import sys
from ACLlama_el_s2s import ACLlamaForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, WhisperProcessor
from peft import PeftModel, PeftConfig
import json
from tqdm import tqdm
import torch
import re
import os
torch.backends.cudnn.benchmark = False
import librosa
from text_to_speech import *
# Imported explicitly (after the star import so it cannot be shadowed) because
# _generate_audio_segment uses ``np``; before, it only resolved if text_to_speech
# happened to re-export numpy under that name.
import numpy as np
import torch.nn.functional as F
from collections import deque
from concurrent.futures import ThreadPoolExecutor, as_completed
from transformers import logging as hf_logging
hf_logging.set_verbosity_error()
from huggingface_hub import hf_hub_download
from typing import Dict, Optional, List
import tempfile
import select
from copy import deepcopy
from typing import Generator, Tuple
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Placeholder token that marks where audio features go in the prompt. It is used
# both to look up the audio patch token id (load_model) and to build the prompt
# placeholder (EchoxAssistant.__init__).
# NOTE(review): an empty string is unlikely to be a real vocabulary entry, so
# ``tokenizer.get_vocab()[DEFAULT_AUDIO_PATCH_TOKEN]`` will probably raise KeyError;
# this looks like a special-token literal (e.g. an ``<audio_patch>``-style tag)
# that got lost — confirm against ACLlama_el_s2s.
DEFAULT_AUDIO_PATCH_TOKEN = ""

# Pulls the payload out of ``<|unit_...|>`` tokens in the unit translator output.
_UNIT_TOKEN_PATTERN = re.compile(r"<\|unit_(.*?)\|>")
_DIGITS_PATTERN = re.compile(r'\d+')


def load_model(args, device):
    """Load the EchoX LLM, its tokenizer, the Whisper processor and the unit translator.

    Args:
        args: settings object; reads ``base_model_path``, ``peft_model_id``,
            ``audio_token_len`` and ``audio_tower``.
        device: target device for the model (e.g. ``"cuda"``).

    Returns:
        Tuple ``(model, audio_processor, tokenizer, unit_translator)``.
    """
    quantization_config = None
    # Optional Hugging Face access token (read from the environment).
    hf_token = os.getenv("HF_TOKEN")

    # Load the base model in fp16 and move it to the target device.
    model = ACLlamaForCausalLM.from_pretrained(
        args.base_model_path,
        device_map=None,
        torch_dtype=torch.float16,
        quantization_config=quantization_config,
        token=hf_token,
    ).eval().to(device)
    # nn.Module.to() moves parameters in place, so the result need not be rebound.
    for module in model.model.audio_tower:
        module.to(device)

    # Optionally merge a LoRA adapter into the base weights.
    if args.peft_model_id:
        lora_config = PeftConfig.from_pretrained(args.peft_model_id)
        torch.cuda.empty_cache()
        model = PeftModel.from_pretrained(model, args.peft_model_id, config=lora_config).to(
            dtype=torch.float16, device=device
        )
        model = model.merge_and_unload()
        model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.base_model_path, token=hf_token)

    # Tell the audio tower which token id marks audio positions and how many
    # placeholder tokens each audio clip occupies.
    audio_config = model.get_model().audio_tower[0].config
    audio_config.audio_patch_token = tokenizer.get_vocab()[DEFAULT_AUDIO_PATCH_TOKEN]
    audio_config.llm_pad_token_id = tokenizer.pad_token_id
    audio_config.audio_patch_size = args.audio_token_len

    # Whisper feature extractor for the input speech.
    audio_processor = WhisperProcessor.from_pretrained(args.audio_tower, torch_dtype=torch.float16)

    # Text-to-unit (t2u) module used to turn LLM hidden states into speech units.
    unit_translator = model.get_unit_translator().eval()
    return model, audio_processor, tokenizer, unit_translator


def load_speech_model(device, vocoder_path="./vocoder/g_00500000", vocoder_cfg_path="./vocoder/config.json"):
    """Load the unit vocoder and its config from local files.

    Args:
        device: device to load the vocoder onto.
        vocoder_path: vocoder checkpoint path (defaults to the bundled ./vocoder checkpoint).
        vocoder_cfg_path: vocoder JSON config path.

    Returns:
        Tuple ``(vocoder, voc_cfg)``.
    """
    # get_vocoder_config / load_units_vocoder come from the star-imported
    # text_to_speech module (presumably — they are not defined in this file).
    voc_cfg = get_vocoder_config(vocoder_path, vocoder_cfg_path)
    vocoder = load_units_vocoder(voc_cfg, device)
    return vocoder, voc_cfg


# Alternative loader that fetches the vocoder from the Hugging Face Hub:
# def load_speech_model(device):
#     hf_token = os.getenv("HF_TOKEN")
#     vocoder_repo_id = "FreedomIntelligence/EchoX-Vocoder"
#     cache_path = './hf_cache'
#     vocoder_path = hf_hub_download(repo_id=vocoder_repo_id, filename="g_00500000", token=hf_token, cache_dir=cache_path)
#     vocoder_cfg_path = hf_hub_download(repo_id=vocoder_repo_id, filename="config.json", token=hf_token, cache_dir=cache_path)
#     voc_cfg = get_vocoder_config(vocoder_path, vocoder_cfg_path)
#     vocoder = load_units_vocoder(voc_cfg, device)
#     return vocoder, voc_cfg


class EchoxAssistant:
    """Streaming speech assistant built on EchoX.

    Generates the text response token by token. In parallel, it turns segments
    of the LLM's last-layer hidden states into speech units and waveforms on a
    background thread pool.
    """

    def __init__(self):
        class BasicSetting:
            """Hard-coded runtime settings read by ``load_model`` and ``_inference``."""

            def __init__(self):
                self.device = "cuda:0"
                self.sampling_rate = 16000  # rate for librosa loading and Whisper features
                self.audio_token_len = 1  # 1500 = 300 token x 5 compress
                self.stop = ""
                self.base_model_path = "FreedomIntelligence/EchoX-8B"
                self.peft_model_id = None
                self.audio_tower = "openai/whisper-large-v3"

        self.args = BasicSetting()
        # NOTE(review): everything is placed on "cuda"; self.args.device is not
        # read anywhere in this file.
        self.device = "cuda"
        self.vocoder, self.voc_cfg = load_speech_model(self.device)
        self.model, self.audio_processor, self.tokenizer, self.unit_translator = load_model(self.args, self.device)
        # Workers that synthesize audio for finished text segments while text
        # generation continues.
        self.audio_executor = ThreadPoolExecutor(max_workers=2)

        # Prompt pieces for the Llama-3-style chat template.
        audio_placeholder = "\n" + DEFAULT_AUDIO_PATCH_TOKEN * self.args.audio_token_len
        self.audio_placeholder_ids = self.tokenizer(audio_placeholder).input_ids
        vocab = self.tokenizer.get_vocab()
        self.begin_of_text_id = vocab["<|begin_of_text|>"]
        self.start_header_id = vocab["<|start_header_id|>"]
        self.end_header_id = vocab["<|end_header_id|>"]
        self.eot_id = vocab["<|eot_id|>"]
        self.nl_tokens = self.tokenizer('\n').input_ids
        self._system = self.tokenizer('system').input_ids
        self._user = self.tokenizer('user').input_ids
        self._assistant = self.tokenizer('assistant').input_ids
        self._speaker = self.tokenizer('speaker').input_ids
        self.max_len = 1024  # prompt truncation length
        self.unit_max_len = 2048  # max units generated per audio segment
        self.system_message = "You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language."

    def _generate_audio_segment(self, segment_hidden_states):
        """Turn one segment of LLM hidden states into an audio chunk.

        Runs on ``self.audio_executor`` threads. Errors are printed and swallowed,
        so one bad segment does not abort the stream.

        Returns:
            ``(16000, int16 waveform)`` or ``None`` if no units were produced or
            an error occurred. 16000 is presumably the vocoder's output rate —
            confirm against the vocoder config.
        """
        try:
            audio_units = self._generate_audio_units_from_hidden_states(segment_hidden_states)
            if not audio_units:
                return None
            unit_ids = [int(unit) for unit in audio_units.split(" ")]
            audio_float32 = self.generate_with_speech_model([unit_ids])
            # Clip before scaling: samples outside [-1, 1] would otherwise wrap
            # around when cast to int16 and produce loud clicks.
            audio_int16 = (np.clip(audio_float32, -1.0, 1.0) * 32767).astype(np.int16)
            print(f"Generated audio segment in background: {len(audio_units.split())} units")
            return (16000, audio_int16)
        except Exception as e:
            print(f"Background audio generation error: {e}")
            return None

    def gen_model_inputs(
        self,
        sources,
        tokenizer,
        max_len,
        system_message,
        audio_placeholder_ids,
        begin_of_text_id,
        start_header_id,
        end_header_id,
        eot_id,
        nl_tokens,
        _system,
        _user,
        _assistant,
    ) -> dict:
        """Build chat-template token ids for each source conversation.

        Each source is a dict with a ``"conversations"`` list of turns. Every turn
        has ``"from"`` (``'user'`` or ``'assistant'``) and ``"value"``; user turns
        may also carry ``"audio"`` (a path). Audio user turns get
        ``audio_placeholder_ids`` inserted before their text. Each prompt ends
        with an open assistant header.

        NOTE(review): with this tokenizer each ``tokenizer(...)`` call appears to
        prepend ``<|begin_of_text|>`` (see the "." example in ``_inference``), so
        the pieces each carry a BOS id. Left as-is; it presumably matches
        training — confirm.

        Returns:
            dict with
            - ``input_ids``: int tensor, each prompt truncated to ``max_len``;
            - ``audio_paths``: one list of paths per source;
            - ``attention_mask``: ``input_ids != pad_token_id``.

        Raises:
            NotImplementedError: for a turn whose role is not user/assistant.
            AssertionError: if placeholder count and audio path count disagree.
        """
        input_ids, audio_paths = [], []
        system = (
            [begin_of_text_id, start_header_id] + _system + [end_header_id]
            + nl_tokens + tokenizer(system_message).input_ids + [eot_id]
        )
        user_header = [start_header_id] + _user + [end_header_id]
        assistant_header = [start_header_id] + _assistant + [end_header_id] + nl_tokens

        for source in sources:
            input_id = list(system)
            # Per source: a list shared across sources would accumulate paths
            # from earlier samples and break the count check below.
            audio_path = []
            for item in source["conversations"]:
                role = item["from"]
                value = item["value"]
                if role == 'user':
                    if "audio" in item:
                        input_id += user_header + audio_placeholder_ids + tokenizer(value).input_ids + [eot_id]
                        if item["audio"]:
                            audio_path.append(item["audio"])
                    else:
                        input_id += user_header + tokenizer(value).input_ids + [eot_id]
                elif role == 'assistant':
                    input_id += assistant_header + tokenizer(value).input_ids + [eot_id]
                else:
                    raise NotImplementedError
            # Leave the assistant turn open for generation.
            input_id += assistant_header

            # The last placeholder id marks audio positions; there should be
            # audio_token_len of them per audio clip.
            audio_num = int(input_id.count(audio_placeholder_ids[-1]) / self.args.audio_token_len)
            assert len(audio_path) == audio_num
            if len(input_id) >= max_len:
                print(f"[WARNING] Your Input Length More Than {max_len}")
            input_ids.append(input_id[:max_len])
            audio_paths.append(audio_path)

        # NOTE: no padding is applied, so sources of different lengths cannot be
        # batched here.
        input_ids = torch.tensor(input_ids, dtype=torch.int)
        return dict(
            input_ids=input_ids,
            audio_paths=audio_paths,
            attention_mask=input_ids.ne(tokenizer.pad_token_id),
        )

    def get_unit_result(self, ret):
        """Run the unit translator and return its speech units as a string.

        Args:
            ret: keyword arguments for ``self.unit_translator.generate``. Its
                ``"input_ids"`` entry is overwritten with ``None`` in place, so
                generation relies on the other entries (e.g. ``inputs_embeds``).

        Returns:
            Space-separated unit ids below 1000; may be ``""``.
        """
        self.unit_translator.generation_config.pad_token_id = self.tokenizer.eos_token_id
        ret["input_ids"] = None
        output_ids = self.unit_translator.generate(
            **ret,
            max_new_tokens=self.unit_max_len,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        unit_output = self.tokenizer.batch_decode(output_ids)[0]
        if "▁" in unit_output:
            unit_output = ''.join(_UNIT_TOKEN_PATTERN.findall(unit_output))
        # TODO: handle unknown units properly. For now ids >= 1000 are dropped
        # (presumably outside the vocoder's unit vocabulary — confirm).
        return ' '.join(unit for unit in _DIGITS_PATTERN.findall(unit_output) if int(unit) < 1000)

    def _prepare_inputs(self, prompt):
        """Tokenize ``prompt`` and encode any referenced audio; return model kwargs on ``self.device``."""
        audio_paths = [
            conv["audio"]
            for item in prompt
            for conv in item["conversations"]
            if "audio" in conv
        ]
        model_inputs = self.gen_model_inputs(
            prompt,
            self.tokenizer,
            self.max_len,
            self.system_message,
            self.audio_placeholder_ids,
            self.begin_of_text_id,
            self.start_header_id,
            self.end_header_id,
            self.eot_id,
            self.nl_tokens,
            self._system,
            self._user,
            self._assistant,
        )
        ret = dict(
            input_ids=model_inputs["input_ids"].to(self.device),
            attention_mask=model_inputs["attention_mask"].to(self.device),
        )
        if audio_paths and audio_paths[0] is not None:
            audio_list = []
            for audio_path in audio_paths:
                audio, _ = librosa.load(audio_path, sr=self.args.sampling_rate)
                audio_feat = self.audio_processor(
                    audio, sampling_rate=self.args.sampling_rate, return_tensors="pt"
                ).input_features
                audio_list.append(audio_feat)
            ret["audios"] = torch.stack(audio_list, dim=0).to(dtype=torch.float16).to(self.device)
        return ret

    @staticmethod
    def _sample_next_token(next_token_logits, do_sample, temperature, top_p):
        """Choose the next token from ``next_token_logits`` (``[batch_size, vocab_size]``).

        With ``do_sample``: temperature scaling, nucleus (top-p) filtering, then
        multinomial sampling. Otherwise: greedy argmax. Returns a ``[batch, 1]``
        tensor.
        """
        if not do_sample:
            return torch.argmax(next_token_logits, dim=-1, keepdim=True)
        next_token_logits = next_token_logits / temperature
        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens past the top_p mass. The mask is shifted right so the token
        # that crosses the threshold is kept, and the top token is always kept.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        next_token_logits[indices_to_remove] = float('-inf')
        probs = F.softmax(next_token_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def _submit_audio_segment(self, hidden_state_list):
        """Stack per-token hidden states into ``[1, n, hidden]`` and queue them for audio synthesis."""
        segment_hidden_states = torch.stack(hidden_state_list, dim=0).unsqueeze(0)
        return self.audio_executor.submit(self._generate_audio_segment, segment_hidden_states)

    def _inference(
        self,
        prompt,
        **kwargs,
    ):
        """Stream a response to ``prompt`` as text updates and audio chunks.

        Args:
            prompt: list of samples in the format accepted by ``gen_model_inputs``.
                Any turn with ``"audio"`` is loaded with librosa and encoded with
                the Whisper processor.
            **kwargs: optional ``max_new_tokens`` (512), ``temperature`` (0.2),
                ``top_p`` (0.9), ``do_sample`` (True). A non-positive temperature
                falls back to greedy decoding.

        Yields:
            - ``(current_text, None)`` after each generated token, where
              ``current_text`` is the full response so far;
            - ``(None, audio)`` whenever a background segment finishes, with
              ``audio`` as returned by ``_generate_audio_segment``. Audio chunks
              are yielded in submission order.

        Generation stops at the tokenizer's EOS or ``<|eot_id|>``. Segmentation:
        the text is cut where the cosine similarity between the last-layer hidden
        state and the lm_head row for "." has a local peak. A peak is only
        confirmed ``segmentation_latency`` tokens later, so the window after it
        exists. Assumes batch size 1 (``.item()`` / ``hidden_states[0, -1]``).
        """
        ret = self._prepare_inputs(prompt)
        self.model.generation_config.pad_token_id = self.tokenizer.eos_token_id

        # shape: (1, 2), value: [[128000, 13]] — the last id is the "." token.
        dot_input_ids = self.tokenizer(".", return_tensors="pt").input_ids.to(self.device)
        period_token_id = dot_input_ids[0, -1]
        period_lm_head_embedding = self.model.lm_head.weight[period_token_id]

        max_new_tokens = kwargs.get('max_new_tokens', 512)
        temperature = kwargs.get('temperature', 0.2)
        top_p = kwargs.get('top_p', 0.9)
        do_sample = kwargs.get('do_sample', True)
        if do_sample and temperature <= 0:
            # Avoid dividing logits by zero; zero temperature means greedy.
            do_sample = False

        stop_token_ids = {self.tokenizer.eos_token_id, self.eot_id}
        accumulated_hidden_states = []
        accumulated_tokens = []
        similarity_scores = []
        segment_start_idx = 0
        current_input_ids = ret["input_ids"]
        current_attention_mask = ret["attention_mask"]
        past_key_values = None
        pending_audio = deque()
        segmentation_latency = 5

        with torch.no_grad():
            for step in range(max_new_tokens):
                # Yield audio that is already finished, strictly in submission
                # order, so chunks play back in sequence.
                while pending_audio and pending_audio[0].done():
                    audio_data = pending_audio.popleft().result()
                    if audio_data:
                        yield None, audio_data

                model_kwargs = {
                    "input_ids": current_input_ids,
                    "attention_mask": current_attention_mask,
                    "past_key_values": past_key_values,
                    "use_cache": True,
                    "output_hidden_states": True,
                    "do_task": "skip",
                }
                # Audio features are only passed on the first (prefill) step;
                # later steps rely on past_key_values.
                if step == 0 and "audios" in ret:
                    model_kwargs["audios"] = ret["audios"]
                outputs = self.model(**model_kwargs)
                hidden_states = outputs.hidden_states[-1]
                past_key_values = outputs.past_key_values

                next_token = self._sample_next_token(outputs.logits[:, -1, :], do_sample, temperature, top_p)
                token_id = next_token.item()
                if token_id in stop_token_ids:
                    break

                accumulated_tokens.append(token_id)
                last_hidden_state = hidden_states[0, -1]  # [hidden_dim]
                accumulated_hidden_states.append(last_hidden_state)
                similarity_scores.append(
                    F.cosine_similarity(last_hidden_state, period_lm_head_embedding, dim=0).item()
                )
                # Re-decode the whole token list instead of concatenating
                # per-token decodes. Otherwise a multi-byte character split
                # across tokens can come out as replacement characters.
                current_text = self.tokenizer.decode(accumulated_tokens, skip_special_tokens=True)
                yield current_text, None

                check_idx = len(similarity_scores) - 1 - segmentation_latency
                if check_idx >= 0:
                    similarity_at_check = similarity_scores[check_idx]
                    is_peak = self._is_local_maximum(similarity_scores, check_idx, window=segmentation_latency)
                    tokens_since_cut = check_idx - segment_start_idx
                    # Cut at a peak once the segment is long (>= 50 tokens), or
                    # earlier (>= 20) if the peak is strongly period-like.
                    should_segment = is_peak and (
                        tokens_since_cut >= 50
                        or (similarity_at_check > 0.1 and tokens_since_cut >= 20)
                    )
                    if should_segment:
                        segment_end_idx = check_idx + 1
                        print(f"Segmenting at step {segment_end_idx-1}, similarity={similarity_at_check:.4f}. Submitting to background audio generation.")
                        pending_audio.append(
                            self._submit_audio_segment(accumulated_hidden_states[segment_start_idx:segment_end_idx])
                        )
                        segment_start_idx = segment_end_idx

                current_input_ids = next_token
                # NOTE(review): from here on only a length-1 mask is passed with
                # past_key_values. Stock HF Llama expects a mask covering past +
                # current positions. Left unchanged because the audio splice may
                # change the cached length — confirm against ACLlama_el_s2s.
                current_attention_mask = torch.ones_like(next_token)

        # Whatever text remains after the last cut becomes the final segment.
        if segment_start_idx < len(accumulated_hidden_states):
            print(f"Processing final segment from {segment_start_idx} to {len(accumulated_hidden_states)}")
            pending_audio.append(self._submit_audio_segment(accumulated_hidden_states[segment_start_idx:]))

        # Block on the remaining segments, in order.
        for future in pending_audio:
            audio_data = future.result()
            if audio_data:
                yield None, audio_data

    def _is_local_maximum(self, scores, idx, window=5):
        """Return True if ``scores[idx]`` is the maximum of the ``±window`` neighbourhood.

        The neighbourhood is clamped to the list bounds; ties count as a maximum.
        """
        lo = max(0, idx - window)
        hi = min(len(scores), idx + window + 1)
        return scores[idx] == max(scores[lo:hi])

    def _generate_audio_units_from_hidden_states(self, hidden_states):
        """Translate a ``[1, n, hidden]`` hidden-state segment into a unit-id string.

        Returns ``None`` (after printing the error) if anything fails.
        """
        try:
            # Grad mode is thread-local, and this runs on an executor thread, so
            # the caller's torch.no_grad() does not cover it. Without this, an
            # autograd graph would be built through the unit translator.
            with torch.no_grad():
                _, adapted_inputs_embeds = self.unit_translator.insert_text_embedding(
                    inputs_embeds=hidden_states,
                    do_task="skip",
                )
                attention_mask = torch.ones(adapted_inputs_embeds.shape[:2], device=self.device)
                ret = dict(
                    input_ids=None,
                    inputs_embeds=adapted_inputs_embeds,
                    attention_mask=attention_mask,
                )
                return self.get_unit_result(ret)
        except Exception as e:
            print(f"Error generating audio units: {e}")
            return None

    def generate_with_speech_model(self, units):
        """Synthesize a waveform from ``units`` (a list of unit-id lists) with the vocoder."""
        wav = gen_wav(self.vocoder, self.voc_cfg, units, self.device)
        return wav