FikriRiyadi commited on
Commit
09ca6a5
·
verified ·
1 Parent(s): c58d472

Update predict_utils.py

Browse files
Files changed (1) hide show
  1. predict_utils.py +4 -14
predict_utils.py CHANGED
@@ -3,11 +3,8 @@ import numpy as np
3
  from transformers import BertTokenizer
4
  from model import HybridModel
5
 
6
- LABELS = [
7
- 'HS', 'Abusive', 'HS_Individual', 'HS_Group', 'HS_Religion', 'HS_Race',
8
- 'HS_Physical', 'HS_Gender', 'HS_Other', 'HS_Weak', 'HS_Moderate', 'HS_Strong'
9
- ]
10
-
11
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
  def load_model_and_thresholds():
@@ -21,18 +18,11 @@ def load_model_and_thresholds():
21
  return model, tokenizer, thresholds
22
 
23
  def predict(text, model, tokenizer, thresholds):
24
- encoding = tokenizer(
25
- text,
26
- return_tensors='pt',
27
- padding='max_length',
28
- truncation=True,
29
- max_length=128
30
- )
31
  input_ids = encoding["input_ids"].to(DEVICE)
32
  attention_mask = encoding["attention_mask"].to(DEVICE)
33
 
34
  with torch.no_grad():
35
  probs = model(input_ids, attention_mask).squeeze(0).cpu().numpy()
36
 
37
- result = {label: float(prob) for label, prob in zip(LABELS, probs)}
38
- return result
 
3
  from transformers import BertTokenizer
4
  from model import HybridModel
5
 
6
+ LABELS = ['HS', 'Abusive', 'HS_Individual', 'HS_Group', 'HS_Religion', 'HS_Race',
7
+ 'HS_Physical', 'HS_Gender', 'HS_Other', 'HS_Weak', 'HS_Moderate', 'HS_Strong']
 
 
 
8
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
  def load_model_and_thresholds():
 
18
  return model, tokenizer, thresholds
19
 
20
  def predict(text, model, tokenizer, thresholds):
21
+ encoding = tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
 
 
 
 
 
 
22
  input_ids = encoding["input_ids"].to(DEVICE)
23
  attention_mask = encoding["attention_mask"].to(DEVICE)
24
 
25
  with torch.no_grad():
26
  probs = model(input_ids, attention_mask).squeeze(0).cpu().numpy()
27
 
28
+ return {label: float(prob) for label, prob in zip(LABELS, probs)}