Spaces:
Running
Running
import plotly.graph_objects as go | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import random | |
import itertools as it | |
from src.leaderboard.build_leaderboard import build_leadearboard_df | |
def create_plot(selected_models): | |
""" | |
Создает визуализацию для сравнения выбранных моделей по метрикам DeathMath | |
Args: | |
selected_models: Список названий моделей для отображения на графике | |
Returns: | |
matplotlib.figure.Figure: График для отображения в интерфейсе | |
""" | |
# Получаем данные моделей из лидерборда | |
models_df = build_leadearboard_df() | |
# Если нет выбранных моделей или данные не загружены, возвращаем пустой график | |
if not selected_models or models_df.empty: | |
fig, ax = plt.subplots(figsize=(10, 6)) | |
ax.text(0.5, 0.5, "Нет данных для отображения", | |
horizontalalignment='center', verticalalignment='center', | |
transform=ax.transAxes, fontsize=14) | |
ax.set_axis_off() | |
return fig | |
# Фильтруем DataFrame, чтобы оставить только выбранные модели | |
models_to_show = models_df[models_df['model'].isin(selected_models)] | |
if models_to_show.empty: | |
fig, ax = plt.subplots(figsize=(10, 6)) | |
ax.text(0.5, 0.5, "Выбранные модели не найдены в данных", | |
horizontalalignment='center', verticalalignment='center', | |
transform=ax.transAxes, fontsize=14) | |
ax.set_axis_off() | |
return fig | |
# Настройка бар-графика для сравнения моделей | |
fig, ax = plt.subplots(figsize=(12, 8)) | |
# Ширина столбцов | |
bar_width = 0.25 | |
# Позиции на оси x | |
models_count = len(models_to_show) | |
indices = np.arange(models_count) | |
# Цветовая палитра | |
colors = ['#1f77b4', '#ff7f0e', '#2ca02c'] | |
# Строим столбцы для разных метрик | |
ax.bar(indices - bar_width, models_to_show['math_score'], bar_width, | |
label='RussianMath Score', color=colors[0]) | |
ax.bar(indices, models_to_show['physics_score'], bar_width, | |
label='RussianPhysics Score', color=colors[1]) | |
ax.bar(indices + bar_width, models_to_show['score'], bar_width, | |
label='Combined Score', color=colors[2]) | |
# Настройка осей и меток | |
ax.set_xlabel('Модели') | |
ax.set_ylabel('Баллы') | |
ax.set_title('Сравнение производительности моделей на DeathMath benchmark') | |
ax.set_xticks(indices) | |
ax.set_xticklabels(models_to_show['model'], rotation=45, ha='right') | |
ax.legend() | |
# Ограничение значений по оси y от 0 до 1 | |
ax.set_ylim(0, 1.0) | |
# Добавляем сетку для лучшей читаемости | |
ax.grid(axis='y', linestyle='--', alpha=0.7) | |
# Обеспечиваем, чтобы все метки помещались | |
plt.tight_layout() | |
return fig | |
def create_radar_plot(selected_models): | |
""" | |
Создает радиальную диаграмму для сравнения выбранных моделей | |
Args: | |
selected_models: Список названий моделей для отображения на графике | |
Returns: | |
plotly.graph_objects.Figure: Интерактивный радиальный график | |
""" | |
models = build_leadearboard_df() | |
metrics = ["math_score", "physics_score", "score"] | |
metric_labels = ["RussianMath", "RussianPhysics", "Combined"] | |
MIN_COLOUR_DISTANCE_BETWEEN_MODELS = 100 | |
seed = 42 | |
def generate_colours(min_distance, seed): | |
colour_mapping = {} | |
all_models = selected_models | |
for i in it.count(): | |
min_colour_distance = min_distance - i | |
retries_left = 10 * len(all_models) | |
for model_id in all_models: | |
random.seed(hash(model_id) + i + seed) | |
r, g, b = 0, 0, 0 | |
too_bright, similar_to_other_model = True, True | |
while (too_bright or similar_to_other_model) and retries_left > 0: | |
r, g, b = tuple(random.randint(0, 255) for _ in range(3)) | |
too_bright = np.min([r, g, b]) > 200 | |
similar_to_other_model = any( | |
np.abs(np.array(colour) - np.array([r, g, b])).sum() < min_colour_distance | |
for colour in colour_mapping.values() | |
) | |
retries_left -= 1 | |
colour_mapping[model_id] = (r, g, b) | |
if len(colour_mapping) == len(all_models): | |
break | |
return colour_mapping | |
colour_mapping = generate_colours(MIN_COLOUR_DISTANCE_BETWEEN_MODELS, seed) | |
fig = go.Figure() | |
for _, model_data in models.iterrows(): | |
model_name = model_data["model"] | |
if model_name not in selected_models: | |
continue | |
values = [model_data[metric] for metric in metrics] | |
color = f'rgb{colour_mapping[model_name]}' | |
fig.add_trace(go.Scatterpolar( | |
r=values, | |
theta=metric_labels, | |
name=model_name, | |
fill='toself', | |
fillcolor=f'rgba{colour_mapping[model_name] + (0.6,)}', | |
line=dict(color=color) | |
)) | |
fig.update_layout( | |
polar=dict( | |
radialaxis=dict( | |
visible=True, | |
range=[0, 1] | |
) | |
), | |
showlegend=True, | |
title='Сравнение моделей на DeathMath', | |
template="plotly_dark", | |
) | |
return fig | |