ruben3010 committed
Commit afe0bf4 · verified · 1 Parent(s): f90c5ff

Create handler.py

Files changed (1)
  1. handler.py +116 -0
handler.py ADDED
@@ -0,0 +1,116 @@
+ from typing import Dict, Any
+ import base64
+ import io
+ import logging
+
+ import requests
+ import torch
+ from PIL import Image
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+
+
+ class EndpointHandler:
+     """
+     Handler class for the Qwen2-VL-7B-Instruct model on Hugging Face Inference Endpoints.
+
+     This handler processes text and image inputs, leveraging the Qwen2-VL model
+     for multimodal understanding and generation.
+     """
+
+     def __init__(self, path=""):
+         """
+         Initializes the handler and loads the Qwen2-VL model.
+
+         Args:
+             path (str, optional): The path to the Qwen2-VL model directory. Defaults to "".
+         """
+         self.model_dir = path
+
+         # Load the Qwen2-VL model and its processor
+         self.model = Qwen2VLForConditionalGeneration.from_pretrained(
+             self.model_dir, torch_dtype="auto", device_map="auto"
+         )
+         self.processor = AutoProcessor.from_pretrained(self.model_dir)
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Processes the input data and returns the Qwen2-VL model's output.
+
+         Args:
+             data (Dict[str, Any]): A dictionary containing the input data.
+                 - "inputs" (str): The input text, including <image> references.
+                 - "max_new_tokens" (int, optional): Max tokens to generate (default: 128).
+
+         Returns:
+             Dict[str, Any]: A dictionary containing the generated text.
+         """
+         prompt = data.get("inputs")
+         max_new_tokens = data.get("max_new_tokens", 128)
+
+         # Construct the messages list from the input string
+         messages = [{"role": "user", "content": self._parse_input(prompt)}]
+
+         # Prepare for inference (using qwen_vl_utils)
+         text = self.processor.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+         image_inputs, video_inputs = process_vision_info(messages)
+
+         model_inputs = self.processor(
+             text=[text],
+             images=image_inputs,
+             videos=video_inputs,
+             padding=True,
+             return_tensors="pt",
+         )
+         model_inputs = model_inputs.to("cuda" if torch.cuda.is_available() else "cpu")
+
+         # Inference: generate, then strip the prompt tokens from each output sequence
+         generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
+         generated_ids_trimmed = [
+             out_ids[len(in_ids):]
+             for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
+         ]
+         output_text = self.processor.batch_decode(
+             generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+         )[0]
+
+         return {"generated_text": output_text}
+
+     def _parse_input(self, input_string):
+         """
+         Parses the input string to identify image references and text.
+
+         Args:
+             input_string (str): The input string containing text and <image> references.
+
+         Returns:
+             list: A list of dictionaries representing the parsed content.
+         """
+         content = []
+         parts = input_string.split("<image>")
+         for i, part in enumerate(parts):
+             if i == 0:  # Text before the first <image> tag
+                 if part.strip():
+                     content.append({"type": "text", "text": part.strip()})
+             else:  # Everything after an <image> tag is treated as image data
+                 image = self._load_image(part.strip())
+                 if image:
+                     content.append({"type": "image", "image": image})
+         return content
+
+     def _load_image(self, image_data):
+         """
+         Loads an image from a URL or base64 encoded string.
+
+         Args:
+             image_data (str): The image data, either a URL or a base64 encoded string.
+
+         Returns:
+             PIL.Image.Image or None: The loaded image, or None if loading fails.
+         """
+         try:
+             if image_data.startswith("http"):
+                 response = requests.get(image_data, timeout=30)
+                 response.raise_for_status()  # Check for HTTP errors
+                 # Use the decoded body; response.raw bypasses content-encoding
+                 # handling and can yield a truncated or undecodable stream
+                 return Image.open(io.BytesIO(response.content))
+             elif image_data.startswith("data:image"):
+                 base64_data = image_data.split(",")[1]
+                 image_bytes = base64.b64decode(base64_data)
+                 return Image.open(io.BytesIO(image_bytes))
+         except requests.exceptions.RequestException as e:
+             logging.error(f"HTTP error occurred while loading image: {e}")
+         except (IOError, ValueError) as e:
+             logging.error(f"Error decoding or opening image: {e}")
+         return None
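
A minimal client-side sketch of how the deployed endpoint might be called. The endpoint URL and token below are placeholders, not values from this repository; the payload keys ("inputs", "max_new_tokens") and the <image> tag convention follow __call__ and _parse_input above.

# Hypothetical client for the handler above; URL and token are placeholders.
import requests

API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HEADERS = {
    "Authorization": "Bearer <hf_token>",  # placeholder access token
    "Content-Type": "application/json",
}

payload = {
    # Free text, then an <image> tag followed by an http(s) URL or a
    # data:image base64 string, as expected by _parse_input/_load_image
    "inputs": "Describe this picture. <image>https://example.com/cat.jpg",
    "max_new_tokens": 64,
}

response = requests.post(API_URL, headers=HEADERS, json=payload)
response.raise_for_status()
print(response.json()["generated_text"])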