jeongminl commited on
Commit
0b6af33
·
verified ·
1 Parent(s): bfe552a

Update modeling_mambavision.py

Browse files

There seems to be a typo in torch.nn.cross_entropy. This seems to be torch.nn.functional.cross_entropy

Files changed (1) hide show
  1. modeling_mambavision.py +1 -1
modeling_mambavision.py CHANGED
@@ -778,6 +778,6 @@ class MambaVisionModelForImageClassification(PreTrainedModel):
778
  def forward(self, tensor, labels=None):
779
  logits = self.model(tensor)
780
  if labels is not None:
781
- loss = torch.nn.cross_entropy(logits, labels)
782
  return {"loss": loss, "logits": logits}
783
  return {"logits": logits}
 
778
  def forward(self, tensor, labels=None):
779
  logits = self.model(tensor)
780
  if labels is not None:
781
+ loss = torch.nn.functional.cross_entropy(logits, labels)
782
  return {"loss": loss, "logits": logits}
783
  return {"logits": logits}