import streamlit as st
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import pandas as pd

# ✅ Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ✅ Load tokenizer from repo files, cached across Streamlit reruns
@st.cache_resource(show_spinner=False)
def load_tokenizer():
    return AutoTokenizer.from_pretrained(".")  # expects tokenizer files in the app directory

tokenizer = load_tokenizer()

# ✅ Define ScoringModel with safe DeBERTa load
class ScoringModel(nn.Module):
    def __init__(self, base_model_name="microsoft/deberta-v3-small", dropout_rate=0.242):
        super().__init__()
        self.base = AutoModel.from_pretrained(
            base_model_name,
            torch_dtype=torch.float32,         # load weights in full fp32 precision
            low_cpu_mem_usage=False            # materialize all weights (avoid meta tensors)
        )
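        # Gradient checkpointing only affects training; it is inert under eval()/no_grad,
        # so leaving it enabled here is harmless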
        self.base.gradient_checkpointing_enable()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.dropout3 = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.base.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
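        # Pool the sequence with the final hidden state of the first ([CLS]) token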
        hidden = self.base(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
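        # Multi-sample dropout: average the classifier head over three dropout masks.
        # This regularizes training; in eval() the dropouts are identity, so the
        # average reduces to a single forward pass.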
        logits = (self.classifier(self.dropout1(hidden)) +
                  self.classifier(self.dropout2(hidden)) +
                  self.classifier(self.dropout3(hidden))) / 3
        return logits

# ✅ Load model and weights safely, cached so Streamlit reruns don't reload the checkpoint
@st.cache_resource(show_spinner=False)
def load_model():
    model = ScoringModel()
    # weights_only=True restricts torch.load to tensors/containers, the safe choice for a state dict
    state_dict = torch.load("scoring_model.pt", map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # disables dropout so scoring is deterministic
    return model

model = load_model()

# ✅ Streamlit UI setup
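# (st.set_page_config must run before anything that draws to the page;
#  the cached loaders above render nothing, so this ordering is valid)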
st.set_page_config(page_title="🧠 LLM Response Evaluator", page_icon="📝", layout="wide")
st.markdown("<h1 style='text-align: center;'>🧠 LLM Response Evaluator</h1>", unsafe_allow_html=True)
st.markdown("---")

# ✅ Sidebar info
with st.sidebar:
    st.header("ℹ️ About")
    st.markdown("""
    This app evaluates *which AI response is better* given a prompt.

    - Enter a **prompt** and two **responses**
    - The model predicts **which one is higher quality**

    Powered by a fine-tuned **DeBERTa-v3-small** model 🚀
    """)

# ✅ Input form
col1, col2 = st.columns(2)

with col1:
    prompt = st.text_area("📝 Enter the Prompt", height=150)

with col2:
    st.markdown("<br>", unsafe_allow_html=True)
    st.markdown("👉 Provide two possible responses below:")

response_a = st.text_area("✏️ Response A", height=100)
response_b = st.text_area("✏️ Response B", height=100)

# ✅ Prediction
if st.button("🔍 Evaluate Responses"):
    if prompt.strip() and response_a.strip() and response_b.strip():
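        # Join prompt and response with "[SEP]", presumably mirroring the input format used at fine-tuning time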
        text_a = f"Prompt: {prompt} [SEP] {response_a}"
        text_b = f"Prompt: {prompt} [SEP] {response_b}"

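        # max_length=186 presumably matches the fine-tuning setup; padding='max_length'
        # keeps tensor shapes identical across inputs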
        encoded_a = tokenizer(text_a, return_tensors='pt', padding='max_length', truncation=True, max_length=186)
        encoded_b = tokenizer(text_b, return_tensors='pt', padding='max_length', truncation=True, max_length=186)

        # Keep only the fields the model's forward() accepts and move them to the device
        # (the tokenizer also returns token_type_ids, which the model does not take)
        encoded_a = {k: encoded_a[k].to(device) for k in ("input_ids", "attention_mask")}
        encoded_b = {k: encoded_b[k].to(device) for k in ("input_ids", "attention_mask")}

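        # Score each response independently; inference only, so skip gradient tracking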
        with torch.no_grad():
            score_a = model(**encoded_a).squeeze()
            score_b = model(**encoded_b).squeeze()

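        # Sigmoid maps each raw logit to (0, 1); the two scores are independent,
        # not a softmax over the pair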
        prob_a = torch.sigmoid(score_a).item()
        prob_b = torch.sigmoid(score_b).item()

        st.subheader("🔮 Prediction Result")
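        # Note: ties (prob_a == prob_b) fall through to the Response A branch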
        if prob_b > prob_a:
            st.success(f"✅ *Response B is better!* (Confidence: {prob_b:.4f})")
        else:
            st.success(f"✅ *Response A is better!* (Confidence: {prob_a:.4f})")

        mcol1, mcol2 = st.columns(2)
        mcol1.metric(label="Confidence A", value=f"{prob_a:.4f}")
        mcol2.metric(label="Confidence B", value=f"{prob_b:.4f}")

        st.markdown("---")
        st.subheader("📊 Confidence Comparison")
        chart_df = pd.DataFrame({"Confidence": [prob_a, prob_b]}, index=["Response A", "Response B"])
        st.bar_chart(chart_df)
    else:
        st.warning("⚠️ Please fill in *all fields* before evaluating!")