# Hugging Face Space (status when captured: Sleeping)
import onnxruntime as ort
from transformers import AutoTokenizer
import gradio as gr
# Registry of the selectable classifiers. Each display name maps to the ONNX
# file that is loaded for inference and the Hugging Face tokenizer name used
# to encode text for it. Insertion order is the order in which the names are
# offered in the UI dropdown, which is built from this dict's keys.
models = {
    "DistilBERT": {
        "onnx_model_path": "distilbert.onnx",
        "tokenizer_name": "distilbert-base-multilingual-cased",
    },
    "BERT": {
        "onnx_model_path": "bert.onnx",
        "tokenizer_name": "bert-base-multilingual-cased",
    },
    "MuRIL": {
        "onnx_model_path": "muril.onnx",
        "tokenizer_name": "google/muril-base-cased",
    },
    "RoBERTa": {
        "onnx_model_path": "roberta.onnx",
        "tokenizer_name": "cardiffnlp/twitter-roberta-base-emotion",
    },
}
# Load every configured model and tokenizer once, at import time, so that
# predict_with_model only has to look them up by display name.
model_sessions = {}  # display name -> ort.InferenceSession for its ONNX file
tokenizers = {}  # display name -> tokenizer from AutoTokenizer.from_pretrained
for model_name, config in models.items():
    # Progress message per model; loading happens before the UI is built.
    print(f"Loading {model_name}...")
    model_sessions[model_name] = ort.InferenceSession(config["onnx_model_path"])
    tokenizers[model_name] = AutoTokenizer.from_pretrained(config["tokenizer_name"])
print("All models loaded!")
# Prediction function
def predict_with_model(text, model_name):
    """Classify *text* with the model registered under *model_name*.

    Args:
        text: The text to classify; passed straight to the tokenizer.
        model_name: Key into ``model_sessions`` and ``tokenizers``.

    Returns:
        ``"Hate Speech"`` when, in the first row of the session's first
        output, the value at index 1 is greater than the value at index 0;
        otherwise ``"Not Hate Speech"``.

    Raises:
        gr.Error: If ``model_name`` is not one of the loaded models (for
            example ``None`` when nothing was picked in the dropdown).
    """
    # The dropdown has no default value, so the selection may be missing;
    # report that to the user instead of failing with a bare KeyError.
    if model_name not in model_sessions:
        raise gr.Error("Please select a model before submitting.")

    # Select the appropriate ONNX session and tokenizer
    ort_session = model_sessions[model_name]
    tokenizer = tokenizers[model_name]

    # Tokenize the input text
    inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True)

    # Feed exactly the inputs the graph declares, taken from what the
    # tokenizer produced. A graph that only declares input_ids and
    # attention_mask gets the same feed as before; one that also declares
    # e.g. token_type_ids now receives it instead of failing on a missing
    # required input.
    feed = {
        node.name: inputs[node.name]
        for node in ort_session.get_inputs()
        if node.name in inputs
    }

    # Run ONNX inference (None requests all outputs)
    outputs = ort_session.run(None, feed)

    # Post-process the output: compare the two scores of the first row.
    logits = outputs[0]
    label = "Hate Speech" if logits[0][1] > logits[0][0] else "Not Hate Speech"
    return label
# Define Gradio interface: a text box and a model picker feed
# predict_with_model, whose returned label is shown as plain text.
interface = gr.Interface(
    fn=predict_with_model,
    inputs=[
        # First positional argument of predict_with_model (text).
        gr.Textbox(label="Enter text to classify"),
        # Second positional argument (model_name); choices are the registry
        # keys in their declared order.
        gr.Dropdown(
            choices=list(models),
            label="Select a model",
        ),
    ],
    outputs="text",
    title="Multi-Model Hate Speech Detection",
    description="Choose a model and enter text to classify whether it's hate speech.",
)
# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()