from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, CTCLoss
from transformers import AutoConfig, AutoModelForCausalLM, \
    LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.trainer_pt_utils import LabelSmoother
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers import (
    WhisperProcessor,
    WhisperModel,
)

from T2ULlama_CR_online import T2ULlamaForCausalLM

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


class ACLlamaConfig(LlamaConfig):
    """Llama configuration registered under the ``"ACLlama"`` model type."""

    model_type = "ACLlama"


def load_whisper(audio_tower_name, device="cuda"):
    """Load the Whisper model used as the audio tower.

    Args:
        audio_tower_name: Currently unused; the checkpoint name is hard-coded.
        device: Device the loaded model is moved to.

    Returns:
        A ``WhisperModel`` in float16 with ``forced_decoder_ids`` cleared.
    """
    # NOTE(review): `audio_tower_name` is ignored and "openai/whisper-large-v3"
    # is always loaded. Kept as-is because existing configs may rely on it;
    # confirm whether the argument should select the checkpoint.
    model = WhisperModel.from_pretrained(
        "openai/whisper-large-v3",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    ).to(device)
    model.config.forced_decoder_ids = None
    return model


class LookBackModule(nn.Module):
    """Cross-attention from a shrunk feature sequence back to the full one.

    Applies multi-head attention with a residual connection followed by a
    layer norm.
    """

    def __init__(self, cfg: LlamaConfig):
        super().__init__()
        self.encoder_attn = nn.MultiheadAttention(
            cfg.hidden_size,
            cfg.num_attention_heads,
            dropout=0.1,
            batch_first=True,
        )
        self.atten_layer_norm = nn.LayerNorm(cfg.hidden_size)

    def forward(self, x, wav_feature, bf_shrink_padding_mask):
        """Attend from ``x`` (queries) to ``wav_feature`` (keys and values).

        Args:
            x: Query features.
            wav_feature: Features used as both key and value.
            bf_shrink_padding_mask: Passed as ``key_padding_mask``; key
                positions where it is True are ignored by the attention.

        Returns:
            Layer-normalised ``attention(x) + x``.
        """
        residual = x
        x, _ = self.encoder_attn(
            query=x,
            key=wav_feature,
            value=wav_feature,
            key_padding_mask=bf_shrink_padding_mask,
        )
        x += residual
        x = self.atten_layer_norm(x)
        return x


class ACLlamaModel(LlamaModel):
    """Llama backbone that splices projected, shrunk audio features into the
    token embedding sequence at the audio placeholder positions."""

    config_class = ACLlamaConfig

    def __init__(self, config: LlamaConfig):
        super(ACLlamaModel, self).__init__(config)

        if hasattr(config, "audio_tower"):
            # Wrapped in a list so the tower is not registered as a submodule.
            self.audio_tower = [load_whisper(config.audio_tower)]

        if hasattr(config, "adapter_size"):
            self.mm_projector1 = nn.Linear(config.adapter_size * 2, config.hidden_size)
            # NOTE(review): this layer is built without batch_first=True while
            # forward() feeds it (turn, frames, hidden) tensors -- confirm that
            # attending across the first dimension is intended.
            asr_encoder_layer = nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.hidden_size * 2,
                dropout=0.1,
                norm_first=True,
            )
            self.lbm = LookBackModule(config)
            self.out_norm = nn.LayerNorm(config.hidden_size)
            self.audio_feature_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.asr_transformer_encoder = nn.TransformerEncoder(asr_encoder_layer, num_layers=1)

        # Inference-time attention-mask cache, see forward().
        self.mask_tensor = (torch.ones([1, 2048]) > 0)
        self.length = -1

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        audios: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the backbone, optionally replacing audio placeholders with audio features.

        The audio branch runs when an audio tower exists, ``audios`` is given,
        and either the model is training or ``input_ids`` holds more than one
        token (i.e. it is not a single-token decoding step).

        In training mode with audio, the returned output additionally carries
        ``"audio_features"`` (logits of the audio feature head),
        ``"label_shift"``, ``"label_extend"`` and ``"speech_pos"`` so the
        caller can realign labels with the rewritten sequence.

        In eval mode the attention mask is replaced by a cached mask whose
        length is advanced by one on every call, so that decoding steps see a
        mask matching the audio-expanded sequence.

        Raises:
            ValueError: If ``audios`` is given for a model with an audio tower
                but ``input_ids`` is None.
            AssertionError: If the eval-mode audio branch receives a batch
                larger than one.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        audio_tower = getattr(self, 'audio_tower', None)
        has_audio = audio_tower is not None and audios is not None
        if has_audio and input_ids is None:
            # The placeholder positions are located through input_ids.
            raise ValueError(
                "input_ids is required to locate audio placeholder tokens when audios is given"
            )
        audio_processed = has_audio and (input_ids.shape[1] != 1 or self.training)

        if audio_processed:
            audio_tower = audio_tower[0]  # HACK: for FSDP
            audio_config = audio_tower.config

            audio_list = []
            for audio in audios:
                # The audio tower itself is run without gradients.
                with torch.no_grad():
                    audio_feature = audio_tower.encoder(audio).last_hidden_state
                # Merge every two consecutive frames into one.
                audio_feature = audio_feature.view(
                    audio_feature.shape[0],
                    audio_feature.shape[1] // 2,
                    2 * audio_feature.shape[2],
                )
                audio_feature = self.mm_projector1(audio_feature)
                audio_feature = self.asr_transformer_encoder(audio_feature)
                audio_feature = self.out_norm(audio_feature)
                audio_list.append(audio_feature)

            audio_features = torch.stack(audio_list, dim=0)
            batch = audio_features.shape[0]
            audio_turn = audio_features.shape[1]
            # Flatten (batch, turn) into one leading dimension.
            audio_features = audio_features.view((batch * audio_turn,) + audio_features.shape[2:])
            predict_logits = self.audio_feature_head(audio_features)

            new_input_embeds = []
            label_shift = []
            speech_pos = []
            label_extend = -1
            new_input_ids = []

            # Keep a frame only when its argmax token differs from the previous
            # frame's; the first frame of every row is always kept.
            tokens = predict_logits.argmax(dim=-1)
            shrink_mask = tokens.roll(1) != tokens
            shrink_mask[:, 0] = True
            lengths = shrink_mask.long().sum(-1)
            shrink_2d = audio_features[shrink_mask]
            num_patches = audio_config.audio_patch_size

            l_index = 0
            shrink_features_raw = []
            for v, audio_feature, mask in zip(lengths, audio_features, ~shrink_mask):
                shrink_feature = shrink_2d[l_index:l_index + v]
                shrink_feature = self.lbm(shrink_feature, audio_feature, bf_shrink_padding_mask=mask)
                shrink_features_raw.append(shrink_feature)
                l_index += v

            # Regroup the flat list into one list of turns per sample.
            shrink_features = [
                shrink_features_raw[i:i + audio_turn]
                for i in range(0, len(shrink_features_raw), audio_turn)
            ]

            if self.training:
                maxn_length = lengths.view(batch, audio_turn).sum(-1).max()
                # NOTE(review): a computed value of exactly -1 is indistinguishable
                # from the "unset" sentinel checked by ACLlamaForCausalLM.forward.
                label_extend = maxn_length - num_patches * audio_turn
                old_seq_length = inputs_embeds.shape[1]

                # Padding is identical for every sample, so build it once.
                pad_ids = torch.full(
                    size=(maxn_length,),
                    fill_value=audio_config.llm_pad_token_id,
                    dtype=torch.long,
                ).to(input_ids.device)
                pad_embeds = self.embed_tokens(pad_ids)

                for cur_input_ids, cur_input_embeds, cur_shrink_features in zip(
                        input_ids, inputs_embeds, shrink_features):
                    audio_start_token_pos_all = torch.where(
                        cur_input_ids == audio_config.audio_patch_token)[0]
                    inner_label_shift = []
                    inner_speech_pos = []
                    # Splice from the last placeholder to the first so earlier
                    # positions stay valid while the sequence changes length.
                    for audio_start_token_pos, shrink_feature in reversed(
                            list(zip(audio_start_token_pos_all, cur_shrink_features))):
                        cur_speech_length = shrink_feature.shape[0]
                        cur_input_ids = torch.cat((
                            cur_input_ids[:audio_start_token_pos],
                            cur_input_ids[audio_start_token_pos: audio_start_token_pos + 1].repeat(cur_speech_length),
                            cur_input_ids[audio_start_token_pos + num_patches:],
                        ), dim=0)
                        cur_input_embeds = torch.cat((
                            cur_input_embeds[:audio_start_token_pos],
                            shrink_feature,
                            cur_input_embeds[audio_start_token_pos + num_patches:],
                        ), dim=0)
                        inner_label_shift.insert(0, cur_speech_length - num_patches)
                        inner_speech_pos.insert(0, audio_start_token_pos)

                    label_shift = label_shift + inner_label_shift
                    speech_pos = speech_pos + inner_speech_pos

                    # Pad every sample to the common length old_seq_length + label_extend.
                    cur_new_input_embeds = torch.cat((
                        cur_input_embeds,
                        pad_embeds[:old_seq_length + label_extend - cur_input_embeds.shape[0]],
                    ), dim=0)
                    cur_new_input_ids = torch.cat((
                        cur_input_ids,
                        pad_ids[:old_seq_length + label_extend - cur_input_ids.shape[0]],
                    ), dim=0)
                    new_input_embeds.append(cur_new_input_embeds)
                    new_input_ids.append(cur_new_input_ids)

                input_ids = torch.stack(new_input_ids, dim=0)
                attention_mask = input_ids.ne(audio_config.llm_pad_token_id)
                inputs_embeds = torch.stack(new_input_embeds, dim=0)

                batch_label_shift = []
                batch_speech_pos = []
                for i in range(0, len(label_shift), audio_turn):
                    batch_label_shift.append(label_shift[i:i + audio_turn])
                    batch_speech_pos.append(speech_pos[i:i + audio_turn])
            else:
                # Inference mode with batch_size=1
                assert input_ids.shape[0] == 1, "This implementation only supports batch_size=1 during inference"

                # All audio placeholder positions in this sample.
                audio_start_token_positions = torch.where(
                    input_ids[0] == audio_config.audio_patch_token)[0]

                current_embeds = inputs_embeds[0]  # [seq_len, embed_dim]
                current_ids = input_ids[0]  # [seq_len]

                # Offset accumulated by earlier expansions.
                position_shift = 0

                # A list of per-turn lists is flattened to a list of tensors.
                if isinstance(shrink_features[0], list):
                    shrink_features = [item for sublist in shrink_features for item in sublist]

                for pos_idx, audio_pos in enumerate(audio_start_token_positions):
                    adjusted_pos = audio_pos + position_shift

                    shrink_feature = shrink_features[pos_idx]
                    if isinstance(shrink_feature, list):
                        shrink_feature = torch.stack(shrink_feature, dim=0)

                    v = shrink_feature.shape[0]

                    # Expand the ids and embeddings in lock-step.
                    current_ids = torch.cat([
                        current_ids[:adjusted_pos],
                        current_ids[adjusted_pos:adjusted_pos + 1].repeat(v),
                        current_ids[adjusted_pos + num_patches:],
                    ], dim=0)
                    current_embeds = torch.cat([
                        current_embeds[:adjusted_pos],
                        shrink_feature,
                        current_embeds[adjusted_pos + num_patches:],
                    ], dim=0)

                    position_shift += (v - num_patches)

                # Restore the batch dimension.
                input_ids = current_ids.unsqueeze(0)  # [1, new_seq_len]
                inputs_embeds = current_embeds.unsqueeze(0)  # [1, new_seq_len, embed_dim]
                attention_mask = input_ids.ne(audio_config.llm_pad_token_id)

                # Update inference state tracking.
                if not hasattr(self, 'mask_tensor'):
                    self.mask_tensor = attention_mask.clone()
                    self.length = attention_mask.shape[1]
                else:
                    self.mask_tensor = self.mask_tensor.to(attention_mask.device)

                    # NOTE(review): the buffer is only grown here, to exactly the
                    # prompt length; decoding steps beyond its width get a mask
                    # that is shorter than the sequence.
                    if self.mask_tensor.shape[1] < attention_mask.shape[1]:
                        new_mask = torch.zeros(1, attention_mask.shape[1],
                                               dtype=torch.bool,
                                               device=attention_mask.device)
                        new_mask[0, :self.mask_tensor.shape[1]] = self.mask_tensor
                        self.mask_tensor = new_mask

                    self.mask_tensor[0, :attention_mask.shape[1]] = attention_mask[0]
                    self.length = attention_mask.shape[1]

        if not self.training:
            # The cached mask tracks the audio-expanded length across decoding
            # steps. It is not applied in training, where it would discard the
            # mask computed for the current batch.
            # NOTE(review): a text-only prompt in eval mode still picks up
            # whatever length the previous call left behind.
            attention_mask = self.mask_tensor[:, :self.length]
            self.length += 1

        return_state = super(ACLlamaModel, self).forward(
            input_ids=None,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Only attach the extras when the audio branch actually produced them.
        if self.training and audio_processed:
            return_state["audio_features"] = predict_logits
            return_state["label_shift"] = batch_label_shift
            return_state["label_extend"] = label_extend
            return_state["speech_pos"] = batch_speech_pos
        return return_state


class ACLlamaForCausalLM(LlamaForCausalLM):
    """Causal LM head on top of ``ACLlamaModel`` with an optional
    text-to-unit translator (``T2ULlamaForCausalLM``)."""

    config_class = ACLlamaConfig

    def __init__(self, config):
        # Deliberately skips LlamaForCausalLM.__init__ so the backbone can be
        # replaced by ACLlamaModel.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = ACLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # t2u by kkq
        if hasattr(config, "unit_output"):
            self.unit_output = config.unit_output
            self.unit_translator = T2ULlamaForCausalLM(config, self.lm_head.weight)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the underlying ``ACLlamaModel``."""
        return self.model

    def get_unit_translator(self):
        """Return the text-to-unit translator (only set when the config has ``unit_output``)."""
        return self.unit_translator

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        t2u_input_ids: Optional[torch.LongTensor] = None,
        t2u_labels: Optional[torch.LongTensor] = None,
        t2u_attention_mask: Optional[torch.Tensor] = None,
        unit_targets: Optional[torch.Tensor] = None,
        sub_lengths: Optional[torch.Tensor] = None,
        asr_targets: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        audios: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        do_task: str = None,
        assistant_after_audio_shifts: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute LM logits and, depending on ``do_task``, the training losses.

        ``do_task`` defaults to ``self.unit_output`` when present. As handled
        below:

        * ``"pretrain"``: the speech/text backbone is skipped; only the unit
          translator runs and its output is returned.
        * ``"finetune_kd"``: the backbone runs but no LM loss is computed; the
          unit translator output is returned.
        * ``"finetune"``: LM loss (plus ``0.3 *`` CTC loss when ``asr_targets``
          is given) is combined with the unit translator loss.
        * ``"skip"`` or None: the unit translator is not run.

        ``position_ids`` and ``cache_position`` are accepted but not forwarded
        to the backbone.

        Raises:
            NotImplementedError: If label shifts are reported but
                ``label_extend`` equals -1.
            AssertionError: If ``do_task`` is inconsistent with whether an LM
                loss was computed.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # t2u by kkq
        # pretrain(t2u only) finetune(s2t&e2u)
        do_task = do_task if do_task is not None else getattr(self, 'unit_output', None)

        outputs = None
        hidden_states = None
        logits = None
        new_shift_labels = None
        if do_task != "pretrain":
            # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                audios=audios,
            )

            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None and do_task != "pretrain" and do_task != "finetune_kd":
            if asr_targets is not None:
                asr_logits = outputs["audio_features"]
                # Flatten the first two dimensions to match the flattened logits.
                asr_targets = asr_targets.view(
                    asr_targets.shape[0] * asr_targets.shape[1], asr_targets.shape[2])
                mask_asr_targets = (asr_targets != IGNORE_TOKEN_ID)
                target_lengths = mask_asr_targets.sum(1)
                input_lengths = torch.full(
                    size=(asr_logits.shape[0],),
                    fill_value=asr_logits.shape[1],
                    dtype=torch.long,
                )
                log_probs = F.log_softmax(asr_logits, dim=-1).transpose(0, 1)
                with torch.backends.cudnn.flags(enabled=False):
                    loss_asr = F.ctc_loss(
                        log_probs,
                        asr_targets,
                        input_lengths,
                        target_lengths,
                        blank=self.model.audio_tower[0].config.audio_patch_token,
                        reduction='mean',
                        zero_infinity=True,
                    )
            else:
                loss_asr = 0

            shift_labels = labels
            if "label_shift" in outputs.keys() and len(outputs["label_shift"]) > 0:
                if outputs["label_extend"] != -1:
                    # Rebuild the labels on the audio-expanded sequence: each
                    # segment starting at a speech position is moved by the
                    # cumulative length change of the audio spans up to it.
                    new_shift_labels = torch.full(
                        size=(shift_labels.shape[0], outputs["label_extend"] + shift_labels.shape[1]),
                        fill_value=IGNORE_TOKEN_ID,
                        dtype=torch.long,
                    ).to(shift_labels.device)
                    for batch in range(len(outputs["label_shift"])):
                        it_lable_shift = outputs["label_shift"][batch]
                        it_speech_pos = outputs["speech_pos"][batch]
                        prefix = 0
                        for i in range(len(it_lable_shift)):
                            if i == len(it_lable_shift) - 1:
                                length = shift_labels.shape[1] - it_speech_pos[i]
                            else:
                                length = it_speech_pos[i + 1] - it_speech_pos[i]
                            prefix += it_lable_shift[i]
                            new_shift_labels[batch][it_speech_pos[i] + prefix: it_speech_pos[i] + length + prefix] = \
                                shift_labels[batch][it_speech_pos[i]:it_speech_pos[i] + length]
                    shift_labels = new_shift_labels
                else:
                    raise NotImplementedError

            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = shift_labels[..., 1:].contiguous()

            loss_fct = CrossEntropyLoss()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            loss = loss + 0.3 * loss_asr

        t2u_output = None
        if do_task is not None and do_task != "skip":
            if do_task == "finetune_kd":
                # Per-sample start index: first speech position, plus its
                # length change, plus the caller-supplied shift.
                text_start_index = []
                for batch in range(len(outputs["label_shift"])):
                    text_start_index.append(
                        outputs["speech_pos"][batch][0]
                        + outputs["label_shift"][batch][0]
                        + assistant_after_audio_shifts[batch]
                    )

                t2u_embeds_output = self.unit_translator.insert_text_embedding(
                    input_ids=t2u_input_ids,
                    attention_mask=t2u_attention_mask,
                    inputs_embeds=None,
                    labels=t2u_labels,
                    text_labels=labels,
                    shift_text_labels=new_shift_labels,
                    shift_text_hidden_states=hidden_states,
                    unit_targets=unit_targets,
                    sub_lengths=sub_lengths,
                    text_start_index=text_start_index,
                    do_task=do_task,
                )

                vae_loss, t2u_inputs_embeds, unit_targets, t2u_attention_mask = t2u_embeds_output

                t2u_output = self.unit_translator(
                    input_ids=None,
                    attention_mask=t2u_attention_mask,
                    past_key_values=past_key_values,
                    inputs_embeds=t2u_inputs_embeds,
                    use_cache=use_cache,
                    labels=unit_targets,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )
            else:
                t2u_embeds_output = self.unit_translator.insert_text_embedding(
                    input_ids=t2u_input_ids,
                    attention_mask=t2u_attention_mask,
                    inputs_embeds=None,
                    labels=t2u_labels,
                    text_labels=labels,
                    shift_text_labels=new_shift_labels,
                    shift_text_hidden_states=hidden_states,
                    do_task=do_task,
                )

                vae_loss, t2u_inputs_embeds = t2u_embeds_output

                t2u_output = self.unit_translator(
                    input_ids=None,
                    attention_mask=t2u_attention_mask,
                    past_key_values=past_key_values,
                    inputs_embeds=t2u_inputs_embeds,
                    use_cache=use_cache,
                    labels=t2u_labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )

            t2u_loss = t2u_output[0]
            if vae_loss is not None:
                # Cap the VAE term's contribution at 20% of the t2u loss value.
                target_scale = t2u_loss.item() * 0.2
                vae_loss_weight = target_scale / vae_loss.item() if vae_loss > target_scale else 1.0
                t2u_loss = t2u_loss + vae_loss_weight * vae_loss

            if loss is not None:
                # S2T + T2U loss
                assert do_task in ["finetune"]
                if loss.item() < 1.0:
                    # Once the LM loss is small, emphasise the t2u loss.
                    loss = 0.2 * loss + t2u_loss * 2.0
                else:
                    loss = loss + t2u_loss
            else:
                assert do_task in ["pretrain", "finetune_kd"]
                t2u_output["loss"] = t2u_loss
                return t2u_output

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """Build the keyword arguments for one generation step.

        Besides the usual fields, ``audios`` and ``do_task`` are forwarded from
        ``kwargs`` when present, and ``return_dict_in_generate`` is forwarded
        as ``return_dict``.
        """
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0]:]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        if "audios" in kwargs:
            model_inputs["audios"] = kwargs["audios"]
        if "do_task" in kwargs:
            model_inputs["do_task"] = kwargs["do_task"]
        if "return_dict_in_generate" in kwargs:
            model_inputs["return_dict"] = kwargs["return_dict_in_generate"]
        return model_inputs


AutoConfig.register("ACLlama", ACLlamaConfig)
AutoModelForCausalLM.register(ACLlamaConfig, ACLlamaForCausalLM)