Srinivasan Iyer (sviyer) committed
Commit 0da051f (unverified)
1 parent: aeb95f1

Initialize rope embeddings properly for the entropy model (#72)

bytelatent/base_transformer.py CHANGED
@@ -617,12 +617,8 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
             h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
         return h
 
-    def reset_parameters(self):
-        # Either use fixed base std or sqrt model dim
-        self.rope_embeddings.reset_parameters()
-
     def init_weights(self):
-        self.reset_parameters()
+        self.rope_embeddings.reset_parameters()
         for depth, layer in enumerate(self.layers):
             factor = {
                 InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
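
The hunk above folds the rope reset directly into BaseTransformer.init_weights instead of routing it through an overridable reset_parameters. The hazard this removes is ordinary Python dynamic dispatch: if a subclass redefines reset_parameters and does not call super(), the rope initialization silently disappears. A minimal, torch-free sketch of that failure mode (class names here are illustrative, not from the repo):

class Base:
    def reset_parameters(self):
        print("base reset: rope embeddings initialized")

    def init_weights(self):
        # self.reset_parameters() dispatches to the most-derived override,
        # not to Base.reset_parameters.
        self.reset_parameters()

class Child(Base):
    def reset_parameters(self):
        # Omitting super().reset_parameters() here drops the rope init.
        print("child reset: norms and embeddings only")

Child().init_weights()  # prints only the child line; the rope init never runs
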
bytelatent/transformer.py CHANGED
@@ -116,10 +116,11 @@ class LMTransformer(BaseTransformer):
         return logits
 
     def reset_parameters(self, init_std=None):
-        # Either use fixed base std or sqrt model dim
-        super().reset_parameters()
-        init_std = init_std or (self.dim ** (-0.5))
         self.norm.reset_parameters()
+
+    def init_weights(self):
+        self.reset_parameters()
+        init_std = self.dim ** (-0.5)
         nn.init.trunc_normal_(
             self.tok_embeddings.weight,
             mean=0.0,
@@ -127,6 +128,8 @@ class LMTransformer(BaseTransformer):
             a=-3 * init_std,
             b=3 * init_std,
         )
+        super().init_weights()
+
         if not self.weight_tying:
             nn.init.trunc_normal_(
                 self.output.weight,
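
After this change, init_weights on the LMTransformer used for the entropy model (per the commit title) runs in a fixed order: reset_parameters (norm only), the truncated-normal token-embedding init with std dim ** (-0.5), super().init_weights() for the rope embeddings and layers, and finally the output projection when weights are untied. The embedding init itself is a truncated normal clipped at three standard deviations; a small standalone illustration of that call (vocabulary size and model dim here are made up):

import torch.nn as nn

dim = 512
init_std = dim ** (-0.5)  # std shrinks as the model dim grows

emb = nn.Embedding(32_000, dim)
nn.init.trunc_normal_(
    emb.weight,
    mean=0.0,
    std=init_std,
    a=-3 * init_std,  # truncate at -3 sigma
    b=3 * init_std,   # truncate at +3 sigma
)
print(emb.weight.std().item())  # close to init_std (~0.044 for dim=512)
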