from typing import Any, Dict, List, Optional, Union
import json


class HajoTextTokenizer:
    """Greedy vocabulary tokenizer with CTC-style blank/repeat handling.

    Emitted label ids are indices into ``all_tokens``.  From the logic of
    :meth:`encode` / :meth:`decode`: id 0 acts as the blank separating two
    identical consecutive tokens, id 1 terminates decoding (end-of-sequence),
    -100 is an ignored padding label, and the last vocabulary slot is
    overwritten with ``'?'`` and serves as the unknown-character token.
    """

    def __init__(self, config_file: str):
        """Load the vocabulary (a JSON list of strings) from *config_file*.

        Sets:
          all_tokens   -- full vocabulary; the last entry is replaced by '?'
          unk          -- internal (offset-by-1000) id of the unknown token
          valid_tokens -- tokens eligible for matching (everything but '?')
        """
        # JSON is UTF-8 by spec; be explicit so non-ASCII tokens (e.g. 'ß')
        # load correctly on platforms whose default encoding differs.
        with open(config_file, 'rt', encoding='utf-8') as f:
            self.all_tokens = json.load(f)
        # Internal ids are offset by +1000 so that, during encode(), matched
        # token ids (ints >= 1000) are distinguishable from still-unmatched
        # single characters (strs) in the same working list.
        self.unk = 1000 + len(self.all_tokens) - 1
        self.all_tokens[self.unk - 1000] = '?'
        self.valid_tokens = self.all_tokens[:-1]

    def encode(self, sentence: str) -> List[int]:
        """Encode *sentence* into label ids (indices into ``all_tokens``).

        Characters not covered by any vocabulary token map to the unknown
        token; a 0 (blank) is inserted between two identical consecutive
        tokens so that decode() keeps both; runs of unknowns collapse to a
        single unknown.
        """
        # Normalise: 'ß' -> 'ss', dashes -> spaces, lowercase.
        # NOTE(review): the two replace(' ', ' ') calls are no-ops as
        # written; they look like whitespace-mangled replace('  ', ' ')
        # double-space collapses — confirm against the original source
        # before changing them.
        sentence = sentence.replace('ß', 'ss').replace('-', ' ').replace(' ', ' ').replace(' ', ' ').lower()
        sentence = list(sentence)
        # Greedily substitute character runs with token ids, in vocabulary
        # order: earlier vocabulary entries take priority over later ones.
        for tokid, tok in enumerate(self.valid_tokens):
            tlen = len(tok)
            ltok = list(tok)
            # The scan range is fixed before any replacement; once the list
            # shrinks, trailing offsets yield short slices that can never
            # equal ltok, so they are skipped harmlessly.
            for off in range(len(sentence) - tlen + 1):
                if sentence[off:off + tlen] == ltok:
                    prefix = sentence[:off]
                    suffix = sentence[off + tlen:]
                    sentence = prefix + [1000 + tokid] + suffix
        out = []
        last_id = 0
        for t in sentence:
            if isinstance(t, str):
                # Character matched by no vocabulary token -> unknown.
                t = self.unk
            if t == last_id:
                if t == self.unk:
                    # Collapse runs of unknowns into a single '?'.
                    continue
                # CTC-style blank between repeated tokens so decode() does
                # not merge them into one.
                out.append(0)
            last_id = t
            out.append(t - 1000)
        return out

    def decode(self, label_ids) -> str:
        """Decode a sequence of label ids back into text.

        0 (blank) and -100 (ignore-index) emit nothing but reset repeat
        suppression; 1 ends decoding; consecutive identical ids are emitted
        once.
        """
        pieces = []
        last_id = 0
        for i in label_ids:
            if i == 0 or i == -100:
                # Blank / padding: reset the repeat filter so the next
                # token is always emitted, even if it equals the previous.
                last_id = i
                continue
            if i == 1:
                # End-of-sequence marker.
                break
            if i != last_id:
                pieces.append(self.all_tokens[i])
                last_id = i
        # join avoids quadratic += string building on long sequences.
        return ''.join(pieces)