# xtts-castellano / finetune_xtts_hf_long.py
# Author: sob111
# Last commit: ea36659 "Update finetune_xtts_hf_long.py" (verified)
import os, subprocess, sys, zipfile, csv
import threading
import http.server
import socketserver
import time
# Environment for the Hugging Face / TTS stack. These must be set before
# huggingface_hub is imported below, since the library reads most of its
# settings from the environment at import time.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
# Disable the optional Rust download accelerators (hf_transfer, hf-xet).
# The two original names below are kept for backward compatibility, but the
# variables huggingface_hub documents are HF_HUB_ENABLE_HF_TRANSFER and
# HF_HUB_DISABLE_XET, so those are set as well.
os.environ["HF_HUB_DISABLE_HF_TRANSFER"] = "1"
os.environ["HF_HUB_ENABLE_XET"] = "0"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
os.environ["HF_HUB_DISABLE_XET"] = "1"
# Keep every cache under /tmp (writable in the container).
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["NUMBA_CACHE_DIR"] = "/tmp/numba_cache"
os.environ["HF_HOME"] = "/tmp/hf_home"
os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
# Output locations, exported for the training subprocess.
os.environ["CHECKPOINTS_OUT_PATH"] = "/tmp/xtts_checkpoints"
os.environ["OUT_PATH"] = "/tmp/output_model"
os.makedirs("/tmp/numba_cache", exist_ok=True)
os.makedirs("/tmp/hf_cache", exist_ok=True)
os.makedirs("/tmp/hf_home", exist_ok=True)
os.makedirs("/tmp/xtts_checkpoints", exist_ok=True)
os.environ["NUMBA_DISABLE_JIT"] = "1"
from huggingface_hub import HfApi, HfFolder, upload_folder, snapshot_download
# Remove hf_transfer if it is installed. This is best-effort: no check=True,
# so a non-zero pip exit (e.g. package not installed) does not stop the script.
_pip_uninstall_cmd = [sys.executable, "-m", "pip", "uninstall", "-y", "hf_transfer"]
subprocess.run(_pip_uninstall_cmd)
# === Configuration ===
# Target model repo on the Hugging Face Hub (change it to your own repo).
HF_MODEL_ID = "sob111/xtts-v2-finetuned"
# Must be defined in the Space / environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Local data and output locations.
DATASET_PATH = "/tmp/dataset"
VOXPOPULI_PATH = "/tmp/dataset/voxpopuli_es_500_vctk"
OUTPUT_PATH = "/tmp/output_model"

BASE_MODEL = "coqui/XTTS-v2"
PORT = int(os.environ.get("PORT", 7860))

# Working directories, created if missing and made world-writable (0o777).
for _writable_dir in (
    "/tmp/xtts_cache",
    "/tmp/xtts_model",
    "/tmp/xtts_model/.huggingface",
    OUTPUT_PATH,
):
    os.makedirs(_writable_dir, exist_ok=True)
    os.chmod(_writable_dir, 0o777)
del _writable_dir

# Point matplotlib's config directory at a writable location under /tmp.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
os.makedirs("/tmp/matplotlib", exist_ok=True)
# The rest of the script uses the paths above consistently.
# Download the base XTTS-v2 checkpoint directly into XTTS_MODEL_DIR, using a
# cache under /tmp. BASE_MODEL is the single source of truth for the repo id.
XTTS_MODEL_DIR = "/tmp/xtts_model"
model_dir = snapshot_download(
    repo_id=BASE_MODEL,
    local_dir=XTTS_MODEL_DIR,  # download straight here
    cache_dir="/tmp/hf_cache",  # cache kept in /tmp
    # local_dir_use_symlinks=False,  # avoid symlinks (left disabled)
    # NOTE(review): recent huggingface_hub releases deprecate resume_download
    # (downloads always resume); kept for older versions — confirm the pin.
    resume_download=True,
    token=HF_TOKEN,
)
print(f"✅ Modelo descargado en: {model_dir}")
# Paths to the downloaded base config and weights.
CONFIG_PATH = os.path.join(XTTS_MODEL_DIR, "config.json")
RESTORE_PATH = os.path.join(XTTS_MODEL_DIR, "model.pth")
class HealthCheckHandler(http.server.BaseHTTPRequestHandler):
    """Answer every GET/HEAD request with ``200`` and the body ``OK``.

    This replaces ``SimpleHTTPRequestHandler``, which serves (and lists) the
    files of the process's current working directory to anyone who can reach
    the port. A health probe only needs a status code.
    """

    _BODY = b"OK"

    def _send_ok_headers(self):
        """Send the status line and headers shared by GET and HEAD."""
        self.send_response(200)
        self.send_header("Content-Type", "text/plain; charset=utf-8")
        self.send_header("Content-Length", str(len(self._BODY)))
        self.end_headers()

    def do_GET(self):
        self._send_ok_headers()
        self.wfile.write(self._BODY)

    def do_HEAD(self):
        self._send_ok_headers()

    def log_message(self, *args):
        # Health probes arrive continuously; keep them out of the training logs.
        pass


def start_health_check_server(port=None):
    """Run a minimal HTTP health-check server; blocks forever.

    Intended as the target of a daemon thread.

    Args:
        port (int | None): TCP port to bind on all interfaces. Defaults to the
            module-level ``PORT``.
    """
    if port is None:
        port = PORT
    # HTTPServer is a TCPServer subclass with allow_reuse_address enabled, so
    # a quick restart does not fail while the old socket is in TIME_WAIT.
    with http.server.HTTPServer(("", port), HealthCheckHandler) as httpd:
        print(f"Servidor de health check iniciado en el puerto {port}")
        httpd.serve_forever()
# === 1.B Extraer el dataset
def extract_zip(zip_file_path, destination_path):
    """Extract every member of a ZIP archive into ``destination_path``.

    The destination directory is created first if it does not exist. Errors
    are reported on stdout instead of being raised (best-effort, as before),
    but the outcome is now also returned so callers can react to it.

    Args:
        zip_file_path (str): Full path to the ZIP file.
        destination_path (str): Directory where the contents are extracted.

    Returns:
        bool: True if extraction completed, False if an error was reported.
    """
    os.makedirs(destination_path, exist_ok=True)
    try:
        # ZipFile.extractall strips absolute paths and ".." components from
        # member names, so entries cannot be written outside destination_path.
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(destination_path)
    except zipfile.BadZipFile:
        print(f"❌ Error: The file '{zip_file_path}' is not a valid ZIP file.")
        return False
    except FileNotFoundError:
        print(f"❌ Error: The file '{zip_file_path}' was not found.")
        return False
    except Exception as e:
        # Top-level boundary of a best-effort step: report, don't crash.
        print(f"❌ An unexpected error occurred: {e}")
        return False
    print(f"✅ Extracted '{zip_file_path}' to '{destination_path}' successfully.")
    return True
# Extract the bundled VoxPopuli (VCTK-layout) dataset into DATASET_PATH.
zip_file = "/home/user/app/voxpopuli_es_500_vctk.zip"
# os.path.abspath only normalises the destination path string; it does not by
# itself stop a malicious archive from writing outside the folder ("zip slip").
# That protection comes from zipfile.ZipFile.extractall, which strips absolute
# paths and ".." components from member names.
safe_destination = os.path.abspath(DATASET_PATH)
extract_zip(zip_file, safe_destination)
print(f"safe destination {safe_destination}")
# === 2. Editar configuración para tu dataset VoxPopuli ===
import json
# === Convertir metadata.json → metadata.csv ===
#json_path = os.path.join(VOXPOPULI_PATH, "metadata.json")
#print(f"ruta de json {json_path}")
#csv_path = os.path.join(VOXPOPULI_PATH, "metadata.csv")
#if os.path.exists(json_path):
# print("🔄 Convirtiendo metadata.json → metadata.csv...")
# with open(json_path, "r", encoding="utf-8") as f:
# data = json.load(f)
# with open(csv_path, "w", encoding="utf-8", newline="") as f:
# writer = csv.writer(f, delimiter=",", quoting=csv.QUOTE_MINIMAL)
# for entry in data:
# path = entry["audio_filepath"]
# Quitar prefijo "voxpopuli_es_500/" si existe
# if path.startswith("voxpopuli_es_500/"):
# path = path.replace("voxpopuli_es_500/", "", 1)
# text = entry["text"].replace("\n", " ").strip()
# speaker = entry.get("speaker", "spk1")
# writer.writerow([path, text, speaker])
# print(f"✅ metadata.csv generado en {csv_path}")
#else:
# raise FileNotFoundError(f"❌ No se encontró {json_path}. Verifica el zip.")
from TTS.tts.datasets import load_tts_samples
from TTS.config.shared_configs import BaseDatasetConfig
# Dataset description matching the extracted VoxPopuli corpus.
# NOTE(review): config_dataset is not referenced by any active code in this
# script (the load_tts_samples experiments are commented out), and its
# formatter ("vctk") differs from the "ljspeech" entry written into
# config.json later — confirm which one training actually uses.
config_dataset = BaseDatasetConfig(
    dataset_name="voxpopuli",
    formatter="vctk",
    language="es",
    path="/tmp/dataset",
)

# Diagnostics: report which of the expected dataset directories exist.
for _expected_dir in (
    "/tmp/dataset",
    "/tmp/dataset/wav48",
    "/tmp/dataset/voxpopuli_es_500_vctk",
    "/tmp/dataset/voxpopuli_es_500_vctk/wav48",
):
    if os.path.exists(_expected_dir):
        print(f"{_expected_dir} encontrado")
del _expected_dir
#train_samples, eval_samples = load_tts_samples(
# "/tmp/dataset","vctk"
#)
# Construimos rutas completas
#root_path = config_dataset.path
#meta_file_train = config_dataset.meta_file_train
#meta_path = os.path.join(root_path, meta_file_train)
#print(f"Verificando archivo CSV: {meta_path}")
#print(f"Existe?: {os.path.exists(meta_path)}")
# Intentamos cargar los samples
#try:
# train_samples, eval_samples = load_tts_samples(config_dataset)
# print(f"Samples detectados: {len(train_samples)} training, {len(eval_samples)} eval")
# print("Primeros 3 samples:")
# for s in train_samples[:3]:
# print(s)
#except AssertionError as e:
# print("❌ Error cargando samples:", e)
# === 2.B Start the health-check web server ===
# Daemon thread: it must not keep the process alive once the main script
# finishes or calls sys.exit.
health_check_thread = threading.Thread(target=start_health_check_server)
health_check_thread.daemon = True
health_check_thread.start()
print("=== Editando configuración para fine-tuning con VoxPopuli ===")
# Patch the downloaded XTTS config.json in place with fine-tuning settings.
# JSON is UTF-8 by spec; pass the encoding explicitly so reading and writing
# do not depend on the container's locale.
# NOTE(review): train_gpt_xtts.py is launched without --config_path (those
# arguments are commented out), so confirm the training script reads this file.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = json.load(f)
config["output_path"] = OUTPUT_PATH
config["datasets"] = [
    {
        "formatter": "ljspeech",
        "path": VOXPOPULI_PATH,
        "meta_file_train": "metadata.csv",
    }
]
config["run_name"] = "xtts-finetune-voxpopuli"
config["lr"] = 1e-5  # lower learning rate for fine-tuning
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)
# === 3. Launch training ===
print("=== Iniciando fine-tuning de XTTS-v2 ===")
# NOTE(review): librosa / spectrum are not used in this script; presumably
# imported as a fail-fast check that librosa is installed — confirm.
import librosa
from librosa.core import spectrum
# Run the training script with the same interpreter that runs this file
# (sys.executable, as the pip call at the top does) instead of whatever
# "python" resolves to on PATH, so the child sees the same installed packages.
# check=True: a failed training run raises and stops before the upload step.
subprocess.run([
    sys.executable, "/home/user/app/train_gpt_xtts.py",
    # "--config_path", CONFIG_PATH,
    # "--restore_path", RESTORE_PATH
], check=True)
# Alternative: Coqui's generic trainer entry point.
# subprocess.run([
#     "python", "-m", "TTS.bin.train",
#     "--config_path", CONFIG_PATH,
#     "--restore_path", RESTORE_PATH
# ], check=True)
# === 4. Upload the resulting model to the Hugging Face Hub ===
print("=== Subiendo modelo fine-tuneado a Hugging Face Hub ===")
# Fail with a clear message instead of a TypeError deep inside the hub client.
if not HF_TOKEN:
    sys.exit("❌ HF_TOKEN no está definido: no se puede subir el modelo a Hugging Face.")
# Pass the token explicitly instead of persisting it to disk via HfFolder.
api = HfApi(token=HF_TOKEN)
# exist_ok=True: after the first successful run the repo already exists, and
# without it create_repo raises, so the freshly trained model is never uploaded.
api.create_repo(
    repo_id=HF_MODEL_ID,
    repo_type="model",
    private=False,
    exist_ok=True,
)
upload_folder(
    repo_id=HF_MODEL_ID,
    repo_type="model",
    folder_path=OUTPUT_PATH,
    token=HF_TOKEN,
)
print("✅ Fine-tuning completado y modelo subido a Hugging Face.")
sys.exit(0)