import base64
import io
from typing import Any, Dict

import torch
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel


class EndpointHandler:
    def __init__(self, path=""):
        # Load processor and model from the provided path or the model ID
        self.processor = AutoProcessor.from_pretrained(path or "bytedance/Dolphin")
        self.model = VisionEncoderDecoderModel.from_pretrained(path or "bytedance/Dolphin")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        # Half precision for speed, but only on GPU: fp16 generation on CPU is
        # slow and poorly supported, so fall back to fp32 there.
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self.model = self.model.to(self.dtype)
        self.tokenizer = self.processor.tokenizer

    def decode_base64_image(self, image_base64: str) -> Image.Image:
        image_bytes = base64.b64decode(image_base64)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Check for image input
        if "inputs" not in data:
            return {"error": "No inputs provided"}

        image_input = data["inputs"]

        # Support both base64-encoded image strings and raw PIL images
        # (Hugging Face Inference Endpoints can deliver either)
        if isinstance(image_input, str):
            try:
                image = self.decode_base64_image(image_input)
            except Exception as e:
                return {"error": f"Invalid base64 image: {str(e)}"}
        else:
            image = image_input  # Assume a PIL-compatible image

        # Optional custom prompt (default: plain text reading). Dolphin expects
        # the prompt wrapped in its <s> ... <Answer/> template.
        prompt = data.get("prompt", "Read text in the image.")
        full_prompt = f"<s>{prompt} <Answer/>"

        # Preprocess inputs; pixel values must match the model dtype
        inputs = self.processor(image, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(self.device, dtype=self.dtype)
        prompt_ids = self.tokenizer(
            full_prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(self.device)
        decoder_attention_mask = torch.ones_like(prompt_ids)

        # Inference: greedy decoding (do_sample=False, num_beams=1)
        with torch.inference_mode():
            outputs = self.model.generate(
                pixel_values=pixel_values,
                decoder_input_ids=prompt_ids,
                decoder_attention_mask=decoder_attention_mask,
                min_length=1,
                max_length=4096,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                use_cache=True,
                bad_words_ids=[[self.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
                do_sample=False,
                num_beams=1,
            )

        sequence = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]

        # Strip the echoed prompt and special tokens from the decoded output
        generated_text = (
            sequence.replace(full_prompt, "")
            .replace("<pad>", "")
            .replace("</s>", "")
            .strip()
        )

        return {"text": generated_text}
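

# A minimal local smoke test, sketched as an assumption rather than part of the
# handler contract: it presumes a test image at "sample.png" (hypothetical path)
# and mimics the JSON payload an Inference Endpoint would deliver to __call__.
if __name__ == "__main__":
    with open("sample.png", "rb") as f:  # hypothetical test image
        payload = {
            "inputs": base64.b64encode(f.read()).decode("utf-8"),
            "prompt": "Read text in the image.",
        }
    handler = EndpointHandler()
    print(handler(payload))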