Upload BD3LM
Browse files- modeling_bd3lm.py +4 -0
modeling_bd3lm.py
CHANGED
@@ -525,6 +525,10 @@ class BD3LM(transformers.PreTrainedModel):
|
|
525 |
'sampling_eps_max',
|
526 |
torch.tensor(config.sampling_eps_max))
|
527 |
|
|
|
|
|
|
|
|
|
528 |
def forward(
|
529 |
self,
|
530 |
input_ids: torch.LongTensor = None,
|
|
|
525 |
'sampling_eps_max',
|
526 |
torch.tensor(config.sampling_eps_max))
|
527 |
|
528 |
+
def reset_kv_cache(self):
|
529 |
+
for block in self.backbone.blocks:
|
530 |
+
block.kv_cache = None
|
531 |
+
|
532 |
def forward(
|
533 |
self,
|
534 |
input_ids: torch.LongTensor = None,
|