|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import json |
|
import argparse |
|
import os |
|
from scipy.stats import gaussian_kde |
|
import numpy as np |
|
|
|
def get_model_for(doc_type: str, override_model: str) -> str:
    """Pick the perplexity model key to use for a document type.

    Only the part of ``doc_type`` before the first underscore is considered
    (e.g. ``"lovdata_2020"`` -> ``"lovdata"``). Unknown prefixes fall back to
    ``"wikipedia_pp"``.

    Args:
        doc_type: Document type label, possibly with an ``_``-separated suffix.
        override_model: If truthy, returned unchanged regardless of doc_type.

    Returns:
        The model key to read from a record's ``perplexities`` mapping.
    """
    if override_model:
        return override_model

    prefix = doc_type.partition("_")[0]

    # Model key -> doc-type prefixes it covers. The groups are disjoint, so
    # lookup order does not affect the result.
    prefixes_by_model = {
        "books_pp": ("book", "books", "pg19"),
        "wikipedia_pp": ("culturax", "slimpajama", "wikipedia", "digimanus"),
        "newspapers_pp": ("newspaper", "newspapers"),
        "maalfrid_pp": ("evalueringsrapport", "lovdata", "maalfrid", "parlamint"),
    }
    for model_name, known_prefixes in prefixes_by_model.items():
        if prefix in known_prefixes:
            return model_name

    return "wikipedia_pp"
|
|
|
def load_data(files):
    """Load one or more JSON Lines files into a single DataFrame.

    Every non-blank line of every file is parsed as one JSON object and
    becomes one row. Rows from all files are concatenated in input order.

    Args:
        files: Iterable of paths to JSON Lines files.

    Returns:
        A ``pandas.DataFrame`` with one row per JSON record. It is empty if
        no records were found.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If a file cannot be opened.
    """
    records = []
    for file_path in files:
        # Read as UTF-8 explicitly. The platform's locale default can fail on
        # non-ASCII text (e.g. on Windows).
        with open(file_path, "r", encoding="utf-8") as file:
            # Stream line by line rather than readlines(), which would hold a
            # second full copy of the file in memory.
            for line in file:
                # Skip blank lines (e.g. a trailing empty line), which
                # json.loads would otherwise reject.
                if line.strip():
                    records.append(json.loads(line))
    return pd.DataFrame(records)
|
|
|
def _fit_matching_kde(values, bw_adjust):
    """Fit a Gaussian KDE whose bandwidth matches seaborn's ``kdeplot``.

    seaborn multiplies scipy's default (Scott) bandwidth factor by
    ``bw_adjust``. Doing the same here puts the quartile markers on the curve
    that is actually drawn. Returns None when the data cannot support a KDE
    (fewer than two points or zero spread), where ``gaussian_kde`` would raise.
    """
    if values.size < 2 or np.ptp(values) == 0:
        return None
    try:
        kde = gaussian_kde(values)
    except (np.linalg.LinAlgError, ValueError):
        # Numerically singular data that slipped past the cheap checks above.
        return None
    kde.set_bandwidth(kde.factor * bw_adjust)
    return kde


def _annotate_quartiles(ax, values, kde, color):
    """Draw dashed Q1/Q2/Q3 lines up to the KDE curve, labelled with their values."""
    quartiles = np.quantile(values, [0.25, 0.5, 0.75])
    # Evaluate the density exactly at each quartile rather than at the
    # nearest point of a fixed grid.
    heights = kde(quartiles)
    for label, quartile, height in zip(("Q1", "Q2", "Q3"), quartiles, heights):
        ax.plot([quartile, quartile], [0, height], color=color, linestyle='--', linewidth=1)
        # clip_on keeps labels for quartiles beyond the x-limit from being
        # drawn outside the axes (text is not clipped by default).
        ax.text(quartile, height, f'{label}: {quartile:.2f}', verticalalignment='bottom',
                horizontalalignment='right', color=color, fontsize=6, clip_on=True)


def plot_histograms(files, output_folder, xlim, override_model, *, binwidth=30, bw_adjust=2):
    """Plot per-language perplexity distributions, one subplot per doc type.

    Each subplot shows, for every language in that doc type:
    - a step histogram (density-normalised);
    - a KDE curve;
    - dashed quartile markers that rise to the KDE curve.

    The figure is saved as ``all_doc_types_plots.png`` in ``output_folder``.

    Args:
        files: Paths to JSON Lines files, read via ``load_data``. Records are
            expected to have ``doctype``, ``lang`` and ``perplexities``
            (a mapping from model key to value) fields.
        output_folder: Existing directory to write the PNG into.
        xlim: Upper x-axis limit for every subplot (lower limit is 0).
        override_model: If truthy, the perplexity model key used for every doc
            type; otherwise the key is chosen per doc type by ``get_model_for``.
        binwidth: Histogram bin width, in perplexity units.
        bw_adjust: KDE bandwidth multiplier, shared by the drawn curve and the
            quartile-marker heights so they stay consistent.

    Raises:
        ValueError: If the input files contain no records.
        KeyError: If a record lacks a required field or the chosen model key.
    """
    df = load_data(files)
    if df.empty:
        raise ValueError("No records found in the input files")

    doc_types = df['doctype'].unique()
    fig, axes = plt.subplots(len(doc_types), 1, figsize=(12, 4 * len(doc_types)), squeeze=False)

    for ax, doc_type in zip(axes[:, 0], doc_types):
        group = df[df['doctype'] == doc_type]
        # The model depends only on the doc type, so resolve it once per subplot.
        perplexity_model = get_model_for(doc_type, override_model)
        languages = group['lang'].unique()
        colors = sns.color_palette("husl", len(languages))

        for color, lang in zip(colors, languages):
            lang_group = group[group['lang'] == lang]
            values = np.asarray(
                lang_group['perplexities'].apply(lambda x: x[perplexity_model]),
                dtype=float,
            )
            # Drop NaN/inf, which would break quantiles, the KDE and binning.
            values = values[np.isfinite(values)]
            if values.size == 0:
                continue

            sns.histplot(values, ax=ax, color=color, alpha=0.3, element="step",
                         fill=True, stat="density", binwidth=binwidth)
            sns.kdeplot(values, ax=ax, bw_adjust=bw_adjust, color=color,
                        label=f"{lang} - {doc_type} ({perplexity_model})", linewidth=1.5)

            kde = _fit_matching_kde(values, bw_adjust)
            if kde is not None:
                _annotate_quartiles(ax, values, kde, color)

        ax.set_title(f'Document Type: {doc_type} ({perplexity_model})')
        ax.set_xlabel('Perplexity Value')
        ax.set_ylabel('Density')
        # Avoid matplotlib's "No artists with labels" warning when nothing was drawn.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.set_xlim(left=0, right=xlim)

    fig.tight_layout()
    output_filename = os.path.join(output_folder, "all_doc_types_plots.png")
    fig.savefig(output_filename, dpi=300)
    plt.close(fig)
    print(f"All document type plots saved to {output_filename}")
|
|
|
def main(argv=None):
    """Command-line entry point: parse arguments and write the combined plot.

    Args:
        argv: Optional argument list to parse. When None, argparse reads
            ``sys.argv[1:]``, which matches the original behaviour; passing a
            list makes the CLI callable from other code and tests.
    """
    parser = argparse.ArgumentParser(description="Plot histograms from JSON lines files.")
    parser.add_argument('files', nargs='+', help="Path to the JSON lines files")
    parser.add_argument('-o', '--output_folder', default=".", help="Output folder for the plots")
    parser.add_argument('--xlim', type=int, default=2500, help="Maximum x-axis limit for the plots")
    parser.add_argument('--model', default="", help="Override the perplexity model for all plots")

    args = parser.parse_args(argv)

    # exist_ok=True already makes this a no-op for an existing folder, so a
    # separate os.path.exists() pre-check is redundant (and racy).
    os.makedirs(args.output_folder, exist_ok=True)

    plot_histograms(args.files, args.output_folder, args.xlim, args.model)
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
|
|