Upload 35 files
Browse files- .gitattributes +5 -0
- punctuation/vosk-recasepunc-de-0.21.7z +3 -0
- punctuation/vosk-recasepunc-de-0.21/README +7 -0
- punctuation/vosk-recasepunc-de-0.21/checkpoint +3 -0
- punctuation/vosk-recasepunc-de-0.21/de-test.txt +6 -0
- punctuation/vosk-recasepunc-de-0.21/de-test.txt.orig +6 -0
- punctuation/vosk-recasepunc-de-0.21/example.py +23 -0
- punctuation/vosk-recasepunc-de-0.21/recasepunc.py +742 -0
- punctuation/vosk-recasepunc-en-0.22.7z +3 -0
- punctuation/vosk-recasepunc-en-0.22/README +7 -0
- punctuation/vosk-recasepunc-en-0.22/checkpoint +3 -0
- punctuation/vosk-recasepunc-en-0.22/example.py +26 -0
- punctuation/vosk-recasepunc-en-0.22/recasepunc.py +742 -0
- punctuation/vosk-recasepunc-en-0.22/vosk-adapted.txt +17 -0
- punctuation/vosk-recasepunc-en-0.22/vosk-adapted.txt.punc +1 -0
- punctuation/vosk-recasepunc-ru-0.22.7z +3 -0
- punctuation/vosk-recasepunc-ru-0.22/README +7 -0
- punctuation/vosk-recasepunc-ru-0.22/checkpoint +3 -0
- punctuation/vosk-recasepunc-ru-0.22/example.py +23 -0
- punctuation/vosk-recasepunc-ru-0.22/recasepunc.py +743 -0
- punctuation/vosk-recasepunc-ru-0.22/ru-test.txt +17 -0
- punctuation/vosk-recasepunc-ru-0.22/ru-test.txt.orig +17 -0
- speaker_indentification/vosk-model-spk-0.4.7z +3 -0
- speaker_indentification/vosk-model-spk-0.4/README.txt +119 -0
- speaker_indentification/vosk-model-spk-0.4/final.ext.raw +3 -0
- speaker_indentification/vosk-model-spk-0.4/mean.vec +1 -0
- speaker_indentification/vosk-model-spk-0.4/mfcc.conf +5 -0
- speaker_indentification/vosk-model-spk-0.4/transform.mat +0 -0
- tts/vosk-model-tts-ru-0.9-multi.7z +3 -0
- tts/vosk-model-tts-ru-0.9-multi/README.md +22 -0
- tts/vosk-model-tts-ru-0.9-multi/bert/README.md +39 -0
- tts/vosk-model-tts-ru-0.9-multi/bert/model.onnx +3 -0
- tts/vosk-model-tts-ru-0.9-multi/bert/vocab.txt +0 -0
- tts/vosk-model-tts-ru-0.9-multi/config.json +85 -0
- tts/vosk-model-tts-ru-0.9-multi/dictionary +3 -0
- tts/vosk-model-tts-ru-0.9-multi/model.onnx +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
punctuation/vosk-recasepunc-de-0.21/checkpoint filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
punctuation/vosk-recasepunc-en-0.22/checkpoint filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
punctuation/vosk-recasepunc-ru-0.22/checkpoint filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
speaker_indentification/vosk-model-spk-0.4/final.ext.raw filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
tts/vosk-model-tts-ru-0.9-multi/dictionary filter=lfs diff=lfs merge=lfs -text
|
punctuation/vosk-recasepunc-de-0.21.7z
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:42e06dab56196498cde6e89c5e6d4b97cab72942827153b06180f6a156bebc0b
|
| 3 |
+
size 1153855092
|
punctuation/vosk-recasepunc-de-0.21/README
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
1. Install pytorch and transformers:
|
| 2 |
+
|
| 3 |
+
pip3 install transformers
|
| 4 |
+
|
| 5 |
+
2. Run python3 example.py de-test.txt
|
| 6 |
+
|
| 7 |
+
3. Compare with de-test.txt.orig
|
punctuation/vosk-recasepunc-de-0.21/checkpoint
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:90c70f58d865013bf1245d3c7fae229d6029ee0b16470c055f81a63e668de685
|
| 3 |
+
size 1315574525
|
punctuation/vosk-recasepunc-de-0.21/de-test.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nachdem sein vater schon 1707 starb als reinhart erst elf jahre alt war
|
| 2 |
+
wurde er von hauslehrern in seega erzogen hierauf kam er 1708 in die
|
| 3 |
+
stadtschule nach frankenhausen und war dort von der dritten bis zur
|
| 4 |
+
ersten klasse der bekannte schulmann magister hoffmann stand der schule
|
| 5 |
+
als rektor vor unter dem er publice prodiret hatte also öffentlich
|
| 6 |
+
aufgetreten war um eine rede zu halten
|
punctuation/vosk-recasepunc-de-0.21/de-test.txt.orig
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Nachdem sein Vater schon 1707 starb, als Reinhart erst elf Jahre alt war,
|
| 2 |
+
wurde er von Hauslehrern in Seega erzogen. Hierauf kam er 1708 in die
|
| 3 |
+
Stadtschule nach Frankenhausen und war dort von der dritten bis zur
|
| 4 |
+
ersten Klasse. Der bekannte Schulmann Magister Hoffmann stand der Schule
|
| 5 |
+
als Rektor vor, unter dem er publice prodiret hatte, also öffentlich
|
| 6 |
+
aufgetreten war, um eine Rede zu halten.
|
punctuation/vosk-recasepunc-de-0.21/example.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import time
|
| 3 |
+
from transformers import logging
|
| 4 |
+
from recasepunc import CasePuncPredictor
|
| 5 |
+
from recasepunc import WordpieceTokenizer
|
| 6 |
+
from recasepunc import Config
|
| 7 |
+
|
| 8 |
+
logging.set_verbosity_error()
|
| 9 |
+
|
| 10 |
+
predictor = CasePuncPredictor('checkpoint', lang="de")
|
| 11 |
+
|
| 12 |
+
text = " ".join(open(sys.argv[1]).readlines())
|
| 13 |
+
tokens = list(enumerate(predictor.tokenize(text)))
|
| 14 |
+
|
| 15 |
+
results = ""
|
| 16 |
+
for token, case_label, punc_label in predictor.predict(tokens, lambda x: x[1]):
|
| 17 |
+
prediction = predictor.map_punc_label(predictor.map_case_label(token[1], case_label), punc_label)
|
| 18 |
+
if token[1][0] != '#':
|
| 19 |
+
results = results + ' ' + prediction
|
| 20 |
+
else:
|
| 21 |
+
results = results + prediction
|
| 22 |
+
|
| 23 |
+
print (results.strip())
|
punctuation/vosk-recasepunc-de-0.21/recasepunc.py
ADDED
|
@@ -0,0 +1,742 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import collections
|
| 3 |
+
import os
|
| 4 |
+
import regex as re
|
| 5 |
+
#from mosestokenizer import *
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torch.optim as optim
|
| 11 |
+
import random
|
| 12 |
+
import unicodedata
|
| 13 |
+
import numpy as np
|
| 14 |
+
import argparse
|
| 15 |
+
from torch.utils.data import TensorDataset, DataLoader
|
| 16 |
+
|
| 17 |
+
from transformers import AutoModel, AutoTokenizer, BertTokenizer
|
| 18 |
+
|
| 19 |
+
default_config = argparse.Namespace(
|
| 20 |
+
seed=871253,
|
| 21 |
+
lang='de',
|
| 22 |
+
#flavor='flaubert/flaubert_base_uncased',
|
| 23 |
+
flavor=None,
|
| 24 |
+
max_length=256,
|
| 25 |
+
batch_size=16,
|
| 26 |
+
updates=24000,
|
| 27 |
+
period=1000,
|
| 28 |
+
lr=1e-5,
|
| 29 |
+
dab_rate=0.1,
|
| 30 |
+
device='cuda',
|
| 31 |
+
debug=False
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
default_flavors = {
|
| 35 |
+
'fr': 'flaubert/flaubert_base_uncased',
|
| 36 |
+
'en': 'bert-base-uncased',
|
| 37 |
+
'zh': 'ckiplab/bert-base-chinese',
|
| 38 |
+
'tr': 'dbmdz/bert-base-turkish-uncased',
|
| 39 |
+
'de': 'dbmdz/bert-base-german-uncased',
|
| 40 |
+
'pt': 'neuralmind/bert-base-portuguese-cased'
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
class Config(argparse.Namespace):
|
| 44 |
+
def __init__(self, **kwargs):
|
| 45 |
+
for key, value in default_config.__dict__.items():
|
| 46 |
+
setattr(self, key, value)
|
| 47 |
+
for key, value in kwargs.items():
|
| 48 |
+
setattr(self, key, value)
|
| 49 |
+
|
| 50 |
+
assert self.lang in ['fr', 'en', 'zh', 'tr', 'pt', 'de']
|
| 51 |
+
|
| 52 |
+
if 'lang' in kwargs and ('flavor' not in kwargs or kwargs['flavor'] is None):
|
| 53 |
+
self.flavor = default_flavors[self.lang]
|
| 54 |
+
|
| 55 |
+
#print(self.lang, self.flavor)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def init_random(seed):
|
| 59 |
+
# make sure everything is deterministic
|
| 60 |
+
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
| 61 |
+
#torch.use_deterministic_algorithms(True)
|
| 62 |
+
torch.manual_seed(seed)
|
| 63 |
+
torch.cuda.manual_seed_all(seed)
|
| 64 |
+
random.seed(seed)
|
| 65 |
+
np.random.seed(seed)
|
| 66 |
+
|
| 67 |
+
# NOTE: it is assumed in the implementation that y[:,0] is the punctuation label, and y[:,1] is the case label!
|
| 68 |
+
|
| 69 |
+
punctuation = {
|
| 70 |
+
'O': 0,
|
| 71 |
+
'COMMA': 1,
|
| 72 |
+
'PERIOD': 2,
|
| 73 |
+
'QUESTION': 3,
|
| 74 |
+
'EXCLAMATION': 4,
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
punctuation_syms = ['', ',', '.', ' ?', ' !']
|
| 78 |
+
|
| 79 |
+
case = {
|
| 80 |
+
'LOWER': 0,
|
| 81 |
+
'UPPER': 1,
|
| 82 |
+
'CAPITALIZE': 2,
|
| 83 |
+
'OTHER': 3,
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class Model(nn.Module):
|
| 88 |
+
def __init__(self, flavor, device):
|
| 89 |
+
super().__init__()
|
| 90 |
+
self.bert = AutoModel.from_pretrained(flavor)
|
| 91 |
+
# need a proper way of determining representation size
|
| 92 |
+
size = self.bert.dim if hasattr(self.bert, 'dim') else self.bert.config.pooler_fc_size if hasattr(self.bert.config, 'pooler_fc_size') else self.bert.config.emb_dim if hasattr(self.bert.config, 'emb_dim') else self.bert.config.hidden_size
|
| 93 |
+
self.punc = nn.Linear(size, 5)
|
| 94 |
+
self.case = nn.Linear(size, 4)
|
| 95 |
+
self.dropout = nn.Dropout(0.3)
|
| 96 |
+
self.to(device)
|
| 97 |
+
|
| 98 |
+
def forward(self, x):
|
| 99 |
+
output = self.bert(x)
|
| 100 |
+
representations = self.dropout(F.gelu(output['last_hidden_state']))
|
| 101 |
+
punc = self.punc(representations)
|
| 102 |
+
case = self.case(representations)
|
| 103 |
+
return punc, case
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# randomly create sequences that align to punctuation boundaries
|
| 107 |
+
def drop_at_boundaries(rate, x, y, cls_token_id, sep_token_id, pad_token_id):
|
| 108 |
+
for i, dropped in enumerate(torch.rand((len(x),)) < rate):
|
| 109 |
+
if dropped:
|
| 110 |
+
# select all indices that are sentence endings
|
| 111 |
+
indices = (y[i,:,0] > 1).nonzero(as_tuple=True)[0]
|
| 112 |
+
if len(indices) < 2:
|
| 113 |
+
continue
|
| 114 |
+
start = indices[0] + 1
|
| 115 |
+
end = indices[random.randint(1, len(indices) - 1)] + 1
|
| 116 |
+
length = end - start
|
| 117 |
+
if length + 2 > len(x[i]):
|
| 118 |
+
continue
|
| 119 |
+
x[i, 0] = cls_token_id
|
| 120 |
+
x[i, 1: length + 1] = x[i, start: end].clone()
|
| 121 |
+
x[i, length + 1] = sep_token_id
|
| 122 |
+
x[i, length + 2:] = pad_token_id
|
| 123 |
+
y[i, 0] = 0
|
| 124 |
+
y[i, 1: length + 1] = y[i, start: end].clone()
|
| 125 |
+
y[i, length + 1:] = 0
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def compute_performance(config, model, loader):
|
| 129 |
+
device = config.device
|
| 130 |
+
criterion = nn.CrossEntropyLoss()
|
| 131 |
+
model.eval()
|
| 132 |
+
total_loss = all_correct1 = all_correct2 = num_loss = num_perf = 0
|
| 133 |
+
num_ref = collections.defaultdict(float)
|
| 134 |
+
num_hyp = collections.defaultdict(float)
|
| 135 |
+
num_correct = collections.defaultdict(float)
|
| 136 |
+
for x, y in loader:
|
| 137 |
+
x = x.long().to(device)
|
| 138 |
+
y = y.long().to(device)
|
| 139 |
+
y1 = y[:,:,0]
|
| 140 |
+
y2 = y[:,:,1]
|
| 141 |
+
with torch.no_grad():
|
| 142 |
+
y_scores1, y_scores2 = model(x.to(device))
|
| 143 |
+
loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1)))
|
| 144 |
+
loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1)))
|
| 145 |
+
loss = loss1 + loss2
|
| 146 |
+
y_pred1 = torch.max(y_scores1, 2)[1]
|
| 147 |
+
y_pred2 = torch.max(y_scores2, 2)[1]
|
| 148 |
+
for label in range(1, 5):
|
| 149 |
+
ref = (y1 == label)
|
| 150 |
+
hyp = (y_pred1 == label)
|
| 151 |
+
correct = (ref * hyp == 1)
|
| 152 |
+
num_ref[label] += ref.sum()
|
| 153 |
+
num_hyp[label] += hyp.sum()
|
| 154 |
+
num_correct[label] += correct.sum()
|
| 155 |
+
num_ref[0] += ref.sum()
|
| 156 |
+
num_hyp[0] += hyp.sum()
|
| 157 |
+
num_correct[0] += correct.sum()
|
| 158 |
+
all_correct1 += (y_pred1 == y1).sum()
|
| 159 |
+
all_correct2 += (y_pred2 == y2).sum()
|
| 160 |
+
total_loss += loss.item()
|
| 161 |
+
num_loss += len(y)
|
| 162 |
+
num_perf += len(y) * config.max_length
|
| 163 |
+
recall = {}
|
| 164 |
+
precision = {}
|
| 165 |
+
fscore = {}
|
| 166 |
+
for label in range(0, 5):
|
| 167 |
+
recall[label] = num_correct[label] / num_ref[label] if num_ref[label] > 0 else 0
|
| 168 |
+
precision[label] = num_correct[label] / num_hyp[label] if num_hyp[label] > 0 else 0
|
| 169 |
+
fscore[label] = (2 * recall[label] * precision[label] / (recall[label] + precision[label])).item() if recall[label] + precision[label] > 0 else 0
|
| 170 |
+
return total_loss / num_loss, all_correct2.item() / num_perf, all_correct1.item() / num_perf, fscore
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def fit(config, model, checkpoint_path, train_loader, valid_loader, iterations, valid_period=200, lr=1e-5):
|
| 174 |
+
device = config.device
|
| 175 |
+
criterion = nn.CrossEntropyLoss()
|
| 176 |
+
optimizer = optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=lr)
|
| 177 |
+
iteration = 0
|
| 178 |
+
while True:
|
| 179 |
+
model.train()
|
| 180 |
+
total_loss = num = 0
|
| 181 |
+
for x, y in tqdm(train_loader):
|
| 182 |
+
x = x.long().to(device)
|
| 183 |
+
y = y.long().to(device)
|
| 184 |
+
drop_at_boundaries(config.dab_rate, x, y, config.cls_token_id, config.sep_token_id, config.pad_token_id)
|
| 185 |
+
y1 = y[:,:,0]
|
| 186 |
+
y2 = y[:,:,1]
|
| 187 |
+
optimizer.zero_grad()
|
| 188 |
+
y_scores1, y_scores2 = model(x)
|
| 189 |
+
loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1)))
|
| 190 |
+
loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1)))
|
| 191 |
+
loss = loss1 + loss2
|
| 192 |
+
loss.backward()
|
| 193 |
+
optimizer.step()
|
| 194 |
+
total_loss += loss.item()
|
| 195 |
+
num += len(y)
|
| 196 |
+
if iteration % valid_period == valid_period - 1:
|
| 197 |
+
train_loss = total_loss / num
|
| 198 |
+
valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore = compute_performance(config, model, valid_loader)
|
| 199 |
+
torch.save({
|
| 200 |
+
'iteration': iteration + 1,
|
| 201 |
+
'model_state_dict': model.state_dict(),
|
| 202 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 203 |
+
'train_loss': train_loss,
|
| 204 |
+
'valid_loss': valid_loss,
|
| 205 |
+
'valid_accuracy_case': valid_accuracy_case,
|
| 206 |
+
'valid_accuracy_punc': valid_accuracy_punc,
|
| 207 |
+
'valid_fscore': valid_fscore,
|
| 208 |
+
'config': config.__dict__,
|
| 209 |
+
}, '%s.%d' % (checkpoint_path, iteration + 1))
|
| 210 |
+
print(iteration + 1, train_loss, valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore)
|
| 211 |
+
total_loss = num = 0
|
| 212 |
+
|
| 213 |
+
iteration += 1
|
| 214 |
+
if iteration > iterations:
|
| 215 |
+
return
|
| 216 |
+
|
| 217 |
+
sys.stderr.flush()
|
| 218 |
+
sys.stdout.flush()
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def batchify(max_length, x, y):
|
| 222 |
+
print (x.shape)
|
| 223 |
+
print (y.shape)
|
| 224 |
+
x = x[:(len(x) // max_length) * max_length].reshape(-1, max_length)
|
| 225 |
+
y = y[:(len(y) // max_length) * max_length, :].reshape(-1, max_length, 2)
|
| 226 |
+
return x, y
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def train(config, train_x_fn, train_y_fn, valid_x_fn, valid_y_fn, checkpoint_path):
|
| 230 |
+
X_train, Y_train = batchify(config.max_length, torch.load(train_x_fn), torch.load(train_y_fn))
|
| 231 |
+
X_valid, Y_valid = batchify(config.max_length, torch.load(valid_x_fn), torch.load(valid_y_fn))
|
| 232 |
+
|
| 233 |
+
train_set = TensorDataset(X_train, Y_train)
|
| 234 |
+
valid_set = TensorDataset(X_valid, Y_valid)
|
| 235 |
+
|
| 236 |
+
train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
|
| 237 |
+
valid_loader = DataLoader(valid_set, batch_size=config.batch_size)
|
| 238 |
+
|
| 239 |
+
model = Model(config.flavor, config.device)
|
| 240 |
+
|
| 241 |
+
fit(config, model, checkpoint_path, train_loader, valid_loader, config.updates, config.period, config.lr)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def run_eval(config, test_x_fn, test_y_fn, checkpoint_path):
|
| 245 |
+
X_test, Y_test = batchify(config.max_length, torch.load(test_x_fn), torch.load(test_y_fn))
|
| 246 |
+
test_set = TensorDataset(X_test, Y_test)
|
| 247 |
+
test_loader = DataLoader(test_set, batch_size=config.batch_size)
|
| 248 |
+
|
| 249 |
+
loaded = torch.load(checkpoint_path, map_location=config.device)
|
| 250 |
+
if 'config' in loaded:
|
| 251 |
+
config = Config(**loaded['config'])
|
| 252 |
+
init(config)
|
| 253 |
+
|
| 254 |
+
model = Model(config.flavor, config.device)
|
| 255 |
+
model.load_state_dict(loaded['model_state_dict'])
|
| 256 |
+
|
| 257 |
+
print(*compute_performance(config, model, test_loader))
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def recase(token, label):
|
| 261 |
+
if label == case['LOWER']:
|
| 262 |
+
return token.lower()
|
| 263 |
+
elif label == case['CAPITALIZE']:
|
| 264 |
+
return token.lower().capitalize()
|
| 265 |
+
elif label == case['UPPER']:
|
| 266 |
+
return token.upper()
|
| 267 |
+
else:
|
| 268 |
+
return token
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class CasePuncPredictor:
|
| 272 |
+
def __init__(self, checkpoint_path, lang=default_config.lang, flavor=default_config.flavor, device=default_config.device):
|
| 273 |
+
loaded = torch.load(checkpoint_path, map_location=device if torch.cuda.is_available() else 'cpu')
|
| 274 |
+
if 'config' in loaded:
|
| 275 |
+
self.config = Config(**loaded['config'])
|
| 276 |
+
else:
|
| 277 |
+
self.config = Config(lang=lang, flavor=flavor, device=device)
|
| 278 |
+
init(self.config)
|
| 279 |
+
|
| 280 |
+
self.model = Model(self.config.flavor, self.config.device)
|
| 281 |
+
self.model.load_state_dict(loaded['model_state_dict'])
|
| 282 |
+
self.model.eval()
|
| 283 |
+
self.model.to(self.config.device)
|
| 284 |
+
|
| 285 |
+
self.rev_case = {b: a for a, b in case.items()}
|
| 286 |
+
self.rev_punc = {b: a for a, b in punctuation.items()}
|
| 287 |
+
|
| 288 |
+
def tokenize(self, text):
|
| 289 |
+
return [self.config.cls_token] + self.config.tokenizer.tokenize(text) + [self.config.sep_token]
|
| 290 |
+
|
| 291 |
+
def predict(self, tokens, getter=lambda x: x):
|
| 292 |
+
max_length = self.config.max_length
|
| 293 |
+
device = self.config.device
|
| 294 |
+
if type(tokens) == str:
|
| 295 |
+
tokens = self.tokenize(tokens)
|
| 296 |
+
previous_label = punctuation['PERIOD']
|
| 297 |
+
for start in range(0, len(tokens), max_length):
|
| 298 |
+
instance = tokens[start: start + max_length]
|
| 299 |
+
if type(getter(instance[0])) == str:
|
| 300 |
+
ids = self.config.tokenizer.convert_tokens_to_ids(getter(token) for token in instance)
|
| 301 |
+
else:
|
| 302 |
+
ids = [getter(token) for token in instance]
|
| 303 |
+
if len(ids) < max_length:
|
| 304 |
+
ids += [0] * (max_length - len(ids))
|
| 305 |
+
x = torch.tensor([ids]).long().to(device)
|
| 306 |
+
y_scores1, y_scores2 = self.model(x)
|
| 307 |
+
y_pred1 = torch.max(y_scores1, 2)[1]
|
| 308 |
+
y_pred2 = torch.max(y_scores2, 2)[1]
|
| 309 |
+
for i, id, token, punc_label, case_label in zip(range(len(instance)), ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
|
| 310 |
+
if id == self.config.cls_token_id or id == self.config.sep_token_id:
|
| 311 |
+
continue
|
| 312 |
+
if previous_label != None and previous_label > 1:
|
| 313 |
+
if case_label in [case['LOWER'], case['OTHER']]: # LOWER, OTHER
|
| 314 |
+
case_label = case['CAPITALIZE']
|
| 315 |
+
if i + start == len(tokens) - 2 and punc_label == punctuation['O']:
|
| 316 |
+
punc_label = punctuation['PERIOD']
|
| 317 |
+
yield (token, self.rev_case[case_label], self.rev_punc[punc_label])
|
| 318 |
+
previous_label = punc_label
|
| 319 |
+
|
| 320 |
+
def map_case_label(self, token, case_label):
|
| 321 |
+
if token.endswith('</w>'):
|
| 322 |
+
token = token[:-4]
|
| 323 |
+
if token.startswith('##'):
|
| 324 |
+
token = token[2:]
|
| 325 |
+
return recase(token, case[case_label])
|
| 326 |
+
|
| 327 |
+
def map_punc_label(self, token, punc_label):
|
| 328 |
+
if token.endswith('</w>'):
|
| 329 |
+
token = token[:-4]
|
| 330 |
+
if token.startswith('##'):
|
| 331 |
+
token = token[2:]
|
| 332 |
+
return token + punctuation_syms[punctuation[punc_label]]
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def generate_predictions(config, checkpoint_path):
|
| 337 |
+
loaded = torch.load(checkpoint_path, map_location=config.device if torch.cuda.is_available() else 'cpu')
|
| 338 |
+
if 'config' in loaded:
|
| 339 |
+
config = Config(**loaded['config'])
|
| 340 |
+
init(config)
|
| 341 |
+
|
| 342 |
+
model = Model(config.flavor, config.device)
|
| 343 |
+
model.load_state_dict(loaded['model_state_dict'])
|
| 344 |
+
|
| 345 |
+
rev_case = {b: a for a, b in case.items()}
|
| 346 |
+
rev_punc = {b: a for a, b in punctuation.items()}
|
| 347 |
+
|
| 348 |
+
for line in sys.stdin:
|
| 349 |
+
# also drop punctuation that we may generate
|
| 350 |
+
line = ''.join([c for c in line if c not in mapped_punctuation])
|
| 351 |
+
if config.debug:
|
| 352 |
+
print(line)
|
| 353 |
+
tokens = [config.cls_token] + config.tokenizer.tokenize(line) + [config.sep_token]
|
| 354 |
+
if config.debug:
|
| 355 |
+
print(tokens)
|
| 356 |
+
previous_label = punctuation['PERIOD']
|
| 357 |
+
first_time = True
|
| 358 |
+
was_word = False
|
| 359 |
+
for start in range(0, len(tokens), config.max_length):
|
| 360 |
+
instance = tokens[start: start + config.max_length]
|
| 361 |
+
ids = config.tokenizer.convert_tokens_to_ids(instance)
|
| 362 |
+
#print(len(ids), file=sys.stderr)
|
| 363 |
+
if len(ids) < config.max_length:
|
| 364 |
+
ids += [config.pad_token_id] * (config.max_length - len(ids))
|
| 365 |
+
x = torch.tensor([ids]).long().to(config.device)
|
| 366 |
+
y_scores1, y_scores2 = model(x)
|
| 367 |
+
y_pred1 = torch.max(y_scores1, 2)[1]
|
| 368 |
+
y_pred2 = torch.max(y_scores2, 2)[1]
|
| 369 |
+
for id, token, punc_label, case_label in zip(ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
|
| 370 |
+
if config.debug:
|
| 371 |
+
print(id, token, punc_label, case_label, file=sys.stderr)
|
| 372 |
+
if id == config.cls_token_id or id == config.sep_token_id:
|
| 373 |
+
continue
|
| 374 |
+
if previous_label != None and previous_label > 1:
|
| 375 |
+
if case_label in [case['LOWER'], case['OTHER']]:
|
| 376 |
+
case_label = case['CAPITALIZE']
|
| 377 |
+
previous_label = punc_label
|
| 378 |
+
# different strategy due to sub-lexical token encoding in Flaubert
|
| 379 |
+
if config.lang == 'fr':
|
| 380 |
+
if token.endswith('</w>'):
|
| 381 |
+
cased_token = recase(token[:-4], case_label)
|
| 382 |
+
if was_word:
|
| 383 |
+
print(' ', end='')
|
| 384 |
+
print(cased_token + punctuation_syms[punc_label], end='')
|
| 385 |
+
was_word = True
|
| 386 |
+
else:
|
| 387 |
+
cased_token = recase(token, case_label)
|
| 388 |
+
if was_word:
|
| 389 |
+
print(' ', end='')
|
| 390 |
+
print(cased_token, end='')
|
| 391 |
+
was_word = False
|
| 392 |
+
else:
|
| 393 |
+
if token.startswith('##'):
|
| 394 |
+
cased_token = recase(token[2:], case_label)
|
| 395 |
+
print(cased_token, end='')
|
| 396 |
+
else:
|
| 397 |
+
cased_token = recase(token, case_label)
|
| 398 |
+
if not first_time:
|
| 399 |
+
print(' ', end='')
|
| 400 |
+
first_time = False
|
| 401 |
+
print(cased_token + punctuation_syms[punc_label], end='')
|
| 402 |
+
if previous_label == 0:
|
| 403 |
+
print('.', end='')
|
| 404 |
+
print()
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def label_for_case(token):
|
| 408 |
+
token = re.sub('[^\p{Han}\p{Ll}\p{Lu}]', '', token)
|
| 409 |
+
if token == token.lower():
|
| 410 |
+
return 'LOWER'
|
| 411 |
+
elif token == token.lower().capitalize():
|
| 412 |
+
return 'CAPITALIZE'
|
| 413 |
+
elif token == token.upper():
|
| 414 |
+
return 'UPPER'
|
| 415 |
+
else:
|
| 416 |
+
return 'OTHER'
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def make_tensors(config, input_fn, output_x_fn, output_y_fn):
|
| 420 |
+
# count file lines without loading them
|
| 421 |
+
size = 0
|
| 422 |
+
with open(input_fn) as fp:
|
| 423 |
+
for line in fp:
|
| 424 |
+
size += 1
|
| 425 |
+
|
| 426 |
+
with open(input_fn) as fp:
|
| 427 |
+
X = torch.IntTensor(size)
|
| 428 |
+
Y = torch.ByteTensor(size, 2)
|
| 429 |
+
|
| 430 |
+
offset = 0
|
| 431 |
+
for n, line in enumerate(fp):
|
| 432 |
+
word, case_label, punc_label = line.strip().split('\t')
|
| 433 |
+
id = config.tokenizer.convert_tokens_to_ids(word)
|
| 434 |
+
if config.debug:
|
| 435 |
+
assert word.lower() == tokenizer.convert_ids_to_tokens(id)
|
| 436 |
+
X[offset] = id
|
| 437 |
+
Y[offset, 0] = punctuation[punc_label]
|
| 438 |
+
Y[offset, 1] = case[case_label]
|
| 439 |
+
offset += 1
|
| 440 |
+
|
| 441 |
+
torch.save(X, output_x_fn)
|
| 442 |
+
torch.save(Y, output_y_fn)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
mapped_punctuation = {
|
| 446 |
+
'.': 'PERIOD',
|
| 447 |
+
'...': 'PERIOD',
|
| 448 |
+
',': 'COMMA',
|
| 449 |
+
';': 'COMMA',
|
| 450 |
+
':': 'COMMA',
|
| 451 |
+
'(': 'COMMA',
|
| 452 |
+
')': 'COMMA',
|
| 453 |
+
'?': 'QUESTION',
|
| 454 |
+
'!': 'EXCLAMATION',
|
| 455 |
+
',': 'COMMA',
|
| 456 |
+
'!': 'EXCLAMATION',
|
| 457 |
+
'?': 'QUESTION',
|
| 458 |
+
';': 'COMMA',
|
| 459 |
+
':': 'COMMA',
|
| 460 |
+
'(': 'COMMA',
|
| 461 |
+
'(': 'COMMA',
|
| 462 |
+
')': 'COMMA',
|
| 463 |
+
'[': 'COMMA',
|
| 464 |
+
']': 'COMMA',
|
| 465 |
+
'【': 'COMMA',
|
| 466 |
+
'】': 'COMMA',
|
| 467 |
+
'└': 'COMMA',
|
| 468 |
+
'└ ': 'COMMA',
|
| 469 |
+
'_': 'O',
|
| 470 |
+
'。': 'PERIOD',
|
| 471 |
+
'、': 'COMMA', # enumeration comma
|
| 472 |
+
'、': 'COMMA',
|
| 473 |
+
'…': 'PERIOD',
|
| 474 |
+
'—': 'COMMA',
|
| 475 |
+
'「': 'COMMA',
|
| 476 |
+
'」': 'COMMA',
|
| 477 |
+
'.': 'PERIOD',
|
| 478 |
+
'《': 'O',
|
| 479 |
+
'》': 'O',
|
| 480 |
+
',': 'COMMA',
|
| 481 |
+
'“': 'O',
|
| 482 |
+
'”': 'O',
|
| 483 |
+
'"': 'O',
|
| 484 |
+
'-': 'O',
|
| 485 |
+
'-': 'O',
|
| 486 |
+
'〉': 'COMMA',
|
| 487 |
+
'〈': 'COMMA',
|
| 488 |
+
'↑': 'O',
|
| 489 |
+
'〔': 'COMMA',
|
| 490 |
+
'〕': 'COMMA',
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
def preprocess_text(config, max_token_count=-1):
|
| 494 |
+
global num_tokens_output
|
| 495 |
+
max_token_count = int(max_token_count)
|
| 496 |
+
num_tokens_output = 0
|
| 497 |
+
def process_segment(text, punctuation):
|
| 498 |
+
global num_tokens_output
|
| 499 |
+
text = text.replace('\t', ' ')
|
| 500 |
+
tokens = config.tokenizer.tokenize(text)
|
| 501 |
+
for i, token in enumerate(tokens):
|
| 502 |
+
case_label = label_for_case(token)
|
| 503 |
+
if i == len(tokens) - 1:
|
| 504 |
+
print(token.lower(), case_label, punctuation, sep='\t')
|
| 505 |
+
else:
|
| 506 |
+
print(token.lower(), case_label, 'O', sep='\t')
|
| 507 |
+
num_tokens_output += 1
|
| 508 |
+
# a bit too ugly, but alternative is to throw an exception
|
| 509 |
+
if max_token_count > 0 and num_tokens_output >= max_token_count:
|
| 510 |
+
sys.exit(0)
|
| 511 |
+
|
| 512 |
+
for line in sys.stdin:
|
| 513 |
+
line = line.strip()
|
| 514 |
+
if line != '':
|
| 515 |
+
line = unicodedata.normalize("NFC", line)
|
| 516 |
+
if config.debug:
|
| 517 |
+
print(line)
|
| 518 |
+
start = 0
|
| 519 |
+
for i, char in enumerate(line):
|
| 520 |
+
if char in mapped_punctuation:
|
| 521 |
+
if i > start and line[start: i].strip() != '':
|
| 522 |
+
process_segment(line[start: i], mapped_punctuation[char])
|
| 523 |
+
start = i + 1
|
| 524 |
+
if start < len(line):
|
| 525 |
+
process_segment(line[start:], 'PERIOD')
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def preprocess_text_old_fr(config):
|
| 529 |
+
assert config.lang == 'fr'
|
| 530 |
+
splitsents = MosesSentenceSplitter(lang)
|
| 531 |
+
tokenize = MosesTokenizer(lang, extra=['-no-escape'])
|
| 532 |
+
normalize = MosesPunctuationNormalizer(lang)
|
| 533 |
+
|
| 534 |
+
for line in sys.stdin:
|
| 535 |
+
if line.strip() != '':
|
| 536 |
+
for sentence in splitsents([normalize(line)]):
|
| 537 |
+
tokens = tokenize(sentence)
|
| 538 |
+
previous_token = None
|
| 539 |
+
for token in tokens:
|
| 540 |
+
if token in mapped_punctuation:
|
| 541 |
+
if previous_token != None:
|
| 542 |
+
print(previous_token, mapped_punctuation[token], sep='\t')
|
| 543 |
+
previous_token = None
|
| 544 |
+
elif not re.search('[\p{Han}\p{Ll}\p{Lu}\d]', token): # remove non-alphanumeric tokens
|
| 545 |
+
continue
|
| 546 |
+
else:
|
| 547 |
+
if previous_token != None:
|
| 548 |
+
print(previous_token, 'O', sep='\t')
|
| 549 |
+
previous_token = token
|
| 550 |
+
if previous_token != None:
|
| 551 |
+
print(previous_token, 'PERIOD', sep='\t')
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
# modification of the wordpiece tokenizer to keep case information even if vocab is lower cased
|
| 555 |
+
# forked from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py
|
| 556 |
+
|
| 557 |
+
class WordpieceTokenizer(object):
|
| 558 |
+
"""Runs WordPiece tokenization."""
|
| 559 |
+
|
| 560 |
+
def __init__(self, vocab, unk_token, max_input_chars_per_word=100, keep_case=True):
|
| 561 |
+
self.vocab = vocab
|
| 562 |
+
self.unk_token = unk_token
|
| 563 |
+
self.max_input_chars_per_word = max_input_chars_per_word
|
| 564 |
+
self.keep_case = keep_case
|
| 565 |
+
|
| 566 |
+
def tokenize(self, text):
|
| 567 |
+
"""
|
| 568 |
+
Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
|
| 569 |
+
tokenization using the given vocabulary.
|
| 570 |
+
For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`.
|
| 571 |
+
Args:
|
| 572 |
+
text: A single token or whitespace separated tokens. This should have
|
| 573 |
+
already been passed through `BasicTokenizer`.
|
| 574 |
+
Returns:
|
| 575 |
+
A list of wordpiece tokens.
|
| 576 |
+
"""
|
| 577 |
+
|
| 578 |
+
output_tokens = []
|
| 579 |
+
for token in text.strip().split():
|
| 580 |
+
chars = list(token)
|
| 581 |
+
if len(chars) > self.max_input_chars_per_word:
|
| 582 |
+
output_tokens.append(self.unk_token)
|
| 583 |
+
continue
|
| 584 |
+
|
| 585 |
+
is_bad = False
|
| 586 |
+
start = 0
|
| 587 |
+
sub_tokens = []
|
| 588 |
+
while start < len(chars):
|
| 589 |
+
end = len(chars)
|
| 590 |
+
cur_substr = None
|
| 591 |
+
while start < end:
|
| 592 |
+
substr = "".join(chars[start:end])
|
| 593 |
+
if start > 0:
|
| 594 |
+
substr = "##" + substr
|
| 595 |
+
# optionaly lowercase substring before checking for inclusion in vocab
|
| 596 |
+
if (self.keep_case and substr.lower() in self.vocab) or (substr in self.vocab):
|
| 597 |
+
cur_substr = substr
|
| 598 |
+
break
|
| 599 |
+
end -= 1
|
| 600 |
+
if cur_substr is None:
|
| 601 |
+
is_bad = True
|
| 602 |
+
break
|
| 603 |
+
sub_tokens.append(cur_substr)
|
| 604 |
+
start = end
|
| 605 |
+
|
| 606 |
+
if is_bad:
|
| 607 |
+
output_tokens.append(self.unk_token)
|
| 608 |
+
else:
|
| 609 |
+
output_tokens.extend(sub_tokens)
|
| 610 |
+
return output_tokens
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
# modification of XLM bpe tokenizer for keeping case information when vocab is lowercase
|
| 614 |
+
# forked from https://github.com/huggingface/transformers/blob/cd56f3fe7eae4a53a9880e3f5e8f91877a78271c/src/transformers/models/xlm/tokenization_xlm.py
|
| 615 |
+
def bpe(self, token):
|
| 616 |
+
def to_lower(pair):
|
| 617 |
+
#print(' ',pair)
|
| 618 |
+
return (pair[0].lower(), pair[1].lower())
|
| 619 |
+
|
| 620 |
+
from transformers.models.xlm.tokenization_xlm import get_pairs
|
| 621 |
+
|
| 622 |
+
word = tuple(token[:-1]) + (token[-1] + "</w>",)
|
| 623 |
+
if token in self.cache:
|
| 624 |
+
return self.cache[token]
|
| 625 |
+
pairs = get_pairs(word)
|
| 626 |
+
|
| 627 |
+
if not pairs:
|
| 628 |
+
return token + "</w>"
|
| 629 |
+
|
| 630 |
+
while True:
|
| 631 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(to_lower(pair), float("inf")))
|
| 632 |
+
#print(bigram)
|
| 633 |
+
if to_lower(bigram) not in self.bpe_ranks:
|
| 634 |
+
break
|
| 635 |
+
first, second = bigram
|
| 636 |
+
new_word = []
|
| 637 |
+
i = 0
|
| 638 |
+
while i < len(word):
|
| 639 |
+
try:
|
| 640 |
+
j = word.index(first, i)
|
| 641 |
+
except ValueError:
|
| 642 |
+
new_word.extend(word[i:])
|
| 643 |
+
break
|
| 644 |
+
else:
|
| 645 |
+
new_word.extend(word[i:j])
|
| 646 |
+
i = j
|
| 647 |
+
|
| 648 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
| 649 |
+
new_word.append(first + second)
|
| 650 |
+
i += 2
|
| 651 |
+
else:
|
| 652 |
+
new_word.append(word[i])
|
| 653 |
+
i += 1
|
| 654 |
+
new_word = tuple(new_word)
|
| 655 |
+
word = new_word
|
| 656 |
+
if len(word) == 1:
|
| 657 |
+
break
|
| 658 |
+
else:
|
| 659 |
+
pairs = get_pairs(word)
|
| 660 |
+
word = " ".join(word)
|
| 661 |
+
if word == "\n </w>":
|
| 662 |
+
word = "\n</w>"
|
| 663 |
+
self.cache[token] = word
|
| 664 |
+
return word
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
def init(config):
|
| 669 |
+
init_random(config.seed)
|
| 670 |
+
|
| 671 |
+
if config.lang == 'fr':
|
| 672 |
+
config.tokenizer = tokenizer = AutoTokenizer.from_pretrained(config.flavor, do_lower_case=False)
|
| 673 |
+
|
| 674 |
+
from transformers.models.xlm.tokenization_xlm import XLMTokenizer
|
| 675 |
+
assert isinstance(tokenizer, XLMTokenizer)
|
| 676 |
+
|
| 677 |
+
# monkey patch XLM tokenizer
|
| 678 |
+
import types
|
| 679 |
+
tokenizer.bpe = types.MethodType(bpe, tokenizer)
|
| 680 |
+
else:
|
| 681 |
+
# warning: needs to be BertTokenizer for monkey patching to work
|
| 682 |
+
config.tokenizer = tokenizer = BertTokenizer.from_pretrained(config.flavor, do_lower_case=False)
|
| 683 |
+
|
| 684 |
+
# warning: monkey patch tokenizer to keep case information
|
| 685 |
+
#from recasing_tokenizer import WordpieceTokenizer
|
| 686 |
+
config.tokenizer.wordpiece_tokenizer = WordpieceTokenizer(vocab=tokenizer.vocab, unk_token=tokenizer.unk_token)
|
| 687 |
+
|
| 688 |
+
if config.lang == 'fr':
|
| 689 |
+
config.pad_token_id = tokenizer.pad_token_id
|
| 690 |
+
config.cls_token_id = tokenizer.bos_token_id
|
| 691 |
+
config.cls_token = tokenizer.bos_token
|
| 692 |
+
config.sep_token_id = tokenizer.sep_token_id
|
| 693 |
+
config.sep_token = tokenizer.sep_token
|
| 694 |
+
else:
|
| 695 |
+
config.pad_token_id = tokenizer.pad_token_id
|
| 696 |
+
config.cls_token_id = tokenizer.cls_token_id
|
| 697 |
+
config.cls_token = tokenizer.cls_token
|
| 698 |
+
config.sep_token_id = tokenizer.sep_token_id
|
| 699 |
+
config.sep_token = tokenizer.sep_token
|
| 700 |
+
|
| 701 |
+
if not torch.cuda.is_available() and config.device == 'cuda':
|
| 702 |
+
print('WARNING: reverting to cpu as cuda is not available', file=sys.stderr)
|
| 703 |
+
config.device = torch.device(config.device if torch.cuda.is_available() else 'cpu')
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def main(config, action, args):
|
| 707 |
+
init(config)
|
| 708 |
+
|
| 709 |
+
if action == 'train':
|
| 710 |
+
train(config, *args)
|
| 711 |
+
elif action == 'eval':
|
| 712 |
+
run_eval(config, *args)
|
| 713 |
+
elif action == 'predict':
|
| 714 |
+
generate_predictions(config, *args)
|
| 715 |
+
elif action == 'tensorize':
|
| 716 |
+
make_tensors(config, *args)
|
| 717 |
+
elif action == 'preprocess':
|
| 718 |
+
preprocess_text(config, *args)
|
| 719 |
+
else:
|
| 720 |
+
print('invalid action "%s"' % action)
|
| 721 |
+
sys.exit(1)
|
| 722 |
+
|
| 723 |
+
if __name__ == '__main__':
|
| 724 |
+
parser = argparse.ArgumentParser()
|
| 725 |
+
parser.add_argument("action", help="train|eval|predict|tensorize|preprocess", type=str)
|
| 726 |
+
parser.add_argument("action_args", help="arguments for selected action", type=str, nargs='*')
|
| 727 |
+
parser.add_argument("--seed", help="random seed", default=default_config.seed, type=int)
|
| 728 |
+
parser.add_argument("--lang", help="language (fr, en, zh)", default=default_config.lang, type=str)
|
| 729 |
+
parser.add_argument("--flavor", help="bert flavor in transformers model zoo", default=default_config.flavor, type=str)
|
| 730 |
+
parser.add_argument("--max-length", help="maximum input length", default=default_config.max_length, type=int)
|
| 731 |
+
parser.add_argument("--batch-size", help="size of batches", default=default_config.batch_size, type=int)
|
| 732 |
+
parser.add_argument("--device", help="computation device (cuda, cpu)", default=default_config.device, type=str)
|
| 733 |
+
parser.add_argument("--debug", help="whether to output more debug info", default=default_config.debug, type=bool)
|
| 734 |
+
parser.add_argument("--updates", help="number of training updates to perform", default=default_config.updates, type=bool)
|
| 735 |
+
parser.add_argument("--period", help="validation period in updates", default=default_config.period, type=bool)
|
| 736 |
+
parser.add_argument("--lr", help="learning rate", default=default_config.lr, type=bool)
|
| 737 |
+
parser.add_argument("--dab-rate", help="drop at boundaries rate", default=default_config.dab_rate, type=bool)
|
| 738 |
+
config = Config(**parser.parse_args().__dict__)
|
| 739 |
+
|
| 740 |
+
main(config, config.action, config.action_args)
|
| 741 |
+
|
| 742 |
+
|
punctuation/vosk-recasepunc-en-0.22.7z
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d754b827d2b3f85fe56cfb7a5262dc5658a9257ff7f9404e57595328882b8777
|
| 3 |
+
size 1148483511
|
punctuation/vosk-recasepunc-en-0.22/README
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
1. Install pytorch and transformers:
|
| 2 |
+
|
| 3 |
+
pip3 install transformers
|
| 4 |
+
|
| 5 |
+
2. Run python3 example.py de-test.txt
|
| 6 |
+
|
| 7 |
+
3. Compare with de-test.txt.orig
|
punctuation/vosk-recasepunc-en-0.22/checkpoint
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9782ccd13a130feffb13609834778421ebd39e26910d25ddcf2185a0eea75935
|
| 3 |
+
size 1310193349
|
punctuation/vosk-recasepunc-en-0.22/example.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import time
|
| 3 |
+
from transformers import logging
|
| 4 |
+
from recasepunc import CasePuncPredictor
|
| 5 |
+
from recasepunc import WordpieceTokenizer
|
| 6 |
+
from recasepunc import Config
|
| 7 |
+
|
| 8 |
+
logging.set_verbosity_error()
|
| 9 |
+
|
| 10 |
+
predictor = CasePuncPredictor('checkpoint', lang="en")
|
| 11 |
+
|
| 12 |
+
text = " ".join(open(sys.argv[1]).readlines())
|
| 13 |
+
tokens = list(enumerate(predictor.tokenize(text)))
|
| 14 |
+
|
| 15 |
+
results = ""
|
| 16 |
+
for token, case_label, punc_label in predictor.predict(tokens, lambda x: x[1]):
|
| 17 |
+
prediction = predictor.map_punc_label(predictor.map_case_label(token[1], case_label), punc_label)
|
| 18 |
+
|
| 19 |
+
if token[1][0] == '\'' or (len(results) > 0 and results[-1] == '\''):
|
| 20 |
+
results = results + prediction
|
| 21 |
+
elif token[1][0] != '#':
|
| 22 |
+
results = results + ' ' + prediction
|
| 23 |
+
else:
|
| 24 |
+
results = results + prediction
|
| 25 |
+
|
| 26 |
+
print (results.strip())
|
punctuation/vosk-recasepunc-en-0.22/recasepunc.py
ADDED
|
@@ -0,0 +1,742 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import collections
|
| 3 |
+
import os
|
| 4 |
+
import regex as re
|
| 5 |
+
#from mosestokenizer import *
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torch.optim as optim
|
| 11 |
+
import random
|
| 12 |
+
import unicodedata
|
| 13 |
+
import numpy as np
|
| 14 |
+
import argparse
|
| 15 |
+
from torch.utils.data import TensorDataset, DataLoader
|
| 16 |
+
|
| 17 |
+
from transformers import AutoModel, AutoTokenizer, BertTokenizer
|
| 18 |
+
|
| 19 |
+
default_config = argparse.Namespace(
|
| 20 |
+
seed=871253,
|
| 21 |
+
lang='en',
|
| 22 |
+
#flavor='flaubert/flaubert_base_uncased',
|
| 23 |
+
flavor=None,
|
| 24 |
+
max_length=256,
|
| 25 |
+
batch_size=16,
|
| 26 |
+
updates=24000,
|
| 27 |
+
period=1000,
|
| 28 |
+
lr=1e-5,
|
| 29 |
+
dab_rate=0.1,
|
| 30 |
+
device='cuda',
|
| 31 |
+
debug=False
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
default_flavors = {
|
| 35 |
+
'fr': 'flaubert/flaubert_base_uncased',
|
| 36 |
+
'en': 'bert-base-uncased',
|
| 37 |
+
'zh': 'ckiplab/bert-base-chinese',
|
| 38 |
+
'tr': 'dbmdz/bert-base-turkish-uncased',
|
| 39 |
+
'de': 'dbmdz/bert-base-german-uncased',
|
| 40 |
+
'pt': 'neuralmind/bert-base-portuguese-cased'
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
class Config(argparse.Namespace):
|
| 44 |
+
def __init__(self, **kwargs):
|
| 45 |
+
for key, value in default_config.__dict__.items():
|
| 46 |
+
setattr(self, key, value)
|
| 47 |
+
for key, value in kwargs.items():
|
| 48 |
+
setattr(self, key, value)
|
| 49 |
+
|
| 50 |
+
assert self.lang in ['fr', 'en', 'zh', 'tr', 'pt', 'de']
|
| 51 |
+
|
| 52 |
+
if 'lang' in kwargs and ('flavor' not in kwargs or kwargs['flavor'] is None):
|
| 53 |
+
self.flavor = default_flavors[self.lang]
|
| 54 |
+
|
| 55 |
+
#print(self.lang, self.flavor)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def init_random(seed):
|
| 59 |
+
# make sure everything is deterministic
|
| 60 |
+
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
| 61 |
+
#torch.use_deterministic_algorithms(True)
|
| 62 |
+
torch.manual_seed(seed)
|
| 63 |
+
torch.cuda.manual_seed_all(seed)
|
| 64 |
+
random.seed(seed)
|
| 65 |
+
np.random.seed(seed)
|
| 66 |
+
|
| 67 |
+
# NOTE: it is assumed in the implementation that y[:,0] is the punctuation label, and y[:,1] is the case label!
|
| 68 |
+
|
| 69 |
+
punctuation = {
|
| 70 |
+
'O': 0,
|
| 71 |
+
'COMMA': 1,
|
| 72 |
+
'PERIOD': 2,
|
| 73 |
+
'QUESTION': 3,
|
| 74 |
+
'EXCLAMATION': 4,
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
punctuation_syms = ['', ',', '.', ' ?', ' !']
|
| 78 |
+
|
| 79 |
+
case = {
|
| 80 |
+
'LOWER': 0,
|
| 81 |
+
'UPPER': 1,
|
| 82 |
+
'CAPITALIZE': 2,
|
| 83 |
+
'OTHER': 3,
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class Model(nn.Module):
|
| 88 |
+
def __init__(self, flavor, device):
|
| 89 |
+
super().__init__()
|
| 90 |
+
self.bert = AutoModel.from_pretrained(flavor)
|
| 91 |
+
# need a proper way of determining representation size
|
| 92 |
+
size = self.bert.dim if hasattr(self.bert, 'dim') else self.bert.config.pooler_fc_size if hasattr(self.bert.config, 'pooler_fc_size') else self.bert.config.emb_dim if hasattr(self.bert.config, 'emb_dim') else self.bert.config.hidden_size
|
| 93 |
+
self.punc = nn.Linear(size, 5)
|
| 94 |
+
self.case = nn.Linear(size, 4)
|
| 95 |
+
self.dropout = nn.Dropout(0.3)
|
| 96 |
+
self.to(device)
|
| 97 |
+
|
| 98 |
+
def forward(self, x):
|
| 99 |
+
output = self.bert(x)
|
| 100 |
+
representations = self.dropout(F.gelu(output['last_hidden_state']))
|
| 101 |
+
punc = self.punc(representations)
|
| 102 |
+
case = self.case(representations)
|
| 103 |
+
return punc, case
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# randomly create sequences that align to punctuation boundaries
|
| 107 |
+
def drop_at_boundaries(rate, x, y, cls_token_id, sep_token_id, pad_token_id):
|
| 108 |
+
for i, dropped in enumerate(torch.rand((len(x),)) < rate):
|
| 109 |
+
if dropped:
|
| 110 |
+
# select all indices that are sentence endings
|
| 111 |
+
indices = (y[i,:,0] > 1).nonzero(as_tuple=True)[0]
|
| 112 |
+
if len(indices) < 2:
|
| 113 |
+
continue
|
| 114 |
+
start = indices[0] + 1
|
| 115 |
+
end = indices[random.randint(1, len(indices) - 1)] + 1
|
| 116 |
+
length = end - start
|
| 117 |
+
if length + 2 > len(x[i]):
|
| 118 |
+
continue
|
| 119 |
+
x[i, 0] = cls_token_id
|
| 120 |
+
x[i, 1: length + 1] = x[i, start: end].clone()
|
| 121 |
+
x[i, length + 1] = sep_token_id
|
| 122 |
+
x[i, length + 2:] = pad_token_id
|
| 123 |
+
y[i, 0] = 0
|
| 124 |
+
y[i, 1: length + 1] = y[i, start: end].clone()
|
| 125 |
+
y[i, length + 1:] = 0
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def compute_performance(config, model, loader):
|
| 129 |
+
device = config.device
|
| 130 |
+
criterion = nn.CrossEntropyLoss()
|
| 131 |
+
model.eval()
|
| 132 |
+
total_loss = all_correct1 = all_correct2 = num_loss = num_perf = 0
|
| 133 |
+
num_ref = collections.defaultdict(float)
|
| 134 |
+
num_hyp = collections.defaultdict(float)
|
| 135 |
+
num_correct = collections.defaultdict(float)
|
| 136 |
+
for x, y in loader:
|
| 137 |
+
x = x.long().to(device)
|
| 138 |
+
y = y.long().to(device)
|
| 139 |
+
y1 = y[:,:,0]
|
| 140 |
+
y2 = y[:,:,1]
|
| 141 |
+
with torch.no_grad():
|
| 142 |
+
y_scores1, y_scores2 = model(x.to(device))
|
| 143 |
+
loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1)))
|
| 144 |
+
loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1)))
|
| 145 |
+
loss = loss1 + loss2
|
| 146 |
+
y_pred1 = torch.max(y_scores1, 2)[1]
|
| 147 |
+
y_pred2 = torch.max(y_scores2, 2)[1]
|
| 148 |
+
for label in range(1, 5):
|
| 149 |
+
ref = (y1 == label)
|
| 150 |
+
hyp = (y_pred1 == label)
|
| 151 |
+
correct = (ref * hyp == 1)
|
| 152 |
+
num_ref[label] += ref.sum()
|
| 153 |
+
num_hyp[label] += hyp.sum()
|
| 154 |
+
num_correct[label] += correct.sum()
|
| 155 |
+
num_ref[0] += ref.sum()
|
| 156 |
+
num_hyp[0] += hyp.sum()
|
| 157 |
+
num_correct[0] += correct.sum()
|
| 158 |
+
all_correct1 += (y_pred1 == y1).sum()
|
| 159 |
+
all_correct2 += (y_pred2 == y2).sum()
|
| 160 |
+
total_loss += loss.item()
|
| 161 |
+
num_loss += len(y)
|
| 162 |
+
num_perf += len(y) * config.max_length
|
| 163 |
+
recall = {}
|
| 164 |
+
precision = {}
|
| 165 |
+
fscore = {}
|
| 166 |
+
for label in range(0, 5):
|
| 167 |
+
recall[label] = num_correct[label] / num_ref[label] if num_ref[label] > 0 else 0
|
| 168 |
+
precision[label] = num_correct[label] / num_hyp[label] if num_hyp[label] > 0 else 0
|
| 169 |
+
fscore[label] = (2 * recall[label] * precision[label] / (recall[label] + precision[label])).item() if recall[label] + precision[label] > 0 else 0
|
| 170 |
+
return total_loss / num_loss, all_correct2.item() / num_perf, all_correct1.item() / num_perf, fscore
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def fit(config, model, checkpoint_path, train_loader, valid_loader, iterations, valid_period=200, lr=1e-5):
|
| 174 |
+
device = config.device
|
| 175 |
+
criterion = nn.CrossEntropyLoss()
|
| 176 |
+
optimizer = optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=lr)
|
| 177 |
+
iteration = 0
|
| 178 |
+
while True:
|
| 179 |
+
model.train()
|
| 180 |
+
total_loss = num = 0
|
| 181 |
+
for x, y in tqdm(train_loader):
|
| 182 |
+
x = x.long().to(device)
|
| 183 |
+
y = y.long().to(device)
|
| 184 |
+
drop_at_boundaries(config.dab_rate, x, y, config.cls_token_id, config.sep_token_id, config.pad_token_id)
|
| 185 |
+
y1 = y[:,:,0]
|
| 186 |
+
y2 = y[:,:,1]
|
| 187 |
+
optimizer.zero_grad()
|
| 188 |
+
y_scores1, y_scores2 = model(x)
|
| 189 |
+
loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1)))
|
| 190 |
+
loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1)))
|
| 191 |
+
loss = loss1 + loss2
|
| 192 |
+
loss.backward()
|
| 193 |
+
optimizer.step()
|
| 194 |
+
total_loss += loss.item()
|
| 195 |
+
num += len(y)
|
| 196 |
+
if iteration % valid_period == valid_period - 1:
|
| 197 |
+
train_loss = total_loss / num
|
| 198 |
+
valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore = compute_performance(config, model, valid_loader)
|
| 199 |
+
torch.save({
|
| 200 |
+
'iteration': iteration + 1,
|
| 201 |
+
'model_state_dict': model.state_dict(),
|
| 202 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 203 |
+
'train_loss': train_loss,
|
| 204 |
+
'valid_loss': valid_loss,
|
| 205 |
+
'valid_accuracy_case': valid_accuracy_case,
|
| 206 |
+
'valid_accuracy_punc': valid_accuracy_punc,
|
| 207 |
+
'valid_fscore': valid_fscore,
|
| 208 |
+
'config': config.__dict__,
|
| 209 |
+
}, '%s.%d' % (checkpoint_path, iteration + 1))
|
| 210 |
+
print(iteration + 1, train_loss, valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore)
|
| 211 |
+
total_loss = num = 0
|
| 212 |
+
|
| 213 |
+
iteration += 1
|
| 214 |
+
if iteration > iterations:
|
| 215 |
+
return
|
| 216 |
+
|
| 217 |
+
sys.stderr.flush()
|
| 218 |
+
sys.stdout.flush()
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def batchify(max_length, x, y):
|
| 222 |
+
print (x.shape)
|
| 223 |
+
print (y.shape)
|
| 224 |
+
x = x[:(len(x) // max_length) * max_length].reshape(-1, max_length)
|
| 225 |
+
y = y[:(len(y) // max_length) * max_length, :].reshape(-1, max_length, 2)
|
| 226 |
+
return x, y
|
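batchify just cuts the flat token/label streams into fixed-length windows and drops the remainder. A toy illustration with max_length=4 instead of the real 256:
import torch
x = torch.arange(10)                 # 10 token ids
y = torch.zeros(10, 2)               # matching (punc, case) labels
max_length = 4
x = x[:(len(x) // max_length) * max_length].reshape(-1, max_length)        # shape (2, 4); last 2 ids dropped
y = y[:(len(y) // max_length) * max_length, :].reshape(-1, max_length, 2)  # shape (2, 4, 2)
print(x.shape, y.shape)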
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def train(config, train_x_fn, train_y_fn, valid_x_fn, valid_y_fn, checkpoint_path):
|
| 230 |
+
X_train, Y_train = batchify(config.max_length, torch.load(train_x_fn), torch.load(train_y_fn))
|
| 231 |
+
X_valid, Y_valid = batchify(config.max_length, torch.load(valid_x_fn), torch.load(valid_y_fn))
|
| 232 |
+
|
| 233 |
+
train_set = TensorDataset(X_train, Y_train)
|
| 234 |
+
valid_set = TensorDataset(X_valid, Y_valid)
|
| 235 |
+
|
| 236 |
+
train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
|
| 237 |
+
valid_loader = DataLoader(valid_set, batch_size=config.batch_size)
|
| 238 |
+
|
| 239 |
+
model = Model(config.flavor, config.device)
|
| 240 |
+
|
| 241 |
+
fit(config, model, checkpoint_path, train_loader, valid_loader, config.updates, config.period, config.lr)
|
| 242 |
+
|
| 243 |
+
|
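Going by the argument lists of train() above and the actions dispatched in main() further down, an end-to-end run would look roughly like the following (file names are placeholders, not files shipped with this model):
cat corpus.train.txt | python3 recasepunc.py preprocess --lang en > train.tsv
cat corpus.valid.txt | python3 recasepunc.py preprocess --lang en > valid.tsv
python3 recasepunc.py tensorize train.tsv train.x train.y --lang en
python3 recasepunc.py tensorize valid.tsv valid.x valid.y --lang en
python3 recasepunc.py train train.x train.y valid.x valid.y checkpoint --lang en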
| 244 |
+
def run_eval(config, test_x_fn, test_y_fn, checkpoint_path):
|
| 245 |
+
X_test, Y_test = batchify(config.max_length, torch.load(test_x_fn), torch.load(test_y_fn))
|
| 246 |
+
test_set = TensorDataset(X_test, Y_test)
|
| 247 |
+
test_loader = DataLoader(test_set, batch_size=config.batch_size)
|
| 248 |
+
|
| 249 |
+
loaded = torch.load(checkpoint_path, map_location=config.device)
|
| 250 |
+
if 'config' in loaded:
|
| 251 |
+
config = Config(**loaded['config'])
|
| 252 |
+
init(config)
|
| 253 |
+
|
| 254 |
+
model = Model(config.flavor, config.device)
|
| 255 |
+
model.load_state_dict(loaded['model_state_dict'])
|
| 256 |
+
|
| 257 |
+
print(*compute_performance(config, model, test_loader))
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def recase(token, label):
|
| 261 |
+
if label == case['LOWER']:
|
| 262 |
+
return token.lower()
|
| 263 |
+
elif label == case['CAPITALIZE']:
|
| 264 |
+
return token.lower().capitalize()
|
| 265 |
+
elif label == case['UPPER']:
|
| 266 |
+
return token.upper()
|
| 267 |
+
else:
|
| 268 |
+
return token
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class CasePuncPredictor:
|
| 272 |
+
def __init__(self, checkpoint_path, lang=default_config.lang, flavor=default_config.flavor, device=default_config.device):
|
| 273 |
+
loaded = torch.load(checkpoint_path, map_location=device if torch.cuda.is_available() else 'cpu')
|
| 274 |
+
if 'config' in loaded:
|
| 275 |
+
self.config = Config(**loaded['config'])
|
| 276 |
+
else:
|
| 277 |
+
self.config = Config(lang=lang, flavor=flavor, device=device)
|
| 278 |
+
init(self.config)
|
| 279 |
+
|
| 280 |
+
self.model = Model(self.config.flavor, self.config.device)
|
| 281 |
+
self.model.load_state_dict(loaded['model_state_dict'])
|
| 282 |
+
self.model.eval()
|
| 283 |
+
self.model.to(self.config.device)
|
| 284 |
+
|
| 285 |
+
self.rev_case = {b: a for a, b in case.items()}
|
| 286 |
+
self.rev_punc = {b: a for a, b in punctuation.items()}
|
| 287 |
+
|
| 288 |
+
def tokenize(self, text):
|
| 289 |
+
return [self.config.cls_token] + self.config.tokenizer.tokenize(text) + [self.config.sep_token]
|
| 290 |
+
|
| 291 |
+
def predict(self, tokens, getter=lambda x: x):
|
| 292 |
+
max_length = self.config.max_length
|
| 293 |
+
device = self.config.device
|
| 294 |
+
if type(tokens) == str:
|
| 295 |
+
tokens = self.tokenize(tokens)
|
| 296 |
+
previous_label = punctuation['PERIOD']
|
| 297 |
+
for start in range(0, len(tokens), max_length):
|
| 298 |
+
instance = tokens[start: start + max_length]
|
| 299 |
+
if type(getter(instance[0])) == str:
|
| 300 |
+
ids = self.config.tokenizer.convert_tokens_to_ids(getter(token) for token in instance)
|
| 301 |
+
else:
|
| 302 |
+
ids = [getter(token) for token in instance]
|
| 303 |
+
if len(ids) < max_length:
|
| 304 |
+
ids += [0] * (max_length - len(ids))
|
| 305 |
+
x = torch.tensor([ids]).long().to(device)
|
| 306 |
+
y_scores1, y_scores2 = self.model(x)
|
| 307 |
+
y_pred1 = torch.max(y_scores1, 2)[1]
|
| 308 |
+
y_pred2 = torch.max(y_scores2, 2)[1]
|
| 309 |
+
for i, id, token, punc_label, case_label in zip(range(len(instance)), ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
|
| 310 |
+
if id == self.config.cls_token_id or id == self.config.sep_token_id:
|
| 311 |
+
continue
|
| 312 |
+
if previous_label != None and previous_label > 1:
|
| 313 |
+
if case_label in [case['LOWER'], case['OTHER']]: # LOWER, OTHER
|
| 314 |
+
case_label = case['CAPITALIZE']
|
| 315 |
+
if i + start == len(tokens) - 2 and punc_label == punctuation['O']:
|
| 316 |
+
punc_label = punctuation['PERIOD']
|
| 317 |
+
yield (token, self.rev_case[case_label], self.rev_punc[punc_label])
|
| 318 |
+
previous_label = punc_label
|
| 319 |
+
|
| 320 |
+
def map_case_label(self, token, case_label):
|
| 321 |
+
if token.endswith('</w>'):
|
| 322 |
+
token = token[:-4]
|
| 323 |
+
if token.startswith('##'):
|
| 324 |
+
token = token[2:]
|
| 325 |
+
return recase(token, case[case_label])
|
| 326 |
+
|
| 327 |
+
def map_punc_label(self, token, punc_label):
|
| 328 |
+
if token.endswith('</w>'):
|
| 329 |
+
token = token[:-4]
|
| 330 |
+
if token.startswith('##'):
|
| 331 |
+
token = token[2:]
|
| 332 |
+
return token + punctuation_syms[punctuation[punc_label]]
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def generate_predictions(config, checkpoint_path):
|
| 337 |
+
loaded = torch.load(checkpoint_path, map_location=config.device if torch.cuda.is_available() else 'cpu')
|
| 338 |
+
if 'config' in loaded:
|
| 339 |
+
config = Config(**loaded['config'])
|
| 340 |
+
init(config)
|
| 341 |
+
|
| 342 |
+
model = Model(config.flavor, config.device)
|
| 343 |
+
model.load_state_dict(loaded['model_state_dict'])
|
| 344 |
+
|
| 345 |
+
rev_case = {b: a for a, b in case.items()}
|
| 346 |
+
rev_punc = {b: a for a, b in punctuation.items()}
|
| 347 |
+
|
| 348 |
+
for line in sys.stdin:
|
| 349 |
+
# also drop punctuation that we may generate
|
| 350 |
+
line = ''.join([c for c in line if c not in mapped_punctuation])
|
| 351 |
+
if config.debug:
|
| 352 |
+
print(line)
|
| 353 |
+
tokens = [config.cls_token] + config.tokenizer.tokenize(line) + [config.sep_token]
|
| 354 |
+
if config.debug:
|
| 355 |
+
print(tokens)
|
| 356 |
+
previous_label = punctuation['PERIOD']
|
| 357 |
+
first_time = True
|
| 358 |
+
was_word = False
|
| 359 |
+
for start in range(0, len(tokens), config.max_length):
|
| 360 |
+
instance = tokens[start: start + config.max_length]
|
| 361 |
+
ids = config.tokenizer.convert_tokens_to_ids(instance)
|
| 362 |
+
#print(len(ids), file=sys.stderr)
|
| 363 |
+
if len(ids) < config.max_length:
|
| 364 |
+
ids += [config.pad_token_id] * (config.max_length - len(ids))
|
| 365 |
+
x = torch.tensor([ids]).long().to(config.device)
|
| 366 |
+
y_scores1, y_scores2 = model(x)
|
| 367 |
+
y_pred1 = torch.max(y_scores1, 2)[1]
|
| 368 |
+
y_pred2 = torch.max(y_scores2, 2)[1]
|
| 369 |
+
for id, token, punc_label, case_label in zip(ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
|
| 370 |
+
if config.debug:
|
| 371 |
+
print(id, token, punc_label, case_label, file=sys.stderr)
|
| 372 |
+
if id == config.cls_token_id or id == config.sep_token_id:
|
| 373 |
+
continue
|
| 374 |
+
if previous_label != None and previous_label > 1:
|
| 375 |
+
if case_label in [case['LOWER'], case['OTHER']]:
|
| 376 |
+
case_label = case['CAPITALIZE']
|
| 377 |
+
previous_label = punc_label
|
| 378 |
+
# different strategy due to sub-lexical token encoding in Flaubert
|
| 379 |
+
if config.lang == 'fr':
|
| 380 |
+
if token.endswith('</w>'):
|
| 381 |
+
cased_token = recase(token[:-4], case_label)
|
| 382 |
+
if was_word:
|
| 383 |
+
print(' ', end='')
|
| 384 |
+
print(cased_token + punctuation_syms[punc_label], end='')
|
| 385 |
+
was_word = True
|
| 386 |
+
else:
|
| 387 |
+
cased_token = recase(token, case_label)
|
| 388 |
+
if was_word:
|
| 389 |
+
print(' ', end='')
|
| 390 |
+
print(cased_token, end='')
|
| 391 |
+
was_word = False
|
| 392 |
+
else:
|
| 393 |
+
if token.startswith('##'):
|
| 394 |
+
cased_token = recase(token[2:], case_label)
|
| 395 |
+
print(cased_token, end='')
|
| 396 |
+
else:
|
| 397 |
+
cased_token = recase(token, case_label)
|
| 398 |
+
if not first_time:
|
| 399 |
+
print(' ', end='')
|
| 400 |
+
first_time = False
|
| 401 |
+
print(cased_token + punctuation_syms[punc_label], end='')
|
| 402 |
+
if previous_label == 0:
|
| 403 |
+
print('.', end='')
|
| 404 |
+
print()
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def label_for_case(token):
|
| 408 |
+
token = re.sub('[^\p{Han}\p{Ll}\p{Lu}]', '', token)
|
| 409 |
+
if token == token.lower():
|
| 410 |
+
return 'LOWER'
|
| 411 |
+
elif token == token.lower().capitalize():
|
| 412 |
+
return 'CAPITALIZE'
|
| 413 |
+
elif token == token.upper():
|
| 414 |
+
return 'UPPER'
|
| 415 |
+
else:
|
| 416 |
+
return 'OTHER'
|
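A few spot checks of label_for_case, run from this model directory so recasepunc.py is importable:
from recasepunc import label_for_case
print(label_for_case('hello'))    # LOWER
print(label_for_case('Hello'))    # CAPITALIZE
print(label_for_case('NASA'))     # UPPER
print(label_for_case('iPhone'))   # OTHER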
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def make_tensors(config, input_fn, output_x_fn, output_y_fn):
|
| 420 |
+
# count file lines without loading them
|
| 421 |
+
size = 0
|
| 422 |
+
with open(input_fn) as fp:
|
| 423 |
+
for line in fp:
|
| 424 |
+
size += 1
|
| 425 |
+
|
| 426 |
+
with open(input_fn) as fp:
|
| 427 |
+
X = torch.IntTensor(size)
|
| 428 |
+
Y = torch.ByteTensor(size, 2)
|
| 429 |
+
|
| 430 |
+
offset = 0
|
| 431 |
+
for n, line in enumerate(fp):
|
| 432 |
+
word, case_label, punc_label = line.strip().split('\t')
|
| 433 |
+
id = config.tokenizer.convert_tokens_to_ids(word)
|
| 434 |
+
if config.debug:
|
| 435 |
+
assert word.lower() == config.tokenizer.convert_ids_to_tokens(id)
|
| 436 |
+
X[offset] = id
|
| 437 |
+
Y[offset, 0] = punctuation[punc_label]
|
| 438 |
+
Y[offset, 1] = case[case_label]
|
| 439 |
+
offset += 1
|
| 440 |
+
|
| 441 |
+
torch.save(X, output_x_fn)
|
| 442 |
+
torch.save(Y, output_y_fn)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
mapped_punctuation = {
|
| 446 |
+
'.': 'PERIOD',
|
| 447 |
+
'...': 'PERIOD',
|
| 448 |
+
',': 'COMMA',
|
| 449 |
+
';': 'COMMA',
|
| 450 |
+
':': 'COMMA',
|
| 451 |
+
'(': 'COMMA',
|
| 452 |
+
')': 'COMMA',
|
| 453 |
+
'?': 'QUESTION',
|
| 454 |
+
'!': 'EXCLAMATION',
|
| 455 |
+
',': 'COMMA',
|
| 456 |
+
'!': 'EXCLAMATION',
|
| 457 |
+
'?': 'QUESTION',
|
| 458 |
+
';': 'COMMA',
|
| 459 |
+
':': 'COMMA',
|
| 460 |
+
'(': 'COMMA',
|
| 461 |
+
'(': 'COMMA',
|
| 462 |
+
')': 'COMMA',
|
| 463 |
+
'[': 'COMMA',
|
| 464 |
+
']': 'COMMA',
|
| 465 |
+
'【': 'COMMA',
|
| 466 |
+
'】': 'COMMA',
|
| 467 |
+
'└': 'COMMA',
|
| 468 |
+
'└ ': 'COMMA',
|
| 469 |
+
'_': 'O',
|
| 470 |
+
'。': 'PERIOD',
|
| 471 |
+
'、': 'COMMA', # enumeration comma
|
| 472 |
+
'、': 'COMMA',
|
| 473 |
+
'…': 'PERIOD',
|
| 474 |
+
'—': 'COMMA',
|
| 475 |
+
'「': 'COMMA',
|
| 476 |
+
'」': 'COMMA',
|
| 477 |
+
'.': 'PERIOD',
|
| 478 |
+
'《': 'O',
|
| 479 |
+
'》': 'O',
|
| 480 |
+
',': 'COMMA',
|
| 481 |
+
'“': 'O',
|
| 482 |
+
'”': 'O',
|
| 483 |
+
'"': 'O',
|
| 484 |
+
'-': 'O',
|
| 485 |
+
'-': 'O',
|
| 486 |
+
'〉': 'COMMA',
|
| 487 |
+
'〈': 'COMMA',
|
| 488 |
+
'↑': 'O',
|
| 489 |
+
'〔': 'COMMA',
|
| 490 |
+
'〕': 'COMMA',
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
def preprocess_text(config, max_token_count=-1):
|
| 494 |
+
global num_tokens_output
|
| 495 |
+
max_token_count = int(max_token_count)
|
| 496 |
+
num_tokens_output = 0
|
| 497 |
+
def process_segment(text, punctuation):
|
| 498 |
+
global num_tokens_output
|
| 499 |
+
text = text.replace('\t', ' ')
|
| 500 |
+
tokens = config.tokenizer.tokenize(text)
|
| 501 |
+
for i, token in enumerate(tokens):
|
| 502 |
+
case_label = label_for_case(token)
|
| 503 |
+
if i == len(tokens) - 1:
|
| 504 |
+
print(token.lower(), case_label, punctuation, sep='\t')
|
| 505 |
+
else:
|
| 506 |
+
print(token.lower(), case_label, 'O', sep='\t')
|
| 507 |
+
num_tokens_output += 1
|
| 508 |
+
# a bit too ugly, but alternative is to throw an exception
|
| 509 |
+
if max_token_count > 0 and num_tokens_output >= max_token_count:
|
| 510 |
+
sys.exit(0)
|
| 511 |
+
|
| 512 |
+
for line in sys.stdin:
|
| 513 |
+
line = line.strip()
|
| 514 |
+
if line != '':
|
| 515 |
+
line = unicodedata.normalize("NFC", line)
|
| 516 |
+
if config.debug:
|
| 517 |
+
print(line)
|
| 518 |
+
start = 0
|
| 519 |
+
for i, char in enumerate(line):
|
| 520 |
+
if char in mapped_punctuation:
|
| 521 |
+
if i > start and line[start: i].strip() != '':
|
| 522 |
+
process_segment(line[start: i], mapped_punctuation[char])
|
| 523 |
+
start = i + 1
|
| 524 |
+
if start < len(line):
|
| 525 |
+
process_segment(line[start:], 'PERIOD')
|
| 526 |
+
|
| 527 |
+
|
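Concretely, the preprocess action turns raw text on stdin into one token per line with tab-separated case and punctuation labels. For the input line "Hello, world." the output would be roughly the following (columns are tab-separated; this assumes the BERT tokenizer keeps both words as single pieces):
hello   CAPITALIZE   COMMA
world   LOWER        PERIOD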
| 528 |
+
def preprocess_text_old_fr(config):
|
| 529 |
+
assert config.lang == 'fr'
|
| 530 |
+
splitsents = MosesSentenceSplitter(config.lang)
|
| 531 |
+
tokenize = MosesTokenizer(config.lang, extra=['-no-escape'])
|
| 532 |
+
normalize = MosesPunctuationNormalizer(config.lang)
|
| 533 |
+
|
| 534 |
+
for line in sys.stdin:
|
| 535 |
+
if line.strip() != '':
|
| 536 |
+
for sentence in splitsents([normalize(line)]):
|
| 537 |
+
tokens = tokenize(sentence)
|
| 538 |
+
previous_token = None
|
| 539 |
+
for token in tokens:
|
| 540 |
+
if token in mapped_punctuation:
|
| 541 |
+
if previous_token != None:
|
| 542 |
+
print(previous_token, mapped_punctuation[token], sep='\t')
|
| 543 |
+
previous_token = None
|
| 544 |
+
elif not re.search('[\p{Han}\p{Ll}\p{Lu}\d]', token): # remove non-alphanumeric tokens
|
| 545 |
+
continue
|
| 546 |
+
else:
|
| 547 |
+
if previous_token != None:
|
| 548 |
+
print(previous_token, 'O', sep='\t')
|
| 549 |
+
previous_token = token
|
| 550 |
+
if previous_token != None:
|
| 551 |
+
print(previous_token, 'PERIOD', sep='\t')
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
# modification of the wordpiece tokenizer to keep case information even if vocab is lower cased
|
| 555 |
+
# forked from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py
|
| 556 |
+
|
| 557 |
+
class WordpieceTokenizer(object):
|
| 558 |
+
"""Runs WordPiece tokenization."""
|
| 559 |
+
|
| 560 |
+
def __init__(self, vocab, unk_token, max_input_chars_per_word=100, keep_case=True):
|
| 561 |
+
self.vocab = vocab
|
| 562 |
+
self.unk_token = unk_token
|
| 563 |
+
self.max_input_chars_per_word = max_input_chars_per_word
|
| 564 |
+
self.keep_case = keep_case
|
| 565 |
+
|
| 566 |
+
def tokenize(self, text):
|
| 567 |
+
"""
|
| 568 |
+
Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
|
| 569 |
+
tokenization using the given vocabulary.
|
| 570 |
+
For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.
|
| 571 |
+
Args:
|
| 572 |
+
text: A single token or whitespace separated tokens. This should have
|
| 573 |
+
already been passed through `BasicTokenizer`.
|
| 574 |
+
Returns:
|
| 575 |
+
A list of wordpiece tokens.
|
| 576 |
+
"""
|
| 577 |
+
|
| 578 |
+
output_tokens = []
|
| 579 |
+
for token in text.strip().split():
|
| 580 |
+
chars = list(token)
|
| 581 |
+
if len(chars) > self.max_input_chars_per_word:
|
| 582 |
+
output_tokens.append(self.unk_token)
|
| 583 |
+
continue
|
| 584 |
+
|
| 585 |
+
is_bad = False
|
| 586 |
+
start = 0
|
| 587 |
+
sub_tokens = []
|
| 588 |
+
while start < len(chars):
|
| 589 |
+
end = len(chars)
|
| 590 |
+
cur_substr = None
|
| 591 |
+
while start < end:
|
| 592 |
+
substr = "".join(chars[start:end])
|
| 593 |
+
if start > 0:
|
| 594 |
+
substr = "##" + substr
|
| 595 |
+
# optionally lowercase substring before checking for inclusion in vocab
|
| 596 |
+
if (self.keep_case and substr.lower() in self.vocab) or (substr in self.vocab):
|
| 597 |
+
cur_substr = substr
|
| 598 |
+
break
|
| 599 |
+
end -= 1
|
| 600 |
+
if cur_substr is None:
|
| 601 |
+
is_bad = True
|
| 602 |
+
break
|
| 603 |
+
sub_tokens.append(cur_substr)
|
| 604 |
+
start = end
|
| 605 |
+
|
| 606 |
+
if is_bad:
|
| 607 |
+
output_tokens.append(self.unk_token)
|
| 608 |
+
else:
|
| 609 |
+
output_tokens.extend(sub_tokens)
|
| 610 |
+
return output_tokens
|
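The effect of keep_case is that vocabulary lookup happens on the lowercased piece while the emitted piece keeps its original casing. A toy check with an invented vocabulary (not the real BERT vocab):
from recasepunc import WordpieceTokenizer
toy_vocab = {'un': 0, '##aff': 1, '##able': 2, '[UNK]': 3}
wp = WordpieceTokenizer(vocab=toy_vocab, unk_token='[UNK]')
print(wp.tokenize('UNaffable'))   # ['UN', '##aff', '##able'] -- case preserved, matched in lowercase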
| 611 |
+
|
| 612 |
+
|
| 613 |
+
# modification of XLM bpe tokenizer for keeping case information when vocab is lowercase
|
| 614 |
+
# forked from https://github.com/huggingface/transformers/blob/cd56f3fe7eae4a53a9880e3f5e8f91877a78271c/src/transformers/models/xlm/tokenization_xlm.py
|
| 615 |
+
def bpe(self, token):
|
| 616 |
+
def to_lower(pair):
|
| 617 |
+
#print(' ',pair)
|
| 618 |
+
return (pair[0].lower(), pair[1].lower())
|
| 619 |
+
|
| 620 |
+
from transformers.models.xlm.tokenization_xlm import get_pairs
|
| 621 |
+
|
| 622 |
+
word = tuple(token[:-1]) + (token[-1] + "</w>",)
|
| 623 |
+
if token in self.cache:
|
| 624 |
+
return self.cache[token]
|
| 625 |
+
pairs = get_pairs(word)
|
| 626 |
+
|
| 627 |
+
if not pairs:
|
| 628 |
+
return token + "</w>"
|
| 629 |
+
|
| 630 |
+
while True:
|
| 631 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(to_lower(pair), float("inf")))
|
| 632 |
+
#print(bigram)
|
| 633 |
+
if to_lower(bigram) not in self.bpe_ranks:
|
| 634 |
+
break
|
| 635 |
+
first, second = bigram
|
| 636 |
+
new_word = []
|
| 637 |
+
i = 0
|
| 638 |
+
while i < len(word):
|
| 639 |
+
try:
|
| 640 |
+
j = word.index(first, i)
|
| 641 |
+
except ValueError:
|
| 642 |
+
new_word.extend(word[i:])
|
| 643 |
+
break
|
| 644 |
+
else:
|
| 645 |
+
new_word.extend(word[i:j])
|
| 646 |
+
i = j
|
| 647 |
+
|
| 648 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
| 649 |
+
new_word.append(first + second)
|
| 650 |
+
i += 2
|
| 651 |
+
else:
|
| 652 |
+
new_word.append(word[i])
|
| 653 |
+
i += 1
|
| 654 |
+
new_word = tuple(new_word)
|
| 655 |
+
word = new_word
|
| 656 |
+
if len(word) == 1:
|
| 657 |
+
break
|
| 658 |
+
else:
|
| 659 |
+
pairs = get_pairs(word)
|
| 660 |
+
word = " ".join(word)
|
| 661 |
+
if word == "\n </w>":
|
| 662 |
+
word = "\n</w>"
|
| 663 |
+
self.cache[token] = word
|
| 664 |
+
return word
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
def init(config):
|
| 669 |
+
init_random(config.seed)
|
| 670 |
+
|
| 671 |
+
if config.lang == 'fr':
|
| 672 |
+
config.tokenizer = tokenizer = AutoTokenizer.from_pretrained(config.flavor, do_lower_case=False)
|
| 673 |
+
|
| 674 |
+
from transformers.models.xlm.tokenization_xlm import XLMTokenizer
|
| 675 |
+
assert isinstance(tokenizer, XLMTokenizer)
|
| 676 |
+
|
| 677 |
+
# monkey patch XLM tokenizer
|
| 678 |
+
import types
|
| 679 |
+
tokenizer.bpe = types.MethodType(bpe, tokenizer)
|
| 680 |
+
else:
|
| 681 |
+
# warning: needs to be BertTokenizer for monkey patching to work
|
| 682 |
+
config.tokenizer = tokenizer = BertTokenizer.from_pretrained(config.flavor, do_lower_case=False)
|
| 683 |
+
|
| 684 |
+
# warning: monkey patch tokenizer to keep case information
|
| 685 |
+
#from recasing_tokenizer import WordpieceTokenizer
|
| 686 |
+
config.tokenizer.wordpiece_tokenizer = WordpieceTokenizer(vocab=tokenizer.vocab, unk_token=tokenizer.unk_token)
|
| 687 |
+
|
| 688 |
+
if config.lang == 'fr':
|
| 689 |
+
config.pad_token_id = tokenizer.pad_token_id
|
| 690 |
+
config.cls_token_id = tokenizer.bos_token_id
|
| 691 |
+
config.cls_token = tokenizer.bos_token
|
| 692 |
+
config.sep_token_id = tokenizer.sep_token_id
|
| 693 |
+
config.sep_token = tokenizer.sep_token
|
| 694 |
+
else:
|
| 695 |
+
config.pad_token_id = tokenizer.pad_token_id
|
| 696 |
+
config.cls_token_id = tokenizer.cls_token_id
|
| 697 |
+
config.cls_token = tokenizer.cls_token
|
| 698 |
+
config.sep_token_id = tokenizer.sep_token_id
|
| 699 |
+
config.sep_token = tokenizer.sep_token
|
| 700 |
+
|
| 701 |
+
if not torch.cuda.is_available() and config.device == 'cuda':
|
| 702 |
+
print('WARNING: reverting to cpu as cuda is not available', file=sys.stderr)
|
| 703 |
+
config.device = torch.device(config.device if torch.cuda.is_available() else 'cpu')
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def main(config, action, args):
|
| 707 |
+
init(config)
|
| 708 |
+
|
| 709 |
+
if action == 'train':
|
| 710 |
+
train(config, *args)
|
| 711 |
+
elif action == 'eval':
|
| 712 |
+
run_eval(config, *args)
|
| 713 |
+
elif action == 'predict':
|
| 714 |
+
generate_predictions(config, *args)
|
| 715 |
+
elif action == 'tensorize':
|
| 716 |
+
make_tensors(config, *args)
|
| 717 |
+
elif action == 'preprocess':
|
| 718 |
+
preprocess_text(config, *args)
|
| 719 |
+
else:
|
| 720 |
+
print('invalid action "%s"' % action)
|
| 721 |
+
sys.exit(1)
|
| 722 |
+
|
| 723 |
+
if __name__ == '__main__':
|
| 724 |
+
parser = argparse.ArgumentParser()
|
| 725 |
+
parser.add_argument("action", help="train|eval|predict|tensorize|preprocess", type=str)
|
| 726 |
+
parser.add_argument("action_args", help="arguments for selected action", type=str, nargs='*')
|
| 727 |
+
parser.add_argument("--seed", help="random seed", default=default_config.seed, type=int)
|
| 728 |
+
parser.add_argument("--lang", help="language (fr, en, zh)", default=default_config.lang, type=str)
|
| 729 |
+
parser.add_argument("--flavor", help="bert flavor in transformers model zoo", default=default_config.flavor, type=str)
|
| 730 |
+
parser.add_argument("--max-length", help="maximum input length", default=default_config.max_length, type=int)
|
| 731 |
+
parser.add_argument("--batch-size", help="size of batches", default=default_config.batch_size, type=int)
|
| 732 |
+
parser.add_argument("--device", help="computation device (cuda, cpu)", default=default_config.device, type=str)
|
| 733 |
+
parser.add_argument("--debug", help="whether to output more debug info", default=default_config.debug, type=bool)
|
| 734 |
+
parser.add_argument("--updates", help="number of training updates to perform", default=default_config.updates, type=bool)
|
| 735 |
+
parser.add_argument("--period", help="validation period in updates", default=default_config.period, type=bool)
|
| 736 |
+
parser.add_argument("--lr", help="learning rate", default=default_config.lr, type=bool)
|
| 737 |
+
parser.add_argument("--dab-rate", help="drop at boundaries rate", default=default_config.dab_rate, type=bool)
|
| 738 |
+
config = Config(**parser.parse_args().__dict__)
|
| 739 |
+
|
| 740 |
+
main(config, config.action, config.action_args)
|
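With the shipped checkpoint, the simplest way to exercise this script is the predict action, which reads plain lowercase text from stdin and prints the recased, repunctuated version. Run inside the model directory (the checkpoint file is simply named checkpoint):
echo "hello how are you doing today" | python3 recasepunc.py predict checkpoint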
| 741 |
+
|
| 742 |
+
|
punctuation/vosk-recasepunc-en-0.22/vosk-adapted.txt
ADDED
|
@@ -0,0 +1,17 @@
|
| 1 |
+
the
|
| 2 |
+
the
|
| 3 |
+
the beijing and shanghai welcome to the market strata open i'm yvonne good morning and i'm david ingles counting down of course the diablo trade on the chinese
|
| 4 |
+
mainland here in hong kong let's get your top stories today taper and a timetable dominating the latest fed minutes as official debates the exit path meanwhile i got beijing heading the other way hinting at the first triple r cut in more than a year and after the didi debacle here china may move to close a loophole long used
|
| 5 |
+
by companies to take their listings abroad all to enhance that was a horrible mistake council yesterday from china as a maybe it's time to cut the triple r to help them with small businesses they are struggling from the rise of raw material costs the key question is how likely is this yeah what they say it chances are likely it's probably going to be up yet
|
| 6 |
+
the fact that they're saying it might actually already mean we're getting some sentiment coming through in terms of an improved material tracker ten year yield we'll get to that in just a moment in china we're now flirting with the three percent level equity markets futures are pointing up as you can see here in china though broadly speaking though we're down for a seven day across asia seventh day in the last excuse me
|
| 7 |
+
in the last eight sessions here have little commodity markets we're stabilising across your oil or oil prices we're still down five six per cent from highs though as far as that is concerned fx markets your story is guys can we change the police are we're looking at generally speaking the dollar that's very much in focus here so you look at that against the euro you look at that
|
| 8 |
+
against the chinese currency twenty four hours ago who would have thought we were talking about this sort of more divergence and starker labour discord between where you are in a pboc to easily in the fed and very quickly we alluded to this of course if one three percent on your chinese ten year yield and we're not one point three percent lower and lower
|
| 9 |
+
yields there is a charge for you china's top us ten year yield is at the bottom yeah the chinatown area lowest since we saw last year of september yup
|
| 10 |
+
yeah it is a really big major shift in china's central bank policy that's the key question could it be coming of course let's flash out that into what we heard from the cabinet there raising the possibility of a cut to the reserve requirement ratio to both the economy at the same time we also from a former pboc official sheng songcheng said the central bank should actually
|
| 11 |
+
cut rates he's not just talking about a triple r and either the second half is an important window when china's monetary policy can tilt towards loosening while remaining stable and the interest rates can be lowered in a reasonable and moderate manner let's get the take from also be as well whether daisy i'm david chiu here the short of it is
|
| 12 |
+
so i guess one point if we still haven't gotten that if in the event that we do their take is they it might be a little bit too aggressive to address some of the softness in the economy in other words what they're saying is it needs some help the economy maybe not this much yeah there preferring perhaps perhaps liquidity injections here and there but this might signal a bit too much
|
| 13 |
+
for when it comes to reflating the economy joining us out of the dice all this let's bring in wang tao ubi as head of asia economics and the chief china economists as well wang tao thanks much for joining us first off do you think this is actually a real possibility now
|
| 14 |
+
or well will shrink or fade contro as a frequently called using triple r cut as a tool so i think yes indeed it is a real possibility that they could do this however in the past whenever the state council called for this a few days to a couple of weeks later we were
|
| 15 |
+
would have we would see a triple r cut if they called for it and but it's worth noting that last year in june shoot at the chicago auto quote for it and by the pbc did not hold onto with any market so i i would say at this moment it's probably a relatively high likelihood but anything
|
| 16 |
+
the wording is really you know about mitigating the higher cost of commodity prices they impact on at an ease and make their effective conquered funding a bit lower so it's possible that it's going to be a targeted not a overall triple cut and i i don't think this really reflects a
|
| 17 |
+
wholesale shift in monetary policy i think very very much in the same state concrete statement also talked about
|
punctuation/vosk-recasepunc-en-0.22/vosk-adapted.txt.punc
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
The. The. The Beijing and Shanghai. Welcome to the market strata open. I'm Yvonne, good morning, and I'm David Ingles, counting down, of course, the Diablo trade on the Chinese mainland here in Hong Kong. Let's get your top stories today, taper and a timetable dominating the latest Fed minutes as official debates. The exit path. Meanwhile, I got Beijing heading the other way, hinting at the first triple R cut in more than a year. And after the Didi debacle here, China may move to close a loophole. Long used by companies to take their listings abroad, all to enhance. That was a horrible mistake. Council yesterday from China as a. Maybe it's time to cut the triple R to help them with small businesses they are struggling from the rise of raw material costs. The key question is, how likely is this ? Yeah, what they say it. Chances are likely it's probably going to be up yet. The fact that they're saying it might actually already mean we're getting some sentiment coming through in terms of an improved material tracker. Ten year yield. We'll get to that in just a moment. In China. We're now flirting with the three percent level equity markets futures are pointing up. As you can see here in China, though. Broadly speaking, though, we're down for a seven day across Asia. Seventh day in the last. Excuse me, in the last eight sessions here have little commodity markets. We're stabilising across your oil or oil prices. We're still down five, six per cent from highs, though as far as that is concerned FX markets. Your story is, guys, can we change the police are we're looking at, generally speaking, the dollar. That's very much in focus here. So you look at that against the euro. You look at that against the Chinese currency Twenty four hours ago. Who would have thought we were talking about this sort of more divergence and starker labour discord between where you are in a PBOC to easily in the Fed and very quickly. We alluded to this, Of course, if one three percent on your Chinese ten year yield and we're not one point three percent lower and lower yields, there is a charge for you. China's top US ten year yield is at the bottom. Yeah, the Chinatown area lowest since we saw last year of September. Yup. Yeah, it is a really big major shift in China's central bank policy. That's the key question. Could it be coming ? Of course. Let's flash out that into what we heard from the cabinet there, raising the possibility of a cut to the reserve requirement ratio to both the economy at the same time. We also from a former PBOC official, Sheng Songcheng said the central bank should actually cut rates. He's not just talking about a triple R. And either the second half is an important window when China's monetary policy can tilt towards loosening while remaining stable and the interest rates can be lowered in a reasonable and moderate manner. Let's get the take from also be as well, whether Daisy, I'm David Chiu here, the short of it is so I guess one point, if we still haven't gotten that if in the event that we do their take is they, it might be a little bit too aggressive to address some of the softness in the economy. In other words, what they're saying is it needs some help. The economy, maybe not this much. Yeah, there, preferring perhaps perhaps liquidity injections here and there. But this might signal a bit too much for when it comes to reflating the economy. Joining us out of the dice. All this, Let's bring in Wang Tao Ubi as head of Asia Economics, and the chief China economists as well. 
Wang Tao, thanks much for joining us. First off, do you think this is actually a real possibility now or well will shrink or fade ? Contro as a frequently called using triple R cut as a tool. So I think yes, indeed, it is a real possibility. That they could do this. However, in the past, whenever the State Council called for this a few days to a couple of weeks later, we were. Would have we would see a triple R cut if they called for it. And. But it's worth noting that last year in June, shoot at the Chicago auto quote for it and by the PBC did not hold onto with any market so I. I would say at this moment it's probably a relatively high likelihood, but anything. The wording is really, you know about mitigating the higher cost of commodity prices they impact on at an ease and make their effective conquered funding a bit lower. So it's possible that it's going to be a targeted, not a overall triple cut and I. I don't think this really reflects a wholesale shift in monetary policy. I think very, very much in the same state. Concrete statement also talked about.
|
punctuation/vosk-recasepunc-ru-0.22.7z
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f23cc633c06d910e056234d6b83a11a3683e582abc56c30771dbec98a91034de
|
| 3 |
+
size 1639885297
|
punctuation/vosk-recasepunc-ru-0.22/README
ADDED
|
@@ -0,0 +1,7 @@
|
| 1 |
+
1. Install pytorch and transformers:
|
| 2 |
+
|
| 3 |
+
pip3 install transformers
|
| 4 |
+
|
| 5 |
+
2. Run python3 example.py ru-test.txt
|
| 6 |
+
|
| 7 |
+
3. Compare with ru-test.txt.orig
|
punctuation/vosk-recasepunc-ru-0.22/checkpoint
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:61fd424795c046963f88534071abde0813a4a6c66c07f0335b013825e536c1ae
|
| 3 |
+
size 2134070889
|
punctuation/vosk-recasepunc-ru-0.22/example.py
ADDED
|
@@ -0,0 +1,23 @@
|
| 1 |
+
import sys
|
| 2 |
+
import time
|
| 3 |
+
from transformers import logging
|
| 4 |
+
from recasepunc import CasePuncPredictor
|
| 5 |
+
from recasepunc import WordpieceTokenizer
|
| 6 |
+
from recasepunc import Config
|
| 7 |
+
|
| 8 |
+
logging.set_verbosity_error()
|
| 9 |
+
|
| 10 |
+
predictor = CasePuncPredictor('checkpoint', lang="ru")
|
| 11 |
+
|
| 12 |
+
text = " ".join(open(sys.argv[1]).readlines())
|
| 13 |
+
tokens = list(enumerate(predictor.tokenize(text)))
|
| 14 |
+
|
| 15 |
+
results = ""
|
| 16 |
+
for token, case_label, punc_label in predictor.predict(tokens, lambda x: x[1]):
|
| 17 |
+
prediction = predictor.map_punc_label(predictor.map_case_label(token[1], case_label), punc_label)
|
| 18 |
+
if token[1][0] != '#':
|
| 19 |
+
results = results + ' ' + prediction
|
| 20 |
+
else:
|
| 21 |
+
results = results + prediction
|
| 22 |
+
|
| 23 |
+
print (results.strip())
|
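The token[1][0] != '#' test above is what re-attaches WordPiece continuation pieces: a piece starting with '##' is glued to the previous word, anything else is preceded by a space. The same idea in isolation, with invented pieces rather than actual model output:
pieces = ['при', '##вет', 'как', 'дела']
out = ''
for p in pieces:
    if p.startswith('##'):
        out += p[2:]        # continuation piece: glue to the previous word
    else:
        out += ' ' + p      # new word: add a separating space
print(out.strip())          # привет как дела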
punctuation/vosk-recasepunc-ru-0.22/recasepunc.py
ADDED
|
@@ -0,0 +1,743 @@
|
| 1 |
+
import sys
|
| 2 |
+
import collections
|
| 3 |
+
import os
|
| 4 |
+
import regex as re
|
| 5 |
+
#from mosestokenizer import *
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torch.optim as optim
|
| 11 |
+
import random
|
| 12 |
+
import unicodedata
|
| 13 |
+
import numpy as np
|
| 14 |
+
import argparse
|
| 15 |
+
from torch.utils.data import TensorDataset, DataLoader
|
| 16 |
+
|
| 17 |
+
from transformers import AutoModel, AutoTokenizer, BertTokenizer
|
| 18 |
+
|
| 19 |
+
default_config = argparse.Namespace(
|
| 20 |
+
seed=871253,
|
| 21 |
+
lang='ru',
|
| 22 |
+
#flavor='flaubert/flaubert_base_uncased',
|
| 23 |
+
flavor=None,
|
| 24 |
+
max_length=256,
|
| 25 |
+
batch_size=16,
|
| 26 |
+
updates=50000,
|
| 27 |
+
period=1000,
|
| 28 |
+
lr=1e-5,
|
| 29 |
+
dab_rate=0.1,
|
| 30 |
+
device='cuda',
|
| 31 |
+
debug=False
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
default_flavors = {
|
| 35 |
+
'fr': 'flaubert/flaubert_base_uncased',
|
| 36 |
+
'en': 'bert-base-uncased',
|
| 37 |
+
'zh': 'ckiplab/bert-base-chinese',
|
| 38 |
+
'tr': 'dbmdz/bert-base-turkish-uncased',
|
| 39 |
+
'de': 'dbmdz/bert-base-german-uncased',
|
| 40 |
+
'pt': 'neuralmind/bert-base-portuguese-cased',
|
| 41 |
+
'ru': 'DeepPavlov/rubert-base-cased'
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
class Config(argparse.Namespace):
|
| 45 |
+
def __init__(self, **kwargs):
|
| 46 |
+
for key, value in default_config.__dict__.items():
|
| 47 |
+
setattr(self, key, value)
|
| 48 |
+
for key, value in kwargs.items():
|
| 49 |
+
setattr(self, key, value)
|
| 50 |
+
|
| 51 |
+
assert self.lang in ['fr', 'en', 'zh', 'tr', 'pt', 'de', 'ru']
|
| 52 |
+
|
| 53 |
+
if 'lang' in kwargs and ('flavor' not in kwargs or kwargs['flavor'] is None):
|
| 54 |
+
self.flavor = default_flavors[self.lang]
|
| 55 |
+
|
| 56 |
+
#print(self.lang, self.flavor)
|
| 57 |
+
|
| 58 |
+
|
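When only a language is given, Config falls back to the matching entry of default_flavors. A quick check, run from this model directory:
from recasepunc import Config
cfg = Config(lang='ru')    # no flavor given, so default_flavors['ru'] is used
print(cfg.flavor)          # DeepPavlov/rubert-base-cased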
| 59 |
+
def init_random(seed):
|
| 60 |
+
# make sure everything is deterministic
|
| 61 |
+
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
| 62 |
+
#torch.use_deterministic_algorithms(True)
|
| 63 |
+
torch.manual_seed(seed)
|
| 64 |
+
torch.cuda.manual_seed_all(seed)
|
| 65 |
+
random.seed(seed)
|
| 66 |
+
np.random.seed(seed)
|
| 67 |
+
|
| 68 |
+
# NOTE: it is assumed in the implementation that y[:,0] is the punctuation label, and y[:,1] is the case label!
|
| 69 |
+
|
| 70 |
+
punctuation = {
|
| 71 |
+
'O': 0,
|
| 72 |
+
'COMMA': 1,
|
| 73 |
+
'PERIOD': 2,
|
| 74 |
+
'QUESTION': 3,
|
| 75 |
+
'EXCLAMATION': 4,
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
punctuation_syms = ['', ',', '.', ' ?', ' !']
|
| 79 |
+
|
| 80 |
+
case = {
|
| 81 |
+
'LOWER': 0,
|
| 82 |
+
'UPPER': 1,
|
| 83 |
+
'CAPITALIZE': 2,
|
| 84 |
+
'OTHER': 3,
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class Model(nn.Module):
|
| 89 |
+
def __init__(self, flavor, device):
|
| 90 |
+
super().__init__()
|
| 91 |
+
self.bert = AutoModel.from_pretrained(flavor)
|
| 92 |
+
# need a proper way of determining representation size
|
| 93 |
+
size = self.bert.dim if hasattr(self.bert, 'dim') else self.bert.config.pooler_fc_size if hasattr(self.bert.config, 'pooler_fc_size') else self.bert.config.emb_dim if hasattr(self.bert.config, 'emb_dim') else self.bert.config.hidden_size
|
| 94 |
+
self.punc = nn.Linear(size, 5)
|
| 95 |
+
self.case = nn.Linear(size, 4)
|
| 96 |
+
self.dropout = nn.Dropout(0.3)
|
| 97 |
+
self.to(device)
|
| 98 |
+
|
| 99 |
+
def forward(self, x):
|
| 100 |
+
output = self.bert(x)
|
| 101 |
+
representations = self.dropout(F.gelu(output['last_hidden_state']))
|
| 102 |
+
punc = self.punc(representations)
|
| 103 |
+
case = self.case(representations)
|
| 104 |
+
return punc, case
|
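As a sanity check on the two heads, the forward pass maps a batch of token ids to per-token scores over 5 punctuation labels and 4 case labels. A minimal sketch (this downloads the DeepPavlov/rubert-base-cased weights on first use):
import torch
from recasepunc import Model
model = Model('DeepPavlov/rubert-base-cased', 'cpu')
x = torch.zeros(2, 256, dtype=torch.long)     # (batch, max_length) of token ids
punc_scores, case_scores = model(x)
print(punc_scores.shape, case_scores.shape)   # torch.Size([2, 256, 5]) torch.Size([2, 256, 4])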
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# randomly create sequences that align to punctuation boundaries
|
| 108 |
+
def drop_at_boundaries(rate, x, y, cls_token_id, sep_token_id, pad_token_id):
|
| 109 |
+
for i, dropped in enumerate(torch.rand((len(x),)) < rate):
|
| 110 |
+
if dropped:
|
| 111 |
+
# select all indices that are sentence endings
|
| 112 |
+
indices = (y[i,:,0] > 1).nonzero(as_tuple=True)[0]
|
| 113 |
+
if len(indices) < 2:
|
| 114 |
+
continue
|
| 115 |
+
start = indices[0] + 1
|
| 116 |
+
end = indices[random.randint(1, len(indices) - 1)] + 1
|
| 117 |
+
length = end - start
|
| 118 |
+
if length + 2 > len(x[i]):
|
| 119 |
+
continue
|
| 120 |
+
x[i, 0] = cls_token_id
|
| 121 |
+
x[i, 1: length + 1] = x[i, start: end].clone()
|
| 122 |
+
x[i, length + 1] = sep_token_id
|
| 123 |
+
x[i, length + 2:] = pad_token_id
|
| 124 |
+
y[i, 0] = 0
|
| 125 |
+
y[i, 1: length + 1] = y[i, start: end].clone()
|
| 126 |
+
y[i, length + 1:] = 0
|
| 127 |
+
|
| 128 |
+
|
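In effect, drop_at_boundaries replaces a randomly chosen fraction of training windows with a shorter span that starts right after one sentence end and stops at a later one, re-wrapped in CLS/SEP and padded, so the model also sees inputs aligned to sentence boundaries. A toy trace with ids 101/102/0 standing in for CLS/SEP/PAD:
import torch
from recasepunc import drop_at_boundaries
x = torch.tensor([[101, 11, 12, 13, 14, 15, 16, 102]])
y = torch.zeros(1, 8, 2, dtype=torch.long)
y[0, 3, 0] = 2    # PERIOD after the third word piece
y[0, 6, 0] = 2    # PERIOD after the sixth word piece
drop_at_boundaries(1.0, x, y, cls_token_id=101, sep_token_id=102, pad_token_id=0)
print(x)          # tensor([[101,  14,  15,  16, 102,   0,   0,   0]])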
| 129 |
+
def compute_performance(config, model, loader):
|
| 130 |
+
device = config.device
|
| 131 |
+
criterion = nn.CrossEntropyLoss()
|
| 132 |
+
model.eval()
|
| 133 |
+
total_loss = all_correct1 = all_correct2 = num_loss = num_perf = 0
|
| 134 |
+
num_ref = collections.defaultdict(float)
|
| 135 |
+
num_hyp = collections.defaultdict(float)
|
| 136 |
+
num_correct = collections.defaultdict(float)
|
| 137 |
+
for x, y in loader:
|
| 138 |
+
x = x.long().to(device)
|
| 139 |
+
y = y.long().to(device)
|
| 140 |
+
y1 = y[:,:,0]
|
| 141 |
+
y2 = y[:,:,1]
|
| 142 |
+
with torch.no_grad():
|
| 143 |
+
y_scores1, y_scores2 = model(x.to(device))
|
| 144 |
+
loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1)))
|
| 145 |
+
loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1)))
|
| 146 |
+
loss = loss1 + loss2
|
| 147 |
+
y_pred1 = torch.max(y_scores1, 2)[1]
|
| 148 |
+
y_pred2 = torch.max(y_scores2, 2)[1]
|
| 149 |
+
for label in range(1, 5):
|
| 150 |
+
ref = (y1 == label)
|
| 151 |
+
hyp = (y_pred1 == label)
|
| 152 |
+
correct = (ref * hyp == 1)
|
| 153 |
+
num_ref[label] += ref.sum()
|
| 154 |
+
num_hyp[label] += hyp.sum()
|
| 155 |
+
num_correct[label] += correct.sum()
|
| 156 |
+
num_ref[0] += ref.sum()
|
| 157 |
+
num_hyp[0] += hyp.sum()
|
| 158 |
+
num_correct[0] += correct.sum()
|
| 159 |
+
all_correct1 += (y_pred1 == y1).sum()
|
| 160 |
+
all_correct2 += (y_pred2 == y2).sum()
|
| 161 |
+
total_loss += loss.item()
|
| 162 |
+
num_loss += len(y)
|
| 163 |
+
num_perf += len(y) * config.max_length
|
| 164 |
+
recall = {}
|
| 165 |
+
precision = {}
|
| 166 |
+
fscore = {}
|
| 167 |
+
for label in range(0, 5):
|
| 168 |
+
recall[label] = num_correct[label] / num_ref[label] if num_ref[label] > 0 else 0
|
| 169 |
+
precision[label] = num_correct[label] / num_hyp[label] if num_hyp[label] > 0 else 0
|
| 170 |
+
fscore[label] = (2 * recall[label] * precision[label] / (recall[label] + precision[label])).item() if recall[label] + precision[label] > 0 else 0
|
| 171 |
+
return total_loss / num_loss, all_correct2.item() / num_perf, all_correct1.item() / num_perf, fscore
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def fit(config, model, checkpoint_path, train_loader, valid_loader, iterations, valid_period=200, lr=1e-5):
|
| 175 |
+
device = config.device
|
| 176 |
+
criterion = nn.CrossEntropyLoss()
|
| 177 |
+
optimizer = optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=lr)
|
| 178 |
+
iteration = 0
|
| 179 |
+
while True:
|
| 180 |
+
model.train()
|
| 181 |
+
total_loss = num = 0
|
| 182 |
+
for x, y in tqdm(train_loader):
|
| 183 |
+
x = x.long().to(device)
|
| 184 |
+
y = y.long().to(device)
|
| 185 |
+
drop_at_boundaries(config.dab_rate, x, y, config.cls_token_id, config.sep_token_id, config.pad_token_id)
|
| 186 |
+
y1 = y[:,:,0]
|
| 187 |
+
y2 = y[:,:,1]
|
| 188 |
+
optimizer.zero_grad()
|
| 189 |
+
y_scores1, y_scores2 = model(x)
|
| 190 |
+
loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1)))
|
| 191 |
+
loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1)))
|
| 192 |
+
loss = loss1 + loss2
|
| 193 |
+
loss.backward()
|
| 194 |
+
optimizer.step()
|
| 195 |
+
total_loss += loss.item()
|
| 196 |
+
num += len(y)
|
| 197 |
+
if iteration % valid_period == valid_period - 1:
|
| 198 |
+
train_loss = total_loss / num
|
| 199 |
+
valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore = compute_performance(config, model, valid_loader)
|
| 200 |
+
torch.save({
|
| 201 |
+
'iteration': iteration + 1,
|
| 202 |
+
'model_state_dict': model.state_dict(),
|
| 203 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 204 |
+
'train_loss': train_loss,
|
| 205 |
+
'valid_loss': valid_loss,
|
| 206 |
+
'valid_accuracy_case': valid_accuracy_case,
|
| 207 |
+
'valid_accuracy_punc': valid_accuracy_punc,
|
| 208 |
+
'valid_fscore': valid_fscore,
|
| 209 |
+
'config': config.__dict__,
|
| 210 |
+
}, '%s.%d' % (checkpoint_path, iteration + 1))
|
| 211 |
+
print(iteration + 1, train_loss, valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore)
|
| 212 |
+
total_loss = num = 0
|
| 213 |
+
|
| 214 |
+
iteration += 1
|
| 215 |
+
if iteration > iterations:
|
| 216 |
+
return
|
| 217 |
+
|
| 218 |
+
sys.stderr.flush()
|
| 219 |
+
sys.stdout.flush()
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def batchify(max_length, x, y):
|
| 223 |
+
print (x.shape)
|
| 224 |
+
print (y.shape)
|
| 225 |
+
x = x[:(len(x) // max_length) * max_length].reshape(-1, max_length)
|
| 226 |
+
y = y[:(len(y) // max_length) * max_length, :].reshape(-1, max_length, 2)
|
| 227 |
+
return x, y
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def train(config, train_x_fn, train_y_fn, valid_x_fn, valid_y_fn, checkpoint_path):
|
| 231 |
+
X_train, Y_train = batchify(config.max_length, torch.load(train_x_fn), torch.load(train_y_fn))
|
| 232 |
+
X_valid, Y_valid = batchify(config.max_length, torch.load(valid_x_fn), torch.load(valid_y_fn))
|
| 233 |
+
|
| 234 |
+
train_set = TensorDataset(X_train, Y_train)
|
| 235 |
+
valid_set = TensorDataset(X_valid, Y_valid)
|
| 236 |
+
|
| 237 |
+
train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
|
| 238 |
+
valid_loader = DataLoader(valid_set, batch_size=config.batch_size)
|
| 239 |
+
|
| 240 |
+
model = Model(config.flavor, config.device)
|
| 241 |
+
|
| 242 |
+
fit(config, model, checkpoint_path, train_loader, valid_loader, config.updates, config.period, config.lr)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def run_eval(config, test_x_fn, test_y_fn, checkpoint_path):
|
| 246 |
+
X_test, Y_test = batchify(config.max_length, torch.load(test_x_fn), torch.load(test_y_fn))
|
| 247 |
+
test_set = TensorDataset(X_test, Y_test)
|
| 248 |
+
test_loader = DataLoader(test_set, batch_size=config.batch_size)
|
| 249 |
+
|
| 250 |
+
loaded = torch.load(checkpoint_path, map_location=config.device)
|
| 251 |
+
if 'config' in loaded:
|
| 252 |
+
config = Config(**loaded['config'])
|
| 253 |
+
init(config)
|
| 254 |
+
|
| 255 |
+
model = Model(config.flavor, config.device)
|
| 256 |
+
model.load_state_dict(loaded['model_state_dict'])
|
| 257 |
+
|
| 258 |
+
print(*compute_performance(config, model, test_loader))
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def recase(token, label):
|
| 262 |
+
if label == case['LOWER']:
|
| 263 |
+
return token.lower()
|
| 264 |
+
elif label == case['CAPITALIZE']:
|
| 265 |
+
return token.lower().capitalize()
|
| 266 |
+
elif label == case['UPPER']:
|
| 267 |
+
return token.upper()
|
| 268 |
+
else:
|
| 269 |
+
return token
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class CasePuncPredictor:
|
| 273 |
+
def __init__(self, checkpoint_path, lang=default_config.lang, flavor=default_config.flavor, device=default_config.device):
|
| 274 |
+
loaded = torch.load(checkpoint_path, map_location=device if torch.cuda.is_available() else 'cpu')
|
| 275 |
+
if 'config' in loaded:
|
| 276 |
+
self.config = Config(**loaded['config'])
|
| 277 |
+
else:
|
| 278 |
+
self.config = Config(lang=lang, flavor=flavor, device=device)
|
| 279 |
+
init(self.config)
|
| 280 |
+
|
| 281 |
+
self.model = Model(self.config.flavor, self.config.device)
|
| 282 |
+
self.model.load_state_dict(loaded['model_state_dict'])
|
| 283 |
+
self.model.eval()
|
| 284 |
+
self.model.to(self.config.device)
|
| 285 |
+
|
| 286 |
+
self.rev_case = {b: a for a, b in case.items()}
|
| 287 |
+
self.rev_punc = {b: a for a, b in punctuation.items()}
|
| 288 |
+
|
| 289 |
+
def tokenize(self, text):
|
| 290 |
+
return [self.config.cls_token] + self.config.tokenizer.tokenize(text) + [self.config.sep_token]
|
| 291 |
+
|
| 292 |
+
def predict(self, tokens, getter=lambda x: x):
|
| 293 |
+
max_length = self.config.max_length
|
| 294 |
+
device = self.config.device
|
| 295 |
+
if type(tokens) == str:
|
| 296 |
+
tokens = self.tokenize(tokens)
|
| 297 |
+
previous_label = punctuation['PERIOD']
|
| 298 |
+
for start in range(0, len(tokens), max_length):
|
| 299 |
+
instance = tokens[start: start + max_length]
|
| 300 |
+
if type(getter(instance[0])) == str:
|
| 301 |
+
ids = self.config.tokenizer.convert_tokens_to_ids(getter(token) for token in instance)
|
| 302 |
+
else:
|
| 303 |
+
ids = [getter(token) for token in instance]
|
| 304 |
+
if len(ids) < max_length:
|
| 305 |
+
ids += [0] * (max_length - len(ids))
|
| 306 |
+
x = torch.tensor([ids]).long().to(device)
|
| 307 |
+
y_scores1, y_scores2 = self.model(x)
|
| 308 |
+
y_pred1 = torch.max(y_scores1, 2)[1]
|
| 309 |
+
y_pred2 = torch.max(y_scores2, 2)[1]
|
| 310 |
+
for i, id, token, punc_label, case_label in zip(range(len(instance)), ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
|
| 311 |
+
if id == self.config.cls_token_id or id == self.config.sep_token_id:
|
| 312 |
+
continue
|
| 313 |
+
if previous_label != None and previous_label > 1:
|
| 314 |
+
if case_label in [case['LOWER'], case['OTHER']]: # LOWER, OTHER
|
| 315 |
+
case_label = case['CAPITALIZE']
|
| 316 |
+
if i + start == len(tokens) - 2 and punc_label == punctuation['O']:
|
| 317 |
+
punc_label = punctuation['PERIOD']
|
| 318 |
+
yield (token, self.rev_case[case_label], self.rev_punc[punc_label])
|
| 319 |
+
previous_label = punc_label
|
| 320 |
+
|
| 321 |
+
def map_case_label(self, token, case_label):
|
| 322 |
+
if token.endswith('</w>'):
|
| 323 |
+
token = token[:-4]
|
| 324 |
+
if token.startswith('##'):
|
| 325 |
+
            token = token[2:]
        return recase(token, case[case_label])

    def map_punc_label(self, token, punc_label):
        if token.endswith('</w>'):
            token = token[:-4]
        if token.startswith('##'):
            token = token[2:]
        return token + punctuation_syms[punctuation[punc_label]]


def generate_predictions(config, checkpoint_path):
    loaded = torch.load(checkpoint_path, map_location=config.device if torch.cuda.is_available() else 'cpu')
    if 'config' in loaded:
        config = Config(**loaded['config'])
        init(config)

    model = Model(config.flavor, config.device)
    model.load_state_dict(loaded['model_state_dict'])

    rev_case = {b: a for a, b in case.items()}
    rev_punc = {b: a for a, b in punctuation.items()}

    for line in sys.stdin:
        # also drop punctuation that we may generate
        line = ''.join([c for c in line if c not in mapped_punctuation])
        if config.debug:
            print(line)
        tokens = [config.cls_token] + config.tokenizer.tokenize(line) + [config.sep_token]
        if config.debug:
            print(tokens)
        previous_label = punctuation['PERIOD']
        first_time = True
        was_word = False
        for start in range(0, len(tokens), config.max_length):
            instance = tokens[start: start + config.max_length]
            ids = config.tokenizer.convert_tokens_to_ids(instance)
            #print(len(ids), file=sys.stderr)
            if len(ids) < config.max_length:
                ids += [config.pad_token_id] * (config.max_length - len(ids))
            x = torch.tensor([ids]).long().to(config.device)
            y_scores1, y_scores2 = model(x)
            y_pred1 = torch.max(y_scores1, 2)[1]
            y_pred2 = torch.max(y_scores2, 2)[1]
            for id, token, punc_label, case_label in zip(ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
                if config.debug:
                    print(id, token, punc_label, case_label, file=sys.stderr)
                if id == config.cls_token_id or id == config.sep_token_id:
                    continue
                if previous_label != None and previous_label > 1:
                    if case_label in [case['LOWER'], case['OTHER']]:
                        case_label = case['CAPITALIZE']
                previous_label = punc_label
                # different strategy due to sub-lexical token encoding in Flaubert
                if config.lang == 'fr':
                    if token.endswith('</w>'):
                        cased_token = recase(token[:-4], case_label)
                        if was_word:
                            print(' ', end='')
                        print(cased_token + punctuation_syms[punc_label], end='')
                        was_word = True
                    else:
                        cased_token = recase(token, case_label)
                        if was_word:
                            print(' ', end='')
                        print(cased_token, end='')
                        was_word = False
                else:
                    if token.startswith('##'):
                        cased_token = recase(token[2:], case_label)
                        print(cased_token, end='')
                    else:
                        cased_token = recase(token, case_label)
                        if not first_time:
                            print(' ', end='')
                        first_time = False
                        print(cased_token + punctuation_syms[punc_label], end='')
        if previous_label == 0:
            print('.', end='')
        print()


def label_for_case(token):
    token = re.sub('[^\p{Han}\p{Ll}\p{Lu}]', '', token)
    if token == token.lower():
        return 'LOWER'
    elif token == token.lower().capitalize():
        return 'CAPITALIZE'
    elif token == token.upper():
        return 'UPPER'
    else:
        return 'OTHER'


def make_tensors(config, input_fn, output_x_fn, output_y_fn):
    # count file lines without loading them
    size = 0
    with open(input_fn) as fp:
        for line in fp:
            size += 1

    with open(input_fn) as fp:
        X = torch.IntTensor(size)
        Y = torch.ByteTensor(size, 2)

        offset = 0
        for n, line in enumerate(fp):
            word, case_label, punc_label = line.strip().split('\t')
            id = config.tokenizer.convert_tokens_to_ids(word)
            if config.debug:
                assert word.lower() == tokenizer.convert_ids_to_tokens(id)
            X[offset] = id
            Y[offset, 0] = punctuation[punc_label]
            Y[offset, 1] = case[case_label]
            offset += 1

    torch.save(X, output_x_fn)
    torch.save(Y, output_y_fn)

mapped_punctuation = {
    '.': 'PERIOD',
    '...': 'PERIOD',
    ',': 'COMMA',
    ';': 'COMMA',
    ':': 'COMMA',
    '(': 'COMMA',
    ')': 'COMMA',
    '?': 'QUESTION',
    '!': 'EXCLAMATION',
    '，': 'COMMA',
    '！': 'EXCLAMATION',
    '？': 'QUESTION',
    '；': 'COMMA',
    '：': 'COMMA',
    '（': 'COMMA',
    '(': 'COMMA',
    '）': 'COMMA',
    '[': 'COMMA',
    ']': 'COMMA',
    '【': 'COMMA',
    '】': 'COMMA',
    '└': 'COMMA',
    '└ ': 'COMMA',
    '_': 'O',
    '。': 'PERIOD',
    '、': 'COMMA', # enumeration comma
    '、': 'COMMA',
    '…': 'PERIOD',
    '—': 'COMMA',
    '「': 'COMMA',
    '」': 'COMMA',
    '．': 'PERIOD',
    '《': 'O',
    '》': 'O',
    '，': 'COMMA',
    '“': 'O',
    '”': 'O',
    '"': 'O',
    '-': 'O',
    '−': 'O',
    '〉': 'COMMA',
    '〈': 'COMMA',
    '↑': 'O',
    '〔': 'COMMA',
    '〕': 'COMMA',
}

def preprocess_text(config, max_token_count=-1):
    global num_tokens_output
    max_token_count = int(max_token_count)
    num_tokens_output = 0
    def process_segment(text, punctuation):
        global num_tokens_output
        text = text.replace('\t', ' ')
        tokens = config.tokenizer.tokenize(text)
        for i, token in enumerate(tokens):
            case_label = label_for_case(token)
            if i == len(tokens) - 1:
                print(token.lower(), case_label, punctuation, sep='\t')
            else:
                print(token.lower(), case_label, 'O', sep='\t')
            num_tokens_output += 1
            # a bit too ugly, but alternative is to throw an exception
            if max_token_count > 0 and num_tokens_output >= max_token_count:
                sys.exit(0)

    for line in sys.stdin:
        line = line.strip()
        if line != '':
            line = unicodedata.normalize("NFC", line)
            if config.debug:
                print(line)
            start = 0
            for i, char in enumerate(line):
                if char in mapped_punctuation:
                    if i > start and line[start: i].strip() != '':
                        process_segment(line[start: i], mapped_punctuation[char])
                    start = i + 1
            if start < len(line):
                process_segment(line[start:], 'PERIOD')


def preprocess_text_old_fr(config):
    assert config.lang == 'fr'
    splitsents = MosesSentenceSplitter(lang)
    tokenize = MosesTokenizer(lang, extra=['-no-escape'])
    normalize = MosesPunctuationNormalizer(lang)

    for line in sys.stdin:
        if line.strip() != '':
            for sentence in splitsents([normalize(line)]):
                tokens = tokenize(sentence)
                previous_token = None
                for token in tokens:
                    if token in mapped_punctuation:
                        if previous_token != None:
                            print(previous_token, mapped_punctuation[token], sep='\t')
                        previous_token = None
                    elif not re.search('[\p{Han}\p{Ll}\p{Lu}\d]', token): # remove non-alphanumeric tokens
                        continue
                    else:
                        if previous_token != None:
                            print(previous_token, 'O', sep='\t')
                        previous_token = token
                if previous_token != None:
                    print(previous_token, 'PERIOD', sep='\t')


# modification of the wordpiece tokenizer to keep case information even if vocab is lower cased
# forked from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py

class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100, keep_case=True):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word
        self.keep_case = keep_case

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.
        For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.
        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.
        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in text.strip().split():
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    # optionally lowercase substring before checking for inclusion in vocab
                    if (self.keep_case and substr.lower() in self.vocab) or (substr in self.vocab):
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens


# modification of XLM bpe tokenizer for keeping case information when vocab is lowercase
# forked from https://github.com/huggingface/transformers/blob/cd56f3fe7eae4a53a9880e3f5e8f91877a78271c/src/transformers/models/xlm/tokenization_xlm.py
def bpe(self, token):
    def to_lower(pair):
        #print('  ',pair)
        return (pair[0].lower(), pair[1].lower())

    from transformers.models.xlm.tokenization_xlm import get_pairs

    word = tuple(token[:-1]) + (token[-1] + "</w>",)
    if token in self.cache:
        return self.cache[token]
    pairs = get_pairs(word)

    if not pairs:
        return token + "</w>"

    while True:
        bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(to_lower(pair), float("inf")))
        #print(bigram)
        if to_lower(bigram) not in self.bpe_ranks:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
            except ValueError:
                new_word.extend(word[i:])
                break
            else:
                new_word.extend(word[i:j])
                i = j

            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)
    word = " ".join(word)
    if word == "\n  </w>":
        word = "\n</w>"
    self.cache[token] = word
    return word


def init(config):
    init_random(config.seed)

    if config.lang == 'fr':
        config.tokenizer = tokenizer = AutoTokenizer.from_pretrained(config.flavor, do_lower_case=False)

        from transformers.models.xlm.tokenization_xlm import XLMTokenizer
        assert isinstance(tokenizer, XLMTokenizer)

        # monkey patch XLM tokenizer
        import types
        tokenizer.bpe = types.MethodType(bpe, tokenizer)
    else:
        # warning: needs to be BertTokenizer for monkey patching to work
        config.tokenizer = tokenizer = BertTokenizer.from_pretrained(config.flavor, do_lower_case=False)

        # warning: monkey patch tokenizer to keep case information
        #from recasing_tokenizer import WordpieceTokenizer
        config.tokenizer.wordpiece_tokenizer = WordpieceTokenizer(vocab=tokenizer.vocab, unk_token=tokenizer.unk_token)

    if config.lang == 'fr':
        config.pad_token_id = tokenizer.pad_token_id
        config.cls_token_id = tokenizer.bos_token_id
        config.cls_token = tokenizer.bos_token
        config.sep_token_id = tokenizer.sep_token_id
        config.sep_token = tokenizer.sep_token
    else:
        config.pad_token_id = tokenizer.pad_token_id
        config.cls_token_id = tokenizer.cls_token_id
        config.cls_token = tokenizer.cls_token
        config.sep_token_id = tokenizer.sep_token_id
        config.sep_token = tokenizer.sep_token

    if not torch.cuda.is_available() and config.device == 'cuda':
        print('WARNING: reverting to cpu as cuda is not available', file=sys.stderr)
    config.device = torch.device(config.device if torch.cuda.is_available() else 'cpu')


def main(config, action, args):
    init(config)

    if action == 'train':
        train(config, *args)
    elif action == 'eval':
        run_eval(config, *args)
    elif action == 'predict':
        generate_predictions(config, *args)
    elif action == 'tensorize':
        make_tensors(config, *args)
    elif action == 'preprocess':
        preprocess_text(config, *args)
    else:
        print('invalid action "%s"' % action)
        sys.exit(1)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("action", help="train|eval|predict|tensorize|preprocess", type=str)
    parser.add_argument("action_args", help="arguments for selected action", type=str, nargs='*')
    parser.add_argument("--seed", help="random seed", default=default_config.seed, type=int)
    parser.add_argument("--lang", help="language (fr, en, zh)", default=default_config.lang, type=str)
    parser.add_argument("--flavor", help="bert flavor in transformers model zoo", default=default_config.flavor, type=str)
    parser.add_argument("--max-length", help="maximum input length", default=default_config.max_length, type=int)
    parser.add_argument("--batch-size", help="size of batches", default=default_config.batch_size, type=int)
    parser.add_argument("--device", help="computation device (cuda, cpu)", default=default_config.device, type=str)
    parser.add_argument("--debug", help="whether to output more debug info", default=default_config.debug, type=bool)
parser.add_argument("--updates", help="number of training updates to perform", default=default_config.updates, type=bool)
|
| 736 |
+
parser.add_argument("--period", help="validation period in updates", default=default_config.period, type=bool)
|
| 737 |
+
parser.add_argument("--lr", help="learning rate", default=default_config.lr, type=bool)
|
| 738 |
+
parser.add_argument("--dab-rate", help="drop at boundaries rate", default=default_config.dab_rate, type=bool)
|
| 739 |
+
    config = Config(**parser.parse_args().__dict__)

    main(config, config.action, config.action_args)
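The `predict` action defined above reads lower-cased, unpunctuated text from stdin and writes re-cased, re-punctuated text to stdout, loading the trained checkpoint given as its positional argument. A minimal invocation sketch (the sample sentence is an arbitrary placeholder, and language/flavor defaults come from the script's `default_config`):

```
# Sketch: run the "predict" action of recasepunc.py on one line of plain text.
# The positional argument after "predict" is the path to the shipped checkpoint file.
import subprocess

result = subprocess.run(
    ["python3", "recasepunc.py", "predict", "checkpoint"],
    input="все смешалось в доме облонских\n",  # placeholder test sentence
    capture_output=True,
    text=True,
)
print(result.stdout)
```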
punctuation/vosk-recasepunc-ru-0.22/ru-test.txt
ADDED
@@ -0,0 +1,17 @@
все смешалось в доме облонских жена узнала что муж был в связи с бывшею
в их доме француженкою-гувернанткой и объявила мужу что не может жить с
ним в одном доме положение это продолжалось уже третий день и мучительно
чувствовалось и самими супругами и всеми членами семьи и домочадцами
все члены семьи и домочадцы чувствовали что нет смысла в их сожительстве
и что на каждом постоялом дворе случайно сошедшиеся люди более связаны
между собой чем они члены семьи и домочадцы облонских жена не выходила
из своих комнат мужа третий день не было дома дети бегали по всему
дому как потерянные англичанка поссорилась с экономкой и написала
записку приятельнице прося приискать ей новое место повар ушел еще
вчера со двора во время обеда черная кухарка и кучер просили расчета
На третий день после ссоры князь степан аркадьич облонский стива как
его звали в свете в обычный час то есть в восемь часов утра
проснулся не в спальне жены а в своем кабинете на сафьянном диване
он повернул свое полное выхоленное тело на пружинах дивана как бы желая
опять заснуть надолго с другой стороны крепко обнял подушку и прижался к
ней щекой но вдруг вскочил сел на диван и открыл глаза
punctuation/vosk-recasepunc-ru-0.22/ru-test.txt.orig
ADDED
@@ -0,0 +1,17 @@
Все смешалось в доме Облонских. Жена узнала, что муж был в связи с бывшею
в их доме француженкою-гувернанткой, и объявила мужу, что не может жить с
ним в одном доме. Положение это продолжалось уже третий день и мучительно
чувствовалось и самими супругами, и всеми членами семьи, и домочадцами.
Все члены семьи и домочадцы чувствовали, что нет смысла в их сожительстве
и что на каждом постоялом дворе случайно сошедшиеся люди более связаны
между собой, чем они, члены семьи и домочадцы Облонских. Жена не выходила
из своих комнат, мужа третий день не было дома. Дети бегали по всему
дому, как потерянные; англичанка поссорилась с экономкой и написала
записку приятельнице, прося приискать ей новое место; повар ушел еще
вчера со двора, во время обеда; черная кухарка и кучер просили расчета.
На третий день после ссоры князь Степан Аркадьич Облонский -- Стива, как
его звали в свете, -- в обычный час, то есть в восемь часов утра,
проснулся не в спальне жены, а в своем кабинете, на сафьянном диване...
Он повернул свое полное, выхоленное тело на пружинах дивана, как бы желая
опять заснуть надолго, с другой стороны крепко обнял подушку и прижался к
ней щекой; но вдруг вскочил, сел на диван и открыл глаза.
speaker_indentification/vosk-model-spk-0.4.7z
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e4c7ccb7760ffc2ccb780beac3ba40907728c075fbb1cbb66b4dacc0afda4598
size 13785421
speaker_indentification/vosk-model-spk-0.4/README.txt
ADDED
@@ -0,0 +1,119 @@

UPLOADER       David Snyder
DATE           2018-05-30
KALDI VERSION  108832d

This directory contains files generated from the recipe in
egs/callhome_diarization/v2/. Its contents should be placed in a similar
directory, with symbolic links to diarization/, sid/, steps/, etc. This was
created when Kaldi's master branch was at git log
2ad8d7821867a199e435aa36bbd13af6ed937c94.


I. Files list
------------------------------------------------------------------------------

./
README.txt           This file
run.sh                A copy of the egs/callhome_diarization/v2/run.sh
                      at the time of uploading this file. Use this to
                      figure out how to compute features, extract
                      embeddings, etc.

local/nnet3/xvector/tuning/
run_xvector_1a.sh     This is the default recipe, at the time of
                      uploading this resource. The script generates
                      the configs, egs, and trains the model.

conf/
vad.conf              The energy-based VAD configuration
mfcc.conf             MFCC configuration

exp/xvector_nnet_1a/
final.raw             The pretrained DNN model
nnet.config           The nnet3 config file that was used when the
                      DNN model was first instantiated.
extract.config        Another nnet3 config file that modifies the DNN
                      final.raw to extract x-vectors. It should be
                      automatically handled by the script
                      extract_xvectors.sh.
min_chunk_size        Min chunk size used (see extract_xvectors.sh)
max_chunk_size        Max chunk size used (see extract_xvectors.sh)
srand                 The RNG seed used when creating the DNN

exp/xvectors_callhome1/
mean.vec              Vector for centering, from callhome1
transform.mat         Whitening matrix, trained on callhome1
plda                  PLDA model for callhome1, trained on SRE data

exp/xvectors_callhome2/
mean.vec              Vector for centering, from callhome2
transform.mat         Whitening matrix, trained on callhome2
plda                  PLDA model for callhome1, trained on SRE data


II. Citation
------------------------------------------------------------------------------

If you wish to use this architecture in a publication, please cite one of the
following papers.

The x-vector architecture:

@inproceedings{snyder2018xvector,
  title={X-vectors: Robust DNN Embeddings for Speaker Recognition},
  author={Snyder, D. and Garcia-Romero, D. and Sell, G. and Povey, D. and Khudanpur, S.},
  booktitle={2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  year={2018},
  organization={IEEE},
  url={http://www.danielpovey.com/files/2018_icassp_xvectors.pdf}
}

Diarization with x-vectors:

@article{sell2018dihard,
  title={Diarization is Hard: Some Experiences and Lessons Learned for the JHU Team in the Inaugural DIHARD Challenge},
  author={Sell, G. and Snyder, D. and McCree, A. and Garcia-Romero, D. and Villalba, J. and Maciejewski, M. and Manohar, V. and Dehak, N. and Povey, D. and Watanabe, S. and Khudanpur, J.},
  journal={Interspeech},
  year={2018}
}


III. Recipe README.txt
------------------------------------------------------------------------------
The following text is the README.txt from egs/callhome_diarization/v2 at the
time this archive was created.

This recipe replaces i-vectors used in the v1 recipe with embeddings extracted
from a deep neural network. In the scripts, we refer to these embeddings as
"x-vectors." The x-vector recipe in
local/nnet3/xvector/tuning/run_xvector_1a.sh is closely based on the
following paper:

However, in this example, the x-vectors are used for diarization, rather
than speaker recognition. Diarization is performed by splitting speech
segments into very short segments (e.g., 1.5 seconds), extracting embeddings
from the segments, and clustering them to obtain speaker labels.

The recipe uses the following data for system development. This is in
addition to the NIST SRE 2000 dataset (Callhome) which is used for
evaluation (see ../README.txt).

    Corpus              LDC Catalog No.
    SRE2004             LDC2006S44
    SRE2005 Train       LDC2011S01
    SRE2005 Test        LDC2011S04
    SRE2006 Train       LDC2011S09
    SRE2006 Test 1      LDC2011S10
    SRE2006 Test 2      LDC2012S01
    SRE2008 Train       LDC2011S05
    SRE2008 Test        LDC2011S08
    SWBD2 Phase 2       LDC99S79
    SWBD2 Phase 3       LDC2002S06
    SWBD Cellular 1     LDC2001S13
    SWBD Cellular 2     LDC2004S07

The following datasets are used in data augmentation.

    MUSAN               http://www.openslr.org/17
    RIR_NOISES          http://www.openslr.org/28
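The diarization procedure described in the recipe README above (short sub-segments, x-vector extraction, clustering) uses mean.vec for centering and transform.mat for whitening of the embeddings. The following is only a rough sketch of that post-processing outside Kaldi: how the x-vectors are loaded, the PLDA scoring step, and the use of scikit-learn agglomerative clustering in place of Kaldi's scoring are all simplifications assumed for illustration.

```
# Illustrative sketch only: center, whiten, length-normalize and cluster
# per-segment x-vectors. PLDA scoring (the plda files) is omitted here.
import numpy as np
from sklearn.cluster import AgglomerativeClustering

def cluster_segments(xvectors, mean_vec, transform_mat, num_speakers):
    centered = xvectors - mean_vec                      # subtract mean.vec
    whitened = centered @ transform_mat.T               # apply transform.mat
    normed = whitened / np.linalg.norm(whitened, axis=1, keepdims=True)
    return AgglomerativeClustering(n_clusters=num_speakers).fit_predict(normed)
```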
speaker_indentification/vosk-model-spk-0.4/final.ext.raw
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d30354474ac719b3cfce58186463302b988cf22ce3afee7237a2da862c4a91d1
size 14929171
speaker_indentification/vosk-model-spk-0.4/mean.vec
ADDED
@@ -0,0 +1 @@
[ 4.450152 4.672029 4.148891 -1.711527 3.509846 2.931994 2.850384 3.178227 0.2563171 -0.9261234 3.37196 0.1472566 5.635284 -0.01870821 1.972103 -0.9502754 4.401544 2.795261 2.67637 3.917823 0.6549923 -0.02103148 4.064806 4.100016 3.700118 1.252804 5.399523 4.084152 4.106742 3.5622 4.165306 -0.2494654 -0.9603948 4.272289 -2.332889 -0.7292819 3.646834 0.3090337 4.624666 5.089351 -5.635771 1.634198 1.089098 4.363739 3.618721 0.2134228 -0.3965465 5.353687 4.034757 4.032773 3.749556 3.166129 3.868708 4.381798 -0.02561651 0.3426051 4.402168 0.1237091 0.8197291 3.809948 -2.995811 -1.648535 3.202967 3.239381 3.250949 -0.9064079 4.452719 0.2775586 0.80832 3.036884 5.163679 0.4273587 3.537773 2.539269 3.151272 4.064805 3.56104 4.244997 3.660802 4.949434 4.013721 1.418729 1.845101 4.74059 3.280786 -1.731479 1.492544 -2.88268 5.013491 5.327713 -2.668042 1.02902 -0.9622369 3.954224 3.2533 3.348548 2.906777 -0.3059559 4.595854 0.3410174 2.116138 4.830284 3.402886 3.014466 4.481457 5.14358 2.05649 3.883894 -0.9075359 4.574888 4.064843 -1.416883 3.493051 -0.06792944 4.978102 4.930044 4.138368 2.826191 4.031521 2.575887 0.7125556 4.15551 2.601444 1.190357 -1.060124 0.9739355 4.671662 -1.613742 ]
speaker_indentification/vosk-model-spk-0.4/mfcc.conf
ADDED
@@ -0,0 +1,5 @@
--sample-frequency=8000
--high-freq=3700
--low-freq=20
--num-ceps=23
--allow-downsample=true
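mfcc.conf asks for Kaldi-style 23-dimensional MFCCs on 8 kHz audio with the mel filterbank limited to 20–3700 Hz. A rough non-Kaldi equivalent, assuming torchaudio's Kaldi-compliance layer and a hypothetical input file, would be:

```
# Sketch: Kaldi-compatible MFCCs roughly matching mfcc.conf (illustrative only).
import torchaudio
import torchaudio.compliance.kaldi as kaldi

waveform, sample_rate = torchaudio.load("segment.wav")  # hypothetical 8 kHz mono file
feats = kaldi.mfcc(
    waveform,
    sample_frequency=8000.0,
    low_freq=20.0,
    high_freq=3700.0,
    num_ceps=23,
    num_mel_bins=23,  # assumption: the mel-bin count is not specified in mfcc.conf
)
print(feats.shape)  # (num_frames, 23)
```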
speaker_indentification/vosk-model-spk-0.4/transform.mat
ADDED
Binary file (65.6 kB).
tts/vosk-model-tts-ru-0.9-multi.7z
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8e0f1ce5406abbabf29b823a11a2d937e70a1abeeb5d96c5fb518edc0cd4b949
size 761220882
tts/vosk-model-tts-ru-0.9-multi/README.md
ADDED
@@ -0,0 +1,22 @@
Russian Vosk TTS model

Version 0.9

Metrics:

CER                 0.6
FAD                 0.810
UTMOS               3.290
Speaker Similarity  0.875
xRT CPU             0.35
xRT GPU             0.06

License: Apache 2.0

Changelog:

* ASR alignment
* No encoder, just duration predictor
* Slightly thinner predictor width (160) to fit DiT hidden vector
* Scale for diffusion loss (so that it does not dominate the duration loss)
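The last changelog entry refers to weighting the diffusion objective against the duration-predictor objective during training. Schematically it amounts to the following; the 0.1 weight is a made-up placeholder, not the value used for this model:

```
# Schematic of "Scale for diffusion loss": down-weight the diffusion term so it
# does not swamp the duration loss. The scale value here is purely illustrative.
DIFFUSION_LOSS_SCALE = 0.1

def combined_loss(diffusion_loss, duration_loss):
    return DIFFUSION_LOSS_SCALE * diffusion_loss + duration_loss
```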
tts/vosk-model-tts-ru-0.9-multi/bert/README.md
ADDED
@@ -0,0 +1,39 @@
---
language:
- ru
tags:
- PyTorch
- Transformers
- bert
- exbert
pipeline_tag: fill-mask
thumbnail: "https://github.com/sberbank-ai/model-zoo"
license: apache-2.0
---

# ruBert-base
The model architecture design, pretraining, and evaluation are documented in our preprint: [**A Family of Pretrained Transformer Language Models for Russian**](https://arxiv.org/abs/2309.10931).

The model is pretrained by the [SberDevices](https://sberdevices.ru/) team.
* Task: `mask filling`
* Type: `encoder`
* Tokenizer: `BPE`
* Dict size: `120 138`
* Num Parameters: `178 M`
* Training Data Volume: `30 GB`

# Authors
+ NLP core team RnD [Telegram channel](https://t.me/nlpcoreteam):
  + Dmitry Zmitrovich

# Cite us
```
@misc{zmitrovich2023family,
    title={A Family of Pretrained Transformer Language Models for Russian},
    author={Dmitry Zmitrovich and Alexander Abramov and Andrey Kalmykov and Maria Tikhonova and Ekaterina Taktasheva and Danil Astafurov and Mark Baushenko and Artem Snegirev and Tatiana Shavrina and Sergey Markov and Vladislav Mikhailov and Alena Fenogenova},
    year={2023},
    eprint={2309.10931},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```
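For context, the `fill-mask` task named in this model card can be exercised against the upstream PyTorch checkpoint with the standard transformers pipeline. The hub id `ai-forever/ruBert-base` is an assumption about where the upstream weights are published; the model.onnx shipped next to this README is the ONNX export used by the TTS frontend and is not loadable this way.

```
# Minimal fill-mask sketch against the (assumed) upstream ruBert-base checkpoint.
from transformers import pipeline

fill = pipeline("fill-mask", model="ai-forever/ruBert-base")
masked = f"Столица России - {fill.tokenizer.mask_token}."  # reuse the model's own mask token
for candidate in fill(masked)[:3]:
    print(candidate["token_str"], candidate["score"])
```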
tts/vosk-model-tts-ru-0.9-multi/bert/model.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2e2f1740eaae5e29c2b4844625cbb01ff644b2b5fb0560bd34374c35d8a092c1
size 654361598
tts/vosk-model-tts-ru-0.9-multi/bert/vocab.txt
ADDED
The diff for this file is too large to render.
tts/vosk-model-tts-ru-0.9-multi/config.json
ADDED
@@ -0,0 +1,85 @@
{
    "audio": {
        "sample_rate": 22050
    },
    "inference": {
        "noise_level": 0.8,
        "speech_rate": 1,
        "duration_noise_level": 0.8
    },
    "phoneme_id_map": {
        "_": 0,
        "^": 1,
        "$": 2,
        " ": 3,
        "!": 4,
        "'": 5,
        "(": 6,
        ")": 7,
        ",": 8,
        "-": 9,
        ".": 10,
        "...": 11,
        ":": 12,
        ";": 13,
        "?": 14,
        "a0": 15,
        "a1": 16,
        "b": 17,
        "bj": 18,
        "c": 19,
        "ch": 20,
        "d": 21,
        "dj": 22,
        "e0": 23,
        "e1": 24,
        "f": 25,
        "fj": 26,
        "g": 27,
        "gj": 28,
        "h": 29,
        "hj": 30,
        "i0": 31,
        "i1": 32,
        "j": 33,
        "k": 34,
        "kj": 35,
        "l": 36,
        "lj": 37,
        "m": 38,
        "mj": 39,
        "n": 40,
        "nj": 41,
        "o0": 42,
        "o1": 43,
        "p": 44,
        "pj": 45,
        "r": 46,
        "rj": 47,
        "s": 48,
        "sch": 49,
        "sh": 50,
        "sj": 51,
        "t": 52,
        "tj": 53,
        "u0": 54,
        "u1": 55,
        "v": 56,
        "vj": 57,
        "y0": 58,
        "y1": 59,
        "z": 60,
        "zh": 61,
        "zj": 62
    },
    "num_symbols": 62,
    "num_speakers": 5,
    "speaker_id_map": {
        "female_0": 0,
        "female_1": 1,
        "female_2": 2,
        "male_0": 3,
        "male_1": 4
    },
    "model_type": "multistream_v1"
}
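A short sketch of how the maps in config.json could be used to turn a phoneme sequence and a speaker name into the integer inputs the ONNX model expects. The treatment of "^"/"$" as utterance start/end markers and the helper itself are assumptions made for illustration, not the actual Vosk TTS API.

```
# Illustrative only: look up phoneme and speaker IDs from config.json.
import json

def encode(phonemes, speaker, config_path="config.json"):
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)
    pmap = cfg["phoneme_id_map"]
    # Assumption: "^" and "$" act as utterance start/end markers.
    ids = [pmap["^"]] + [pmap[p] for p in phonemes] + [pmap["$"]]
    return ids, cfg["speaker_id_map"][speaker]

# e.g. encode(["p", "rj", "i0", "vj", "e1", "t"], "female_0")
```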
tts/vosk-model-tts-ru-0.9-multi/dictionary
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2939e72c170bb41ac8e256828cca1c5fac4db1e36717f9f53fde843b00a220ba
size 101431118
tts/vosk-model-tts-ru-0.9-multi/model.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0fa5a36b22a8bf7fe7179a3882c6371d2c01e5317019e717516f892d329c24b9
size 179314533