Commit
·
d09f022
1
Parent(s):
c2cd532
Upload TEVR Explanation.ipynb
Browse files- TEVR Explanation.ipynb +186 -0
TEVR Explanation.ipynb
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "89c94977",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"from huggingface_hub import snapshot_download\n",
|
11 |
+
"data_folder = snapshot_download(\"fxtentacle/tevr-token-entropy-predictor-de\")"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": 2,
|
17 |
+
"id": "a48a49d6",
|
18 |
+
"metadata": {},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"from transformers import T5ForConditionalGeneration\n",
|
22 |
+
"model = T5ForConditionalGeneration.from_pretrained(data_folder)\n",
|
23 |
+
"model.to('cuda')\n",
|
24 |
+
"model.eval()\n",
|
25 |
+
"None"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "code",
|
30 |
+
"execution_count": 3,
|
31 |
+
"id": "eed8bfc3",
|
32 |
+
"metadata": {},
|
33 |
+
"outputs": [],
|
34 |
+
"source": [
|
35 |
+
"import torch\n",
|
36 |
+
"\n",
|
37 |
+
"def text_to_cross_entropy(text):\n",
|
38 |
+
" ttext = torch.tensor([[0]+list(text.encode('UTF-8'))],dtype=torch.int64).to('cuda')\n",
|
39 |
+
" tone = torch.tensor([[1]],dtype=torch.int32).to('cuda')\n",
|
40 |
+
" logits = model.forward(input_ids=tone, attention_mask=tone, decoder_input_ids=ttext, return_dict=False)[0].detach()\n",
|
41 |
+
" cross_entropy = torch.nn.functional.cross_entropy(input=logits[0][:-1], target=ttext[0][1:], reduction='none').detach().cpu().numpy()\n",
|
42 |
+
" return cross_entropy"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": 4,
|
48 |
+
"id": "8ec8cf8d",
|
49 |
+
"metadata": {},
|
50 |
+
"outputs": [],
|
51 |
+
"source": [
|
52 |
+
"import sys\n",
|
53 |
+
"import os\n",
|
54 |
+
"sys.path.append(data_folder)\n",
|
55 |
+
"from text_tokenizer import HajoTextTokenizer"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"execution_count": 5,
|
61 |
+
"id": "37165805",
|
62 |
+
"metadata": {},
|
63 |
+
"outputs": [],
|
64 |
+
"source": [
|
65 |
+
"tokenizer_file = 'text-tokenizer-de-4m.txt'\n",
|
66 |
+
"text_tokenizer = HajoTextTokenizer(data_folder+'/'+tokenizer_file)"
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "code",
|
71 |
+
"execution_count": 6,
|
72 |
+
"id": "73e55343",
|
73 |
+
"metadata": {},
|
74 |
+
"outputs": [
|
75 |
+
{
|
76 |
+
"name": "stdout",
|
77 |
+
"output_type": "stream",
|
78 |
+
"text": [
|
79 |
+
"['die', ' ', 'k', 'at', 'ze', ' ', 'ist', ' ', 'n', 'ied', 'lich']\n",
|
80 |
+
"[3.3762913048267365, 3.3762913048267365, 3.3762913048267365, 0.29695791006088257, 4.193424224853516, 2.3430762887001038, 2.3430762887001038, 2.8417416363954544, 2.8417416363954544, 1.1227068901062012, 2.017452405144771, 2.017452405144771, 2.017452405144771, 0.0016304069431498647, 2.580254554748535, 2.3091587026913962, 2.3091587026913962, 2.3091587026913962, 1.0126478232632508, 1.0126478232632508, 1.0126478232632508, 1.0126478232632508]\n"
|
81 |
+
]
|
82 |
+
}
|
83 |
+
],
|
84 |
+
"source": [
|
85 |
+
"text = \"die katze ist niedlich\"\n",
|
86 |
+
"cross_entropy = text_to_cross_entropy(text)\n",
|
87 |
+
"\n",
|
88 |
+
"tokens = text_tokenizer.encode(text)\n",
|
89 |
+
"tokens = [text_tokenizer.all_tokens[t] for t in tokens]\n",
|
90 |
+
"print(tokens)\n",
|
91 |
+
"token_sums = []\n",
|
92 |
+
"token_sums2 = []\n",
|
93 |
+
"for t in tokens:\n",
|
94 |
+
" ce = sum(cross_entropy[len(token_sums):len(token_sums)+len(t)])\n",
|
95 |
+
" for r in range(len(t)): token_sums.append(ce / len(t))\n",
|
96 |
+
" token_sums2.append(ce)\n",
|
97 |
+
"print(token_sums)"
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "code",
|
102 |
+
"execution_count": 7,
|
103 |
+
"id": "e61e00aa",
|
104 |
+
"metadata": {},
|
105 |
+
"outputs": [
|
106 |
+
{
|
107 |
+
"data": {
|
108 |
+
"text/html": [
|
109 |
+
"<table style=\"font-size: 20px; font-family: Roboto\"><tr><td><b>(1)</b></td><td style=\"text-align:left\">d</td><td style=\"text-align:left\">i</td><td style=\"text-align:left\">e</td><td style=\"text-align:left\"> </td><td style=\"text-align:left\">k</td><td style=\"text-align:left\">a</td><td style=\"text-align:left\">t</td><td style=\"text-align:left\">z</td><td style=\"text-align:left\">e</td><td style=\"text-align:left\"> </td><td style=\"text-align:left\">i</td><td style=\"text-align:left\">s</td><td style=\"text-align:left\">t</td><td style=\"text-align:left\"> </td><td style=\"text-align:left\">n</td><td style=\"text-align:left\">i</td><td style=\"text-align:left\">e</td><td style=\"text-align:left\">d</td><td style=\"text-align:left\">l</td><td style=\"text-align:left\">i</td><td style=\"text-align:left\">c</td><td style=\"text-align:left\">h</td></tr><tr><td><b>(2)</b></td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>σ²=0.0</td></tr><tr><td><b>(3)</b></td><td>8.9</td><td>1.0</td><td>0.2</td><td>0.3</td><td>4.2</td><td>1.6</td><td>3.1</td><td>5.4</td><td>0.3</td><td>1.1</td><td>3.0</td><td>3.0</td><td>0.0</td><td>0.0</td><td>2.6</td><td>0.6</td><td>4.4</td><td>1.9</td><td>4.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>σ²=5.0</td></tr><tr><td><b>(4)</b></td><td style=\"text-align:center\" colspan=3>die</td><td style=\"text-align:center\" colspan=1> </td><td style=\"text-align:center\" colspan=1>k</td><td style=\"text-align:center\" colspan=2>at</td><td style=\"text-align:center\" colspan=2>ze</td><td style=\"text-align:center\" colspan=1> </td><td style=\"text-align:center\" colspan=3>ist</td><td style=\"text-align:center\" colspan=1> </td><td style=\"text-align:center\" colspan=1>n</td><td style=\"text-align:center\" colspan=3>ied</td><td style=\"text-align:center\" colspan=4>lich</td></tr><tr><td><b>(5)</b></td><td style=\"text-align:center\" colspan=3>10.1</td><td style=\"text-align:center\" colspan=1>0.3</td><td style=\"text-align:center\" colspan=1>4.2</td><td style=\"text-align:center\" colspan=2>4.7</td><td style=\"text-align:center\" colspan=2>5.7</td><td style=\"text-align:center\" colspan=1>1.1</td><td style=\"text-align:center\" colspan=3>6.1</td><td style=\"text-align:center\" colspan=1>0.0</td><td style=\"text-align:center\" colspan=1>2.6</td><td style=\"text-align:center\" colspan=3>6.9</td><td style=\"text-align:center\" colspan=4>4.1</td></tr><tr><td><b>(6)</b></td><td>3.4</td><td>3.4</td><td>3.4</td><td>0.3</td><td>4.2</td><td>2.3</td><td>2.3</td><td>2.8</td><td>2.8</td><td>1.1</td><td>2.0</td><td>2.0</td><td>2.0</td><td>0.0</td><td>2.6</td><td>2.3</td><td>2.3</td><td>2.3</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>σ²=1.1</td></tr></table>"
|
110 |
+
],
|
111 |
+
"text/plain": [
|
112 |
+
"<IPython.core.display.HTML object>"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
"execution_count": 7,
|
116 |
+
"metadata": {},
|
117 |
+
"output_type": "execute_result"
|
118 |
+
}
|
119 |
+
],
|
120 |
+
"source": [
|
121 |
+
"import numpy as np\n",
|
122 |
+
"html = '<table style=\"font-size: 20px; font-family: Roboto\">'\n",
|
123 |
+
"html += '<tr><td><b>(1)</b></td>'+''.join([f'<td style=\"text-align:left\">{c}</td>' for c in list(text)])+'</tr>'\n",
|
124 |
+
"html += '<tr><td><b>(2)</b></td>'+''.join(['<td>1.0</td>'.format(v) for v in cross_entropy])+'<td>σ²={:3.1f}</td>'.format(np.var([1.0 for v in cross_entropy]))+'</tr>'\n",
|
125 |
+
"html += '<tr><td><b>(3)</b></td>'+''.join(['<td>{:3.1f}</td>'.format(v) for v in cross_entropy])+'<td>σ²={:3.1f}</td>'.format(np.var(cross_entropy))+'</tr>'\n",
|
126 |
+
"html += '<tr><td><b>(4)</b></td>'+''.join([f'<td style=\"text-align:center\" colspan={len(t)}>{t}</td>' for t in tokens])+'</tr>'\n",
|
127 |
+
"html += '<tr><td><b>(5)</b></td>'+''.join([f'<td style=\"text-align:center\" colspan={len(t)}>{\"{:3.1f}\".format(token_sums2[i])}</td>' for i,t in enumerate(tokens)])+'</tr>'\n",
|
128 |
+
"html += '<tr><td><b>(6)</b></td>'+''.join(['<td>{:3.1f}</td>'.format(v) for v in token_sums])+'<td>σ²={:3.1f}</td>'.format(np.var(token_sums))+'</tr>'\n",
|
129 |
+
"html += '</table>'\n",
|
130 |
+
"\n",
|
131 |
+
"import IPython\n",
|
132 |
+
"IPython.display.HTML(html)"
|
133 |
+
]
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"cell_type": "code",
|
137 |
+
"execution_count": 8,
|
138 |
+
"id": "dcafdcab",
|
139 |
+
"metadata": {},
|
140 |
+
"outputs": [
|
141 |
+
{
|
142 |
+
"name": "stdout",
|
143 |
+
"output_type": "stream",
|
144 |
+
"text": [
|
145 |
+
"<pad>, <eos>, , chen, sche, lich, isch, icht, iche, eine, rden, tion, urde, haft, eich, rung, chte, ssen, chaf, nder, tlic, tung, eite, iert, sich, ngen, erde, scha, nden, unge, lung, mmen, eren, ende, inde, erun, sten, iese, igen, erte, iner, tsch, keit, der, die, ter, und, ein, ist, den, ten, ber, ver, sch, ung, ste, ent, ach, nte, auf, ben, eit, des, ers, aus, das, von, ren, gen, nen, lle, hre, mit, iel, uch, lte, ann, lie, men, dem, and, ind, als, sta, elt, ges, tte, ern, wir, ell, war, ere, rch, abe, len, ige, ied, ger, nnt, wei, ele, och, sse, end, all, ahr, bei, sie, ede, ion, ieg, ege, auc, che, rie, eis, vor, her, ang, für, ass, uss, tel, er, in, ge, en, st, ie, an, te, be, re, zu, ar, es, ra, al, or, ch, et, ei, un, le, rt, se, is, ha, we, at, me, ne, ur, he, au, ro, ti, li, ri, eh, im, ma, tr, ig, el, um, la, am, de, so, ol, tz, il, on, it, sc, sp, ko, na, pr, ni, si, fe, wi, ns, ke, ut, da, gr, eu, mi, hr, ze, hi, ta, ss, ng, sa, us, ba, ck, em, kt, ka, ve, fr, bi, wa, ah, gt, di, ab, fo, to, rk, as, ag, gi, hn, s, t, n, m, r, l, f, e, a, b, d, h, k, g, o, i, u, w, p, z, ä, ü, v, ö, j, c, y, x, q, á, í, ō, ó, š, é, č, ?\n"
|
146 |
+
]
|
147 |
+
}
|
148 |
+
],
|
149 |
+
"source": [
|
150 |
+
"from text_tokenizer import HajoTextTokenizer\n",
|
151 |
+
"text_tokenizer = HajoTextTokenizer(data_folder+'/'+tokenizer_file)\n",
|
152 |
+
"tt = text_tokenizer.all_tokens\n",
|
153 |
+
"print(', '.join(tt))"
|
154 |
+
]
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"cell_type": "code",
|
158 |
+
"execution_count": null,
|
159 |
+
"id": "b87b7fd0",
|
160 |
+
"metadata": {},
|
161 |
+
"outputs": [],
|
162 |
+
"source": []
|
163 |
+
}
|
164 |
+
],
|
165 |
+
"metadata": {
|
166 |
+
"kernelspec": {
|
167 |
+
"display_name": "Python 3 (ipykernel)",
|
168 |
+
"language": "python",
|
169 |
+
"name": "python3"
|
170 |
+
},
|
171 |
+
"language_info": {
|
172 |
+
"codemirror_mode": {
|
173 |
+
"name": "ipython",
|
174 |
+
"version": 3
|
175 |
+
},
|
176 |
+
"file_extension": ".py",
|
177 |
+
"mimetype": "text/x-python",
|
178 |
+
"name": "python",
|
179 |
+
"nbconvert_exporter": "python",
|
180 |
+
"pygments_lexer": "ipython3",
|
181 |
+
"version": "3.7.5"
|
182 |
+
}
|
183 |
+
},
|
184 |
+
"nbformat": 4,
|
185 |
+
"nbformat_minor": 5
|
186 |
+
}
|