yakilee committed
Commit dcb7d6a · verified · 1 parent: fd2bd94

Update app.py

Files changed (1): app.py (+155, -0)
app.py CHANGED
@@ -0,0 +1,155 @@
+ import gradio as gr
+ import spaces
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+ from PIL import Image
+ from datetime import datetime
+ import os
+
+ # subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+ DESCRIPTION = "[Sparrow Qwen2-VL-7B Backend](https://github.com/katanaml/sparrow)"
+
+
+ def array_to_image_path(image_filepath, max_width=1250, max_height=1750):
+     if image_filepath is None:
+         raise ValueError("No image provided. Please upload an image before submitting.")
+
+     # Open the uploaded image using its filepath
+     img = Image.open(image_filepath)
+
+     # Extract the file extension from the uploaded file
+     input_image_extension = image_filepath.split('.')[-1].lower()
+
+     # Keep the file extension of the original file, otherwise default to PNG
+     if input_image_extension in ['jpg', 'jpeg', 'png']:
+         file_extension = input_image_extension
+     else:
+         file_extension = 'png'  # Default to PNG if extension is unavailable or invalid
+
+     # Get the current dimensions of the image
+     width, height = img.size
+
+     # Initialize new dimensions to current size
+     new_width, new_height = width, height
+
+     # Check if the image exceeds the maximum dimensions
+     if width > max_width or height > max_height:
+         # Calculate the new size, maintaining the aspect ratio
+         aspect_ratio = width / height
+
+         if width > max_width:
+             new_width = max_width
+             new_height = int(new_width / aspect_ratio)
+
+         if new_height > max_height:
+             new_height = max_height
+             new_width = int(new_height * aspect_ratio)
+
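+     # The image itself is saved below at its original size; new_width/new_height
+     # are returned so the caller can pass them as resized_height/resized_width,
+     # which qwen_vl_utils uses to resize the image when loading it for the model
+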
+     # Generate a unique filename using a timestamp
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     filename = f"image_{timestamp}.{file_extension}"
+
+     # Save the image
+     img.save(filename)
+
+     # Get the full path of the saved image
+     full_path = os.path.abspath(filename)
+
+     return full_path, new_width, new_height
+
+
+ # Initialize the model and processor globally so they are loaded once, not per request
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
+     "Qwen/Qwen2-VL-7B-Instruct",
+     torch_dtype="auto",
+     device_map="auto"
+ )
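+ # Note: with flash-attn installed (see the commented pip install above),
+ # from_pretrained also accepts attn_implementation="flash_attention_2"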
+
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+
+
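+ # On ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration of each call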
+ @spaces.GPU
+ def run_inference(input_imgs, text_input):
+     results = []
+
+     for image in input_imgs:
+         # Convert each image to the required format
+         image_path, width, height = array_to_image_path(image)
+
+         try:
+             # Prepare messages for each image
+             messages = [
+                 {
+                     "role": "user",
+                     "content": [
+                         {
+                             "type": "image",
+                             "image": image_path,
+                             "resized_height": height,
+                             "resized_width": width
+                         },
+                         {
+                             "type": "text",
+                             "text": text_input
+                         }
+                     ]
+                 }
+             ]
+
+             # Prepare inputs for the model
+             text = processor.apply_chat_template(
+                 messages, tokenize=False, add_generation_prompt=True
+             )
+
+             image_inputs, video_inputs = process_vision_info(messages)
+             inputs = processor(
+                 text=[text],
+                 images=image_inputs,
+                 videos=video_inputs,
+                 padding=True,
+                 return_tensors="pt",
+             )
+             inputs = inputs.to("cuda")
+
+             # Generate inference output
+             generated_ids = model.generate(**inputs, max_new_tokens=4096)
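+             # Trim the prompt tokens so only the newly generated tokens are decoded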
+             generated_ids_trimmed = [
+                 out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+             ]
+             raw_output = processor.batch_decode(
+                 generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
+             )
+
+             results.append(raw_output[0])
+             print(f"Processed: {image}")
+         finally:
+             # Clean up the temporary image file
+             os.remove(image_path)
+
+     # Join the per-image outputs so the single Response Textbox receives a string
+     return "\n\n".join(results)
+
+
+ css = """
+ #output {
+     height: 500px;
+     overflow: auto;
+     border: 1px solid #ccc;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown(DESCRIPTION)
+     with gr.Tab(label="Qwen2-VL-7B Input"):
+         with gr.Row():
+             with gr.Column():
+                 input_imgs = gr.Files(file_types=["image"], label="Upload Document Images")
+                 text_input = gr.Textbox(label="Query")
+                 submit_btn = gr.Button(value="Submit", variant="primary")
+             with gr.Column():
+                 # elem_id="output" ties this Textbox to the #output CSS rule above
+                 output_text = gr.Textbox(label="Response", elem_id="output")
+
+     submit_btn.click(run_inference, [input_imgs, text_input], [output_text])
+
+ demo.queue(api_open=True)
+ demo.launch(debug=True)
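
Since the queue is launched with api_open=True, the run_inference endpoint can also be called programmatically, which is how a client such as Sparrow can use this Space as a backend. A minimal sketch using gradio_client; the Space ID and file name below are hypothetical placeholders:

```python
from gradio_client import Client, handle_file

# Hypothetical Space ID; point this at the actual deployment of app.py
client = Client("katanaml/sparrow-qwen2-vl-7b-backend")

# The endpoint takes a list of image files and a text query,
# mirroring the (input_imgs, text_input) signature above
result = client.predict(
    [handle_file("invoice.png")],      # input_imgs (hypothetical local file)
    "retrieve invoice_number, total",  # text_input
    api_name="/run_inference",
)
print(result)
```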