removed no_weight_decay from modeling_dass.py
Browse files- modeling_dass.py +2 -2
modeling_dass.py
CHANGED
|
@@ -942,7 +942,7 @@ class SS2D(nn.Module):
|
|
| 942 |
if merge:
|
| 943 |
A_log = A_log.flatten(0, 1)
|
| 944 |
A_log = nn.Parameter(A_log)
|
| 945 |
-
A_log._no_weight_decay = True
|
| 946 |
return A_log
|
| 947 |
|
| 948 |
@staticmethod
|
|
@@ -953,7 +953,7 @@ class SS2D(nn.Module):
|
|
| 953 |
if merge:
|
| 954 |
D = D.flatten(0, 1)
|
| 955 |
D = nn.Parameter(D)
|
| 956 |
-
D._no_weight_decay = True
|
| 957 |
return D
|
| 958 |
|
| 959 |
@classmethod
|
|
|
|
| 942 |
if merge:
|
| 943 |
A_log = A_log.flatten(0, 1)
|
| 944 |
A_log = nn.Parameter(A_log)
|
| 945 |
+
#A_log._no_weight_decay = True
|
| 946 |
return A_log
|
| 947 |
|
| 948 |
@staticmethod
|
|
|
|
| 953 |
if merge:
|
| 954 |
D = D.flatten(0, 1)
|
| 955 |
D = nn.Parameter(D)
|
| 956 |
+
#D._no_weight_decay = True
|
| 957 |
return D
|
| 958 |
|
| 959 |
@classmethod
|