File size: 1,929 Bytes
fbe2da8
86f9780
 
 
f081761
 
 
 
 
86f9780
f081761
 
 
 
86f9780
f081761
 
 
 
 
 
23cf7a9
f081761
 
23cf7a9
86f9780
f081761
4036329
 
f081761
 
 
 
 
 
 
 
86f9780
f081761
 
 
86f9780
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import functools
import tempfile

import gradio as gr
import pandas as pd
from transformers import pipeline

# Display name -> Hugging Face Hub repo id handed to `pipeline(model=...)`.
# Every entry currently uses the repo id as its display name, so the map is
# built from a single tuple; dict insertion order follows the tuple order.
MODEL_MAP = {
    repo_id: repo_id
    for repo_id in (
        "MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
        "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
        "joeddav/xlm-roberta-large-xnli",
    )
}

@functools.lru_cache(maxsize=1)
def _get_classifier(model_name):
    """Build the zero-shot pipeline for a ``MODEL_MAP`` key, caching the last one.

    Only the most recently used model is kept, so repeated clicks with the same
    model reuse the loaded weights without holding every large model in memory.
    """
    return pipeline("zero-shot-classification", model=MODEL_MAP[model_name])


def _split_nonempty(text, sep):
    """Split *text* on *sep*, strip each piece, and drop empty pieces."""
    return [piece.strip() for piece in (text or "").split(sep) if piece.strip()]


def _write_csv(df):
    """Write *df* (without its index) to a new temporary ``.csv`` file; return its path."""
    with tempfile.NamedTemporaryFile(
        "w",
        suffix=".csv",
        prefix="zero_shot_results_",
        delete=False,  # Gradio must still be able to read/serve the file after we close it.
        encoding="utf-8",
        newline="",
    ) as fh:
        df.to_csv(fh, index=False)
    return fh.name


def classify_items(model_name, items_text, labels_text):
    """Score every questionnaire item against every response option.

    Args:
        model_name: A key of ``MODEL_MAP`` (the dropdown selection).
        items_text: Questionnaire items, one per line; blank lines are ignored.
        labels_text: Comma-separated response options; blanks and duplicates
            are ignored.

    Returns:
        A ``(DataFrame, update)`` pair: the DataFrame has an ``item`` column
        followed by one score column per label (in the order the user typed
        them), and ``update`` makes the download component visible with a
        path to the same table written as CSV.

    Raises:
        gr.Error: If no valid model is selected, or there are no items/labels.
    """
    if model_name not in MODEL_MAP:
        raise gr.Error("Please choose a zero-shot model.")

    items = _split_nonempty(items_text, "\n")
    # Duplicate labels would collapse into a single column; keep first occurrence.
    labels = list(dict.fromkeys(_split_nonempty(labels_text, ",")))
    if not items:
        raise gr.Error("Please enter at least one questionnaire item.")
    if not labels:
        raise gr.Error("Please enter at least one response option.")

    classifier = _get_classifier(model_name)
    rows = []
    for item in items:
        # multi_label=True scores each label independently of the others.
        out = classifier(item, labels, multi_label=True)
        row = {"item": item}
        row.update(zip(out["labels"], out["scores"]))
        rows.append(row)

    # Fix the column order explicitly: the pipeline returns labels sorted by
    # score, which differs per item, so relying on dict order would make the
    # table layout depend on whichever item came first.
    df = pd.DataFrame(rows, columns=["item", *labels]).fillna(0)
    # gr.File expects a file path, not CSV text; component.update() was
    # removed in Gradio 4, while gr.update() works across versions.
    return df, gr.update(value=_write_csv(df), visible=True)

with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Zero-Shot Questionnaire Classifier")

    with gr.Row():
        # Preselect the first model so pressing "Classify" straight away does
        # not send None (which is not a MODEL_MAP key) to the callback.
        model_choice = gr.Dropdown(
            choices=list(MODEL_MAP),
            value=next(iter(MODEL_MAP)),
            label="Choose a zero-shot model",
        )

    item_input = gr.Textbox(
        label="Enter questionnaire items (one per line)",
        lines=6,
        placeholder="I enjoy social gatherings.\nI prefer planning over spontaneity.",
    )
    label_input = gr.Textbox(
        label="Enter response options (comma-separated)",
        placeholder="Strongly disagree, Disagree, Neutral, Agree, Strongly agree",
    )

    run_button = gr.Button("Classify")
    output_table = gr.Dataframe(label="Classification Results")
    # Hidden until the first successful run returns a CSV path for it.
    download_csv = gr.File(label="Download CSV", visible=False)

    run_button.click(
        fn=classify_items,
        inputs=[model_choice, item_input, label_input],
        outputs=[output_table, download_csv],
    )

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()