# clipL336_TTR/transformer.py
from collections import OrderedDict
import math
from typing import Callable, Optional, Sequence, Tuple, Text
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
import numbers
import einops
import numpy as np
from misc import to_2tuple
class LayerNorm(nn.Module):
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
def __init__(
self,
normalized_shape,
eps: float = 1e-5,
elementwise_affine: bool = True,
device=None,
dtype=None,
):
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = tuple(normalized_shape)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
            # Initialize like nn.LayerNorm (weight=1, bias=0); in practice these are
            # overwritten when a pretrained checkpoint is loaded.
            self.weight = nn.Parameter(torch.ones(self.normalized_shape))
            self.bias = nn.Parameter(torch.zeros(self.normalized_shape))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor):
orig_type = x.dtype
assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
dims = [-(i + 1) for i in range(len(self.normalized_shape))]
mean = x.mean(dim=dims, keepdim=True)
mean_x2 = (x**2).mean(dim=dims, keepdim=True)
var = mean_x2 - mean**2
x_norm = (x - mean) / torch.sqrt(var + self.eps)
if self.elementwise_affine:
x_norm = self.weight * x_norm + self.bias
return x_norm.to(orig_type)
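def _layernorm_sanity_sketch():
    # Minimal sanity sketch (illustration only, small hypothetical sizes): the
    # hand-rolled LayerNorm above should agree with torch.nn.LayerNorm once the
    # affine parameters match, since both use the biased variance E[x^2] - E[x]^2.
    x = torch.randn(2, 5, 8)
    ln = LayerNorm(8)
    ref = nn.LayerNorm(8, eps=ln.eps)
    with torch.no_grad():
        ln.weight.copy_(torch.rand(8) + 0.5)
        ln.bias.copy_(torch.rand(8))
        ref.weight.copy_(ln.weight)
        ref.bias.copy_(ln.bias)
    assert torch.allclose(ln(x), ref(x), atol=1e-5)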
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
        # Standard LayerScale: scale each channel by a learned per-channel gamma.
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
class PatchDropout(nn.Module):
"""
https://arxiv.org/abs/2212.00794
"""
def __init__(self, prob, exclude_first_token=True):
super().__init__()
assert 0 <= prob < 1.0
self.prob = prob
self.exclude_first_token = exclude_first_token
def forward(self, x):
if not self.training or self.prob == 0.0:
return x
if self.exclude_first_token:
cls_tokens, x = x[:, :1], x[:, 1:]
else:
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
batch = x.size()[0]
num_tokens = x.size()[1]
batch_indices = torch.arange(batch)
batch_indices = batch_indices[..., None]
keep_prob = 1 - self.prob
num_patches_keep = max(1, int(num_tokens * keep_prob))
rand = torch.randn(batch, num_tokens)
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
x = x[batch_indices, patch_indices_keep]
if self.exclude_first_token:
x = torch.cat((cls_tokens, x), dim=1)
return x
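def _patch_dropout_sketch():
    # Minimal usage sketch (illustration only, hypothetical sizes): with prob=0.5
    # and exclude_first_token=True, training mode keeps the CLS token plus roughly
    # half of the patch tokens; eval mode is a no-op.
    drop = PatchDropout(prob=0.5, exclude_first_token=True)
    drop.train()
    x = torch.randn(4, 1 + 16, 32)  # [batch, CLS + patches, width]
    assert drop(x).shape == (4, 1 + 8, 32)
    drop.eval()
    assert drop(x).shape == x.shape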
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
scaled_cosine=False,
scale_heads=False,
logit_scale_max=math.log(1.0 / 0.01),
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.scaled_cosine = scaled_cosine
self.scale_heads = scale_heads
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.logit_scale_max = logit_scale_max
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
if qkv_bias:
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
else:
self.in_proj_bias = None
if self.scaled_cosine:
self.logit_scale = nn.Parameter(
torch.log(10 * torch.ones((num_heads, 1, 1)))
)
else:
self.logit_scale = None
self.attn_drop = nn.Dropout(attn_drop)
if self.scale_heads:
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
else:
self.head_scale = None
self.out_proj = nn.Linear(dim, dim)
self.out_drop = nn.Dropout(proj_drop)
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
L, N, C = x.shape
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
if self.logit_scale is not None:
attn = torch.bmm(
F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)
)
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
attn = attn.view(N, self.num_heads, L, L) * logit_scale
attn = attn.view(-1, L, L)
else:
q = q * self.scale
attn = torch.bmm(q, k.transpose(-1, -2))
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
attn_mask = new_attn_mask
attn += attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.bmm(attn, v)
        if self.head_scale is not None:
            # x is (N * num_heads, L, head_dim) at this point, so reshape per head.
            x = x.view(N, self.num_heads, L, self.head_dim) * self.head_scale
            x = x.view(-1, L, self.head_dim)
x = x.transpose(0, 1).reshape(L, N, C)
x = self.out_proj(x)
x = self.out_drop(x)
return x
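def _attention_shape_sketch():
    # Minimal shape sketch (illustration only, hypothetical sizes): this Attention
    # variant expects sequence-first input [L, N, C], mirroring nn.MultiheadAttention
    # with batch_first=False, and returns a tensor of the same shape.
    attn = Attention(dim=64, num_heads=4)
    x = torch.randn(10, 2, 64)  # [L, N, C]
    assert attn(x).shape == (10, 2, 64)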
class AttentionalPooler(nn.Module):
def __init__(
self,
d_model: int,
context_dim: int,
n_head: int = 8,
n_queries: int = 256,
norm_layer: Callable = LayerNorm,
):
super().__init__()
self.query = nn.Parameter(torch.randn(n_queries, d_model))
self.attn = nn.MultiheadAttention(
d_model, n_head, kdim=context_dim, vdim=context_dim
)
self.ln_q = norm_layer(d_model)
self.ln_k = norm_layer(context_dim)
def forward(self, x: torch.Tensor):
x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
return out.permute(1, 0, 2) # LND -> NLD
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
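def _attentional_pooler_sketch():
    # Minimal usage sketch (illustration only, hypothetical sizes): the pooler
    # cross-attends learned queries to an [N, L, context_dim] token sequence and
    # returns a fixed [N, n_queries, d_model] output.
    pooler = AttentionalPooler(d_model=64, context_dim=96, n_head=8, n_queries=4)
    tokens = torch.randn(2, 50, 96)
    assert pooler(tokens).shape == (2, 4, 64)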
class MLP(nn.Module):
def __init__(
self,
d_model: int,
mlp_width: int,
act_layer: Callable = nn.GELU,
layer_id: Optional[int] = None,
):
super().__init__()
self.c_fc = nn.Linear(d_model, mlp_width)
self.gelu = act_layer()
self.c_proj = nn.Linear(mlp_width, d_model)
self.layer_id = layer_id
def forward(self, x, neuron_dict=None, num_register_tokens=0):
x = self.c_fc(x)
# If we have a dictionary of modifications and this layer is in it
if neuron_dict is not None and self.layer_id in neuron_dict and num_register_tokens>0:
neurons = neuron_dict[self.layer_id]
# Apply GELU to all activations
x_after_gelu = self.gelu(x)
original_activations = x_after_gelu.clone()
# Create new activation map for specified neurons
new_activation_map = torch.zeros(
(x_after_gelu.shape[0], x_after_gelu.shape[1], len(neurons)),
device=x_after_gelu.device,
).to(x_after_gelu.dtype)
max_values = torch.max(original_activations[:, :, neurons], dim=1, keepdim=True).values
new_activation_map[:, -num_register_tokens:, :] = max_values
new_activation_map[:,0,:] = x_after_gelu[:,0,neurons]
x_after_gelu[:,:,neurons] = new_activation_map
x = x_after_gelu
else:
x = self.gelu(x)
x = self.c_proj(x)
return x
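def _mlp_register_token_sketch():
    # Minimal usage sketch (illustration only, hypothetical neuron indices): when
    # neuron_dict maps this layer id to hidden-neuron indices and register tokens
    # are present, those neurons keep their CLS activation, take the per-neuron
    # max over the sequence on the register positions, and are zeroed elsewhere.
    mlp = MLP(d_model=16, mlp_width=64, layer_id=3)
    x = torch.randn(2, 10, 16)  # [B, CLS + patches + registers, C]
    out = mlp(x, neuron_dict={3: [0, 5, 7]}, num_register_tokens=2)
    assert out.shape == (2, 10, 16)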
# TODO: the problem seems to have come from the fact that this was not a custom attention implementation.
class MultiheadAttention(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
device=None,
dtype=None,
):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.q_out = nn.Identity()
self.k_out = nn.Identity()
self.v_out = nn.Identity()
self.qkv_out = nn.Identity()
self.attn_map = nn.Identity()
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim)))
if bias:
self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
if add_bias_kv:
self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim)))
self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim)))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
def forward_direct(self, x, attn_mask=None):
B, N, C = x.shape
qkv = x @ self.in_proj_weight.T + self.in_proj_bias
qkv = self.qkv_out(qkv)
qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        # B, S, 3, H, d -> 3, B, H, S, d (batch-first computation)
        # NOTE: this reordering appears to be where the difference in the computed results comes from.
q, k, v = qkv.unbind(0)
q = self.q_out(q)
k = self.k_out(k)
v = self.v_out(v)
dk = q.size()[-1]
q = q / math.sqrt(dk)
attn = q @ k.transpose(-2, -1)
if attn_mask is not None:
attn += attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_map(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = x @ self.out_proj.weight.T + self.out_proj.bias
return x
def _split_qkv_weight(self):
q_weight, k_weight, v_weight = (
self.in_proj_weight[: self.embed_dim].reshape(
self.num_heads, self.head_dim, -1
),
self.in_proj_weight[self.embed_dim : self.embed_dim * 2].reshape(
self.num_heads, self.head_dim, -1
),
self.in_proj_weight[self.embed_dim * 2 :].reshape(
self.num_heads, self.head_dim, -1
),
)
return q_weight, k_weight, v_weight
def _split_qkv_bias(self):
q_bias, k_bias, v_bias = (
self.in_proj_bias[: self.embed_dim].reshape(
1, self.num_heads, 1, self.head_dim
),
self.in_proj_bias[self.embed_dim : self.embed_dim * 2].reshape(
1, self.num_heads, 1, self.head_dim
),
self.in_proj_bias[self.embed_dim * 2 :].reshape(
1, self.num_heads, 1, self.head_dim
),
)
return q_bias, k_bias, v_bias
def forward_qkv(self, x, attn_mask=None):
B, N, C = x.shape
q_weight, k_weight, v_weight = (
self.in_proj_weight[: self.embed_dim],
self.in_proj_weight[self.embed_dim : self.embed_dim * 2],
self.in_proj_weight[self.embed_dim * 2 :],
)
q_bias, k_bias, v_bias = (
self.in_proj_bias[: self.embed_dim],
self.in_proj_bias[self.embed_dim : self.embed_dim * 2],
self.in_proj_bias[self.embed_dim * 2 :],
)
q = (x @ q_weight.T + q_bias).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
k = (x @ k_weight.T + k_bias).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
v = (x @ v_weight.T + v_bias).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
dk = q.size()[-1]
q = q / math.sqrt(dk)
attn = q @ k.transpose(-2, -1)
if attn_mask is not None:
attn += attn_mask
attn = attn.softmax(dim=-1)
x = torch.einsum("bhnm,bhmc->bhnmc", attn, v)
x = x.sum(axis=3).transpose(1, 2).reshape(B, N, C)
x = x @ self.out_proj.weight.T + self.out_proj.bias
return x
def forward_per_head(self, x, attn_mask=None):
B, N, C = x.shape
q_weight, k_weight, v_weight = self._split_qkv_weight()
q_bias, k_bias, v_bias = self._split_qkv_bias()
q = torch.einsum("bnc,hdc->bhnd", x, q_weight) + q_bias
k = torch.einsum("bnc,hdc->bhnd", x, k_weight) + k_bias
v = torch.einsum("bnc,hdc->bhnd", x, v_weight) + v_bias
dk = q.size()[-1]
q = q / math.sqrt(dk)
attn = q @ k.transpose(-2, -1)
if attn_mask is not None:
attn += attn_mask
attn = attn.softmax(dim=-1)
x = torch.einsum("bhnm,bhmc->bnmhc", attn, v)
x = torch.einsum(
"bnmhc,dhc->bnmhd",
x,
self.out_proj.weight.reshape(self.embed_dim, self.num_heads, self.head_dim),
)
x = x.sum(axis=[2, 3]) + self.out_proj.bias
return x
def _get_ov_circuit(self):
reshaped_o = self.out_proj.weight.reshape(
self.embed_dim, self.num_heads, self.head_dim
)
_, _, v_weight = self._split_qkv_weight()
_, _, v_bias = self._split_qkv_bias()
ov_circuit = torch.einsum("onh,nhi->oni", reshaped_o, v_weight)
ov_bias_circuit = torch.einsum("onh,bnxh->bnxo", reshaped_o, v_bias)
return ov_circuit, ov_bias_circuit
def forward_ov_circuit(self, x, attn_mask=None):
B, N, C = x.shape
q_weight, k_weight, _ = self._split_qkv_weight()
q_bias, k_bias, _ = self._split_qkv_bias()
q = torch.einsum("bnc,hdc->bhnd", x, q_weight) + q_bias
k = torch.einsum("bnc,hdc->bhnd", x, k_weight) + k_bias
ov, ov_bias = self._get_ov_circuit()
v = torch.einsum("bnc,dhc->bhnd", x, ov) + ov_bias
dk = q.size()[-1]
q = q / math.sqrt(dk)
attn = q @ k.transpose(-2, -1)
if attn_mask is not None:
attn += attn_mask
attn = attn.softmax(dim=-1)
x = torch.einsum("bhnm,bhmc->bnmhc", attn, v)
x = x.sum(axis=[2, 3]) + self.out_proj.bias
return x
def forward(self, x, attn_mask=None, method: Text = "ov_circuit"):
if method == "direct":
return self.forward_direct(x, attn_mask=attn_mask)
elif method == "qkv":
return self.forward_qkv(x, attn_mask=attn_mask)
elif method == "head":
return self.forward_per_head(x, attn_mask=attn_mask)
elif method == "ov_circuit":
return self.forward_ov_circuit(x, attn_mask=attn_mask)
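def _multihead_attention_equivalence_sketch():
    # Minimal consistency sketch (illustration only, hypothetical sizes): the four
    # forward paths ("direct", "qkv", "head", "ov_circuit") are algebraically
    # equivalent and should agree up to floating-point error. Weights are
    # initialized here only so the sketch runs on sensible values; in the model
    # they come from a pretrained checkpoint.
    torch.manual_seed(0)
    attn = MultiheadAttention(embed_dim=32, num_heads=4)
    nn.init.normal_(attn.in_proj_weight, std=0.02)
    nn.init.zeros_(attn.in_proj_bias)
    x = torch.randn(2, 7, 32)  # [B, N, C]
    ref = attn(x, method="direct")
    for method in ("qkv", "head", "ov_circuit"):
        assert torch.allclose(attn(x, method=method), ref, atol=1e-4)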
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
layer_id: Optional[int] = None,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.attn = MultiheadAttention(d_model, n_head)
self.layer_id = layer_id
self.ls_1 = (
LayerScale(d_model, ls_init_value)
if ls_init_value is not None
else nn.Identity()
)
self.ln_2 = norm_layer(d_model)
self.mlp_width = int(d_model * mlp_ratio)
self.mlp = MLP(
d_model,
self.mlp_width,
act_layer=act_layer,
layer_id=layer_id,
)
self.ls_2 = (
LayerScale(d_model, ls_init_value)
if ls_init_value is not None
else nn.Identity()
)
def attention(
self,
q_x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
method: Text = "direct",
):
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
return self.attn(q_x, attn_mask=attn_mask, method=method)
def forward(
self,
q_x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
attn_method: Text = "direct",
neuron_dict=None,
num_register_tokens=0
):
after_ln1 = self.ln_1(q_x)
after_attn = self.attention(
q_x=after_ln1, attn_mask=attn_mask, method=attn_method
)
x = q_x + self.ls_1(after_attn)
after_ln2 = self.ln_2(x)
after_mlp = self.mlp(after_ln2, neuron_dict=neuron_dict, num_register_tokens=num_register_tokens)
x = x + self.ls_2(after_mlp)
return x
class Transformer(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
):
super().__init__()
self.width = width
self.layers = layers
self.grad_checkpointing = False
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(
width,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
layer_id=i,
)
for i in range(layers)
]
)
def get_cast_dtype(self) -> torch.dtype:
if hasattr(self.resblocks[0].mlp.c_fc, "int8_original_dtype"):
return self.resblocks[0].mlp.c_fc.int8_original_dtype
return self.resblocks[0].mlp.c_fc.weight.dtype
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
attn_method: Text = "direct",
neuron_dict=None,
num_register_tokens=0,
get_hidden_states: bool = False,
):
if "hidden_states" not in locals():
hidden_states = []
hidden_states.append(x) # input embedding 저장
for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
raise ValueError("grad_checkpointing not implemented")
else:
x = r(
x,
attn_mask=attn_mask,
attn_method=attn_method,
neuron_dict=neuron_dict,
num_register_tokens=num_register_tokens
)
if get_hidden_states:
hidden_states.append(x)
if get_hidden_states:
return {
"hidden_states": hidden_states,
"last_hidden_state": x,
}
else:
return x
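def _transformer_hidden_states_sketch():
    # Minimal usage sketch (illustration only, tiny hypothetical config): with
    # get_hidden_states=True the Transformer returns the input embedding plus one
    # hidden state per block, alongside the final output. Attention weights are
    # initialized here only so the sketch runs on sensible values.
    model = Transformer(width=32, layers=2, heads=4)
    for block in model.resblocks:
        nn.init.normal_(block.attn.in_proj_weight, std=0.02)
        nn.init.zeros_(block.attn.in_proj_bias)
    x = torch.randn(2, 7, 32)  # [B, N, C]
    ret = model(x, attn_method="direct", get_hidden_states=True)
    assert len(ret["hidden_states"]) == 1 + 2
    assert ret["last_hidden_state"].shape == (2, 7, 32)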
class VisionTransformer(nn.Module):
output_tokens: torch.jit.Final[bool]
def __init__(
self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
ls_init_value: float = None,
global_average_pool: bool = False,
attentional_pool: bool = False,
n_queries: int = 256,
attn_pooler_heads: int = 8,
output_dim: int = 512,
patch_dropout: float = 0.0,
input_patchnorm: bool = False,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_tokens: bool = False,
):
super().__init__()
self.output_tokens = output_tokens
image_height, image_width = self.image_size = to_2tuple(image_size)
patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
self.output_dim = output_dim
self.num_register_tokens = 0
self.neuron_dict = None
self.input_patchnorm = input_patchnorm
if input_patchnorm:
patch_input_dim = patch_height * patch_width * 3
self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
self.conv1 = nn.Linear(patch_input_dim, width)
else:
self.patchnorm_pre_ln = nn.Identity()
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False,
)
scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(
scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)
)
self.width = width
self.scale = scale
self.extra_token = self.scale * torch.randn(width)
self.patch_dropout = (
PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
)
self.ln_pre = norm_layer(width)
self.transformer = Transformer(
width,
layers,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
self.global_average_pool = global_average_pool
if attentional_pool:
self.attn_pool = AttentionalPooler(
output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries
)
self.ln_post = norm_layer(output_dim)
self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
else:
self.attn_pool = None
self.ln_post = norm_layer(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.global_average_pool:
return x.mean(dim=1), x
else:
return x[:, 0], x[:, 1:]
    def forward(
        self,
        x: torch.Tensor,
        attn_method: Text = "direct",
        num_register_tokens: Optional[int] = None,
        neuron_dict=None,
        get_hidden_states: bool = False,
    ):
        # to patches
        # Fall back to the module-level defaults when not overridden per call.
        if num_register_tokens is None:
            num_register_tokens = self.num_register_tokens
        if neuron_dict is None:
            neuron_dict = self.neuron_dict
if self.input_patchnorm:
x = x.reshape(
x.shape[0],
x.shape[1],
self.grid_size[0],
self.patch_size[0],
self.grid_size[1],
self.patch_size[1],
)
x = x.permute(0, 2, 4, 1, 3, 5)
x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
x = self.patchnorm_pre_ln(x)
x = self.conv1(x)
else:
x = self.conv1(x)
x = x.reshape(x.shape[0], x.shape[1], -1)
x = x.permute(0, 2, 1)
# class embeddings and positional embeddings
x = torch.cat([
self.class_embedding.to(x.dtype)
+ torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
),
x,
],
dim=1,
)
x = x + self.positional_embedding.to(x.dtype)
        # Append zero-initialized register tokens after the CLS and patch tokens.
        if num_register_tokens > 0:
            register_tokens = torch.zeros(
                x.shape[0], num_register_tokens, x.shape[-1],
                dtype=x.dtype, device=x.device,
            )
            x = torch.cat([x, register_tokens], dim=1)
x = self.patch_dropout(x)
x = self.ln_pre(x)
        # The transformer input here has shape [B, 1 + 576 + num_register_tokens, C].
        if get_hidden_states:
            ret = self.transformer(
                x,
                attn_mask=None,
                attn_method=attn_method,
                neuron_dict=neuron_dict,
                num_register_tokens=num_register_tokens,
                get_hidden_states=True,
            )
            hidden_states = ret["hidden_states"]
            x = ret["last_hidden_state"]
        else:
            x = self.transformer(
                x,
                attn_mask=None,
                attn_method=attn_method,
                neuron_dict=neuron_dict,
                num_register_tokens=num_register_tokens,
            )
if self.attn_pool is not None:
x = self.attn_pool(x)
x = self.ln_post(x)
pooled, tokens = self._global_pool(x)
else:
pooled, tokens = self._global_pool(x)
pooled = self.ln_post(pooled)
if self.proj is not None:
pooled = pooled @ self.proj
if self.output_tokens:
return pooled, tokens
if get_hidden_states:
return {
"pooled": pooled, # GLOBAL CLS
"tokens": tokens, # ALL TOKENS
"hidden_states": hidden_states # layer-wise hidden states
}
return pooled
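def _vision_transformer_sketch():
    # Minimal usage sketch (illustration only): a tiny hypothetical configuration,
    # not the pretrained CLIP ViT this file ships with. Register tokens are
    # appended after the positional embedding, and neuron_dict routes the MLP
    # modification to specific layers.
    vit = VisionTransformer(
        image_size=32, patch_size=8, width=64, layers=2, heads=4,
        mlp_ratio=4.0, output_dim=32,
    )
    for block in vit.transformer.resblocks:
        # Initialized here only so the sketch runs on sensible values; the real
        # weights come from the pretrained checkpoint.
        nn.init.normal_(block.attn.in_proj_weight, std=0.02)
        nn.init.zeros_(block.attn.in_proj_bias)
    images = torch.randn(2, 3, 32, 32)
    pooled = vit(
        images, attn_method="direct", num_register_tokens=2, neuron_dict={0: [1, 2]}
    )
    assert pooled.shape == (2, 32)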
class TextTransformer(nn.Module):
output_tokens: torch.jit.Final[bool]
def __init__(
self,
context_length: int = 77,
vocab_size: int = 49408,
width: int = 512,
heads: int = 8,
layers: int = 12,
ls_init_value: float = None,
output_dim: int = 512,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
embed_cls: bool = False,
pad_id: int = 0,
output_tokens: bool = False,
):
super().__init__()
self.output_tokens = output_tokens
self.num_pos = self.context_length = context_length
self.vocab_size = vocab_size
self.width = width
self.output_dim = output_dim
self.heads = heads
self.pad_id = pad_id
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
if embed_cls:
self.cls_emb = nn.Parameter(torch.empty(width))
self.num_pos += 1
else:
self.cls_emb = None
self.token_embedding = nn.Embedding(vocab_size, width)
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
self.transformer = Transformer(
width=width,
layers=layers,
heads=heads,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
self.ln_final = norm_layer(width)
self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
self.init_parameters()
def init_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if self.cls_emb is not None:
nn.init.normal_(self.cls_emb, std=0.01)
proj_std = (self.transformer.width**-0.5) * (
(2 * self.transformer.layers) ** -0.5
)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
def build_attention_mask(self):
mask = torch.empty(self.num_pos, self.num_pos)
mask.fill_(float("-inf"))
mask.triu_(1)
return mask
def build_cls_mask(self, text, cast_dtype: torch.dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
additive_mask = torch.empty(
cls_mask.shape, dtype=cast_dtype, device=cls_mask.device
)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
return additive_mask
def _repeat(self, t, N: int):
return t.reshape(1, 1, -1).repeat(N, 1, 1)
def forward(self, text, attn_method: Text = "direct"):
cast_dtype = self.transformer.get_cast_dtype()
seq_len = text.shape[1]
x = self.token_embedding(text).to(cast_dtype)
attn_mask = self.attn_mask
if self.cls_emb is not None:
seq_len += 1
x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
cls_mask = self.build_cls_mask(text, cast_dtype)
attn_mask = (
attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
)
x = x + self.positional_embedding[:seq_len].to(cast_dtype)
x = self.transformer(x, attn_mask=attn_mask, attn_method=attn_method)
if self.cls_emb is not None:
pooled, tokens = x[:, -1], x[:, :-1]
pooled = self.ln_final(pooled)
else:
x = self.ln_final(x)
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
if self.text_projection is not None:
pooled = pooled @ self.text_projection
if self.output_tokens:
return pooled, tokens
return pooled
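def _text_transformer_sketch():
    # Minimal usage sketch (illustration only, tiny hypothetical config): the text
    # tower pools at the position of the highest token id (the EOT token in CLIP's
    # vocabulary) when embed_cls is False, then projects to output_dim.
    torch.manual_seed(0)
    text_model = TextTransformer(
        context_length=16, vocab_size=100, width=32, heads=4, layers=2, output_dim=24
    )
    for block in text_model.transformer.resblocks:
        nn.init.zeros_(block.attn.in_proj_bias)  # not covered by init_parameters
    tokens = torch.randint(1, 100, (2, 16))  # hypothetical token ids
    assert text_model(tokens, attn_method="direct").shape == (2, 24)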