Andres77872 committed (verified)
Commit b71dc51 · Parent: eaf2a17

Update app.py

Files changed (1): app.py (+8 -10)
app.py CHANGED
@@ -9,9 +9,8 @@ base_model_id = "Andres77872/SmolVLM-500M-anime-caption-v0.1"
 processor = AutoProcessor.from_pretrained(base_model_id)
 model = Idefics3ForConditionalGeneration.from_pretrained(
     base_model_id,
-    device_map="auto",
     torch_dtype=torch.bfloat16
-)
+).to("cuda:0")
 
 class StopOnTokens(StoppingCriteria):
     def __init__(self, tokenizer, stop_sequence):
@@ -26,7 +25,12 @@ class StopOnTokens(StoppingCriteria):
         new_text = new_text[-max_keep:]
         return self.stop_sequence in new_text
 
-def prepare_inputs(image: Image.Image):
+@spaces.GPU
+def caption_anime_image_stream(image):
+    if image is None:
+        yield "Please upload an image."
+        return
+
     question = "describe the image"
     messages = [
         {
@@ -44,13 +48,7 @@ def prepare_inputs(image: Image.Image):
     prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
     inputs = processor(text=[prompt], images=[[image]], return_tensors='pt', padding=True, size=size)
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
-    return inputs
-
-def caption_anime_image_stream(image):
-    if image is None:
-        yield "Please upload an image."
-        return
-    inputs = prepare_inputs(image)
+
     stop_sequence = "</QUERY>"
     streamer = TextIteratorStreamer(
         processor.tokenizer,
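
Taken together, this commit pins the model to a single GPU at load time (replacing device_map="auto" with an explicit .to("cuda:0")) and folds the former prepare_inputs helper into the streaming handler, which is now decorated with @spaces.GPU, the decorator ZeroGPU Spaces use to attach a GPU only while the function runs. Below is a minimal sketch of the resulting pattern; the imports, message payload, streamer flags, max_new_tokens, and the background generation thread are illustrative assumptions, and the size argument and StopOnTokens wiring from the full app.py are omitted here.

import threading

import spaces  # ZeroGPU helper package available on Hugging Face Spaces
import torch
from transformers import (AutoProcessor, Idefics3ForConditionalGeneration,
                          TextIteratorStreamer)

base_model_id = "Andres77872/SmolVLM-500M-anime-caption-v0.1"
processor = AutoProcessor.from_pretrained(base_model_id)
# Load in bfloat16 and pin to the first GPU, as the commit does, instead of
# letting device_map="auto" decide placement.
model = Idefics3ForConditionalGeneration.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16,
).to("cuda:0")

@spaces.GPU  # the GPU is attached only while this generator runs (ZeroGPU pattern)
def caption_anime_image_stream(image):
    if image is None:
        yield "Please upload an image."
        return

    # Build the chat prompt; the exact message payload is an assumption based on
    # the "describe the image" question visible as context in the diff above.
    messages = [{
        "role": "user",
        "content": [{"type": "image"},
                    {"type": "text", "text": "describe the image"}],
    }]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=[prompt], images=[[image]], return_tensors="pt", padding=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Stream tokens as they are produced; generation runs in a background thread.
    streamer = TextIteratorStreamer(processor.tokenizer,
                                    skip_prompt=True, skip_special_tokens=True)
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=512),
    )
    thread.start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text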