kcz358 committed · verified
Commit 91ff004 · Parent: 6f5686b

Update README.md

Files changed (1):
  1. README.md +54 -0

README.md CHANGED
@@ -80,6 +80,8 @@ python3 -m pip install transformers@git+https://github.com/huggingface/transform
 ```
 as this is the transformers version we are using when building this model.
 
+ ### Simple Demo
+
 ```python
 from transformers import AutoProcessor, AutoModelForCausalLM
 
@@ -125,6 +127,58 @@ cont = outputs[:, inputs["input_ids"].shape[-1] :]
 print(processor.batch_decode(cont, skip_special_tokens=True)[0])
 ```
 
+ ### Batch Inference
+ The model supports batch inference with transformers. An example is shown below:
+ ```python
+ from transformers import AutoProcessor, AutoModelForCausalLM
+
+ import torch
+ import librosa
+
+ def load_audio():
+     return librosa.load(librosa.ex("libri1"), sr=16000)[0]
+
+ def load_audio_2():
+     return librosa.load(librosa.ex("libri2"), sr=16000)[0]
+
+
+ processor = AutoProcessor.from_pretrained("lmms-lab/Aero-1-Audio-1.5B", trust_remote_code=True)
+ # We encourage using flash attention 2 for better performance.
+ # Install it with `pip install --no-build-isolation flash-attn`.
+ # If you do not want flash attention, use sdpa or eager instead.
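+ # For example, the same loading call with sdpa (a sketch; only the
+ # attn_implementation argument differs from the call below):
+ # model = AutoModelForCausalLM.from_pretrained("lmms-lab/Aero-1-Audio-1.5B", device_map="cuda", torch_dtype="auto", attn_implementation="sdpa", trust_remote_code=True)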
+ model = AutoModelForCausalLM.from_pretrained("lmms-lab/Aero-1-Audio-1.5B", device_map="cuda", torch_dtype="auto", attn_implementation="flash_attention_2", trust_remote_code=True)
+ model.eval()
+
+ messages = [
+     {
+         "role": "user",
+         "content": [
+             {
+                 "type": "audio_url",
+                 "audio": "placeholder",
+             },
+             {
+                 "type": "text",
+                 "text": "Please transcribe the audio",
+             }
+         ]
+     }
+ ]
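+ # Duplicate the conversation to form a batch of two requests.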
+ messages = [messages, messages]
+
+ audios = [load_audio(), load_audio_2()]
+
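+ # Left padding aligns every prompt at the right edge, so the generated
+ # continuation starts at the same index for each sequence in the batch.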
+ processor.tokenizer.padding_side = "left"
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+ inputs = processor(text=prompt, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True)
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
+ outputs = model.generate(**inputs, eos_token_id=151645, pad_token_id=151643, max_new_tokens=4096)
+
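+ # Keep only the newly generated tokens; with left padding the prompt
+ # length is identical for every row of the batch.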
+ cont = outputs[:, inputs["input_ids"].shape[-1] :]
+
+ print(processor.batch_decode(cont, skip_special_tokens=True))
+ ```
+
 ## Training Details
 
 ### Training Data