Update modeling_qwen2.py
#7
by xiezhe24 - opened

modeling_qwen2.py  CHANGED  (+9 -2)
@@ -1450,6 +1450,9 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
             attention_mask=attention_mask
         )

+    def _extract_past_from_model_output(self, outputs: ModelOutput):
+        return "past_key_values", outputs.past_key_values
+
     def _update_model_kwargs_for_generation(
         self,
         outputs: ModelOutput,
@@ -1505,8 +1508,12 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = (
+                    past_key_values.get_max_length()
+                    if hasattr(past_key_values, "get_max_length")
+                    else past_key_values.get_max_cache_shape()
+                )
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
                 max_cache_length = None
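
For context, both edits look like compatibility shims for the evolving transformers Cache/generation API: the added _extract_past_from_model_output returns a ("past_key_values", cache) tuple, and max_cache_length now falls back to get_max_cache_shape() when get_max_length() is missing, as it is on newer Cache classes. Below is a minimal, self-contained sketch of the same pattern, assuming only a stock DynamicCache; the names resolve_cache_lengths, extract_past, and DummyOutput are illustrative and not part of this repository.

# Minimal sketch (not this repository's code): the same version-compatibility
# pattern the diff applies, pulled out into standalone helpers for clarity.
from dataclasses import dataclass

from transformers.cache_utils import DynamicCache


def resolve_cache_lengths(past_key_values):
    """Return (cache_length, past_length, max_cache_length) for a Cache object,
    tolerating both the older get_max_length() and newer get_max_cache_shape() APIs."""
    cache_length = past_key_values.get_seq_length()
    # Older Cache classes track .seen_tokens; fall back to the sequence length otherwise.
    past_length = getattr(past_key_values, "seen_tokens", cache_length)
    if hasattr(past_key_values, "get_max_length"):
        max_cache_length = past_key_values.get_max_length()
    else:
        max_cache_length = past_key_values.get_max_cache_shape()
    return cache_length, past_length, max_cache_length


@dataclass
class DummyOutput:
    # Stand-in for a ModelOutput that carries a cache; used only for this illustration.
    past_key_values: DynamicCache


def extract_past(outputs):
    # Mirrors the overridden _extract_past_from_model_output: a (cache_name, cache)
    # tuple rather than the bare cache object.
    return "past_key_values", outputs.past_key_values


if __name__ == "__main__":
    cache = DynamicCache()
    print(resolve_cache_lengths(cache))         # typically (0, 0, None) for an empty cache
    print(extract_past(DummyOutput(cache))[0])  # "past_key_values"

The hasattr/getattr guards are the substance of the change: they let the same remote code load and generate under both older and newer transformers installations without pinning a specific version.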