import torch
from torch import nn


def get_last_assistant_masks(input_ids: list) -> list:
    '''
    Build a 0/1 mask over input_ids selecting the tokens of the last assistant
    turn. The header pattern [128006, 78191, 128007, 271] corresponds to the
    Llama-3 chat-template tokens <|start_header_id|>, 'assistant',
    <|end_header_id|>, and the following double newline.
    '''
    # Search backwards for the start of the last assistant header.
    pos = None
    i = len(input_ids) - 4
    while i >= 0:
        if input_ids[i:i + 4] == [128006, 78191, 128007, 271]:
            pos = i + 4
            break
        i -= 1
    assert pos is not None, 'assistant header pattern not found in input_ids'
    # 0 for prompt/header tokens, 1 for the last assistant response tokens.
    assistant_masks = [0 if i < pos else 1 for i in range(len(input_ids))]
    assert input_ids[-1] == 128009  # sequence must end with <|eot_id|>
    return assistant_masks

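# Illustrative usage (the token ids below are hypothetical; the exact sequence
# depends on the tokenizer and chat template used upstream):
#   ids  = [128000, 882, 128006, 78191, 128007, 271, 9906, 0, 128009]
#   mask = get_last_assistant_masks(ids)
#   # -> [0, 0, 0, 0, 0, 0, 1, 1, 1]
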
def Normalized_MSE_loss(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
    '''MSE between x_hat and x, normalized by the mean squared magnitude of x.'''
    return (((x_hat - x) ** 2).mean(dim=-1) / (x ** 2).mean(dim=-1)).mean()

def Masked_Normalized_MSE_loss(x: torch.Tensor, x_hat: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    '''Normalized MSE averaged only over positions where mask == 1.'''
    mask = mask.to(torch.bfloat16)
    # Per-token normalized MSE, shape [batch_size, max_length].
    loss = ((x_hat - x) ** 2).mean(dim=-1) / (x ** 2).mean(dim=-1)
    assert loss.shape == mask.shape
    # Average over the masked positions of each sequence, then over the batch.
    seq_loss = (mask * loss).sum(-1) / mask.sum(-1)
    return seq_loss.mean()

def pre_process(hidden_states: torch.Tensor, eps: float = 1e-6) -> tuple:
    '''
    Normalize each activation vector to zero mean and unit standard deviation.
    :param hidden_states: Hidden states (shape: [batch_size, max_length, hidden_size]).
    :param eps: Epsilon value for numerical stability.
    :return: Normalized activations, per-token mean, per-token std.
    '''
    mean = hidden_states.mean(dim=-1, keepdim=True)
    std = hidden_states.std(dim=-1, keepdim=True)
    x = (hidden_states - mean) / (std + eps)
    return x, mean, std

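# Note (assumed usage, not stated in this file): because `mean` and `std` are
# returned alongside the normalized activations, a reconstruction can be mapped
# back to the original activation scale, e.g.
#   x, mean, std = pre_process(hidden_states)
#   recovered = x * (std + eps) + mean   # recovers hidden_states up to float error
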
class TopkSAE(nn.Module):
    '''
    TopK Sparse Autoencoder.
    Implements:
        z = TopK(encoder(x - pre_bias) + latent_bias)
        x_hat = decoder(z) + pre_bias
    '''
    def __init__(
        self, hidden_size: int, latent_size: int, k: int
    ) -> None:
        '''
        :param hidden_size: Dimensionality of the input residual stream activation.
        :param latent_size: Number of latent units.
        :param k: Number of activated latents.
        '''
        # 'sae_pre_bias', 'sae_latent_bias', 'sae_encoder.weight', 'sae_decoder.weight'
        assert k <= latent_size, f'k should be less than or equal to {latent_size}'
        super(TopkSAE, self).__init__()
        self.pre_bias = nn.Parameter(torch.zeros(hidden_size))
        self.latent_bias = nn.Parameter(torch.zeros(latent_size))
        self.encoder = nn.Linear(hidden_size, latent_size, bias=False)
        self.decoder = nn.Linear(latent_size, hidden_size, bias=False)
        self.k = k
        self.latent_size = latent_size
        self.hidden_size = hidden_size
        # "tied" init
        # self.decoder.weight.data = self.encoder.weight.data.T.clone()

    def pre_acts(self, x: torch.Tensor) -> torch.Tensor:
        x = x - self.pre_bias
        return self.encoder(x) + self.latent_bias

    def get_latents(self, pre_acts: torch.Tensor) -> torch.Tensor:
        # Keep only the k largest pre-activations per position; zero out the rest.
        topk = torch.topk(pre_acts, self.k, dim=-1)
        latents = torch.zeros_like(pre_acts)
        latents.scatter_(-1, topk.indices, topk.values)
        return latents

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        pre_acts = self.pre_acts(x)
        latents = self.get_latents(pre_acts)
        return latents

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        return self.decoder(latents) + self.pre_bias

    def forward(self, x: torch.Tensor) -> tuple:
        '''
        :param x: Input residual stream activation (shape: [batch_size, max_length, hidden_size]).
        :return: latents (shape: [batch_size, max_length, latent_size]),
                 x_hat (shape: [batch_size, max_length, hidden_size]).
        '''
        latents = self.encode(x)
        x_hat = self.decode(latents)
        return latents, x_hat
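

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (illustrative only; the dimensions and the mask
# below are assumptions, not the original training configuration).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    batch_size, max_length, hidden_size = 2, 8, 1024
    sae = TopkSAE(hidden_size=hidden_size, latent_size=4096, k=32)

    # Stand-in for residual stream activations taken from a language model.
    hidden_states = torch.randn(batch_size, max_length, hidden_size)
    x, mean, std = pre_process(hidden_states)

    latents, x_hat = sae(x)

    # Example mask that keeps only the second half of each sequence
    # (in a real run this mask could come from get_last_assistant_masks).
    mask = torch.zeros(batch_size, max_length)
    mask[:, max_length // 2:] = 1.0

    loss = Masked_Normalized_MSE_loss(x, x_hat, mask)
    print(latents.shape, x_hat.shape, loss.item())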