Derur committed
Commit 23d6cc4 · verified · 1 Parent(s): 7a67169

Upload 35 files

Files changed (36)
  1. .gitattributes +5 -0
  2. punctuation/vosk-recasepunc-de-0.21.7z +3 -0
  3. punctuation/vosk-recasepunc-de-0.21/README +7 -0
  4. punctuation/vosk-recasepunc-de-0.21/checkpoint +3 -0
  5. punctuation/vosk-recasepunc-de-0.21/de-test.txt +6 -0
  6. punctuation/vosk-recasepunc-de-0.21/de-test.txt.orig +6 -0
  7. punctuation/vosk-recasepunc-de-0.21/example.py +23 -0
  8. punctuation/vosk-recasepunc-de-0.21/recasepunc.py +742 -0
  9. punctuation/vosk-recasepunc-en-0.22.7z +3 -0
  10. punctuation/vosk-recasepunc-en-0.22/README +7 -0
  11. punctuation/vosk-recasepunc-en-0.22/checkpoint +3 -0
  12. punctuation/vosk-recasepunc-en-0.22/example.py +26 -0
  13. punctuation/vosk-recasepunc-en-0.22/recasepunc.py +742 -0
  14. punctuation/vosk-recasepunc-en-0.22/vosk-adapted.txt +17 -0
  15. punctuation/vosk-recasepunc-en-0.22/vosk-adapted.txt.punc +1 -0
  16. punctuation/vosk-recasepunc-ru-0.22.7z +3 -0
  17. punctuation/vosk-recasepunc-ru-0.22/README +7 -0
  18. punctuation/vosk-recasepunc-ru-0.22/checkpoint +3 -0
  19. punctuation/vosk-recasepunc-ru-0.22/example.py +23 -0
  20. punctuation/vosk-recasepunc-ru-0.22/recasepunc.py +743 -0
  21. punctuation/vosk-recasepunc-ru-0.22/ru-test.txt +17 -0
  22. punctuation/vosk-recasepunc-ru-0.22/ru-test.txt.orig +17 -0
  23. speaker_indentification/vosk-model-spk-0.4.7z +3 -0
  24. speaker_indentification/vosk-model-spk-0.4/README.txt +119 -0
  25. speaker_indentification/vosk-model-spk-0.4/final.ext.raw +3 -0
  26. speaker_indentification/vosk-model-spk-0.4/mean.vec +1 -0
  27. speaker_indentification/vosk-model-spk-0.4/mfcc.conf +5 -0
  28. speaker_indentification/vosk-model-spk-0.4/transform.mat +0 -0
  29. tts/vosk-model-tts-ru-0.9-multi.7z +3 -0
  30. tts/vosk-model-tts-ru-0.9-multi/README.md +22 -0
  31. tts/vosk-model-tts-ru-0.9-multi/bert/README.md +39 -0
  32. tts/vosk-model-tts-ru-0.9-multi/bert/model.onnx +3 -0
  33. tts/vosk-model-tts-ru-0.9-multi/bert/vocab.txt +0 -0
  34. tts/vosk-model-tts-ru-0.9-multi/config.json +85 -0
  35. tts/vosk-model-tts-ru-0.9-multi/dictionary +3 -0
  36. tts/vosk-model-tts-ru-0.9-multi/model.onnx +3 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ punctuation/vosk-recasepunc-de-0.21/checkpoint filter=lfs diff=lfs merge=lfs -text
+ punctuation/vosk-recasepunc-en-0.22/checkpoint filter=lfs diff=lfs merge=lfs -text
+ punctuation/vosk-recasepunc-ru-0.22/checkpoint filter=lfs diff=lfs merge=lfs -text
+ speaker_indentification/vosk-model-spk-0.4/final.ext.raw filter=lfs diff=lfs merge=lfs -text
+ tts/vosk-model-tts-ru-0.9-multi/dictionary filter=lfs diff=lfs merge=lfs -text
punctuation/vosk-recasepunc-de-0.21.7z ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:42e06dab56196498cde6e89c5e6d4b97cab72942827153b06180f6a156bebc0b
+ size 1153855092
punctuation/vosk-recasepunc-de-0.21/README ADDED
@@ -0,0 +1,7 @@
+ 1. Install pytorch and transformers:
+
+ pip3 install transformers
+
+ 2. Run python3 example.py de-test.txt
+
+ 3. Compare with de-test.txt.orig
punctuation/vosk-recasepunc-de-0.21/checkpoint ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:90c70f58d865013bf1245d3c7fae229d6029ee0b16470c055f81a63e668de685
+ size 1315574525
punctuation/vosk-recasepunc-de-0.21/de-test.txt ADDED
@@ -0,0 +1,6 @@
+ nachdem sein vater schon 1707 starb als reinhart erst elf jahre alt war
+ wurde er von hauslehrern in seega erzogen hierauf kam er 1708 in die
+ stadtschule nach frankenhausen und war dort von der dritten bis zur
+ ersten klasse der bekannte schulmann magister hoffmann stand der schule
+ als rektor vor unter dem er publice prodiret hatte also öffentlich
+ aufgetreten war um eine rede zu halten
punctuation/vosk-recasepunc-de-0.21/de-test.txt.orig ADDED
@@ -0,0 +1,6 @@
+ Nachdem sein Vater schon 1707 starb, als Reinhart erst elf Jahre alt war,
+ wurde er von Hauslehrern in Seega erzogen. Hierauf kam er 1708 in die
+ Stadtschule nach Frankenhausen und war dort von der dritten bis zur
+ ersten Klasse. Der bekannte Schulmann Magister Hoffmann stand der Schule
+ als Rektor vor, unter dem er publice prodiret hatte, also öffentlich
+ aufgetreten war, um eine Rede zu halten.
punctuation/vosk-recasepunc-de-0.21/example.py ADDED
@@ -0,0 +1,23 @@
+ import sys
+ import time
+ from transformers import logging
+ from recasepunc import CasePuncPredictor
+ from recasepunc import WordpieceTokenizer
+ from recasepunc import Config
+
+ logging.set_verbosity_error()
+
+ predictor = CasePuncPredictor('checkpoint', lang="de")
+
+ text = " ".join(open(sys.argv[1]).readlines())
+ tokens = list(enumerate(predictor.tokenize(text)))
+
+ results = ""
+ for token, case_label, punc_label in predictor.predict(tokens, lambda x: x[1]):
+     prediction = predictor.map_punc_label(predictor.map_case_label(token[1], case_label), punc_label)
+     if token[1][0] != '#':
+         results = results + ' ' + prediction
+     else:
+         results = results + prediction
+
+ print (results.strip())
punctuation/vosk-recasepunc-de-0.21/recasepunc.py ADDED
@@ -0,0 +1,742 @@
1
+ import sys
2
+ import collections
3
+ import os
4
+ import regex as re
5
+ #from mosestokenizer import *
6
+ from tqdm import tqdm
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.optim as optim
11
+ import random
12
+ import unicodedata
13
+ import numpy as np
14
+ import argparse
15
+ from torch.utils.data import TensorDataset, DataLoader
16
+
17
+ from transformers import AutoModel, AutoTokenizer, BertTokenizer
18
+
19
+ default_config = argparse.Namespace(
20
+ seed=871253,
21
+ lang='de',
22
+ #flavor='flaubert/flaubert_base_uncased',
23
+ flavor=None,
24
+ max_length=256,
25
+ batch_size=16,
26
+ updates=24000,
27
+ period=1000,
28
+ lr=1e-5,
29
+ dab_rate=0.1,
30
+ device='cuda',
31
+ debug=False
32
+ )
33
+
34
+ default_flavors = {
35
+ 'fr': 'flaubert/flaubert_base_uncased',
36
+ 'en': 'bert-base-uncased',
37
+ 'zh': 'ckiplab/bert-base-chinese',
38
+ 'tr': 'dbmdz/bert-base-turkish-uncased',
39
+ 'de': 'dbmdz/bert-base-german-uncased',
40
+ 'pt': 'neuralmind/bert-base-portuguese-cased'
41
+ }
42
+
43
+ class Config(argparse.Namespace):
44
+ def __init__(self, **kwargs):
45
+ for key, value in default_config.__dict__.items():
46
+ setattr(self, key, value)
47
+ for key, value in kwargs.items():
48
+ setattr(self, key, value)
49
+
50
+ assert self.lang in ['fr', 'en', 'zh', 'tr', 'pt', 'de']
51
+
52
+ if 'lang' in kwargs and ('flavor' not in kwargs or kwargs['flavor'] is None):
53
+ self.flavor = default_flavors[self.lang]
54
+
55
+ #print(self.lang, self.flavor)
56
+
57
+
58
+ def init_random(seed):
59
+ # make sure everything is deterministic
60
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
61
+ #torch.use_deterministic_algorithms(True)
62
+ torch.manual_seed(seed)
63
+ torch.cuda.manual_seed_all(seed)
64
+ random.seed(seed)
65
+ np.random.seed(seed)
66
+
67
+ # NOTE: it is assumed in the implementation that y[:,0] is the punctuation label, and y[:,1] is the case label!
68
+
69
+ punctuation = {
70
+ 'O': 0,
71
+ 'COMMA': 1,
72
+ 'PERIOD': 2,
73
+ 'QUESTION': 3,
74
+ 'EXCLAMATION': 4,
75
+ }
76
+
77
+ punctuation_syms = ['', ',', '.', ' ?', ' !']
78
+
79
+ case = {
80
+ 'LOWER': 0,
81
+ 'UPPER': 1,
82
+ 'CAPITALIZE': 2,
83
+ 'OTHER': 3,
84
+ }
85
+
86
+
87
+ class Model(nn.Module):
88
+ def __init__(self, flavor, device):
89
+ super().__init__()
90
+ self.bert = AutoModel.from_pretrained(flavor)
91
+ # need a proper way of determining representation size
92
+ size = self.bert.dim if hasattr(self.bert, 'dim') else self.bert.config.pooler_fc_size if hasattr(self.bert.config, 'pooler_fc_size') else self.bert.config.emb_dim if hasattr(self.bert.config, 'emb_dim') else self.bert.config.hidden_size
93
+ self.punc = nn.Linear(size, 5)
94
+ self.case = nn.Linear(size, 4)
95
+ self.dropout = nn.Dropout(0.3)
96
+ self.to(device)
97
+
98
+ def forward(self, x):
99
+ output = self.bert(x)
100
+ representations = self.dropout(F.gelu(output['last_hidden_state']))
101
+ punc = self.punc(representations)
102
+ case = self.case(representations)
103
+ return punc, case
104
+
105
+
106
+ # randomly create sequences that align to punctuation boundaries
107
+ def drop_at_boundaries(rate, x, y, cls_token_id, sep_token_id, pad_token_id):
108
+ for i, dropped in enumerate(torch.rand((len(x),)) < rate):
109
+ if dropped:
110
+ # select all indices that are sentence endings
111
+ indices = (y[i,:,0] > 1).nonzero(as_tuple=True)[0]
112
+ if len(indices) < 2:
113
+ continue
114
+ start = indices[0] + 1
115
+ end = indices[random.randint(1, len(indices) - 1)] + 1
116
+ length = end - start
117
+ if length + 2 > len(x[i]):
118
+ continue
119
+ x[i, 0] = cls_token_id
120
+ x[i, 1: length + 1] = x[i, start: end].clone()
121
+ x[i, length + 1] = sep_token_id
122
+ x[i, length + 2:] = pad_token_id
123
+ y[i, 0] = 0
124
+ y[i, 1: length + 1] = y[i, start: end].clone()
125
+ y[i, length + 1:] = 0
126
+
127
+
128
+ def compute_performance(config, model, loader):
129
+ device = config.device
130
+ criterion = nn.CrossEntropyLoss()
131
+ model.eval()
132
+ total_loss = all_correct1 = all_correct2 = num_loss = num_perf = 0
133
+ num_ref = collections.defaultdict(float)
134
+ num_hyp = collections.defaultdict(float)
135
+ num_correct = collections.defaultdict(float)
136
+ for x, y in loader:
137
+ x = x.long().to(device)
138
+ y = y.long().to(device)
139
+ y1 = y[:,:,0]
140
+ y2 = y[:,:,1]
141
+ with torch.no_grad():
142
+ y_scores1, y_scores2 = model(x.to(device))
143
+ loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1)))
144
+ loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1)))
145
+ loss = loss1 + loss2
146
+ y_pred1 = torch.max(y_scores1, 2)[1]
147
+ y_pred2 = torch.max(y_scores2, 2)[1]
148
+ for label in range(1, 5):
149
+ ref = (y1 == label)
150
+ hyp = (y_pred1 == label)
151
+ correct = (ref * hyp == 1)
152
+ num_ref[label] += ref.sum()
153
+ num_hyp[label] += hyp.sum()
154
+ num_correct[label] += correct.sum()
155
+ num_ref[0] += ref.sum()
156
+ num_hyp[0] += hyp.sum()
157
+ num_correct[0] += correct.sum()
158
+ all_correct1 += (y_pred1 == y1).sum()
159
+ all_correct2 += (y_pred2 == y2).sum()
160
+ total_loss += loss.item()
161
+ num_loss += len(y)
162
+ num_perf += len(y) * config.max_length
163
+ recall = {}
164
+ precision = {}
165
+ fscore = {}
166
+ for label in range(0, 5):
167
+ recall[label] = num_correct[label] / num_ref[label] if num_ref[label] > 0 else 0
168
+ precision[label] = num_correct[label] / num_hyp[label] if num_hyp[label] > 0 else 0
169
+ fscore[label] = (2 * recall[label] * precision[label] / (recall[label] + precision[label])).item() if recall[label] + precision[label] > 0 else 0
170
+ return total_loss / num_loss, all_correct2.item() / num_perf, all_correct1.item() / num_perf, fscore
171
+
172
+
173
+ def fit(config, model, checkpoint_path, train_loader, valid_loader, iterations, valid_period=200, lr=1e-5):
174
+ device = config.device
175
+ criterion = nn.CrossEntropyLoss()
176
+ optimizer = optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=lr)
177
+ iteration = 0
178
+ while True:
179
+ model.train()
180
+ total_loss = num = 0
181
+ for x, y in tqdm(train_loader):
182
+ x = x.long().to(device)
183
+ y = y.long().to(device)
184
+ drop_at_boundaries(config.dab_rate, x, y, config.cls_token_id, config.sep_token_id, config.pad_token_id)
185
+ y1 = y[:,:,0]
186
+ y2 = y[:,:,1]
187
+ optimizer.zero_grad()
188
+ y_scores1, y_scores2 = model(x)
189
+ loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1)))
190
+ loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1)))
191
+ loss = loss1 + loss2
192
+ loss.backward()
193
+ optimizer.step()
194
+ total_loss += loss.item()
195
+ num += len(y)
196
+ if iteration % valid_period == valid_period - 1:
197
+ train_loss = total_loss / num
198
+ valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore = compute_performance(config, model, valid_loader)
199
+ torch.save({
200
+ 'iteration': iteration + 1,
201
+ 'model_state_dict': model.state_dict(),
202
+ 'optimizer_state_dict': optimizer.state_dict(),
203
+ 'train_loss': train_loss,
204
+ 'valid_loss': valid_loss,
205
+ 'valid_accuracy_case': valid_accuracy_case,
206
+ 'valid_accuracy_punc': valid_accuracy_punc,
207
+ 'valid_fscore': valid_fscore,
208
+ 'config': config.__dict__,
209
+ }, '%s.%d' % (checkpoint_path, iteration + 1))
210
+ print(iteration + 1, train_loss, valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore)
211
+ total_loss = num = 0
212
+
213
+ iteration += 1
214
+ if iteration > iterations:
215
+ return
216
+
217
+ sys.stderr.flush()
218
+ sys.stdout.flush()
219
+
220
+
221
+ def batchify(max_length, x, y):
222
+ print (x.shape)
223
+ print (y.shape)
224
+ x = x[:(len(x) // max_length) * max_length].reshape(-1, max_length)
225
+ y = y[:(len(y) // max_length) * max_length, :].reshape(-1, max_length, 2)
226
+ return x, y
227
+
228
+
229
+ def train(config, train_x_fn, train_y_fn, valid_x_fn, valid_y_fn, checkpoint_path):
230
+ X_train, Y_train = batchify(config.max_length, torch.load(train_x_fn), torch.load(train_y_fn))
231
+ X_valid, Y_valid = batchify(config.max_length, torch.load(valid_x_fn), torch.load(valid_y_fn))
232
+
233
+ train_set = TensorDataset(X_train, Y_train)
234
+ valid_set = TensorDataset(X_valid, Y_valid)
235
+
236
+ train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
237
+ valid_loader = DataLoader(valid_set, batch_size=config.batch_size)
238
+
239
+ model = Model(config.flavor, config.device)
240
+
241
+ fit(config, model, checkpoint_path, train_loader, valid_loader, config.updates, config.period, config.lr)
242
+
243
+
244
+ def run_eval(config, test_x_fn, test_y_fn, checkpoint_path):
245
+ X_test, Y_test = batchify(config.max_length, torch.load(test_x_fn), torch.load(test_y_fn))
246
+ test_set = TensorDataset(X_test, Y_test)
247
+ test_loader = DataLoader(test_set, batch_size=config.batch_size)
248
+
249
+ loaded = torch.load(checkpoint_path, map_location=config.device)
250
+ if 'config' in loaded:
251
+ config = Config(**loaded['config'])
252
+ init(config)
253
+
254
+ model = Model(config.flavor, config.device)
255
+ model.load_state_dict(loaded['model_state_dict'])
256
+
257
+ print(*compute_performance(config, model, test_loader))
258
+
259
+
260
+ def recase(token, label):
261
+ if label == case['LOWER']:
262
+ return token.lower()
263
+ elif label == case['CAPITALIZE']:
264
+ return token.lower().capitalize()
265
+ elif label == case['UPPER']:
266
+ return token.upper()
267
+ else:
268
+ return token
269
+
270
+
271
+ class CasePuncPredictor:
272
+ def __init__(self, checkpoint_path, lang=default_config.lang, flavor=default_config.flavor, device=default_config.device):
273
+ loaded = torch.load(checkpoint_path, map_location=device if torch.cuda.is_available() else 'cpu')
274
+ if 'config' in loaded:
275
+ self.config = Config(**loaded['config'])
276
+ else:
277
+ self.config = Config(lang=lang, flavor=flavor, device=device)
278
+ init(self.config)
279
+
280
+ self.model = Model(self.config.flavor, self.config.device)
281
+ self.model.load_state_dict(loaded['model_state_dict'])
282
+ self.model.eval()
283
+ self.model.to(self.config.device)
284
+
285
+ self.rev_case = {b: a for a, b in case.items()}
286
+ self.rev_punc = {b: a for a, b in punctuation.items()}
287
+
288
+ def tokenize(self, text):
289
+ return [self.config.cls_token] + self.config.tokenizer.tokenize(text) + [self.config.sep_token]
290
+
291
+ def predict(self, tokens, getter=lambda x: x):
292
+ max_length = self.config.max_length
293
+ device = self.config.device
294
+ if type(tokens) == str:
295
+ tokens = self.tokenize(tokens)
296
+ previous_label = punctuation['PERIOD']
297
+ for start in range(0, len(tokens), max_length):
298
+ instance = tokens[start: start + max_length]
299
+ if type(getter(instance[0])) == str:
300
+ ids = self.config.tokenizer.convert_tokens_to_ids(getter(token) for token in instance)
301
+ else:
302
+ ids = [getter(token) for token in instance]
303
+ if len(ids) < max_length:
304
+ ids += [0] * (max_length - len(ids))
305
+ x = torch.tensor([ids]).long().to(device)
306
+ y_scores1, y_scores2 = self.model(x)
307
+ y_pred1 = torch.max(y_scores1, 2)[1]
308
+ y_pred2 = torch.max(y_scores2, 2)[1]
309
+ for i, id, token, punc_label, case_label in zip(range(len(instance)), ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
310
+ if id == self.config.cls_token_id or id == self.config.sep_token_id:
311
+ continue
312
+ if previous_label != None and previous_label > 1:
313
+ if case_label in [case['LOWER'], case['OTHER']]: # LOWER, OTHER
314
+ case_label = case['CAPITALIZE']
315
+ if i + start == len(tokens) - 2 and punc_label == punctuation['O']:
316
+ punc_label = punctuation['PERIOD']
317
+ yield (token, self.rev_case[case_label], self.rev_punc[punc_label])
318
+ previous_label = punc_label
319
+
320
+ def map_case_label(self, token, case_label):
321
+ if token.endswith('</w>'):
322
+ token = token[:-4]
323
+ if token.startswith('##'):
324
+ token = token[2:]
325
+ return recase(token, case[case_label])
326
+
327
+ def map_punc_label(self, token, punc_label):
328
+ if token.endswith('</w>'):
329
+ token = token[:-4]
330
+ if token.startswith('##'):
331
+ token = token[2:]
332
+ return token + punctuation_syms[punctuation[punc_label]]
333
+
334
+
335
+
336
+ def generate_predictions(config, checkpoint_path):
337
+ loaded = torch.load(checkpoint_path, map_location=config.device if torch.cuda.is_available() else 'cpu')
338
+ if 'config' in loaded:
339
+ config = Config(**loaded['config'])
340
+ init(config)
341
+
342
+ model = Model(config.flavor, config.device)
343
+ model.load_state_dict(loaded['model_state_dict'])
344
+
345
+ rev_case = {b: a for a, b in case.items()}
346
+ rev_punc = {b: a for a, b in punctuation.items()}
347
+
348
+ for line in sys.stdin:
349
+ # also drop punctuation that we may generate
350
+ line = ''.join([c for c in line if c not in mapped_punctuation])
351
+ if config.debug:
352
+ print(line)
353
+ tokens = [config.cls_token] + config.tokenizer.tokenize(line) + [config.sep_token]
354
+ if config.debug:
355
+ print(tokens)
356
+ previous_label = punctuation['PERIOD']
357
+ first_time = True
358
+ was_word = False
359
+ for start in range(0, len(tokens), config.max_length):
360
+ instance = tokens[start: start + config.max_length]
361
+ ids = config.tokenizer.convert_tokens_to_ids(instance)
362
+ #print(len(ids), file=sys.stderr)
363
+ if len(ids) < config.max_length:
364
+ ids += [config.pad_token_id] * (config.max_length - len(ids))
365
+ x = torch.tensor([ids]).long().to(config.device)
366
+ y_scores1, y_scores2 = model(x)
367
+ y_pred1 = torch.max(y_scores1, 2)[1]
368
+ y_pred2 = torch.max(y_scores2, 2)[1]
369
+ for id, token, punc_label, case_label in zip(ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
370
+ if config.debug:
371
+ print(id, token, punc_label, case_label, file=sys.stderr)
372
+ if id == config.cls_token_id or id == config.sep_token_id:
373
+ continue
374
+ if previous_label != None and previous_label > 1:
375
+ if case_label in [case['LOWER'], case['OTHER']]:
376
+ case_label = case['CAPITALIZE']
377
+ previous_label = punc_label
378
+ # different strategy due to sub-lexical token encoding in Flaubert
379
+ if config.lang == 'fr':
380
+ if token.endswith('</w>'):
381
+ cased_token = recase(token[:-4], case_label)
382
+ if was_word:
383
+ print(' ', end='')
384
+ print(cased_token + punctuation_syms[punc_label], end='')
385
+ was_word = True
386
+ else:
387
+ cased_token = recase(token, case_label)
388
+ if was_word:
389
+ print(' ', end='')
390
+ print(cased_token, end='')
391
+ was_word = False
392
+ else:
393
+ if token.startswith('##'):
394
+ cased_token = recase(token[2:], case_label)
395
+ print(cased_token, end='')
396
+ else:
397
+ cased_token = recase(token, case_label)
398
+ if not first_time:
399
+ print(' ', end='')
400
+ first_time = False
401
+ print(cased_token + punctuation_syms[punc_label], end='')
402
+ if previous_label == 0:
403
+ print('.', end='')
404
+ print()
405
+
406
+
407
+ def label_for_case(token):
408
+ token = re.sub('[^\p{Han}\p{Ll}\p{Lu}]', '', token)
409
+ if token == token.lower():
410
+ return 'LOWER'
411
+ elif token == token.lower().capitalize():
412
+ return 'CAPITALIZE'
413
+ elif token == token.upper():
414
+ return 'UPPER'
415
+ else:
416
+ return 'OTHER'
417
+
418
+
419
+ def make_tensors(config, input_fn, output_x_fn, output_y_fn):
420
+ # count file lines without loading them
421
+ size = 0
422
+ with open(input_fn) as fp:
423
+ for line in fp:
424
+ size += 1
425
+
426
+ with open(input_fn) as fp:
427
+ X = torch.IntTensor(size)
428
+ Y = torch.ByteTensor(size, 2)
429
+
430
+ offset = 0
431
+ for n, line in enumerate(fp):
432
+ word, case_label, punc_label = line.strip().split('\t')
433
+ id = config.tokenizer.convert_tokens_to_ids(word)
434
+ if config.debug:
435
+ assert word.lower() == tokenizer.convert_ids_to_tokens(id)
436
+ X[offset] = id
437
+ Y[offset, 0] = punctuation[punc_label]
438
+ Y[offset, 1] = case[case_label]
439
+ offset += 1
440
+
441
+ torch.save(X, output_x_fn)
442
+ torch.save(Y, output_y_fn)
443
+
444
+
445
+ mapped_punctuation = {
446
+ '.': 'PERIOD',
447
+ '...': 'PERIOD',
448
+ ',': 'COMMA',
449
+ ';': 'COMMA',
450
+ ':': 'COMMA',
451
+ '(': 'COMMA',
452
+ ')': 'COMMA',
453
+ '?': 'QUESTION',
454
+ '!': 'EXCLAMATION',
455
+ ',': 'COMMA',
456
+ '!': 'EXCLAMATION',
457
+ '?': 'QUESTION',
458
+ ';': 'COMMA',
459
+ ':': 'COMMA',
460
+ '(': 'COMMA',
461
+ '(': 'COMMA',
462
+ ')': 'COMMA',
463
+ '[': 'COMMA',
464
+ ']': 'COMMA',
465
+ '【': 'COMMA',
466
+ '】': 'COMMA',
467
+ '└': 'COMMA',
468
+ '└ ': 'COMMA',
469
+ '_': 'O',
470
+ '。': 'PERIOD',
471
+ '、': 'COMMA', # enumeration comma
472
+ '、': 'COMMA',
473
+ '…': 'PERIOD',
474
+ '—': 'COMMA',
475
+ '「': 'COMMA',
476
+ '」': 'COMMA',
477
+ '.': 'PERIOD',
478
+ '《': 'O',
479
+ '》': 'O',
480
+ ',': 'COMMA',
481
+ '“': 'O',
482
+ '”': 'O',
483
+ '"': 'O',
484
+ '-': 'O',
485
+ '-': 'O',
486
+ '〉': 'COMMA',
487
+ '〈': 'COMMA',
488
+ '↑': 'O',
489
+ '〔': 'COMMA',
490
+ '〕': 'COMMA',
491
+ }
492
+
493
+ def preprocess_text(config, max_token_count=-1):
494
+ global num_tokens_output
495
+ max_token_count = int(max_token_count)
496
+ num_tokens_output = 0
497
+ def process_segment(text, punctuation):
498
+ global num_tokens_output
499
+ text = text.replace('\t', ' ')
500
+ tokens = config.tokenizer.tokenize(text)
501
+ for i, token in enumerate(tokens):
502
+ case_label = label_for_case(token)
503
+ if i == len(tokens) - 1:
504
+ print(token.lower(), case_label, punctuation, sep='\t')
505
+ else:
506
+ print(token.lower(), case_label, 'O', sep='\t')
507
+ num_tokens_output += 1
508
+ # a bit too ugly, but alternative is to throw an exception
509
+ if max_token_count > 0 and num_tokens_output >= max_token_count:
510
+ sys.exit(0)
511
+
512
+ for line in sys.stdin:
513
+ line = line.strip()
514
+ if line != '':
515
+ line = unicodedata.normalize("NFC", line)
516
+ if config.debug:
517
+ print(line)
518
+ start = 0
519
+ for i, char in enumerate(line):
520
+ if char in mapped_punctuation:
521
+ if i > start and line[start: i].strip() != '':
522
+ process_segment(line[start: i], mapped_punctuation[char])
523
+ start = i + 1
524
+ if start < len(line):
525
+ process_segment(line[start:], 'PERIOD')
526
+
527
+
528
+ def preprocess_text_old_fr(config):
529
+ assert config.lang == 'fr'
530
+ splitsents = MosesSentenceSplitter(lang)
531
+ tokenize = MosesTokenizer(lang, extra=['-no-escape'])
532
+ normalize = MosesPunctuationNormalizer(lang)
533
+
534
+ for line in sys.stdin:
535
+ if line.strip() != '':
536
+ for sentence in splitsents([normalize(line)]):
537
+ tokens = tokenize(sentence)
538
+ previous_token = None
539
+ for token in tokens:
540
+ if token in mapped_punctuation:
541
+ if previous_token != None:
542
+ print(previous_token, mapped_punctuation[token], sep='\t')
543
+ previous_token = None
544
+ elif not re.search('[\p{Han}\p{Ll}\p{Lu}\d]', token): # remove non-alphanumeric tokens
545
+ continue
546
+ else:
547
+ if previous_token != None:
548
+ print(previous_token, 'O', sep='\t')
549
+ previous_token = token
550
+ if previous_token != None:
551
+ print(previous_token, 'PERIOD', sep='\t')
552
+
553
+
554
+ # modification of the wordpiece tokenizer to keep case information even if vocab is lower cased
555
+ # forked from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py
556
+
557
+ class WordpieceTokenizer(object):
558
+ """Runs WordPiece tokenization."""
559
+
560
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100, keep_case=True):
561
+ self.vocab = vocab
562
+ self.unk_token = unk_token
563
+ self.max_input_chars_per_word = max_input_chars_per_word
564
+ self.keep_case = keep_case
565
+
566
+ def tokenize(self, text):
567
+ """
568
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
569
+ tokenization using the given vocabulary.
570
+ For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`.
571
+ Args:
572
+ text: A single token or whitespace separated tokens. This should have
573
+ already been passed through `BasicTokenizer`.
574
+ Returns:
575
+ A list of wordpiece tokens.
576
+ """
577
+
578
+ output_tokens = []
579
+ for token in text.strip().split():
580
+ chars = list(token)
581
+ if len(chars) > self.max_input_chars_per_word:
582
+ output_tokens.append(self.unk_token)
583
+ continue
584
+
585
+ is_bad = False
586
+ start = 0
587
+ sub_tokens = []
588
+ while start < len(chars):
589
+ end = len(chars)
590
+ cur_substr = None
591
+ while start < end:
592
+ substr = "".join(chars[start:end])
593
+ if start > 0:
594
+ substr = "##" + substr
595
+ # optionaly lowercase substring before checking for inclusion in vocab
596
+ if (self.keep_case and substr.lower() in self.vocab) or (substr in self.vocab):
597
+ cur_substr = substr
598
+ break
599
+ end -= 1
600
+ if cur_substr is None:
601
+ is_bad = True
602
+ break
603
+ sub_tokens.append(cur_substr)
604
+ start = end
605
+
606
+ if is_bad:
607
+ output_tokens.append(self.unk_token)
608
+ else:
609
+ output_tokens.extend(sub_tokens)
610
+ return output_tokens
611
+
612
+
613
+ # modification of XLM bpe tokenizer for keeping case information when vocab is lowercase
614
+ # forked from https://github.com/huggingface/transformers/blob/cd56f3fe7eae4a53a9880e3f5e8f91877a78271c/src/transformers/models/xlm/tokenization_xlm.py
615
+ def bpe(self, token):
616
+ def to_lower(pair):
617
+ #print(' ',pair)
618
+ return (pair[0].lower(), pair[1].lower())
619
+
620
+ from transformers.models.xlm.tokenization_xlm import get_pairs
621
+
622
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
623
+ if token in self.cache:
624
+ return self.cache[token]
625
+ pairs = get_pairs(word)
626
+
627
+ if not pairs:
628
+ return token + "</w>"
629
+
630
+ while True:
631
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(to_lower(pair), float("inf")))
632
+ #print(bigram)
633
+ if to_lower(bigram) not in self.bpe_ranks:
634
+ break
635
+ first, second = bigram
636
+ new_word = []
637
+ i = 0
638
+ while i < len(word):
639
+ try:
640
+ j = word.index(first, i)
641
+ except ValueError:
642
+ new_word.extend(word[i:])
643
+ break
644
+ else:
645
+ new_word.extend(word[i:j])
646
+ i = j
647
+
648
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
649
+ new_word.append(first + second)
650
+ i += 2
651
+ else:
652
+ new_word.append(word[i])
653
+ i += 1
654
+ new_word = tuple(new_word)
655
+ word = new_word
656
+ if len(word) == 1:
657
+ break
658
+ else:
659
+ pairs = get_pairs(word)
660
+ word = " ".join(word)
661
+ if word == "\n </w>":
662
+ word = "\n</w>"
663
+ self.cache[token] = word
664
+ return word
665
+
666
+
667
+
668
+ def init(config):
669
+ init_random(config.seed)
670
+
671
+ if config.lang == 'fr':
672
+ config.tokenizer = tokenizer = AutoTokenizer.from_pretrained(config.flavor, do_lower_case=False)
673
+
674
+ from transformers.models.xlm.tokenization_xlm import XLMTokenizer
675
+ assert isinstance(tokenizer, XLMTokenizer)
676
+
677
+ # monkey patch XLM tokenizer
678
+ import types
679
+ tokenizer.bpe = types.MethodType(bpe, tokenizer)
680
+ else:
681
+ # warning: needs to be BertTokenizer for monkey patching to work
682
+ config.tokenizer = tokenizer = BertTokenizer.from_pretrained(config.flavor, do_lower_case=False)
683
+
684
+ # warning: monkey patch tokenizer to keep case information
685
+ #from recasing_tokenizer import WordpieceTokenizer
686
+ config.tokenizer.wordpiece_tokenizer = WordpieceTokenizer(vocab=tokenizer.vocab, unk_token=tokenizer.unk_token)
687
+
688
+ if config.lang == 'fr':
689
+ config.pad_token_id = tokenizer.pad_token_id
690
+ config.cls_token_id = tokenizer.bos_token_id
691
+ config.cls_token = tokenizer.bos_token
692
+ config.sep_token_id = tokenizer.sep_token_id
693
+ config.sep_token = tokenizer.sep_token
694
+ else:
695
+ config.pad_token_id = tokenizer.pad_token_id
696
+ config.cls_token_id = tokenizer.cls_token_id
697
+ config.cls_token = tokenizer.cls_token
698
+ config.sep_token_id = tokenizer.sep_token_id
699
+ config.sep_token = tokenizer.sep_token
700
+
701
+ if not torch.cuda.is_available() and config.device == 'cuda':
702
+ print('WARNING: reverting to cpu as cuda is not available', file=sys.stderr)
703
+ config.device = torch.device(config.device if torch.cuda.is_available() else 'cpu')
704
+
705
+
706
+ def main(config, action, args):
707
+ init(config)
708
+
709
+ if action == 'train':
710
+ train(config, *args)
711
+ elif action == 'eval':
712
+ run_eval(config, *args)
713
+ elif action == 'predict':
714
+ generate_predictions(config, *args)
715
+ elif action == 'tensorize':
716
+ make_tensors(config, *args)
717
+ elif action == 'preprocess':
718
+ preprocess_text(config, *args)
719
+ else:
720
+ print('invalid action "%s"' % action)
721
+ sys.exit(1)
722
+
723
+ if __name__ == '__main__':
724
+ parser = argparse.ArgumentParser()
725
+ parser.add_argument("action", help="train|eval|predict|tensorize|preprocess", type=str)
726
+ parser.add_argument("action_args", help="arguments for selected action", type=str, nargs='*')
727
+ parser.add_argument("--seed", help="random seed", default=default_config.seed, type=int)
728
+ parser.add_argument("--lang", help="language (fr, en, zh)", default=default_config.lang, type=str)
729
+ parser.add_argument("--flavor", help="bert flavor in transformers model zoo", default=default_config.flavor, type=str)
730
+ parser.add_argument("--max-length", help="maximum input length", default=default_config.max_length, type=int)
731
+ parser.add_argument("--batch-size", help="size of batches", default=default_config.batch_size, type=int)
732
+ parser.add_argument("--device", help="computation device (cuda, cpu)", default=default_config.device, type=str)
733
+ parser.add_argument("--debug", help="whether to output more debug info", default=default_config.debug, type=bool)
734
+ parser.add_argument("--updates", help="number of training updates to perform", default=default_config.updates, type=bool)
735
+ parser.add_argument("--period", help="validation period in updates", default=default_config.period, type=bool)
736
+ parser.add_argument("--lr", help="learning rate", default=default_config.lr, type=bool)
737
+ parser.add_argument("--dab-rate", help="drop at boundaries rate", default=default_config.dab_rate, type=bool)
738
+ config = Config(**parser.parse_args().__dict__)
739
+
740
+ main(config, config.action, config.action_args)
741
+
742
+
punctuation/vosk-recasepunc-en-0.22.7z ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d754b827d2b3f85fe56cfb7a5262dc5658a9257ff7f9404e57595328882b8777
+ size 1148483511
punctuation/vosk-recasepunc-en-0.22/README ADDED
@@ -0,0 +1,7 @@
+ 1. Install pytorch and transformers:
+
+ pip3 install transformers
+
+ 2. Run python3 example.py de-test.txt
+
+ 3. Compare with de-test.txt.orig
punctuation/vosk-recasepunc-en-0.22/checkpoint ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9782ccd13a130feffb13609834778421ebd39e26910d25ddcf2185a0eea75935
+ size 1310193349
punctuation/vosk-recasepunc-en-0.22/example.py ADDED
@@ -0,0 +1,26 @@
+ import sys
+ import time
+ from transformers import logging
+ from recasepunc import CasePuncPredictor
+ from recasepunc import WordpieceTokenizer
+ from recasepunc import Config
+
+ logging.set_verbosity_error()
+
+ predictor = CasePuncPredictor('checkpoint', lang="en")
+
+ text = " ".join(open(sys.argv[1]).readlines())
+ tokens = list(enumerate(predictor.tokenize(text)))
+
+ results = ""
+ for token, case_label, punc_label in predictor.predict(tokens, lambda x: x[1]):
+     prediction = predictor.map_punc_label(predictor.map_case_label(token[1], case_label), punc_label)
+
+     if token[1][0] == '\'' or (len(results) > 0 and results[-1] == '\''):
+         results = results + prediction
+     elif token[1][0] != '#':
+         results = results + ' ' + prediction
+     else:
+         results = results + prediction
+
+ print (results.strip())
punctuation/vosk-recasepunc-en-0.22/recasepunc.py ADDED
@@ -0,0 +1,742 @@
1
+ import sys
2
+ import collections
3
+ import os
4
+ import regex as re
5
+ #from mosestokenizer import *
6
+ from tqdm import tqdm
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.optim as optim
11
+ import random
12
+ import unicodedata
13
+ import numpy as np
14
+ import argparse
15
+ from torch.utils.data import TensorDataset, DataLoader
16
+
17
+ from transformers import AutoModel, AutoTokenizer, BertTokenizer
18
+
19
+ default_config = argparse.Namespace(
20
+ seed=871253,
21
+ lang='en',
22
+ #flavor='flaubert/flaubert_base_uncased',
23
+ flavor=None,
24
+ max_length=256,
25
+ batch_size=16,
26
+ updates=24000,
27
+ period=1000,
28
+ lr=1e-5,
29
+ dab_rate=0.1,
30
+ device='cuda',
31
+ debug=False
32
+ )
33
+
34
+ default_flavors = {
35
+ 'fr': 'flaubert/flaubert_base_uncased',
36
+ 'en': 'bert-base-uncased',
37
+ 'zh': 'ckiplab/bert-base-chinese',
38
+ 'tr': 'dbmdz/bert-base-turkish-uncased',
39
+ 'de': 'dbmdz/bert-base-german-uncased',
40
+ 'pt': 'neuralmind/bert-base-portuguese-cased'
41
+ }
42
+
43
+ class Config(argparse.Namespace):
44
+ def __init__(self, **kwargs):
45
+ for key, value in default_config.__dict__.items():
46
+ setattr(self, key, value)
47
+ for key, value in kwargs.items():
48
+ setattr(self, key, value)
49
+
50
+ assert self.lang in ['fr', 'en', 'zh', 'tr', 'pt', 'de']
51
+
52
+ if 'lang' in kwargs and ('flavor' not in kwargs or kwargs['flavor'] is None):
53
+ self.flavor = default_flavors[self.lang]
54
+
55
+ #print(self.lang, self.flavor)
56
+
57
+
58
+ def init_random(seed):
59
+ # make sure everything is deterministic
60
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
61
+ #torch.use_deterministic_algorithms(True)
62
+ torch.manual_seed(seed)
63
+ torch.cuda.manual_seed_all(seed)
64
+ random.seed(seed)
65
+ np.random.seed(seed)
66
+
67
+ # NOTE: it is assumed in the implementation that y[:,0] is the punctuation label, and y[:,1] is the case label!
68
+
69
+ punctuation = {
70
+ 'O': 0,
71
+ 'COMMA': 1,
72
+ 'PERIOD': 2,
73
+ 'QUESTION': 3,
74
+ 'EXCLAMATION': 4,
75
+ }
76
+
77
+ punctuation_syms = ['', ',', '.', ' ?', ' !']
78
+
79
+ case = {
80
+ 'LOWER': 0,
81
+ 'UPPER': 1,
82
+ 'CAPITALIZE': 2,
83
+ 'OTHER': 3,
84
+ }
85
+
86
+
87
+ class Model(nn.Module):
88
+ def __init__(self, flavor, device):
89
+ super().__init__()
90
+ self.bert = AutoModel.from_pretrained(flavor)
91
+ # need a proper way of determining representation size
92
+ size = self.bert.dim if hasattr(self.bert, 'dim') else self.bert.config.pooler_fc_size if hasattr(self.bert.config, 'pooler_fc_size') else self.bert.config.emb_dim if hasattr(self.bert.config, 'emb_dim') else self.bert.config.hidden_size
93
+ self.punc = nn.Linear(size, 5)
94
+ self.case = nn.Linear(size, 4)
95
+ self.dropout = nn.Dropout(0.3)
96
+ self.to(device)
97
+
98
+ def forward(self, x):
99
+ output = self.bert(x)
100
+ representations = self.dropout(F.gelu(output['last_hidden_state']))
101
+ punc = self.punc(representations)
102
+ case = self.case(representations)
103
+ return punc, case
104
+
105
+
106
+ # randomly create sequences that align to punctuation boundaries
107
+ def drop_at_boundaries(rate, x, y, cls_token_id, sep_token_id, pad_token_id):
108
+ for i, dropped in enumerate(torch.rand((len(x),)) < rate):
109
+ if dropped:
110
+ # select all indices that are sentence endings
111
+ indices = (y[i,:,0] > 1).nonzero(as_tuple=True)[0]
112
+ if len(indices) < 2:
113
+ continue
114
+ start = indices[0] + 1
115
+ end = indices[random.randint(1, len(indices) - 1)] + 1
116
+ length = end - start
117
+ if length + 2 > len(x[i]):
118
+ continue
119
+ x[i, 0] = cls_token_id
120
+ x[i, 1: length + 1] = x[i, start: end].clone()
121
+ x[i, length + 1] = sep_token_id
122
+ x[i, length + 2:] = pad_token_id
123
+ y[i, 0] = 0
124
+ y[i, 1: length + 1] = y[i, start: end].clone()
125
+ y[i, length + 1:] = 0
126
+
127
+
128
+ def compute_performance(config, model, loader):
129
+ device = config.device
130
+ criterion = nn.CrossEntropyLoss()
131
+ model.eval()
132
+ total_loss = all_correct1 = all_correct2 = num_loss = num_perf = 0
133
+ num_ref = collections.defaultdict(float)
134
+ num_hyp = collections.defaultdict(float)
135
+ num_correct = collections.defaultdict(float)
136
+ for x, y in loader:
137
+ x = x.long().to(device)
138
+ y = y.long().to(device)
139
+ y1 = y[:,:,0]
140
+ y2 = y[:,:,1]
141
+ with torch.no_grad():
142
+ y_scores1, y_scores2 = model(x.to(device))
143
+ loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1)))
144
+ loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1)))
145
+ loss = loss1 + loss2
146
+ y_pred1 = torch.max(y_scores1, 2)[1]
147
+ y_pred2 = torch.max(y_scores2, 2)[1]
148
+ for label in range(1, 5):
149
+ ref = (y1 == label)
150
+ hyp = (y_pred1 == label)
151
+ correct = (ref * hyp == 1)
152
+ num_ref[label] += ref.sum()
153
+ num_hyp[label] += hyp.sum()
154
+ num_correct[label] += correct.sum()
155
+ num_ref[0] += ref.sum()
156
+ num_hyp[0] += hyp.sum()
157
+ num_correct[0] += correct.sum()
158
+ all_correct1 += (y_pred1 == y1).sum()
159
+ all_correct2 += (y_pred2 == y2).sum()
160
+ total_loss += loss.item()
161
+ num_loss += len(y)
162
+ num_perf += len(y) * config.max_length
163
+ recall = {}
164
+ precision = {}
165
+ fscore = {}
166
+ for label in range(0, 5):
167
+ recall[label] = num_correct[label] / num_ref[label] if num_ref[label] > 0 else 0
168
+ precision[label] = num_correct[label] / num_hyp[label] if num_hyp[label] > 0 else 0
169
+ fscore[label] = (2 * recall[label] * precision[label] / (recall[label] + precision[label])).item() if recall[label] + precision[label] > 0 else 0
170
+ return total_loss / num_loss, all_correct2.item() / num_perf, all_correct1.item() / num_perf, fscore
171
+
172
+
173
+ def fit(config, model, checkpoint_path, train_loader, valid_loader, iterations, valid_period=200, lr=1e-5):
174
+ device = config.device
175
+ criterion = nn.CrossEntropyLoss()
176
+ optimizer = optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=lr)
177
+ iteration = 0
178
+ while True:
179
+ model.train()
180
+ total_loss = num = 0
181
+ for x, y in tqdm(train_loader):
182
+ x = x.long().to(device)
183
+ y = y.long().to(device)
184
+ drop_at_boundaries(config.dab_rate, x, y, config.cls_token_id, config.sep_token_id, config.pad_token_id)
185
+ y1 = y[:,:,0]
186
+ y2 = y[:,:,1]
187
+ optimizer.zero_grad()
188
+ y_scores1, y_scores2 = model(x)
189
+ loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1)))
190
+ loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1)))
191
+ loss = loss1 + loss2
192
+ loss.backward()
193
+ optimizer.step()
194
+ total_loss += loss.item()
195
+ num += len(y)
196
+ if iteration % valid_period == valid_period - 1:
197
+ train_loss = total_loss / num
198
+ valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore = compute_performance(config, model, valid_loader)
199
+ torch.save({
200
+ 'iteration': iteration + 1,
201
+ 'model_state_dict': model.state_dict(),
202
+ 'optimizer_state_dict': optimizer.state_dict(),
203
+ 'train_loss': train_loss,
204
+ 'valid_loss': valid_loss,
205
+ 'valid_accuracy_case': valid_accuracy_case,
206
+ 'valid_accuracy_punc': valid_accuracy_punc,
207
+ 'valid_fscore': valid_fscore,
208
+ 'config': config.__dict__,
209
+ }, '%s.%d' % (checkpoint_path, iteration + 1))
210
+ print(iteration + 1, train_loss, valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore)
211
+ total_loss = num = 0
212
+
213
+ iteration += 1
214
+ if iteration > iterations:
215
+ return
216
+
217
+ sys.stderr.flush()
218
+ sys.stdout.flush()
219
+
220
+
221
+ def batchify(max_length, x, y):
222
+ print (x.shape)
223
+ print (y.shape)
224
+ x = x[:(len(x) // max_length) * max_length].reshape(-1, max_length)
225
+ y = y[:(len(y) // max_length) * max_length, :].reshape(-1, max_length, 2)
226
+ return x, y
227
+
228
+
229
+ def train(config, train_x_fn, train_y_fn, valid_x_fn, valid_y_fn, checkpoint_path):
230
+ X_train, Y_train = batchify(config.max_length, torch.load(train_x_fn), torch.load(train_y_fn))
231
+ X_valid, Y_valid = batchify(config.max_length, torch.load(valid_x_fn), torch.load(valid_y_fn))
232
+
233
+ train_set = TensorDataset(X_train, Y_train)
234
+ valid_set = TensorDataset(X_valid, Y_valid)
235
+
236
+ train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
237
+ valid_loader = DataLoader(valid_set, batch_size=config.batch_size)
238
+
239
+ model = Model(config.flavor, config.device)
240
+
241
+ fit(config, model, checkpoint_path, train_loader, valid_loader, config.updates, config.period, config.lr)
242
+
243
+
244
+ def run_eval(config, test_x_fn, test_y_fn, checkpoint_path):
245
+ X_test, Y_test = batchify(config.max_length, torch.load(test_x_fn), torch.load(test_y_fn))
246
+ test_set = TensorDataset(X_test, Y_test)
247
+ test_loader = DataLoader(test_set, batch_size=config.batch_size)
248
+
249
+ loaded = torch.load(checkpoint_path, map_location=config.device)
250
+ if 'config' in loaded:
251
+ config = Config(**loaded['config'])
252
+ init(config)
253
+
254
+ model = Model(config.flavor, config.device)
255
+ model.load_state_dict(loaded['model_state_dict'])
256
+
257
+ print(*compute_performance(config, model, test_loader))
258
+
259
+
260
+ def recase(token, label):
261
+ if label == case['LOWER']:
262
+ return token.lower()
263
+ elif label == case['CAPITALIZE']:
264
+ return token.lower().capitalize()
265
+ elif label == case['UPPER']:
266
+ return token.upper()
267
+ else:
268
+ return token
269
+
270
+
271
+ class CasePuncPredictor:
272
+ def __init__(self, checkpoint_path, lang=default_config.lang, flavor=default_config.flavor, device=default_config.device):
273
+ loaded = torch.load(checkpoint_path, map_location=device if torch.cuda.is_available() else 'cpu')
274
+ if 'config' in loaded:
275
+ self.config = Config(**loaded['config'])
276
+ else:
277
+ self.config = Config(lang=lang, flavor=flavor, device=device)
278
+ init(self.config)
279
+
280
+ self.model = Model(self.config.flavor, self.config.device)
281
+ self.model.load_state_dict(loaded['model_state_dict'])
282
+ self.model.eval()
283
+ self.model.to(self.config.device)
284
+
285
+ self.rev_case = {b: a for a, b in case.items()}
286
+ self.rev_punc = {b: a for a, b in punctuation.items()}
287
+
288
+ def tokenize(self, text):
289
+ return [self.config.cls_token] + self.config.tokenizer.tokenize(text) + [self.config.sep_token]
290
+
291
+ def predict(self, tokens, getter=lambda x: x):
292
+ max_length = self.config.max_length
293
+ device = self.config.device
294
+ if type(tokens) == str:
295
+ tokens = self.tokenize(tokens)
296
+ previous_label = punctuation['PERIOD']
297
+ for start in range(0, len(tokens), max_length):
298
+ instance = tokens[start: start + max_length]
299
+ if type(getter(instance[0])) == str:
300
+ ids = self.config.tokenizer.convert_tokens_to_ids(getter(token) for token in instance)
301
+ else:
302
+ ids = [getter(token) for token in instance]
303
+ if len(ids) < max_length:
304
+ ids += [0] * (max_length - len(ids))
305
+ x = torch.tensor([ids]).long().to(device)
306
+ y_scores1, y_scores2 = self.model(x)
307
+ y_pred1 = torch.max(y_scores1, 2)[1]
308
+ y_pred2 = torch.max(y_scores2, 2)[1]
309
+ for i, id, token, punc_label, case_label in zip(range(len(instance)), ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
310
+ if id == self.config.cls_token_id or id == self.config.sep_token_id:
311
+ continue
312
+ if previous_label != None and previous_label > 1:
313
+ if case_label in [case['LOWER'], case['OTHER']]: # LOWER, OTHER
314
+ case_label = case['CAPITALIZE']
315
+ if i + start == len(tokens) - 2 and punc_label == punctuation['O']:
316
+ punc_label = punctuation['PERIOD']
317
+ yield (token, self.rev_case[case_label], self.rev_punc[punc_label])
318
+ previous_label = punc_label
319
+
320
+ def map_case_label(self, token, case_label):
321
+ if token.endswith('</w>'):
322
+ token = token[:-4]
323
+ if token.startswith('##'):
324
+ token = token[2:]
325
+ return recase(token, case[case_label])
326
+
327
+ def map_punc_label(self, token, punc_label):
328
+ if token.endswith('</w>'):
329
+ token = token[:-4]
330
+ if token.startswith('##'):
331
+ token = token[2:]
332
+ return token + punctuation_syms[punctuation[punc_label]]
333
+
334
+
335
+
336
+ def generate_predictions(config, checkpoint_path):
337
+ loaded = torch.load(checkpoint_path, map_location=config.device if torch.cuda.is_available() else 'cpu')
338
+ if 'config' in loaded:
339
+ config = Config(**loaded['config'])
340
+ init(config)
341
+
342
+ model = Model(config.flavor, config.device)
343
+ model.load_state_dict(loaded['model_state_dict'])
344
+
345
+ rev_case = {b: a for a, b in case.items()}
346
+ rev_punc = {b: a for a, b in punctuation.items()}
347
+
348
+ for line in sys.stdin:
349
+ # also drop punctuation that we may generate
350
+ line = ''.join([c for c in line if c not in mapped_punctuation])
351
+ if config.debug:
352
+ print(line)
353
+ tokens = [config.cls_token] + config.tokenizer.tokenize(line) + [config.sep_token]
354
+ if config.debug:
355
+ print(tokens)
356
+ previous_label = punctuation['PERIOD']
357
+ first_time = True
358
+ was_word = False
359
+ for start in range(0, len(tokens), config.max_length):
360
+ instance = tokens[start: start + config.max_length]
361
+ ids = config.tokenizer.convert_tokens_to_ids(instance)
362
+ #print(len(ids), file=sys.stderr)
363
+ if len(ids) < config.max_length:
364
+ ids += [config.pad_token_id] * (config.max_length - len(ids))
365
+ x = torch.tensor([ids]).long().to(config.device)
366
+ y_scores1, y_scores2 = model(x)
367
+ y_pred1 = torch.max(y_scores1, 2)[1]
368
+ y_pred2 = torch.max(y_scores2, 2)[1]
369
+ for id, token, punc_label, case_label in zip(ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
370
+ if config.debug:
371
+ print(id, token, punc_label, case_label, file=sys.stderr)
372
+ if id == config.cls_token_id or id == config.sep_token_id:
373
+ continue
374
+ if previous_label != None and previous_label > 1:
375
+ if case_label in [case['LOWER'], case['OTHER']]:
376
+ case_label = case['CAPITALIZE']
377
+ previous_label = punc_label
378
+ # different strategy due to sub-lexical token encoding in Flaubert
379
+ if config.lang == 'fr':
380
+ if token.endswith('</w>'):
381
+ cased_token = recase(token[:-4], case_label)
382
+ if was_word:
383
+ print(' ', end='')
384
+ print(cased_token + punctuation_syms[punc_label], end='')
385
+ was_word = True
386
+ else:
387
+ cased_token = recase(token, case_label)
388
+ if was_word:
389
+ print(' ', end='')
390
+ print(cased_token, end='')
391
+ was_word = False
392
+ else:
393
+ if token.startswith('##'):
394
+ cased_token = recase(token[2:], case_label)
395
+ print(cased_token, end='')
396
+ else:
397
+ cased_token = recase(token, case_label)
398
+ if not first_time:
399
+ print(' ', end='')
400
+ first_time = False
401
+ print(cased_token + punctuation_syms[punc_label], end='')
402
+ if previous_label == 0:
403
+ print('.', end='')
404
+ print()
405
+
406
+
407
+ def label_for_case(token):
408
+ token = re.sub('[^\p{Han}\p{Ll}\p{Lu}]', '', token)
409
+ if token == token.lower():
410
+ return 'LOWER'
411
+ elif token == token.lower().capitalize():
412
+ return 'CAPITALIZE'
413
+ elif token == token.upper():
414
+ return 'UPPER'
415
+ else:
416
+ return 'OTHER'
417
+
418
+
419
+ def make_tensors(config, input_fn, output_x_fn, output_y_fn):
420
+ # count file lines without loading them
421
+ size = 0
422
+ with open(input_fn) as fp:
423
+ for line in fp:
424
+ size += 1
425
+
426
+ with open(input_fn) as fp:
427
+ X = torch.IntTensor(size)
428
+ Y = torch.ByteTensor(size, 2)
429
+
430
+ offset = 0
431
+ for n, line in enumerate(fp):
432
+ word, case_label, punc_label = line.strip().split('\t')
433
+ id = config.tokenizer.convert_tokens_to_ids(word)
434
+ if config.debug:
435
+ assert word.lower() == tokenizer.convert_ids_to_tokens(id)
436
+ X[offset] = id
437
+ Y[offset, 0] = punctuation[punc_label]
438
+ Y[offset, 1] = case[case_label]
439
+ offset += 1
440
+
441
+ torch.save(X, output_x_fn)
442
+ torch.save(Y, output_y_fn)
443
+
444
+
445
+ mapped_punctuation = {
446
+ '.': 'PERIOD',
447
+ '...': 'PERIOD',
448
+ ',': 'COMMA',
449
+ ';': 'COMMA',
450
+ ':': 'COMMA',
451
+ '(': 'COMMA',
452
+ ')': 'COMMA',
453
+ '?': 'QUESTION',
454
+ '!': 'EXCLAMATION',
455
+ ',': 'COMMA',
456
+ '!': 'EXCLAMATION',
457
+ '?': 'QUESTION',
458
+ ';': 'COMMA',
459
+ ':': 'COMMA',
460
+ '(': 'COMMA',
461
+ '(': 'COMMA',
462
+ ')': 'COMMA',
463
+ '[': 'COMMA',
464
+ ']': 'COMMA',
465
+ '【': 'COMMA',
466
+ '】': 'COMMA',
467
+ '└': 'COMMA',
468
+ '└ ': 'COMMA',
469
+ '_': 'O',
470
+ '。': 'PERIOD',
471
+ '、': 'COMMA', # enumeration comma
472
+ '、': 'COMMA',
473
+ '…': 'PERIOD',
474
+ '—': 'COMMA',
475
+ '「': 'COMMA',
476
+ '」': 'COMMA',
477
+ '.': 'PERIOD',
478
+ '《': 'O',
479
+ '》': 'O',
480
+ ',': 'COMMA',
481
+ '“': 'O',
482
+ '”': 'O',
483
+ '"': 'O',
484
+ '-': 'O',
485
+ '-': 'O',
486
+ '〉': 'COMMA',
487
+ '〈': 'COMMA',
488
+ '↑': 'O',
489
+ '〔': 'COMMA',
490
+ '〕': 'COMMA',
491
+ }
492
+
493
+ def preprocess_text(config, max_token_count=-1):
494
+ global num_tokens_output
495
+ max_token_count = int(max_token_count)
496
+ num_tokens_output = 0
497
+ def process_segment(text, punctuation):
498
+ global num_tokens_output
499
+ text = text.replace('\t', ' ')
500
+ tokens = config.tokenizer.tokenize(text)
501
+ for i, token in enumerate(tokens):
502
+ case_label = label_for_case(token)
503
+ if i == len(tokens) - 1:
504
+ print(token.lower(), case_label, punctuation, sep='\t')
505
+ else:
506
+ print(token.lower(), case_label, 'O', sep='\t')
507
+ num_tokens_output += 1
508
+ # a bit too ugly, but alternative is to throw an exception
509
+ if max_token_count > 0 and num_tokens_output >= max_token_count:
510
+ sys.exit(0)
511
+
512
+ for line in sys.stdin:
513
+ line = line.strip()
514
+ if line != '':
515
+ line = unicodedata.normalize("NFC", line)
516
+ if config.debug:
517
+ print(line)
518
+ start = 0
519
+ for i, char in enumerate(line):
520
+ if char in mapped_punctuation:
521
+ if i > start and line[start: i].strip() != '':
522
+ process_segment(line[start: i], mapped_punctuation[char])
523
+ start = i + 1
524
+ if start < len(line):
525
+ process_segment(line[start:], 'PERIOD')
526
+
527
+
528
+ def preprocess_text_old_fr(config):
529
+ assert config.lang == 'fr'
530
+ splitsents = MosesSentenceSplitter(lang)
531
+ tokenize = MosesTokenizer(lang, extra=['-no-escape'])
532
+ normalize = MosesPunctuationNormalizer(lang)
533
+
534
+ for line in sys.stdin:
535
+ if line.strip() != '':
536
+ for sentence in splitsents([normalize(line)]):
537
+ tokens = tokenize(sentence)
538
+ previous_token = None
539
+ for token in tokens:
540
+ if token in mapped_punctuation:
541
+ if previous_token != None:
542
+ print(previous_token, mapped_punctuation[token], sep='\t')
543
+ previous_token = None
544
+ elif not re.search('[\p{Han}\p{Ll}\p{Lu}\d]', token): # remove non-alphanumeric tokens
545
+ continue
546
+ else:
547
+ if previous_token != None:
548
+ print(previous_token, 'O', sep='\t')
549
+ previous_token = token
550
+ if previous_token != None:
551
+ print(previous_token, 'PERIOD', sep='\t')
552
+
553
+
554
+ # modification of the wordpiece tokenizer to keep case information even if vocab is lower cased
555
+ # forked from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py
556
+
557
+ class WordpieceTokenizer(object):
558
+ """Runs WordPiece tokenization."""
559
+
560
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100, keep_case=True):
561
+ self.vocab = vocab
562
+ self.unk_token = unk_token
563
+ self.max_input_chars_per_word = max_input_chars_per_word
564
+ self.keep_case = keep_case
565
+
566
+ def tokenize(self, text):
567
+ """
568
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
569
+ tokenization using the given vocabulary.
570
+ For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.
571
+ Args:
572
+ text: A single token or whitespace separated tokens. This should have
573
+ already been passed through `BasicTokenizer`.
574
+ Returns:
575
+ A list of wordpiece tokens.
576
+ """
577
+
578
+ output_tokens = []
579
+ for token in text.strip().split():
580
+ chars = list(token)
581
+ if len(chars) > self.max_input_chars_per_word:
582
+ output_tokens.append(self.unk_token)
583
+ continue
584
+
585
+ is_bad = False
586
+ start = 0
587
+ sub_tokens = []
588
+ while start < len(chars):
589
+ end = len(chars)
590
+ cur_substr = None
591
+ while start < end:
592
+ substr = "".join(chars[start:end])
593
+ if start > 0:
594
+ substr = "##" + substr
595
+ # optionally lowercase substring before checking for inclusion in vocab
596
+ if (self.keep_case and substr.lower() in self.vocab) or (substr in self.vocab):
597
+ cur_substr = substr
598
+ break
599
+ end -= 1
600
+ if cur_substr is None:
601
+ is_bad = True
602
+ break
603
+ sub_tokens.append(cur_substr)
604
+ start = end
605
+
606
+ if is_bad:
607
+ output_tokens.append(self.unk_token)
608
+ else:
609
+ output_tokens.extend(sub_tokens)
610
+ return output_tokens
611
+
612
+
613
+ # modification of XLM bpe tokenizer for keeping case information when vocab is lowercase
614
+ # forked from https://github.com/huggingface/transformers/blob/cd56f3fe7eae4a53a9880e3f5e8f91877a78271c/src/transformers/models/xlm/tokenization_xlm.py
615
+ def bpe(self, token):
616
+ def to_lower(pair):
617
+ #print(' ',pair)
618
+ return (pair[0].lower(), pair[1].lower())
619
+
620
+ from transformers.models.xlm.tokenization_xlm import get_pairs
621
+
622
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
623
+ if token in self.cache:
624
+ return self.cache[token]
625
+ pairs = get_pairs(word)
626
+
627
+ if not pairs:
628
+ return token + "</w>"
629
+
630
+ while True:
631
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(to_lower(pair), float("inf")))
632
+ #print(bigram)
633
+ if to_lower(bigram) not in self.bpe_ranks:
634
+ break
635
+ first, second = bigram
636
+ new_word = []
637
+ i = 0
638
+ while i < len(word):
639
+ try:
640
+ j = word.index(first, i)
641
+ except ValueError:
642
+ new_word.extend(word[i:])
643
+ break
644
+ else:
645
+ new_word.extend(word[i:j])
646
+ i = j
647
+
648
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
649
+ new_word.append(first + second)
650
+ i += 2
651
+ else:
652
+ new_word.append(word[i])
653
+ i += 1
654
+ new_word = tuple(new_word)
655
+ word = new_word
656
+ if len(word) == 1:
657
+ break
658
+ else:
659
+ pairs = get_pairs(word)
660
+ word = " ".join(word)
661
+ if word == "\n </w>":
662
+ word = "\n</w>"
663
+ self.cache[token] = word
664
+ return word
665
+
666
+
667
+
668
+ def init(config):
669
+ init_random(config.seed)
670
+
671
+ if config.lang == 'fr':
672
+ config.tokenizer = tokenizer = AutoTokenizer.from_pretrained(config.flavor, do_lower_case=False)
673
+
674
+ from transformers.models.xlm.tokenization_xlm import XLMTokenizer
675
+ assert isinstance(tokenizer, XLMTokenizer)
676
+
677
+ # monkey patch XLM tokenizer
678
+ import types
679
+ tokenizer.bpe = types.MethodType(bpe, tokenizer)
680
+ else:
681
+ # warning: needs to be BertTokenizer for monkey patching to work
682
+ config.tokenizer = tokenizer = BertTokenizer.from_pretrained(config.flavor, do_lower_case=False)
683
+
684
+ # warning: monkey patch tokenizer to keep case information
685
+ #from recasing_tokenizer import WordpieceTokenizer
686
+ config.tokenizer.wordpiece_tokenizer = WordpieceTokenizer(vocab=tokenizer.vocab, unk_token=tokenizer.unk_token)
687
+
688
+ if config.lang == 'fr':
689
+ config.pad_token_id = tokenizer.pad_token_id
690
+ config.cls_token_id = tokenizer.bos_token_id
691
+ config.cls_token = tokenizer.bos_token
692
+ config.sep_token_id = tokenizer.sep_token_id
693
+ config.sep_token = tokenizer.sep_token
694
+ else:
695
+ config.pad_token_id = tokenizer.pad_token_id
696
+ config.cls_token_id = tokenizer.cls_token_id
697
+ config.cls_token = tokenizer.cls_token
698
+ config.sep_token_id = tokenizer.sep_token_id
699
+ config.sep_token = tokenizer.sep_token
700
+
701
+ if not torch.cuda.is_available() and config.device == 'cuda':
702
+ print('WARNING: reverting to cpu as cuda is not available', file=sys.stderr)
703
+ config.device = torch.device(config.device if torch.cuda.is_available() else 'cpu')
704
+
705
+
706
+ def main(config, action, args):
707
+ init(config)
708
+
709
+ if action == 'train':
710
+ train(config, *args)
711
+ elif action == 'eval':
712
+ run_eval(config, *args)
713
+ elif action == 'predict':
714
+ generate_predictions(config, *args)
715
+ elif action == 'tensorize':
716
+ make_tensors(config, *args)
717
+ elif action == 'preprocess':
718
+ preprocess_text(config, *args)
719
+ else:
720
+ print('invalid action "%s"' % action)
721
+ sys.exit(1)
722
+
723
+ if __name__ == '__main__':
724
+ parser = argparse.ArgumentParser()
725
+ parser.add_argument("action", help="train|eval|predict|tensorize|preprocess", type=str)
726
+ parser.add_argument("action_args", help="arguments for selected action", type=str, nargs='*')
727
+ parser.add_argument("--seed", help="random seed", default=default_config.seed, type=int)
728
+ parser.add_argument("--lang", help="language (fr, en, zh, tr, pt, de, ru)", default=default_config.lang, type=str)
729
+ parser.add_argument("--flavor", help="bert flavor in transformers model zoo", default=default_config.flavor, type=str)
730
+ parser.add_argument("--max-length", help="maximum input length", default=default_config.max_length, type=int)
731
+ parser.add_argument("--batch-size", help="size of batches", default=default_config.batch_size, type=int)
732
+ parser.add_argument("--device", help="computation device (cuda, cpu)", default=default_config.device, type=str)
733
+ parser.add_argument("--debug", help="whether to output more debug info", default=default_config.debug, type=bool)
734
+ parser.add_argument("--updates", help="number of training updates to perform", default=default_config.updates, type=int)
735
+ parser.add_argument("--period", help="validation period in updates", default=default_config.period, type=int)
736
+ parser.add_argument("--lr", help="learning rate", default=default_config.lr, type=float)
737
+ parser.add_argument("--dab-rate", help="drop at boundaries rate", default=default_config.dab_rate, type=float)
738
+ config = Config(**parser.parse_args().__dict__)
739
+
740
+ main(config, config.action, config.action_args)
741
+
742
+
punctuation/vosk-recasepunc-en-0.22/vosk-adapted.txt ADDED
@@ -0,0 +1,17 @@
1
+ the
2
+ the
3
+ the beijing and shanghai welcome to the market strata open i'm yvonne good morning and i'm david ingles counting down of course the diablo trade on the chinese
4
+ mainland here in hong kong let's get your top stories today taper and a timetable dominating the latest fed minutes as official debates the exit path meanwhile i got beijing heading the other way hinting at the first triple r cut in more than a year and after the didi debacle here china may move to close a loophole long used
5
+ by companies to take their listings abroad all to enhance that was a horrible mistake council yesterday from china as a maybe it's time to cut the triple r to help them with small businesses they are struggling from the rise of raw material costs the key question is how likely is this yeah what they say it chances are likely it's probably going to be up yet
6
+ the fact that they're saying it might actually already mean we're getting some sentiment coming through in terms of an improved material tracker ten year yield we'll get to that in just a moment in china we're now flirting with the three percent level equity markets futures are pointing up as you can see here in china though broadly speaking though we're down for a seven day across asia seventh day in the last excuse me
7
+ in the last eight sessions here have little commodity markets we're stabilising across your oil or oil prices we're still down five six per cent from highs though as far as that is concerned fx markets your story is guys can we change the police are we're looking at generally speaking the dollar that's very much in focus here so you look at that against the euro you look at that
8
+ against the chinese currency twenty four hours ago who would have thought we were talking about this sort of more divergence and starker labour discord between where you are in a pboc to easily in the fed and very quickly we alluded to this of course if one three percent on your chinese ten year yield and we're not one point three percent lower and lower
9
+ yields there is a charge for you china's top us ten year yield is at the bottom yeah the chinatown area lowest since we saw last year of september yup
10
+ yeah it is a really big major shift in china's central bank policy that's the key question could it be coming of course let's flash out that into what we heard from the cabinet there raising the possibility of a cut to the reserve requirement ratio to both the economy at the same time we also from a former pboc official sheng songcheng said the central bank should actually
11
+ cut rates he's not just talking about a triple r and either the second half is an important window when china's monetary policy can tilt towards loosening while remaining stable and the interest rates can be lowered in a reasonable and moderate manner let's get the take from also be as well whether daisy i'm david chiu here the short of it is
12
+ so i guess one point if we still haven't gotten that if in the event that we do their take is they it might be a little bit too aggressive to address some of the softness in the economy in other words what they're saying is it needs some help the economy maybe not this much yeah there preferring perhaps perhaps liquidity injections here and there but this might signal a bit too much
13
+ for when it comes to reflating the economy joining us out of the dice all this let's bring in wang tao ubi as head of asia economics and the chief china economists as well wang tao thanks much for joining us first off do you think this is actually a real possibility now
14
+ or well will shrink or fade contro as a frequently called using triple r cut as a tool so i think yes indeed it is a real possibility that they could do this however in the past whenever the state council called for this a few days to a couple of weeks later we were
15
+ would have we would see a triple r cut if they called for it and but it's worth noting that last year in june shoot at the chicago auto quote for it and by the pbc did not hold onto with any market so i i would say at this moment it's probably a relatively high likelihood but anything
16
+ the wording is really you know about mitigating the higher cost of commodity prices they impact on at an ease and make their effective conquered funding a bit lower so it's possible that it's going to be a targeted not a overall triple cut and i i don't think this really reflects a
17
+ wholesale shift in monetary policy i think very very much in the same state concrete statement also talked about
punctuation/vosk-recasepunc-en-0.22/vosk-adapted.txt.punc ADDED
@@ -0,0 +1 @@
1
+ The. The. The Beijing and Shanghai. Welcome to the market strata open. I'm Yvonne, good morning, and I'm David Ingles, counting down, of course, the Diablo trade on the Chinese mainland here in Hong Kong. Let's get your top stories today, taper and a timetable dominating the latest Fed minutes as official debates. The exit path. Meanwhile, I got Beijing heading the other way, hinting at the first triple R cut in more than a year. And after the Didi debacle here, China may move to close a loophole. Long used by companies to take their listings abroad, all to enhance. That was a horrible mistake. Council yesterday from China as a. Maybe it's time to cut the triple R to help them with small businesses they are struggling from the rise of raw material costs. The key question is, how likely is this ? Yeah, what they say it. Chances are likely it's probably going to be up yet. The fact that they're saying it might actually already mean we're getting some sentiment coming through in terms of an improved material tracker. Ten year yield. We'll get to that in just a moment. In China. We're now flirting with the three percent level equity markets futures are pointing up. As you can see here in China, though. Broadly speaking, though, we're down for a seven day across Asia. Seventh day in the last. Excuse me, in the last eight sessions here have little commodity markets. We're stabilising across your oil or oil prices. We're still down five, six per cent from highs, though as far as that is concerned FX markets. Your story is, guys, can we change the police are we're looking at, generally speaking, the dollar. That's very much in focus here. So you look at that against the euro. You look at that against the Chinese currency Twenty four hours ago. Who would have thought we were talking about this sort of more divergence and starker labour discord between where you are in a PBOC to easily in the Fed and very quickly. We alluded to this, Of course, if one three percent on your Chinese ten year yield and we're not one point three percent lower and lower yields, there is a charge for you. China's top US ten year yield is at the bottom. Yeah, the Chinatown area lowest since we saw last year of September. Yup. Yeah, it is a really big major shift in China's central bank policy. That's the key question. Could it be coming ? Of course. Let's flash out that into what we heard from the cabinet there, raising the possibility of a cut to the reserve requirement ratio to both the economy at the same time. We also from a former PBOC official, Sheng Songcheng said the central bank should actually cut rates. He's not just talking about a triple R. And either the second half is an important window when China's monetary policy can tilt towards loosening while remaining stable and the interest rates can be lowered in a reasonable and moderate manner. Let's get the take from also be as well, whether Daisy, I'm David Chiu here, the short of it is so I guess one point, if we still haven't gotten that if in the event that we do their take is they, it might be a little bit too aggressive to address some of the softness in the economy. In other words, what they're saying is it needs some help. The economy, maybe not this much. Yeah, there, preferring perhaps perhaps liquidity injections here and there. But this might signal a bit too much for when it comes to reflating the economy. Joining us out of the dice. All this, Let's bring in Wang Tao Ubi as head of Asia Economics, and the chief China economists as well. 
Wang Tao, thanks much for joining us. First off, do you think this is actually a real possibility now or well will shrink or fade ? Contro as a frequently called using triple R cut as a tool. So I think yes, indeed, it is a real possibility. That they could do this. However, in the past, whenever the State Council called for this a few days to a couple of weeks later, we were. Would have we would see a triple R cut if they called for it. And. But it's worth noting that last year in June, shoot at the Chicago auto quote for it and by the PBC did not hold onto with any market so I. I would say at this moment it's probably a relatively high likelihood, but anything. The wording is really, you know about mitigating the higher cost of commodity prices they impact on at an ease and make their effective conquered funding a bit lower. So it's possible that it's going to be a targeted, not a overall triple cut and I. I don't think this really reflects a wholesale shift in monetary policy. I think very, very much in the same state. Concrete statement also talked about.
punctuation/vosk-recasepunc-ru-0.22.7z ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f23cc633c06d910e056234d6b83a11a3683e582abc56c30771dbec98a91034de
3
+ size 1639885297
punctuation/vosk-recasepunc-ru-0.22/README ADDED
@@ -0,0 +1,7 @@
1
+ 1. Install pytorch and transformers:
2
+
3
+ pip3 install transformers
4
+
5
+ 2. Run python3 example.py ru-test.txt
6
+
7
+ 3. Compare with ru-test.txt.orig
punctuation/vosk-recasepunc-ru-0.22/checkpoint ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61fd424795c046963f88534071abde0813a4a6c66c07f0335b013825e536c1ae
3
+ size 2134070889
punctuation/vosk-recasepunc-ru-0.22/example.py ADDED
@@ -0,0 +1,23 @@
1
+ import sys
2
+ import time
3
+ from transformers import logging
4
+ from recasepunc import CasePuncPredictor
5
+ from recasepunc import WordpieceTokenizer
6
+ from recasepunc import Config
7
+
8
+ logging.set_verbosity_error()
9
+
10
+ predictor = CasePuncPredictor('checkpoint', lang="ru")
11
+
12
+ text = " ".join(open(sys.argv[1]).readlines())
13
+ tokens = list(enumerate(predictor.tokenize(text)))
14
+
15
+ results = ""
16
+ for token, case_label, punc_label in predictor.predict(tokens, lambda x: x[1]):
17
+ prediction = predictor.map_punc_label(predictor.map_case_label(token[1], case_label), punc_label)
18
+ if token[1][0] != '#':
19
+ results = results + ' ' + prediction
20
+ else:
21
+ results = results + prediction
22
+
23
+ print (results.strip())
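For reference, a minimal sketch (not shipped with the upload) of the same API applied to a single raw string instead of a file: `CasePuncPredictor.predict` tokenizes plain-string input itself, so one utterance can be recased and punctuated directly. The `checkpoint` path and the sample sentence (the opening words of ru-test.txt) are illustrative assumptions.

```python
from transformers import logging
from recasepunc import CasePuncPredictor, WordpieceTokenizer, Config

logging.set_verbosity_error()

predictor = CasePuncPredictor('checkpoint', lang="ru")

# hypothetical single utterance; any lowercase, unpunctuated string works
text = "все смешалось в доме облонских жена узнала что муж был в связи с бывшею"

result = ""
# predict() accepts a raw string and yields (token, case label, punctuation label) triples
for token, case_label, punc_label in predictor.predict(text):
    word = predictor.map_punc_label(predictor.map_case_label(token, case_label), punc_label)
    # '##' marks a continuation word piece, glued to the previous piece without a space
    result += word if token.startswith('##') else ' ' + word

print(result.strip())
```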
punctuation/vosk-recasepunc-ru-0.22/recasepunc.py ADDED
@@ -0,0 +1,743 @@
1
+ import sys
2
+ import collections
3
+ import os
4
+ import regex as re
5
+ #from mosestokenizer import *
6
+ from tqdm import tqdm
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.optim as optim
11
+ import random
12
+ import unicodedata
13
+ import numpy as np
14
+ import argparse
15
+ from torch.utils.data import TensorDataset, DataLoader
16
+
17
+ from transformers import AutoModel, AutoTokenizer, BertTokenizer
18
+
19
+ default_config = argparse.Namespace(
20
+ seed=871253,
21
+ lang='ru',
22
+ #flavor='flaubert/flaubert_base_uncased',
23
+ flavor=None,
24
+ max_length=256,
25
+ batch_size=16,
26
+ updates=50000,
27
+ period=1000,
28
+ lr=1e-5,
29
+ dab_rate=0.1,
30
+ device='cuda',
31
+ debug=False
32
+ )
33
+
34
+ default_flavors = {
35
+ 'fr': 'flaubert/flaubert_base_uncased',
36
+ 'en': 'bert-base-uncased',
37
+ 'zh': 'ckiplab/bert-base-chinese',
38
+ 'tr': 'dbmdz/bert-base-turkish-uncased',
39
+ 'de': 'dbmdz/bert-base-german-uncased',
40
+ 'pt': 'neuralmind/bert-base-portuguese-cased',
41
+ 'ru': 'DeepPavlov/rubert-base-cased'
42
+ }
43
+
44
+ class Config(argparse.Namespace):
45
+ def __init__(self, **kwargs):
46
+ for key, value in default_config.__dict__.items():
47
+ setattr(self, key, value)
48
+ for key, value in kwargs.items():
49
+ setattr(self, key, value)
50
+
51
+ assert self.lang in ['fr', 'en', 'zh', 'tr', 'pt', 'de', 'ru']
52
+
53
+ if 'lang' in kwargs and ('flavor' not in kwargs or kwargs['flavor'] is None):
54
+ self.flavor = default_flavors[self.lang]
55
+
56
+ #print(self.lang, self.flavor)
57
+
58
+
59
+ def init_random(seed):
60
+ # make sure everything is deterministic
61
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
62
+ #torch.use_deterministic_algorithms(True)
63
+ torch.manual_seed(seed)
64
+ torch.cuda.manual_seed_all(seed)
65
+ random.seed(seed)
66
+ np.random.seed(seed)
67
+
68
+ # NOTE: it is assumed in the implementation that y[:,0] is the punctuation label, and y[:,1] is the case label!
69
+
70
+ punctuation = {
71
+ 'O': 0,
72
+ 'COMMA': 1,
73
+ 'PERIOD': 2,
74
+ 'QUESTION': 3,
75
+ 'EXCLAMATION': 4,
76
+ }
77
+
78
+ punctuation_syms = ['', ',', '.', ' ?', ' !']
79
+
80
+ case = {
81
+ 'LOWER': 0,
82
+ 'UPPER': 1,
83
+ 'CAPITALIZE': 2,
84
+ 'OTHER': 3,
85
+ }
86
+
87
+
88
+ class Model(nn.Module):
89
+ def __init__(self, flavor, device):
90
+ super().__init__()
91
+ self.bert = AutoModel.from_pretrained(flavor)
92
+ # need a proper way of determining representation size
93
+ size = self.bert.dim if hasattr(self.bert, 'dim') else self.bert.config.pooler_fc_size if hasattr(self.bert.config, 'pooler_fc_size') else self.bert.config.emb_dim if hasattr(self.bert.config, 'emb_dim') else self.bert.config.hidden_size
94
+ self.punc = nn.Linear(size, 5)
95
+ self.case = nn.Linear(size, 4)
96
+ self.dropout = nn.Dropout(0.3)
97
+ self.to(device)
98
+
99
+ def forward(self, x):
100
+ output = self.bert(x)
101
+ representations = self.dropout(F.gelu(output['last_hidden_state']))
102
+ punc = self.punc(representations)
103
+ case = self.case(representations)
104
+ return punc, case
105
+
106
+
107
+ # randomly create sequences that align to punctuation boundaries
108
+ def drop_at_boundaries(rate, x, y, cls_token_id, sep_token_id, pad_token_id):
109
+ for i, dropped in enumerate(torch.rand((len(x),)) < rate):
110
+ if dropped:
111
+ # select all indices that are sentence endings
112
+ indices = (y[i,:,0] > 1).nonzero(as_tuple=True)[0]
113
+ if len(indices) < 2:
114
+ continue
115
+ start = indices[0] + 1
116
+ end = indices[random.randint(1, len(indices) - 1)] + 1
117
+ length = end - start
118
+ if length + 2 > len(x[i]):
119
+ continue
120
+ x[i, 0] = cls_token_id
121
+ x[i, 1: length + 1] = x[i, start: end].clone()
122
+ x[i, length + 1] = sep_token_id
123
+ x[i, length + 2:] = pad_token_id
124
+ y[i, 0] = 0
125
+ y[i, 1: length + 1] = y[i, start: end].clone()
126
+ y[i, length + 1:] = 0
127
+
128
+
129
+ def compute_performance(config, model, loader):
130
+ device = config.device
131
+ criterion = nn.CrossEntropyLoss()
132
+ model.eval()
133
+ total_loss = all_correct1 = all_correct2 = num_loss = num_perf = 0
134
+ num_ref = collections.defaultdict(float)
135
+ num_hyp = collections.defaultdict(float)
136
+ num_correct = collections.defaultdict(float)
137
+ for x, y in loader:
138
+ x = x.long().to(device)
139
+ y = y.long().to(device)
140
+ y1 = y[:,:,0]
141
+ y2 = y[:,:,1]
142
+ with torch.no_grad():
143
+ y_scores1, y_scores2 = model(x.to(device))
144
+ loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1)))
145
+ loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1)))
146
+ loss = loss1 + loss2
147
+ y_pred1 = torch.max(y_scores1, 2)[1]
148
+ y_pred2 = torch.max(y_scores2, 2)[1]
149
+ for label in range(1, 5):
150
+ ref = (y1 == label)
151
+ hyp = (y_pred1 == label)
152
+ correct = (ref * hyp == 1)
153
+ num_ref[label] += ref.sum()
154
+ num_hyp[label] += hyp.sum()
155
+ num_correct[label] += correct.sum()
156
+ num_ref[0] += ref.sum()
157
+ num_hyp[0] += hyp.sum()
158
+ num_correct[0] += correct.sum()
159
+ all_correct1 += (y_pred1 == y1).sum()
160
+ all_correct2 += (y_pred2 == y2).sum()
161
+ total_loss += loss.item()
162
+ num_loss += len(y)
163
+ num_perf += len(y) * config.max_length
164
+ recall = {}
165
+ precision = {}
166
+ fscore = {}
167
+ for label in range(0, 5):
168
+ recall[label] = num_correct[label] / num_ref[label] if num_ref[label] > 0 else 0
169
+ precision[label] = num_correct[label] / num_hyp[label] if num_hyp[label] > 0 else 0
170
+ fscore[label] = (2 * recall[label] * precision[label] / (recall[label] + precision[label])).item() if recall[label] + precision[label] > 0 else 0
171
+ return total_loss / num_loss, all_correct2.item() / num_perf, all_correct1.item() / num_perf, fscore
172
+
173
+
174
+ def fit(config, model, checkpoint_path, train_loader, valid_loader, iterations, valid_period=200, lr=1e-5):
175
+ device = config.device
176
+ criterion = nn.CrossEntropyLoss()
177
+ optimizer = optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=lr)
178
+ iteration = 0
179
+ while True:
180
+ model.train()
181
+ total_loss = num = 0
182
+ for x, y in tqdm(train_loader):
183
+ x = x.long().to(device)
184
+ y = y.long().to(device)
185
+ drop_at_boundaries(config.dab_rate, x, y, config.cls_token_id, config.sep_token_id, config.pad_token_id)
186
+ y1 = y[:,:,0]
187
+ y2 = y[:,:,1]
188
+ optimizer.zero_grad()
189
+ y_scores1, y_scores2 = model(x)
190
+ loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1)))
191
+ loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1)))
192
+ loss = loss1 + loss2
193
+ loss.backward()
194
+ optimizer.step()
195
+ total_loss += loss.item()
196
+ num += len(y)
197
+ if iteration % valid_period == valid_period - 1:
198
+ train_loss = total_loss / num
199
+ valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore = compute_performance(config, model, valid_loader)
200
+ torch.save({
201
+ 'iteration': iteration + 1,
202
+ 'model_state_dict': model.state_dict(),
203
+ 'optimizer_state_dict': optimizer.state_dict(),
204
+ 'train_loss': train_loss,
205
+ 'valid_loss': valid_loss,
206
+ 'valid_accuracy_case': valid_accuracy_case,
207
+ 'valid_accuracy_punc': valid_accuracy_punc,
208
+ 'valid_fscore': valid_fscore,
209
+ 'config': config.__dict__,
210
+ }, '%s.%d' % (checkpoint_path, iteration + 1))
211
+ print(iteration + 1, train_loss, valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore)
212
+ total_loss = num = 0
213
+
214
+ iteration += 1
215
+ if iteration > iterations:
216
+ return
217
+
218
+ sys.stderr.flush()
219
+ sys.stdout.flush()
220
+
221
+
222
+ def batchify(max_length, x, y):
223
+ print (x.shape)
224
+ print (y.shape)
225
+ x = x[:(len(x) // max_length) * max_length].reshape(-1, max_length)
226
+ y = y[:(len(y) // max_length) * max_length, :].reshape(-1, max_length, 2)
227
+ return x, y
228
+
229
+
230
+ def train(config, train_x_fn, train_y_fn, valid_x_fn, valid_y_fn, checkpoint_path):
231
+ X_train, Y_train = batchify(config.max_length, torch.load(train_x_fn), torch.load(train_y_fn))
232
+ X_valid, Y_valid = batchify(config.max_length, torch.load(valid_x_fn), torch.load(valid_y_fn))
233
+
234
+ train_set = TensorDataset(X_train, Y_train)
235
+ valid_set = TensorDataset(X_valid, Y_valid)
236
+
237
+ train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
238
+ valid_loader = DataLoader(valid_set, batch_size=config.batch_size)
239
+
240
+ model = Model(config.flavor, config.device)
241
+
242
+ fit(config, model, checkpoint_path, train_loader, valid_loader, config.updates, config.period, config.lr)
243
+
244
+
245
+ def run_eval(config, test_x_fn, test_y_fn, checkpoint_path):
246
+ X_test, Y_test = batchify(config.max_length, torch.load(test_x_fn), torch.load(test_y_fn))
247
+ test_set = TensorDataset(X_test, Y_test)
248
+ test_loader = DataLoader(test_set, batch_size=config.batch_size)
249
+
250
+ loaded = torch.load(checkpoint_path, map_location=config.device)
251
+ if 'config' in loaded:
252
+ config = Config(**loaded['config'])
253
+ init(config)
254
+
255
+ model = Model(config.flavor, config.device)
256
+ model.load_state_dict(loaded['model_state_dict'])
257
+
258
+ print(*compute_performance(config, model, test_loader))
259
+
260
+
261
+ def recase(token, label):
262
+ if label == case['LOWER']:
263
+ return token.lower()
264
+ elif label == case['CAPITALIZE']:
265
+ return token.lower().capitalize()
266
+ elif label == case['UPPER']:
267
+ return token.upper()
268
+ else:
269
+ return token
270
+
271
+
272
+ class CasePuncPredictor:
273
+ def __init__(self, checkpoint_path, lang=default_config.lang, flavor=default_config.flavor, device=default_config.device):
274
+ loaded = torch.load(checkpoint_path, map_location=device if torch.cuda.is_available() else 'cpu')
275
+ if 'config' in loaded:
276
+ self.config = Config(**loaded['config'])
277
+ else:
278
+ self.config = Config(lang=lang, flavor=flavor, device=device)
279
+ init(self.config)
280
+
281
+ self.model = Model(self.config.flavor, self.config.device)
282
+ self.model.load_state_dict(loaded['model_state_dict'])
283
+ self.model.eval()
284
+ self.model.to(self.config.device)
285
+
286
+ self.rev_case = {b: a for a, b in case.items()}
287
+ self.rev_punc = {b: a for a, b in punctuation.items()}
288
+
289
+ def tokenize(self, text):
290
+ return [self.config.cls_token] + self.config.tokenizer.tokenize(text) + [self.config.sep_token]
291
+
292
+ def predict(self, tokens, getter=lambda x: x):
293
+ max_length = self.config.max_length
294
+ device = self.config.device
295
+ if type(tokens) == str:
296
+ tokens = self.tokenize(tokens)
297
+ previous_label = punctuation['PERIOD']
298
+ for start in range(0, len(tokens), max_length):
299
+ instance = tokens[start: start + max_length]
300
+ if type(getter(instance[0])) == str:
301
+ ids = self.config.tokenizer.convert_tokens_to_ids(getter(token) for token in instance)
302
+ else:
303
+ ids = [getter(token) for token in instance]
304
+ if len(ids) < max_length:
305
+ ids += [0] * (max_length - len(ids))
306
+ x = torch.tensor([ids]).long().to(device)
307
+ y_scores1, y_scores2 = self.model(x)
308
+ y_pred1 = torch.max(y_scores1, 2)[1]
309
+ y_pred2 = torch.max(y_scores2, 2)[1]
310
+ for i, id, token, punc_label, case_label in zip(range(len(instance)), ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
311
+ if id == self.config.cls_token_id or id == self.config.sep_token_id:
312
+ continue
313
+ if previous_label != None and previous_label > 1:
314
+ if case_label in [case['LOWER'], case['OTHER']]: # LOWER, OTHER
315
+ case_label = case['CAPITALIZE']
316
+ if i + start == len(tokens) - 2 and punc_label == punctuation['O']:
317
+ punc_label = punctuation['PERIOD']
318
+ yield (token, self.rev_case[case_label], self.rev_punc[punc_label])
319
+ previous_label = punc_label
320
+
321
+ def map_case_label(self, token, case_label):
322
+ if token.endswith('</w>'):
323
+ token = token[:-4]
324
+ if token.startswith('##'):
325
+ token = token[2:]
326
+ return recase(token, case[case_label])
327
+
328
+ def map_punc_label(self, token, punc_label):
329
+ if token.endswith('</w>'):
330
+ token = token[:-4]
331
+ if token.startswith('##'):
332
+ token = token[2:]
333
+ return token + punctuation_syms[punctuation[punc_label]]
334
+
335
+
336
+
337
+ def generate_predictions(config, checkpoint_path):
338
+ loaded = torch.load(checkpoint_path, map_location=config.device if torch.cuda.is_available() else 'cpu')
339
+ if 'config' in loaded:
340
+ config = Config(**loaded['config'])
341
+ init(config)
342
+
343
+ model = Model(config.flavor, config.device)
344
+ model.load_state_dict(loaded['model_state_dict'])
345
+
346
+ rev_case = {b: a for a, b in case.items()}
347
+ rev_punc = {b: a for a, b in punctuation.items()}
348
+
349
+ for line in sys.stdin:
350
+ # also drop punctuation that we may generate
351
+ line = ''.join([c for c in line if c not in mapped_punctuation])
352
+ if config.debug:
353
+ print(line)
354
+ tokens = [config.cls_token] + config.tokenizer.tokenize(line) + [config.sep_token]
355
+ if config.debug:
356
+ print(tokens)
357
+ previous_label = punctuation['PERIOD']
358
+ first_time = True
359
+ was_word = False
360
+ for start in range(0, len(tokens), config.max_length):
361
+ instance = tokens[start: start + config.max_length]
362
+ ids = config.tokenizer.convert_tokens_to_ids(instance)
363
+ #print(len(ids), file=sys.stderr)
364
+ if len(ids) < config.max_length:
365
+ ids += [config.pad_token_id] * (config.max_length - len(ids))
366
+ x = torch.tensor([ids]).long().to(config.device)
367
+ y_scores1, y_scores2 = model(x)
368
+ y_pred1 = torch.max(y_scores1, 2)[1]
369
+ y_pred2 = torch.max(y_scores2, 2)[1]
370
+ for id, token, punc_label, case_label in zip(ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
371
+ if config.debug:
372
+ print(id, token, punc_label, case_label, file=sys.stderr)
373
+ if id == config.cls_token_id or id == config.sep_token_id:
374
+ continue
375
+ if previous_label != None and previous_label > 1:
376
+ if case_label in [case['LOWER'], case['OTHER']]:
377
+ case_label = case['CAPITALIZE']
378
+ previous_label = punc_label
379
+ # different strategy due to sub-lexical token encoding in Flaubert
380
+ if config.lang == 'fr':
381
+ if token.endswith('</w>'):
382
+ cased_token = recase(token[:-4], case_label)
383
+ if was_word:
384
+ print(' ', end='')
385
+ print(cased_token + punctuation_syms[punc_label], end='')
386
+ was_word = True
387
+ else:
388
+ cased_token = recase(token, case_label)
389
+ if was_word:
390
+ print(' ', end='')
391
+ print(cased_token, end='')
392
+ was_word = False
393
+ else:
394
+ if token.startswith('##'):
395
+ cased_token = recase(token[2:], case_label)
396
+ print(cased_token, end='')
397
+ else:
398
+ cased_token = recase(token, case_label)
399
+ if not first_time:
400
+ print(' ', end='')
401
+ first_time = False
402
+ print(cased_token + punctuation_syms[punc_label], end='')
403
+ if previous_label == 0:
404
+ print('.', end='')
405
+ print()
406
+
407
+
408
+ def label_for_case(token):
409
+ token = re.sub('[^\p{Han}\p{Ll}\p{Lu}]', '', token)
410
+ if token == token.lower():
411
+ return 'LOWER'
412
+ elif token == token.lower().capitalize():
413
+ return 'CAPITALIZE'
414
+ elif token == token.upper():
415
+ return 'UPPER'
416
+ else:
417
+ return 'OTHER'
418
+
419
+
420
+ def make_tensors(config, input_fn, output_x_fn, output_y_fn):
421
+ # count file lines without loading them
422
+ size = 0
423
+ with open(input_fn) as fp:
424
+ for line in fp:
425
+ size += 1
426
+
427
+ with open(input_fn) as fp:
428
+ X = torch.IntTensor(size)
429
+ Y = torch.ByteTensor(size, 2)
430
+
431
+ offset = 0
432
+ for n, line in enumerate(fp):
433
+ word, case_label, punc_label = line.strip().split('\t')
434
+ id = config.tokenizer.convert_tokens_to_ids(word)
435
+ if config.debug:
436
+ assert word.lower() == config.tokenizer.convert_ids_to_tokens(id)
437
+ X[offset] = id
438
+ Y[offset, 0] = punctuation[punc_label]
439
+ Y[offset, 1] = case[case_label]
440
+ offset += 1
441
+
442
+ torch.save(X, output_x_fn)
443
+ torch.save(Y, output_y_fn)
444
+
445
+
446
+ mapped_punctuation = {
447
+ '.': 'PERIOD',
448
+ '...': 'PERIOD',
449
+ ',': 'COMMA',
450
+ ';': 'COMMA',
451
+ ':': 'COMMA',
452
+ '(': 'COMMA',
453
+ ')': 'COMMA',
454
+ '?': 'QUESTION',
455
+ '!': 'EXCLAMATION',
456
+ ',': 'COMMA',
457
+ '!': 'EXCLAMATION',
458
+ '?': 'QUESTION',
459
+ ';': 'COMMA',
460
+ ':': 'COMMA',
461
+ '(': 'COMMA',
462
+ '(': 'COMMA',
463
+ ')': 'COMMA',
464
+ '[': 'COMMA',
465
+ ']': 'COMMA',
466
+ '【': 'COMMA',
467
+ '】': 'COMMA',
468
+ '└': 'COMMA',
469
+ '└ ': 'COMMA',
470
+ '_': 'O',
471
+ '。': 'PERIOD',
472
+ '、': 'COMMA', # enumeration comma
473
+ '、': 'COMMA',
474
+ '…': 'PERIOD',
475
+ '—': 'COMMA',
476
+ '「': 'COMMA',
477
+ '」': 'COMMA',
478
+ '.': 'PERIOD',
479
+ '《': 'O',
480
+ '》': 'O',
481
+ ',': 'COMMA',
482
+ '“': 'O',
483
+ '”': 'O',
484
+ '"': 'O',
485
+ '-': 'O',
486
+ '-': 'O',
487
+ '〉': 'COMMA',
488
+ '〈': 'COMMA',
489
+ '↑': 'O',
490
+ '〔': 'COMMA',
491
+ '〕': 'COMMA',
492
+ }
493
+
494
+ def preprocess_text(config, max_token_count=-1):
495
+ global num_tokens_output
496
+ max_token_count = int(max_token_count)
497
+ num_tokens_output = 0
498
+ def process_segment(text, punctuation):
499
+ global num_tokens_output
500
+ text = text.replace('\t', ' ')
501
+ tokens = config.tokenizer.tokenize(text)
502
+ for i, token in enumerate(tokens):
503
+ case_label = label_for_case(token)
504
+ if i == len(tokens) - 1:
505
+ print(token.lower(), case_label, punctuation, sep='\t')
506
+ else:
507
+ print(token.lower(), case_label, 'O', sep='\t')
508
+ num_tokens_output += 1
509
+ # a bit too ugly, but the alternative is to throw an exception
510
+ if max_token_count > 0 and num_tokens_output >= max_token_count:
511
+ sys.exit(0)
512
+
513
+ for line in sys.stdin:
514
+ line = line.strip()
515
+ if line != '':
516
+ line = unicodedata.normalize("NFC", line)
517
+ if config.debug:
518
+ print(line)
519
+ start = 0
520
+ for i, char in enumerate(line):
521
+ if char in mapped_punctuation:
522
+ if i > start and line[start: i].strip() != '':
523
+ process_segment(line[start: i], mapped_punctuation[char])
524
+ start = i + 1
525
+ if start < len(line):
526
+ process_segment(line[start:], 'PERIOD')
527
+
528
+
529
+ def preprocess_text_old_fr(config):
530
+ assert config.lang == 'fr'
531
+ splitsents = MosesSentenceSplitter(lang)
532
+ tokenize = MosesTokenizer(lang, extra=['-no-escape'])
533
+ normalize = MosesPunctuationNormalizer(lang)
534
+
535
+ for line in sys.stdin:
536
+ if line.strip() != '':
537
+ for sentence in splitsents([normalize(line)]):
538
+ tokens = tokenize(sentence)
539
+ previous_token = None
540
+ for token in tokens:
541
+ if token in mapped_punctuation:
542
+ if previous_token != None:
543
+ print(previous_token, mapped_punctuation[token], sep='\t')
544
+ previous_token = None
545
+ elif not re.search('[\p{Han}\p{Ll}\p{Lu}\d]', token): # remove non-alphanumeric tokens
546
+ continue
547
+ else:
548
+ if previous_token != None:
549
+ print(previous_token, 'O', sep='\t')
550
+ previous_token = token
551
+ if previous_token != None:
552
+ print(previous_token, 'PERIOD', sep='\t')
553
+
554
+
555
+ # modification of the wordpiece tokenizer to keep case information even if vocab is lower cased
556
+ # forked from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py
557
+
558
+ class WordpieceTokenizer(object):
559
+ """Runs WordPiece tokenization."""
560
+
561
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100, keep_case=True):
562
+ self.vocab = vocab
563
+ self.unk_token = unk_token
564
+ self.max_input_chars_per_word = max_input_chars_per_word
565
+ self.keep_case = keep_case
566
+
567
+ def tokenize(self, text):
568
+ """
569
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
570
+ tokenization using the given vocabulary.
571
+ For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.
572
+ Args:
573
+ text: A single token or whitespace separated tokens. This should have
574
+ already been passed through `BasicTokenizer`.
575
+ Returns:
576
+ A list of wordpiece tokens.
577
+ """
578
+
579
+ output_tokens = []
580
+ for token in text.strip().split():
581
+ chars = list(token)
582
+ if len(chars) > self.max_input_chars_per_word:
583
+ output_tokens.append(self.unk_token)
584
+ continue
585
+
586
+ is_bad = False
587
+ start = 0
588
+ sub_tokens = []
589
+ while start < len(chars):
590
+ end = len(chars)
591
+ cur_substr = None
592
+ while start < end:
593
+ substr = "".join(chars[start:end])
594
+ if start > 0:
595
+ substr = "##" + substr
596
+ # optionally lowercase substring before checking for inclusion in vocab
597
+ if (self.keep_case and substr.lower() in self.vocab) or (substr in self.vocab):
598
+ cur_substr = substr
599
+ break
600
+ end -= 1
601
+ if cur_substr is None:
602
+ is_bad = True
603
+ break
604
+ sub_tokens.append(cur_substr)
605
+ start = end
606
+
607
+ if is_bad:
608
+ output_tokens.append(self.unk_token)
609
+ else:
610
+ output_tokens.extend(sub_tokens)
611
+ return output_tokens
612
+
613
+
614
+ # modification of XLM bpe tokenizer for keeping case information when vocab is lowercase
615
+ # forked from https://github.com/huggingface/transformers/blob/cd56f3fe7eae4a53a9880e3f5e8f91877a78271c/src/transformers/models/xlm/tokenization_xlm.py
616
+ def bpe(self, token):
617
+ def to_lower(pair):
618
+ #print(' ',pair)
619
+ return (pair[0].lower(), pair[1].lower())
620
+
621
+ from transformers.models.xlm.tokenization_xlm import get_pairs
622
+
623
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
624
+ if token in self.cache:
625
+ return self.cache[token]
626
+ pairs = get_pairs(word)
627
+
628
+ if not pairs:
629
+ return token + "</w>"
630
+
631
+ while True:
632
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(to_lower(pair), float("inf")))
633
+ #print(bigram)
634
+ if to_lower(bigram) not in self.bpe_ranks:
635
+ break
636
+ first, second = bigram
637
+ new_word = []
638
+ i = 0
639
+ while i < len(word):
640
+ try:
641
+ j = word.index(first, i)
642
+ except ValueError:
643
+ new_word.extend(word[i:])
644
+ break
645
+ else:
646
+ new_word.extend(word[i:j])
647
+ i = j
648
+
649
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
650
+ new_word.append(first + second)
651
+ i += 2
652
+ else:
653
+ new_word.append(word[i])
654
+ i += 1
655
+ new_word = tuple(new_word)
656
+ word = new_word
657
+ if len(word) == 1:
658
+ break
659
+ else:
660
+ pairs = get_pairs(word)
661
+ word = " ".join(word)
662
+ if word == "\n </w>":
663
+ word = "\n</w>"
664
+ self.cache[token] = word
665
+ return word
666
+
667
+
668
+
669
+ def init(config):
670
+ init_random(config.seed)
671
+
672
+ if config.lang == 'fr':
673
+ config.tokenizer = tokenizer = AutoTokenizer.from_pretrained(config.flavor, do_lower_case=False)
674
+
675
+ from transformers.models.xlm.tokenization_xlm import XLMTokenizer
676
+ assert isinstance(tokenizer, XLMTokenizer)
677
+
678
+ # monkey patch XLM tokenizer
679
+ import types
680
+ tokenizer.bpe = types.MethodType(bpe, tokenizer)
681
+ else:
682
+ # warning: needs to be BertTokenizer for monkey patching to work
683
+ config.tokenizer = tokenizer = BertTokenizer.from_pretrained(config.flavor, do_lower_case=False)
684
+
685
+ # warning: monkey patch tokenizer to keep case information
686
+ #from recasing_tokenizer import WordpieceTokenizer
687
+ config.tokenizer.wordpiece_tokenizer = WordpieceTokenizer(vocab=tokenizer.vocab, unk_token=tokenizer.unk_token)
688
+
689
+ if config.lang == 'fr':
690
+ config.pad_token_id = tokenizer.pad_token_id
691
+ config.cls_token_id = tokenizer.bos_token_id
692
+ config.cls_token = tokenizer.bos_token
693
+ config.sep_token_id = tokenizer.sep_token_id
694
+ config.sep_token = tokenizer.sep_token
695
+ else:
696
+ config.pad_token_id = tokenizer.pad_token_id
697
+ config.cls_token_id = tokenizer.cls_token_id
698
+ config.cls_token = tokenizer.cls_token
699
+ config.sep_token_id = tokenizer.sep_token_id
700
+ config.sep_token = tokenizer.sep_token
701
+
702
+ if not torch.cuda.is_available() and config.device == 'cuda':
703
+ print('WARNING: reverting to cpu as cuda is not available', file=sys.stderr)
704
+ config.device = torch.device(config.device if torch.cuda.is_available() else 'cpu')
705
+
706
+
707
+ def main(config, action, args):
708
+ init(config)
709
+
710
+ if action == 'train':
711
+ train(config, *args)
712
+ elif action == 'eval':
713
+ run_eval(config, *args)
714
+ elif action == 'predict':
715
+ generate_predictions(config, *args)
716
+ elif action == 'tensorize':
717
+ make_tensors(config, *args)
718
+ elif action == 'preprocess':
719
+ preprocess_text(config, *args)
720
+ else:
721
+ print('invalid action "%s"' % action)
722
+ sys.exit(1)
723
+
724
+ if __name__ == '__main__':
725
+ parser = argparse.ArgumentParser()
726
+ parser.add_argument("action", help="train|eval|predict|tensorize|preprocess", type=str)
727
+ parser.add_argument("action_args", help="arguments for selected action", type=str, nargs='*')
728
+ parser.add_argument("--seed", help="random seed", default=default_config.seed, type=int)
729
+ parser.add_argument("--lang", help="language (fr, en, zh, tr, pt, de, ru)", default=default_config.lang, type=str)
730
+ parser.add_argument("--flavor", help="bert flavor in transformers model zoo", default=default_config.flavor, type=str)
731
+ parser.add_argument("--max-length", help="maximum input length", default=default_config.max_length, type=int)
732
+ parser.add_argument("--batch-size", help="size of batches", default=default_config.batch_size, type=int)
733
+ parser.add_argument("--device", help="computation device (cuda, cpu)", default=default_config.device, type=str)
734
+ parser.add_argument("--debug", help="whether to output more debug info", default=default_config.debug, type=bool)
735
+ parser.add_argument("--updates", help="number of training updates to perform", default=default_config.updates, type=int)
736
+ parser.add_argument("--period", help="validation period in updates", default=default_config.period, type=int)
737
+ parser.add_argument("--lr", help="learning rate", default=default_config.lr, type=float)
738
+ parser.add_argument("--dab-rate", help="drop at boundaries rate", default=default_config.dab_rate, type=float)
739
+ config = Config(**parser.parse_args().__dict__)
740
+
741
+ main(config, config.action, config.action_args)
742
+
743
+
punctuation/vosk-recasepunc-ru-0.22/ru-test.txt ADDED
@@ -0,0 +1,17 @@
1
+ все смешалось в доме облонских жена узнала что муж был в связи с бывшею
2
+ в их доме француженкою-гувернанткой и объявила мужу что не может жить с
3
+ ним в одном доме положение это продолжалось уже третий день и мучительно
4
+ чувствовалось и самими супругами и всеми членами семьи и домочадцами
5
+ все члены семьи и домочадцы чувствовали что нет смысла в их сожительстве
6
+ и что на каждом постоялом дворе случайно сошедшиеся люди более связаны
7
+ между собой чем они члены семьи и домочадцы облонских жена не выходила
8
+ из своих комнат мужа третий день не было дома дети бегали по всему
9
+ дому как потерянные англичанка поссорилась с экономкой и написала
10
+ записку приятельнице прося приискать ей новое место повар ушел еще
11
+ вчера со двора во время обеда черная кухарка и кучер просили расчета
12
+ На третий день после ссоры князь степан аркадьич облонский стива как
13
+ его звали в свете в обычный час то есть в восемь часов утра
14
+ проснулся не в спальне жены а в своем кабинете на сафьянном диване
15
+ он повернул свое полное выхоленное тело на пружинах дивана как бы желая
16
+ опять заснуть надолго с другой стороны крепко обнял подушку и прижался к
17
+ ней щекой но вдруг вскочил сел на диван и открыл глаза
punctuation/vosk-recasepunc-ru-0.22/ru-test.txt.orig ADDED
@@ -0,0 +1,17 @@
1
+ Все смешалось в доме Облонских. Жена узнала, что муж был в связи с бывшею
2
+ в их доме француженкою-гувернанткой, и объявила мужу, что не может жить с
3
+ ним в одном доме. Положение это продолжалось уже третий день и мучительно
4
+ чувствовалось и самими супругами, и всеми членами семьи, и домочадцами.
5
+ Все члены семьи и домочадцы чувствовали, что нет смысла в их сожительстве
6
+ и что на каждом постоялом дворе случайно сошедшиеся люди более связаны
7
+ между собой, чем они, члены семьи и домочадцы Облонских. Жена не выходила
8
+ из своих комнат, мужа третий день не было дома. Дети бегали по всему
9
+ дому, как потерянные; англичанка поссорилась с экономкой и написала
10
+ записку приятельнице, прося приискать ей новое место; повар ушел еще
11
+ вчера со двора, во время обеда; черная кухарка и кучер просили расчета.
12
+ На третий день после ссоры князь Степан Аркадьич Облонский -- Стива, как
13
+ его звали в свете, -- в обычный час, то есть в восемь часов утра,
14
+ проснулся не в спальне жены, а в своем кабинете, на сафьянном диване...
15
+ Он повернул свое полное, выхоленное тело на пружинах дивана, как бы желая
16
+ опять заснуть надолго, с другой стороны крепко обнял подушку и прижался к
17
+ ней щекой; но вдруг вскочил, сел на диван и открыл глаза.
speaker_indentification/vosk-model-spk-0.4.7z ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4c7ccb7760ffc2ccb780beac3ba40907728c075fbb1cbb66b4dacc0afda4598
3
+ size 13785421
speaker_indentification/vosk-model-spk-0.4/README.txt ADDED
@@ -0,0 +1,119 @@
1
+
2
+ UPLOADER David Snyder
3
+ DATE 2018-05-30
4
+ KALDI VERSION 108832d
5
+
6
+ This directory contains files generated from the recipe in
7
+ egs/callhome_diarization/v2/. Its contents should be placed in a similar
8
+ directory, with symbolic links to diarization/, sid/, steps/, etc. This was
9
+ created when Kaldi's master branch was at git log
10
+ 2ad8d7821867a199e435aa36bbd13af6ed937c94.
11
+
12
+
13
+ I. Files list
14
+ ------------------------------------------------------------------------------
15
+
16
+ ./
17
+ README.txt This file
18
+ run.sh A copy of the egs/callhome_diarization/v2/run.sh
19
+ at the time of uploading this file. Use this to
20
+ figure out how to compute features, extract
21
+ embeddings, etc.
22
+
23
+ local/nnet3/xvector/tuning/
24
+ run_xvector_1a.sh This is the default recipe, at the time of
25
+ uploading this resource. The script generates
26
+ the configs, egs, and trains the model.
27
+
28
+ conf/
29
+ vad.conf The energy-based VAD configuration
30
+ mfcc.conf MFCC configuration
31
+
32
+ exp/xvector_nnet_1a/
33
+ final.raw The pretrained DNN model
34
+ nnet.config The nnet3 config file that was used when the
35
+ DNN model was first instantiated.
36
+ extract.config Another nnet3 config file that modifies the DNN
37
+ final.raw to extract x-vectors. It should be
38
+ automatically handled by the script
39
+ extract_xvectors.sh.
40
+ min_chunk_size Min chunk size used (see extract_xvectors.sh)
41
+ max_chunk_size Max chunk size used (see extract_xvectors.sh)
42
+ srand The RNG seed used when creating the DNN
43
+
44
+ exp/xvectors_callhome1/
45
+ mean.vec Vector for centering, from callhome1
46
+ transform.mat Whitening matrix, trained on callhome1
47
+ plda PLDA model for callhome1, trained on SRE data
48
+
49
+ exp/xvectors_callhome2/
50
+ mean.vec Vector for centering, from callhome2
51
+ transform.mat Whitening matrix, trained on callhome2
52
+ plda PLDA model for callhome1, trained on SRE data
53
+
54
+
55
+ II. Citation
56
+ ------------------------------------------------------------------------------
57
+
58
+ If you wish to use this architecture in a publication, please cite one of the
59
+ following papers.
60
+
61
+ The x-vector architecture:
62
+
63
+ @inproceedings{snyder2018xvector,
64
+ title={X-vectors: Robust DNN Embeddings for Speaker Recognition},
65
+ author={Snyder, D. and Garcia-Romero, D. and Sell, G. and Povey, D. and Khudanpur, S.},
66
+ booktitle={2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
67
+ year={2018},
68
+ organization={IEEE},
69
+ url={http://www.danielpovey.com/files/2018_icassp_xvectors.pdf}
70
+ }
71
+
72
+ Diarization with x-vectors:
73
+
74
+ @article{sell2018dihard,
75
+ title={Diarization is Hard: Some Experiences and Lessons Learned for the JHU Team in the Inaugural DIHARD Challenge},
76
+ author={Sell, G. and Snyder, D. and McCree, A. and Garcia-Romero, D. and Villalba, J. and Maciejewski, M. and Manohar, V. and Dehak, N. and Povey, D. and Watanabe, S. and Khudanpur, S.},
77
+ journal={Interspeech},
78
+ year={2018}
79
+ }
80
+
81
+
82
+ III. Recipe README.txt
83
+ ------------------------------------------------------------------------------
84
+ The following text is the README.txt from egs/callhome_diarization/v2 at the
85
+ time this archive was created.
86
+
87
+ This recipe replaces i-vectors used in the v1 recipe with embeddings extracted
88
+ from a deep neural network. In the scripts, we refer to these embeddings as
89
+ "x-vectors." The x-vector recipe in
90
+ local/nnet3/xvector/tuning/run_xvector_1a.sh is closely based on the
91
+ x-vector paper cited in section II above.
92
+
93
+ However, in this example, the x-vectors are used for diarization, rather
94
+ than speaker recognition. Diarization is performed by splitting speech
95
+ segments into very short segments (e.g., 1.5 seconds), extracting embeddings
96
+ from the segments, and clustering them to obtain speaker labels.
97
+
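As a rough illustration of the clustering step just described (the recipe itself does this with Kaldi binaries and PLDA scoring, not with the code below), short-segment embeddings can be grouped with cosine-distance agglomerative clustering; the scipy calls, the 0.5 threshold, and the random placeholder embeddings are stand-ins only.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage

def cluster_segments(xvectors, threshold=0.5):
    # xvectors: (num_segments, dim) array, one embedding per ~1.5 second sub-segment
    z = linkage(xvectors, method='average', metric='cosine')
    # assign an integer speaker label to every segment
    return fcluster(z, t=threshold, criterion='distance')

labels = cluster_segments(np.random.randn(20, 128))  # placeholder x-vectors
print(labels)
```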
98
+ The recipe uses the following data for system development. This is in
99
+ addition to the NIST SRE 2000 dataset (Callhome) which is used for
100
+ evaluation (see ../README.txt).
101
+
102
+ Corpus LDC Catalog No.
103
+ SRE2004 LDC2006S44
104
+ SRE2005 Train LDC2011S01
105
+ SRE2005 Test LDC2011S04
106
+ SRE2006 Train LDC2011S09
107
+ SRE2006 Test 1 LDC2011S10
108
+ SRE2006 Test 2 LDC2012S01
109
+ SRE2008 Train LDC2011S05
110
+ SRE2008 Test LDC2011S08
111
+ SWBD2 Phase 2 LDC99S79
112
+ SWBD2 Phase 3 LDC2002S06
113
+ SWBD Cellular 1 LDC2001S13
114
+ SWBD Cellular 2 LDC2004S07
115
+
116
+ The following datasets are used in data augmentation.
117
+
118
+ MUSAN http://www.openslr.org/17
119
+ RIR_NOISES http://www.openslr.org/28
speaker_indentification/vosk-model-spk-0.4/final.ext.raw ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d30354474ac719b3cfce58186463302b988cf22ce3afee7237a2da862c4a91d1
3
+ size 14929171
speaker_indentification/vosk-model-spk-0.4/mean.vec ADDED
@@ -0,0 +1 @@
1
+ [ 4.450152 4.672029 4.148891 -1.711527 3.509846 2.931994 2.850384 3.178227 0.2563171 -0.9261234 3.37196 0.1472566 5.635284 -0.01870821 1.972103 -0.9502754 4.401544 2.795261 2.67637 3.917823 0.6549923 -0.02103148 4.064806 4.100016 3.700118 1.252804 5.399523 4.084152 4.106742 3.5622 4.165306 -0.2494654 -0.9603948 4.272289 -2.332889 -0.7292819 3.646834 0.3090337 4.624666 5.089351 -5.635771 1.634198 1.089098 4.363739 3.618721 0.2134228 -0.3965465 5.353687 4.034757 4.032773 3.749556 3.166129 3.868708 4.381798 -0.02561651 0.3426051 4.402168 0.1237091 0.8197291 3.809948 -2.995811 -1.648535 3.202967 3.239381 3.250949 -0.9064079 4.452719 0.2775586 0.80832 3.036884 5.163679 0.4273587 3.537773 2.539269 3.151272 4.064805 3.56104 4.244997 3.660802 4.949434 4.013721 1.418729 1.845101 4.74059 3.280786 -1.731479 1.492544 -2.88268 5.013491 5.327713 -2.668042 1.02902 -0.9622369 3.954224 3.2533 3.348548 2.906777 -0.3059559 4.595854 0.3410174 2.116138 4.830284 3.402886 3.014466 4.481457 5.14358 2.05649 3.883894 -0.9075359 4.574888 4.064843 -1.416883 3.493051 -0.06792944 4.978102 4.930044 4.138368 2.826191 4.031521 2.575887 0.7125556 4.15551 2.601444 1.190357 -1.060124 0.9739355 4.671662 -1.613742 ]
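mean.vec is a Kaldi text-format vector ("[ v1 v2 ... ]"). Below is a minimal sketch, assuming the file is used outside Kaldi, of loading it and centering an embedding before scoring; applying transform.mat and PLDA is omitted here.

```python
import numpy as np

def load_kaldi_vector(path):
    # parse a Kaldi text vector of the form "[ v1 v2 ... ]"
    with open(path) as f:
        body = f.read().strip().lstrip('[').rstrip(']')
    return np.array([float(v) for v in body.split()])

mean = load_kaldi_vector('mean.vec')
xvector = np.zeros_like(mean)   # placeholder embedding from the extractor
centered = xvector - mean       # the centering step described in the README above
print(mean.shape)
```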
speaker_indentification/vosk-model-spk-0.4/mfcc.conf ADDED
@@ -0,0 +1,5 @@
1
+ --sample-frequency=8000
2
+ --high-freq=3700
3
+ --low-freq=20
4
+ --num-ceps=23
5
+ --allow-downsample=true
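A small helper sketch (not part of the model files) that reads mfcc.conf into a dict, for example to check that application audio matches the expected 8 kHz, 23-cepstra front end before extracting embeddings.

```python
def read_kaldi_conf(path):
    # each option line has the form "--key=value"
    opts = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line.startswith('--'):
                key, _, value = line[2:].partition('=')
                opts[key] = value
    return opts

conf = read_kaldi_conf('mfcc.conf')
assert conf['sample-frequency'] == '8000'  # telephone-band audio is expected
print(conf)
```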
speaker_indentification/vosk-model-spk-0.4/transform.mat ADDED
Binary file (65.6 kB).
 
tts/vosk-model-tts-ru-0.9-multi.7z ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e0f1ce5406abbabf29b823a11a2d937e70a1abeeb5d96c5fb518edc0cd4b949
3
+ size 761220882
tts/vosk-model-tts-ru-0.9-multi/README.md ADDED
@@ -0,0 +1,22 @@
+ Russian Vosk TTS model
+
+ Version 0.9
+
+ Metrics:
+
+ * CER 0.6
+ * FAD 0.810
+ * UTMOS 3.290
+ * Speaker Similarity 0.875
+ * xRT CPU 0.35
+ * xRT GPU 0.06
+
+ License: Apache 2.0
+
+ Changelog:
+
+ * ASR alignment
+ * No encoder, just a duration predictor
+ * Slightly thinner predictor width (160) to fit the DiT hidden vector
+ * Scale on the diffusion loss (so it does not dominate the duration loss)
+
tts/vosk-model-tts-ru-0.9-multi/bert/README.md ADDED
@@ -0,0 +1,39 @@
+ ---
+ language:
+ - ru
+ tags:
+ - PyTorch
+ - Transformers
+ - bert
+ - exbert
+ pipeline_tag: fill-mask
+ thumbnail: "https://github.com/sberbank-ai/model-zoo"
+ license: apache-2.0
+ ---
+
+ # ruBert-base
+ The model architecture design, pretraining, and evaluation are documented in our preprint: [**A Family of Pretrained Transformer Language Models for Russian**](https://arxiv.org/abs/2309.10931).
+
+ The model was pretrained by the [SberDevices](https://sberdevices.ru/) team.
+ * Task: `mask filling`
+ * Type: `encoder`
+ * Tokenizer: `BPE`
+ * Dict size: `120 138`
+ * Num Parameters: `178 M`
+ * Training Data Volume: `30 GB`
+
+ # Authors
+ + NLP core team RnD [Telegram channel](https://t.me/nlpcoreteam):
+ + Dmitry Zmitrovich
+
+ # Cite us
+ ```
+ @misc{zmitrovich2023family,
+   title={A Family of Pretrained Transformer Language Models for Russian},
+   author={Dmitry Zmitrovich and Alexander Abramov and Andrey Kalmykov and Maria Tikhonova and Ekaterina Taktasheva and Danil Astafurov and Mark Baushenko and Artem Snegirev and Tatiana Shavrina and Sergey Markov and Vladislav Mikhailov and Alena Fenogenova},
+   year={2023},
+   eprint={2309.10931},
+   archivePrefix={arXiv},
+   primaryClass={cs.CL}
+ }
+ ```
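For orientation only: the checkpoint this card describes is normally queried through the transformers fill-mask pipeline, as sketched below. The `ai-forever/ruBert-base` model id is an assumption about the public upload; the copy bundled in this archive is an ONNX export (bert/model.onnx) loaded internally by the TTS front end, so this is not how Vosk itself uses it.

```python
# Illustrative fill-mask query against the public ruBert-base checkpoint (assumed model id).
from transformers import pipeline

fill_mask = pipeline("fill-mask", model="ai-forever/ruBert-base")
for cand in fill_mask("Столица России — [MASK]."):
    print(cand["token_str"], round(cand["score"], 3))
```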
tts/vosk-model-tts-ru-0.9-multi/bert/model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2e2f1740eaae5e29c2b4844625cbb01ff644b2b5fb0560bd34374c35d8a092c1
+ size 654361598
tts/vosk-model-tts-ru-0.9-multi/bert/vocab.txt ADDED
The diff for this file is too large to render.
tts/vosk-model-tts-ru-0.9-multi/config.json ADDED
@@ -0,0 +1,85 @@
+ {
+   "audio": {
+     "sample_rate": 22050
+   },
+   "inference": {
+     "noise_level": 0.8,
+     "speech_rate": 1,
+     "duration_noise_level": 0.8
+   },
+   "phoneme_id_map": {
+     "_": 0,
+     "^": 1,
+     "$": 2,
+     " ": 3,
+     "!": 4,
+     "'": 5,
+     "(": 6,
+     ")": 7,
+     ",": 8,
+     "-": 9,
+     ".": 10,
+     "...": 11,
+     ":": 12,
+     ";": 13,
+     "?": 14,
+     "a0": 15,
+     "a1": 16,
+     "b": 17,
+     "bj": 18,
+     "c": 19,
+     "ch": 20,
+     "d": 21,
+     "dj": 22,
+     "e0": 23,
+     "e1": 24,
+     "f": 25,
+     "fj": 26,
+     "g": 27,
+     "gj": 28,
+     "h": 29,
+     "hj": 30,
+     "i0": 31,
+     "i1": 32,
+     "j": 33,
+     "k": 34,
+     "kj": 35,
+     "l": 36,
+     "lj": 37,
+     "m": 38,
+     "mj": 39,
+     "n": 40,
+     "nj": 41,
+     "o0": 42,
+     "o1": 43,
+     "p": 44,
+     "pj": 45,
+     "r": 46,
+     "rj": 47,
+     "s": 48,
+     "sch": 49,
+     "sh": 50,
+     "sj": 51,
+     "t": 52,
+     "tj": 53,
+     "u0": 54,
+     "u1": 55,
+     "v": 56,
+     "vj": 57,
+     "y0": 58,
+     "y1": 59,
+     "z": 60,
+     "zh": 61,
+     "zj": 62
+   },
+   "num_symbols": 62,
+   "num_speakers": 5,
+   "speaker_id_map": {
+     "female_0": 0,
+     "female_1": 1,
+     "female_2": 2,
+     "male_0": 3,
+     "male_1": 4
+   },
+   "model_type": "multistream_v1"
+ }
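The phoneme_id_map and speaker_id_map above define how input is encoded for the ONNX model. Below is a minimal sketch of that encoding step; the phonemization, the BERT features, and the ONNX input names are assumptions for illustration, not the documented Vosk TTS interface (in practice the vosk-tts package drives model.onnx itself, using the bundled dictionary for stress-marked phonemes).

```python
# Sketch: map a hand-phonemized utterance and a speaker to ids using config.json.
# The commented-out onnxruntime call uses assumed input names; it is not the real API.
import json
import numpy as np
# import onnxruntime

cfg = json.load(open("tts/vosk-model-tts-ru-0.9-multi/config.json", encoding="utf-8"))
phoneme_map, speaker_map = cfg["phoneme_id_map"], cfg["speaker_id_map"]

# Hand-phonemized "привет": '^' and '$' bracket the utterance, digits mark vowel stress.
phonemes = ["^", "p", "rj", "i0", "vj", "e1", "t", "$"]
ids = np.array([[phoneme_map[p] for p in phonemes]], dtype=np.int64)
speaker = np.array([speaker_map["female_0"]], dtype=np.int64)

# sess = onnxruntime.InferenceSession("tts/vosk-model-tts-ru-0.9-multi/model.onnx")
# audio = sess.run(None, {"input_ids": ids, "speaker_id": speaker})[0]   # names assumed
```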
tts/vosk-model-tts-ru-0.9-multi/dictionary ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2939e72c170bb41ac8e256828cca1c5fac4db1e36717f9f53fde843b00a220ba
+ size 101431118
tts/vosk-model-tts-ru-0.9-multi/model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0fa5a36b22a8bf7fe7179a3882c6371d2c01e5317019e717516f892d329c24b9
+ size 179314533