remiai3's picture
Upload 8 files
22ab6e1 verified
raw
history blame
4.02 kB
import torch
import torchaudio
import os
from transformers import VitsModel, VitsTokenizer
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
# Paths (update if needed)
# Root of the LibriTTS "train-clean-100" split; .wav files are discovered
# recursively beneath this directory.
DATASET_PATH = "./train-clean-100" # Points to the root of train-clean-100
# Custom Dataset
class LibriTTSDataset(Dataset):
    """LibriTTS ``.wav`` clips paired with their ``.normalized.txt`` transcripts.

    Recursively indexes every ``.wav`` file under *root_dir*. Each item is
    resampled to *sample_rate*, downmixed to mono, and padded or truncated to
    exactly *max_length* samples so batches can be stacked.

    Args:
        root_dir: Directory walked recursively for ``.wav`` files.
        sample_rate: Target sample rate in Hz (default 16 kHz).
        max_length: Fixed output length in samples (default 160000 = 10 s).
    """

    def __init__(self, root_dir, sample_rate=16000, max_length=160000):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.max_length = max_length
        # Cache one Resample transform per encountered source rate so the
        # (kernel-building) transform is not reconstructed for every item.
        self._resamplers = {}
        self.files = []
        for root, _, files in os.walk(root_dir):
            for file in files:
                if file.endswith(".wav"):
                    self.files.append(os.path.join(root, file))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(transcript_text, waveform)`` for the idx-th clip.

        Raises:
            FileNotFoundError: if the matching ``.normalized.txt`` is missing.
        """
        wav_path = self.files[idx]
        text_path = wav_path.replace(".wav", ".normalized.txt")
        waveform, sr = torchaudio.load(wav_path)  # shape: (channels, samples)
        if sr != self.sample_rate:
            if sr not in self._resamplers:
                self._resamplers[sr] = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = self._resamplers[sr](waveform)
        # BUG FIX: the original `squeeze(0)` only removed the channel dim for
        # mono clips; a stereo file stayed (2, samples), so the pad/truncate
        # below operated on the CHANNEL axis and returned a 2-sample tensor.
        # Downmix to mono instead (identical to squeeze for mono input).
        waveform = waveform.mean(dim=0)
        # Pad or truncate waveform to a fixed length
        if waveform.size(0) > self.max_length:
            waveform = waveform[:self.max_length]
        else:
            waveform = F.pad(waveform, (0, self.max_length - waveform.size(0)))
        with open(text_path, "r", encoding="utf-8") as f:
            text = f.read().strip()
        return text, waveform
# Custom collate function
def collate_fn(batch):
    """Collate dataset items into a mini-batch.

    Args:
        batch: Iterable of ``(text, waveform)`` pairs where every waveform
            has the same fixed length.

    Returns:
        A ``(list_of_texts, stacked_waveforms)`` tuple; waveforms are
        stacked along a new leading batch dimension.
    """
    texts = [text for text, _ in batch]
    waveforms = torch.stack([wav for _, wav in batch])
    return texts, waveforms
# Load model and tokenizer
# NOTE(review): everything runs on CPU; switch the device strings to "cuda"
# if a GPU is available.
model = VitsModel.from_pretrained("facebook/mms-tts-eng").to("cpu")
tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
# Small learning rate — we are fine-tuning a pretrained checkpoint.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# Training objective: MSE between generated and reference waveforms.
criterion = torch.nn.MSELoss()
# DataLoader
dataset = LibriTTSDataset(DATASET_PATH)
# batch_size=2 keeps CPU memory low; collate_fn stacks the fixed-length waveforms.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0, collate_fn=collate_fn)
# Training loop with plotting
epochs = 5
train_losses = []
val_losses = []
for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    for batch_idx, (texts, waveforms) in enumerate(dataloader):
        waveforms = waveforms.to("cpu")
        # Tokenize the batch transcripts; padding/truncation yields a
        # rectangular input_ids tensor.
        inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"].to("cpu")
        optimizer.zero_grad()
        try:
            # VitsModel emits a (batch, 1, samples) waveform; drop the channel dim.
            output = model(input_ids).waveform.squeeze(1)
            # Resize output to match target waveform length so the MSE shapes agree.
            if output.size(1) > waveforms.size(1):
                output = output[:, :waveforms.size(1)]
            elif output.size(1) < waveforms.size(1):
                output = F.pad(output, (0, waveforms.size(1) - output.size(1)))
            loss = criterion(output, waveforms)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item()}")
        except Exception as e:
            # Best-effort: log and skip a failing batch rather than abort the
            # epoch. NOTE(review): this broad catch also hides genuine bugs —
            # consider narrowing the exception types.
            print(f"Error in batch {batch_idx}: {e}")
            continue
    # Average over ALL batches, including any skipped by the except above.
    avg_loss = epoch_loss / len(dataloader)
    train_losses.append(avg_loss)
    # Dummy validation loss
    # NOTE(review): there is no validation split — this is train loss * 1.1,
    # plotted for illustration only; do not use it for model selection.
    val_losses.append(avg_loss * 1.1)
    # Save model checkpoint
    model.save_pretrained(f"mms_tts_finetuned_epoch_{epoch+1}")
    tokenizer.save_pretrained(f"mms_tts_finetuned_epoch_{epoch+1}")
# Plotting
# Plot per-epoch training and (dummy) validation loss curves, save the figure
# to disk, and show it interactively if a display is available.
plt.figure(figsize=(10, 5))
plt.plot(range(1, epochs+1), train_losses, label="Training Loss")
plt.plot(range(1, epochs+1), val_losses, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.savefig("finetune_loss_plot.png")
plt.show()
print("Fine-tuning complete. Model checkpoints saved.")