import spaces
import gradio as gr
import torch
import subprocess
import sys
from transformers import AutoProcessor, Qwen2_5OmniThinkerForConditionalGeneration, TextIteratorStreamer
import torchaudio
from threading import Thread
from qwen_omni_utils import process_mm_info
from transformers import StoppingCriteria, StoppingCriteriaList
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
# Check environment
result = subprocess.run(
    [sys.executable, '-c',
     'import torch; print(f"PyTorch: {torch.__version__}"); print(f"CUDA: {torch.version.cuda}")'],
    capture_output=True, text=True
)
print(result.stdout)
# Model paths and configuration
model_path_1 = "./model"
model_path_2 = "./model2"
base_model_id = "Qwen/Qwen2.5-Omni-7B"
# Dictionary to store loaded models and processors
loaded_models = {}
# Load the model and processor
def load_model(model_path):
    # Check if model is already loaded
    if model_path in loaded_models:
        return loaded_models[model_path]
    # Load the processor from the base model
    processor = AutoProcessor.from_pretrained(
        base_model_id,
        trust_remote_code=True,
    )
    # Load the model
    model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        device_map="auto",
    )
    model.eval()
    # Store in cache
    loaded_models[model_path] = (model, processor)
    return model, processor
# Initialize first model and processor
model, processor = load_model(model_path_1)
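# Normalize a streamed text chunk: place the <think>, <semantic_elements> and <answer> tags on
# their own lines and unescape literal "\n" sequences so the output renders cleanly in the textbox.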
def process_output(output):
    if "<think>" in output:
        rest = output.split("<think>")[1]
        output = "<think>\n" + rest
    elif "<semantic_elements>" in output:
        rest = output.split("<semantic_elements>")[1]
        output = "<semantic_elements>\n" + rest
    elif "<answer>" in output:
        rest = output.split("<answer>")[1]
        output = "<answer>\n" + rest
    elif "</think>" in output:
        rest = output.split("</think>")[0]
        output = rest + "\n</think>\n\n"
    elif "</semantic_elements>" in output:
        rest = output.split("</semantic_elements>")[0]
        output = rest + "\n</semantic_elements>\n\n"
    elif "</answer>" in output:
        rest = output.split("</answer>")[0]
        output = rest + "\n</answer>\n"
    output = output.replace("\\n", "\n")
    output = output.replace("\\", "\n")
    output = output.replace("\n-", "-")
    return output
# Custom Stopping Criteria
class StopOnSpecificToken(StoppingCriteria):
    def __init__(self, stop_token_sequences: list[list[int]], device: str = "cuda"):
        super().__init__()
        self.stop_sequence_tensors = []
        for seq in stop_token_sequences:
            if seq:  # Only process non-empty sequences
                self.stop_sequence_tensors.append(torch.tensor(seq, dtype=torch.long, device=device))

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_seq_tensor in self.stop_sequence_tensors:
            current_sequence_length = input_ids.shape[-1]
            stop_sequence_length = stop_seq_tensor.shape[-1]
            # stop_sequence_length should be > 0 due to the check in __init__
            if stop_sequence_length == 0:  # Should not be reached if seq was non-empty
                continue
            if current_sequence_length >= stop_sequence_length:
                # Check the last tokens of the first (and only) sequence in the batch
                last_tokens = input_ids[0, -stop_sequence_length:]
                if torch.equal(last_tokens, stop_seq_tensor):
                    return True
        return False
# Streaming inference function used by the Gradio interface
@spaces.GPU
def process_audio_streaming(audio_file, model_choice, question="Describe this audio in detail"):
    # Load the selected model
    model_path = model_path_2 if model_choice == "Think" else model_path_1
    model, processor = load_model(model_path)
    # Load and process the audio with torchaudio
    waveform, sr = torchaudio.load(audio_file)
    # Resample to 16kHz if needed
    if sr != processor.feature_extractor.sampling_rate:
        waveform = torchaudio.functional.resample(waveform, sr, processor.feature_extractor.sampling_rate)
        sr = processor.feature_extractor.sampling_rate
    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # Get the audio data as a numpy array
    y = waveform.squeeze().numpy()
    # Set sampling rate for the processor
    sampling_rate = processor.feature_extractor.sampling_rate
    # Define prompts based on model choice
    prompt_think_semantics = f"You are given a question and an audio clip. Your task is to answer the question based on the audio clip. First, think about the question and the audio clip and put your thoughts in <think> and </think> tags. Then reason about the semantic elements involved in the audio clip and put your reasoning in <semantic_elements> and </semantic_elements> tags. Then answer the question based on the audio clip, put your answer in <answer> and </answer> tags. {question}"
    instruction_text = ""
    if model_choice == "Think + Semantics":
        instruction_text = prompt_think_semantics
    else:
        # Default to the question if no specific model processing is chosen
        instruction_text = question
    # Create conversation format
    conversation = [
        {"role": "system", "content": [
            {"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}]},
        {"role": "user", "content": [
            {"type": "audio", "audio": y},
            {"type": "text", "text": instruction_text}
        ]}
    ]
    # Format the chat
    chat_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    # Process multimedia info using qwen_omni_utils
    audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)
    # Process the inputs
    inputs = processor(
        text=chat_text,
        audio=audios,
        images=images,
        videos=videos,
        return_tensors="pt",
        sampling_rate=sampling_rate,
    ).to(model.device)
    # Create a standard streamer instance
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        timeout=10.,
        skip_prompt=True,
        skip_special_tokens=True
    )
    # Initialize variables for buffering
    accumulated_output = ""
    buffer = ""
    stop_sequence = "</answer>"
    stop_found = False
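    # Tokenize the closing "</answer>" tag so generation can halt as soon as the answer is complete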
    answer_token_ids = processor.tokenizer.encode("</answer>", add_special_tokens=False)
    processed_stop_token_ids = [answer_token_ids]
    # Get the device of the model to ensure tensors are on the same device
    model_device = next(model.parameters()).device
    custom_stopping_criteria = StopOnSpecificToken(stop_token_sequences=processed_stop_token_ids, device=model_device.type)
    stopping_criteria = StoppingCriteriaList([custom_stopping_criteria])
    # Generate the output with streaming
    with torch.no_grad():
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=768,
            do_sample=False,
            stopping_criteria=stopping_criteria
        )
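        # Run generation in a background thread so the streamer can be consumed as tokens arrive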
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        # Process the stream with buffering
        for output in streamer:
            if stop_found:
                break
            output = process_output(output)
            buffer += output
            # Check if stop sequence is in the buffer
            if stop_sequence in buffer:
                # Output everything up to and including the stop sequence
                before_stop = buffer.split(stop_sequence)[0]
                accumulated_output += before_stop + stop_sequence
                yield accumulated_output
                stop_found = True
                break
            else:
                # Check if we can safely output part of the buffer
                # Keep the last N characters where N is the length of the stop sequence
                if len(buffer) > len(stop_sequence):
                    # Output all but the last len(stop_sequence) characters
                    safe_output = buffer[:-len(stop_sequence)]
                    buffer = buffer[-len(stop_sequence):]
                    accumulated_output += safe_output
                    yield accumulated_output
        # Output any remaining buffer if no stop sequence was found
        if not stop_found and buffer:
            accumulated_output += buffer
            yield accumulated_output
# Create Gradio interface for audio processing
audio_demo = gr.Interface(
    fn=process_audio_streaming,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio"),
        gr.Radio(["Think", "Think + Semantics"], label="Select Model", value="Think + Semantics"),
        gr.Textbox(label="Question", value="Describe this audio in detail")
    ],
    outputs=gr.Textbox(label="Generated Output", lines=30),
    title="AudSemThinker",
    description="Upload an audio file and the model will provide detailed analysis and description. Choose between different model versions.",
    examples=[["examples/1.wav", "Think + Semantics", "Describe this audio in detail"]],
    cache_examples=False,
    live=True
)
# Launch the app
if __name__ == "__main__":
    audio_demo.launch()
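# Note: running this app outside the hosted Space presumably requires the fine-tuned checkpoints
# to be present in ./model and ./model2, plus flash-attn and the qwen_omni_utils package installed.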