Removed timm import
modeling_vmamba.py  +34 -2
@@ -36,7 +36,6 @@ import warnings
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, trunc_normal_
 from functools import partial
 from typing import Optional, Callable, Any, Union
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
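The removed import supplied two helpers, `DropPath` and `trunc_normal_`; the hunks below replace both with dependency-free equivalents (a vendored copy of `DropPath` and PyTorch's built-in `nn.init.trunc_normal_`), so this modeling file no longer pulls in timm. A hedged sketch of the practical effect; the repository id below is a placeholder, not the actual checkpoint:

```python
# Placeholder repo id, for illustration only: with the timm import gone, loading
# the remote-code model needs torch + transformers but no timm installation.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "your-namespace/vmamba-checkpoint",  # placeholder, not the real repo id
    trust_remote_code=True,
)
```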
@@ -744,6 +743,39 @@ def selective_scan_fn(
 ############## HuggingFace modeling file #################
 ##########################################################
 
+# DropPath from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+
+def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
 class VMambaLinear2d(nn.Linear):
     def __init__(self, *args, groups=1, **kwargs):
         nn.Linear.__init__(self, *args, **kwargs)
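The vendored `DropPath` keeps timm's stochastic-depth behavior: it is a no-op at inference, and during training it zeroes whole samples along the batch dimension and rescales the survivors by `1 / keep_prob`. A minimal sanity-check sketch, assuming the class above is importable from `modeling_vmamba`; the import path, tensor shape, and drop probability here are illustrative, not part of the commit:

```python
import torch

from modeling_vmamba import DropPath  # assumes the file above is importable

dp = DropPath(drop_prob=0.2)             # keep_prob = 0.8
x = torch.ones(4, 8, 16, 16)             # (batch, channels, height, width)

dp.eval()                                # inference: identity
assert torch.equal(dp(x), x)

dp.train()                               # training: whole samples are zeroed,
y = dp(x)                                # survivors rescaled by 1 / keep_prob
kept = y[:, 0, 0, 0] != 0                # per-sample mask (scaling is per sample)
assert torch.allclose(y[kept], x[kept] / 0.8)
assert torch.equal(y[~kept], torch.zeros_like(y[~kept]))
```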
@@ -1118,7 +1150,7 @@ class VMambaPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, nn.Linear):
-            trunc_normal_(module.weight, std=0.02)
+            nn.init.trunc_normal_(module.weight, std=0.02)
         if isinstance(module, nn.Linear) and module.bias is not None:
             nn.init.constant_(module.bias, 0)
         elif isinstance(module, nn.LayerNorm):
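`torch.nn.init.trunc_normal_` is the in-tree counterpart of timm's `trunc_normal_`, with the same default truncation bounds (a = -2.0, b = 2.0) and mean 0.0, so the `std=0.02` initialization should behave the same way. A small illustrative check, not part of the commit; the layer size and tolerance are arbitrary:

```python
import torch
import torch.nn as nn

linear = nn.Linear(96, 96)
nn.init.trunc_normal_(linear.weight, std=0.02)  # what _init_weights now calls
nn.init.constant_(linear.bias, 0)

# Values stay inside the default truncation bounds [-2, 2], and the empirical
# std should sit near the requested 0.02 (loose tolerance for sampling noise).
assert linear.weight.abs().max().item() <= 2.0
assert abs(linear.weight.std().item() - 0.02) < 0.005
```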