# CartPole-v1-a2c / a2c.py
import os
import random
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Callable
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from safetensors.torch import save_model
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
warnings.filterwarnings("ignore", category=UserWarning)
device = torch.device(
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
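# Prefer CUDA, fall back to Apple MPS, then CPU.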
@dataclass
class HyperParams:
env_id: str = "CartPole-v1"
"""The ID of the environment to train on."""
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""The name of the experiment, used for saving models and logs."""
n_envs: int = 32
"""The number of parallel environments to run."""
seed: int = 1
"""The random seed for reproducibility."""
video_capture_frequency: int = 50
"""The interval (in episodes) to record videos of the agent's performance."""
total_timesteps: int = 150_000
"""The total number of timesteps to train the agent."""
num_steps: int = 20
"The number of steps to run for each environment per update."
gamma: float = 0.99
"""The discount factor (gamma) for future rewards."""
gae_lambda: float = 0.95
"""The lambda for the Generalized Advantage Estimator (GAE)."""
ent_coef: float = 0.01
"""The coefficient for the entropy bonus."""
learning_rate: float = 3e-4
"""The learning rate for the optimizer."""
log_interval: int = 100
"""The interval (in timesteps) to log training statistics."""
evaluate: bool = True
"""Whether to evaluate the agent after training."""
eval_episodes: int = 10
"""The number of episodes to run for evaluation."""
push_model: bool = True
"""Whether to upload the saved model to huggingface"""
hf_entity: str = "alperenunlu"
"""The user or org name of the model repository from the Hugging Face Hub"""
def make_env(
env_id: str, seed: int, idx: int, video_capture_frequency: int, run_name: str
) -> Callable[[], gym.Env]:
"""Create a gym environment with specific configurations.
Args:
env_id (str): The ID of the environment to create.
seed (int): The seed for random number generation.
idx (int): The index of the environment (for vectorized environments).
        video_capture_frequency (int): Frequency (in episodes) of recording videos (0 to disable).
run_name (str): The name of the run for saving videos.
Returns:
Callable[[], gym.Env]: A function that returns the configured environment.
"""
def _thunk() -> gym.Env:
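        # Only the first environment (idx == 0) records video, so the vector env
        # produces a single video stream instead of one per worker.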
if video_capture_frequency > 0 and idx == 0:
env = gym.make(env_id, render_mode="rgb_array")
env = gym.wrappers.RecordVideo(
env,
video_folder=f"videos/{run_name}",
episode_trigger=lambda x: x % video_capture_frequency == 0,
name_prefix=env_id,
)
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
env.action_space.seed(seed)
return env
return _thunk
class ActorCritic(nn.Module):
def __init__(self, envs):
super().__init__()
self.critic = nn.Sequential(
nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64),
nn.ReLU(inplace=True),
nn.Linear(64, 64),
nn.ReLU(inplace=True),
nn.Linear(64, 1),
)
self.actor = nn.Sequential(
nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64),
nn.ReLU(inplace=True),
nn.Linear(64, 64),
nn.ReLU(inplace=True),
nn.Linear(64, envs.single_action_space.n),
)
def forward(self, states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
values = self.critic(states)
logits = self.actor(states)
return values, logits
def act(
self, states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
values, logits = self.forward(states)
pd = torch.distributions.Categorical(logits=logits)
actions = pd.sample()
logprobs = pd.log_prob(actions)
entropy = pd.entropy()
return actions, logprobs, entropy, values
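# The actor and critic are two separate 2x64 MLPs (no shared trunk); act() samples
# a discrete action per state and also returns its log-probability, the policy
# entropy, and the critic's value estimate.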
def main() -> None:
args = tyro.cli(HyperParams)
run_name = f"{args.env_id}_{args.exp_name.replace('/', '_')}_{args.seed}_{datetime.now().strftime('%y%m%d_%H%M%S')}"
print(run_name)
writer = SummaryWriter(f"runs/{run_name}")
keyval_str = "\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])
writer.add_text(
"hyperparameters",
f"|param|value|\n|-|-|\n{keyval_str}",
)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
envs = gym.vector.AsyncVectorEnv(
[
make_env(
args.env_id, args.seed + i, i, args.video_capture_frequency, run_name
)
            for i in range(args.n_envs)
]
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), (
"Only discrete action space is supported"
)
envs.action_space.seed(args.seed)
actor_critic = ActorCritic(envs).to(device)
optimizer = optim.AdamW(actor_critic.parameters(), lr=args.learning_rate)
start_time = time.time()
obs, _ = envs.reset(seed=args.seed)
step_pbar = tqdm(total=args.total_timesteps)
postfix_dict = dict()
for step in range(0, args.total_timesteps, envs.num_envs):
values = torch.zeros(args.num_steps, envs.num_envs, device=device)
rewards = torch.zeros(args.num_steps, envs.num_envs, device=device)
logprobs = torch.zeros(args.num_steps, envs.num_envs, device=device)
entropies = torch.zeros(args.num_steps, envs.num_envs, device=device)
masks = torch.zeros(args.num_steps, envs.num_envs, device=device)
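        # Roll out num_steps transitions in every parallel env, storing the pieces
        # needed for GAE and the policy-gradient loss in (num_steps, n_envs) buffers.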
for t in range(args.num_steps):
action, logprob, entropy, value = actor_critic.act(
torch.from_numpy(obs).to(device)
)
obs, reward, terminations, _, infos = envs.step(action.cpu().numpy())
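            # With gymnasium's vector API, per-env infos are aggregated into arrays;
            # "_episode" is the boolean mask of envs whose episode just finished.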
if "episode" in infos:
mask = infos["_episode"]
r_mean = infos["episode"]["r"][mask].mean()
l_mean = infos["episode"]["l"][mask].mean()
writer.add_scalar("charts/episodic_return", r_mean, step)
writer.add_scalar("charts/episodic_length", l_mean, step)
postfix_dict.update(
r=r_mean,
l=l_mean,
)
values[t] = value.squeeze()
rewards[t] = torch.from_numpy(reward)
logprobs[t] = logprob
entropies[t] = entropy
masks[t] = torch.from_numpy(~terminations)
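        # Generalized Advantage Estimation, computed backwards through the rollout:
        #   delta_t = r_t + gamma * V(s_{t+1}) * mask_t - V(s_t)
        #   A_t     = delta_t + gamma * lambda * mask_t * A_{t+1}
        # Note: the loop stops at num_steps - 2, so the last step's advantage stays
        # zero; this implementation does not bootstrap from the post-rollout state.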
        advantages = torch.zeros_like(rewards)
gae = 0.0
for t in reversed(range(args.num_steps - 1)):
td_error = rewards[t] + args.gamma * values[t + 1] * masks[t] - values[t]
gae = td_error + args.gamma * args.gae_lambda * masks[t] * gae
advantages[t] = gae
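        # A2C loss: MSE of the GAE estimates for the critic, policy gradient with
        # detached advantages for the actor, minus an entropy bonus to keep exploring.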
critic_loss = advantages.pow(2).mean()
actor_loss = (
-(logprobs * advantages.detach()).mean() - args.ent_coef * entropies.mean()
)
loss = actor_loss + critic_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % args.log_interval < envs.num_envs:
writer.add_scalar("losses/actor_loss", actor_loss, step)
writer.add_scalar("losses/critic_loss", critic_loss, step)
writer.add_scalar("losses/entropy", entropies.mean(), step)
writer.add_scalar("charts/SPS", step // (time.time() - start_time), step)
writer.add_scalar("losses/total_loss", loss, step)
writer.add_scalar("losses/value_estimate", values.mean().item(), step)
writer.add_scalar("losses/advantage", advantages.mean().item(), step)
postfix_dict.update(
actor_loss=actor_loss.item(),
critic_loss=critic_loss.item(),
# total_loss=loss.item(),
# entropy=entropies.mean().item(),
# value_estimate=values.mean().item(),
advantage=advantages.mean().item(),
sps=step // (time.time() - start_time),
)
step_pbar.set_postfix(postfix_dict)
step_pbar.update(envs.num_envs)
envs.close()
step_pbar.close()
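    # After training: save the weights with safetensors, then (optionally) evaluate
    # and push to the Hugging Face Hub via helpers imported from the hellrl package.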
if args.evaluate:
run_name_eval = f"{run_name}_eval"
final_model_path = f"runs/{run_name}/{args.exp_name}_final.safetensors"
save_model(model=actor_critic, filename=final_model_path)
from hellrl.evals.a2c_eval import evaluate
episode_rewards = evaluate(
final_model_path=final_model_path,
make_env=make_env,
env_id=args.env_id,
ActorCritic=ActorCritic,
run_name_eval=run_name_eval,
device=device,
eval_episodes=args.eval_episodes,
)
for i, r in enumerate(episode_rewards):
writer.add_scalar("eval/episodic_return_eval", r, i)
if args.push_model:
from hellrl.utils.huggingface import push_model
push_model(
args=args,
episode_rewards=episode_rewards,
algo_name="A2C",
run_path=f"runs/{run_name}",
video_folder_path=f"videos/{run_name_eval}",
)
writer.close()
if __name__ == "__main__":
main()
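# A minimal sketch of reloading the saved weights for inference (assumes the same
# observation/action spaces; load_model is the safetensors counterpart of save_model):
#
#   from safetensors.torch import load_model
#   eval_envs = gym.vector.SyncVectorEnv([make_env("CartPole-v1", 0, 0, 0, "eval")])
#   model = ActorCritic(eval_envs).to(device)
#   load_model(model, "runs/<run_name>/a2c_final.safetensors")
#   actions, *_ = model.act(torch.from_numpy(eval_envs.reset()[0]).to(device))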