sk23aib committed
Commit 46c07fd · verified · 1 Parent(s): bea3a35

Update app.py

Files changed (1)
  1. app.py +71 -46
app.py CHANGED
@@ -1,46 +1,71 @@
- import gradio as gr
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
- import torch
-
- # Load models
- model_paths = {
-     "BERT": "models/bert_model",
-     "XLNet": "models/xlnet_model",
-     "GPT-2": "models/gpt2_model"
- }
-
- models = {}
- tokenizers = {}
-
- for name, path in model_paths.items():
-     tokenizers[name] = AutoTokenizer.from_pretrained(path)
-     models[name] = AutoModelForSequenceClassification.from_pretrained(path)
-     models[name].eval()
-
- # Emotion labels (adjust based on your dataset!)
- labels = ["anger", "joy", "sadness", "fear", "love", "surprise"]
-
- def classify(text, model_choice):
-     tokenizer = tokenizers[model_choice]
-     model = models[model_choice]
-
-     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
-     with torch.no_grad():
-         outputs = model(**inputs)
-     probs = torch.nn.functional.softmax(outputs.logits, dim=1)
-     top_prob, top_idx = torch.max(probs, dim=1)
-
-     return f"Predicted: {labels[top_idx.item()]} ({top_prob.item():.2f})"
-
- iface = gr.Interface(
-     fn=classify,
-     inputs=[
-         gr.Textbox(lines=3, placeholder="Enter text here...", label="Text"),
-         gr.Radio(choices=["BERT", "XLNet", "GPT-2"], label="Choose Model")
-     ],
-     outputs="text",
-     title="Emotion Classifier (BERT / XLNet / GPT-2)"
- )
-
- if __name__ == "__main__":
-     iface.launch()
+ import gradio as gr
+ import torch
+ from transformers import (
+     BertTokenizer,
+     XLNetTokenizer,
+     GPT2Tokenizer,
+     AutoModelForSequenceClassification
+ )
+
+ # Model repositories on Hugging Face Hub
+ model_repos = {
+     "BERT": "sk23aib/emotion-bert",
+     "XLNet": "sk23aib/emotion-xlnet",
+     "GPT-2": "sk23aib/emotion-gpt2"
+ }
+
+ # Emotion labels (must match model training order)
+ emotion_labels = [
+     "anger", "boredom", "empty", "enthusiasm", "fun", "happiness", "hate",
+     "love", "neutral", "relief", "sadness", "surprise", "worry"
+ ]
+
+ # Load models and tokenizers
+ loaded_models = {}
+
+ # BERT
+ bert_tokenizer = BertTokenizer.from_pretrained(model_repos["BERT"])
+ bert_model = AutoModelForSequenceClassification.from_pretrained(model_repos["BERT"])
+ bert_model.eval()
+ loaded_models["BERT"] = {"tokenizer": bert_tokenizer, "model": bert_model}
+
+ # XLNet
+ xlnet_tokenizer = XLNetTokenizer.from_pretrained(model_repos["XLNet"])
+ xlnet_model = AutoModelForSequenceClassification.from_pretrained(model_repos["XLNet"])
+ xlnet_model.eval()
+ loaded_models["XLNet"] = {"tokenizer": xlnet_tokenizer, "model": xlnet_model}
+
+ # GPT-2
+ gpt2_tokenizer = GPT2Tokenizer.from_pretrained(model_repos["GPT-2"], padding_side="left")
+ gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token  # Required for GPT-2
+ gpt2_model = AutoModelForSequenceClassification.from_pretrained(model_repos["GPT-2"])
+ gpt2_model.config.pad_token_id = gpt2_tokenizer.pad_token_id
+ gpt2_model.eval()
+ loaded_models["GPT-2"] = {"tokenizer": gpt2_tokenizer, "model": gpt2_model}
+
+ # Inference function – return top emotion + probability
+ def predict_emotions(text):
+     output_lines = []
+     with torch.no_grad():
+         for model_name, components in loaded_models.items():
+             tokenizer = components["tokenizer"]
+             model = components["model"]
+             inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+             logits = model(**inputs).logits
+             probs = torch.nn.functional.softmax(logits, dim=1)[0]
+             top_idx = torch.argmax(probs).item()
+             top_emotion = emotion_labels[top_idx]
+             top_confidence = round(float(probs[top_idx]), 4)
+             output_lines.append(f"{model_name}: {top_emotion} ({top_confidence})")
+     return "\n".join(output_lines)
+
+ # Gradio Interface
+ interface = gr.Interface(
+     fn=predict_emotions,
+     inputs=gr.Textbox(lines=3, placeholder="Type a sentence to analyze..."),
+     outputs=gr.Textbox(label="Top Emotion by Model"),
+     title="Multi-Model Emotion Classifier",
+     description="See which emotion is predicted by BERT, XLNet, and GPT-2, along with their confidence."
+ )
+
+ interface.launch()
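
Note (not part of this commit): once this revision is deployed as a Space, the interface can also be exercised programmatically. Below is a minimal sketch using gradio_client; the Space id is a placeholder (the real id is not shown on this page), "/predict" is Gradio's default endpoint name for a single gr.Interface, and the sample sentence and printed output are illustrative only.

    # Hypothetical client-side check, not part of this commit.
    from gradio_client import Client

    # Placeholder Space id - replace with the actual "owner/space-name".
    client = Client("sk23aib/emotion-classifier")

    # gr.Interface exposes its fn under the default endpoint name "/predict";
    # the single positional argument maps to the Textbox input.
    result = client.predict("I just got the best news of my life!", api_name="/predict")

    # One "<model>: <emotion> (<confidence>)" line per model, e.g.
    # BERT: happiness (0.91)
    print(result)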