# moondream2 / layers.py
# vikhyatk: "Upload HfMoondream" (commit 05d640e, verified)
from dataclasses import dataclass
from typing import Literal
import torch
from torch.nn import functional as F
def gelu_approx(x: torch.Tensor) -> torch.Tensor:
    """Apply GELU using the tanh approximation.

    Args:
        x: Input tensor; the activation is applied elementwise.

    Returns:
        The result of ``F.gelu(x, approximate="tanh")``, same shape as ``x``.
    """
    return F.gelu(x, approximate="tanh")
@dataclass
class LinearWeights:
    """Parameters of a linear layer, consumed by ``linear``.

    Attributes:
        weight: Weight tensor handed to ``F.linear``, laid out the way that
            function expects, i.e. ``(out_features, in_features)``.
        bias: Bias tensor handed to ``F.linear``.
    """

    weight: torch.Tensor
    bias: torch.Tensor
def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
    """Apply the affine map described by ``w`` to ``x``.

    Args:
        x: Input tensor; the map is applied over its last dimension.
        w: Weight and bias of the layer.

    Returns:
        ``F.linear(x, w.weight, w.bias)``.
    """
    return F.linear(x, w.weight, w.bias)
@dataclass
class LayerNormWeights:
    """Parameters of a layer-normalization step, consumed by ``layer_norm``.

    Attributes:
        weight: Scale tensor handed to ``F.layer_norm``.
        bias: Shift tensor handed to ``F.layer_norm``; its shape is also used
            by ``layer_norm`` as the normalized shape.
    """

    weight: torch.Tensor
    bias: torch.Tensor
def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
    """Layer-normalize ``x`` and apply the scale and shift held in ``w``.

    Args:
        x: Input tensor; its trailing dimensions must match ``w.bias.shape``.
        w: Scale (``weight``) and shift (``bias``) parameters.

    Returns:
        The normalized tensor, same shape as ``x``.
    """
    # The normalized shape is taken from the bias, so no separate size
    # argument is needed; epsilon is left at F.layer_norm's default.
    return F.layer_norm(x, w.bias.shape, w.weight, w.bias)
@dataclass
class MLPWeights:
    """Parameters of the two-layer feed-forward block evaluated by ``mlp``.

    Attributes:
        fc1: First linear layer, applied before the activation.
        fc2: Second linear layer, applied after the activation.
        act: Name of the activation. ``"gelu_approx"`` is the only permitted
            value; ``mlp`` does not read this field and always applies
            ``gelu_approx``.
    """

    fc1: LinearWeights
    fc2: LinearWeights
    act: Literal["gelu_approx"] = "gelu_approx"
def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
    """Evaluate the feed-forward block: linear, tanh-GELU, linear.

    Args:
        x: Input tensor.
        w: The two linear layers of the block.

    Returns:
        The output of ``w.fc2`` applied to the activated output of ``w.fc1``.
    """
    x = linear(x, w.fc1)
    # The activation is fixed here; w.act is not consulted.
    x = gelu_approx(x)
    x = linear(x, w.fc2)
    return x
@dataclass
class AttentionWeights:
    """Parameters of the self-attention layer evaluated by ``attn``.

    Attributes:
        qkv: Fused projection whose output ``attn`` splits into three equal
            chunks (query, key, value) along the last dimension.
        proj: Output projection applied after the heads are merged.
    """

    qkv: LinearWeights
    proj: LinearWeights
def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
    """Run multi-head self-attention over ``x``.

    Args:
        x: 3-D input, unpacked as ``(bsz, q_len, d_model)``.
        w: Fused query/key/value projection (``qkv``) and output projection
            (``proj``).
        n_heads: Number of attention heads the feature dimension is split
            into.

    Returns:
        ``w.proj`` applied to the merged heads of shape
        ``(bsz, q_len, d_model)``.

    Note:
        ``F.scaled_dot_product_attention`` is called with only ``q``, ``k``
        and ``v``, so no attention mask or causal masking is requested.
    """
    bsz, q_len, d_model = x.shape
    # NOTE(review): assumes d_model is divisible by n_heads and that w.qkv
    # emits 3 * d_model features; otherwise the view below fails. Confirm
    # against callers.
    head_dim = d_model // n_heads

    # One fused projection, split into three equal chunks on the last dim;
    # each chunk is reshaped to (bsz, n_heads, q_len, head_dim).
    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]

    out = F.scaled_dot_product_attention(q, k, v)
    # Move heads back next to head_dim and merge them. reshape (not view) is
    # used because the transposed tensor is not contiguous.
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out