akhaliq (HF Staff) committed
Commit 3622941 · verified · 1 parent: 05acf34

Update app.py

Files changed (1)
  1. app.py +52 -38
app.py CHANGED
@@ -8,6 +8,7 @@ from janus.models import MultiModalityCausalLM, VLChatProcessor
 from dataclasses import dataclass
 import spaces
 
+# This dataclass definition is required for the processor
 @dataclass
 class VLChatProcessorOutput():
     sft_format: str
@@ -24,9 +25,8 @@ def process_image(image_paths, vl_chat_processor):
     images_outputs = vl_chat_processor.image_processor(images, return_tensors="pt")
     return images_outputs['pixel_values']
 
-# === Load Janus model ===
-# NOTE: This section assumes the model and processor can be loaded.
-# In a local environment, you might need to adjust paths or download assets.
+# === Load Janus model and processor ===
+# This setup assumes the necessary model files are accessible.
 model_path = "FreedomIntelligence/Janus-4o-7B"
 vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
@@ -66,7 +66,7 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
 
     with torch.inference_mode():
         input_image_pixel_values = process_image(input_images, vl_chat_processor).to(torch.bfloat16).cuda()
-        quant_input, emb_loss_input, info_input = vl_gpt.gen_vision_model.encode(input_image_pixel_values)
+        _, _, info_input = vl_gpt.gen_vision_model.encode(input_image_pixel_values)
         image_tokens_input = info_input[2].detach().reshape(input_image_pixel_values.shape[0], -1)
         image_embeds_input = vl_gpt.prepare_gen_img_embeds(image_tokens_input)
 
@@ -99,13 +99,15 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
     inputs_embeds[ind[0], offset: offset + image_embeds_input.shape[1], :] = image_embeds_input[(ii // 2) % img_len]
 
     generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
+
+    # --- FIX: Initialize past_key_values for cached generation ---
     past_key_values = None
 
     for i in range(image_token_num_per_image):
         outputs = vl_gpt.language_model.model(
             inputs_embeds=inputs_embeds,
             use_cache=True,
-            past_key_values=past_key_values
+            past_key_values=past_key_values  # Pass cached values
         )
         hidden_states = outputs.last_hidden_state
 
@@ -124,7 +126,8 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
     next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
     img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
     inputs_embeds = img_embeds.unsqueeze(dim=1)
-
+
+    # --- FIX: Update past_key_values with the output from the current step ---
     past_key_values = outputs.past_key_values
 
     dec = vl_gpt.gen_vision_model.decode_code(
@@ -184,13 +187,14 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
     generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
 
+    # --- FIX: Initialize past_key_values for cached generation ---
    past_key_values = None
 
    for i in range(image_token_num_per_image):
        outputs = vl_gpt.language_model.model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
-            past_key_values=past_key_values
+            past_key_values=past_key_values  # Pass cached values
        )
 
        hidden_states = outputs.last_hidden_state
@@ -208,6 +212,7 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
     img_embeds = vl_gpt.prepare_gen_img_embeds(next_token_expanded)
     inputs_embeds = img_embeds.unsqueeze(dim=1)
 
+    # --- FIX: Update past_key_values with the output from the current step ---
     past_key_values = outputs.past_key_values
 
     dec = vl_gpt.gen_vision_model.decode_code(
@@ -244,53 +249,62 @@ def janus_chat_responder(message, history):
     prompt = message["text"]
     uploaded_files = message["files"]
 
-    if uploaded_files:
-        # Handle text+image to image generation
-        # Assuming the first uploaded file is the image to process
-        temp_image_path = uploaded_files[0]
-
-        try:
+    try:
+        if uploaded_files:
+            # Handle text+image to image generation
+            temp_image_path = uploaded_files[0]
             images = text_and_image_to_image_generate(
                 prompt, temp_image_path, output_path, vl_chat_processor, vl_gpt
             )
-            # FIX APPLIED HERE: Return a gr.Gallery component to display all generated images
-            return gr.Gallery(value=images, label="Generated Images")
-        except Exception as e:
-            return f"Error during image-to-image generation: {str(e)}"
-
-    else:
-        # Handle text-to-image generation
-        try:
+        else:
+            # Handle text-to-image generation
             images = text_to_image_generate(prompt, output_path, vl_chat_processor, vl_gpt)
-            # FIX APPLIED HERE: Return a gr.Gallery component to display all generated images
-            return gr.Gallery(value=images, label="Generated Images")
-        except Exception as e:
-            return f"Error during text-to-image generation: {str(e)}"
+
+        # Return a gallery component to display all generated images
+        return gr.Gallery(value=images, label="Generated Images")
+
+    except Exception as e:
+        # Return a user-friendly error message
+        gr.Error(f"An error occurred during generation: {str(e)}")
+        # Return None or an empty list for the gallery to clear it
+        return None
 
 
-# === Simplified Gradio UI with a single ChatInterface ===
+# === Gradio UI with a single ChatInterface ===
 with gr.Blocks(theme="soft", title="Janus Image Generation") as demo:
     gr.Markdown("# Janus Multi-Modal Image Generation")
     gr.Markdown("Generate images from text prompts, or upload an image and a prompt to transform it.")
 
+    # Using gr.ChatInterface which handles the chat history and input box automatically
     gr.ChatInterface(
         fn=janus_chat_responder,
-        multimodal=True,
-        title="Janus-4o-7B Chat",
+        multimodal=True,  # Enables file uploads
+        title="Janus-4o-7B",
+        chatbot=gr.Chatbot(height=400, label="Chat", show_label=False),
+        textbox=gr.MultimodalTextbox(
+            file_types=["image"],
+            placeholder="Type a prompt or upload an image...",
+            label="Input"
+        ),
         examples=[
-            {"text": "a cat sitting on a windowsill", "files": []},
-            {"text": "a futuristic city at sunset", "files": []},
-            {"text": "a dragon flying over mountains", "files": []},
-            {"text": "Turn this into a watercolor painting", "files": ["./assets/example_image.jpg"]}
+            {"text": "A cat made of glass, sitting on a table.", "files": []},
+            {"text": "A futuristic city at sunset, with flying cars.", "files": []},
+            {"text": "A dragon breathing fire over a medieval castle.", "files": []},
+            {"text": "Turn this into a watercolor painting.", "files": ["./assets/example_image.jpg"]}
         ]
     )
 
 if __name__ == "__main__":
-    # Create a dummy image for the example if it doesn't exist
-    if not os.path.exists("./assets"):
-        os.makedirs("./assets")
-    if not os.path.exists("./assets/example_image.jpg"):
-        dummy_image = Image.new('RGB', (100, 100), color = 'red')
-        dummy_image.save("./assets/example_image.jpg")
+    # Create a dummy image for the example if it doesn't exist to prevent errors
+    assets_dir = "./assets"
+    example_image_path = os.path.join(assets_dir, "example_image.jpg")
+    if not os.path.exists(example_image_path):
+        os.makedirs(assets_dir, exist_ok=True)
+        try:
+            dummy_image = Image.new('RGB', (384, 384), color = 'red')
+            dummy_image.save(example_image_path)
+            print(f"Created dummy example image at: {example_image_path}")
+        except Exception as e:
+            print(f"Could not create dummy image: {e}")
 
  demo.launch()
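
For context on the past_key_values hunks above: they annotate the cache handling in the two generation loops (initialize the cache before the loop, pass it into each forward call, and update it from the outputs). This is the standard incremental-decoding pattern for Hugging Face causal language models, where reusing the cache means only the newest token is reprocessed each step. The sketch below is a minimal illustration of that pattern, not the Janus code; GPT-2 as a stand-in model and greedy decoding are assumptions chosen for brevity rather than anything taken from app.py.

# Minimal sketch of cached autoregressive decoding (assumption: GPT-2 as a
# stand-in causal LM; the Janus loops above follow the same cache-passing idea).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("a cat sitting on a windowsill", return_tensors="pt").input_ids
generated = input_ids
next_input = input_ids
past_key_values = None  # initialize the cache before the loop

with torch.inference_mode():
    for _ in range(20):
        outputs = model(
            input_ids=next_input,
            use_cache=True,
            past_key_values=past_key_values,  # reuse cached keys/values
        )
        past_key_values = outputs.past_key_values  # update the cache each step
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        next_input = next_token  # only the newest token is fed back in

print(tokenizer.decode(generated[0]))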
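Similarly, for the janus_chat_responder changes: with multimodal=True, gr.ChatInterface passes each user turn as a dict containing "text" and "files", which is the structure the responder unpacks before branching between text-to-image and image-to-image generation. The following is a minimal, self-contained sketch of that wiring, with an echo responder standing in for the real generation functions; it assumes a recent Gradio 4.x release with MultimodalTextbox support.

# Minimal sketch of a multimodal ChatInterface (assumption: Gradio 4.x; the echo
# responder stands in for the Janus generation functions in app.py).
import gradio as gr

def echo_responder(message, history):
    # With multimodal=True the message is a dict: {"text": str, "files": [paths]}
    text = message["text"]
    files = message["files"]
    if files:
        return f"Would run image-to-image on {files[0]} with prompt: {text}"
    return f"Would run text-to-image with prompt: {text}"

demo = gr.ChatInterface(
    fn=echo_responder,
    multimodal=True,  # enables file uploads in the textbox
    textbox=gr.MultimodalTextbox(
        file_types=["image"],
        placeholder="Type a prompt or upload an image...",
    ),
)

if __name__ == "__main__":
    demo.launch()

Returning a gr.Gallery component from the responder, as the commit does, lets a single chat reply display several generated images at once instead of plain text.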