# Source: dhivehi-ocr / gemma.py by alakxender (commit 97bb8f1)
import torch
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForImageTextToText
from peft import PeftModel, PeftConfig
import numpy as np
from detector import TextDetector
import tempfile
import os
# Registry of selectable OCR models, keyed by the ``model_name`` that the
# GemmaHandler methods accept. Each entry holds:
#   id     - repository id handed to ``from_pretrained``
#   prompt - instruction text sent alongside every image
MODELS = {
    "Gemma-3 10k": dict(
        id="alakxender/dhivehi-image-text-init10k-gemma",
        prompt="Extract the dhivehi text from the image",
    ),
}
class GemmaHandler:
    """OCR front-end that runs a Gemma-3 image-text model over Dhivehi text.

    Images judged to be a single text line are sent to the model directly.
    Multi-line images and PDFs are first split into line regions by
    ``TextDetector``, and each region is OCR'd on its own.
    """

    def __init__(self):
        self.model = None
        self.processor = None
        self.current_model_name = None
        # Shared detector used by process_pdf; its output_dir is pointed at a
        # per-call temporary directory there.
        self.detector = TextDetector()

    @staticmethod
    def _report_progress(progress, value, desc):
        """Send ``(value, desc=desc)`` to the ``progress`` callback, if any.

        Progress reporting is best-effort, so ordinary exceptions raised by
        the callback are ignored. KeyboardInterrupt and SystemExit are not
        ``Exception`` subclasses, so they still propagate.
        """
        if progress is None:
            return
        try:
            progress(value, desc=desc)
        except Exception:
            pass

    def load_model(self, model_name):
        """Load the model and processor registered under ``model_name``.

        Raises:
            KeyError: if ``model_name`` is not a key of ``MODELS``.
        """
        model_id = MODELS[model_name]['id']
        # Release the previous checkpoint first so two models are never held
        # at once. Clear the name as well, so a failed load is not mistaken
        # for a loaded model on the next call.
        self.model = None
        self.processor = None
        self.current_model_name = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.current_model_name = model_name

    def _ensure_model(self, model_name, progress=None):
        """Load ``model_name`` unless it is already the active model."""
        if model_name != self.current_model_name:
            self._report_progress(progress, 0, "Loading model...")
            self.load_model(model_name)

    def process_image(self, model_name, image, progress=None):
        """OCR a single image (PIL image or numpy array).

        Returns ``(text, gallery_images)``. The image is treated as one text
        line if it is at most 50 px tall or its width/height ratio exceeds 3.
        Otherwise it is split into lines by the detector.
        """
        if image is None:
            return "", []
        self._ensure_model(model_name, progress)
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        width, height = image.size
        print(f"Image dimensions: {width}x{height}")
        # Short-circuit keeps a zero-height image from dividing by zero.
        if height <= 50 or width / height > 3:
            self._report_progress(progress, 0.5, "Processing single line...")
            result = self.process_single_line(image, model_name)
            self._report_progress(progress, 1.0, "Done!")
            return result, [image]
        return self.process_multi_line(image, model_name, progress)

    def process_single_line(self, image, model_name):
        """OCR one line image with the loaded model and return the text."""
        prompt_text = MODELS[model_name]["prompt"]
        # Convert once, so the chat message and the pixel inputs use the
        # same RGB image.
        rgb_image = image.convert("RGB")
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image", "image": rgb_image},
                ],
            }
        ]
        prompt = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(
            text=prompt,
            images=[rgb_image],
            return_tensors="pt",
        ).to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=128)
        # generate() returns prompt + continuation. Decode only the new
        # tokens, so chat-template role words never leak into the result and
        # legitimate text is never stripped.
        prompt_len = inputs["input_ids"].shape[-1]
        decoded = self.processor.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )
        # Defensive cleanup in case the model itself echoes the instruction.
        for unwanted in ("Instruction:", prompt_text):
            decoded = decoded.replace(unwanted, "")
        return decoded.strip()

    @staticmethod
    def _sorted_lines(page_regions):
        """Return a page's detected lines ordered top to bottom (by bbox y1).

        A new list is returned, so the detector's result is left unmodified.
        """
        return sorted(
            page_regions.get('bboxes', []), key=lambda line: line['bbox'][1]
        )

    def _ocr_lines(self, image, text_lines, model_name, on_line=None):
        """Crop each line's bbox from ``image`` and OCR it, in order.

        ``on_line(index, total)`` is called before each line, for progress
        reporting. Boxes with no width or height are skipped, because they
        would produce empty crops.
        """
        texts = []
        total = len(text_lines)
        for index, line in enumerate(text_lines):
            if on_line is not None:
                on_line(index, total)
            x1, y1, x2, y2 = line['bbox']
            if x2 <= x1 or y2 <= y1:
                continue
            line_image = image.crop((x1, y1, x2, y2))
            texts.append(self.process_single_line(line_image, model_name))
        return texts

    def process_multi_line(self, image, model_name, progress=None):
        """Detect the text lines in ``image`` and OCR each one, top to bottom.

        Returns ``(text, [annotated_image])``, where ``text`` joins the
        per-line results with newlines. Returns a short message and ``[]``
        when detection finds nothing. Only the first entry of the detector's
        results is used.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            # The detector works from a file path, so save the image first.
            input_path = os.path.join(temp_dir, "input.png")
            image.save(input_path)
            # NOTE(review): a fresh TextDetector is built per call. If its
            # constructor loads models, reusing self.detector may be cheaper.
            detector = TextDetector(output_dir=temp_dir)
            self._report_progress(progress, 0.1, "Detecting text regions...")
            results = detector.process_input(input_path, save_images=True)
            regions = detector.get_text_regions(results, "input")
            if not regions:
                return "No text regions detected", []
            text_lines = self._sorted_lines(regions[0])
            if not text_lines:
                return "No text lines detected", []
            bbox_image = self.draw_bboxes(image.copy(), text_lines)

            def on_line(i, total):
                self._report_progress(
                    progress,
                    0.2 + 0.8 * i / total,
                    f"Processing line {i+1}/{total}",
                )

            all_text = self._ocr_lines(image, text_lines, model_name, on_line)
            self._report_progress(progress, 1.0, "Done!")
            return "\n".join(all_text), [bbox_image]  # list for gallery

    def process_pdf(self, pdf_path, model_name, progress=None, page_range="0"):
        """OCR the PDF pages selected by ``page_range``.

        ``page_range`` is passed unchanged to ``TextDetector.process_input``.
        The default ``"0"`` is the value the handler has always used
        (presumably only the first page; confirm against TextDetector).

        Returns ``(text, bbox_images)``, with one annotated image per page
        that had detected lines. Pages without an image or without lines add
        an explanatory note to ``text``.
        """
        if pdf_path is None:
            return "", []
        self._ensure_model(model_name, progress)
        pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
        with tempfile.TemporaryDirectory() as temp_dir:
            self.detector.output_dir = temp_dir
            self._report_progress(progress, 0.1, "Detecting text regions in PDF...")
            results = self.detector.process_input(
                pdf_path, save_images=True, page_range=page_range
            )
            regions = self.detector.get_text_regions(results, pdf_name)
            if not regions:
                return "No text regions detected", []

            all_text = []
            bbox_images = []
            num_pages = len(regions)
            for page_num, page_regions in enumerate(regions):
                # Progress runs monotonically from 0.2 to 1.0 across all pages.
                self._report_progress(
                    progress,
                    0.2 + 0.8 * page_num / num_pages,
                    f"Processing page {page_num+1}/{num_pages}...",
                )
                possible_paths = [
                    # NOTE(review): the "_bbox" image presumably already has the
                    # detector's boxes drawn on it, so crops would include
                    # those lines. Confirm whether a clean page image is saved.
                    os.path.join(temp_dir, pdf_name, f"{pdf_name}_{page_num}_bbox.png"),
                    os.path.join(temp_dir, pdf_name, f"page_{page_num}.png"),
                    os.path.join(temp_dir, f"page_{page_num}.png"),
                    os.path.join(temp_dir, f"{pdf_name}_page_{page_num}.png"),
                ]
                page_image = None
                for page_image_path in possible_paths:
                    if os.path.exists(page_image_path):
                        # Load the pixels and close the file now. The temp dir
                        # is deleted on exit, and an open handle would block
                        # that on some platforms.
                        with Image.open(page_image_path) as opened:
                            page_image = opened.copy()
                        break
                if page_image is None:
                    all_text.append(
                        f"\nPage {page_num+1}: Page image not found. Tried paths:\n"
                        + "\n".join(f"- {path}" for path in possible_paths)
                    )
                    continue

                text_lines = self._sorted_lines(page_regions)
                if not text_lines:
                    all_text.append(f"\nPage {page_num+1}: No text lines detected")
                    continue

                bbox_images.append(self.draw_bboxes(page_image.copy(), text_lines))

                def on_line(i, total, page_num=page_num):
                    self._report_progress(
                        progress,
                        0.2 + 0.8 * (page_num + i / total) / num_pages,
                        f"Processing line {i+1}/{total} on page {page_num+1}/{num_pages}...",
                    )

                # Page text is appended without a page-number header.
                all_text.extend(
                    self._ocr_lines(page_image, text_lines, model_name, on_line)
                )

            self._report_progress(progress, 1.0, "Done!")
            return "\n".join(all_text), bbox_images

    @staticmethod
    def draw_bboxes(image, text_lines):
        """Draw detector output onto ``image`` in place and return it.

        Each entry of ``text_lines`` needs ``polygon`` (a sequence of points),
        ``bbox`` (x1, y1, x2, y2) and ``confidence``. Polygons are outlined in
        red and boxes in blue. The confidence is written 10 px above each box.
        """
        draw = ImageDraw.Draw(image)
        for line in text_lines:
            # ImageDraw.polygon takes a flat [x0, y0, x1, y1, ...] list.
            flat_polygon = [coord for point in line['polygon'] for coord in point]
            draw.polygon(flat_polygon, outline="red", width=2)
            x1, y1, x2, y2 = line['bbox']
            draw.rectangle([x1, y1, x2, y2], outline="blue", width=1)
            draw.text((x1, y1 - 10), f"{line['confidence']:.2f}", fill="red")
        return image