# Spaces: Running on Zero
import torch | |
from PIL import Image, ImageDraw | |
from transformers import AutoProcessor, AutoModelForImageTextToText | |
from peft import PeftModel, PeftConfig | |
import numpy as np | |
from detector import TextDetector | |
import tempfile | |
import os | |
# Registry of selectable OCR models, keyed by display name. Each entry holds
# the model identifier passed to `from_pretrained` ("id") and the instruction
# text sent to the model together with the image ("prompt").
MODELS = {
    "Gemma-3 10k": dict(
        id="alakxender/dhivehi-image-text-init10k-gemma",
        prompt="Extract the dhivehi text from the image",
    ),
}
class GemmaHandler:
    """Run image-to-text OCR on images and PDFs with a model from MODELS.

    Wide or short images are sent to the model as a single text line. Other
    images, and PDF pages, are first passed through a TextDetector; each
    detected bounding box is then cropped and OCR'd on its own.
    """

    def __init__(self):
        # Model and processor are loaded lazily by load_model().
        self.model = None
        self.processor = None
        # Key into MODELS of the currently loaded model (None until first load).
        self.current_model_name = None
        # Detector instance reused by process_pdf().
        self.detector = TextDetector()

    @staticmethod
    def _report_progress(progress, fraction, desc):
        """Send a best-effort progress update.

        Progress reporting is cosmetic, so a failing callback must never
        abort OCR. Only ``Exception`` is caught, so ``KeyboardInterrupt``
        and ``SystemExit`` still propagate.

        Args:
            progress: Callable invoked as ``progress(fraction, desc=...)``,
                or None to disable reporting.
            fraction: Completion value forwarded to the callback.
            desc: Status text forwarded to the callback.
        """
        if progress is None:
            return
        try:
            progress(fraction, desc=desc)
        except Exception as exc:
            print(f"Progress update failed: {exc}")

    def _ensure_model(self, model_name, progress=None):
        """Load ``model_name`` unless it is already the active model."""
        if model_name != self.current_model_name:
            self._report_progress(progress, 0, "Loading model...")
            self.load_model(model_name)

    def _ocr_region(self, image, bbox, model_name):
        """Crop ``bbox`` (x1, y1, x2, y2) out of ``image`` and OCR the crop."""
        x1, y1, x2, y2 = bbox
        return self.process_single_line(image.crop((x1, y1, x2, y2)), model_name)

    def load_model(self, model_name):
        """Load the model and processor registered under ``model_name``.

        Args:
            model_name: Key into MODELS.

        Raises:
            KeyError: If ``model_name`` is not a key of MODELS.
        """
        model_id = MODELS[model_name]['id']
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16
        )
        self.processor = AutoProcessor.from_pretrained(model_id)
        # Recorded last, so a failed load is retried on the next request.
        self.current_model_name = model_name

    def process_image(self, model_name, image, progress=None):
        """OCR a single image.

        Args:
            model_name: Key into MODELS selecting the model to use.
            image: PIL image, numpy array, or None.
            progress: Optional callable ``progress(fraction, desc=...)``.

        Returns:
            Tuple ``(text, images)``. ``images`` holds the image(s) to
            display: the input itself for single-line input, or a copy with
            the detected boxes drawn for multi-line input. ``("", [])`` is
            returned when ``image`` is None or has zero width or height.
        """
        if image is None:
            return "", []

        self._ensure_model(model_name, progress)

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        width, height = image.size
        print(f"Image dimensions: {width}x{height}")

        # A zero-area image has no text and would make the aspect-ratio
        # division below fail.
        if width == 0 or height == 0:
            return "", []

        # Treat the image as one line of text when it is very short or much
        # wider than it is tall (width/height > 3).
        if height <= 50 or width / height > 3:
            self._report_progress(progress, 0.5, "Processing single line...")
            result = self.process_single_line(image, model_name)
            self._report_progress(progress, 1.0, "Done!")
            return result, [image]

        return self.process_multi_line(image, model_name, progress)

    def process_single_line(self, image, model_name):
        """OCR one line of text and return the recognised string.

        Args:
            image: PIL image containing the text line.
            model_name: Key into MODELS; selects the prompt text. The model
                itself must already be loaded.

        Returns:
            The generated text with surrounding whitespace removed.
        """
        instruction = MODELS[model_name]["prompt"]
        # Convert once so the chat message and the processor receive the
        # same RGB image.
        rgb_image = image.convert("RGB")

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": instruction},
                    {"type": "image", "image": rgb_image}
                ],
            }
        ]

        prompt = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        inputs = self.processor(
            text=prompt,
            images=[rgb_image],
            return_tensors="pt"
        ).to(self.model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=128)

        # Decode only the newly generated tokens. Decoding the full sequence
        # would include the chat prompt, and stripping words such as "user"
        # or "model" afterwards would also delete them from genuine text.
        # NOTE(review): assumes generate() returns prompt + continuation.
        generated_tokens = outputs[0][prompt_length:]
        decoded = self.processor.tokenizer.decode(
            generated_tokens, skip_special_tokens=True
        )

        # Safeguard in case the model echoes the instruction back.
        decoded = decoded.replace(instruction, "")
        return decoded.strip()

    def process_multi_line(self, image, model_name, progress=None):
        """Detect text regions in ``image`` and OCR each one.

        Args:
            image: PIL image.
            model_name: Key into MODELS; the model must already be loaded.
            progress: Optional callable ``progress(fraction, desc=...)``.

        Returns:
            Tuple ``(text, images)``: the recognised lines joined with
            newlines, and a one-element list holding a copy of the image
            with the detected regions drawn on it. When nothing is detected,
            a message string and an empty list are returned.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            input_path = os.path.join(temp_dir, "input.png")
            image.save(input_path)

            # NOTE(review): a new detector is built per call here, while
            # process_pdf() reuses self.detector — confirm whether the
            # shared instance could be used instead.
            detector = TextDetector(output_dir=temp_dir)

            self._report_progress(progress, 0.1, "Detecting text regions...")
            results = detector.process_input(input_path, save_images=True)

            regions = detector.get_text_regions(results, "input")
            if not regions:
                return "No text regions detected", []

            # Only the first entry is used for a single image.
            text_lines = regions[0].get('bboxes', [])
            if not text_lines:
                return "No text lines detected", []

            # Order by the bbox's second coordinate (y1). sorted() leaves
            # the detector's own list untouched.
            text_lines = sorted(text_lines, key=lambda entry: entry['bbox'][1])

            bbox_image = self.draw_bboxes(image.copy(), text_lines)

            all_text = []
            total_lines = len(text_lines)
            for i, line in enumerate(text_lines):
                self._report_progress(
                    progress,
                    (i + 1) / total_lines,
                    f"Processing line {i+1}/{total_lines}",
                )
                all_text.append(self._ocr_region(image, line['bbox'], model_name))

            self._report_progress(progress, 1.0, "Done!")
            return "\n".join(all_text), [bbox_image]  # List form for the gallery

    def process_pdf(self, pdf_path, model_name, progress=None):
        """Detect and OCR the text of a PDF file.

        Args:
            pdf_path: Path to the PDF, or None.
            model_name: Key into MODELS selecting the model to use.
            progress: Optional callable ``progress(fraction, desc=...)``.

        Returns:
            Tuple ``(text, images)``: recognised lines (plus per-page
            diagnostic messages) joined with newlines, and one annotated
            image per page that had detected text lines. ``("", [])`` is
            returned when ``pdf_path`` is None.
        """
        if pdf_path is None:
            return "", []

        self._ensure_model(model_name, progress)

        # Base name of the PDF without extension.
        pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]

        with tempfile.TemporaryDirectory() as temp_dir:
            # NOTE(review): self.detector.output_dir keeps pointing at this
            # directory after it has been deleted — confirm that is harmless.
            self.detector.output_dir = temp_dir

            self._report_progress(progress, 0.1, "Detecting text regions in PDF...")
            # NOTE(review): page_range="0" presumably limits detection to
            # the first page — confirm against TextDetector.
            results = self.detector.process_input(
                pdf_path, save_images=True, page_range="0"
            )

            regions = self.detector.get_text_regions(results, pdf_name)
            if not regions:
                return "No text regions detected", []

            all_text = []
            bbox_images = []
            total_pages = len(regions)

            for page_num, page_regions in enumerate(regions):
                self._report_progress(
                    progress,
                    0.2 + (page_num/total_pages)*0.3,
                    f"Processing page {page_num+1}/{total_pages}...",
                )

                # Candidate locations for the rendered page image.
                possible_paths = [
                    os.path.join(temp_dir, pdf_name, f"{pdf_name}_{page_num}_bbox.png"),  # Detector's actual path
                    os.path.join(temp_dir, pdf_name, f"page_{page_num}.png"),  # Original path
                    os.path.join(temp_dir, f"page_{page_num}.png"),  # Direct in output dir
                    os.path.join(temp_dir, f"{pdf_name}_page_{page_num}.png")  # Alternative naming
                ]

                page_image = None
                for page_image_path in possible_paths:
                    if os.path.exists(page_image_path):
                        # Copy the pixels and close the file so no handle
                        # stays open in the soon-to-be-deleted temp dir.
                        with Image.open(page_image_path) as opened:
                            page_image = opened.copy()
                        break

                if page_image is None:
                    all_text.append(f"\nPage {page_num+1}: Page image not found. Tried paths:\n" +
                                    "\n".join(f"- {path}" for path in possible_paths))
                    continue

                text_lines = page_regions.get('bboxes', [])
                if not text_lines:
                    all_text.append(f"\nPage {page_num+1}: No text lines detected")
                    continue

                # Order by the bbox's second coordinate (y1).
                text_lines = sorted(text_lines, key=lambda entry: entry['bbox'][1])

                bbox_images.append(self.draw_bboxes(page_image.copy(), text_lines))

                total_lines = len(text_lines)
                for i, line in enumerate(text_lines):
                    self._report_progress(
                        progress,
                        0.5 + (page_num/total_pages)*0.2 + (i/total_lines)*0.3,
                        f"Processing line {i+1}/{total_lines} on page {page_num+1}/{total_pages}...",
                    )
                    # Page text is appended without a page-number header.
                    all_text.append(
                        self._ocr_region(page_image, line['bbox'], model_name)
                    )

            self._report_progress(progress, 1.0, "Done!")
            return "\n".join(all_text), bbox_images

    @staticmethod
    def draw_bboxes(image, text_lines):
        """Draw detection results onto ``image`` in place and return it.

        Declared as a static method because it takes no ``self`` yet is
        invoked as ``self.draw_bboxes(...)``; as a plain function that call
        would pass the instance as an extra positional argument.

        Args:
            image: PIL image to draw on (modified in place).
            text_lines: Iterable of dicts with 'polygon' (sequence of
                points), 'bbox' (x1, y1, x2, y2) and 'confidence' (number).

        Returns:
            The same ``image`` object.
        """
        draw = ImageDraw.Draw(image)
        for line in text_lines:
            # Flatten the nested point list into one coordinate sequence.
            flat_polygon = [coord for point in line['polygon'] for coord in point]
            draw.polygon(flat_polygon, outline="red", width=2)

            x1, y1, x2, y2 = line['bbox']
            draw.rectangle([x1, y1, x2, y2], outline="blue", width=1)

            # Confidence score just above the box's top-left corner.
            draw.text((x1, y1 - 10), f"{line['confidence']:.2f}", fill="red")
        return image