# SmolVLM-500M anime caption demo, served as a Hugging Face Space running on ZeroGPU.

import threading

import gradio as gr
import torch
from PIL import Image
import requests
import spaces
from transformers import (AutoProcessor, Idefics3ForConditionalGeneration,
                          TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList)

base_model_id = "Andres77872/SmolVLM-500M-anime-caption-v0.2"

processor = AutoProcessor.from_pretrained(base_model_id)
model = Idefics3ForConditionalGeneration.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16,
).to("cuda:0")
class StopOnTokens(StoppingCriteria):
    def __init__(self, tokenizer, stop_sequence):
        super().__init__()
        self.tokenizer = tokenizer
        self.stop_sequence = stop_sequence

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        new_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        # Only the tail of the decoded text needs to be checked for the stop sequence.
        max_keep = len(self.stop_sequence) + 10
        if len(new_text) > max_keep:
            new_text = new_text[-max_keep:]
        return self.stop_sequence in new_text
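

# Streaming caption generator used by the Gradio interface below. The
# @spaces.GPU decorator is an assumption here: the Space runs on ZeroGPU, and
# ZeroGPU only attaches a GPU to the worker while a decorated function runs.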
@spaces.GPU
def caption_anime_image_stream(image):
    if image is None:
        yield "Please upload an image."
        return

    question = "describe the image"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        }
    ]
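
    # Clamp the resize target so the longest image edge never exceeds what the
    # image processor supports.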
    max_image_size = processor.image_processor.max_image_size["longest_edge"]
    size = processor.image_processor.size.copy()
    if "longest_edge" in size and size["longest_edge"] > max_image_size:
        size["longest_edge"] = max_image_size

    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=[prompt], images=[[image]], return_tensors='pt', padding=True, size=size)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    stop_sequence = "</RATING>"
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    custom_stopping_criteria = StoppingCriteriaList([
        StopOnTokens(processor.tokenizer, stop_sequence)
    ])
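
    # Run generation in a background thread so partial captions can be yielded
    # to the UI as the streamer produces new text.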
    with torch.no_grad():
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            do_sample=False,
            max_new_tokens=1024,
            pad_token_id=processor.tokenizer.pad_token_id,
            stopping_criteria=custom_stopping_criteria,
        )
        generation_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
        generation_thread.start()

        caption = ""
        for new_text in streamer:
            caption += new_text
            yield caption.strip()
        generation_thread.join()


demo = gr.Interface(
    caption_anime_image_stream,
    inputs=gr.Image(type="pil", label="Anime Image"),
    outputs=gr.Textbox(lines=8, label="Caption"),
    title="SmolVLM-500M-anime-caption-v0.2 Demo",
    description="Upload an anime-style image to generate a caption.",
    allow_flagging="auto",
    examples=None,
)

# Streaming (generator) outputs are delivered through the Gradio queue.
demo.queue()
demo.launch()
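
# Quick local smoke test outside the Gradio UI (illustrative only; assumes a
# CUDA device and a local image file named "example.png"):
#
#     from PIL import Image
#     for partial_caption in caption_anime_image_stream(Image.open("example.png")):
#         print(partial_caption)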