ahatamiz commited on
Commit
7ad1881
·
verified ·
1 Parent(s): fd2550f

Upload model

Browse files
config.json CHANGED
@@ -1,10 +1,10 @@
1
  {
2
  "architectures": [
3
- "MambaVisionModelForImageClassification"
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "configuration_mambavision.MambaVisionConfig",
7
- "AutoModelForImageClassification": "modeling_mambavision.MambaVisionModelForImageClassification"
8
  },
9
  "depths": [
10
  1,
 
1
  {
2
  "architectures": [
3
+ "MambaVisionModel"
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "configuration_mambavision.MambaVisionConfig",
7
+ "AutoModel": "modeling_mambavision.MambaVisionModel"
8
  },
9
  "depths": [
10
  1,
configuration_mambavision.py CHANGED
@@ -1,6 +1,4 @@
1
  from transformers import PretrainedConfig
2
- from typing import List
3
-
4
 
5
  class MambaVisionConfig(PretrainedConfig):
6
  model_type = "mambavision"
 
1
  from transformers import PretrainedConfig
 
 
2
 
3
  class MambaVisionConfig(PretrainedConfig):
4
  model_type = "mambavision"
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9970e390c1e0014ceaee412e485de6016e8db68aef6823dcc2801e838b2be114
3
  size 127219000
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f6987be0a2ca2222f386eb750d028a05203b047d3c8dfb664c27e2295d02fc0
3
  size 127219000
modeling_mambavision.py CHANGED
@@ -28,7 +28,7 @@ from einops import rearrange, repeat
28
 
29
  from transformers import PreTrainedModel
30
 
31
- from .configuration_mambavision import MambaVisionConfig
32
 
33
 
34
  def _cfg(url='', **kwargs):
@@ -602,8 +602,8 @@ class MambaVisionLayer(nn.Module):
602
  if pad_r > 0 or pad_b > 0:
603
  x = x[:, :, :H, :W].contiguous()
604
  if self.downsample is None:
605
- return x
606
- return self.downsample(x)
607
 
608
 
609
  class MambaVision(nn.Module):
@@ -697,15 +697,17 @@ class MambaVision(nn.Module):
697
 
698
  def forward_features(self, x):
699
  x = self.patch_embed(x)
 
700
  for level in self.levels:
701
- x = level(x)
 
702
  x = self.norm(x)
703
  x = self.avgpool(x)
704
  x = torch.flatten(x, 1)
705
- return x
706
 
707
  def forward(self, x):
708
- x = self.forward_features(x)
709
  x = self.head(x)
710
  return x
711
 
 
28
 
29
  from transformers import PreTrainedModel
30
 
31
+ from configuration_mambavision import MambaVisionConfig
32
 
33
 
34
  def _cfg(url='', **kwargs):
 
602
  if pad_r > 0 or pad_b > 0:
603
  x = x[:, :, :H, :W].contiguous()
604
  if self.downsample is None:
605
+ return x, x
606
+ return self.downsample(x), x
607
 
608
 
609
  class MambaVision(nn.Module):
 
697
 
698
  def forward_features(self, x):
699
  x = self.patch_embed(x)
700
+ outs = []
701
  for level in self.levels:
702
+ x, xo = level(x)
703
+ outs.append(xo)
704
  x = self.norm(x)
705
  x = self.avgpool(x)
706
  x = torch.flatten(x, 1)
707
+ return x, outs
708
 
709
  def forward(self, x):
710
+ x, outs = self.forward_features(x)
711
  x = self.head(x)
712
  return x
713