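"""Splinter tokenization for Hebrew BERT models.

Each Hebrew word is "splintered": letters are removed one at a time according
to a frequency table of 'position:letter' reductions until a suspected root
remains, and each reduction or root letter is then mapped to a dedicated
substitute character. The classes below plug this encoding into a
BertTokenizerFast as a custom pre-tokenizer and decoder.
"""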
import json
import re
from typing import List

from tokenizers import NormalizedString, PreTokenizedString
from tokenizers.decoders import Decoder
from tokenizers.pre_tokenizers import PreTokenizer, Sequence as PreTokenizerSequence
from transformers import BertTokenizerFast

from .splinter_json import splinter_data

# bidirectional map between the final and regular forms of Hebrew letters
final_letters_map = {
    'ך': 'כ',
    'ם': 'מ',
    'ץ': 'צ',
    'ף': 'פ',
    'ן': 'נ',
    'כ': 'ך',
    'מ': 'ם',
    'צ': 'ץ',
    'פ': 'ף',
    'נ': 'ן'
}

def get_permutation(word, position, word_length):
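    """Return `word` with the character at `position` removed (negative
    positions count back from `word_length`)."""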
    if position < 0:
        permutation = word[:word_length + position] + word[(word_length + position + 1):]
    else:
        permutation = word[:position] + word[(position + 1):]
    return permutation

def replace_final_letters(text):
    if text == '': return text
    if text[-1] in final_letters_map:
        return replace_last_letter(text, final_letters_map[text[-1]])
    return text
    
def replace_last_letter(text, replacement):
    return text[:-1] + replacement

def is_hebrew_letter(char):
    return '\u05D0' <= char <= '\u05EA'

def is_word_contains_non_hebrew_letters(word) -> bool:
    return re.search(r'[^\u05D0-\u05EA]', word) is not None

class Splinter:
    def __init__(self, path, use_cache=True):
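        """`path` is either a filesystem path to the splinter JSON file or an
        already-parsed dict with the same structure."""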
        if isinstance(path, str):
            with open(path, 'r', encoding='utf-8-sig') as r:
                parsed = json.loads(r.read())
        else:
            parsed = path

        self.reductions_map = {int(key): value for key, value in parsed['reductions_map'].items()}
        self.new_unicode_chars_map = parsed['new_unicode_chars']
        self.new_unicode_chars_inverted_map = {v:k for k,v in self.new_unicode_chars_map.items()}
        self.word_reductions_cache = dict()
        self.use_cache = use_cache

    def splinter_word(self, word: str):
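        """Encode `word` as a string of substitute characters, one output
        character per input character."""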
        if self.use_cache:
            cached = self.word_reductions_cache.get(word)
            if cached is not None:
                return cached
        
        clean_word = replace_final_letters(word)
        # long words and words containing non-Hebrew characters are encoded
        # letter by letter, converting only the Hebrew letters
        if len(clean_word) > 15 or is_word_contains_non_hebrew_letters(clean_word):
            encoded_word = self.get_word_with_non_hebrew_chars_reduction(clean_word)
        else:
            word_reductions = self.get_word_reductions(clean_word)
            encoded_word = ''.join([self.new_unicode_chars_map[reduction] for reduction in word_reductions])
        if self.use_cache:
            self.word_reductions_cache[word] = encoded_word
        return encoded_word

    def unsplinter_word(self, word: str):
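        """Decode a splintered word back to its original surface form."""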
        decoded_word = self.decode_word(word)
        return self.rebuild_reduced_word(decoded_word)

    def decode_word(self, word: str):
        return [self.new_unicode_chars_inverted_map.get(char, char) for char in word]

    def rebuild_reduced_word(self, decoded_word):
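        """Rebuild the original word: plain characters are appended as-is and
        'position:letter' reductions re-insert their letter at `position`;
        the final-letter form is restored at the end."""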
        original_word = ""
        for reduction in decoded_word:
            if ':' in reduction and len(reduction) > 1:
                position, letter = reduction.split(':')
                position = int(position)
                if position < 0:
                    position = len(original_word) + position + 1
                if len(original_word) == position - 1:
                    # insertion point is past the current end of the word; keep
                    # the raw reduction token rather than dropping it
                    original_word += reduction
                else:
                    original_word = original_word[:position] + letter + original_word[position:]
            else:
                original_word += reduction

        original_word = replace_final_letters(original_word)
        return original_word

    def get_word_reductions(self, word):
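        """Reduce `word` one letter at a time until at most three letters (the
        suspected root) remain, then return the root characters followed by
        the reductions, ordered for re-insertion during decoding."""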
        reduced_word = word
        reductions = []
        while len(reduced_word) > 3:
            # if this word length has no known reductions - return what's left of the word as is
            if len(reduced_word) not in self.reductions_map:
                reductions.extend(self.get_single_chars_reductions(reduced_word))
                break
            reduction = self.get_reduction(reduced_word, 3, 3)
            if reduction is not None:
                position = int(reduction.split(':')[0])
                reductions.append(reduction)
                reduced_word = get_permutation(reduced_word, position, len(reduced_word))
            # if we couldn't find a reduction - return what's left of the word as is
            else:
                reductions.extend(self.get_single_chars_reductions(reduced_word))
                break

        # if we found all reductions and left only with the suspected root - keep it as is
        if len(reduced_word) < 4:
            reductions.extend(self.get_single_chars_reductions(reduced_word))

        reductions.reverse()
        return reductions

    def get_reduction(self, word, depth, width):
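        """Beam search over reduction sequences: expand up to `width`
        candidates per step for at most `depth` steps, score each path by the
        product of its reduction frequencies, and return the first reduction
        of the best-scoring path (or None if no reduction applies)."""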
        curr_step_reductions = [{"word": word, "reduction": None, "root_reduction": None, "score": 1}]
        word_length = len(word)
        i = 0
        while i < depth and len(curr_step_reductions) > 0 and word_length > 3:
            next_step_reductions = list()
            for reduction in curr_step_reductions:
                possible_reductions = self.get_most_frequent_reduction_keys(
                    reduction["word"],
                    reduction["root_reduction"],
                    reduction["score"],
                    width,
                    word_length
                )
                next_step_reductions += possible_reductions
            curr_step_reductions = list(next_step_reductions)
            i += 1
            word_length -= 1

        max_score_reduction = None
        if len(curr_step_reductions) > 0:
            max_score_reduction = max(curr_step_reductions, key=lambda x: x["score"])["root_reduction"]
        return max_score_reduction

    def get_most_frequent_reduction_keys(self, word, root_reduction, parent_score, number_of_reductions, word_length):
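        """Return up to `number_of_reductions` reductions applicable to `word`,
        taking entries in `reductions_map` order (assumed most-frequent-first,
        as preserved from the JSON file)."""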
        possible_reductions = list()
        for reduction, score in self.reductions_map[len(word)].items():
            position, letter = reduction.split(':')
            position = int(position)
            if word[position] == letter:
                permutation = get_permutation(word, position, word_length)
                possible_reductions.append({
                    "word": permutation,
                    "reduction": reduction,
                    "root_reduction": root_reduction if root_reduction is not None else reduction,
                    "score": parent_score * score
                })
                if len(possible_reductions) >= number_of_reductions:
                    break
        return possible_reductions

    def get_word_with_non_hebrew_chars_reduction(self, word):
        return ''.join(self.new_unicode_chars_map[char] if is_hebrew_letter(char) else char for char in word)

    @staticmethod
    def get_single_chars_reductions(reduced_word):
        # collect the remaining characters in reverse order, since the caller
        # reverses the full reductions list before returning it
        return list(reduced_word[::-1])

class SplinterPreTokenizer:
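    """Custom pre-tokenizer that replaces each word with its splintered form."""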
    def __init__(self, splinter: Splinter):
        super().__init__()
        self.splinter = splinter
        
    def splinter_split(self, i: int, normalized: NormalizedString):
        # the encoded word has one character per input character, so map each
        # character in place; pad any shortfall with spaces and strip them
        splintered_word = iter(self.splinter.splinter_word(normalized.normalized))
        normalized.map(lambda _: next(splintered_word, ' '))
        normalized.strip()
        return [normalized]
                
    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.splinter_split)
        
class SplinterDecoder:
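    """Custom decoder: merges WordPiece pieces back into whole encoded words
    and unsplinters each word to its original form."""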
    def __init__(self, splinter: Splinter):
        self.splinter = splinter

    def decode_chain(self, tokens: List[str]) -> List[str]:
        # merge '##' continuation wordpieces into whole encoded words
        combined_tokens = []
        for token in tokens:
            if token.startswith('##') and combined_tokens:
                combined_tokens[-1] += token[2:]
            else:
                combined_tokens.append(token)

        # decode each combined word back to its original surface form
        return [f' {t}' for t in map(self.splinter.unsplinter_word, combined_tokens)]
    

class SplinterBertTokenizerFast(BertTokenizerFast):
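    """BertTokenizerFast with the splinter pre-tokenizer and decoder attached."""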
    def __init__(self, *args, use_cache=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.splinter = Splinter(splinter_data, use_cache=use_cache)
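        # run the default BERT pre-tokenizer first, then splinter each word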
        self._tokenizer.pre_tokenizer = PreTokenizerSequence([
            self._tokenizer.pre_tokenizer,
            PreTokenizer.custom(SplinterPreTokenizer(self.splinter))
        ])
        self._tokenizer.decoder = Decoder.custom(SplinterDecoder(self.splinter))

    def save_pretrained(self, *args, **kwargs):
        self._save_pretrained(*args, **kwargs)

    def _save_pretrained(self, *args, **kwargs):
        # custom Python pre-tokenizers and decoders cannot be serialized, so
        # saving is unsupported; the repository files must be copied instead
        print('Cannot save SplinterBertTokenizerFast; please copy the files directly from the repository')
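

# A minimal usage sketch (the checkpoint path below is a placeholder, not part
# of this file). The custom pre-tokenizer and decoder are attached
# automatically in __init__, so the tokenizer is used like any
# BertTokenizerFast:
#
#   tokenizer = SplinterBertTokenizerFast.from_pretrained("path/to/checkpoint")
#   encoded = tokenizer("שלום עולם")
#   print(tokenizer.decode(encoded["input_ids"], skip_special_tokens=True))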