import gradio as gr
from transformers import pipeline
import torch
import logging

# Настройка логирования
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Загружаем модель
model_name = "sberbank-ai/rugpt3large_based_on_gpt2"
try:
    logger.info(f"Попытка загрузки модели {model_name}...")
    generator = pipeline(
        "text-generation",
        model=model_name,
        device=-1,  # Используем CPU
        framework="pt",
        max_length=80,  # Уменьшен для стабильности на CPU
        truncation=True,
        model_kwargs={"torch_dtype": torch.float32}
    )
    logger.info("Модель успешно загружена.")
except Exception as e:
    logger.error(f"Ошибка загрузки модели: {e}")
    exit(1)

def respond(message, max_tokens=80, temperature=0.5, top_p=0.7):
    # Промпт с акцентом на медицинский ответ
    prompt = f"Вы медицинский чат-бот. Пользователь говорит: '{message}'. Дайте краткий ответ только с диагнозом и лечением на русском языке в формате: Диагноз: [диагноз]. Лечение: [лечение]."
    try:
        logger.info(f"Генерация ответа для: {message}")
        outputs = generator(
            prompt,
            max_length=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            num_return_sequences=1,
            no_repeat_ngram_size=2  # Предотвращаем повторы
        )
        response = outputs[0]["generated_text"].replace(prompt, "").strip()
        logger.info(f"Ответ сгенерирован: {response}")

        # Проверка и форматирование ответа
        if "Диагноз:" in response and "Лечение:" in response:
            return response
        else:
            # Если формат не соблюден, извлекаем диагноз и добавляем базовое лечение
            diagnosis = response.split(".")[0].strip() if response else "Неизвестно"
            return f"Диагноз: {diagnosis}. Лечение: Обратитесь к врачу для точной помощи."
    except Exception as e:
        logger.error(f"Ошибка генерации ответа: {e}")
        return "Ошибка генерации. Проконсультируйтесь с врачом."

demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Ваше сообщение", placeholder="Опишите симптомы (например, 'Болит горло')..."),
        gr.Slider(minimum=50, maximum=150, value=80, step=10, label="Макс. токенов"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Температура"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Top-p")
    ],
    outputs="text",
    title="Медицинский чат-бот на базе RuGPT-3 Large",
    theme=gr.themes.Soft(),
    description="Введите симптомы, и чат-бот предложит диагноз и лечение. Для точной помощи обратитесь к врачу."
)

if __name__ == "__main__":
    demo.launch()