import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor, TextIteratorStreamer
from threading import Thread
import os
import fitz  # PyMuPDF
# Registry of display names -> Hugging Face repo ids.
MODELS = {
    "Gemma Multiline - no-format": "alakxender/dhivehi-image-text-10k-multi-gemma",
    "Gemma Multiline - line": "alakxender/dhivehi-image-text-line-gemma",
    # "Gemma Multiline - markdown": "alakxender/dhivehi-image-text-10k-multi-gemma-md",
    # "Gemma Multiline - html": "alakxender/dhivehi-image-text-10k-multi-gemma-md",
    # "Gemma Multiline - json": "alakxender/dhivehi-image-text-10k-multi-gemma-json",
    # "Gemma Multiline - json+bbox": "alakxender/dhivehi-image-text-10k-multi-gemma-json-bbox",
}
class GemmaMultilineHandler:
    """Loads a Gemma image-text-to-text model and runs Dhivehi OCR on images and PDFs."""

    def __init__(self):
        self.model = None
        self.processor = None
        self.current_model_name = None
        self.instruction = 'Extract the dhivehi text from the image'
    def load_model(self, model_name: str):
        """Load the model/processor for the given display name, reusing it if already resident."""
        if not model_name:
            self.model = None
            self.processor = None
            self.current_model_name = None
            print("Model name is empty. No model loaded.")
            return
        model_path = MODELS.get(model_name)
        if not model_path:
            print(f"Model '{model_name}' not found.")
            return
        # Skip reloading when the requested model is already in memory.
        if model_name == self.current_model_name and self.model is not None:
            print(f"Model '{model_name}' is already loaded.")
            return
        try:
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_path,
                device_map="auto",
                torch_dtype=torch.bfloat16,
            )
            self.processor = AutoProcessor.from_pretrained(model_path)
            self.current_model_name = model_name
            print(f"Model loaded from {model_path}")
        except Exception as e:
            self.model = None
            self.processor = None
            self.current_model_name = None
            print(f"Failed to load model: {e}")
    def process_vision_info(self, messages: list[dict]) -> list[Image.Image]:
        """Collect PIL images from chat messages, converting each to RGB."""
        image_inputs = []
        for msg in messages:
            content = msg.get("content", [])
            if not isinstance(content, list):
                content = [content]
            for element in content:
                if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
                    image = element["image"]
                    image_inputs.append(image.convert("RGB"))
        return image_inputs
    def generate_text_from_image(self, model_name: str, image: Image.Image, temperature: float = 0.8, top_p: float = 1.0, repetition_penalty: float = 1.2, progress=None) -> str:
        # Lazily (re)load the model if a different one is selected.
        if model_name != self.current_model_name or self.model is None:
            try:
                if progress:
                    progress(0, desc=f"Loading {model_name}...")
            except Exception:
                pass
            self.load_model(model_name)
        if self.model is None or self.processor is None:
            return "Model not loaded. Please select a model."
        messages = [
            {
                'role': 'user',
                'content': [
                    {'type': 'text', 'text': self.instruction},
                    {'type': 'image', 'image': image}
                ]
            },
        ]
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs = self.process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)
        # Stop on either the tokenizer EOS or Gemma's explicit end-of-turn token.
        stop_token_ids = [
            self.processor.tokenizer.eos_token_id,
            self.processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")
        ]
        generated_ids = self.model.generate(
            **inputs,
            max_new_tokens=1024,
            top_p=top_p,
            do_sample=True,
            temperature=temperature,
            eos_token_id=stop_token_ids,
            disable_compile=True,
            pad_token_id=self.processor.tokenizer.eos_token_id,
            repetition_penalty=repetition_penalty
        )
        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0]
    def generate_text_stream(self, model_name: str, image: Image.Image, temperature: float = 0.8, top_p: float = 1.0, repetition_penalty: float = 1.2, progress=None):
        # Lazily (re)load the model if a different one is selected.
        if model_name != self.current_model_name or self.model is None:
            try:
                if progress:
                    progress(0, desc=f"Loading {model_name}...")
            except Exception:
                pass
            self.load_model(model_name)
        if self.model is None or self.processor is None:
            yield "Model not loaded. Please select a model."
            return
        messages = [
            {
                'role': 'user',
                'content': [
                    {'type': 'text', 'text': self.instruction},
                    {'type': 'image', 'image': image}
                ]
            },
        ]
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs = self.process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)
        stop_token_ids = [
            self.processor.tokenizer.eos_token_id,
            self.processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")
        ]
        # The streamer emits the prompt as well (skip_prompt defaults to False);
        # the reply is recovered below by splitting on Gemma's "model\n" turn marker.
        streamer = TextIteratorStreamer(self.processor.tokenizer, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=1024,
            top_p=top_p,
            do_sample=True,
            temperature=temperature,
            eos_token_id=stop_token_ids,
            disable_compile=True,
            pad_token_id=self.processor.tokenizer.eos_token_id,
            streamer=streamer,
            repetition_penalty=repetition_penalty
        )
        # Generate on a background thread so tokens can be yielded as they arrive.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            # Everything after the last "model\n" marker is the model's reply so far.
            model_turn_start = generated_text.rfind("model\n")
            if model_turn_start != -1:
                response_start_index = model_turn_start + len("model\n")
                clean_text = generated_text[response_start_index:]
                yield clean_text.strip()
    def process_pdf(self, model_name: str, pdf_path, temperature, top_p, repetition_penalty, progress=None):
        # Lazily (re)load the model if a different one is selected.
        if model_name != self.current_model_name or self.model is None:
            try:
                if progress:
                    progress(0, desc=f"Loading {model_name}...")
            except Exception:
                pass
            self.load_model(model_name)
        if self.model is None or self.processor is None:
            return "Model not loaded. Please load a model first."
        if pdf_path is None:
            return "No PDF file provided."
        try:
            doc = fitz.open(pdf_path.name)
            if doc.page_count > 0:
                # Render only the first page to an RGB image and run OCR on it.
                page = doc.load_page(0)
                pix = page.get_pixmap()
                image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                doc.close()
                return self.generate_text_from_image(model_name, image, temperature, top_p, repetition_penalty, progress)
            else:
                doc.close()
                return "PDF has no pages."
        except Exception as e:
            return f"Failed to process PDF: {e}"
    def process_pdf_stream(self, model_name: str, pdf_path, temperature, top_p, repetition_penalty, progress=None):
        # Lazily (re)load the model if a different one is selected.
        if model_name != self.current_model_name or self.model is None:
            try:
                if progress:
                    progress(0, desc=f"Loading {model_name}...")
            except Exception:
                pass
            self.load_model(model_name)
        if self.model is None or self.processor is None:
            yield "Model not loaded. Please load a model first."
            return
        if pdf_path is None:
            yield "No PDF file provided."
            return
        try:
            doc = fitz.open(pdf_path.name)
            if doc.page_count > 0:
                # Render only the first page to an RGB image, then stream OCR output.
                page = doc.load_page(0)
                pix = page.get_pixmap()
                image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                doc.close()
                yield from self.generate_text_stream(model_name, image, temperature, top_p, repetition_penalty, progress)
            else:
                doc.close()
                yield "PDF has no pages."
        except Exception as e:
            yield f"Failed to process PDF: {e}"