# EchoX / Echox_copy_stream.py
import sys
from ACLlama_el_s2s import ACLlamaForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, WhisperProcessor
from peft import PeftModel, PeftConfig
import json
from tqdm import tqdm
import torch
import re
import os
torch.backends.cudnn.benchmark = False
import librosa
from text_to_speech import *  # provides get_vocoder_config, load_units_vocoder, gen_wav
import numpy as np  # used for int16 conversion of generated waveforms
import torch.nn.functional as F
from concurrent.futures import ThreadPoolExecutor, as_completed
from transformers import logging as hf_logging
hf_logging.set_verbosity_error()
from huggingface_hub import hf_hub_download
from typing import Dict, Optional, List
import tempfile
import select
from copy import deepcopy
from typing import Generator, Tuple
os.environ["TOKENIZERS_PARALLELISM"] = "true"
def load_model(args, device):
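    """Load the EchoX causal LM (optionally merging a LoRA adapter), its tokenizer,
    the Whisper feature processor, and the text-to-unit translator head."""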
quantization_config = None
hf_token = os.getenv("HF_TOKEN")
    # load base model
model = ACLlamaForCausalLM.from_pretrained(
args.base_model_path,
device_map=None,
torch_dtype=torch.float16,
quantization_config=quantization_config,
token=hf_token,
).eval().to(device)
    # nn.Module.to() moves parameters in place, so reassigning the loop variable is unnecessary
    for module in model.model.audio_tower:
        module.to(device)
if args.peft_model_id:
lora_config = PeftConfig.from_pretrained(args.peft_model_id)
torch.cuda.empty_cache()
model = PeftModel.from_pretrained(model, args.peft_model_id, config=lora_config).to(
dtype=torch.float16, device=device
)
model = model.merge_and_unload()
model.eval()
# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.base_model_path, token=hf_token)
audio_config = model.get_model().audio_tower[0].config
audio_config.audio_patch_token = tokenizer.get_vocab()["<audio_patch>"]
audio_config.llm_pad_token_id = tokenizer.pad_token_id
audio_config.audio_patch_size = args.audio_token_len
# whisper processor
audio_processor = WhisperProcessor.from_pretrained(args.audio_tower, torch_dtype=torch.float16)
    # text-to-unit (t2u) translator
unit_translator = model.get_unit_translator().eval()
return model, audio_processor, tokenizer, unit_translator
def load_speech_model(device):
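    """Load the unit-to-waveform vocoder and its config from the local ./vocoder directory."""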
vocoder = "./vocoder/g_00500000"
vocoder_cfg = "./vocoder/config.json"
voc_cfg = get_vocoder_config(vocoder, vocoder_cfg)
vocoder = load_units_vocoder(voc_cfg, device)
return vocoder, voc_cfg
# def load_speech_model(device):
# hf_token = os.getenv("HF_TOKEN")
# vocoder_repo_id = "FreedomIntelligence/EchoX-Vocoder"
# cache_path = './hf_cache'
# vocoder_path = hf_hub_download(repo_id=vocoder_repo_id, filename="g_00500000", token=hf_token, cache_dir=cache_path)
# vocoder_cfg_path = hf_hub_download(repo_id=vocoder_repo_id, filename="config.json", token=hf_token, cache_dir=cache_path)
# voc_cfg = get_vocoder_config(vocoder_path, vocoder_cfg_path)
# vocoder = load_units_vocoder(voc_cfg, device)
# return vocoder, voc_cfg
class EchoxAssistant():
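    """Streaming speech-to-speech assistant: decodes text token by token, cuts the
    hidden-state stream into segments at sentence-like boundaries, and converts each
    segment into speech units and audio in background threads."""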
def __init__(self):
class BasicSetting:
def __init__(self):
self.device = "cuda:0"
self.sampling_rate = 16000
self.audio_token_len = 1 # 1500 = 300 token x 5 compress
self.stop = "</s>"
self.base_model_path = "FreedomIntelligence/EchoX-8B"
self.peft_model_id = None
self.audio_tower = "openai/whisper-large-v3"
self.args = BasicSetting()
        self.device = self.args.device
        self.vocoder, self.voc_cfg = load_speech_model(self.device)
self.model, self.audio_processor, self.tokenizer, self.unit_translator = load_model(self.args, self.device)
self.audio_executor = ThreadPoolExecutor(max_workers=2)
# self.specAug = SpecAugmentTransform()
# special_token
DEFAULT_AUDIO_PATCH_TOKEN = "<audio_patch>"
audio_placeholder = DEFAULT_AUDIO_PATCH_TOKEN * self.args.audio_token_len
audio_placeholder = "\n"+audio_placeholder
self.audio_placeholder_ids = self.tokenizer(audio_placeholder).input_ids
self.begin_of_text_id = self.tokenizer.get_vocab()["<|begin_of_text|>"]
self.start_header_id = self.tokenizer.get_vocab()["<|start_header_id|>"]
self.end_header_id = self.tokenizer.get_vocab()["<|end_header_id|>"]
self.eot_id = self.tokenizer.get_vocab()["<|eot_id|>"]
self.nl_tokens = self.tokenizer('\n').input_ids
self._system = self.tokenizer('system').input_ids
self._user = self.tokenizer('user').input_ids
self._assistant = self.tokenizer('assistant').input_ids
self._speaker = self.tokenizer('speaker').input_ids
self.max_len = 1024
self.unit_max_len = 2048
self.system_message = "You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language."
def _generate_audio_segment(self, segment_hidden_states):
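        """Run in a worker thread: turn a segment of decoder hidden states into
        discrete speech units and then into a 16 kHz int16 waveform."""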
try:
audio_units = self._generate_audio_units_from_hidden_states(segment_hidden_states)
if audio_units:
audio_float32 = self.generate_with_speech_model([list(map(int, audio_units.split(" ")))])
audio_int16 = (audio_float32 * 32767).astype(np.int16)
print(f"Generated audio segment in background: {len(audio_units.split())} units")
return (16000, audio_int16)
return None
except Exception as e:
print(f"Background audio generation error: {e}")
return None
def gen_model_inputs(
self,
sources,
tokenizer,
max_len,
system_message,
audio_placeholder_ids, begin_of_text_id, start_header_id, end_header_id, eot_id, nl_tokens, _system, _user, _assistant,
) -> dict:
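        """Build Llama-3-style chat input_ids (system/user/assistant headers plus
        audio placeholder tokens) and collect the audio paths referenced by the turns."""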
        # Apply prompt templates
        input_ids, audio_paths = [], []
        for source in sources:
            input_id = []
            audio_path = []  # reset per source so audio paths do not leak across samples
system = [begin_of_text_id] + [start_header_id] + _system + [end_header_id] + nl_tokens + tokenizer(system_message).input_ids + [eot_id]
input_id += system
for j, item in enumerate(source["conversations"]):
role = item["from"]
value = item["value"]
_audio_path = None
if role == 'user':
if "audio" in item.keys():
_input_id = [start_header_id] + _user + [end_header_id] + audio_placeholder_ids + tokenizer(value).input_ids + [eot_id]
_audio_path = item["audio"]
else:
_input_id = [start_header_id] + _user + [end_header_id] + tokenizer(value).input_ids + [eot_id]
elif role == 'assistant':
_input_id = [start_header_id] + _assistant + [end_header_id] + nl_tokens + tokenizer(value).input_ids + [eot_id]
else:
raise NotImplementedError
input_id += _input_id
if _audio_path:
audio_path.append(_audio_path)
assistant_input_id = [start_header_id] + _assistant + [end_header_id] + nl_tokens
input_id += assistant_input_id
audio_num = int(input_id.count(audio_placeholder_ids[-1]) / self.args.audio_token_len)
assert len(audio_path) == audio_num
            if len(input_id) >= max_len:
                print(f"[WARNING] Input length {len(input_id)} exceeds max_len={max_len}; truncating.")
input_ids.append(input_id[:max_len])
audio_paths.append(audio_path)
input_ids = torch.tensor(input_ids, dtype=torch.int)
return dict(
input_ids=input_ids,
audio_paths=audio_paths,
attention_mask=input_ids.ne(tokenizer.pad_token_id),
)
def get_unit_result(self, ret):
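        """Generate discrete speech units with the unit translator and return them as a
        space-separated string, keeping only unit ids below 1000."""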
# print(ret)
self.unit_translator.generation_config.pad_token_id = self.tokenizer.eos_token_id
        # generation runs from inputs_embeds, so drop input_ids to avoid passing both
        ret["input_ids"] = None
model_outputs = self.unit_translator.generate(
**ret,
max_new_tokens=2048,
eos_token_id=self.tokenizer.eos_token_id,
)
# print(model_outputs, model_outputs.shape)
output_ids = model_outputs
unit_output = self.tokenizer.batch_decode(output_ids)[0]
if "▁" in unit_output:
unit_output = ''.join(re.findall(r"<\|unit_(.*?)\|>", unit_output))
units = re.findall(r'\d+', unit_output)
        # TODO: get rid of unknown / out-of-range units
new_units = []
for unit in units:
if int(unit) < 1000:
new_units.append(unit)
units = ' '.join(new_units)
return units
def _inference(
self,
prompt,
**kwargs,
):
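        """Stream a response. Yields (partial_text, None) as text tokens are decoded and
        (None, (sample_rate, int16_waveform)) whenever a finished audio segment is ready."""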
audio_paths = []
response = []
for item in prompt:
for conv in item["conversations"]:
if "audio" in conv:
audio_paths.append(conv["audio"])
model_inputs = self.gen_model_inputs(
prompt,
self.tokenizer,
self.max_len,
self.system_message,
self.audio_placeholder_ids, self.begin_of_text_id, self.start_header_id, self.end_header_id, self.eot_id, self.nl_tokens, self._system, self._user, self._assistant)
audio_list = []
if audio_paths and audio_paths[0] is not None:
for audio_path in audio_paths:
# print("read audio file name: ", audio_path)
audio, _ = librosa.load(audio_path, sr=self.args.sampling_rate)
audio_feat = self.audio_processor(audio, sampling_rate=self.args.sampling_rate, return_tensors="pt").input_features
audio_list.append(audio_feat)
audio_feats = torch.stack(audio_list, dim=0)
audio_feats = audio_feats.to(dtype=torch.float16).to(self.device)
if not audio_list:
ret = dict(
input_ids=model_inputs["input_ids"].to(self.device),
attention_mask=model_inputs["attention_mask"].to(self.device),
)
else:
ret = dict(
input_ids=model_inputs["input_ids"].to(self.device),
attention_mask=model_inputs["attention_mask"].to(self.device),
audios=audio_feats,
)
self.model.generation_config.pad_token_id = self.tokenizer.eos_token_id
#print(self.model.lm_head.weight.shape)
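        # The lm_head embedding of "." serves as a reference vector: high cosine
        # similarity between a decoded token's hidden state and this embedding marks a
        # likely sentence boundary for audio segmentation.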
        dot_input_ids = self.tokenizer(".", return_tensors="pt").input_ids.to(self.device)  # shape: (1, 2), values: [[128000, 13]]
period_token_id = dot_input_ids[0, -1]
period_lm_head_embedding = self.model.lm_head.weight[period_token_id]
input_ids = ret["input_ids"]
attention_mask = ret["attention_mask"]
input_token_len = input_ids.shape[1]
max_new_tokens = kwargs.get('max_new_tokens', 512)
temperature = kwargs.get('temperature', 0.2)
top_p = kwargs.get('top_p', 0.9)
do_sample = kwargs.get('do_sample', True)
current_text = ""
accumulated_hidden_states = []
accumulated_tokens = []
similarity_scores = []
segment_start_idx = 0
current_input_ids = input_ids
current_attention_mask = attention_mask
past_key_values = None
audio_futures = []
        segmentation_latency = 5  # confirm a similarity peak only after this many additional tokens
with torch.no_grad():
for step in range(max_new_tokens):
while audio_futures and audio_futures[0].done():
completed_future = audio_futures.pop(0)
audio_data = completed_future.result()
if audio_data:
yield None, audio_data
if current_input_ids is None:
break
model_kwargs = {
"input_ids": current_input_ids,
"attention_mask": current_attention_mask,
"past_key_values": past_key_values,
"use_cache": True,
"output_hidden_states": True,
"do_task": "skip"
}
if step == 0 and "audios" in ret:
model_kwargs["audios"] = ret["audios"]
outputs = self.model(**model_kwargs)
logits = outputs.logits
hidden_states = outputs.hidden_states[-1]
past_key_values = outputs.past_key_values
next_token_logits = logits[:, -1, :] # [batch_size, vocab_size]
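                # Nucleus (top-p) sampling with temperature: mask out the low-probability
                # tail whose cumulative probability exceeds top_p, then sample from the rest.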
if do_sample:
next_token_logits = next_token_logits / temperature
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
next_token_logits[indices_to_remove] = float('-inf')
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
if next_token.item() == self.tokenizer.eos_token_id:
current_input_ids = None
continue
accumulated_tokens.append(next_token.item())
last_hidden_state = hidden_states[0, -1] # [hidden_dim]
accumulated_hidden_states.append(last_hidden_state)
similarity = F.cosine_similarity(last_hidden_state, period_lm_head_embedding, dim=0).item()
similarity_scores.append(similarity)
token_text = self.tokenizer.decode([next_token.item()], skip_special_tokens=True)
current_text += token_text
yield current_text, None
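                # Segmentation runs `segmentation_latency` tokens behind decoding so a
                # similarity peak can be confirmed against its local neighborhood before
                # the segment is cut and handed to the audio worker.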
current_idx = len(similarity_scores) - 1
check_idx = current_idx - segmentation_latency
if check_idx >= 0:
similarity_at_check = similarity_scores[check_idx]
is_peak = self._is_local_maximum(similarity_scores, check_idx, window=segmentation_latency)
should_segment = (is_peak and
check_idx - segment_start_idx >= 50) or (
is_peak and
similarity_at_check > 0.1 and
check_idx - segment_start_idx >= 20
)
if should_segment:
segment_end_idx = check_idx + 1
print(f"Segmenting at step {segment_end_idx-1}, similarity={similarity_at_check:.4f}. Submitting to background audio generation.")
segment_hidden_states = torch.stack(
accumulated_hidden_states[segment_start_idx:segment_end_idx], dim=0
).unsqueeze(0)
future = self.audio_executor.submit(self._generate_audio_segment, segment_hidden_states)
audio_futures.append(future)
segment_start_idx = segment_end_idx
current_input_ids = next_token
current_attention_mask = torch.ones_like(next_token)
if segment_start_idx < len(accumulated_hidden_states):
print(f"Processing final segment from {segment_start_idx} to {len(accumulated_hidden_states)}")
segment_hidden_states = torch.stack(
accumulated_hidden_states[segment_start_idx:], dim=0
).unsqueeze(0)
future = self.audio_executor.submit(self._generate_audio_segment, segment_hidden_states)
audio_futures.append(future)
for future in audio_futures:
audio_data = future.result()
if audio_data:
yield None, audio_data
def _is_local_maximum(self, scores, idx, window=5):
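        """Return True if scores[idx] is the maximum within a +/- `window` neighborhood."""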
start = max(0, idx - window)
end = min(len(scores), idx + window + 1)
local_scores = scores[start:end]
return scores[idx] == max(local_scores)
def _generate_audio_units_from_hidden_states(self, hidden_states):
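        """Adapt decoder hidden states through the unit translator's embedding path and
        decode them into a string of discrete speech units."""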
try:
_, adapted_inputs_embeds = self.unit_translator.insert_text_embedding(
inputs_embeds=hidden_states,
do_task="skip",
)
attention_mask = torch.ones(adapted_inputs_embeds.shape[:2]).to(self.device)
ret = dict(
input_ids=None,
inputs_embeds=adapted_inputs_embeds,
attention_mask=attention_mask,
)
return self.get_unit_result(ret)
except Exception as e:
print(f"Error generating audio units: {e}")
return None
def generate_with_speech_model(self, units):
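        """Synthesize a waveform from a batch of unit-id sequences with the vocoder."""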
wav = gen_wav(self.vocoder, self.voc_cfg, units, self.device)
return wav
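
# --- Usage sketch (not part of the original script) --------------------------
# A minimal example of how this module might be driven, assuming a CUDA device,
# the checkpoints referenced above, and a local 16 kHz file "./example.wav"
# (hypothetical path). _inference yields (partial_text, audio_chunk) tuples as
# implemented above; audio chunks are (sample_rate, int16 numpy array).
if __name__ == "__main__":
    assistant = EchoxAssistant()
    prompt = [{
        "conversations": [
            {"from": "user", "value": "Please answer the question in the audio.", "audio": "./example.wav"}
        ]
    }]
    for text, audio in assistant._inference(prompt):
        if text is not None:
            print(text)
        if audio is not None:
            sr, wav = audio
            print(f"received audio chunk: {wav.shape[0]} samples at {sr} Hz")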