marriola commited on
Commit
e58f74c
·
verified ·
1 Parent(s): d25f94c

Upload BD3LM

Browse files
Files changed (1) hide show
  1. modeling_bd3lm.py +4 -0
modeling_bd3lm.py CHANGED
@@ -525,6 +525,10 @@ class BD3LM(transformers.PreTrainedModel):
525
  'sampling_eps_max',
526
  torch.tensor(config.sampling_eps_max))
527
 
 
 
 
 
528
  def forward(
529
  self,
530
  input_ids: torch.LongTensor = None,
 
525
  'sampling_eps_max',
526
  torch.tensor(config.sampling_eps_max))
527
 
528
+ def reset_kv_cache(self):
529
+ for block in self.backbone.blocks:
530
+ block.kv_cache = None
531
+
532
  def forward(
533
  self,
534
  input_ids: torch.LongTensor = None,