Fix model code
modeling_blaser.py (CHANGED: +22 -19)
@@ -36,6 +36,7 @@ class BlaserConfig(PretrainedConfig):
 # ---------------- CORE MODEL ---------------- #
 ACTIVATIONS = {"TANH": nn.Tanh, "RELU": nn.ReLU}
 
+
 class BlaserCore(nn.Module):
     def __init__(
         self,
@@ -80,14 +81,6 @@ class BlaserCore(nn.Module):
 
         self.mlp = nn.Sequential(*modules)
 
-    def forward(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
-        proc = self._featurize(
-            src=self._norm(src),
-            mt=self._norm(mt),
-            ref=self._norm(ref),
-        )
-        return self.mlp(proc)
-
     def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
         return F.normalize(emb) if (emb is not None and self.norm_emb) else emb
 
@@ -104,14 +97,13 @@ class BlaserCore(nn.Module):
 
 
 # ---------------- HF MODEL WRAPPER ---------------- #
-
 class BlaserModel(PreTrainedModel):
     config_class = BlaserConfig
 
     def __init__(self, config: BlaserConfig):
         super().__init__(config)
-        #
-        self.mlp = BlaserCore(
+        # Directly assign the Sequential MLP to self.mlp
+        core = BlaserCore(
             embedding_dim=config.embedding_dim,
             output_dim=config.output_dim,
             hidden_dims=config.hidden_dims,
@@ -120,14 +112,25 @@ class BlaserModel(PreTrainedModel):
             input_form=config.input_form,
             norm_emb=config.norm_emb,
             output_act=config.output_act,
-        )
+        )
+        self.mlp = core.mlp
+        self.input_form = core.input_form
+        self.norm_emb = core.norm_emb
 
     def forward(self, src, mt, ref=None):
-        #
-
-
-
-
-
-
+        # Use the same featurization as in BlaserCore
+        src = F.normalize(src) if self.norm_emb else src
+        mt = F.normalize(mt) if self.norm_emb else mt
+        ref = F.normalize(ref) if (ref is not None and self.norm_emb) else ref
+
+        if self.input_form == "COMET":
+            if ref is None:
+                raise ValueError("COMET input_form requires reference embedding")
+            proc = torch.cat(
+                [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
+                dim=-1,
+            )
+        else:  # QE
+            proc = torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)
+
         return self.mlp(proc)
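The concatenation in the new forward pins down the width of the MLP's first layer: the COMET form stacks six embedding-sized features, the QE form four. A quick standalone shape check of that featurization (the embedding size here is an assumed SONAR-style value, not one read from this commit):

import torch

# Standalone shape check for the two input forms; 1024 is an assumed
# embedding size, not a value taken from this commit.
embedding_dim = 1024
src, mt, ref = (torch.randn(2, embedding_dim) for _ in range(3))

comet = torch.cat(
    [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
    dim=-1,
)
qe = torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)

# The first nn.Linear of the MLP must therefore take 6 * embedding_dim
# inputs in COMET form and 4 * embedding_dim inputs in QE form.
assert comet.shape[-1] == 6 * embedding_dim
assert qe.shape[-1] == 4 * embedding_dim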
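A minimal end-to-end smoke test of the fixed wrapper, assuming modeling_blaser.py is importable and that BlaserConfig accepts the constructor fields visible in the diff; the hyperparameter values below are illustrative placeholders, not the checkpoint's real configuration:

import torch

from modeling_blaser import BlaserConfig, BlaserModel

# Illustrative config; real checkpoints ship their own config.json, and
# BlaserConfig may take further fields not visible in this diff.
config = BlaserConfig(
    embedding_dim=1024,
    output_dim=1,
    hidden_dims=[3072, 1536],
    input_form="COMET",
    norm_emb=True,
    output_act=False,
)
model = BlaserModel(config).eval()

# src/mt/ref are pooled sentence embeddings; random stand-ins here.
src = torch.randn(2, config.embedding_dim)
mt = torch.randn(2, config.embedding_dim)
ref = torch.randn(2, config.embedding_dim)

with torch.no_grad():
    scores = model(src=src, mt=mt, ref=ref)

print(scores.shape)  # (2, output_dim)

Since self.mlp now points directly at the core's nn.Sequential, checkpoints whose weights are stored under mlp.* keys should load into BlaserModel without key remapping, which appears to be the motivation for flattening BlaserCore out of the wrapper.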