import threading

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    Idefics3ForConditionalGeneration,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

base_model_id = "Andres77872/SmolVLM-500M-anime-caption-v0.2"

processor = AutoProcessor.from_pretrained(base_model_id)
model = Idefics3ForConditionalGeneration.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16,
).to("cuda:0")


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as a given stop string appears in the decoded output."""

    def __init__(self, tokenizer, stop_sequence):
        super().__init__()
        self.tokenizer = tokenizer
        self.stop_sequence = stop_sequence

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode without skipping special tokens so the stop marker stays visible,
        # then keep only the tail of the text: the stop string cannot begin earlier
        # than len(stop_sequence) + 10 characters from the end.
        new_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=False)
        max_keep = len(self.stop_sequence) + 10
        if len(new_text) > max_keep:
            new_text = new_text[-max_keep:]
        return self.stop_sequence in new_text


@spaces.GPU
def caption_anime_image_stream(image: Image.Image):
    if image is None:
        yield "Please upload an image."
        return

    question = "describe the image"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        }
    ]

    # Clamp the resize target so the longest edge never exceeds the
    # image processor's maximum image size.
    max_image_size = processor.image_processor.max_image_size["longest_edge"]
    size = processor.image_processor.size.copy()
    if "longest_edge" in size and size["longest_edge"] > max_image_size:
        size["longest_edge"] = max_image_size

    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        text=[prompt],
        images=[[image]],
        return_tensors="pt",
        padding=True,
        size=size,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # End-of-turn marker used by the SmolVLM / Idefics3 chat template.
    stop_sequence = "<end_of_utterance>"

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    custom_stopping_criteria = StoppingCriteriaList(
        [StopOnTokens(processor.tokenizer, stop_sequence)]
    )

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=False,
        max_new_tokens=1024,
        pad_token_id=processor.tokenizer.pad_token_id,
        stopping_criteria=custom_stopping_criteria,
    )

    # Run generation in a background thread so the streamer can be consumed
    # here and partial captions yielded to the UI as they arrive.
    generation_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    generation_thread.start()

    caption = ""
    for new_text in streamer:
        caption += new_text
        yield caption.strip()

    generation_thread.join()


demo = gr.Interface(
    caption_anime_image_stream,
    inputs=gr.Image(type="pil", label="Anime Image"),
    outputs=gr.Textbox(lines=8, label="Caption"),
    title="SmolVLM-500M-anime-caption-v0.2 Demo",
    description="Upload an anime-style image to generate a caption.",
    allow_flagging="auto",
    examples=None,
)

# Queue requests so the generator's intermediate yields stream to the client.
demo.queue()
demo.launch()