oist committed
Commit 03493f1 · 1 Parent(s): 1f4121b

Fix model code

Files changed (1)
  1. modeling_blaser.py +22 -19
modeling_blaser.py CHANGED
@@ -36,6 +36,7 @@ class BlaserConfig(PretrainedConfig):
 # ---------------- CORE MODEL ---------------- #
 ACTIVATIONS = {"TANH": nn.Tanh, "RELU": nn.ReLU}
 
+
 class BlaserCore(nn.Module):
     def __init__(
         self,
@@ -80,14 +81,6 @@ class BlaserCore(nn.Module):
 
         self.mlp = nn.Sequential(*modules)
 
-    def forward(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
-        proc = self._featurize(
-            src=self._norm(src),
-            mt=self._norm(mt),
-            ref=self._norm(ref),
-        )
-        return self.mlp(proc)
-
     def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
         return F.normalize(emb) if (emb is not None and self.norm_emb) else emb
 
@@ -104,14 +97,13 @@ class BlaserCore(nn.Module):
 
 
 # ---------------- HF MODEL WRAPPER ---------------- #
-
 class BlaserModel(PreTrainedModel):
     config_class = BlaserConfig
 
     def __init__(self, config: BlaserConfig):
         super().__init__(config)
-        # Instead of self.core, assign directly
-        self.mlp = BlaserCore(
+        # Directly assign the Sequential MLP to self.mlp
+        core = BlaserCore(
             embedding_dim=config.embedding_dim,
             output_dim=config.output_dim,
             hidden_dims=config.hidden_dims,
@@ -120,14 +112,25 @@ class BlaserModel(PreTrainedModel):
             input_form=config.input_form,
             norm_emb=config.norm_emb,
             output_act=config.output_act,
-        ).mlp  # only take the Sequential MLP
+        )
+        self.mlp = core.mlp
+        self.input_form = core.input_form
+        self.norm_emb = core.norm_emb
 
     def forward(self, src, mt, ref=None):
-        # The old checkpoint expects the input feature processing inside BlaserCore
-        proc = BlaserCore._featurize(
-            self.mlp,  # pass self as `self` for static call
-            src=BlaserCore._norm(self.mlp, src),
-            mt=BlaserCore._norm(self.mlp, mt),
-            ref=BlaserCore._norm(self.mlp, ref)
-        )
+        # Use the same featurization as in BlaserCore
+        src = F.normalize(src) if self.norm_emb else src
+        mt = F.normalize(mt) if self.norm_emb else mt
+        ref = F.normalize(ref) if (ref is not None and self.norm_emb) else ref
+
+        if self.input_form == "COMET":
+            if ref is None:
+                raise ValueError("COMET input_form requires reference embedding")
+            proc = torch.cat(
+                [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
+                dim=-1,
+            )
+        else:  # QE
+            proc = torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)
+
         return self.mlp(proc)
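
For reference, a minimal sketch of the featurization the new forward performs, assuming batched sentence embeddings. The value embedding_dim = 1024 is an illustrative assumption, not taken from the diff; in practice it comes from BlaserConfig. The sketch shows why the first nn.Linear of the checkpoint's mlp must expect 4 * embedding_dim inputs in QE form and 6 * embedding_dim in COMET form:

import torch

embedding_dim = 1024  # illustrative only; the real value comes from BlaserConfig
src = torch.randn(2, embedding_dim)  # source-sentence embeddings
mt = torch.randn(2, embedding_dim)   # machine-translation embeddings
ref = torch.randn(2, embedding_dim)  # reference-translation embeddings

# QE form: [src, mt, src*mt, |mt - src|] -> 4 * embedding_dim features
qe = torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)
assert qe.shape == (2, 4 * embedding_dim)

# COMET form: six concatenated blocks -> 6 * embedding_dim features
comet = torch.cat(
    [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
    dim=-1,
)
assert comet.shape == (2, 6 * embedding_dim)

Because self.mlp = core.mlp preserves the same mlp.* parameter names as the old self.mlp = BlaserCore(...).mlp assignment, existing checkpoints load without any key remapping; only the featurization moves out of BlaserCore.forward.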
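A hedged usage sketch follows. The repo id is a placeholder, and this assumes the checkpoint's config registers BlaserModel under AutoModel via auto_map; loading custom modeling code from the Hub requires trust_remote_code=True:

from transformers import AutoModel

# "your-org/blaser-checkpoint" is a placeholder, not a real repo id
model = AutoModel.from_pretrained(
    "your-org/blaser-checkpoint", trust_remote_code=True
)
with torch.no_grad():
    score = model(src=src, mt=mt, ref=ref)  # shape: (batch, output_dim)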