A2C playing BreakoutNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
`OptimizeCallback` evaluates the policy every `step_freq` timesteps during Optuna hyperparameter tuning, reports the intermediate score to the trial, and flags the run for pruning when Optuna judges the trial unpromising:

```python
import numpy as np
import optuna

from time import perf_counter
from torch.utils.tensorboard.writer import SummaryWriter
from typing import NamedTuple, Union

from rl_algo_impls.shared.callbacks.callback import Callback
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.shared.stats import EpisodesStats
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, find_wrapper


class Evaluation(NamedTuple):
    eval_stat: EpisodesStats
    train_stat: EpisodesStats
    score: float


class OptimizeCallback(Callback):
    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        trial: optuna.Trial,
        tb_writer: SummaryWriter,
        step_freq: Union[int, float] = 50_000,
        n_episodes: int = 10,
        deterministic: bool = True,
    ) -> None:
        super().__init__()
        self.policy = policy
        self.env = env
        self.trial = trial
        self.tb_writer = tb_writer
        self.step_freq = step_freq
        self.n_episodes = n_episodes
        self.deterministic = deterministic

        stats_writer = find_wrapper(policy.env, EpisodeStatsWriter)
        assert stats_writer
        self.stats_writer = stats_writer

        self.eval_step = 1
        self.is_pruned = False
        self.last_eval_stat = None
        self.last_train_stat = None
        self.last_score = -np.inf

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        super().on_step(timesteps_elapsed)
        if self.timesteps_elapsed >= self.eval_step * self.step_freq:
            self.evaluate()
            # Returning False stops training once the trial has been pruned.
            return not self.is_pruned
        return True

    def evaluate(self) -> None:
        self.last_eval_stat, self.last_train_stat, score = evaluation(
            self.policy,
            self.env,
            self.tb_writer,
            self.n_episodes,
            self.deterministic,
            self.timesteps_elapsed,
        )
        self.last_score = score

        # Report the intermediate score so Optuna's pruner can act on it.
        self.trial.report(score, self.eval_step)
        if self.trial.should_prune():
            self.is_pruned = True

        self.eval_step += 1


def evaluation(
    policy: Policy,
    env: VecEnv,
    tb_writer: SummaryWriter,
    n_episodes: int,
    deterministic: bool,
    timesteps_elapsed: int,
) -> Evaluation:
    start_time = perf_counter()
    eval_stat = evaluate(
        env,
        policy,
        n_episodes,
        deterministic=deterministic,
        print_returns=False,
    )
    end_time = perf_counter()
    tb_writer.add_scalar(
        "eval/steps_per_second",
        eval_stat.length.sum() / (end_time - start_time),
        timesteps_elapsed,
    )
    policy.train()
    print(f"Eval Timesteps: {timesteps_elapsed} | {eval_stat}")
    eval_stat.write_to_tensorboard(tb_writer, "eval", timesteps_elapsed)

    stats_writer = find_wrapper(policy.env, EpisodeStatsWriter)
    assert stats_writer
    train_stat = EpisodesStats(stats_writer.episodes)
    print(f"  Train Stat: {train_stat}")

    # Trial score is the average of evaluation and recent training returns.
    score = (eval_stat.score.mean + train_stat.score.mean) / 2
    print(f"  Score: {round(score, 2)}")
    tb_writer.add_scalar(
        "eval/score",
        score,
        timesteps_elapsed,
    )

    return Evaluation(eval_stat, train_stat, score)
```
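For context, below is a minimal sketch of how a callback like this might be wired into an Optuna objective. The `make_env`, `make_policy`, and `train_a2c` helpers and the tuned hyperparameter are hypothetical placeholders, not the repository's actual tuning entry point:

```python
# Hypothetical sketch: using OptimizeCallback inside an Optuna study.
# make_env, make_policy, and train_a2c are placeholder helpers, not
# functions from rl-algo-impls.
import optuna
from torch.utils.tensorboard.writer import SummaryWriter


def objective(trial: optuna.Trial) -> float:
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)

    env = make_env("BreakoutNoFrameskip-v4")       # placeholder training env
    eval_env = make_env("BreakoutNoFrameskip-v4")  # placeholder eval env
    policy = make_policy(env, learning_rate)       # placeholder policy factory
    tb_writer = SummaryWriter(f"runs/trial_{trial.number}")

    callback = OptimizeCallback(policy, eval_env, trial, tb_writer, step_freq=50_000)
    train_a2c(policy, env, total_steps=1_000_000, callback=callback)  # placeholder

    if callback.is_pruned:
        raise optuna.TrialPruned()
    return callback.last_score


study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=100)
```

Because the callback returns `False` from `on_step` once `is_pruned` is set, the training loop can exit early and the objective raises `optuna.TrialPruned()` so the study records the trial as pruned rather than failed.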