svjack commited on
Commit
957334c
·
verified ·
1 Parent(s): 565c0d8

Update caption_generator.py

Browse files
Files changed (1) hide show
  1. caption_generator.py +30 -16
caption_generator.py CHANGED
@@ -1,5 +1,5 @@
1
  '''
2
- python caption_generator.py /path/to/input/image.jpg /path/to/output/directory --caption_type "Descriptive" --caption_length "long" --extra_options 0 2 5 --name_input "John"
3
  '''
4
 
5
  import argparse
@@ -9,6 +9,7 @@ from torch import nn
9
  from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
10
  from PIL import Image
11
  import torchvision.transforms.functional as TVF
 
12
 
13
  # Constants
14
  CLIP_PATH = "google/siglip-so400m-patch14-384"
@@ -174,8 +175,8 @@ def generate_caption(input_image: Image.Image, caption_type: str, caption_length
174
  # Main function
175
  def main():
176
  parser = argparse.ArgumentParser(description="Generate a caption for an image.")
177
- parser.add_argument("input_image", type=str, help="Path to the input image")
178
- parser.add_argument("output_path", type=str, help="Path to save the output caption and image")
179
  parser.add_argument("--caption_type", type=str, default="Descriptive", choices=CAPTION_TYPE_MAP.keys(), help="Type of caption to generate")
180
  parser.add_argument("--caption_length", type=str, default="long", help="Length of the caption")
181
  parser.add_argument("--extra_options", nargs="*", type=int, default=[], help="Extra options for caption generation (provide IDs separated by spaces)")
@@ -189,24 +190,37 @@ def main():
189
  # Load models
190
  clip_processor, clip_model, tokenizer, text_model, image_adapter = load_models()
191
 
192
- # Open the input image
193
- input_image = Image.open(args.input_image)
 
 
 
 
194
 
195
- # Generate caption
196
- prompt_str, caption = generate_caption(input_image, args.caption_type, args.caption_length, selected_extra_options, args.name_input, args.custom_prompt, clip_processor, clip_model, tokenizer, text_model, image_adapter)
197
-
198
- # Save caption and image
199
  output_path = Path(args.output_path)
200
  output_path.mkdir(parents=True, exist_ok=True)
201
- image_name = Path(args.input_image).name.replace(" ", "_")
202
- output_image_path = output_path / image_name
203
- input_image.save(output_image_path)
204
 
205
- txt_file_path = output_path / f"{output_image_path.stem}.txt"
206
- with open(txt_file_path, "w") as f:
207
- f.write(f"Prompt: {prompt_str}\n\nCaption: {caption}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
- print(f"Caption saved to {txt_file_path}")
 
210
 
211
  if __name__ == "__main__":
212
  # Print extra options with IDs for reference
 
1
  '''
2
+ python caption_generator.py /path/to/input /path/to/output/directory --caption_type "Descriptive" --caption_length "long" --extra_options 0 2 5 --name_input "John"
3
  '''
4
 
5
  import argparse
 
9
  from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
10
  from PIL import Image
11
  import torchvision.transforms.functional as TVF
12
+ from tqdm import tqdm # import tqdm to display a progress bar
13
 
14
  # Constants
15
  CLIP_PATH = "google/siglip-so400m-patch14-384"
 
175
  # Main function
176
  def main():
177
  parser = argparse.ArgumentParser(description="Generate a caption for an image.")
178
+ parser.add_argument("input_path", type=str, help="Path to the input image or directory containing images")
179
+ parser.add_argument("output_path", type=str, help="Path to save the output captions and images")
180
  parser.add_argument("--caption_type", type=str, default="Descriptive", choices=CAPTION_TYPE_MAP.keys(), help="Type of caption to generate")
181
  parser.add_argument("--caption_length", type=str, default="long", help="Length of the caption")
182
  parser.add_argument("--extra_options", nargs="*", type=int, default=[], help="Extra options for caption generation (provide IDs separated by spaces)")
 
190
  # Load models
191
  clip_processor, clip_model, tokenizer, text_model, image_adapter = load_models()
192
 
193
+ # Determine if input is a directory or a single file
194
+ input_path = Path(args.input_path)
195
+ if input_path.is_dir():
196
+ image_paths = list(input_path.glob("*.[pjP][npP][gG]")) + list(input_path.glob("*.[jJ][pP][eE][gG]")) # supports PNG and JPEG formats — NOTE(review): the first pattern also matches unintended extensions such as ".ppg"/".jng" and misses ".PNG"/".JPG" (no uppercase N/J in the classes); a case-insensitive suffix check would be safer
197
+ else:
198
+ image_paths = [input_path]
199
 
200
+ # Create output directory if it doesn't exist
 
 
 
201
  output_path = Path(args.output_path)
202
  output_path.mkdir(parents=True, exist_ok=True)
 
 
 
203
 
204
+ # Process each image
205
+ for image_path in tqdm(image_paths, desc="Processing images"):
206
+ try:
207
+ # Open the input image
208
+ input_image = Image.open(image_path)
209
+
210
+ # Generate caption
211
+ prompt_str, caption = generate_caption(input_image, args.caption_type, args.caption_length, selected_extra_options, args.name_input, args.custom_prompt, clip_processor, clip_model, tokenizer, text_model, image_adapter)
212
+
213
+ # Save caption and image
214
+ image_name = image_path.name.replace(" ", "_")
215
+ output_image_path = output_path / image_name
216
+ input_image.save(output_image_path)
217
+
218
+ txt_file_path = output_path / f"{output_image_path.stem}.txt"
219
+ with open(txt_file_path, "w") as f:
220
+ f.write(f"Prompt: {prompt_str}\n\nCaption: {caption}")
221
 
222
+ except Exception as e:
223
+ print(f"Error processing {image_path}: {e}")
224
 
225
  if __name__ == "__main__":
226
  # Print extra options with IDs for reference