File size: 9,927 Bytes
190ddd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import spaces
import gradio as gr
import torch
import subprocess
import sys
from transformers import AutoProcessor, Qwen2_5OmniThinkerForConditionalGeneration, TextIteratorStreamer
import torchaudio
from threading import Thread
from qwen_omni_utils import process_mm_info
from transformers import StoppingCriteria, StoppingCriteriaList

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")

# Check environment: repeat the version check in a separate interpreter
# (the same sys.executable) and echo what it reports.
result = subprocess.run([sys.executable, '-c',
    'import torch; print(f"PyTorch: {torch.__version__}"); print(f"CUDA: {torch.version.cuda}")'],
    capture_output=True, text=True)
print(result.stdout)
# capture_output=True swallows the child's stderr, so without this a failing
# check would only print an empty line; report the failure explicitly.
if result.returncode != 0:
    print(
        f"Environment check failed (exit code {result.returncode}):\n{result.stderr}",
        file=sys.stderr,
    )

# Checkpoint locations. The UI option "Think" selects model_path_2; any other
# choice selects model_path_1.
model_path_1: str = "./model"
model_path_2: str = "./model2"
# Base model id; the processor (tokenizer + feature extractor) is always
# loaded from here, whichever checkpoint is in use.
base_model_id: str = "Qwen/Qwen2.5-Omni-7B"

# Memo of checkpoints loaded so far: model path -> (model, processor).
loaded_models: dict[str, tuple] = {}

# Load (and memoise) a checkpoint together with its processor
def load_model(model_path):
    """Return the ``(model, processor)`` pair for *model_path*, loading it once.

    Loaded pairs are memoised in the module-level ``loaded_models`` dict, so
    later calls with the same path return the already-loaded objects.

    Args:
        model_path: Path (or model id) of the Qwen2.5-Omni thinker checkpoint.

    Returns:
        A ``(model, processor)`` tuple. The processor is always built from
        ``base_model_id`` rather than from *model_path*; the model is put in
        eval mode.
    """
    if model_path not in loaded_models:
        # NOTE(review): presumably the local checkpoints ship without processor
        # files, hence the base-model processor — confirm.
        shared_processor = AutoProcessor.from_pretrained(base_model_id, trust_remote_code=True)
        thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            trust_remote_code=True,
            device_map="auto",
        )
        thinker.eval()
        loaded_models[model_path] = (thinker, shared_processor)
    return loaded_models[model_path]

# Load the default checkpoint at import time; load_model caches it, so later
# requests for model_path_1 reuse this instance.
model, processor = load_model(model_path=model_path_1)


# Opening/closing section tags that the prompt asks the model to emit.
_SECTION_TAG_PATTERN = r"</?(?:think|semantic_elements|answer)>"
# Closing tags mapped to the whitespace printed after each one.
_CLOSING_TAG_TRAILERS = {
    "</think>": "\n\n",
    "</semantic_elements>": "\n\n",
    "</answer>": "\n",
}


def process_output(output):
    r"""Reformat one streamed text chunk for display.

    Every ``<think>``, ``<semantic_elements>`` and ``<answer>`` tag (opening or
    closing) in *output* is formatted in the order it appears:

    * an opening tag is followed by a newline, and text before it that does not
      belong to an open section is dropped;
    * a closing tag is preceded by a newline and followed by a blank line (a
      single newline for ``</answer>``); text after it that does not lead into
      another closing tag is dropped.

    Thus text after an opening tag is kept, and text between a closing tag and
    the next opening tag is discarded. A chunk carrying several tags, such as
    ``"</think><semantic_elements>"``, has all of them formatted instead of only
    the first match.

    Afterwards, literal ``\n`` escape sequences become real newlines, any other
    backslash becomes a newline, and a newline directly before ``-`` is removed.

    Args:
        output: A text chunk as produced by the streamer.

    Returns:
        The reformatted chunk.
    """
    import re

    tag_matches = list(re.finditer(_SECTION_TAG_PATTERN, output))
    if tag_matches:
        pieces = []
        cursor = 0
        prev_is_opening = False  # nothing is open before the first tag
        for tag_match in tag_matches:
            tag = tag_match.group()
            is_closing = tag in _CLOSING_TAG_TRAILERS
            # Keep text that belongs to a section: it follows an opening tag,
            # or it is terminated by this closing tag.
            if prev_is_opening or is_closing:
                pieces.append(output[cursor:tag_match.start()])
            if is_closing:
                pieces.append("\n" + tag + _CLOSING_TAG_TRAILERS[tag])
            else:
                pieces.append(tag + "\n")
            prev_is_opening = not is_closing
            cursor = tag_match.end()
        # Trailing text survives only inside an open section.
        if prev_is_opening:
            pieces.append(output[cursor:])
        output = "".join(pieces)
    output = output.replace("\\n", "\n")
    output = output.replace("\\", "\n")
    output = output.replace("\n-", "-")
    return output

# Custom Stopping Criteria
class StopOnSpecificToken(StoppingCriteria):
    """Stop generation once the generated ids end with one of the given token sequences.

    Only the first row of the batch (``input_ids[0]``) is examined, so this
    criterion is meant for single-sequence generation.
    """

    def __init__(self, stop_token_sequences: list[list[int]], device: str = "cuda"):
        """Build the stop tensors.

        Args:
            stop_token_sequences: Token-id sequences that end generation when
                they appear at the tail of ``input_ids``. Empty sequences are
                ignored.
            device: Device the stop tensors are created on. ``__call__`` moves
                them to ``input_ids.device`` if the two differ.
        """
        super().__init__()
        self.stop_sequence_tensors = [
            torch.tensor(seq, dtype=torch.long, device=device)
            for seq in stop_token_sequences
            if seq  # empty sequences are ignored
        ]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if ``input_ids[0]`` currently ends with any stop sequence."""
        current_sequence_length = input_ids.shape[-1]
        for index, stop_seq_tensor in enumerate(self.stop_sequence_tensors):
            stop_sequence_length = stop_seq_tensor.shape[-1]
            if current_sequence_length < stop_sequence_length:
                continue
            # A bare device type such as "cuda" resolves to the current GPU,
            # which need not be the one holding input_ids; torch.equal needs
            # both tensors on the same device. Move once and keep the copy.
            if stop_seq_tensor.device != input_ids.device:
                stop_seq_tensor = stop_seq_tensor.to(input_ids.device)
                self.stop_sequence_tensors[index] = stop_seq_tensor
            # Compare the tail of the first sequence in the batch.
            if torch.equal(input_ids[0, -stop_sequence_length:], stop_seq_tensor):
                return True
        return False

def _stream_until_stop(chunks, stop_sequence):
    """Yield the growing concatenation of *chunks*, cut off after *stop_sequence*.

    The last ``len(stop_sequence)`` characters are held back between chunks so
    that a stop sequence split across two chunks is detected before any part
    of it is shown. When the stop sequence appears, the text up to and
    including it is yielded and iteration ends; anything after it is dropped.
    If the chunks run out first, the held-back tail is flushed.
    """
    hold = len(stop_sequence)
    accumulated = ""
    buffer = ""
    for chunk in chunks:
        buffer += chunk
        if stop_sequence in buffer:
            accumulated += buffer.split(stop_sequence)[0] + stop_sequence
            yield accumulated
            return
        if len(buffer) > hold:
            accumulated += buffer[:-hold]
            buffer = buffer[-hold:]
            yield accumulated
    if buffer:
        accumulated += buffer
        yield accumulated


# Gradio entry point: stream the model's answer for an uploaded audio clip.
@spaces.GPU
def process_audio_streaming(audio_file, model_choice, question="Describe this audio in detail"):
    """Stream the model's response to *question* about *audio_file*.

    Args:
        audio_file: Path to the uploaded audio (``gr.Audio(type="filepath")``),
            or ``None`` when no audio is present.
        model_choice: ``"Think"`` selects the checkpoint at ``model_path_2``;
            any other value selects ``model_path_1``. ``"Think + Semantics"``
            also wraps the question in the structured think /
            semantic_elements / answer prompt.
        question: The user's question; sent verbatim unless wrapped as above.

    Yields:
        The cumulative formatted output so far. Output stops after the first
        ``</answer>``.
    """
    # The interface is live, so it can fire before any audio is uploaded;
    # torchaudio.load(None) would raise.
    if audio_file is None:
        yield ""
        return

    # Load the selected model
    model_path = model_path_2 if model_choice == "Think" else model_path_1
    model, processor = load_model(model_path)
    sampling_rate = processor.feature_extractor.sampling_rate

    # Load the audio and resample to the processor's expected rate if needed.
    waveform, sr = torchaudio.load(audio_file)
    if sr != sampling_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sampling_rate)

    # Down-mix multi-channel audio to mono by averaging the channels.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    y = waveform.squeeze().numpy()

    # Define prompts based on model choice
    prompt_think_semantics = f"You are given a question and an audio clip. Your task is to answer the question based on the audio clip. First, think about the question and the audio clip and put your thoughts in <think> and </think> tags. Then reason about the semantic elements involved in the audio clip and put your reasoning in <semantic_elements> and </semantic_elements> tags. Then answer the question based on the audio clip, put your answer in <answer> and </answer> tags. {question}"
    if model_choice == "Think + Semantics":
        instruction_text = prompt_think_semantics
    else:
        instruction_text = question

    conversation = [
        {"role": "system", "content": [
            {"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}]},
        {"role": "user", "content": [
            {"type": "audio", "audio": y},
            {"type": "text", "text": instruction_text}
        ]}
    ]

    chat_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)
    inputs = processor(
        text=chat_text,
        audio=audios,
        images=images,
        videos=videos,
        return_tensors="pt",
        sampling_rate=sampling_rate,
    ).to(model.device)

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        timeout=10.,
        skip_prompt=True,
        skip_special_tokens=True
    )

    # Stop generation as soon as the closing answer tag has been produced.
    stop_sequence = "</answer>"
    stop_token_ids = processor.tokenizer.encode(stop_sequence, add_special_tokens=False)
    # Pass the full device (type and index), not just its type, so the stop
    # tensors are created on the device holding the model's parameters.
    model_device = next(model.parameters()).device
    stopping_criteria = StoppingCriteriaList([
        StopOnSpecificToken(stop_token_sequences=[stop_token_ids], device=str(model_device))
    ])

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=768,
        do_sample=False,
        stopping_criteria=stopping_criteria,
    )
    # torch.no_grad() is thread-local, so it must wrap the function that runs
    # in the worker thread rather than the Thread.start() call.
    worker = Thread(target=torch.no_grad()(model.generate), kwargs=generate_kwargs)
    worker.start()

    yield from _stream_until_stop(
        (process_output(text) for text in streamer), stop_sequence
    )

# Gradio UI. The three inputs are passed positionally to
# process_audio_streaming(audio_file, model_choice, question); live=True
# re-runs it whenever an input changes.
audio_demo = gr.Interface(
    fn=process_audio_streaming,
    title="AudSemThinker",
    description="Upload an audio file and the model will provide detailed analysis and description. Choose between different model versions.",
    inputs=[
        gr.Audio(label="Upload Audio", type="filepath"),
        gr.Radio(choices=["Think", "Think + Semantics"], value="Think + Semantics", label="Select Model"),
        gr.Textbox(value="Describe this audio in detail", label="Question"),
    ],
    outputs=gr.Textbox(lines=30, label="Generated Output"),
    examples=[["examples/1.wav", "Think + Semantics", "Describe this audio in detail"]],
    cache_examples=False,
    live=True,
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    audio_demo.launch()