Update README.md
README.md CHANGED

@@ -20,64 +20,81 @@ Corresponding paper: https://arxiv.org/abs/2505.14142
**Removed:** the previous usage example imported `AutoProcessor` and `AutoModelForCausalLM` from `transformers`, loaded the checkpoint with `torch_dtype=torch.bfloat16`, and decoded with a bare `processor.batch_decode(output_ids, skip_special_tokens=True)` followed by `print(response)`. The unchanged context lines (audio loading, the example question and prompt, the user message, and the expected-output comments) appear in the updated section below.
**Added** (new README lines 20-100):

To use `AudSemThinker-QA-GRPO` for audio question answering, you can load it using the `transformers` library. Ensure you have `torch`, `torchaudio`, `soundfile`, and `qwen-omni-utils` (which provides `process_mm_info`) installed.

```python
import soundfile as sf
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
from qwen_omni_utils import process_mm_info
import torchaudio

# default: Load the model on the available device(s)
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
    "gijs/audsemthinker-qa-grpo",
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True
)

# We recommend enabling flash_attention_2 for better acceleration and memory saving.
# model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
#     "gijs/audsemthinker-qa-grpo",
#     torch_dtype="auto",
#     device_map="auto",
#     attn_implementation="flash_attention_2",
#     trust_remote_code=True,
#     low_cpu_mem_usage=True
# )

processor = Qwen2_5OmniProcessor.from_pretrained("gijs/audsemthinker-qa-grpo", trust_remote_code=True)

# Load and preprocess audio
audio_file = "path/to/your/audio.wav"
audio_input, sampling_rate = torchaudio.load(audio_file)
if sampling_rate != processor.feature_extractor.sampling_rate:
    audio_input = torchaudio.transforms.Resample(
        orig_freq=sampling_rate,
        new_freq=processor.feature_extractor.sampling_rate
    )(audio_input)
audio_input = audio_input.squeeze().numpy()

# Example multiple-choice question
question = "What type of sound is present in the audio? Options: (A) Speech (B) Music (C) Environmental Sound (D) Silence"
user_prompt_text = f"You are given a question and an audio clip. Your task is to answer the question based on the audio clip. First, think about the question and the audio clip and put your thoughts in <think> and </think> tags. Then reason about the semantic elements involved in the audio clip and put your reasoning in <semantic_elements> and </semantic_elements> tags. Then answer the question based on the audio clip, put your answer in <answer> and </answer> tags.\nQuestion: {question}"

# Conversation format
conversation = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}
        ],
    },
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": audio_input},
            {"type": "text", "text": user_prompt_text}
        ],
    },
]

# Preparation for inference
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
audios, images, videos = process_mm_info(conversation)
inputs = processor(
    text=text,
    audio=audios,
    images=images,
    videos=videos,
    return_tensors="pt",
    padding=True
)
inputs = inputs.to(model.device).to(model.dtype)

# Inference: Generation of the output
output_ids = model.generate(**inputs, max_new_tokens=512)
response = processor.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(response[0])

# Expected output format for QA:
# <think>...detailed reasoning about the audio scene and question...</think>
# <semantic_elements>...list of identified semantic descriptors...</semantic_elements>
```
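Because the model wraps its output in `<think>`, `<semantic_elements>`, and `<answer>` tags, a small post-processing step can pull out just the final answer. The sketch below is illustrative rather than part of the model card: the `extract_tag` helper and the hard-coded sample string (standing in for `response[0]` from the example above) are assumptions, and the only thing it relies on is the tag format shown in the expected output.

```python
import re

def extract_tag(text: str, tag: str) -> str:
    """Return the content of the first <tag>...</tag> span, or an empty string if absent."""
    match = re.search(rf"<{tag}>(.*?)</{tag}>", text, flags=re.DOTALL)
    return match.group(1).strip() if match else ""

# In practice, pass `response[0]` from the generation example above; this
# made-up string only mimics the expected output format for demonstration.
generated_text = (
    "<think>The clip contains birdsong and wind, with no speech or music.</think>"
    "<semantic_elements>sound-generating entities: birds, wind; setting: outdoors</semantic_elements>"
    "<answer>(C) Environmental Sound</answer>"
)

print(extract_tag(generated_text, "think"))              # reasoning trace
print(extract_tag(generated_text, "semantic_elements"))  # semantic descriptors
print(extract_tag(generated_text, "answer"))             # -> (C) Environmental Sound
```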