from datasets import get_dataset_config_names, get_dataset_split_names
from distilabel.steps.tasks import (
    ChatGeneration,
    Magpie,
    GenerateSentencePair,
    TextGeneration,
)

from src.synthetic_dataset_generator.constants import (
    MAGPIE_PRE_QUERY_TEMPLATE,
    MAX_NUM_TOKENS,
)
from src.synthetic_dataset_generator.pipelines.base import _get_llm, _get_llm_class

# ---------------------------------------------------------------------------
# Example system prompts, one per task category.
#
# Each constant is a single string assembled from adjacent string literals
# (implicit concatenation); continuation pieces start with a space so the
# pieces join into normal sentences. All of them except MATH_SYSTEM_PROMPT are
# interpolated into PROMPT_CREATION_PROMPT in this module as style examples.
# ---------------------------------------------------------------------------

# Persona: factual question answering.
INFORMATION_SEEKING_PROMPT = (
    "You are an AI assistant designed to provide accurate and concise information on a wide"
    " range of topics. Your purpose is to assist users in finding specific facts,"
    " explanations, or details about various subjects. Provide clear, factual responses and,"
    " when appropriate, offer additional context or related information that might be useful"
    " to the user."
)

# Persona: step-by-step logical reasoning.
REASONING_PROMPT = (
    "You are an AI assistant specialized in logical thinking and problem-solving. Your"
    " purpose is to help users work through complex ideas, analyze situations, and draw"
    " conclusions based on given information. Approach each query with structured thinking,"
    " break down problems into manageable parts, and guide users through the reasoning"
    " process step-by-step."
)

# Persona: planning and strategy.
PLANNING_PROMPT = (
    "You are an AI assistant focused on helping users create effective plans and strategies."
    " Your purpose is to assist in organizing thoughts, setting goals, and developing"
    " actionable steps for various projects or activities. Offer structured approaches,"
    " consider potential challenges, and provide tips for efficient execution of plans."
)

# Persona: editing and improving written text.
EDITING_PROMPT = (
    "You are an AI assistant specialized in editing and improving written content. Your"
    " purpose is to help users refine their writing by offering suggestions for grammar,"
    " style, clarity, and overall structure. Provide constructive feedback, explain your"
    " edits, and offer alternative phrasings when appropriate."
)

# Persona: writing, reviewing and debugging code.
CODING_DEBUGGING_PROMPT = (
    "You are an AI assistant designed to help with programming tasks. Your purpose is to"
    " assist users in writing, reviewing, and debugging code across various programming"
    " languages. Provide clear explanations, offer best practices, and help troubleshoot"
    " issues. When appropriate, suggest optimizations or alternative approaches to coding"
    " problems."
)

# Persona: guided math problem solving.
# NOTE: unlike its siblings this one is named *_SYSTEM_PROMPT and is not
# referenced anywhere else in the visible part of this module.
MATH_SYSTEM_PROMPT = (
    "You are an AI assistant designed to provide helpful, step-by-step guidance on solving"
    " math problems. The user will ask you a wide range of complex mathematical questions."
    " Your purpose is to assist users in understanding mathematical concepts, working through"
    " equations, and arriving at the correct solutions."
)

# Persona: role play in a requested character.
ROLE_PLAYING_PROMPT = (
    "You are an AI assistant capable of engaging in various role-playing scenarios. Your"
    " purpose is to adopt different personas or characters as requested by the user. Maintain"
    " consistency with the chosen role, respond in character, and help create immersive and"
    " interactive experiences for the user."
)

# Persona: data analysis and interpretation.
DATA_ANALYSIS_PROMPT = (
    "You are an AI assistant specialized in data analysis and interpretation. Your purpose is"
    " to help users understand and derive insights from data sets, statistics, and analytical"
    " tasks. Offer clear explanations of data trends, assist with statistical calculations,"
    " and provide guidance on data visualization and interpretation techniques."
)

# Persona: creative writing support.
CREATIVE_WRITING_PROMPT = (
    "You are an AI assistant designed to support creative writing endeavors. Your purpose is"
    " to help users craft engaging stories, poems, and other creative texts. Offer"
    " suggestions for plot development, character creation, dialogue writing, and other"
    " aspects of creative composition. Provide constructive feedback and inspire creativity."
)

# Persona: personal and professional advice.
ADVICE_SEEKING_PROMPT = (
    "You are an AI assistant focused on providing thoughtful advice and guidance. Your"
    " purpose is to help users navigate various personal or professional issues by offering"
    " balanced perspectives, considering potential outcomes, and suggesting practical"
    " solutions. Encourage users to think critically about their situations while providing"
    " supportive and constructive advice."
)

# Persona: idea generation and brainstorming.
BRAINSTORMING_PROMPT = (
    "You are an AI assistant specialized in generating ideas and facilitating creative"
    " thinking. Your purpose is to help users explore possibilities, think outside the box,"
    " and develop innovative concepts. Encourage free-flowing thoughts, offer diverse"
    " perspectives, and help users build upon and refine their ideas."
)

# Meta system prompt used by get_prompt_generator(): it asks the model to write
# a new system prompt in the style of the category prompts interpolated below.
# Being an f-string, the examples are baked in once at import time.
# MATH_SYSTEM_PROMPT is not among the interpolated examples.
# The text ends with "User dataset description:" followed by a newline —
# presumably the user's description is supplied right after it as the
# instruction; confirm against the caller.
PROMPT_CREATION_PROMPT = f"""You are an AI assistant specialized in generating very precise prompts for dataset creation.

Your task is to write a prompt following the instruction of the user. Respond with the prompt and nothing else.

In the generated prompt always finish with this sentence: User questions are direct and concise.

The prompt you write should follow the same style and structure as the following example prompts:

{INFORMATION_SEEKING_PROMPT}

{REASONING_PROMPT}

{PLANNING_PROMPT}

{CODING_DEBUGGING_PROMPT}

{EDITING_PROMPT}

{ROLE_PLAYING_PROMPT}

{DATA_ANALYSIS_PROMPT}

{CREATIVE_WRITING_PROMPT}

{ADVICE_SEEKING_PROMPT}

{BRAINSTORMING_PROMPT}

User dataset description:
"""

# Template passed as `template=` to TextGeneration in get_follow_up_generator()
# together with columns=["messages"]. It uses Jinja-style delimiters
# ({% ... %} for control flow, {{ ... }} for values) to print every message of
# the conversation, labelled by role, and then asks for the next user message.
# This is a plain (non-f) string, so the braces are literal template syntax.
# The trailing .rstrip() drops the final newline of the literal.
FOLLOW_UP_TEMPLATE = """Conversation:
{% for message in messages %}
    {% if message.role == "user" %}
User Question: {{ message.content }}
    {% elif message.role == "assistant" %}
Assistant Response: {{ message.content }}
    {% endif %}
{% endfor %}

Please generate the next logical user message in this conversation. Do not include any other information or 'User Question' in your response.
""".rstrip()

# Example dataset descriptions.
DEFAULT_DATASET_DESCRIPTIONS = [
    "rude customer assistant for a phone company",
    "assistant that solves math puzzles using python",
]

# Stop sequences handed to the Magpie generators, chosen by pre-query template.
# Only "qwen2" has its own markers; "llama3" and every other template value
# share the Llama 3 markers (note the leading space in " \n\n").
if MAGPIE_PRE_QUERY_TEMPLATE == "qwen2":
    _STOP_SEQUENCES = ["<|im_end|>", "<|im_start|>", "assistant", "\n\n"]
else:
    _STOP_SEQUENCES = [
        "<|eot_id|>",
        "<|start_header_id|>",
        "assistant",
        " \n\n",
    ]


def _get_output_mappings(num_turns: int):
    if num_turns == 1:
        return {"instruction": "prompt", "response": "completion"}
    else:
        return {"conversation": "messages"}


def get_prompt_generator():
    """Build and load the task that writes system prompts.

    The task is a ``TextGeneration`` whose system prompt is
    ``PROMPT_CREATION_PROMPT``; sampling is enabled at temperature 0.8 with a
    budget of ``MAX_NUM_TOKENS`` new tokens.

    Returns:
        The ``TextGeneration`` task after ``load()`` has been called on it.
    """
    llm = _get_llm(
        generation_kwargs={
            "temperature": 0.8,
            "max_new_tokens": MAX_NUM_TOKENS,
            "do_sample": True,
        }
    )
    task = TextGeneration(
        llm=llm,
        system_prompt=PROMPT_CREATION_PROMPT,
        use_system_prompt=True,
    )
    task.load()
    return task


def get_magpie_generator(num_turns: int, temperature: float, is_sample: bool):
    """Build and load a ``Magpie`` task for single- or multi-turn generation.

    Args:
        num_turns: Number of turns. With exactly 1 the task is created with
            ``only_instruction=True``; otherwise with ``end_with_user=True``.
        temperature: Sampling temperature passed to the LLM.
        is_sample: When true the token budget is fixed at 256; otherwise it is
            a share of ``MAX_NUM_TOKENS`` (25% for one turn, 50% otherwise).

    Returns:
        The ``Magpie`` task after ``load()`` has been called on it.
    """
    single_turn = num_turns == 1

    # The two modes differ only in token share and in one Magpie flag.
    token_share = 0.25 if single_turn else 0.5
    mode_options = (
        {"only_instruction": True} if single_turn else {"end_with_user": True}
    )

    llm = _get_llm(
        generation_kwargs={
            "temperature": temperature,
            "do_sample": True,
            "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * token_share),
            "stop_sequences": _STOP_SEQUENCES,
        },
        magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
        use_magpie_template=True,
    )
    generator = Magpie(
        llm=llm,
        n_turns=num_turns,
        output_mappings=dict(_get_output_mappings(num_turns)),
        **mode_options,
    )
    generator.load()
    return generator


def get_sentence_pair_generator(temperature: float, is_sample: bool):
    """Build and load the ``GenerateSentencePair`` task.

    The task is configured with ``action="query"``, ``triplet=False`` and
    ``hard_negative=True``.

    Args:
        temperature: Sampling temperature passed to the LLM.
        is_sample: When true the token budget is 256, otherwise
            ``MAX_NUM_TOKENS``.

    Returns:
        The ``GenerateSentencePair`` task after ``load()`` has been called.
    """
    token_budget = 256 if is_sample else MAX_NUM_TOKENS
    llm = _get_llm(
        generation_kwargs={
            "temperature": temperature,
            "max_new_tokens": token_budget,
        }
    )
    task = GenerateSentencePair(
        llm=llm,
        triplet=False,
        action="query",
        hard_negative=True,
    )
    task.load()
    return task


def get_response_generator(
    system_prompt: str, num_turns: int, temperature: float, is_sample: bool
):
    """Build and load the task that produces assistant responses.

    Args:
        system_prompt: System prompt for the single-turn ``TextGeneration``
            task. It is not used when ``num_turns`` is not 1.
        num_turns: With exactly 1 a ``TextGeneration`` task reading the
            ``prompt`` column is built; otherwise a ``ChatGeneration`` task
            reading the ``conversation`` column.
        temperature: Sampling temperature passed to the LLM.
        is_sample: Single-turn only: 256 tokens when true, else half of
            ``MAX_NUM_TOKENS``. The multi-turn task always gets
            ``MAX_NUM_TOKENS``.

    Returns:
        The loaded task; both variants write their output to ``completion``.
    """
    if num_turns != 1:
        chat_llm = _get_llm(
            is_completion=True,
            generation_kwargs={
                "temperature": temperature,
                "max_new_tokens": MAX_NUM_TOKENS,
            },
        )
        task = ChatGeneration(
            llm=chat_llm,
            output_mappings={"generation": "completion"},
            input_mappings={"conversation": "messages"},
        )
    else:
        text_llm = _get_llm(
            is_completion=True,
            generation_kwargs={
                "temperature": temperature,
                "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
            },
        )
        task = TextGeneration(
            llm=text_llm,
            system_prompt=system_prompt,
            output_mappings={"generation": "completion"},
            input_mappings={"instruction": "prompt"},
        )
    task.load()
    return task


def get_follow_up_generator(type: str, temperature: float, is_sample: bool):
    """Build and load a task for continuing an existing conversation.

    Args:
        type: ``"instruction"`` builds a ``TextGeneration`` task driven by
            ``FOLLOW_UP_TEMPLATE`` over the ``messages`` column; any other
            value builds a ``ChatGeneration`` task. (The name shadows the
            builtin but is kept because callers may pass it by keyword.)
        temperature: Sampling temperature passed to the LLM.
        is_sample: Instruction variant only: 256 tokens when true, else half
            of ``MAX_NUM_TOKENS``. The other variant always gets
            ``MAX_NUM_TOKENS``.

    Returns:
        The task after ``load()`` has been called on it.
    """
    if type != "instruction":
        task = ChatGeneration(
            llm=_get_llm(
                is_completion=True,
                generation_kwargs={
                    "temperature": temperature,
                    "max_new_tokens": MAX_NUM_TOKENS,
                },
            ),
        )
    else:
        token_budget = 256 if is_sample else int(MAX_NUM_TOKENS * 0.5)
        task = TextGeneration(
            llm=_get_llm(
                generation_kwargs={
                    "temperature": temperature,
                    "max_new_tokens": token_budget,
                }
            ),
            template=FOLLOW_UP_TEMPLATE,
            columns=["messages"],
        )
    task.load()
    return task

def generate_pipeline_code_system_prompt(
    system_prompt: str,
    num_turns: int,
    num_rows: int,
):
    """Render a standalone distilabel script for system-prompt-driven generation.

    The script builds a ``MagpieGenerator`` configured with the current LLM
    (class name from ``_get_llm_class()``, settings from ``_get_llm().dump()``)
    and connects it to a ``KeepColumns`` step.

    Args:
        system_prompt: System prompt to embed in the script. It is written as
            a Python literal via ``repr`` so quotes, backslashes and newlines
            inside the prompt cannot break the generated source.
        num_turns: Value emitted for ``n_turns``; also selects the output
            column mapping through ``_get_output_mappings``.
        num_rows: Value emitted for ``num_rows``.

    Returns:
        The script source as a string, starting with a newline and written at
        column 0 so it can be saved and executed as-is.
    """
    output_mappings = _get_output_mappings(num_turns)
    kept_columns = list(output_mappings.values()) + ["model_name"]
    # Resolve the LLM details once; both usages below must agree.
    llm_class = _get_llm_class()
    llm_config = _get_llm().dump()

    # The template is deliberately flush-left: any common indentation would be
    # part of the emitted script and make it fail with an IndentationError.
    # The LLM is imported from `distilabel.models`, matching the script that
    # generate_pipeline_code_seed emits.
    code = f"""
# Requirements: `pip install distilabel[hf-inference-endpoints]`
import os
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns
from distilabel.steps.tasks import MagpieGenerator
from distilabel.models import {llm_class}

SYSTEM_PROMPT = {system_prompt!r}

with Pipeline(name="sft") as pipeline:
    magpie = MagpieGenerator(
        llm={llm_class}.from_dict(
            {llm_config}
        ),
        n_turns={num_turns},
        num_rows={num_rows},
        batch_size=1,
        system_prompt=SYSTEM_PROMPT,
        output_mappings={output_mappings},
    )
    keep_columns = KeepColumns(
        columns={kept_columns},
    )
    magpie.connect(keep_columns)

if __name__ == "__main__":
    distiset = pipeline.run()
"""
    return code

def generate_pipeline_code_seed(
    repo_id: str,
    subset: str,
    split: str,
    input_type: str,
    document_column: str,
    num_turns: int,
    num_rows: int,
):
    """Render a standalone distilabel script for seed-data-driven generation.

    The emitted pipeline loads seed documents, turns each one into a query
    with ``GenerateSentencePair`` and answers it with ``TextGeneration``. When
    more than one turn is requested it also emits ``num_turns - 1`` follow-up
    rounds (a user message and an assistant reply each) plus the helper steps
    that maintain the ``messages`` column.

    Args:
        repo_id: Hub dataset id, emitted only when ``input_type`` is
            ``"dataset-input"``.
        subset: Dataset config name emitted for ``LoadDataFromHub``.
        split: Dataset split name emitted for ``LoadDataFromHub``.
        input_type: ``"dataset-input"`` loads from the Hub; any other value
            emits a ``LoadDataFromDicts`` step. In that case the script
            references ``process_and_chunk_files`` and ``files``, which the
            script itself does not define.
        document_column: Source column that is renamed to ``anchor``.
        num_turns: Number of conversation turns.
        num_rows: Value emitted for ``num_examples`` of ``LoadDataFromHub``.

    Returns:
        The script source as a string, starting with a newline.
    """
    # Resolve the LLM details once so every emitted step shows the same config.
    llm_class = _get_llm_class()
    llm_config = _get_llm().dump()
    from_hub = input_type == "dataset-input"
    multi_turn = num_turns > 1

    step_imports = [
        "KeepColumns",
        "LoadDataFromHub" if from_hub else "LoadDataFromDicts",
    ]
    task_imports = ["GenerateSentencePair", "TextGeneration"]
    if multi_turn:
        # StepOutput is used in the emitted annotations, which Python evaluates
        # at definition time, so it has to be imported alongside StepInput.
        step_imports += ["StepInput", "StepOutput", "step"]
        task_imports.append("ChatGeneration")
    step_import_list = ", ".join(step_imports)
    task_import_list = ", ".join(task_imports)

    code = f"""
# Requirements: `pip install distilabel[hf-inference-endpoints]`
from distilabel.models import {llm_class}
from distilabel.pipeline import Pipeline
from distilabel.steps import {step_import_list}
from distilabel.steps.tasks import {task_import_list}
"""

    if multi_turn:
        # Plain (non-f) string: braces are emitted verbatim, so the template
        # delimiters are written exactly as they must appear in the script.
        code += """
FOLLOW_UP_TEMPLATE = '''Conversation:
{% for message in messages %}
    {% if message.role == "user" %}
User Question: {{ message.content }}
    {% elif message.role == "assistant" %}
Assistant Response: {{ message.content }}
    {% endif %}
{% endfor %}

Please generate the next logical user message in this conversation. Do not include any other information or 'User Question' in your response.
'''.rstrip()

@step(inputs=["prompt", "completion"], outputs=["messages"])
def PrepareMessages(*inputs: StepInput) -> StepOutput:
    for input in inputs:
        for item in input:
            item["messages"] = [
                {"role": "user", "content": item["prompt"]},
                {"role": "assistant", "content": item["completion"]},
            ]
        yield input


@step(inputs=["messages", "generation"], outputs=["messages"])
def FormatMessagesInstruction(*inputs: StepInput) -> StepOutput:
    for input in inputs:
        for item in input:
            item["messages"].append({"role": "user", "content": item["generation"]})
        yield input


@step(inputs=["messages", "generation"], outputs=["messages"])
def FormatMessagesResponse(*inputs: StepInput) -> StepOutput:
    for input in inputs:
        for item in input:
            item["messages"].append({"role": "assistant", "content": item["generation"]})
        yield input
"""

    if from_hub:
        code += f"""
with Pipeline(name="sft") as pipeline:
    load_the_dataset = LoadDataFromHub(
        repo_id='{repo_id}',
        config='{subset}',
        split='{split}',
        num_examples={num_rows},
        batch_size=2,
        output_mappings={{'{document_column}':'anchor'}},
    )
"""
    else:
        code += """
data = process_and_chunk_files(files=[files])

with Pipeline(name="sft") as pipeline:
    load_the_dataset = LoadDataFromDicts(
        data = data
    )
"""

    code += f"""
    instruction_generator = GenerateSentencePair(
        name="instruction_generation",
        triplet=False,
        hard_negative=True,
        action="query",
        llm={llm_class}.from_dict(
            {llm_config}
        ),
        input_batch_size=10,
        output_mappings={{"positive": "prompt"}},
    )

    response_generator = TextGeneration(
        name="response_generation",
        llm={llm_class}.from_dict(
            {llm_config}
        ),
        input_batch_size=10,
        input_mappings={{"instruction": "prompt"}},
        output_mappings={{"generation": "completion"}},
    )
"""

    # Step names in execution order. The chain is built next to the step
    # definitions so it can only reference steps that were actually emitted.
    chain = ["load_the_dataset", "instruction_generator", "response_generator"]

    if multi_turn:
        code += """
    prepare_messages = PrepareMessages()
"""
        chain.append("prepare_messages")

        # The first turn is produced above; each further turn is one round.
        for turn in range(num_turns - 1):
            code += f"""
    follow_up_instruction_{turn} = TextGeneration(
        llm={llm_class}.from_dict(
            {llm_config}
        ),
        template=FOLLOW_UP_TEMPLATE,
        columns=["messages"],
    )
    format_instruction_{turn} = FormatMessagesInstruction()
    follow_up_response_{turn} = ChatGeneration(
        llm={llm_class}.from_dict(
            {llm_config}
        ),
    )
    format_response_{turn} = FormatMessagesResponse()
"""
            chain += [
                f"follow_up_instruction_{turn}",
                f"format_instruction_{turn}",
                f"follow_up_response_{turn}",
                f"format_response_{turn}",
            ]
        kept_columns = '["messages"]'
    else:
        # Single turn: keep the columns written by the two generators above.
        kept_columns = '["prompt", "completion"]'

    chain.append("keep_columns")
    connections = " >> ".join(chain)

    # Everything inside the emitted `with` block stays at a 4-space indent.
    code += f"""
    keep_columns = KeepColumns(columns={kept_columns})

    {connections}

if __name__ == "__main__":
    distiset = pipeline.run()
"""
    return code

def generate_pipeline_code(
    repo_id: str,
    input_type: str,
    system_prompt: str,
    document_column: str,
    num_turns: int,
    num_rows: int,
):
    """Return the pipeline script that matches the selected input type.

    Args:
        repo_id: Hub dataset id; only looked up when ``input_type`` is
            ``"dataset-input"`` and the id is not ``None``.
        input_type: ``"prompt-type"`` delegates to
            ``generate_pipeline_code_system_prompt``; every other value
            delegates to ``generate_pipeline_code_seed``.
        system_prompt: Forwarded to the system-prompt script generator.
        document_column: Forwarded to the seed script generator.
        num_turns: Forwarded to whichever generator is used.
        num_rows: Forwarded to whichever generator is used.

    Returns:
        The generated script source as a string.
    """
    if input_type == "prompt-type":
        return generate_pipeline_code_system_prompt(
            system_prompt=system_prompt,
            num_turns=num_turns,
            num_rows=num_rows,
        )

    # Defaults, replaced by the dataset's first config and that config's
    # first split when a Hub dataset was given.
    subset, split = "default", "train"
    if input_type == "dataset-input" and repo_id is not None:
        subset = get_dataset_config_names(repo_id)[0]
        split = get_dataset_split_names(repo_id, subset)[0]

    return generate_pipeline_code_seed(
        repo_id=repo_id,
        subset=subset,
        split=split,
        input_type=input_type,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
    )