Upload BD3LM
Browse files- modeling_bd3lm.py +2 -4
modeling_bd3lm.py
CHANGED
@@ -299,7 +299,6 @@ class DDiTBlock(nn.Module):
|
|
299 |
else:
|
300 |
return bias_dropout_add_scale_fused_inference
|
301 |
|
302 |
-
|
303 |
def get_qkv(self, x, rotary_cos_sin, store_kv=False):
|
304 |
# compute qkv (potentially use cache)
|
305 |
if self.kv_cache is not None:
|
@@ -307,11 +306,10 @@ class DDiTBlock(nn.Module):
|
|
307 |
qkv = torch.cat((self.kv_cache, new_qkv), dim=1)
|
308 |
else:
|
309 |
qkv = self.attn_qkv(x)
|
310 |
-
|
311 |
# store kv cache in a sliding window (can't exceed context len)
|
312 |
if store_kv:
|
313 |
-
self.kv_cache = qkv
|
314 |
-
|
315 |
qkv = einops.rearrange(
|
316 |
qkv,
|
317 |
'b s (three h d) -> b s three h d',
|
|
|
299 |
else:
|
300 |
return bias_dropout_add_scale_fused_inference
|
301 |
|
|
|
302 |
def get_qkv(self, x, rotary_cos_sin, store_kv=False):
|
303 |
# compute qkv (potentially use cache)
|
304 |
if self.kv_cache is not None:
|
|
|
306 |
qkv = torch.cat((self.kv_cache, new_qkv), dim=1)
|
307 |
else:
|
308 |
qkv = self.attn_qkv(x)
|
|
|
309 |
# store kv cache in a sliding window (can't exceed context len)
|
310 |
if store_kv:
|
311 |
+
self.kv_cache = qkv[:, -(self.n-self.block_size):]
|
312 |
+
|
313 |
qkv = einops.rearrange(
|
314 |
qkv,
|
315 |
'b s (three h d) -> b s three h d',
|