import math

import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import FrequencyMasking

from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
from .model import TaikoConformer5

mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=N_MELS,
    hop_length=HOP_LENGTH,
    n_fft=2048,
)

freq_mask = FrequencyMasking(freq_mask_param=15)
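
# Label time resolution (follows from the constants above): the mel transform
# produces SAMPLE_RATE / HOP_LENGTH frames per second, and labels are
# subsampled by TIME_SUB, so one label step covers
# TIME_SUB * HOP_LENGTH / SAMPLE_RATE seconds of audio.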


def preprocess(example, difficulty="oni"):
    wav_tensor = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]

    # 1) resample to the target rate (the audio is already decoded to a tensor)
    if sr != SAMPLE_RATE:
        wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)

    # Peak-normalize the audio
    wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)

    # Data augmentation: add random Gaussian noise to half of the examples
    if torch.rand(1).item() < 0.5:
        wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)

    # 2) mel spectrogram: (1, N_MELS, T)
    mel = mel_transform(wav_tensor).unsqueeze(0)

    # Apply SpecAugment frequency masking only; time masking is skipped so the
    # model is never asked to predict notes whose audio evidence is masked out.
    mel = freq_mask(mel)

    _, _, T = mel.shape

    # 3) build label sequences of length ceil(T / TIME_SUB)
    T_sub = math.ceil(T / TIME_SUB)

    # Energy-based labels for the Don, Ka, and Drumroll channels
    don_labels = torch.zeros(T_sub, dtype=torch.float32)
    ka_labels = torch.zeros(T_sub, dtype=torch.float32)
    drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)
    # Exponential decay tail parameters
    tail_length = 40  # number of label frames in the decay tail
    decay_rate = 8.0  # time constant of the decay, in label frames
    tail_kernel = torch.exp(
        -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
    )
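    # For scale: the kernel starts at 1.0, decays by a factor of
    # exp(-1/8) ≈ 0.88 per label frame, and falls to exp(-39/8) ≈ 0.008 by the
    # final frame.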
    fps = SAMPLE_RATE / HOP_LENGTH
    num_valid_notes = 0
    for onset in example[difficulty]:
        typ, t_start, t_end, *_ = onset

        # Assumes N_TYPES in config covers all note types used here (7 or more)
        if typ < 1 or typ > N_TYPES:  # filter out invalid types
            continue

        num_valid_notes += 1
        f = int(round(t_start.item() * fps))
        idx = f // TIME_SUB
        if 0 <= idx < T_sub:
            # Apply the exponential decay kernel to the matching energy channel.
            # Types 1 and 3 are Don
            if typ in (1, 3):
                for i, val in enumerate(tail_kernel):
                    target_idx = idx + i
                    if 0 <= target_idx < T_sub:
                        don_labels[target_idx] = max(
                            don_labels[target_idx].item(), val.item()
                        )
            # Types 2 and 4 are Ka
            elif typ in (2, 4):
                for i, val in enumerate(tail_kernel):
                    target_idx = idx + i
                    if 0 <= target_idx < T_sub:
                        ka_labels[target_idx] = max(
                            ka_labels[target_idx].item(), val.item()
                        )
            # Types 5, 6, and 7 are Drumroll
            elif 5 <= typ <= 7:
                f_end = int(round(t_end.item() * fps))
                idx_end = f_end // TIME_SUB
                # Hold full energy for the duration of the roll
                for dr in range(idx, idx_end):
                    if 0 <= dr < T_sub:
                        drumroll_labels[dr] = 1.0
                # Then decay from the end of the roll
                for i, val in enumerate(tail_kernel):
                    target_idx = idx_end + i
                    if 0 <= target_idx < T_sub:
                        drumroll_labels[target_idx] = max(
                            drumroll_labels[target_idx].item(), val.item()
                        )
    duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
    nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0
    print(
        f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}"
    )

    return {
        "mel": mel,
        "don_labels": don_labels,
        "ka_labels": ka_labels,
        "drumroll_labels": drumroll_labels,
        "nps": torch.tensor(nps, dtype=torch.float32),
        "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
    }
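
# Shape summary for the dict returned above: "mel" is (1, N_MELS, T), each of
# the three label tensors is 1-D with length ceil(T / TIME_SUB), and "nps" and
# "duration_seconds" are scalar tensors.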


def collate_fn(batch):
    mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch]  # (T, N_MELS)

    # Energy-based labels and per-item stats
    don_labels_list = [b["don_labels"] for b in batch]
    ka_labels_list = [b["ka_labels"] for b in batch]
    drumroll_labels_list = [b["drumroll_labels"] for b in batch]
    nps_list = [b["nps"] for b in batch]
    durations_list = [b["duration_seconds"] for b in batch]

    # Pad mels to the longest item in the batch
    padded_mels = nn.utils.rnn.pad_sequence(
        mels_list, batch_first=True
    )  # (B, T_max, N_MELS)

    # Reshape for the CNN front end: (B, 1, N_MELS, T_max)
    reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)

    # Run the CNN once to infer its time-downsampled output length T_cnn.
    # Note: this constructs a fresh TaikoConformer5 on every batch purely for
    # shape inference; caching a single instance at module level would avoid
    # the repeated allocation.
    dummy_model_for_shape_inference = TaikoConformer5()
    dummy_cnn = dummy_model_for_shape_inference.cnn
    with torch.no_grad():
        cnn_out = dummy_cnn(reshaped_mels)  # reshaped_mels already has a batch dim
        _, _, _, T_cnn = cnn_out.shape
    # Pad with 0 (silence energy) or truncate each label track so its length
    # matches the CNN output length T_cnn
    def _pad_or_truncate(labels: torch.Tensor, out_len: int) -> torch.Tensor:
        if labels.shape[0] < out_len:
            pad = torch.zeros(
                out_len - labels.shape[0], dtype=labels.dtype, device=labels.device
            )
            return torch.cat([labels, pad], dim=0)
        return labels[:out_len]

    padded_don_labels = [_pad_or_truncate(d, T_cnn) for d in don_labels_list]
    padded_ka_labels = [_pad_or_truncate(k, T_cnn) for k in ka_labels_list]
    padded_drumroll_labels = [
        _pad_or_truncate(dr, T_cnn) for dr in drumroll_labels_list
    ]
    # Lengths for the Conformer: per-item label lengths before padding,
    # i.e. ceil(T_mel / TIME_SUB), clamped to the CNN output length T_cnn.
    # This assumes the CNN's time downsampling factor matches TIME_SUB; if the
    # CNN stride ever changes, this calculation needs to be revisited.
    conformer_input_lengths = [
        math.ceil(mels_list[i].shape[0] / TIME_SUB) for i in range(len(batch))
    ]
    conformer_input_lengths = torch.tensor(
        [min(l, T_cnn) for l in conformer_input_lengths], dtype=torch.long
    )

    return {
        "mel": reshaped_mels,
        "don_labels": torch.stack(padded_don_labels),
        "ka_labels": torch.stack(padded_ka_labels),
        "drumroll_labels": torch.stack(padded_drumroll_labels),
        "lengths": conformer_input_lengths,  # time lengths for the Conformer
        "nps": torch.stack(nps_list),
        "durations": torch.stack(durations_list),
    }
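

# Smoke-test sketch (illustrative only, not part of the training pipeline).
# It builds a synthetic example matching the schema preprocess() expects:
# decoded audio plus an "oni" list of (type, t_start, t_end) entries with
# tensor-valued times; the note times below are arbitrary. Because of the
# relative imports, run it as a module (python -m <package>.preprocess).
if __name__ == "__main__":
    dummy_example = {
        "audio": {
            "array": torch.randn(SAMPLE_RATE * 2),  # 2 s of random noise
            "sampling_rate": SAMPLE_RATE,
        },
        "oni": [
            (1, torch.tensor(0.50), torch.tensor(0.50)),  # Don
            (2, torch.tensor(1.00), torch.tensor(1.00)),  # Ka
            (5, torch.tensor(1.25), torch.tensor(1.75)),  # Drumroll
        ],
    }
    features = preprocess(dummy_example)
    batch = collate_fn([features, features])
    print(batch["mel"].shape)  # (2, 1, N_MELS, T_max)
    print(batch["don_labels"].shape)  # (2, T_cnn)
    print(batch["lengths"])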