legacy cache support
Browse files- modeling_arctic.py +7 -3
modeling_arctic.py
CHANGED
```diff
@@ -1763,9 +1763,13 @@ class ArcticForCausalLM(ArcticPreTrainedModel):
     ):
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
-            [removed line 1766 — content not preserved in this extraction]
-            [removed line 1767 — content not preserved in this extraction]
-            [removed line 1768 — content not preserved in this extraction]
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length() if hasattr(past_key_values, "get_max_length") else past_key_values.get_max_cache_shape()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None

             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
```

NOTE(review): this span was a flattened side-by-side diff render; it has been
reconstructed above as a unified diff. The three removed lines (old 1766–1768)
had no content in the extracted page — only bare `-` markers survived — so they
are shown as placeholders rather than guessed at. Presumably they were the
unconditional `Cache`-method calls that the new `isinstance(past_key_values,
Cache)` branch now guards (hence the commit title "legacy cache support") —
verify against the original commit before relying on this.