darshanmakwana committed on
Commit 479fcf6 · verified · 1 Parent(s): 2675a94

Upload 2 files

Files changed (2)
  1. ner.ipynb +398 -0
  2. train.py +179 -0
ner.ipynb ADDED
@@ -0,0 +1,398 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "debeec92",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "## Gathering NER Dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b83776e8",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from datasets import DatasetDict\n",
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "dataset = DatasetDict.load_from_disk().remove_columns([\"token_type_ids\", \"attention_mask\"])\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"./../tokenizer\")\n",
+ "tokenizer.pad_token_id = 0\n",
+ "tokenizer.pad_token = \"<|padding|>\"\n",
+ "tokenizer.padding_side = \"right\"\n",
+ "\n",
+ "# new tokens for prompting\n",
+ "num_new_tokens = tokenizer.add_tokens([\"<|startofprompt|>\", \"<|sepofprompt|>\", \"<|endofprompt|>\"])\n",
+ "# new tokens for entities\n",
+ "tokenizer.add_tokens([\"<|entity:PER|>\", \"<|entity:LOC|>\", \"<|entity:ORG|>\", \"<|entity|>\", \"<|detectentities|>\"])\n",
+ "# new tokens for images\n",
+ "tokenizer.add_tokens([\"<|startofimage|>\", \"<|endofimage|>\"])\n",
+ "tokenizer.add_tokens([ f\"<|image:{tkn}|>\" for tkn in range(16000)])\n",
+ "\n",
+ "tokenizer.save_pretrained(\"./tokenizer\")\n",
+ "\n",
+ "print(\"Total Vocab Size:\", len(tokenizer))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f2a95871-6e2d-4b96-bc36-8febac09d795",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a706dd6d-e9b2-4e42-baf1-7d17cd93c54f",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "from tqdm import tqdm\n",
+ "import string\n",
+ "import os\n",
+ "import re\n",
+ "\n",
+ "audio_paths = sorted(os.listdir(\"./mp3\"))\n",
+ "txt_paths = sorted(os.listdir(\"./txt\"))\n",
+ "data = np.load(\"tokens.npz\")\n",
+ "audio_tokens = [data[key] for key in data.keys()]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ce8bb550-8149-438c-9ca5-b12681f36476",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "def tag_entities(text):\n",
+ "\n",
+ "    patterns = {\n",
+ "        \"PER\": r'\\|(.*?)\\]',\n",
+ "        \"LOC\": r'\\$(.*?)\\]',\n",
+ "        \"ORG\": r'\\{(.*?)\\]'\n",
+ "    }\n",
+ "\n",
+ "    entities = []\n",
+ "\n",
+ "    for entity, pattern in patterns.items():\n",
+ "        matches = re.findall(pattern, text)\n",
+ "        text = re.sub(pattern, lambda m: f'<|entity:{entity}|>{m.group(1)}<|entity|>', text)\n",
+ "        entities += matches\n",
+ "\n",
+ "    return text, entities\n",
+ "\n",
+ "data = []\n",
+ "\n",
+ "for idx in tqdm(range(len(txt_paths))):\n",
+ "\n",
+ "    with open(os.path.join(\"./txt\", txt_paths[idx])) as f:\n",
+ "        txt = f.read()\n",
+ "\n",
+ "    text, entities = tag_entities(txt.lower())\n",
+ "\n",
+ "    audio_token = audio_tokens[idx]\n",
+ "\n",
+ "    prompt = \"\".join([f\"<|audio:{tkn}|>\" for tkn in audio_token]) + \"<|detectentities|><|startofprompt|><|endofprompt|>\" + \"<|startoftranscript|>\" + text + \"<|endoftranscript|>\"\n",
+ "\n",
+ "    try:\n",
+ "        outputs = tokenizer(prompt, truncation=True, padding=\"max_length\", max_length=2048)\n",
+ "        data.append({\n",
+ "            \"audio_tokens\": audio_token,\n",
+ "            \"raw_text\": text,\n",
+ "            \"transcript\": txt.translate(str.maketrans('', '', string.punctuation)).lower(),\n",
+ "            \"entities\": entities,\n",
+ "            \"prompt\": prompt,\n",
+ "            \"input_ids\": outputs[\"input_ids\"],\n",
+ "            \"attention_mask\": outputs[\"attention_mask\"]\n",
+ "        })\n",
+ "    except Exception:\n",
+ "        print(idx)\n",
+ "        continue\n",
+ "\n",
+ "from datasets import Dataset\n",
+ "import pandas as pd\n",
+ "\n",
+ "ds = Dataset.from_pandas(pd.DataFrame(data))\n",
+ "\n",
+ "ds.save_to_disk(\"entity_tokenized\")\n",
+ "ds.push_to_hub(\"darshanmakwana/entity_tokenized\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "38191f9a-2a11-4bb2-a885-ef303d6c43f7",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "## Validating Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "710a1144-46a1-43d4-9bf9-1c01569b26d4",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import GPT2LMHeadModel, AutoTokenizer\n",
+ "from datasets import Dataset\n",
+ "import torch\n",
+ "\n",
+ "dataset_name = \"entity_tokenized\"\n",
+ "tokenizer_path = \"./../tokenizer\"\n",
+ "max_length = 2048\n",
+ "device = \"cuda:0\"\n",
+ "dtype = torch.float16\n",
+ "\n",
+ "dataset = Dataset.load_from_disk(dataset_name)\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)\n",
+ "tokenizer.pad_token_id = 0\n",
+ "tokenizer.pad_token = \"<|padding|>\"\n",
+ "tokenizer.padding_side = \"left\"\n",
+ "\n",
+ "# new tokens for prompting\n",
+ "num_new_tokens = tokenizer.add_tokens([\"<|startofprompt|>\", \"<|sepofprompt|>\", \"<|endofprompt|>\"])\n",
+ "# new tokens for entities\n",
+ "tokenizer.add_tokens([\"<|entity:PER|>\", \"<|entity:LOC|>\", \"<|entity:ORG|>\", \"<|entity|>\", \"<|detectentities|>\"])\n",
+ "\n",
+ "model = GPT2LMHeadModel.from_pretrained(\"./out/checkpoint-20000\").to(device).to(dtype).eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "cea0d8c4-5c56-47eb-934a-86293bed6afa",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "114.073974609375"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sum([param.numel() for param in model.parameters()]) / (1024 * 1024)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "529ca732-569f-4b7d-8448-1f16b35a6694",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 8/8 [00:27<00:00, 3.42s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from eval_model import process\n",
+ "from math import ceil\n",
+ "from tqdm import tqdm\n",
+ "import re\n",
+ "\n",
+ "def extract_entities(text):\n",
+ "\n",
+ "    patterns = {\n",
+ "        \"PER\": r'<\\|entity:PER\\|>(.*?)<\\|entity\\|>',\n",
+ "        \"LOC\": r'<\\|entity:LOC\\|>(.*?)<\\|entity\\|>',\n",
+ "        \"ORG\": r'<\\|entity:ORG\\|>(.*?)<\\|entity\\|>'\n",
+ "    }\n",
+ "\n",
+ "    entities = []\n",
+ "\n",
+ "    for entity, pattern in patterns.items():\n",
+ "        matches = re.findall(pattern, text)\n",
+ "        text = re.sub(pattern, lambda m: f'{m.group(1)}', text)\n",
+ "        entities += [process(match) for match in matches]\n",
+ "\n",
+ "    return text, entities\n",
+ "\n",
+ "def preprocess(sample):\n",
+ "    prompt = \"\".join([f\"<|audio:{tkn}|>\" for tkn in sample[\"audio_tokens\"]]) + \"<|detectentities|><|startofprompt|><|endofprompt|>\" + \"<|startoftranscript|>\"\n",
+ "    return {\"prompt\": prompt}\n",
+ "\n",
+ "dataset = dataset.map(preprocess)\n",
+ "dataset = dataset.select(list(range(0, 1000)))\n",
+ "\n",
+ "eot_token = tokenizer.encode(\"<|endoftranscript|>\")[0]\n",
+ "\n",
+ "batch_size = 128\n",
+ "texts = []\n",
+ "tp = 0\n",
+ "fp = 0\n",
+ "tn = 0\n",
+ "\n",
+ "for idx in tqdm(range(ceil(len(dataset)/batch_size))):\n",
+ "\n",
+ "    input_ids = tokenizer(dataset[idx * batch_size: (idx + 1) * batch_size][\"prompt\"], return_tensors=\"pt\", padding=True, truncation=True).input_ids.to(model.device)\n",
+ "    par = input_ids.shape[-1]\n",
+ "\n",
+ "    generations = model.generate(\n",
+ "        input_ids,\n",
+ "        max_new_tokens=max_length,\n",
+ "        eos_token_id = eot_token\n",
+ "    )\n",
+ "    texts += tokenizer.batch_decode(generations[:, par:], skip_special_tokens=True)\n",
+ "\n",
+ "# transcript, pred_entities = extract_entities(transcripts[0])\n",
+ "\n",
+ "# entities = sample[\"entities\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "5ce4384e-8771-487e-86e1-de5489ee4e59",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 1000/1000 [00:04<00:00, 241.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Precision: 69.53846153846153\n",
+ "Recall: 69.32515337423312\n",
+ "F1 Score: 69.43164362519201\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "tp = 0\n",
+ "fp = 0\n",
+ "fn = 0\n",
+ "\n",
+ "for idx in tqdm(range(len(dataset))):\n",
+ "\n",
+ "    transcript, entities = extract_entities(texts[idx])\n",
+ "\n",
+ "    for entity in entities:\n",
+ "        if entity in dataset[idx][\"entities\"]:\n",
+ "            tp += 1\n",
+ "        else:\n",
+ "            fp += 1\n",
+ "    for entity in dataset[idx][\"entities\"]:\n",
+ "        if entity not in entities:\n",
+ "            fn += 1\n",
+ "\n",
+ "pre = tp / (tp + fp) * 100\n",
+ "recall = tp / (tp + fn) * 100\n",
+ "print(\"Precision:\", pre)\n",
+ "print(\"Recall:\", recall)\n",
+ "print(\"F1 Score:\", 2 / ((1/pre) + (1/recall)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ed0fad1a-bb30-446e-83a9-4a972fdb7766",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "## Train Iter    Precision    Recall    F1 Score\n",
+ "#       16000        68.80     69.27       69.03\n",
+ "#       17000        72.92     70.78       71.83\n",
+ "#       18000        76.78     75.34       76.05\n",
+ "#       19000        81.78     80.92       81.34\n",
+ "#       20000        85.05     80.74       82.84"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "113df077-c31c-4b57-876e-b19942100306",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "81.34772710510141"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "2 / ((1/81.78) + (1/80.92))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
train.py ADDED
@@ -0,0 +1,179 @@
+ import os
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,5,7"
+ # os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+ from transformers import AutoTokenizer, AutoConfig, GPT2LMHeadModel, AutoModel, AutoModelForCausalLM
+ from transformers import Trainer, TrainingArguments
+ from datasets import Dataset, DatasetDict, concatenate_datasets, Sequence, Value
+ from torch.nn import functional as F
+ from tqdm import tqdm
+ import time
+ import torch
+ import wandb
+ import random
+ import string
+ from eval_model import evaluate_model
+
+ def process(text):
+
+     # Lower case every letter
+     text = text.lower()
+
+     # Remove punctuation
+     punctuation_to_remove = string.punctuation.replace("'", "")
+     translation_table = str.maketrans('', '', punctuation_to_remove)
+     text = text.translate(translation_table)
+
+     # Remove whitespaces from front and behind
+     text = text.strip()
+
+     return text
+
+ dataset_name = "entity_tokenized"
+ tokenizer_path = "./../tokenizer"
+ max_length = 2048
+ # n_layer = 16
+ # n_head = 16
+ # n_emb = 1024
+ n_bwords = 25
+
+ dataset = Dataset.load_from_disk(dataset_name)
+ dataset = dataset.remove_columns(["audio_tokens", "raw_text", "transcript", "entities", "prompt"])
+ feat = dataset.features.copy()
+ feat["input_ids"] = Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None)
+ feat["attention_mask"] = Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)
+ dataset = dataset.cast(feat)
+ dataset = dataset.train_test_split(test_size=0.025)
+
+ asr_dataset = DatasetDict.load_from_disk("/root/.cache/huggingface/hub/models--darshanmakwana--storage/snapshots/b6e4caa73046e02ad19b48b39c097ba7b9980210/ASR/tokenized_librispeech/").remove_columns(["token_type_ids"])
+
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ tokenizer.pad_token_id = 0
+ tokenizer.pad_token = "<|padding|>"
+ tokenizer.padding_side = "right"
+
+ # new tokens for prompting
+ num_new_tokens = tokenizer.add_tokens(["<|startofprompt|>", "<|sepofprompt|>", "<|endofprompt|>"])
+ # new tokens for entities
+ tokenizer.add_tokens(["<|entity:PER|>", "<|entity:LOC|>", "<|entity:ORG|>", "<|entity|>", "<|detectentities|>"])
+ # new tokens for images
+ # tokenizer.add_tokens(["<|startofimage|>", "<|endofimage|>"])
+ # tokenizer.add_tokens([ f"<|image:{tkn}|>" for tkn in range(16000)])
+
+ with open("./../prompting/blist/all_rare_words.txt") as fin:
+     rarewords = [process(word.strip()) for word in fin]
+
+ def tokenize(element):
+
+     # Add audio
+     audio_tkns = element["audio_tokens"]
+     data = "".join([f"<|audio:{tkn}|>" for tkn in audio_tkns]) + "<|startofprompt|>"
+
+     # sample context words and mix with the biasing list
+     b_words = element["b_words"]
+     if n_bwords > len(b_words):
+         context = b_words + random.sample(rarewords, n_bwords - len(b_words))
+     else:
+         context = random.sample(b_words, n_bwords)
+     random.shuffle(context)
+
+     # add the context words
+     data += "<|sepofprompt|>".join(context)
+
+     # Add text
+     data += "<|endofprompt|><|startoftranscript|>" + element["text"] + "<|endoftranscript|>"
+
+     outputs = tokenizer(data, truncation=True, max_length=max_length, padding="max_length")
+     return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}
+
+ p_dataset = DatasetDict.load_from_disk("./../libripseech_tokenized")
+ prompt_dataset = p_dataset.map(
+     tokenize, batched=False, remove_columns = p_dataset["train.clean.100"].column_names
+ )
+
+ print("Total Vocab Size:", len(tokenizer))
+
+ model = GPT2LMHeadModel.from_pretrained("./../models/checkpoint-prompting")
+ model.resize_token_embeddings(len(tokenizer))
+
+ from transformers import DataCollatorForLanguageModeling
+
+ data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
+
+ config = {
+     "output_dir": "./out",
+     "max_steps": 20000,
+     "per_device_train_batch_size": 5,
+     "per_device_eval_batch_size": 5,
+     "gradient_accumulation_steps": 1,
+     "eval_strategy": "steps",
+     "save_strategy": "steps",
+     "eval_steps": 500,
+     "logging_steps": 1,
+     "logging_first_step": True,
+     "save_total_limit": 5,
+     "load_best_model_at_end": True,
+     "save_steps": 1000,
+     "lr_scheduler_type": "cosine",
+     "learning_rate": 1e-4,
+     "warmup_steps": 10,
+     "weight_decay": 0.01,
+     "report_to": "wandb",
+     "fp16": True
+ }
+
+ from argparse import Namespace
+
+ args = Namespace(**config)
+ train_args = TrainingArguments(**config)
+
+ wandb.init(project="multi_modal_exps", name="entity")
+
+ class GPTTrainer(Trainer):
+     def compute_loss(self, model, inputs, return_outputs=False):
+
+         labels = inputs.get("labels")
+         outputs = model(**inputs)
+         logits = outputs.get("logits")
+
+         # shift labels/logits so each position predicts the next token
+         labels = labels[:, 1:]
+         logits = logits[:, :-1, :]
+
+         print(logits.shape, labels.shape, torch.max(logits).item(), torch.max(labels).item(), torch.min(logits).item(), torch.min(labels).item())
+
+         loss = F.cross_entropy(torch.reshape(logits, (-1, logits.size(-1))), torch.reshape(labels, (-1, )), ignore_index=-100)
+
+         return (loss, outputs) if return_outputs else loss
+
+     @torch.no_grad()
+     def evaluation_loop(self, dataloader, description, prediction_loss_only=None, ignore_keys=None, metric_key_prefix="eval"):
+
+         eval_output = super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
+
+         wer, cer, b_wer, u_wer = evaluate_model(model)
+
+         wandb.log({
+             "Word Error Rate": wer,
+             "Char Error Rate": cer,
+             "Biased Word Error Rate": b_wer,
+             "Unbiased Word Error Rate": u_wer
+         })
+
+         return eval_output
+
+ trainer = GPTTrainer(
+     model = model,
+     tokenizer = tokenizer,
+     args = train_args,
+     data_collator = data_collator,
+     train_dataset = concatenate_datasets([dataset["train"], asr_dataset["train.clean.100"], prompt_dataset["train.clean.100"]]),
+     eval_dataset = concatenate_datasets([dataset["test"], asr_dataset["validation.clean"], prompt_dataset["validation.clean"]]),
+ )
+
+ trainer.train()
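
For reference, a minimal sketch of the entity-markup round trip that ner.ipynb relies on: raw transcripts mark people, locations and organisations as |...], $...] and {...], tag_entities rewrites those spans into <|entity:PER|>...<|entity|>-style special-token spans before tokenization, and extract_entities parses the same spans back out of the model's generations for scoring. The helper names tag and untag below are made up for this illustration and are not functions from the uploaded files.

import re

# Raw annotation: |...] = person, $...] = location, {...] = organisation
PATTERNS = {
    "PER": r'\|(.*?)\]',
    "LOC": r'\$(.*?)\]',
    "ORG": r'\{(.*?)\]',
}

def tag(text):
    # Rewrite raw markers into the special-token spans fed to the tokenizer
    for name, pattern in PATTERNS.items():
        text = re.sub(pattern, lambda m, n=name: f'<|entity:{n}|>{m.group(1)}<|entity|>', text)
    return text

def untag(text):
    # Recover (plain text, entity list) from a generated transcript
    span = r'<\|entity:(?:PER|LOC|ORG)\|>(.*?)<\|entity\|>'
    entities = re.findall(span, text)
    plain = re.sub(span, r'\1', text)
    return plain, entities

tagged = tag("|john doe] flew to $paris] for {acme corp]")
print(tagged)         # <|entity:PER|>john doe<|entity|> flew to <|entity:LOC|>paris<|entity|> for <|entity:ORG|>acme corp<|entity|>
print(untag(tagged))  # ('john doe flew to paris for acme corp', ['john doe', 'paris', 'acme corp'])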