Spaces:
Running
Running
import os | |
import time | |
import tempfile | |
import subprocess | |
import threading | |
import json | |
import base64 | |
import io | |
import random | |
import logging | |
from queue import Queue | |
from threading import Thread | |
import gradio as gr | |
import torch | |
import librosa | |
import soundfile as sf | |
import requests | |
import numpy as np | |
from scipy import signal | |
from transformers import pipeline, AutoTokenizer, AutoModel | |
# Thiết lập logging | |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
logger = logging.getLogger(__name__) | |
# Tạo các thư mục cần thiết | |
os.makedirs("data", exist_ok=True) | |
os.makedirs("data/audio", exist_ok=True) | |
os.makedirs("data/reports", exist_ok=True) | |
os.makedirs("data/models", exist_ok=True) | |
class AsyncProcessor: | |
"""Xử lý các tác vụ nặng trong thread riêng để không làm 'đơ' giao diện.""" | |
def __init__(self): | |
self.task_queue = Queue() | |
self.result_queue = Queue() | |
self.running = True | |
self.worker_thread = Thread(target=self._worker) | |
self.worker_thread.daemon = True | |
self.worker_thread.start() | |
def _worker(self): | |
while self.running: | |
if not self.task_queue.empty(): | |
task_id, func, args, kwargs = self.task_queue.get() | |
try: | |
result = func(*args, **kwargs) | |
self.result_queue.put((task_id, result, None)) | |
except Exception as e: | |
logger.error(f"Lỗi trong xử lý tác vụ {task_id}: {str(e)}") | |
self.result_queue.put((task_id, None, str(e))) | |
self.task_queue.task_done() | |
time.sleep(0.1) | |
def add_task(self, task_id, func, *args, **kwargs): | |
self.task_queue.put((task_id, func, args, kwargs)) | |
def get_result(self): | |
if not self.result_queue.empty(): | |
return self.result_queue.get() | |
return None | |
def stop(self): | |
self.running = False | |
if self.worker_thread.is_alive(): | |
self.worker_thread.join(timeout=1) | |
class VietSpeechTrainer: | |
def __init__(self): | |
# Đọc cấu hình từ file config.json và từ biến môi trường | |
self.config = self._load_config() | |
# Khởi tạo bộ xử lý bất đồng bộ | |
self.async_processor = AsyncProcessor() | |
# Lưu trữ lịch sử phiên làm việc | |
self.session_history = [] | |
self.current_session_id = int(time.time()) | |
# Các biến trạng thái hội thoại | |
self.current_scenario = None | |
self.current_prompt_index = 0 | |
# Khởi tạo các mô hình (STT, TTS và phân tích LLM) | |
logger.info("Đang tải các mô hình...") | |
self._initialize_models() | |
def _load_config(self): | |
"""Đọc file config.json và cập nhật từ biến môi trường (Secrets khi deploy)""" | |
config = { | |
"stt_model": "nguyenvulebinh/wav2vec2-base-vietnamese-250h", | |
"use_phowhisper": False, | |
"use_phobert": False, | |
"use_vncorenlp": False, | |
"llm_provider": "none", # openai, gemini, local hoặc none | |
"openai_api_key": "", | |
"gemini_api_key": "", | |
"local_llm_endpoint": "", | |
"use_viettts": False, | |
"default_dialect": "Bắc", | |
"enable_pronunciation_eval": False, | |
"preprocess_audio": True, | |
"save_history": True, | |
"enable_english_tts": False | |
} | |
if os.path.exists("config.json"): | |
try: | |
with open("config.json", "r", encoding="utf-8") as f: | |
file_config = json.load(f) | |
config.update(file_config) | |
except Exception as e: | |
logger.error(f"Lỗi đọc config.json: {e}") | |
# Cập nhật từ biến môi trường | |
if os.environ.get("LLM_PROVIDER"): | |
config["llm_provider"] = os.environ.get("LLM_PROVIDER").lower() | |
if os.environ.get("OPENAI_API_KEY"): | |
config["openai_api_key"] = os.environ.get("OPENAI_API_KEY") | |
if os.environ.get("GEMINI_API_KEY"): | |
config["gemini_api_key"] = os.environ.get("GEMINI_API_KEY") | |
if os.environ.get("LOCAL_LLM_ENDPOINT"): | |
config["local_llm_endpoint"] = os.environ.get("LOCAL_LLM_ENDPOINT") | |
if os.environ.get("ENABLE_ENGLISH_TTS") and os.environ.get("ENABLE_ENGLISH_TTS").lower() == "true": | |
config["enable_english_tts"] = True | |
return config | |
def _initialize_models(self): | |
"""Khởi tạo mô hình STT và thiết lập CSM cho TTS tiếng Anh nếu được bật.""" | |
try: | |
# Khởi tạo STT | |
if self.config["use_phowhisper"]: | |
logger.info("Loading PhoWhisper...") | |
self.stt_model = pipeline("automatic-speech-recognition", | |
model="vinai/PhoWhisper-small", | |
device=0 if torch.cuda.is_available() else -1) | |
else: | |
logger.info(f"Loading STT model: {self.config['stt_model']}") | |
self.stt_model = pipeline("automatic-speech-recognition", | |
model=self.config["stt_model"], | |
device=0 if torch.cuda.is_available() else -1) | |
except Exception as e: | |
logger.error(f"Lỗi khởi tạo STT: {e}") | |
self.stt_model = None | |
# Các mô hình NLP (PhoBERT, VnCoreNLP) nếu cần. | |
# ... | |
# Nếu bật TTS tiếng Anh thì thiết lập CSM | |
if self.config.get("enable_english_tts", False): | |
self._setup_csm() | |
else: | |
self.csm_ready = False | |
def _setup_csm(self): | |
"""Cài đặt mô hình CSM (Conversational Speech Generation Model) cho TTS tiếng Anh.""" | |
try: | |
csm_dir = os.path.join(os.getcwd(), "csm") | |
if not os.path.exists(csm_dir): | |
logger.info("Cloning CSM repo...") | |
subprocess.run(["git", "clone", "https://github.com/SesameAILabs/csm", csm_dir], check=True) | |
logger.info("Installing CSM requirements...") | |
subprocess.run(["pip", "install", "-r", os.path.join(csm_dir, "requirements.txt")], check=True) | |
self.csm_ready = True | |
logger.info("CSM đã được thiết lập thành công!") | |
except Exception as e: | |
logger.error(f"Failed to set up CSM: {e}") | |
self.csm_ready = False | |
def text_to_speech(self, text, language="vi", dialect="Bắc"): | |
""" | |
Chuyển văn bản thành giọng nói: | |
- Nếu language == "en": sử dụng CSM để tạo TTS tiếng Anh. | |
- Nếu language == "vi": sử dụng API hoặc logic TTS tiếng Việt. | |
""" | |
if language == "en": | |
if not self.csm_ready: | |
logger.error("CSM chưa được thiết lập hoặc không được bật.") | |
return None | |
output_file = f"data/audio/csm_{int(time.time())}.wav" | |
csm_script_path = os.path.join(os.getcwd(), "csm", "run_csm.py") | |
cmd = [ | |
"python", | |
csm_script_path, | |
"--text", text, | |
"--speaker_id", "0", # Mặc định, có thể cho phép người dùng chọn | |
"--output", output_file | |
] | |
try: | |
subprocess.run(cmd, check=True) | |
return output_file | |
except subprocess.CalledProcessError as e: | |
logger.error(f"CSM generation failed: {e}") | |
return None | |
else: | |
# Ví dụ: Nếu có API TTS tiếng Việt, gọi API đó. | |
tts_api_url = self.config.get("tts_api_url", "") | |
if tts_api_url: | |
try: | |
resp = requests.post(tts_api_url, json={"text": text, "dialect": dialect.lower()}) | |
if resp.status_code == 200: | |
output_file = f"data/audio/tts_{int(time.time())}.wav" | |
with open(output_file, "wb") as f: | |
f.write(resp.content) | |
return output_file | |
else: | |
logger.error(f"Error calling TTS API: {resp.text}") | |
return None | |
except Exception as e: | |
logger.error(f"Lỗi gọi TTS API: {e}") | |
return None | |
else: | |
# Nếu không có API TTS, bạn có thể tích hợp VietTTS hoặc khác. | |
return None | |
def transcribe_audio(self, audio_path): | |
"""Chuyển đổi giọng nói thành văn bản (STT).""" | |
if not self.stt_model: | |
return "STT model not available." | |
try: | |
result = self.stt_model(audio_path) | |
if isinstance(result, dict) and "text" in result: | |
return result["text"] | |
elif isinstance(result, list): | |
return " ".join([chunk.get("text", "") for chunk in result]) | |
else: | |
return str(result) | |
except Exception as e: | |
logger.error(f"Lỗi chuyển giọng nói: {e}") | |
return f"Lỗi: {str(e)}" | |
def analyze_text(self, transcript, dialect="Bắc"): | |
""" | |
Phân tích văn bản sử dụng LLM: | |
- Nếu LLM_PROVIDER là "openai", "gemini" hay "local" thì gọi API tương ứng. | |
- Nếu LLM_PROVIDER là "none", sử dụng phân tích rule-based. | |
""" | |
llm_provider = self.config["llm_provider"] | |
if llm_provider == "openai" and self.config["openai_api_key"]: | |
return self._analyze_with_openai(transcript) | |
elif llm_provider == "gemini" and self.config["gemini_api_key"]: | |
return self._analyze_with_gemini(transcript) | |
elif llm_provider == "local" and self.config["local_llm_endpoint"]: | |
return self._analyze_with_local_llm(transcript) | |
else: | |
return self._rule_based_analysis(transcript, dialect) | |
def _analyze_with_openai(self, transcript): | |
headers = { | |
"Authorization": f"Bearer {self.config['openai_api_key']}", | |
"Content-Type": "application/json" | |
} | |
data = { | |
"model": "gpt-3.5-turbo", | |
"messages": [ | |
{"role": "system", "content": "Bạn là trợ lý dạy tiếng Việt."}, | |
{"role": "user", "content": transcript} | |
], | |
"temperature": 0.5, | |
"max_tokens": 150 | |
} | |
try: | |
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=data) | |
if response.status_code == 200: | |
result = response.json() | |
return result["choices"][0]["message"]["content"] | |
else: | |
return "Lỗi khi gọi OpenAI API." | |
except Exception as e: | |
logger.error(f"Lỗi OpenAI: {e}") | |
return "Lỗi phân tích với OpenAI." | |
def _analyze_with_gemini(self, transcript): | |
# Ví dụ minh họa: Gọi Gemini API (chi tiết phụ thuộc vào tài liệu của Gemini) | |
return "Gemini analysis..." | |
def _analyze_with_local_llm(self, transcript): | |
# Giả sử gọi một endpoint local (nếu có) cho LLM cục bộ. | |
headers = {"Content-Type": "application/json"} | |
data = { | |
"model": "local-model", | |
"messages": [ | |
{"role": "system", "content": "Bạn là trợ lý dạy tiếng Việt."}, | |
{"role": "user", "content": transcript} | |
], | |
"temperature": 0.5, | |
"max_tokens": 150 | |
} | |
try: | |
response = requests.post(self.config["local_llm_endpoint"] + "/chat/completions", headers=headers, json=data) | |
if response.status_code == 200: | |
result = response.json() | |
return result["choices"][0]["message"]["content"] | |
else: | |
return "Lỗi khi gọi Local LLM." | |
except Exception as e: | |
logger.error(f"Lỗi local LLM: {e}") | |
return "Lỗi phân tích với LLM local." | |
def _rule_based_analysis(self, transcript, dialect): | |
# Phân tích đơn giản không dùng LLM | |
return "Phân tích rule-based: " + transcript | |
def clean_up(self): | |
self.async_processor.stop() | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
logger.info("Clean up done.") | |
def create_demo(): | |
trainer = VietSpeechTrainer() | |
with gr.Blocks(title="Ứng dụng Luyện Nói & TTS", theme=gr.themes.Soft(primary_hue="blue")) as demo: | |
gr.Markdown("## Ứng dụng Luyện Nói & TTS (Tiếng Việt & Tiếng Anh)") | |
with gr.Tabs(): | |
# Tab 1: TTS Tiếng Việt | |
with gr.Tab("TTS Tiếng Việt"): | |
vi_text_input = gr.Textbox(label="Nhập văn bản tiếng Việt") | |
vi_audio_output = gr.Audio(label="Kết quả âm thanh") | |
gen_vi_btn = gr.Button("Chuyển thành giọng nói") | |
def gen_vi_tts(txt): | |
return trainer.text_to_speech(txt, language="vi", dialect=trainer.config["default_dialect"]) | |
gen_vi_btn.click(fn=gen_vi_tts, inputs=vi_text_input, outputs=vi_audio_output) | |
# Tab 2: TTS Tiếng Anh (sử dụng CSM) | |
with gr.Tab("TTS Tiếng Anh"): | |
en_text_input = gr.Textbox(label="Enter English text") | |
en_audio_output = gr.Audio(label="Generated English Audio (CSM)") | |
gen_en_btn = gr.Button("Generate English Speech") | |
def gen_en_tts(txt): | |
return trainer.text_to_speech(txt, language="en") | |
gen_en_btn.click(fn=gen_en_tts, inputs=en_text_input, outputs=en_audio_output) | |
# Tab 3: Luyện phát âm (Tiếng Việt) | |
with gr.Tab("Luyện phát âm"): | |
audio_input = gr.Audio(source="microphone", type="filepath", label="Giọng nói của bạn") | |
transcript_output = gr.Textbox(label="Transcript") | |
analysis_output = gr.Markdown(label="Phân tích") | |
analyze_btn = gr.Button("Phân tích") | |
def process_audio(audio_path): | |
transcript = trainer.transcribe_audio(audio_path) | |
analysis = trainer.analyze_text(transcript, dialect=trainer.config["default_dialect"]) | |
return transcript, analysis | |
analyze_btn.click(fn=process_audio, inputs=audio_input, outputs=[transcript_output, analysis_output]) | |
# Tab 4: Thông tin & Hướng dẫn | |
with gr.Tab("Thông tin"): | |
gr.Markdown(""" | |
### Hướng dẫn sử dụng: | |
- **TTS Tiếng Việt:** Nhập văn bản tiếng Việt và nhấn "Chuyển thành giọng nói". | |
- **TTS Tiếng Anh (CSM):** Nhập English text và nhấn "Generate English Speech". | |
- **Luyện phát âm:** Thu âm giọng nói, sau đó nhấn "Phân tích" để xem transcript và phân tích. | |
### Cấu hình LLM: | |
- **OpenAI:** Đặt biến môi trường `LLM_PROVIDER=openai` và `OPENAI_API_KEY` với key của bạn. | |
- **Gemini:** Đặt `LLM_PROVIDER=gemini` và `GEMINI_API_KEY`. | |
- **Local LLM:** Đặt `LLM_PROVIDER=local` và `LOCAL_LLM_ENDPOINT` với URL của server LLM nếu bạn có. | |
- **None:** Đặt `LLM_PROVIDER=none` để sử dụng phân tích rule-based. | |
### Lưu ý: | |
- Để sử dụng TTS tiếng Anh (CSM), hãy bật biến `ENABLE_ENGLISH_TTS` (hoặc đặt `"enable_english_tts": true` trong config.json). | |
""") | |
return demo | |
def main(): | |
demo = create_demo() | |
# Sử dụng hàng đợi Gradio để xử lý tác vụ dài (ví dụ TTS CSM) | |
demo.queue() | |
demo.launch(server_name="0.0.0.0", server_port=7860) | |
if __name__ == "__main__": | |
main() | |