Update app.py
Browse files
app.py
CHANGED
|
@@ -1193,22 +1193,44 @@ def plot_compression_tradeoff(summaries_by_ratio: Dict[float, Dict[str, Any]],
|
|
| 1193 |
logger.info(f"Compression trade-off plots saved: {plot_path}")
|
| 1194 |
return plot_path
|
| 1195 |
|
| 1196 |
-
def plot_ablation_results(ablation_results: Dict[str, Dict[str, Any]], baseline_summary: Dict[str, Any]) -> str:
|
| 1197 |
"""Generate publication-grade ablation study plots."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1198 |
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
|
| 1199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1200 |
# Prepare data
|
| 1201 |
-
configs = list(
|
| 1202 |
-
compression_ratios = [
|
| 1203 |
-
gen_ppls = [
|
| 1204 |
-
decode_times = [
|
| 1205 |
-
kv_memories = [
|
| 1206 |
-
throughputs = [
|
| 1207 |
|
| 1208 |
baseline_gen_ppl = baseline_summary['generation_perplexity']
|
| 1209 |
baseline_decode_time = baseline_summary['decode_time_ms']
|
| 1210 |
baseline_kv_memory = baseline_summary['kv_cache_memory_mb']
|
| 1211 |
-
baseline_throughput = baseline_summary.get('end_to_end_throughput',
|
| 1212 |
|
| 1213 |
# 1. Compression Ratio by Component
|
| 1214 |
ax1 = axes[0, 0]
|
|
@@ -3448,9 +3470,26 @@ def create_research_interface():
|
|
| 3448 |
tradeoff_path = None
|
| 3449 |
|
| 3450 |
# Generate ablation plots if ablation study was done
|
| 3451 |
-
if enable_ablation and ablation_results
|
| 3452 |
try:
|
| 3453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3454 |
except Exception as e:
|
| 3455 |
logger.error(f"Failed to generate ablation plots: {e}")
|
| 3456 |
ablation_plots_path = None
|
|
|
|
| 1193 |
logger.info(f"Compression trade-off plots saved: {plot_path}")
|
| 1194 |
return plot_path
|
| 1195 |
|
| 1196 |
+
def plot_ablation_results(ablation_results: Dict[str, Dict[str, Any]], baseline_summary: Dict[str, Any] = None) -> str:
|
| 1197 |
"""Generate publication-grade ablation study plots."""
|
| 1198 |
+
# Find baseline summary if not provided
|
| 1199 |
+
if baseline_summary is None:
|
| 1200 |
+
if "baseline" in ablation_results and "summary" in ablation_results["baseline"]:
|
| 1201 |
+
baseline_summary = ablation_results["baseline"]["summary"]
|
| 1202 |
+
else:
|
| 1203 |
+
# Find any NONE compression type as baseline
|
| 1204 |
+
for config_name, result in ablation_results.items():
|
| 1205 |
+
if "summary" in result:
|
| 1206 |
+
baseline_summary = result["summary"]
|
| 1207 |
+
logger.warning(f"Using {config_name} as baseline for ablation plots")
|
| 1208 |
+
break
|
| 1209 |
+
|
| 1210 |
+
if baseline_summary is None:
|
| 1211 |
+
logger.error("No valid baseline found for ablation plots")
|
| 1212 |
+
return None
|
| 1213 |
+
|
| 1214 |
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
|
| 1215 |
|
| 1216 |
+
# Filter out configs with errors
|
| 1217 |
+
valid_configs = {k: v for k, v in ablation_results.items() if "summary" in v}
|
| 1218 |
+
if not valid_configs:
|
| 1219 |
+
logger.error("No valid ablation results to plot")
|
| 1220 |
+
return None
|
| 1221 |
+
|
| 1222 |
# Prepare data
|
| 1223 |
+
configs = list(valid_configs.keys())
|
| 1224 |
+
compression_ratios = [valid_configs[c]['summary']['compression_ratio'] for c in configs]
|
| 1225 |
+
gen_ppls = [valid_configs[c]['summary']['generation_perplexity'] for c in configs]
|
| 1226 |
+
decode_times = [valid_configs[c]['summary']['decode_time_ms'] for c in configs]
|
| 1227 |
+
kv_memories = [valid_configs[c]['summary']['kv_cache_memory_mb'] for c in configs]
|
| 1228 |
+
throughputs = [valid_configs[c]['summary'].get('end_to_end_throughput', 0) for c in configs]
|
| 1229 |
|
| 1230 |
baseline_gen_ppl = baseline_summary['generation_perplexity']
|
| 1231 |
baseline_decode_time = baseline_summary['decode_time_ms']
|
| 1232 |
baseline_kv_memory = baseline_summary['kv_cache_memory_mb']
|
| 1233 |
+
baseline_throughput = baseline_summary.get('end_to_end_throughput', 1)
|
| 1234 |
|
| 1235 |
# 1. Compression Ratio by Component
|
| 1236 |
ax1 = axes[0, 0]
|
|
|
|
| 3470 |
tradeoff_path = None
|
| 3471 |
|
| 3472 |
# Generate ablation plots if ablation study was done
|
| 3473 |
+
if enable_ablation and ablation_results:
|
| 3474 |
try:
|
| 3475 |
+
# Find a valid baseline for comparison
|
| 3476 |
+
baseline_for_plots = None
|
| 3477 |
+
if "baseline" in ablation_results and "summary" in ablation_results["baseline"]:
|
| 3478 |
+
baseline_for_plots = ablation_results["baseline"]["summary"]
|
| 3479 |
+
elif "NONE" in all_summaries:
|
| 3480 |
+
baseline_for_plots = all_summaries["NONE"]
|
| 3481 |
+
else:
|
| 3482 |
+
# Use the first available summary
|
| 3483 |
+
for result in ablation_results.values():
|
| 3484 |
+
if "summary" in result:
|
| 3485 |
+
baseline_for_plots = result["summary"]
|
| 3486 |
+
break
|
| 3487 |
+
|
| 3488 |
+
if baseline_for_plots:
|
| 3489 |
+
ablation_plots_path = plot_ablation_results(ablation_results, baseline_for_plots)
|
| 3490 |
+
else:
|
| 3491 |
+
logger.warning("No baseline available for ablation plots")
|
| 3492 |
+
ablation_plots_path = None
|
| 3493 |
except Exception as e:
|
| 3494 |
logger.error(f"Failed to generate ablation plots: {e}")
|
| 3495 |
ablation_plots_path = None
|