from datetime import datetime, timedelta
import json
import requests
import streamlit as st
from any_agent import AgentFramework
from any_agent.tracing.trace import _is_tracing_supported
from any_agent.evaluation import EvaluationCase
from any_agent.evaluation.schemas import CheckpointCriteria
import pandas as pd
from constants import DEFAULT_EVALUATION_CASE, MODEL_OPTIONS
import copy

from pydantic import BaseModel, ConfigDict


class UserInputs(BaseModel):
    """Validated bundle of everything the user chose in the input form.

    Instances are built by ``get_user_inputs`` from the Streamlit widget values.
    """

    # Reject any keyword argument that is not a declared field.
    model_config = ConfigDict(extra="forbid")
    # Model identifier picked from MODEL_OPTIONS for running the agent.
    model_id: str
    # Free-text location typed by the user (may be empty if the text box was cleared).
    location: str
    # Upper bound on driving time, in hours, as entered in the number input.
    max_driving_hours: int
    # Selected date combined with the selected hour via datetime.combine (naive datetime).
    date: datetime
    # NOTE(review): get_user_inputs passes an AgentFramework member here, not a plain
    # str; this relies on AgentFramework being a str-based enum — confirm.
    framework: str
    # Evaluation case (LLM judge model + checkpoint criteria) assembled in the form.
    evaluation_case: EvaluationCase
    # Whether the user ticked the "Run Evaluation" checkbox.
    run_evaluation: bool


@st.cache_data
def get_area(area_name: str) -> list[dict]:
    """Look up an area by free-text name using the Nominatim search API.

    Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).
    Results are memoized per ``area_name`` with ``st.cache_data`` so Streamlit
    reruns do not hit the service again for the same query. ``cache_data`` hands
    each caller its own copy of the result, unlike ``cache_resource``, which
    shares one mutable object.

    Args:
        area_name (str): The free-text name of the area to search for.

    Returns:
        list[dict]: The decoded JSON search results (presumably an empty list
        when nothing matches — callers treat a falsy result as "not found").

    Raises:
        requests.HTTPError: If Nominatim responds with an error status.
        requests.RequestException: On connection failures or the 5 s timeout.
    """
    response = requests.get(
        "https://nominatim.openstreetmap.org/search",
        # Let requests URL-encode the query. Interpolating it into the URL breaks
        # on names containing '&', '#', '?', '+', etc.
        params={"q": area_name, "format": "json"},
        # NOTE(review): Nominatim's usage policy asks for a User-Agent that
        # identifies the application; a generic browser string may be throttled.
        headers={"User-Agent": "Mozilla/5.0"},
        timeout=5,
    )
    response.raise_for_status()
    return response.json()


# Demo limits on user-supplied evaluation checkpoints.
_MAX_CHECKPOINTS = 20
_MAX_CRITERIA_WORDS = 100


def _checkpoints_from_editor(checkpoints_df: pd.DataFrame) -> list[CheckpointCriteria]:
    """Convert the edited checkpoint table back into ``CheckpointCriteria`` objects.

    At most ``_MAX_CHECKPOINTS`` rows are kept. Rows whose criteria cell is
    missing (None/NaN/pd.NA, as in rows newly added through the dynamic editor)
    or blank are skipped. Rows that fail validation are reported with
    ``st.error`` and dropped, so one bad row does not abort the form.
    """
    # Don't let a user add more than 20 checkpoints.
    if len(checkpoints_df) > _MAX_CHECKPOINTS:
        st.error(
            "You can only add up to 20 checkpoints for the purpose of this demo."
        )
        checkpoints_df = checkpoints_df.head(_MAX_CHECKPOINTS)

    new_ckpts = []
    for _, row in checkpoints_df.iterrows():
        criteria = row["criteria"]
        # Comparing against "" alone misses missing cells: None/NaN slip through
        # to `.split`, and `pd.NA == ""` cannot even be used in an `if`.
        if not isinstance(criteria, str) or not criteria.strip():
            continue
        try:
            # Don't let people write essays for criteria in this demo.
            # split() with no argument ignores runs of whitespace when counting words.
            if len(criteria.split()) > _MAX_CRITERIA_WORDS:
                raise ValueError("Criteria is too long")
            new_ckpts.append(
                CheckpointCriteria(criteria=criteria, points=row["points"])
            )
        except (ValueError, TypeError) as e:
            # pydantic's ValidationError (e.g. missing/NaN points) subclasses ValueError.
            st.error(f"Error creating checkpoint: {e}")
    return new_ckpts


def get_user_inputs() -> UserInputs:
    """Render the input form and collect the user's choices.

    Streamlit renders widgets in call order, so the statement order below defines
    the page layout. Problems with the location lookup or with individual
    checkpoints are reported inline with ``st.error`` instead of raising.

    Returns:
        UserInputs: The validated set of form values.
    """
    default_val = "Los Angeles California, US"

    location = st.text_input("Enter a location", value=default_val)
    if location:
        try:
            location_check = get_area(location)
        except requests.RequestException as e:
            # A Nominatim outage, timeout or HTTP error should not crash the whole
            # page; report it and keep rendering the rest of the form.
            st.error(f"❌ Could not validate location: {e}")
        else:
            if not location_check:
                st.error("❌ Invalid location")

    max_driving_hours = st.number_input(
        "Enter the maximum driving hours", min_value=1, value=2
    )

    col_date, col_time = st.columns([2, 1])
    with col_date:
        date = st.date_input(
            "Select a date in the future", value=datetime.now() + timedelta(days=1)
        )
    with col_time:
        # default to 9am
        time = st.selectbox(
            "Select a time",
            [datetime.strptime(f"{i:02d}:00", "%H:%M").time() for i in range(24)],
            index=9,
        )
    date = datetime.combine(date, time)

    supported_frameworks = [
        framework for framework in AgentFramework if _is_tracing_supported(framework)
    ]

    framework = st.selectbox(
        "Select the agent framework to use",
        supported_frameworks,
        # Prefer the third option, but clamp so a shorter list cannot raise.
        index=max(0, min(2, len(supported_frameworks) - 1)),
        format_func=lambda x: x.name,
    )

    model_id = st.selectbox(
        "Select the model to use",
        MODEL_OPTIONS,
        index=1,
        format_func=lambda x: "/".join(x.split("/")[-3:]),
    )

    # Add evaluation case section
    with st.expander("Custom Evaluation"):
        evaluation_model_id = st.selectbox(
            "Select the model to use for LLM-as-a-Judge evaluation",
            MODEL_OPTIONS,
            index=2,
            format_func=lambda x: "/".join(x.split("/")[-3:]),
        )
        # Deep-copy so edits never mutate the shared module-level default.
        evaluation_case = copy.deepcopy(DEFAULT_EVALUATION_CASE)
        evaluation_case.llm_judge = evaluation_model_id
        # Show the checkpoints as an editable table (one row per checkpoint).
        checkpoints_df = pd.DataFrame(
            [checkpoint.model_dump() for checkpoint in evaluation_case.checkpoints]
        )
        checkpoints_df = st.data_editor(
            checkpoints_df,
            column_config={
                "points": st.column_config.NumberColumn(label="Points"),
                "criteria": st.column_config.TextColumn(label="Criteria"),
            },
            hide_index=True,
            num_rows="dynamic",
        )
        evaluation_case.checkpoints = _checkpoints_from_editor(checkpoints_df)

    return UserInputs(
        model_id=model_id,
        location=location,
        max_driving_hours=max_driving_hours,
        date=date,
        framework=framework,
        evaluation_case=evaluation_case,
        run_evaluation=st.checkbox("Run Evaluation", value=True),
    )