import gradio as gr
import spaces
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
from datetime import datetime
import os
import torch
import gc

# Configure the CUDA caching allocator (cap splittable blocks at 64 MB to limit fragmentation)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
DESCRIPTION = "[Sparrow Qwen2-VL-2B Backend](https://github.com/katanaml/sparrow)"

def process_image(image_filepath, max_width=640, max_height=800):
    if image_filepath is None:
        raise ValueError("No image provided")
    
    img = Image.open(image_filepath)
    width, height = img.size
    
    # Enhanced resizing with aspect ratio preservation
    aspect_ratio = width / height
    if aspect_ratio > (max_width/max_height):
        new_width = max_width
        new_height = int(max_width / aspect_ratio)
    else:
        new_height = max_height
        new_width = int(max_height * aspect_ratio)
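    # Example: a 1240x1754 portrait scan (aspect ratio ~0.707 < 640/800) takes the
    # else-branch and is scaled to 565x800.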
    
    img = img.resize((new_width, new_height), Image.LANCZOS)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"/tmp/image_{timestamp}.jpg"
    img.save(filename, format='JPEG', quality=75, optimize=True)
    
    return os.path.abspath(filename), new_width, new_height

# Model and processor are loaded lazily on the first request (see load_model)
model = None
processor = None

def load_model():
    global model, processor
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

@spaces.GPU
def run_inference(input_imgs, text_input):
    global model, processor
    if model is None:
        load_model()
    
    results = []
    
    for image in input_imgs:
        torch.cuda.empty_cache()
        gc.collect()
        
        image_path, width, height = process_image(image)
        
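        # Qwen2-VL chat format: a single user turn pairing the image (referenced by
        # its local file path) with the text query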
        try:
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": text_input}
                ]
            }]
            
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            
            # Encode the prompt text and the resized image into model inputs
            inputs = processor(
                text=[text],
                images=[Image.open(image_path)],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to("cuda")
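            # Note: truncation at max_length=512 acts as a memory guard; if the expanded
            # image tokens exceed that budget the prompt may be cut short, so consider
            # raising max_length if generation misbehaves on larger images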
            
            # Memory-efficient generation
            with torch.inference_mode():
                generated_ids = model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=False,
                    num_beams=1  # greedy decoding; early_stopping applies only to beam search
                )
            
            # Decode only the newly generated tokens, skipping the prompt portion
            output = processor.batch_decode(
                generated_ids[:, inputs.input_ids.shape[1]:], 
                skip_special_tokens=True
            )[0]
            
            results.append(output)
            
            # Force memory cleanup
            del inputs, generated_ids
            torch.cuda.empty_cache()
            gc.collect()
            
        finally:
            if os.path.exists(image_path):
                os.remove(image_path)
    
    # Join the per-image answers into a single string for the Textbox output
    return "\n\n".join(results)

# Streamlined interface
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        input_imgs = gr.Files(file_types=["image"], label="Upload Images")
        text_input = gr.Textbox(label="Query")
        submit_btn = gr.Button("Submit", variant="primary")
    output_text = gr.Textbox(label="Response", elem_id="output")

    submit_btn.click(run_inference, [input_imgs, text_input], output_text)

demo.queue(max_size=1)
demo.launch()
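
# Minimal client-side sketch (not executed) for querying this backend once it is
# running. It assumes the default local URL, the gradio_client package, and that
# the click handler is exposed under its function name; the file name and query
# below are placeholders, so adjust all of these for your deployment.
#
# from gradio_client import Client, handle_file
#
# client = Client("http://127.0.0.1:7860/")
# answer = client.predict(
#     [handle_file("invoice_sample.jpg")],             # files for the gr.Files input
#     "retrieve the invoice number and total amount",  # query text
#     api_name="/run_inference",
# )
# print(answer)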