ASesYusuf1's picture
Update ensemble.py
dc08c30 verified
# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
import os
import librosa
import soundfile as sf
import numpy as np
import argparse
import logging
import gc
# Module-wide logger used for the progress/error messages below; the root
# logger is configured at INFO level so those messages are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def stft(wave, nfft, hl):
    """
    Short-time Fourier transform of every channel of ``wave``.

    :param wave: 1-D mono signal, or an array with channels on the first axis
        (e.g. shape (2, length) for stereo).
    :param nfft: FFT window size, passed to ``librosa.stft`` as ``n_fft``.
    :param hl: hop length in samples, passed as ``hop_length``.
    :return: complex spectrogram. For multi-channel input the per-channel
        spectrograms are stacked on the first axis; for 1-D input no channel
        axis is added.
    """
    wave = np.asarray(wave)
    if wave.ndim == 1:
        return librosa.stft(np.ascontiguousarray(wave), n_fft=nfft, hop_length=hl)
    # Transform every channel, not just channels 0 and 1, so inputs with more
    # than two channels keep all of them. Contiguous buffers are handed to
    # librosa as in the original implementation.
    return np.stack([
        librosa.stft(np.ascontiguousarray(channel), n_fft=nfft, hop_length=hl)
        for channel in wave
    ])
def istft(spec, hl, length):
    """
    Inverse short-time Fourier transform of every channel of ``spec``.

    :param spec: complex spectrogram, either a single (freq, frames) matrix or
        per-channel matrices stacked on the first axis.
    :param hl: hop length in samples, passed to ``librosa.istft``.
    :param length: number of output samples per channel.
    :return: waveform of shape (length,) for a single matrix, otherwise
        (channels, length).
    """
    spec = np.asarray(spec)
    if spec.ndim == 2:
        return librosa.istft(np.ascontiguousarray(spec), hop_length=hl, length=length)
    # Invert every channel, not just channels 0 and 1, so spectrograms with
    # more than two channels keep all of them.
    return np.stack([
        librosa.istft(np.ascontiguousarray(channel), hop_length=hl, length=length)
        for channel in spec
    ])
def absmax(a, *, axis):
    """
    Select, along ``axis``, the element of ``a`` with the largest magnitude.

    The comparison uses ``np.abs(a)``, but the returned values are the original
    elements, so signs (or complex phases) are kept. Ties go to the first
    occurrence, as with ``argmax``.

    :param a: ndarray.
    :param axis: axis to reduce over; negative values count from the end.
    :return: array shaped like ``a`` with ``axis`` removed.
    """
    shape_without_axis = list(a.shape)
    del shape_without_axis[axis]
    # Open index grids over every axis except the reduced one; the argmax
    # positions are slotted in where the reduced axis used to be.
    open_grid = list(np.ogrid[tuple(slice(0, n) for n in shape_without_axis)])
    winners = np.abs(a).argmax(axis=axis)
    open_grid.insert((a.ndim + axis) % a.ndim, winners)
    return a[tuple(open_grid)]
def absmin(a, *, axis):
    """
    Select, along ``axis``, the element of ``a`` with the smallest magnitude.

    The comparison uses ``np.abs(a)``, but the returned values are the original
    elements, so signs (or complex phases) are kept. Ties go to the first
    occurrence, as with ``argmin``.

    :param a: ndarray.
    :param axis: axis to reduce over; negative values count from the end.
    :return: array shaped like ``a`` with ``axis`` removed.
    """
    shape_without_axis = list(a.shape)
    del shape_without_axis[axis]
    # Open index grids over every axis except the reduced one; the argmin
    # positions are slotted in where the reduced axis used to be.
    open_grid = list(np.ogrid[tuple(slice(0, n) for n in shape_without_axis)])
    argmin = np.abs(a).argmin(axis=axis)
    open_grid.insert((a.ndim + axis) % a.ndim, argmin)
    return a[tuple(open_grid)]
def lambda_max(arr, axis=None, key=None, keepdims=False):
    """
    Return the element(s) of ``arr`` whose ``key`` value is largest.

    Like ``np.max``, but the comparison is made on ``key(arr)`` while the
    returned values come from ``arr`` itself. For example, ``key=np.abs``
    keeps the sign of the winning sample.

    :param arr: ndarray.
    :param axis: axis to reduce over; ``None`` searches the flattened array.
    :param key: elementwise function applied to ``arr`` before comparing;
        ``None`` (the default) compares ``arr`` directly.
    :param keepdims: keep the reduced axis with length 1. Only applies when
        ``axis`` is not ``None``.
    :return: the selected values; a single element when ``axis`` is ``None``.
    """
    # Previously key(arr) was called unconditionally, so the documented
    # default key=None raised TypeError.
    scores = arr if key is None else key(arr)
    idxs = np.argmax(scores, axis)
    if axis is None:
        # argmax without an axis returns an index into the flattened array.
        return arr.ravel()[idxs]
    idxs = np.expand_dims(idxs, axis)
    result = np.take_along_axis(arr, idxs, axis)
    if not keepdims:
        result = np.squeeze(result, axis=axis)
    return result
def lambda_min(arr, axis=None, key=None, keepdims=False):
    """
    Return the element(s) of ``arr`` whose ``key`` value is smallest.

    Like ``np.min``, but the comparison is made on ``key(arr)`` while the
    returned values come from ``arr`` itself. For example, ``key=np.abs``
    keeps the sign of the winning sample.

    :param arr: ndarray.
    :param axis: axis to reduce over; ``None`` searches the flattened array.
    :param key: elementwise function applied to ``arr`` before comparing;
        ``None`` (the default) compares ``arr`` directly.
    :param keepdims: keep the reduced axis with length 1. Only applies when
        ``axis`` is not ``None``.
    :return: the selected values; a single element when ``axis`` is ``None``.
    """
    # Previously key(arr) was called unconditionally, so the documented
    # default key=None raised TypeError.
    scores = arr if key is None else key(arr)
    idxs = np.argmin(scores, axis)
    if axis is None:
        # argmin without an axis returns an index into the flattened array.
        return arr.ravel()[idxs]
    idxs = np.expand_dims(idxs, axis)
    result = np.take_along_axis(arr, idxs, axis)
    if not keepdims:
        result = np.squeeze(result, axis=axis)
    return result
def average_waveforms(pred_track, weights, algorithm, *, n_fft=2048, hop_length=1024):
    """
    Combine several aligned predictions of the same signal into one.

    :param pred_track: array-like, shape = (num, channels, length).
    :param weights: shape = (num, ); only used by avg_wave and avg_fft.
    :param algorithm: One of avg_wave, median_wave, min_wave, max_wave,
        avg_fft, median_fft, min_fft, max_fft. The ``*_wave`` variants work
        on raw samples; the ``*_fft`` variants work on STFT coefficients and
        transform the combined spectrogram back to audio.
    :param n_fft: STFT window size for the ``*_fft`` algorithms.
    :param hop_length: STFT hop size for the ``*_fft`` algorithms.
    :return: averaged waveform in shape (channels, length)
    :raises ValueError: unknown algorithm or empty ``pred_track``; for the
        averaging algorithms, also when the weight count differs from the
        track count or the weights sum to zero.
    """
    wave_algorithms = ('avg_wave', 'median_wave', 'min_wave', 'max_wave')
    fft_algorithms = ('avg_fft', 'median_fft', 'min_fft', 'max_fft')
    # Without this check an unknown name fell through every branch and
    # failed with UnboundLocalError on `result`.
    if algorithm not in wave_algorithms and algorithm not in fft_algorithms:
        raise ValueError(f"Unknown ensemble algorithm: {algorithm!r}")

    pred_track = np.asarray(pred_track)  # NumPy 2.0+ compatibility
    if len(pred_track) == 0:
        raise ValueError("pred_track must contain at least one track")
    final_length = pred_track.shape[-1]

    weighted = algorithm in ('avg_wave', 'avg_fft')
    if weighted:
        # A shorter weight list used to raise IndexError, and a longer one
        # silently normalized by weights that were never applied.
        if len(weights) != len(pred_track):
            raise ValueError(
                f"Expected {len(pred_track)} weights, got {len(weights)}"
            )
        total_weight = np.sum(weights)
        if total_weight == 0:
            raise ValueError("Weights must not sum to zero")

    if algorithm in wave_algorithms:
        if algorithm == 'avg_wave':
            scaled = np.asarray([track * w for track, w in zip(pred_track, weights)])
            result = scaled.sum(axis=0) / total_weight
        elif algorithm == 'median_wave':
            result = np.median(pred_track, axis=0)
        elif algorithm == 'min_wave':
            result = lambda_min(pred_track, axis=0, key=np.abs)
        else:  # max_wave
            result = lambda_max(pred_track, axis=0, key=np.abs)
        gc.collect()
        return result

    specs = []
    for i, track in enumerate(pred_track):
        spec = stft(track, nfft=n_fft, hl=hop_length)
        specs.append(spec * weights[i] if weighted else spec)
    specs = np.asarray(specs)  # NumPy 2.0+ compatibility
    if algorithm == 'avg_fft':
        combined = specs.sum(axis=0) / total_weight
    elif algorithm == 'median_fft':
        combined = np.median(specs, axis=0)
    elif algorithm == 'min_fft':
        combined = lambda_min(specs, axis=0, key=np.abs)
    else:  # max_fft
        combined = absmax(specs, axis=0)
    # Release the stacked spectrograms before the inverse transform allocates.
    del specs
    gc.collect()
    return istft(combined, hop_length, final_length)
def ensemble_files(args):
    """
    Ensemble several audio files into one and write the result to disk.

    :param args: CLI-style argument tokens (list/tuple, or None to read
        ``sys.argv``), or an already-parsed namespace with ``files``, ``type``,
        ``weights`` and ``output`` attributes.
    :return: the output file path.
    :raises ValueError: the number of weights does not match the number of files.
    :raises FileNotFoundError: an input path does not exist.
    :raises RuntimeError: an input file cannot be read or has a different
        sample rate, or the ensembling/writing step fails (the underlying
        exception is chained as ``__cause__``).
    """
    parser = argparse.ArgumentParser(description="Ensemble audio files")
    parser.add_argument('--files', nargs='+', required=True, help="Input audio files")
    parser.add_argument('--type', required=True, choices=['avg_wave', 'median_wave', 'max_wave', 'min_wave', 'avg_fft', 'median_fft', 'max_fft', 'min_fft'], help="Ensemble type")
    parser.add_argument('--weights', nargs='+', type=float, default=None, help="Weights for each file")
    parser.add_argument('--output', required=True, help="Output file path")
    # Raw tokens are parsed; anything else is assumed to be a parsed namespace.
    if args is None or isinstance(args, (list, tuple)):
        args = parser.parse_args(args)
    logger.info(f"Ensemble type: {args.type}")
    logger.info(f"Number of input files: {len(args.files)}")
    weights = args.weights if args.weights else [1.0] * len(args.files)
    if len(weights) != len(args.files):
        logger.error("Number of weights must match number of audio files")
        raise ValueError("Number of weights must match number of audio files")
    logger.info(f"Weights: {weights}")
    logger.info(f"Output file: {args.output}")

    data = []
    sr = None
    for f in args.files:
        if not os.path.isfile(f):
            logger.error(f"Cannot find file: {f}")
            raise FileNotFoundError(f"Cannot find file: {f}")
        logger.info(f"Reading file: {f}")
        try:
            wav, curr_sr = librosa.load(f, sr=None, mono=False)
            if sr is None:
                sr = curr_sr
            elif sr != curr_sr:
                logger.error("All audio files must have the same sample rate")
                raise ValueError("All audio files must have the same sample rate")
        except Exception as e:
            logger.error(f"Error reading audio file {f}: {str(e)}")
            raise RuntimeError(f"Error reading audio file {f}: {str(e)}") from e
        logger.info(f"Waveform shape: {wav.shape} sample rate: {sr}")
        data.append(wav)

    try:
        # Differing shapes (length or channel count) would otherwise fail deep
        # inside np.asarray with an obscure "inhomogeneous shape" error.
        shapes = {w.shape for w in data}
        if len(shapes) > 1:
            raise ValueError(f"All audio files must have the same shape, got: {sorted(shapes)}")
        data = np.asarray(data)  # NumPy 2.0+ compatibility
        res = average_waveforms(data, weights, args.type)
        logger.info(f"Result shape: {res.shape}")
        # A bare filename has an empty dirname, and os.makedirs('') raises
        # FileNotFoundError, so only create a directory when there is one.
        output_dir = os.path.dirname(args.output)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        sf.write(args.output, res.T, sr, 'FLOAT')
        logger.info(f"Output written to: {args.output}")
        return args.output
    except Exception as e:
        logger.error(f"Error during ensemble processing: {str(e)}")
        raise RuntimeError(f"Error during ensemble processing: {str(e)}") from e
    finally:
        gc.collect()
if __name__ == "__main__":
    # sys is not among this file's top-level imports; without this the script
    # entry point fails with NameError.
    import sys

    ensemble_files(sys.argv[1:])