|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import WhisperProcessor, WhisperForConditionalGeneration |
|
|
import soundfile as sf |
|
|
import scipy |
|
|
import argparse |
|
|
from whisper_normalizer.english import EnglishTextNormalizer |
|
|
import os |
|
|
import string |
|
|
import lingvo.tasks.asr.tools.simple_wer_v2 as WER |
|
|
from tqdm import tqdm |
|
|
import logging |
|
|
import torch |
|
|
|
|
|
# Passed to SimpleWER as ``key_phrases``; no key phrases are configured here.
keyphrases = None

# Whisper's English text normalizer, applied to both sides when --norm-text
# is given.
english_normalizer = EnglishTextNormalizer()

# Prefer the GPU, but fall back to CPU so that moving the model to ``device``
# does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Local directory holding the Whisper large-v3 checkpoint and processor files.
en_asr_model_path = "./whisper-large-v3"

# Single module-wide WER scorer: calc_wer() feeds every hyp/ref pair into this
# one object and then reads its summaries and aligned HTML.
wer_obj = WER.SimpleWER(
    key_phrases=keyphrases,
    html_handler=WER.HighlightAlignedHtmlHandler(WER.HighlightAlignedHtml),
    preprocess_handler=WER.RemoveCommentTxtPreprocess,
)
|
|
|
|
|
|
|
|
def dummy_split_text(text):
    """Identity text splitter: give ``text`` back exactly as received.

    Used as the ``text_spliter`` hook when no further tokenization of the
    processed transcript is wanted.
    """
    unchanged = text
    return unchanged
|
|
|
|
|
|
|
|
# Translation table that deletes every character of ``string.punctuation``;
# built once instead of on every call.
_PUNCT_DELETE_TABLE = str.maketrans("", "", string.punctuation)


def remove_punct(text):
    """Strip ASCII punctuation from ``text`` and collapse runs of spaces.

    Deleting a free-standing punctuation token (e.g. the ``-`` in ``"a - b"``)
    leaves adjacent spaces behind; those are folded back into a single space
    so the result tokenizes cleanly.

    Args:
        text: Input string.

    Returns:
        ``text`` without any character from ``string.punctuation`` and with
        no two consecutive spaces.
    """
    output = text.translate(_PUNCT_DELETE_TABLE)
    # Repeat until stable: a single ``replace`` only halves a long run.
    while "  " in output:
        output = output.replace("  ", " ")
    return output
|
|
|
|
|
|
|
|
def get_gt_ref_texts_and_wav_files(
    args, gt_test_lst, gt_folder, punct_remover, text_spliter
):
    """Read ground-truth transcripts and their wav paths from a list file.

    Each non-blank line of ``gt_test_lst`` is ``|``-separated. The first
    field is the utterance id, resolved to ``<gt_folder>/<id>.wav``, and the
    last field is the transcript. Entries whose wav file does not exist are
    skipped, and how many were skipped is logged as a warning.

    Args:
        args: Parsed CLI args; ``norm_text`` and ``remove_punct`` are read.
        gt_test_lst: Path to the ``|``-separated list file (UTF-8).
        gt_folder: Folder containing the ground-truth wav files.
        punct_remover: Callable applied to the lower-cased transcript when
            ``args.remove_punct`` is set and ``args.norm_text`` is not.
        text_spliter: Callable applied last to the processed transcript.

    Returns:
        ``(reference, wav_file_list)``, where ``reference[i]`` is
        ``[processed_text, raw_transcript]`` for ``wav_file_list[i]``.
    """
    wav_file_list = []
    reference = []
    num_missing = 0
    with open(gt_test_lst, "r", encoding="utf-8") as fp:
        for line in fp:
            line = line.strip()
            if not line:
                continue
            fields = line.split("|")
            wav_file = f"{gt_folder}/{fields[0]}.wav"
            if not os.path.isfile(wav_file):
                num_missing += 1
                continue

            wav_file_list.append(wav_file)
            text = fields[-1].lower()
            if args.norm_text:
                truth_text = english_normalizer(text)
            elif args.remove_punct:
                truth_text = punct_remover(text)
            else:
                truth_text = text
            truth_text = text_spliter(truth_text)
            reference.append([truth_text, fields[-1]])

    if num_missing:
        # Make silently shrinking test sets visible in the log.
        logging.warning(
            "%d entries in %s skipped: wav file not found under %s",
            num_missing,
            gt_test_lst,
            gt_folder,
        )

    assert len(reference) == len(wav_file_list)
    return reference, wav_file_list
|
|
|
|
|
|
|
|
def get_ref_texts_and_gen_files(
    args, test_lst, test_folder, punct_remover, text_spliter
):
    """Read reference transcripts and the matching generated-wav paths.

    Each non-blank line of ``test_lst`` is ``|``-separated. The basename of
    the third field, cut at its first ``.``, names the generated file
    ``<test_folder>/<name>_gen.wav``; the last field is the transcript.
    Generated files are not checked for existence here.

    Args:
        args: Parsed CLI args; ``norm_text`` and ``remove_punct`` are read.
        test_lst: Path to the ``|``-separated list file (UTF-8).
        test_folder: Folder containing the generated ``*_gen.wav`` files.
        punct_remover: Callable applied to the lower-cased transcript when
            ``args.remove_punct`` is set and ``args.norm_text`` is not.
        text_spliter: Callable applied last to the processed transcript.

    Returns:
        ``(reference, gen_file_list)``, where ``reference[i]`` is
        ``[processed_text, raw_transcript]`` for ``gen_file_list[i]``.
    """
    reference = []
    gen_file_list = []
    with open(test_lst, "r", encoding="utf-8") as fp:
        for line in fp:
            line = line.strip()
            # A blank line has no third field and would raise IndexError.
            if not line:
                continue
            fields = line.split("|")
            # NOTE: cuts at the *first* dot, so "a.b.wav" maps to "a_gen.wav".
            filename = fields[2].split("/")[-1].split(".")[0]
            gen_file_list.append(f"{test_folder}/{filename}_gen.wav")

            text = fields[-1].lower()
            if args.norm_text:
                truth_text = english_normalizer(text)
            elif args.remove_punct:
                truth_text = punct_remover(text)
            else:
                truth_text = text
            truth_text = text_spliter(truth_text)
            reference.append([truth_text, fields[-1]])

    assert len(reference) == len(gen_file_list)
    return reference, gen_file_list
|
|
|
|
|
|
|
|
def get_hypo_texts(args, results_list, punct_remover, text_spliter):
    """Convert raw ASR outputs into ``[processed_text, raw_text]`` pairs.

    Each ``item["text"]`` is lower-cased, then passed through the English
    normalizer when ``args.norm_text`` is set, otherwise through
    ``punct_remover`` when ``args.remove_punct`` is set, and finally through
    ``text_spliter``. The untouched ``item["text"]`` is kept as the second
    element.
    """

    def _process(raw):
        lowered = raw.lower()
        if args.norm_text:
            cleaned = english_normalizer(lowered)
        elif args.remove_punct:
            cleaned = punct_remover(lowered)
        else:
            cleaned = lowered
        return text_spliter(cleaned)

    return [[_process(item["text"]), item["text"]] for item in results_list]
|
|
|
|
|
|
|
|
def calc_wer(reference, hypothesis, test_lst):
    """Score hypotheses against references and write diagnosis files.

    Every (hypothesis, reference) pair is added to the module-level
    ``wer_obj``, and its summary, details and key-phrase info are logged.
    Then two files are written using ``test_lst`` as a path prefix:

    * ``<test_lst>_diagnosis.html``: the aligned HTML collected by
      ``wer_obj``;
    * ``<test_lst>_rawtext.lst``: one ``raw_reference|raw_hypothesis`` line
      per utterance.

    A failure to write these files is logged with its traceback and does not
    abort the run.

    Args:
        reference: ``[processed_text, raw_text]`` pairs.
        hypothesis: ``[processed_text, raw_text]`` pairs, aligned with
            ``reference`` by index.
        test_lst: Path prefix for the diagnosis files.

    Raises:
        ValueError: If ``reference`` and ``hypothesis`` differ in length.
    """
    if len(reference) != len(hypothesis):
        raise ValueError(
            f"reference ({len(reference)}) and hypothesis ({len(hypothesis)}) "
            "must have the same length"
        )

    logging.info("calc WER:")
    for hypo_pair, ref_pair in tqdm(
        zip(hypothesis, reference), total=len(hypothesis)
    ):
        wer_obj.AddHypRef(hypo_pair[0].strip(), ref_pair[0].strip())

    str_summary, str_details, str_keyphrases_info = wer_obj.GetSummaries()
    logging.info("WER summary:")
    logging.info(str_summary)
    logging.info(str_details)
    logging.info(str_keyphrases_info)

    fn_output = test_lst + "_diagnosis.html"
    text_output = test_lst + "_rawtext.lst"
    try:
        aligned_html = "<br>".join(wer_obj.aligned_htmls)
        with open(fn_output, "wt", encoding="utf-8") as fp:
            # Properly nested document; declare the charset we write with.
            fp.write('<html><head><meta charset="utf-8"></head><body>')
            fp.write("<div>%s</div>" % aligned_html)
            fp.write("</body></html>")

        with open(text_output, "w", encoding="utf-8") as fp:
            for ref, hypo in zip(reference, hypothesis):
                fp.write(f"{ref[1]}|{hypo[1]}\n")
        logging.info(f"Save {fn_output} and {text_output} for diagnosis")
    except OSError:
        logging.exception("failed to write diagnosis html")
|
|
|
|
|
|
|
|
def load_en_model():
    """Load the Whisper processor and model from ``en_asr_model_path``.

    The model is moved to the module-level ``device`` before being returned.

    Returns:
        A ``(processor, model)`` tuple.
    """
    model_dir = en_asr_model_path
    asr_processor = WhisperProcessor.from_pretrained(model_dir)
    asr_model = WhisperForConditionalGeneration.from_pretrained(model_dir)
    asr_model = asr_model.to(device)
    return asr_processor, asr_model
|
|
|
|
|
|
|
|
def process_wavs(wav_file_list, batch_size=300):
    """Transcribe each audio file with Whisper (English, transcribe task).

    Audio is down-mixed to mono when it has several channels and resampled
    to 16 kHz when needed. Multi-channel input would otherwise reach the
    Whisper feature extractor as a 2-D array, which it interprets as a batch
    of separate signals.

    Args:
        wav_file_list: Paths of audio files readable by ``soundfile``.
        batch_size: Currently unused; files are transcribed one at a time.
            Kept so existing call sites keep working.

    Returns:
        One ``{"text": transcription}`` dict per input file, in input order.
    """
    # ``import scipy`` alone does not load the ``signal`` submodule on older
    # SciPy releases, so import it explicitly.
    import scipy.signal

    results = []
    processor, model = load_en_model()
    # The decoder prompt (language + task) is identical for every file.
    forced_decoder_ids = processor.get_decoder_prompt_ids(
        language="english", task="transcribe"
    )
    for wav_file_path in tqdm(wav_file_list):
        wav, sr = sf.read(wav_file_path)
        if wav.ndim > 1:
            # soundfile returns (frames, channels) for multi-channel audio.
            wav = wav.mean(axis=1)
        if sr != 16000:
            wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr))
        input_features = processor(
            wav, sampling_rate=16000, return_tensors="pt"
        ).input_features.to(device)
        predicted_ids = model.generate(
            input_features, forced_decoder_ids=forced_decoder_ids
        )
        transcription = processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0]
        results.append({"text": transcription.strip()})
    return results
|
|
|
|
|
|
|
|
def main(args):
    """Run Whisper ASR on a test set and log WER against the references.

    With ``args.eval_gt`` the ground-truth wavs ``<test_path>/<id>.wav``
    listed in ``args.test_lst`` are transcribed; otherwise the generated
    ``<test_path>/<name>_gen.wav`` files are. Logs go to ``args.log_file``
    when it is given, otherwise to stderr.
    """
    if args.log_file is not None:
        handler = logging.FileHandler(filename=args.log_file, mode="w")
    else:
        # --log-file is optional, and FileHandler(None) raises TypeError.
        handler = logging.StreamHandler()
    logging.root.setLevel(logging.INFO)
    logging.root.addHandler(handler)

    test_path = args.test_path
    lst_path = args.test_lst
    logging.info(
        f"Evaluate {args.test_path} with Text Normalization: {args.norm_text} and Remove Punct: {args.remove_punct}"
    )

    if args.eval_gt:
        logging.info(f"run ASR for GT: {lst_path}")
        reference, wav_file_list = get_gt_ref_texts_and_wav_files(
            args, lst_path, test_path, remove_punct, dummy_split_text
        )
    else:
        logging.info(f"run ASR for detok: {lst_path}")
        reference, wav_file_list = get_ref_texts_and_gen_files(
            args, lst_path, test_path, remove_punct, dummy_split_text
        )
    results = process_wavs(wav_file_list, batch_size=12)

    hypothesis = get_hypo_texts(args, results, remove_punct, dummy_split_text)

    assert len(hypothesis) == len(reference)
    logging.info(f"Finish runing ASR for {lst_path}")
    logging.info(f"hypothesis: {len(hypothesis)} vs reference: {len(reference)}")

    # Diagnosis files are prefixed with the wav folder path (test_path).
    calc_wer(reference, hypothesis, test_path)
    destination = args.log_file if args.log_file is not None else "stderr"
    logging.info(f"Finish evaluate {lst_path}, results are in {destination}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Transcribe wavs with Whisper and report WER."
    )
    parser.add_argument(
        "--test-path",
        required=True,
        type=str,
        help="folder of wav files",
    )
    parser.add_argument(
        "--test-lst",
        required=True,
        type=str,
        help="path to test file lst",
    )
    parser.add_argument(
        "--log-file",
        required=False,
        type=str,
        default=None,
        help="path of the log file to write evaluation results to",
    )
    parser.add_argument(
        "--norm-text",
        default=False,
        action="store_true",
        help="normalize GT and hypo texts with the Whisper English normalizer",
    )
    parser.add_argument(
        "--remove-punct",
        default=False,
        action="store_true",
        help="remove punct from GT and hypo texts",
    )
    parser.add_argument(
        "--eval-gt",
        default=False,
        action="store_true",
        help=(
            "evaluate ground-truth wavs (<test-path>/<id>.wav) instead of "
            "generated <name>_gen.wav files"
        ),
    )
    args = parser.parse_args()

    main(args)
|
|
|