A2C playing BreakoutNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c
import os

# Fall back to CPU for any PyTorch ops not yet implemented on Apple's MPS backend
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import argparse
import requests
import shutil
import subprocess
import tempfile

import wandb
import wandb.apis.public
from typing import List, Optional

from huggingface_hub.hf_api import HfApi, upload_folder
from huggingface_hub.repocard import metadata_save
from pyvirtualdisplay.display import Display

from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
from rl_algo_impls.runner.config import EnvHyperparams
from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
from rl_algo_impls.shared.vec_env import make_eval_env
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder


def publish(
    wandb_run_paths: List[str],
    wandb_report_url: str,
    huggingface_user: Optional[str] = None,
    huggingface_token: Optional[str] = None,
    virtual_display: bool = False,
) -> None:
    if virtual_display:
        display = Display(visible=False, size=(1400, 900))
        display.start()

    api = wandb.Api()
    runs = [api.run(rp) for rp in wandb_run_paths]
    algo = runs[0].config["algo"]
    hyperparam_id = runs[0].config["env"]
    # Evaluate the best checkpoint of each run for 10 episodes
    evaluations = [
        evaluate_model(
            EvalArgs(
                algo,
                hyperparam_id,
                seed=r.config.get("seed", None),
                render=False,
                best=True,
                n_envs=None,
                n_episodes=10,
                no_print_returns=True,
                wandb_run_path="/".join(r.path),
            ),
            os.getcwd(),
        )
        for r in runs
    ]
    run_metadata = requests.get(runs[0].file("wandb-metadata.json").url).json()
    table_data = list(EvalTableData(r, e) for r, e in zip(runs, evaluations))
    # The run with the highest evaluation score is the one whose model gets published
    best_eval = sorted(
        table_data, key=lambda d: d.evaluation.stats.score, reverse=True
    )[0]

    with tempfile.TemporaryDirectory() as tmpdirname:
        _, (policy, stats, config) = best_eval

        repo_name = config.model_name(include_seed=False)
        repo_dir_path = os.path.join(tmpdirname, repo_name)
        # Locally clone this repo to a temp directory
        subprocess.run(["git", "clone", ".", repo_dir_path])
        shutil.rmtree(os.path.join(repo_dir_path, ".git"))
        model_path = config.model_dir_path(best=True, downloaded=True)
        shutil.copytree(
            model_path,
            os.path.join(
                repo_dir_path, "saved_models", config.model_dir_name(best=True)
            ),
        )

        github_url = "https://github.com/sgoodfriend/rl-algo-impls"
        commit_hash = run_metadata.get("git", {}).get("commit", None)
        env_id = runs[0].config.get("env_id") or runs[0].config["env"]
        # Replace the cloned README with a generated model card
        card_text = model_card_text(
            algo,
            env_id,
            github_url,
            commit_hash,
            wandb_report_url,
            table_data,
            best_eval,
        )
        readme_filepath = os.path.join(repo_dir_path, "README.md")
        os.remove(readme_filepath)
        with open(readme_filepath, "w") as f:
            f.write(card_text)

        # Model card metadata read by the Hugging Face Hub (tags and model-index)
        metadata = {
            "library_name": "rl-algo-impls",
            "tags": [
                env_id,
                algo,
                "deep-reinforcement-learning",
                "reinforcement-learning",
            ],
            "model-index": [
                {
                    "name": algo,
                    "results": [
                        {
                            "metrics": [
                                {
                                    "type": "mean_reward",
                                    "value": str(stats.score),
                                    "name": "mean_reward",
                                }
                            ],
                            "task": {
                                "type": "reinforcement-learning",
                                "name": "reinforcement-learning",
                            },
                            "dataset": {
                                "name": env_id,
                                "type": env_id,
                            },
                        }
                    ],
                }
            ],
        }
        metadata_save(readme_filepath, metadata)

        # Record a single replay episode (capped at 3600 steps) for the model card
        video_env = VecEpisodeRecorder(
            make_eval_env(
                config,
                EnvHyperparams(**config.env_hyperparams),
                override_n_envs=1,
                normalize_load_path=model_path,
            ),
            os.path.join(repo_dir_path, "replay"),
            max_video_length=3600,
        )
        evaluate(
            video_env,
            policy,
            1,
            deterministic=config.eval_params.get("deterministic", True),
        )

        # Create (or reuse) the Hub repo and upload the prepared folder
        api = HfApi()
        huggingface_user = huggingface_user or api.whoami()["name"]
        huggingface_repo = f"{huggingface_user}/{repo_name}"
        api.create_repo(
            token=huggingface_token,
            repo_id=huggingface_repo,
            private=False,
            exist_ok=True,
        )
        repo_url = upload_folder(
            repo_id=huggingface_repo,
            folder_path=repo_dir_path,
            path_in_repo="",
            commit_message=f"{algo.upper()} playing {env_id} from {github_url}/tree/{commit_hash}",
            token=huggingface_token,
            delete_patterns="*",
        )
        print(f"Pushed model to the hub: {repo_url}")


def huggingface_publish():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wandb-run-paths",
        type=str,
        nargs="+",
        help="Run paths of the form entity/project/run_id",
    )
    parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
    parser.add_argument(
        "--huggingface-user",
        type=str,
        help="Huggingface user or team to upload model cards",
        default=None,
    )
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    args = parser.parse_args()
    print(args)
    publish(**vars(args))


if __name__ == "__main__":
    huggingface_publish()
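
For reference, a minimal sketch of invoking the publisher programmatically rather than through the CLI above. The W&B run path, report URL, and the module path rl_algo_impls.huggingface_publish are assumptions for illustration; substitute your own entity/project/run_id and report link, and have a Hugging Face token available (e.g. via huggingface-cli login) so HfApi can authenticate.

# Hypothetical usage sketch: run path, report URL, and module path are placeholders.
# CLI equivalent (assumed module path):
#   python -m rl_algo_impls.huggingface_publish \
#       --wandb-run-paths sgoodfriend/rl-algo-impls/abc123 \
#       --wandb-report-url https://api.wandb.ai/links/sgoodfriend/example \
#       --virtual-display
from rl_algo_impls.huggingface_publish import publish

publish(
    wandb_run_paths=["sgoodfriend/rl-algo-impls/abc123"],  # entity/project/run_id (placeholder)
    wandb_report_url="https://api.wandb.ai/links/sgoodfriend/example",  # placeholder report link
    huggingface_user=None,  # None publishes under the token owner's account
    virtual_display=True,  # useful on headless machines so rendering for the replay video works
)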