import logging
import os

import gradio as gr
import numpy as np
import pandas as pd
import scipy.stats
from apscheduler.schedulers.background import BackgroundScheduler
from datasets import load_dataset
from huggingface_hub import HfApi

# Application logger "app": INFO and above, written to stderr (StreamHandler's default stream)
# as "<time> - <logger name> - <level> - <message>".
logger = logging.getLogger("app")
ch = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)
logger.addHandler(ch)
logger.setLevel(logging.INFO)

# absl logs a lot at INFO level; only let its warnings and errors through.
logging.getLogger("absl").setLevel(logging.WARNING)

# Hugging Face Hub client; the token is read from the TOKEN environment variable (None if unset).
API = HfApi(token=os.environ.get("TOKEN"))
# Hub dataset holding the evaluation results, loaded by get_leaderboard_df().
RESULTS_REPO = "open-rl-leaderboard/results_v2"
REFRESH_RATE = 5 * 60  # seconds between background reloads of the results (5 minutes)
# Environments shown on the leaderboard, grouped by domain. Each domain key becomes a tab and each
# env id a sub-tab of it; ids ending in "NoFrameskip-v4" are displayed without that suffix.
ALL_ENV_IDS = {
    "Atari": [
        "AdventureNoFrameskip-v4",
        "AirRaidNoFrameskip-v4",
        "AlienNoFrameskip-v4",
        "AmidarNoFrameskip-v4",
        "AssaultNoFrameskip-v4",
        "AsterixNoFrameskip-v4",
        "AsteroidsNoFrameskip-v4",
        "AtlantisNoFrameskip-v4",
        "BankHeistNoFrameskip-v4",
        "BattleZoneNoFrameskip-v4",
        "BeamRiderNoFrameskip-v4",
        "BerzerkNoFrameskip-v4",
        "BowlingNoFrameskip-v4",
        "BoxingNoFrameskip-v4",
        "BreakoutNoFrameskip-v4",
        "CarnivalNoFrameskip-v4",
        "CentipedeNoFrameskip-v4",
        "ChopperCommandNoFrameskip-v4",
        "CrazyClimberNoFrameskip-v4",
        "DefenderNoFrameskip-v4",
        "DemonAttackNoFrameskip-v4",
        "DoubleDunkNoFrameskip-v4",
        "ElevatorActionNoFrameskip-v4",
        "EnduroNoFrameskip-v4",
        "FishingDerbyNoFrameskip-v4",
        "FreewayNoFrameskip-v4",
        "FrostbiteNoFrameskip-v4",
        "GopherNoFrameskip-v4",
        "GravitarNoFrameskip-v4",
        "HeroNoFrameskip-v4",
        "IceHockeyNoFrameskip-v4",
        "JamesbondNoFrameskip-v4",
        "JourneyEscapeNoFrameskip-v4",
        "KangarooNoFrameskip-v4",
        "KrullNoFrameskip-v4",
        "KungFuMasterNoFrameskip-v4",
        "MontezumaRevengeNoFrameskip-v4",
        "MsPacmanNoFrameskip-v4",
        "NameThisGameNoFrameskip-v4",
        "PhoenixNoFrameskip-v4",
        "PitfallNoFrameskip-v4",
        "PongNoFrameskip-v4",
        "PooyanNoFrameskip-v4",
        "PrivateEyeNoFrameskip-v4",
        "QbertNoFrameskip-v4",
        "RiverraidNoFrameskip-v4",
        "RoadRunnerNoFrameskip-v4",
        "RobotankNoFrameskip-v4",
        "SeaquestNoFrameskip-v4",
        "SkiingNoFrameskip-v4",
        "SolarisNoFrameskip-v4",
        "SpaceInvadersNoFrameskip-v4",
        "StarGunnerNoFrameskip-v4",
        "TennisNoFrameskip-v4",
        "TimePilotNoFrameskip-v4",
        "TutankhamNoFrameskip-v4",
        "UpNDownNoFrameskip-v4",
        "VentureNoFrameskip-v4",
        "VideoPinballNoFrameskip-v4",
        "WizardOfWorNoFrameskip-v4",
        "YarsRevengeNoFrameskip-v4",
        "ZaxxonNoFrameskip-v4",
    ],
    "Box2D": [
        "BipedalWalker-v3",
        "BipedalWalkerHardcore-v3",
        "CarRacing-v2",
        "LunarLander-v2",
        "LunarLanderContinuous-v2",
    ],
    "Toy text": [
        "Blackjack-v1",
        "CliffWalking-v0",
        "FrozenLake-v1",
        "FrozenLake8x8-v1",
    ],
    "Classic control": [
        "Acrobot-v1",
        "CartPole-v1",
        "MountainCar-v0",
        "MountainCarContinuous-v0",
        "Pendulum-v1",
    ],
    "MuJoCo": [
        "Ant-v4",
        "HalfCheetah-v4",
        "Hopper-v4",
        "Humanoid-v4",
        "HumanoidStandup-v4",
        "InvertedDoublePendulum-v4",
        "InvertedPendulum-v4",
        "Pusher-v4",
        "Reacher-v4",
        "Swimmer-v4",
        "Walker2d-v4",
    ],
    "PyBullet": [
        "AntBulletEnv-v0",
        "HalfCheetahBulletEnv-v0",
        "HopperBulletEnv-v0",
        "HumanoidBulletEnv-v0",
        "InvertedDoublePendulumBulletEnv-v0",
        "InvertedPendulumSwingupBulletEnv-v0",
        "MinitaurBulletEnv-v0",
        "ReacherBulletEnv-v0",
        "Walker2DBulletEnv-v0",
    ],
}


def iqm(x):
    """Return the interquartile mean of ``x``.

    The lowest and highest 25% of the values are discarded and the rest averaged.
    ``axis=None`` flattens ``x`` first, so nested sequences or multi-dimensional arrays
    are pooled into a single sample.
    """
    quarter = 0.25
    return scipy.stats.trim_mean(x, quarter, axis=None)


def get_leaderboard_df():
    """Download the evaluation results and score every finished model.

    Returns:
        pd.DataFrame: the rows of ``RESULTS_REPO`` whose ``status`` is ``"DONE"``, with an added
        ``iqm_episodic_return`` column holding the interquartile mean of ``episodic_returns``.
    """
    logger.info("Downloading results")
    dataset = load_dataset(RESULTS_REPO, split="train")  # split is not important, but we need to use "train"
    df = dataset.to_pandas()  # convert to pandas dataframe
    # Keep only the models whose evaluation is done. ``.copy()`` detaches the filtered frame from
    # the full one, so adding a column below is not chained assignment on a slice (which makes
    # pandas emit SettingWithCopyWarning).
    df = df[df["status"] == "DONE"].copy()
    df["iqm_episodic_return"] = df["episodic_returns"].apply(iqm)
    logger.debug("Results downloaded")
    return df


def select_env(df: pd.DataFrame, env_id: str):
    """Return the rows of ``df`` for ``env_id``, best first, with a 1-based ``ranking`` column.

    Rows are sorted by ``iqm_episodic_return`` in descending order; the input frame is left
    unchanged.
    """
    ranked = df.loc[df["env_id"] == env_id].sort_values("iqm_episodic_return", ascending=False)
    return ranked.assign(ranking=np.arange(1, len(ranked) + 1))


def format_df(df: pd.DataFrame):
    """Turn a ranked per-environment frame into the rows of the Gradio leaderboard table.

    ``user_id`` and ``model_id`` are rendered as Markdown links to the user's and the model's
    Hugging Face pages. Only ``ranking``, ``user_id``, ``model_id`` and ``iqm_episodic_return``
    are kept, in that order. The input frame is not modified.

    Returns:
        list[list]: one ``[ranking, user_link, model_link, iqm_episodic_return]`` list per row.
    """
    # Build both link columns in one pass over the raw ids. Writing them row by row with
    # ``df.loc[index, ...]`` is slow, and with a non-unique index every row sharing a label
    # would be overwritten by the last one's link.
    user_ids = df["user_id"].tolist()
    model_ids = df["model_id"].tolist()
    linked = df.assign(
        user_id=[f"[{user}](https://huggingface.co/{user})" for user in user_ids],
        model_id=[f"[{model}](https://huggingface.co/{user}/{model})" for user, model in zip(user_ids, model_ids)],
    )

    # Keep only the columns displayed in the table, in header order.
    return linked[["ranking", "user_id", "model_id", "iqm_episodic_return"]].values.tolist()


def refresh_video(df, env_id):
    """Fetch the replay video of the top-ranked model for ``env_id``.

    Downloads ``replay.mp4`` from the best model's repo at the revision recorded in its
    ``sha`` column.

    Returns:
        The local path returned by ``API.hf_hub_download``, or ``None`` when the environment has
        no models or the download raises (the error is logged).
    """
    ranked = select_env(df, env_id)
    if ranked.empty:
        return None

    best = ranked.iloc[0]
    repo_id = f"{best['user_id']}/{best['model_id']}"
    try:
        return API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=best["sha"], repo_type="model")
    except Exception as e:
        # Best effort: a missing or unreachable video must not break the leaderboard.
        logger.error(f"Error while downloading video for {env_id}: {e}")
        return None


def refresh_one_video(df, env_id):
    """Build a zero-argument Gradio callback returning the replay video path for ``env_id``.

    Args:
        df: Leaderboard frame to pick the best model from. If this is the module-level ``df``,
            the callback follows that global instead of the object passed in.
        env_id: Environment whose best model's video is returned.

    Returns:
        A callable taking no arguments that returns ``refresh_video(<frame>, env_id)``.
    """
    # The callback is created once, when the UI is built, but update_globals() periodically
    # rebinds the module-level ``df``. Capturing the object would keep serving the snapshot from
    # build time, so the video would disagree with the refreshed tables and winner text.
    # When the caller passed that global, look it up again on every call.
    follow_global = df is globals().get("df")

    def inner():
        current_df = globals()["df"] if follow_global else df
        return refresh_video(current_df, env_id)

    return inner


def refresh_winner(df, env_id):
    """Return the Markdown shown above the video for ``env_id``.

    The text starts with an ``## <env_id>`` heading. It is followed by a link to the top-ranked
    model's Hugging Face page or, when no model has been evaluated yet, by an invitation to submit one.
    """
    ranked = select_env(df, env_id)
    if ranked.empty:
        return f"""## {env_id}

This leaderboard is quite empty... 😢

Be the first to submit your model!
Check the tab "🚀 Getting my agent evaluated"
"""

    top = ranked.iloc[0]
    url = f"https://huggingface.co/{top['user_id']}/{top['model_id']}"
    return f"""## {env_id}

### 🏆 [Best model]({url}) 🏆"""


def refresh_num_models(df):
    """Return the sentence reporting how many rows ``df`` holds, with thousands separators."""
    count = len(df)
    return f"The leaderboard currently contains {count:,} models."


# Page-wide CSS passed to gr.Blocks. It removes the border on elements with the "generating"
# class (presumably Gradio's marker for components whose value is being updated — confirm).
# It also centers the h2/h3 headings, e.g. the "## <env_id>" and "### 🏆 ... 🏆" lines built by
# refresh_winner.
css = """
.generating {
    border: none;
}
h2 {
    text-align: center;
}
h3 {
    text-align: center;
}

"""


def update_globals():
    """Reload the results and rebuild every cached per-environment view.

    Rebinds the module-level ``df``, ``dataframes``, ``winner_texts``, ``video_pathes`` and
    ``num_models_str``. Everything is computed into locals first and the globals are rebound
    together at the end:

    * if any step raises, the previously published state stays intact instead of being
      half-replaced (e.g. a new ``df`` next to old tables);
    * concurrent readers such as ``refresh()`` do not see new tables and winner texts next to a
      stale model count while the video downloads are still running.
    """
    global dataframes, winner_texts, video_pathes, num_models_str, df
    new_df = get_leaderboard_df()
    all_env_ids = [env_id for env_ids in ALL_ENV_IDS.values() for env_id in env_ids]
    new_dataframes = {env_id: format_df(select_env(new_df, env_id)) for env_id in all_env_ids}
    new_winner_texts = {env_id: refresh_winner(new_df, env_id) for env_id in all_env_ids}
    new_video_pathes = {env_id: refresh_video(new_df, env_id) for env_id in all_env_ids}
    new_num_models_str = refresh_num_models(new_df)

    # Publish everything back-to-back only once all of it has been computed successfully.
    df, dataframes, winner_texts, video_pathes, num_models_str = (
        new_df,
        new_dataframes,
        new_winner_texts,
        new_video_pathes,
        new_num_models_str,
    )


# Populate the globals once at import time; the UI below reads them.
update_globals()


def refresh():
    """Collect the cached leaderboard outputs for the page-load callback.

    Returns the per-environment tables, then the per-environment winner texts (both in the order
    update_globals built them, i.e. ALL_ENV_IDS order), then the model-count sentence.
    """
    # Only reads the module-level caches, so no ``global`` declaration is needed.
    outputs = [*dataframes.values(), *winner_texts.values()]
    outputs.append(num_models_str)
    return outputs


# Build the UI: one tab per domain, one sub-tab per environment holding the leaderboard table,
# the winner text and the winner's replay video. Their values are filled in by the load/select
# callbacks registered below.
with gr.Blocks(css=css) as demo:
    # Read the markdown texts as UTF-8 explicitly so loading them does not depend on the locale.
    with open("texts/heading.md", encoding="utf-8") as fp:
        gr.Markdown(fp.read())
    num_models_md = gr.Markdown()  # filled in by refresh() on page load
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard"):
            all_gr_dfs = {}
            all_gr_winners = {}
            all_gr_videos = {}
            for env_domain, env_ids in ALL_ENV_IDS.items():
                with gr.TabItem(env_domain):
                    for env_id in env_ids:
                        # If the env_id ends with "NoFrameskip-v4", strip it from the tab label to improve readability
                        tab_env_id = env_id[: -len("NoFrameskip-v4")] if env_id.endswith("NoFrameskip-v4") else env_id
                        with gr.TabItem(tab_env_id) as tab:
                            logger.debug("Creating tab for %s", env_id)
                            with gr.Row(equal_height=False):
                                with gr.Column(scale=3):
                                    gr_df = gr.components.Dataframe(
                                        headers=["🏆", "🧑 User", "🤖 Model id", "📊 IQM episodic return"],
                                        datatype=["number", "markdown", "markdown", "number"],
                                    )
                                with gr.Column(scale=1):
                                    with gr.Row():  # Display the env_id and the winner
                                        gr_winner = gr.Markdown()
                                    with gr.Row():  # Play the video of the best model
                                        # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689
                                        gr_video = gr.PlayableVideo(
                                            min_width=50,
                                            show_download_button=False,
                                            show_share_button=False,
                                            show_label=False,
                                            interactive=False,
                                        )

                            all_gr_dfs[env_id] = gr_df
                            all_gr_winners[env_id] = gr_winner
                            all_gr_videos[env_id] = gr_video

                            # Videos are fetched when their sub-tab is selected.
                            tab.select(refresh_one_video(df, env_id), outputs=[gr_video])

                # For every domain, also load the video of its first environment on page load
                # (that sub-tab is the one shown when the domain tab is opened).
                demo.load(refresh_one_video(df, env_ids[0]), outputs=[all_gr_videos[env_ids[0]]])

        with gr.TabItem("🚀 Getting my agent evaluated"):
            with open("texts/getting_my_agent_evaluated.md", encoding="utf-8") as fp:
                gr.Markdown(fp.read())

        with gr.TabItem("📝 About"):
            with open("texts/about.md", encoding="utf-8") as fp:
                gr.Markdown(fp.read())

    # On page load, fill every table, every winner text and the model count from the caches;
    # the output order matches the list returned by refresh().
    demo.load(refresh, outputs=list(all_gr_dfs.values()) + list(all_gr_winners.values()) + [num_models_md])

# Re-run update_globals every REFRESH_RATE seconds in a background thread. max_instances=1 keeps
# a slow refresh from overlapping with the next one.
scheduler = BackgroundScheduler()
scheduler.add_job(update_globals, "interval", seconds=REFRESH_RATE, max_instances=1)
scheduler.start()


if __name__ == "__main__":
    # queue() returns the same Blocks with request queuing enabled; then serve it.
    app = demo.queue()
    app.launch()