# mimir-perplexity / histograms.py
# Last commit: dcc5cd1 by versae ("Models and code")
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
import argparse
import os
from scipy.stats import gaussian_kde
import numpy as np
def get_model_for(doc_type: str, override_model: str) -> str:
    """Map a document type to the perplexity model used to score it.

    A truthy ``override_model`` always wins and is returned unchanged.
    Otherwise only the part of ``doc_type`` before the first underscore is
    looked at (so ``"books_nb"`` is treated as ``"books"``), and any family
    not listed below falls back to ``"wikipedia_pp"``.
    """
    if override_model:
        return override_model

    # Document-type family -> perplexity model key.
    model_by_family = {
        **dict.fromkeys(("book", "books", "pg19"), "books_pp"),
        **dict.fromkeys(("culturax", "slimpajama", "wikipedia", "digimanus"), "wikipedia_pp"),
        **dict.fromkeys(("newspaper", "newspapers"), "newspapers_pp"),
        **dict.fromkeys(("evalueringsrapport", "lovdata", "maalfrid", "parlamint"), "maalfrid_pp"),
    }
    family, _, _ = doc_type.partition("_")
    return model_by_family.get(family, "wikipedia_pp")
def load_data(files):
    """Read one or more JSON Lines files into a single DataFrame.

    Every non-blank line is parsed with ``json.loads`` (and is expected to be
    a JSON object). Records from all files are concatenated in the order the
    files are given, and ``pandas.DataFrame`` turns their keys into columns.

    Args:
        files: Iterable of paths to JSON Lines files.

    Returns:
        A ``pandas.DataFrame`` with one row per record (empty when there are
        no records).

    Raises:
        OSError: If a file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    for file_path in files:
        # JSON Lines is UTF-8 by definition; don't depend on the platform locale.
        with open(file_path, 'r', encoding='utf-8') as file:
            # Stream line by line rather than readlines(), and skip blank lines
            # (e.g. a doubled trailing newline) that json.loads would reject.
            records.extend(json.loads(line) for line in file if line.strip())
    return pd.DataFrame(records)
def _finite_perplexities(lang_group, perplexity_model):
    """Return one model's perplexity scores for a group as a float array of finite values.

    Each row's ``perplexities`` cell is indexed with ``perplexity_model``
    (presumably a dict keyed by model name -- confirm against the data
    producer). NaN and infinite scores are dropped because neither the KDE
    nor the quantiles are defined over them.
    """
    scores = lang_group['perplexities'].apply(lambda cell: cell[perplexity_model])
    values = scores.to_numpy(dtype=float)
    return values[np.isfinite(values)]


def _annotate_quartiles(ax, values, color, bw_adjust):
    """Draw dashed Q1/Q2/Q3 markers on ``ax`` whose tops end on the series' KDE.

    The KDE is fitted the way seaborn's ``kdeplot`` fits it (Scott's rule,
    with the bandwidth factor multiplied by ``bw_adjust``). The marker tops
    therefore lie on the plotted curve, and each height is the KDE evaluated
    exactly at the quartile rather than at a nearest grid point.
    Requires at least two distinct values in ``values``.
    """
    kde = gaussian_kde(values, bw_method="scott")
    kde.set_bandwidth(kde.factor * bw_adjust)
    quartiles = np.quantile(values, [0.25, 0.5, 0.75])
    heights = kde.evaluate(quartiles)
    for label, quartile, height in zip(("Q1", "Q2", "Q3"), quartiles, heights):
        ax.plot([quartile, quartile], [0, height], color=color, linestyle='--', linewidth=1)
        ax.text(quartile, height, f'{label}: {quartile:.2f}', verticalalignment='bottom', horizontalalignment='right', color=color, fontsize=6)


def plot_histograms(files, output_folder, xlim, override_model):
    """Plot per-language perplexity distributions, one subplot per document type.

    In each document type's subplot, every language gets:
    - a density-normalised step histogram (bin width 30);
    - a KDE curve;
    - dashed Q1/Q2/Q3 markers whose tops lie on that curve.

    A language whose scores have fewer than two distinct values gets only the
    histogram, plus a legend entry. The figure is saved as
    ``all_doc_types_plots.png`` in ``output_folder``.

    Args:
        files: Paths of JSON Lines files. Records need ``doctype``, ``lang``
            and ``perplexities`` fields.
        output_folder: Directory the PNG is saved into (must already exist).
        xlim: Upper x-axis limit; the lower limit is fixed at 0.
        override_model: If truthy, the perplexity model used for every
            document type instead of the per-type default.

    Raises:
        ValueError: If no record has a non-missing ``doctype``.
    """
    df = load_data(files)
    # Missing doctype/lang values would otherwise give empty groups (NaN never
    # compares equal to itself), which then crash the plotting below.
    doc_types = df['doctype'].dropna().unique() if 'doctype' in df.columns else []
    if len(doc_types) == 0:
        raise ValueError("No records with a 'doctype' field were found in the input files")

    # Shared by seaborn's KDE and the quartile markers so both describe one curve.
    bw_adjust = 2
    fig, axes = plt.subplots(len(doc_types), 1, figsize=(12, 4 * len(doc_types)), squeeze=False)
    for ax, doc_type in zip(axes[:, 0], doc_types):
        # The model depends only on the document type, not on the language.
        perplexity_model = get_model_for(doc_type, override_model)
        group = df[df['doctype'] == doc_type]
        languages = group['lang'].dropna().unique()
        # A distinct colour for each language within this document type.
        colors = sns.color_palette("husl", len(languages))
        for lang, series_color in zip(languages, colors):
            values = _finite_perplexities(group[group['lang'] == lang], perplexity_model)
            if values.size == 0:
                continue
            label = f"{lang} - {doc_type} ({perplexity_model})"
            # Lighter, filled histogram behind the KDE line.
            sns.histplot(values, ax=ax, color=series_color, alpha=0.3, element="step", fill=True, stat="density", binwidth=30)
            if np.unique(values).size < 2:
                # gaussian_kde raises on a single point or zero variance, so only
                # the histogram is drawn; an empty line keeps the legend entry.
                ax.plot([], [], color=series_color, label=label, linewidth=1.5)
                continue
            sns.kdeplot(values, ax=ax, bw_adjust=bw_adjust, color=series_color, label=label, linewidth=1.5)
            _annotate_quartiles(ax, values, series_color, bw_adjust)
        ax.set_title(f'Document Type: {doc_type} ({perplexity_model})')
        ax.set_xlabel('Perplexity Value')
        ax.set_ylabel('Density')
        ax.legend()
        ax.set_xlim(left=0, right=xlim)

    fig.tight_layout()
    output_filename = os.path.join(output_folder, "all_doc_types_plots.png")
    fig.savefig(output_filename, dpi=300)
    plt.close(fig)
    print(f"All document type plots saved to {output_filename}")
def main(argv=None):
    """Command-line entry point: parse arguments, ensure the output folder exists, plot.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which preserves the original CLI behaviour.
    """
    parser = argparse.ArgumentParser(description="Plot histograms from JSON lines files.")
    parser.add_argument('files', nargs='+', help="Path to the JSON lines files")
    parser.add_argument('-o', '--output_folder', default=".", help="Output folder for the plots")
    parser.add_argument('--xlim', type=int, default=2500, help="Maximum x-axis limit for the plots")
    parser.add_argument('--model', default="", help="Override the perplexity model for all plots")
    args = parser.parse_args(argv)
    # exist_ok=True already tolerates an existing directory, so a prior
    # os.path.exists() check is redundant (and racy).
    os.makedirs(args.output_folder, exist_ok=True)
    plot_histograms(args.files, args.output_folder, args.xlim, args.model)


if __name__ == "__main__":
    main()