import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional

from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel


# ---------------- CONFIG ---------------- #
class BlaserConfig(PretrainedConfig):
    model_type = "blaser"

    def __init__(
        self,
        embedding_dim=1024,
        output_dim=1,
        hidden_dims=None,
        dropout=0.1,
        activation="TANH",
        input_form="QE",
        norm_emb=True,
        output_act=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims if hidden_dims is not None else [3072, 1536]
        self.dropout = dropout
        self.activation = activation
        self.input_form = input_form
        self.norm_emb = norm_emb
        self.output_act = output_act


# ---------------- CORE MODEL ---------------- #
ACTIVATIONS = {"TANH": nn.Tanh, "RELU": nn.ReLU}


class BlaserCore(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        output_dim: int,
        hidden_dims: List[int],
        dropout: float,
        activation: str,
        input_form: str,
        norm_emb: bool,
        output_act: bool,
    ):
        super().__init__()
        self.input_form = input_form
        self.norm_emb = norm_emb

        # The MLP consumes a concatenation of embedding-derived features, so its
        # input width is a multiple of the base embedding dimension.
        if input_form == "COMET":
            embedding_dim *= 6
        elif input_form == "QE":
            embedding_dim *= 4
        else:
            raise ValueError(f"Unrecognized input_form: {input_form}")

        if activation not in ACTIVATIONS:
            raise ValueError(f"Unrecognized activation: {activation}")

        modules: List[nn.Module] = []
        if hidden_dims:
            if dropout > 0:
                modules.append(nn.Dropout(p=dropout))
            nprev = embedding_dim
            for h in hidden_dims:
                modules.append(nn.Linear(nprev, h))
                modules.append(ACTIVATIONS[activation]())
                if dropout > 0:
                    modules.append(nn.Dropout(p=dropout))
                nprev = h
            modules.append(nn.Linear(nprev, output_dim))
            if output_act:
                modules.append(nn.Tanh())
        else:
            # No hidden layers: a single linear projection from features to score.
            modules.append(nn.Linear(embedding_dim, output_dim))
        self.mlp = nn.Sequential(*modules)

    def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
        return F.normalize(emb) if (emb is not None and self.norm_emb) else emb

    def _featurize(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
        if self.input_form == "COMET":
            if ref is None:
                raise ValueError("COMET input_form requires reference embedding")
            return torch.cat(
                [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
                dim=-1,
            )
        # QE: reference-free features built from source and translation only.
        return torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)


# ---------------- HF MODEL WRAPPER ---------------- #
class BlaserModel(PreTrainedModel):
    config_class = BlaserConfig

    def __init__(self, config: BlaserConfig):
        super().__init__(config)
        # Directly assign the Sequential MLP to self.mlp
        core = BlaserCore(
            embedding_dim=config.embedding_dim,
            output_dim=config.output_dim,
            hidden_dims=config.hidden_dims,
            dropout=config.dropout,
            activation=config.activation,
            input_form=config.input_form,
            norm_emb=config.norm_emb,
            output_act=config.output_act,
        )
        self.mlp = core.mlp
        self.input_form = core.input_form
        self.norm_emb = core.norm_emb

    def forward(self, src, mt, ref=None):
        # Use the same featurization as in BlaserCore
        src = F.normalize(src) if self.norm_emb else src
        mt = F.normalize(mt) if self.norm_emb else mt
        ref = F.normalize(ref) if (ref is not None and self.norm_emb) else ref

        if self.input_form == "COMET":
            if ref is None:
                raise ValueError("COMET input_form requires reference embedding")
            proc = torch.cat(
                [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
                dim=-1,
            )
        else:  # QE
            proc = torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)

        return self.mlp(proc)
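

# ---------------- USAGE SKETCH ---------------- #
# Minimal smoke test, not part of the original module. It assumes 1024-dim
# sentence embeddings (e.g. from a SONAR-style encoder) and substitutes random
# tensors for them, so the predicted scores are meaningless; it only checks
# that the wrapper builds and produces the expected output shape.
if __name__ == "__main__":
    config = BlaserConfig(embedding_dim=1024, input_form="QE")
    model = BlaserModel(config)
    model.eval()

    batch_size = 2
    src = torch.randn(batch_size, config.embedding_dim)  # stand-in source embeddings
    mt = torch.randn(batch_size, config.embedding_dim)   # stand-in translation embeddings

    with torch.no_grad():
        scores = model(src=src, mt=mt)  # QE mode: no reference embedding required
    print(scores.shape)  # expected: torch.Size([2, 1])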