|
import base64 |
|
import os |
|
from functools import lru_cache |
|
from typing import Optional |
|
import torch |
|
from transformers import AutoTokenizer |
|
from whisper.tokenizer import Tokenizer |
|
|
|
import tiktoken |
|
|
|
# Language code -> language name.
#
# Insertion order is significant: get_encoding() turns the first
# `num_languages` keys of this dict into "<|code|>" special tokens and assigns
# their ids sequentially in this order. Reordering entries, or inserting one
# anywhere but the end, shifts the ids of every special token that follows.
# With the default num_languages=99, entries up to and including "su" are used;
# "yue" and everything after it are only emitted when num_languages is larger.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",  # 99th entry: the last one included at num_languages=99
    "yue": "cantonese",
    # The remaining tags map to themselves rather than to a language name.
    "minnan": "minnan",
    "wuyu": "wuyu",
    "dialect": "dialect",
    "zh/en": "zh/en",
    "en/zh": "en/zh",
}
|
|
|
|
|
# Language name (or accepted alias) -> language code.
# Starts as the inverse of LANGUAGES; common alternative names are then added
# on top so that, e.g., "mandarin" resolves to "zh".
TO_LANGUAGE_CODE = {name: code for code, name in LANGUAGES.items()}
TO_LANGUAGE_CODE.update(
    burmese="my",
    valencian="ca",
    flemish="nl",
    haitian="ht",
    letzeburgesch="lb",
    pushto="ps",
    panjabi="pa",
    moldavian="ro",
    moldovan="ro",
    sinhalese="si",
    castilian="es",
    mandarin="zh",
)
|
|
|
# Task / audio-event tags. Each tag maps to itself; get_encoding() emits one
# "<|tag|>" special token per key, in this order, so keep the order stable.
AUDIO_EVENT = {
    tag: tag
    for tag in (
        "ASR", "AED", "SER",
        "Speech", "/Speech",
        "BGM", "/BGM",
        "Laughter", "/Laughter",
        "Applause", "/Applause",
    )
}
|
|
|
# Emotion tags. Each tag maps to itself; get_encoding() emits one "<|tag|>"
# special token per key, in this order, so keep the order stable.
EMOTION = {tag: tag for tag in ("HAPPY", "SAD", "ANGRY", "NEUTRAL")}
|
|
|
# TTS vocal tags. Each tag maps to itself; get_encoding() emits one "<|tag|>"
# special token per key, in this order. The numbered tags run TTS/SP01..TTS/SP13.
TTS_Vocal_Token = {
    tag: tag
    for tag in (
        "TTS/B", "TTS/O", "TTS/Q", "TTS/A", "TTS/CO", "TTS/CL", "TTS/H",
        *(f"TTS/SP{idx:02d}" for idx in range(1, 14)),
    )
}
|
|
|
|
|
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    """Build a ``tiktoken.Encoding`` from ``assets/<name>.tiktoken`` plus special tokens.

    Each non-blank line of the vocab file holds a base64-encoded token and its
    integer rank, separated by whitespace. Special tokens are appended after the
    mergeable ranks in the order of ``specials`` below. Their ids therefore
    start at ``len(ranks)`` and are contiguous.

    Args:
        name: Stem of the ``.tiktoken`` file in the ``assets`` directory next
            to this module.
        num_languages: How many leading keys of ``LANGUAGES`` to expose as
            ``<|code|>`` special tokens.

    Returns:
        A ``tiktoken.Encoding`` whose ``explicit_n_vocab`` is the number of
        mergeable ranks plus the number of special tokens.

    Raises:
        OSError: If the vocab file cannot be opened.
        ValueError: If a non-blank vocab line does not split into exactly a
            token and a rank, or the rank is not an integer.
    """
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")

    # `with` closes the handle deterministically. A bare "\n" line is truthy,
    # so filter on strip() to keep blank lines away from the 2-way unpack.
    with open(vocab_path, encoding="utf-8") as vocab_file:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in vocab_file if line.strip())
        }

    # Order matters: each token's id is its position in this list + len(ranks).
    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        *[f"<|{audio_event}|>" for audio_event in AUDIO_EVENT],
        *[f"<|{emotion}|>" for emotion in EMOTION],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],
        *[f"<|{tts}|>" for tts in TTS_Vocal_Token],
        # Timestamp tokens <|0.00|> ... <|30.00|> in 0.02 steps.
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    first_special_id = len(ranks)
    special_tokens = {
        token: first_special_id + offset for offset, token in enumerate(specials)
    }
    n_vocab = first_special_id + len(specials)

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
|
|
|
|
|
@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,
) -> Tokenizer:
    """Return a cached Whisper ``Tokenizer`` backed by this module's encodings.

    Args:
        multilingual: If true, use the ``multilingual_zh_ja_yue_char_del``
            encoding and default ``language``/``task`` to ``"en"`` and
            ``"transcribe"``. Otherwise use ``gpt2`` and force both to ``None``.
        num_languages: Forwarded to ``get_encoding`` and to ``Tokenizer``.
        language: A language code or name, case-insensitive. Names and aliases
            are resolved to codes through ``TO_LANGUAGE_CODE``.
        task: Task string passed to ``Tokenizer`` when multilingual.

    Raises:
        ValueError: If ``language`` is given but is neither a known code nor a
            known name/alias.
    """
    # Validate/normalise the language even in the non-multilingual case,
    # where it is discarded afterwards.
    if language is not None:
        lowered = language.lower()
        if lowered in LANGUAGES:
            language = lowered
        elif lowered in TO_LANGUAGE_CODE:
            language = TO_LANGUAGE_CODE[lowered]
        else:
            raise ValueError(f"Unsupported language: {lowered}")

    if not multilingual:
        encoding_name, language, task = "gpt2", None, None
    else:
        encoding_name = "multilingual_zh_ja_yue_char_del"
        language = language or "en"
        task = task or "transcribe"

    return Tokenizer(
        encoding=get_encoding(name=encoding_name, num_languages=num_languages),
        num_languages=num_languages,
        language=language,
        task=task,
    )
|
|
|
|
|
class QwenTokenizer:
    """Thin wrapper around a Hugging Face ``AutoTokenizer`` with extra special tokens.

    On construction, the pretrained tokenizer at ``token_path`` is loaded and
    the markers listed in ``__init__`` are registered as special tokens.
    ``<|endoftext|>`` is registered as both EOS and PAD.

    Attributes:
        special_tokens: The dict passed to ``add_special_tokens``.
        tokenizer: The underlying Hugging Face tokenizer.
        skip_special_tokens: Whether ``decode`` drops special tokens.
    """

    def __init__(self, token_path, skip_special_tokens=True):
        super().__init__()

        end_of_text = '<|endoftext|>'
        extra_tokens = [
            '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
            '[breath]', '<strong>', '</strong>', '[noise]',
            '[laughter]', '[cough]', '[clucking]', '[accent]',
            '[quick_breath]',
            "<laughter>", "</laughter>",
            "[hissing]", "[sigh]", "[vocalized-noise]",
            "[lipsmack]", "[mn]",
        ]
        self.special_tokens = {
            'eos_token': end_of_text,
            'pad_token': end_of_text,
            'additional_special_tokens': extra_tokens,
        }

        hf_tokenizer = AutoTokenizer.from_pretrained(token_path)
        hf_tokenizer.add_special_tokens(self.special_tokens)
        self.tokenizer = hf_tokenizer
        self.skip_special_tokens = skip_special_tokens

    def encode(self, text, **kwargs):
        """Tokenize a single string and return its ids as a Python list.

        Extra keyword arguments are accepted but ignored.
        """
        batch = self.tokenizer([text], return_tensors="pt")
        first_row = batch["input_ids"][0]
        return first_row.cpu().tolist()

    def decode(self, tokens):
        """Decode a sequence of token ids back into a string.

        Special tokens are dropped when ``self.skip_special_tokens`` is true.
        """
        id_tensor = torch.tensor(tokens, dtype=torch.int64)
        texts = self.tokenizer.batch_decode(
            [id_tensor], skip_special_tokens=self.skip_special_tokens
        )
        return texts[0]
|
|
|
|
|
@lru_cache(maxsize=None)
def get_qwen_tokenizer(
    token_path: str,
    skip_special_tokens: bool
) -> QwenTokenizer:
    """Return a memoised ``QwenTokenizer`` for the given arguments.

    ``lru_cache`` keys on the call's arguments as passed, so repeated calls
    with identical arguments reuse one instance instead of reloading the
    pretrained tokenizer.
    """
    instance = QwenTokenizer(token_path, skip_special_tokens)
    return instance
|
|