Spaces:
Running
on
Zero
Running
on
Zero
# coding: utf-8 | |
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' | |
import argparse | |
import time | |
import os | |
import glob | |
import torch | |
import librosa | |
import numpy as np | |
import soundfile as sf | |
from tqdm.auto import tqdm | |
from ml_collections import ConfigDict | |
from typing import Tuple, Dict, List, Union | |
from utils import demix, get_model_from_config, prefer_target_instrument, draw_spectrogram | |
from utils import normalize_audio, denormalize_audio, apply_tta, read_audio_transposed, load_start_checkpoint | |
from metrics import get_metrics | |
import warnings | |
warnings.filterwarnings("ignore") | |
def logging(logs: List[str], text: str, verbose_logging: bool = False) -> None:
    """
    Print a validation message and optionally record it in a log list.

    Parameters:
    ----------
    logs : List[str]
        List collecting the messages that will later be written to the results
        file. It is only appended to when ``verbose_logging`` is True.
    text : str
        The message to print (and possibly record).
    verbose_logging : bool, optional
        If True, ``text`` is also appended to ``logs``. Default is False.

    Returns:
    -------
    None
        The text is always printed; ``logs`` is modified in place.
    """
    print(text)
    if verbose_logging:
        logs.append(text)
def write_results_in_file(store_dir: str, logs: List[str]) -> None:
    """
    Write collected log lines to ``results.txt`` inside ``store_dir``.

    An existing ``results.txt`` is overwritten. Each entry is written on its
    own line.

    Parameters:
    ----------
    store_dir : str
        Existing directory in which ``results.txt`` is created.
    logs : List[str]
        Log lines to write, in order.

    Returns:
    -------
    None
    """
    results_path = os.path.join(store_dir, 'results.txt')
    # Explicit UTF-8 so logged paths/args with non-ASCII characters can be
    # written regardless of the platform's default locale encoding.
    with open(results_path, 'w', encoding='utf-8') as out:
        out.writelines(item + "\n" for item in logs)
def get_mixture_paths(
    args,
    verbose: bool,
    config: "ConfigDict",
    extension: str
) -> List[str]:
    """
    Collect ``<dir>/<song>/mixture.<extension>`` files from the validation dirs.

    Parameters:
    ----------
    args : Namespace-like
        Must expose ``valid_path``: a directory or a list of directories to
        search. A single string/path is treated as a one-element list.
    verbose : bool
        If True, prints which directories were empty and summary information.
    config : ConfigDict
        Configuration object; ``inference.num_overlap`` and
        ``inference.batch_size`` are only read for the verbose summary.
    extension : str
        File extension of the mixture files (e.g. 'wav').

    Returns:
    -------
    List[str]
        Mixture file paths, sorted within each validation directory and
        concatenated in the order the directories were given.

    Raises:
    ------
    AttributeError
        If ``args`` has no ``valid_path`` attribute.
    ValueError
        If ``args.valid_path`` is None.
    """
    try:
        valid_path = args.valid_path
    except AttributeError:
        print('No valid path in args')
        raise
    if valid_path is None:
        raise ValueError('No valid path in args: args.valid_path is None')
    # A single directory can arrive as a plain string (e.g. through dict_args);
    # iterating a str would otherwise walk over its individual characters.
    if isinstance(valid_path, (str, os.PathLike)):
        valid_path = [valid_path]

    all_mixtures_path = []
    for path in valid_path:
        part = sorted(glob.glob(f"{path}/*/mixture.{extension}"))
        if not part and verbose:
            print(f'No validation data found in: {path}')
        all_mixtures_path.extend(part)

    if verbose:
        print(f'Total mixtures: {len(all_mixtures_path)}')
        print(f'Overlap: {config.inference.num_overlap} Batch size: {config.inference.batch_size}')
    return all_mixtures_path
def update_metrics_and_pbar(
    track_metrics: Dict,
    all_metrics: Dict,
    instr: str,
    pbar_dict: Dict,
    mixture_paths: Union[List[str], "tqdm"],
    verbose: bool = False
) -> None:
    """
    Record one track's metric values and refresh the progress-bar postfix.

    Parameters:
    ----------
    track_metrics : Dict
        Mapping of metric name to the value computed for this track.
    all_metrics : Dict
        Accumulator laid out as ``{metric_name: {instrument: [values...]}}``;
        each value in ``track_metrics`` is appended to its list.
    instr : str
        Instrument the metrics belong to.
    pbar_dict : Dict
        Postfix mapping; updated in place with ``"<metric>_<instr>"`` keys.
    mixture_paths : List[str] or tqdm
        The iterable being processed. If it provides ``set_postfix`` (i.e. it
        is a tqdm bar) the postfix is refreshed; a plain list or None is ignored.
    verbose : bool, optional
        If True, prints each metric value. Default is False.

    Returns:
    -------
    None
    """
    for metric_name, metric_value in track_metrics.items():
        if verbose:
            print(f"Metric {metric_name:11s} value: {metric_value:.4f}")
        all_metrics[metric_name][instr].append(metric_value)
        pbar_dict[f'{metric_name}_{instr}'] = metric_value

    # Only a progress bar can show a postfix; checking for the method avoids a
    # blanket try/except that would also hide genuine errors.
    set_postfix = getattr(mixture_paths, 'set_postfix', None)
    if set_postfix is not None:
        set_postfix(pbar_dict)
def process_audio_files(
    mixture_paths: List[str],
    model: torch.nn.Module,
    args,
    config,
    device: torch.device,
    verbose: bool = False,
    is_tqdm: bool = True
) -> Dict[str, Dict[str, List[float]]]:
    """
    Separate every mixture in ``mixture_paths`` and score the estimates.

    Reference stems are read from the mixture's folder as
    ``<folder>/<instrument>.<extension>``. If reading a stem returns None
    (the reader is called with ``skip_err=True``), that instrument is skipped
    for the song. For ``other``, unless ``config.training.other_fix is False``,
    the reference is computed as ``mixture - <folder>/vocals.<extension>``.

    If ``config.audio`` has a ``sample_rate`` different from the file's rate,
    the mixture is resampled for the model and estimates are resampled back
    and trimmed/padded to the original length. If
    ``config.inference['normalize'] is True`` the mixture is normalized before
    separation and estimates are denormalized afterwards.

    Parameters:
    ----------
    mixture_paths : List[str]
        List of file paths to the audio mixtures.
    model : torch.nn.Module
        The trained model used for source separation.
    args : Any
        Namespace-like options. ``metrics`` and ``model_type`` are required;
        ``use_tta``, ``store_dir``, ``extension`` and ``draw_spectro`` are
        optional and default to False, '', 'wav' and 0 respectively.
    config : Any
        Configuration object containing model and processing parameters.
    device : torch.device
        Device (CPU or CUDA) on which the model will be executed.
    verbose : bool, optional
        If True, prints detailed logs for each processed file. Default is False.
    is_tqdm : bool, optional
        If True, displays a progress bar for file processing. Default is True.

    Returns:
    -------
    Dict[str, Dict[str, List[float]]]
        ``{metric_name: {instrument: [score per processed track]}}``; the inner
        keys cover every instrument in ``config.training.instruments``.
    """
    instruments = prefer_target_instrument(config)
    use_tta = getattr(args, 'use_tta', False)
    # Directory to save separated stems; an empty string disables saving.
    store_dir = getattr(args, 'store_dir', '')
    # Seconds of spectrogram to draw for saved stems (0 disables). Read with a
    # default like the options above so callers whose args lack it still work.
    draw_spectro = getattr(args, 'draw_spectro', 0)
    # Extension of the reference stem files; the config value wins over args.
    if 'extension' in config['inference']:
        extension = config['inference']['extension']
    else:
        extension = getattr(args, 'extension', 'wav')
    # Run-wide settings, read once rather than for every file and instrument.
    model_sr = config.audio['sample_rate'] if 'sample_rate' in config.audio else None
    normalize = 'normalize' in config.inference and config.inference['normalize'] is True

    all_metrics = {
        metric: {instr: [] for instr in config.training.instruments}
        for metric in args.metrics
    }

    if is_tqdm:
        mixture_paths = tqdm(mixture_paths)

    for path in mixture_paths:
        start_time = time.time()
        mix, sr = read_audio_transposed(path)
        # Unmodified copy: used for metrics and for the "other" reference.
        mix_orig = mix.copy()
        folder = os.path.dirname(path)

        # The model runs at model_sr; estimates are brought back to sr below.
        resample = model_sr is not None and sr != model_sr
        if resample:
            orig_length = mix.shape[-1]
            if verbose:
                print(f'Warning: sample rate is different. In config: {model_sr} in file {path}: {sr}')
            mix = librosa.resample(mix, orig_sr=sr, target_sr=model_sr, res_type='kaiser_best')

        if verbose:
            folder_name = os.path.abspath(folder)
            print(f'Song: {folder_name} Shape: {mix.shape}')

        if normalize:
            mix, norm_params = normalize_audio(mix)

        waveforms_orig = demix(config, model, mix.copy(), device, model_type=args.model_type)
        if use_tta:
            waveforms_orig = apply_tta(config, model, mix, waveforms_orig, device, args.model_type)

        pbar_dict = {}
        song_name = os.path.basename(folder)
        for instr in instruments:
            if verbose:
                print(f"Instr: {instr}")

            if instr != 'other' or config.training.other_fix is False:
                track, _ = read_audio_transposed(f"{folder}/{instr}.{extension}", instr, skip_err=True)
                if track is None:
                    continue
            else:
                # Reference "other" is the mixture with the vocals removed.
                track, _ = read_audio_transposed(f"{folder}/vocals.{extension}")
                track = mix_orig - track

            estimates = waveforms_orig[instr]
            if resample:
                estimates = librosa.resample(estimates, orig_sr=model_sr, target_sr=sr,
                                             res_type='kaiser_best')
                estimates = librosa.util.fix_length(estimates, size=orig_length)
            if normalize:
                estimates = denormalize_audio(estimates, norm_params)

            if store_dir:
                os.makedirs(store_dir, exist_ok=True)
                out_wav_name = f"{store_dir}/{song_name}_{instr}.wav"
                sf.write(out_wav_name, estimates.T, sr, subtype='FLOAT')
                if draw_spectro > 0:
                    out_img_name = f"{store_dir}/{song_name}_{instr}.jpg"
                    draw_spectrogram(estimates.T, sr, draw_spectro, out_img_name)
                    out_img_name_orig = f"{store_dir}/{song_name}_{instr}_orig.jpg"
                    draw_spectrogram(track.T, sr, draw_spectro, out_img_name_orig)

            track_metrics = get_metrics(
                args.metrics,
                track,
                estimates,
                mix_orig,
                device=device,
            )
            update_metrics_and_pbar(
                track_metrics,
                all_metrics,
                instr, pbar_dict,
                mixture_paths=mixture_paths,
                verbose=verbose
            )

        if verbose:
            print(f"Time for song: {time.time() - start_time:.2f} sec")

    return all_metrics
def compute_metric_avg(
    store_dir: str,
    args,
    instruments: List[str],
    config: ConfigDict,
    all_metrics: Dict[str, Dict[str, List[float]]],
    start_time: float
) -> Dict[str, float]:
    """
    Log per-instrument mean/std of every metric and return cross-instrument averages.

    Instruments with no recorded values for a metric are logged as "no data"
    and left out of that metric's average. Otherwise ``mean([])`` would be NaN
    and would turn the whole average into NaN. A metric with no data for any
    instrument averages to NaN.

    Parameters:
    ----------
    store_dir : str
        If non-empty, ``str(args)`` and all log lines are written to
        ``<store_dir>/results.txt``; otherwise they are only printed.
    args : Any
        Options object; only its string form is logged.
    instruments : List[str]
        Instruments to report on and average over.
    config : ConfigDict
        Configuration; ``inference.num_overlap`` is logged.
    all_metrics : Dict[str, Dict[str, List[float]]]
        ``{metric_name: {instrument_name: [metric_values]}}``.
    start_time : float
        ``time.time()`` value used to report the elapsed time.

    Returns:
    -------
    Dict[str, float]
        For each metric in ``all_metrics``, the mean of the per-instrument means.
    """
    # Log lines are collected (for results.txt) only when store_dir is given.
    verbose_logging = bool(store_dir)
    logs = []
    if verbose_logging:
        logs.append(str(args))
    logging(logs, text=f"Num overlap: {config.inference.num_overlap}", verbose_logging=verbose_logging)

    metric_sums = {metric_name: 0.0 for metric_name in all_metrics}
    metric_counts = {metric_name: 0 for metric_name in all_metrics}
    for instr in instruments:
        for metric_name in all_metrics:
            metric_values = np.array(all_metrics[metric_name][instr])
            if metric_values.size == 0:
                logging(logs, text=f"Instr {instr} {metric_name}: no data", verbose_logging=verbose_logging)
                continue
            mean_val = metric_values.mean()
            std_val = metric_values.std()
            logging(logs, text=f"Instr {instr} {metric_name}: {mean_val:.4f} (Std: {std_val:.4f})",
                    verbose_logging=verbose_logging)
            metric_sums[metric_name] += mean_val
            metric_counts[metric_name] += 1

    metric_avg = {
        metric_name: (metric_sums[metric_name] / metric_counts[metric_name]
                      if metric_counts[metric_name] else float('nan'))
        for metric_name in all_metrics
    }

    if len(instruments) > 1:
        for metric_name in metric_avg:
            logging(logs, text=f'Metric avg {metric_name:11s}: {metric_avg[metric_name]:.4f}',
                    verbose_logging=verbose_logging)
    logging(logs, text=f"Elapsed time: {time.time() - start_time:.2f} sec", verbose_logging=verbose_logging)

    if store_dir:
        write_results_in_file(store_dir, logs)
    return metric_avg
def valid(
    model: torch.nn.Module,
    args,
    config: ConfigDict,
    device: torch.device,
    verbose: bool = False
) -> Tuple[dict, dict]:
    """
    Validate a trained model on a set of audio mixtures and compute metrics.

    Separates the sources of every validation mixture on a single device,
    scores them against the reference stems, and logs the averages. If
    ``args.store_dir`` is set, it also writes the results there.

    Parameters:
    ----------
    model : torch.nn.Module
        The trained model for source separation; it is put in eval mode and
        moved to ``device``.
    args : Namespace
        Command-line arguments or equivalent object containing configurations.
    config : dict
        Configuration dictionary with model and processing parameters.
    device : torch.device
        The device (CPU or CUDA) to run the model on.
    verbose : bool, optional
        If True, prints per-file details; otherwise a progress bar is shown.
        Default is False.

    Returns:
    -------
    Tuple[dict, dict]
        ``(metric_avg, all_metrics)``:
        - ``metric_avg`` is the per-metric average over instruments, as returned
          by ``compute_metric_avg``.
        - ``all_metrics`` holds the raw scores as
          ``{metric_name: {instrument: [values...]}}``.
    """
    start_time = time.time()
    model.eval().to(device)

    # Directory to save files; empty disables saving.
    store_dir = getattr(args, 'store_dir', '')
    # Extension of the mixture files; the config value wins over args.
    if 'extension' in config['inference']:
        extension = config['inference']['extension']
    else:
        extension = getattr(args, 'extension', 'wav')

    all_mixtures_path = get_mixture_paths(args, verbose, config, extension)
    # The progress bar is only useful when per-file logging is off.
    all_metrics = process_audio_files(
        all_mixtures_path, model, args, config, device,
        verbose=verbose, is_tqdm=not verbose,
    )
    instruments = prefer_target_instrument(config)
    metric_avg = compute_metric_avg(store_dir, args, instruments, config, all_metrics, start_time)
    return metric_avg, all_metrics
def validate_in_subprocess(
    proc_id: int,
    queue: torch.multiprocessing.Queue,
    all_mixtures_path: List[str],
    model: torch.nn.Module,
    args,
    config: ConfigDict,
    device: str,
    return_dict
) -> None:
    """
    Worker loop: pull mixtures from ``queue``, score them, store totals in ``return_dict``.

    Each queue item is ``(step, path)``, where ``step`` is the 0-based index of
    ``path`` in ``all_mixtures_path``. ``(None, None)`` tells the worker to
    stop. Only worker 0 shows a progress bar.

    Parameters:
    ----------
    proc_id : int
        Worker index; results are stored under this key in ``return_dict``.
    queue : torch.multiprocessing.Queue
        Source of ``(step, path)`` work items.
    all_mixtures_path : List[str]
        Full list of mixtures. It is used only to size the progress bar.
    model : torch.nn.Module
        The model to be used for inference; moved to ``device`` in eval mode.
    args : Any
        Options object (``metrics`` and whatever ``process_audio_files`` reads).
    config : ConfigDict
        Configuration object containing model settings and training parameters.
    device : str
        The device to use for inference (e.g., 'cpu', 'cuda:0').
    return_dict : Manager dict proxy
        Shared dictionary; receives
        ``{metric_name: {instrument: [values...]}}`` for this worker.

    Returns:
    -------
    None
    """
    model = model.eval().to(device)
    progress_bar = tqdm(total=len(all_mixtures_path)) if proc_id == 0 else None

    all_metrics = {
        metric: {instr: [] for instr in config.training.instruments}
        for metric in args.metrics
    }

    try:
        while True:
            current_step, path = queue.get()
            if path is None:  # sentinel: no more work for this worker
                break
            single_metrics = process_audio_files([path], model, args, config, device, False, False)
            pbar_dict = {}
            for instr in config.training.instruments:
                for metric_name in all_metrics:
                    values = single_metrics[metric_name][instr]
                    all_metrics[metric_name][instr] += values
                    if values:
                        pbar_dict[f"{metric_name}_{instr}"] = f"{values[0]:.4f}"
            if progress_bar is not None:
                # current_step is a 0-based index, so current_step + 1 files have
                # been dispatched up to and including this one. Using the bare
                # index left the bar one short of its total.
                progress_bar.update(current_step + 1 - progress_bar.n)
                progress_bar.set_postfix(pbar_dict)
    finally:
        if progress_bar is not None:
            progress_bar.close()

    return_dict[proc_id] = all_metrics
def run_parallel_validation(
    verbose: bool,
    all_mixtures_path: List[str],
    config: ConfigDict,
    model: torch.nn.Module,
    device_ids: List[int],
    args,
    return_dict
) -> None:
    """
    Run validation in one worker process per device and wait for all of them.

    Mixtures are handed out through a shared queue, so workers pull files as
    they become free. Each worker stores its metrics in ``return_dict`` under
    its 0-based worker index.

    Parameters:
    ----------
    verbose : bool
        Currently unused; kept for interface compatibility.
    all_mixtures_path : List[str]
        List of paths to the mixture files to be processed.
    config : ConfigDict
        Configuration object containing model settings and validation parameters.
    model : torch.nn.Module
        The model to be used for inference. It is moved to CPU before being
        sent to the workers. A ``.module`` attribute (multi-GPU wrapper) is
        unwrapped.
    device_ids : List[int]
        GPU ids, one worker per id. Workers use 'cpu' when CUDA is unavailable.
    args : Any
        Options object forwarded to each worker (e.g. metrics to calculate).
    return_dict : Manager dict proxy
        Shared dictionary that receives each worker's metrics.

    Returns:
    -------
    None
        Results are delivered through ``return_dict``.

    Raises:
    ------
    RuntimeError
        If any worker exits with a non-zero exit code (its results are missing).
    """
    model = model.to('cpu')
    # For multi-GPU training, extract the single underlying model.
    model = getattr(model, 'module', model)

    queue = torch.multiprocessing.Queue()
    processes = []
    for proc_id, device_id in enumerate(device_ids):
        device = f'cuda:{device_id}' if torch.cuda.is_available() else 'cpu'
        p = torch.multiprocessing.Process(
            target=validate_in_subprocess,
            args=(proc_id, queue, all_mixtures_path, model, args, config, device, return_dict)
        )
        p.start()
        processes.append(p)

    for step, path in enumerate(all_mixtures_path):
        queue.put((step, path))
    # One sentinel per worker so every worker's loop terminates.
    for _ in processes:
        queue.put((None, None))

    for p in processes:
        p.join()

    failed = [proc_id for proc_id, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        raise RuntimeError(f'Validation worker(s) {failed} exited with a non-zero exit code')
def valid_multi_gpu(
    model: torch.nn.Module,
    args,
    config: ConfigDict,
    device_ids: List[int],
    verbose: bool = False
) -> Tuple[Dict[str, float], dict]:
    """
    Validate on several devices in parallel and aggregate the metrics.

    Parameters:
    ----------
    model : torch.nn.Module
        The model to be used for inference.
    args : Any
        Options object (validation paths, metrics, store_dir, extension, ...).
    config : ConfigDict
        Configuration object containing model settings and validation parameters.
    device_ids : List[int]
        GPU ids; one worker process is started per id.
    verbose : bool, optional
        Flag to print detailed information about mixture discovery. Default is False.

    Returns:
    -------
    Tuple[Dict[str, float], dict]
        ``(metric_avg, all_metrics)``:
        - ``metric_avg`` is the per-metric average over instruments.
        - ``all_metrics`` is ``{metric_name: {instrument: [values...]}}``,
          concatenated over workers in worker order.
    """
    start_time = time.time()
    # Directory to save files; empty disables saving.
    store_dir = getattr(args, 'store_dir', '')
    # Extension of the mixture files; the config value wins over args.
    if 'extension' in config['inference']:
        extension = config['inference']['extension']
    else:
        extension = getattr(args, 'extension', 'wav')

    all_mixtures_path = get_mixture_paths(args, verbose, config, extension)

    # The Manager runs a server process; the context manager shuts it down.
    with torch.multiprocessing.Manager() as manager:
        return_dict = manager.dict()
        run_parallel_validation(verbose, all_mixtures_path, config, model, device_ids, args, return_dict)
        # Each proxy lookup is an IPC round trip, so copy every worker's
        # results out once instead of once per metric and instrument.
        worker_results = [return_dict[i] for i in range(len(device_ids))]

    all_metrics = {
        metric: {
            instr: [value for result in worker_results for value in result[metric][instr]]
            for instr in config.training.instruments
        }
        for metric in args.metrics
    }

    instruments = prefer_target_instrument(config)
    return compute_metric_avg(store_dir, args, instruments, config, all_metrics, start_time), all_metrics
def parse_args(dict_args: Union[Dict, None]) -> argparse.Namespace:
    """
    Build the validation option namespace.

    Args:
        dict_args: Mapping of option name to value. Its entries override the
            parser defaults, and keys the parser does not know are added as-is.
            If None, arguments are parsed from sys.argv.

    Returns:
        Namespace object containing parsed arguments and their values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default='mdx23c',
                        help="One of mdx23c, htdemucs, segm_models, mel_band_roformer,"
                             " bs_roformer, swin_upernet, bandit")
    parser.add_argument("--config_path", type=str, help="Path to config file")
    parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint"
                                                                         " to valid weights")
    parser.add_argument("--valid_path", nargs="+", type=str, help="Validate path")
    parser.add_argument("--store_dir", type=str, default="", help="Path to store results as wav file")
    parser.add_argument("--draw_spectro", type=float, default=0,
                        help="If --store_dir is set then code will generate spectrograms for resulted stems as well."
                             " Value defines for how many seconds of track spectrogram will be generated.")
    # With nargs='+' the parsed value is a list, so the default must be a list
    # too; an int default made `device_ids[0]` fail when the flag was omitted.
    parser.add_argument("--device_ids", nargs='+', type=int, default=[0], help='List of gpu ids')
    parser.add_argument("--num_workers", type=int, default=0, help="Dataloader num_workers")
    parser.add_argument("--pin_memory", action='store_true', help="Dataloader pin_memory")
    parser.add_argument("--extension", type=str, default='wav', help="Choose extension for validation")
    parser.add_argument("--use_tta", action='store_true',
                        help="Flag adds test time augmentation during inference (polarity and channel inverse). "
                             "While this triples the runtime, it reduces noise and slightly improves prediction quality.")
    parser.add_argument("--metrics", nargs='+', type=str, default=["sdr"],
                        choices=['sdr', 'l1_freq', 'si_sdr', 'neg_log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
                                 'fullness'], help='List of metrics to use.')
    parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint to LoRA weights")

    if dict_args is None:
        return parser.parse_args()

    # Start from the defaults, then overlay the caller-supplied values.
    args_dict = vars(parser.parse_args([]))
    args_dict.update(dict_args)
    return argparse.Namespace(**args_dict)
def check_validation(dict_args):
    """
    Entry point: build the model from its config, load weights, and run validation.

    Uses multi-process validation when CUDA is available and more than one
    device id is given; otherwise it validates on a single device (the first
    GPU id, or the CPU).

    Parameters:
    ----------
    dict_args : dict or None
        Option overrides passed to ``parse_args``; None parses sys.argv.

    Returns:
    -------
    None
    """
    args = parse_args(dict_args)
    # A single id may arrive as a bare int (e.g. via dict_args); normalise it
    # to a list so indexing and len() below work.
    if isinstance(args.device_ids, int):
        args.device_ids = [args.device_ids]

    torch.backends.cudnn.benchmark = True
    try:
        torch.multiprocessing.set_start_method('spawn')
    except RuntimeError:
        # The start method can only be set once per interpreter; keep the existing one.
        pass

    model, config = get_model_from_config(args.model_type, args.config_path)
    if args.start_check_point:
        load_start_checkpoint(args, model, type_='valid')
    print(f"Instruments: {config.training.instruments}")

    device_ids = args.device_ids
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{device_ids[0]}')
    else:
        device = 'cpu'
        print('CUDA is not available. Run validation on CPU. It will be very slow...')

    if torch.cuda.is_available() and len(device_ids) > 1:
        valid_multi_gpu(model, args, config, device_ids, verbose=False)
    else:
        valid(model, args, config, device, verbose=True)
if __name__ == "__main__":
    # Script mode: options come from the command line (dict_args=None).
    check_validation(dict_args=None)