''' Base class for simulating games.'''

import json
import os
import random
from abc import ABC
from enum import Enum, unique
from typing import Any, Dict, List, Optional

from utils.llm_utils import generate_prompt, llm_decide_move


@unique
class PlayerId(Enum):
    """Special (negative) player IDs, mirroring OpenSpiel's reserved IDs.

    Regular players are numbered 0, 1, ... and have no member here.
    """

    CHANCE = -1
    SIMULTANEOUS = -2
    INVALID = -3
    TERMINAL = -4
    MEAN_FIELD = -5

    @classmethod
    def from_value(cls, value: int) -> Optional["PlayerId"]:
        """Returns the PlayerId corresponding to a given integer value.

        Args:
            value (int): The numerical value to map to a PlayerId.

        Returns:
            Optional[PlayerId]: The matching enum member, or ``None`` when
            ``value`` is non-negative (a regular player index, which has no
            enum member).

        Raises:
            ValueError: If ``value`` is negative and matches no member.
        """
        try:
            # Enum's built-in value lookup replaces a manual scan of members.
            return cls(value)
        except ValueError:
            pass
        if value >= 0:  # Non-negative integers are regular player indices.
            return None
        raise ValueError(f"Unknown player ID value: {value}")


@unique
class PlayerType(Enum):
    """Kinds of participant a seat can be assigned to.

    The string values are what ``player_type`` configuration dictionaries hold.
    """

    HUMAN = "human"  # GameSimulator prompts for a move on the console
    RANDOM_BOT = "random_bot"  # GameSimulator picks a uniformly random legal action
    LLM = "llm"  # GameSimulator asks the registered LLM for a move
    # NOTE(review): GameSimulator._get_action has no branch for SELF_PLAY and
    # raises ValueError for it — confirm the intended handling.
    SELF_PLAY = "self_play"


class GameSimulator(ABC):
    """Base class for simulating games with LLMs.

    Handles common functionality like state transitions, scoring, and logging.
    Seat ``i`` (0-based) is named ``"Player {i + 1}"``; the ``llms`` and
    ``player_type`` dictionaries are looked up with those names.
    """

    def __init__(self, game: Any, game_name: str, llms: Dict[str, Any],
                 player_type: Dict[str, str], max_game_rounds: Optional[int] = None):
        """Initializes the simulator.

        Args:
            game (Any): The OpenSpiel game object being simulated.
            game_name (str): A human-readable name for the game (for logging and reporting).
            llms (Dict[str, Any]): A dictionary mapping player names (e.g., "Player 1")
                to their corresponding LLM instances. Can be empty if no LLMs are used.
            player_type (Dict[str, str]): A dictionary mapping player names to their
                types (the string values of ``PlayerType``).
            max_game_rounds (Optional[int]): Maximum number of moves
                (``state.move_number()``) before a game is stopped; intended for
                iterated games. ``None`` means no limit.
        """
        self.game = game
        self.game_name = game_name
        self.llms = llms
        self.player_type = player_type
        self.max_game_rounds = max_game_rounds  # For iterated games
        self.scores = {name: 0 for name in self.llms}  # Reset at the start of every game

    def simulate(self, rounds: int = 1, log_fn=None) -> Dict[str, Any]:
        """Simulates a game for multiple rounds and computes metrics.

        Args:
            rounds: Number of times the game should be played.
            log_fn: Optional callable invoked with the state before every move.

        Returns:
            Dict[str, Any]: Win/loss/tie tallies over all rounds, in the shape
            produced by ``_initialize_outcomes``.

        Raises:
            ValueError: If the state reports a player ID this loop cannot handle.
        """
        outcomes = self._initialize_outcomes()  # Reset the outcomes dictionary

        for _ in range(rounds):
            self.scores = {name: 0 for name in self.llms}  # Reset scores
            state = self.game.new_initial_state()

            while not state.is_terminal():
                # Iterated games (e.g. Iterated Prisoner's Dilemma) are cut off
                # after max_game_rounds moves even if they could continue.
                if (self.max_game_rounds is not None
                        and state.move_number() >= self.max_game_rounds):
                    break
                if log_fn:
                    log_fn(state)

                current_player = state.current_player()
                # Work with the normalized integer so enum-valued IDs never
                # reach the ordering comparison below.
                player_id = self.normalize_player_id(current_player)

                if player_id == PlayerId.CHANCE.value:
                    # Chance node: the environment acts.
                    self._handle_chance_node(state)
                elif player_id == PlayerId.SIMULTANEOUS.value:
                    # Every player chooses an action at the same time.
                    state.apply_actions(self._collect_actions(state))
                elif player_id == PlayerId.TERMINAL.value:
                    break
                elif player_id >= 0:  # Regular turn-based player
                    legal_actions = state.legal_actions(player_id)
                    action = self._get_action(player_id, state, legal_actions)
                    state.apply_action(action)
                else:
                    raise ValueError(f"Unexpected player ID: {current_player}")

            # Record outcomes
            final_scores = state.returns()
            self._record_outcomes(final_scores, outcomes)

        return outcomes

    def _handle_chance_node(self, state: Any) -> None:
        """Handle chance nodes. Default behavior raises an error.

        Subclasses for games with chance nodes must override this and apply the
        environment's action to ``state``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # NOTE(review): OpenSpiel states expose chance_outcomes(); a generic
        # default could sample from it — confirm before changing this contract.
        raise NotImplementedError("Chance node handling not implemented for this game.")

    def _collect_actions(self, state: Any) -> List[int]:
        """Collects actions for all players in a simultaneous-move game.

        Args:
            state: The current game state.

        Returns:
            List[int]: One action per player, ordered by seat index.
        """
        return [
            self._get_action(player, state, state.legal_actions(player))
            for player in range(self.game.num_players())
        ]

    def _initialize_outcomes(self) -> Dict[str, Any]:
        """Returns a fresh outcomes dictionary.

        ``"wins"`` and ``"losses"`` map every LLM player name (the keys of
        ``self.llms``) to a counter; ``"ties"`` counts games with no single winner.
        """
        return {
            "wins": {name: 0 for name in self.llms},
            "losses": {name: 0 for name in self.llms},
            "ties": 0,
        }

    @staticmethod
    def _player_name(player: int) -> str:
        """Returns the name used as a dictionary key for 0-based seat ``player``."""
        return f"Player {player + 1}"

    def _get_action(self, player: int, state: Any, legal_actions: List[int]) -> int:
        """Gets the action for the current player.

        Args:
            player: The index of the current player.
            state: The current game state.
            legal_actions: The legal actions available for the player.

        Returns:
            int: The action selected by the player.

        Raises:
            ValueError: If the player's configured type is not handled here.
        """
        player_name = self._player_name(player)
        player_type = self.player_type.get(player_name)

        if player_type == PlayerType.HUMAN.value:
            return self._get_human_action(state, legal_actions)
        if player_type == PlayerType.RANDOM_BOT.value:
            return random.choice(legal_actions)
        if player_type == PlayerType.LLM.value:
            return self._get_llm_action(player, state, legal_actions)

        # NOTE(review): PlayerType.SELF_PLAY also lands here — confirm whether
        # it should be routed to the LLM path.
        raise ValueError(f"Unknown player type for {player_name}: {player_type}")

    def _get_human_action(self, state: Any, legal_actions: List[int]) -> int:
        """Handles input for human players.

        Re-prompts until the user enters an integer contained in ``legal_actions``.
        """
        print(f"Current state of {self.game_name}:\n{state}")
        print(f"Your options: {legal_actions}")  # Display legal moves to the user
        while True:
            try:
                action = int(input("Enter your action (number): "))
                if action in legal_actions:  # Validate the move
                    return action
            except ValueError:
                pass  # Non-integer input: fall through to the retry message
            print("Invalid action. Please choose from:", legal_actions)

    def _get_llm_action(self, player: int, state: Any, legal_actions: List[int]) -> int:
        """Handles LLM-based decisions.

        Raises:
            KeyError: If no LLM is registered under this player's name.
        """
        llm = self.llms[self._player_name(player)]
        prompt = generate_prompt(self.game_name, str(state), legal_actions)
        return llm_decide_move(llm, prompt, tuple(legal_actions))

    def _apply_default_action(self, state: Any) -> None:
        """Applies a uniformly random legal action to ``state``.

        Intended as a fallback when the current player is invalid.
        """
        state.apply_action(random.choice(state.legal_actions()))

    def _record_outcomes(self, final_scores: List[float], outcomes: Dict[str, Any]) -> str:
        """Records the outcome of a single game round.

        The winner is decided across *all* seats (``final_scores[i]`` belongs
        to ``"Player {i + 1}"``); wins and losses are only tallied for players
        that have a counter in ``outcomes`` (the LLM players).

        Args:
            final_scores (List[float]): Final cumulative scores of all players,
                indexed by seat.
            outcomes (Dict[str, Any]): Dictionary to record wins, losses, and ties.

        Returns:
            str: Name of the winner or "tie" if there is no single winner.
        """
        scores = list(final_scores)

        # No scores, or all scores equal: a tie.
        if not scores or all(score == scores[0] for score in scores):
            outcomes["ties"] += 1
            return "tie"

        names = [self._player_name(seat) for seat in range(len(scores))]
        max_score = max(scores)
        winners = [name for name, score in zip(names, scores) if score == max_score]

        # Several players share the top score: also a tie.
        if len(winners) != 1:
            outcomes["ties"] += 1
            return "tie"

        winner = winners[0]
        if winner in outcomes["wins"]:
            outcomes["wins"][winner] += 1
        for name, score in zip(names, scores):
            if score != max_score and name in outcomes["losses"]:
                outcomes["losses"][name] += 1
        return winner

    def save_results(self, state: Any, final_scores: List[float]) -> None:
        """Save simulation results to a JSON file under ``results/``."""
        results = self._prepare_results(state, final_scores)
        filename = self._get_results_filename()

        with open(filename, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4)
        print(f"Results saved to {filename}")

    def _prepare_results(self, state: Any, final_scores: List[float]) -> Dict[str, Any]:
        """Prepares the results dictionary for JSON serialization.

        Array-like scores exposing ``tolist()`` are converted to plain lists.
        """
        final_scores = final_scores.tolist() if hasattr(final_scores, "tolist") else final_scores
        return {
            "game_name": self.game_name,
            "final_state": str(state),
            "scores": self.scores,
            "returns": final_scores,
            "history": state.history_str(),
        }

    def _get_results_filename(self) -> str:
        """Generates the filename for saving results, creating ``results/`` if needed."""
        results_dir = "results"
        os.makedirs(results_dir, exist_ok=True)
        return os.path.join(results_dir, f"{self.game_name.lower().replace(' ', '_')}_results.json")

    def log_progress(self, state: Any) -> None:
        """Log the current game state to stdout."""
        print(f"Current state of {self.game_name}:\n{state}")

    def normalize_player_id(self, player_id):
        """Normalize player_id to its integer value for consistent comparisons.

        This is needed as OpenSpiel has an ambiguous representation of player IDs.

        Args:
            player_id (Union[int, PlayerId]): The player ID, which can be an
                integer or a PlayerId enum instance.

        Returns:
            int: The integer value of the player ID.
        """
        if isinstance(player_id, PlayerId):
            return player_id.value  # Extract the integer value from the enum
        return player_id  # Already an integer: return it unchanged