import torch
from torch import nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, AutoModel

class EmbeddingMoEConfig(PretrainedConfig):
    model_type = "embedding_moe"

    def __init__(self, output_dim=128, num_experts=2, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.num_experts = num_experts
        self.dropout_rate = dropout_rate

# Expert: a frozen pre-trained BERT encoder with a trainable projection head
class EmbeddingExpert(nn.Module):
    def __init__(self, model_name, output_dim, dropout_rate=0.1):
        super().__init__()
        self.base = AutoModel.from_pretrained(model_name)
        self.layer_norm = nn.LayerNorm(self.base.config.hidden_size)
        self.dropout = nn.Dropout(dropout_rate)
        # Freeze the pre-trained encoder; only the projection head is trained
        for param in self.base.parameters():
            param.requires_grad = False
        # Projection layer to get the final embedding
        self.projection = nn.Linear(self.base.config.hidden_size, output_dim)
        nn.init.xavier_uniform_(self.projection.weight)
        nn.init.zeros_(self.projection.bias)

    def mean_pooling(self, model_output, attention_mask):
        # Mean pooling - take the attention mask into account for averaging
        token_embeddings = model_output.last_hidden_state
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.base(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.mean_pooling(outputs, attention_mask)
        pooled_output = self.layer_norm(pooled_output)
        pooled_output = self.dropout(pooled_output)
        embedding = self.projection(pooled_output)
        # L2-normalize so each expert returns a unit-length embedding
        embedding = F.normalize(embedding, p=2, dim=1)
        return embedding

# Gating network: maps the averaged expert embedding to per-expert mixture weights
class GatingNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_experts, dropout_rate=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, num_experts)
        self.softmax = nn.Softmax(dim=-1)
        nn.init.xavier_uniform_(self.linear1.weight)
        nn.init.zeros_(self.linear1.bias)
        nn.init.xavier_uniform_(self.linear2.weight)
        nn.init.zeros_(self.linear2.bias)

    def forward(self, x):
        x = self.layer_norm(x)
        x = self.dropout(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        # Clamp the logits before the softmax for numerical stability
        x = torch.clamp(x, min=-10, max=10)
        x = self.softmax(x)
        return x

# Mixture of Experts for sentence embeddings using BERT
class EmbeddingMoE(PreTrainedModel):
    config_class = EmbeddingMoEConfig

    def __init__(self, config):
        super().__init__(config)
        output_dim = getattr(config, "output_dim", 128)
        num_experts = getattr(config, "num_experts", 2)
        dropout_rate = getattr(config, "dropout_rate", 0.1)
        self.expert1 = EmbeddingExpert("bert-base-uncased", output_dim, dropout_rate)
        self.expert2 = EmbeddingExpert("bert-base-uncased", output_dim, dropout_rate)
        self.gating = GatingNetwork(output_dim, 256, num_experts, dropout_rate)

    def forward(self, input_ids, attention_mask):
        # Get embeddings from both experts
        expert1_output = self.expert1(input_ids, attention_mask)
        expert2_output = self.expert2(input_ids, attention_mask)
        # Average the expert outputs as input to the gating network
        gating_input = (expert1_output + expert2_output) / 2
        # Get gating weights
        gating_output = self.gating(gating_input)
        # Combine the two expert outputs with their gating weights
        mixed_output = (
            gating_output[:, 0].unsqueeze(1) * expert1_output
            + gating_output[:, 1].unsqueeze(1) * expert2_output
        )
        # Normalize the mixed embedding to unit length
        embedding = F.normalize(mixed_output, p=2, dim=1)
        return embedding

    def encode_sentence(self, input_ids, attention_mask):
        """Helper method to get the embedding for a single sentence."""
        with torch.no_grad():
            return self.forward(input_ids, attention_mask)
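

# Usage sketch: a minimal example of how this model might be instantiated and
# queried, assuming the standard "bert-base-uncased" tokenizer is shared by both
# experts. The sentences and variable names below are illustrative only.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    config = EmbeddingMoEConfig(output_dim=128, num_experts=2, dropout_rate=0.1)
    model = EmbeddingMoE(config)
    model.eval()

    sentences = ["A quick brown fox.", "A fast auburn fox."]
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

    # encode_sentence runs the forward pass under torch.no_grad()
    embeddings = model.encode_sentence(batch["input_ids"], batch["attention_mask"])
    print(embeddings.shape)  # torch.Size([2, 128])

    # Outputs are L2-normalized, so the dot product is the cosine similarity
    similarity = embeddings[0] @ embeddings[1]
    print(similarity.item())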