import argparse |
import os |
from collections import defaultdict |
from io import StringIO |
import pandas as pd |
from tqdm import tqdm |
from perplexity import get_model_for |
from subsampler import PerplexitySubsampler |
def _split_prefix_lang(filename):
    """Return ``(prefix, lang)`` parsed from ``filename``.

    ``prefix`` is everything before the first ``_``; ``lang`` is the first two
    characters of the last ``_``-separated part, cut at its first ``.``.
    """
    parts = filename.split("_")
    return parts[0], parts[-1].split(".")[0][:2]


def _group_files(files, group_by_prefix_lang, prefix_lang_mapping, include):
    """Partition ``files`` into lists whose perplexities are pooled together.

    Without grouping every file is its own group and ``include`` is matched with
    ``str.startswith`` against the whole file name. With grouping, files are keyed
    by their (mapped) ``"<prefix>_<lang>"`` and ``include`` must equal the prefix.
    """
    if not group_by_prefix_lang:
        # NOTE(review): prefix matching differs from the grouped branch below
        # (startswith vs exact); kept as-is for backward compatibility.
        return [
            [file]
            for file in files
            if not include or any(file.startswith(prefix) for prefix in include)
        ]
    grouped = defaultdict(list)
    for file in files:
        prefix, lang = _split_prefix_lang(file)
        if include and prefix not in include:
            continue
        key = f"{prefix}_{lang}"
        grouped[prefix_lang_mapping.get(key, key)].append(file)
    return list(grouped.values())


def _doc_type_and_lang(filename, prefix_lang_mapping):
    """Return the ``(doc_type, lang)`` of ``filename`` after applying the mapping.

    Raises ValueError if the mapped key contains no ``_``.
    """
    prefix, lang = _split_prefix_lang(filename)
    key = f"{prefix}_{lang}"
    doc_type, lang = prefix_lang_mapping.get(key, key).rsplit("_", 1)
    return doc_type, lang


def _read_perplexities(path):
    """Load the ``perplexities`` objects of a JSON-lines file as a flat DataFrame."""
    records = pd.read_json(path, lines=True)
    # Serialise the column of per-record dicts back to JSON lines and re-read it,
    # so that every key of the ``perplexities`` object becomes its own column.
    return pd.read_json(
        StringIO(records["perplexities"].to_json(lines=True, orient="records")),
        lines=True,
    )


def _sampling_stats(values, lang, ratio, ratio_per_lang, pa, pb):
    """Return ``",norm,mean,std"`` from ``PerplexitySubsampler``, or ``""`` when
    neither ``ratio`` nor ``ratio_per_lang`` is set."""
    if ratio_per_lang:
        # Per-language ratios take precedence; languages not listed fall back to
        # the global ratio (or 1.0 when no global ratio was given).
        target_ratio = ratio_per_lang.get(lang, ratio or 1.0)
    elif ratio:
        target_ratio = ratio
    else:
        return ""
    subsampler = PerplexitySubsampler(values)
    subsampler.set(ratio=target_ratio, pa=pa, pb=pb)
    return f",{subsampler.norm},{subsampler.mean},{subsampler.sdev}"


def process_files(
    directory,
    reject_level,
    model_override,
    output_file,
    group_by_prefix_lang,
    prefix_lang_mapping=None,
    ratio=None,
    ratio_per_lang=None,
    pa=None,
    pb=None,
    include=None,
):
    """Compute perplexity quantiles per file (or per prefix/language group) as CSV.

    Files in ``directory`` are expected to be named ``<prefix>_..._<lang>.<ext>``
    and to contain JSON lines whose ``perplexities`` object holds a
    ``<model>_pp`` value. The doc type is the file prefix and the language is the
    first two characters of the last ``_``-separated part of the name. Both may
    be rewritten through ``prefix_lang_mapping``.

    Each output row is ``doc_type,model,language,reject,bad,medium,good``, where
    ``reject`` is the ``reject_level`` quantile and ``bad``/``medium``/``good`` are
    the 0.75/0.50/0.25 quantiles, all rounded to 2 decimals. When sampling is
    requested, ``,norm,mean,std`` from ``PerplexitySubsampler`` is appended.

    Args:
        directory: Directory whose files are processed, in sorted name order.
        reject_level: Quantile reported in the ``reject`` column.
        model_override: Model name to use instead of ``get_model_for(doc_type)``.
        output_file: CSV path to write (with a tqdm progress bar). When falsy,
            the header and rows are printed to stdout instead.
        group_by_prefix_lang: Pool all files that share the same (mapped)
            ``"<prefix>_<lang>"`` key into one row.
        prefix_lang_mapping: Maps ``"<prefix>_<lang>"`` to a replacement
            ``"<doc_type>_<lang>"``; the value must contain a ``_``.
        ratio: Global sampling ratio; enables the sampling columns.
        ratio_per_lang: Per-language sampling ratios. They take precedence over
            ``ratio``, which serves as the fallback for unlisted languages.
        pa, pb: Forwarded unchanged to ``PerplexitySubsampler.set``.
        include: Doc-type prefixes to keep; ``None``/empty keeps every file.
    """
    if ratio or ratio_per_lang:
        rows = ["doc_type,model,language,reject,bad,medium,good,norm,mean,std"]
    else:
        rows = ["doc_type,model,language,reject,bad,medium,good"]
    if prefix_lang_mapping is None:
        prefix_lang_mapping = {}

    # Sort so the output row order is deterministic (os.listdir order is arbitrary).
    files = sorted(os.listdir(directory))
    file_groups = _group_files(files, group_by_prefix_lang, prefix_lang_mapping, include)

    if output_file:
        description = "Processing files in groups" if group_by_prefix_lang else "Processing files"
        progress = tqdm(file_groups, desc=description)
    else:
        progress = file_groups
        print(rows[0])

    for group in progress:
        # Every file in a group maps to the same key, so the first one is enough.
        doc_type, lang = _doc_type_and_lang(group[0], prefix_lang_mapping)
        # Concatenate once instead of growing a DataFrame file by file.
        combined = pd.concat(
            [_read_perplexities(os.path.join(directory, file)) for file in group],
            ignore_index=True,
        )
        if model_override:
            model = model_override
        else:
            model, _ = get_model_for(doc_type)
        column = combined[f"{model}_pp"]
        reject = round(column.quantile(q=reject_level), 2)
        bad = round(column.quantile(q=0.75), 2)
        medium = round(column.quantile(q=0.50), 2)
        good = round(column.quantile(q=0.25), 2)
        sampling = _sampling_stats(column.values, lang, ratio, ratio_per_lang, pa, pb)
        row = f"{doc_type},{model},{lang},{reject},{bad},{medium},{good}{sampling}"
        if output_file:
            rows.append(row)
        else:
            print(row)

    if output_file:
        with open(output_file, "w", encoding="utf-8") as f:
            f.writelines(f"{row}\n" for row in rows)
def _parse_key_value_list(spec, convert=str):
    """Parse a ``"key:value,key:value"`` CLI string into a dict.

    Keys and values are stripped of surrounding whitespace, and each value is
    passed through ``convert``. Empty items (e.g. from a trailing comma) are
    skipped.

    Raises:
        ValueError: if an item has no ``:`` or ``convert`` rejects a value.
    """
    pairs = {}
    for item in spec.split(","):
        if not item.strip():
            continue
        key, sep, value = item.partition(":")
        if not sep:
            raise ValueError(f"expected 'key:value', got {item!r}")
        pairs[key.strip()] = convert(value.strip())
    return pairs


def main():
    """Command-line entry point: parse arguments and run ``process_files``.

    Each doc_type prefix needs to have a row with the "no" language, even if
    there is no real data for it. These rows are crucial for the rest of the
    process; a language can be remapped with ``--overwrite_prefix_lang``
    (e.g. ``hplt_en:hplt_no``).
    """
    parser = argparse.ArgumentParser(description="Process files and compute perplexity metrics.")
    parser.add_argument('directory', type=str, help='Directory containing the files to process')
    parser.add_argument('--reject_level', type=float, default=0.95, help='Rejection quantile level (default: 0.95)')
    parser.add_argument('--model_override', type=str, help='Override the model used')
    parser.add_argument('--output_file', type=str, help='Output file in CSV format. If not given, prints to standard output.')
    parser.add_argument('--group_by_prefix_lang', action='store_true', help='Group and calculate quantiles for files with the same prefix and language')
    parser.add_argument('--overwrite_prefix_lang', type=str, help='Overwrite the assignment of languages to doc_type prefixes, e.g., "starcoder_en:starcoder_code,hplt_en:hplt_no"')
    parser.add_argument('--sampling_ratio', type=float, help='Ratio of documents to keep for sampling. If passed, it generate distribution statistics (norm, mean, std) needed for sampling')
    parser.add_argument('--sampling_ratio_per_lang', type=str, help='Ratio of documents per lang, e.g., "en:0.25,sv:0.34"')
    parser.add_argument('--sampling_q1_prob', type=float, default=0.20, help='Probabilty for keeping documents in the Q1 range')
    parser.add_argument('--sampling_q3_prob', type=float, default=0.05, help='Probabilty for keeping documents in the Q3 range')
    parser.add_argument('--include', type=str, help='Comma separeted list of doc type prefixes to include')
    args = parser.parse_args()

    # Report malformed key:value lists as a usage error instead of a traceback.
    try:
        ratio_per_lang = (
            _parse_key_value_list(args.sampling_ratio_per_lang, float)
            if args.sampling_ratio_per_lang
            else None
        )
        prefix_lang_mapping = (
            _parse_key_value_list(args.overwrite_prefix_lang)
            if args.overwrite_prefix_lang
            else {}
        )
    except ValueError as err:
        parser.error(f"invalid key:value list: {err}")

    # process_files splits each mapped value on its last "_" into doc_type/lang.
    malformed = [v for v in prefix_lang_mapping.values() if "_" not in v]
    if malformed:
        parser.error(f"--overwrite_prefix_lang values must look like '<doc_type>_<lang>': {malformed}")

    include = (
        [prefix.strip() for prefix in args.include.split(",") if prefix.strip()]
        if args.include
        else None
    )

    process_files(
        args.directory,
        args.reject_level,
        args.model_override,
        args.output_file,
        group_by_prefix_lang=args.group_by_prefix_lang,
        prefix_lang_mapping=prefix_lang_mapping,
        pa=args.sampling_q1_prob,
        pb=args.sampling_q3_prob,
        ratio=args.sampling_ratio,
        ratio_per_lang=ratio_per_lang,
        include=include,
    )


if __name__ == "__main__":
    main()