datalama committed
Commit 06575a7 · verified · 1 Parent(s): 7a6b851

Update 0_KananaEmbeddingWrapper/kanana2vec/modeling_kanana2vec.py

0_KananaEmbeddingWrapper/kanana2vec/modeling_kanana2vec.py CHANGED
@@ -97,7 +97,6 @@ class BiLlamaModel(LlamaModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         **kwargs,
@@ -117,8 +116,6 @@ class BiLlamaModel(LlamaModel):
             to account for the 0 padding, the part of the cache that is not filled yet.
         dtype (`torch.dtype`):
             The dtype to use for the 4D attention mask.
-        device (`torch.device`):
-            The device to plcae the 4D attention mask on.
         cache_position (`torch.Tensor`):
             Indices depicting the position of the input sequence tokens in the sequence.
         batch_size (`torch.Tensor`):
@@ -130,7 +127,7 @@ class BiLlamaModel(LlamaModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.zeros(
-                (sequence_length, target_length), dtype=dtype, device=device
+                (sequence_length, target_length), dtype=dtype, device=cache_position.device
             )
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
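
Note: the commit drops the explicit `device` argument from the mask-construction helper and instead derives the device from `cache_position`, so the zeroed (bidirectional) 4D mask is always allocated on the same device as the position indices. A minimal standalone sketch of the updated construction follows; the function name, the padding handling inside `if attention_mask is not None:`, and the example values are illustrative assumptions, not the repository's exact code.

import torch

def build_bidirectional_4d_mask(attention_mask, sequence_length, target_length,
                                dtype, cache_position, batch_size):
    # Sketch of the changed lines: the mask starts as all zeros (no causal
    # restriction, since BiLlamaModel attends bidirectionally) and its device
    # now comes from cache_position rather than a separate `device` argument.
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.zeros(
        (sequence_length, target_length), dtype=dtype, device=cache_position.device
    )
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is not None:
        # Assumed padding handling: push padded key positions to the most
        # negative representable value so they get ~zero attention weight.
        padding = (1.0 - attention_mask[:, None, None, :].to(dtype)) * min_dtype
        causal_mask = causal_mask + padding
    return causal_mask

# Example on CPU: the mask lands on cache_position.device automatically.
cache_position = torch.arange(4)
mask = build_bidirectional_4d_mask(
    attention_mask=torch.tensor([[1, 1, 1, 0]]),
    sequence_length=4, target_length=4,
    dtype=torch.float32, cache_position=cache_position, batch_size=1,
)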