Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -2,23 +2,22 @@
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import torch
|
|
|
5 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
6 |
|
7 |
-
# β
|
|
|
8 |
model_id = "Ozziejoe/eemm-deberta-v3-small"
|
9 |
|
10 |
-
# β
Label names
|
11 |
label_names = [
|
12 |
"Cognition", "Affect", "Self", "Motivation", "Attention", "OB", "Context",
|
13 |
"Social", "Physical", "Psych"
|
14 |
]
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
model = AutoModelForSequenceClassification.from_pretrained(model_id)
|
19 |
model.eval()
|
20 |
|
21 |
-
# β
Classification function
|
22 |
def classify(text):
|
23 |
with torch.no_grad():
|
24 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
@@ -27,14 +26,14 @@ def classify(text):
|
|
27 |
labels = [label_names[i] for i, p in enumerate(probs) if p > 0.5]
|
28 |
return ", ".join(labels) if labels else "No domain confidently predicted."
|
29 |
|
30 |
-
# β
Gradio interface
|
31 |
demo = gr.Interface(
|
32 |
fn=classify,
|
33 |
inputs=gr.Textbox(label="Enter a question"),
|
34 |
outputs=gr.Textbox(label="Predicted domains"),
|
35 |
title="EEMM Multi-Label Classifier",
|
36 |
-
description="Classifies a question into psychological domains.",
|
37 |
allow_flagging="never"
|
38 |
)
|
39 |
|
40 |
-
|
|
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import torch
|
5 |
+
import os
|
6 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
7 |
|
8 |
+
# β
Load token from secret environment variable
|
9 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
10 |
model_id = "Ozziejoe/eemm-deberta-v3-small"
|
11 |
|
|
|
12 |
label_names = [
|
13 |
"Cognition", "Affect", "Self", "Motivation", "Attention", "OB", "Context",
|
14 |
"Social", "Physical", "Psych"
|
15 |
]
|
16 |
|
17 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN)
|
18 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_id, use_auth_token=HF_TOKEN)
|
|
|
19 |
model.eval()
|
20 |
|
|
|
21 |
def classify(text):
|
22 |
with torch.no_grad():
|
23 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
|
|
26 |
labels = [label_names[i] for i, p in enumerate(probs) if p > 0.5]
|
27 |
return ", ".join(labels) if labels else "No domain confidently predicted."
|
28 |
|
|
|
29 |
demo = gr.Interface(
|
30 |
fn=classify,
|
31 |
inputs=gr.Textbox(label="Enter a question"),
|
32 |
outputs=gr.Textbox(label="Predicted domains"),
|
33 |
title="EEMM Multi-Label Classifier",
|
34 |
+
description="Classifies a question into multiple psychological domains.",
|
35 |
allow_flagging="never"
|
36 |
)
|
37 |
|
38 |
+
if __name__ == "__main__":
|
39 |
+
demo.launch(share=True, server_name="0.0.0.0")
|