import logging
from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

logger = logging.getLogger(__name__)

# Use a publicly available high-capacity model.
# For instance, we use "google/pix2struct-docvqa-large".
# (If you need a different model or a private one, adjust accordingly and add
# authentication if necessary.)
model_name = "google/pix2struct-docvqa-large"

# Run on GPU when one is available; the inputs are moved to the same device
# inside solve_problem.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Pix2StructForConditionalGeneration.from_pretrained(model_name).to(device)
processor = Pix2StructProcessor.from_pretrained(model_name)

# Text prompt sent alongside every image.
PROMPT = "Solve the following problem:"

# Example images offered in the UI. They are optional: only the ones present
# on disk are passed to Gradio, so missing files do not break the interface.
EXAMPLE_FILES = ("example_problem1.png", "example_problem2.jpg")


def solve_problem(image):
    """Run the Pix2Struct model on an uploaded image and format its answer.

    Args:
        image: PIL image from the Gradio upload widget, or None when the user
            submits without uploading anything.

    Returns:
        "Problem: <prompt>\\nSolution: <decoded model output>" on success, or a
        string starting with "Error processing image:" on failure. Errors are
        returned as text instead of raised because this function is the
        boundary between the UI and the model.
    """
    if image is None:
        return "Error processing image: please upload an image first."

    try:
        # Ensure the image is in RGB.
        image = image.convert("RGB")

        # Preprocess image and text prompt, then move every tensor to the
        # model's device.
        encoded = processor(
            images=[image],
            text=PROMPT,
            return_tensors="pt",
            max_patches=2048,
        )
        inputs = {key: value.to(device) for key, value in encoded.items()}

        # Deterministic beam search. `temperature` is only honoured when
        # sampling (do_sample=True), so it is deliberately not passed here.
        predictions = model.generate(
            **inputs,
            max_new_tokens=200,
            early_stopping=True,
            num_beams=4,
        )

        solution = processor.decode(
            predictions[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    except Exception as e:
        # Top-level UI boundary: keep the full traceback in the logs but show
        # the user a short message instead of failing the request.
        logger.exception("Failed to process image")
        return f"Error processing image: {str(e)}"

    # The processor output contains no "input_ids" entry to decode the prompt
    # from (VQA checkpoints render the prompt onto the image instead), so the
    # prompt text is reported directly.
    return f"Problem: {PROMPT}\nSolution: {solution}"


# Only offer example files that actually exist on disk.
examples = [[path] for path in EXAMPLE_FILES if Path(path).exists()]

# Set up the Gradio interface.
iface = gr.Interface(
    fn=solve_problem,
    inputs=gr.Image(type="pil", label="Upload Your Problem Image", image_mode="RGB"),
    outputs=gr.Textbox(label="Solution", show_copy_button=True),
    title="Problem Solver with Pix2Struct",
    description=(
        "Upload an image (for example, a handwritten math or logic problem) "
        "and get a solution generated by a high-capacity Pix2Struct model.\n\n"
        "Note: For best results on domain-specific tasks, consider fine-tuning on your own dataset."
    ),
    examples=examples or None,
    theme="soft",
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()