Rerandaka commited on
Commit
9e84c8b
Β·
verified Β·
1 Parent(s): 0f79359

danwath weda krpn

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -1,28 +1,27 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification # βœ… required
4
 
5
  # Load model
6
  model_id = "Rerandaka/Cild_safety_bigbird"
7
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_id)
9
 
10
- # Inference function
11
  def classify(text):
12
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
13
  with torch.no_grad():
14
  logits = model(**inputs).logits
15
- predicted_class = torch.argmax(logits, dim=1).item()
16
- return str(predicted_class)
17
 
18
- # API-ready Gradio Interface
19
  demo = gr.Interface(
20
  fn=classify,
21
- inputs=gr.Textbox(label="Enter text"),
22
- outputs=gr.Textbox(label="Prediction"),
23
- api_name="/classify"
24
  )
25
 
26
- # βœ… Enable API and queue
27
  demo.queue()
28
  demo.launch(show_api=True)
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
  # Load model
6
  model_id = "Rerandaka/Cild_safety_bigbird"
7
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_id)
9
 
10
+ # Classification function
11
  def classify(text):
12
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
13
  with torch.no_grad():
14
  logits = model(**inputs).logits
15
+ prediction = torch.argmax(logits, dim=1).item()
16
+ return prediction # 0 = safe, 1 = unsafe
17
 
18
+ # βœ… API-compatible Interface with explicit name
19
  demo = gr.Interface(
20
  fn=classify,
21
+ inputs=gr.Textbox(label="Enter paragraph..."),
22
+ outputs=gr.Number(label="Prediction (0=safe, 1=unsafe)"),
23
+ api_name="/classify" # πŸ”₯ This only works in gr.Interface (not Blocks)
24
  )
25
 
 
26
  demo.queue()
27
  demo.launch(show_api=True)