Spaces: Running on Zero
Initialize rope embeddings properly for the entropy model (#72)
Browse files
bytelatent/base_transformer.py
CHANGED
@@ -617,12 +617,8 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
             h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
         return h
 
-    def reset_parameters(self):
-        # Either use fixed base std or sqrt model dim
-        self.rope_embeddings.reset_parameters()
-
     def init_weights(self):
-        self.reset_parameters()
+        self.rope_embeddings.reset_parameters()
         for depth, layer in enumerate(self.layers):
             factor = {
                 InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
bytelatent/transformer.py
CHANGED
@@ -116,10 +116,11 @@ class LMTransformer(BaseTransformer):
         return logits
 
     def reset_parameters(self, init_std=None):
-        # Either use fixed base std or sqrt model dim
-        super().reset_parameters()
-        init_std = init_std or (self.dim ** (-0.5))
         self.norm.reset_parameters()
+
+    def init_weights(self):
+        self.reset_parameters()
+        init_std = self.dim ** (-0.5)
         nn.init.trunc_normal_(
             self.tok_embeddings.weight,
             mean=0.0,
@@ -127,6 +128,8 @@ class LMTransformer(BaseTransformer):
             a=-3 * init_std,
             b=3 * init_std,
         )
+        super().init_weights()
+
         if not self.weight_tying:
             nn.init.trunc_normal_(
                 self.output.weight,