"""Compute KenLM perplexity scores for JSON Lines documents.

Documents are scored against source-specific KenLM models (optionally
SentencePiece-tokenized) after text normalization.
"""
import argparse
import json
import os
import re
from functools import cache
from pathlib import Path
from typing import Iterator, List, NoReturn, Optional, Tuple, Union

import kenlm
import msgspec
import sentencepiece
from numpy.random import default_rng
from scipy.stats import norm
from tqdm import tqdm

from normalization import normalize_text


RNG = default_rng()
LANGS = ("no", "nn", "nob", "nno", "da", "sv", "is", "en")
DEFAULT_LANG = "no"
BASEPATH = Path(os.environ.get("PERPLEXITY_BASEPATH", "/nfsmounts/datastore/mimir/perplexity"))
CONFIG = {
    "harmful": {
        "no": {"model": BASEPATH / "kenlm" / "harmful" / "no.bin", "normalize": True},
        "nn": {"model": BASEPATH / "kenlm" / "harmful" / "no.bin", "normalize": True},
        "nob": {"model": BASEPATH / "kenlm" / "harmful" / "no.bin", "normalize": True},
        "nno": {"model": BASEPATH / "kenlm" / "harmful" / "no.bin", "normalize": True},
        "da": {"model": BASEPATH / "kenlm" / "harmful" / "da.bin", "normalize": True},
        "sv": {"model": BASEPATH / "kenlm" / "harmful" / "sv.bin", "normalize": True},
        "is": {"model": BASEPATH / "kenlm" / "harmful" / "is.bin", "normalize": True},
        "en": {"model": BASEPATH / "kenlm" / "harmful" / "en.bin", "normalize": True},
    },
    "wikipedia": {
        "no": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "no.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "no.sp.model",
            "normalize": True,
        },
        "nn": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "nn.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "nn.sp.model",
            "normalize": True,
        },
        "nob": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "no.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "no.sp.model",
            "normalize": True,
        },
        "nno": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "nn.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "nn.sp.model",
            "normalize": True,
        },
        "da": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "da.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "da.sp.model",
            "normalize": True,
        },
        "en": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "en.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "en.sp.model",
            "normalize": True,
        },
        "is": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "is.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "is.sp.model",
            "normalize": True,
        },
        "sv": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "sv.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "sv.sp.model",
            "normalize": True,
        },
    },
    "books": {
        "model": BASEPATH / "kenlm" / "books.norm.sp.arpa.bin",
        "tokenizer": BASEPATH / "spm" / "books.norm.sp.model",
        "normalize": True,
    },
    "newspapers": {
        "model": BASEPATH / "kenlm" / "newspapers.norm.sp.arpa.bin",
        "tokenizer": BASEPATH / "spm" / "newspapers.norm.sp.model",
        "normalize": True,
    },
    "maalfrid": {
        "model": BASEPATH / "kenlm" / "maalfrid.norm.sp.arpa.bin",
        "tokenizer": BASEPATH / "spm" / "maalfrid.norm.sp.model",
        "normalize": True,
    },
}


def should_keep(
    perp: float, dist_norm: float, dist_mean: float, dist_std: float
) -> bool:
    """
    Decide whether a document is retained, based on its perplexity value.

    Note: the target distribution parameters (dist_norm, dist_mean, dist_std)
    must have been computed beforehand.
    """
    p = norm.pdf(perp, loc=dist_mean, scale=dist_std) / dist_norm
    return RNG.uniform() < p
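

# Illustrative use of should_keep (hypothetical numbers): with a target
# distribution N(300, 100) and its peak density as the normalizer, a document
# at perplexity 320 is kept with probability exp(-0.02) ≈ 0.98:
#
#   dist_mean, dist_std = 300.0, 100.0
#   dist_norm = norm.pdf(dist_mean, loc=dist_mean, scale=dist_std)
#   should_keep(320.0, dist_norm, dist_mean, dist_std)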


def fix_language(language: str) -> str:
    """Map unsupported language codes to the default language."""
    if language not in LANGS:
        return DEFAULT_LANG
    else:
        return language


def pp(log_score: float, length: int) -> float:
    """Convert a summed log10 score and token count into a perplexity."""
    return 10.0 ** (-log_score / length)
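

# Worked example (pure arithmetic, no models needed): a summed log10 score of
# -1000 over 500 tokens gives a perplexity of 10 ** (1000 / 500) = 100.0:
#
#   pp(-1000.0, 500)  # -> 100.0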


@cache
def load_kenlm(model: Union[str, Path]) -> kenlm.Model:
    lm_config = kenlm.Config()
    # load_method 2 selects eager loading of the model into memory
    # (POPULATE_OR_READ in kenlm's load-method enum).
    lm_config.load_method = 2
    return kenlm.Model(str(model), lm_config)


@cache
def load_sentencepiece(model: Union[str, Path]) -> sentencepiece.SentencePieceProcessor:
    sp = sentencepiece.SentencePieceProcessor()
    sp.load(str(model))
    return sp


def get_perplexity(
    document: str,
    model: Union[str, Path],
    tokenizer: Optional[Union[str, Path]] = None,
    normalize: bool = False,
) -> float:
    lines = document.split("\n")
    lm = load_kenlm(model)
    if not lines or not lm:
        return 0.0
    if tokenizer:
        sp = load_sentencepiece(tokenizer)
    doc_log_score, doc_length = 0, 0
    for line in lines:
        if not line:
            continue
        if normalize:
            line = normalize_text(line)
        if tokenizer:
            line = " ".join(sp.encode_as_pieces(line))
        log_score = lm.score(line)
        length = len(line.split()) + 1
        doc_log_score += log_score
        doc_length += length
    # Guard against empty documents to avoid a zero division in pp().
    if not doc_length:
        return 0.0
    return round(pp(doc_log_score, doc_length), 1)
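

# Minimal sketch of a direct call, assuming the configured model files exist
# on disk (the Norwegian sentence is only an illustration):
#
#   params = CONFIG["wikipedia"]["no"]
#   get_perplexity("Dette er en setning.", **params)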


def get_perplexity_local(
    document: str,
    model: kenlm.Model,
    tokenizer: Optional[sentencepiece.SentencePieceProcessor] = None,
    normalize: bool = False,
) -> float:
    """Variant of get_perplexity that takes preloaded model objects."""
    lines = document.split("\n")
    if not lines or not model:
        return 0.0
    doc_log_score, doc_length = 0, 0
    for line in lines:
        if not line:
            continue
        if normalize:
            line = normalize_text(line)
        if tokenizer is not None:
            line = " ".join(tokenizer.encode_as_pieces(line))
        log_score = model.score(line)
        length = len(line.split()) + 1
        doc_log_score += log_score
        doc_length += length
    if not doc_length:
        return 0.0
    return round(pp(doc_log_score, doc_length), 1)


def harmful_perplexity(document: str, language: str) -> float:
    params = CONFIG["harmful"][fix_language(language)]
    return get_perplexity(document=document, **params)


def wikipedia_perplexity(document: str, language: str) -> float:
    params = CONFIG["wikipedia"][fix_language(language)]
    return get_perplexity(document=document, **params)


def books_perplexity(document: str) -> float:
    params = CONFIG["books"]
    return get_perplexity(document=document, **params)


def newspapers_perplexity(document: str) -> float:
    params = CONFIG["newspapers"]
    return get_perplexity(document=document, **params)


def maalfrid_perplexity(document: str) -> float:
    params = CONFIG["maalfrid"]
    return get_perplexity(document=document, **params)


def source_perplexities(
    document: str,
    language: str,
    model: str | None = None,
    include_harmful: bool = True,
) -> dict:
    """
    Calculate perplexities for all models (or a single model) at once.

    The document is normalized once up front, so the per-model "normalize"
    flag is overridden to False to avoid normalizing twice.
    """
    normalized_document = "\n".join(normalize_text(line) for line in document.split("\n"))
    language = fix_language(language)

    if model is not None:
        params = CONFIG[model]
        if model == "wikipedia":
            params = params[language]
        # Copy before overriding so CONFIG itself is not mutated.
        params = {**params, "normalize": False}
        perplexities = {
            f"{model}_pp": get_perplexity(document=normalized_document, **params),
        }
    else:
        perplexities = {}
        for source in ("wikipedia", "books", "newspapers", "maalfrid"):
            params = CONFIG[source]
            if source == "wikipedia":
                params = params[language]
            params = {**params, "normalize": False}
            perplexities[f"{source}_pp"] = get_perplexity(
                document=normalized_document, **params
            )
    if include_harmful:
        params = {**CONFIG["harmful"][language], "normalize": False}
        perplexities["harmful_pp"] = get_perplexity(
            document=normalized_document, **params
        )
    return perplexities
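

# Example of the mapping returned when no single model is selected (the values
# are placeholders, not real measurements):
#
#   source_perplexities("Dette er en setning.", "no")
#   # -> {"wikipedia_pp": ..., "books_pp": ..., "newspapers_pp": ...,
#   #     "maalfrid_pp": ..., "harmful_pp": ...}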


def get_model_for(doc_type: str) -> Tuple[str, bool]:
    """Return the model type for a doc type and whether it is language-specific."""
    doc_type = doc_type.split("_", 1)[0]
    if "-" in doc_type:
        doc_type = doc_type.split("-", 1)[-1]
    if doc_type in ("book", "books"):
        return "books", False
    elif doc_type in ("culturax", "slimpajama", "wikipedia", "digimanus", "pg19", "hplt", "starcoder"):
        return "wikipedia", True
    elif doc_type in ("newspaper", "newspapers"):
        return "newspapers", False
    elif doc_type in ("evalueringsrapport", "lovdata", "maalfrid", "parlamint"):
        return "maalfrid", False
    else:
        return "wikipedia", True


def preload_models_tokenizers() -> dict:
    """Warm the load_kenlm/load_sentencepiece caches for every configured model."""
    print("Preloading models...", end=" ")
    models = {}
    # Resolve paths through CONFIG so the cached entries match the ones
    # get_perplexity will request later.
    for source in ("books", "newspapers", "maalfrid"):
        params = CONFIG[source]
        models[source] = (
            load_kenlm(params["model"]),
            load_sentencepiece(params["tokenizer"]),
        )
    for lang, params in CONFIG["harmful"].items():
        models[f"harmful-{lang}"] = load_kenlm(params["model"]), None
    for lang, params in CONFIG["wikipedia"].items():
        models[f"wikipedia-{lang}"] = (
            load_kenlm(params["model"]),
            load_sentencepiece(params["tokenizer"]),
        )
    print("Done")
    return models
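

# Typical use (a sketch): warm the caches once up front, then score with the
# local variant to avoid repeated lookups:
#
#   models = preload_models_tokenizers()
#   model, tokenizer = models["wikipedia-no"]
#   get_perplexity_local("Dette er en setning.", model, tokenizer, normalize=True)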


def process_file(input_file, output_path, cutoff=None, model=None, overwrite_output=True):
    """
    Read a JSON Lines file, compute perplexities for each document, and write
    the enriched documents to a new file.

    Steps:
    1. Determine the output file path; skip if it exists and overwriting is off.
    2. Read the input file line by line, treating each line as a JSON document.
    3. Take the document's language from its precomputed "lang_fasttext" field;
       "starcoder" documents default to English.
    4. Depending on `model`, compute perplexities either with the model matched
       to the document's type ("single") or with the specified model(s).
    5. Update the document with the computed perplexities and write it to the
       output file in JSON Lines format.
    6. Optionally stop after `cutoff` lines.

    Parameters:
    - input_file (str or Path): Path to the input file.
    - output_path (str or Path): Directory where the output file is saved,
      under the same name as the input file.
    - cutoff (int, optional): If given, stop after this many lines.
    - model (str, optional): "single" picks a model per document type; a model
      name ("books", "wikipedia", "newspapers" or "maalfrid") uses that model;
      None runs all models. Defaults to None.
    - overwrite_output (bool): If True, overwrite an existing output file; if
      False, skip processing when it exists. Defaults to True.

    Returns:
        None. Writes processed documents to the output path.
    """
    input_file = Path(input_file)
    output_file = Path(output_path) / input_file.name
    if not overwrite_output and output_file.exists():
        print(f"Skipping {output_file} as it already exists")
        return
    with (open(output_file, 'w', encoding='utf-8') as f,
          open(input_file, 'r', encoding='utf-8') as lines):
        for line_count, line in tqdm(enumerate(lines, start=1), desc=f"Processing {input_file.name}"):
            doc = json.loads(line)
            language = doc["lang_fasttext"]
            if doc["doc_type"] == "starcoder":
                language = "en"
            if model == "single":
                doc_type_model, _ = get_model_for(doc["doc_type"])
                perplexities = source_perplexities(doc["text"], language, model=doc_type_model)
                perplexities["perplexity"] = perplexities.pop(f"{doc_type_model}_pp")
                perplexities["perplexity_model"] = doc_type_model
            else:
                perplexities = source_perplexities(doc["text"], language, model=model)
            doc.update(perplexities)
            f.write(json.dumps(doc) + "\n")
            if cutoff is not None and line_count >= cutoff:
                break
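

# Example CLI invocation (the script and file names are illustrative):
#
#   python perplexity.py -i data/shard-000.jsonl -o enriched/ -m single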


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Calculate perplexity values for a given JSON Lines file and output the result to a new file.')
    parser.add_argument('-i', '--input_file', type=str,
                        help='Input file path')
    parser.add_argument('-o', '--output_path', type=str,
                        help='Output path to write the enriched file')
    parser.add_argument('-c', '--cutoff', required=False, type=int,
                        help='Max number of lines to process')
    parser.add_argument('-m', '--model', required=False, type=str,
                        help='Run a "single" model per doc type, "all" the models, '
                             'or a specific model chosen from '
                             '"books", "wikipedia", "newspapers" or "maalfrid". '
                             'Defaults to running all the models')
    parser.add_argument('--overwrite_output',
                        action=argparse.BooleanOptionalAction, default=True,
                        help='Whether to overwrite the output file if it exists.')

    args = parser.parse_args()

    if args.model in ("single", "books", "wikipedia", "newspapers", "maalfrid"):
        model = args.model
    else:
        model = None  # "all", unspecified, or unrecognized: run every model
    process_file(
        args.input_file, args.output_path, args.cutoff,
        model=model, overwrite_output=args.overwrite_output,
    )