import gradio as gr
import spaces
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
from datetime import datetime
import os
# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
DESCRIPTION = "[Sparrow Qwen2-VL-7B Backend](https://github.com/katanaml/sparrow)"
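
# This Space serves Qwen2-VL-7B-Instruct as the vision-language backend for
# Sparrow's document data extraction pipeline (see the link above).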
def array_to_image_path(image_filepath, max_width=1250, max_height=1750):
    if image_filepath is None:
        raise ValueError("No image provided. Please upload an image before submitting.")

    # Open the uploaded image from its filepath
    img = Image.open(image_filepath)

    # Keep the original extension for supported formats, otherwise default to PNG
    input_image_extension = image_filepath.split('.')[-1].lower()
    if input_image_extension in ['jpg', 'jpeg', 'png']:
        file_extension = input_image_extension
    else:
        file_extension = 'png'

    # Start from the current dimensions
    width, height = img.size
    new_width, new_height = width, height

    # If the image exceeds the maximum dimensions, compute a smaller target
    # size that preserves the aspect ratio
    if width > max_width or height > max_height:
        aspect_ratio = width / height
        if width > max_width:
            new_width = max_width
            new_height = int(new_width / aspect_ratio)
        if new_height > max_height:
            new_height = max_height
            new_width = int(new_height * aspect_ratio)

    # Save the image under a unique, timestamped filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"image_{timestamp}.{file_extension}"
    img.save(filename)

    # Return the absolute path of the saved copy plus the target dimensions
    full_path = os.path.abspath(filename)
    return full_path, new_width, new_height
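
# A quick sanity check of array_to_image_path (a minimal sketch; the sample
# filename below is hypothetical and assumes the file exists locally):
#
#   path, w, h = array_to_image_path("sample_invoice.jpg")
#   # path -> absolute path of a timestamped copy, e.g. image_20240101_120000.jpg
#   # (w, h) -> dimensions capped at 1250x1750 with the aspect ratio preserved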
# Initialize the model and processor globally to optimize performance
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype="auto",
    device_map="auto"
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
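
# @spaces.GPU requests a GPU for the duration of each call when running on
# Hugging Face Spaces (ZeroGPU hardware); outside Spaces it is a no-op.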
@spaces.GPU
def run_inference(input_imgs, text_input):
    results = []
    for image in input_imgs:
        # Save a normalized copy of the image and get its target dimensions
        image_path, width, height = array_to_image_path(image)
        try:
            # Build a chat message pairing the image with the user's query
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image_path,
                            "resized_height": height,
                            "resized_width": width
                        },
                        {
                            "type": "text",
                            "text": text_input
                        }
                    ]
                }
            ]

            # Render the chat template and preprocess the vision inputs
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to("cuda")

            # Generate, then strip the echoed prompt tokens from each sequence
            generated_ids = model.generate(**inputs, max_new_tokens=4096)
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            raw_output = processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            results.append(raw_output[0])
            print(f"Processed: {image}")
        finally:
            # Clean up the temporary image file
            os.remove(image_path)

    # Join the per-image outputs for the single response textbox
    return "\n\n".join(results)
css = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Tab(label="Qwen2-VL-7B Input"):
        with gr.Row():
            with gr.Column():
                input_imgs = gr.Files(file_types=["image"], label="Upload Document Images")
                text_input = gr.Textbox(label="Query")
                submit_btn = gr.Button(value="Submit", variant="primary")
            with gr.Column():
                # elem_id="output" connects this textbox to the #output CSS rule above
                output_text = gr.Textbox(label="Response", elem_id="output")

        submit_btn.click(run_inference, [input_imgs, text_input], [output_text])
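
# api_open=True keeps the auto-generated API endpoints exposed so clients
# (e.g. the Sparrow pipeline) can call this Space programmatically.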
demo.queue(api_open=True)
demo.launch(debug=True)
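
# A minimal client-side sketch (the Space name is hypothetical; the endpoint
# name assumes Gradio's default of deriving it from the function name):
#
#   from gradio_client import Client, handle_file
#   client = Client("katanaml/sparrow-qwen2-vl-7b")  # hypothetical Space id
#   response = client.predict(
#       [handle_file("invoice.jpg")],      # input_imgs
#       "retrieve the invoice number",     # text_input
#       api_name="/run_inference",
#   )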