import os
import re
import unicodedata
from typing import Dict

import kenlm
import sentencepiece


class SentencePiece:
    def __init__(
        self,
        model: str,
    ):
        self.sp = sentencepiece.SentencePieceProcessor()
        self.sp.load(str(model))

    def do(self, text: str) -> str:
        # Encode into SentencePiece pieces, then re-join with spaces so the
        # KenLM model scores the same token stream it was trained on.
        tokenized = self.sp.encode_as_pieces(text)
        return " ".join(tokenized)
					
						
class KenlmModel:
    digit_re: re.Pattern = re.compile(r"\d")
    # Map unicode punctuation (mostly full-width CJK forms) to ASCII
    # equivalents, as in cc_net's text normalization.
    unicode_punct: Dict[str, str] = {
        "，": ",",
        "。": ".",
        "、": ",",
        "„": '"',
        "”": '"',
        "“": '"',
        "«": '"',
        "»": '"',
        "１": '"',
        "」": '"',
        "「": '"',
        "《": '"',
        "》": '"',
        "´": "'",
        "∶": ":",
        "：": ":",
        "？": "?",
        "！": "!",
        "（": "(",
        "）": ")",
        "；": ";",
        "–": "-",
        "—": " - ",
        "．": ". ",
        "～": "~",
        "’": "'",
        "…": "...",
        "━": "-",
        "〈": "<",
        "〉": ">",
        "【": "[",
        "】": "]",
        "％": "%",
        "►": "-",
    }
    unicode_punct_re = re.compile(f"[{''.join(unicode_punct.keys())}]")
    non_printing_chars_re = re.compile(
        f"[{''.join(map(chr, list(range(0, 32)) + list(range(127, 160))))}]"
    )
    kenlm_model_dir = None
    sentence_piece_model_dir = None

    def __init__(
        self,
        model_dataset: str,
        language: str,
        lower_case: bool = False,
        remove_accents: bool = False,
        normalize_numbers: bool = True,
        punctuation: int = 1,
    ):
        # Expects "{language}.arpa.bin" (KenLM binary) and "{language}.sp.model"
        # (SentencePiece) side by side under model_dataset.
        self.model = kenlm.Model(os.path.join(model_dataset, f"{language}.arpa.bin"))
        self.tokenizer = SentencePiece(os.path.join(model_dataset, f"{language}.sp.model"))
        self.accent = remove_accents
        self.case = lower_case
        self.numbers = normalize_numbers
        self.punct = punctuation

    @classmethod
    def from_pretrained(
        cls,
        model_dataset: str,
        language: str,
    ):
        # Convenience constructor using the default normalization settings.
        return cls(
            model_dataset,
            language,
            lower_case=False,
            remove_accents=False,
            normalize_numbers=True,
            punctuation=1,
        )

    def pp(self, log_score, length):
        # Perplexity from a summed log10 probability and a token count.
        return 10.0 ** (-log_score / length)
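
    # Worked example: kenlm's Model.score returns a base-10 log probability,
    # so a summed log score of -6.0 over 3 tokens yields a perplexity of
    # 10 ** (6.0 / 3) = 100.0.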
					
						
    def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
        if normalize_cc_net:
            doc = self.normalize(
                doc,
                accent=self.accent,
                case=self.case,
                numbers=self.numbers,
                punct=self.punct,
            )
        # Tokenize with SentencePiece so the text matches what the KenLM
        # model was trained on, then score it line by line.
        doc = self.tokenizer.do(doc)
        doc_log_score, doc_length = 0, 0
        for line in doc.split("\n"):
            log_score = self.model.score(line)
            # +1 accounts for the end-of-sentence token KenLM appends.
            length = len(line.split()) + 1
            doc_log_score += log_score
            doc_length += length
        return round(self.pp(doc_log_score, doc_length), 1)

    def normalize(
        self,
        line: str,
        accent: bool = True,
        case: bool = True,
        numbers: bool = True,
        punct: int = 1,
    ) -> str:
        line = line.strip()
        if not line:
            return line
        if case:
            line = line.lower()
        if accent:
            line = self.strip_accents(line)
        if numbers:
            # Collapse all digits to "0" so numbers do not fragment the vocabulary.
            line = self.digit_re.sub("0", line)
        if punct == 1:
            line = self.replace_unicode_punct(line)
        elif punct == 2:
            line = self.remove_unicode_punct(line)
        line = self.remove_non_printing_char(line)
        return line
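
    # Worked example using normalize()'s own defaults (all flags on, punct=1):
    #   normalize("Café, 2023!")  ->  "cafe, 0000!"
    # i.e. lower-cased, accents stripped, every digit mapped to "0".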
					
						
    def strip_accents(self, line: str) -> str:
        """Strips accents from a piece of text."""
        nfd = unicodedata.normalize("NFD", line)
        output = [c for c in nfd if unicodedata.category(c) != "Mn"]
        if len(output) == len(line):
            # Nothing was stripped; return the original string unchanged.
            return line
        return "".join(output)
					
						
    def replace_unicode_punct(self, text: str) -> str:
        return "".join(self.unicode_punct.get(c, c) for c in text)

    def remove_unicode_punct(self, text: str) -> str:
        """More aggressive version of replace_unicode_punct but also faster."""
        return self.unicode_punct_re.sub("", text)

    def remove_non_printing_char(self, text: str) -> str:
        return self.non_printing_chars_re.sub("", text)
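

# Minimal usage sketch, assuming a local directory that holds the paired
# "{language}.arpa.bin" and "{language}.sp.model" files. The path and the
# sample sentences are illustrative, not shipped with this module.
if __name__ == "__main__":
    model = KenlmModel.from_pretrained("models/wikipedia", "en")
    for text in ["I am very perplexed", "im hella trippin"]:
        # Lower perplexity means the text is closer to the model's training data.
        print(text, "->", model.get_perplexity(text))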