Andres77872 committed (verified)
Commit fb1e20b · 1 parent: 0842d98

Update app.py
Files changed (1): app.py (+90, -4)
app.py CHANGED
@@ -1,7 +1,93 @@
 import gradio as gr
+import torch
+import threading
+from PIL import Image
+import requests
+from transformers import AutoProcessor, Idefics3ForConditionalGeneration, TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
 
-def greet(name):
-    return "Hello " + name + "!!"
+base_model_id = "Andres77872/SmolVLM-500M-anime-caption-v0.1"
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+# Load the processor and the fine-tuned SmolVLM model once at startup.
+processor = AutoProcessor.from_pretrained(base_model_id)
+model = Idefics3ForConditionalGeneration.from_pretrained(
+    base_model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16
+)
+
+class StopOnTokens(StoppingCriteria):
+    """Stop generation as soon as the decoded output contains the stop sequence."""
+
+    def __init__(self, tokenizer, stop_sequence):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.stop_sequence = stop_sequence
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        # Only the tail of the decoded text can contain a newly completed stop sequence.
+        new_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
+        max_keep = len(self.stop_sequence) + 10
+        if len(new_text) > max_keep:
+            new_text = new_text[-max_keep:]
+        return self.stop_sequence in new_text
+
+def prepare_inputs(image: Image.Image):
+    """Build the chat prompt and move the processor outputs to the model device."""
+    question = "describe the image"
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {"type": "text", "text": question}
+            ]
+        }
+    ]
+    # Cap the resize target so the longest edge never exceeds the processor's maximum.
+    max_image_size = processor.image_processor.max_image_size["longest_edge"]
+    size = processor.image_processor.size.copy()
+    if "longest_edge" in size and size["longest_edge"] > max_image_size:
+        size["longest_edge"] = max_image_size
+    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+    inputs = processor(text=[prompt], images=[[image]], return_tensors='pt', padding=True, size=size)
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    return inputs
+
+def caption_anime_image_stream(image):
+    """Generator that yields the caption incrementally as tokens are produced."""
+    if image is None:
+        yield "Please upload an image."
+        return
+    inputs = prepare_inputs(image)
+    stop_sequence = "</QUERY>"
+    streamer = TextIteratorStreamer(
+        processor.tokenizer,
+        skip_prompt=True,
+        skip_special_tokens=True,
+    )
+    custom_stopping_criteria = StoppingCriteriaList([
+        StopOnTokens(processor.tokenizer, stop_sequence)
+    ])
+    with torch.no_grad():
+        generation_kwargs = dict(
+            **inputs,
+            streamer=streamer,
+            do_sample=False,
+            max_new_tokens=512,
+            pad_token_id=processor.tokenizer.pad_token_id,
+            stopping_criteria=custom_stopping_criteria,
+        )
+        # Run generation in a background thread so the streamer can be consumed here.
+        generation_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+        generation_thread.start()
+        caption = ""
+        for new_text in streamer:
+            caption += new_text
+            yield caption.strip()
+        generation_thread.join()
+
+demo = gr.Interface(
+    caption_anime_image_stream,
+    inputs=gr.Image(type="pil", label="Anime Image"),
+    outputs=gr.Textbox(lines=8, label="Caption"),
+    title="SmolVLM-500M-Anime-Caption Demo",
+    description="Upload an anime-style image to generate a caption.",
+    allow_flagging="auto",  # log every submission automatically
+    examples=None,
+)
+
+# queue() is required so the generator above can stream partial captions to the UI.
+demo.queue()
+demo.launch()
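
For reference, a minimal sketch of how the streaming captioner defined above could be exercised outside the Gradio UI. It assumes the setup code and functions from app.py have already been run in the current Python session (with demo.launch() skipped), and "sample.png" is a placeholder path, not part of the commit.

# Sketch: drive caption_anime_image_stream directly, without the Gradio UI.
# Assumes model, processor and the functions from app.py are already defined
# in this session; "sample.png" is a placeholder image path.
from PIL import Image

image = Image.open("sample.png").convert("RGB")
caption = ""
for partial in caption_anime_image_stream(image):
    caption = partial        # each yield is the full caption accumulated so far
print(caption)               # final caption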