File size: 3,559 Bytes
5d734b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import os
import streamlit as st
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import uvicorn
import threading
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# ======== Cargar el modelo GPT-2 en español =========
tokenizer = AutoTokenizer.from_pretrained("PlanTL-GOB-ES/gpt2-large-bne")
model = AutoModelForCausalLM.from_pretrained("PlanTL-GOB-ES/gpt2-large-bne")
# Configurar el padding token y el modelo
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
# ======== Definir API con FastAPI =========
app = FastAPI()
class Pareja(BaseModel):
nombre_novio: str
nombre_novia: str
historia: str
frases: List[str]
anecdotas: List[str]
# Definir el modelo de solicitud que recibirá el JSON
class ChatRequest(BaseModel):
pareja: Pareja
pregunta: str
@app.post("/chat")
def chat(msg: ChatRequest):
nombre_novio = msg.pareja.nombre_novio
nombre_novia = msg.pareja.nombre_novia
historia = msg.pareja.historia
frases = ", ".join(msg.pareja.frases)
anecdotas = " ".join(msg.pareja.anecdotas)
pregunta = msg.pregunta
# Mejorar el formato del prompt con instrucciones claras
input_text = f"""
Eres un asistente virtual especializado en responder preguntas sobre la pareja {nombre_novio} y {nombre_novia}. Usa la siguiente información para responder de manera precisa y cercana:
Historia de la pareja:
{historia}
Frases memorables de la pareja:
{frases}
Momentos especiales juntos:
{anecdotas}
Pregunta: {pregunta}
Respuesta corta y precisa:
"""
# Tokenización con parámetros ajustados
inputs = tokenizer.encode_plus(
input_text,
return_tensors="pt",
max_length=1024,
truncation=True,
padding=True,
return_attention_mask=True
)
# Parámetros ajustados para generación de texto más coherente
response_ids = model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=150, # Aumentar el número de tokens generados
num_return_sequences=1,
no_repeat_ngram_size=2,
do_sample=True, # Cambiar a True para generar texto con variabilidad
top_p=0.9,
top_k=50,
temperature=0.8, # Aumento de la temperatura para una generación más flexible
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
early_stopping=True,
repetition_penalty=1.2 # Penalización de repetición
)
# Procesar la respuesta
response_text = tokenizer.decode(response_ids[0], skip_special_tokens=True).strip()
# Verificación básica
if len(response_text.strip()) < 10:
return {
"response": "Lo siento, necesito reformular la respuesta. ¿Podrías hacer la pregunta de otra manera?",
"contexto": input_text
}
return {"response": response_text, "contexto": input_text}
# ======== Función para ejecutar FastAPI en segundo plano =========
def run_api():
port = int(os.getenv("PORT", 7860))
uvicorn.run(app, host="0.0.0.0", port=port)
threading.Thread(target=run_api, daemon=True).start()
# ======== Interfaz con Streamlit =========
st.title("Mi Amigo Virtual 🤖")
st.write("Escríbeme algo y te responderé!")
contexto_defecto = "Historia de amor"
user_input = st.text_input("Tú:")
if user_input:
response = chat(Message(contexto=contexto_defecto, text=user_input))
st.write("🤖:", response["response"]) |