marriola committed
Commit 00c5b41 · verified · 1 Parent(s): e58f74c

Upload BD3LM

Files changed (1):
  1. modeling_bd3lm.py +6 -8
modeling_bd3lm.py CHANGED
@@ -267,10 +267,11 @@ def regular_attention_multi_headed(qkv):
 
 
 class DDiTBlock(nn.Module):
-  def __init__(self, n, dim, n_heads, cond_dim, mlp_ratio=4,
+  def __init__(self, n, block_size, dim, n_heads, cond_dim, mlp_ratio=4,
               dropout=0.1, attn_backend='flash_attn'):
     super().__init__()
     self.n = n
+    self.block_size = block_size
     self.n_heads = n_heads
     self.attn_backend = attn_backend
     self.kv_cache = None
@@ -302,19 +303,15 @@ class DDiTBlock(nn.Module):
   def get_qkv(self, x, rotary_cos_sin, store_kv=False):
     # compute qkv (potentially use cache)
     if self.kv_cache is not None:
-      block_len = x.shape[1] - self.kv_cache.shape[1]
-      new_qkv = self.attn_qkv(x[:, -block_len:])
+      new_qkv = self.attn_qkv(x[:, -self.block_size:])
       qkv = torch.cat((self.kv_cache, new_qkv), dim=1)
     else:
       qkv = self.attn_qkv(x)
 
     # store kv cache in a sliding window (can't exceed context len)
     if store_kv:
-      if self.kv_cache is not None:
-        cache_len = min(x.shape[1], self.n - block_len)
-        self.kv_cache = qkv[:, -cache_len:]
-      else:
-        self.kv_cache = qkv
+      self.kv_cache = qkv
+      self.kv_cache = self.kv_cache[:, -self.n:]
     qkv = einops.rearrange(
       qkv,
       'b s (three h d) -> b s three h d',
@@ -440,6 +437,7 @@ class DITBackbone(nn.Module):
     blocks = []
     for _ in range(config.n_blocks):
       blocks.append(DDiTBlock(self.n,
+                              self.block_size,
                               config.hidden_dim,
                               config.n_heads,
                               config.cond_dim,
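
The net effect of the commit: DDiTBlock now receives the sampling block size explicitly (forwarded from DITBackbone) instead of inferring it from tensor shapes via the old block_len = x.shape[1] - self.kv_cache.shape[1], and when store_kv is set the stored QKV tensor is unconditionally clipped to the last n positions so the cache can never exceed the context length. The following is a minimal sketch of just that caching pattern, not the shipped module: a plain nn.Linear stands in for attn_qkv, and the class name ToyCachedQKV plus all tensor sizes are illustrative assumptions; only the names n, block_size, attn_qkv, and kv_cache mirror the diff.

import torch
import torch.nn as nn

class ToyCachedQKV(nn.Module):
  """Illustrative stand-in for DDiTBlock's sliding-window QKV cache."""

  def __init__(self, n, block_size, dim):
    super().__init__()
    self.n = n                    # max sequence length kept in the cache
    self.block_size = block_size  # tokens added per decoding step
    self.attn_qkv = nn.Linear(dim, 3 * dim)  # stand-in for the real attn_qkv projection
    self.kv_cache = None

  def forward(self, x, store_kv=False):
    # compute qkv (potentially use cache), mirroring the updated get_qkv
    if self.kv_cache is not None:
      # only the newest block is projected; earlier positions come from the cache
      new_qkv = self.attn_qkv(x[:, -self.block_size:])
      qkv = torch.cat((self.kv_cache, new_qkv), dim=1)
    else:
      qkv = self.attn_qkv(x)
    # store kv cache in a sliding window (can't exceed context len)
    if store_kv:
      self.kv_cache = qkv[:, -self.n:]
    return qkv

# With n=8 and block_size=4, the cache never holds more than 8 positions.
m = ToyCachedQKV(n=8, block_size=4, dim=16)
x = torch.randn(1, 4, 16)
print(m(x, store_kv=True).shape)   # torch.Size([1, 4, 48])
x = torch.cat((x, torch.randn(1, 4, 16)), dim=1)
print(m(x, store_kv=True).shape)   # torch.Size([1, 8, 48])
x = torch.cat((x, torch.randn(1, 4, 16)), dim=1)
print(m(x, store_kv=True).shape)   # torch.Size([1, 12, 48])
print(m.kv_cache.shape)            # torch.Size([1, 8, 48]) -- clipped to n

Compared with the previous revision, there is no separate first-call branch when storing: the cache is written unconditionally and then truncated to the last n positions, and because DITBackbone passes self.block_size to every block, the width of the newly projected slice no longer depends on the shapes of the incoming hidden states and the existing cache.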