import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer,  StoppingCriteria, StoppingCriteriaList, GenerationConfig
import os

#sft_model = "somosnlp/RecetasDeLaAbuela_mistral-7b-instruct-v0.2-bnb-4bit"
#base_model_name = "unsloth/Mistral-7B-Instruct-v0.2"
#sft_model = "somosnlp/RecetasDeLaAbuela_gemma-2b-it-bnb-4bit"
#base_model_name = "unsloth/gemma-2b-it-bnb-4bit"

sft_model = "somosnlp/RecetasDeLaAbuela5k_gemma-2b-bnb-4bit"
base_model_name = "unsloth/gemma-2b-bnb-4bit"
#base_model_name = "unsloth/gemma-2b-it-bnb-4bit"

max_seq_length=300
base_model = AutoModelForCausalLM.from_pretrained(base_model_name,return_dict=True,device_map="auto", torch_dtype=torch.float16,)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, max_length = max_seq_length)
ft_model = PeftModel.from_pretrained(base_model, sft_model)
model = ft_model.merge_and_unload()
model.save_pretrained(".")
tokenizer.save_pretrained(".")

class ListOfTokensStoppingCriteria(StoppingCriteria):
    """
    Clase para definir un criterio de parada basado en una lista de tokens específicos.
    """
    def __init__(self, tokenizer, stop_tokens):
        self.tokenizer = tokenizer
        # Codifica cada token de parada y guarda sus IDs en una lista
        self.stop_token_ids_list = [tokenizer.encode(stop_token, add_special_tokens=False) for stop_token in stop_tokens]

    def __call__(self, input_ids, scores, **kwargs):
        # Verifica si los últimos tokens generados coinciden con alguno de los conjuntos de tokens de parada
        for stop_token_ids in self.stop_token_ids_list:
            len_stop_tokens = len(stop_token_ids)
            if len(input_ids[0]) >= len_stop_tokens:
                if input_ids[0, -len_stop_tokens:].tolist() == stop_token_ids:
                    return True
        return False

# Uso del criterio de parada personalizado
stop_tokens = ["<end_of_turn>"]  # Lista de tokens de parada

# Inicializa tu criterio de parada con el tokenizer y la lista de tokens de parada
stopping_criteria = ListOfTokensStoppingCriteria(tokenizer, stop_tokens)

# Añade tu criterio de parada a una StoppingCriteriaList
stopping_criteria_list = StoppingCriteriaList([stopping_criteria])

def generate_text(prompt, context, max_length=max_seq_length):
  prompt=prompt.replace("\n", "").replace("¿","").replace("?","")
  input_text = f'''<bos><start_of_turn>system\n{context}<end_of_turn><start_of_turn>user\n{prompt}<end_of_turn><start_of_turn>model\n'''
  inputs = tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=False).to("cuda:0")
  max_new_tokens=max_length
  generation_config = GenerationConfig(
                max_new_tokens=max_new_tokens,
                temperature=0.1, #top_p=0.9,top_k=50, # 45
                repetition_penalty=1.3,  #1.1
                do_sample=True,
            )
  outputs = model.generate(generation_config=generation_config, input_ids=inputs, stopping_criteria=stopping_criteria_list,)
  return tokenizer.decode(outputs[0], skip_special_tokens=False) #True

def mostrar_respuesta(pregunta, contexto):
    try:
      res= generate_text(pregunta, contexto, max_length=max_seq_length)
      return str(res)
    except Exception as e:
      return str(e)

# Ejemplos de preguntas
mis_ejemplos = [
    ["Ingredientes y pasos de la receta asado de cordero", "Eres un agente experto en nutrición y cocina."],
    ["Ingredientes y pasos de la receta lomo a la pimienta con papas", "Eres un agente experto en nutrición y cocina."],
    ["Ingredientes de la receta coles de bruselas", "Eres un agente experto en nutrición y cocina."],
]

iface = gr.Interface(
    fn=mostrar_respuesta,
    inputs=[gr.Textbox(label="Pregunta"), gr.Textbox(label="Contexto", value="Eres un agente experto en nutrición y cocina."),],
    outputs=[gr.Textbox(label="Respuesta", lines=4),],
    title="Recetas de la Abuel@",
    description=f'Esta aplicación RecetasDeLaAbuel@ es una demostración de un asistente inteligente de cocina especializado en el idioma español. '
    f'Está basada en el corpus https://huggingface.co/datasets/somosnlp/RecetasDeLaAbuela y utiliza el modelo '
    f'https://huggingface.co/somosnlp/RecetasDeLaAbuela_gemma-2b-it-bnb-4bit . En las siguientes entradas indica tu pregunta sobre una comida o receta de cocina '
    f'(ingredientes o pasos de preparación) y en el contexto añade el escenario de uso (por defecto, agente experto en nutrición y cocina). '
    f'Introduce tu pregunta sobre recetas de cocina.',
    examples=mis_ejemplos,
)

iface.queue(max_size=14).launch() # share=True,debug=True