tzzte committed
Commit 30320c9 · verified · 1 Parent(s): 1d60139

Upload 13 files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ EchoX-Vocoder/g_00500000 filter=lfs diff=lfs merge=lfs -text
37
+ show_case/2.wav filter=lfs diff=lfs merge=lfs -text
38
+ show_case/Translate_de_audio_prompt.wav filter=lfs diff=lfs merge=lfs -text
ACLlama_el_s2s.py ADDED
@@ -0,0 +1,565 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn import CrossEntropyLoss, CTCLoss
7
+
8
+
9
+ from transformers import AutoConfig, AutoModelForCausalLM, \
10
+ LlamaConfig, LlamaModel, LlamaForCausalLM
11
+ from transformers.trainer_pt_utils import LabelSmoother
12
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
+ from transformers import (
14
+ WhisperProcessor,
15
+ WhisperModel,
16
+ )
17
+ from T2ULlama_CR_online import T2ULlamaForCausalLM
18
+
19
+ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
20
+
21
+ class ACLlamaConfig(LlamaConfig):
22
+ model_type = "ACLlama"
23
+
24
+ def load_whisper(audio_tower_name, device="cuda"):
25
+ model = WhisperModel.from_pretrained(
26
+ audio_tower_name,torch_dtype=torch.float16,low_cpu_mem_usage=True).to(device)
27
+ model.config.forced_decoder_ids = None
28
+ return model
29
+
30
+ class LookBackModule(nn.Module):
31
+ def __init__(self, cfg: LlamaConfig):
32
+ super().__init__()
33
+ self.encoder_attn = nn.MultiheadAttention(
34
+ cfg.hidden_size,
35
+ cfg.num_attention_heads,
36
+ dropout=0.1,
37
+ batch_first=True
38
+ )
39
+ self.atten_layer_norm = nn.LayerNorm(cfg.hidden_size)
40
+
41
+
42
+ def forward(self, x, wav_feature, bf_shrink_padding_mask):
43
+
44
+ residual = x
45
+ x, _ = self.encoder_attn(
46
+ query=x,
47
+ key=wav_feature,
48
+ value=wav_feature,
49
+ key_padding_mask=bf_shrink_padding_mask,
50
+ #attn_mask=padding_mask,
51
+ )
52
+ x += residual
53
+ x = self.atten_layer_norm(x)
54
+ return x
55
+
56
+ class ACLlamaModel(LlamaModel):
57
+ config_class = ACLlamaConfig
58
+
59
+ def __init__(self, config: LlamaConfig):
60
+ super(ACLlamaModel, self).__init__(config)
61
+
62
+ if hasattr(config, "audio_tower"):
63
+ self.audio_tower = [load_whisper(config.audio_tower)]
64
+
65
+ if hasattr(config, "adapter_size"):
66
+
67
+ self.mm_projector1 = nn.Linear(config.adapter_size*2 , config.hidden_size)
68
+ asr_encoder_layer = nn.TransformerEncoderLayer(
69
+ d_model=config.hidden_size,
70
+ nhead=config.num_attention_heads,
71
+ dim_feedforward=config.hidden_size*2,
72
+ dropout=0.1,
73
+ norm_first=True
74
+ )
75
+ self.lbm = LookBackModule(config)
76
+ self.out_norm = nn.LayerNorm(config.hidden_size)
77
+ self.audio_feature_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
78
+ self.asr_transformer_encoder = nn.TransformerEncoder(asr_encoder_layer, num_layers=1)
79
+ self.mask_tensor=(torch.ones([1, 2048])>0)
80
+ self.length=-1
81
+
82
+ def forward(
83
+ self,
84
+ input_ids: torch.LongTensor = None,
85
+ attention_mask: Optional[torch.Tensor] = None,
86
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
87
+ inputs_embeds: Optional[torch.FloatTensor] = None,
88
+ use_cache: Optional[bool] = None,
89
+ output_attentions: Optional[bool] = None,
90
+ output_hidden_states: Optional[bool] = None,
91
+ audios: Optional[torch.FloatTensor] = None,
92
+ return_dict: Optional[bool] = None,
93
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
94
+
95
+ # HACK: replace back original embeddings for LLaAA pretraining
96
+ orig_embeds_params = getattr(self, 'orig_embeds_params', None)
97
+
98
+ if inputs_embeds is None:
99
+ inputs_embeds = self.embed_tokens(input_ids)
100
+
101
+ audio_tower = getattr(self, 'audio_tower', None)
102
+ if audio_tower is not None and (input_ids.shape[1] != 1 or self.training) and audios is not None:
103
+ audio_tower = audio_tower[0] # HACK: for FSDP
104
+ audio_list=[]
105
+
106
+ audio_config = audio_tower.config
107
+ for audio in audios:
108
+ with torch.no_grad():
109
+ audio_feature = audio_tower.encoder(audio).last_hidden_state
110
+
111
+ audio_feature = audio_feature.view(audio_feature.shape[0], audio_feature.shape[1]//2, 2 * audio_feature.shape[2])
112
+ audio_feature = self.mm_projector1(audio_feature)
113
+ audio_feature = self.asr_transformer_encoder(audio_feature)
114
+ audio_feature = self.out_norm(audio_feature)
115
+ audio_list.append(audio_feature)
116
+
117
+ audio_features = torch.stack(audio_list, dim=0)
118
+ batch = audio_features.shape[0]
119
+ audio_turn = audio_features.shape[1]
120
+ audio_features = audio_features.view((batch * audio_turn,)+audio_features.shape[2:])
121
+
122
+ predict_logits = self.audio_feature_head(audio_features)
123
+
124
+ new_input_embeds = []
125
+ label_shift = []
126
+ speech_pos = []
127
+ label_extend = -1
128
+ new_input_ids = []
129
+ tokens = predict_logits.argmax(dim=-1)
130
+ shrink_mask = tokens.roll(1) != tokens
131
+ shrink_mask[:,0] = True
132
+
133
+ lengths = shrink_mask.long().sum(-1)
134
+ shrink_2d = audio_features[shrink_mask]
135
+ #num_patches = audio_features.shape[1]
136
+ num_patches = audio_config.audio_patch_size
137
+ l_index=0
138
+ shrink_features_raw = []
139
+ for v, audio_feature, mask in zip(lengths, audio_features, ~shrink_mask):
140
+ shrink_feature = shrink_2d[l_index:l_index+v]
141
+ shrink_feature = self.lbm(shrink_feature, audio_feature, bf_shrink_padding_mask=mask)
142
+ shrink_features_raw.append(shrink_feature)
143
+ l_index += v
144
+
145
+ shrink_features = []
146
+ for i in range(0, len(shrink_features_raw), audio_turn):
147
+ shrink_features.append(shrink_features_raw[i:i+audio_turn])
148
+ if self.training:
149
+ maxn_length = lengths.view(batch,audio_turn).sum(-1).max()
150
+ label_extend = maxn_length - num_patches * audio_turn
151
+ old_seq_length = inputs_embeds.shape[1]
152
+ for cur_input_ids, cur_input_embeds, cur_shrink_features in zip(input_ids, inputs_embeds, shrink_features):
153
+ pad_ids = torch.full(size=(maxn_length,), fill_value=audio_config.llm_pad_token_id, dtype=torch.long).to(attention_mask.device)
154
+ pad_embeds = self.embed_tokens(pad_ids)
155
+ audio_start_token_pos_all = torch.where(cur_input_ids == audio_config.audio_patch_token)[0]
156
+ #print(cur_input_embeds.shape,cur_input_ids.shape)
157
+ inner_label_shift = []
158
+ inner_speech_pos = []
159
+ for audio_start_token_pos, shrink_feature in reversed(list(zip(audio_start_token_pos_all, cur_shrink_features))): #zip(audio_start_token_pos_all, cur_shrink_features):
160
+ cur_speech_length = shrink_feature.shape[0]
161
+
162
+ cur_input_ids = torch.cat((cur_input_ids[:audio_start_token_pos],
163
+ cur_input_ids[audio_start_token_pos: audio_start_token_pos+1].repeat(cur_speech_length),
164
+ cur_input_ids[audio_start_token_pos + num_patches:]), dim=0)
165
+ cur_input_embeds = torch.cat((
166
+ cur_input_embeds[:audio_start_token_pos],
167
+ shrink_feature,
168
+ cur_input_embeds[audio_start_token_pos + num_patches:]), dim=0)
169
+ inner_label_shift.insert(0, cur_speech_length - num_patches)
170
+ inner_speech_pos.insert(0, audio_start_token_pos)
171
+
172
+ label_shift = label_shift + inner_label_shift
173
+ speech_pos = speech_pos + inner_speech_pos
174
+
175
+ cur_new_input_embeds = torch.cat((cur_input_embeds, pad_embeds[:old_seq_length + label_extend - cur_input_embeds.shape[0]]),dim=0)
176
+ cur_new_input_ids = torch.cat((cur_input_ids, pad_ids[:old_seq_length + label_extend - cur_input_ids.shape[0]]),dim=0)
177
+ new_input_embeds.append(cur_new_input_embeds)
178
+ new_input_ids.append(cur_new_input_ids)
179
+
180
+ input_ids = torch.stack(new_input_ids, dim=0)
181
+ attention_mask=input_ids.ne(audio_config.llm_pad_token_id)
182
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
183
+
184
+ batch_label_shift = []
185
+ batch_speech_pos=[]
186
+ for i in range(0, len(label_shift), audio_turn):
187
+ batch_label_shift.append(label_shift[i:i+audio_turn])
188
+ batch_speech_pos.append(speech_pos[i:i+audio_turn])
189
+ else:
190
+ # Inference mode with batch_size=1
191
+ assert input_ids.shape[0] == 1, "This implementation only supports batch_size=1 during inference"
192
+
193
+ # Get all audio token positions in this sample
194
+ audio_start_token_positions = torch.where(input_ids[0] == audio_config.audio_patch_token)[0]
195
+
196
+ # Initialize with original embeddings
197
+ current_embeds = inputs_embeds[0] # [seq_len, embed_dim]
198
+ current_ids = input_ids[0] # [seq_len]
199
+
200
+ # Process each audio token position sequentially
201
+ position_shift = 0 # Track position changes due to expansions
202
+
203
+ # Ensure shrink_features is properly formatted
204
+ if isinstance(shrink_features[0], list):
205
+ # If it's a list of lists (batch_size=1 but multiple turns), flatten it
206
+ shrink_features = [item for sublist in shrink_features for item in sublist]
207
+
208
+ for pos_idx, audio_pos in enumerate(audio_start_token_positions):
209
+ adjusted_pos = audio_pos + position_shift
210
+
211
+ # Get corresponding shrink feature (ensure it's a tensor)
212
+ shrink_feature = shrink_features[pos_idx]
213
+ if isinstance(shrink_feature, list):
214
+ shrink_feature = torch.stack(shrink_feature, dim=0)
215
+
216
+ v = shrink_feature.shape[0] # Now this should work
217
+ # print('len: ', v)
218
+
219
+ # Expand the input ids and embeddings
220
+ current_ids = torch.cat([
221
+ current_ids[:adjusted_pos],
222
+ current_ids[adjusted_pos:adjusted_pos+1].repeat(v),
223
+ current_ids[adjusted_pos + num_patches:]
224
+ ], dim=0)
225
+
226
+ current_embeds = torch.cat([
227
+ current_embeds[:adjusted_pos],
228
+ shrink_feature,
229
+ current_embeds[adjusted_pos + num_patches:]
230
+ ], dim=0)
231
+
232
+ # Update position shift for next iteration
233
+ position_shift += (v - num_patches)
234
+
235
+ # Update the tensors (unsqueeze to restore batch dim)
236
+ input_ids = current_ids.unsqueeze(0) # [1, new_seq_len]
237
+ inputs_embeds = current_embeds.unsqueeze(0) # [1, new_seq_len, embed_dim]
238
+ attention_mask = input_ids.ne(audio_config.llm_pad_token_id)
239
+
240
+ # Update inference state tracking
241
+ if not hasattr(self, 'mask_tensor'):
242
+ # Initialize with current attention mask
243
+ self.mask_tensor = attention_mask.clone()
244
+ self.length = attention_mask.shape[1]
245
+ else:
246
+ # Ensure mask tensor is on correct device
247
+ self.mask_tensor = self.mask_tensor.to(attention_mask.device)
248
+
249
+ # Expand mask tensor if needed
250
+ if self.mask_tensor.shape[1] < attention_mask.shape[1]:
251
+ new_mask = torch.zeros(1, attention_mask.shape[1],
252
+ dtype=torch.bool,
253
+ device=attention_mask.device)
254
+ new_mask[0, :self.mask_tensor.shape[1]] = self.mask_tensor
255
+ self.mask_tensor = new_mask
256
+
257
+ # Update mask tensor
258
+ self.mask_tensor[0, :attention_mask.shape[1]] = attention_mask[0]
259
+ self.length = attention_mask.shape[1]
260
+
261
+ attention_mask=self.mask_tensor[:,:self.length]
262
+ self.length+=1
263
+
264
+ return_state=super(ACLlamaModel, self).forward(
265
+ input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
266
+ inputs_embeds=inputs_embeds, use_cache=use_cache,
267
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
268
+ return_dict=return_dict
269
+ )
270
+ if self.training and audios is not None:
271
+ return_state["audio_features"] = predict_logits
272
+ return_state["label_shift"] = batch_label_shift
273
+ return_state["label_extend"] = label_extend
274
+ return_state["speech_pos"] = batch_speech_pos
275
+ #return_state = {"audio_features":predict_logits}
276
+ return return_state
277
+
278
+
279
+ class ACLlamaForCausalLM(LlamaForCausalLM):
280
+ config_class = ACLlamaConfig
281
+
282
+ def __init__(self, config):
283
+ super(LlamaForCausalLM, self).__init__(config)
284
+ self.model = ACLlamaModel(config)
285
+
286
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
287
+
288
+ # t2u by kkq
289
+ if hasattr(config, "unit_output"):
290
+ self.unit_output = config.unit_output
291
+ self.unit_translator = T2ULlamaForCausalLM(config, self.lm_head.weight)
292
+
293
+ # Initialize weights and apply final processing
294
+ self.post_init()
295
+
296
+ def get_model(self):
297
+ return self.model
298
+
299
+ def get_unit_translator(self):
300
+ return self.unit_translator
301
+
302
+ def forward(
303
+ self,
304
+ input_ids: torch.LongTensor = None,
305
+ attention_mask: Optional[torch.Tensor] = None,
306
+ position_ids: Optional[torch.LongTensor] = None,
307
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
308
+ inputs_embeds: Optional[torch.FloatTensor] = None,
309
+ labels: Optional[torch.LongTensor] = None,
310
+ t2u_input_ids: Optional[torch.LongTensor] = None,
311
+ t2u_labels: Optional[torch.LongTensor] = None,
312
+ t2u_attention_mask: Optional[torch.Tensor] = None,
313
+ unit_targets: Optional[torch.Tensor] = None,
314
+ sub_lengths: Optional[torch.Tensor] = None,
315
+ asr_targets: Optional[torch.LongTensor] = None,
316
+ use_cache: Optional[bool] = None,
317
+ output_attentions: Optional[bool] = None,
318
+ output_hidden_states: Optional[bool] = None,
319
+ audios: Optional[torch.FloatTensor] = None,
320
+ return_dict: Optional[bool] = None,
321
+ cache_position: Optional[torch.LongTensor] = None,
322
+ do_task: str = None,
323
+ assistant_after_audio_shifts: Optional[torch.Tensor] = None,
324
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
325
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
326
+ output_hidden_states = (
327
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
328
+ )
329
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
330
+
331
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
332
+
333
+ # t2u by kkq
334
+ # pretrain(t2u only) finetune(s2t&e2u)
335
+ do_task = do_task if do_task != None else getattr(self, 'unit_output', None)
336
+
337
+ outputs = None
338
+ hidden_states = None
339
+ new_shift_labels = None
340
+ if do_task != "pretrain":
341
+ outputs = self.model(
342
+ input_ids=input_ids,
343
+ attention_mask=attention_mask,
344
+ past_key_values=past_key_values,
345
+ inputs_embeds=inputs_embeds,
346
+ use_cache=use_cache,
347
+ output_attentions=output_attentions,
348
+ output_hidden_states=output_hidden_states,
349
+ return_dict=return_dict,
350
+ audios=audios
351
+ )
352
+
353
+
354
+ hidden_states = outputs[0]
355
+ logits = self.lm_head(hidden_states)
356
+
357
+ loss = None
358
+ if labels is not None and do_task != "pretrain" and do_task != "finetune_kd":
359
+ if asr_targets is not None:
360
+ asr_logits = outputs["audio_features"]
361
+ asr_targets = asr_targets.view(asr_targets.shape[0] * asr_targets.shape[1], asr_targets.shape[2])
362
+ mask_asr_targets = (asr_targets != IGNORE_TOKEN_ID)
363
+ target_lengths = mask_asr_targets.sum(1)
364
+ input_lengths = torch.full(size=(asr_logits.shape[0],), fill_value=asr_logits.shape[1], dtype=torch.long)
365
+
366
+ loss_ctc = CTCLoss()
367
+
368
+ log_probs = F.log_softmax(asr_logits, dim=-1).transpose(0, 1)
369
+ #print(asr_targets.shape)
370
+ #print(input_lengths, target_lengths)
371
+
372
+ with torch.backends.cudnn.flags(enabled=False):
373
+ loss_asr = F.ctc_loss(
374
+ log_probs,
375
+ asr_targets,
376
+ input_lengths,
377
+ target_lengths,
378
+ blank=self.model.audio_tower[0].config.audio_patch_token,
379
+ reduction='mean',
380
+ zero_infinity=True,
381
+ )
382
+ else:
383
+ loss_asr=0
384
+
385
+ shift_labels = labels
386
+ if "label_shift" in outputs.keys() and len(outputs["label_shift"]) >0:
387
+ if outputs["label_extend"] != -1:
388
+ new_shift_labels = torch.full(size=(shift_labels.shape[0], outputs["label_extend"]+shift_labels.shape[1]), fill_value=IGNORE_TOKEN_ID, dtype=torch.long).to(shift_labels.device)
389
+ for batch in range(len(outputs["label_shift"])):
390
+ it_lable_shift = outputs["label_shift"][batch]
391
+ it_speech_pos = outputs["speech_pos"][batch]
392
+ prefix = 0
393
+ for i in range(len(it_lable_shift)):
394
+ if i == len(it_lable_shift) - 1:
395
+ length = shift_labels.shape[1] - it_speech_pos[i] #len(shift_labels[batch]) - it_speech_pos[i]
396
+ else:
397
+ length = it_speech_pos[i + 1] - it_speech_pos[i]
398
+ prefix += it_lable_shift[i]
399
+ new_shift_labels[batch][it_speech_pos[i] + prefix: it_speech_pos[i] + length + prefix]= shift_labels[batch][it_speech_pos[i]:it_speech_pos[i]+length]
400
+ shift_labels = new_shift_labels
401
+ else:
402
+ raise NotImplementedError
403
+
404
+ # Shift so that tokens < n predict n
405
+ shift_logits = logits[..., :-1, :].contiguous()
406
+ shift_labels = shift_labels[..., 1:].contiguous()
407
+ #print(shift_labels[:,:50])
408
+
409
+ #print(shift_labels[:,:150])
410
+ loss_fct = CrossEntropyLoss()
411
+ # Flatten the tokens
412
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
413
+ shift_labels = shift_labels.view(-1)
414
+
415
+ shift_labels = shift_labels.to(shift_logits.device)
416
+ loss = loss_fct(shift_logits, shift_labels)
417
+ loss = loss + 0.3 * loss_asr
418
+
419
+ t2u_output = None
420
+ if do_task != None and do_task != "skip":
421
+ if do_task == "finetune_kd":
422
+ text_start_index = []
423
+ for batch in range(len(outputs["label_shift"])):
424
+ text_start_index.append(outputs["speech_pos"][batch][0] + outputs["label_shift"][batch][0]+assistant_after_audio_shifts[batch])
425
+
426
+ t2u_embeds_output = self.unit_translator.insert_text_embedding(
427
+ input_ids=t2u_input_ids,
428
+ attention_mask=t2u_attention_mask,
429
+ inputs_embeds=None,
430
+ labels=t2u_labels,
431
+ text_labels=labels,
432
+ shift_text_labels=new_shift_labels,
433
+ shift_text_hidden_states=hidden_states,
434
+ unit_targets=unit_targets,
435
+ sub_lengths=sub_lengths,
436
+ text_start_index=text_start_index,
437
+ do_task=do_task,
438
+ )
439
+
440
+ vae_loss, t2u_inputs_embeds, unit_targets, t2u_attention_mask = t2u_embeds_output
441
+
442
+ t2u_output = self.unit_translator(
443
+ input_ids=None,
444
+ attention_mask=t2u_attention_mask,
445
+ past_key_values=past_key_values,
446
+ inputs_embeds=t2u_inputs_embeds,
447
+ use_cache=use_cache,
448
+ labels=unit_targets,
449
+ output_attentions=output_attentions,
450
+ output_hidden_states=output_hidden_states,
451
+ return_dict=return_dict,
452
+ )
453
+ else:
454
+ t2u_embeds_output = self.unit_translator.insert_text_embedding(
455
+ input_ids=t2u_input_ids,
456
+ attention_mask=t2u_attention_mask,
457
+ inputs_embeds=None,
458
+ labels=t2u_labels,
459
+ text_labels=labels,
460
+ shift_text_labels=new_shift_labels,
461
+ shift_text_hidden_states=hidden_states,
462
+ do_task=do_task,
463
+ )
464
+ vae_loss, t2u_inputs_embeds = t2u_embeds_output
465
+
466
+ t2u_output = self.unit_translator(
467
+ input_ids=None,
468
+ attention_mask=t2u_attention_mask,
469
+ past_key_values=past_key_values,
470
+ inputs_embeds=t2u_inputs_embeds,
471
+ use_cache=use_cache,
472
+ labels=t2u_labels,
473
+ output_attentions=output_attentions,
474
+ output_hidden_states=output_hidden_states,
475
+ return_dict=return_dict,
476
+ )
477
+ t2u_loss = t2u_output[0]
478
+ # print(do_task, t2u_loss, vae_loss)
479
+ if vae_loss != None:
480
+ target_scale = t2u_loss.item() * 0.2
481
+ vae_loss_weight = target_scale / vae_loss.item() if vae_loss > target_scale else 1.0
482
+ t2u_loss = t2u_loss + vae_loss_weight * vae_loss
483
+ #print(vae_loss)
484
+
485
+ if loss != None: # S2T + T2U loss
486
+ # ignore LLM loss
487
+ # t2u_output["loss"] = t2u_loss
488
+ # return t2u_output
489
+ # original version
490
+ assert do_task in ["finetune"]
491
+ if loss.item() < 1.0: # 1.7
492
+ loss = 0.2 * loss + t2u_loss * 2.0
493
+ else:
494
+ loss = loss + t2u_loss
495
+ else:
496
+ assert do_task in ["pretrain", "finetune_kd"]
497
+ t2u_output["loss"] = t2u_loss
498
+ return t2u_output
499
+
500
+ #return CausalLMOutputWithPast(
501
+ # loss=loss,
502
+ # logits=outputs["audio_features"],
503
+ #)
504
+
505
+ if not return_dict:
506
+ output = (logits,) + outputs[1:]
507
+ return (loss,) + output if loss is not None else output
508
+
509
+ return CausalLMOutputWithPast(
510
+ loss=loss,
511
+ logits=logits,
512
+ past_key_values=outputs.past_key_values,
513
+ hidden_states=outputs.hidden_states,
514
+ attentions=outputs.attentions,
515
+ )
516
+
517
+ def prepare_inputs_for_generation(
518
+ self,
519
+ input_ids,
520
+ past_key_values=None,
521
+ attention_mask=None,
522
+ inputs_embeds=None,
523
+ cache_position=None,
524
+ position_ids=None,
525
+ use_cache=True,
526
+ **kwargs,
527
+ ):
528
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
529
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
530
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
531
+ if past_key_values is not None:
532
+ if inputs_embeds is not None: # Exception 1
533
+ input_ids = input_ids[:, -cache_position.shape[0] :]
534
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
535
+ input_ids = input_ids[:, cache_position]
536
+
537
+ if attention_mask is not None and position_ids is None:
538
+ # create position_ids on the fly for batch generation
539
+ position_ids = attention_mask.long().cumsum(-1) - 1
540
+ position_ids.masked_fill_(attention_mask == 0, 1)
541
+ if past_key_values:
542
+ position_ids = position_ids[:, -input_ids.shape[1] :]
543
+
544
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
545
+ if inputs_embeds is not None and cache_position[0] == 0:
546
+ model_inputs = {"inputs_embeds": inputs_embeds}
547
+ else:
548
+ model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
549
+
550
+ model_inputs.update(
551
+ {
552
+ "position_ids": position_ids,
553
+ "cache_position": cache_position,
554
+ "past_key_values": past_key_values,
555
+ "use_cache": use_cache,
556
+ "attention_mask": attention_mask,
557
+ }
558
+ )
559
+ model_inputs.update({"audios": kwargs["audios"]} if "audios" in kwargs.keys() else {})
560
+ model_inputs.update({"do_task": kwargs["do_task"]} if "do_task" in kwargs.keys() else {})
561
+ model_inputs.update({"return_dict": kwargs["return_dict_in_generate"]} if "return_dict_in_generate" in kwargs.keys() else {})
562
+ return model_inputs
563
+
564
+ AutoConfig.register("ACLlama", ACLlamaConfig)
565
+ AutoModelForCausalLM.register(ACLlamaConfig, ACLlamaForCausalLM)
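
Note: because the module registers ACLlamaConfig and ACLlamaForCausalLM with the Auto classes above, a checkpoint whose config.json declares model_type "ACLlama" can also be loaded through the Auto API. A minimal sketch, assuming a placeholder checkpoint path; loading in float16 simply mirrors the inference script further below and is not required by this file:

import torch
from transformers import AutoModelForCausalLM
import ACLlama_el_s2s  # noqa: F401 -- importing runs the AutoConfig/AutoModelForCausalLM registration above

# Hypothetical local or Hub checkpoint directory whose config.json sets model_type "ACLlama".
model = AutoModelForCausalLM.from_pretrained(
    "path/to/ACLlama-checkpoint",
    torch_dtype=torch.float16,
).eval()
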
EchoX-Vocoder/checkpoint_last.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abb7b49cc59bbf058719cdae2252069dce1bb7b73362a0ec5273670ed7a6d4cc
3
+ size 389348172
EchoX-Vocoder/config.json ADDED
@@ -0,0 +1,53 @@
1
+ {
2
+ "input_wavs_dir": "/private/home/adampolyak/datasets/LJ/LJSpeech-1.1/wavs_16khz_padded",
3
+ "input_training_file": "/large_experiments/ust/annl/datasets/tts/LJSpeech/filelist/mhubert_vp_en_es_fr_it3_400k/lj_train_layer11_hubert1000_filelist.txt",
4
+ "input_validation_file": "/large_experiments/ust/annl/datasets/tts/LJSpeech/filelist/mhubert_vp_en_es_fr_it3_400k/lj_dev_layer11_hubert1000_filelist.txt",
5
+
6
+ "resblock": "1",
7
+ "num_gpus": 0,
8
+ "batch_size": 16,
9
+ "learning_rate": 0.0002,
10
+ "adam_b1": 0.8,
11
+ "adam_b2": 0.99,
12
+ "lr_decay": 0.999,
13
+ "seed": 1234,
14
+
15
+ "upsample_rates": [5,4,4,2,2],
16
+ "upsample_kernel_sizes": [11,8,8,4,4],
17
+ "upsample_initial_channel": 512,
18
+ "resblock_kernel_sizes": [3,7,11],
19
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
20
+ "num_embeddings": 1000,
21
+ "embedding_dim": 128,
22
+ "model_in_dim": 128,
23
+
24
+ "segment_size": 8960,
25
+ "code_hop_size": 320,
26
+ "f0": false,
27
+ "num_mels": 80,
28
+ "num_freq": 1025,
29
+ "n_fft": 1024,
30
+ "hop_size": 256,
31
+ "win_size": 1024,
32
+
33
+ "dur_prediction_weight": 1.0,
34
+ "dur_predictor_params": {
35
+ "encoder_embed_dim": 128,
36
+ "var_pred_hidden_dim": 128,
37
+ "var_pred_kernel_size": 3,
38
+ "var_pred_dropout": 0.5
39
+ },
40
+
41
+ "sampling_rate": 16000,
42
+
43
+ "fmin": 0,
44
+ "fmax": 8000,
45
+ "fmax_for_loss": null,
46
+
47
+ "num_workers": 4,
48
+
49
+ "dist_config": {
50
+ "dist_backend": "nccl",
51
+ "dist_url": "env://"
52
+ }
53
+ }
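
As a quick consistency check on this vocoder config, the product of "upsample_rates" equals "code_hop_size" (5*4*4*2*2 = 320), i.e. each discrete unit is rendered as 320 waveform samples, which is 20 ms at the 16 kHz sampling rate. A small sketch, assuming the file sits at ./EchoX-Vocoder/config.json as it does in Echox_copy_stream.py:

import json
from math import prod

with open("./EchoX-Vocoder/config.json") as f:
    cfg = json.load(f)

# The generator's total upsampling factor should match the unit hop size.
hop = prod(cfg["upsample_rates"])  # 5 * 4 * 4 * 2 * 2 = 320
assert hop == cfg["code_hop_size"]
print(f"{hop} samples per unit = {1000 * hop / cfg['sampling_rate']:.0f} ms at {cfg['sampling_rate']} Hz")
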
EchoX-Vocoder/g_00500000 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d1f7188b95b06304bc05e524fddf93c7fe682fdd93acff022685663a5e26b97
3
+ size 54051213
EchoX-Vocoder/re_config.log ADDED
File without changes
EchoX-Vocoder/spm_1k.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d95d5585291329feaf35d3cb39fe5181e4987549097a9daa36f468dab9e82556
3
+ size 254653
Echox_copy_stream.py ADDED
@@ -0,0 +1,426 @@
1
+ import sys
2
+ from ACLlama_el_s2s import ACLlamaForCausalLM
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, WhisperProcessor
4
+ from peft import PeftModel, PeftConfig
5
+ import json
6
+ from tqdm import tqdm
7
+ import torch
8
+ import re
9
+ import os
10
+ torch.backends.cudnn.benchmark = False
11
+ import librosa
12
+ from text_to_speech import *
13
+ import torch.nn.functional as F
14
+ from concurrent.futures import ThreadPoolExecutor, as_completed
15
+
16
+ from transformers import logging as hf_logging
17
+ hf_logging.set_verbosity_error()
18
+ from huggingface_hub import hf_hub_download
19
+ from typing import Dict, Optional, List
20
+ import tempfile
21
+ import select
22
+ from copy import deepcopy
23
+ from typing import Generator, Tuple
24
+
25
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
26
+
27
+ def load_model(args, device):
28
+ quantization_config = None
29
+ hf_token = os.getenv("HF_TOKEN")
30
+
31
+ # load base model
32
+ model = ACLlamaForCausalLM.from_pretrained(
33
+ args.base_model_path,
34
+ device_map=None,
35
+ torch_dtype=torch.float16,
36
+ quantization_config=quantization_config,
37
+ token=hf_token,
38
+ ).eval().to(device)
39
+ for module in model.model.audio_tower:
40
+ module = module.to(device)
41
+
42
+ if args.peft_model_id:
43
+ lora_config = PeftConfig.from_pretrained(args.peft_model_id)
44
+ torch.cuda.empty_cache()
45
+ model = PeftModel.from_pretrained(model, args.peft_model_id, config=lora_config).to(
46
+ dtype=torch.float16, device=device
47
+ )
48
+ model = model.merge_and_unload()
49
+
50
+ model.eval()
51
+
52
+ # load tokenizer
53
+ tokenizer = AutoTokenizer.from_pretrained(args.base_model_path, token=hf_token)
54
+
55
+ audio_config = model.get_model().audio_tower[0].config
56
+ audio_config.audio_patch_token = tokenizer.get_vocab()["<audio_patch>"]
57
+ audio_config.llm_pad_token_id = tokenizer.pad_token_id
58
+ audio_config.audio_patch_size = args.audio_token_len
59
+
60
+
61
+ # whisper processor
62
+ audio_processor = WhisperProcessor.from_pretrained(args.audio_tower, torch_dtype=torch.float16)
63
+
64
+ # t2u
65
+ unit_translator = model.get_unit_translator().eval()
66
+ return model, audio_processor, tokenizer, unit_translator
67
+
68
+ def load_speech_model(device):
69
+ vocoder = "./EchoX-Vocoder/g_00500000"
70
+ vocoder_cfg = "./EchoX-Vocoder/config.json"
71
+ voc_cfg = get_vocoder_config(vocoder, vocoder_cfg)
72
+ vocoder = load_units_vocoder(voc_cfg, device)
73
+ return vocoder, voc_cfg
74
+
75
+ # def load_speech_model(device):
76
+ # hf_token = os.getenv("HF_TOKEN")
77
+
78
+ # vocoder_repo_id = "FreedomIntelligence/EchoX-Vocoder"
79
+
80
+ # cache_path = './hf_cache'
81
+ # vocoder_path = hf_hub_download(repo_id=vocoder_repo_id, filename="g_00500000", token=hf_token, cache_dir=cache_path)
82
+ # vocoder_cfg_path = hf_hub_download(repo_id=vocoder_repo_id, filename="config.json", token=hf_token, cache_dir=cache_path)
83
+
84
+ # voc_cfg = get_vocoder_config(vocoder_path, vocoder_cfg_path)
85
+ # vocoder = load_units_vocoder(voc_cfg, device)
86
+ # return vocoder, voc_cfg
87
+
88
+ class EchoxAssistant():
89
+ def __init__(self):
90
+ class BasicSetting:
91
+ def __init__(self):
92
+ self.device = "cuda:0"
93
+ self.sampling_rate = 16000
94
+ self.audio_token_len = 1 # 1500 = 300 token x 5 compress
95
+ self.stop = "</s>"
96
+ self.base_model_path = "FreedomIntelligence/EchoX-8B"
97
+ self.peft_model_id = None
98
+ self.audio_tower = "openai/whisper-large-v3"
99
+ self.args = BasicSetting()
100
+ self.device = "cuda"
101
+ self.vocoder, self.voc_cfg= load_speech_model(self.device)
102
+ self.model, self.audio_processor, self.tokenizer, self.unit_translator = load_model(self.args, self.device)
103
+ self.audio_executor = ThreadPoolExecutor(max_workers=2)
104
+ # self.specAug = SpecAugmentTransform()
105
+ # special_token
106
+ DEFAULT_AUDIO_PATCH_TOKEN = "<audio_patch>"
107
+ audio_placeholder = DEFAULT_AUDIO_PATCH_TOKEN * self.args.audio_token_len
108
+ audio_placeholder = "\n"+audio_placeholder
109
+ self.audio_placeholder_ids = self.tokenizer(audio_placeholder).input_ids
110
+
111
+ self.begin_of_text_id = self.tokenizer.get_vocab()["<|begin_of_text|>"]
112
+ self.start_header_id = self.tokenizer.get_vocab()["<|start_header_id|>"]
113
+ self.end_header_id = self.tokenizer.get_vocab()["<|end_header_id|>"]
114
+ self.eot_id = self.tokenizer.get_vocab()["<|eot_id|>"]
115
+ self.nl_tokens = self.tokenizer('\n').input_ids
116
+ self._system = self.tokenizer('system').input_ids
117
+ self._user = self.tokenizer('user').input_ids
118
+ self._assistant = self.tokenizer('assistant').input_ids
119
+ self._speaker = self.tokenizer('speaker').input_ids
120
+
121
+ self.max_len = 1024
122
+ self.unit_max_len = 2048
123
+ self.system_message = "You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language."
124
+
125
+ def _generate_audio_segment(self, segment_hidden_states):
126
+ try:
127
+ audio_units = self._generate_audio_units_from_hidden_states(segment_hidden_states)
128
+ if audio_units:
129
+ audio_float32 = self.generate_with_speech_model([list(map(int, audio_units.split(" ")))])
130
+ audio_int16 = (audio_float32 * 32767).astype(np.int16)
131
+
132
+ print(f"Generated audio segment in background: {len(audio_units.split())} units")
133
+ return (16000, audio_int16)
134
+ return None
135
+ except Exception as e:
136
+ print(f"Background audio generation error: {e}")
137
+ return None
138
+
139
+ def gen_model_inputs(
140
+ self,
141
+ sources,
142
+ tokenizer,
143
+ max_len,
144
+ system_message,
145
+ audio_placeholder_ids, begin_of_text_id, start_header_id, end_header_id, eot_id, nl_tokens, _system, _user, _assistant,
146
+ ) -> dict:
147
+ # max_len 512
148
+
149
+ # Apply prompt templates
150
+ input_ids, audio_paths = [], []
151
+ audio_path = []
152
+
153
+ for source in sources:
154
+ input_id = []
155
+ system = [begin_of_text_id] + [start_header_id] + _system + [end_header_id] + nl_tokens + tokenizer(system_message).input_ids + [eot_id]
156
+ input_id += system
157
+
158
+ for j, item in enumerate(source["conversations"]):
159
+ role = item["from"]
160
+ value = item["value"]
161
+ _audio_path = None
162
+
163
+ if role == 'user':
164
+ if "audio" in item.keys():
165
+ _input_id = [start_header_id] + _user + [end_header_id] + audio_placeholder_ids + tokenizer(value).input_ids + [eot_id]
166
+ _audio_path = item["audio"]
167
+ else:
168
+ _input_id = [start_header_id] + _user + [end_header_id] + tokenizer(value).input_ids + [eot_id]
169
+
170
+ elif role == 'assistant':
171
+ _input_id = [start_header_id] + _assistant + [end_header_id] + nl_tokens + tokenizer(value).input_ids + [eot_id]
172
+
173
+ else:
174
+ raise NotImplementedError
175
+ input_id += _input_id
176
+
177
+ if _audio_path:
178
+ audio_path.append(_audio_path)
179
+ assistant_input_id = [start_header_id] + _assistant + [end_header_id] + nl_tokens
180
+ input_id += assistant_input_id
181
+
182
+ audio_num = int(input_id.count(audio_placeholder_ids[-1]) / self.args.audio_token_len)
183
+ assert len(audio_path) == audio_num
184
+ if len(input_id) >= max_len:
185
+ print(f"[WARNING] Input length {len(input_id)} exceeds max_len ({max_len}); truncating.")
186
+ input_ids.append(input_id[:max_len])
187
+ audio_paths.append(audio_path)
188
+ input_ids = torch.tensor(input_ids, dtype=torch.int)
189
+ return dict(
190
+ input_ids=input_ids,
191
+ audio_paths=audio_paths,
192
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
193
+ )
194
+
195
+ def get_unit_result(self, ret):
196
+ # print(ret)
197
+ self.unit_translator.generation_config.pad_token_id = self.tokenizer.eos_token_id
198
+ input_ids = ret["input_ids"]
199
+ ret["input_ids"] = None
200
+ model_outputs = self.unit_translator.generate(
201
+ **ret,
202
+ max_new_tokens=2048,
203
+ eos_token_id=self.tokenizer.eos_token_id,
204
+ )
205
+ # print(model_outputs, model_outputs.shape)
206
+ output_ids = model_outputs
207
+ unit_output = self.tokenizer.batch_decode(output_ids)[0]
208
+ if "▁" in unit_output:
209
+ unit_output = ''.join(re.findall(r"<\|unit_(.*?)\|>", unit_output))
210
+
211
+ units = re.findall(r'\d+', unit_output)
212
+
213
+ # TODO: drop out-of-vocabulary (unk) units
214
+ new_units = []
215
+ for unit in units:
216
+ if int(unit) < 1000:
217
+ new_units.append(unit)
218
+
219
+ units = ' '.join(new_units)
220
+ return units
221
+
222
+
223
+ def _inference(
224
+ self,
225
+ prompt,
226
+ **kwargs,
227
+ ):
228
+ audio_paths = []
229
+ response = []
230
+ for item in prompt:
231
+ for conv in item["conversations"]:
232
+ if "audio" in conv:
233
+ audio_paths.append(conv["audio"])
234
+
235
+ model_inputs = self.gen_model_inputs(
236
+ prompt,
237
+ self.tokenizer,
238
+ self.max_len,
239
+ self.system_message,
240
+ self.audio_placeholder_ids, self.begin_of_text_id, self.start_header_id, self.end_header_id, self.eot_id, self.nl_tokens, self._system, self._user, self._assistant)
241
+
242
+ audio_list = []
243
+ if audio_paths and audio_paths[0] is not None:
244
+ for audio_path in audio_paths:
245
+ # print("read audio file name: ", audio_path)
246
+ audio, _ = librosa.load(audio_path, sr=self.args.sampling_rate)
247
+ audio_feat = self.audio_processor(audio, sampling_rate=self.args.sampling_rate, return_tensors="pt").input_features
248
+ audio_list.append(audio_feat)
249
+ audio_feats = torch.stack(audio_list, dim=0)
250
+ audio_feats = audio_feats.to(dtype=torch.float16).to(self.device)
251
+
252
+ if not audio_list:
253
+ ret = dict(
254
+ input_ids=model_inputs["input_ids"].to(self.device),
255
+ attention_mask=model_inputs["attention_mask"].to(self.device),
256
+ )
257
+ else:
258
+ ret = dict(
259
+ input_ids=model_inputs["input_ids"].to(self.device),
260
+ attention_mask=model_inputs["attention_mask"].to(self.device),
261
+ audios=audio_feats,
262
+ )
263
+
264
+ self.model.generation_config.pad_token_id = self.tokenizer.eos_token_id
265
+ #print(self.model.lm_head.weight.shape)
266
+
267
+ dot_input_ids = self.tokenizer(".", return_tensors="pt").input_ids.to(self.device) # shape: (1, 2), values: [[128000, 13]]
268
+ period_token_id = dot_input_ids[0, -1]
269
+ period_lm_head_embedding = self.model.lm_head.weight[period_token_id]
270
+
271
+ input_ids = ret["input_ids"]
272
+ attention_mask = ret["attention_mask"]
273
+ input_token_len = input_ids.shape[1]
274
+
275
+ max_new_tokens = kwargs.get('max_new_tokens', 512)
276
+ temperature = kwargs.get('temperature', 0.2)
277
+ top_p = kwargs.get('top_p', 0.9)
278
+ do_sample = kwargs.get('do_sample', True)
279
+
280
+ current_text = ""
281
+ accumulated_hidden_states = []
282
+ accumulated_tokens = []
283
+ similarity_scores = []
284
+ segment_start_idx = 0
285
+
286
+ current_input_ids = input_ids
287
+ current_attention_mask = attention_mask
288
+ past_key_values = None
289
+
290
+ audio_futures = []
291
+ segmentation_latency = 5
292
+
293
+ with torch.no_grad():
294
+ for step in range(max_new_tokens):
295
+ while audio_futures and audio_futures[0].done():
296
+ completed_future = audio_futures.pop(0)
297
+ audio_data = completed_future.result()
298
+ if audio_data:
299
+ yield None, audio_data
300
+
301
+ if current_input_ids is None:
302
+ break
303
+
304
+ model_kwargs = {
305
+ "input_ids": current_input_ids,
306
+ "attention_mask": current_attention_mask,
307
+ "past_key_values": past_key_values,
308
+ "use_cache": True,
309
+ "output_hidden_states": True,
310
+ "do_task": "skip"
311
+ }
312
+
313
+ if step == 0 and "audios" in ret:
314
+ model_kwargs["audios"] = ret["audios"]
315
+
316
+ outputs = self.model(**model_kwargs)
317
+
318
+ logits = outputs.logits
319
+ hidden_states = outputs.hidden_states[-1]
320
+ past_key_values = outputs.past_key_values
321
+
322
+ next_token_logits = logits[:, -1, :] # [batch_size, vocab_size]
323
+
324
+ if do_sample:
325
+ next_token_logits = next_token_logits / temperature
326
+
327
+ sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
328
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
329
+
330
+ sorted_indices_to_remove = cumulative_probs > top_p
331
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
332
+ sorted_indices_to_remove[..., 0] = 0
333
+
334
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
335
+ next_token_logits[indices_to_remove] = float('-inf')
336
+
337
+ probs = F.softmax(next_token_logits, dim=-1)
338
+ next_token = torch.multinomial(probs, num_samples=1)
339
+ else:
340
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
341
+
342
+ if next_token.item() == self.tokenizer.eos_token_id:
343
+ current_input_ids = None
344
+ continue
345
+
346
+ accumulated_tokens.append(next_token.item())
347
+ last_hidden_state = hidden_states[0, -1] # [hidden_dim]
348
+ accumulated_hidden_states.append(last_hidden_state)
349
+
350
+ similarity = F.cosine_similarity(last_hidden_state, period_lm_head_embedding, dim=0).item()
351
+ similarity_scores.append(similarity)
352
+
353
+ token_text = self.tokenizer.decode([next_token.item()], skip_special_tokens=True)
354
+ current_text += token_text
355
+
356
+ yield current_text, None
357
+
358
+ current_idx = len(similarity_scores) - 1
359
+ check_idx = current_idx - segmentation_latency
360
+ if check_idx >= 0:
361
+ similarity_at_check = similarity_scores[check_idx]
362
+ is_peak = self._is_local_maximum(similarity_scores, check_idx, window=segmentation_latency)
363
+ should_segment = (is_peak and
364
+ check_idx - segment_start_idx >= 50) or (
365
+ is_peak and
366
+ similarity_at_check > 0.1 and
367
+ check_idx - segment_start_idx >= 20
368
+ )
369
+
370
+ if should_segment:
371
+ segment_end_idx = check_idx + 1
372
+ print(f"Segmenting at step {segment_end_idx-1}, similarity={similarity_at_check:.4f}. Submitting to background audio generation.")
373
+
374
+ segment_hidden_states = torch.stack(
375
+ accumulated_hidden_states[segment_start_idx:segment_end_idx], dim=0
376
+ ).unsqueeze(0)
377
+
378
+ future = self.audio_executor.submit(self._generate_audio_segment, segment_hidden_states)
379
+ audio_futures.append(future)
380
+
381
+ segment_start_idx = segment_end_idx
382
+
383
+ current_input_ids = next_token
384
+ current_attention_mask = torch.ones_like(next_token)
385
+
386
+ if segment_start_idx < len(accumulated_hidden_states):
387
+ print(f"Processing final segment from {segment_start_idx} to {len(accumulated_hidden_states)}")
388
+ segment_hidden_states = torch.stack(
389
+ accumulated_hidden_states[segment_start_idx:], dim=0
390
+ ).unsqueeze(0)
391
+ future = self.audio_executor.submit(self._generate_audio_segment, segment_hidden_states)
392
+ audio_futures.append(future)
393
+
394
+ for future in audio_futures:
395
+ audio_data = future.result()
396
+ if audio_data:
397
+ yield None, audio_data
398
+
399
+ def _is_local_maximum(self, scores, idx, window=5):
400
+ start = max(0, idx - window)
401
+ end = min(len(scores), idx + window + 1)
402
+ local_scores = scores[start:end]
403
+ return scores[idx] == max(local_scores)
404
+
405
+ def _generate_audio_units_from_hidden_states(self, hidden_states):
406
+ try:
407
+ _, adapted_inputs_embeds = self.unit_translator.insert_text_embedding(
408
+ inputs_embeds=hidden_states,
409
+ do_task="skip",
410
+ )
411
+
412
+ attention_mask = torch.ones(adapted_inputs_embeds.shape[:2]).to(self.device)
413
+ ret = dict(
414
+ input_ids=None,
415
+ inputs_embeds=adapted_inputs_embeds,
416
+ attention_mask=attention_mask,
417
+ )
418
+
419
+ return self.get_unit_result(ret)
420
+ except Exception as e:
421
+ print(f"Error generating audio units: {e}")
422
+ return None
423
+
424
+ def generate_with_speech_model(self, units):
425
+ wav = gen_wav(self.vocoder, self.voc_cfg, units, self.device)
426
+ return wav
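
For reference, a minimal driver sketch for the streaming interface above: EchoxAssistant._inference is a generator that yields (partial_text, None) for each decoded token and (None, (16000, int16_waveform)) whenever a background audio segment finishes. The audio path below is a placeholder, and saving the result with soundfile/numpy is an assumption, not something this repo requires:

import numpy as np
import soundfile as sf  # assumed helper for writing the waveform to disk
from Echox_copy_stream import EchoxAssistant

assistant = EchoxAssistant()
prompt = [{
    "conversations": [
        {"from": "user",
         "value": "Please respond to the speech.",
         "audio": "show_case/2.wav"},
    ]
}]

segments = []
for text, audio in assistant._inference(prompt, max_new_tokens=256):
    if text is not None:
        print(text)            # running transcript of the text response
    if audio is not None:
        sr, wav_int16 = audio  # (16000, np.int16 array) from _generate_audio_segment
        segments.append(wav_int16)

if segments:
    sf.write("response.wav", np.concatenate(segments), 16000)
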
T2ULlama_CR_online.py ADDED
@@ -0,0 +1,412 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn import CrossEntropyLoss, CTCLoss
7
+ import transformers
8
+
9
+ from transformers import AutoConfig, AutoModelForCausalLM, \
10
+ LlamaConfig, LlamaModel, LlamaForCausalLM
11
+ from transformers.trainer_pt_utils import LabelSmoother
12
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
+ from transformers import (
14
+ WhisperProcessor,
15
+ WhisperModel,
16
+ )
17
+
18
+ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
19
+
20
+
21
+ def padding_tensor(tensor, length, dim=0, pad=False):
22
+
23
+ if length == 0:
24
+ return tensor
25
+
26
+ assert length > 0, f"Wrong padding length: {length}"
27
+
28
+ shape = list(tensor.shape)
29
+ assert dim < len(shape), f"dim {dim} out of shape {shape}"
30
+ shape[dim] = length
31
+ padding_tensor = torch.cat(
32
+ (
33
+ tensor,
34
+ torch.full(tuple(shape), pad, dtype=tensor.dtype, device=tensor.device)
35
+ ),
36
+ dim=dim
37
+ )
38
+ return padding_tensor
39
+
40
+
41
+ class T2ULlamaConfig(LlamaConfig):
42
+ model_type = "T2ULlama"
43
+
44
+ class T2ULlamaForCausalLM(LlamaForCausalLM):
45
+ config_class = T2ULlamaConfig
46
+
47
+ def __init__(self, config, embedding_weight=None):
48
+
49
+ self.current_step = 0
50
+ self.log = {}
51
+
52
+ super(LlamaForCausalLM, self).__init__(config)
53
+ self.config = config
54
+ self.training_stage = config.unit_output
55
+ self.pad_token_id = 128009
56
+
57
+ llama_config = T2ULlamaConfig(**config.to_dict(),
58
+ batch_first=True,
59
+ norm_first=True
60
+ )
61
+ llama_config.architectures = ["T2ULlamaForCausalLM"]
62
+ llama_config.pad_token_id = self.pad_token_id
63
+ llama_config.vocab_size += llama_config.unit_vocab_size
64
+ #######################################################
65
+ llama_config.unit_model = "medium"
66
+ llama_config.max_position_embeddings = 2048 # 1024 1536 2048 # origin 1024 reduced 512
67
+ #######################################################
68
+ if hasattr(llama_config, "unit_model"):
69
+ if llama_config.unit_model == "large":
70
+ llama_config.num_hidden_layers = 2
71
+ # llama_config.hidden_size = 4096
72
+ # llama_config.num_attention_heads = 32
73
+ # llama_config.intermediate_size = 14336
74
+ # llama_config.head_dim = llama_config.hidden_size // llama_config.num_attention_heads
75
+
76
+ elif llama_config.unit_model == "tiny":
77
+ llama_config.num_hidden_layers = 4
78
+ llama_config.hidden_size = 512
79
+ llama_config.num_attention_heads = 8
80
+ llama_config.intermediate_size = 2048
81
+ llama_config.head_dim = llama_config.hidden_size // llama_config.num_attention_heads
82
+ else:
83
+ llama_config.num_hidden_layers = 8
84
+ llama_config.hidden_size = 768
85
+ llama_config.num_attention_heads = 12
86
+ llama_config.num_key_value_heads = 12
87
+ llama_config.intermediate_size = 2048
88
+ llama_config.head_dim = llama_config.hidden_size // llama_config.num_attention_heads
89
+ else:
90
+ llama_config.num_hidden_layers = 6
91
+ llama_config.hidden_size = 512
92
+ llama_config.num_attention_heads = 8
93
+ llama_config.intermediate_size = 2048
94
+ llama_config.head_dim = llama_config.hidden_size // llama_config.num_attention_heads
95
+ # print(llama_config)
96
+
97
+ self.model = LlamaModel(llama_config)
98
+ # share embedding 0501 by kkq
99
+ self.model.embed_tokens = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, padding_idx=self.pad_token_id) # redefine
100
+ self.unit_embedding = nn.Linear(config.hidden_size, llama_config.unit_vocab_size, bias=False)
101
+ self.adapter = nn.Linear(config.hidden_size, llama_config.hidden_size, bias = True)
102
+ self.lm_head = nn.Linear(llama_config.hidden_size, llama_config.vocab_size, bias=False)
103
+
104
+ if self.training_stage == "pretrain":
105
+ pass
106
+ elif self.training_stage == "finetune" or self.training_stage == "finetune_kd" or self.training_stage == "finetune_kd_online":
107
+ self.aligner_MLP = nn.Sequential(
108
+ nn.Linear(config.hidden_size, config.intermediate_size),
109
+ nn.ReLU(),
110
+ nn.Dropout(0.1),
111
+ nn.Linear(config.intermediate_size, config.hidden_size),
112
+ )
113
+ torch.nn.init.ones_(self.aligner_MLP[0].weight)
114
+ torch.nn.init.zeros_(self.aligner_MLP[0].bias)
115
+ torch.nn.init.ones_(self.aligner_MLP[3].weight)
116
+ torch.nn.init.zeros_(self.aligner_MLP[3].bias)
117
+
118
+ # Initialize weights and apply final processing
119
+ self.post_init()
120
+
121
+ def get_model(self):
122
+ return self.model
123
+
124
+ def insert_text_embedding(
125
+ self,
126
+ input_ids: torch.LongTensor = None,
127
+ attention_mask: Optional[torch.Tensor] = None,
128
+ inputs_embeds: Optional[torch.FloatTensor] = None,
129
+ labels: Optional[torch.LongTensor] = None,
130
+ text_labels: Optional[torch.LongTensor] = None,
131
+ shift_text_labels: Optional[torch.LongTensor] = None,
132
+ shift_text_hidden_states: Optional[torch.FloatTensor] = None,
133
+ unit_targets: Optional[torch.LongTensor] = None,
134
+ sub_lengths: Optional[torch.LongTensor] = None,
135
+ text_start_index: Optional[torch.LongTensor] = None,
136
+ do_task: str = None,
137
+ **kwargs: dict,
138
+ ):
139
+
140
+ if inputs_embeds == None:
141
+ # share embedding 0501 by kkq
142
+ embed_tokens_weight = torch.cat(
143
+ [
144
+ self.model.embed_tokens.weight.detach(), self.unit_embedding.weight
145
+ ],
146
+ dim = 0,
147
+ )
148
+ # print(embed_tokens_weight, embed_tokens_weight.shape)
149
+ inputs_embeds = F.embedding(input_ids, embed_tokens_weight, padding_idx=self.pad_token_id)
150
+
151
+ emb_loss = None
152
+ if do_task == "pretrain":
153
+ if self.training:
154
+ if hasattr(self, "embedding_dropout"):
155
+ emb_origin_mask = text_labels != -100
156
+ origin_padding_length = labels.shape[-1] - emb_origin_mask.shape[-1]
157
+ extend_emb_origin_mask = padding_tensor(emb_origin_mask, origin_padding_length, 1, False)
158
+ extend_emb_origin_mask = ~extend_emb_origin_mask.unsqueeze(-1).expand_as(inputs_embeds)
159
+
160
+ # Π-Model + noise
161
+ log_var = self.perturb(inputs_embeds)
162
+ perturbed_inputs_embeds_2 = inputs_embeds + torch.randn_like(inputs_embeds) * (torch.exp(0.5 * log_var) + 1e-6)
163
+ # Π-Model + dropout
164
+ perturbed_inputs_embeds_1 = self.embedding_dropout(inputs_embeds)
165
+ perturbed_inputs_embeds_2 = self.embedding_dropout(perturbed_inputs_embeds_2)
166
+ perturbed_inputs_embeds_1 = torch.where(extend_emb_origin_mask, inputs_embeds, perturbed_inputs_embeds_1)
167
+ perturbed_inputs_embeds_2 = torch.where(extend_emb_origin_mask, inputs_embeds, perturbed_inputs_embeds_2)
168
+
169
+ inputs_embeds = torch.cat(
170
+ (perturbed_inputs_embeds_1, perturbed_inputs_embeds_2),
171
+ dim=0,
172
+ )
173
+
174
+ kl_loss = -0.5 * (1 + log_var - log_var.exp()).mean(dim=-1).sum(dim=-1).mean()
175
+ contrastive_loss = (1 - F.cosine_similarity(perturbed_inputs_embeds_1, perturbed_inputs_embeds_2, dim=-1)).sum(dim=-1).mean()
176
+ emb_loss = kl_loss + contrastive_loss
177
+
178
+ if kl_loss.device == torch.device("cuda:0"):
179
+ self.log["kl_loss"] = kl_loss.item()
180
+ self.log["std"] = torch.exp(0.5 * log_var).mean().item()
181
+ self.log["contrastive_loss"] = contrastive_loss.item()
182
+
183
+ pass
184
+ elif do_task == "finetune":
185
+ inputs_embeds = inputs_embeds.detach()
186
+ inputs_embeds_refer = inputs_embeds.clone().detach()
187
+ shift_text_hidden_states = self.aligner_MLP(shift_text_hidden_states)
188
+ emb_origin_mask = text_labels != -100 # get output text pos
189
+ emb_shift_mask = shift_text_labels != -100
190
+
191
+ origin_padding_length = labels.shape[-1] - emb_origin_mask.shape[-1]
192
+ shift_padding_length = labels.shape[-1] - emb_shift_mask.shape[-1]
193
+
194
+ extend_emb_origin_mask = padding_tensor(emb_origin_mask, origin_padding_length, 1, False)
195
+ extend_emb_shift_mask = padding_tensor(emb_shift_mask, shift_padding_length, 1, False)
196
+ extend_shift_text_hidden_states = padding_tensor(shift_text_hidden_states, shift_padding_length, 1, 1e-9)
197
+ # check
198
+ extend_text_labels = padding_tensor(text_labels, origin_padding_length, 1, -100)
199
+ extend_shift_text_labels = padding_tensor(shift_text_labels, shift_padding_length, 1, -100)
200
+
201
+ assert torch.equal(
202
+ extend_text_labels[extend_emb_origin_mask],
203
+ extend_shift_text_labels[extend_emb_shift_mask]
204
+ ), "{}\n{}\n{}\n{}".format(labels, extend_emb_origin_mask, extend_shift_text_labels, extend_emb_shift_mask)
205
+
206
+ inputs_embeds[extend_emb_origin_mask.unsqueeze(-1).expand_as(inputs_embeds)] = \
207
+ extend_shift_text_hidden_states[extend_emb_shift_mask.unsqueeze(-1).expand_as(extend_shift_text_hidden_states)].to(dtype=inputs_embeds.dtype)
208
+
209
+ if self.training:
210
+ contrastive_loss = (1 - F.cosine_similarity(inputs_embeds, inputs_embeds_refer, dim=-1)).sum(-1).mean()
211
+ emb_loss = contrastive_loss
212
+ if emb_loss.device == torch.device("cuda:0"):
213
+ self.log["contrastive_loss"] = contrastive_loss.item()
214
+ pass
215
+ elif do_task == "finetune_kd" :
216
+ #inputs_embeds = inputs_embeds.detach()
217
+ #inputs_embeds_refer = inputs_embeds.clone().detach()
218
+ #print(text_labels)
219
+ #print(sub_lengths.sum())
220
+ emb_origin_mask = text_labels != -100
221
+
222
+ fetch_lables_list = []
223
+ for batch in range(emb_origin_mask.shape[0]):
224
+ fetch_lables_list.append(text_labels[batch][emb_origin_mask[batch]])
225
+ shift_text_hidden_states = self.aligner_MLP(shift_text_hidden_states)
226
+
227
+ #split the shift_text_hidden_states
228
+ #[128006, 128000, 78191, 128007, 128000, 198, 128000]
229
+ maxn_length = sub_lengths.max() + 8
230
+ pad_ids = torch.full(size=(sub_lengths.shape[0], sub_lengths.shape[1], maxn_length), fill_value=self.pad_token_id, dtype=torch.long).to(shift_text_hidden_states.device)
231
+
232
+ pad_text_ids = torch.full(size=(sub_lengths.shape[0], sub_lengths.shape[1], maxn_length), fill_value=self.pad_token_id, dtype=torch.long).to(shift_text_hidden_states.device)
233
+
234
+ atten_mask = pad_ids.ne(self.pad_token_id)
235
+ #target_mask_part1 = pad_ids.ne(self.pad_token_id)
236
+ shift_text_hidden_states_slice = F.embedding(pad_ids, embed_tokens_weight, padding_idx=self.pad_token_id)
237
+
238
+ #print(shift_text_hidden_states_slice.shape,shift_text_hidden_states.shape)
239
+ for batch in range(sub_lengths.shape[0]):
240
+ cot=0
241
+ start_index = text_start_index[batch]
242
+ for index, sub_length in enumerate(sub_lengths[batch]):
243
+ if sub_length==-1:
244
+ break
245
+ #print(shift_text_hidden_states_slice[batch][index][:sub_length].shape, shift_text_hidden_states[batch][cot:cot+sub_length].shape)
246
+ eos_id = torch.IntTensor([128009]).to(inputs_embeds.device)
247
+ eos = self.model.embed_tokens(eos_id)
248
+ if index == 0:
249
+ text_prefix_ids = torch.IntTensor([128006, 128000, 65576, 128007, 128000, 198]).to(inputs_embeds.device)
250
+ preifx_embed = self.model.embed_tokens(text_prefix_ids)
251
+ pad_text_ids[batch][index][:sub_length+7] = torch.cat([text_prefix_ids, fetch_lables_list[batch][cot:cot+sub_length], eos_id],dim=0)
252
+ atten_mask[batch][index][:sub_length+7]=True
253
+ else:
254
+ text_prefix_ids = torch.IntTensor([128006, 128000, 65576, 128007, 128000, 198, 12800]).to(inputs_embeds.device)
255
+ preifx_embed = self.model.embed_tokens(text_prefix_ids)
256
+ pad_text_ids[batch][index][:sub_length+8] = torch.cat([text_prefix_ids, fetch_lables_list[batch][cot:cot+sub_length], eos_id], dim=0)
257
+ atten_mask[batch][index][:sub_length+8]=True
258
+
259
+ new_shift_text_hidden_states = torch.cat([preifx_embed, shift_text_hidden_states[batch][start_index+cot:start_index+cot+sub_length], eos], dim = 0)
260
+ shift_text_hidden_states_slice[batch][index][:new_shift_text_hidden_states.shape[0]] = new_shift_text_hidden_states
261
+
262
+ cot+=sub_length
263
+ shift_text_hidden_states_slice = shift_text_hidden_states_slice.reshape(shift_text_hidden_states_slice.shape[0]*shift_text_hidden_states_slice.shape[1],shift_text_hidden_states_slice.shape[2],shift_text_hidden_states_slice.shape[3])
264
+
265
+
266
+ padding_unit_targets = unit_targets.clone()
267
+ padding_unit_targets = torch.where(padding_unit_targets == IGNORE_TOKEN_ID, self.pad_token_id, padding_unit_targets)
268
+ target_mask_part = padding_unit_targets.ne(self.pad_token_id)
269
+ atten_mask = torch.cat([atten_mask, target_mask_part], dim = -1)
270
+ atten_mask = atten_mask.reshape(atten_mask.shape[0]*atten_mask.shape[1],atten_mask.shape[2])
271
+
272
+ pad_text_ids = pad_text_ids.reshape(pad_text_ids.shape[0]*pad_text_ids.shape[1],pad_text_ids.shape[2])
273
+ shift_text_embeddings = F.embedding(pad_text_ids, embed_tokens_weight, padding_idx=self.pad_token_id)
274
+
275
+ unit_target_slice = F.embedding(padding_unit_targets, embed_tokens_weight, padding_idx=self.pad_token_id)
276
+ # unit_target_slice = F.embedding(unit_targets, embed_tokens_weight, padding_idx=self.pad_token_id)
277
+ unit_target_slice = unit_target_slice.reshape(unit_target_slice.shape[0]*unit_target_slice.shape[1],unit_target_slice.shape[2],unit_target_slice.shape[3])
278
+
279
+ inputs_embeds = torch.cat([shift_text_hidden_states_slice, unit_target_slice], dim = 1)
280
+
281
+ ignore_ids = torch.full(size=(sub_lengths.shape[0], sub_lengths.shape[1], maxn_length), fill_value=IGNORE_TOKEN_ID, dtype=torch.long).to(shift_text_hidden_states.device)
282
+ unit_targets = torch.cat([ignore_ids,unit_targets],dim=-1)
283
+ unit_targets = unit_targets.reshape(unit_targets.shape[0]*unit_targets.shape[1],unit_targets.shape[2])
284
+
285
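+ # Embedding-matching objective: pull the aligner-projected hidden states toward the gold
+ # text-token embeddings with a cosine-distance loss.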
+ if self.training:
286
+ #print(shift_text_hidden_states_slice.shape, shift_text_embeddings.shape)
287
+ contrastive_loss = (1 - F.cosine_similarity(shift_text_hidden_states_slice, shift_text_embeddings, dim=-1)).sum(-1).mean()
288
+ emb_loss = contrastive_loss
289
+ if emb_loss.device == torch.device("cuda:0"):
290
+ self.log["contrastive_loss"] = contrastive_loss.item()
291
+
292
+ elif do_task == "finetune_kd_online":
293
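+ # Online variant: the aligner-projected hidden states overwrite the gold text embeddings inside
+ # inputs_embeds, while a masked cosine-distance loss keeps them close to the original embeddings.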
+ shift_text_hidden_states = self.aligner_MLP(shift_text_hidden_states)
294
+ gold_inputs_embeds = inputs_embeds.clone()
295
+ for batch in range (inputs_embeds.shape[0]):
296
+ start_index = text_start_index[batch]
297
+ for slice_index in range (inputs_embeds.shape[1]):
298
+ sub_length= sub_lengths[batch][slice_index]
299
+ inputs_embeds[batch][slice_index][7:7+sub_length] = shift_text_hidden_states[batch][start_index+1:start_index+1+sub_length]
300
+ start_index += sub_length
301
+ if self.training:
302
+ #print(shift_text_hidden_states_slice.shape, shift_text_embeddings.shape)
303
+ contrastive_loss = ((1 - F.cosine_similarity(inputs_embeds, gold_inputs_embeds, dim=-1)) * attention_mask).sum(-1).mean()
304
+ emb_loss = contrastive_loss
305
+ if emb_loss.device == torch.device("cuda:0"):
306
+ self.log["contrastive_loss"] = contrastive_loss.item()
307
+ unit_embeds = F.embedding(unit_targets, embed_tokens_weight, padding_idx=self.pad_token_id)
308
+
309
+ inputs_embeds = torch.cat([inputs_embeds,unit_embeds], dim=2)
310
+ else:
311
+ inputs_embeds = self.aligner_MLP(inputs_embeds)
312
+ # [start_header_id] + speaker + [end_header_id] + nl_tokens; note this inference branch assumes batch size 1
313
+ units_ids = torch.IntTensor([[128009, 128006, 128000, 65576, 128007, 128000, 198]]).to(inputs_embeds.device)
314
+ units_prefix = self.model.embed_tokens(units_ids)
315
+ text_ids = torch.IntTensor([[128006, 128000, 65576, 128007, 128000, 198, 12800]]).to(inputs_embeds.device)
316
+ text_prefix = self.model.embed_tokens(text_ids)
317
+ inputs_embeds = torch.cat([text_prefix, inputs_embeds, units_prefix], dim = 1)
318
+
319
+ inputs_embeds = self.adapter(inputs_embeds)
320
+ if do_task == "finetune_kd":
321
+ return (emb_loss, inputs_embeds, unit_targets, atten_mask,)
322
+ else:
323
+ return (emb_loss, inputs_embeds)
324
+
325
+ def forward(
326
+ self,
327
+ input_ids: torch.LongTensor = None,
328
+ attention_mask: Optional[torch.Tensor] = None,
329
+ position_ids: Optional[torch.LongTensor] = None,
330
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
331
+ inputs_embeds: Optional[torch.FloatTensor] = None,
332
+ labels: Optional[torch.LongTensor] = None,
333
+ use_cache: Optional[bool] = None,
334
+ output_attentions: Optional[bool] = None,
335
+ output_hidden_states: Optional[bool] = None,
336
+ return_dict: Optional[bool] = None,
337
+ cache_position: Optional[torch.LongTensor] = None,
338
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
339
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
340
+ output_hidden_states = (
341
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
342
+ )
343
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
344
+
345
+ if inputs_embeds is None:
346
+ # inputs_embeds = self.model.embed_tokens(input_ids)
347
+ # share embedding 0501 by kkq
348
+ embed_tokens_weight = torch.cat(
349
+ [
350
+ self.model.embed_tokens.weight.detach(), self.unit_embedding.weight
351
+ ],
352
+ dim = 0,
353
+ )
354
+ # print(embed_tokens_weight, embed_tokens_weight.shape)
355
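+ # The frozen LLaMA text embeddings and the trainable unit embeddings form one shared lookup table,
+ # so unit token IDs (offset past the text vocabulary) resolve in a single F.embedding call.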
+ inputs_embeds = F.embedding(input_ids, embed_tokens_weight, padding_idx=self.pad_token_id)
356
+ inputs_embeds = self.adapter(inputs_embeds)
357
+
358
+ outputs = self.model(
359
+ input_ids=None,
360
+ attention_mask=attention_mask,
361
+ past_key_values=past_key_values,
362
+ inputs_embeds=inputs_embeds,
363
+ use_cache=use_cache,
364
+ output_attentions=output_attentions,
365
+ output_hidden_states=output_hidden_states,
366
+ return_dict=return_dict,
367
+ )
368
+ hidden_states = outputs[0]
369
+ logits = self.lm_head(hidden_states)
370
+
371
+ loss = None
372
+ cr_loss = None
373
+ if labels is not None:
374
+ shift_labels = labels
375
+
376
+ # Shift so that tokens < n predict n
377
+ shift_logits = logits[..., :-1, :].contiguous()
378
+ shift_labels = shift_labels[..., 1:].contiguous()
379
+
380
+ loss_fct = CrossEntropyLoss()
381
+
382
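+ # Cross-entropy is computed over the combined text + unit vocabulary, matching the concatenated embedding table above.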
+ shift_logits = shift_logits.view(-1, (self.config.vocab_size + self.config.unit_vocab_size))
383
+ shift_labels = shift_labels.view(-1)
384
+ shift_labels = shift_labels.to(shift_logits.device)
385
+
386
+ loss = loss_fct(shift_logits, shift_labels)
387
+
388
+ if loss.device == torch.device("cuda:0"):
389
+ self.log["unit_loss"] = loss.item()
390
+
391
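+ # If an auxiliary consistency loss were provided, it would be rescaled to contribute at most ~20% of
+ # the unit loss; as written, cr_loss stays None in this forward, so this block is effectively inactive.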
+ if cr_loss is not None:
392
+ target_scale = loss.item() * 0.2
393
+ cr_loss_weight = target_scale / cr_loss.item() if cr_loss > target_scale else 1.0
394
+ loss = loss + cr_loss_weight * cr_loss
395
+
396
+ if loss.device == torch.device("cuda:0") and (self.current_step - 10) % 100 == 0:
397
+ print(self.log, loss.device)
398
+
399
+ if not return_dict:
400
+ output = (logits,) + outputs[1:]
401
+ return (loss,) + output if loss is not None else output
402
+
403
+ return CausalLMOutputWithPast(
404
+ loss=loss,
405
+ logits=logits,
406
+ past_key_values=outputs.past_key_values,
407
+ hidden_states=outputs.hidden_states,
408
+ attentions=outputs.attentions,
409
+ )
410
+
411
+ AutoConfig.register("T2ULlama", T2ULlamaConfig)
412
+ AutoModelForCausalLM.register(T2ULlamaConfig, T2ULlamaForCausalLM)
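+ # Registering the config/model pair lets the Hugging Face Auto classes resolve this custom
+ # architecture from its config, e.g. (the checkpoint path below is illustrative):
+ #   model = AutoModelForCausalLM.from_pretrained("path/to/T2ULlama_checkpoint")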
show_case/1.wav ADDED
Binary file (72.1 kB).
 
show_case/2.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03198985b12bd892d05b8ae9b2e6c8303b15b0a570eea9647b2e314332340711
3
+ size 165742
show_case/Translate_de_audio_prompt.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aee2ba64475fb597ec0a68da4e1a552fc4463bca0e9393d9f170e37f563d7792
3
+ size 288164
text_to_speech.py ADDED
@@ -0,0 +1,326 @@
1
+ from fairseq.dataclass.configs import FairseqConfig
2
+ from fairseq import utils
3
+ from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder
4
+ from fairseq import checkpoint_utils, options, tasks, utils
5
+ from fairseq.distributed import utils as distributed_utils
6
+ import torch
7
+ import json
8
+ from tqdm import tqdm
9
+ import random
10
+ import soundfile as sf
11
+ import numpy as np
12
+ import ast
13
+ import time
14
+ import math
15
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
16
+ from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
17
+ from fairseq_cli.generate import get_symbols_to_strip_from_output
18
+ from collections import namedtuple
19
+ import sys
20
+ from argparse import Namespace
21
+ import argparse
22
+ import sentencepiece as spm
23
+ import re
24
+
25
+ Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
26
+ Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")
27
+
28
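+ # Wrap raw input lines into fairseq inference batches, optionally parsing tab-separated decoding
+ # constraints when constrained generation is enabled.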
+ def make_batches(lines, cfg, task, max_positions, encode_fn):
29
+ def encode_fn_target(x):
30
+ return encode_fn(x)
31
+
32
+ if cfg.generation.constraints:
33
+ # Strip (tab-delimited) constraints, if present, from input lines,
34
+ # store them in batch_constraints
35
+ batch_constraints = [list() for _ in lines]
36
+ for i, line in enumerate(lines):
37
+ if "\t" in line:
38
+ lines[i], *batch_constraints[i] = line.split("\t")
39
+
40
+ # Convert each List[str] to List[Tensor]
41
+ for i, constraint_list in enumerate(batch_constraints):
42
+ batch_constraints[i] = [
43
+ task.target_dictionary.encode_line(
44
+ encode_fn_target(constraint),
45
+ append_eos=False,
46
+ add_if_not_exist=False,
47
+ )
48
+ for constraint in constraint_list
49
+ ]
50
+
51
+ if cfg.generation.constraints:
52
+ constraints_tensor = pack_constraints(batch_constraints)
53
+ else:
54
+ constraints_tensor = None
55
+
56
+ tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)
57
+
58
+ itr = task.get_batch_iterator(
59
+ dataset=task.build_dataset_for_inference(
60
+ tokens, lengths, constraints=constraints_tensor
61
+ ),
62
+ max_tokens=cfg.dataset.max_tokens,
63
+ max_sentences=cfg.dataset.batch_size,
64
+ max_positions=max_positions,
65
+ ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
66
+ ).next_epoch_itr(shuffle=False)
67
+ for batch in itr:
68
+ ids = batch["id"]
69
+ src_tokens = batch["net_input"]["src_tokens"]
70
+ src_lengths = batch["net_input"]["src_lengths"]
71
+ constraints = batch.get("constraints", None)
72
+
73
+ yield Batch(
74
+ ids=ids,
75
+ src_tokens=src_tokens,
76
+ src_lengths=src_lengths,
77
+ constraints=constraints,
78
+ )
79
+
80
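+ # Lowercase, strip punctuation, and segment the text with the supplied SentencePiece model,
+ # presumably to match the preprocessing used when training the text-to-unit model.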
+ def tokenize(inputs, sp):
81
+ text = re.sub(r'[^\w\s]', '', inputs.lower())
82
+ inputs = ' '.join(sp.EncodeAsPieces(text))
83
+ # print(inputs)
84
+ return inputs
85
+
86
+ def get_t2u_config(model, beam=5):
87
+
88
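+ # Emulate a `fairseq-interactive` command line for the text-to-unit checkpoint and convert the
+ # parsed namespace into an omegaconf config.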
+ sys.argv = [
89
+ "fairseq-interactive",
90
+ "libri_t2u",
91
+ "--path", model,
92
+ "--gen-subset", "valid",
93
+ "--max-len-b", "1024",
94
+ "--max-source-positions", "500",
95
+ "--max-target-positions", "1024",
96
+ "--beam", str(beam),
97
+ "--results-path", "decode"
98
+ ]
99
+
100
+ parser = options.get_interactive_generation_parser()
101
+ args = options.parse_args_and_arch(parser)
102
+ # distributed_utils.call_main(convert_namespace_to_omegaconf(args), load_text2units_model)
103
+ return convert_namespace_to_omegaconf(args)
104
+
105
+ def load_text2units_model(cfg: FairseqConfig, device):
106
+
107
+ if isinstance(cfg, Namespace):
108
+ cfg = convert_namespace_to_omegaconf(cfg)
109
+
110
+ utils.import_user_module(cfg.common)
111
+ if cfg.interactive.buffer_size < 1:
112
+ cfg.interactive.buffer_size = 1
113
+ if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
114
+ cfg.dataset.batch_size = 1
115
+
116
+ assert (
117
+ not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
118
+ ), "--sampling requires --nbest to be equal to --beam"
119
+ assert (
120
+ not cfg.dataset.batch_size
121
+ or cfg.dataset.batch_size <= cfg.interactive.buffer_size
122
+ ), "--batch-size cannot be larger than --buffer-size"
123
+
124
+ # Fix seed for stochastic decoding
125
+ if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
126
+ np.random.seed(cfg.common.seed)
127
+ utils.set_torch_seed(cfg.common.seed)
128
+
129
+ use_cuda = torch.cuda.is_available() and not cfg.common.cpu
130
+
131
+ # Setup task, e.g., translation
132
+ task = tasks.setup_task(cfg.task)
133
+
134
+ # Load ensemble
135
+ overrides = ast.literal_eval(cfg.common_eval.model_overrides)
136
+ models, _model_args = checkpoint_utils.load_model_ensemble(
137
+ utils.split_paths(cfg.common_eval.path),
138
+ arg_overrides=overrides,
139
+ task=task,
140
+ suffix=cfg.checkpoint.checkpoint_suffix,
141
+ strict=(cfg.checkpoint.checkpoint_shard_count == 1),
142
+ num_shards=cfg.checkpoint.checkpoint_shard_count,
143
+ )
144
+
145
+ # Set dictionaries
146
+ src_dict = task.source_dictionary
147
+ tgt_dict = task.target_dictionary
148
+
149
+ # Optimize ensemble for generation
150
+ for model in models:
151
+ if model is None:
152
+ continue
153
+ if cfg.common.fp16:
154
+ model.half()
155
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
156
+ model.cuda()
157
+ model.prepare_for_inference_(cfg)
158
+
159
+ # Initialize generator
160
+ generator = task.build_generator(models, cfg.generation)
161
+
162
+ # Handle tokenization and BPE
163
+ tokenizer = task.build_tokenizer(cfg.tokenizer)
164
+ bpe = task.build_bpe(cfg.bpe)
165
+
166
+ return {
167
+ "models": models,
168
+ "generator": generator,
169
+ "tokenizer": tokenizer,
170
+ "bpe": bpe,
171
+ "task": task,
172
+ "src_dict": src_dict,
173
+ "tgt_dict": tgt_dict,
174
+ "use_cuda": use_cuda
175
+ }
176
+
177
+ def gen_units(model, cfg, inputs):
178
+ inputs = [inputs]
179
+
180
+ models = model['models']
181
+ generator = model['generator']
182
+ tokenizer = model['tokenizer']
183
+ bpe = model['bpe']
184
+ task = model['task']
185
+ src_dict = model['src_dict']
186
+ tgt_dict = model['tgt_dict']
187
+ use_cuda = model['use_cuda']
188
+
189
+ def encode_fn(x):
190
+ if tokenizer is not None:
191
+ x = tokenizer.encode(x)
192
+ if bpe is not None:
193
+ x = bpe.encode(x)
194
+ return x
195
+
196
+ def decode_fn(x):
197
+ if bpe is not None:
198
+ x = bpe.decode(x)
199
+ if tokenizer is not None:
200
+ x = tokenizer.decode(x)
201
+ return x
202
+
203
+ align_dict = utils.load_align_dict(cfg.generation.replace_unk)
204
+
205
+ max_positions = utils.resolve_max_positions(
206
+ task.max_positions(), *[model.max_positions() for model in models]
207
+ )
208
+
209
+ start_id = 0
210
+ results = []
211
+ for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
212
+ print("[INFO_DEBUG]", batch)
213
+ bsz = batch.src_tokens.size(0)
214
+ src_tokens = batch.src_tokens
215
+ src_lengths = batch.src_lengths
216
+ constraints = batch.constraints
217
+ if use_cuda:
218
+ src_tokens = src_tokens.cuda()
219
+ src_lengths = src_lengths.cuda()
220
+ if constraints is not None:
221
+ constraints = constraints.cuda()
222
+
223
+ sample = {
224
+ "net_input": {
225
+ "src_tokens": src_tokens,
226
+ "src_lengths": src_lengths,
227
+ },
228
+ }
229
+ translate_start_time = time.time()
230
+ translations = task.inference_step(
231
+ generator, models, sample, constraints=constraints
232
+ )
233
+ translate_time = time.time() - translate_start_time
234
+ list_constraints = [[] for _ in range(bsz)]
235
+ if cfg.generation.constraints:
236
+ list_constraints = [unpack_constraints(c) for c in constraints]
237
+ for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
238
+ src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
239
+ constraints = list_constraints[i]
240
+ results.append(
241
+ (
242
+ start_id + id,
243
+ src_tokens_i,
244
+ hypos,
245
+ {
246
+ "constraints": constraints,
247
+ "time": translate_time / len(translations),
248
+ },
249
+ )
250
+ )
251
+
252
+ # print(results)
253
+
254
+ units = []
255
+ for id_, _, hypos, info in sorted(results, key=lambda x: x[0]):
256
+ print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
257
+
258
+ # Process top predictions
259
+ for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]:
260
+ hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
261
+ hypo_tokens=hypo["tokens"].int().cpu(),
262
+ src_str="",
263
+ alignment=hypo["alignment"],
264
+ align_dict=align_dict,
265
+ tgt_dict=tgt_dict,
266
+ remove_bpe=cfg.common_eval.post_process,
267
+ extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
268
+ )
269
+
270
+ units.append(list(map(int, hypo_str.split(' '))))
271
+
272
+ return units
273
+
274
+ def get_vocoder_config(vocoder, config):
275
+
276
+ args = argparse.Namespace(
277
+ vocoder=vocoder,
278
+ vocoder_cfg=config,
279
+ dur_prediction=True,
280
+ speaker_id=1,
281
+ cpu=False
282
+ )
283
+
284
+ return args
285
+
286
+ def load_units_vocoder(args, device):
287
+ with open(args.vocoder_cfg) as f:
288
+ vocoder_cfg = json.load(f)
289
+ vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg).to(device)
290
+
291
+ multispkr = vocoder.model.multispkr
292
+ if multispkr:
293
+ num_speakers = vocoder_cfg.get(
294
+ "num_speakers", 200
295
+ ) # 200 is the default number of speakers used by CodeHiFiGAN
296
+ assert (
297
+ args.speaker_id < num_speakers
298
+ ), f"invalid --speaker-id ({args.speaker_id}) with total #speakers = {num_speakers}"
299
+
300
+ return vocoder, num_speakers if multispkr else 1, 'cuda' in device
301
+
302
+ def gen_wav(vocoder, args, data, device):
303
+ vocoder, num_speakers, use_cuda = vocoder
304
+ res = []
305
+ for i, d in enumerate(data): # tqdm is removed for cleaner streaming
306
+ x = {
307
+ "code": torch.LongTensor(d).view(1, -1).to(device),
308
+ }
309
+ suffix = ""
310
+
311
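+ # For multi-speaker vocoders, sample a random speaker when args.speaker_id == -1; otherwise use the configured speaker embedding.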
+ multispkr = vocoder.model.multispkr
312
+ if multispkr:
313
+ spk = (
314
+ random.randint(0, num_speakers - 1)
315
+ if args.speaker_id == -1
316
+ else args.speaker_id
317
+ )
318
+ suffix = f"_spk{spk}"
319
+ x["spkr"] = torch.LongTensor([spk]).view(1, 1)
320
+
321
+ x = utils.move_to_cuda(x) if use_cuda else x
322
+ wav = vocoder(x, args.dur_prediction).detach().cpu().numpy()
323
+
324
+ res.append(wav)
325
+
326
+ return res[0]
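+
+ # Minimal end-to-end sketch of how these helpers appear intended to be composed. The paths, the
+ # SentencePiece model and the sample rate below are illustrative assumptions, not files shipped here:
+ #
+ #   sp = spm.SentencePieceProcessor(model_file="spm.model")
+ #   t2u_cfg = get_t2u_config("t2u_checkpoint.pt", beam=5)
+ #   t2u_model = load_text2units_model(t2u_cfg, device="cuda")
+ #   units = gen_units(t2u_model, t2u_cfg, tokenize("hello world", sp))
+ #
+ #   voc_args = get_vocoder_config("vocoder_checkpoint", "vocoder_config.json")
+ #   vocoder = load_units_vocoder(voc_args, device="cuda")
+ #   wav = gen_wav(vocoder, voc_args, units, device="cuda")  # returns the first generated waveform
+ #   sf.write("out.wav", wav.squeeze(), 16000)               # use the sample rate from the vocoder config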