Fix model code
Browse files- modeling_blaser.py +12 -3
modeling_blaser.py
CHANGED
@@ -104,12 +104,14 @@ class BlaserCore(nn.Module):
|
|
104 |
|
105 |
|
106 |
# ---------------- HF MODEL WRAPPER ---------------- #
|
|
|
107 |
class BlaserModel(PreTrainedModel):
|
108 |
config_class = BlaserConfig
|
109 |
|
110 |
def __init__(self, config: BlaserConfig):
|
111 |
super().__init__(config)
|
112 |
-
self.core
|
|
|
113 |
embedding_dim=config.embedding_dim,
|
114 |
output_dim=config.output_dim,
|
115 |
hidden_dims=config.hidden_dims,
|
@@ -118,7 +120,14 @@ class BlaserModel(PreTrainedModel):
|
|
118 |
input_form=config.input_form,
|
119 |
norm_emb=config.norm_emb,
|
120 |
output_act=config.output_act,
|
121 |
-
)
|
122 |
|
123 |
def forward(self, src, mt, ref=None):
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
|
106 |
# ---------------- HF MODEL WRAPPER ---------------- #
|
107 |
+
|
108 |
class BlaserModel(PreTrainedModel):
|
109 |
config_class = BlaserConfig
|
110 |
|
111 |
def __init__(self, config: BlaserConfig):
|
112 |
super().__init__(config)
|
113 |
+
# Instead of self.core, assign directly
|
114 |
+
self.mlp = BlaserCore(
|
115 |
embedding_dim=config.embedding_dim,
|
116 |
output_dim=config.output_dim,
|
117 |
hidden_dims=config.hidden_dims,
|
|
|
120 |
input_form=config.input_form,
|
121 |
norm_emb=config.norm_emb,
|
122 |
output_act=config.output_act,
|
123 |
+
).mlp # only take the Sequential MLP
|
124 |
|
125 |
def forward(self, src, mt, ref=None):
|
126 |
+
# The old checkpoint expects the input feature processing inside BlaserCore
|
127 |
+
proc = BlaserCore._featurize(
|
128 |
+
self.mlp, # pass self as `self` for static call
|
129 |
+
src=BlaserCore._norm(self.mlp, src),
|
130 |
+
mt=BlaserCore._norm(self.mlp, mt),
|
131 |
+
ref=BlaserCore._norm(self.mlp, ref)
|
132 |
+
)
|
133 |
+
return self.mlp(proc)
|