Upload HfMoondream
- config.json +1 -1
- config.py +2 -2
- layers.py +52 -51
- model.safetensors +2 -2
- moondream.py +29 -24
- packing.py +52 -0
- region.py +3 -3
- text.py +87 -36
- vision.py +3 -3
config.json
CHANGED
@@ -8,6 +8,6 @@
   },
   "config": {},
   "model_type": "moondream1",
-  "torch_dtype": "
+  "torch_dtype": "bfloat16",
   "transformers_version": "4.44.0"
 }
config.py
CHANGED
@@ -12,7 +12,7 @@ class TextConfig:
     n_heads: int = 32
     n_kv_heads: int = 32
     prefix_attn: int = 730
-    group_size: int = 128
+    group_size: Optional[int] = 128
 
 
 @dataclass(frozen=True)
@@ -38,7 +38,7 @@ class RegionConfig:
     size_feat_dim: int = 512
     size_out_dim: int = 2048
     inner_dim: int = 8192
-
+    group_size: Optional[int] = 128
 
 @dataclass(frozen=True)
 class TokenizerConfig:
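Note on the config change: group_size is now Optional, and the downstream builders key on it to choose between the packed-int4 QuantizedLinear path and plain nn.Linear. A minimal sketch of the intent (the import path is an assumption, and it assumes the remaining dataclass fields keep their defaults as in the snippet above):

    from config import TextConfig, RegionConfig  # assumed import path

    quantized = TextConfig()                      # group_size defaults to 128 -> QuantizedLinear layers
    full_precision = TextConfig(group_size=None)  # None -> plain nn.Linear layers
    print(quantized.group_size, full_precision.group_size)  # 128 None
    print(RegionConfig().group_size)              # 128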
layers.py
CHANGED
@@ -1,11 +1,13 @@
-import bitblas
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from dataclasses import dataclass
 from typing import Literal
-from
-from
+from torchao import quantize_
+from torchao.quantization import int4_weight_only
+
+from .packing import dequantize_tensor
 
 
 def gelu_approx(x):
@@ -18,65 +20,65 @@ class LinearWeights:
     bias: torch.Tensor
 
 
-
-
-
-    If dtype is torch.int8, it uses bitblas for quantization.
-    Otherwise, it uses a standard nn.Linear layer.
-    """
+def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
+    return F.linear(x, w.weight, w.bias)
+
 
+class QuantizedLinear(nn.Module):
     def __init__(
         self,
         in_features: int,
         out_features: int,
-
-        dtype: torch.dtype = None,
-        group_size: int = 128,
+        dtype: torch.dtype,
     ):
+        # TODO: Take group_size as an input instead of hardcoding it here.
         super().__init__()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = nn.ParameterDict(
+            {
+                "packed": nn.Parameter(
+                    torch.empty(
+                        out_features, in_features // 128, 64, dtype=torch.uint8
+                    ),
+                    requires_grad=False,
+                ),
+                "scales": nn.Parameter(
+                    torch.empty(out_features, in_features // 128), requires_grad=False
+                ),
+            }
         )
-
+        self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
+        self.unpacked = False
+
+    def unpack(self):
+        self.weight = nn.Parameter(
+            dequantize_tensor(
+                self.weight["packed"],
+                self.weight["scales"],
+                (self.weight["packed"].shape[0], self.weight["packed"].shape[1] * 128),
+                128,
+                torch.bfloat16,
+            )
+        )
+        with torch.device("meta"):
             self.linear = nn.Linear(
-                in_features=
-                out_features=out_features,
-                bias=bias,
-                dtype=torch.float16,
+                self.in_features, self.out_features, dtype=torch.bfloat16
             )
-
-
+        self.linear.weight = self.weight
+        self.linear.bias = nn.Parameter(
+            self.bias.to(torch.bfloat16), requires_grad=False
+        )
+        del self.weight, self.bias
+        quantize_(self, int4_weight_only(group_size=128))
+        torch.cuda.empty_cache()
+        self.unpacked = True
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if not self.unpacked:
+            self.unpack()
         return self.linear(x)
 
-    @property
-    def weight(self) -> torch.Tensor:
-        try:
-            return self.linear.weight
-        except AttributeError:
-            return self.linear.qweight
-
-    @property
-    def bias(self) -> torch.Tensor:
-        return self.linear.bias
-
-
-def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
-    return F.linear(x, w.weight, w.bias)
-
 
 @dataclass
 class LayerNormWeights:
@@ -96,7 +98,6 @@ class MLPWeights:
 
 
 def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
-
     x = w.fc1(x)
     x = gelu_approx(x)
     x = w.fc2(x)
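For orientation, the sketch below shows the storage layout QuantizedLinear expects and the bf16 matrix its unpack() step rebuilds via dequantize_tensor before torchao's int4_weight_only re-quantizes it for the CUDA kernels. It deliberately avoids calling forward()/unpack(), which need torchao kernels and a GPU; the import paths and the random fill are assumptions for illustration only.

    import torch
    from layers import QuantizedLinear      # assumed import paths
    from packing import dequantize_tensor

    in_f, out_f = 256, 128
    layer = QuantizedLinear(in_f, out_f, dtype=torch.bfloat16)
    print(layer.weight["packed"].shape)     # (out_f, in_f // 128, 64): two int4 values per uint8 byte
    print(layer.weight["scales"].shape)     # (out_f, in_f // 128): one scale per 128-value group

    # Normally these come from model.safetensors; arbitrary values here.
    layer.weight["packed"].copy_(
        torch.randint(0, 256, layer.weight["packed"].shape, dtype=torch.uint8)
    )
    layer.weight["scales"].copy_(torch.rand(layer.weight["scales"].shape))

    # The bf16 weight that unpack() would hand to torchao's int4_weight_only.
    w_bf16 = dequantize_tensor(
        layer.weight["packed"], layer.weight["scales"], (out_f, in_f), 128, torch.bfloat16
    )
    print(w_bf16.shape)                     # torch.Size([128, 256])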
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:97076df1a9a09ff4108a69ea59b4c9abf522b248e8425c9334bab98ddbaf4b33
+size 1838828672
moondream.py
CHANGED
@@ -12,6 +12,7 @@ from .image_crops import reconstruct_from_crops
 from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
 from .text import build_text_model, text_encoder, lm_head, text_decoder
 from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
+from .layers import QuantizedLinear
 from .utils import remove_outlier_points
 
 
@@ -63,47 +64,49 @@ class KVCache(nn.Module):
 
 
 class MoondreamModel(nn.Module):
-
+
+    def __init__(
+        self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True
+    ):
         super().__init__()
         self.config = config
-        self.dtype = dtype
-        self.setup_caches_flag = setup_caches
 
         self.tokenizer = Tokenizer.from_pretrained(
             "vikhyatk/moondream2", revision="2025-01-09"
         )
-
         self.vision = build_vision_model(config.vision, dtype)
-
-        self.text = build_text_model(config.text, torch.int8)
+        self.text = build_text_model(config.text, dtype)
 
         # Region Model
+        linear_cls = (
+            QuantizedLinear if config.region.group_size is not None else nn.Linear
+        )
         self.region = nn.ModuleDict(
             {
-                "coord_encoder":
+                "coord_encoder": linear_cls(
                     config.region.coord_feat_dim, config.region.dim, dtype=dtype
                 ),
                 "coord_decoder": nn.ModuleDict(
                     {
-                        "fc1":
+                        "fc1": linear_cls(
                             config.region.dim, config.region.inner_dim, dtype=dtype
                         ),
-                        "fc2":
+                        "fc2": linear_cls(
                             config.region.inner_dim,
                             config.region.coord_out_dim,
                             dtype=dtype,
                         ),
                     }
                 ),
-                "size_encoder":
+                "size_encoder": linear_cls(
                    config.region.size_feat_dim, config.region.dim, dtype=dtype
                ),
                "size_decoder": nn.ModuleDict(
                    {
-                        "fc1":
+                        "fc1": linear_cls(
                            config.region.dim, config.region.inner_dim, dtype=dtype
                        ),
-                        "fc2":
+                        "fc2": linear_cls(
                            config.region.inner_dim,
                            config.region.size_out_dim,
                            dtype=dtype,
@@ -129,11 +132,11 @@ class MoondreamModel(nn.Module):
         attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
         self.register_buffer("attn_mask", attn_mask, persistent=False)
 
-
-
-
-            return # Can't set up caches without text model
+        # Initialize KV caches.
+        if setup_caches:
+            self._setup_caches()
 
+    def _setup_caches(self):
         c = self.config.text
         for b in self.text.blocks:
             b.kv_cache = KVCache(
@@ -166,12 +169,16 @@
         return logits, hidden
 
     def compile(self):
+        for module in self.modules():
+            if isinstance(module, QuantizedLinear):
+                module.unpack()
+
         # TODO: vision_projection is not being compiled
-        self._vis_enc = torch.compile(
-
-        )
-        self._prefill = torch.compile(self._prefill)
-        self._decode_one_tok = torch.compile(self._decode_one_tok)
+        self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
+        self._prefill = torch.compile(self._prefill, fullgraph=True)
+        self._decode_one_tok = torch.compile(
+            self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
+        )
 
     def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
         all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
@@ -204,7 +211,6 @@ class MoondreamModel(nn.Module):
 
         # Run through text model in addition to the vision encoder, to minimize
         # re-computation if multiple queries are performed on this image.
-
         with torch.inference_mode():
             img_emb = self._run_vision_encoder(image)
             bos_emb = text_encoder(
@@ -240,7 +246,6 @@ class MoondreamModel(nn.Module):
     def _prefill_prompt(
        self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float
    ):
-
        with torch.inference_mode():
            prompt_emb = text_encoder(prompt_tokens, self.text)
            torch._dynamo.mark_dynamic(prompt_emb, 1)
@@ -585,11 +590,11 @@ class MoondreamModel(nn.Module):
             self.text,
         )
         x_emb = encode_coordinate(
-            torch.tensor([[[source[0]]]], device=self.device, dtype=torch.
+            torch.tensor([[[source[0]]]], device=self.device, dtype=torch.bfloat16),
             self.region,
         )
         y_emb = encode_coordinate(
-            torch.tensor([[[source[1]]]], device=self.device, dtype=torch.
+            torch.tensor([[[source[1]]]], device=self.device, dtype=torch.bfloat16),
             self.region,
         )
 
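The constructor now takes setup_caches, _setup_caches() can be deferred, and compile() force-unpacks every QuantizedLinear before tracing. A hedged sketch of the intended loading flow; the import paths, checkpoint filename, default-constructed MoondreamConfig, network access for the tokenizer, and CUDA availability are all assumptions, not something this commit shows:

    import torch
    from safetensors.torch import load_file
    from config import MoondreamConfig       # assumed import paths
    from moondream import MoondreamModel

    # Builds the module tree with packed-int4 QuantizedLinear layers (group_size defaults to 128).
    model = MoondreamModel(MoondreamConfig(), dtype=torch.bfloat16, setup_caches=True)
    model.load_state_dict(load_file("model.safetensors"), strict=False)  # packed weights + scales
    model = model.cuda()
    model.compile()  # unpacks QuantizedLinear (torchao int4), then torch.compile's the hot paths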
packing.py
ADDED
@@ -0,0 +1,52 @@
+import torch
+
+
+def unpack_int4(packed: torch.Tensor, original_length: int) -> torch.Tensor:
+    """
+    Unpack a tensor of uint8 packed bytes (two 4-bit values per byte) into a 1D tensor of int8 values,
+    vectorized over the entire input.
+    """
+    lower = packed & 0xF
+    upper = (packed >> 4) & 0xF
+    # Interleave lower and upper nibbles
+    nibbles = torch.stack([lower, upper], dim=-1).view(-1)[:original_length]
+    nibbles = nibbles.to(torch.int8)
+    nibbles[nibbles >= 8] -= 16
+    return nibbles
+
+
+def dequantize_tensor(
+    packed: torch.Tensor,
+    scales: torch.Tensor,
+    orig_shape: torch.Size,
+    block_size: int,
+    dtype: torch.dtype,
+):
+    """
+    Dequantizes a packed int4 tensor (with given per-block scales) back to bfloat16,
+    using vectorized operations to avoid Python loops.
+    """
+    num_bytes_per_block = (block_size + 1) // 2  # number of packed bytes per block
+    num_blocks_total = packed.numel() // num_bytes_per_block
+    # Reshape to (num_blocks_total, num_bytes_per_block)
+    packed_rows = packed.view(num_blocks_total, num_bytes_per_block)
+
+    # Vectorized unpacking: compute lower and upper nibbles for all rows at once.
+    lower = packed_rows & 0xF
+    upper = (packed_rows >> 4) & 0xF
+    # Create a new dimension for the two nibbles and then flatten.
+    nibbles = torch.stack([lower, upper], dim=2).view(num_blocks_total, -1)
+    # Slice to get exactly block_size values per block.
+    quantized_flat = nibbles[:, :block_size].to(torch.int8)
+    quantized_flat[quantized_flat >= 8] -= 16
+
+    # Reshape to original block structure.
+    last_dim = orig_shape[-1]
+    num_blocks = last_dim // block_size
+    new_shape = orig_shape[:-1] + (num_blocks, block_size)
+    quantized = quantized_flat.view(new_shape)
+
+    # Dequantize using scales.
+    dequantized = quantized.to(torch.float32) * scales.unsqueeze(-1)
+    dequantized = dequantized.view(orig_shape)
+    return dequantized.to(dtype)
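dequantize_tensor inverts a fairly standard symmetric int4 scheme: each group of 128 values shares one scale, and two signed 4-bit values share a byte with the lower nibble stored first. The matching pack step is not part of this commit, so quantize_tensor below is a hypothetical reference written only to round-trip against dequantize_tensor:

    import torch
    from packing import dequantize_tensor   # assumed import path

    def quantize_tensor(x: torch.Tensor, block_size: int = 128):
        """Hypothetical inverse of dequantize_tensor: symmetric int4, per-block scales."""
        out_f, in_f = x.shape
        blocks = x.float().view(out_f, in_f // block_size, block_size)
        scales = blocks.abs().amax(dim=-1).clamp(min=1e-8) / 7.0   # (out_f, n_blocks)
        q = torch.clamp(torch.round(blocks / scales.unsqueeze(-1)), -8, 7).to(torch.int8)
        q = (q & 0xF).to(torch.uint8)                # two's-complement nibbles
        packed = q[..., 0::2] | (q[..., 1::2] << 4)  # lower nibble first, as unpack_int4 expects
        return packed, scales                        # packed: (out_f, n_blocks, block_size // 2)

    w = torch.randn(4, 256)
    packed, scales = quantize_tensor(w)
    w_hat = dequantize_tensor(packed, scales, w.shape, 128, torch.float32)
    print((w - w_hat).abs().max())                   # small quantization error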
region.py
CHANGED
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import math
 
-from .layers import
+from .layers import mlp
 
 
 def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
@@ -36,7 +36,7 @@ def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
     Returns:
         Encoded hidden states tensor for input to text model
     """
-    return
+    return w.coord_encoder(fourier_features(coord, w.coord_features))
 
 
 def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
@@ -64,7 +64,7 @@ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
     Returns:
         Encoded hidden states tensor for input to text model
     """
-    return
+    return w.size_encoder(fourier_features(size, w.size_features))
 
 
 def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
text.py
CHANGED
@@ -2,9 +2,8 @@ import torch
 import torch.nn as nn
 
 from torch.nn import functional as F
-from bitblas.cache import OperatorCache
 
-from .layers import layer_norm, mlp,
+from .layers import layer_norm, mlp, QuantizedLinear
 from .rope import apply_rotary_emb, precompute_freqs_cis
 from .config import TextConfig
 
@@ -27,7 +26,6 @@ def attn(
     head_dim = d_model // n_heads
 
     qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
-
     q_dim = n_heads * head_dim
     kv_dim = n_kv_heads * head_dim
 
@@ -57,6 +55,71 @@ def attn(
     return out
 
 
+def _attn(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    freqs_cis: torch.Tensor,
+    attn_mask: torch.Tensor,
+    n_heads: int,
+    n_kv_heads: int,
+):
+    bsz, q_len, d_model = x.shape
+    head_dim = d_model // n_heads
+    pos = 0
+
+    qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
+    q_dim = n_heads * head_dim
+    kv_dim = n_kv_heads * head_dim
+
+    q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+    k = (
+        qkv_out[..., q_dim : q_dim + kv_dim]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+    v = (
+        qkv_out[..., q_dim + kv_dim :]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+
+    position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
+    q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
+    k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
+    out = F.scaled_dot_product_attention(
+        q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
+    )
+    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
+    out = w.proj(out)
+    return out
+
+
+def _produce_hidden(inputs_embeds: torch.Tensor, w: nn.Module, config: TextConfig):
+    hidden_BTC = inputs_embeds
+
+    bsz, q_len, d_model = inputs_embeds.shape
+    attn_mask = torch.zeros(q_len, q_len)
+    attn_mask[:730, :730] = 1
+    for i in range(730, q_len):
+        attn_mask[i, : i + 1] = 1
+    attn_mask = attn_mask.to(dtype=torch.bool)
+
+    for i, block in enumerate(w.blocks):
+        l_in = layer_norm(hidden_BTC, block.ln)
+        l_attn = _attn(
+            x=l_in,
+            w=block.attn,
+            freqs_cis=w.freqs_cis,
+            attn_mask=attn_mask,
+            n_heads=config.n_heads,
+            n_kv_heads=config.n_kv_heads,
+        )
+        l_mlp = mlp(l_in, block.mlp)
+        hidden_BTC = hidden_BTC + l_attn + l_mlp
+
+    return hidden_BTC
+
+
 def text_decoder(
     x: torch.Tensor,
     w: nn.Module,
@@ -76,7 +139,6 @@ def text_decoder(
             n_kv_heads=config.n_kv_heads,
             position_ids=position_ids,
         )
-
         l_mlp = mlp(l_in, block.mlp)
         x = x + l_attn + l_mlp
 
@@ -90,30 +152,15 @@ def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
     return logits
 
 
-def
-
-
-
-) -> nn.Module:
-    # note : layernorm dtype is used for layernorm, lm_head and wte not just layernorm
-    print(
-        "Initializing quantized backend. This only has to run once, but may take a few minutes."
-    )
-    qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
-
-    group_size = None
-    if linear_dtype == torch.int8:
-        group_size = config.group_size
-
-
-
-
-        in_features=in_features,
-        out_features=out_features,
-        dtype=dtype,
-        group_size=group_size,
-    )
+def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
+    hidden_BTC = layer_norm(hidden_BTC, w.post_ln)
+    logits = w.lm_head(hidden_BTC)
+    return logits
+
+
+def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
+    qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
+    linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear
 
     text = nn.ModuleDict(
         {
@@ -121,17 +168,23 @@ def build_text_model(
                 [
                     nn.ModuleDict(
                         {
-                            "ln": nn.LayerNorm(config.dim, dtype=
+                            "ln": nn.LayerNorm(config.dim, dtype=dtype),
                             "attn": nn.ModuleDict(
                                 {
-                                    "qkv":
-                                    "proj":
+                                    "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype),
+                                    "proj": linear_cls(
+                                        config.dim, config.dim, dtype=dtype
+                                    ),
                                 }
                             ),
                             "mlp": nn.ModuleDict(
                                 {
-                                    "fc1":
-
+                                    "fc1": linear_cls(
+                                        config.dim, config.ff_dim, dtype=dtype
+                                    ),
+                                    "fc2": linear_cls(
+                                        config.ff_dim, config.dim, dtype=dtype
+                                    ),
                                 }
                             ),
                         }
@@ -139,13 +192,11 @@ def build_text_model(
                     for _ in range(config.n_layers)
                 ]
             ),
-            "post_ln": nn.LayerNorm(config.dim, dtype=
-            "lm_head":
+            "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
+            "lm_head": linear_cls(config.dim, config.vocab_size, dtype=dtype),
         }
     )
-    text.wte = nn.Parameter(
-        torch.empty(config.vocab_size, config.dim, dtype=layernorm_dtype)
-    )
+    text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
    text.register_buffer(
        "freqs_cis",
        precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
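build_text_model now keys the linear class off config.group_size instead of a separate linear_dtype/bitblas path. A small sketch of the selection behaviour (import paths are assumptions; only uninitialized weights are allocated, so no checkpoint is needed):

    import torch
    import torch.nn as nn
    from config import TextConfig            # assumed import paths
    from text import build_text_model
    from layers import QuantizedLinear

    packed = build_text_model(TextConfig(), dtype=torch.bfloat16)                # group_size=128
    plain = build_text_model(TextConfig(group_size=None), dtype=torch.bfloat16)  # no quantization

    print(isinstance(packed.blocks[0].attn.qkv, QuantizedLinear))  # True: unpacked lazily on first forward
    print(isinstance(plain.blocks[0].attn.qkv, nn.Linear))         # True: ordinary bf16 linear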
vision.py
CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 from typing import Union, Tuple
 from PIL import Image
 
-from .layers import attn, layer_norm,
+from .layers import attn, layer_norm, mlp
 from .image_crops import overlap_crop_image
 from .config import VisionConfig
 
@@ -33,7 +33,7 @@ def prepare_crops(
     all_crops = np.transpose(all_crops, (0, 3, 1, 2))
     all_crops = (
         torch.from_numpy(all_crops)
-        .to(device=device, dtype=torch.
+        .to(device=device, dtype=torch.bfloat16)
         .div_(255.0)
         .sub_(0.5)
         .div_(0.5)
@@ -64,7 +64,7 @@ def create_patches(x, patch_size):
 def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
     x = create_patches(input_BCHW, config.enc_patch_size)
 
-    x =
+    x = w.patch_emb(x)
     x = x + w.pos_emb
     for block in w.blocks:
         x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)
|