import os
import threading
from typing import List

import streamlit as st
import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

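# Load the Spanish GPT-2 model trained on the BNE (National Library of Spain) corpus.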
tokenizer = AutoTokenizer.from_pretrained("PlanTL-GOB-ES/gpt2-large-bne")
model = AutoModelForCausalLM.from_pretrained("PlanTL-GOB-ES/gpt2-large-bne")

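# GPT-2 checkpoints ship without a dedicated padding token, so reuse the EOS token.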
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id

app = FastAPI()

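# Request schema: the couple's profile plus the user's question.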
class Pareja(BaseModel):
    nombre_novio: str
    nombre_novia: str
    historia: str
    frases: List[str]
    anecdotas: List[str]

class ChatRequest(BaseModel):
    pareja: Pareja
    pregunta: str

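# Chat endpoint: builds a prompt from the couple's profile and generates a reply.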
@app.post("/chat")
def chat(msg: ChatRequest):
    nombre_novio = msg.pareja.nombre_novio
    nombre_novia = msg.pareja.nombre_novia
    historia = msg.pareja.historia
    frases = ", ".join(msg.pareja.frases)
    anecdotas = " ".join(msg.pareja.anecdotas)
    pregunta = msg.pregunta

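    # Assemble the prompt in Spanish, since the model generates Spanish text.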
    input_text = f"""
Eres un asistente virtual especializado en responder preguntas sobre la pareja {nombre_novio} y {nombre_novia}. Usa la siguiente información para responder de manera precisa y cercana:

Historia de la pareja:
{historia}

Frases memorables de la pareja:
{frases}

Momentos especiales juntos:
{anecdotas}

Pregunta: {pregunta}
Respuesta corta y precisa:
"""

    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
        padding=True,
        return_attention_mask=True,
    )

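    # Sample a reply: nucleus/top-k sampling with a mild repetition penalty.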
    with torch.no_grad():  # inference only; no gradients needed
        response_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=150,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            temperature=0.8,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens; otherwise the reply echoes the whole prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    response_text = tokenizer.decode(
        response_ids[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # Fall back to a canned reply when the generation is too short to be useful.
    if len(response_text) < 10:
        return {
            "response": "Lo siento, necesito reformular la respuesta. ¿Podrías hacer la pregunta de otra manera?",
            "contexto": input_text,
        }

    return {"response": response_text, "contexto": input_text}

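# Serve the FastAPI app from a daemon thread so it runs alongside the Streamlit UI.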
def run_api():
    # PORT is provided by the host environment (e.g. Hugging Face Spaces exposes 7860).
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)


threading.Thread(target=run_api, daemon=True).start()

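# Streamlit front end; it calls chat() in-process rather than over HTTP.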
st.title("Mi Amigo Virtual 🤖")
st.write("¡Escríbeme algo y te responderé!")

contexto_defecto = "Historia de amor"
user_input = st.text_input("Tú:")

if user_input:
    # Build a minimal ChatRequest for the demo; the Pareja values below are
    # placeholder defaults, not data from the original source.
    pareja_defecto = Pareja(
        nombre_novio="Novio",
        nombre_novia="Novia",
        historia=contexto_defecto,
        frases=[],
        anecdotas=[],
    )
    response = chat(ChatRequest(pareja=pareja_defecto, pregunta=user_input))
    st.write("🤖:", response["response"])