"""Compute perplexity quantile statistics for JSONL files in a directory.

For each file (or group of files sharing a doc_type prefix and language),
reads the nested ``perplexities`` records, computes rejection/quartile
thresholds on the ``<model>_pp`` column, and optionally derives subsampling
distribution parameters (norm, mean, std). Results are emitted as CSV rows
to stdout or to ``--output_file``.
"""

import argparse
import os
from collections import defaultdict
from io import StringIO

import pandas as pd
from tqdm import tqdm

from perplexity import get_model_for
from subsampler import PerplexitySubsampler


def process_files(
    directory,
    reject_level,
    model_override,
    output_file,
    group_by_prefix_lang,
    prefix_lang_mapping=None,
    ratio=None,
    ratio_per_lang=None,
    pa=None,
    pb=None,
    include=None,
):
    """Compute quantile (and optional sampling) statistics per file or group.

    Parameters:
        directory: directory containing JSONL files named ``<prefix>_..._<lang>.<ext>``.
        reject_level: quantile used for the ``reject`` threshold (e.g. 0.95).
        model_override: if set, use this model name instead of ``get_model_for``.
        output_file: path for CSV output; when falsy, rows go to stdout.
        group_by_prefix_lang: when True, files sharing ``<prefix>_<lang>`` are
            combined into one group before computing quantiles.
        prefix_lang_mapping: optional remapping of ``"<prefix>_<lang>"`` keys to
            replacement ``"<prefix>_<lang>"`` values.
        ratio: global sampling ratio; enables norm/mean/std columns.
        ratio_per_lang: per-language sampling ratios; enables norm/mean/std columns.
        pa, pb: probabilities for the Q1/Q3 ranges, forwarded to the subsampler.
        include: optional list of doc_type prefixes to restrict processing to.
    """
    # Header gains the sampling-statistics columns only when sampling is requested.
    if ratio or ratio_per_lang:
        rows = ["doc_type,model,language,reject,bad,medium,good,norm,mean,std"]
    else:
        rows = ["doc_type,model,language,reject,bad,medium,good"]

    files = os.listdir(directory)
    grouped_files = defaultdict(list)
    if prefix_lang_mapping is None:
        prefix_lang_mapping = {}

    description = "Processing files"
    if group_by_prefix_lang:
        # Group files by "<prefix>_<lang>", applying any remapping first.
        description = "Processing files in groups"
        for file in files:
            parts = file.split('_')
            prefix = parts[0]
            if include and prefix not in include:
                continue
            # Language code: first two chars of the last '_'-separated part,
            # with the file extension stripped.
            lang = parts[-1].split(".")[0][:2]
            group_key = prefix_lang_mapping.get(f"{prefix}_{lang}", f"{prefix}_{lang}")
            grouped_files[group_key].append(file)
        file_groups = grouped_files.values()
    else:
        # Each file is its own group.
        # NOTE(review): this branch filters by startswith while the grouped
        # branch requires an exact prefix match — confirm the difference is
        # intentional.
        file_groups = []
        for file in files:
            if include and not any(file.startswith(prefix) for prefix in include):
                continue
            file_groups.append([file])

    if output_file:
        progress = tqdm(file_groups, desc=description)
    else:
        progress = file_groups
        # Print the CSV header only in stdout mode; in file mode it is
        # written to the file together with the data rows below.
        print(rows[0])

    # Process each group of files.
    for group in progress:
        combined_perplexities = pd.DataFrame()
        doc_type, lang = None, None
        for file in group:
            if not doc_type or not lang:
                # Derive doc_type and lang from the first file of the group,
                # then apply the optional prefix/lang remapping.
                parts = file.split('_')
                doc_type = parts[0]
                lang = parts[-1].split(".")[0][:2]
                doc_type, lang = prefix_lang_mapping.get(
                    f"{doc_type}_{lang}", f"{doc_type}_{lang}"
                ).rsplit("_", 1)
            perp = pd.read_json(os.path.join(directory, file), lines=True)
            # The "perplexities" column holds nested records; round-trip them
            # through JSON-lines to expand into a flat DataFrame.
            perplexities = pd.read_json(
                StringIO(perp["perplexities"].to_json(lines=True, orient="records")),
                lines=True,
            )
            combined_perplexities = pd.concat(
                [combined_perplexities, perplexities], ignore_index=True
            )

        if model_override:
            model = model_override
        else:
            model, _ = get_model_for(doc_type)
        model_with_suffix = f"{model}_pp"

        # Quantiles over the combined perplexities of the group.
        pp = combined_perplexities[model_with_suffix]
        reject = round(pp.quantile(q=reject_level), 2)
        bad = round(pp.quantile(q=0.75), 2)
        medium = round(pp.quantile(q=0.50), 2)
        good = round(pp.quantile(q=0.25), 2)

        if ratio or ratio_per_lang:
            # A global --sampling_ratio takes precedence; otherwise look up
            # the language-specific ratio, defaulting to 1.0 (keep all).
            effective_ratio = ratio if ratio else ratio_per_lang.get(lang, 1.0)
            subsampler = PerplexitySubsampler(pp.values)
            subsampler.set(ratio=effective_ratio, pa=pa, pb=pb)
            sampling_stats = f",{subsampler.norm},{subsampler.mean},{subsampler.sdev}"
        else:
            sampling_stats = ""

        row = f"{doc_type},{model},{lang},{reject},{bad},{medium},{good}{sampling_stats}"
        if output_file:
            rows.append(row)
        else:
            print(row)

    if output_file:
        with open(output_file, "w") as f:
            for row in rows:
                f.write(f"{row}\n")


def main():
    """Parse command-line arguments and run :func:`process_files`.

    Each doc_type prefix needs to have a "no" lang, even if there's no real
    data. These rows are crucial for the rest of the process.
    """
    parser = argparse.ArgumentParser(description="Process files and compute perplexity metrics.")
    parser.add_argument('directory', type=str, help='Directory containing the files to process')
    parser.add_argument('--reject_level', type=float, default=0.95, help='Rejection quantile level (default: 0.95)')
    parser.add_argument('--model_override', type=str, help='Override the model used')
    parser.add_argument('--output_file', type=str, help='Output file in CSV format. If not given, prints to standard output.')
    parser.add_argument('--group_by_prefix_lang', action='store_true', help='Group and calculate quantiles for files with the same prefix and language')
    parser.add_argument('--overwrite_prefix_lang', type=str, help='Overwrite the assignment of languages to doc_type prefixes, e.g., "starcoder_en:starcoder_code,hplt_en:hplt_no"')
    parser.add_argument('--sampling_ratio', type=float, help='Ratio of documents to keep for sampling. If passed, it generate distribution statistics (norm, mean, std) needed for sampling')
    parser.add_argument('--sampling_ratio_per_lang', type=str, help='Ratio of documents per lang, e.g., "en:0.25,sv:0.34"')
    parser.add_argument('--sampling_q1_prob', type=float, default=0.20, help='Probabilty for keeping documents in the Q1 range')
    parser.add_argument('--sampling_q3_prob', type=float, default=0.05, help='Probabilty for keeping documents in the Q3 range')
    parser.add_argument('--include', type=str, help='Comma separeted list of doc type prefixes to include')
    args = parser.parse_args()

    if args.sampling_ratio_per_lang:
        # Turns "en: 0.25, sv : 0.34" into {'en': 0.25, 'sv': 0.34}
        ratio_per_lang = {
            k.strip(): float(v.strip())
            for k, v in (item.split(":") for item in args.sampling_ratio_per_lang.split(","))
        }
    else:
        ratio_per_lang = None

    if args.overwrite_prefix_lang:
        # Turns "starcoder_en:starcoder_code,hplt_en:hplt_no" into
        # {'starcoder_en': 'starcoder_code', 'hplt_en': 'hplt_no'}
        prefix_lang_mapping = {
            k.strip(): v.strip()
            for k, v in (item.split(":") for item in args.overwrite_prefix_lang.split(","))
        }
    else:
        prefix_lang_mapping = {}

    process_files(
        args.directory,
        args.reject_level,
        args.model_override,
        args.output_file,
        group_by_prefix_lang=args.group_by_prefix_lang,
        prefix_lang_mapping=prefix_lang_mapping,
        pa=args.sampling_q1_prob,
        pb=args.sampling_q3_prob,
        ratio=args.sampling_ratio,
        ratio_per_lang=ratio_per_lang,
        include=args.include.split(",") if args.include else None,
    )


if __name__ == "__main__":
    main()