darshanmakwana commited on
Commit
2675a94
·
verified ·
1 Parent(s): 34974f2

Upload eval_model.py

Browse files
Files changed (1) hide show
  1. eval_model.py +153 -0
eval_model.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, GPT2LMHeadModel
2
+ from datasets import load_dataset, Dataset, DatasetDict
3
+ import random
4
+ import string
5
+ import torch
6
+
7
+ from torchmetrics.text import WordErrorRate, CharErrorRate
8
+
9
+ wer = WordErrorRate()
10
+ cer = CharErrorRate()
11
+
12
+ def process(text):
13
+
14
+ # Lower case every letter
15
+ text = text.lower()
16
+
17
+ # Remove punctuation
18
+ punctuation_to_remove = string.punctuation.replace("'", "")
19
+ translation_table = str.maketrans('', '', punctuation_to_remove)
20
+ text = text.translate(translation_table)
21
+
22
+ # Remove whitespaces from front and behind
23
+ while text[0] == ' ' or text[-1] == ' ':
24
+ if text[0] == ' ':
25
+ text = text[1:]
26
+ if text[-1] == ' ':
27
+ text = text[:-1]
28
+
29
+ return text
30
+
31
+ import jiwer
32
+ from edit_distance import SequenceMatcher
33
+ def correct_text(text):
34
+ transforms = jiwer.Compose(
35
+ [
36
+ jiwer.ExpandCommonEnglishContractions(),
37
+ jiwer.ToLowerCase(),
38
+ jiwer.RemoveMultipleSpaces(),
39
+ jiwer.Strip(),
40
+ jiwer.RemovePunctuation(),
41
+ jiwer.ReduceToListOfListOfWords(),
42
+ ]
43
+ )
44
+ return transforms(text)
45
+
46
+ def align_gt_asr(gt, asr):
47
+ sm = SequenceMatcher(a=gt, b=asr)
48
+ best_path = []
49
+ opcodes = sm.get_opcodes()
50
+ for tag, i1, i2, j1, j2 in opcodes:
51
+ if tag == "delete":
52
+ for i in range(i1, i2):
53
+ best_path.append([gt[i], ""])
54
+ if tag == "replace" or tag == "equal":
55
+ for i, j in zip(range(i1, i2), range(j1, j2)):
56
+ best_path.append([gt[i], asr[j]])
57
+ if tag == "insert":
58
+ for j in range(j1, j2):
59
+ best_path.append(["", asr[j]])
60
+ return best_path
61
+
62
+ dtype = torch.float16
63
+
64
+ dataset_name = "./../libripseech_tokenized"
65
+ dataset = DatasetDict.load_from_disk(dataset_name)
66
+
67
+ with open("./../prompting/blist/all_rare_words.txt") as fin:
68
+ rarewords = [process(word.strip()) for word in fin]
69
+
70
+ tokenizer = AutoTokenizer.from_pretrained("./../tokenizer")
71
+ tokenizer.pad_token_id = 0
72
+ tokenizer.pad_token = "<|padding|>"
73
+ tokenizer.padding_side = "left"
74
+
75
+ # Adding new tokens for introducing prompts
76
+ tokenizer.add_tokens(["<|startofprompt|>", "<|sepofprompt|>", "<|endofprompt|>"])
77
+ sot_token = tokenizer.encode("<|startoftranscript|>")[0]
78
+ eot_token = tokenizer.encode("<|endoftranscript|>")[0]
79
+
80
+ from math import ceil
81
+ from tqdm import tqdm
82
+
83
+ val_bs = 32
84
+ n_bwords = 25
85
+ context_length = 2048
86
+
87
+ def prepare(element):
88
+
89
+ # Add audio
90
+ audio_tkns = element["audio_tokens"]
91
+ data = "".join([f"<|audio:{tkn}|>" for tkn in audio_tkns])
92
+
93
+ # sample context words and mix with the biasing list
94
+ b_words = element["b_words"]
95
+ if n_bwords > len(b_words):
96
+ context = b_words + random.sample(rarewords, n_bwords - len(b_words))
97
+ else:
98
+ context = random.sample(b_words, n_bwords)
99
+ random.shuffle(context)
100
+
101
+ # add the context words
102
+ data += "<|startofprompt|>" + "<|sepofprompt|>".join(context) + "<|endofprompt|>"
103
+
104
+ # Add text
105
+ data += "<|startoftranscript|>"
106
+
107
+ return {"data": data, "context": context}
108
+
109
+ @torch.no_grad()
110
+ def evaluate_model(model):
111
+
112
+ transcripts = []
113
+
114
+ processed_data = dataset["test.clean"].map(prepare)
115
+ data = processed_data["data"]
116
+
117
+ for idx in tqdm(range(ceil(len(data)/val_bs))):
118
+
119
+ outputs = tokenizer(data[idx * val_bs: (idx + 1) * val_bs], truncation=False, max_length=None, padding=True, return_tensors="pt").to(model.device)
120
+ input_ids = outputs["input_ids"]
121
+ par = input_ids.shape[-1]
122
+
123
+ generations = model.generate(
124
+ input_ids,
125
+ max_new_tokens=context_length - par - 1,
126
+ eos_token_id = eot_token
127
+ )
128
+ transcripts += tokenizer.batch_decode(generations[:, par:], skip_special_tokens=True)
129
+
130
+ bias_word_cnt = 0
131
+ normal_word_cnt = 0
132
+ u_wer = 0.0
133
+ b_wer = 0.0
134
+ pred_list = correct_text(transcripts)
135
+ text_list = correct_text(processed_data["text"])
136
+ prompt_list = processed_data["context"]
137
+ for a, b, c in zip(pred_list, text_list, prompt_list):
138
+ aligned_pair = align_gt_asr(b, a)
139
+ for gt_word, asr_word in aligned_pair:
140
+ if gt_word in c or asr_word in c:
141
+ if gt_word != asr_word:
142
+ b_wer += 1.0
143
+ if gt_word in c:
144
+ bias_word_cnt += 1
145
+ else:
146
+ if gt_word != asr_word:
147
+ u_wer += 1.0
148
+ if gt_word != "":
149
+ normal_word_cnt += 1
150
+ u_wer = u_wer / normal_word_cnt * 100
151
+ b_wer = b_wer / bias_word_cnt * 100
152
+
153
+ return wer(transcripts, processed_data["text"]).item() * 100, cer(transcripts, processed_data["text"]).item() * 100, b_wer, u_wer