Upload BD3LM
Browse files- modeling_bd3lm.py +6 -8
modeling_bd3lm.py
CHANGED
@@ -267,10 +267,11 @@ def regular_attention_multi_headed(qkv):
|
|
267 |
|
268 |
|
269 |
class DDiTBlock(nn.Module):
|
270 |
-
def __init__(self, n, dim, n_heads, cond_dim, mlp_ratio=4,
|
271 |
dropout=0.1, attn_backend='flash_attn'):
|
272 |
super().__init__()
|
273 |
self.n = n
|
|
|
274 |
self.n_heads = n_heads
|
275 |
self.attn_backend = attn_backend
|
276 |
self.kv_cache = None
|
@@ -302,19 +303,15 @@ class DDiTBlock(nn.Module):
|
|
302 |
def get_qkv(self, x, rotary_cos_sin, store_kv=False):
|
303 |
# compute qkv (potentially use cache)
|
304 |
if self.kv_cache is not None:
|
305 |
-
|
306 |
-
new_qkv = self.attn_qkv(x[:, -block_len:])
|
307 |
qkv = torch.cat((self.kv_cache, new_qkv), dim=1)
|
308 |
else:
|
309 |
qkv = self.attn_qkv(x)
|
310 |
|
311 |
# store kv cache in a sliding window (can't exceed context len)
|
312 |
if store_kv:
|
313 |
-
|
314 |
-
|
315 |
-
self.kv_cache = qkv[:, -cache_len:]
|
316 |
-
else:
|
317 |
-
self.kv_cache = qkv
|
318 |
qkv = einops.rearrange(
|
319 |
qkv,
|
320 |
'b s (three h d) -> b s three h d',
|
@@ -440,6 +437,7 @@ class DITBackbone(nn.Module):
|
|
440 |
blocks = []
|
441 |
for _ in range(config.n_blocks):
|
442 |
blocks.append(DDiTBlock(self.n,
|
|
|
443 |
config.hidden_dim,
|
444 |
config.n_heads,
|
445 |
config.cond_dim,
|
|
|
267 |
|
268 |
|
269 |
class DDiTBlock(nn.Module):
|
270 |
+
def __init__(self, n, block_size, dim, n_heads, cond_dim, mlp_ratio=4,
|
271 |
dropout=0.1, attn_backend='flash_attn'):
|
272 |
super().__init__()
|
273 |
self.n = n
|
274 |
+
self.block_size = block_size
|
275 |
self.n_heads = n_heads
|
276 |
self.attn_backend = attn_backend
|
277 |
self.kv_cache = None
|
|
|
303 |
def get_qkv(self, x, rotary_cos_sin, store_kv=False):
|
304 |
# compute qkv (potentially use cache)
|
305 |
if self.kv_cache is not None:
|
306 |
+
new_qkv = self.attn_qkv(x[:, -self.block_size:])
|
|
|
307 |
qkv = torch.cat((self.kv_cache, new_qkv), dim=1)
|
308 |
else:
|
309 |
qkv = self.attn_qkv(x)
|
310 |
|
311 |
# store kv cache in a sliding window (can't exceed context len)
|
312 |
if store_kv:
|
313 |
+
self.kv_cache = qkv
|
314 |
+
self.kv_cache = self.kv_cache[:, -self.n:]
|
|
|
|
|
|
|
315 |
qkv = einops.rearrange(
|
316 |
qkv,
|
317 |
'b s (three h d) -> b s three h d',
|
|
|
437 |
blocks = []
|
438 |
for _ in range(config.n_blocks):
|
439 |
blocks.append(DDiTBlock(self.n,
|
440 |
+
self.block_size,
|
441 |
config.hidden_dim,
|
442 |
config.n_heads,
|
443 |
config.cond_dim,
|