breadlicker45 committed on
Commit
1de48dc
·
verified ·
1 Parent(s): eef3b6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -23,18 +23,18 @@ def load_model():
23
  )
24
 
25
  # Load the processor and model using the correct identifier
26
- model_id = "google/paligemma2-10b-pt-224"
27
- processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  model = PaliGemmaForConditionalGeneration.from_pretrained(
30
- model_id, torch_dtype=torch.bfloat16, token=token
31
  ).to(device).eval()
32
 
33
  return processor, model
34
 
35
 
36
  @spaces.GPU(duration=120) # Increased timeout to 120 seconds
37
- def process_image_and_text(image_pil, text_input):
38
  """Extract text from image using PaliGemma2."""
39
  try:
40
  processor, model = load_model()
@@ -43,6 +43,9 @@ def process_image_and_text(image_pil, text_input):
43
  # Load the image using load_image
44
  image = load_image(image_pil)
45
 
 
 
 
46
  # Use the provided text input
47
  model_inputs = processor(text=text_input, images=image, return_tensors="pt").to(
48
  device, dtype=torch.bfloat16
@@ -50,7 +53,7 @@ def process_image_and_text(image_pil, text_input):
50
  input_len = model_inputs["input_ids"].shape[-1]
51
 
52
  with torch.inference_mode():
53
- generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
54
  generation = generation[0][input_len:]
55
  decoded = processor.decode(generation, skip_special_tokens=True)
56
 
@@ -66,6 +69,7 @@ if __name__ == "__main__":
66
  inputs=[
67
  gr.Image(type="pil", label="Upload an image"),
68
  gr.Textbox(label="Enter Text Prompt"),
 
69
  ],
70
  outputs=gr.Textbox(label="Generated Text"),
71
  title="PaliGemma2 Image and Text to Text",
 
23
  )
24
 
25
  # Load the processor and model using the correct identifier
26
+ model_id = "google/paligemma2-28b-pt-896"
27
+ processor = PaliGemmaProcessor.from_pretrained(model_id, use_auth_token=token)
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  model = PaliGemmaForConditionalGeneration.from_pretrained(
30
+ model_id, torch_dtype=torch.bfloat16, use_auth_token=token
31
  ).to(device).eval()
32
 
33
  return processor, model
34
 
35
 
36
  @spaces.GPU(duration=120) # Increased timeout to 120 seconds
37
+ def process_image_and_text(image_pil, text_input, num_beams):
38
  """Extract text from image using PaliGemma2."""
39
  try:
40
  processor, model = load_model()
 
43
  # Load the image using load_image
44
  image = load_image(image_pil)
45
 
46
+ # Add <image> token to the beginning of the text prompt
47
+ text_input = "<image> " + text_input
48
+
49
  # Use the provided text input
50
  model_inputs = processor(text=text_input, images=image, return_tensors="pt").to(
51
  device, dtype=torch.bfloat16
 
53
  input_len = model_inputs["input_ids"].shape[-1]
54
 
55
  with torch.inference_mode():
56
+ generation = model.generate(**model_inputs, max_new_tokens=200, do_sample=False, num_beams=num_beams)
57
  generation = generation[0][input_len:]
58
  decoded = processor.decode(generation, skip_special_tokens=True)
59
 
 
69
  inputs=[
70
  gr.Image(type="pil", label="Upload an image"),
71
  gr.Textbox(label="Enter Text Prompt"),
72
+ gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Beams"),
73
  ],
74
  outputs=gr.Textbox(label="Generated Text"),
75
  title="PaliGemma2 Image and Text to Text",