import os
import random
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Callable

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from safetensors.torch import save_model
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

warnings.filterwarnings("ignore", category=UserWarning)

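# Pick the fastest available backend: CUDA GPU, then Apple MPS, then CPU.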
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)


@dataclass
class HyperParams:
    env_id: str = "CartPole-v1"
    """The ID of the environment to train on."""
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """The name of the experiment, used for saving models and logs."""
    n_envs: int = 32
    """The number of parallel environments to run."""
    seed: int = 1
    """The random seed for reproducibility."""
    video_capture_frequency: int = 50
    """The interval (in episodes) to record videos of the agent's performance."""

    total_timesteps: int = 150_000
    """The total number of timesteps to train the agent."""
    num_steps: int = 20
    """The number of steps to run for each environment per update."""
    gamma: float = 0.99
    """The discount factor (gamma) for future rewards."""
    gae_lambda: float = 0.95
    """The lambda for the Generalized Advantage Estimator (GAE)."""
    ent_coef: float = 0.01
    """The coefficient for the entropy bonus."""
    learning_rate: float = 3e-4
    """The learning rate for the optimizer."""

    log_interval: int = 100
    """The interval (in timesteps) to log training statistics."""

    evaluate: bool = True
    """Whether to evaluate the agent after training."""
    eval_episodes: int = 10
    """The number of episodes to run for evaluation."""

    push_model: bool = True
    """Whether to upload the saved model to the Hugging Face Hub."""
    hf_entity: str = "alperenunlu"
    """The user or org name of the model repository on the Hugging Face Hub."""


def make_env(
    env_id: str, seed: int, idx: int, video_capture_frequency: int, run_name: str
) -> Callable[[], gym.Env]:
    """Create a gym environment with specific configurations.

    Args:
        env_id (str): The ID of the environment to create.
        seed (int): The seed for random number generation.
        idx (int): The index of the environment (for vectorized environments).
        video_capture_frequency (int): Frequency (in episodes) of recording videos (0 to disable).
        run_name (str): The name of the run for saving videos.

    Returns:
        Callable[[], gym.Env]: A function that returns the configured environment.
    """

    def _thunk() -> gym.Env:
        # Only the first sub-environment renders frames and records video.
        if video_capture_frequency > 0 and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(
                env,
                video_folder=f"videos/{run_name}",
                episode_trigger=lambda x: x % video_capture_frequency == 0,
                name_prefix=env_id,
            )
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return _thunk


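# Separate actor and critic MLPs (no shared trunk): the critic maps a state to a
# scalar value, the actor maps a state to logits for a categorical policy.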
class ActorCritic(nn.Module):
    def __init__(self, envs):
        super().__init__()
        self.critic = nn.Sequential(
            nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 1),
        )

        self.actor = nn.Sequential(
            nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, envs.single_action_space.n),
        )

    def forward(self, states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        values = self.critic(states)
        logits = self.actor(states)
        return values, logits

    def act(
        self, states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        values, logits = self.forward(states)
        pd = torch.distributions.Categorical(logits=logits)
        actions = pd.sample()
        logprobs = pd.log_prob(actions)
        entropy = pd.entropy()
        return actions, logprobs, entropy, values


def main() -> None:
    args = tyro.cli(HyperParams)
    run_name = f"{args.env_id}_{args.exp_name.replace('/', '_')}_{args.seed}_{datetime.now().strftime('%y%m%d_%H%M%S')}"
    print(run_name)
    writer = SummaryWriter(f"runs/{run_name}")
    keyval_str = "\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])
    writer.add_text(
        "hyperparameters",
        f"|param|value|\n|-|-|\n{keyval_str}",
    )

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    envs = gym.vector.AsyncVectorEnv(
        [
            make_env(
                args.env_id, args.seed + i, i, args.video_capture_frequency, run_name
            )
            for i in range(int(args.n_envs))
        ]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), (
        "Only discrete action space is supported"
    )
    envs.action_space.seed(args.seed)

    actor_critic = ActorCritic(envs).to(device)
    optimizer = optim.AdamW(actor_critic.parameters(), lr=args.learning_rate)

    start_time = time.time()

    obs, _ = envs.reset(seed=args.seed)
    step_pbar = tqdm(total=args.total_timesteps)
    postfix_dict = dict()
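    # Each outer iteration gathers a fresh on-policy rollout of num_steps
    # transitions from every parallel environment, then performs one A2C update.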
    for step in range(0, args.total_timesteps, envs.num_envs):
        values = torch.zeros(args.num_steps, envs.num_envs, device=device)
        rewards = torch.zeros(args.num_steps, envs.num_envs, device=device)
        logprobs = torch.zeros(args.num_steps, envs.num_envs, device=device)
        entropies = torch.zeros(args.num_steps, envs.num_envs, device=device)
        masks = torch.zeros(args.num_steps, envs.num_envs, device=device)

        for t in range(args.num_steps):
            action, logprob, entropy, value = actor_critic.act(
                torch.from_numpy(obs).to(device)
            )

            obs, reward, terminations, _, infos = envs.step(action.cpu().numpy())
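            # RecordEpisodeStatistics surfaces finished episodes through `infos`;
            # `infos["_episode"]` marks which sub-environments just finished.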
if "episode" in infos: |
|
|
mask = infos["_episode"] |
|
|
r_mean = infos["episode"]["r"][mask].mean() |
|
|
l_mean = infos["episode"]["l"][mask].mean() |
|
|
|
|
|
writer.add_scalar("charts/episodic_return", r_mean, step) |
|
|
writer.add_scalar("charts/episodic_length", l_mean, step) |
|
|
postfix_dict.update( |
|
|
r=r_mean, |
|
|
l=l_mean, |
|
|
) |
|
|
|
|
|
values[t] = value.squeeze() |
|
|
rewards[t] = torch.from_numpy(reward) |
|
|
logprobs[t] = logprob |
|
|
entropies[t] = entropy |
|
|
masks[t] = torch.from_numpy(~terminations) |
|
|
|
|
|
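        # Generalized Advantage Estimation, swept backwards over the rollout:
        #   delta_t = r_t + gamma * V(s_{t+1}) * mask_t - V(s_t)
        #   A_t     = delta_t + gamma * lambda * mask_t * A_{t+1}
        # The last rollout step keeps a zero advantage because no bootstrap value
        # beyond the rollout is stored.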
        advantages = torch.zeros_like(rewards).to(device)
        gae = 0.0
        for t in reversed(range(args.num_steps - 1)):
            td_error = rewards[t] + args.gamma * values[t + 1] * masks[t] - values[t]
            gae = td_error + args.gamma * args.gae_lambda * masks[t] * gae
            advantages[t] = gae

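        # A2C loss: the squared advantage doubles as the critic loss (values are
        # not detached inside the GAE computation), while the actor term uses
        # detached advantages plus an entropy bonus weighted by ent_coef.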
        critic_loss = advantages.pow(2).mean()
        actor_loss = (
            -(logprobs * advantages.detach()).mean() - args.ent_coef * entropies.mean()
        )

        loss = actor_loss + critic_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

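        # `step` advances by num_envs per iteration, so test a window rather than
        # exact equality when deciding whether to log.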
        if step % args.log_interval < envs.num_envs:
            writer.add_scalar("losses/actor_loss", actor_loss, step)
            writer.add_scalar("losses/critic_loss", critic_loss, step)
            writer.add_scalar("losses/entropy", entropies.mean(), step)
            writer.add_scalar("charts/SPS", step // (time.time() - start_time), step)
            writer.add_scalar("losses/total_loss", loss, step)
            writer.add_scalar("losses/value_estimate", values.mean().item(), step)
            writer.add_scalar("losses/advantage", advantages.mean().item(), step)
            postfix_dict.update(
                actor_loss=actor_loss.item(),
                critic_loss=critic_loss.item(),
                advantage=advantages.mean().item(),
                sps=step // (time.time() - start_time),
            )

        step_pbar.set_postfix(postfix_dict)
        step_pbar.update(envs.num_envs)
    envs.close()
    step_pbar.close()

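    # After training: save the final weights, run evaluation episodes, and
    # optionally push the model, logs, and videos to the Hugging Face Hub.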
    if args.evaluate:
        run_name_eval = f"{run_name}_eval"
        final_model_path = f"runs/{run_name}/{args.exp_name}_final.safetensors"
        save_model(model=actor_critic, filename=final_model_path)
        from hellrl.evals.a2c_eval import evaluate

        episode_rewards = evaluate(
            final_model_path=final_model_path,
            make_env=make_env,
            env_id=args.env_id,
            ActorCritic=ActorCritic,
            run_name_eval=run_name_eval,
            device=device,
            eval_episodes=args.eval_episodes,
        )
        for i, r in enumerate(episode_rewards):
            writer.add_scalar("eval/episodic_return_eval", r, i)

        if args.push_model:
            from hellrl.utils.huggingface import push_model

            push_model(
                args=args,
                episode_rewards=episode_rewards,
                algo_name="A2C",
                run_path=f"runs/{run_name}",
                video_folder_path=f"videos/{run_name_eval}",
            )

    writer.close()


if __name__ == "__main__":
    main()