import os
import zipfile
import requests
from tqdm import tqdm
from typing import List, Tuple
import numpy as np
from torch.utils.data import Dataset
import librosa
import torch
from torch.nn.utils.rnn import pad_sequence

SAMPLE_RATE = 22050
DURATION = 1.4  # seconds

class CTCEmodbDataset(Dataset):
    __url__ = "http://www.emodb.bilderbar.info/download/download.zip"
    __labels__ = ("angry", "happy", "neutral", "sad")
    __suffixes__ = {
        "angry": ["Wa", "Wb", "Wc", "Wd"],
        "happy": ["Fa", "Fb", "Fc", "Fd"],
        "neutral": ["Na", "Nb", "Nc", "Nd"],
        "sad": ["Ta", "Tb", "Tc", "Td"]
    }

    def __init__(self, root_path: str = './data/emodb', transform=None):
        super().__init__()
        self.root_path = root_path
        self.audio_root_path = os.path.join(root_path, "wav")
        
        # Ensure the dataset is downloaded
        self._ensure_dataset()

        ids = []
        targets = []
        for audio_file in sorted(os.listdir(self.audio_root_path)):  # sorted for a deterministic sample order
            f_name, ext = os.path.splitext(audio_file)
            if ext != ".wav":
                continue

            suffix = f_name[-2:]
            for label, suffixes in self.__suffixes__.items():
                if suffix in suffixes:
                    ids.append(os.path.join(self.audio_root_path, audio_file))
                    targets.append(self.label2id(label))  # Store as integers
                    break

        self.ids = ids
        self.targets = targets  # Target label ids as a list of ints
        self.transform = transform

    def _ensure_dataset(self):
        """
        Ensures the dataset is downloaded and extracted.
        """
        if not os.path.isdir(self.audio_root_path):
            print(f"Dataset not found at {self.audio_root_path}. Downloading...")
            self._download_and_extract()

    def _download_and_extract(self):
        """
        Downloads and extracts the dataset zip file.
        """
        os.makedirs(self.root_path, exist_ok=True)
        zip_path = os.path.join(self.root_path, "emodb.zip")
        with requests.get(self.__url__, stream=True) as r:
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            with open(zip_path, "wb") as f, tqdm(
                desc="Downloading EMO-DB dataset",
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
                    bar.update(len(chunk))

        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(self.root_path)

        os.remove(zip_path)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
        """
        Returns:
            x (torch.Tensor): Input sequence (audio features or waveform)
            y (torch.Tensor): Target sequence (labels or tokenized transcription)
            input_length (int): Length of input sequence
            target_length (int): Length of target sequence
        """
        target = torch.tensor([self.targets[idx]], dtype=torch.long)
        audio = self.load_audio(self.ids[idx])  # Returns a 1-D torch.Tensor waveform

        if self.transform:
            audio = self.transform(audio)

        # Input length (for CTC)
        input_length = audio.shape[-1]  # Last dimension is the time dimension
        target_length = len(target)  # Length of target sequence

        return audio, target, input_length, target_length

    @staticmethod
    def id2label(idx: int) -> str:
        return CTCEmodbDataset.__labels__[idx]

    @staticmethod
    def label2id(label: str) -> int:
        if label not in CTCEmodbDataset.__labels__:
            raise ValueError(f"Unknown label: {label}")
        return CTCEmodbDataset.__labels__.index(label)

    @staticmethod
    def load_audio(audio_file_path: str) -> torch.Tensor:
        # librosa.load resamples to SAMPLE_RATE and truncates to at most DURATION seconds
        audio, sr = librosa.load(audio_file_path, sr=SAMPLE_RATE, duration=DURATION)
        assert sr == SAMPLE_RATE, f"unexpected sample rate {sr} for {audio_file_path}"
        return torch.tensor(audio, dtype=torch.float32)

    @staticmethod
    def get_labels() -> List[str]:
        return list(CTCEmodbDataset.__labels__)
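

# --- Optional batching helper (not part of the original class) ---
# A minimal sketch of a CTC-style collate function, assuming raw 1-D waveform
# tensors (i.e. no transform that changes the tensor layout). It uses the
# pad_sequence import above to pad variable-length audio and stacks the
# per-sample lengths so they can be passed to torch.nn.CTCLoss. The name
# `ctc_collate_fn` is illustrative and does not come from the original code.
def ctc_collate_fn(batch):
    audios, targets, input_lengths, target_lengths = zip(*batch)
    padded_audio = pad_sequence(list(audios), batch_first=True)      # (B, T_max)
    padded_targets = pad_sequence(list(targets), batch_first=True)   # (B, S_max)
    input_lengths = torch.tensor(input_lengths, dtype=torch.long)    # (B,)
    target_lengths = torch.tensor(target_lengths, dtype=torch.long)  # (B,)
    return padded_audio, padded_targets, input_lengths, target_lengths


# Example usage (illustrative):
#   from torch.utils.data import DataLoader
#   dataset = CTCEmodbDataset()
#   loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=ctc_collate_fn)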