Upload BD3LM
Browse files- modeling_bd3lm.py +2 -4
modeling_bd3lm.py
CHANGED
|
@@ -299,7 +299,6 @@ class DDiTBlock(nn.Module):
|
|
| 299 |
else:
|
| 300 |
return bias_dropout_add_scale_fused_inference
|
| 301 |
|
| 302 |
-
|
| 303 |
def get_qkv(self, x, rotary_cos_sin, store_kv=False):
|
| 304 |
# compute qkv (potentially use cache)
|
| 305 |
if self.kv_cache is not None:
|
|
@@ -307,11 +306,10 @@ class DDiTBlock(nn.Module):
|
|
| 307 |
qkv = torch.cat((self.kv_cache, new_qkv), dim=1)
|
| 308 |
else:
|
| 309 |
qkv = self.attn_qkv(x)
|
| 310 |
-
|
| 311 |
# store kv cache in a sliding window (can't exceed context len)
|
| 312 |
if store_kv:
|
| 313 |
-
self.kv_cache = qkv
|
| 314 |
-
|
| 315 |
qkv = einops.rearrange(
|
| 316 |
qkv,
|
| 317 |
'b s (three h d) -> b s three h d',
|
|
|
|
| 299 |
else:
|
| 300 |
return bias_dropout_add_scale_fused_inference
|
| 301 |
|
|
|
|
| 302 |
def get_qkv(self, x, rotary_cos_sin, store_kv=False):
|
| 303 |
# compute qkv (potentially use cache)
|
| 304 |
if self.kv_cache is not None:
|
|
|
|
| 306 |
qkv = torch.cat((self.kv_cache, new_qkv), dim=1)
|
| 307 |
else:
|
| 308 |
qkv = self.attn_qkv(x)
|
|
|
|
| 309 |
# store kv cache in a sliding window (can't exceed context len)
|
| 310 |
if store_kv:
|
| 311 |
+
self.kv_cache = qkv[:, -(self.n-self.block_size):]
|
| 312 |
+
|
| 313 |
qkv = einops.rearrange(
|
| 314 |
qkv,
|
| 315 |
'b s (three h d) -> b s three h d',
|