File size: 2,105 Bytes
75db504 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import torch
from torch import nn
import timm
from transformers import DistilBertModel, DistilBertConfig
import config as CFG
class ImageEncoder(nn.Module):
    """Image backbone that maps a batch of images to one feature vector each.

    The backbone is built with ``timm.create_model`` using ``num_classes=0``
    and ``global_pool="avg"``; per timm's feature-extraction documentation
    this yields average-pooled features with no classification head.

    Args:
        model_name: Name of the timm architecture to build.
        pretrained: Forwarded to timm as its ``pretrained`` argument.
        trainable: Value assigned to ``requires_grad`` on every backbone
            parameter (``False`` freezes the backbone).
    """

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool="avg",
        )
        # One call flips requires_grad on every parameter of the backbone.
        backbone.requires_grad_(trainable)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and return its output unchanged."""
        features = self.model(x)
        return features
class TextEncoder(nn.Module):
    """DistilBERT wrapper that returns one embedding vector per input sequence.

    Args:
        model_name: Checkpoint identifier passed to
            ``DistilBertModel.from_pretrained``. Only consulted when
            ``pretrained`` is truthy.
        pretrained: When truthy, load weights from ``model_name``; otherwise
            build a DistilBERT from a default ``DistilBertConfig()``.
        trainable: Value assigned to ``requires_grad`` on every model
            parameter (``False`` freezes the encoder).
    """

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        self.model = (
            DistilBertModel.from_pretrained(model_name)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )
        # One call flips requires_grad on every parameter of the encoder.
        self.model.requires_grad_(trainable)

        # The hidden state at the first position (the CLS token) serves as
        # the sentence embedding.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return the hidden state at ``target_token_idx``.

        Both arguments are forwarded by keyword to the underlying model; the
        result is ``last_hidden_state`` sliced at the target position along
        its second axis.
        """
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return encoded.last_hidden_state[:, self.target_token_idx, :]
class ProjectionHead(nn.Module):
    """Project encoder features into the shared embedding space.

    Pipeline: Linear -> GELU -> Linear -> Dropout, then a residual add with
    the output of the first Linear, followed by LayerNorm.

    Args:
        embedding_dim: Size of the last dimension of the incoming features.
        projection_dim: Size of the last dimension of the output.
        dropout: Drop probability passed to ``nn.Dropout``.
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        super().__init__()
        # Attribute names and registration order are part of the checkpoint
        # layout (state_dict keys), so they are kept as-is.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Return the layer-normalised sum of the MLP branch and the projection."""
        projected = self.projection(x)
        branch = self.dropout(self.fc(self.gelu(projected)))
        # Skip connection around the GELU/fc/dropout branch.
        return self.layer_norm(branch + projected)
|