# Page residue captured with this file: "Spaces: Runtime error" (status banner, not part of the program).
from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
# Checkpoint used by the app: a publicly available, high-capacity Pix2Struct
# model ("google/pix2struct-docvqa-large").
# To switch to a different or private checkpoint, change the name below and
# add authentication if the hub requires it.
model_name = "google/pix2struct-docvqa-large"

# Both objects are loaded once at import time and shared by every request.
processor = Pix2StructProcessor.from_pretrained(model_name)
model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
def solve_problem(image):
    """Run the Pix2Struct model on an uploaded image and format its answer.

    Args:
        image: PIL image handed over by the Gradio ``Image`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        On success, a two-line string: ``Problem: <prompt>`` followed by
        ``Solution: <generated text>``. On any failure, a string starting
        with ``Error processing image:`` describing what went wrong; the
        function never raises, so the UI always gets text to display.
    """
    prompt = "Solve the following problem:"

    # Submitting the form without an image passes None; report that plainly
    # instead of surfacing an AttributeError message.
    if image is None:
        return "Error processing image: no image was provided."

    try:
        # Ensure the image is in RGB.
        rgb_image = image.convert("RGB")

        # Preprocess image and text prompt.
        inputs = processor(
            images=[rgb_image],
            text=prompt,
            return_tensors="pt",
            max_patches=2048,
        )

        # Generate prediction with beam search. No `temperature` is passed:
        # it only affects sampling, which is not enabled here, so supplying
        # it merely triggers a warning without changing the output.
        predictions = model.generate(
            **inputs,
            max_new_tokens=200,
            early_stopping=True,
            num_beams=4,
        )

        solution = processor.decode(
            predictions[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    except Exception as e:
        # UI boundary: turn any failure into a readable message.
        return f"Error processing image: {e}"

    # For VQA-style Pix2Struct checkpoints the processor renders the prompt
    # onto the image as a header and returns no "input_ids", so the prompt
    # text is reported directly rather than decoded from the inputs.
    return f"Problem: {prompt}\nSolution: {solution}"
# Set up the Gradio interface.
# Only advertise example images that are actually present in the working
# directory: referencing a missing example file can make building the
# interface fail at startup, which would take the whole app down.
_EXAMPLE_FILES = ("example_problem1.png", "example_problem2.jpg")
_available_examples = [
    [name] for name in _EXAMPLE_FILES if Path(name).is_file()
]

iface = gr.Interface(
    fn=solve_problem,
    inputs=gr.Image(type="pil", label="Upload Your Problem Image", image_mode="RGB"),
    outputs=gr.Textbox(label="Solution", show_copy_button=True),
    title="Problem Solver with Pix2Struct",
    description=(
        "Upload an image (for example, a handwritten math or logic problem) "
        "and get a solution generated by a high-capacity Pix2Struct model.\n\n"
        "Note: For best results on domain-specific tasks, consider fine-tuning on your own dataset."
    ),
    # Pass None (no examples section) when none of the files are available.
    examples=_available_examples or None,
    theme="soft",
    allow_flagging="never",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()