|  |  | 
					
						
						|  | __author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' | 
					
						
						|  |  | 
					
						
						|  | import argparse | 
					
						
						|  | import time | 
					
						
						|  | import os | 
					
						
						|  | import glob | 
					
						
						|  | import torch | 
					
						
						|  | import librosa | 
					
						
						|  | import numpy as np | 
					
						
						|  | import soundfile as sf | 
					
						
						|  | from tqdm.auto import tqdm | 
					
						
						|  | from ml_collections import ConfigDict | 
					
						
						|  | from typing import Tuple, Dict, List, Union | 
					
						
						|  | from utils import demix, get_model_from_config, prefer_target_instrument, draw_spectrogram | 
					
						
						|  | from utils import normalize_audio, denormalize_audio, apply_tta, read_audio_transposed, load_start_checkpoint | 
					
						
						|  | from metrics import get_metrics | 
					
						
						|  | import warnings | 
					
						
						|  |  | 
					
						
						|  | warnings.filterwarnings("ignore") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def logging(logs: List[str], text: str, verbose_logging: bool = False) -> None: | 
					
						
						|  | """ | 
					
						
						|  | Log validation information by printing the text and appending it to a log list. | 
					
						
						|  |  | 
					
						
						|  | Parameters: | 
					
						
						|  | ---------- | 
					
						
						|  | store_dir : str | 
					
						
						|  | Directory to store the logs. If empty, logs are not stored. | 
					
						
						|  | logs : List[str] | 
					
						
						|  | List where the logs will be appended if the store_dir is specified. | 
					
						
						|  | text : str | 
					
						
						|  | The text to be logged, printed, and optionally added to the logs list. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | ------- | 
					
						
						|  | None | 
					
						
						|  | This function modifies the logs list in place and prints the text. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | print(text) | 
					
						
						|  | if verbose_logging: | 
					
						
						|  | logs.append(text) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def write_results_in_file(store_dir: str, logs: List[str]) -> None: | 
					
						
						|  | """ | 
					
						
						|  | Write the list of results into a file in the specified directory. | 
					
						
						|  |  | 
					
						
						|  | Parameters: | 
					
						
						|  | ---------- | 
					
						
						|  | store_dir : str | 
					
						
						|  | The directory where the results file will be saved. | 
					
						
						|  | results : List[str] | 
					
						
						|  | A list of result strings to be written to the file. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | ------- | 
					
						
						|  | None | 
					
						
						|  | """ | 
					
						
						|  | with open(f'{store_dir}/results.txt', 'w') as out: | 
					
						
						|  | for item in logs: | 
					
						
						|  | out.write(item + "\n") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_mixture_paths( | 
					
						
						|  | args, | 
					
						
						|  | verbose: bool, | 
					
						
						|  | config: ConfigDict, | 
					
						
						|  | extension: str | 
					
						
						|  | ) -> List[str]: | 
					
						
						|  | """ | 
					
						
						|  | Retrieve paths to mixture files in the specified validation directories. | 
					
						
						|  |  | 
					
						
						|  | Parameters: | 
					
						
						|  | ---------- | 
					
						
						|  | valid_path : List[str] | 
					
						
						|  | A list of directories to search for validation mixtures. | 
					
						
						|  | verbose : bool | 
					
						
						|  | If True, prints detailed information about the search process. | 
					
						
						|  | config : ConfigDict | 
					
						
						|  | Configuration object containing parameters like `inference.num_overlap` and `inference.batch_size`. | 
					
						
						|  | extension : str | 
					
						
						|  | File extension of the mixture files (e.g., 'wav'). | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | ------- | 
					
						
						|  | List[str] | 
					
						
						|  | A list of file paths to the mixture files. | 
					
						
						|  | """ | 
					
						
						|  | try: | 
					
						
						|  | valid_path = args.valid_path | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print('No valid path in args') | 
					
						
						|  | raise e | 
					
						
						|  |  | 
					
						
						|  | all_mixtures_path = [] | 
					
						
						|  | for path in valid_path: | 
					
						
						|  | part = sorted(glob.glob(f"{path}/*/mixture.{extension}")) | 
					
						
						|  | if len(part) == 0: | 
					
						
						|  | if verbose: | 
					
						
						|  | print(f'No validation data found in: {path}') | 
					
						
						|  | all_mixtures_path += part | 
					
						
						|  | if verbose: | 
					
						
						|  | print(f'Total mixtures: {len(all_mixtures_path)}') | 
					
						
						|  | print(f'Overlap: {config.inference.num_overlap} Batch size: {config.inference.batch_size}') | 
					
						
						|  |  | 
					
						
						|  | return all_mixtures_path | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def update_metrics_and_pbar( | 
					
						
						|  | track_metrics: Dict, | 
					
						
						|  | all_metrics: Dict, | 
					
						
						|  | instr: str, | 
					
						
						|  | pbar_dict: Dict, | 
					
						
						|  | mixture_paths: Union[List[str], tqdm], | 
					
						
						|  | verbose: bool = False | 
					
						
						|  | ) -> None: | 
					
						
						|  | """ | 
					
						
						|  | Update metrics dictionary and progress bar with new metric values. | 
					
						
						|  |  | 
					
						
						|  | Parameters: | 
					
						
						|  | ---------- | 
					
						
						|  | track_metrics : Dict | 
					
						
						|  | Dictionary with metric names as keys and their computed values as values. | 
					
						
						|  | all_metrics : Dict | 
					
						
						|  | Dictionary to store all metrics, organized by metric name and instrument. | 
					
						
						|  | instr : str | 
					
						
						|  | Name of the instrument for which the metrics are being computed. | 
					
						
						|  | pbar_dict : Dict | 
					
						
						|  | Dictionary for progress bar updates. | 
					
						
						|  | mixture_paths : tqdm, optional | 
					
						
						|  | Progress bar object, if available. Default is None. | 
					
						
						|  | verbose : bool, optional | 
					
						
						|  | If True, prints metric values to the console. Default is False. | 
					
						
						|  | """ | 
					
						
						|  | for metric_name, metric_value in track_metrics.items(): | 
					
						
						|  | if verbose: | 
					
						
						|  | print(f"Metric {metric_name:11s} value: {metric_value:.4f}") | 
					
						
						|  | all_metrics[metric_name][instr].append(metric_value) | 
					
						
						|  | pbar_dict[f'{metric_name}_{instr}'] = metric_value | 
					
						
						|  |  | 
					
						
						|  | if mixture_paths is not None: | 
					
						
						|  | try: | 
					
						
						|  | mixture_paths.set_postfix(pbar_dict) | 
					
						
						|  | except Exception: | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def process_audio_files( | 
					
						
						|  | mixture_paths: List[str], | 
					
						
						|  | model: torch.nn.Module, | 
					
						
						|  | args, | 
					
						
						|  | config, | 
					
						
						|  | device: torch.device, | 
					
						
						|  | verbose: bool = False, | 
					
						
						|  | is_tqdm: bool = True | 
					
						
						|  | ) -> Dict[str, Dict[str, List[float]]]: | 
					
						
						|  | """ | 
					
						
						|  | Process a list of audio files, perform source separation, and evaluate metrics. | 
					
						
						|  |  | 
					
						
						|  | Parameters: | 
					
						
						|  | ---------- | 
					
						
						|  | mixture_paths : List[str] | 
					
						
						|  | List of file paths to the audio mixtures. | 
					
						
						|  | model : torch.nn.Module | 
					
						
						|  | The trained model used for source separation. | 
					
						
						|  | args : Any | 
					
						
						|  | Argument object containing user-specified options like metrics, model type, etc. | 
					
						
						|  | config : Any | 
					
						
						|  | Configuration object containing model and processing parameters. | 
					
						
						|  | device : torch.device | 
					
						
						|  | Device (CPU or CUDA) on which the model will be executed. | 
					
						
						|  | verbose : bool, optional | 
					
						
						|  | If True, prints detailed logs for each processed file. Default is False. | 
					
						
						|  | is_tqdm : bool, optional | 
					
						
						|  | If True, displays a progress bar for file processing. Default is True. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | ------- | 
					
						
						|  | Dict[str, Dict[str, List[float]]] | 
					
						
						|  | A nested dictionary where the outer keys are metric names, | 
					
						
						|  | the inner keys are instrument names, and the values are lists of metric scores. | 
					
						
						|  | """ | 
					
						
						|  | instruments = prefer_target_instrument(config) | 
					
						
						|  |  | 
					
						
						|  | use_tta = getattr(args, 'use_tta', False) | 
					
						
						|  |  | 
					
						
						|  | store_dir = getattr(args, 'store_dir', '') | 
					
						
						|  |  | 
					
						
						|  | if 'extension' in config['inference']: | 
					
						
						|  | extension = config['inference']['extension'] | 
					
						
						|  | else: | 
					
						
						|  | extension = getattr(args, 'extension', 'wav') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | all_metrics = { | 
					
						
						|  | metric: {instr: [] for instr in config.training.instruments} | 
					
						
						|  | for metric in args.metrics | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | if is_tqdm: | 
					
						
						|  | mixture_paths = tqdm(mixture_paths) | 
					
						
						|  |  | 
					
						
						|  | for path in mixture_paths: | 
					
						
						|  | start_time = time.time() | 
					
						
						|  | mix, sr = read_audio_transposed(path) | 
					
						
						|  | mix_orig = mix.copy() | 
					
						
						|  | folder = os.path.dirname(path) | 
					
						
						|  |  | 
					
						
						|  | if 'sample_rate' in config.audio: | 
					
						
						|  | if sr != config.audio['sample_rate']: | 
					
						
						|  | orig_length = mix.shape[-1] | 
					
						
						|  | if verbose: | 
					
						
						|  | print(f'Warning: sample rate is different. In config: {config.audio["sample_rate"]} in file {path}: {sr}') | 
					
						
						|  | mix = librosa.resample(mix, orig_sr=sr, target_sr=config.audio['sample_rate'], res_type='kaiser_best') | 
					
						
						|  |  | 
					
						
						|  | if verbose: | 
					
						
						|  | folder_name = os.path.abspath(folder) | 
					
						
						|  | print(f'Song: {folder_name} Shape: {mix.shape}') | 
					
						
						|  |  | 
					
						
						|  | if 'normalize' in config.inference: | 
					
						
						|  | if config.inference['normalize'] is True: | 
					
						
						|  | mix, norm_params = normalize_audio(mix) | 
					
						
						|  |  | 
					
						
						|  | waveforms_orig = demix(config, model, mix.copy(), device, model_type=args.model_type) | 
					
						
						|  |  | 
					
						
						|  | if use_tta: | 
					
						
						|  | waveforms_orig = apply_tta(config, model, mix, waveforms_orig, device, args.model_type) | 
					
						
						|  |  | 
					
						
						|  | pbar_dict = {} | 
					
						
						|  |  | 
					
						
						|  | for instr in instruments: | 
					
						
						|  | if verbose: | 
					
						
						|  | print(f"Instr: {instr}") | 
					
						
						|  |  | 
					
						
						|  | if instr != 'other' or config.training.other_fix is False: | 
					
						
						|  | track, sr1 = read_audio_transposed(f"{folder}/{instr}.{extension}", instr, skip_err=True) | 
					
						
						|  | if track is None: | 
					
						
						|  | continue | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | track, sr1 = read_audio_transposed(f"{folder}/vocals.{extension}") | 
					
						
						|  | track = mix_orig - track | 
					
						
						|  |  | 
					
						
						|  | estimates = waveforms_orig[instr] | 
					
						
						|  |  | 
					
						
						|  | if 'sample_rate' in config.audio: | 
					
						
						|  | if sr != config.audio['sample_rate']: | 
					
						
						|  | estimates = librosa.resample(estimates, orig_sr=config.audio['sample_rate'], target_sr=sr, | 
					
						
						|  | res_type='kaiser_best') | 
					
						
						|  | estimates = librosa.util.fix_length(estimates, size=orig_length) | 
					
						
						|  |  | 
					
						
						|  | if 'normalize' in config.inference: | 
					
						
						|  | if config.inference['normalize'] is True: | 
					
						
						|  | estimates = denormalize_audio(estimates, norm_params) | 
					
						
						|  |  | 
					
						
						|  | if store_dir: | 
					
						
						|  | os.makedirs(store_dir, exist_ok=True) | 
					
						
						|  | out_wav_name = f"{store_dir}/{os.path.basename(folder)}_{instr}.wav" | 
					
						
						|  | sf.write(out_wav_name, estimates.T, sr, subtype='FLOAT') | 
					
						
						|  | if args.draw_spectro > 0: | 
					
						
						|  | out_img_name = f"{store_dir}/{os.path.basename(folder)}_{instr}.jpg" | 
					
						
						|  | draw_spectrogram(estimates.T, sr, args.draw_spectro, out_img_name) | 
					
						
						|  | out_img_name_orig = f"{store_dir}/{os.path.basename(folder)}_{instr}_orig.jpg" | 
					
						
						|  | draw_spectrogram(track.T, sr, args.draw_spectro, out_img_name_orig) | 
					
						
						|  |  | 
					
						
						|  | track_metrics = get_metrics( | 
					
						
						|  | args.metrics, | 
					
						
						|  | track, | 
					
						
						|  | estimates, | 
					
						
						|  | mix_orig, | 
					
						
						|  | device=device, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | update_metrics_and_pbar( | 
					
						
						|  | track_metrics, | 
					
						
						|  | all_metrics, | 
					
						
						|  | instr, pbar_dict, | 
					
						
						|  | mixture_paths=mixture_paths, | 
					
						
						|  | verbose=verbose | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if verbose: | 
					
						
						|  | print(f"Time for song: {time.time() - start_time:.2f} sec") | 
					
						
						|  |  | 
					
						
						|  | return all_metrics | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def compute_metric_avg( | 
					
						
						|  | store_dir: str, | 
					
						
						|  | args, | 
					
						
						|  | instruments: List[str], | 
					
						
						|  | config: ConfigDict, | 
					
						
						|  | all_metrics: Dict[str, Dict[str, List[float]]], | 
					
						
						|  | start_time: float | 
					
						
						|  | ) -> Dict[str, float]: | 
					
						
						|  | """ | 
					
						
						|  | Calculate and log the average metrics for each instrument, including per-instrument metrics and overall averages. | 
					
						
						|  |  | 
					
						
						|  | Parameters: | 
					
						
						|  | ---------- | 
					
						
						|  | store_dir : str | 
					
						
						|  | Directory to store the logs. If empty, logs are not stored. | 
					
						
						|  | args : dict | 
					
						
						|  | Dictionary containing the arguments, used for logging. | 
					
						
						|  | instruments : List[str] | 
					
						
						|  | List of instruments to process. | 
					
						
						|  | config : ConfigDict | 
					
						
						|  | Configuration dictionary containing the inference settings. | 
					
						
						|  | all_metrics : Dict[str, Dict[str, List[float]]] | 
					
						
						|  | A dictionary containing metric values for each instrument. | 
					
						
						|  | The structure is {metric_name: {instrument_name: [metric_values]}}. | 
					
						
						|  | start_time : float | 
					
						
						|  | The starting time for calculating elapsed time. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | ------- | 
					
						
						|  | Dict[str, float] | 
					
						
						|  | A dictionary with the average value for each metric across all instruments. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | logs = [] | 
					
						
						|  | if store_dir: | 
					
						
						|  | logs.append(str(args)) | 
					
						
						|  | verbose_logging = True | 
					
						
						|  | else: | 
					
						
						|  | verbose_logging = False | 
					
						
						|  |  | 
					
						
						|  | logging(logs, text=f"Num overlap: {config.inference.num_overlap}", verbose_logging=verbose_logging) | 
					
						
						|  |  | 
					
						
						|  | metric_avg = {} | 
					
						
						|  | for instr in instruments: | 
					
						
						|  | for metric_name in all_metrics: | 
					
						
						|  | metric_values = np.array(all_metrics[metric_name][instr]) | 
					
						
						|  |  | 
					
						
						|  | mean_val = metric_values.mean() | 
					
						
						|  | std_val = metric_values.std() | 
					
						
						|  |  | 
					
						
						|  | logging(logs, text=f"Instr {instr} {metric_name}: {mean_val:.4f} (Std: {std_val:.4f})", verbose_logging=verbose_logging) | 
					
						
						|  | if metric_name not in metric_avg: | 
					
						
						|  | metric_avg[metric_name] = 0.0 | 
					
						
						|  | metric_avg[metric_name] += mean_val | 
					
						
						|  | for metric_name in all_metrics: | 
					
						
						|  | metric_avg[metric_name] /= len(instruments) | 
					
						
						|  |  | 
					
						
						|  | if len(instruments) > 1: | 
					
						
						|  | for metric_name in metric_avg: | 
					
						
						|  | logging(logs, text=f'Metric avg {metric_name:11s}: {metric_avg[metric_name]:.4f}', verbose_logging=verbose_logging) | 
					
						
						|  | logging(logs, text=f"Elapsed time: {time.time() - start_time:.2f} sec", verbose_logging=verbose_logging) | 
					
						
						|  |  | 
					
						
						|  | if store_dir: | 
					
						
						|  | write_results_in_file(store_dir, logs) | 
					
						
						|  |  | 
					
						
						|  | return metric_avg | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def valid( | 
					
						
						|  | model: torch.nn.Module, | 
					
						
						|  | args, | 
					
						
						|  | config: ConfigDict, | 
					
						
						|  | device: torch.device, | 
					
						
						|  | verbose: bool = False | 
					
						
						|  | ) -> Tuple[dict, dict]: | 
					
						
						|  | """ | 
					
						
						|  | Validate a trained model on a set of audio mixtures and compute metrics. | 
					
						
						|  |  | 
					
						
						|  | This function performs validation by separating audio sources from mixtures, | 
					
						
						|  | computing evaluation metrics, and optionally saving results to a file. | 
					
						
						|  |  | 
					
						
						|  | Parameters: | 
					
						
						|  | ---------- | 
					
						
						|  | model : torch.nn.Module | 
					
						
						|  | The trained model for source separation. | 
					
						
						|  | args : Namespace | 
					
						
						|  | Command-line arguments or equivalent object containing configurations. | 
					
						
						|  | config : dict | 
					
						
						|  | Configuration dictionary with model and processing parameters. | 
					
						
						|  | device : torch.device | 
					
						
						|  | The device (CPU or CUDA) to run the model on. | 
					
						
						|  | verbose : bool, optional | 
					
						
						|  | If True, enables verbose output during processing. Default is False. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | ------- | 
					
						
						|  | dict | 
					
						
						|  | A dictionary of average metrics across all instruments. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | start_time = time.time() | 
					
						
						|  | model.eval().to(device) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | store_dir = getattr(args, 'store_dir', '') | 
					
						
						|  |  | 
					
						
						|  | if 'extension' in config['inference']: | 
					
						
						|  | extension = config['inference']['extension'] | 
					
						
						|  | else: | 
					
						
						|  | extension = getattr(args, 'extension', 'wav') | 
					
						
						|  |  | 
					
						
						|  | all_mixtures_path = get_mixture_paths(args, verbose, config, extension) | 
					
						
						|  | all_metrics = process_audio_files(all_mixtures_path, model, args, config, device, verbose, not verbose) | 
					
						
						|  | instruments = prefer_target_instrument(config) | 
					
						
						|  |  | 
					
						
						|  | return compute_metric_avg(store_dir, args, instruments, config, all_metrics, start_time), all_metrics | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def validate_in_subprocess( | 
					
						
						|  | proc_id: int, | 
					
						
						|  | queue: torch.multiprocessing.Queue, | 
					
						
						|  | all_mixtures_path: List[str], | 
					
						
						|  | model: torch.nn.Module, | 
					
						
						|  | args, | 
					
						
						|  | config: ConfigDict, | 
					
						
						|  | device: str, | 
					
						
						|  | return_dict | 
					
						
						|  | ) -> None: | 
					
						
						|  | """ | 
					
						
						|  | Perform validation on a subprocess with multi-processing support. Each process handles inference on a subset of the mixture files | 
					
						
						|  | and updates the shared metrics dictionary. | 
					
						
						|  |  | 
					
						
						|  | Parameters: | 
					
						
						|  | ---------- | 
					
						
						|  | proc_id : int | 
					
						
						|  | The process ID (used to assign metrics to the correct key in `return_dict`). | 
					
						
						|  | queue : torch.multiprocessing.Queue | 
					
						
						|  | Queue to receive paths to the mixture files for processing. | 
					
						
						|  | all_mixtures_path : List[str] | 
					
						
						|  | List of paths to the mixture files to be processed. | 
					
						
						|  | model : torch.nn.Module | 
					
						
						|  | The model to be used for inference. | 
					
						
						|  | args : dict | 
					
						
						|  | Dictionary containing various argument configurations (e.g., metrics to calculate). | 
					
						
						|  | config : ConfigDict | 
					
						
						|  | Configuration object containing model settings and training parameters. | 
					
						
						|  | device : str | 
					
						
						|  | The device to use for inference (e.g., 'cpu', 'cuda:0'). | 
					
						
						|  | return_dict : torch.multiprocessing.Manager().dict | 
					
						
						|  | Shared dictionary to store the results from each process. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | ------- | 
					
						
						|  | None | 
					
						
						|  | The function modifies the `return_dict` in place, but does not return any value. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | m1 = model.eval().to(device) | 
					
						
						|  | if proc_id == 0: | 
					
						
						|  | progress_bar = tqdm(total=len(all_mixtures_path)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | all_metrics = { | 
					
						
						|  | metric: {instr: [] for instr in config.training.instruments} | 
					
						
						|  | for metric in args.metrics | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | while True: | 
					
						
						|  | current_step, path = queue.get() | 
					
						
						|  | if path is None: | 
					
						
						|  | break | 
					
						
						|  | single_metrics = process_audio_files([path], m1, args, config, device, False, False) | 
					
						
						|  | pbar_dict = {} | 
					
						
						|  | for instr in config.training.instruments: | 
					
						
						|  | for metric_name in all_metrics: | 
					
						
						|  | all_metrics[metric_name][instr] += single_metrics[metric_name][instr] | 
					
						
						|  | if len(single_metrics[metric_name][instr]) > 0: | 
					
						
						|  | pbar_dict[f"{metric_name}_{instr}"] = f"{single_metrics[metric_name][instr][0]:.4f}" | 
					
						
						|  | if proc_id == 0: | 
					
						
						|  | progress_bar.update(current_step - progress_bar.n) | 
					
						
						|  | progress_bar.set_postfix(pbar_dict) | 
					
						
						|  |  | 
					
						
						|  | return_dict[proc_id] = all_metrics | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def run_parallel_validation( | 
					
						
						|  | verbose: bool, | 
					
						
						|  | all_mixtures_path: List[str], | 
					
						
						|  | config: ConfigDict, | 
					
						
						|  | model: torch.nn.Module, | 
					
						
						|  | device_ids: List[int], | 
					
						
						|  | args, | 
					
						
						|  | return_dict | 
					
						
						|  | ) -> None: | 
					
						
						|  | """ | 
					
						
						|  | Run parallel validation using multiple processes. Each process handles a subset of the mixture files and computes the metrics. | 
					
						
						|  | The results are stored in a shared dictionary. | 
					
						
						|  |  | 
					
						
						|  | Parameters: | 
					
						
						|  | ---------- | 
					
						
						|  | verbose : bool | 
					
						
						|  | Flag to print detailed information about the validation process. | 
					
						
						|  | all_mixtures_path : List[str] | 
					
						
						|  | List of paths to the mixture files to be processed. | 
					
						
						|  | config : ConfigDict | 
					
						
						|  | Configuration object containing model settings and validation parameters. | 
					
						
						|  | model : torch.nn.Module | 
					
						
						|  | The model to be used for inference. | 
					
						
						|  | device_ids : List[int] | 
					
						
						|  | List of device IDs (for multi-GPU setups) to use for validation. | 
					
						
						|  | args : dict | 
					
						
						|  | Dictionary containing various argument configurations (e.g., metrics to calculate). | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | ------- | 
					
						
						|  | A shared dictionary containing the validation metrics from all processes. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | model = model.to('cpu') | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | model = model.module | 
					
						
						|  | except: | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  | queue = torch.multiprocessing.Queue() | 
					
						
						|  | processes = [] | 
					
						
						|  |  | 
					
						
						|  | for i, device in enumerate(device_ids): | 
					
						
						|  | if torch.cuda.is_available(): | 
					
						
						|  | device = f'cuda:{device}' | 
					
						
						|  | else: | 
					
						
						|  | device = 'cpu' | 
					
						
						|  | p = torch.multiprocessing.Process( | 
					
						
						|  | target=validate_in_subprocess, | 
					
						
						|  | args=(i, queue, all_mixtures_path, model, args, config, device, return_dict) | 
					
						
						|  | ) | 
					
						
						|  | p.start() | 
					
						
						|  | processes.append(p) | 
					
						
						|  | for i, path in enumerate(all_mixtures_path): | 
					
						
						|  | queue.put((i, path)) | 
					
						
						|  | for _ in range(len(device_ids)): | 
					
						
						|  | queue.put((None, None)) | 
					
						
						|  | for p in processes: | 
					
						
						|  | p.join() | 
					
						
						|  |  | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def valid_multi_gpu( | 
					
						
						|  | model: torch.nn.Module, | 
					
						
						|  | args, | 
					
						
						|  | config: ConfigDict, | 
					
						
						|  | device_ids: List[int], | 
					
						
						|  | verbose: bool = False | 
					
						
						|  | ) -> Tuple[Dict[str, float], dict]: | 
					
						
						|  | """ | 
					
						
						|  | Perform validation across multiple GPUs, processing mixtures and computing metrics using parallel processes. | 
					
						
						|  | The results from each GPU are aggregated and the average metrics are computed. | 
					
						
						|  |  | 
					
						
						|  | Parameters: | 
					
						
						|  | ---------- | 
					
						
						|  | model : torch.nn.Module | 
					
						
						|  | The model to be used for inference. | 
					
						
						|  | args : dict | 
					
						
						|  | Dictionary containing various argument configurations, such as file saving directory and codec settings. | 
					
						
						|  | config : ConfigDict | 
					
						
						|  | Configuration object containing model settings and validation parameters. | 
					
						
						|  | device_ids : List[int] | 
					
						
						|  | List of device IDs (for multi-GPU setups) to use for validation. | 
					
						
						|  | verbose : bool, optional | 
					
						
						|  | Flag to print detailed information about the validation process. Default is False. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | ------- | 
					
						
						|  | Dict[str, float] | 
					
						
						|  | A dictionary containing the average metrics for each metric name. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | start_time = time.time() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | store_dir = getattr(args, 'store_dir', '') | 
					
						
						|  |  | 
					
						
						|  | if 'extension' in config['inference']: | 
					
						
						|  | extension = config['inference']['extension'] | 
					
						
						|  | else: | 
					
						
						|  | extension = getattr(args, 'extension', 'wav') | 
					
						
						|  |  | 
					
						
						|  | all_mixtures_path = get_mixture_paths(args, verbose, config, extension) | 
					
						
						|  |  | 
					
						
						|  | return_dict = torch.multiprocessing.Manager().dict() | 
					
						
						|  |  | 
					
						
						|  | run_parallel_validation(verbose, all_mixtures_path, config, model, device_ids, args, return_dict) | 
					
						
						|  |  | 
					
						
						|  | all_metrics = dict() | 
					
						
						|  | for metric in args.metrics: | 
					
						
						|  | all_metrics[metric] = dict() | 
					
						
						|  | for instr in config.training.instruments: | 
					
						
						|  | all_metrics[metric][instr] = [] | 
					
						
						|  | for i in range(len(device_ids)): | 
					
						
						|  | all_metrics[metric][instr] += return_dict[i][metric][instr] | 
					
						
						|  |  | 
					
						
						|  | instruments = prefer_target_instrument(config) | 
					
						
						|  |  | 
					
						
						|  | return compute_metric_avg(store_dir, args, instruments, config, all_metrics, start_time), all_metrics | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def parse_args(dict_args: Union[Dict, None]) -> argparse.Namespace: | 
					
						
						|  | """ | 
					
						
						|  | Parse command-line arguments for configuring the model, dataset, and training parameters. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | dict_args: Dict of command-line arguments. If None, arguments will be parsed from sys.argv. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | Namespace object containing parsed arguments and their values. | 
					
						
						|  | """ | 
					
						
						|  | parser = argparse.ArgumentParser() | 
					
						
						|  | parser.add_argument("--model_type", type=str, default='mdx23c', | 
					
						
						|  | help="One of mdx23c, htdemucs, segm_models, mel_band_roformer," | 
					
						
						|  | " bs_roformer, swin_upernet, bandit") | 
					
						
						|  | parser.add_argument("--config_path", type=str, help="Path to config file") | 
					
						
						|  | parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint" | 
					
						
						|  | " to valid weights") | 
					
						
						|  | parser.add_argument("--valid_path", nargs="+", type=str, help="Validate path") | 
					
						
						|  | parser.add_argument("--store_dir", type=str, default="", help="Path to store results as wav file") | 
					
						
						|  | parser.add_argument("--draw_spectro", type=float, default=0, | 
					
						
						|  | help="If --store_dir is set then code will generate spectrograms for resulted stems as well." | 
					
						
						|  | " Value defines for how many seconds os track spectrogram will be generated.") | 
					
						
						|  | parser.add_argument("--device_ids", nargs='+', type=int, default=0, help='List of gpu ids') | 
					
						
						|  | parser.add_argument("--num_workers", type=int, default=0, help="Dataloader num_workers") | 
					
						
						|  | parser.add_argument("--pin_memory", action='store_true', help="Dataloader pin_memory") | 
					
						
						|  | parser.add_argument("--extension", type=str, default='wav', help="Choose extension for validation") | 
					
						
						|  | parser.add_argument("--use_tta", action='store_true', | 
					
						
						|  | help="Flag adds test time augmentation during inference (polarity and channel inverse)." | 
					
						
						|  | "While this triples the runtime, it reduces noise and slightly improves prediction quality.") | 
					
						
						|  | parser.add_argument("--metrics", nargs='+', type=str, default=["sdr"], | 
					
						
						|  | choices=['sdr', 'l1_freq', 'si_sdr', 'neg_log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless', | 
					
						
						|  | 'fullness'], help='List of metrics to use.') | 
					
						
						|  | parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint to LoRA weights") | 
					
						
						|  |  | 
					
						
						|  | if dict_args is not None: | 
					
						
						|  | args = parser.parse_args([]) | 
					
						
						|  | args_dict = vars(args) | 
					
						
						|  | args_dict.update(dict_args) | 
					
						
						|  | args = argparse.Namespace(**args_dict) | 
					
						
						|  | else: | 
					
						
						|  | args = parser.parse_args() | 
					
						
						|  |  | 
					
						
						|  | return args | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def check_validation(dict_args): | 
					
						
						|  | args = parse_args(dict_args) | 
					
						
						|  | torch.backends.cudnn.benchmark = True | 
					
						
						|  | try: | 
					
						
						|  | torch.multiprocessing.set_start_method('spawn') | 
					
						
						|  | except Exception as e: | 
					
						
						|  | pass | 
					
						
						|  | model, config = get_model_from_config(args.model_type, args.config_path) | 
					
						
						|  |  | 
					
						
						|  | if args.start_check_point: | 
					
						
						|  | load_start_checkpoint(args, model, type_='valid') | 
					
						
						|  |  | 
					
						
						|  | print(f"Instruments: {config.training.instruments}") | 
					
						
						|  |  | 
					
						
						|  | device_ids = args.device_ids | 
					
						
						|  | if torch.cuda.is_available(): | 
					
						
						|  | device = torch.device(f'cuda:{device_ids[0]}') | 
					
						
						|  | else: | 
					
						
						|  | device = 'cpu' | 
					
						
						|  | print('CUDA is not available. Run validation on CPU. It will be very slow...') | 
					
						
						|  |  | 
					
						
						|  | if torch.cuda.is_available() and len(device_ids) > 1: | 
					
						
						|  | valid_multi_gpu(model, args, config, device_ids, verbose=False) | 
					
						
						|  | else: | 
					
						
						|  | valid(model, args, config, device, verbose=True) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  | check_validation(None) | 
					
						
						|  |  |