tristan-deep committed on
Commit 1398519 · 2 Parent(s): 57a5488 b97027a

Merge branch 'main' of github.com:tristan-deep/semantic-diffusion-echo-dehazing

Files changed (4)
  1. .gitignore +1 -0
  2. eval.py +21 -2
  3. plots.py +104 -0
  4. sweeper.py +418 -0
.gitignore CHANGED
@@ -5,3 +5,4 @@ temp/
  *.pdf
  *.hash
  *.npz
+ sweep_results/
eval.py CHANGED
@@ -208,7 +208,7 @@ def calculate_final_score(aggregates):
      return 0


- def main(folder: str, noisy_folder: str, roi_folder: str, reference_folder: str):
+ def evaluate(folder: str, noisy_folder: str, roi_folder: str, reference_folder: str):
      """Evaluate the dehazing algorithm.

      Args:
@@ -294,6 +294,25 @@ def main(folder: str, noisy_folder: str, roi_folder: str, reference_folder: str)
      fid_score = calculate_fid_score(fid_image_paths, str(reference_folder))
      print(f"FID between {folder} and {reference_folder}: {fid_score:.3f}")

+     # Create aggregates dictionary for final score calculation
+     aggregates = {
+         "fid": float(fid_score),
+         "cnr_mean": float(np.mean(metrics["CNR"])),
+         "cnr_std": float(np.std(metrics["CNR"])),
+         "gcnr_mean": float(np.mean(metrics["gCNR"])),
+         "gcnr_std": float(np.std(metrics["gCNR"])),
+         "ks_roi1_ksstatistic_mean": float(np.mean(metrics["KS_A"])),
+         "ks_roi1_ksstatistic_std": float(np.std(metrics["KS_A"])),
+         "ks_roi2_ksstatistic_mean": float(np.mean(metrics["KS_B"])),
+         "ks_roi2_ksstatistic_std": float(np.std(metrics["KS_B"])),
+     }
+
+     # Calculate final score
+     final_score = calculate_final_score(aggregates)
+     aggregates["final_score"] = float(final_score)
+
+     return aggregates
+

  if __name__ == "__main__":
-     tyro.cli(main)
+     tyro.cli(evaluate)
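Note: main() in eval.py is renamed to evaluate() and now returns the aggregates dict instead of only printing, so the metrics can be consumed programmatically as well as through the tyro CLI. A minimal usage sketch (the folder paths are placeholders, not part of this commit):

    from eval import evaluate

    # Placeholder paths for illustration only
    aggregates = evaluate(
        folder="outputs/dehazed",
        noisy_folder="data/hazy",
        roi_folder="data/rois",
        reference_folder="data/clean",
    )
    print(aggregates["final_score"], aggregates["fid"])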
plots.py CHANGED
@@ -1,5 +1,10 @@
+ import json
+ from pathlib import Path
+ from typing import Any, Dict, List
+
  import matplotlib.pyplot as plt
  import numpy as np
+ import tyro
  from keras import ops
  from matplotlib.patches import PathPatch
  from matplotlib.path import Path as pltPath
@@ -252,3 +257,102 @@ def plot_metrics(metrics, limits, out_path):
          bbox_to_anchor=(0.5, 1.02),
      )
      return fig
+
+
+ def plot_optimization_history_from_json(
+     trials_data: List[Dict[str, Any]], output_path: Path, method: str
+ ):
+     """Plot optimization history directly from JSON data."""
+
+     # Extract completed trials with values
+     completed_trials = [
+         t for t in trials_data if t["state"] == "COMPLETE" and t["value"] is not None
+     ]
+
+     if not completed_trials:
+         print("No completed trials found!")
+         return
+
+     # Sort by trial number
+     completed_trials.sort(key=lambda x: x["number"])
+
+     trial_numbers = [t["number"] for t in completed_trials]
+     values = [t["value"] for t in completed_trials]
+
+     # Apply seaborn styling
+     plt.style.use("seaborn-v0_8-darkgrid")
+
+     # Create the plot
+     fig, ax = plt.subplots(figsize=(5, 3), dpi=600)
+
+     # Plot all trial values with styling similar to plots.py
+     ax.scatter(
+         trial_numbers,
+         values,
+         c="#0057b7",
+         alpha=0.6,
+         s=30,
+         edgecolor="black",
+         linewidth=0.5,
+     )
+
+     # Plot best value so far
+     best_values = []
+     current_best = values[0]
+     for val in values:
+         if val > current_best:  # Assuming maximization
+             current_best = val
+         best_values.append(current_best)
+
+     ax.plot(
+         trial_numbers,
+         best_values,
+         color="#d62d20",
+         linewidth=2.5,
+         label="Best Value",
+         marker="o",
+         markersize=4,
+         markevery=max(1, len(trial_numbers) // 20),
+     )
+
+     ax.set_xlabel("Trial", fontsize=11)
+     ax.set_ylabel("Objective Value", fontsize=11)
+     # ax.set_title("Optimization History", fontsize=12)
+     ax.legend(fontsize=10, frameon=False)
+
+     # Remove top and right spines like in plots.py
+     ax.spines["top"].set_visible(False)
+     ax.spines["right"].set_visible(False)
+     ax.tick_params(axis="both", which="major", labelsize=9)
+
+     # Save plot
+     fig.savefig(
+         output_path / f"optimization_history_{method}.png", dpi=600, bbox_inches="tight"
+     )
+     fig.savefig(
+         output_path / f"optimization_history_{method}.pdf", dpi=600, bbox_inches="tight"
+     )
+     plt.close(fig)
+
+
+ def main(json_file: str, output_dir: str = "plots", method: str = "semantic_dps"):
+     json_path = Path(json_file)
+     if not json_path.exists():
+         raise FileNotFoundError(f"JSON file not found: {json_file}")
+
+     # Load trial data
+     with open(json_path, "r") as f:
+         trials_data = json.load(f)
+
+     print(f"Loaded {len(trials_data)} trials from {json_file}")
+
+     # Create output directory
+     output_path = Path(output_dir)
+     output_path.mkdir(parents=True, exist_ok=True)
+
+     print("Generating optimization history plot...")
+     plot_optimization_history_from_json(trials_data, output_path, method)
+
+
+ if __name__ == "__main__":
+     tyro.cli(main)
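Note: plots.py now also works as a tyro CLI that reads the per-trial JSON written by sweeper.py and renders the optimization history. A hedged sketch of calling it programmatically (the JSON path is an assumed example following the all_trials_{method}_{study_name}.json naming used in sweeper.py):

    from plots import main as plot_history

    # Assumed output file from a sweep run with the default method and study name
    plot_history(
        json_file="sweep_results/all_trials_semantic_dps_dehaze_optimization.json",
        output_dir="plots",
        method="semantic_dps",
    )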
sweeper.py ADDED
@@ -0,0 +1,418 @@
+ """
+
+ NOTE: pip install optuna
+
+ """
+
+ import dataclasses
+ import json
+ import shutil
+ import tempfile
+ from pathlib import Path
+ from typing import Any, Dict, Optional
+
+ import jax
+ import numpy as np
+ import optuna
+ import tyro
+ import yaml
+ import zea
+ from keras import ops
+ from PIL import Image
+ from zea import init_device, log
+
+ from eval import evaluate
+ from main import init, run
+
+
+ def load_images_from_dir(input_folder):
+     """Load images from directory, similar to main.py implementation."""
+     paths = list(Path(input_folder).glob("*.png"))
+
+     images = []
+     for path in paths:
+         image = zea.io_lib.load_image(path)
+         images.append(image)
+
+     if len(images) == 0:
+         raise ValueError(f"No PNG images found in {input_folder}")
+
+     images = ops.stack(images, axis=0)
+     return images, paths
+
+
+ def save_images_to_temp_dir(images, image_paths, prefix=""):
+     """Save numpy arrays as PNG images to a temporary directory."""
+     temp_dir = tempfile.mkdtemp(prefix=prefix)
+     temp_dir_path = Path(temp_dir)
+
+     for i, (img, path) in enumerate(zip(images, image_paths)):
+         # Get the filename from the original path
+         filename = Path(path).name
+
+         # Convert image to uint8 if needed
+         if img.dtype != np.uint8:
+             # Assume image is in [0, 1] range and convert to [0, 255]
+             if img.max() <= 1.0:
+                 img = (img * 255).astype(np.uint8)
+             else:
+                 img = img.astype(np.uint8)
+
+         # Ensure image is 2D or 3D
+         if len(img.shape) == 3 and img.shape[-1] == 1:
+             img = img.squeeze(-1)
+
+         # Save as PNG
+         img_pil = Image.fromarray(img)
+         save_path = temp_dir_path / filename
+         img_pil.save(save_path)
+
+     return str(temp_dir_path)
+
+
+ @dataclasses.dataclass
+ class SweeperConfig:
+     """Configuration for hyperparameter sweeping with Optuna."""
+
+     # Required paths - no defaults
+     input_image_dir: str  # Path to input hazy images
+     roi_folder: str  # Path to ROI mask images
+     reference_folder: str  # Path to reference/ground truth images
+     base_config_path: str = "configs/semantic_dps.yaml"
+
+     # Base configuration
+     method: str = "semantic_dps"  # Which method to optimize
+     broad_sweep: bool = False  # Choose between broad or narrow sweep
+
+     # Optuna settings
+     study_name: str = "dehaze_optimization"
+     storage: Optional[str] = None  # e.g., "sqlite:///dehaze_study.db" for persistence
+     n_trials: int = 100
+
+     # Optimization settings
+     objective_metric: str = "final_score"  # Which metric to optimize
+     direction: str = "maximize"  # "maximize" or "minimize"
+
+     # Output settings
+     output_dir: str = "sweep_results"
+
+     # Evaluation settings
+     skip_fid: bool = False
+
+     # Device configuration
+     device: str = "auto:1"
+
+     # Pruning settings
+     enable_pruning: bool = True
+     pruner_type: str = "median"  # "median", "hyperband", or "none"
+
+
+ class OptunaObjective:
+     """Optuna objective function for hyperparameter optimization."""
+
+     def __init__(self, config: SweeperConfig):
+         self.config = config
+         self.base_config = self._load_base_config()
+         self.hazy_images, self.image_paths = load_images_from_dir(
+             config.input_image_dir
+         )
+
+         # Initialize device
+         init_device(config.device)
+
+         # Initialize the diffusion model once
+         self.diffusion_model = init(self.base_config)
+
+     def _load_base_config(self):
+         """Load base configuration from YAML file."""
+         with open(self.config.base_config_path, "r") as f:
+             config_dict = yaml.safe_load(f)
+         return zea.Config(**config_dict)
+
+     def _create_trial_params(self, trial: optuna.Trial) -> Dict[str, Any]:
+         """Create trial parameters by suggesting hyperparameters."""
+         params = {
+             "guidance_kwargs": {
+                 "omega": trial.suggest_float("omega", 0.5, 50.0, log=True),
+                 "omega_vent": trial.suggest_float("omega_vent", 0.0001, 50.0, log=True),
+                 "omega_sept": trial.suggest_float("omega_sept", 0.1, 50.0, log=True),
+                 "omega_dark": trial.suggest_float("omega_dark", 0.001, 50.0, log=True),
+                 "eta": trial.suggest_float("eta", 0.001, 1.0, log=True),
+                 "smooth_l1_beta": trial.suggest_float(
+                     "smooth_l1_beta", 0.1, 10.0, log=True
+                 ),
+             },
+             "skeleton_params": {
+                 "sigma_pre": trial.suggest_float("skeleton_sigma_pre", 0.0, 10.0),
+                 "sigma_post": trial.suggest_float("skeleton_sigma_post", 0.0, 10.0),
+                 "threshold": trial.suggest_float("skeleton_threshold", 0.0, 1.0),
+             },
+             "mask_params": {
+                 "threshold": trial.suggest_float("mask_threshold", 0.0, 1.0),
+                 "sigma": trial.suggest_float("mask_sigma", 0.0, 10.0),
+             },
+         }
+
+         # Add base config parameters that aren't being optimized
+         if hasattr(self.base_config, "params"):
+             base_params = self.base_config.params
+             for key, value in base_params.items():
+                 if key not in params:
+                     params[key] = value
+
+         return params
+
+     def __call__(self, trial: optuna.Trial) -> float:
+         """Optuna objective function."""
+         # Suggest hyperparameters for this trial
+         params = self._create_trial_params(trial)
+
+         # Create seed for reproducibility
+         seed = jax.random.PRNGKey(self.base_config.seed + trial.number)
+
+         # Run the semantic DPS method
+         try:
+             hazy_images, pred_tissue_images, pred_haze_images, masks = run(
+                 hazy_images=self.hazy_images,
+                 diffusion_model=self.diffusion_model,
+                 seed=seed,
+                 **params,
+             )
+         except Exception as e:
+             log.error(f"Error during model inference: {e}")
+             return 0.0
+
+         # Convert tensors to numpy arrays if needed
+         if hasattr(pred_tissue_images, "numpy"):
+             pred_tissue_images = pred_tissue_images.numpy()
+
+         # Initialize temp directory
+         pred_tissue_temp_dir = None
+
+         try:
+             # Save predicted tissue images to temp directory
+             pred_tissue_temp_dir = save_images_to_temp_dir(
+                 pred_tissue_images, self.image_paths, prefix="pred_tissue_"
+             )
+
+             # Run evaluation using the updated evaluate function
+             results = evaluate(
+                 folder=pred_tissue_temp_dir,
+                 noisy_folder=self.config.input_image_dir,
+                 roi_folder=self.config.roi_folder,
+                 reference_folder=self.config.reference_folder,
+             )
+
+             objective_value = results[self.config.objective_metric]
+
+         except Exception as e:
+             log.error(f"Error during evaluation: {e}")
+             objective_value = 0.0
+
+         finally:
+             # Clean up temporary directory
+             if pred_tissue_temp_dir and Path(pred_tissue_temp_dir).exists():
+                 try:
+                     shutil.rmtree(pred_tissue_temp_dir)
+                 except Exception as e:
+                     log.warning(
+                         f"Failed to clean up temp directory {pred_tissue_temp_dir}: {e}"
+                     )
+
+         # Log intermediate results for potential pruning
+         trial.report(objective_value, step=0)
+
+         # Check if trial should be pruned
+         if trial.should_prune():
+             raise optuna.TrialPruned()
+
+         # Store hyperparameters as user attributes
+         for key, value in params.items():
+             if isinstance(value, dict):
+                 for subkey, subvalue in value.items():
+                     trial.set_user_attr(f"{key}_{subkey}", subvalue)
+             else:
+                 trial.set_user_attr(key, value)
+
+         log.info(
+             f"Trial {trial.number}: {self.config.objective_metric} = {objective_value:.4f}"
+         )
+
+         return objective_value
+
+
+ def create_pruner(pruner_type: str) -> optuna.pruners.BasePruner:
+     """Create an Optuna pruner based on the specified type."""
+     if pruner_type == "median":
+         return optuna.pruners.MedianPruner(
+             n_startup_trials=5, n_warmup_steps=0, interval_steps=1
+         )
+     elif pruner_type == "hyperband":
+         return optuna.pruners.HyperbandPruner(
+             min_resource=1, max_resource=100, reduction_factor=3
+         )
+     else:
+         return optuna.pruners.NopPruner()
+
+
+ def run_optimization(config: SweeperConfig):
+     """Run hyperparameter optimization using Optuna."""
+
+     # Create pruner
+     pruner = create_pruner(config.pruner_type) if config.enable_pruning else None
+
+     # Create or load study
+     study = optuna.create_study(
+         study_name=config.study_name,
+         storage=config.storage,
+         direction=config.direction,
+         pruner=pruner,
+         load_if_exists=True,
+     )
+
+     log.info(f"Starting optimization for method: {config.method}")
+     log.info(f"Study name: {config.study_name}")
+     log.info(f"Number of trials: {config.n_trials}")
+     log.info(f"Objective metric: {config.objective_metric} ({config.direction})")
+
+     # Create objective function
+     objective = OptunaObjective(config)
+
+     # Run optimization
+     study.optimize(objective, n_trials=config.n_trials)
+
+     # Save results
+     output_dir = Path(config.output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Save best trial info
+     best_trial = study.best_trial
+     best_results = {
+         "best_value": best_trial.value,
+         "best_params": best_trial.params,
+         "best_user_attrs": best_trial.user_attrs,
+         "study_stats": {
+             "n_trials": len(study.trials),
+             "n_complete_trials": len(
+                 [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
+             ),
+             "n_pruned_trials": len(
+                 [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
+             ),
+             "n_failed_trials": len(
+                 [t for t in study.trials if t.state == optuna.trial.TrialState.FAIL]
+             ),
+         },
+     }
+
+     with open(
+         output_dir / f"best_results_{config.method}_{config.study_name}.json", "w"
+     ) as f:
+         json.dump(best_results, f, indent=2)
+
+     # Save all trials data
+     trials_data = []
+     for trial in study.trials:
+         trial_data = {
+             "number": trial.number,
+             "value": trial.value,
+             "params": trial.params,
+             "user_attrs": trial.user_attrs,
+             "state": trial.state.name,
+             "datetime_start": trial.datetime_start.isoformat()
+             if trial.datetime_start
+             else None,
+             "datetime_complete": trial.datetime_complete.isoformat()
+             if trial.datetime_complete
+             else None,
+         }
+         trials_data.append(trial_data)
+
+     with open(
+         output_dir / f"all_trials_{config.method}_{config.study_name}.json", "w"
+     ) as f:
+         json.dump(trials_data, f, indent=2)
+
+     # Print summary
+     log.success("Optimization completed!")
+     log.info(f"Best {config.objective_metric}: {best_trial.value:.4f}")
+     log.info("Best parameters:")
+     for key, value in best_trial.params.items():
+         log.info(f" {key}: {value}")
+
+     # Print study statistics
+     stats = best_results["study_stats"]
+     log.info("Study statistics:")
+     log.info(f" Total trials: {stats['n_trials']}")
+     log.info(f" Complete trials: {stats['n_complete_trials']}")
+     log.info(f" Pruned trials: {stats['n_pruned_trials']}")
+     log.info(f" Failed trials: {stats['n_failed_trials']}")
+
+     return study
+
+
+ def main():
+     """Main function for running hyperparameter optimization."""
+     config = tyro.cli(SweeperConfig)
+
+     # Validate required paths exist
+     required_paths = [
+         (config.input_image_dir, "Input image directory"),
+         (config.roi_folder, "ROI folder"),
+         (config.reference_folder, "Reference folder"),
+     ]
+
+     for path, description in required_paths:
+         if not Path(path).exists():
+             raise FileNotFoundError(f"{description} not found: {path}")
+
+     # Set visualization style
+     zea.visualize.set_mpl_style()
+
+     # Run optimization
+     study = run_optimization(config)
+
+     # Optionally, generate optimization plots
+     try:
+         import matplotlib.pyplot as plt
+         import optuna.visualization as vis
+
+         output_dir = Path(config.output_dir)
+
+         # Plot optimization history
+         fig = vis.matplotlib.plot_optimization_history(study).figure
+         fig.savefig(
+             output_dir / f"optimization_history_{config.method}.png",
+             dpi=300,
+             bbox_inches="tight",
+         )
+         plt.close(fig)
+
+         # Plot parameter importances
+         fig = vis.matplotlib.plot_param_importances(study).figure
+         fig.savefig(
+             output_dir / f"param_importances_{config.method}.png",
+             dpi=300,
+             bbox_inches="tight",
+         )
+         plt.close(fig)
+
+         # Plot parallel coordinate
+         fig = vis.matplotlib.plot_parallel_coordinate(study).figure
+         fig.savefig(
+             output_dir / f"parallel_coordinate_{config.method}.png",
+             dpi=300,
+             bbox_inches="tight",
+         )
+         plt.close(fig)
+
+         log.success(f"Optimization plots saved to {output_dir}")
+
+     except ImportError:
+         log.warning(
+             "Optuna visualization not available. Install with: pip install optuna[visualization]"
+         )
+
+
+ if __name__ == "__main__":
+     main()
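Note: sweeper.py is normally driven through tyro.cli(SweeperConfig) in main(), but run_optimization() can also be called directly. A minimal sketch assuming the default study settings (the data paths are placeholders):

    from sweeper import SweeperConfig, run_optimization

    # Placeholder paths; only input_image_dir, roi_folder and reference_folder are required
    config = SweeperConfig(
        input_image_dir="data/hazy",
        roi_folder="data/rois",
        reference_folder="data/clean",
        n_trials=20,
        storage="sqlite:///dehaze_study.db",  # optional; enables resuming the study
    )
    study = run_optimization(config)
    print(study.best_trial.params)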