# blaser_2_0_qe_ported / modeling_blaser.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel
# ---------------- CONFIG ---------------- #
class BlaserConfig(PretrainedConfig):
    """Configuration for the BLASER 2.0 regressor head."""

    model_type = "blaser"

    def __init__(
        self,
        embedding_dim=1024,
        output_dim=1,
        hidden_dims=None,
        dropout=0.1,
        activation="TANH",
        input_form="QE",
        norm_emb=True,
        output_act=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims if hidden_dims is not None else [3072, 1536]
        self.dropout = dropout
        self.activation = activation
        self.input_form = input_form
        self.norm_emb = norm_emb
        self.output_act = output_act
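
# Illustrative note (not part of the original file): with the default config
# (embedding_dim=1024, input_form="QE"), the featurized input to the MLP is
# 4 * 1024 = 4096 wide; see BlaserCore below.
#
#   config = BlaserConfig()          # QE, TANH, hidden_dims=[3072, 1536]
#   assert config.embedding_dim * 4 == 4096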
# ---------------- CORE MODEL ---------------- #
ACTIVATIONS = {"TANH": nn.Tanh, "RELU": nn.ReLU}


class BlaserCore(nn.Module):
    """MLP regressor over concatenated sentence-embedding features."""

    def __init__(
        self,
        embedding_dim: int,
        output_dim: int,
        hidden_dims: List[int],
        dropout: float,
        activation: str,
        input_form: str,
        norm_emb: bool,
        output_act: bool,
    ):
        super().__init__()
        self.input_form = input_form
        self.norm_emb = norm_emb

        # The featurized input concatenates several embedding-sized vectors,
        # so the effective input width is a multiple of embedding_dim.
        if input_form == "COMET":
            embedding_dim *= 6
        elif input_form == "QE":
            embedding_dim *= 4
        else:
            raise ValueError(f"Unrecognized input_form: {input_form}")

        if activation not in ACTIVATIONS:
            raise ValueError(f"Unrecognized activation: {activation}")

        modules: List[nn.Module] = []
        if hidden_dims:
            if dropout > 0:
                modules.append(nn.Dropout(p=dropout))
            nprev = embedding_dim
            for h in hidden_dims:
                modules.append(nn.Linear(nprev, h))
                modules.append(ACTIVATIONS[activation]())
                if dropout > 0:
                    modules.append(nn.Dropout(p=dropout))
                nprev = h
            modules.append(nn.Linear(nprev, output_dim))
            if output_act:
                modules.append(nn.Tanh())
        else:
            modules.append(nn.Linear(embedding_dim, output_dim))
        self.mlp = nn.Sequential(*modules)

    def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
        return F.normalize(emb) if (emb is not None and self.norm_emb) else emb

    def _featurize(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
        if self.input_form == "COMET":
            if ref is None:
                raise ValueError("COMET input_form requires reference embedding")
            return torch.cat(
                [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
                dim=-1,
            )
        elif self.input_form == "QE":
            return torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)
        else:
            raise ValueError(f"Unrecognized input_form: {self.input_form}")
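
# Feature layout, for reference (with D = embedding_dim):
#   QE    -> cat([src, mt, src*mt, |mt - src|])                      : 4 * D
#   COMET -> cat([ref, mt, src*mt, ref*mt, |mt - src|, |mt - ref|])  : 6 * D
# which is why the constructor widens embedding_dim by 4 or 6 before building
# the MLP.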
# ---------------- HF MODEL WRAPPER ---------------- #
class BlaserModel(PreTrainedModel):
    config_class = BlaserConfig

    def __init__(self, config: BlaserConfig):
        super().__init__(config)
        # Build the core once and directly assign its Sequential MLP to self.mlp.
        core = BlaserCore(
            embedding_dim=config.embedding_dim,
            output_dim=config.output_dim,
            hidden_dims=config.hidden_dims,
            dropout=config.dropout,
            activation=config.activation,
            input_form=config.input_form,
            norm_emb=config.norm_emb,
            output_act=config.output_act,
        )
        self.mlp = core.mlp
        self.input_form = core.input_form
        self.norm_emb = core.norm_emb

    def forward(self, src, mt, ref=None):
        # Use the same normalization and featurization as in BlaserCore.
        src = F.normalize(src) if self.norm_emb else src
        mt = F.normalize(mt) if self.norm_emb else mt
        ref = F.normalize(ref) if (ref is not None and self.norm_emb) else ref
        if self.input_form == "COMET":
            if ref is None:
                raise ValueError("COMET input_form requires reference embedding")
            proc = torch.cat(
                [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
                dim=-1,
            )
        else:  # QE
            proc = torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)
        return self.mlp(proc)
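

# A minimal smoke test, assuming 2-D (batch, embedding_dim) inputs such as
# SONAR sentence embeddings; the random tensors below are placeholders, not
# real embeddings.
if __name__ == "__main__":
    config = BlaserConfig()  # defaults: QE input form, 1024-dim embeddings
    model = BlaserModel(config)
    model.eval()

    src = torch.randn(2, config.embedding_dim)  # source-sentence embeddings
    mt = torch.randn(2, config.embedding_dim)   # translation embeddings

    with torch.no_grad():
        scores = model(src=src, mt=mt)  # shape: (2, 1)
    print(scores.shape)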