# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import os
import time
import json
import logging
import gc
import torch
from pathlib import Path
from trt_llama_api import TrtLlmAPI
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from collections import defaultdict
from llama_index import ServiceContext
from llama_index.llms.llama_utils import messages_to_prompt, completion_to_prompt
from llama_index import set_global_service_context
from faiss_vector_storage import FaissEmbeddingStorage
from ui.user_interface import MainInterface

# Config locations relative to the working directory. Built with
# os.path.join so they resolve on any OS; on Windows the values are
# identical to 'config\\app_config.json' etc.
app_config_file = os.path.join('config', 'app_config.json')
model_config_file = os.path.join('config', 'config.json')
preference_config_file = os.path.join('config', 'preferences.json')
# Active retrieval source; UI handlers further down may reassign it
# (e.g. to "nodataset").
data_source = 'directory'

def read_config(file_name):
    """Load and return the JSON document stored at *file_name*.

    Failures are reported on stdout instead of being raised, so callers
    must be prepared for a ``None`` result.

    Args:
        file_name: Path of the JSON file to read.

    Returns:
        The decoded JSON value, or ``None`` if the file is missing,
        unreadable, or not valid JSON.
    """
    try:
        # JSON is UTF-8 by specification; don't depend on the platform's
        # locale encoding (often cp1252 on Windows). 'utf-8-sig' also
        # accepts files saved with a leading BOM (e.g. by Notepad).
        with open(file_name, 'r', encoding='utf-8-sig') as file:
            return json.load(file)
    except FileNotFoundError:
        print(f"The file {file_name} was not found.")
    except json.JSONDecodeError:
        print(f"There was an error decoding the JSON from the file {file_name}.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    return None

def get_model_config(config, model_name=None):
    """Resolve the runtime settings of one supported model.

    Args:
        config: Parsed model config; ``config["models"]["supported"]`` is a
            list of entries, each with ``name`` and ``metadata`` keys.
        model_name: Name of the desired model. When no entry matches (or it
            is ``None``), the first supported model is used.

    Returns:
        dict with ``model_path`` and ``tokenizer_path`` joined onto the
        current working directory, plus ``engine``, ``max_new_tokens``,
        ``max_input_token`` and ``temperature`` copied from the metadata.
    """
    candidates = config["models"]["supported"]
    chosen = candidates[0]
    for candidate in candidates:
        if candidate["name"] == model_name:
            chosen = candidate
            break
    meta = chosen["metadata"]
    cwd = os.getcwd()
    return {
        "model_path": os.path.join(cwd, meta["model_path"]),
        "engine": meta["engine"],
        "tokenizer_path": os.path.join(cwd, meta["tokenizer_path"]),
        "max_new_tokens": meta["max_new_tokens"],
        "max_input_token": meta["max_input_token"],
        "temperature": meta["temperature"],
    }

def get_data_path(config):
    """Return ``config["dataset"]["path"]`` joined onto the working directory.

    An absolute dataset path is returned unchanged (``os.path.join`` rule).
    """
    relative = config["dataset"]["path"]
    return os.path.join(os.getcwd(), relative)

# read the app specific config
app_config = read_config(app_config_file)
if app_config is None:
    # read_config has already printed the cause; stop with a clear message
    # instead of an opaque TypeError on the first lookup below.
    raise RuntimeError(f"Unable to load the app config file {app_config_file}")
streaming = app_config["streaming"]
similarity_top_k = app_config["similarity_top_k"]
is_chat_engine = app_config["is_chat_engine"]
embedded_model_name = app_config["embedded_model"]
embedded_model = os.path.join(os.getcwd(), "model", embedded_model_name)
embedded_dimension = app_config["embedded_dimension"]

# read model specific config
selected_model_name = None
selected_data_directory = None
config = read_config(model_config_file)
if config is None:
    raise RuntimeError(f"Unable to load the model config file {model_config_file}")
if os.path.exists(preference_config_file):
    perf_config = read_config(preference_config_file)
    # An existing but unreadable/corrupt preferences file yields None; treat
    # it as "no saved preferences" instead of crashing on .get().
    if isinstance(perf_config, dict):
        selected_model_name = perf_config.get('models', {}).get('selected')
        selected_data_directory = perf_config.get('dataset', {}).get('path')

# Saved preferences take precedence; otherwise fall back to the model config.
if selected_model_name is None:
    selected_model_name = config["models"].get("selected")

model_config = get_model_config(config, selected_model_name)
trt_engine_path = model_config["model_path"]
trt_engine_name = model_config["engine"]
tokenizer_dir_path = model_config["tokenizer_path"]
if selected_data_directory is None:
    data_dir = config["dataset"]["path"]
else:
    data_dir = selected_data_directory

# create trt_llm engine object
llm = TrtLlmAPI(
    model_path=model_config["model_path"],
    engine_name=model_config["engine"],
    tokenizer_dir=model_config["tokenizer_path"],
    temperature=model_config["temperature"],
    max_new_tokens=model_config["max_new_tokens"],
    context_window=model_config["max_input_token"],
    messages_to_prompt=messages_to_prompt,
    completion_to_prompt=completion_to_prompt,
    verbose=False
)

# create embeddings model object
embed_model = HuggingFaceEmbeddings(model_name=embedded_model)
service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model,
                                               context_window=model_config["max_input_token"], chunk_size=512,
                                               chunk_overlap=200)
set_global_service_context(service_context)


def generate_inferance_engine(data, force_rewrite=False):
    """Build the FAISS index for *data* and install an inference engine on it.

    On success the module-level ``faiss_storage`` and ``engine`` globals are
    replaced together. On failure both keep their previous values.

    Args:
        data: The directory where the data for the inference engine is located.
        force_rewrite (bool): If True, force rewriting the index.

    Returns:
        The initialized inference engine (also stored in the ``engine`` global).

    Raises:
        RuntimeError: If unable to generate the inference engine.
    """
    # faiss_storage must be published as a global as well: reset_chat_handler
    # and on_shutdown_handler refer to it at module scope.
    global engine, faiss_storage
    try:
        storage = FaissEmbeddingStorage(data_dir=data,
                                        dimension=embedded_dimension)
        storage.initialize_index(force_rewrite=force_rewrite)
        new_engine = storage.get_engine(is_chat_engine=is_chat_engine, streaming=streaming,
                                        similarity_top_k=similarity_top_k)
    except Exception as e:
        raise RuntimeError(f"Unable to generate the inference engine: {e}") from e
    # Commit only after every step succeeded so storage and engine stay paired.
    faiss_storage = storage
    engine = new_engine
    return engine


# load the vectorstore index
generate_inferance_engine(data_dir)

def call_llm_streamed(query):
    """Stream a plain LLM completion for *query*, with no retrieval step.

    Yields:
        The cumulative response text after each streamed chunk, so each
        yielded string extends the previous one.
    """
    pieces = []
    for chunk in llm.stream_complete(query):
        pieces.append(chunk.delta)
        yield "".join(pieces)

def chatbot(query, chat_history, session_id):
    """Answer *query* in a single response (non-streaming UI mode).

    If ``data_source`` is ``"nodataset"``, the bare LLM answers. Otherwise the
    retrieval engine answers, and the source file with the highest summed
    retrieval score is appended as a reference. If no retrieved chunk names
    a file, the bare LLM answers instead.

    Args:
        query: The user's question.
        chat_history: Part of the UI callback signature; unused here.
        session_id: Part of the UI callback signature; unused here.

    Yields:
        Exactly one string: the complete response, using ``<br>`` separators.
    """
    if data_source == "nodataset":
        yield llm.complete(query).text
        return

    response = engine.chat(query) if is_chat_engine else engine.query(query)

    # Aggregate scores by file. A retrieved node may carry no score (None);
    # count it as 0 rather than failing on `float += None`.
    file_scores = defaultdict(float)
    for node in response.source_nodes:
        metadata = node.metadata
        if 'filename' in metadata:
            file_scores[metadata['filename']] += node.score or 0.0

    top_file = max(file_scores, key=file_scores.get, default=None)
    if not top_file:
        # Nothing citable was retrieved: answer from the bare LLM.
        yield llm.complete(query).text
        return

    response_txt = str(response)
    if data_source == 'directory':
        # Cite only the base name; backslashes are normalised first so
        # Windows-style paths in the metadata split correctly.
        file_name = Path(os.getcwd(), top_file.replace('\\', '/')).name
        response_txt += "<br>Reference files:<br>" + file_name
    else:
        # Unknown source type: answer without a reference instead of
        # terminating the whole process from inside a chat request.
        print("Wrong data_source type", data_source)
    yield response_txt

def stream_chatbot(query, chat_history, session_id):
    """Stream an answer to *query* (streaming UI mode).

    If ``data_source`` is ``"nodataset"``, the bare LLM streams the answer.
    Otherwise the retrieval engine streams it. If retrieval returned no
    source nodes, the bare LLM answers instead. After the engine's answer,
    the file with the highest summed retrieval score is appended as a
    reference.

    Args:
        query: The user's question.
        chat_history: Part of the UI callback signature; unused here.
        session_id: Part of the UI callback signature; unused here.

    Yields:
        The cumulative response text. The last value may carry a
        "Reference files" suffix.
    """
    if data_source == "nodataset":
        yield from call_llm_streamed(query)
        return

    try:
        if is_chat_engine:
            response = engine.stream_chat(query)
        else:
            response = engine.query(query)

        partial_response = ""
        if len(response.source_nodes) == 0:
            for token in llm.stream_complete(query):
                partial_response += token.delta
                yield partial_response
        else:
            # Aggregate scores by file; a node without a score counts as 0.
            file_scores = defaultdict(float)
            for node in response.source_nodes:
                if 'filename' in node.metadata:
                    file_scores[node.metadata['filename']] += node.score or 0.0

            # Find the file with the highest aggregated score
            highest_score_file = max(file_scores, key=file_scores.get, default=None)

            for token in response.response_gen:
                partial_response += token
                yield partial_response
                # Fixed per-token delay; presumably for UI pacing.
                time.sleep(0.05)

            time.sleep(0.2)

            if highest_score_file:
                if data_source == 'directory':
                    file_name = Path(os.getcwd(), highest_score_file.replace('\\', '/')).name
                    partial_response += "<br>Reference files:<br>" + file_name
                else:
                    # Unknown source type: finish without a reference rather
                    # than terminating the whole process.
                    print("Wrong data_source type", data_source)
            yield partial_response
    finally:
        # call garbage collector after inference -- in a finally block so it
        # also runs when the consumer stops iterating early (generator close).
        torch.cuda.empty_cache()
        gc.collect()

# Hand the UI the chat handler that matches the configured response mode.
if streaming:
    interface = MainInterface(chatbot=stream_chatbot, streaming=streaming)
else:
    interface = MainInterface(chatbot=chatbot, streaming=streaming)

def on_shutdown_handler(session_id):
    """Unload the TensorRT-LLM model when the UI shuts down.

    ``llm`` is rebound to ``None`` rather than deleted. A repeated call, or
    any later ``llm is not None`` check elsewhere, therefore cannot raise
    NameError.

    Args:
        session_id: Part of the UI callback signature; unused here.
    """
    global llm
    if llm is not None:
        llm.unload_model()
        llm = None
    # Force a garbage collection cycle
    gc.collect()


interface.on_shutdown(on_shutdown_handler)


def reset_chat_handler(session_id):
    """Reset the chat engine when the UI clears a conversation.

    Does nothing unless the app is configured with a chat engine
    (``is_chat_engine``).

    Args:
        session_id: Identifier supplied by the UI; only printed.
    """
    print('reset chat called', session_id)
    if not is_chat_engine:
        return
    # faiss_storage exists at module scope only if it was published as a
    # global; look it up defensively instead of raising NameError.
    storage = globals().get('faiss_storage')
    if storage is None:
        print('reset chat skipped: no FAISS storage available', session_id)
        return
    storage.reset_engine(engine)


interface.on_reset_chat(reset_chat_handler)


def on_dataset_path_updated_handler(source, new_directory, video_count, session_id):
    """Rebuild the index when the UI selects a different data directory.

    Args:
        source: Dataset source type; only 'directory' triggers a rebuild.
        new_directory: The newly selected data directory.
        video_count: Supplied by the UI; only printed.
        session_id: Supplied by the UI; only printed.
    """
    global engine
    global data_dir
    print('data set path updated to ', source, new_directory, video_count, session_id)
    # Guard clause: ignore other source types and unchanged directories.
    if source != 'directory' or data_dir == new_directory:
        return
    data_dir = new_directory
    generate_inferance_engine(data_dir)


interface.on_dataset_path_updated(on_dataset_path_updated_handler)

def on_model_change_handler(model, metadata, session_id):
    """Swap the active TensorRT-LLM model for the one selected in the UI.

    Steps:
    - Unload the previous model.
    - Build a new ``TrtLlmAPI`` from *metadata*.
    - Rebuild the global service context around it.
    - Regenerate the inference engine for the current ``data_dir``.

    Args:
        model: Selected model name; not used here.
        metadata: Model metadata dict. ``model_path``, ``engine`` and
            ``tokenizer_path`` are required; ``temperature``,
            ``max_new_tokens`` and ``max_input_token`` default to 0.1, 512
            and 512.
        session_id: Part of the UI callback signature; unused here.
    """
    global llm, service_context

    # Validate the raw values *before* joining them onto the cwd:
    # os.path.join(cwd, None) raises TypeError, and a joined path is never
    # empty, so validation has to happen on the raw strings. All checks run
    # before the current model is unloaded.
    raw_model_path = metadata.get('model_path')
    engine_name = metadata.get('engine')
    raw_tokenizer_path = metadata.get('tokenizer_path')
    if not raw_model_path or not engine_name:
        print("Model path or engine not provided in metadata")
        return
    if not raw_tokenizer_path:
        print("Tokenizer path not provided in metadata")
        return
    model_path = os.path.join(os.getcwd(), raw_model_path)
    tokenizer_path = os.path.join(os.getcwd(), raw_tokenizer_path)

    if llm is not None:
        llm.unload_model()
        # Rebind instead of `del` so the global stays defined even if the
        # constructor below raises.
        llm = None

    llm = TrtLlmAPI(
        model_path=model_path,
        engine_name=engine_name,
        tokenizer_dir=tokenizer_path,
        temperature=metadata.get('temperature', 0.1),
        max_new_tokens=metadata.get('max_new_tokens', 512),
        context_window=metadata.get('max_input_token', 512),
        messages_to_prompt=messages_to_prompt,
        completion_to_prompt=completion_to_prompt,
        verbose=False
    )
    service_context = ServiceContext.from_service_context(service_context=service_context, llm=llm)
    set_global_service_context(service_context)
    generate_inferance_engine(data_dir)


interface.on_model_change(on_model_change_handler)


def on_dataset_source_change_handler(source, path, session_id):
    """Switch between retrieval over a directory and plain LLM chat.

    The new *source* always becomes the global ``data_source``. For
    ``"nodataset"`` nothing else happens. For ``"directory"``, *path* becomes
    ``data_dir`` and the index is rebuilt. For any other value a warning is
    printed and the index is rebuilt from the previous ``data_dir``.

    Args:
        source: The selected dataset source type.
        path: Data directory to use when *source* is ``"directory"``.
        session_id: Supplied by the UI; only printed.
    """
    global data_source, data_dir, engine
    data_source = source

    if source == "nodataset":
        print(' No dataset source selected', session_id)
        return

    print('dataset source updated ', source, path, session_id)

    if source != "directory":
        print("Wrong data type selected")
    else:
        data_dir = path
    generate_inferance_engine(data_dir)

interface.on_dataset_source_updated(on_dataset_source_change_handler)

def handle_regenerate_index(source, path, session_id):
    """Rebuild the index for *path*, forcing the index to be rewritten.

    Args:
        source: Dataset source type; only printed.
        path: Directory whose index is rebuilt.
        session_id: Supplied by the UI; only printed.
    """
    generate_inferance_engine(data=path, force_rewrite=True)
    print("on regenerate index", source, path, session_id)

interface.on_regenerate_index(handle_regenerate_index)
# render the interface (kept last, presumably because render() runs the UI)
interface.render()