iabodas / app2.py
Garabatos's picture
1
5d734b8
raw
history blame
3.56 kB
import os
import streamlit as st
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import uvicorn
import threading
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# ======== Cargar el modelo GPT-2 en español =========
tokenizer = AutoTokenizer.from_pretrained("PlanTL-GOB-ES/gpt2-large-bne")
model = AutoModelForCausalLM.from_pretrained("PlanTL-GOB-ES/gpt2-large-bne")
# Configurar el padding token y el modelo
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
# ======== Definir API con FastAPI =========
app = FastAPI()
class Pareja(BaseModel):
nombre_novio: str
nombre_novia: str
historia: str
frases: List[str]
anecdotas: List[str]
# Definir el modelo de solicitud que recibirá el JSON
class ChatRequest(BaseModel):
pareja: Pareja
pregunta: str
@app.post("/chat")
def chat(msg: ChatRequest):
nombre_novio = msg.pareja.nombre_novio
nombre_novia = msg.pareja.nombre_novia
historia = msg.pareja.historia
frases = ", ".join(msg.pareja.frases)
anecdotas = " ".join(msg.pareja.anecdotas)
pregunta = msg.pregunta
# Mejorar el formato del prompt con instrucciones claras
input_text = f"""
Eres un asistente virtual especializado en responder preguntas sobre la pareja {nombre_novio} y {nombre_novia}. Usa la siguiente información para responder de manera precisa y cercana:
Historia de la pareja:
{historia}
Frases memorables de la pareja:
{frases}
Momentos especiales juntos:
{anecdotas}
Pregunta: {pregunta}
Respuesta corta y precisa:
"""
# Tokenización con parámetros ajustados
inputs = tokenizer.encode_plus(
input_text,
return_tensors="pt",
max_length=1024,
truncation=True,
padding=True,
return_attention_mask=True
)
# Parámetros ajustados para generación de texto más coherente
response_ids = model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=150, # Aumentar el número de tokens generados
num_return_sequences=1,
no_repeat_ngram_size=2,
do_sample=True, # Cambiar a True para generar texto con variabilidad
top_p=0.9,
top_k=50,
temperature=0.8, # Aumento de la temperatura para una generación más flexible
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
early_stopping=True,
repetition_penalty=1.2 # Penalización de repetición
)
# Procesar la respuesta
response_text = tokenizer.decode(response_ids[0], skip_special_tokens=True).strip()
# Verificación básica
if len(response_text.strip()) < 10:
return {
"response": "Lo siento, necesito reformular la respuesta. ¿Podrías hacer la pregunta de otra manera?",
"contexto": input_text
}
return {"response": response_text, "contexto": input_text}
# ======== Función para ejecutar FastAPI en segundo plano =========
def run_api():
port = int(os.getenv("PORT", 7860))
uvicorn.run(app, host="0.0.0.0", port=port)
threading.Thread(target=run_api, daemon=True).start()
# ======== Interfaz con Streamlit =========
st.title("Mi Amigo Virtual 🤖")
st.write("Escríbeme algo y te responderé!")
contexto_defecto = "Historia de amor"
user_input = st.text_input("Tú:")
if user_input:
response = chat(Message(contexto=contexto_defecto, text=user_input))
st.write("🤖:", response["response"])