|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
from tqdm import tqdm |
|
|
import logging |
|
|
import os |
|
|
from verification import init_model, MODEL_LIST |
|
|
import soundfile as sf |
|
|
import torch |
|
|
import numpy as np |
|
|
import torch.nn.functional as F |
|
|
from torchaudio.transforms import Resample |
|
|
import torch.multiprocessing as mp |
|
|
|
|
|
# Send every log record through a single console handler whose format shows
# the timestamp, source file, level, process id and thread name.
console_format = logging.Formatter(
    "[%(asctime)s][%(filename)s:%(levelname)s][%(process)d:%(threadName)s]%(message)s"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_format)
console_handler.setLevel(logging.INFO)

# Drop whatever handlers are already installed on the root logger.
# Iterate over a copy: removeHandler() mutates logging.root.handlers, and
# removing from the live list while looping over it skips every other entry.
for handler in list(logging.root.handlers):
    logging.root.removeHandler(handler)

logging.root.addHandler(console_handler)
logging.root.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
# Identifier handed to init_model() when a worker builds its model.
MODEL_NAME = "wavlm_large"

# Optional hook: when the S3PRL_PATH environment variable is set, import
# patch_unispeech and call its patch_for_npu() at import time.
S3PRL_PATH = os.getenv("S3PRL_PATH")
if S3PRL_PATH is not None:
    import patch_unispeech

    logging.info("Applying Patches for unispeech!!!")
    patch_unispeech.patch_for_npu()
|
|
|
|
|
|
|
|
def get_ref_and_gen_files(
    test_lst, test_folder, task_queue
):
    """Queue one (reference wav, generated wav) path pair per list entry.

    Every line of ``test_lst`` is split on ``"|"``. Field 0 names the
    reference audio and field 2 names the generated audio. For both, only the
    last path component up to its first ``"."`` is kept, and the paths are
    rebuilt as ``{test_folder}/{name}_ref.wav`` and
    ``{test_folder}/{name}_gen.wav``.

    Blank lines are ignored. Lines with fewer than three fields are skipped
    with a warning instead of aborting the whole run with ``IndexError``.

    Args:
        test_lst: Path of the ``"|"``-separated list file.
        test_folder: Folder expected to contain the ``*_ref.wav`` and
            ``*_gen.wav`` files.
        task_queue: Object with a ``put`` method; receives
            ``(ref_file, gen_file)`` tuples in file order.

    Returns:
        None.
    """
    with open(test_lst, "r") as fp:
        for line_no, raw_line in enumerate(fp, start=1):
            line = raw_line.strip()
            if not line:
                # A trailing or separating blank line carries no entry.
                continue

            fields = line.split("|")
            if len(fields) < 3:
                logging.warning(
                    "Skipping malformed line %d of %s: expected at least 3 '|'-separated fields",
                    line_no,
                    test_lst,
                )
                continue

            # Keep only the file name, truncated at its first dot.
            gen_name = fields[2].split("/")[-1].split(".")[0]
            ref_name = fields[0].split("/")[-1].split(".")[0]

            gen_file = f"{test_folder}/{gen_name}_gen.wav"
            ref_file = f"{test_folder}/{ref_name}_ref.wav"

            task_queue.put((ref_file, gen_file))
|
|
|
|
|
|
|
|
def _load_mono_16k(path):
    """Read ``path`` with soundfile and return a (1, samples) float tensor at 16 kHz."""
    audio, sample_rate = sf.read(path)
    # soundfile yields a (frames, channels) array for multi-channel files.
    # Average the channels so that time is the last axis, which is the axis
    # Resample operates on.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    waveform = torch.from_numpy(audio).unsqueeze(0).float()
    return Resample(orig_freq=sample_rate, new_freq=16000)(waveform)


def eval_speaker_similarity(model, wav1, wav2, rank):
    """Return the cosine similarity between the model embeddings of two audio files.

    Both files are loaded, down-mixed to mono when they have several
    channels, resampled to 16 kHz, moved to ``cuda:{rank}`` and passed through
    ``model`` in eval mode with gradients disabled.

    Args:
        model: Callable module producing an embedding for a waveform batch.
            NOTE(review): presumably returns (batch, dim) tensors, since the
            similarity is taken along ``cosine_similarity``'s default dim=1.
        wav1: Path of the first audio file.
        wav2: Path of the second audio file.
        rank: Index of the CUDA device to run on.

    Returns:
        The similarity of the first batch item as a Python float.
    """
    device = f"cuda:{rank}"
    wav1 = _load_mono_16k(wav1).to(device)
    wav2 = _load_mono_16k(wav2).to(device)

    model.eval()
    with torch.no_grad():
        emb1 = model(wav1)
        emb2 = model(wav2)

    sim = F.cosine_similarity(emb1, emb2)
    score = sim[0].item()
    logging.info("The similarity score between two audios is %.4f (-1.0, 1.0).", score)
    return score
|
|
|
|
|
|
|
|
def eval_proc(model_path, task_queue, rank, sim_list):
    """Worker loop: score queued (ref, gen) pairs on ``cuda:{rank}``.

    Builds the ``MODEL_NAME`` model from ``model_path``, moves it to the
    worker's GPU, then consumes ``task_queue`` until a ``None`` sentinel
    arrives. Pairs with a missing file are logged and skipped. A pair whose
    evaluation raises is logged with its traceback and skipped, so one bad
    file does not stop the worker.

    Args:
        model_path: Checkpoint path forwarded to ``init_model``.
        task_queue: Queue of ``(ref_path, gen_path)`` tuples terminated by ``None``.
        rank: Index of the CUDA device used by this worker.
        sim_list: Shared list; receives ``(sim, ref, gen)`` for each scored pair.
    """
    assert MODEL_NAME in MODEL_LIST, 'The model_name should be in {}'.format(MODEL_LIST)
    model = init_model(MODEL_NAME, model_path)
    model.to(f"cuda:{rank}")

    while True:
        # Fetch outside the try block: if the queue itself fails there is no
        # current pair to report, and retrying forever would spin the worker.
        new_record = task_queue.get()
        if new_record is None:
            logging.info("FINISH processing all inputs")
            break

        ref = new_record[0]
        gen = new_record[1]
        logging.info(f"eval SIM: {ref} v.s. {gen}")

        if not os.path.exists(ref) or not os.path.exists(gen):
            logging.info(f"MISSING: {ref} v.s. {gen}")
            continue

        try:
            sim = eval_speaker_similarity(model, ref, gen, rank)
            sim_list.append((sim, ref, gen))
        except Exception:
            # Best effort per pair; KeyboardInterrupt/SystemExit still propagate.
            logging.exception(f"FAIL to eval SIM: {ref} v.s. {gen}")
|
|
|
|
|
|
|
|
def main(args):
    """Score every (ref, gen) pair of the test list and return the mean similarity.

    One worker process is spawned per entry of ``CUDA_VISIBLE_DEVICES`` (a
    single worker when the variable is unset). Workers pull pairs from a
    shared queue and append ``(sim, ref, gen)`` tuples to a shared list.

    Args:
        args: Parsed arguments providing ``log_file`` (may be ``None``),
            ``test_lst``, ``test_path`` and ``model_path``.

    Returns:
        The average similarity rounded to 3 decimals, or NaN when no pair
        could be evaluated.
    """
    # --log-file is optional; logging.FileHandler(None) raises TypeError, so
    # only attach a file handler when a path was actually given.
    if args.log_file is not None:
        handler = logging.FileHandler(filename=args.log_file, mode="w")
        logging.root.addHandler(handler)

    device_list = [0]
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        device_list = [int(x.strip()) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]

    logging.info(f"Using devices: {device_list}")
    # Only the number of entries matters: each worker's rank is its position
    # in device_list, not the id stored there.
    n_procs = len(device_list)
    ctx = mp.get_context('spawn')
    with ctx.Manager() as manager:
        sim_list = manager.list()
        task_queue = manager.Queue()
        get_ref_and_gen_files(args.test_lst, args.test_path, task_queue)

        processes = []
        for rank in range(n_procs):
            # One None sentinel per worker so that each of them terminates.
            task_queue.put(None)
            p = ctx.Process(target=eval_proc, args=(args.model_path, task_queue, rank, sim_list))
            processes.append(p)

        for proc in processes:
            proc.start()

        for proc in processes:
            proc.join()

        # Copy the results out while the manager is still alive.
        results = list(sim_list)

    sim_scores = []
    for sim, ref, gen in results:
        logging.info(f"{ref} vs {gen} : {sim}")
        sim_scores.append(sim)

    if sim_scores:
        avg_sim = round(np.mean(np.array(sim_scores)), 3)
    else:
        # np.mean of an empty array warns and yields NaN; make that explicit.
        logging.warning("No wav pair was evaluated; the average similarity is undefined.")
        avg_sim = float("nan")

    logging.info("total evaluated wav pairs: %d" % (len(results)))
    logging.info("The average similarity score of %s is %.4f (-1.0, 1.0)." % (args.test_path, avg_sim))
    return avg_sim
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-path",
        required=True,
        type=str,
        help="folder of wav files",
    )
    parser.add_argument(
        "--test-lst",
        required=True,
        type=str,
        help="path to test file lst",
    )
    parser.add_argument(
        "--log-file",
        required=False,
        type=str,
        default=None,
        # The previous help text was copied from --test-lst by mistake.
        help="path to log file",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="./wavlm-sv",
        help="path to sv model",
    )

    args = parser.parse_args()
    main(args)
|
|
|