|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from transformers import PreTrainedModel, PretrainedConfig, AutoModel |
|
|
|
class EmbeddingMoEConfig(PretrainedConfig):
    """Hugging Face configuration object for ``EmbeddingMoE``.

    Args:
        output_dim: Dimensionality stored as ``output_dim`` (default 128).
        num_experts: Value stored as ``num_experts`` (default 2).
        dropout_rate: Value stored as ``dropout_rate`` (default 0.1).
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "embedding_moe"

    def __init__(self, output_dim=128, num_experts=2, dropout_rate=0.1, **kwargs):
        # Let the base class consume its own keyword arguments first.
        super().__init__(**kwargs)
        self.output_dim, self.num_experts, self.dropout_rate = (
            output_dim,
            num_experts,
            dropout_rate,
        )
|
|
|
|
|
|
|
class EmbeddingExpert(nn.Module):
    """Frozen pretrained encoder plus a small trainable head that emits unit-norm embeddings.

    The model named ``model_name`` is loaded with ``AutoModel.from_pretrained``
    and all of its parameters are frozen. The trainable part is the LayerNorm and
    the linear projection applied to the mean-pooled hidden states.

    Args:
        model_name: Name or path passed to ``AutoModel.from_pretrained``.
        output_dim: Size of the projected embedding.
        dropout_rate: Dropout probability applied before the projection.
    """

    def __init__(self, model_name, output_dim, dropout_rate=0.1):
        super().__init__()
        self.base = AutoModel.from_pretrained(model_name)
        self.layer_norm = nn.LayerNorm(self.base.config.hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

        # Freeze the pretrained encoder; layer_norm and projection stay trainable.
        for param in self.base.parameters():
            param.requires_grad = False

        self.projection = nn.Linear(self.base.config.hidden_size, output_dim)
        nn.init.xavier_uniform_(self.projection.weight)
        nn.init.zeros_(self.projection.bias)

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, weighting each position by ``attention_mask``.

        Args:
            model_output: Object exposing ``last_hidden_state``; assumed to be
                (batch, seq_len, hidden) — TODO confirm for non-BERT bases.
            attention_mask: Mask broadcastable to (batch, seq_len, 1) after
                ``unsqueeze(-1)``; non-zero entries contribute to the average.

        Returns:
            Pooled tensor in the same dtype as ``last_hidden_state``. Rows whose
            mask is all zeros pool to zeros instead of NaN.
        """
        token_embeddings = model_output.last_hidden_state
        # Accumulate in at least float32 (the sum and the 1e-9 floor below would
        # lose precision / underflow to 0 in float16), but return the encoder's
        # dtype so half-precision models do not receive float32 activations.
        acc_dtype = torch.promote_types(token_embeddings.dtype, torch.float32)
        mask = attention_mask.unsqueeze(-1).to(acc_dtype)
        summed = (token_embeddings.to(acc_dtype) * mask).sum(dim=1)
        # Floor the token count to avoid 0/0 for all-padding rows.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return (summed / counts).to(token_embeddings.dtype)

    def forward(self, input_ids, attention_mask):
        """Embed a tokenised batch.

        Runs the frozen encoder and mean-pools its last hidden state over
        ``attention_mask``. It then applies LayerNorm, dropout and the linear
        projection, and L2-normalises the result along dim 1.
        """
        outputs = self.base(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.mean_pooling(outputs, attention_mask)
        pooled_output = self.layer_norm(pooled_output)
        pooled_output = self.dropout(pooled_output)
        embedding = self.projection(pooled_output)
        return F.normalize(embedding, p=2, dim=1)
|
|
|
|
|
|
|
class GatingNetwork(nn.Module):
    """MLP that turns an input vector into softmax weights over experts.

    Pipeline: LayerNorm -> Dropout -> Linear -> ReLU -> Linear ->
    clamp to [-10, 10] -> softmax over the last dimension.

    Args:
        input_dim: Size of the input's last dimension.
        hidden_dim: Width of the hidden layer.
        num_experts: Number of output weights produced per input.
        dropout_rate: Dropout probability applied after the input LayerNorm.
    """

    def __init__(self, input_dim, hidden_dim, num_experts, dropout_rate=0.1):
        super().__init__()
        # Keep this construction order: nn.Linear's default init draws from the
        # global RNG, so reordering would change seeded weights.
        self.layer_norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, num_experts)
        self.softmax = nn.Softmax(dim=-1)

        # Re-initialise both projections: Xavier-uniform weights, zero biases.
        for layer in (self.linear1, self.linear2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return gate weights for ``x``; the last dimension sums to 1."""
        hidden = self.relu(self.linear1(self.dropout(self.layer_norm(x))))
        # Clamping the logits to [-10, 10] bounds the ratio between any two
        # gate weights by e**20.
        logits = self.linear2(hidden).clamp(min=-10, max=10)
        return self.softmax(logits)
|
|
|
|
|
|
|
class EmbeddingMoE(PreTrainedModel):
    """Mixture of two embedding experts combined by a learned softmax gate.

    Both experts are ``EmbeddingExpert`` instances built on ``bert-base-uncased``
    (frozen encoder, trainable head). The gating network reads the average of the
    two expert embeddings and produces per-example weights used to mix them. The
    mixture is then re-normalised to unit L2 norm.
    """

    config_class = EmbeddingMoEConfig

    # This implementation wires exactly two experts (expert1 / expert2).
    _SUPPORTED_NUM_EXPERTS = 2

    def __init__(self, config):
        super().__init__(config)
        output_dim = getattr(config, "output_dim", 128)
        num_experts = getattr(config, "num_experts", 2)
        dropout_rate = getattr(config, "dropout_rate", 0.1)

        # The gate must emit exactly one weight per expert. With fewer columns,
        # forward() would fail indexing gating_output[:, 1]. With more, the extra
        # weights would be silently dropped and the mix would no longer sum to 1.
        if num_experts != self._SUPPORTED_NUM_EXPERTS:
            raise ValueError(
                f"EmbeddingMoE supports exactly {self._SUPPORTED_NUM_EXPERTS} "
                f"experts, got num_experts={num_experts}"
            )

        # Propagate the configured dropout (previously ignored; default unchanged).
        self.expert1 = EmbeddingExpert("bert-base-uncased", output_dim, dropout_rate)
        self.expert2 = EmbeddingExpert("bert-base-uncased", output_dim, dropout_rate)
        self.gating = GatingNetwork(output_dim, 256, num_experts, dropout_rate)

    def forward(self, input_ids, attention_mask):
        """Embed a batch with both experts and return the gated, L2-normalised mix."""
        expert1_output = self.expert1(input_ids, attention_mask)
        expert2_output = self.expert2(input_ids, attention_mask)

        # The gate decides the mix from the average expert embedding.
        gating_input = (expert1_output + expert2_output) / 2
        gating_output = self.gating(gating_input)

        # Per-example convex combination of the two expert embeddings.
        mixed_output = (
            gating_output[:, 0].unsqueeze(1) * expert1_output
            + gating_output[:, 1].unsqueeze(1) * expert2_output
        )
        return F.normalize(mixed_output, p=2, dim=1)

    def encode_sentence(self, input_ids, attention_mask):
        """Return embeddings for a batch without tracking gradients or applying dropout.

        Every submodule is switched to eval mode for the call. Each submodule's
        previous ``training`` flag is restored afterwards, including the
        pretrained bases, which may be in a different mode from the wrapper.
        """
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(input_ids, attention_mask)
        finally:
            for module, mode in previous_modes:
                module.training = mode
|
|