import argparse
import json
import os
import random
import re
from collections import defaultdict


def list_directories(path):
    """Return the names of the sub-directories located directly under *path*."""
    return [
        entry
        for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    ]


def parse_by_regex(string):
    """Split a VARCO-style prompt back into its (source, instruction) parts.

    Two prompt layouts are recognised: one that carries an extra input
    ("### 입력:") section and one that does not.  Returns ("", instruction)
    for the source-less layout, and (None, None) when the string matches
    neither template.
    """
    # Template for prompts that include an input/source section.
    varco_template_w_src = r"아래는 작업을 설명하는 명령어와 추가적 맥락을 제공하는 입력이 짝을 이루는 예제입니다.\n주어진 입력에 대해 명령어를 적절히 수행하는 응답을 작성하세요.\n\n### 입력:\n(?P<source>.*?)\n\n### 명령어:\n(?P<instruction>.*?)\n\n### 응답:\n"
    # Template for prompts without an input/source section.
    varco_template_wo_src = r"아래는 작업을 설명하는 명령어입니다.\n명령어에 따른 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n(?P<instruction>.*?)\n\n### 응답:\n"

    # DOTALL lets the named groups span multiple lines.  Match each pattern
    # only once (the original compiled and matched every pattern twice).
    flags = re.MULTILINE | re.DOTALL
    match = re.match(varco_template_w_src, string, flags)
    if match:
        return match.group("source"), match.group("instruction")
    match = re.match(varco_template_wo_src, string, flags)
    if match:
        return "", match.group("instruction")
    return None, None


# Read the result.json file at `path` and turn it into preprocessed instances.
def result_file_process(model, task, path):
    """Load a result.json file and normalise each record.

    Each returned instance is a dict with keys model_id / task /
    instruction / source / generated.  Records whose prompt cannot be
    parsed are skipped with a diagnostic print.
    """
    with open(path, encoding="utf8") as f:
        instances = json.load(f)

    processed_instances = []
    for instance in instances:
        # NOTE: a falsy "input" value (e.g. "") is treated like a missing
        # key, preserving the original behaviour.
        raw = instance.get("input", False)
        if raw:
            source = instance["source"]
            instruction = instance["instruction"]
        else:
            # Fall back to parsing the raw prompt text.  Default to "" (not
            # False) so a missing "source" key cannot crash the regex match.
            raw = instance.get("source", "")
            source, instruction = parse_by_regex(raw)

        if source is None or instruction is None:
            print(f"PARSING ERROR IN MODEL {model} TASK {task} PATH {path} SRC {raw}")
        else:
            processed_instances.append(
                {
                    "model_id": model,
                    "task": task,
                    "instruction": instruction.strip(),
                    "source": source.strip(),
                    "generated": instance["generated_result"],
                }
            )
    return processed_instances


# Transform the raw results tree under `input_path` into per-model jsonl files.
def transform_results_folder(input_path, output_path, model_name_pattern, num_instance):
    """Aggregate, sample, and export generation results.

    Expects ``input_path/<task>/<model>/result.json``.  Only models whose
    directory name matches *model_name_pattern* are kept.  For each task,
    up to *num_instance* prompts answered by EVERY model are sampled at
    random, then one ``<model>.jsonl`` file per model is written to
    *output_path*.
    """
    regex_pattern = re.compile(model_name_pattern)

    tasks = list_directories(input_path)
    print(f"TASKS: {tasks}")
    # The model list may differ per task; this first-task listing is purely
    # informational.  Guard against an empty task list instead of crashing
    # on tasks[0] (original bug).
    first_task_models = (
        [
            model
            for model in list_directories(os.path.join(input_path, tasks[0]))
            if regex_pattern.match(model)
        ]
        if tasks
        else []
    )
    print(f"MODELS: {first_task_models}")

    model_results = _collect_model_results(input_path, tasks, regex_pattern)
    for k, v in model_results.items():
        print(f"# of instances in {k} is {len(v)}")

    new_results = _sample_common_instances(model_results, num_instance)
    _write_jsonl_per_model(new_results, output_path)


def _collect_model_results(input_path, tasks, regex_pattern):
    """Read every task/model result.json and group instances by model id."""
    model_results = defaultdict(list)
    for task in tasks:
        models = [
            model
            for model in list_directories(os.path.join(input_path, task))
            if regex_pattern.match(model)
        ]
        for model in models:
            result_path = os.path.join(input_path, task, model, "result.json")
            # Directory names often embed the task ("<model>-<task>-...");
            # strip that suffix so results from all tasks merge per model.
            model_name = model.split(f"-{task}-")[0] if task in model else model
            model_results[model_name] += result_file_process(model_name, task, result_path)
        print(f"{task} results processing is over..")
    return dict(model_results)


def _sample_common_instances(model_results, num_instance):
    """Per task, sample up to *num_instance* prompts answered by every model."""
    all_datasets = [obj for obj_list in model_results.values() for obj in obj_list]

    # Group instances by task, then by the exact prompt text, so we can tell
    # which prompts were answered by every model.
    dataset_by_task = defaultdict(lambda: defaultdict(list))
    for data in all_datasets:
        key = f"{data['instruction']}\n\n{data['source']}"
        dataset_by_task[data["task"]][key].append(data)

    new_results = {data["model_id"]: [] for data in all_datasets}
    num_model = len(new_results)
    for prompt_groups in dataset_by_task.values():
        # Keep only prompts answered by all models so the sampled sets stay
        # directly comparable across models.
        candidates = [
            group for group in prompt_groups.values() if len(group) == num_model
        ]
        random.shuffle(candidates)
        # A None num_instance slices to the full list, as in the original.
        for group in candidates[:num_instance]:
            for data in group:
                new_results[data["model_id"]].append(data)
    return new_results


def _write_jsonl_per_model(new_results, output_path):
    """Write one ``<model>.jsonl`` file per model under *output_path*."""
    os.makedirs(output_path, exist_ok=True)
    for model, instances in new_results.items():
        path = os.path.join(output_path, f"{model}.jsonl")
        with open(path, "w", encoding="utf8") as f_out:
            for instance in instances:
                f_out.write(json.dumps(instance, ensure_ascii=False) + "\n")

if __name__ == "__main__":
    # Command-line entry point: collect paths/options and run the transform.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-i", "--input_path", type=str, help="path of generated result directory"
    )
    cli.add_argument(
        "-o", "--output_path", type=str, help="path of processed result directory"
    )
    cli.add_argument(
        "-m",
        "--model_name_pattern",
        type=str,
        help="model name's pattern for regex",
        default="",
    )
    cli.add_argument(
        "-n", "--num_instance", type=int, help="number of instance to choice"
    )
    opts = cli.parse_args()
    transform_results_folder(
        opts.input_path, opts.output_path, opts.model_name_pattern, opts.num_instance
    )