# dhivehi-ocr / app.py
# (Page-header residue from the source listing: "alakxender's picture", "t", "ba9dade")
import spaces
import gradio as gr
import os
import sys
import subprocess
import numpy as np
from paligemma2 import PaliGemma2Handler, MODELS as PALIGEMMA_MODELS
from gemma import GemmaHandler, MODELS as GEMMA_MODELS
from gemma_multiline import GemmaMultilineHandler, MODELS as GEMMA_MULTILINE_MODELS
# Initialize model handlers: one instance per model family, created once at
# import time (in PaliGemma2 -> Gemma -> Gemma-multiline order) and shared by
# the GPU wrapper functions.
paligemma_handler, gemma_handler, gemma_multiline_handler = (
    PaliGemma2Handler(),
    GemmaHandler(),
    GemmaMultilineHandler(),
)
@spaces.GPU
def process_image_paligemma(model_name, image, progress=gr.Progress()):
    """Run PaliGemma2 OCR on a single image.

    Thin ``spaces.GPU``-decorated wrapper: the arguments are forwarded
    positionally to ``paligemma_handler.process_image`` and its result is
    returned unchanged. The UI binds that result to the extracted-text box and
    the detected-regions gallery of the PaliGemma image tab.

    Args:
        model_name: Key chosen in the PaliGemma model dropdown.
        image: Image from the ``gr.Image(type="pil")`` input.
        progress: Progress tracker that Gradio injects for ``gr.Progress`` defaults.
    """
    ocr_result = paligemma_handler.process_image(model_name, image, progress)
    return ocr_result
@spaces.GPU
def process_image_gemma(model_name, image, progress=gr.Progress()):
    """Run sentence-level Gemma OCR on a single image.

    Thin ``spaces.GPU``-decorated wrapper: the arguments are forwarded
    positionally to ``gemma_handler.process_image`` and its result is returned
    unchanged. The UI binds that result to the extracted-text box and the
    detected-regions gallery of the "Gemma Sentence" image tab.

    Args:
        model_name: Key chosen in the Gemma model dropdown.
        image: Image from the ``gr.Image(type="pil")`` input.
        progress: Progress tracker that Gradio injects for ``gr.Progress`` defaults.
    """
    ocr_result = gemma_handler.process_image(model_name, image, progress)
    return ocr_result
@spaces.GPU
def process_pdf_paligemma(pdf_path, model_name, progress=gr.Progress()):
    """Run PaliGemma2 OCR over a PDF file.

    Note the argument order: the PDF comes first, then the model name. This
    matches the ``inputs=[pdf_input_paligemma, model_dropdown_paligemma]``
    wiring in the UI. Arguments are forwarded positionally to
    ``paligemma_handler.process_pdf`` and its result is returned unchanged.

    Args:
        pdf_path: Value of the ``gr.File`` PDF input.
        model_name: Key chosen in the PaliGemma model dropdown.
        progress: Progress tracker that Gradio injects for ``gr.Progress`` defaults.
    """
    pdf_result = paligemma_handler.process_pdf(pdf_path, model_name, progress)
    return pdf_result
@spaces.GPU
def process_pdf_gemma(pdf_path, model_name, progress=gr.Progress()):
    """Run sentence-level Gemma OCR over a PDF file.

    Note the argument order: the PDF comes first, then the model name. This
    matches the ``inputs=[pdf_input_gemma, model_dropdown_gemma]`` wiring in
    the UI. Arguments are forwarded positionally to
    ``gemma_handler.process_pdf`` and its result is returned unchanged.

    Args:
        pdf_path: Value of the ``gr.File`` PDF input.
        model_name: Key chosen in the Gemma model dropdown.
        progress: Progress tracker that Gradio injects for ``gr.Progress`` defaults.
    """
    pdf_result = gemma_handler.process_pdf(pdf_path, model_name, progress)
    return pdf_result
@spaces.GPU
def process_image_multiline(model_name, image, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Run document-level (multi-line) Gemma OCR on one image, non-streaming.

    Forwards every argument positionally to
    ``gemma_multiline_handler.generate_text_from_image`` and returns its
    result. The UI writes that result into a single extracted-text Textbox.

    Args:
        model_name: Checkpoint key from the "Model Checkpoint" dropdown.
        image: Image from the ``gr.Image(type="pil")`` input.
        temp: Temperature slider value.
        top_p: Top-p slider value.
        repetition_penalty: Repetition-penalty slider value.
        progress: Progress tracker that Gradio injects for ``gr.Progress`` defaults.
    """
    extracted = gemma_multiline_handler.generate_text_from_image(
        model_name, image, temp, top_p, repetition_penalty, progress
    )
    return extracted
@spaces.GPU
def process_image_multiline_stream(model_name, image, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Stream document-level Gemma OCR output for one image.

    This is a generator. Each value yielded by
    ``gemma_multiline_handler.generate_text_stream`` is re-yielded to Gradio
    as it arrives. ``yield from`` is kept deliberately so that close/throw
    (e.g. when the Stop button cancels the event) propagate to the inner
    generator.

    Args:
        model_name: Checkpoint key from the "Model Checkpoint" dropdown.
        image: Image from the ``gr.Image(type="pil")`` input.
        temp: Temperature slider value.
        top_p: Top-p slider value.
        repetition_penalty: Repetition-penalty slider value.
        progress: Progress tracker that Gradio injects for ``gr.Progress`` defaults.
    """
    text_stream = gemma_multiline_handler.generate_text_stream(
        model_name, image, temp, top_p, repetition_penalty, progress
    )
    yield from text_stream
@spaces.GPU
def process_pdf_multiline(model_name, pdf, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Run document-level Gemma OCR over a PDF, non-streaming.

    Unlike the PaliGemma/Gemma PDF wrappers, the model name comes first here,
    matching the UI's ``inputs=[model_path_dropdown, pdf_input, ...]`` wiring.
    Arguments are forwarded positionally to
    ``gemma_multiline_handler.process_pdf`` and its result is returned
    unchanged.

    Args:
        model_name: Checkpoint key from the "Model Checkpoint" dropdown.
        pdf: Value of the ``gr.File`` PDF input.
        temp: Temperature slider value.
        top_p: Top-p slider value.
        repetition_penalty: Repetition-penalty slider value.
        progress: Progress tracker that Gradio injects for ``gr.Progress`` defaults.
    """
    extracted = gemma_multiline_handler.process_pdf(
        model_name, pdf, temp, top_p, repetition_penalty, progress
    )
    return extracted
@spaces.GPU
def process_pdf_multiline_stream(model_name, pdf, temp, top_p, repetition_penalty, progress=gr.Progress()):
    """Stream document-level Gemma OCR output for a PDF.

    This is a generator. Each value yielded by
    ``gemma_multiline_handler.process_pdf_stream`` is re-yielded to Gradio as
    it arrives. ``yield from`` is kept so that cancellation (close/throw)
    reaches the inner generator.

    Args:
        model_name: Checkpoint key from the "Model Checkpoint" dropdown.
        pdf: Value of the ``gr.File`` PDF input.
        temp: Temperature slider value.
        top_p: Top-p slider value.
        repetition_penalty: Repetition-penalty slider value.
        progress: Progress tracker that Gradio injects for ``gr.Progress`` defaults.
    """
    page_stream = gemma_multiline_handler.process_pdf_stream(
        model_name, pdf, temp, top_p, repetition_penalty, progress
    )
    yield from page_stream
# Example images for document-level OCR, as [file name, caption] rows.
# Only the file name is fed to gr.Examples; the caption is descriptive.
document_examples = [
    [file_name, caption]
    for file_name, caption in (
        ("ml.png", "Multi-line Dhivehi text sample"),
        ("ml1.png", "Multi-line Dhivehi text sample 2"),
        ("ml2.png", "Multi-line Dhivehi text sample 3"),
        ("ml3.png", "Multi-line Dhivehi text sample 4"),
    )
]
# Example images for sentence-level OCR (typed, handwritten, and one
# multi-line sample), also as [file name, caption] rows.
sentence_examples = [
    [file_name, caption]
    for file_name, caption in (
        ("type_1_sl.png", "Typed Dhivehi text sample 1"),
        ("type_2_sl.png", "Typed Dhivehi text sample 2"),
        ("hw_1_sl.png", "Handwritten Dhivehi text sample 1"),
        ("hw_2_sl.jpg", "Handwritten Dhivehi text sample 2"),
        ("hw_3_sl.png", "Handwritten Dhivehi text sample 3"),
        ("hw_4_sl.png", "Handwritten Dhivehi text sample 4"),
        ("ml.png", "Multi-line Dhivehi text sample"),
    )
]
# Custom CSS passed to gr.Blocks.
#   .textbox1 - font size, Dhivehi font stack and line height for the textareas
#               of components created with elem_classes "textbox1".
#   .textbox2 - hides textareas of components with that class (no component in
#               this file uses it).
css = (
    "\n"
    ".textbox1 textarea {\n"
    "font-size: 18px !important;\n"
    "font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;\n"
    "line-height: 1.8 !important;\n"
    "}\n"
    ".textbox2 textarea {\n"
    "display: none;\n"
    "}\n"
)
with gr.Blocks(title="Dhivehi Image to Text",css=css) as demo:
    # Top-level UI. Three tabs: "Gemma Document" (multi-line OCR, streaming and
    # non-streaming), "PaliGemma" and "Gemma Sentence" (each with Image/PDF
    # sub-tabs). Components render in creation order inside each `with`
    # context, so statement order here defines the layout.
    gr.Markdown("# Dhivehi Image to Text")
    gr.Markdown("Dhivehi Image to Text experimental finetunes")
    with gr.Tabs():
        with gr.Tab("Gemma Document"):
            # Checkpoint selector shared by the Image and PDF sub-tabs below.
            with gr.Row():
                model_path_dropdown = gr.Dropdown(
                    label="Model Checkpoint",
                    choices=list(GEMMA_MULTILINE_MODELS.keys()),
                    value=list(GEMMA_MULTILINE_MODELS.keys())[0],
                    interactive=True,
                    scale=2
                )
            # Sampling controls, also shared by both sub-tabs; their values are
            # passed positionally (temp, top_p, repetition_penalty) to the
            # process_*_multiline* wrappers.
            with gr.Accordion("Advanced Options", open=False):
                with gr.Row():
                    temperature_slider = gr.Slider(
                        minimum=0.1, maximum=1.9, value=0.2, step=0.1,
                        label="Temperature", info="Controls randomness in generation"
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.1, maximum=1.0, value=1, step=0.1,
                        label="Top-p", info="Controls diversity via nucleus sampling"
                    )
                    repetition_penalty_slider = gr.Slider(
                        minimum=1.0, maximum=2.0, value=1.2, step=0.1,
                        label="Repetition Penalty", info="Penalizes repeated tokens. >1 encourages new tokens."
                    )
            with gr.Tabs():
                with gr.Tab("Image Input"):
                    with gr.Row():
                        with gr.Column():
                            image_input = gr.Image(type="pil", label="Upload Image")
                            with gr.Row():
                                generate_button = gr.Button("Generate Text (Non-streaming)")
                                stream_button = gr.Button("Generate Text (Streaming)", variant="primary")
                                # Hidden until a streaming run is started.
                                stop_button = gr.Button("Stop", visible=False, variant="stop")
                            # Only the file name of each example row is used.
                            gr.Examples(
                                examples=[[img] for img, _ in document_examples],
                                inputs=[image_input],
                                outputs=None,
                                label="Example Images",
                                examples_per_page=7
                            )
                        with gr.Column():
                            # "textbox1" picks up the font rules from `css`.
                            text_output = gr.Textbox(
                                label="Extracted Dhivehi Text",
                                lines=20,
                                rtl=True,
                                elem_classes=["textbox1"],
                                show_copy_button=True,
                                scale=2
                            )
                    def show_stop_button_image():
                        """Updates for (stop_button, stream_button, generate_button): show Stop, disable both generate buttons."""
                        return gr.update(visible=True), gr.update(interactive=False), gr.update(interactive=False)
                    def hide_stop_button_image():
                        """Inverse of show_stop_button_image: hide Stop, re-enable both generate buttons."""
                        return gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True)
                    # Non-streaming path: one call, result written to text_output.
                    generate_button.click(
                        fn=process_image_multiline,
                        inputs=[model_path_dropdown, image_input, temperature_slider, top_p_slider, repetition_penalty_slider],
                        outputs=text_output,
                        show_progress="full"
                    )
                    # Streaming chain: toggle buttons -> stream text into
                    # text_output -> restore buttons once the stream event ends.
                    show_event = stream_button.click(fn=show_stop_button_image, outputs=[stop_button, stream_button, generate_button])
                    gen_event = show_event.then(fn=process_image_multiline_stream, inputs=[model_path_dropdown, image_input, temperature_slider, top_p_slider, repetition_penalty_slider], outputs=text_output, show_progress="full")
                    gen_event.then(fn=hide_stop_button_image, outputs=[stop_button, stream_button, generate_button])
                    # Stop cancels the streaming event (gen_event) and also runs
                    # hide_stop_button_image directly to restore the buttons.
                    stop_button.click(fn=hide_stop_button_image, outputs=[stop_button, stream_button, generate_button], cancels=[gen_event])
                with gr.Tab("PDF Input"):
                    with gr.Row():
                        with gr.Column():
                            pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
                            with gr.Row():
                                pdf_generate_button = gr.Button("Generate Text (Non-streaming)")
                                pdf_stream_button = gr.Button("Generate Text (Streaming)", variant="primary")
                                # Hidden until a streaming run is started.
                                pdf_stop_button = gr.Button("Stop", visible=False, variant="stop")
                            # NOTE(review): the example row has two values but there is
                            # only one input component; the second value looks like a
                            # caption only - confirm gr.Examples ignores the extra column.
                            gr.Examples(
                                examples=[["example.pdf", "Example PDF"]],
                                inputs=[pdf_input],
                                outputs=None,
                                label="Example PDFs",
                                examples_per_page=7
                            )
                        with gr.Column():
                            pdf_text_output = gr.Textbox(
                                label="Extracted Dhivehi Text",
                                lines=20,
                                rtl=True,
                                elem_classes=["textbox1"],
                                show_copy_button=True,
                                scale=2
                            )
                    def show_stop_button_pdf():
                        """Updates for (pdf_stop_button, pdf_stream_button, pdf_generate_button): show Stop, disable both generate buttons."""
                        return gr.update(visible=True), gr.update(interactive=False), gr.update(interactive=False)
                    def hide_stop_button_pdf():
                        """Inverse of show_stop_button_pdf: hide Stop, re-enable both generate buttons."""
                        return gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True)
                    # Non-streaming path. Inputs are (model, pdf, ...) to match
                    # process_pdf_multiline's parameter order.
                    pdf_generate_button.click(
                        fn=process_pdf_multiline,
                        inputs=[model_path_dropdown, pdf_input, temperature_slider, top_p_slider, repetition_penalty_slider],
                        outputs=pdf_text_output,
                        show_progress="full"
                    )
                    # Streaming chain, mirroring the image tab.
                    pdf_show_event = pdf_stream_button.click(fn=show_stop_button_pdf, outputs=[pdf_stop_button, pdf_stream_button, pdf_generate_button])
                    pdf_gen_event = pdf_show_event.then(fn=process_pdf_multiline_stream, inputs=[model_path_dropdown, pdf_input, temperature_slider, top_p_slider, repetition_penalty_slider], outputs=pdf_text_output, show_progress="full")
                    pdf_gen_event.then(fn=hide_stop_button_pdf, outputs=[pdf_stop_button, pdf_stream_button, pdf_generate_button])
                    # Stop cancels pdf_gen_event and restores the buttons directly.
                    pdf_stop_button.click(fn=hide_stop_button_pdf, outputs=[pdf_stop_button, pdf_stream_button, pdf_generate_button], cancels=[pdf_gen_event])
            # NOTE(review): disabled wiring; no `load_model_multiline` is defined
            # in this file, so this line cannot simply be re-enabled as-is.
            # model_path_dropdown.change(fn=load_model_multiline, inputs=model_path_dropdown)
        with gr.Tab("PaliGemma"):
            model_dropdown_paligemma = gr.Dropdown(
                choices=list(PALIGEMMA_MODELS.keys()),
                value=list(PALIGEMMA_MODELS.keys())[0],
                label="Select PaliGemma Model"
            )
            with gr.Tabs():
                with gr.Tab("Image Input"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            image_input_paligemma = gr.Image(type="pil", label="Input Image")
                            image_submit_btn_paligemma = gr.Button("Extract Text")
                            # Image examples
                            gr.Examples(
                                examples=[[img] for img, _ in sentence_examples],
                                inputs=[image_input_paligemma],
                                label="Example Images",
                                examples_per_page=8
                            )
                        with gr.Column(scale=3):
                            # Outputs: extracted text and a gallery of detected regions.
                            with gr.Tabs():
                                with gr.Tab("Extracted Text"):
                                    image_text_output_paligemma = gr.Textbox(
                                        lines=5,
                                        label="Extracted Text",
                                        show_copy_button=True,
                                        rtl=True,
                                        elem_classes="textbox1"
                                    )
                                with gr.Tab("Detected Text Regions"):
                                    image_bbox_output_paligemma = gr.Gallery(
                                        label="Detected Text Regions",
                                        show_label=True,
                                        columns=2
                                    )
                with gr.Tab("PDF Input"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            pdf_input_paligemma = gr.File(
                                label="Input PDF",
                                file_types=[".pdf"]
                            )
                            pdf_submit_btn_paligemma = gr.Button("Extract Text from PDF")
                            # PDF examples
                            gr.Examples(
                                examples=[
                                    ["example.pdf", "Example 1"],
                                ],
                                inputs=[pdf_input_paligemma],
                                label="Example PDFs",
                                examples_per_page=8
                            )
                        with gr.Column(scale=3):
                            with gr.Tabs():
                                with gr.Tab("Extracted Text"):
                                    pdf_text_output_paligemma = gr.Textbox(
                                        lines=5,
                                        label="Extracted Text",
                                        show_copy_button=True,
                                        rtl=True,
                                        elem_classes="textbox1"
                                    )
                                with gr.Tab("Detected Text Regions"):
                                    pdf_bbox_output_paligemma = gr.Gallery(
                                        label="Detected Text Regions",
                                        show_label=True,
                                        columns=2
                                    )
        with gr.Tab("Gemma Sentence"):
            # Same layout as the PaliGemma tab, backed by the sentence-level Gemma handler.
            model_dropdown_gemma = gr.Dropdown(
                choices=list(GEMMA_MODELS.keys()),
                value=list(GEMMA_MODELS.keys())[0],
                label="Select Gemma Model"
            )
            with gr.Tabs():
                with gr.Tab("Image Input"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            image_input_gemma = gr.Image(type="pil", label="Input Image")
                            image_submit_btn_gemma = gr.Button("Extract Text")
                            # Image examples
                            gr.Examples(
                                examples=[[img] for img, _ in sentence_examples],
                                inputs=[image_input_gemma],
                                label="Example Images",
                                examples_per_page=8
                            )
                        with gr.Column(scale=3):
                            with gr.Tabs():
                                with gr.Tab("Extracted Text"):
                                    image_text_output_gemma = gr.Textbox(
                                        lines=5,
                                        label="Extracted Text",
                                        show_copy_button=True,
                                        rtl=True,
                                        elem_classes="textbox1"
                                    )
                                with gr.Tab("Detected Text Regions"):
                                    image_bbox_output_gemma = gr.Gallery(
                                        label="Detected Text Regions",
                                        show_label=True,
                                        columns=2
                                    )
                with gr.Tab("PDF Input"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            pdf_input_gemma = gr.File(
                                label="Input PDF",
                                file_types=[".pdf"]
                            )
                            pdf_submit_btn_gemma = gr.Button("Extract Text from PDF")
                            # PDF examples
                            gr.Examples(
                                examples=[
                                    ["example.pdf", "Example 1"],
                                ],
                                inputs=[pdf_input_gemma],
                                label="Example PDFs",
                                examples_per_page=8
                            )
                        with gr.Column(scale=3):
                            with gr.Tabs():
                                with gr.Tab("Extracted Text"):
                                    pdf_text_output_gemma = gr.Textbox(
                                        lines=5,
                                        label="Extracted Text",
                                        show_copy_button=True,
                                        rtl=True,
                                        elem_classes="textbox1"
                                    )
                                with gr.Tab("Detected Text Regions"):
                                    pdf_bbox_output_gemma = gr.Gallery(
                                        label="Detected Text Regions",
                                        show_label=True,
                                        columns=2
                                    )
    # PaliGemma event handlers
    # Each handler's result is mapped onto (extracted-text Textbox, regions Gallery).
    image_submit_btn_paligemma.click(
        fn=process_image_paligemma,
        inputs=[model_dropdown_paligemma, image_input_paligemma],
        outputs=[image_text_output_paligemma, image_bbox_output_paligemma]
    )
    # Inputs are (pdf, model) to match process_pdf_paligemma(pdf_path, model_name).
    pdf_submit_btn_paligemma.click(
        fn=process_pdf_paligemma,
        inputs=[pdf_input_paligemma, model_dropdown_paligemma],
        outputs=[pdf_text_output_paligemma, pdf_bbox_output_paligemma]
    )
    # Gemma event handlers
    image_submit_btn_gemma.click(
        fn=process_image_gemma,
        inputs=[model_dropdown_gemma, image_input_gemma],
        outputs=[image_text_output_gemma, image_bbox_output_gemma]
    )
    # Inputs are (pdf, model) to match process_pdf_gemma(pdf_path, model_name).
    pdf_submit_btn_gemma.click(
        fn=process_pdf_gemma,
        inputs=[pdf_input_gemma, model_dropdown_gemma],
        outputs=[pdf_text_output_gemma, pdf_bbox_output_gemma]
    )
# Function to install requirements
def install_requirements(requirements_path='requirements.txt'):
    """Install Python dependencies from a pip requirements file.

    Runs ``<current interpreter> -m pip install -r <requirements_path>
    --no-cache-dir`` in a subprocess, so packages are installed into the same
    environment that is running this script.

    Args:
        requirements_path: Path to the requirements file. A relative path is
            resolved against the current working directory. Defaults to
            ``'requirements.txt'`` (the previously hard-coded value).

    Returns:
        True if pip exited with status 0. False if the file does not exist
        (or is not a regular file), if pip exited non-zero, or if pip could not
        be run at all. Failures are reported on stderr.
    """
    # Check the requirements file up front so pip is only spawned for a real file.
    if not os.path.isfile(requirements_path):
        print(f"Error: {requirements_path} not found", file=sys.stderr)
        return False
    print("Installing requirements...")
    try:
        # Using --no-cache-dir to avoid memory issues
        subprocess.check_call([
            sys.executable,
            "-m",
            "pip",
            "install",
            "-r",
            requirements_path,
            "--no-cache-dir",
        ])
    except subprocess.CalledProcessError as e:
        # pip ran but exited with a non-zero status.
        print(f"Error installing requirements: {e}", file=sys.stderr)
        return False
    except Exception as e:
        # Deliberate best-effort boundary (e.g. OSError if the interpreter
        # cannot be launched): report and signal failure instead of crashing.
        print(f"Unexpected error: {e}", file=sys.stderr)
        return False
    print("Successfully installed all requirements")
    return True
# Launch the app
if __name__ == "__main__":
    # First install requirements.
    # NOTE(review): every third-party import (gradio, spaces, the handler
    # modules) and the UI construction above have already run by this point.
    # A dependency missing at import time would have raised before reaching
    # here, so this step cannot repair the current process's imports.
    success = install_requirements()
    if success:
        print("All requirements installed successfully")
        # Pre-load the multiline gemma model
        #print("Loading default Gemma Multiline model...")
        #gemma_multiline_handler.load_model(list(GEMMA_MULTILINE_MODELS.keys())[0])
        #print("Default model loaded.")
        demo.launch()
    else:
        # Report the failure once (previously printed twice) on stderr and exit
        # non-zero so whatever supervises the process sees the failed start.
        print("Failed to install some requirements", file=sys.stderr)
        sys.exit(1)