Update modeling_mambavision.py
Browse filesThere seems to be a typo in torch.nn.cross_entropy. This seems to be torch.nn.functional.cross_entropy
- modeling_mambavision.py +1 -1
modeling_mambavision.py
CHANGED
@@ -778,6 +778,6 @@ class MambaVisionModelForImageClassification(PreTrainedModel):
|
|
778 |
def forward(self, tensor, labels=None):
|
779 |
logits = self.model(tensor)
|
780 |
if labels is not None:
|
781 |
-
loss = torch.nn.cross_entropy(logits, labels)
|
782 |
return {"loss": loss, "logits": logits}
|
783 |
return {"logits": logits}
|
|
|
778 |
def forward(self, tensor, labels=None):
|
779 |
logits = self.model(tensor)
|
780 |
if labels is not None:
|
781 |
+
loss = torch.nn.functional.cross_entropy(logits, labels)
|
782 |
return {"loss": loss, "logits": logits}
|
783 |
return {"logits": logits}
|