oist commited on
Commit
5c4f06c
·
1 Parent(s): 3bd5472

Fix model code

Browse files
Files changed (1) hide show
  1. modeling_blaser.py +12 -3
modeling_blaser.py CHANGED
@@ -104,12 +104,14 @@ class BlaserCore(nn.Module):
104
 
105
 
106
  # ---------------- HF MODEL WRAPPER ---------------- #
 
107
  class BlaserModel(PreTrainedModel):
108
  config_class = BlaserConfig
109
 
110
  def __init__(self, config: BlaserConfig):
111
  super().__init__(config)
112
- self.core = BlaserCore(
 
113
  embedding_dim=config.embedding_dim,
114
  output_dim=config.output_dim,
115
  hidden_dims=config.hidden_dims,
@@ -118,7 +120,14 @@ class BlaserModel(PreTrainedModel):
118
  input_form=config.input_form,
119
  norm_emb=config.norm_emb,
120
  output_act=config.output_act,
121
- )
122
 
123
  def forward(self, src, mt, ref=None):
124
- return self.core(src, mt, ref)
 
 
 
 
 
 
 
 
104
 
105
 
106
  # ---------------- HF MODEL WRAPPER ---------------- #
107
+
108
  class BlaserModel(PreTrainedModel):
109
  config_class = BlaserConfig
110
 
111
  def __init__(self, config: BlaserConfig):
112
  super().__init__(config)
113
+ # Instead of self.core, assign directly
114
+ self.mlp = BlaserCore(
115
  embedding_dim=config.embedding_dim,
116
  output_dim=config.output_dim,
117
  hidden_dims=config.hidden_dims,
 
120
  input_form=config.input_form,
121
  norm_emb=config.norm_emb,
122
  output_act=config.output_act,
123
+ ).mlp # only take the Sequential MLP
124
 
125
  def forward(self, src, mt, ref=None):
126
+ # The old checkpoint expects the input feature processing inside BlaserCore
127
+ proc = BlaserCore._featurize(
128
+ self.mlp, # pass self as `self` for static call
129
+ src=BlaserCore._norm(self.mlp, src),
130
+ mt=BlaserCore._norm(self.mlp, mt),
131
+ ref=BlaserCore._norm(self.mlp, ref)
132
+ )
133
+ return self.mlp(proc)