# coding: utf-8
"""Ensemble several audio files into a single track.

The inputs are combined either directly in the waveform domain or in the
STFT domain, using a weighted average, median, or a per-element pick of the
smallest / largest absolute value.
"""
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'

import argparse
import gc
import logging
import os
import sys

import librosa
import numpy as np
import soundfile as sf

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# STFT parameters used by the *_fft ensemble algorithms.
_N_FFT = 2048
_HOP_LENGTH = 1024

# Supported ensemble algorithms, named "<method>_<domain>".
_ALGORITHMS = (
    'avg_wave', 'median_wave', 'min_wave', 'max_wave',
    'avg_fft', 'median_fft', 'min_fft', 'max_fft',
)


def stft(wave, nfft, hl):
    """Compute the STFT of the first two channels of a waveform.

    :param wave: waveform indexed as wave[0] (left) and wave[1] (right);
        any further channels are ignored
    :param nfft: FFT size, forwarded to librosa.stft as n_fft
    :param hl: hop length, forwarded to librosa.stft as hop_length
    :return: both spectrograms stacked along a new leading axis
    """
    wave_left = np.ascontiguousarray(wave[0])
    wave_right = np.ascontiguousarray(wave[1])
    spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
    spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
    return np.stack([spec_left, spec_right])


def istft(spec, hl, length):
    """Invert a two-channel spectrogram back to a waveform.

    :param spec: spectrogram indexed as spec[0] (left) and spec[1] (right)
    :param hl: hop length, forwarded to librosa.istft as hop_length
    :param length: output length, forwarded to librosa.istft as length
    :return: both waveforms stacked along a new leading axis
    """
    spec_left = np.ascontiguousarray(spec[0])
    spec_right = np.ascontiguousarray(spec[1])
    wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
    wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
    return np.stack([wave_left, wave_right])


def _select_by_abs(a, axis, arg_reduce):
    """Collapse `axis` of `a`, keeping the element whose |value| `arg_reduce` picks."""
    other_dims = list(a.shape)
    other_dims.pop(axis)
    # Open index grids for every axis except the one being reduced.
    indices = list(np.ogrid[tuple(slice(0, d) for d in other_dims)])
    chosen = arg_reduce(np.abs(a), axis=axis)
    # Normalise a possibly negative axis before re-inserting the chosen index.
    insert_pos = (a.ndim + axis) % a.ndim
    indices.insert(insert_pos, chosen)
    return a[tuple(indices)]


def absmax(a, *, axis):
    """Return, along `axis`, the elements of `a` with the largest absolute value.

    The original (signed or complex) values are returned, not their magnitudes.
    """
    return _select_by_abs(a, axis, np.argmax)


def absmin(a, *, axis):
    """Return, along `axis`, the elements of `a` with the smallest absolute value.

    The original (signed or complex) values are returned, not their magnitudes.
    """
    return _select_by_abs(a, axis, np.argmin)


def _select_by_key(arr, arg_reduce, axis, key, keepdims):
    """Pick elements of `arr` at the positions `arg_reduce` finds in key(arr)."""
    # A missing key means "rank the values themselves".
    scores = arr if key is None else key(arr)
    idxs = arg_reduce(scores, axis)
    if axis is None:
        # The arg-reduction returned a flat index into the whole array.
        return arr.flatten()[idxs]
    idxs = np.expand_dims(idxs, axis)
    result = np.take_along_axis(arr, idxs, axis)
    if not keepdims:
        result = np.squeeze(result, axis=axis)
    return result


def lambda_max(arr, axis=None, key=None, keepdims=False):
    """Return the elements of `arr` that maximise key(arr) along `axis`.

    :param arr: input array
    :param axis: axis to reduce; None reduces over the flattened array
    :param key: callable applied to `arr` to produce the ranking values;
        None ranks by the values themselves
    :param keepdims: keep the reduced axis with size one (ignored when
        axis is None)
    :return: the selected original values from `arr`
    """
    return _select_by_key(arr, np.argmax, axis, key, keepdims)


def lambda_min(arr, axis=None, key=None, keepdims=False):
    """Return the elements of `arr` that minimise key(arr) along `axis`.

    :param arr: input array
    :param axis: axis to reduce; None reduces over the flattened array
    :param key: callable applied to `arr` to produce the ranking values;
        None ranks by the values themselves
    :param keepdims: keep the reduced axis with size one (ignored when
        axis is None)
    :return: the selected original values from `arr`
    """
    return _select_by_key(arr, np.argmin, axis, key, keepdims)


def average_waveforms(pred_track, weights, algorithm):
    """
    Combine several waveforms into one.

    :param pred_track: shape = (num, channels, length)
    :param weights: shape = (num, ); only read by avg_wave and avg_fft
    :param algorithm: One of avg_wave, median_wave, min_wave, max_wave,
        avg_fft, median_fft, min_fft, max_fft
    :return: averaged waveform in shape (channels, length)
    :raises ValueError: if `algorithm` is not one of the supported names
    """
    if algorithm not in _ALGORITHMS:
        # Fail early instead of hitting an unbound `result` further down.
        raise ValueError(
            f"Unknown ensemble algorithm: {algorithm!r}; "
            f"expected one of {', '.join(_ALGORITHMS)}"
        )
    method, domain = algorithm.split('_')
    use_fft = domain == 'fft'

    pred_track = np.asarray(pred_track)  # NumPy 2.0+ compatibility
    final_length = pred_track.shape[-1]

    mod_track = []
    for i in range(pred_track.shape[0]):
        item = pred_track[i]
        if use_fft:
            item = stft(item, nfft=_N_FFT, hl=_HOP_LENGTH)
        if method == 'avg':
            item = item * weights[i]
        mod_track.append(item)
        if use_fft:
            gc.collect()
    mod_track = np.asarray(mod_track)  # NumPy 2.0+ compatibility

    if method == 'avg':
        result = mod_track.sum(axis=0) / np.sum(weights)
    elif method == 'median':
        # NOTE(review): for the fft domain this is a median over complex
        # values, which NumPy orders lexicographically (real, then imag),
        # not by magnitude -- confirm this is intended.
        result = np.median(mod_track, axis=0)
    elif method == 'min':
        result = lambda_min(mod_track, axis=0, key=np.abs)
    else:  # method == 'max'
        result = lambda_max(mod_track, axis=0, key=np.abs)

    if use_fft:
        result = istft(result, _HOP_LENGTH, final_length)

    gc.collect()
    return result


def ensemble_files(args):
    """Read the input audio files, ensemble them and write the result.

    :param args: either a list of command-line strings (parsed here with
        argparse) or an already-parsed namespace exposing `files`, `type`,
        `weights` and `output`
    :return: the output file path
    :raises ValueError: if the number of weights differs from the number of files
    :raises FileNotFoundError: if an input file does not exist
    :raises RuntimeError: if reading a file fails, sample rates differ, or
        the ensembling / writing step fails
    """
    parser = argparse.ArgumentParser(description="Ensemble audio files")
    parser.add_argument('--files', nargs='+', required=True, help="Input audio files")
    parser.add_argument('--type', required=True, choices=list(_ALGORITHMS),
                        help="Ensemble type")
    parser.add_argument('--weights', nargs='+', type=float, default=None,
                        help="Weights for each file")
    parser.add_argument('--output', required=True, help="Output file path")

    args = parser.parse_args(args) if isinstance(args, list) else args

    logger.info("Ensemble type: %s", args.type)
    logger.info("Number of input files: %s", len(args.files))

    weights = args.weights if args.weights else [1.0] * len(args.files)
    if len(weights) != len(args.files):
        msg = "Number of weights must match number of audio files"
        logger.error(msg)
        raise ValueError(msg)

    logger.info("Weights: %s", weights)
    logger.info("Output file: %s", args.output)

    data = []
    sr = None
    for f in args.files:
        if not os.path.isfile(f):
            msg = f"Cannot find file: {f}"
            logger.error(msg)
            raise FileNotFoundError(msg)

        logger.info("Reading file: %s", f)
        try:
            wav, curr_sr = librosa.load(f, sr=None, mono=False)
            if sr is None:
                sr = curr_sr
            elif sr != curr_sr:
                msg = "All audio files must have the same sample rate"
                logger.error(msg)
                raise ValueError(msg)
            logger.info("Waveform shape: %s sample rate: %s", wav.shape, sr)
            data.append(wav)
            gc.collect()
        except Exception as e:
            msg = f"Error reading audio file {f}: {e}"
            logger.error(msg)
            raise RuntimeError(msg) from e

    try:
        data = np.asarray(data)  # NumPy 2.0+ compatibility
        res = average_waveforms(data, weights, args.type)
        logger.info("Result shape: %s", res.shape)

        # dirname is '' for a bare filename, and os.makedirs('') raises.
        out_dir = os.path.dirname(args.output)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        sf.write(args.output, res.T, sr, 'FLOAT')
        logger.info("Output written to: %s", args.output)
        return args.output
    except Exception as e:
        msg = f"Error during ensemble processing: {e}"
        logger.error(msg)
        raise RuntimeError(msg) from e
    finally:
        gc.collect()


if __name__ == "__main__":
    ensemble_files(sys.argv[1:])