|
import argparse |
|
import os |
|
from collections import defaultdict |
|
from io import StringIO |
|
|
|
import pandas as pd |
|
from tqdm import tqdm |
|
|
|
from perplexity import get_model_for |
|
from subsampler import PerplexitySubsampler |
|
|
|
|
|
def process_files(
    directory,
    reject_level,
    model_override,
    output_file,
    group_by_prefix_lang,
    prefix_lang_mapping=None,
    ratio=None,
    ratio_per_lang=None,
    pa=None,
    pb=None,
    include=None,
):
    """Compute perplexity quantile metrics for JSONL files in *directory*.

    Each file is expected to be named ``<prefix>_..._<lang>.<ext>`` and to
    contain a ``perplexities`` column of records with a ``<model>_pp`` score.
    For each file (or group of files sharing ``prefix_lang``), the function
    emits one CSV row with the reject/bad/medium/good quantiles and, when a
    sampling ratio is given, the subsampler's norm/mean/std statistics.

    Args:
        directory: Directory containing the JSONL files to process.
        reject_level: Quantile used for the ``reject`` threshold (e.g. 0.95).
        model_override: If truthy, use this model name instead of looking it
            up with ``get_model_for(doc_type)``.
        output_file: Path of the CSV file to write. If falsy, rows are
            printed to standard output instead (including the header).
        group_by_prefix_lang: Combine files sharing the same ``prefix_lang``
            key into a single group before computing quantiles.
        prefix_lang_mapping: Optional remapping of ``"<prefix>_<lang>"`` keys
            (e.g. ``{"hplt_en": "hplt_no"}``).
        ratio: Global document-keep ratio; takes precedence over
            ``ratio_per_lang``.
        ratio_per_lang: Per-language keep ratios (default 1.0 for languages
            not listed).
        pa: Probability of keeping documents in the Q1 range.
        pb: Probability of keeping documents in the Q3 range.
        include: Optional list of doc_type prefixes to restrict processing to.
    """
    if prefix_lang_mapping is None:
        prefix_lang_mapping = {}

    # The header carries sampling-statistics columns only when sampling
    # parameters were requested.
    if ratio or ratio_per_lang:
        rows = ["doc_type,model,language,reject,bad,medium,good,norm,mean,std"]
    else:
        rows = ["doc_type,model,language,reject,bad,medium,good"]

    files = os.listdir(directory)

    description = "Processing files"
    if group_by_prefix_lang:
        description = "Processing files in groups"
        grouped_files = defaultdict(list)
        for file in files:
            parts = file.split('_')
            prefix = parts[0]
            if include and prefix not in include:
                continue
            # Language code: first two chars of the last "_" part, extension
            # stripped (e.g. "foo_bar_en.jsonl" -> "en").
            lang = parts[-1].split(".")[0][:2]
            group_key = prefix_lang_mapping.get(f"{prefix}_{lang}", f"{prefix}_{lang}")
            grouped_files[group_key].append(file)
        file_groups = grouped_files.values()
    else:
        # One singleton group per file, honoring the --include prefix filter.
        file_groups = [
            [file]
            for file in files
            if not include or any(file.startswith(prefix) for prefix in include)
        ]

    if output_file:
        progress = tqdm(file_groups, desc=description)
    else:
        progress = file_groups
        print(rows[0])

    for group in progress:
        frames = []
        doc_type, lang = None, None

        for file in group:
            # Derive doc_type/lang once per group, from its first file, and
            # apply the optional prefix_lang remapping.
            if not doc_type or not lang:
                parts = file.split('_')
                doc_type = parts[0]
                lang = parts[-1].split(".")[0][:2]
                doc_type, lang = prefix_lang_mapping.get(
                    f"{doc_type}_{lang}", f"{doc_type}_{lang}"
                ).rsplit("_", 1)
            perp = pd.read_json(os.path.join(directory, file), lines=True)
            # Expand the nested "perplexities" records into their own frame.
            frames.append(
                pd.read_json(
                    StringIO(perp["perplexities"].to_json(lines=True, orient="records")),
                    lines=True,
                )
            )

        # Concatenate once instead of growing a DataFrame inside the loop:
        # repeated pd.concat is quadratic in the number of files per group.
        combined_perplexities = pd.concat(frames, ignore_index=True)

        if model_override:
            model = model_override
        else:
            model, _ = get_model_for(doc_type)
        model_with_suffix = f"{model}_pp"

        scores = combined_perplexities[model_with_suffix]
        reject = round(scores.quantile(q=reject_level), 2)
        bad = round(scores.quantile(q=0.75), 2)
        medium = round(scores.quantile(q=0.50), 2)
        good = round(scores.quantile(q=0.25), 2)

        # A global --sampling_ratio takes precedence over per-language ratios;
        # languages missing from ratio_per_lang default to 1.0.
        if ratio:
            effective_ratio = ratio
        elif ratio_per_lang:
            effective_ratio = ratio_per_lang.get(lang, 1.0)
        else:
            effective_ratio = None

        if effective_ratio is not None:
            subsampler = PerplexitySubsampler(scores.values)
            subsampler.set(ratio=effective_ratio, pa=pa, pb=pb)
            sampling_stats = f",{subsampler.norm},{subsampler.mean},{subsampler.sdev}"
        else:
            sampling_stats = ""

        row = f"{doc_type},{model},{lang},{reject},{bad},{medium},{good}{sampling_stats}"
        if output_file:
            rows.append(row)
        else:
            print(row)

    if output_file:
        with open(output_file, "w") as f:
            for row in rows:
                f.write(f"{row}\n")
|
|
|
|
|
def main():
    """Parse command-line arguments and run the perplexity metrics pipeline.

    Note: each doc_type prefix needs to have a "no" lang, even if there's
    no real data. These rows are crucial for the rest of the process.
    """
    parser = argparse.ArgumentParser(description="Process files and compute perplexity metrics.")
    parser.add_argument('directory', type=str, help='Directory containing the files to process')
    parser.add_argument('--reject_level', type=float, default=0.95, help='Rejection quantile level (default: 0.95)')
    parser.add_argument('--model_override', type=str, help='Override the model used')
    parser.add_argument('--output_file', type=str, help='Output file in CSV format. If not given, prints to standard output.')
    parser.add_argument('--group_by_prefix_lang', action='store_true', help='Group and calculate quantiles for files with the same prefix and language')
    parser.add_argument('--overwrite_prefix_lang', type=str, help='Overwrite the assignment of languages to doc_type prefixes, e.g., "starcoder_en:starcoder_code,hplt_en:hplt_no"')
    parser.add_argument('--sampling_ratio', type=float, help='Ratio of documents to keep for sampling. If passed, it generates distribution statistics (norm, mean, std) needed for sampling')
    parser.add_argument('--sampling_ratio_per_lang', type=str, help='Ratio of documents per lang, e.g., "en:0.25,sv:0.34"')
    parser.add_argument('--sampling_q1_prob', type=float, default=0.20, help='Probability for keeping documents in the Q1 range')
    parser.add_argument('--sampling_q3_prob', type=float, default=0.05, help='Probability for keeping documents in the Q3 range')
    parser.add_argument('--include', type=str, help='Comma separated list of doc type prefixes to include')

    args = parser.parse_args()

    # "en:0.25,sv:0.34" -> {"en": 0.25, "sv": 0.34}
    if args.sampling_ratio_per_lang:
        ratio_per_lang = {
            k.strip(): float(v.strip())
            for k, v in (
                item.split(":")
                for item in args.sampling_ratio_per_lang.split(",")
            )
        }
    else:
        ratio_per_lang = None

    # "hplt_en:hplt_no,..." -> {"hplt_en": "hplt_no", ...}
    if args.overwrite_prefix_lang:
        prefix_lang_mapping = {
            k.strip(): v.strip()
            for k, v in (
                item.split(":")
                for item in args.overwrite_prefix_lang.split(",")
            )
        }
    else:
        prefix_lang_mapping = {}

    process_files(
        args.directory,
        args.reject_level,
        args.model_override,
        args.output_file,
        group_by_prefix_lang=args.group_by_prefix_lang,
        prefix_lang_mapping=prefix_lang_mapping,
        pa=args.sampling_q1_prob,
        pb=args.sampling_q3_prob,
        ratio=args.sampling_ratio,
        ratio_per_lang=ratio_per_lang,
        include=args.include.split(",") if args.include else None
    )
|
|
|
# Script entry point: run the CLI only when executed directly.
if __name__ == "__main__":

    main()
|
|