Update app.py
app.py CHANGED
@@ -8,6 +8,7 @@ from janus.models import MultiModalityCausalLM, VLChatProcessor
 from dataclasses import dataclass
 import spaces
 
+# This dataclass definition is required for the processor
 @dataclass
 class VLChatProcessorOutput():
     sft_format: str
@@ -24,9 +25,8 @@ def process_image(image_paths, vl_chat_processor):
     images_outputs = vl_chat_processor.image_processor(images, return_tensors="pt")
     return images_outputs['pixel_values']
 
-# === Load Janus model ===
-#
-# In a local environment, you might need to adjust paths or download assets.
+# === Load Janus model and processor ===
+# This setup assumes the necessary model files are accessible.
 model_path = "FreedomIntelligence/Janus-4o-7B"
 vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
@@ -66,7 +66,7 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
 
     with torch.inference_mode():
         input_image_pixel_values = process_image(input_images, vl_chat_processor).to(torch.bfloat16).cuda()
-
+        _, _, info_input = vl_gpt.gen_vision_model.encode(input_image_pixel_values)
         image_tokens_input = info_input[2].detach().reshape(input_image_pixel_values.shape[0], -1)
         image_embeds_input = vl_gpt.prepare_gen_img_embeds(image_tokens_input)
 
@@ -99,13 +99,15 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
             inputs_embeds[ind[0], offset: offset + image_embeds_input.shape[1], :] = image_embeds_input[(ii // 2) % img_len]
 
         generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
+
+        # --- FIX: Initialize past_key_values for cached generation ---
        past_key_values = None
 
        for i in range(image_token_num_per_image):
            outputs = vl_gpt.language_model.model(
                inputs_embeds=inputs_embeds,
                use_cache=True,
-                past_key_values=past_key_values
+                past_key_values=past_key_values # Pass cached values
            )
            hidden_states = outputs.last_hidden_state
 
@@ -124,7 +126,8 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
             next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
             img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
             inputs_embeds = img_embeds.unsqueeze(dim=1)
-
+
+            # --- FIX: Update past_key_values with the output from the current step ---
             past_key_values = outputs.past_key_values
 
         dec = vl_gpt.gen_vision_model.decode_code(
@@ -184,13 +187,14 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
     generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
 
+    # --- FIX: Initialize past_key_values for cached generation ---
     past_key_values = None
 
     for i in range(image_token_num_per_image):
         outputs = vl_gpt.language_model.model(
             inputs_embeds=inputs_embeds,
             use_cache=True,
-            past_key_values=past_key_values
+            past_key_values=past_key_values # Pass cached values
         )
 
         hidden_states = outputs.last_hidden_state
@@ -208,6 +212,7 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
         img_embeds = vl_gpt.prepare_gen_img_embeds(next_token_expanded)
         inputs_embeds = img_embeds.unsqueeze(dim=1)
 
+        # --- FIX: Update past_key_values with the output from the current step ---
         past_key_values = outputs.past_key_values
 
     dec = vl_gpt.gen_vision_model.decode_code(
@@ -244,53 +249,62 @@ def janus_chat_responder(message, history):
     prompt = message["text"]
     uploaded_files = message["files"]
 
-
-
-
-
-
-        try:
+    try:
+        if uploaded_files:
+            # Handle text+image to image generation
+            temp_image_path = uploaded_files[0]
             images = text_and_image_to_image_generate(
                 prompt, temp_image_path, output_path, vl_chat_processor, vl_gpt
             )
-
-
-        except Exception as e:
-            return f"Error during image-to-image generation: {str(e)}"
-
-    else:
-        # Handle text-to-image generation
-        try:
+        else:
+            # Handle text-to-image generation
             images = text_to_image_generate(prompt, output_path, vl_chat_processor, vl_gpt)
-
-
-
-
+
+        # Return a gallery component to display all generated images
+        return gr.Gallery(value=images, label="Generated Images")
+
+    except Exception as e:
+        # Return a user-friendly error message
+        gr.Error(f"An error occurred during generation: {str(e)}")
+        # Return None or an empty list for the gallery to clear it
+        return None
 
 
-# ===
+# === Gradio UI with a single ChatInterface ===
 with gr.Blocks(theme="soft", title="Janus Image Generation") as demo:
     gr.Markdown("# Janus Multi-Modal Image Generation")
     gr.Markdown("Generate images from text prompts, or upload an image and a prompt to transform it.")
 
+    # Using gr.ChatInterface which handles the chat history and input box automatically
     gr.ChatInterface(
         fn=janus_chat_responder,
-        multimodal=True,
-        title="Janus-4o-7B
+        multimodal=True, # Enables file uploads
+        title="Janus-4o-7B",
+        chatbot=gr.Chatbot(height=400, label="Chat", show_label=False),
+        textbox=gr.MultimodalTextbox(
+            file_types=["image"],
+            placeholder="Type a prompt or upload an image...",
+            label="Input"
+        ),
         examples=[
-            {"text": "
-            {"text": "
-            {"text": "
-            {"text": "Turn this into a watercolor painting", "files": ["./assets/example_image.jpg"]}
+            {"text": "A cat made of glass, sitting on a table.", "files": []},
+            {"text": "A futuristic city at sunset, with flying cars.", "files": []},
+            {"text": "A dragon breathing fire over a medieval castle.", "files": []},
+            {"text": "Turn this into a watercolor painting.", "files": ["./assets/example_image.jpg"]}
        ]
    )
 
 if __name__ == "__main__":
-    # Create a dummy image for the example if it doesn't exist
-
-
-    if not os.path.exists(
-
-
+    # Create a dummy image for the example if it doesn't exist to prevent errors
+    assets_dir = "./assets"
+    example_image_path = os.path.join(assets_dir, "example_image.jpg")
+    if not os.path.exists(example_image_path):
+        os.makedirs(assets_dir, exist_ok=True)
+        try:
+            dummy_image = Image.new('RGB', (384, 384), color = 'red')
+            dummy_image.save(example_image_path)
+            print(f"Created dummy example image at: {example_image_path}")
+        except Exception as e:
+            print(f"Could not create dummy image: {e}")
 
     demo.launch()