Spaces:
Running
on
Zero
Running
on
Zero
# coding: utf-8 | |
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' | |
import argparse | |
import time | |
import librosa | |
from tqdm.auto import tqdm | |
import sys | |
import os | |
import glob | |
import torch | |
import soundfile as sf | |
import torch.nn as nn | |
from datetime import datetime | |
import numpy as np | |
import librosa | |
# Add this script's directory to sys.path so the local utils module can be imported even when running under an embedded Python distribution.
current_dir = os.path.dirname(os.path.abspath(__file__)) | |
sys.path.append(current_dir) | |
from utils import demix, get_model_from_config, normalize_audio, denormalize_audio | |
from utils import prefer_target_instrument, apply_tta, load_start_checkpoint, load_lora_weights | |
import warnings | |
warnings.filterwarnings("ignore") | |
def shorten_filename(filename, max_length=30):
    """
    Shorten a filename so that its base name fits within a maximum length.

    The extension is always preserved and is not counted against
    ``max_length``. Names whose base is already short enough are returned
    unchanged. Longer names keep the start and the end of the base joined
    by ``"..."``.

    Args:
        filename (str): The filename to be shortened
        max_length (int): Maximum allowed length for the base name

    Returns:
        str: Shortened filename
    """
    base, ext = os.path.splitext(filename)
    if len(base) <= max_length:
        return filename

    ellipsis = "..."
    # Preferred layout: first 15 and last 10 characters of the base name.
    head_len, tail_len = 15, 10

    budget = max_length - len(ellipsis)
    if budget <= 0:
        # Not even room for the ellipsis: fall back to a hard truncation.
        return base[:max(max_length, 0)] + ext

    if budget < head_len + tail_len:
        # The preferred layout would exceed max_length (and could even be
        # longer than the input), so shrink both parts to fit the budget,
        # keeping roughly the same 3:2 head/tail proportion.
        tail_len = budget * 2 // 5
        head_len = budget - tail_len

    # Slice with an explicit start index: base[-0:] would return the whole
    # string when tail_len is 0.
    return base[:head_len] + ellipsis + base[len(base) - tail_len:] + ext
def get_soundfile_subtype(pcm_type, is_float=False):
    """
    Pick the soundfile subtype that corresponds to the requested PCM type.

    Args:
        pcm_type (str): Requested PCM type ('PCM_16', 'PCM_24', 'FLOAT')
        is_float (bool): When true, force the float format regardless of
            ``pcm_type``

    Returns:
        str: Soundfile subtype; 'FLOAT' is used for any unrecognised
        ``pcm_type``
    """
    recognised = {'PCM_16', 'PCM_24', 'FLOAT'}
    if not is_float and pcm_type in recognised:
        return pcm_type
    return 'FLOAT'
def run_folder(model, args, config, device, verbose: bool = False):
    """
    Separate every file found in ``args.input_folder`` and write one audio
    file per instrument into ``args.store_dir``.

    Args:
        model: Separation model; put into eval mode and handed to ``demix``.
        args: Parsed command-line namespace. Reads ``input_folder``,
            ``store_dir``, ``model_type``, ``use_tta``,
            ``demud_phaseremix_inst``, ``extract_instrumental``,
            ``pcm_type`` and, when present, ``export_format`` and
            ``flac_file``.
        config: Model configuration. Reads ``config.audio.sample_rate``
            (falls back to 44100) and ``config.inference['normalize']``.
        device: Device forwarded to ``demix`` and ``apply_tta``.
        verbose (bool): Currently unused; kept for interface compatibility.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    start_time = time.time()
    model.eval()

    mixture_paths = sorted(glob.glob(os.path.join(args.input_folder, '*.*')))
    sample_rate = getattr(config.audio, 'sample_rate', 44100)
    print(f"Total files found: {len(mixture_paths)}. Using sample rate: {sample_rate}")

    # Stems produced by the model itself. Each file works on its own copy so
    # that extra stems added for one file never leak into the next one.
    base_instruments = prefer_target_instrument(config)[:]
    os.makedirs(args.store_dir, exist_ok=True)

    normalize = 'normalize' in config.inference and config.inference['normalize'] is True

    # Output container and sample format do not depend on the file or the
    # instrument, so resolve them once.
    is_float = getattr(args, 'export_format', '').startswith('wav FLOAT')
    codec = 'flac' if getattr(args, 'flac_file', False) else 'wav'
    if codec == 'flac':
        # NOTE(review): combining --flac_file with --export_format "wav FLOAT"
        # requests a FLOAT subtype for FLAC; confirm soundfile accepts that.
        subtype = get_soundfile_subtype(args.pcm_type, is_float)
    else:
        subtype = get_soundfile_subtype('FLOAT', is_float)

    total_files = len(mixture_paths)
    for current_file, path in enumerate(mixture_paths, start=1):
        print(f"Processing file {current_file}/{total_files}")
        try:
            mix, sr = librosa.load(path, sr=sample_rate, mono=False)
        except Exception as e:
            # Unreadable files are reported and skipped; the batch goes on.
            print(f'Cannot read track: {path}')
            print(f'Error message: {str(e)}')
            continue

        instruments = list(base_instruments)

        mix_orig = mix.copy()
        if normalize:
            mix, norm_params = normalize_audio(mix)

        waveforms_orig = demix(config, model, mix, device, model_type=args.model_type)
        if args.use_tta:
            waveforms_orig = apply_tta(config, model, mix, waveforms_orig, device, args.model_type)

        if args.demud_phaseremix_inst:
            print(f"Demudding track (phase remix - instrumental): {path}")
            instr = 'vocals' if 'vocals' in instruments else instruments[0]
            instruments.append('instrumental_phaseremix')
            if 'instrumental' not in instruments and 'Instrumental' not in instruments:
                mix_modified = mix_orig - 2 * waveforms_orig[instr]
                waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type)
                if args.use_tta:
                    waveforms_modified = apply_tta(config, model, mix_modified, waveforms_modified, device, args.model_type)
                waveforms_orig['instrumental_phaseremix'] = mix_orig + waveforms_modified[instr]
            else:
                mix_modified = 2 * waveforms_orig[instr] - mix_orig
                # Keep an untouched copy: it is needed after demix has run.
                mix_modified_ = mix_modified.copy()
                waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type)
                if args.use_tta:
                    # TTA must refine the estimates of the modified mix, the
                    # same way the branch above does.
                    waveforms_modified = apply_tta(config, model, mix_modified, waveforms_modified, device, args.model_type)
                waveforms_orig['instrumental_phaseremix'] = mix_orig + mix_modified_ - waveforms_modified[instr]

        if args.extract_instrumental:
            instr = 'vocals' if 'vocals' in instruments else instruments[0]
            waveforms_orig['instrumental'] = mix_orig - waveforms_orig[instr]
            if 'instrumental' not in instruments:
                instruments.append('instrumental')

        # The output name keeps the (possibly shortened) source name,
        # including its original extension, followed by the stem name.
        shortened_filename = shorten_filename(os.path.basename(path))
        for instr in instruments:
            estimates = waveforms_orig[instr]
            if normalize:
                estimates = denormalize_audio(estimates, norm_params)

            output_filename = f"{shortened_filename}_{instr}.{codec}"
            output_path = os.path.join(args.store_dir, output_filename)
            sf.write(output_path, estimates.T, sr, subtype=subtype)

        progress_percent = int((current_file / total_files) * 100)
        print(f"Progress: {progress_percent}%")

    print(f"Elapsed time: {time.time() - start_time:.2f} seconds.")
def proc_folder(args):
    """
    Parse the command-line options, load the model and run separation on
    the input folder.

    Args:
        args (list[str] | None): Argument list to parse. ``None`` makes
            argparse read the options from ``sys.argv``.

    Returns:
        None. All work is delegated to ``run_folder``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default='mdx23c',
                        help="Model type (bandit, bs_roformer, mdx23c, etc.)")
    parser.add_argument("--config_path", type=str, help="Path to config file")
    parser.add_argument("--demud_phaseremix_inst", action='store_true', help="demud_phaseremix_inst")
    parser.add_argument("--start_check_point", type=str, default='',
                        help="Initial checkpoint to valid weights")
    parser.add_argument("--input_folder", type=str, help="Folder with mixtures to process")
    # Parsed and echoed below; nothing else in this function consumes it.
    parser.add_argument("--audio_path", type=str, help="Path to a single audio file to process")
    parser.add_argument("--store_dir", default="", type=str, help="Path to store results")
    # The default is a bare int while command-line values arrive as a list,
    # hence the isinstance checks further down.
    parser.add_argument("--device_ids", nargs='+', type=int, default=0,
                        help='List of GPU IDs')
    parser.add_argument("--extract_instrumental", action='store_true',
                        help="Invert vocals to get instrumental if provided")
    parser.add_argument("--force_cpu", action='store_true',
                        help="Force the use of CPU even if CUDA is available")
    parser.add_argument("--flac_file", action='store_true',
                        help="Output flac file instead of wav")
    parser.add_argument("--export_format", type=str,
                        choices=['wav FLOAT', 'flac PCM_16', 'flac PCM_24'],
                        default='flac PCM_24',
                        help="Export format and PCM type")
    parser.add_argument("--pcm_type", type=str,
                        choices=['PCM_16', 'PCM_24'],
                        default='PCM_24',
                        help="PCM type for FLAC files")
    parser.add_argument("--use_tta", action='store_true',
                        help="Enable test time augmentation")
    parser.add_argument("--lora_checkpoint", type=str, default='',
                        help="Initial checkpoint to LoRA weights")

    # parse_args(None) already falls back to sys.argv, so a single call
    # covers both the programmatic and the command-line case.
    args = parser.parse_args(args)
    print(f"Audio path provided: {args.audio_path}")

    # Device selection: CPU when forced, otherwise CUDA, then MPS, then CPU.
    has_id_list = isinstance(args.device_ids, list)
    device = "cpu"
    if not args.force_cpu:
        if torch.cuda.is_available():
            print('CUDA is available, use --force_cpu to disable it.')
            first_id = args.device_ids[0] if has_id_list else args.device_ids
            device = f'cuda:{first_id}'
        elif torch.backends.mps.is_available():
            device = "mps"
    print("Using device: ", device)

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True

    model, config = get_model_from_config(args.model_type, args.config_path)
    if args.start_check_point != '':
        load_start_checkpoint(args, model, type_='inference')
    print("Instruments: {}".format(config.training.instruments))

    # Spread the model over several GPUs when more than one ID was given.
    if has_id_list and len(args.device_ids) > 1 and not args.force_cpu:
        model = nn.DataParallel(model, device_ids=args.device_ids)
    model = model.to(device)

    print("Model load time: {:.2f} sec".format(time.time() - model_load_start_time))
    run_folder(model, args, config, device, verbose=True)
# Script entry point: passing None makes argparse read the options from
# sys.argv.
if __name__ == "__main__":
    proc_folder(None)