Spaces:
Runtime error
Runtime error
Update caption_generator.py
Browse files- caption_generator.py +30 -16
caption_generator.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
'''
|
2 |
-
python caption_generator.py /path/to/input
|
3 |
'''
|
4 |
|
5 |
import argparse
|
@@ -9,6 +9,7 @@ from torch import nn
|
|
9 |
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
|
10 |
from PIL import Image
|
11 |
import torchvision.transforms.functional as TVF
|
|
|
12 |
|
13 |
# Constants
|
14 |
CLIP_PATH = "google/siglip-so400m-patch14-384"
|
@@ -174,8 +175,8 @@ def generate_caption(input_image: Image.Image, caption_type: str, caption_length
|
|
174 |
# Main function
|
175 |
def main():
|
176 |
parser = argparse.ArgumentParser(description="Generate a caption for an image.")
|
177 |
-
parser.add_argument("
|
178 |
-
parser.add_argument("output_path", type=str, help="Path to save the output
|
179 |
parser.add_argument("--caption_type", type=str, default="Descriptive", choices=CAPTION_TYPE_MAP.keys(), help="Type of caption to generate")
|
180 |
parser.add_argument("--caption_length", type=str, default="long", help="Length of the caption")
|
181 |
parser.add_argument("--extra_options", nargs="*", type=int, default=[], help="Extra options for caption generation (provide IDs separated by spaces)")
|
@@ -189,24 +190,37 @@ def main():
|
|
189 |
# Load models
|
190 |
clip_processor, clip_model, tokenizer, text_model, image_adapter = load_models()
|
191 |
|
192 |
-
#
|
193 |
-
|
|
|
|
|
|
|
|
|
194 |
|
195 |
-
#
|
196 |
-
prompt_str, caption = generate_caption(input_image, args.caption_type, args.caption_length, selected_extra_options, args.name_input, args.custom_prompt, clip_processor, clip_model, tokenizer, text_model, image_adapter)
|
197 |
-
|
198 |
-
# Save caption and image
|
199 |
output_path = Path(args.output_path)
|
200 |
output_path.mkdir(parents=True, exist_ok=True)
|
201 |
-
image_name = Path(args.input_image).name.replace(" ", "_")
|
202 |
-
output_image_path = output_path / image_name
|
203 |
-
input_image.save(output_image_path)
|
204 |
|
205 |
-
|
206 |
-
|
207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
|
209 |
-
|
|
|
210 |
|
211 |
if __name__ == "__main__":
|
212 |
# Print extra options with IDs for reference
|
|
|
1 |
'''
|
2 |
+
python caption_generator.py /path/to/input /path/to/output/directory --caption_type "Descriptive" --caption_length "long" --extra_options 0 2 5 --name_input "John"
|
3 |
'''
|
4 |
|
5 |
import argparse
|
|
|
9 |
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
|
10 |
from PIL import Image
|
11 |
import torchvision.transforms.functional as TVF
|
12 |
+
from tqdm import tqdm # 引入 tqdm 用于显示进度条
|
13 |
|
14 |
# Constants
|
15 |
CLIP_PATH = "google/siglip-so400m-patch14-384"
|
|
|
175 |
# Main function
|
176 |
def main():
|
177 |
parser = argparse.ArgumentParser(description="Generate a caption for an image.")
|
178 |
+
parser.add_argument("input_path", type=str, help="Path to the input image or directory containing images")
|
179 |
+
parser.add_argument("output_path", type=str, help="Path to save the output captions and images")
|
180 |
parser.add_argument("--caption_type", type=str, default="Descriptive", choices=CAPTION_TYPE_MAP.keys(), help="Type of caption to generate")
|
181 |
parser.add_argument("--caption_length", type=str, default="long", help="Length of the caption")
|
182 |
parser.add_argument("--extra_options", nargs="*", type=int, default=[], help="Extra options for caption generation (provide IDs separated by spaces)")
|
|
|
190 |
# Load models
|
191 |
clip_processor, clip_model, tokenizer, text_model, image_adapter = load_models()
|
192 |
|
193 |
+
# Determine if input is a directory or a single file
|
194 |
+
input_path = Path(args.input_path)
|
195 |
+
if input_path.is_dir():
|
196 |
+
image_paths = list(input_path.glob("*.[pjP][npP][gG]")) + list(input_path.glob("*.[jJ][pP][eE][gG]")) # 支持 PNG 和 JPEG 格式
|
197 |
+
else:
|
198 |
+
image_paths = [input_path]
|
199 |
|
200 |
+
# Create output directory if it doesn't exist
|
|
|
|
|
|
|
201 |
output_path = Path(args.output_path)
|
202 |
output_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
203 |
|
204 |
+
# Process each image
|
205 |
+
for image_path in tqdm(image_paths, desc="Processing images"):
|
206 |
+
try:
|
207 |
+
# Open the input image
|
208 |
+
input_image = Image.open(image_path)
|
209 |
+
|
210 |
+
# Generate caption
|
211 |
+
prompt_str, caption = generate_caption(input_image, args.caption_type, args.caption_length, selected_extra_options, args.name_input, args.custom_prompt, clip_processor, clip_model, tokenizer, text_model, image_adapter)
|
212 |
+
|
213 |
+
# Save caption and image
|
214 |
+
image_name = image_path.name.replace(" ", "_")
|
215 |
+
output_image_path = output_path / image_name
|
216 |
+
input_image.save(output_image_path)
|
217 |
+
|
218 |
+
txt_file_path = output_path / f"{output_image_path.stem}.txt"
|
219 |
+
with open(txt_file_path, "w") as f:
|
220 |
+
f.write(f"Prompt: {prompt_str}\n\nCaption: {caption}")
|
221 |
|
222 |
+
except Exception as e:
|
223 |
+
print(f"Error processing {image_path}: {e}")
|
224 |
|
225 |
if __name__ == "__main__":
|
226 |
# Print extra options with IDs for reference
|