Spaces:
Sleeping
Sleeping
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torchaudio | |
| from torchaudio.transforms import FrequencyMasking | |
| from tja import parse_tja, PyParsingMode | |
| from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB | |
| from .model import TaikoConformer7 | |
# Shared, module-level audio front end used by preprocess().
# The mel extractor is configured from the project constants and uses a
# 2048-point FFT; the frequency mask hides up to 15 consecutive mel bins.
mel_transform = torchaudio.transforms.MelSpectrogram(
    n_fft=2048,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS,
    sample_rate=SAMPLE_RATE,
)
freq_mask = FrequencyMasking(freq_mask_param=15)
# Lazily-built model instance reused by every preprocess() call; it is only
# used to find out how long the CNN front end's output is for a given mel.
_shape_probe_model = None


def _cnn_output_length(mel):
    """Return the time length produced by the model's CNN for ``mel``.

    The model is constructed on first use and cached at module level, so
    repeated calls no longer pay for building a whole network per example.
    Only the model's ``cnn`` attribute is invoked, under ``torch.no_grad()``.
    """
    global _shape_probe_model
    if _shape_probe_model is None:
        _shape_probe_model = TaikoConformer7()
    with torch.no_grad():
        cnn_out = _shape_probe_model.cnn(mel.unsqueeze(0))  # (1, C, F, T_cnn)
    _, _, _, t_cnn = cnn_out.shape
    return t_cnn


def _stamp_decay_tail(labels, exact_pos, tail_kernel):
    """Max-merge a decaying tail into ``labels`` starting at ``exact_pos``.

    ``exact_pos`` is a (possibly fractional) index into ``labels``. When it
    lies within 1e-6 of an integer the tail starts on that frame with weight
    1.0; otherwise it is split linearly between the two neighbouring frames.
    A start frame outside ``labels`` is skipped entirely. ``labels`` is
    modified in place.
    """
    n_frames = labels.shape[0]
    nearest = round(exact_pos)
    if abs(exact_pos - nearest) < 1e-6:  # tolerance for float precision
        anchors = [(int(nearest), 1.0)]
    else:
        lower = math.floor(exact_pos)
        frac = exact_pos - lower
        anchors = [(lower, 1.0 - frac), (lower + 1, frac)]

    for start, weight in anchors:
        if weight <= 1e-6 or not (0 <= start < n_frames):
            continue
        stop = min(start + tail_kernel.shape[0], n_frames)
        # Multiply in float64 and round once, matching scalar Python-float
        # arithmetic followed by a store into the float32 label tensor.
        scaled = (tail_kernel[: stop - start].double() * weight).to(labels.dtype)
        labels[start:stop] = torch.maximum(labels[start:stop], scaled)


def _fit_length(label, out_len):
    """Zero-pad or truncate a 1-D label tensor to exactly ``out_len`` entries."""
    missing = out_len - label.shape[0]
    if missing > 0:
        pad = torch.zeros(missing, dtype=label.dtype)
        return torch.cat([label, pad], dim=0)
    return label[:out_len]


def _sliding_nps_labels(hit_times, drumrolls, n_frames, fps, window_sec=0.5):
    """Build the per-frame sliding notes-per-second label, scaled to 0-1.

    Args:
        hit_times: start times (seconds) of Don/Ka notes.
        drumrolls: ``(start_sec, duration_sec)`` pairs for drumroll notes.
        n_frames: number of label frames to produce.
        fps: mel frames per second (before ``TIME_SUB`` subsampling).
        window_sec: length of the trailing window ending at each frame.

    Returns:
        A float32 tensor of shape ``(n_frames,)``. Frame ``k`` covers the
        window ``[end_k - window_sec, end_k)`` with
        ``end_k = (k + 1) * TIME_SUB / fps``. Don/Ka hits count 1 each.
        A drumroll contributes ``overlap * rate`` where ``rate`` is the peak
        Don/Ka-only NPS observed during that drumroll, floored at 5.0.
        The result is min-max normalised; a constant positive signal becomes
        all ones and a constant zero signal stays all zeros.
    """
    from bisect import bisect_left

    hits = sorted(hit_times)
    rolls = sorted(drumrolls, key=lambda roll: roll[0])
    window_ends = [((k + 1) * TIME_SUB) / fps for k in range(n_frames)]

    # Step 1: Don/Ka hits per window. bisect_left(hits, x) is the number of
    # hits strictly before x, so the difference counts start <= t < end.
    hit_counts = [
        bisect_left(hits, end) - bisect_left(hits, end - window_sec)
        for end in window_ends
    ]
    base_nps = torch.tensor(
        [count / window_sec for count in hit_counts], dtype=torch.float32
    )

    # Step 2: adaptive rate for each drumroll.
    step_sec = TIME_SUB / fps  # duration of one label frame
    rates = []
    for roll_start, duration in rolls:
        lo = max(0, math.floor(roll_start / step_sec))
        hi = min(n_frames, math.ceil((roll_start + duration) / step_sec))
        peak = base_nps[lo:hi].max().item() if lo < hi else 0.0
        rates.append(max(5.0, peak))

    # Step 3: hits plus rate-weighted drumroll overlap, per window.
    totals = []
    for end, count in zip(window_ends, hit_counts):
        begin = end - window_sec
        total = float(count)
        for (roll_start, duration), rate in zip(rolls, rates):
            overlap_start = max(begin, roll_start)
            overlap_end = min(end, roll_start + duration)
            if overlap_end > overlap_start:
                total += (overlap_end - overlap_start) * rate
        totals.append(total / window_sec)
    labels = torch.tensor(totals, dtype=torch.float32)

    # Normalise to the 0-1 range.
    if n_frames > 0:
        low = torch.min(labels)
        high = torch.max(labels)
        span = high - low
        if span > 1e-6:
            labels = (labels - low) / span
        elif high > 1e-6:
            labels = torch.ones_like(labels)
        else:
            labels = torch.zeros_like(labels)
    return labels


def preprocess(example, difficulty="oni", *, augment=True):
    """Turn one dataset example into model inputs and frame-level labels.

    Args:
        example: mapping with
            ``"audio"``: dict holding ``"array"`` (waveform; it is used with
            ``.abs()`` and ``torch.randn_like``, so a torch tensor is
            expected) and ``"sampling_rate"``;
            ``difficulty`` key: iterable of ``(type, start, end, ...)`` notes
            whose ``start``/``end`` expose ``.item()`` and are treated as
            seconds;
            ``"tja"``: chart text handed to ``parse_tja``.
        difficulty: course name; selects the note list, the parsed chart
            (for ``level``) and the difficulty id (easy=0, normal=1, hard=2,
            oni=3, anything else=4).
        augment: when True (default, the original behaviour) Gaussian noise
            is added with probability 0.5 and frequency masking is applied
            to the mel. Pass False for deterministic features.

    Returns:
        dict with ``mel``, the four label tensors (``don_labels``,
        ``ka_labels``, ``drumroll_labels``, ``sliding_nps_labels``, each of
        length T_cnn), scalar tensors ``nps``, ``difficulty``, ``level``,
        ``duration_seconds``, and ``length`` (= T_cnn).
    """
    wav_tensor = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]

    # 1) resample, peak-normalise, optionally add noise
    if sr != SAMPLE_RATE:
        wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)
    wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)
    if augment and torch.rand(1).item() < 0.5:
        wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)

    # 2) mel: (1, N_MELS, T), optionally with SpecAugment frequency masking
    mel = mel_transform(wav_tensor).unsqueeze(0)
    if augment:
        mel = freq_mask(mel)
    _, _, T = mel.shape

    # 3) label sequences of length ceil(T / TIME_SUB)
    T_sub = math.ceil(T / TIME_SUB)
    don_labels = torch.zeros(T_sub, dtype=torch.float32)
    ka_labels = torch.zeros(T_sub, dtype=torch.float32)
    drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)

    # Exponential decay tail stamped after every onset / drumroll end.
    tail_length = 40  # number of label frames covered by the tail
    decay_rate = 8.0
    tail_kernel = torch.exp(
        -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
    )

    fps = SAMPLE_RATE / HOP_LENGTH
    num_valid_notes = 0
    hit_times = []  # Don/Ka start times, for the sliding NPS label
    drumrolls = []  # (start_sec, duration_sec), for the sliding NPS label

    for onset in example[difficulty]:
        typ, t_start_tensor, t_end_tensor, *_ = onset
        t_start = t_start_tensor.item()
        t_end = t_end_tensor.item()

        # Types 1-4 are Don/Ka hits, 5-7 are drumrolls.
        is_hit = typ in (1, 2, 3, 4)
        is_drumroll = (not is_hit) and (typ >= 5 and typ <= 7)

        # NPS bookkeeping is intentionally not gated by N_TYPES.
        if is_hit:
            hit_times.append(t_start)
        elif is_drumroll:
            drumrolls.append((t_start, t_end - t_start))

        if typ < 1 or typ > N_TYPES:  # filter out invalid types
            continue
        num_valid_notes += 1

        start_pos = t_start * fps / TIME_SUB  # fractional label-frame index
        if is_hit:
            # Types 1 and 3 are Don, types 2 and 4 are Ka.
            target = don_labels if (typ == 1 or typ == 3) else ka_labels
            _stamp_decay_tail(target, start_pos, tail_kernel)
        elif is_drumroll:
            end_pos = t_end * fps / TIME_SUB
            # Body: every frame touched by the roll is fully active.
            body_lo = max(0, math.floor(start_pos))
            body_hi = min(T_sub, math.ceil(end_pos))
            if body_lo < body_hi:
                drumroll_labels[body_lo:body_hi] = 1.0
            # Tail: decays from the exact end position.
            _stamp_decay_tail(drumroll_labels, end_pos, tail_kernel)

    sliding_nps_labels = _sliding_nps_labels(hit_times, drumrolls, T_sub, fps)

    duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
    nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0

    parsed = parse_tja(example["tja"], mode=PyParsingMode.Full)
    chart = next(
        (chart for chart in parsed.charts if chart.course.lower() == difficulty),
        None,
    )
    # Any course name not listed here (e.g. edit/ura) maps to 4.
    difficulty_id = {"easy": 0, "normal": 1, "hard": 2, "oni": 3}.get(difficulty, 4)
    level = chart.level if chart else 0

    # --- CNN shape inference and label padding/truncation ---
    T_cnn = _cnn_output_length(mel)
    don_labels = _fit_length(don_labels, T_cnn)
    ka_labels = _fit_length(ka_labels, T_cnn)
    drumroll_labels = _fit_length(drumroll_labels, T_cnn)
    sliding_nps_labels = _fit_length(sliding_nps_labels, T_cnn)

    print(
        f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}, Difficulty: {difficulty_id}, Level: {level}"
    )

    return {
        "mel": mel,  # (1, N_MELS, T)
        "don_labels": don_labels,  # (T_cnn,)
        "ka_labels": ka_labels,  # (T_cnn,)
        "drumroll_labels": drumroll_labels,  # (T_cnn,)
        "sliding_nps_labels": sliding_nps_labels,  # (T_cnn,)
        "nps": torch.tensor(nps, dtype=torch.float32),
        "difficulty": torch.tensor(difficulty_id, dtype=torch.long),
        "level": torch.tensor(level, dtype=torch.long),
        "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
        # Sequence length after the CNN; used for conformer and loss masking.
        "length": torch.tensor(T_cnn, dtype=torch.long),
    }
def collate_fn(batch):
    """Collate preprocessed examples into one padded batch.

    Args:
        batch: non-empty list of dicts as returned by ``preprocess``. Each
            holds ``"mel"`` (expected shape ``(1, N_MELS, T)``), the 1-D
            label tensors ``"don_labels"``, ``"ka_labels"``,
            ``"drumroll_labels"``, ``"sliding_nps_labels"``, and 0-dim
            tensors ``"nps"``, ``"difficulty"``, ``"level"``,
            ``"duration_seconds"``, ``"length"``.

    Returns:
        dict with
            ``"mel"``: ``(B, 1, N_MELS, T_max)``, zero-padded along time;
            the four label entries: ``(B, L_max)`` where ``L_max`` is the
            largest ``"length"`` in the batch (labels are zero-padded, or
            truncated if longer than ``L_max``);
            ``"lengths"``: ``(B,)`` long tensor of the per-example lengths;
            ``"nps"``, ``"difficulty"``, ``"level"``, ``"durations"``:
            stacked per-example scalars.
    """
    # pad_sequence pads along dim 0, so put time first: (T, N_MELS).
    mels = [item["mel"].squeeze(0).transpose(0, 1) for item in batch]
    padded_mels = nn.utils.rnn.pad_sequence(mels, batch_first=True)  # (B, T_max, N_MELS)
    reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)  # (B, 1, N_MELS, T_max)

    # Labels are padded to the longest post-CNN length, not the mel length.
    length_values = [item["length"].item() for item in batch]
    max_label_len = max(length_values)

    def fit(label):
        """Zero-pad or truncate one 1-D label to ``max_label_len``."""
        missing = max_label_len - label.shape[0]
        if missing > 0:
            # Keep the padding on the same device/dtype as the label.
            padding = torch.zeros(missing, dtype=label.dtype, device=label.device)
            return torch.cat((label, padding), dim=0)
        return label[:max_label_len]

    collated = {"mel": reshaped_mels}
    for key in ("don_labels", "ka_labels", "drumroll_labels", "sliding_nps_labels"):
        collated[key] = torch.stack([fit(item[key]) for item in batch])

    # Per-example lengths, for conformer and loss masking.
    collated["lengths"] = torch.tensor(length_values, dtype=torch.long)
    collated["nps"] = torch.stack([item["nps"] for item in batch])
    collated["difficulty"] = torch.stack([item["difficulty"] for item in batch])
    collated["level"] = torch.stack([item["level"] for item in batch])
    collated["durations"] = torch.stack([item["duration_seconds"] for item in batch])
    return collated