Update 0_KananaEmbeddingWrapper/kanana2vec/modeling_kanana2vec.py
0_KananaEmbeddingWrapper/kanana2vec/modeling_kanana2vec.py CHANGED
@@ -97,7 +97,6 @@ class BiLlamaModel(LlamaModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         **kwargs,
@@ -117,8 +116,6 @@ class BiLlamaModel(LlamaModel):
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to plcae the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -130,7 +127,7 @@ class BiLlamaModel(LlamaModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.zeros(
-                (sequence_length, target_length), dtype=dtype, device=device
+                (sequence_length, target_length), dtype=dtype, device=cache_position.device
             )
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         if attention_mask is not None: