svjack committed
Commit 565c0d8 · verified · 1 parent: f54cdb7

Create caption_generator.py

Files changed (1):
caption_generator.py +216 -0
caption_generator.py ADDED
@@ -0,0 +1,216 @@
'''
python caption_generator.py /path/to/input/image.jpg /path/to/output/directory --caption_type "Descriptive" --caption_length "long" --extra_options 0 2 5 --name_input "John"
'''

import argparse
from pathlib import Path
import torch
from torch import nn
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
from PIL import Image
import torchvision.transforms.functional as TVF

# Constants
CLIP_PATH = "google/siglip-so400m-patch14-384"
CHECKPOINT_PATH = Path("cgrkzexw-599808")

# Extra options with IDs for easy selection
EXTRA_OPTIONS = [
    "If there is a person/character in the image you must refer to them as {name}.",
    "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
    "Include information about lighting.",
    "Include information about camera angle.",
    "Include information about whether there is a watermark or not.",
    "Include information about whether there are JPEG artifacts or not.",
    "If it is a photo you MUST include information about what camera was likely used and details such as aperture, shutter speed, ISO, etc.",
    "Do NOT include anything sexual; keep it PG.",
    "Do NOT mention the image's resolution.",
    "You MUST include information about the subjective aesthetic quality of the image from low to very high.",
    "Include information on the image's composition style, such as leading lines, rule of thirds, or symmetry.",
    "Do NOT mention any text that is in the image.",
    "Specify the depth of field and whether the background is in focus or blurred.",
    "If applicable, mention the likely use of artificial or natural lighting sources.",
    "Do NOT use any ambiguous language.",
    "Include whether the image is sfw, suggestive, or nsfw.",
    "ONLY describe the most important elements of the image."
]

# Image Adapter
class ImageAdapter(nn.Module):
    def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
        super().__init__()
        self.deep_extract = deep_extract
        if self.deep_extract:
            input_features = input_features * 5
        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
        self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))
        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, vision_outputs: torch.Tensor):
        if self.deep_extract:
            x = torch.concat((vision_outputs[-2], vision_outputs[3], vision_outputs[7], vision_outputs[13], vision_outputs[20]), dim=-1)
        else:
            x = vision_outputs[-2]
        x = self.ln1(x)
        if self.pos_emb is not None:
            x = x + self.pos_emb
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
        x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
        return x

    def get_eot_embedding(self):
        return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)

# Load models
def load_models():
    print("Loading CLIP")
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    clip_model = AutoModel.from_pretrained(CLIP_PATH)
    clip_model = clip_model.vision_model
    checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu')
    checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
    clip_model.load_state_dict(checkpoint)
    clip_model.eval()
    clip_model.requires_grad_(False)
    clip_model.to("cuda")

    print("Loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH / "text_model", use_fast=True)

    print("Loading LLM")
    text_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16)
    text_model.eval()

    print("Loading image adapter")
    image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False)
    image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu"))
    image_adapter.eval()
    image_adapter.to("cuda")

    return clip_processor, clip_model, tokenizer, text_model, image_adapter

# Generate caption
@torch.no_grad()
def generate_caption(input_image: Image.Image, caption_type: str, caption_length: str | int, extra_options: list[str], name_input: str, custom_prompt: str, clip_processor, clip_model, tokenizer, text_model, image_adapter):
    torch.cuda.empty_cache()

    # Build prompt
    length = None if caption_length == "any" else caption_length
    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            pass
    map_idx = 0 if length is None else 1 if isinstance(length, int) else 2
    prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]

    if len(extra_options) > 0:
        prompt_str += " " + " ".join(extra_options)
    prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)

    if custom_prompt.strip() != "":
        prompt_str = custom_prompt.strip()

    # Preprocess image
    image = input_image.resize((384, 384), Image.LANCZOS)
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    pixel_values = pixel_values.to('cuda')

    # Embed image
    with torch.amp.autocast_mode.autocast('cuda', enabled=True):
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        embedded_images = image_adapter(vision_outputs.hidden_states)
        embedded_images = embedded_images.to('cuda')

    # Build conversation
    convo = [
        {"role": "system", "content": "You are a helpful image captioner."},
        {"role": "user", "content": prompt_str},
    ]
    convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
    prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False)
    convo_tokens = convo_tokens.squeeze(0)
    prompt_tokens = prompt_tokens.squeeze(0)

    # Calculate where to inject the image
    eot_id_indices = (convo_tokens == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
    preamble_len = eot_id_indices[1] - prompt_tokens.shape[0]

    # Embed the tokens
    convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to('cuda'))

    # Construct the input
    input_embeds = torch.cat([
        convo_embeds[:, :preamble_len],
        embedded_images.to(dtype=convo_embeds.dtype),
        convo_embeds[:, preamble_len:],
    ], dim=1).to('cuda')

    input_ids = torch.cat([
        convo_tokens[:preamble_len].unsqueeze(0),
        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
        convo_tokens[preamble_len:].unsqueeze(0),
    ], dim=1).to('cuda')
    attention_mask = torch.ones_like(input_ids)

    # Generate caption
    generate_ids = text_model.generate(input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None)
    generate_ids = generate_ids[:, input_ids.shape[1]:]
    if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
        generate_ids = generate_ids[:, :-1]
    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]

    return prompt_str, caption.strip()

# Main function
def main():
    parser = argparse.ArgumentParser(description="Generate a caption for an image.")
    parser.add_argument("input_image", type=str, help="Path to the input image")
    parser.add_argument("output_path", type=str, help="Path to save the output caption and image")
    parser.add_argument("--caption_type", type=str, default="Descriptive", choices=CAPTION_TYPE_MAP.keys(), help="Type of caption to generate")
    parser.add_argument("--caption_length", type=str, default="long", help="Length of the caption")
    parser.add_argument("--extra_options", nargs="*", type=int, default=[], help="Extra options for caption generation (provide IDs separated by spaces)")
    parser.add_argument("--name_input", type=str, default="", help="Name of the person/character in the image (if applicable)")
    parser.add_argument("--custom_prompt", type=str, default="", help="Custom prompt to override default settings")
    args = parser.parse_args()

    # Map extra option IDs to their corresponding strings
    selected_extra_options = [EXTRA_OPTIONS[i] for i in args.extra_options]

    # Load models
    clip_processor, clip_model, tokenizer, text_model, image_adapter = load_models()

    # Open the input image
    input_image = Image.open(args.input_image)

    # Generate caption
    prompt_str, caption = generate_caption(input_image, args.caption_type, args.caption_length, selected_extra_options, args.name_input, args.custom_prompt, clip_processor, clip_model, tokenizer, text_model, image_adapter)

    # Save caption and image
    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)
    image_name = Path(args.input_image).name.replace(" ", "_")
    output_image_path = output_path / image_name
    input_image.save(output_image_path)

    txt_file_path = output_path / f"{output_image_path.stem}.txt"
    with open(txt_file_path, "w") as f:
        f.write(f"Prompt: {prompt_str}\n\nCaption: {caption}")

    print(f"Caption saved to {txt_file_path}")

if __name__ == "__main__":
    # Print extra options with IDs for reference
    print("Extra Options:")
    for i, option in enumerate(EXTRA_OPTIONS):
        print(f"{i}: {option}")
    main()
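
Note: the script references CAPTION_TYPE_MAP (in generate_caption and as the choices for --caption_type) but never defines it, so it raises a NameError as committed. Below is a minimal sketch of what the missing dictionary would need to look like, assuming the three-slot layout implied by map_idx (index 0 for no length hint, 1 for a numeric word count, 2 for a descriptive length word) and covering only the default "Descriptive" type; the exact prompt wording and any additional caption types are assumptions, not taken from this commit.

# Hypothetical reconstruction: not part of the committed file.
# generate_caption() indexes CAPTION_TYPE_MAP[caption_type][map_idx], where
# map_idx is 0 (no length constraint), 1 (numeric {word_count}), or 2 ({length} word).
CAPTION_TYPE_MAP = {
    "Descriptive": [
        "Write a descriptive caption for this image in a formal tone.",
        "Write a descriptive caption for this image in a formal tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a formal tone.",
    ],
}

With a dictionary of this shape defined near the other constants, the example invocation in the module docstring should run end to end, provided the cgrkzexw-599808 checkpoint directory is present and a CUDA device is available.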