DSTK / evaluation /eval_sim.py
gooorillax's picture
first push of codes and models for g2p, t2u, tokenizer and detokenizer
cd8454d
raw
history blame
5.91 kB
# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from tqdm import tqdm
import logging
import os
from verification import init_model, MODEL_LIST
import soundfile as sf
import torch
import numpy as np
import torch.nn.functional as F
from torchaudio.transforms import Resample
import torch.multiprocessing as mp
# Console logging: every record at INFO or above is written through a single
# StreamHandler, prefixed with timestamp, source file, level, process id and
# thread name.
console_format = logging.Formatter(
    "[%(asctime)s][%(filename)s:%(levelname)s][%(process)d:%(threadName)s]%(message)s"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_format)
console_handler.setLevel(logging.INFO)
# Drop whatever handlers were installed on the root logger before this module
# was imported so that messages are not emitted twice.
# Iterate over a copy: removeHandler() mutates logging.root.handlers, and
# removing entries while walking the live list skips every other handler.
for handler in list(logging.root.handlers):
    logging.root.removeHandler(handler)
logging.root.addHandler(console_handler)
logging.root.setLevel(logging.INFO)
# Name of the speaker-verification model; eval_proc checks it against
# MODEL_LIST before loading.
MODEL_NAME = "wavlm_large"

# Optional switch: when the S3PRL_PATH environment variable is set, the
# project-local patch_unispeech module is imported and its patch_for_npu()
# hook is run once at import time.
# NOTE(review): presumably this adapts the unispeech code for NPU devices —
# confirm against patch_unispeech.
S3PRL_PATH = os.getenv("S3PRL_PATH")
if S3PRL_PATH is not None:
    import patch_unispeech

    logging.info("Applying Patches for unispeech!!!")
    patch_unispeech.patch_for_npu()
def get_ref_and_gen_files(
    test_lst, test_folder, task_queue
):
    """Queue one (reference wav, generated wav) path pair per list-file line.

    Every non-blank line of ``test_lst`` is split on ``|``. Field 0 names the
    reference audio and field 2 names the generated audio. For both, the
    directory part and everything from the first ``.`` onward are dropped, and
    the remaining stem is mapped to ``{test_folder}/{stem}_ref.wav`` and
    ``{test_folder}/{stem}_gen.wav`` respectively.

    Args:
        test_lst: Path of the ``|``-separated list file.
        test_folder: Folder that holds the ``*_ref.wav`` / ``*_gen.wav`` files.
        task_queue: Object with a ``put()`` method; receives one
            ``(ref_file, gen_file)`` tuple per parsed line.

    Raises:
        IndexError: If a non-blank line has fewer than three fields.
    """
    with open(test_lst, "r") as fp:
        for raw_line in fp:
            line = raw_line.strip()
            if not line:
                # Blank lines (typically a trailing empty line) describe no
                # pair; without this guard they crash on fields[2].
                continue
            fields = line.split("|")
            # Only the part before the FIRST dot is kept, so "a.b.wav" -> "a".
            gen_name = fields[2].split("/")[-1].split(".")[0]
            ref_name = fields[0].split("/")[-1].split(".")[0]
            gen_file = f"{test_folder}/{gen_name}_gen.wav"
            ref_file = f"{test_folder}/{ref_name}_ref.wav"
            task_queue.put((ref_file, gen_file))
def eval_speaker_similarity(model, wav1, wav2, rank):
    """Return the cosine similarity between the model outputs of two audio files.

    Both files are read with soundfile, given a leading batch dimension,
    converted to float32, resampled to 16 kHz and moved to ``cuda:{rank}``
    before being passed through ``model`` with gradients disabled.

    Args:
        model: Callable torch module; it is switched to eval mode here.
            NOTE(review): presumably a speaker-embedding network — confirm
            against ``verification.init_model``.
        wav1: Path of the first audio file.
        wav2: Path of the second audio file.
        rank: Index of the CUDA device to run on.

    Returns:
        The similarity of the first batch item as a Python float.
    """
    device = f"cuda:{rank}"

    # Stage 1: read both files before any tensor work is done.
    loaded = [sf.read(path) for path in (wav1, wav2)]

    # Stage 2: numpy samples -> (1, ...) float tensors.
    batches = [
        torch.from_numpy(samples).unsqueeze(0).float() for samples, _ in loaded
    ]

    # Stage 3: bring each signal to 16 kHz using its own source rate.
    resamplers = [
        Resample(orig_freq=sample_rate, new_freq=16000) for _, sample_rate in loaded
    ]
    batches = [resample(batch) for resample, batch in zip(resamplers, batches)]

    # Stage 4: move to the target GPU.
    first, second = (batch.cuda(device) for batch in batches)

    model.eval()
    with torch.no_grad():
        emb_first = model(first)
        emb_second = model(second)

    score = F.cosine_similarity(emb_first, emb_second)[0].item()
    logging.info("The similarity score between two audios is %.4f (-1.0, 1.0)." % (score))
    return score
def eval_proc(model_path, task_queue, rank, sim_list):
    """Worker loop: score queued (ref, gen) wav pairs on one CUDA device.

    The model named by ``MODEL_NAME`` is loaded from ``model_path`` and moved
    to ``cuda:{rank}``. Records are then taken from ``task_queue`` until a
    ``None`` sentinel arrives. Pairs with a missing file are logged and
    skipped; pairs whose evaluation raises are logged with a traceback and
    skipped; every other pair contributes ``(sim, ref, gen)`` to ``sim_list``.

    Args:
        model_path: Checkpoint location handed to ``init_model``.
        task_queue: Queue yielding ``(ref_file, gen_file)`` tuples and a
            terminating ``None``.
        rank: Index of the CUDA device used by this worker.
        sim_list: Shared list that collects ``(sim, ref, gen)`` results.
    """
    assert MODEL_NAME in MODEL_LIST, 'The model_name should be in {}'.format(MODEL_LIST)
    model = init_model(MODEL_NAME, model_path)
    model.to(f"cuda:{rank}")
    while True:
        # Fetch outside any try block: if the queue itself is broken the
        # worker must stop, not swallow the error and spin forever.
        new_record = task_queue.get()
        if new_record is None:
            logging.info("FINISH processing all inputs")
            break
        ref = new_record[0]
        gen = new_record[1]
        logging.info(f"eval SIM: {ref} v.s. {gen}")
        if not os.path.exists(ref) or not os.path.exists(gen):
            logging.info(f"MISSING: {ref} v.s. {gen}")
            continue
        try:
            sim = eval_speaker_similarity(model, ref, gen, rank)
        except Exception:
            # Best effort per pair: keep going, but keep the traceback so the
            # failure can be diagnosed. KeyboardInterrupt/SystemExit are not
            # caught and still stop the worker.
            logging.exception(f"FAIL to eval SIM: {ref} v.s. {gen}")
            continue
        sim_list.append((sim, ref, gen))
def main(args):
    """Evaluate speaker similarity for all listed wav pairs and return the mean.

    One worker process is spawned per entry of ``CUDA_VISIBLE_DEVICES`` (a
    single worker when the variable is unset). The pairs parsed from
    ``args.test_lst`` are placed on a shared queue, followed by one ``None``
    sentinel per worker. Results are gathered in a manager list, logged one
    by one, and averaged.

    Args:
        args: Namespace with ``test_lst``, ``test_path``, ``model_path`` and
            ``log_file`` (``log_file`` may be ``None``).

    Returns:
        The mean similarity rounded to 3 decimals, or NaN when no pair could
        be evaluated.
    """
    # --log-file is optional and defaults to None; logging.FileHandler(None)
    # raises TypeError, so only attach a file handler when a path was given.
    if args.log_file is not None:
        handler = logging.FileHandler(filename=args.log_file, mode="w")
        logging.root.addHandler(handler)

    device_list = [0]
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        device_list = [int(x.strip()) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
    logging.info(f"Using devices: {device_list}")
    n_procs = len(device_list)

    ctx = mp.get_context('spawn')
    with ctx.Manager() as manager:
        sim_list = manager.list()
        task_queue = manager.Queue()
        get_ref_and_gen_files(args.test_lst, args.test_path, task_queue)

        processes = []
        for idx in range(n_procs):
            # One sentinel per worker, queued after all real tasks.
            task_queue.put(None)
            # Workers receive the position in device_list, not the listed id.
            rank = idx
            p = ctx.Process(target=eval_proc, args=(args.model_path, task_queue, rank, sim_list))
            processes.append(p)
        for proc in processes:
            proc.start()
        for proc in processes:
            proc.join()
            if proc.exitcode != 0:
                logging.error(f"worker {proc.name} exited with code {proc.exitcode}")

        # Copy the results out of the proxy while the manager is still alive.
        results = list(sim_list)

    sim_scores = []
    for sim, ref, gen in results:
        logging.info(f"{ref} vs {gen} : {sim}")
        sim_scores.append(sim)

    if sim_scores:
        avg_sim = round(np.mean(np.array(sim_scores)), 3)
    else:
        # np.mean([]) would emit a RuntimeWarning and silently yield NaN.
        logging.warning("no wav pair was evaluated; the average similarity is undefined")
        avg_sim = float("nan")

    logging.info("total evaluated wav pairs: %d" % (len(results)))
    logging.info("The average similarity score of %s is %.4f (-1.0, 1.0)." % (args.test_path, avg_sim))
    return avg_sim
def build_arg_parser():
    """Build the command-line parser of the speaker-similarity evaluation script.

    Returns:
        An ``argparse.ArgumentParser`` with the options ``--test-path`` and
        ``--test-lst`` (both required), ``--log-file`` (default ``None``) and
        ``--model-path`` (default ``./wavlm-sv``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-path",
        required=True,
        type=str,
        help="folder of wav files",
    )
    parser.add_argument(
        "--test-lst",
        required=True,
        type=str,
        help="path to test file lst",
    )
    parser.add_argument(
        "--log-file",
        required=False,
        type=str,
        default=None,
        # The help text used to be a copy of the --test-lst one.
        help="path to log file",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="./wavlm-sv",
        help="path to sv model",
    )
    return parser


if __name__ == "__main__":
    parser = build_arg_parser()
    args = parser.parse_args()
    main(args)