peterroh committed
Commit 99e9899 · verified · 1 Parent(s): db29485

Update tokenization.py

Files changed (1): tokenization.py (+87 -18)
tokenization.py CHANGED
@@ -1,3 +1,4 @@
+from collections import defaultdict
 import logging
 import re
 from typing import Optional
@@ -24,6 +25,44 @@ _INFINITE = int(1e12) # infinite token length for no-truncation
 logger = logging.getLogger("kanana-1.5-v")
 
 
+class AttrDict(dict):
+    __slots__ = ()
+
+    def __getattr__(self, name):
+        try:
+            val = self[name]
+        except KeyError:
+            raise AttributeError(name) from None
+
+        if isinstance(val, dict) and not isinstance(val, AttrDict):
+            val = AttrDict(val)
+            self[name] = val
+        return val
+
+    def __setattr__(self, name, value):
+        if name.startswith('_'):
+            return super().__setattr__(name, value)
+        if isinstance(value, dict) and not isinstance(value, AttrDict):
+            value = AttrDict(value)
+        self[name] = value
+
+    def __delattr__(self, name):
+        try:
+            del self[name]
+        except KeyError:
+            raise AttributeError(name) from None
+
+
+def to_attrdict(obj):
+    if isinstance(obj, dict) and not isinstance(obj, AttrDict):
+        return AttrDict({k: to_attrdict(v) for k, v in obj.items()})
+    if isinstance(obj, list):
+        return [to_attrdict(x) for x in obj]
+    if isinstance(obj, tuple):
+        return tuple(to_attrdict(x) for x in obj)
+    return obj
+
+
 def _pad_trunc(
     x: list[list[int]],
     padding: str,
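
For reference, a minimal sketch of how the two helpers above behave, with made-up values (assumes the AttrDict and to_attrdict definitions from this hunk are in scope):

    # to_attrdict converts nested dicts so encodings support attribute access.
    enc = to_attrdict({"input_ids": [1, 2, 3], "meta": {"height": 336}})
    assert enc.input_ids == enc["input_ids"] == [1, 2, 3]
    assert enc.meta.height == 336  # nested dict wrapped recursively

    # Plain dict values stored later are wrapped lazily on attribute access.
    enc["extra"] = {"width": 336}
    assert enc.extra.width == 336  # __getattr__ converts and caches the value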
@@ -101,20 +140,6 @@ class KananaVTokenizerMixin:
 
         return repeated_tokens
 
-    def encode_text_only(self, prompt: str, add_special_tokens: bool = False) -> list:
-        # Text-only Data
-        # split prompt into chunks by role tokens
-        tokens_to_split = [_AI, _HUMAN]
-        pattern = "|".join(map(re.escape, tokens_to_split))
-        chunk_strs = re.split(f"({pattern})", prompt)
-        chunk_strs = [x for x in chunk_strs if len(x) > 0]
-
-        enc_chunk = []
-        for idx, chunk_str in enumerate(chunk_strs):
-            curr_chunk = self(chunk_str, add_special_tokens=False)["input_ids"]
-            enc_chunk += curr_chunk
-        return enc_chunk
-
     def encode_prompt(
         self, prompt: str, max_length: int | None = None, image_meta: dict | None = None
     ) -> dict:
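
For context, the removed helper returned a flat list of token ids only; the overrides added in the hunk below return the full encoding through the standard tokenizer entry point. A hedged before/after sketch (the prompt value is illustrative):

    # Before (removed): ids only, via the dedicated helper.
    ids = tokenizer.encode_text_only("<prompt with role tokens>")  # list[int]

    # After (see the hunk below): every encoding field is preserved.
    enc = tokenizer("<prompt with role tokens>")  # dict-like encodings
    ids = enc["input_ids"]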
@@ -228,13 +253,57 @@ class KananaVTokenizer(PreTrainedTokenizer, KananaVTokenizerMixin):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    def encode(self, text, add_special_tokens=False) -> list:
-        return self.encode_text_only(prompt=text, add_special_tokens=add_special_tokens)
+    def __call__(self, text, *args, **kwargs):
+        assert isinstance(text, str), "Only str is supported for tokenization."
+
+        # split prompt into chunks by role tokens: text (str) -> chunk_strs (list)
+        tokens_to_split = [_AI, _HUMAN]
+        pattern = "|".join(map(re.escape, tokens_to_split))
+        if re.search(pattern, text):
+            chunk_strs = re.split(f"({pattern})", text)
+            chunk_strs = [x for x in chunk_strs if len(x) > 0]
+
+            # encode chunk strs
+            kwargs["add_special_tokens"] = False
+            encodings = defaultdict(list)
+            for chunk_str in chunk_strs:
+                encoding = super().__call__(chunk_str, *args, **kwargs)
+                for k, v in encoding.items():
+                    encodings[k].extend(v)
+            encodings = to_attrdict(encodings)
+            return encodings
+        else:
+            return super().__call__(text, *args, **kwargs)
+
+    def encode(self, *args, **kwargs):
+        return self.__call__(*args, **kwargs)["input_ids"]
 
 
 class KananaVTokenizerFast(PreTrainedTokenizerFast, KananaVTokenizerMixin):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    def encode(self, text, add_special_tokens=False) -> list:
-        return self.encode_text_only(prompt=text, add_special_tokens=add_special_tokens)
+    def __call__(self, text, *args, **kwargs):
+        assert isinstance(text, str), "Only str is supported for fast tokenization."
+
+        # split prompt into chunks by role tokens: text (str) -> chunk_strs (list)
+        tokens_to_split = [_AI, _HUMAN]
+        pattern = "|".join(map(re.escape, tokens_to_split))
+        if re.search(pattern, text):
+            chunk_strs = re.split(f"({pattern})", text)
+            chunk_strs = [x for x in chunk_strs if len(x) > 0]
+
+            # encode chunk strs
+            kwargs["add_special_tokens"] = False
+            encodings = defaultdict(list)
+            for chunk_str in chunk_strs:
+                encoding = super().__call__(chunk_str, *args, **kwargs)
+                for k, v in encoding.items():
+                    encodings[k].extend(v)
+            encodings = to_attrdict(encodings)
+            return encodings
+        else:
+            return super().__call__(text, *args, **kwargs)
+
+    def encode(self, *args, **kwargs):
+        return self.__call__(*args, **kwargs)["input_ids"]
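
Taken together, a hedged usage sketch of the new entry points (the repo id is a placeholder, and _AI / _HUMAN stand for the role-token constants defined earlier in tokenization.py):

    from transformers import AutoTokenizer

    # Placeholder repo id; the real checkpoint name may differ.
    tokenizer = AutoTokenizer.from_pretrained("<kanana-1.5-v repo>", trust_remote_code=True)

    # A prompt containing role tokens takes the chunked path: the text is
    # split on _AI / _HUMAN, add_special_tokens is forced to False, and the
    # per-chunk encodings are concatenated field by field.
    prompt = f"{_HUMAN}What is in this image?{_AI}"
    enc = tokenizer(prompt)
    ids = tokenizer.encode(prompt)  # delegates to __call__
    assert ids == enc.input_ids == enc["input_ids"]

    # Text without role tokens falls through to the stock
    # PreTrainedTokenizer(Fast).__call__ unchanged.
    plain = tokenizer("hello world")

One design consequence worth noting: the chunked branch returns an AttrDict rather than a transformers BatchEncoding, so BatchEncoding-only conveniences (such as .to()) are not available on that path.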