# mimir-perplexity / samples_quartiles.py
# versae — Models and code (dcc5cd1)
import argparse
import os
from collections import defaultdict
from io import StringIO
import pandas as pd
from tqdm import tqdm
from perplexity import get_model_for
from subsampler import PerplexitySubsampler
def process_files(
    directory,
    reject_level,
    model_override,
    output_file,
    group_by_prefix_lang,
    prefix_lang_mapping=None,
    ratio=None,
    ratio_per_lang=None,
    pa=None,
    pb=None,
    include=None,
):
    """Compute perplexity quantiles (and optional subsampling statistics)
    for JSON-lines files in *directory* and emit them as CSV rows.

    Args:
        directory: Folder containing ``<prefix>_..._<lang>.jsonl`` files whose
            records carry a ``perplexities`` object with ``<model>_pp`` values.
        reject_level: Quantile used for the ``reject`` column (e.g. 0.95).
        model_override: If given, use this model name for every group instead
            of looking it up per doc_type via ``get_model_for``.
        output_file: CSV path to write; when falsy, rows print to stdout.
        group_by_prefix_lang: When True, combine files sharing the same
            prefix and language into one group before computing quantiles.
        prefix_lang_mapping: Optional remapping of ``"{prefix}_{lang}"`` keys,
            e.g. ``{"hplt_en": "hplt_no"}``.
        ratio: Global sampling ratio; when set, norm/mean/std columns are added.
        ratio_per_lang: Per-language sampling ratios (adds the same columns).
        pa: Probability of keeping documents in the Q1 range.
        pb: Probability of keeping documents in the Q3 range.
        include: Optional whitelist of doc_type prefixes to process.
    """
    want_sampling_stats = bool(ratio or ratio_per_lang)
    if want_sampling_stats:
        rows = ["doc_type,model,language,reject,bad,medium,good,norm,mean,std"]
    else:
        rows = ["doc_type,model,language,reject,bad,medium,good"]
    if prefix_lang_mapping is None:
        prefix_lang_mapping = {}
    files = os.listdir(directory)
    # Group files by prefix and language if the option is enabled;
    # otherwise every file forms its own singleton group.
    description = "Processing files"
    if group_by_prefix_lang:
        description = "Processing files in groups"
        grouped_files = defaultdict(list)
        for file in files:
            parts = file.split('_')
            prefix = parts[0]
            if include and prefix not in include:
                continue
            lang = parts[-1].split(".")[0][:2]
            group_key = prefix_lang_mapping.get(f"{prefix}_{lang}", f"{prefix}_{lang}")
            grouped_files[group_key].append(file)
        file_groups = list(grouped_files.values())
    else:
        file_groups = [
            [file]
            for file in files
            if not include or any(file.startswith(prefix) for prefix in include)
        ]
    if output_file:
        progress = tqdm(file_groups, desc=description)
    else:
        progress = file_groups
        print(rows[0])  # header goes straight to stdout when not writing a file
    # Process each group of files
    for group in progress:
        combined_perplexities = pd.DataFrame()
        doc_type, lang = None, None
        for file in group:
            if not doc_type or not lang:
                # Derive doc_type and lang from the first file of the group,
                # applying any prefix/lang remapping.
                parts = file.split('_')
                doc_type = parts[0]
                lang = parts[-1].split(".")[0][:2]
                doc_type, lang = prefix_lang_mapping.get(
                    f"{doc_type}_{lang}", f"{doc_type}_{lang}"
                ).rsplit("_", 1)
            perp = pd.read_json(os.path.join(directory, file), lines=True)
            # Re-serialize the nested "perplexities" objects so they expand
            # into one flat column per model.
            perplexities = pd.read_json(
                StringIO(perp["perplexities"].to_json(lines=True, orient="records")),
                lines=True,
            )
            combined_perplexities = pd.concat(
                [combined_perplexities, perplexities], ignore_index=True
            )
        if model_override:
            model = model_override
        else:
            model, _ = get_model_for(doc_type)
        model_with_suffix = f"{model}_pp"
        series = combined_perplexities[model_with_suffix]
        # Calculate quantiles for the combined perplexities of the group
        reject = round(series.quantile(q=reject_level), 2)
        bad = round(series.quantile(q=0.75), 2)
        medium = round(series.quantile(q=0.50), 2)
        good = round(series.quantile(q=0.25), 2)
        if want_sampling_stats:
            # Single code path for both global and per-language ratios
            # (was duplicated in two near-identical branches).  A global
            # ratio wins; otherwise fall back to the per-language map
            # with a default of 1.0.
            effective_ratio = ratio if ratio else ratio_per_lang.get(lang, 1.0)
            subsampler = PerplexitySubsampler(series.values)
            subsampler.set(ratio=effective_ratio, pa=pa, pb=pb)
            sampling_stats = f",{subsampler.norm},{subsampler.mean},{subsampler.sdev}"
        else:
            sampling_stats = ""
        row = f"{doc_type},{model},{lang},{reject},{bad},{medium},{good}{sampling_stats}"
        if output_file:
            rows.append(row)
        else:
            print(row)
    if output_file:
        with open(output_file, "w") as f:
            for row in rows:
                f.write(f"{row}\n")
def main():
""""
Each doc_type prefix needs to have an "no" lang, even of there's no real data.
These rows are crucial for the rest of the process.
"""
parser = argparse.ArgumentParser(description="Process files and compute perplexity metrics.")
parser.add_argument('directory', type=str, help='Directory containing the files to process')
parser.add_argument('--reject_level', type=float, default=0.95, help='Rejection quantile level (default: 0.95)')
parser.add_argument('--model_override', type=str, help='Override the model used')
parser.add_argument('--output_file', type=str, help='Output file in CSV format. If not given, prints to standard output.')
parser.add_argument('--group_by_prefix_lang', action='store_true', help='Group and calculate quantiles for files with the same prefix and language')
parser.add_argument('--overwrite_prefix_lang', type=str, help='Overwrite the assignment of languages to doc_type prefixes, e.g., "starcoder_en:starcoder_code,hplt_en:hplt_no"')
parser.add_argument('--sampling_ratio', type=float, help='Ratio of documents to keep for sampling. If passed, it generate distribution statistics (norm, mean, std) needed for sampling')
parser.add_argument('--sampling_ratio_per_lang', type=str, help='Ratio of documents per lang, e.g., "en:0.25,sv:0.34"')
parser.add_argument('--sampling_q1_prob', type=float, default=0.20, help='Probabilty for keeping documents in the Q1 range')
parser.add_argument('--sampling_q3_prob', type=float, default=0.05, help='Probabilty for keeping documents in the Q3 range')
parser.add_argument('--include', type=str, help='Comma separeted list of doc type prefixes to include')
args = parser.parse_args()
if args.sampling_ratio_per_lang:
# Turns "en: 0.25, sv : 0.34" into {'en': 0.25, 'sv': 0.34}
ratio_per_lang = dict(
(k.strip(), float(v.strip()))
for k, v in (item.split(":")
for item in args.sampling_ratio_per_lang.split(",")
)
)
else:
ratio_per_lang = None
if args.overwrite_prefix_lang:
# Turns "starcoder_en:starcoder_code,hplt_en:hplt_no" into {'starcoder_en': 'starcoder_code', 'hplt_en': 'hplt_no'}
prefix_lang_mapping = dict(
(k.strip(), v.strip())
for k, v in (item.split(":")
for item in args.overwrite_prefix_lang.split(",")
)
)
else:
prefix_lang_mapping = {}
process_files(
args.directory,
args.reject_level,
args.model_override,
args.output_file,
group_by_prefix_lang=args.group_by_prefix_lang,
prefix_lang_mapping=prefix_lang_mapping,
pa=args.sampling_q1_prob,
pb=args.sampling_q3_prob,
ratio=args.sampling_ratio,
ratio_per_lang=ratio_per_lang,
include=args.include.split(",") if args.include else None
)
if __name__ == "__main__":
main()