|
import tempfile

import gradio as gr
import pandas as pd
from transformers import pipeline
|
|
|
|
|
# Hugging Face checkpoint backing the zero-shot classifier (the id suggests a
# multilingual mDeBERTa-v3 NLI model — confirm on the model card).
model_id = (
    "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"
)

# Built once at import time so every request reuses the same pipeline object.
classifier = pipeline(task="zero-shot-classification", model=model_id)
|
|
|
|
|
def classify_items(items_text, labels_text, *, classifier_fn=None):
    """Run zero-shot classification for several items against shared labels.

    Args:
        items_text: Free text with one item per line; blank lines are ignored.
        labels_text: Comma-separated candidate labels; blank entries and
            duplicates are dropped (first occurrence wins, order preserved).
        classifier_fn: Optional callable with the zero-shot pipeline's
            ``(sequence, candidate_labels)`` interface, returning a dict with
            ``"labels"`` and ``"scores"`` lists. Defaults to the module-level
            ``classifier``.

    Returns:
        A ``(DataFrame, csv_path)`` tuple. The DataFrame has one row per item
        with ``Item``, ``Top Label``, ``Top Score`` and one ``Score: <label>``
        column per label, in the order the user entered the labels.
        ``csv_path`` is the path of a temporary ``.csv`` file holding the same
        table (suitable for a ``gr.File`` output), or ``None`` when there is
        nothing to classify.
    """
    if classifier_fn is None:
        classifier_fn = classifier

    items = [
        line.strip()
        for line in (items_text or "").splitlines()
        if line.strip()
    ]
    # dict.fromkeys drops duplicate labels while keeping the user's order.
    labels = list(dict.fromkeys(
        label.strip()
        for label in (labels_text or "").split(",")
        if label.strip()
    ))

    # The zero-shot pipeline cannot run without labels; return an empty
    # result instead of surfacing an opaque exception to the UI.
    if not items or not labels:
        return pd.DataFrame(), None

    results = []
    for item in items:
        output = classifier_fn(item, labels)
        scores = dict(zip(output["labels"], output["scores"]))
        row = {
            "Item": item,
            # The pipeline returns labels ordered by descending score.
            "Top Label": output["labels"][0],
            "Top Score": round(output["scores"][0], 4),
        }
        # Fixed column order (user's label order) so every row lines up.
        for label in labels:
            row[f"Score: {label}"] = round(scores[label], 4)
        results.append(row)

    df = pd.DataFrame(results)

    # gr.File expects a file path, not CSV text, so persist the table to a
    # temporary file. delete=False keeps it on disk after the handle closes so
    # Gradio can serve it.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, encoding="utf-8", newline=""
    ) as handle:
        df.to_csv(handle, index=False)
    return df, handle.name
|
|
|
|
|
# UI: items and labels side by side, a Classify button, a results table and a
# downloadable CSV.
with gr.Blocks() as demo:
    gr.Markdown("## Zero-shot Classification for Multiple Items")

    with gr.Row():
        item_input = gr.Textbox(
            label="Enter your items (one per line)",
            lines=10,
            placeholder="e.g., I enjoy going to museums.\nI like spicy food.",
        )
        label_input = gr.Textbox(
            label="Enter labels (comma-separated)",
            placeholder="e.g., art, food, travel",
        )

    classify_button = gr.Button("Classify")
    output_table = gr.Dataframe(label="Results", interactive=False)
    download_csv = gr.File(label="Download CSV")

    # classify_items returns a 2-tuple mapped, in order, onto the two outputs.
    classify_button.click(
        fn=classify_items,
        inputs=[item_input, label_input],
        outputs=[output_table, download_csv],
    )

# Start the server only when run as a script, so importing this module does
# not block on launching the app.
if __name__ == "__main__":
    demo.launch()
|
|