import os
import zipfile
from typing import Callable, List, Optional, Tuple

import librosa
import numpy as np
import requests
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

SAMPLE_RATE = 22050  # sampling rate (Hz) requested from librosa.load
DURATION = 1.4  # seconds; maximum amount of audio read from each file


class EmodbDataset(Dataset):
    """Four-class emotion dataset built from the EMO-DB ``wav`` directory.

    On construction the dataset is downloaded and extracted into
    ``root_path`` if ``<root_path>/wav`` does not exist yet. Every ``.wav``
    file whose stem ends in one of the two-character suffixes listed in
    ``__suffixes__`` becomes one sample; all other files are ignored.

    Attributes:
        root_path: Directory the archive is extracted into.
        audio_root_path: ``<root_path>/wav``, the directory scanned for audio.
        ids: Paths of the selected audio files, in sorted file-name order.
        targets: ``np.int64`` array of label indices, aligned with ``ids``.
        transform: Optional callable applied to each loaded waveform.
    """

    __url__ = "http://www.emodb.bilderbar.info/download/download.zip"
    __labels__ = ("angry", "happy", "neutral", "sad")
    # Last two characters of the file stem -> emotion label.
    __suffixes__ = {
        "angry": ["Wa", "Wb", "Wc", "Wd"],
        "happy": ["Fa", "Fb", "Fc", "Fd"],
        "neutral": ["Na", "Nb", "Nc", "Nd"],
        "sad": ["Ta", "Tb", "Tc", "Td"],
    }
    # (connect, read) timeouts in seconds, so a stalled server cannot hang
    # dataset construction indefinitely.
    _download_timeout = (10, 60)

    def __init__(
        self,
        root_path: str = './data/emodb',
        transform: Optional[Callable] = None,
    ):
        """Index the audio files under ``root_path``, downloading if needed.

        Args:
            root_path: Directory that holds (or will hold) the extracted
                dataset; audio is read from its ``wav`` sub-directory.
            transform: Optional callable applied to the waveform returned by
                ``load_audio`` in ``__getitem__``.

        Raises:
            ValueError: If ``__suffixes__`` maps a matching file to a label
                that is not in ``__labels__``.
        """
        super().__init__()
        self.root_path = root_path
        self.audio_root_path = os.path.join(root_path, "wav")

        # Ensure the dataset is downloaded
        self._ensure_dataset()

        # Invert the label -> suffixes table once instead of scanning every
        # suffix list per file. ``setdefault`` keeps the first label that
        # claims a suffix, matching a first-match-wins scan of the table.
        suffix_to_label = {}
        for label, suffixes in self.__suffixes__.items():
            for suffix in suffixes:
                suffix_to_label.setdefault(suffix, label)

        ids = []
        targets = []
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across machines and runs.
        for audio_file in sorted(os.listdir(self.audio_root_path)):
            f_name, ext = os.path.splitext(audio_file)
            if ext != ".wav":
                continue
            label = suffix_to_label.get(f_name[-2:])
            if label is None:
                # Not one of the four emotions this dataset exposes.
                continue
            ids.append(os.path.join(self.audio_root_path, audio_file))
            targets.append(self.label2id(label))

        self.ids = ids
        self.targets = np.array(targets, dtype=np.int64)
        self.transform = transform

    def _ensure_dataset(self):
        """Download and extract the dataset if the audio directory is missing."""
        if not os.path.isdir(self.audio_root_path):
            print(f"Dataset not found at {self.audio_root_path}. Downloading...")
            self._download_and_extract()

    def _download_and_extract(self):
        """Download the dataset archive into ``root_path`` and extract it.

        The temporary ``emodb.zip`` is removed afterwards, whether or not the
        download and extraction succeeded.

        Raises:
            requests.RequestException: On connection errors, timeouts or a
                non-success HTTP status.
            zipfile.BadZipFile: If the downloaded file is not a valid archive.
        """
        # Ensure the root path exists
        os.makedirs(self.root_path, exist_ok=True)

        zip_path = os.path.join(self.root_path, "emodb.zip")
        try:
            # Download the dataset
            with requests.get(
                self.__url__, stream=True, timeout=self._download_timeout
            ) as r:
                r.raise_for_status()
                total_size = int(r.headers.get("content-length", 0))
                with open(zip_path, "wb") as f, tqdm(
                    desc="Downloading EMO-DB dataset",
                    total=total_size,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
                        bar.update(len(chunk))

            # Extract the dataset
            print("Extracting dataset...")
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(self.root_path)
        finally:
            # Clean up the zip file, including a partial one left behind by a
            # failed download or extraction.
            if os.path.exists(zip_path):
                os.remove(zip_path)

    def __len__(self):
        """Return the number of indexed audio files."""
        return len(self.ids)

    def __getitem__(self, idx: int) -> Tuple:
        """Return ``(audio, target)`` for the sample at ``idx``.

        ``audio`` is the tensor produced by ``load_audio``, passed through
        ``transform`` when one was supplied; ``target`` is the label index
        taken from ``targets``.
        """
        target = self.targets[idx]
        audio = self.load_audio(self.ids[idx])  # float32 torch tensor
        # Compare against None rather than relying on truthiness, so a
        # callable that happens to be falsy is still applied.
        if self.transform is not None:
            audio = self.transform(audio)  # Apply transform
        return audio, target

    @staticmethod
    def id2label(idx: int) -> str:
        """Return the label name stored at position ``idx`` of ``__labels__``."""
        return EmodbDataset.__labels__[idx]

    @staticmethod
    def label2id(label: str) -> int:
        """Return the index of ``label`` in ``__labels__``.

        Raises:
            ValueError: If ``label`` is not a known label.
        """
        if label not in EmodbDataset.__labels__:
            raise ValueError(f"Unknown label: {label}")
        return EmodbDataset.__labels__.index(label)

    @staticmethod
    def load_audio(audio_file_path: str) -> torch.Tensor:
        """Load up to ``DURATION`` seconds of audio at ``SAMPLE_RATE``.

        Args:
            audio_file_path: Path of the audio file to read.

        Returns:
            The waveform as a ``torch.float32`` tensor. No padding is applied
            here, so a file shorter than ``DURATION`` yields a shorter tensor.
        """
        audio, sr = librosa.load(audio_file_path, sr=SAMPLE_RATE, duration=DURATION)
        # Sanity check on the loader's contract, not on user input.
        assert SAMPLE_RATE == sr, "broken audio file"
        # Convert numpy array to PyTorch tensor
        return torch.tensor(audio, dtype=torch.float32)

    @staticmethod
    def get_labels() -> List[str]:
        """Return the label names as a new list, in index order."""
        return list(EmodbDataset.__labels__)