File size: 1,294 Bytes
3c438f7
 
 
 
2a469e6
3c438f7
6f7e263
 
 
 
 
 
 
3c438f7
 
 
 
 
6f7e263
 
3c438f7
 
6f7e263
3c438f7
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from transformers import pipeline
import gradio as gr

# Load the escalation-prediction text-classification model from the Hugging Face Hub.
# `return_all_scores=True` asks the pipeline for a score for every label, not just the top one.
# NOTE(review): newer transformers releases reportedly deprecate `return_all_scores`
# in favour of `top_k=None`. Confirm the output nesting against the installed
# version before switching, since `predict_effect` indexes `classifier(text)[0]`.
classifier = pipeline(
    task="text-classification",
    model="SamanthaStorm/EscalationPrediction",
    return_all_scores=True,
)

# Readable escalation-effect names, indexed by the model's label number:
# position i in this tuple is the display name for the raw label "LABEL_i".
_EFFECT_NAMES = ("de-escalate", "escalate", "neutral")

# Map the model's generic label IDs to the readable names above.
label_map = {f"LABEL_{idx}": name for idx, name in enumerate(_EFFECT_NAMES)}

def predict_effect(user_msg, *replies):
    """Score each non-empty reply option for its likely escalation effect.

    Each reply is paired with ``user_msg`` as ``"<user_msg> <sep> <reply>"``
    and passed through ``classifier``. The result reports the highest-scoring
    label plus the percentage for every label the model returned.

    Args:
        user_msg: The message being replied to.
        *replies: Candidate replies, one per "Reply Option" textbox, in order.
            Blank, whitespace-only, or ``None`` entries are skipped.

    Returns:
        A Markdown string with one section per scored reply, sections
        separated by a blank line. Returns an empty string if no reply was
        provided.
    """
    sections = []
    # Enumerate BEFORE skipping blanks so each section keeps the number of the
    # textbox it came from. Filling options 1 and 3 must report "1" and "3",
    # not "1" and "2".
    for option_no, reply in enumerate(replies, start=1):
        # Guard against None as well as blank strings; None.strip() would raise.
        if not reply or not reply.strip():
            continue
        text = f"{user_msg} <sep> {reply}"
        scores_raw = classifier(text)[0]
        # Fall back to the raw model label when it is not in label_map,
        # instead of failing the whole request with a KeyError.
        scores = {
            label_map.get(s["label"], s["label"]): round(s["score"] * 100, 2)
            for s in scores_raw
        }
        top = max(scores, key=scores.get)
        section = (
            f"**Reply Option {option_no}** - Most likely: "
            f"**{top.upper()}** ({scores[top]}%)\n"
        )
        section += "\n".join(f"- {label}: {pct}%" for label, pct in scores.items())
        sections.append(section)
    return "\n\n".join(sections)

# Input components, in the positional order predict_effect(user_msg, *replies)
# expects: one message box followed by five reply-option boxes.
user_box = gr.Textbox(label="User Message", lines=2)
reply_boxes = [gr.Textbox(label=f"Reply Option {n}", lines=2) for n in range(1, 6)]

# Build the web UI; the Markdown output renders the formatted per-reply report.
demo = gr.Interface(
    fn=predict_effect,
    inputs=[user_box, *reply_boxes],
    outputs=gr.Markdown(),
    title="Escalation Effect Predictor",
    description="Paste a message and reply options. See which ones are most likely to escalate.",
)
demo.launch()