Update modeling_plamo.py (#9)
Update modeling_plamo.py (e8dcdfe415d37765d1cf4396cea60cbeb6a8f0b9)
Co-authored-by: Shogo Murai <[email protected]>
modeling_plamo.py (changed, +12 -2)
```diff
@@ -19,6 +19,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 from transformers import PretrainedConfig, PreTrainedModel
+from transformers.cache_utils import DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
 
@@ -327,7 +328,8 @@ class Plamo2Cache(torch.nn.Module):
                     if sequence_length is not None
                     else layer_cache.key.shape[2]
                 )
-
+        if sequence_length is None:
+            return 0
         return sequence_length
 
     def get_max_length(self) -> int | None:
@@ -1387,7 +1389,7 @@ class Plamo2Model(Plamo2PreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Plamo2Cache] = None,
+        past_key_values: Optional[Plamo2Cache | DynamicCache] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         image_features: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
@@ -1419,6 +1421,14 @@ class Plamo2Model(Plamo2PreTrainedModel):
         seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values is not None:
+            # In some `transformers` versions, `past_key_values` may be a `DynamicCache` object.
+            if not isinstance(past_key_values, Plamo2Cache):
+                past_key_values_prev = past_key_values
+                past_key_values = Plamo2Cache(self.config)
+
+                # If `past_key_values` is a `DynamicCache` object, it must be empty
+                assert len(past_key_values_prev) == 0
+            assert isinstance(past_key_values, Plamo2Cache)
             past_key_values_length = past_key_values.get_seq_length()
             seq_length_with_past = seq_length_with_past + past_key_values_length
         assert cache_position is None, "cache_position is not supported yet"
```
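Taken together, the patch lets callers (including newer `transformers` releases, which pre-build an empty `DynamicCache` before the first forward pass) hand the model a `DynamicCache`: `forward()` swaps it for the model's own `Plamo2Cache`, and an empty cache now reports a past length of 0 instead of tripping an assertion. Below is a minimal sketch of what this enables from the caller's side; the checkpoint name `pfnet/plamo-2-1b` and the `trust_remote_code` loading path are illustrative assumptions, not part of this commit.

```python
# Minimal sketch, not from the commit: assumes this modeling file is loaded via
# trust_remote_code for a PLaMo-2 checkpoint (pfnet/plamo-2-1b used as a stand-in).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

model_name = "pfnet/plamo-2-1b"  # hypothetical target checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

inputs = tokenizer("Hello", return_tensors="pt")

# Some transformers versions pass a pre-built, empty DynamicCache into forward().
# With this patch the model accepts it and replaces it with a Plamo2Cache; the
# added assert requires the incoming DynamicCache to be empty (len == 0).
empty_cache = DynamicCache()
assert len(empty_cache) == 0

with torch.no_grad():
    out = model(**inputs, past_key_values=empty_cache, use_cache=True)

# The cache returned to the caller is the model's own cache object, now holding
# the prompt's states; its reported length should match the prompt token count.
print(type(out.past_key_values).__name__)
print(out.past_key_values.get_seq_length())
```

Whether an empty `DynamicCache` actually reaches `forward()` depends on the installed `transformers` version; with older versions the behaviour is unchanged, since `past_key_values` is either `None` or already a `Plamo2Cache`.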