marriola committed on
Commit
35ce154
·
verified ·
1 Parent(s): 00c5b41

Upload BD3LM

Browse files
Files changed (1) hide show
  1. modeling_bd3lm.py +2 -4
modeling_bd3lm.py CHANGED
@@ -299,7 +299,6 @@ class DDiTBlock(nn.Module):
299
  else:
300
  return bias_dropout_add_scale_fused_inference
301
 
302
-
303
  def get_qkv(self, x, rotary_cos_sin, store_kv=False):
304
  # compute qkv (potentially use cache)
305
  if self.kv_cache is not None:
@@ -307,11 +306,10 @@ class DDiTBlock(nn.Module):
307
  qkv = torch.cat((self.kv_cache, new_qkv), dim=1)
308
  else:
309
  qkv = self.attn_qkv(x)
310
-
311
  # store kv cache in a sliding window (can't exceed context len)
312
  if store_kv:
313
- self.kv_cache = qkv
314
- self.kv_cache = self.kv_cache[:, -self.n:]
315
  qkv = einops.rearrange(
316
  qkv,
317
  'b s (three h d) -> b s three h d',
 
299
  else:
300
  return bias_dropout_add_scale_fused_inference
301
 
 
302
  def get_qkv(self, x, rotary_cos_sin, store_kv=False):
303
  # compute qkv (potentially use cache)
304
  if self.kv_cache is not None:
 
306
  qkv = torch.cat((self.kv_cache, new_qkv), dim=1)
307
  else:
308
  qkv = self.attn_qkv(x)
 
309
  # store kv cache in a sliding window (can't exceed context len)
310
  if store_kv:
311
+ self.kv_cache = qkv[:, -(self.n-self.block_size):]
312
+
313
  qkv = einops.rearrange(
314
  qkv,
315
  'b s (three h d) -> b s three h d',