# dhivehi-ocr / gemma_multiline.py
# Last change by alakxender, revision 0e946e4 (commit message: "t").
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor, TextIteratorStreamer
from threading import Thread
import os
import fitz # PyMuPDF
# Display name shown to users -> model id passed to ``from_pretrained``.
MODELS = {
    "Gemma Multiline - no-format": "alakxender/dhivehi-image-text-10k-multi-gemma",
    "Gemma Multiline - line": "alakxender/dhivehi-image-text-line-gemma",
    # Disabled variants. Every entry ends with a comma so any line can be
    # uncommented on its own without producing a syntax error.
    # "Gemma Multiline - markdown": "alakxender/dhivehi-image-text-10k-multi-gemma-md",
    # NOTE(review): the html entry reuses the markdown repo id - confirm before enabling.
    # "Gemma Multiline - html": "alakxender/dhivehi-image-text-10k-multi-gemma-md",
    # "Gemma Multiline - json": "alakxender/dhivehi-image-text-10k-multi-gemma-json",
    # "Gemma Multiline - json+bbox": "alakxender/dhivehi-image-text-10k-multi-gemma-json-bbox",
}
class GemmaMultilineHandler:
    """OCR Dhivehi text from images or PDFs with one of the models in ``MODELS``.

    At most one model/processor pair is held at a time. It is (re)loaded lazily
    whenever a caller asks for a different ``model_name``. A model that is not
    loaded, or a missing input, is reported as a user-facing string. The PDF
    entry points also turn any processing error into a "Failed to process PDF"
    string. The plain image entry points let generation errors propagate.
    """

    def __init__(self):
        self.model = None  # AutoModelForImageTextToText instance, or None
        self.processor = None  # AutoProcessor matching self.model, or None
        self.current_model_name = None  # MODELS key of the loaded model, or None
        self.instruction = 'Extract the dhivehi text from the image'

    def _unload(self):
        """Drop references to the current model and processor."""
        self.model = None
        self.processor = None
        self.current_model_name = None

    def load_model(self, model_name: str):
        """Load the model registered under ``model_name`` in ``MODELS``.

        - An empty name unloads the current model.
        - An unknown name is reported and leaves the handler state unchanged.
        - A failed load leaves the handler with no model loaded.

        Problems are printed, never raised.
        """
        if not model_name:
            self._unload()
            print("Model name is empty. No model loaded.")
            return
        model_path = MODELS.get(model_name)
        if not model_path:
            print(f"Model '{model_name}' not found.")
            return
        if model_name == self.current_model_name and self.model is not None:
            print(f"Model '{model_name}' is already loaded.")
            return
        # Release the previous model first so the old and new weights never
        # have to be held in memory at the same time.
        self._unload()
        try:
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_path,
                device_map="auto",
                torch_dtype=torch.bfloat16,
            )
            self.processor = AutoProcessor.from_pretrained(model_path)
            self.current_model_name = model_name
            print(f"Model loaded from {model_path}")
        except Exception as e:
            self._unload()
            print(f"Failed to load model: {e}")

    @staticmethod
    def _report_progress(progress, desc: str):
        """Send a best-effort progress update; a faulty callback never aborts OCR."""
        if not progress:
            return
        try:
            progress(0, desc=desc)
        except Exception as e:
            print(f"Progress update failed: {e}")

    def _ensure_model(self, model_name: str, progress=None) -> bool:
        """Make ``model_name`` the loaded model and return True once it is usable.

        Returns False if loading failed or the name is unknown. In the unknown
        case a previously loaded *different* model is not silently reused.
        """
        if model_name != self.current_model_name or self.model is None:
            self._report_progress(progress, f"Loading {model_name}...")
            self.load_model(model_name)
        return (
            self.model is not None
            and self.processor is not None
            and self.current_model_name == model_name
        )

    def process_vision_info(self, messages: list[dict]) -> list[Image.Image]:
        """Collect every image referenced in chat ``messages``, converted to RGB.

        An element counts as an image when it is a dict with an ``"image"`` key
        or with ``type == "image"``. Its ``"image"`` value must provide
        ``.convert`` (e.g. a PIL image).
        """
        image_inputs = []
        for msg in messages:
            content = msg.get("content", [])
            if not isinstance(content, list):
                content = [content]
            for element in content:
                if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
                    image_inputs.append(element["image"].convert("RGB"))
        return image_inputs

    def _prepare_inputs(self, image):
        """Build the single-turn OCR prompt for ``image`` and return tensors on the model device."""
        messages = [
            {
                'role': 'user',
                'content': [
                    {'type': 'text', 'text': self.instruction},
                    {'type': 'image', 'image': image},
                ],
            },
        ]
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(
            text=[text],
            images=self.process_vision_info(messages),
            padding=True,
            return_tensors="pt",
        )
        return inputs.to(self.model.device)

    def _generation_kwargs(self, inputs, temperature, top_p, repetition_penalty) -> dict:
        """Return the ``generate`` keyword arguments shared by the blocking and streaming paths."""
        tokenizer = self.processor.tokenizer
        # Stop on EOS as well as on the chat template's "<end_of_turn>" token.
        stop_token_ids = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<end_of_turn>"),
        ]
        return dict(
            **inputs,
            max_new_tokens=1024,
            top_p=top_p,
            do_sample=True,
            temperature=temperature,
            eos_token_id=stop_token_ids,
            disable_compile=True,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=repetition_penalty,
        )

    def generate_text_from_image(self, model_name: str, image: Image.Image, temperature: float = 0.8,
                                 top_p: float = 1.0, repetition_penalty: float = 1.2, progress=None) -> str:
        """Run OCR on ``image`` with ``model_name`` and return the decoded text.

        Sampling is enabled (``do_sample=True``) with at most 1024 new tokens.
        ``progress`` is an optional callback accepting ``(fraction, desc=...)``.
        Returns a message string instead if the model cannot be loaded or
        ``image`` is None.
        """
        if not self._ensure_model(model_name, progress):
            return "Model not loaded. Please select a model."
        if image is None:
            return "No image provided."
        inputs = self._prepare_inputs(image)
        generated_ids = self.model.generate(
            **self._generation_kwargs(inputs, temperature, top_p, repetition_penalty)
        )
        # generate() returns prompt + completion; keep only the new tokens.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0]

    def generate_text_stream(self, model_name: str, image: Image.Image, temperature: float = 0.8,
                             top_p: float = 1.0, repetition_penalty: float = 1.2, progress=None):
        """Stream OCR output for ``image``.

        Yields the accumulated, whitespace-stripped text after each chunk.
        Generation runs in a background thread. If it fails, streaming stops
        and the exception is re-raised here, so the consumer does not block
        forever.
        """
        if not self._ensure_model(model_name, progress):
            yield "Model not loaded. Please provide a model path."
            return
        if image is None:
            yield "No image provided."
            return
        inputs = self._prepare_inputs(image)
        # skip_prompt=True streams only newly generated tokens. There is no need
        # to search the stream for the "model\n" turn marker, which also cut the
        # output whenever the recognised text itself contained "model\n".
        streamer = TextIteratorStreamer(
            self.processor.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        generation_kwargs = self._generation_kwargs(inputs, temperature, top_p, repetition_penalty)
        generation_kwargs["streamer"] = streamer
        generate = self.model.generate
        errors = []

        def _generate():
            try:
                generate(**generation_kwargs)
            except Exception as exc:
                errors.append(exc)
                # Without this the iteration below would wait forever.
                streamer.end()

        thread = Thread(target=_generate)
        thread.start()
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text.strip()
        thread.join()
        if errors:
            raise errors[0]

    @staticmethod
    def _render_first_page(pdf_path):
        """Rasterize the first page of a PDF to an RGB PIL image.

        ``pdf_path`` may be a str/``os.PathLike`` path or an uploaded-file
        object exposing the on-disk path as ``.name``. Returns None when the
        PDF has no pages. The document is always closed.
        """
        path = pdf_path if isinstance(pdf_path, (str, os.PathLike)) else pdf_path.name
        with fitz.open(path) as doc:
            if doc.page_count == 0:
                return None
            pix = doc.load_page(0).get_pixmap()
            return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

    def process_pdf(self, model_name: str, pdf_path, temperature, top_p, repetition_penalty, progress=None):
        """OCR the first page of ``pdf_path`` and return the text or an error message."""
        if not self._ensure_model(model_name, progress):
            return "Model not loaded. Please load a model first."
        if pdf_path is None:
            return "No PDF file provided."
        try:
            image = self._render_first_page(pdf_path)
            if image is None:
                return "PDF has no pages."
            return self.generate_text_from_image(model_name, image, temperature, top_p, repetition_penalty, progress)
        except Exception as e:
            return f"Failed to process PDF: {e}"

    def process_pdf_stream(self, model_name: str, pdf_path, temperature, top_p, repetition_penalty, progress=None):
        """Stream OCR output for the first page of ``pdf_path``.

        Yields the accumulated text, or a single error message.
        """
        if not self._ensure_model(model_name, progress):
            yield "Model not loaded. Please load a model first."
            return
        if pdf_path is None:
            yield "No PDF file provided."
            return
        try:
            image = self._render_first_page(pdf_path)
            if image is None:
                yield "PDF has no pages."
                return
            yield from self.generate_text_stream(model_name, image, temperature, top_p, repetition_penalty, progress)
        except Exception as e:
            yield f"Failed to process PDF: {e}"