yhirokawa and shmurai committed
Commit 74b112b · verified · 1 Parent(s): a99ff56

Update modeling_plamo.py (#9)

- Update modeling_plamo.py (e8dcdfe415d37765d1cf4396cea60cbeb6a8f0b9)


Co-authored-by: Shogo Murai <[email protected]>

Files changed (1)
  1. modeling_plamo.py +12 -2
modeling_plamo.py CHANGED
@@ -19,6 +19,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 from transformers import PretrainedConfig, PreTrainedModel
+from transformers.cache_utils import DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
 
@@ -327,7 +328,8 @@ class Plamo2Cache(torch.nn.Module):
             if sequence_length is not None
             else layer_cache.key.shape[2]
         )
-        assert sequence_length is not None
+        if sequence_length is None:
+            return 0
         return sequence_length
 
     def get_max_length(self) -> int | None:
@@ -1387,7 +1389,7 @@ class Plamo2Model(Plamo2PreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Plamo2Cache] = None,
+        past_key_values: Optional[Plamo2Cache | DynamicCache] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         image_features: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
@@ -1419,6 +1421,14 @@ class Plamo2Model(Plamo2PreTrainedModel):
         seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values is not None:
+            # In some `transformers` versions, `past_key_values` may be a `DynamicCache` object.
+            if not isinstance(past_key_values, Plamo2Cache):
+                past_key_values_prev = past_key_values
+                past_key_values = Plamo2Cache(self.config)
+
+                # If `past_key_values` is a `DynamicCache` object, it must be empty
+                assert len(past_key_values_prev) == 0
+            assert isinstance(past_key_values, Plamo2Cache)
             past_key_values_length = past_key_values.get_seq_length()
             seq_length_with_past = seq_length_with_past + past_key_values_length
             assert cache_position is None, "cache_position is not supported yet"
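
The two cache changes work together: forward() may now receive an empty DynamicCache and swap it for the model's own cache type, which in turn requires get_seq_length() to report 0 for a cache with no layers instead of tripping the old assert. The following is a minimal, self-contained sketch of that shim, runnable without the model weights; the Plamo2Cache class here is a hypothetical stand-in for the real one in modeling_plamo.py, modeling only the two behaviors the patch touches.

from transformers.cache_utils import DynamicCache


class Plamo2Cache:
    """Stand-in: the real implementation lives in modeling_plamo.py."""

    def __init__(self, config=None):
        self.layers: dict = {}  # layer index -> cached key/value states

    def get_seq_length(self) -> int:
        # After this commit, an empty cache reports length 0 instead of
        # failing the removed `assert sequence_length is not None`.
        if not self.layers:
            return 0
        raise NotImplementedError  # the real class reads layer_cache.key.shape[2]


def normalize_past_key_values(past_key_values, config=None):
    # Mirrors the new logic in Plamo2Model.forward: a foreign cache object
    # is replaced by the model's own cache type, but only if it is empty,
    # since a populated DynamicCache holds states Plamo2Cache cannot ingest.
    if past_key_values is not None and not isinstance(past_key_values, Plamo2Cache):
        assert len(past_key_values) == 0
        past_key_values = Plamo2Cache(config)
    return past_key_values


cache = normalize_past_key_values(DynamicCache())
print(type(cache).__name__, cache.get_seq_length())  # Plamo2Cache 0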
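For an end-to-end check, a hypothetical usage sketch follows, assuming a PLaMo 2 checkpoint that ships this modeling file (the repo id below is illustrative, not confirmed by this commit) and a recent transformers release whose generate() seeds past_key_values with an empty DynamicCache. Before this commit such a call could fail inside forward(); with the shim it proceeds normally.

from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "pfnet/plamo-2-1b"  # illustrative checkpoint id
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

inputs = tokenizer("Hello,", return_tensors="pt")
# generate() may hand forward() an empty DynamicCache on the first step;
# the patched forward() transparently replaces it with a Plamo2Cache.
out = model.generate(**inputs, max_new_tokens=8, use_cache=True)
print(tokenizer.decode(out[0]))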