kfoughali commited on
Commit
592dfd5
·
verified ·
1 Parent(s): ad2bc8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -10
app.py CHANGED
@@ -1193,22 +1193,44 @@ def plot_compression_tradeoff(summaries_by_ratio: Dict[float, Dict[str, Any]],
1193
  logger.info(f"Compression trade-off plots saved: {plot_path}")
1194
  return plot_path
1195
 
1196
- def plot_ablation_results(ablation_results: Dict[str, Dict[str, Any]], baseline_summary: Dict[str, Any]) -> str:
1197
  """Generate publication-grade ablation study plots."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1198
  fig, axes = plt.subplots(2, 3, figsize=(18, 10))
1199
 
 
 
 
 
 
 
1200
  # Prepare data
1201
- configs = list(ablation_results.keys())
1202
- compression_ratios = [ablation_results[c]['summary']['compression_ratio'] for c in configs]
1203
- gen_ppls = [ablation_results[c]['summary']['generation_perplexity'] for c in configs]
1204
- decode_times = [ablation_results[c]['summary']['decode_time_ms'] for c in configs]
1205
- kv_memories = [ablation_results[c]['summary']['kv_cache_memory_mb'] for c in configs]
1206
- throughputs = [ablation_results[c]['summary'].get('end_to_end_throughput', 0) for c in configs]
1207
 
1208
  baseline_gen_ppl = baseline_summary['generation_perplexity']
1209
  baseline_decode_time = baseline_summary['decode_time_ms']
1210
  baseline_kv_memory = baseline_summary['kv_cache_memory_mb']
1211
- baseline_throughput = baseline_summary.get('end_to_end_throughput', 0)
1212
 
1213
  # 1. Compression Ratio by Component
1214
  ax1 = axes[0, 0]
@@ -3448,9 +3470,26 @@ def create_research_interface():
3448
  tradeoff_path = None
3449
 
3450
  # Generate ablation plots if ablation study was done
3451
- if enable_ablation and ablation_results and "baseline" in ablation_results:
3452
  try:
3453
- ablation_plots_path = plot_ablation_results(ablation_results, ablation_results["baseline"]["summary"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3454
  except Exception as e:
3455
  logger.error(f"Failed to generate ablation plots: {e}")
3456
  ablation_plots_path = None
 
1193
  logger.info(f"Compression trade-off plots saved: {plot_path}")
1194
  return plot_path
1195
 
1196
+ def plot_ablation_results(ablation_results: Dict[str, Dict[str, Any]], baseline_summary: Dict[str, Any] = None) -> str:
1197
  """Generate publication-grade ablation study plots."""
1198
+ # Find baseline summary if not provided
1199
+ if baseline_summary is None:
1200
+ if "baseline" in ablation_results and "summary" in ablation_results["baseline"]:
1201
+ baseline_summary = ablation_results["baseline"]["summary"]
1202
+ else:
1203
+ # Find any NONE compression type as baseline
1204
+ for config_name, result in ablation_results.items():
1205
+ if "summary" in result:
1206
+ baseline_summary = result["summary"]
1207
+ logger.warning(f"Using {config_name} as baseline for ablation plots")
1208
+ break
1209
+
1210
+ if baseline_summary is None:
1211
+ logger.error("No valid baseline found for ablation plots")
1212
+ return None
1213
+
1214
  fig, axes = plt.subplots(2, 3, figsize=(18, 10))
1215
 
1216
+ # Filter out configs with errors
1217
+ valid_configs = {k: v for k, v in ablation_results.items() if "summary" in v}
1218
+ if not valid_configs:
1219
+ logger.error("No valid ablation results to plot")
1220
+ return None
1221
+
1222
  # Prepare data
1223
+ configs = list(valid_configs.keys())
1224
+ compression_ratios = [valid_configs[c]['summary']['compression_ratio'] for c in configs]
1225
+ gen_ppls = [valid_configs[c]['summary']['generation_perplexity'] for c in configs]
1226
+ decode_times = [valid_configs[c]['summary']['decode_time_ms'] for c in configs]
1227
+ kv_memories = [valid_configs[c]['summary']['kv_cache_memory_mb'] for c in configs]
1228
+ throughputs = [valid_configs[c]['summary'].get('end_to_end_throughput', 0) for c in configs]
1229
 
1230
  baseline_gen_ppl = baseline_summary['generation_perplexity']
1231
  baseline_decode_time = baseline_summary['decode_time_ms']
1232
  baseline_kv_memory = baseline_summary['kv_cache_memory_mb']
1233
+ baseline_throughput = baseline_summary.get('end_to_end_throughput', 1)
1234
 
1235
  # 1. Compression Ratio by Component
1236
  ax1 = axes[0, 0]
 
3470
  tradeoff_path = None
3471
 
3472
  # Generate ablation plots if ablation study was done
3473
+ if enable_ablation and ablation_results:
3474
  try:
3475
+ # Find a valid baseline for comparison
3476
+ baseline_for_plots = None
3477
+ if "baseline" in ablation_results and "summary" in ablation_results["baseline"]:
3478
+ baseline_for_plots = ablation_results["baseline"]["summary"]
3479
+ elif "NONE" in all_summaries:
3480
+ baseline_for_plots = all_summaries["NONE"]
3481
+ else:
3482
+ # Use the first available summary
3483
+ for result in ablation_results.values():
3484
+ if "summary" in result:
3485
+ baseline_for_plots = result["summary"]
3486
+ break
3487
+
3488
+ if baseline_for_plots:
3489
+ ablation_plots_path = plot_ablation_results(ablation_results, baseline_for_plots)
3490
+ else:
3491
+ logger.warning("No baseline available for ablation plots")
3492
+ ablation_plots_path = None
3493
  except Exception as e:
3494
  logger.error(f"Failed to generate ablation plots: {e}")
3495
  ablation_plots_path = None