echarlaix HF Staff committed on
Commit
0d2270a
·
1 Parent(s): 3abc86b

legacy cache support

Browse files
Files changed (1) hide show
  1. modeling_arctic.py +7 -3
modeling_arctic.py CHANGED
@@ -1763,9 +1763,13 @@ class ArcticForCausalLM(ArcticPreTrainedModel):
1763
  ):
1764
  # Omit tokens covered by past_key_values
1765
  if past_key_values is not None:
1766
- cache_length = past_key_values.get_seq_length()
1767
- past_length = past_key_values.seen_tokens
1768
- max_cache_length = past_key_values.get_max_length() if hasattr(past_key_values, "get_max_length") else past_key_values.get_max_cache_shape()
 
 
 
 
1769
 
1770
  # Keep only the unprocessed tokens:
1771
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
 
1763
  ):
1764
  # Omit tokens covered by past_key_values
1765
  if past_key_values is not None:
1766
+ if isinstance(past_key_values, Cache):
1767
+ cache_length = past_key_values.get_seq_length()
1768
+ past_length = past_key_values.seen_tokens
1769
+ max_cache_length = past_key_values.get_max_length() if hasattr(past_key_values, "get_max_length") else past_key_values.get_max_cache_shape()
1770
+ else:
1771
+ cache_length = past_length = past_key_values[0][0].shape[2]
1772
+ max_cache_length = None
1773
 
1774
  # Keep only the unprocessed tokens:
1775
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where