# Author: aipsych — commit 3ffaae9 ("Update app.py")
import tempfile

import gradio as gr
import pandas as pd
from transformers import pipeline
# Zero-shot classification pipeline, constructed once when the module is
# loaded and reused by every call to classify_items.
model_id = "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"
classifier = pipeline(task="zero-shot-classification", model=model_id)
def classify_items(items_text, labels_text):
    """Classify every non-blank line of *items_text* against *labels_text*.

    Args:
        items_text: Free text with one item per line. Blank lines and
            surrounding whitespace are ignored.
        labels_text: Comma-separated candidate labels. Blank entries and
            duplicates are dropped; the first-occurrence order is kept.

    Returns:
        A ``(DataFrame, csv_path)`` pair. The DataFrame has one row per item
        with columns ``Item``, ``Top Label``, ``Top Score`` and one
        ``Score: <label>`` column per label, in the order the labels were
        entered. ``csv_path`` is the path of a temporary ``.csv`` file holding
        the same table, suitable as the value of a ``gr.File`` output.

    Raises:
        gr.Error: if no items or no labels were supplied.
    """
    items = [line.strip() for line in (items_text or "").splitlines() if line.strip()]
    # dict.fromkeys drops duplicate labels while preserving the user's order;
    # duplicates would otherwise collapse into a single "Score: ..." column.
    labels = list(
        dict.fromkeys(
            label.strip() for label in (labels_text or "").split(",") if label.strip()
        )
    )
    if not items:
        raise gr.Error("Please enter at least one item (one per line).")
    if not labels:
        raise gr.Error("Please enter at least one label (comma-separated).")

    rows = []
    for item in items:
        output = classifier(item, labels)
        # The pipeline returns labels ranked by score (highest first), so
        # index 0 is the top prediction.
        scores = dict(zip(output["labels"], output["scores"]))
        row = {
            "Item": item,
            "Top Label": output["labels"][0],
            "Top Score": round(output["scores"][0], 4),
        }
        # Emit score columns in the user's label order rather than per-item
        # ranking order, so the table's columns do not depend on which item
        # happens to come first.
        for label in labels:
            row[f"Score: {label}"] = round(scores[label], 4)
        rows.append(row)

    df = pd.DataFrame(rows)
    # gr.File interprets a string value as a file path, not as file contents,
    # so write the CSV to a temporary file and return that path instead.
    # newline="" is what pandas expects for text handles (avoids blank rows
    # on Windows); delete=False keeps the file around for Gradio to serve.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".csv",
        prefix="classification_results_",
        delete=False,
        newline="",
        encoding="utf-8",
    ) as handle:
        df.to_csv(handle, index=False)
    return df, handle.name
# ---- Gradio UI ----
# The two text inputs sit side by side in one row; the button, the results
# table and the download component stack underneath.
with gr.Blocks() as demo:
    gr.Markdown("## Zero-shot Classification for Multiple Items")

    with gr.Row():
        item_input = gr.Textbox(
            label="Enter your items (one per line)",
            lines=10,
            placeholder="e.g., I enjoy going to museums.\nI like spicy food.",
        )
        label_input = gr.Textbox(
            label="Enter labels (comma-separated)",
            placeholder="e.g., art, food, travel",
        )

    classify_button = gr.Button("Classify")
    output_table = gr.Dataframe(label="Results", interactive=False)
    download_csv = gr.File(label="Download CSV")

    # Clicking the button passes both text boxes to classify_items; its two
    # return values populate the table and the download component.
    classify_button.click(
        classify_items,
        [item_input, label_input],
        [output_table, download_csv],
    )
# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. to reuse `demo` or `classify_items`) does not
# launch a server as a side effect.
if __name__ == "__main__":
    demo.launch()