File size: 1,700 Bytes
3d3e4d2
3ffaae9
 
3d3e4d2
3ffaae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import tempfile

import gradio as gr
import pandas as pd
from transformers import pipeline

# Build the zero-shot classification pipeline once, at import time, so every
# request handled by this module reuses the already-loaded model.
model_id = "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"
classifier = pipeline(
    task="zero-shot-classification",
    model=model_id,
)

# Function to classify multiple items
def classify_items(items_text, labels_text):
    """Zero-shot classify each non-blank line of *items_text* against the labels.

    Args:
        items_text: Free text; every non-blank line (whitespace-stripped) is
            classified as a separate item. ``None`` is treated as empty.
        labels_text: Comma-separated candidate labels. Blank entries are
            dropped and duplicates collapsed, keeping first-seen order.
            ``None`` is treated as empty.

    Returns:
        A ``(DataFrame, csv_path)`` pair. The DataFrame has one row per item
        with columns ``Item``, ``Top Label``, ``Top Score`` followed by one
        ``Score: <label>`` column per label, in the order the labels were
        entered. ``csv_path`` is the path of a temporary ``.csv`` file holding
        the same table, or ``None`` when there are no items to classify.

    Raises:
        gr.Error: if there are items but no labels (the pipeline needs at
            least one candidate label).
    """
    items = [line.strip() for line in (items_text or "").splitlines() if line.strip()]
    # dict.fromkeys dedupes while preserving the user's label order; duplicate
    # labels would otherwise be scored twice and collide as dict keys below.
    labels = list(dict.fromkeys(
        label.strip() for label in (labels_text or "").split(",") if label.strip()
    ))
    columns = ["Item", "Top Label", "Top Score"] + [f"Score: {label}" for label in labels]

    if not items:
        # Nothing to classify: show an empty table and clear the download.
        return pd.DataFrame(columns=columns), None
    if not labels:
        raise gr.Error("Please enter at least one label (comma-separated).")

    rows = []
    for item in items:
        output = classifier(item, labels)
        # The pipeline returns labels sorted by descending score; map them
        # back so every row reports its scores in the user's label order.
        scores = dict(zip(output["labels"], output["scores"]))
        row = {
            "Item": item,
            "Top Label": output["labels"][0],
            "Top Score": round(output["scores"][0], 4),
        }
        row.update({f"Score: {label}": round(scores[label], 4) for label in labels})
        rows.append(row)

    df = pd.DataFrame(rows, columns=columns)

    # gr.File expects a path on disk, not CSV text. delete=False keeps the
    # file around after this call so the UI can serve it for download.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", prefix="classification_",
        delete=False, encoding="utf-8", newline="",
    ) as handle:
        df.to_csv(handle, index=False)
    return df, handle.name

# Gradio UI: item and label inputs side by side, a Classify button, then the
# results table and a downloadable CSV component.
with gr.Blocks() as demo:
    gr.Markdown("## Zero-shot Classification for Multiple Items")

    with gr.Row():
        item_input = gr.Textbox(
            label="Enter your items (one per line)",
            placeholder="e.g., I enjoy going to museums.\nI like spicy food.",
            lines=10,
        )
        label_input = gr.Textbox(
            placeholder="e.g., art, food, travel",
            label="Enter labels (comma-separated)",
        )

    classify_button = gr.Button("Classify")
    output_table = gr.Dataframe(interactive=False, label="Results")
    download_csv = gr.File(label="Download CSV")

    # classify_items receives the two textbox values in order and its two
    # return values populate the table and the file component respectively.
    classify_button.click(
        fn=classify_items,
        inputs=[item_input, label_input],
        outputs=[output_table, download_csv],
    )

demo.launch()