Upload ProkBertForMaskedLM
Browse files
models.py
CHANGED
|
@@ -9,7 +9,7 @@ import torch.nn.functional as F
|
|
| 9 |
from transformers import MegatronBertConfig, MegatronBertModel, MegatronBertForMaskedLM, MegatronBertPreTrainedModel, PreTrainedModel
|
| 10 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 11 |
from transformers.utils.hub import cached_file
|
| 12 |
-
from prokbert.training_utils import compute_metrics_eval_prediction
|
| 13 |
|
| 14 |
class BertForBinaryClassificationWithPooling(nn.Module):
|
| 15 |
"""
|
|
@@ -274,7 +274,7 @@ class ProkBertForSequenceClassification(ProkBertPreTrainedModel):
|
|
| 274 |
loss = None
|
| 275 |
if labels is not None:
|
| 276 |
loss = self.loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
| 277 |
-
|
| 278 |
classification_output = SequenceClassifierOutput(
|
| 279 |
loss=loss,
|
| 280 |
logits=logits,
|
|
|
|
| 9 |
from transformers import MegatronBertConfig, MegatronBertModel, MegatronBertForMaskedLM, MegatronBertPreTrainedModel, PreTrainedModel
|
| 10 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 11 |
from transformers.utils.hub import cached_file
|
| 12 |
+
#from prokbert.training_utils import compute_metrics_eval_prediction
|
| 13 |
|
| 14 |
class BertForBinaryClassificationWithPooling(nn.Module):
|
| 15 |
"""
|
|
|
|
| 274 |
loss = None
|
| 275 |
if labels is not None:
|
| 276 |
loss = self.loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
| 277 |
+
|
| 278 |
classification_output = SequenceClassifierOutput(
|
| 279 |
loss=loss,
|
| 280 |
logits=logits,
|