# Spaces: Running on Zero
from typing import List, Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn import CrossEntropyLoss, CTCLoss | |
from transformers import AutoConfig, AutoModelForCausalLM, \ | |
LlamaConfig, LlamaModel, LlamaForCausalLM | |
from transformers.trainer_pt_utils import LabelSmoother | |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast | |
from transformers import ( | |
WhisperProcessor, | |
WhisperModel, | |
) | |
from T2ULlama_CR_online import T2ULlamaForCausalLM | |
# Label id treated as "ignore" by the losses in this module (e.g. the ASR
# target mask and the realigned LM labels); taken from transformers'
# LabelSmoother so it matches the value the Trainer uses.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


class ACLlamaConfig(LlamaConfig):
    """Llama configuration registered under the ``"ACLlama"`` model type."""

    model_type = "ACLlama"
def load_whisper(
    audio_tower_name,
    device="cuda",
    model_name_or_path="openai/whisper-large-v3",
    torch_dtype=torch.float16,
):
    """Load the Whisper model used as the audio tower.

    Args:
        audio_tower_name: the caller's ``config.audio_tower`` value.
            NOTE(review): it is not used to pick the checkpoint; the
            checkpoint comes from ``model_name_or_path``. The argument is kept
            so existing call sites keep working.
        device: device the loaded model is moved to.
        model_name_or_path: hub id or local path of the Whisper checkpoint.
            Defaults to the previously hard-coded ``"openai/whisper-large-v3"``.
        torch_dtype: dtype the weights are loaded in (default ``torch.float16``,
            the previously hard-coded value).

    Returns:
        A ``WhisperModel`` on ``device`` whose ``config.forced_decoder_ids`` is
        cleared.
    """
    model = WhisperModel.from_pretrained(
        model_name_or_path,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    ).to(device)
    model.config.forced_decoder_ids = None
    return model
class LookBackModule(nn.Module):
    """Residual cross-attention block followed by layer normalisation.

    Computes ``LayerNorm(x + MHA(query=x, key=wav_feature, value=wav_feature))``.
    ACLlamaModel uses it so the frames kept after shrinking (the queries) can
    attend back to the full, un-shrunk frame sequence (keys/values).
    """

    def __init__(self, cfg: "LlamaConfig"):
        super().__init__()
        width = cfg.hidden_size
        # Attribute names are part of the checkpoint's state_dict keys; keep them.
        self.encoder_attn = nn.MultiheadAttention(
            width,
            cfg.num_attention_heads,
            dropout=0.1,
            batch_first=True,
        )
        self.atten_layer_norm = nn.LayerNorm(width)

    def forward(self, x, wav_feature, bf_shrink_padding_mask):
        """Attend from ``x`` to ``wav_feature`` and return the normalised residual sum.

        Args:
            x: query tensor.
            wav_feature: key/value tensor.
            bf_shrink_padding_mask: forwarded as ``key_padding_mask``; key
                positions set to True are excluded from attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attended = self.encoder_attn(
            query=x,
            key=wav_feature,
            value=wav_feature,
            key_padding_mask=bf_shrink_padding_mask,
        )[0]
        return self.atten_layer_norm(attended + x)
class ACLlamaModel(LlamaModel):
    """LlamaModel that splices compressed Whisper speech features into its input embeddings.

    When audio is processed, each audio is encoded by the Whisper encoder,
    consecutive frame pairs are merged and projected to ``hidden_size``, frames
    are shrunk to those where the CTC-head argmax changes, refined with
    ``LookBackModule``, and written in place of the
    ``audio_config.audio_patch_size`` placeholder positions that start at each
    ``audio_config.audio_patch_token`` in ``input_ids``.
    """

    config_class = ACLlamaConfig

    def __init__(self, config: LlamaConfig):
        super(ACLlamaModel, self).__init__(config)
        if hasattr(config, "audio_tower"):
            # Stored inside a plain list, so nn.Module does not register Whisper
            # as a submodule (presumably to keep it out of this model's
            # state_dict / FSDP wrapping - see "HACK: for FSDP" in forward).
            self.audio_tower = [load_whisper(config.audio_tower)]
        if hasattr(config, "adapter_size"):
            # Input is two concatenated encoder frames (see the view in forward), hence *2.
            self.mm_projector1 = nn.Linear(config.adapter_size*2 , config.hidden_size)
            # NOTE(review): batch_first is left at its default (False), so the
            # encoder treats dim 0 of its input as the sequence axis. forward
            # feeds it (turns, frames, hidden), i.e. self-attention runs across
            # turns rather than frames. Trained checkpoints depend on this, so
            # confirm intent before changing it.
            asr_encoder_layer = nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.hidden_size*2,
                dropout=0.1,
                norm_first=True
            )
            self.lbm = LookBackModule(config)
            self.out_norm = nn.LayerNorm(config.hidden_size)
            # Per-frame vocabulary logits; trained with CTC by ACLlamaForCausalLM,
            # and their argmax drives the frame shrinking in forward.
            self.audio_feature_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.asr_transformer_encoder = nn.TransformerEncoder(asr_encoder_layer, num_layers=1)
        # Incremental-decoding state: a reusable boolean attention mask and the
        # prefix length of it that the next forward call exposes.
        self.mask_tensor=(torch.ones([1, 2048])>0)
        self.length=-1

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        audios: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the Llama decoder, splicing speech features into the inputs first when audio is present.

        Audio is processed only when ``self.audio_tower`` exists, ``audios`` is
        given, and the call is a training step or a multi-token (prefill) step.

        Training: each sample's placeholder spans are replaced by its shrunk
        speech features. Samples are right-padded with
        ``audio_config.llm_pad_token_id`` to ``old_seq_length + label_extend``.
        The output additionally carries ``audio_features`` (CTC-head logits),
        ``label_shift``, ``label_extend`` and ``speech_pos`` so the caller can
        realign its labels.

        Inference: batch size must be 1. Placeholder spans are replaced the
        same way, and the expanded mask is cached in ``self.mask_tensor``.
        Later single-token decoding steps, which skip audio processing, then
        receive a mask of the expanded length via ``self.length``.

        Returns:
            The ``LlamaModel.forward`` output, with the extra keys above when
            training with audio.
        """
        # HACK: replace back original embeddings for LLaAA pretraining
        # NOTE(review): orig_embeds_params is read but never used below.
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        audio_tower = getattr(self, 'audio_tower', None)
        # Single-token decoding steps (input length 1, not training) skip this block.
        # NOTE(review): input_ids.shape is evaluated whenever audio_tower exists,
        # so calling with input_ids=None (inputs_embeds only) raises here.
        if audio_tower is not None and (input_ids.shape[1] != 1 or self.training) and audios is not None:
            audio_tower = audio_tower[0] # HACK: for FSDP
            audio_list=[]
            # The Whisper config also supplies audio_patch_size, audio_patch_token and
            # llm_pad_token_id below (presumably set by the caller after loading - verify).
            audio_config = audio_tower.config
            for audio in audios:
                # The Whisper encoder runs without gradient tracking.
                with torch.no_grad():
                    audio_feature = audio_tower.encoder(audio).last_hidden_state
                # Merge consecutive frame pairs: (n, T, D) -> (n, T//2, 2*D).
                audio_feature = audio_feature.view(audio_feature.shape[0], audio_feature.shape[1]//2, 2 * audio_feature.shape[2])
                audio_feature = self.mm_projector1(audio_feature)
                audio_feature = self.asr_transformer_encoder(audio_feature)
                audio_feature = self.out_norm(audio_feature)
                audio_list.append(audio_feature)
            # (batch, audio_turn, frames, hidden) -> (batch * audio_turn, frames, hidden)
            audio_features = torch.stack(audio_list, dim=0)
            batch = audio_features.shape[0]
            audio_turn = audio_features.shape[1]
            audio_features = audio_features.view((batch * audio_turn,)+audio_features.shape[2:])
            predict_logits = self.audio_feature_head(audio_features)
            new_input_embeds = []
            label_shift = []
            speech_pos = []
            label_extend = -1
            new_input_ids = []
            # Keep a frame only where its greedy CTC-head token differs from the
            # previous frame's; the first frame of every audio is always kept.
            tokens = predict_logits.argmax(dim=-1)
            shrink_mask = tokens.roll(1) != tokens
            shrink_mask[:,0] = True
            lengths = shrink_mask.long().sum(-1)  # kept frames per audio
            shrink_2d = audio_features[shrink_mask]  # all kept frames, concatenated
            #num_patches = audio_features.shape[1]
            # Number of placeholder positions each audio occupies in input_ids.
            num_patches = audio_config.audio_patch_size
            l_index=0
            shrink_features_raw = []
            # Each audio's kept frames attend back to all of its frames; frames
            # dropped by the shrink (mask == True) are ignored as keys.
            for v, audio_feature, mask in zip(lengths, audio_features, ~shrink_mask):
                shrink_feature = shrink_2d[l_index:l_index+v]
                shrink_feature = self.lbm(shrink_feature, audio_feature, bf_shrink_padding_mask=mask)
                shrink_features_raw.append(shrink_feature)
                l_index += v
            # Regroup per sample: shrink_features[b] is a list of audio_turn tensors.
            shrink_features = []
            for i in range(0, len(shrink_features_raw), audio_turn):
                shrink_features.append(shrink_features_raw[i:i+audio_turn])
            if self.training:
                # Longest total speech length in the batch; sequences are padded to
                # old_seq_length + label_extend.
                maxn_length = lengths.view(batch,audio_turn).sum(-1).max()
                label_extend = maxn_length - num_patches * audio_turn
                old_seq_length = inputs_embeds.shape[1]
                for cur_input_ids, cur_input_embeds, cur_shrink_features in zip(input_ids, inputs_embeds, shrink_features):
                    # NOTE(review): reads attention_mask.device, so training requires an attention_mask.
                    pad_ids = torch.full(size=(maxn_length,), fill_value=audio_config.llm_pad_token_id, dtype=torch.long).to(attention_mask.device)
                    pad_embeds = self.embed_tokens(pad_ids)
                    # NOTE(review): zip below pairs these positions with the sample's
                    # audio turns in order, which assumes one audio_patch_token per turn.
                    audio_start_token_pos_all = torch.where(cur_input_ids == audio_config.audio_patch_token)[0]
                    #print(cur_input_embeds.shape,cur_input_ids.shape)
                    inner_label_shift = []
                    inner_speech_pos = []
                    # Splice right-to-left so earlier (original) positions stay valid.
                    for audio_start_token_pos, shrink_feature in reversed(list(zip(audio_start_token_pos_all, cur_shrink_features))): #zip(audio_start_token_pos_all, cur_shrink_features):
                        cur_speech_length = shrink_feature.shape[0]
                        # num_patches placeholder ids -> cur_speech_length copies of the patch id ...
                        cur_input_ids = torch.cat((cur_input_ids[:audio_start_token_pos],
                            cur_input_ids[audio_start_token_pos: audio_start_token_pos+1].repeat(cur_speech_length),
                            cur_input_ids[audio_start_token_pos + num_patches:]), dim=0)
                        # ... and the matching embeddings -> the shrunk speech features.
                        cur_input_embeds = torch.cat((
                            cur_input_embeds[:audio_start_token_pos],
                            shrink_feature,
                            cur_input_embeds[audio_start_token_pos + num_patches:]), dim=0)
                        inner_label_shift.insert(0, cur_speech_length - num_patches)
                        inner_speech_pos.insert(0, audio_start_token_pos)
                    label_shift = label_shift + inner_label_shift
                    speech_pos = speech_pos + inner_speech_pos
                    # Right-pad with the LLM pad token up to old_seq_length + label_extend.
                    cur_new_input_embeds = torch.cat((cur_input_embeds, pad_embeds[:old_seq_length + label_extend - cur_input_embeds.shape[0]]),dim=0)
                    cur_new_input_ids = torch.cat((cur_input_ids, pad_ids[:old_seq_length + label_extend - cur_input_ids.shape[0]]),dim=0)
                    new_input_embeds.append(cur_new_input_embeds)
                    new_input_ids.append(cur_new_input_ids)
                input_ids = torch.stack(new_input_ids, dim=0)
                attention_mask=input_ids.ne(audio_config.llm_pad_token_id)
                inputs_embeds = torch.stack(new_input_embeds, dim=0)
                # Regroup the flat per-audio shifts/positions per sample.
                batch_label_shift = []
                batch_speech_pos=[]
                for i in range(0, len(label_shift), audio_turn):
                    batch_label_shift.append(label_shift[i:i+audio_turn])
                    batch_speech_pos.append(speech_pos[i:i+audio_turn])
            else:
                # Inference mode with batch_size=1
                assert input_ids.shape[0] == 1, "This implementation only supports batch_size=1 during inference"
                # Get all audio token positions in this sample
                audio_start_token_positions = torch.where(input_ids[0] == audio_config.audio_patch_token)[0]
                # Initialize with original embeddings
                current_embeds = inputs_embeds[0] # [seq_len, embed_dim]
                current_ids = input_ids[0] # [seq_len]
                # Process each audio token position sequentially
                position_shift = 0 # Track position changes due to expansions
                # Ensure shrink_features is properly formatted
                if isinstance(shrink_features[0], list):
                    # If it's a list of lists (batch_size=1 but multiple turns), flatten it
                    shrink_features = [item for sublist in shrink_features for item in sublist]
                # Left-to-right splice; earlier expansions are compensated by position_shift.
                for pos_idx, audio_pos in enumerate(audio_start_token_positions):
                    adjusted_pos = audio_pos + position_shift
                    # Get corresponding shrink feature (ensure it's a tensor)
                    shrink_feature = shrink_features[pos_idx]
                    if isinstance(shrink_feature, list):
                        shrink_feature = torch.stack(shrink_feature, dim=0)
                    v = shrink_feature.shape[0] # Now this should work
                    # print('len: ', v)
                    # Expand the input ids and embeddings
                    current_ids = torch.cat([
                        current_ids[:adjusted_pos],
                        current_ids[adjusted_pos:adjusted_pos+1].repeat(v),
                        current_ids[adjusted_pos + num_patches:]
                    ], dim=0)
                    current_embeds = torch.cat([
                        current_embeds[:adjusted_pos],
                        shrink_feature,
                        current_embeds[adjusted_pos + num_patches:]
                    ], dim=0)
                    # Update position shift for next iteration
                    position_shift += (v - num_patches)
                # Update the tensors (unsqueeze to restore batch dim)
                input_ids = current_ids.unsqueeze(0) # [1, new_seq_len]
                inputs_embeds = current_embeds.unsqueeze(0) # [1, new_seq_len, embed_dim]
                attention_mask = input_ids.ne(audio_config.llm_pad_token_id)
                # Update inference state tracking
                # NOTE(review): __init__ always sets mask_tensor, so this first branch is not taken.
                if not hasattr(self, 'mask_tensor'):
                    # Initialize with current attention mask
                    self.mask_tensor = attention_mask.clone()
                    self.length = attention_mask.shape[1]
                else:
                    # Ensure mask tensor is on correct device
                    self.mask_tensor = self.mask_tensor.to(attention_mask.device)
                    # Expand mask tensor if needed
                    # NOTE(review): the grown mask is exactly the prompt length and
                    # zero-filled, so decoding past it slices fewer columns than the
                    # key length. Consider over-allocating with True.
                    if self.mask_tensor.shape[1] < attention_mask.shape[1]:
                        new_mask = torch.zeros(1, attention_mask.shape[1],
                                               dtype=torch.bool,
                                               device=attention_mask.device)
                        new_mask[0, :self.mask_tensor.shape[1]] = self.mask_tensor
                        self.mask_tensor = new_mask
                    # Update mask tensor
                    self.mask_tensor[0, :attention_mask.shape[1]] = attention_mask[0]
                    self.length = attention_mask.shape[1]
        # Expose the cached mask up to self.length, then grow it by one so the next
        # single-token decoding step sees a mask one position longer.
        # NOTE(review): this runs on every call, including training and text-only
        # calls where self.length may still be -1; the caller's attention_mask
        # is discarded. Confirm this is intended.
        attention_mask=self.mask_tensor[:,:self.length]
        self.length+=1
        return_state=super(ACLlamaModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        if self.training and audios is not None:
            # Extra outputs read by ACLlamaForCausalLM.forward for the CTC loss and label realignment.
            return_state["audio_features"] = predict_logits
            return_state["label_shift"] = batch_label_shift
            return_state["label_extend"] = label_extend
            return_state["speech_pos"] = batch_speech_pos
        #return_state = {"audio_features":predict_logits}
        return return_state
class ACLlamaForCausalLM(LlamaForCausalLM):
    """Causal LM head on ACLlamaModel, optionally followed by a text-to-unit (T2U) decoder.

    ``do_task`` (or ``config.unit_output`` when ``do_task`` is None) selects the mode:

    * ``"pretrain"``: the LLM is not run; only the T2U loss is returned.
    * ``"finetune"``: text LM loss (+ auxiliary CTC loss) combined with the T2U loss.
    * ``"finetune_kd"``: the LLM is run, but only the T2U loss is returned.
    * ``"skip"`` / None: text model only.
    """

    config_class = ACLlamaConfig

    def __init__(self, config):
        # super(LlamaForCausalLM, ...) bypasses LlamaForCausalLM.__init__,
        # presumably so that a plain LlamaModel is not built only to be replaced.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = ACLlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # t2u by kkq: optional text-to-unit decoder; it receives the LM head weight.
        if hasattr(config, "unit_output"):
            self.unit_output = config.unit_output
            self.unit_translator = T2ULlamaForCausalLM(config, self.lm_head.weight)
        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the underlying ACLlamaModel."""
        return self.model

    def get_unit_translator(self):
        """Return the T2U decoder (only present when the config has ``unit_output``)."""
        return self.unit_translator

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        t2u_input_ids: Optional[torch.LongTensor] = None,
        t2u_labels: Optional[torch.LongTensor] = None,
        t2u_attention_mask: Optional[torch.Tensor] = None,
        unit_targets: Optional[torch.Tensor] = None,
        sub_lengths: Optional[torch.Tensor] = None,
        asr_targets: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        audios: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        do_task: str = None,
        assistant_after_audio_shifts: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute text logits and, depending on ``do_task``, the training loss.

        ``position_ids`` and ``cache_position`` are accepted but not forwarded to
        ``self.model``. Positions are therefore derived inside the base model,
        which keeps them consistent with the speech-expanded sequence.

        Returns:
            For ``"pretrain"`` / ``"finetune_kd"``: the T2U decoder output with
            ``loss`` set to the (VAE-weighted) T2U loss. Otherwise a
            ``CausalLMOutputWithPast``, or a tuple when ``return_dict`` is False.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # t2u by kkq -- pretrain (t2u only), finetune (s2t & e2u).
        # Falls back to config.unit_output when the caller gives no do_task.
        do_task = do_task if do_task is not None else getattr(self, 'unit_output', None)
        outputs = None
        hidden_states = None
        new_shift_labels = None
        if do_task != "pretrain":
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                audios=audios
            )
            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None and do_task != "pretrain" and do_task != "finetune_kd":
            if asr_targets is not None:
                # Auxiliary CTC loss on the per-frame audio logits; the audio patch
                # token id is used as the CTC blank.
                asr_logits = outputs["audio_features"]
                asr_targets = asr_targets.view(asr_targets.shape[0] * asr_targets.shape[1], asr_targets.shape[2])
                mask_asr_targets = (asr_targets != IGNORE_TOKEN_ID)
                target_lengths = mask_asr_targets.sum(1)
                # Every audio uses its full frame count as the CTC input length.
                input_lengths = torch.full(size=(asr_logits.shape[0],), fill_value=asr_logits.shape[1], dtype=torch.long)
                # ctc_loss expects (time, batch, classes) log-probabilities.
                log_probs = F.log_softmax(asr_logits, dim=-1).transpose(0, 1)
                # cuDNN is disabled for the CTC call.
                with torch.backends.cudnn.flags(enabled=False):
                    loss_asr = F.ctc_loss(
                        log_probs,
                        asr_targets,
                        input_lengths,
                        target_lengths,
                        blank=self.model.audio_tower[0].config.audio_patch_token,
                        reduction='mean',
                        zero_infinity=True,
                    )
            else:
                loss_asr=0
            shift_labels = labels
            if "label_shift" in outputs.keys() and len(outputs["label_shift"]) >0:
                if outputs["label_extend"] != -1:
                    # Realign labels with the speech-expanded sequence: the segment
                    # starting at each speech position moves right by the cumulative
                    # expansion. Labels before the first speech position stay IGNORE_TOKEN_ID.
                    new_shift_labels = torch.full(size=(shift_labels.shape[0], outputs["label_extend"]+shift_labels.shape[1]), fill_value=IGNORE_TOKEN_ID, dtype=torch.long).to(shift_labels.device)
                    for batch in range(len(outputs["label_shift"])):
                        it_label_shift = outputs["label_shift"][batch]
                        it_speech_pos = outputs["speech_pos"][batch]
                        prefix = 0
                        for i in range(len(it_label_shift)):
                            if i == len(it_label_shift) - 1:
                                length = shift_labels.shape[1] - it_speech_pos[i]
                            else:
                                length = it_speech_pos[i + 1] - it_speech_pos[i]
                            prefix += it_label_shift[i]
                            new_shift_labels[batch][it_speech_pos[i] + prefix: it_speech_pos[i] + length + prefix]= shift_labels[batch][it_speech_pos[i]:it_speech_pos[i]+length]
                    shift_labels = new_shift_labels
                else:
                    raise NotImplementedError
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = shift_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            loss = loss + 0.3 * loss_asr
        t2u_output = None
        if do_task is not None and do_task != "skip":
            if do_task == "finetune_kd":
                # Index of the first text token after the (expanded) first speech
                # segment, offset by the caller-provided assistant shift.
                text_start_index = []
                for batch in range(len(outputs["label_shift"])):
                    text_start_index.append(outputs["speech_pos"][batch][0] + outputs["label_shift"][batch][0]+assistant_after_audio_shifts[batch])
                t2u_embeds_output = self.unit_translator.insert_text_embedding(
                    input_ids=t2u_input_ids,
                    attention_mask=t2u_attention_mask,
                    inputs_embeds=None,
                    labels=t2u_labels,
                    text_labels=labels,
                    shift_text_labels=new_shift_labels,
                    shift_text_hidden_states=hidden_states,
                    unit_targets=unit_targets,
                    sub_lengths=sub_lengths,
                    text_start_index=text_start_index,
                    do_task=do_task,
                )
                vae_loss, t2u_inputs_embeds, unit_targets, t2u_attention_mask = t2u_embeds_output
                t2u_output = self.unit_translator(
                    input_ids=None,
                    attention_mask=t2u_attention_mask,
                    past_key_values=past_key_values,
                    inputs_embeds=t2u_inputs_embeds,
                    use_cache=use_cache,
                    labels=unit_targets,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )
            else:
                t2u_embeds_output = self.unit_translator.insert_text_embedding(
                    input_ids=t2u_input_ids,
                    attention_mask=t2u_attention_mask,
                    inputs_embeds=None,
                    labels=t2u_labels,
                    text_labels=labels,
                    shift_text_labels=new_shift_labels,
                    shift_text_hidden_states=hidden_states,
                    do_task=do_task,
                )
                vae_loss, t2u_inputs_embeds = t2u_embeds_output
                t2u_output = self.unit_translator(
                    input_ids=None,
                    attention_mask=t2u_attention_mask,
                    past_key_values=past_key_values,
                    inputs_embeds=t2u_inputs_embeds,
                    use_cache=use_cache,
                    labels=t2u_labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )
            t2u_loss = t2u_output[0]
            if vae_loss is not None:
                # Cap the VAE term at 20% of the T2U loss: when larger, it is
                # rescaled so its contribution equals that target.
                target_scale = t2u_loss.item() * 0.2
                vae_loss_weight = target_scale / vae_loss.item() if vae_loss > target_scale else 1.0
                t2u_loss = t2u_loss + vae_loss_weight * vae_loss
            if loss is not None:  # S2T + T2U loss
                assert do_task in ["finetune"]
                # Once the text loss is below 1.0, down-weight it and up-weight T2U.
                if loss.item() < 1.0:
                    loss = 0.2 * loss + t2u_loss * 2.0
                else:
                    loss = loss + t2u_loss
            else:
                assert do_task in ["pretrain", "finetune_kd"]
                t2u_output["loss"] = t2u_loss
                return t2u_output
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """Build the per-step inputs for ``generate``.

        Slices ``input_ids`` to the unprocessed tokens when a cache is present
        and derives ``position_ids`` from ``attention_mask``. It also forwards
        ``audios``, ``do_task`` and ``return_dict_in_generate`` (as
        ``return_dict``) from ``kwargs`` when they are present.
        """
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        model_inputs.update({"audios": kwargs["audios"]} if "audios" in kwargs.keys() else {})
        model_inputs.update({"do_task": kwargs["do_task"]} if "do_task" in kwargs.keys() else {})
        model_inputs.update({"return_dict": kwargs["return_dict_in_generate"]} if "return_dict_in_generate" in kwargs.keys() else {})
        return model_inputs
# Make the "ACLlama" model type resolvable through the transformers Auto* factories.
AutoConfig.register(model_type="ACLlama", config=ACLlamaConfig)
AutoModelForCausalLM.register(config_class=ACLlamaConfig, model_class=ACLlamaForCausalLM)