"""Hugging Face implementation of the BLASER regressor: an MLP that predicts
a translation quality score from sentence embeddings of the source, the
translation (MT), and optionally a reference."""

from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel


class BlaserConfig(PretrainedConfig):
    """Configuration for BlaserModel.

    input_form selects how the embeddings are combined: "QE" uses only the
    source and translation embeddings, "COMET" additionally uses a reference.
    """

    model_type = "blaser"

    def __init__(
        self,
        embedding_dim=1024,
        output_dim=1,
        hidden_dims=None,
        dropout=0.1,
        activation="TANH",
        input_form="QE",
        norm_emb=True,
        output_act=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims if hidden_dims is not None else [3072, 1536]
        self.dropout = dropout
        self.activation = activation
        self.input_form = input_form
        self.norm_emb = norm_emb
        self.output_act = output_act


# Activation functions selectable via BlaserConfig.activation.
ACTIVATIONS = {"TANH": nn.Tanh, "RELU": nn.ReLU}


class BlaserCore(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        output_dim: int,
        hidden_dims: List[int],
        dropout: float,
        activation: str,
        input_form: str,
        norm_emb: bool,
        output_act: bool,
    ):
        super().__init__()
        self.input_form = input_form
        self.norm_emb = norm_emb

        # The featurized input concatenates several combinations of the
        # embeddings, so the MLP's input width is a multiple of embedding_dim:
        # 6 pieces for COMET, 4 for QE (see _featurize below).
        if input_form == "COMET":
            embedding_dim *= 6
        elif input_form == "QE":
            embedding_dim *= 4
        else:
            raise ValueError(f"Unrecognized input_form: {input_form}")
        if activation not in ACTIVATIONS:
            raise ValueError(f"Unrecognized activation: {activation}")

        modules: List[nn.Module] = []
        if hidden_dims:
            if dropout > 0:
                modules.append(nn.Dropout(p=dropout))
            nprev = embedding_dim
            for h in hidden_dims:
                modules.append(nn.Linear(nprev, h))
                modules.append(ACTIVATIONS[activation]())
                if dropout > 0:
                    modules.append(nn.Dropout(p=dropout))
                nprev = h
            modules.append(nn.Linear(nprev, output_dim))
            if output_act:
                modules.append(nn.Tanh())
        else:
            # No hidden layers: a single linear projection.
            modules.append(nn.Linear(embedding_dim, output_dim))

        self.mlp = nn.Sequential(*modules)

    def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
        # L2-normalize each embedding when norm_emb is set; pass None through.
        return F.normalize(emb) if (emb is not None and self.norm_emb) else emb

    def _featurize(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
        if self.input_form == "COMET":
            if ref is None:
                raise ValueError("COMET input_form requires reference embedding")
            return torch.cat(
                [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
                dim=-1,
            )
        # input_form was validated in __init__, so anything else is "QE".
        return torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)


class BlaserModel(PreTrainedModel):
    config_class = BlaserConfig

    def __init__(self, config: BlaserConfig):
        super().__init__(config)

        core = BlaserCore(
            embedding_dim=config.embedding_dim,
            output_dim=config.output_dim,
            hidden_dims=config.hidden_dims,
            dropout=config.dropout,
            activation=config.activation,
            input_form=config.input_form,
            norm_emb=config.norm_emb,
            output_act=config.output_act,
        )
        # Keep the core's MLP as this model's only parameterized submodule;
        # forward() re-implements the normalization/featurization inline.
        self.mlp = core.mlp
        self.input_form = core.input_form
        self.norm_emb = core.norm_emb

    def forward(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
        # Optionally L2-normalize the input embeddings.
        src = F.normalize(src) if self.norm_emb else src
        mt = F.normalize(mt) if self.norm_emb else mt
        ref = F.normalize(ref) if (ref is not None and self.norm_emb) else ref

        # Build the same feature vector as BlaserCore._featurize.
        if self.input_form == "COMET":
            if ref is None:
                raise ValueError("COMET input_form requires reference embedding")
            proc = torch.cat(
                [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
                dim=-1,
            )
        else:
            proc = torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)

        return self.mlp(proc)
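

if __name__ == "__main__":
    # Usage sketch, not part of the original module: BLASER expects sentence
    # embeddings (e.g. from SONAR) for the source and the translation; random
    # vectors stand in for them here purely to exercise the forward pass with
    # the default QE configuration.
    config = BlaserConfig(input_form="QE")
    model = BlaserModel(config).eval()

    src = torch.randn(2, config.embedding_dim)  # 2 source embeddings
    mt = torch.randn(2, config.embedding_dim)  # 2 translation embeddings
    with torch.no_grad():
        scores = model(src=src, mt=mt)
    print(scores.shape)  # torch.Size([2, 1]): one quality score per pair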