import os
import random
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Callable

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from safetensors.torch import save_model
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

warnings.filterwarnings("ignore", category=UserWarning)

device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)


@dataclass
class HyperParams:
    env_id: str = "CartPole-v1"
    """The ID of the environment to train on."""
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """The name of the experiment, used for saving models and logs."""
    n_envs: int = 32
    """The number of parallel environments to run."""
    seed: int = 1
    """The random seed for reproducibility."""
    video_capture_frequency: int = 50
    """The interval (in episodes) to record videos of the agent's performance."""
    total_timesteps: int = 150_000
    """The total number of timesteps to train the agent."""
    num_steps: int = 20
    """The number of steps to run for each environment per update."""
    gamma: float = 0.99
    """The discount factor (gamma) for future rewards."""
    gae_lambda: float = 0.95
    """The lambda for the Generalized Advantage Estimator (GAE)."""
    ent_coef: float = 0.01
    """The coefficient for the entropy bonus."""
    learning_rate: float = 3e-4
    """The learning rate for the optimizer."""
    log_interval: int = 100
    """The interval (in timesteps) to log training statistics."""
    evaluate: bool = True
    """Whether to evaluate the agent after training."""
    eval_episodes: int = 10
    """The number of episodes to run for evaluation."""
    push_model: bool = True
    """Whether to upload the saved model to the Hugging Face Hub."""
    hf_entity: str = "alperenunlu"
    """The user or org name of the model repository on the Hugging Face Hub."""


def make_env(
    env_id: str, seed: int, idx: int, video_capture_frequency: int, run_name: str
) -> Callable[[], gym.Env]:
    """Create a gym environment with specific configurations.

    Args:
        env_id (str): The ID of the environment to create.
        seed (int): The seed for random number generation.
        idx (int): The index of the environment (for vectorized environments).
        video_capture_frequency (int): Frequency (in episodes) of recording videos (0 to disable).
        run_name (str): The name of the run for saving videos.

    Returns:
        Callable[[], gym.Env]: A function that returns the configured environment.
    """

    def _thunk() -> gym.Env:
        # Only the first sub-environment records videos.
        if video_capture_frequency > 0 and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(
                env,
                video_folder=f"videos/{run_name}",
                episode_trigger=lambda x: x % video_capture_frequency == 0,
                name_prefix=env_id,
            )
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return _thunk


class ActorCritic(nn.Module):
    def __init__(self, envs):
        super().__init__()
        self.critic = nn.Sequential(
            nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 1),
        )
        self.actor = nn.Sequential(
            nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, envs.single_action_space.n),
        )

    def forward(self, states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        values = self.critic(states)
        logits = self.actor(states)
        return values, logits

    def act(
        self, states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        values, logits = self.forward(states)
        pd = torch.distributions.Categorical(logits=logits)
        actions = pd.sample()
        logprobs = pd.log_prob(actions)
        entropy = pd.entropy()
        return actions, logprobs, entropy, values


def main() -> None:
    args = tyro.cli(HyperParams)
    run_name = f"{args.env_id}_{args.exp_name.replace('/', '_')}_{args.seed}_{datetime.now().strftime('%y%m%d_%H%M%S')}"
    print(run_name)
    writer = SummaryWriter(f"runs/{run_name}")
    keyval_str = "\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])
    writer.add_text(
        "hyperparameters",
        f"|param|value|\n|-|-|\n{keyval_str}",
    )

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    envs = gym.vector.AsyncVectorEnv(
        [
            make_env(
                args.env_id, args.seed + i, i, args.video_capture_frequency, run_name
            )
            for i in range(int(args.n_envs))
        ]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), (
        "Only discrete action space is supported"
    )
    envs.action_space.seed(args.seed)

    actor_critic = ActorCritic(envs).to(device)
    optimizer = optim.AdamW(actor_critic.parameters(), lr=args.learning_rate)

    start_time = time.time()
    obs, _ = envs.reset(seed=args.seed)
    step_pbar = tqdm(total=args.total_timesteps)
    postfix_dict = dict()
    for step in range(0, args.total_timesteps, envs.num_envs):
        values = torch.zeros(args.num_steps, envs.num_envs, device=device)
        rewards = torch.zeros(args.num_steps, envs.num_envs, device=device)
        logprobs = torch.zeros(args.num_steps, envs.num_envs, device=device)
        entropies = torch.zeros(args.num_steps, envs.num_envs, device=device)
        masks = torch.zeros(args.num_steps, envs.num_envs, device=device)

        # Collect a rollout of num_steps transitions from every environment.
        for t in range(args.num_steps):
            action, logprob, entropy, value = actor_critic.act(
                torch.from_numpy(obs).to(device)
            )
            obs, reward, terminations, _, infos = envs.step(action.cpu().numpy())

            if "episode" in infos:
                mask = infos["_episode"]
                r_mean = infos["episode"]["r"][mask].mean()
                l_mean = infos["episode"]["l"][mask].mean()
                writer.add_scalar("charts/episodic_return", r_mean, step)
                writer.add_scalar("charts/episodic_length", l_mean, step)
                postfix_dict.update(
                    r=r_mean,
                    l=l_mean,
                )

            values[t] = value.squeeze()
            rewards[t] = torch.from_numpy(reward)
            logprobs[t] = logprob
            entropies[t] = entropy
            masks[t] = torch.from_numpy(~terminations)
        advantages = torch.zeros_like(rewards).to(device)
        gae = 0.0
        for t in reversed(range(args.num_steps - 1)):
            td_error = (
                rewards[t] + args.gamma * values[t + 1] * masks[t] - values[t]
            )
            gae = td_error + args.gamma * args.gae_lambda * masks[t] * gae
            advantages[t] = gae

        critic_loss = advantages.pow(2).mean()
        actor_loss = (
            -(logprobs * advantages.detach()).mean() - args.ent_coef * entropies.mean()
        )
        loss = actor_loss + critic_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % args.log_interval < envs.num_envs:
            writer.add_scalar("losses/actor_loss", actor_loss, step)
            writer.add_scalar("losses/critic_loss", critic_loss, step)
            writer.add_scalar("losses/entropy", entropies.mean(), step)
            writer.add_scalar("charts/SPS", step // (time.time() - start_time), step)
            writer.add_scalar("losses/total_loss", loss, step)
            writer.add_scalar("losses/value_estimate", values.mean().item(), step)
            writer.add_scalar("losses/advantage", advantages.mean().item(), step)
            postfix_dict.update(
                actor_loss=actor_loss.item(),
                critic_loss=critic_loss.item(),
                # total_loss=loss.item(),
                # entropy=entropies.mean().item(),
                # value_estimate=values.mean().item(),
                advantage=advantages.mean().item(),
                sps=step // (time.time() - start_time),
            )

        step_pbar.set_postfix(postfix_dict)
        step_pbar.update(envs.num_envs)

    envs.close()
    step_pbar.close()

    if args.evaluate:
        run_name_eval = f"{run_name}_eval"
        final_model_path = f"runs/{run_name}/{args.exp_name}_final.safetensors"
        save_model(model=actor_critic, filename=final_model_path)

        from hellrl.evals.a2c_eval import evaluate

        episode_rewards = evaluate(
            final_model_path=final_model_path,
            make_env=make_env,
            env_id=args.env_id,
            ActorCritic=ActorCritic,
            run_name_eval=run_name_eval,
            device=device,
            eval_episodes=args.eval_episodes,
        )
        for i, r in enumerate(episode_rewards):
            writer.add_scalar("eval/episodic_return_eval", r, i)

        if args.push_model:
            from hellrl.utils.huggingface import push_model

            push_model(
                args=args,
                episode_rewards=episode_rewards,
                algo_name="A2C",
                run_path=f"runs/{run_name}",
                video_folder_path=f"videos/{run_name_eval}",
            )

    writer.close()


if __name__ == "__main__":
    main()
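
# Example invocation (a sketch: the filename `a2c.py` is hypothetical, and the flag
# spellings assume tyro's default conversion of dataclass field names to kebab-case):
#   python a2c.py --env-id CartPole-v1 --n-envs 32 --total-timesteps 150000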