Cricles commited on
Commit
2efd22f
·
verified ·
1 Parent(s): f3afcef

Update model_wrapper/bert_wrapper.py

Browse files
Files changed (1) hide show
  1. model_wrapper/bert_wrapper.py +1 -1
model_wrapper/bert_wrapper.py CHANGED
@@ -34,4 +34,4 @@ class BertWrapper(object):
34
  ) # output is logits for huggingfcae transformers
35
  predicted = torch.nn.functional.softmax(outputs.logits, dim=1)
36
  predicted_id = torch.argmax(predicted, dim=1).numpy()[0]
37
- return self.id2label[predicted_id], predicted[predicted_id]
 
34
  ) # output is logits for huggingfcae transformers
35
  predicted = torch.nn.functional.softmax(outputs.logits, dim=1)
36
  predicted_id = torch.argmax(predicted, dim=1).numpy()[0]
37
+ return self.id2label[predicted_id], predicted[0][predicted_id]