|
import torch
|
|
import torchaudio
|
|
import os
|
|
from transformers import VitsModel, VitsTokenizer
|
|
import matplotlib.pyplot as plt
|
|
from torch.utils.data import Dataset, DataLoader
|
|
import torch.nn.functional as F
|
|
|
|
|
|
# Root of the LibriTTS "train-clean-100" split; .wav files are discovered
# recursively beneath this directory by LibriTTSDataset below.
DATASET_PATH = "./train-clean-100"
|
|
|
|
|
|
class LibriTTSDataset(Dataset):
    """LibriTTS-style dataset of (transcript, waveform) pairs.

    Recursively indexes every ``*.wav`` under ``root_dir`` and pairs it with
    the sibling ``*.normalized.txt`` transcript (the LibriTTS layout).
    Waveforms are resampled to ``sample_rate``, downmixed to mono, and
    trimmed or zero-padded to exactly ``max_length`` samples so the default
    batch collation can stack them.

    Args:
        root_dir: directory walked for ``.wav`` files.
        sample_rate: target sample rate in Hz.
        max_length: fixed output length in samples (default 160000 = 10 s
            at 16 kHz).
    """

    def __init__(self, root_dir, sample_rate=16000, max_length=160000):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.max_length = max_length
        # Recursively collect every .wav file under root_dir.
        self.files = []
        for root, _, files in os.walk(root_dir):
            for file in files:
                if file.endswith(".wav"):
                    self.files.append(os.path.join(root, file))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(text, waveform)`` where ``waveform`` is a 1-D float
        tensor of exactly ``self.max_length`` samples."""
        wav_path = self.files[idx]
        # LibriTTS stores one transcript per utterance next to the audio.
        text_path = wav_path.replace(".wav", ".normalized.txt")

        waveform, sr = torchaudio.load(wav_path)
        if sr != self.sample_rate:
            waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)

        # Downmix to mono. The original squeeze(0) silently left multi-channel
        # audio as (channels, T), which made the trim/pad logic below operate
        # on the channel axis; averaging handles any channel count and is a
        # no-op for mono input.
        waveform = waveform.mean(dim=0)

        # Force a fixed length: truncate long clips, zero-pad short ones.
        if waveform.size(0) > self.max_length:
            waveform = waveform[:self.max_length]
        else:
            waveform = F.pad(waveform, (0, self.max_length - waveform.size(0)))

        with open(text_path, "r", encoding="utf-8") as f:
            text = f.read().strip()
        return text, waveform
|
|
|
|
|
|
def collate_fn(batch):
    """Collate (text, waveform) samples into a batch.

    Returns a plain list of the texts and the waveforms stacked along a new
    leading batch dimension (all waveforms share a fixed length).
    """
    texts = [text for text, _ in batch]
    waveforms = torch.stack([wav for _, wav in batch])
    return texts, waveforms
|
|
|
|
|
|
# Load the pretrained MMS-TTS English model and its tokenizer on CPU.
# NOTE(review): VITS generation is stochastic (latent noise + stochastic
# duration predictor), so the plain waveform-MSE objective used in the
# training loop below is a questionable fine-tuning signal — confirm intent.
model = VitsModel.from_pretrained("facebook/mms-tts-eng").to("cpu")
tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")

# Small learning rate suited to fine-tuning a pretrained model.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# Sample-wise MSE between generated and target waveforms.
criterion = torch.nn.MSELoss()

dataset = LibriTTSDataset(DATASET_PATH)
# batch_size=2 and num_workers=0 keep memory usage modest for CPU training;
# collate_fn stacks the fixed-length waveforms and keeps texts as a list.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0, collate_fn=collate_fn)
|
|
|
|
|
|
epochs = 5
train_losses = []
val_losses = []

for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    num_ok = 0  # batches that completed a full forward/backward step

    for batch_idx, (texts, waveforms) in enumerate(dataloader):
        waveforms = waveforms.to("cpu")
        inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"].to("cpu")

        optimizer.zero_grad()
        try:
            # VITS produces audio of a model-determined length; align it to
            # the target length before computing the waveform MSE.
            output = model(input_ids).waveform.squeeze(1)
            if output.size(1) > waveforms.size(1):
                output = output[:, :waveforms.size(1)]
            elif output.size(1) < waveforms.size(1):
                output = F.pad(output, (0, waveforms.size(1) - output.size(1)))

            loss = criterion(output, waveforms)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_ok += 1
            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item()}")
        except Exception as e:
            # Best-effort training: skip a bad batch but report it.
            print(f"Error in batch {batch_idx}: {e}")
            continue

    # Average only over batches that actually contributed a loss. The
    # original divided by len(dataloader), which understated the average
    # whenever batches failed and raised ZeroDivisionError on an empty
    # loader; max(..., 1) keeps the empty case at 0.0.
    avg_loss = epoch_loss / max(num_ok, 1)
    train_losses.append(avg_loss)

    # FIXME(review): there is no validation split — this fabricates a
    # "validation" curve as 1.1x the training loss purely so the plot has
    # two lines. Replace with a real held-out evaluation.
    val_losses.append(avg_loss * 1.1)

    # Checkpoint model and tokenizer after every epoch.
    model.save_pretrained(f"mms_tts_finetuned_epoch_{epoch+1}")
    tokenizer.save_pretrained(f"mms_tts_finetuned_epoch_{epoch+1}")
|
|
|
|
|
|
# Plot the recorded loss curves, save them to disk, and display the figure.
plt.figure(figsize=(10, 5))
epoch_axis = range(1, epochs + 1)
for series, series_label in ((train_losses, "Training Loss"), (val_losses, "Validation Loss")):
    plt.plot(epoch_axis, series, label=series_label)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.savefig("finetune_loss_plot.png")
plt.show()

print("Fine-tuning complete. Model checkpoints saved.")