DOoM-lb / src /radial /radial.py
Anonumous's picture
Add files
6ee7257
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import itertools as it
from src.leaderboard.build_leaderboard import build_leadearboard_df
def create_plot(selected_models):
"""
Создает визуализацию для сравнения выбранных моделей по метрикам DeathMath
Args:
selected_models: Список названий моделей для отображения на графике
Returns:
matplotlib.figure.Figure: График для отображения в интерфейсе
"""
# Получаем данные моделей из лидерборда
models_df = build_leadearboard_df()
# Если нет выбранных моделей или данные не загружены, возвращаем пустой график
if not selected_models or models_df.empty:
fig, ax = plt.subplots(figsize=(10, 6))
ax.text(0.5, 0.5, "Нет данных для отображения",
horizontalalignment='center', verticalalignment='center',
transform=ax.transAxes, fontsize=14)
ax.set_axis_off()
return fig
# Фильтруем DataFrame, чтобы оставить только выбранные модели
models_to_show = models_df[models_df['model'].isin(selected_models)]
if models_to_show.empty:
fig, ax = plt.subplots(figsize=(10, 6))
ax.text(0.5, 0.5, "Выбранные модели не найдены в данных",
horizontalalignment='center', verticalalignment='center',
transform=ax.transAxes, fontsize=14)
ax.set_axis_off()
return fig
# Настройка бар-графика для сравнения моделей
fig, ax = plt.subplots(figsize=(12, 8))
# Ширина столбцов
bar_width = 0.25
# Позиции на оси x
models_count = len(models_to_show)
indices = np.arange(models_count)
# Цветовая палитра
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
# Строим столбцы для разных метрик
ax.bar(indices - bar_width, models_to_show['math_score'], bar_width,
label='RussianMath Score', color=colors[0])
ax.bar(indices, models_to_show['physics_score'], bar_width,
label='RussianPhysics Score', color=colors[1])
ax.bar(indices + bar_width, models_to_show['score'], bar_width,
label='Combined Score', color=colors[2])
# Настройка осей и меток
ax.set_xlabel('Модели')
ax.set_ylabel('Баллы')
ax.set_title('Сравнение производительности моделей на DeathMath benchmark')
ax.set_xticks(indices)
ax.set_xticklabels(models_to_show['model'], rotation=45, ha='right')
ax.legend()
# Ограничение значений по оси y от 0 до 1
ax.set_ylim(0, 1.0)
# Добавляем сетку для лучшей читаемости
ax.grid(axis='y', linestyle='--', alpha=0.7)
# Обеспечиваем, чтобы все метки помещались
plt.tight_layout()
return fig
def create_radar_plot(selected_models):
"""
Создает радиальную диаграмму для сравнения выбранных моделей
Args:
selected_models: Список названий моделей для отображения на графике
Returns:
plotly.graph_objects.Figure: Интерактивный радиальный график
"""
models = build_leadearboard_df()
metrics = ["math_score", "physics_score", "score"]
metric_labels = ["RussianMath", "RussianPhysics", "Combined"]
MIN_COLOUR_DISTANCE_BETWEEN_MODELS = 100
seed = 42
def generate_colours(min_distance, seed):
colour_mapping = {}
all_models = selected_models
for i in it.count():
min_colour_distance = min_distance - i
retries_left = 10 * len(all_models)
for model_id in all_models:
random.seed(hash(model_id) + i + seed)
r, g, b = 0, 0, 0
too_bright, similar_to_other_model = True, True
while (too_bright or similar_to_other_model) and retries_left > 0:
r, g, b = tuple(random.randint(0, 255) for _ in range(3))
too_bright = np.min([r, g, b]) > 200
similar_to_other_model = any(
np.abs(np.array(colour) - np.array([r, g, b])).sum() < min_colour_distance
for colour in colour_mapping.values()
)
retries_left -= 1
colour_mapping[model_id] = (r, g, b)
if len(colour_mapping) == len(all_models):
break
return colour_mapping
colour_mapping = generate_colours(MIN_COLOUR_DISTANCE_BETWEEN_MODELS, seed)
fig = go.Figure()
for _, model_data in models.iterrows():
model_name = model_data["model"]
if model_name not in selected_models:
continue
values = [model_data[metric] for metric in metrics]
color = f'rgb{colour_mapping[model_name]}'
fig.add_trace(go.Scatterpolar(
r=values,
theta=metric_labels,
name=model_name,
fill='toself',
fillcolor=f'rgba{colour_mapping[model_name] + (0.6,)}',
line=dict(color=color)
))
fig.update_layout(
polar=dict(
radialaxis=dict(
visible=True,
range=[0, 1]
)
),
showlegend=True,
title='Сравнение моделей на DeathMath',
template="plotly_dark",
)
return fig