from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
    LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.trainer_pt_utils import LabelSmoother
from transformers.modeling_outputs import CausalLMOutputWithPast

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


def padding_tensor(tensor, length, dim=0, pad=False):
    """Right-pad `tensor` with `length` copies of the value `pad` along `dim`."""
    if length == 0:
        return tensor
    assert length > 0, f"Wrong padding length: {length}"
    shape = list(tensor.shape)
    assert dim < len(shape), f"dim {dim} out of shape {shape}"
    shape[dim] = length
    padded = torch.cat(
        (
            tensor,
            torch.full(tuple(shape), pad, dtype=tensor.dtype, device=tensor.device),
        ),
        dim=dim,
    )
    return padded
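
# Illustrative usage (added sketch, not part of the original pipeline): pad a
# boolean mask from width 3 to width 5 with False along dim=1.
#   >>> mask = torch.ones(2, 3, dtype=torch.bool)
#   >>> padding_tensor(mask, 2, dim=1, pad=False).shape
#   torch.Size([2, 5])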

class T2ULlamaConfig(LlamaConfig):
    model_type = "T2ULlama"


class T2ULlamaForCausalLM(LlamaForCausalLM):
    config_class = T2ULlamaConfig

    def __init__(self, config, embedding_weight=None):
        self.current_step = 0  # expected to be advanced externally (e.g. by a trainer callback)
        self.log = {}
        # Skip LlamaForCausalLM.__init__ so the backbone and heads below can be
        # built by hand.
        super(LlamaForCausalLM, self).__init__(config)
        self.config = config
        self.training_stage = config.unit_output
        self.pad_token_id = 128009  # Llama 3 <|eot_id|>, reused as the pad token

        llama_config = T2ULlamaConfig(
            **config.to_dict(),
            batch_first=True,
            norm_first=True,
        )
        llama_config.architectures = ["T2ULlamaForCausalLM"]
        llama_config.pad_token_id = self.pad_token_id
        # The joint output vocabulary covers text tokens followed by unit tokens.
        llama_config.vocab_size += llama_config.unit_vocab_size

        #######################################################
        llama_config.unit_model = "medium"
        llama_config.max_position_embeddings = 2048  # tried 1024 / 1536 / 2048 (originally 1024, reduced 512)
        #######################################################
        if hasattr(llama_config, "unit_model"):
            if llama_config.unit_model == "large":
                llama_config.num_hidden_layers = 2
                # llama_config.hidden_size = 4096
                # llama_config.num_attention_heads = 32
                # llama_config.intermediate_size = 14336
                # llama_config.head_dim = llama_config.hidden_size // llama_config.num_attention_heads
            elif llama_config.unit_model == "tiny":
                llama_config.num_hidden_layers = 4
                llama_config.hidden_size = 512
                llama_config.num_attention_heads = 8
                llama_config.intermediate_size = 2048
                llama_config.head_dim = llama_config.hidden_size // llama_config.num_attention_heads
            else:  # "medium"
                llama_config.num_hidden_layers = 8
                llama_config.hidden_size = 768
                llama_config.num_attention_heads = 12
                llama_config.num_key_value_heads = 12
                llama_config.intermediate_size = 2048
                llama_config.head_dim = llama_config.hidden_size // llama_config.num_attention_heads
        else:
            llama_config.num_hidden_layers = 6
            llama_config.hidden_size = 512
            llama_config.num_attention_heads = 8
            llama_config.intermediate_size = 2048
            llama_config.head_dim = llama_config.hidden_size // llama_config.num_attention_heads

        self.model = LlamaModel(llama_config)

        # Shared embedding: text tokens keep a (frozen) copy of the LLM embedding
        # table; unit tokens get a trainable table of their own. The two are
        # concatenated at lookup time (see insert_text_embedding / forward).
        self.model.embed_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
            padding_idx=self.pad_token_id,
        )
        self.unit_embedding = nn.Linear(config.hidden_size, llama_config.unit_vocab_size, bias=False)
        self.adapter = nn.Linear(config.hidden_size, llama_config.hidden_size, bias=True)
        self.lm_head = nn.Linear(llama_config.hidden_size, llama_config.vocab_size, bias=False)

        if self.training_stage == "pretrain":
            pass
        elif self.training_stage in ("finetune", "finetune_kd", "finetune_kd_online"):
            self.aligner_MLP = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(config.intermediate_size, config.hidden_size),
            )
            # All-ones weights and zero biases, as in the original code
            # (note: this is not an identity initialization).
            torch.nn.init.ones_(self.aligner_MLP[0].weight)
            torch.nn.init.zeros_(self.aligner_MLP[0].bias)
            torch.nn.init.ones_(self.aligner_MLP[3].weight)
            torch.nn.init.zeros_(self.aligner_MLP[3].bias)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_model(self):
        return self.model

    def insert_text_embedding(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        text_labels: Optional[torch.LongTensor] = None,
        shift_text_labels: Optional[torch.LongTensor] = None,
        shift_text_hidden_states: Optional[torch.FloatTensor] = None,
        unit_targets: Optional[torch.LongTensor] = None,
        sub_lengths: Optional[torch.LongTensor] = None,
        text_start_index: Optional[torch.LongTensor] = None,
        do_task: str = None,
        **kwargs,
    ):
        if inputs_embeds is None:
            # Shared embedding: rows [0, vocab_size) come from the frozen text
            # table, rows [vocab_size, vocab_size + unit_vocab_size) are trainable.
            embed_tokens_weight = torch.cat(
                [
                    self.model.embed_tokens.weight.detach(),
                    self.unit_embedding.weight,
                ],
                dim=0,
            )
            inputs_embeds = F.embedding(input_ids, embed_tokens_weight, padding_idx=self.pad_token_id)

        emb_loss = None
        if do_task == "pretrain":
            # `embedding_dropout` and `perturb` are expected to be attached
            # elsewhere; the branch is skipped when they are absent.
            if self.training and hasattr(self, "embedding_dropout"):
                emb_origin_mask = text_labels != -100
                origin_padding_length = labels.shape[-1] - emb_origin_mask.shape[-1]
                extend_emb_origin_mask = padding_tensor(emb_origin_mask, origin_padding_length, 1, False)
                extend_emb_origin_mask = ~extend_emb_origin_mask.unsqueeze(-1).expand_as(inputs_embeds)

                # Pi-Model, view 2: additive Gaussian noise with learned variance.
                log_var = self.perturb(inputs_embeds)
                perturbed_inputs_embeds_2 = inputs_embeds + torch.randn_like(inputs_embeds) * (torch.exp(0.5 * log_var) + 1e-6)
                # Pi-Model, view 1: embedding dropout (also applied to view 2).
                perturbed_inputs_embeds_1 = self.embedding_dropout(inputs_embeds)
                perturbed_inputs_embeds_2 = self.embedding_dropout(perturbed_inputs_embeds_2)
                # Only perturb positions that carry text labels.
                perturbed_inputs_embeds_1 = torch.where(extend_emb_origin_mask, inputs_embeds, perturbed_inputs_embeds_1)
                perturbed_inputs_embeds_2 = torch.where(extend_emb_origin_mask, inputs_embeds, perturbed_inputs_embeds_2)
                inputs_embeds = torch.cat(
                    (perturbed_inputs_embeds_1, perturbed_inputs_embeds_2),
                    dim=0,
                )
                kl_loss = -0.5 * (1 + log_var - log_var.exp()).mean(dim=-1).sum(dim=-1).mean()
                contrastive_loss = (1 - F.cosine_similarity(perturbed_inputs_embeds_1, perturbed_inputs_embeds_2, dim=-1)).sum(dim=-1).mean()
                emb_loss = kl_loss + contrastive_loss
                if kl_loss.device == torch.device("cuda:0"):
                    self.log["kl_loss"] = kl_loss.item()
                    self.log["std"] = torch.exp(0.5 * log_var).mean().item()
                    self.log["contrastive_loss"] = contrastive_loss.item()
        elif do_task == "finetune":
            inputs_embeds = inputs_embeds.detach()
            inputs_embeds_refer = inputs_embeds.clone().detach()
            shift_text_hidden_states = self.aligner_MLP(shift_text_hidden_states)
            emb_origin_mask = text_labels != -100  # positions of output text
            emb_shift_mask = shift_text_labels != -100
            origin_padding_length = labels.shape[-1] - emb_origin_mask.shape[-1]
            shift_padding_length = labels.shape[-1] - emb_shift_mask.shape[-1]
            extend_emb_origin_mask = padding_tensor(emb_origin_mask, origin_padding_length, 1, False)
            extend_emb_shift_mask = padding_tensor(emb_shift_mask, shift_padding_length, 1, False)
            extend_shift_text_hidden_states = padding_tensor(shift_text_hidden_states, shift_padding_length, 1, 1e-9)

            # Sanity check: both masks must select the same label sequence.
            extend_text_labels = padding_tensor(text_labels, origin_padding_length, 1, -100)
            extend_shift_text_labels = padding_tensor(shift_text_labels, shift_padding_length, 1, -100)
            assert torch.equal(
                extend_text_labels[extend_emb_origin_mask],
                extend_shift_text_labels[extend_emb_shift_mask]
            ), "{}\n{}\n{}\n{}".format(labels, extend_emb_origin_mask, extend_shift_text_labels, extend_emb_shift_mask)

            # Overwrite the output-text positions with the aligned hidden states.
            inputs_embeds[extend_emb_origin_mask.unsqueeze(-1).expand_as(inputs_embeds)] = \
                extend_shift_text_hidden_states[extend_emb_shift_mask.unsqueeze(-1).expand_as(extend_shift_text_hidden_states)].to(dtype=inputs_embeds.dtype)

            if self.training:
                contrastive_loss = (1 - F.cosine_similarity(inputs_embeds, inputs_embeds_refer, dim=-1)).sum(-1).mean()
                emb_loss = contrastive_loss
                if emb_loss.device == torch.device("cuda:0"):
                    self.log["contrastive_loss"] = contrastive_loss.item()
        elif do_task == "finetune_kd":
            emb_origin_mask = text_labels != -100
            fetch_labels_list = []
            for batch in range(emb_origin_mask.shape[0]):
                fetch_labels_list.append(text_labels[batch][emb_origin_mask[batch]])
            shift_text_hidden_states = self.aligner_MLP(shift_text_hidden_states)

            # Split shift_text_hidden_states into per-sentence slices, each wrapped
            # with a header prefix and an EOS token. The ids below are Llama 3
            # special tokens (128006 = <|start_header_id|>, 128007 = <|end_header_id|>,
            # 128009 = <|eot_id|>); expected prefix layout, e.g.
            # [128006, 128000, 78191, 128007, 128000, 198, 128000].
            max_length = int(sub_lengths.max()) + 8
            pad_ids = torch.full(
                size=(sub_lengths.shape[0], sub_lengths.shape[1], max_length),
                fill_value=self.pad_token_id, dtype=torch.long,
            ).to(shift_text_hidden_states.device)
            pad_text_ids = torch.full(
                size=(sub_lengths.shape[0], sub_lengths.shape[1], max_length),
                fill_value=self.pad_token_id, dtype=torch.long,
            ).to(shift_text_hidden_states.device)
            atten_mask = pad_ids.ne(self.pad_token_id)
            shift_text_hidden_states_slice = F.embedding(pad_ids, embed_tokens_weight, padding_idx=self.pad_token_id)

            for batch in range(sub_lengths.shape[0]):
                offset = 0  # cumulative token offset within this batch item
                start_index = text_start_index[batch]
                for index, sub_length in enumerate(sub_lengths[batch]):
                    if sub_length == -1:
                        break
                    eos_id = torch.tensor([128009], dtype=torch.long, device=inputs_embeds.device)
                    eos = self.model.embed_tokens(eos_id)
                    if index == 0:
                        # 6-token prefix + EOS => sub_length + 7 valid positions.
                        text_prefix_ids = torch.tensor([128006, 128000, 65576, 128007, 128000, 198], dtype=torch.long, device=inputs_embeds.device)
                        prefix_embed = self.model.embed_tokens(text_prefix_ids)
                        pad_text_ids[batch][index][:sub_length + 7] = torch.cat([text_prefix_ids, fetch_labels_list[batch][offset:offset + sub_length], eos_id], dim=0)
                        atten_mask[batch][index][:sub_length + 7] = True
                    else:
                        # 7-token prefix + EOS => sub_length + 8 valid positions.
                        text_prefix_ids = torch.tensor([128006, 128000, 65576, 128007, 128000, 198, 12800], dtype=torch.long, device=inputs_embeds.device)
                        prefix_embed = self.model.embed_tokens(text_prefix_ids)
                        pad_text_ids[batch][index][:sub_length + 8] = torch.cat([text_prefix_ids, fetch_labels_list[batch][offset:offset + sub_length], eos_id], dim=0)
                        atten_mask[batch][index][:sub_length + 8] = True
                    new_shift_text_hidden_states = torch.cat(
                        [prefix_embed, shift_text_hidden_states[batch][start_index + offset:start_index + offset + sub_length], eos],
                        dim=0,
                    )
                    shift_text_hidden_states_slice[batch][index][:new_shift_text_hidden_states.shape[0]] = new_shift_text_hidden_states
                    offset += sub_length

            # Flatten the (batch, slice) grid into a single batch dimension.
            shift_text_hidden_states_slice = shift_text_hidden_states_slice.flatten(0, 1)

            padding_unit_targets = unit_targets.clone()
            padding_unit_targets = torch.where(padding_unit_targets == IGNORE_TOKEN_ID, self.pad_token_id, padding_unit_targets)
            target_mask_part = padding_unit_targets.ne(self.pad_token_id)
            atten_mask = torch.cat([atten_mask, target_mask_part], dim=-1)
            atten_mask = atten_mask.flatten(0, 1)
            pad_text_ids = pad_text_ids.flatten(0, 1)
            shift_text_embeddings = F.embedding(pad_text_ids, embed_tokens_weight, padding_idx=self.pad_token_id)
            unit_target_slice = F.embedding(padding_unit_targets, embed_tokens_weight, padding_idx=self.pad_token_id)
            unit_target_slice = unit_target_slice.flatten(0, 1)
            inputs_embeds = torch.cat([shift_text_hidden_states_slice, unit_target_slice], dim=1)
            ignore_ids = torch.full(
                size=(sub_lengths.shape[0], sub_lengths.shape[1], max_length),
                fill_value=IGNORE_TOKEN_ID, dtype=torch.long,
            ).to(shift_text_hidden_states.device)
            unit_targets = torch.cat([ignore_ids, unit_targets], dim=-1)
            unit_targets = unit_targets.flatten(0, 1)

            if self.training:
                contrastive_loss = (1 - F.cosine_similarity(shift_text_hidden_states_slice, shift_text_embeddings, dim=-1)).sum(-1).mean()
                emb_loss = contrastive_loss
                if emb_loss.device == torch.device("cuda:0"):
                    self.log["contrastive_loss"] = contrastive_loss.item()
        elif do_task == "finetune_kd_online":
            shift_text_hidden_states = self.aligner_MLP(shift_text_hidden_states)
            gold_inputs_embeds = inputs_embeds.clone()
            for batch in range(inputs_embeds.shape[0]):
                start_index = text_start_index[batch]
                for slice_index in range(inputs_embeds.shape[1]):
                    sub_length = sub_lengths[batch][slice_index]
                    # Replace the text span after the 7-token prefix with the
                    # aligned hidden states.
                    inputs_embeds[batch][slice_index][7:7 + sub_length] = shift_text_hidden_states[batch][start_index + 1:start_index + 1 + sub_length]
                    start_index += sub_length
            if self.training:
                contrastive_loss = ((1 - F.cosine_similarity(inputs_embeds, gold_inputs_embeds, dim=-1)) * attention_mask).sum(-1).mean()
                emb_loss = contrastive_loss
                if emb_loss.device == torch.device("cuda:0"):
                    self.log["contrastive_loss"] = contrastive_loss.item()
            unit_embeds = F.embedding(unit_targets, embed_tokens_weight, padding_idx=self.pad_token_id)
            inputs_embeds = torch.cat([inputs_embeds, unit_embeds], dim=2)
        else:
            # Inference path: prefix layout is [start_header_id] + speaker +
            # [end_header_id] + newline tokens; supports batch size 1 only.
            inputs_embeds = self.aligner_MLP(inputs_embeds)
            units_ids = torch.tensor([[128009, 128006, 128000, 65576, 128007, 128000, 198]], dtype=torch.long, device=inputs_embeds.device)
            units_prefix = self.model.embed_tokens(units_ids)
            text_ids = torch.tensor([[128006, 128000, 65576, 128007, 128000, 198, 12800]], dtype=torch.long, device=inputs_embeds.device)
            text_prefix = self.model.embed_tokens(text_ids)
            inputs_embeds = torch.cat([text_prefix, inputs_embeds, units_prefix], dim=1)

        inputs_embeds = self.adapter(inputs_embeds)

        if do_task == "finetune_kd":
            return (emb_loss, inputs_embeds, unit_targets, atten_mask)
        else:
            return (emb_loss, inputs_embeds)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            # Shared embedding lookup over the concatenated text + unit tables
            # (text rows are frozen via detach; unit rows stay trainable).
            embed_tokens_weight = torch.cat(
                [
                    self.model.embed_tokens.weight.detach(),
                    self.unit_embedding.weight,
                ],
                dim=0,
            )
            inputs_embeds = F.embedding(input_ids, embed_tokens_weight, padding_idx=self.pad_token_id)
            inputs_embeds = self.adapter(inputs_embeds)

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        cr_loss = None  # reserved; no consistency loss is computed in this path
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size + self.config.unit_vocab_size)
            shift_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            if loss.device == torch.device("cuda:0"):
                self.log["unit_loss"] = loss.item()
            if cr_loss is not None:
                # Cap the consistency term at 20% of the CE loss.
                target_scale = loss.item() * 0.2
                cr_loss_weight = target_scale / cr_loss.item() if cr_loss > target_scale else 1.0
                loss = loss + cr_loss_weight * cr_loss
            if loss.device == torch.device("cuda:0") and (self.current_step - 10) % 100 == 0:
                print(self.log, loss.device)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


AutoConfig.register("T2ULlama", T2ULlamaConfig)
AutoModelForCausalLM.register(T2ULlamaConfig, T2ULlamaForCausalLM)
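

# Minimal smoke test (added sketch, not from the original pipeline). The config
# values below are assumptions; `vocab_size` must exceed the hard-coded pad
# token id (128009) for the embedding's padding_idx to be valid.
if __name__ == "__main__":
    cfg = T2ULlamaConfig(
        vocab_size=128256,        # Llama 3 text vocabulary size
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=2,
        num_attention_heads=4,
        unit_output="pretrain",   # selects the training-stage branch in __init__
        unit_vocab_size=1024,     # unit tokens appended after the text vocab
    )
    model = T2ULlamaForCausalLM(cfg)
    dummy_ids = torch.randint(0, cfg.vocab_size, (1, 8))
    out = model(input_ids=dummy_ids)
    # Logits cover the joint text + unit vocabulary.
    print(out.logits.shape)  # torch.Size([1, 8, 129280])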