# moondream2 / vision.py
# Uploaded by vikhyatk ("Upload HfMoondream"), commit 05d640e (verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Union, Tuple
from einops import rearrange
from PIL import Image
from .layers import attn, layer_norm, linear, mlp
from .image_crops import overlap_crop_image
from .config import VisionConfig
def adaptive_avg_pool2d(input, output_size):
    """Adaptive 2-D average pooling that also works for tensors on Apple MPS.

    Non-divisible input sizes are not implemented on the MPS device yet
    (https://github.com/pytorch/pytorch/issues/96056). MPS tensors are therefore
    pooled on the CPU, and the result is moved back to the tensor's own device.
    Tensors on any other device go straight to ``F.adaptive_avg_pool2d``, so a
    CPU tensor on a machine that merely *has* MPS is never moved to MPS.

    Args:
        input: tensor to pool; same layouts as ``F.adaptive_avg_pool2d``.
        output_size: target spatial size, forwarded unchanged.

    Returns:
        The pooled tensor, on the same device as ``input``.
    """
    if input.device.type == "mps":
        return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to(input.device)
    return F.adaptive_avg_pool2d(input, output_size)


# Anything this module passes as the ``device`` argument of ``Tensor.to``.
DeviceLike = Union[str, torch.device, int]
def prepare_crops(
    image: Image.Image,
    config: VisionConfig,
    device: DeviceLike,
    dtype: torch.dtype = torch.float16,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Split an image into overlapping crops and normalize them for the encoder.

    Args:
        image: input image; converted to RGB before cropping.
        config: vision config; ``max_crops`` and ``overlap_margin`` are forwarded
            to ``overlap_crop_image``.
        device: device the crop tensor is moved to.
        dtype: floating dtype of the returned crops (defaults to ``float16``).
            Pass the dtype the vision model was built with
            (``build_vision_model(config, dtype)``) so the crops match its weights.

    Returns:
        ``(crops, tiling)``:
        - ``crops`` is the channels-first crop tensor, with pixel values
          rescaled as ``(x / 255 - 0.5) / 0.5``. That maps 8-bit values in
          [0, 255] to [-1, 1].
        - ``tiling`` is the ``"tiling"`` entry returned by
          ``overlap_crop_image``, passed through unchanged.
    """
    np_image = np.array(image.convert("RGB"))
    overlap_crops = overlap_crop_image(
        np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin
    )
    # Assumes crops come back as (N, H, W, C); reorder to (N, C, H, W), the
    # layout vision_encoder's patch split expects. TODO confirm against image_crops.
    all_crops = np.transpose(overlap_crops["crops"], (0, 3, 1, 2))
    all_crops = (
        torch.from_numpy(all_crops)
        # Cast before normalizing so the in-place ops below run in floating point.
        .to(device=device, dtype=dtype)
        .div_(255.0)
        .sub_(0.5)
        .div_(0.5)
    )
    return all_crops, overlap_crops["tiling"]
def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
    """Run a batch of image crops through the patch embedding and encoder stack.

    Args:
        input_BCHW: crops in (B, C, H, W) layout. H and W must be multiples of
            ``config.enc_patch_size`` for the patch split to succeed.
        w: encoder weights with ``patch_emb``, ``pos_emb``, ``blocks`` and
            ``post_ln`` entries, as laid out by ``build_vision_model``.
        config: vision config; ``enc_patch_size`` and ``enc_n_heads`` are read.

    Returns:
        Token features after the final ``post_ln`` layer norm, one token per
        patch (a (B, T, C) tensor).
    """
    p = config.enc_patch_size
    # Cut every crop into non-overlapping p x p patches and flatten each patch into
    # a single token: (B, C, H, W) -> (B, T, C*p*p) with T = (H/p) * (W/p).
    patches = rearrange(
        input_BCHW, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=p, p2=p
    )
    hidden = linear(patches, w.patch_emb) + w.pos_emb

    for layer in w.blocks:
        # Pre-norm residual sub-layers: self-attention first, then the MLP.
        normed = layer_norm(hidden, layer.ln1)
        hidden = hidden + attn(normed, layer.attn, n_heads=config.enc_n_heads)
        normed = layer_norm(hidden, layer.ln2)
        hidden = hidden + mlp(normed, layer.mlp)

    return layer_norm(hidden, w.post_ln)
def vision_projection(
    global_features: torch.Tensor,
    reconstructed: torch.Tensor,
    w: nn.Module,
    config: VisionConfig,
):
    """Fuse global-crop features with pooled local features and project them.

    Args:
        global_features: features for the global crop. Assumed to be
            (grid * grid, enc_dim), one row per encoder patch token. TODO confirm
            against the caller.
        reconstructed: local-crop features stitched into a spatial map. The
            ``permute(2, 0, 1)`` below implies a (H, W, enc_dim) layout.
        w: module holding the ``proj_mlp`` weights.
        config: vision config; ``crop_size``, ``enc_patch_size`` and ``enc_dim``
            are read.

    Returns:
        ``mlp`` applied to the channel-wise concatenation of ``global_features``
        and the pooled local features.
    """
    # Side length of the encoder's patch grid for one crop. This is the same grid
    # that sizes ``pos_emb`` in build_vision_model. It is deliberately NOT
    # ``enc_n_layers`` (a layer count) nor a literal 729; those only agree with
    # the grid for one specific configuration (27 layers, 378 // 14 == 27).
    grid = config.crop_size // config.enc_patch_size

    # (H, W, C) -> (C, H, W) so pooling runs over the spatial axes.
    pooled = adaptive_avg_pool2d(
        reconstructed.permute(2, 0, 1), output_size=(grid, grid)
    )
    # (C, grid, grid) -> (grid * grid, C): one row per patch position, matching
    # the row layout of global_features.
    pooled = pooled.permute(1, 2, 0).reshape(grid * grid, config.enc_dim)

    final_features = torch.cat([global_features, pooled], dim=-1)
    return mlp(final_features, w.proj_mlp)
def build_vision_model(config: VisionConfig, dtype: torch.dtype):
    """Instantiate the vision encoder and projection modules.

    Weights use PyTorch's default initialisation; ``pos_emb`` starts at zero.

    Args:
        config: vision config supplying the sizes. It reads ``enc_patch_size``,
            ``in_channels``, ``crop_size``, ``enc_dim``, ``enc_ff_dim``,
            ``enc_n_layers``, ``proj_inner_dim`` and ``proj_out_dim``.
        dtype: dtype for every parameter.

    Returns:
        An ``nn.ModuleDict`` with:
        - entries ``patch_emb``, ``blocks`` (``enc_n_layers`` encoder blocks),
          ``post_ln`` and ``proj_mlp``;
        - a ``pos_emb`` parameter of shape (1, num_patches, enc_dim), where
          ``num_patches = (crop_size // enc_patch_size) ** 2``.
    """
    patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels
    side = config.crop_size // config.enc_patch_size
    dim = config.enc_dim

    def dense(n_in, n_out):
        return nn.Linear(n_in, n_out, dtype=dtype)

    def norm():
        return nn.LayerNorm(dim, dtype=dtype)

    def encoder_block():
        # Submodules are created in the order ln1, qkv, proj, ln2, fc1, fc2, so
        # random weight-init draws are consumed in a fixed sequence per block.
        ln1 = norm()
        attention = nn.ModuleDict({"qkv": dense(dim, 3 * dim), "proj": dense(dim, dim)})
        ln2 = norm()
        feed_forward = nn.ModuleDict(
            {
                "fc1": dense(dim, config.enc_ff_dim),
                "fc2": dense(config.enc_ff_dim, dim),
            }
        )
        return nn.ModuleDict(
            {"ln1": ln1, "attn": attention, "ln2": ln2, "mlp": feed_forward}
        )

    patch_emb = dense(patch_dim, dim)
    blocks = nn.ModuleList(encoder_block() for _ in range(config.enc_n_layers))
    post_ln = norm()
    # The projection consumes global and pooled local features concatenated
    # along the channel axis, hence the doubled input width.
    proj_mlp = nn.ModuleDict(
        {
            "fc1": dense(2 * dim, config.proj_inner_dim),
            "fc2": dense(config.proj_inner_dim, config.proj_out_dim),
        }
    )

    vision = nn.ModuleDict(
        {
            "patch_emb": patch_emb,
            "blocks": blocks,
            "post_ln": post_ln,
            "proj_mlp": proj_mlp,
        }
    )
    # Assigning an nn.Parameter attribute registers it on the ModuleDict.
    vision.pos_emb = nn.Parameter(torch.zeros(1, side * side, dim, dtype=dtype))
    return vision