"""Plot per-document-type perplexity histograms from JSON Lines files.

Every input record is read for three fields: ``doctype``, ``lang`` and
``perplexities`` (a mapping indexed by perplexity-model name).  One subplot is
drawn per document type, with one histogram + KDE curve per language.
"""

import argparse
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import gaussian_kde

# Perplexity model for each document-type prefix (the text before the
# first "_" in the doctype).
_MODEL_BY_PREFIX = {
    "book": "books_pp",
    "books": "books_pp",
    "pg19": "books_pp",
    "culturax": "wikipedia_pp",
    "slimpajama": "wikipedia_pp",
    "wikipedia": "wikipedia_pp",
    "digimanus": "wikipedia_pp",
    "newspaper": "newspapers_pp",
    "newspapers": "newspapers_pp",
    "evalueringsrapport": "maalfrid_pp",
    "lovdata": "maalfrid_pp",
    "maalfrid": "maalfrid_pp",
    "parlamint": "maalfrid_pp",
}
_DEFAULT_MODEL = "wikipedia_pp"

# Bandwidth multiplier shared by the drawn KDE curve and the quartile markers.
_KDE_BW_ADJUST = 2
_QUARTILE_LABELS = ("Q1", "Q2", "Q3")


def get_model_for(doc_type: str, override_model: str) -> str:
    """Return the perplexity model name to use for a document type.

    Args:
        doc_type: Document type; only the part before the first "_" is used
            for the lookup.
        override_model: When non-empty, returned as-is instead of the lookup.

    Returns:
        The override model if given, otherwise the model mapped to the
        document-type prefix, falling back to "wikipedia_pp" for unknown
        prefixes.
    """
    if override_model:
        return override_model
    prefix = doc_type.split("_", 1)[0]
    return _MODEL_BY_PREFIX.get(prefix, _DEFAULT_MODEL)


def load_data(files):
    """Load JSON Lines files into a single DataFrame.

    Args:
        files: Iterable of paths to JSON Lines files (one JSON value per line).

    Returns:
        A ``pandas.DataFrame`` built from the records of all files, in the
        order they were read.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    for file_path in files:
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(file_path, "r", encoding="utf-8") as file:
            # Stream line by line and ignore blank lines (e.g. a trailing
            # empty line), which json.loads would reject.
            records.extend(json.loads(line) for line in file if line.strip())
    return pd.DataFrame(records)


def _mark_quartiles(ax, values, color):
    """Draw dashed Q1/Q2/Q3 markers from the x-axis up to the KDE height."""
    # gaussian_kde cannot be fitted to fewer than two points or to constant
    # data (singular covariance); skip the markers for such samples.
    if values.size < 2 or values.min() == values.max():
        return

    kde = gaussian_kde(values)
    # Scale the bandwidth by the same factor passed to sns.kdeplot as
    # bw_adjust, so the markers end on the drawn curve rather than on a
    # differently smoothed one.
    kde.set_bandwidth(kde.factor * _KDE_BW_ADJUST)

    quartiles = np.quantile(values, [0.25, 0.5, 0.75])
    heights = kde.evaluate(quartiles)
    for label, quartile, height in zip(_QUARTILE_LABELS, quartiles, heights):
        ax.plot(
            [quartile, quartile],
            [0, height],
            color=color,
            linestyle='--',
            linewidth=1,
        )
        ax.text(
            quartile,
            height,
            f'{label}: {quartile:.2f}',
            verticalalignment='bottom',
            horizontalalignment='right',
            color=color,
            fontsize=6,
        )


def plot_histograms(files, output_folder, xlim, override_model):
    """Plot perplexity distributions per document type and save them as a PNG.

    One subplot is created per distinct ``doctype``; within it, each distinct
    ``lang`` gets a density histogram, a KDE curve and quartile markers.  The
    figure is written to ``all_doc_types_plots.png`` inside *output_folder*.

    Args:
        files: Paths of the JSON Lines files to load.
        output_folder: Directory the PNG is written to.
        xlim: Right limit of the x-axis (the left limit is 0).
        override_model: Perplexity model to use for every document type; an
            empty string selects the model from the document type.

    Raises:
        ValueError: If the input files contain no records.
    """
    df = load_data(files)
    if df.empty:
        raise ValueError("No records found in the input files; nothing to plot.")

    doc_types = df['doctype'].unique()
    fig, axes = plt.subplots(
        len(doc_types), 1, figsize=(12, 4 * len(doc_types)), squeeze=False
    )

    for ax, doc_type in zip(axes[:, 0], doc_types):
        group = df[df['doctype'] == doc_type]
        # Records without a language cannot be matched by equality below.
        languages = group['lang'].dropna().unique()
        # One distinct colour per language within the document type.
        colors = sns.color_palette("husl", len(languages))
        # The model depends only on the document type, so resolve it once;
        # it is also needed for the title when there are no languages.
        perplexity_model = get_model_for(doc_type, override_model)

        for series_color, lang in zip(colors, languages):
            lang_group = group[group['lang'] == lang]
            perplexity_values = lang_group['perplexities'].apply(
                lambda scores: scores[perplexity_model]
            ).to_numpy()

            # Translucent histogram underneath.
            sns.histplot(
                perplexity_values,
                ax=ax,
                color=series_color,
                alpha=0.3,
                element="step",
                fill=True,
                stat="density",
                binwidth=30,
            )
            # Unfilled KDE curve on top.
            sns.kdeplot(
                perplexity_values,
                ax=ax,
                bw_adjust=_KDE_BW_ADJUST,
                color=series_color,
                label=f"{lang} - {doc_type} ({perplexity_model})",
                linewidth=1.5,
            )
            _mark_quartiles(ax, perplexity_values, series_color)

        ax.set_title(f'Document Type: {doc_type} ({perplexity_model})')
        ax.set_xlabel('Perplexity Value')
        ax.set_ylabel('Density')
        # Only build a legend when at least one labelled curve was drawn.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.set_xlim(left=0, right=xlim)

    plt.tight_layout()
    output_filename = os.path.join(output_folder, "all_doc_types_plots.png")
    plt.savefig(output_filename, dpi=300)
    plt.close(fig)
    print(f"All document type plots saved to {output_filename}")


def main():
    """Parse command-line arguments and generate the plots."""
    parser = argparse.ArgumentParser(description="Plot histograms from JSON lines files.")
    parser.add_argument('files', nargs='+', help="Path to the JSON lines files")
    parser.add_argument('-o', '--output_folder', default=".", help="Output folder for the plots")
    parser.add_argument('--xlim', type=int, default=2500, help="Maximum x-axis limit for the plots")
    parser.add_argument('--model', default="", help="Override the perplexity model for all plots")
    args = parser.parse_args()

    # exist_ok=True already makes this a no-op for an existing directory.
    os.makedirs(args.output_folder, exist_ok=True)

    plot_histograms(args.files, args.output_folder, args.xlim, args.model)


if __name__ == "__main__":
    main()