fxtentacle committed
Commit d09f022 · 1 Parent(s): c2cd532

Upload TEVR Explanation.ipynb

Files changed (1):
  1. TEVR Explanation.ipynb +186 -0
TEVR Explanation.ipynb ADDED
@@ -0,0 +1,186 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "89c94977",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from huggingface_hub import snapshot_download\n",
+ "data_folder = snapshot_download(\"fxtentacle/tevr-token-entropy-predictor-de\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "a48a49d6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import T5ForConditionalGeneration\n",
+ "model = T5ForConditionalGeneration.from_pretrained(data_folder)\n",
+ "model.to('cuda')\n",
+ "model.eval()\n",
+ "None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "eed8bfc3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "def text_to_cross_entropy(text):\n",
+ " ttext = torch.tensor([[0]+list(text.encode('UTF-8'))],dtype=torch.int64).to('cuda')\n",
+ " tone = torch.tensor([[1]],dtype=torch.int32).to('cuda')\n",
+ " logits = model.forward(input_ids=tone, attention_mask=tone, decoder_input_ids=ttext, return_dict=False)[0].detach()\n",
+ " cross_entropy = torch.nn.functional.cross_entropy(input=logits[0][:-1], target=ttext[0][1:], reduction='none').detach().cpu().numpy()\n",
+ " return cross_entropy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "8ec8cf8d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "import os\n",
+ "sys.path.append(data_folder)\n",
+ "from text_tokenizer import HajoTextTokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "37165805",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer_file = 'text-tokenizer-de-4m.txt'\n",
+ "text_tokenizer = HajoTextTokenizer(data_folder+'/'+tokenizer_file)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "73e55343",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['die', ' ', 'k', 'at', 'ze', ' ', 'ist', ' ', 'n', 'ied', 'lich']\n",
+ "[3.3762913048267365, 3.3762913048267365, 3.3762913048267365, 0.29695791006088257, 4.193424224853516, 2.3430762887001038, 2.3430762887001038, 2.8417416363954544, 2.8417416363954544, 1.1227068901062012, 2.017452405144771, 2.017452405144771, 2.017452405144771, 0.0016304069431498647, 2.580254554748535, 2.3091587026913962, 2.3091587026913962, 2.3091587026913962, 1.0126478232632508, 1.0126478232632508, 1.0126478232632508, 1.0126478232632508]\n"
+ ]
+ }
+ ],
+ "source": [
+ "text = \"die katze ist niedlich\"\n",
+ "cross_entropy = text_to_cross_entropy(text)\n",
+ "\n",
+ "tokens = text_tokenizer.encode(text)\n",
+ "tokens = [text_tokenizer.all_tokens[t] for t in tokens]\n",
+ "print(tokens)\n",
+ "token_sums = []\n",
+ "token_sums2 = []\n",
+ "for t in tokens:\n",
+ " ce = sum(cross_entropy[len(token_sums):len(token_sums)+len(t)])\n",
+ " for r in range(len(t)): token_sums.append(ce / len(t))\n",
+ " token_sums2.append(ce)\n",
+ "print(token_sums)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "e61e00aa",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<table style=\"font-size: 20px; font-family: Roboto\"><tr><td><b>(1)</b></td><td style=\"text-align:left\">d</td><td style=\"text-align:left\">i</td><td style=\"text-align:left\">e</td><td style=\"text-align:left\"> </td><td style=\"text-align:left\">k</td><td style=\"text-align:left\">a</td><td style=\"text-align:left\">t</td><td style=\"text-align:left\">z</td><td style=\"text-align:left\">e</td><td style=\"text-align:left\"> </td><td style=\"text-align:left\">i</td><td style=\"text-align:left\">s</td><td style=\"text-align:left\">t</td><td style=\"text-align:left\"> </td><td style=\"text-align:left\">n</td><td style=\"text-align:left\">i</td><td style=\"text-align:left\">e</td><td style=\"text-align:left\">d</td><td style=\"text-align:left\">l</td><td style=\"text-align:left\">i</td><td style=\"text-align:left\">c</td><td style=\"text-align:left\">h</td></tr><tr><td><b>(2)</b></td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>σ²=0.0</td></tr><tr><td><b>(3)</b></td><td>8.9</td><td>1.0</td><td>0.2</td><td>0.3</td><td>4.2</td><td>1.6</td><td>3.1</td><td>5.4</td><td>0.3</td><td>1.1</td><td>3.0</td><td>3.0</td><td>0.0</td><td>0.0</td><td>2.6</td><td>0.6</td><td>4.4</td><td>1.9</td><td>4.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>σ²=5.0</td></tr><tr><td><b>(4)</b></td><td style=\"text-align:center\" colspan=3>die</td><td style=\"text-align:center\" colspan=1> </td><td style=\"text-align:center\" colspan=1>k</td><td style=\"text-align:center\" colspan=2>at</td><td style=\"text-align:center\" colspan=2>ze</td><td style=\"text-align:center\" colspan=1> </td><td style=\"text-align:center\" colspan=3>ist</td><td style=\"text-align:center\" colspan=1> </td><td style=\"text-align:center\" colspan=1>n</td><td style=\"text-align:center\" colspan=3>ied</td><td style=\"text-align:center\" colspan=4>lich</td></tr><tr><td><b>(5)</b></td><td style=\"text-align:center\" colspan=3>10.1</td><td style=\"text-align:center\" colspan=1>0.3</td><td style=\"text-align:center\" colspan=1>4.2</td><td style=\"text-align:center\" colspan=2>4.7</td><td style=\"text-align:center\" colspan=2>5.7</td><td style=\"text-align:center\" colspan=1>1.1</td><td style=\"text-align:center\" colspan=3>6.1</td><td style=\"text-align:center\" colspan=1>0.0</td><td style=\"text-align:center\" colspan=1>2.6</td><td style=\"text-align:center\" colspan=3>6.9</td><td style=\"text-align:center\" colspan=4>4.1</td></tr><tr><td><b>(6)</b></td><td>3.4</td><td>3.4</td><td>3.4</td><td>0.3</td><td>4.2</td><td>2.3</td><td>2.3</td><td>2.8</td><td>2.8</td><td>1.1</td><td>2.0</td><td>2.0</td><td>2.0</td><td>0.0</td><td>2.6</td><td>2.3</td><td>2.3</td><td>2.3</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>σ²=1.1</td></tr></table>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "html = '<table style=\"font-size: 20px; font-family: Roboto\">'\n",
+ "html += '<tr><td><b>(1)</b></td>'+''.join([f'<td style=\"text-align:left\">{c}</td>' for c in list(text)])+'</tr>'\n",
+ "html += '<tr><td><b>(2)</b></td>'+''.join(['<td>1.0</td>'.format(v) for v in cross_entropy])+'<td>σ²={:3.1f}</td>'.format(np.var([1.0 for v in cross_entropy]))+'</tr>'\n",
+ "html += '<tr><td><b>(3)</b></td>'+''.join(['<td>{:3.1f}</td>'.format(v) for v in cross_entropy])+'<td>σ²={:3.1f}</td>'.format(np.var(cross_entropy))+'</tr>'\n",
+ "html += '<tr><td><b>(4)</b></td>'+''.join([f'<td style=\"text-align:center\" colspan={len(t)}>{t}</td>' for t in tokens])+'</tr>'\n",
+ "html += '<tr><td><b>(5)</b></td>'+''.join([f'<td style=\"text-align:center\" colspan={len(t)}>{\"{:3.1f}\".format(token_sums2[i])}</td>' for i,t in enumerate(tokens)])+'</tr>'\n",
+ "html += '<tr><td><b>(6)</b></td>'+''.join(['<td>{:3.1f}</td>'.format(v) for v in token_sums])+'<td>σ²={:3.1f}</td>'.format(np.var(token_sums))+'</tr>'\n",
+ "html += '</table>'\n",
+ "\n",
+ "import IPython\n",
+ "IPython.display.HTML(html)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "dcafdcab",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "<pad>, <eos>, , chen, sche, lich, isch, icht, iche, eine, rden, tion, urde, haft, eich, rung, chte, ssen, chaf, nder, tlic, tung, eite, iert, sich, ngen, erde, scha, nden, unge, lung, mmen, eren, ende, inde, erun, sten, iese, igen, erte, iner, tsch, keit, der, die, ter, und, ein, ist, den, ten, ber, ver, sch, ung, ste, ent, ach, nte, auf, ben, eit, des, ers, aus, das, von, ren, gen, nen, lle, hre, mit, iel, uch, lte, ann, lie, men, dem, and, ind, als, sta, elt, ges, tte, ern, wir, ell, war, ere, rch, abe, len, ige, ied, ger, nnt, wei, ele, och, sse, end, all, ahr, bei, sie, ede, ion, ieg, ege, auc, che, rie, eis, vor, her, ang, für, ass, uss, tel, er, in, ge, en, st, ie, an, te, be, re, zu, ar, es, ra, al, or, ch, et, ei, un, le, rt, se, is, ha, we, at, me, ne, ur, he, au, ro, ti, li, ri, eh, im, ma, tr, ig, el, um, la, am, de, so, ol, tz, il, on, it, sc, sp, ko, na, pr, ni, si, fe, wi, ns, ke, ut, da, gr, eu, mi, hr, ze, hi, ta, ss, ng, sa, us, ba, ck, em, kt, ka, ve, fr, bi, wa, ah, gt, di, ab, fo, to, rk, as, ag, gi, hn, s, t, n, m, r, l, f, e, a, b, d, h, k, g, o, i, u, w, p, z, ä, ü, v, ö, j, c, y, x, q, á, í, ō, ó, š, é, č, ?\n"
+ ]
+ }
+ ],
+ "source": [
+ "from text_tokenizer import HajoTextTokenizer\n",
+ "text_tokenizer = HajoTextTokenizer(data_folder+'/'+tokenizer_file)\n",
+ "tt = text_tokenizer.all_tokens\n",
+ "print(', '.join(tt))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b87b7fd0",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
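
Editor's note (not part of the committed notebook): the pooling step in execution cell 6 above sums the per-character cross-entropy inside each token and then spreads that sum back over the token's characters, which is what shrinks the variance from σ²=5.0 in row (3) to σ²=1.1 in row (6) of the table output. The standalone sketch below re-states that step; the helper name pool_entropy_per_token is made up for illustration, and the toy inputs are just the first five rounded values of row (3).

import numpy as np

def pool_entropy_per_token(tokens, char_cross_entropy):
    # tokens: token strings covering the text from left to right
    # char_cross_entropy: one entropy value per character of the text
    per_token = []   # summed cross-entropy per token (row (5) in the table)
    per_char = []    # token average repeated per character (row (6))
    pos = 0
    for t in tokens:
        total = float(sum(char_cross_entropy[pos:pos + len(t)]))
        per_token.append(total)
        per_char.extend([total / len(t)] * len(t))
        pos += len(t)
    return per_token, per_char

# Toy input: the characters 'd', 'i', 'e', ' ', 'k' with the rounded
# per-character entropies from row (3) of the notebook's table.
tokens = ['die', ' ', 'k']
char_ce = [8.9, 1.0, 0.2, 0.3, 4.2]
sums, smoothed = pool_entropy_per_token(tokens, char_ce)
print(sums)                               # approx. [10.1, 0.3, 4.2], matching row (5)
print(np.var(char_ce), np.var(smoothed))  # variance drops after pooling (~11.1 -> ~1.8)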