import json
import os
import random
import re

import tqdm
from camel.agents import ChatAgent
from camel.configs import ChatGPTConfig
from camel.messages import BaseMessage
from camel.types.enums import RoleType

TEMPERATURE = 1.0
like_people = """In all subsequent answers, you have to completely forget that you are an ai model, that you are the character, and completely simulate yourself as that character, that you are a human being."""
with open(
    r"prompt/trust_game_round_prompt.json",
    "r",
) as f:
    prompt = json.load(f)


def check_file_if_exist(file_list, game_name):
    for file in file_list:
        if file in game_name:
            return True
    return False


def extract_n_values_from_dict(dictionary, n):
    all_values = list(dictionary.values())

    n = min(n, len(all_values))

    random_values = random.sample(all_values, n)

    return random_values


def extract_unique_decimal(string):
    numbers = re.findall(r"-?\d+\.?\d*", string)
    if len(numbers) == 1:
        return float(numbers[0])
    else:
        raise ValueError("String does not contain a unique decimal number")


def str_mes(content):
    return BaseMessage(
        role_name="player",
        role_type=RoleType.USER,
        meta_dict={},
        content=content,
    )


def match_and_compare_numbers_v2(text):
    text = text.lower()

    # Updated regex pattern to match numbers ending with a dot and including "give back"
    pattern = r"i will (give back|give) \$([\d\.]+\.?)|i will (give back|give) ([\d\.]+\.?)\s*dollar"
    additional_patterns = [
        r"i would (give back|give) \$([\d\.]+\.?)",
        r"i would (give back|give) ([\d\.]+\.?) dollar",
    ]
    full_pattern = "|".join([pattern] + additional_patterns)
    matches = re.findall(full_pattern, text)

    # Flatten match results and filter out empty values
    numbers = []
    for match in matches:
        # Adjusted to the new grouping
        num_str = match[1] if match[1] else match[3]
        num_str = num_str.rstrip(".")
        try:
            num_float = float(num_str)
            numbers.append(num_float)
        except ValueError:
            continue

    if not numbers:
        return False

    if len(set(numbers)) == 1:
        return numbers[0]
    else:
        return False


def classmate(
    player_1,
    player_2,
    first_round,
    first_prompt,
    second_prompt,
    k,
):
    first_round_prompt = "This is the first round, answer the question."
    money_prompt = "Now,the another player give you {give} dollars,and You receive {N} dollars,the player left {left} dollars now. How much will you give back to the another player"
    return_money_prompt = "In last round ,You give the another player {give} dollars, The another player receive {receive} dollars, and The another player return you {N} dollars.Last round you left {left} dollars.This round is begin. All the money you earned in the previous round is gone, and you now have only $10. How much will you give to the another player?"
    player_2_end_prompt = "In last round, the another player give you {give} dollars, you receive {receive} dollars, and you return the another player {N} dollars.Last round you left {left} dollars. This round is begin. All the money you earned in the previous round is gone."
    grantee = "Your answer needs to include the content and analysis about your BELIEF, DESIRE and INTENTION. You should include your thought. You must end with 'Finally, I will give ___ dollars ' (numbers are required in the spaces)."
    res = []
    cri_agent = ChatAgent(
        BaseMessage(
            role_name="critic",
            role_type=RoleType.ASSISTANT,
            meta_dict={},
            content='How much would this person pay the other student? Only response with a specific price number like "5"!Don\'t response with a sentence',
        ),
        output_language="English",
        # model=ModelType.STUB,
    )

    if first_round:
        player_1_response = player_1.step(
            str_mes(first_round_prompt + grantee)).msgs[0]
        ans = match_and_compare_numbers_v2(player_1_response.content)
        if ans:
            given_num = ans
        else:
            given_num = extract_unique_decimal(
                cri_agent.step(
                    str_mes(player_1_response.content)).msgs[0].content
            )

        money_prompt = money_prompt.format(
            give=given_num, N=given_num * k, left=10 - given_num
        )
        player_2_response = player_2.step(
            str_mes(money_prompt + grantee)).msgs[0]
    else:
        player_1_response = player_1.step(
            str_mes(first_prompt + grantee)).msgs[0]
        print("player 1 input", first_prompt)
        print("Player_1_res", player_1_response.content)
        ans = match_and_compare_numbers_v2(player_1_response.content)
        if ans:
            given_num = ans
        else:
            given_num = extract_unique_decimal(
                cri_agent.step(
                    str_mes(player_1_response.content)).msgs[0].content
            )
        money_prompt = money_prompt.format(
            give=given_num, N=given_num * k, left=10 - given_num
        )
        player_2_response = player_2.step(
            str_mes(second_prompt + money_prompt + grantee)
        )
        player_2_response = player_2_response.msgs[0]

    player_1.record_message(player_1_response)
    player_2.record_message(player_2_response)
    player_1_response = player_1_response.content
    player_2_response = player_2_response.content
    dia_history = {
        "Player_1": player_1_response,
        "Player_2": player_2_response,
    }

    cri_agent.reset()
    ans = match_and_compare_numbers_v2(player_2_response)
    if ans:
        return_num = ans
    else:
        return_num = extract_unique_decimal(
            cri_agent.step(str_mes(player_2_response)).msgs[0].content
        )
    return_money_prompt = return_money_prompt.format(
        give=given_num,
        receive=given_num * k,
        N=return_num,
        left=10 - given_num + return_num,
    )
    player_2_end_prompt = player_2_end_prompt.format(
        give=given_num,
        receive=given_num * k,
        N=return_num,
        left=given_num * k - return_num,
    )
    return (
        [given_num, return_num],
        dia_history,
        return_money_prompt,
        player_2_end_prompt,
    )


def save_experiment_result(
    final_res, save_path, prefix, k, exp_num, all_exp_num, group_num
):
    save_file = (
        save_path
        + prefix
        + "_"
        + str(k)
        + f"_exp_num_{exp_num}_total_num_{all_exp_num}_group_num_{group_num}"
        + ".json"
    )
    print(save_file)
    with open(
        save_file,
        "w",
    ) as json_file:
        json.dump(final_res, json_file)


def multi_round(
    model_type,
    character_json,
    save_path,
    prompt,
    round_num=5,
    exp_num=1,
    round_num_inform=True,
):
    random_character = character_json
    cha = []
    front = "you are a person not an ai model."
    if round_num_inform:
        limited_prompt = f"You will repeat this game for {round_num} rounds. At the begining of each round, you will start from scratch with no dollars left."
    else:
        limited_prompt = f"You don't know how many rounds the game lasts. At the begining of each round, you will start from scratch with no dollars left."
    back = "you need to answer a specific price figure, not a price range!"
    if isinstance(model_type, list):
        prefix = model_type[0].value + "_" + model_type[1].value
    else:
        prefix = model_type.value
    input_record = {}
    chara_record = {}
    for k in range(3, 4):
        for i in range(len(random_character)):
            sys_prompt = (
                random_character[i]
                + like_people
                + front
                + limited_prompt
                + str(prompt[str(i % 2 + 1)]).format(k=k)
                + back
            )

            chara_record[f"cha_{i}_system_message"] = sys_prompt
            model_config = ChatGPTConfig(temperature=TEMPERATURE)
            cha.append(
                ChatAgent(
                    BaseMessage(
                        role_name="player",
                        role_type=RoleType.USER,
                        meta_dict={},
                        content=sys_prompt,
                    ),
                    model_type=model_type
                    if not isinstance(model_type, list)
                    else model_type[i % 2],
                    output_language="English",
                    model_config=model_config,
                )
            )

        for group_num in tqdm.trange(0, len(cha), 2):
            round_res = []
            dialog_history = []
            first_prompt = ""
            second_prompt = ""
            save_file_check = (
                prefix
                + "_"
                + str(k)
                + f"_exp_num_{exp_num}_total_num_{round_num}_group_num_{group_num}"
                + ".json"
            )
            existed_res = [item for item in os.listdir(
                save_path) if ".json" in item]
            if check_file_if_exist(existed_res, save_file_check):
                print(save_file_check + "is exist")
                continue
            for i in tqdm.tqdm(range(round_num)):
                input_record[f"round_{i}_input"] = [
                    first_prompt, second_prompt]
                res, dia, first_prompt, second_prompt = classmate(
                    cha[group_num],
                    cha[group_num + 1],
                    i == 0,
                    first_prompt,
                    second_prompt,
                    k,
                )
                round_res.append(res)
                dialog_history.append(dia)

            final_res = {
                i + 1: [round_res[i], dialog_history[i]] for i in range(len(round_res))
            }
            final_res["input_record"] = input_record
            final_res["character_record"] = [
                chara_record[f"cha_{group_num}_system_message"],
                chara_record[f"cha_{group_num+1}_system_message"],
            ]
            save_experiment_result(
                final_res,
                save_path,
                prefix,
                k,
                exp_num,
                all_exp_num=round_num,
                group_num=group_num,
            )