from typing import Any, Dict, List, Optional
import base64
import io
import logging

import requests
import torch
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

class EndpointHandler():
    """
    Handler class for the Qwen2-VL-7B-Instruct model on Hugging Face Inference Endpoints.
    This handler processes text and image inputs, leveraging the Qwen2-VL model
    for multimodal understanding and generation. (The input parser currently
    recognizes "<image>" references only; video inputs are not parsed.)
    """

    def __init__(self, path: str = ""):
        """
        Initializes the handler and loads the Qwen2-VL model.
        Args:
            path (str, optional): The path to the Qwen2-VL model directory. Defaults to "".
        """
        self.model_dir = path

        # Load the Qwen2-VL model
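        # Note: attn_implementation="flash_attention_2" requires the flash-attn
        # package and a compatible GPU; drop the argument to fall back to the
        # default attention implementation.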
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.model_dir,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(self.model_dir)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes the input data and returns the Qwen2-VL model's output.
        Args:
            data (Dict[str, Any]): A dictionary containing the input data.
                - "inputs" (str): The input text, including image/video references.
                - "max_new_tokens" (int, optional): Max tokens to generate (default: 128).
        Returns:
            Dict[str, Any]: A dictionary containing the generated text.
        """
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("Request payload must contain an 'inputs' field.")
        max_new_tokens = data.get("max_new_tokens", 128)

        # Construct the messages list from the input string
        messages = [{"role": "user", "content": self._parse_input(inputs)}]

        # Prepare for inference (using qwen_vl_utils)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)

        model_inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Place the batch on the model's device (set by device_map="auto").
        model_inputs = model_inputs.to(self.model.device)

        # Inference
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        # Strip the prompt tokens from each sequence so only new tokens are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        return {"generated_text": output_text}

    def _parse_input(self, input_string: str) -> List[Dict[str, Any]]:
        """
        Parses the input string into text and image content parts.
        Text before the first "<image>" tag becomes the prompt text; each tag
        must be followed by an image URL or a base64 data URI.
        Args:
            input_string (str): The input string containing text and image references.
        Returns:
            list: A list of content dictionaries in the Qwen2-VL chat format.
        """
        content = []
        parts = input_string.split("<image>")
        for i, part in enumerate(parts):
            if i == 0:  # Leading text part
                if part.strip():
                    content.append({"type": "text", "text": part.strip()})
            else:  # Image reference part
                image = self._load_image(part.strip())
                if image:
                    content.append({"type": "image", "image": image})
        return content

    def _load_image(self, image_data: str) -> Optional[Image.Image]:
        """
        Loads an image from a URL or a base64-encoded data URI.
        Args:
            image_data (str): The image data, either an http(s) URL or a
                "data:image/...;base64,..." string.
        Returns:
            PIL.Image.Image or None: The loaded image, or None if loading fails.
        """
        try:
            if image_data.startswith("http"):
                response = requests.get(image_data, timeout=30)
                response.raise_for_status()  # Check for HTTP errors
                return Image.open(io.BytesIO(response.content))
            elif image_data.startswith("data:image"):
                base64_data = image_data.split(",", 1)[1]
                image_bytes = base64.b64decode(base64_data)
                return Image.open(io.BytesIO(image_bytes))
            else:
                logging.warning(f"Unrecognized image reference: {image_data[:80]!r}")
        except requests.exceptions.RequestException as e:
            logging.error(f"HTTP error occurred while loading image: {e}")
        except (ValueError, IOError) as e:
            logging.error(f"Error decoding or opening image: {e}")
        return None
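

# A minimal local smoke test (illustrative only, not part of the endpoint
# contract). Assumes the flash-attn stack and a CUDA GPU are available; the
# model ID and image URL below are placeholders, not values from this repo.
if __name__ == "__main__":
    handler = EndpointHandler(path="Qwen/Qwen2-VL-7B-Instruct")
    payload = {
        "inputs": "Describe this picture. <image> https://example.com/cat.jpg",
        "max_new_tokens": 64,
    }
    print(handler(payload))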