"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import hashlib
import os
from pathlib import Path

import ase
import backoff
import gradio as gr
import huggingface_hub as hf_hub
import requests
from ase.calculators.calculator import Calculator
from ase.db.core import now
from ase.db.row import AtomsRow
from ase.io.jsonio import decode, encode
from requests.exceptions import HTTPError


def hash_save_file(atoms: ase.Atoms, task_name, path: Path | str):
    """Write a tagged copy of *atoms* into *path* as ``<md5>.traj``.

    The caller's object is never modified: a copy is made and
    ``info["task_name"]`` is set on the copy only. The file name is the MD5 hex
    digest of ``atoms_to_json(copy)``. That serialization embeds the current
    time (``ctime``/``mtime``), so identical structures saved at different
    moments generally get different names.

    Args:
        atoms: structure to archive.
        task_name: value stored in the copy's ``info["task_name"]``.
        path: existing directory to write into.

    Errors raised while writing (e.g. a missing directory) propagate to the
    caller.
    """
    target_dir = Path(path)

    tagged = atoms.copy()
    tagged.info["task_name"] = task_name

    serialized = atoms_to_json(tagged)
    digest = hashlib.md5(serialized.encode("utf-8")).hexdigest()
    tagged.write(target_dir / f"{digest}.traj")


def validate_uma_access(oauth_token):
    """Report whether *oauth_token* may access the gated ``facebook/UMA`` repo.

    Args:
        oauth_token: object whose ``.token`` attribute is the user's HF token.
            An object without ``.token`` (presumably ``None`` when the user is
            not logged in) is treated as "no access".

    Returns:
        True if ``HfApi.auth_check`` succeeds, False if it raises an
        ``HfHubHTTPError`` or the token cannot be read.
    """
    try:
        api = hf_hub.HfApi()
        api.auth_check(repo_id="facebook/UMA", token=oauth_token.token)
    except (hf_hub.errors.HfHubHTTPError, AttributeError):
        return False
    else:
        return True


class HFEndpointCalculator(Calculator):
    """ASE calculator that evaluates structures on a Hugging Face Inference Endpoint.

    Each call to :meth:`calculate` does the following:

    * serializes the atoms with :func:`atoms_to_json`;
    * POSTs them, with the requested properties, the system changes and the
      lower-cased task name, to ``endpoint_url``, using the ``HF_TOKEN``
      environment variable as bearer token;
    * stores the decoded ``"results"`` on the calculator and ``"info"`` on the
      atoms object.
    """

    implemented_properties = ["energy", "free_energy", "stress", "forces"]

    def __init__(
        self,
        atoms,
        endpoint_url,
        oauth_token,
        task_name,
        example=False,
        *args,
        **kwargs,
    ):
        """Create the calculator, gating user-supplied structures on UMA access.

        Args:
            atoms: structure being set up. For non-example runs it is archived
                (best effort) under ``/data/custom_inputs/``.
            endpoint_url: URL of the inference endpoint to POST to.
            oauth_token: the user's OAuth token object. Its ``.token`` is checked
                against the gated ``facebook/UMA`` repo.
            task_name: task identifier sent (lower-cased) with every request.
            example: True for bundled example structures, which skip both the
                access check and archiving.
            *args, **kwargs: forwarded to ``Calculator.__init__``.

        Raises:
            gr.Error: if ``example`` is False and the user lacks UMA access.
            KeyError: if the ``HF_TOKEN`` environment variable is not set.
        """
        # Example structures need no authentication. Anything else requires a
        # logged-in user with gated access to the UMA models.
        if not example:
            if not validate_uma_access(oauth_token):
                raise gr.Error(
                    "You need to log in to HF and have gated model access to UMA before running your own simulations!"
                )
            try:
                hash_save_file(atoms, task_name, "/data/custom_inputs/")
            except FileNotFoundError:
                # Archiving is best effort; presumably the directory only
                # exists in the hosted deployment.
                pass

        self.url = endpoint_url
        # Server-side credential for the endpoint, distinct from the user's
        # OAuth token that was checked above.
        self.token = os.environ["HF_TOKEN"]
        # NOTE(review): ase's Calculator.__init__ appears to reset self.atoms
        # to None, so this assignment likely has no lasting effect. Confirm
        # before relying on it.
        self.atoms = atoms
        self.task_name = task_name

        super().__init__(*args, **kwargs)

    @staticmethod
    @backoff.on_exception(
        backoff.expo,
        (requests.exceptions.RequestException,),
        max_tries=10,
        jitter=backoff.full_jitter,
    )
    def _post_with_backoff(url, headers, payload):
        """POST *payload* as JSON to *url* and return the successful response.

        Any ``RequestException`` is retried with jittered exponential backoff,
        up to 10 attempts in total. This includes the ``HTTPError`` raised by
        ``raise_for_status`` for 4xx/5xx replies. After the last attempt the
        exception propagates.
        """
        # NOTE(review): no timeout is passed, so a stalled connection can block
        # indefinitely. Consider adding one sized for endpoint cold starts.
        response = requests.post(url, headers=headers, json=payload)
        response.raise_for_status()
        return response

    def calculate(self, atoms, properties, system_changes):
        """Run a remote calculation and store its results.

        Args:
            atoms: structure to evaluate. Its ``info`` dict is sent as the
                row's ``data`` and is *replaced* by the endpoint's ``"info"``.
            properties: property names requested by ASE.
            system_changes: changes since the last calculation, forwarded as-is.

        Raises:
            gr.Error: if the endpoint still returns an HTTP error after the
                retries. Before raising, the failing input is archived (best
                effort) under ``/data/custom_inputs/errors/``.
        """
        Calculator.calculate(self, atoms, properties, system_changes)

        task_name = self.task_name.lower()

        payload = {
            "inputs": atoms_to_json(atoms, data=atoms.info),
            "properties": properties,
            "system_changes": system_changes,
            "task_name": task_name,
        }

        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.token}",
            "Content-Type": "application/json",
        }

        try:
            response = self._post_with_backoff(self.url, headers, payload)
            response_dict = response.json()
        except HTTPError as error:
            # Keep the failing input for debugging. A failure to archive it
            # (missing or unwritable directory) must not replace the
            # user-facing backend error raised below.
            try:
                hash_save_file(atoms, task_name, "/data/custom_inputs/errors/")
            except OSError:
                pass
            raise gr.Error(
                f"Backend failure during your calculation; if you have continued issues please file an issue in the main FAIR chemistry repo (https://github.com/facebookresearch/fairchem).\n{error}"
            ) from error

        # Presumably response.json() yields the ASE-JSON text emitted by the
        # endpoint; decode() parses it back into Python/NumPy objects.
        response_dict = decode(response_dict)
        self.results = response_dict["results"]
        atoms.info = response_dict["info"]


def atoms_to_json(atoms, data=None):
    """Serialize *atoms* to an ASE-JSON string shaped like an ``ase.db`` row.

    The layout mirrors what ``ase.db.jsondb`` writes, in this order:

    * every public attribute of an ``AtomsRow`` built from *atoms*, except
      key-value-pair keys and ``id``;
    * ``mtime``;
    * ``data``;
    * ``constraints``, only when present.

    ``ctime`` and ``mtime`` are both stamped with the current time. Two calls
    on identical atoms therefore generally produce different strings.

    Args:
        atoms: the structure to serialize.
        data: mapping stored under ``"data"``. ``None`` or an empty mapping is
            stored as ``{}``.

    Returns:
        The text produced by ``ase.io.jsonio.encode``.
    """
    timestamp = now()

    row = AtomsRow(atoms)
    row.ctime = timestamp

    # Insertion order is preserved on purpose: it fixes the encoded text,
    # which callers hash.
    dct = {
        key: row[key]
        for key in row.__dict__
        if key[0] != "_" and key not in row._keys and key != "id"
    }

    dct["mtime"] = timestamp
    dct["data"] = data or {}

    if constraints := row.get("constraints"):
        dct["constraints"] = constraints

    return encode(dct)