# DSTK / evaluation / eval_detok_en.py
# Uploaded by gooorillax in commit cd8454d:
# "first push of codes and models for g2p, t2u, tokenizer and detokenizer"
# (file size: 8.5 kB)
# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import soundfile as sf
import scipy
import argparse
from whisper_normalizer.english import EnglishTextNormalizer
import os
import string
import lingvo.tasks.asr.tools.simple_wer_v2 as WER
from tqdm import tqdm
import logging
import torch
# No key phrases are passed to the WER scorer.
keyphrases = None

# Text normalizer applied to reference and hypothesis text when --norm-text is set.
english_normalizer = EnglishTextNormalizer()

# Run ASR on the GPU when one is available; otherwise fall back to the CPU so
# the script still works on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Location of the Whisper checkpoint handed to `from_pretrained`.
en_asr_model_path = "./whisper-large-v3"

# Module-level WER scorer that `calc_wer` feeds hypothesis/reference pairs into.
wer_obj = WER.SimpleWER(
    key_phrases=keyphrases,
    html_handler=WER.HighlightAlignedHtmlHandler(WER.HighlightAlignedHtml),
    preprocess_handler=WER.RemoveCommentTxtPreprocess,
)
def dummy_split_text(text):
    """Identity text splitter: hand ``text`` back exactly as received.

    Used as the no-op ``text_spliter`` callback by the helpers in this file.
    """
    unchanged = text
    return unchanged
def remove_punct(text):
    """Strip ASCII punctuation from ``text`` and collapse repeated spaces.

    Args:
        text: input string.

    Returns:
        ``text`` without any character from ``string.punctuation``; runs of
        two or more spaces are reduced to a single space. Leading and
        trailing whitespace is otherwise left untouched.
    """
    # One pass through a translation table instead of growing the result one
    # character at a time.
    output = text.translate(str.maketrans("", "", string.punctuation))
    # Deleting a punctuation mark that sat between two spaces leaves a double
    # space behind ("a , b" -> "a  b"); squeeze such runs down to one space.
    while "  " in output:
        output = output.replace("  ", " ")
    return output
def get_gt_ref_texts_and_wav_files(
    args, gt_test_lst, gt_folder, punct_remover, text_spliter
):
    """Collect ground-truth transcripts together with their wav files.

    Every line of ``gt_test_lst`` is split on ``|``. The first field names the
    utterance (looked up as ``<gt_folder>/<field>.wav``) and the last field is
    its transcript. Lines whose wav file does not exist are skipped.

    Args:
        args: object exposing the ``norm_text`` and ``remove_punct`` flags.
        gt_test_lst: path of the ``|``-separated list file.
        gt_folder: folder that holds the ground-truth wav files.
        punct_remover: callable used on the lower-cased text when
            ``args.remove_punct`` is set (and ``args.norm_text`` is not).
        text_spliter: callable applied last to the processed text.

    Returns:
        ``(reference, wav_file_list)`` where ``reference[i]`` is
        ``[processed_text, raw_text]`` for ``wav_file_list[i]``.
    """
    reference = []
    wav_file_list = []
    with open(gt_test_lst, "r") as lst_file:
        for raw_line in lst_file:
            columns = raw_line.strip().split("|")
            wav_path = f"{gt_folder}/{columns[0]}.wav"
            if os.path.isfile(wav_path):
                wav_file_list.append(wav_path)
                raw_text = columns[-1]
                lowered = raw_text.lower()
                # Normalization takes precedence over plain punctuation removal.
                if args.norm_text:
                    cleaned = english_normalizer(lowered)
                elif args.remove_punct:
                    cleaned = punct_remover(lowered)
                else:
                    cleaned = lowered
                reference.append([text_spliter(cleaned), raw_text])
    assert len(reference) == len(wav_file_list)
    return reference, wav_file_list
def get_ref_texts_and_gen_files(
    args, test_lst, test_folder, punct_remover, text_spliter
):
    """Collect reference transcripts and the generated wav paths to score.

    Every non-blank line of ``test_lst`` is split on ``|``. The third field is
    a source audio path: its base name up to the first ``.`` gives the
    generated file ``<test_folder>/<name>_gen.wav``. The last field is the
    transcript. The existence of the generated file is not checked here.

    Args:
        args: object exposing the ``norm_text`` and ``remove_punct`` flags.
        test_lst: path of the ``|``-separated list file.
        test_folder: folder that holds the ``*_gen.wav`` files.
        punct_remover: callable used on the lower-cased text when
            ``args.remove_punct`` is set (and ``args.norm_text`` is not).
        text_spliter: callable applied last to the processed text.

    Returns:
        ``(reference, gen_file_list)`` where ``reference[i]`` is
        ``[processed_text, raw_text]`` for ``gen_file_list[i]``.

    Raises:
        IndexError: if a non-blank line has fewer than three fields.
    """
    reference = []
    gen_file_list = []
    with open(test_lst, "r") as fp:
        for line in fp:
            line = line.strip()
            if not line:
                # A blank line has no third field; skip it rather than
                # crashing with IndexError on fields[2].
                continue
            fields = line.split("|")
            # Keep only the base name, cut at the first dot.
            filename = fields[2].split("/")[-1].split(".")[0]
            gen_file_list.append(f"{test_folder}/{filename}_gen.wav")
            text = fields[-1].lower()
            # Normalization takes precedence over plain punctuation removal.
            if args.norm_text:
                truth_text = english_normalizer(text)
            elif args.remove_punct:
                truth_text = punct_remover(text)
            else:
                truth_text = text
            truth_text = text_spliter(truth_text)
            reference.append([truth_text, fields[-1]])
    assert len(reference) == len(gen_file_list)
    return reference, gen_file_list
def get_hypo_texts(args, results_list, punct_remover, text_spliter):
    """Turn ASR results into ``[processed_text, raw_text]`` hypothesis pairs.

    Args:
        args: object exposing the ``norm_text`` and ``remove_punct`` flags.
        results_list: iterable of dicts carrying the transcription under
            the ``"text"`` key.
        punct_remover: callable used on the lower-cased text when
            ``args.remove_punct`` is set (and ``args.norm_text`` is not).
        text_spliter: callable applied last to the processed text.

    Returns:
        One ``[processed_text, raw_text]`` list per entry of ``results_list``,
        in the same order.
    """

    def _prepare(raw_text):
        # Lower-case first, then apply at most one cleaning step.
        lowered = raw_text.lower()
        if args.norm_text:
            cleaned = english_normalizer(lowered)
        elif args.remove_punct:
            cleaned = punct_remover(lowered)
        else:
            cleaned = lowered
        return text_spliter(cleaned)

    return [[_prepare(item["text"]), item["text"]] for item in results_list]
def calc_wer(reference, hypothesis, test_lst):
    """Score hypotheses against references and write diagnosis files.

    Each hypothesis/reference pair is added to the module-level ``wer_obj``,
    whose summaries are then logged. Two files are written afterwards:
    ``<test_lst>_diagnosis.html`` with the aligned HTML from ``wer_obj`` and
    ``<test_lst>_rawtext.lst`` with one ``raw_reference|raw_hypothesis`` line
    per pair. A failure to write these files is logged, not raised.

    Args:
        reference: list of ``[processed_text, raw_text]`` entries.
        hypothesis: list with the same layout, aligned with ``reference``
            by index.
        test_lst: path prefix used for the two output files.
    """
    logging.info("calc WER:")
    for idx, hypo_entry in enumerate(tqdm(hypothesis)):
        hypo = hypo_entry[0].strip()
        ref = reference[idx][0].strip()
        wer_obj.AddHypRef(hypo, ref)
    str_summary, str_details, str_keyphrases_info = wer_obj.GetSummaries()
    logging.info("WER summary:")
    logging.info(str_summary)
    logging.info(str_details)
    logging.info(str_keyphrases_info)
    try:
        fn_output = test_lst + "_diagnosis.html"
        aligned_html = "<br>".join(wer_obj.aligned_htmls)
        # Explicit UTF-8 so non-ASCII transcriptions cannot trip the platform
        # default encoding; the `with` block closes the file on its own.
        with open(fn_output, "wt", encoding="utf-8") as fp:
            # <html> must enclose <body>; the tags used to be swapped.
            fp.write("<html><body>")
            fp.write("<div>%s</div>" % aligned_html)
            fp.write("</body></html>")
        text_output = test_lst + "_rawtext.lst"
        with open(text_output, "w", encoding="utf-8") as fp:
            for ref, hypo in zip(reference, hypothesis):
                fp.write(f"{ref[1]}|{hypo[1]}\n")
        logging.info(f"Save {fn_output} and {text_output} for diagnosis")
    except IOError:
        logging.info("failed to write diagnosis html")
def load_en_model():
    """Load the Whisper processor and model from ``en_asr_model_path``.

    Returns:
        ``(processor, model)``; the model has been moved to ``device``.
    """
    checkpoint = en_asr_model_path
    asr_processor = WhisperProcessor.from_pretrained(checkpoint)
    asr_model = WhisperForConditionalGeneration.from_pretrained(checkpoint)
    asr_model = asr_model.to(device)
    return asr_processor, asr_model
def process_wavs(wav_file_list, batch_size=300):
    """Transcribe every wav file with the English Whisper model.

    Each file is read with ``soundfile``, reduced to mono if it has several
    channels, resampled to 16 kHz when needed and decoded on its own.

    Args:
        wav_file_list: paths of the audio files to transcribe.
        batch_size: kept for backward compatibility; files are decoded one
            at a time, so the value is currently unused.

    Returns:
        A list of ``{"text": transcription}`` dicts in the order of
        ``wav_file_list``.
    """
    # A bare ``import scipy`` does not guarantee that the ``signal`` submodule
    # is loaded; import it explicitly before using it below.
    import scipy.signal

    target_sr = 16000
    results = []
    processor, model = load_en_model()
    # The decoder prompt is the same for every file, so build it once.
    forced_decoder_ids = processor.get_decoder_prompt_ids(
        language="english", task="transcribe"
    )
    for wav_file_path in tqdm(wav_file_list):
        wav, sr = sf.read(wav_file_path)
        if wav.ndim > 1:
            # soundfile returns (frames, channels) for multi-channel audio;
            # average the channels so the processor receives a mono signal.
            wav = wav.mean(axis=1)
        if sr != target_sr:
            wav = scipy.signal.resample(wav, int(len(wav) * target_sr / sr))
        input_features = processor(
            wav, sampling_rate=target_sr, return_tensors="pt"
        ).input_features
        input_features = input_features.to(device)
        predicted_ids = model.generate(
            input_features, forced_decoder_ids=forced_decoder_ids
        )
        transcription = processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0]
        results.append({"text": transcription.strip()})
    return results
def main(args):
    """Run ASR over the test set and report WER.

    With ``args.eval_gt`` the ground-truth wav files listed in
    ``args.test_lst`` are transcribed; otherwise the generated ``*_gen.wav``
    files are. The transcriptions are then scored against the reference text
    with ``calc_wer``.

    Args:
        args: parsed command-line namespace providing ``test_path``,
            ``test_lst``, ``log_file``, ``norm_text``, ``remove_punct`` and
            ``eval_gt``.
    """
    # --log-file is optional and defaults to None, which FileHandler rejects
    # with a TypeError; log to the console (stderr) in that case.
    if args.log_file is None:
        handler = logging.StreamHandler()
        log_target = "the console log"
    else:
        handler = logging.FileHandler(filename=args.log_file, mode="w")
        log_target = args.log_file
    logging.root.setLevel(logging.INFO)
    logging.root.addHandler(handler)

    test_path = (
        args.test_path
    )  # e.g. './40ms.AISHELL2.test_with_single_ref.base.chunk25.gen'
    lst_path = args.test_lst  # e.g. "40ms.AISHELL2.test_with_single_ref.base.lst"
    logging.info(
        f"Evaluate {args.test_path} with Text Normalization: {args.norm_text} and Remove Punct: {args.remove_punct}"
    )
    if args.eval_gt:
        logging.info(f"run ASR for GT: {lst_path}")
        reference, wav_file_list = get_gt_ref_texts_and_wav_files(
            args, lst_path, test_path, remove_punct, dummy_split_text
        )
        results = process_wavs(wav_file_list, batch_size=12)
    else:
        logging.info(f"run ASR for detok: {lst_path}")
        reference, gen_file_list = get_ref_texts_and_gen_files(
            args, lst_path, test_path, remove_punct, dummy_split_text
        )
        results = process_wavs(gen_file_list, batch_size=12)
    hypothesis = get_hypo_texts(args, results, remove_punct, dummy_split_text)
    assert len(hypothesis) == len(reference)
    logging.info(f"Finish runing ASR for {lst_path}")
    logging.info(f"hypothesis: {len(hypothesis)} vs reference: {len(reference)}")
    # The wav folder path doubles as the prefix of the diagnosis output files.
    calc_wer(reference, hypothesis, test_path)
    logging.info(f"Finish evaluate {lst_path}, results are in {log_target}")
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-path",
        required=True,
        type=str,
        help="folder of wav files",
    )
    parser.add_argument(
        "--test-lst",
        required=True,
        type=str,
        help="path to test file lst",
    )
    parser.add_argument(
        "--log-file",
        required=False,
        type=str,
        default=None,
        help="path of the log file to write",
    )
    parser.add_argument(
        "--norm-text",
        default=False,
        action="store_true",
        help="normalized GT and hypo texts",
    )
    parser.add_argument(
        "--remove-punct",
        default=False,
        action="store_true",
        help="remove punct from GT and hypo texts",
    )
    parser.add_argument(
        "--eval-gt",
        default=False,
        action="store_true",
        help="evaluate the ground-truth wav files instead of the generated ones",
    )
    args = parser.parse_args()
    main(args)