import spaces
import gradio as gr
import torch
import subprocess
import sys
from transformers import AutoProcessor, Qwen2_5OmniThinkerForConditionalGeneration, TextIteratorStreamer
import torchaudio
from threading import Thread
from qwen_omni_utils import process_mm_info
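# qwen_omni_utils supplies process_mm_info, used below to pull audio/image/video inputs out of the chat-format conversation.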
from transformers import StoppingCriteria, StoppingCriteriaList
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
# Check environment
result = subprocess.run([sys.executable, '-c',
'import torch; print(f"PyTorch: {torch.__version__}"); print(f"CUDA: {torch.version.cuda}")'],
capture_output=True, text=True)
print(result.stdout)
# Model paths and configuration
model_path_1 = "./model"
model_path_2 = "./model2"
base_model_id = "Qwen/Qwen2.5-Omni-7B"
# Dictionary to store loaded models and processors
loaded_models = {}
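# Caching avoids reloading weights from disk when the user switches between model choices.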
# Load the model and processor
def load_model(model_path):
# Check if model is already loaded
if model_path in loaded_models:
return loaded_models[model_path]
# Load the processor from the base model
processor = AutoProcessor.from_pretrained(
base_model_id,
trust_remote_code=True,
)
# Load the model
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
trust_remote_code=True,
device_map="auto",
)
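    # Note: attn_implementation="flash_attention_2" requires the flash-attn package and a supported GPU;
    # switch to "sdpa" (or drop the argument) if flash attention is unavailable.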
model.eval()
# Store in cache
loaded_models[model_path] = (model, processor)
return model, processor
# Initialize first model and processor
model, processor = load_model(model_path_1)
def process_output(output):
if "" in output:
rest = output.split("")[1]
output = "\n" + rest
elif "" in output:
rest = output.split("")[1]
output = "\n" + rest
elif "" in output:
rest = output.split("")[1]
output = "\n" + rest
elif "" in output:
rest = output.split("")[0]
output = rest + "\n\n\n"
elif "" in output:
rest = output.split("")[0]
output = rest + "\n\n\n"
elif "" in output:
rest = output.split("")[0]
output = rest + "\n\n"
output = output.replace("\\n", "\n")
output = output.replace("\\", "\n")
output = output.replace("\n-", "-")
return output
# Custom Stopping Criteria
class StopOnSpecificToken(StoppingCriteria):
def __init__(self, stop_token_sequences: list[list[int]], device: str = "cuda"):
super().__init__()
self.stop_sequence_tensors = []
for seq in stop_token_sequences:
if seq: # Only process non-empty sequences
self.stop_sequence_tensors.append(torch.tensor(seq, dtype=torch.long, device=device))
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
for stop_seq_tensor in self.stop_sequence_tensors:
current_sequence_length = input_ids.shape[-1]
stop_sequence_length = stop_seq_tensor.shape[-1]
# stop_sequence_length should be > 0 due to the check in __init__
if stop_sequence_length == 0: # Should ideally not be reached if seq was non-empty
continue
if current_sequence_length >= stop_sequence_length:
                # Check the last tokens of the first (and only) sequence in the batch
last_tokens = input_ids[0, -stop_sequence_length:]
if torch.equal(last_tokens, stop_seq_tensor):
return True
return False
# process_audio_streaming is the only entry point used by the Gradio interface
@spaces.GPU
def process_audio_streaming(audio_file, model_choice, question="Describe this audio in detail"):
# Load the selected model
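    # "Think" maps to the second checkpoint; "Think + Semantics" (the UI default) maps to the first.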
model_path = model_path_2 if model_choice == "Think" else model_path_1
model, processor = load_model(model_path)
# Load and process the audio with torchaudio
waveform, sr = torchaudio.load(audio_file)
# Resample to 16kHz if needed
if sr != processor.feature_extractor.sampling_rate:
waveform = torchaudio.functional.resample(waveform, sr, processor.feature_extractor.sampling_rate)
sr = processor.feature_extractor.sampling_rate
# Convert to mono if stereo
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
# Get the audio data as numpy array
y = waveform.squeeze().numpy()
# Set sampling rate for the processor
sampling_rate = processor.feature_extractor.sampling_rate
# Define prompts based on model choice
    prompt_think_semantics = f"You are given a question and an audio clip. Your task is to answer the question based on the audio clip. First, think about the question and the audio clip and put your thoughts in <think> and </think> tags. Then reason about the semantic elements involved in the audio clip and put your reasoning in <semantic_elements> and </semantic_elements> tags. Then answer the question based on the audio clip, put your answer in <answer> and </answer> tags. {question}"
instruction_text = ""
if model_choice == "Think + Semantics":
instruction_text = prompt_think_semantics
else: # Default to the question if no specific model processing is chosen.
instruction_text = question
# Create conversation format
conversation = [
{"role": "system", "content": [
{"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}]},
{"role": "user", "content": [
{"type": "audio", "audio": y},
{"type": "text", "text": instruction_text}
]}
]
# Format the chat
chat_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
# Process multimedia info using qwen_omni_utils
audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)
# Process the inputs
inputs = processor(
text=chat_text,
audio=audios,
images=images,
videos=videos,
return_tensors="pt",
sampling_rate=sampling_rate,
).to(model.device)
# Create a standard streamer instance
streamer = TextIteratorStreamer(
processor.tokenizer,
timeout=10.,
skip_prompt=True,
skip_special_tokens=True
)
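    # skip_special_tokens strips tokenizer special tokens only; the plain-text <think>/<semantic_elements>/<answer>
    # tags survive and are handled by process_output and the stop logic below.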
# Initialize variables for buffering
accumulated_output = ""
buffer = ""
    stop_sequence = "</answer>"
stop_found = False
    answer_token_ids = processor.tokenizer.encode("</answer>", add_special_tokens=False)
processed_stop_token_ids = [answer_token_ids]
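    # The encoded closing tag doubles as a hard stop for generation (via StopOnSpecificToken below),
    # while the string form is used to trim the streamed text.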
# Get the device of the model to ensure tensors are on the same device
model_device = next(model.parameters()).device
custom_stopping_criteria = StopOnSpecificToken(stop_token_sequences=processed_stop_token_ids, device=model_device.type)
stopping_criteria = StoppingCriteriaList([custom_stopping_criteria])
# Generate the output with streaming
with torch.no_grad():
generate_kwargs = dict(
**inputs,
streamer=streamer,
max_new_tokens=768,
do_sample=False,
stopping_criteria=stopping_criteria
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
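        # generate() runs in a background thread so the streamer can be consumed and partial text yielded below.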
# Process the stream with buffering
for output in streamer:
if stop_found:
break
output = process_output(output)
buffer += output
# Check if stop sequence is in the buffer
if stop_sequence in buffer:
# Output everything up to and including the stop sequence
before_stop = buffer.split(stop_sequence)[0]
accumulated_output += before_stop + stop_sequence
yield accumulated_output
stop_found = True
break
else:
# Check if we can safely output part of the buffer
# Keep the last N characters where N is the length of the stop sequence
if len(buffer) > len(stop_sequence):
# Output all but the last len(stop_sequence) characters
safe_output = buffer[:-len(stop_sequence)]
buffer = buffer[-len(stop_sequence):]
accumulated_output += safe_output
yield accumulated_output
# Output any remaining buffer if no stop sequence was found
if not stop_found and buffer:
accumulated_output += buffer
yield accumulated_output
# Create Gradio interface for audio processing
audio_demo = gr.Interface(
fn=process_audio_streaming,
inputs=[
gr.Audio(type="filepath", label="Upload Audio"),
gr.Radio(["Think", "Think + Semantics"], label="Select Model", value="Think + Semantics"),
gr.Textbox(label="Question", value="Describe this audio in detail")
],
outputs=gr.Textbox(label="Generated Output", lines=30),
title="AudSemThinker",
description="Upload an audio file and the model will provide detailed analysis and description. Choose between different model versions.",
examples=[["examples/1.wav", "Think + Semantics", "Describe this audio in detail"]],
cache_examples=False,
live=True
)
# Launch the apps
if __name__ == "__main__":
audio_demo.launch()