# Hugging Face Space - running on ZeroGPU ("Zero") hardware.
import spaces | |
import gradio as gr | |
import os | |
import sys | |
import subprocess | |
import numpy as np | |
from paligemma2 import PaliGemma2Handler, MODELS as PALIGEMMA_MODELS | |
from gemma import GemmaHandler, MODELS as GEMMA_MODELS | |
from gemma_multiline import GemmaMultilineHandler, MODELS as GEMMA_MULTILINE_MODELS | |
# Model handlers: one instance per model family, created once at import time
# and reused by every wrapper function below. Construction order is
# PaliGemma2, Gemma, Gemma multiline (tuple elements evaluate left to right).
paligemma_handler, gemma_handler, gemma_multiline_handler = (
    PaliGemma2Handler(),
    GemmaHandler(),
    GemmaMultilineHandler(),
)
def process_image_paligemma(model_name, image, progress=gr.Progress()):
    """Extract text from one image using the selected PaliGemma2 checkpoint.

    Thin adapter between the Gradio UI and ``paligemma_handler``: the
    arguments are passed through positionally and the handler's result is
    returned unchanged. The ``gr.Progress()`` default lets Gradio inject a
    progress tracker when the function runs as an event handler.
    """
    run_ocr = paligemma_handler.process_image
    return run_ocr(model_name, image, progress)
def process_image_gemma(model_name, image, progress=gr.Progress()):
    """Extract text from one image using the selected Gemma checkpoint.

    Thin adapter between the Gradio UI and ``gemma_handler``: the arguments
    are passed through positionally and the handler's result is returned
    unchanged. The ``gr.Progress()`` default lets Gradio inject a progress
    tracker when the function runs as an event handler.
    """
    run_ocr = gemma_handler.process_image
    return run_ocr(model_name, image, progress)
def process_pdf_paligemma(pdf_path, model_name, progress=gr.Progress()):
    """Extract text from a PDF using the selected PaliGemma2 checkpoint.

    Unlike the image variant, the PDF comes first and the model name second.
    This mirrors the argument order of ``paligemma_handler.process_pdf``.
    The handler's result is returned unchanged.
    """
    call_args = (pdf_path, model_name, progress)
    return paligemma_handler.process_pdf(*call_args)
def process_pdf_gemma(pdf_path, model_name, progress=gr.Progress()):
    """Extract text from a PDF using the selected Gemma checkpoint.

    Unlike the image variant, the PDF comes first and the model name second.
    This mirrors the argument order of ``gemma_handler.process_pdf``.
    The handler's result is returned unchanged.
    """
    call_args = (pdf_path, model_name, progress)
    return gemma_handler.process_pdf(*call_args)
def process_image_multiline(model_name, image, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Transcribe a multi-line document image in a single (non-streaming) call.

    ``temp``, ``top_p`` and ``repetition_penalty`` are sampling settings.
    They are forwarded unchanged, in that order, to
    ``gemma_multiline_handler.generate_text_from_image``, and its result is
    returned as-is.
    """
    sampling = (temp, top_p, repetition_penalty)
    return gemma_multiline_handler.generate_text_from_image(model_name, image, *sampling, progress)
def process_image_multiline_stream(model_name, image, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Stream the transcription of a multi-line document image.

    Generator that re-yields every item produced by
    ``gemma_multiline_handler.generate_text_stream``. ``yield from`` is kept
    deliberately rather than a for-loop. If the handler returns a generator,
    close()/throw() on this stream are then forwarded to it.
    """
    sampling = (temp, top_p, repetition_penalty)
    yield from gemma_multiline_handler.generate_text_stream(model_name, image, *sampling, progress)
def process_pdf_multiline(model_name, pdf, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Transcribe a PDF with the multi-line Gemma model (non-streaming).

    ``temp``, ``top_p`` and ``repetition_penalty`` are forwarded unchanged,
    in that order, to ``gemma_multiline_handler.process_pdf``. Its result is
    returned as-is.
    """
    sampling = (temp, top_p, repetition_penalty)
    return gemma_multiline_handler.process_pdf(model_name, pdf, *sampling, progress)
def process_pdf_multiline_stream(model_name, pdf, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Stream the transcription of a PDF with the multi-line Gemma model.

    Generator that re-yields every item produced by
    ``gemma_multiline_handler.process_pdf_stream``. ``yield from`` is kept
    deliberately rather than a for-loop. If the handler returns a generator,
    close()/throw() on this stream are then forwarded to it.
    """
    sampling = (temp, top_p, repetition_penalty)
    yield from gemma_multiline_handler.process_pdf_stream(model_name, pdf, *sampling, progress)
# Example images for document-level (multi-line) OCR.
# Each row is [image filename, caption]; the UI passes only the filename
# column to gr.Examples.
document_examples = [
    list(row)
    for row in (
        ("ml.png", "Multi-line Dhivehi text sample"),
        ("ml1.png", "Multi-line Dhivehi text sample 2"),
        ("ml2.png", "Multi-line Dhivehi text sample 3"),
        ("ml3.png", "Multi-line Dhivehi text sample 4"),
    )
]

# Example images for sentence-level OCR (typed and handwritten samples),
# in the same [image filename, caption] row format.
sentence_examples = [
    list(row)
    for row in (
        ("type_1_sl.png", "Typed Dhivehi text sample 1"),
        ("type_2_sl.png", "Typed Dhivehi text sample 2"),
        ("hw_1_sl.png", "Handwritten Dhivehi text sample 1"),
        ("hw_2_sl.jpg", "Handwritten Dhivehi text sample 2"),
        ("hw_3_sl.png", "Handwritten Dhivehi text sample 3"),
        ("hw_4_sl.png", "Handwritten Dhivehi text sample 4"),
        ("ml.png", "Multi-line Dhivehi text sample"),
    )
]
# Custom CSS passed to gr.Blocks. The ".textbox1" class is attached (via
# elem_classes) to every textbox that displays extracted Dhivehi text. It
# enlarges the type, prefers Dhivehi (Thaana) fonts and adds line spacing.
# A generic "sans-serif" fallback is listed last so the browser still picks a
# sensible family when none of the named fonts are installed. (The former
# ".textbox2" rule was removed: no component uses that class.)
css = """
.textbox1 textarea {
    font-size: 18px !important;
    font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma', sans-serif !important;
    line-height: 1.8 !important;
}
"""
def _set_streaming_controls(streaming):
    """Return gr.update()s for the (stop, stream, generate) buttons.

    While streaming, the Stop button is visible and both generate buttons
    are disabled. Otherwise Stop is hidden and the generate buttons are
    re-enabled.
    """
    return (
        gr.update(visible=streaming),
        gr.update(interactive=not streaming),
        gr.update(interactive=not streaming),
    )


# Named wrappers are kept (rather than passing lambdas or one shared function)
# so each event keeps its original function name, which Gradio uses as the
# default api_name.
def show_stop_button_image():
    """Enter streaming mode on the document Image tab."""
    return _set_streaming_controls(True)


def hide_stop_button_image():
    """Leave streaming mode on the document Image tab."""
    return _set_streaming_controls(False)


def show_stop_button_pdf():
    """Enter streaming mode on the document PDF tab."""
    return _set_streaming_controls(True)


def hide_stop_button_pdf():
    """Leave streaming mode on the document PDF tab."""
    return _set_streaming_controls(False)


def _build_result_tabs():
    """Create the "Extracted Text" / "Detected Text Regions" output tabs.

    Returns:
        (textbox, gallery): the RTL text output and the region gallery.
    """
    with gr.Tabs():
        with gr.Tab("Extracted Text"):
            text_box = gr.Textbox(
                lines=5,
                label="Extracted Text",
                show_copy_button=True,
                rtl=True,
                elem_classes="textbox1",
            )
        with gr.Tab("Detected Text Regions"):
            region_gallery = gr.Gallery(
                label="Detected Text Regions",
                show_label=True,
                columns=2,
            )
    return text_box, region_gallery


def _build_sentence_ocr_tab(tab_title, models, dropdown_label):
    """Create a sentence-level OCR tab with "Image Input" and "PDF Input" sub-tabs.

    Components are created in the same order as the hand-written tabs they
    replace, so the rendered layout is unchanged.

    Args:
        tab_title: Label of the top-level tab.
        models: Mapping whose keys fill the model dropdown; the first key is
            the default selection.
        dropdown_label: Label of the model dropdown.

    Returns:
        Tuple ``(model_dropdown, image_input, image_submit_btn,
        image_text_output, image_bbox_output, pdf_input, pdf_submit_btn,
        pdf_text_output, pdf_bbox_output)``.
    """
    with gr.Tab(tab_title):
        model_dropdown = gr.Dropdown(
            choices=list(models.keys()),
            value=list(models.keys())[0],
            label=dropdown_label,
        )
        with gr.Tabs():
            with gr.Tab("Image Input"):
                with gr.Row():
                    with gr.Column(scale=2):
                        image_input = gr.Image(type="pil", label="Input Image")
                        image_submit_btn = gr.Button("Extract Text")
                        gr.Examples(
                            examples=[[img] for img, _ in sentence_examples],
                            inputs=[image_input],
                            label="Example Images",
                            examples_per_page=8,
                        )
                    with gr.Column(scale=3):
                        image_text_output, image_bbox_output = _build_result_tabs()
            with gr.Tab("PDF Input"):
                with gr.Row():
                    with gr.Column(scale=2):
                        pdf_input = gr.File(label="Input PDF", file_types=[".pdf"])
                        pdf_submit_btn = gr.Button("Extract Text from PDF")
                        # One column per input component (the caption column
                        # was never used by the single File input).
                        gr.Examples(
                            examples=[["example.pdf"]],
                            inputs=[pdf_input],
                            label="Example PDFs",
                            examples_per_page=8,
                        )
                    with gr.Column(scale=3):
                        pdf_text_output, pdf_bbox_output = _build_result_tabs()
    return (
        model_dropdown,
        image_input,
        image_submit_btn,
        image_text_output,
        image_bbox_output,
        pdf_input,
        pdf_submit_btn,
        pdf_text_output,
        pdf_bbox_output,
    )


with gr.Blocks(title="Dhivehi Image to Text", css=css) as demo:
    gr.Markdown("# Dhivehi Image to Text")
    gr.Markdown("Dhivehi Image to Text experimental finetunes")
    with gr.Tabs():
        # --- Document-level (multi-line) OCR with sampling controls ---
        with gr.Tab("Gemma Document"):
            with gr.Row():
                model_path_dropdown = gr.Dropdown(
                    label="Model Checkpoint",
                    choices=list(GEMMA_MULTILINE_MODELS.keys()),
                    value=list(GEMMA_MULTILINE_MODELS.keys())[0],
                    interactive=True,
                    scale=2,
                )
            with gr.Accordion("Advanced Options", open=False):
                with gr.Row():
                    temperature_slider = gr.Slider(
                        minimum=0.1, maximum=1.9, value=0.2, step=0.1,
                        label="Temperature", info="Controls randomness in generation",
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.1, maximum=1.0, value=1, step=0.1,
                        label="Top-p", info="Controls diversity via nucleus sampling",
                    )
                    repetition_penalty_slider = gr.Slider(
                        minimum=1.0, maximum=2.0, value=1.2, step=0.1,
                        label="Repetition Penalty", info="Penalizes repeated tokens. >1 encourages new tokens.",
                    )
            # Sampling sliders in the positional order the
            # process_*_multiline* functions expect after (model, image/pdf).
            sampling_sliders = [temperature_slider, top_p_slider, repetition_penalty_slider]
            with gr.Tabs():
                with gr.Tab("Image Input"):
                    with gr.Row():
                        with gr.Column():
                            image_input = gr.Image(type="pil", label="Upload Image")
                            with gr.Row():
                                generate_button = gr.Button("Generate Text (Non-streaming)")
                                stream_button = gr.Button("Generate Text (Streaming)", variant="primary")
                                stop_button = gr.Button("Stop", visible=False, variant="stop")
                            gr.Examples(
                                examples=[[img] for img, _ in document_examples],
                                inputs=[image_input],
                                outputs=None,
                                label="Example Images",
                                examples_per_page=7,
                            )
                        with gr.Column():
                            text_output = gr.Textbox(
                                label="Extracted Dhivehi Text",
                                lines=20,
                                rtl=True,
                                elem_classes=["textbox1"],
                                show_copy_button=True,
                                scale=2,
                            )
                    doc_image_inputs = [model_path_dropdown, image_input, *sampling_sliders]
                    doc_image_controls = [stop_button, stream_button, generate_button]
                    generate_button.click(
                        fn=process_image_multiline,
                        inputs=doc_image_inputs,
                        outputs=text_output,
                        show_progress="full",
                    )
                    # Streaming flow: show Stop and lock the generate buttons,
                    # stream the text, then restore the buttons. Stop cancels
                    # the stream and restores the buttons directly.
                    show_event = stream_button.click(fn=show_stop_button_image, outputs=doc_image_controls)
                    gen_event = show_event.then(
                        fn=process_image_multiline_stream,
                        inputs=doc_image_inputs,
                        outputs=text_output,
                        show_progress="full",
                    )
                    gen_event.then(fn=hide_stop_button_image, outputs=doc_image_controls)
                    stop_button.click(fn=hide_stop_button_image, outputs=doc_image_controls, cancels=[gen_event])
                with gr.Tab("PDF Input"):
                    with gr.Row():
                        with gr.Column():
                            pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
                            with gr.Row():
                                pdf_generate_button = gr.Button("Generate Text (Non-streaming)")
                                pdf_stream_button = gr.Button("Generate Text (Streaming)", variant="primary")
                                pdf_stop_button = gr.Button("Stop", visible=False, variant="stop")
                            # One column per input component (the caption
                            # column was never used by the single File input).
                            gr.Examples(
                                examples=[["example.pdf"]],
                                inputs=[pdf_input],
                                outputs=None,
                                label="Example PDFs",
                                examples_per_page=7,
                            )
                        with gr.Column():
                            pdf_text_output = gr.Textbox(
                                label="Extracted Dhivehi Text",
                                lines=20,
                                rtl=True,
                                elem_classes=["textbox1"],
                                show_copy_button=True,
                                scale=2,
                            )
                    doc_pdf_inputs = [model_path_dropdown, pdf_input, *sampling_sliders]
                    doc_pdf_controls = [pdf_stop_button, pdf_stream_button, pdf_generate_button]
                    pdf_generate_button.click(
                        fn=process_pdf_multiline,
                        inputs=doc_pdf_inputs,
                        outputs=pdf_text_output,
                        show_progress="full",
                    )
                    # Same streaming / stop flow as the Image tab.
                    pdf_show_event = pdf_stream_button.click(fn=show_stop_button_pdf, outputs=doc_pdf_controls)
                    pdf_gen_event = pdf_show_event.then(
                        fn=process_pdf_multiline_stream,
                        inputs=doc_pdf_inputs,
                        outputs=pdf_text_output,
                        show_progress="full",
                    )
                    pdf_gen_event.then(fn=hide_stop_button_pdf, outputs=doc_pdf_controls)
                    pdf_stop_button.click(fn=hide_stop_button_pdf, outputs=doc_pdf_controls, cancels=[pdf_gen_event])

        # --- Sentence-level OCR tabs (identical layout, different models) ---
        (
            model_dropdown_paligemma,
            image_input_paligemma,
            image_submit_btn_paligemma,
            image_text_output_paligemma,
            image_bbox_output_paligemma,
            pdf_input_paligemma,
            pdf_submit_btn_paligemma,
            pdf_text_output_paligemma,
            pdf_bbox_output_paligemma,
        ) = _build_sentence_ocr_tab("PaliGemma", PALIGEMMA_MODELS, "Select PaliGemma Model")
        (
            model_dropdown_gemma,
            image_input_gemma,
            image_submit_btn_gemma,
            image_text_output_gemma,
            image_bbox_output_gemma,
            pdf_input_gemma,
            pdf_submit_btn_gemma,
            pdf_text_output_gemma,
            pdf_bbox_output_gemma,
        ) = _build_sentence_ocr_tab("Gemma Sentence", GEMMA_MODELS, "Select Gemma Model")

    # Event wiring for the sentence-level tabs. Registration order is kept
    # unchanged so event indices match the previous layout.
    # PaliGemma event handlers
    image_submit_btn_paligemma.click(
        fn=process_image_paligemma,
        inputs=[model_dropdown_paligemma, image_input_paligemma],
        outputs=[image_text_output_paligemma, image_bbox_output_paligemma],
    )
    # PDF handlers take (pdf, model) - note the reversed input order.
    pdf_submit_btn_paligemma.click(
        fn=process_pdf_paligemma,
        inputs=[pdf_input_paligemma, model_dropdown_paligemma],
        outputs=[pdf_text_output_paligemma, pdf_bbox_output_paligemma],
    )
    # Gemma event handlers
    image_submit_btn_gemma.click(
        fn=process_image_gemma,
        inputs=[model_dropdown_gemma, image_input_gemma],
        outputs=[image_text_output_gemma, image_bbox_output_gemma],
    )
    pdf_submit_btn_gemma.click(
        fn=process_pdf_gemma,
        inputs=[pdf_input_gemma, model_dropdown_gemma],
        outputs=[pdf_text_output_gemma, pdf_bbox_output_gemma],
    )
# Function to install requirements
def install_requirements(requirements_path='requirements.txt'):
    """Install packages from a pip requirements file into this interpreter.

    Runs ``<sys.executable> -m pip install -r <requirements_path> --no-cache-dir``,
    so the packages go into the same environment the app is running in.

    Args:
        requirements_path: Path to the requirements file. Defaults to
            ``requirements.txt`` relative to the current working directory.

    Returns:
        True if pip exited with status 0. False if the path is not a regular
        file, pip reported an error, or pip could not be started. Failures
        are printed, not raised.
    """
    # isfile (not exists): a directory would otherwise be handed to pip.
    if not os.path.isfile(requirements_path):
        print(f"Error: {requirements_path} not found")
        return False
    print("Installing requirements...")
    try:
        # Using --no-cache-dir to avoid memory issues
        subprocess.check_call([
            sys.executable,
            "-m",
            "pip",
            "install",
            "-r",
            requirements_path,
            "--no-cache-dir",
        ])
    except subprocess.CalledProcessError as e:
        # pip ran but exited non-zero.
        print(f"Error installing requirements: {e}")
        return False
    except OSError as e:
        # The interpreter could not be launched at all; anything else is a
        # genuine bug and is allowed to propagate.
        print(f"Unexpected error: {e}")
        return False
    print("Successfully installed all requirements")
    return True
# Launch the app
if __name__ == "__main__":
    # Install requirements.txt first and launch the UI only if that succeeds.
    # NOTE(review): the app's own dependencies (gradio, spaces, the handler
    # modules) are imported at the top of this file, before this step runs,
    # so pip cannot supply them for the current process - confirm this step
    # is still needed.
    success = install_requirements()
    if success:
        print("All requirements installed successfully")
        # Optional warm-up, currently disabled: pre-load the default
        # multiline Gemma model before serving.
        # print("Loading default Gemma Multiline model...")
        # gemma_multiline_handler.load_model(list(GEMMA_MULTILINE_MODELS.keys())[0])
        # print("Default model loaded.")
        demo.launch()
    else:
        # Report once (the message used to be printed twice) and exit
        # non-zero so the failed startup is visible to whatever runs the app.
        print("Failed to install some requirements")
        sys.exit(1)