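# Source path: emotion-detection/datasets/audio_dataset.py
# Web file-view header (page residue, not part of the module): uploader
# avatar "saeedbenadeeb's picture", commit 0874d87 ("First commit"),
# links "raw" / "history" / "blame", file size 3.91 kB.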
import os
import zipfile
import requests
from tqdm import tqdm
from typing import List, Tuple
import numpy as np
from torch.utils.data import Dataset
import librosa
import torch
# Sampling rate (Hz) passed to librosa.load as `sr` when loading audio files.
SAMPLE_RATE = 22050
# Upper bound, in seconds, on how much of each file is read (librosa.load's
# `duration` argument).
DURATION = 1.4  # seconds
class EmodbDataset(Dataset):
    """Map-style dataset over the EMO-DB audio files for four emotion labels.

    Each sample is a ``(audio, target)`` pair where ``target`` is the index
    of the emotion in ``__labels__``. A file is included only when it has a
    ``.wav`` extension and the last two characters of its stem appear in
    ``__suffixes__``; every other file in the ``wav`` directory is ignored.

    If ``<root_path>/wav`` does not exist, the archive at ``__url__`` is
    downloaded and extracted into ``root_path`` on construction.
    """

    # NOTE(review): fetched over plain HTTP with no checksum verification;
    # consider pinning a hash if the download source is not trusted.
    __url__ = "http://www.emodb.bilderbar.info/download/download.zip"
    # Position in this tuple is the integer class id used in `targets`.
    __labels__ = ("angry", "happy", "neutral", "sad")
    # Filename-stem suffixes (last two characters) accepted for each label.
    __suffixes__ = {
        "angry": ["Wa", "Wb", "Wc", "Wd"],
        "happy": ["Fa", "Fb", "Fc", "Fd"],
        "neutral": ["Na", "Nb", "Nc", "Nd"],
        "sad": ["Ta", "Tb", "Tc", "Td"]
    }

    def __init__(self, root_path: str = './data/emodb', transform=None):
        """Index the audio files under ``<root_path>/wav``.

        Args:
            root_path: Directory that holds (or will receive) the extracted
                dataset; audio is expected in its ``wav`` sub-directory.
            transform: Optional callable applied to the loaded audio tensor
                in ``__getitem__``.

        Raises:
            FileNotFoundError: If the ``wav`` directory is still missing
                after the download/extract step.
        """
        super().__init__()
        self.root_path = root_path
        self.audio_root_path = os.path.join(root_path, "wav")
        # Ensure the dataset is downloaded
        self._ensure_dataset()

        # Reverse lookup: two-character suffix -> class id. setdefault keeps
        # the first label that declares a suffix, matching a first-match scan
        # over __suffixes__ in declaration order.
        suffix_to_target = {}
        for label, suffixes in self.__suffixes__.items():
            target = self.label2id(label)
            for suffix in suffixes:
                suffix_to_target.setdefault(suffix, target)

        ids = []
        targets = []
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across machines and runs.
        for audio_file in sorted(os.listdir(self.audio_root_path)):
            f_name, ext = os.path.splitext(audio_file)
            if ext != ".wav":
                continue
            target = suffix_to_target.get(f_name[-2:])
            if target is None:
                # Suffix not listed for any of the four labels: skip the file.
                continue
            ids.append(os.path.join(self.audio_root_path, audio_file))
            targets.append(target)

        self.ids = ids
        self.targets = np.array(targets, dtype=np.int64)
        self.transform = transform

    def _ensure_dataset(self):
        """Download and extract the dataset if the ``wav`` directory is absent.

        Raises:
            FileNotFoundError: If the ``wav`` directory does not exist even
                after the archive has been extracted.
        """
        if os.path.isdir(self.audio_root_path):
            return
        print(f"Dataset not found at {self.audio_root_path}. Downloading...")
        self._download_and_extract()
        if not os.path.isdir(self.audio_root_path):
            # Fail here with a clear message instead of a bare listdir error.
            raise FileNotFoundError(
                f"Expected audio directory {self.audio_root_path} "
                "was not created by extracting the downloaded archive"
            )

    def _download_and_extract(self):
        """Download the dataset zip into ``root_path`` and extract it there.

        The temporary ``emodb.zip`` is removed afterwards, including when the
        download or the extraction raises.
        """
        # Ensure the root path exists
        os.makedirs(self.root_path, exist_ok=True)
        zip_path = os.path.join(self.root_path, "emodb.zip")
        try:
            # Download the dataset. The (connect, read) timeout stops a
            # stalled server from hanging dataset construction forever.
            with requests.get(self.__url__, stream=True, timeout=(10, 60)) as r:
                r.raise_for_status()
                total_size = int(r.headers.get("content-length", 0))
                with open(zip_path, "wb") as f, tqdm(
                    desc="Downloading EMO-DB dataset",
                    # None lets tqdm show an open-ended counter when the
                    # server does not report Content-Length.
                    total=total_size or None,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for chunk in r.iter_content(chunk_size=8192):
                        if not chunk:
                            continue
                        f.write(chunk)
                        bar.update(len(chunk))
            # Extract the dataset
            print("Extracting dataset...")
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(self.root_path)
        finally:
            # Clean up the zip file, also after a failed download/extraction,
            # so no partial archive is left behind.
            if os.path.exists(zip_path):
                os.remove(zip_path)

    def __len__(self):
        """Return the number of indexed audio files."""
        return len(self.ids)

    def __getitem__(self, idx: int) -> Tuple:
        """Return ``(audio, target)`` for the sample at ``idx``.

        ``audio`` is the tensor produced by ``load_audio``, passed through
        ``transform`` when one was given; ``target`` is the class id.
        """
        target = self.targets[idx]
        audio = self.load_audio(self.ids[idx])  # float32 torch tensor
        if self.transform:
            audio = self.transform(audio)  # Apply transform
        return audio, target

    @staticmethod
    def id2label(idx: int) -> str:
        """Return the label name stored at position ``idx`` of ``__labels__``."""
        return EmodbDataset.__labels__[idx]

    @staticmethod
    def label2id(label: str) -> int:
        """Return the class id of ``label``.

        Raises:
            ValueError: If ``label`` is not one of ``__labels__``.
        """
        if label not in EmodbDataset.__labels__:
            raise ValueError(f"Unknown label: {label}")
        return EmodbDataset.__labels__.index(label)

    @staticmethod
    def load_audio(audio_file_path: str) -> torch.Tensor:
        """Load up to ``DURATION`` seconds of audio as a float32 tensor.

        The file is read with ``librosa.load`` at ``SAMPLE_RATE``. No padding
        is applied here, so the returned length is whatever librosa yields
        for the file (presumably shorter for clips under ``DURATION``
        seconds; verify if fixed-length batches are required).

        Raises:
            RuntimeError: If librosa reports a sampling rate other than
                ``SAMPLE_RATE``.
        """
        audio, sr = librosa.load(audio_file_path, sr=SAMPLE_RATE, duration=DURATION)
        # Explicit raise rather than assert: asserts are stripped under -O.
        if sr != SAMPLE_RATE:
            raise RuntimeError("broken audio file")
        # Convert numpy array to PyTorch tensor
        return torch.tensor(audio, dtype=torch.float32)

    @staticmethod
    def get_labels() -> List[str]:
        """Return the label names as a new list, ordered by class id."""
        return list(EmodbDataset.__labels__)