LogicBombaklot committed
Commit f3abe78 · verified · 1 parent: b34d2f8

Upload transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py with huggingface_hub

transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py ADDED
@@ -0,0 +1,348 @@
# coding=utf-8
# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F

from functools import lru_cache
import importlib.metadata
import importlib.util
from packaging import version

from transformers.utils import is_flash_attn_2_available


if is_flash_attn_2_available():
    try:
        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
        from flash_attn import flash_attn_func, flash_attn_varlen_func

        _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
    except ImportError:
        raise ImportError("Unable to import flash_attn")

def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
    # Check if the package spec exists and grab its version to avoid importing a local directory
    package_exists = importlib.util.find_spec(pkg_name) is not None
    package_version = "N/A"
    if package_exists:
        try:
            # Primary method to get the package version
            package_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            # Fallback method: only for "torch" and versions containing "dev"
            if pkg_name == "torch":
                try:
                    package = importlib.import_module(pkg_name)
                    temp_version = getattr(package, "__version__", "N/A")
                    # Check if the version contains "dev"
                    if "dev" in temp_version:
                        package_version = temp_version
                        package_exists = True
                    else:
                        package_exists = False
                except ImportError:
                    # If the package can't be imported, it's not available
                    package_exists = False
            else:
                # For packages other than "torch", don't attempt the fallback and set as not available
                package_exists = False
    if return_version:
        return package_exists, package_version
    else:
        return package_exists

@lru_cache()
def is_flash_attn_greater_or_equal(library_version: str):
    if not _is_package_available("flash_attn"):
        return False

    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)

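# Illustrative note (an addition, not part of the original file): `_is_package_available` mirrors the
# availability check used inside `transformers.utils`, and `is_flash_attn_greater_or_equal` builds on it
# to gate version-dependent kwargs further below. A hypothetical call might look like:
#
#     _is_package_available("flash_attn", return_version=True)  # e.g. (True, "2.6.3") if installed
#     is_flash_attn_greater_or_equal("2.4.1")                    # True only when flash-attn >= 2.4.1
#
# The version string above is a placeholder; the returned values depend on the local install.
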
def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Retrieves indexing data required to repad unpadded (ragged) tensors.

    Arguments:
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

    Return:
        indices (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input sequence.
        cu_seqlens (`torch.Tensor`):
            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        max_seqlen_in_batch (`int`):
            Maximum sequence length in batch.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )

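# Worked example (illustrative only, not part of the original file): for a right-padded batch with
#     attention_mask = torch.tensor([[1, 1, 1, 0],
#                                    [1, 1, 0, 0]])
# `_get_unpad_data` returns
#     indices             -> tensor([0, 1, 2, 4, 5])        # flattened positions of the 5 valid tokens
#     cu_seqlens          -> tensor([0, 3, 5], dtype=torch.int32)
#     max_seqlen_in_batch -> 3
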
def _upad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
):
    """
    Unpads query, key, and value tensors, using a single dimension for all tokens even though they belong to different batches.

    This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
    tensors for the query, key, and value tensors.

    Arguments:
        query_layer (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
        query_length (`int`):
            Target length.

    Return:
        query_layer (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

    key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
    value_layer = index_first_axis(
        value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
    )
    if query_length == kv_seq_len:
        query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )

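# Shape sketch (illustrative only): with the 2x4 attention_mask from the `_get_unpad_data` example above
# (5 valid tokens in total) and query_length == kv_seq_len == 4, `_upad_input` flattens
#     query_layer (2, 4, num_heads, head_dim)           -> (5, num_heads, head_dim)
#     key_layer   (2, 4, num_key_value_heads, head_dim) -> (5, num_key_value_heads, head_dim)
#     value_layer (2, 4, num_key_value_heads, head_dim) -> (5, num_key_value_heads, head_dim)
# and returns cu_seqlens_q == cu_seqlens_k == tensor([0, 3, 5]) with a max sequence length of 3 for both.
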
def prepare_fa2_from_position_ids(query, key, value, position_ids):
    """
    This function returns the necessary arguments to call `flash_attn_varlen_func`.
    All three query, key, value states will be flattened.
    Cumulative lengths of each example in the batch will be extracted from position_ids.

    NOTE: ideally the cumulative lengths should be prepared at the data collator stage

    Arguments:
        query (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        position_ids (`torch.Tensor`):
            Position indices of the tokens, of shape (batch_size, sequence_length). A new packed (sub-)sequence is assumed to start wherever the position index resets to 0.

    Return:
        query (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    query = query.view(-1, query.size(-2), query.size(-1))
    key = key.view(-1, key.size(-2), key.size(-1))
    value = value.view(-1, value.size(-2), value.size(-1))
    position_ids = position_ids.flatten()
    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

    cu_seq_lens = torch.cat(
        (
            indices_q[position_ids == 0],
            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
        )
    )

    max_length = position_ids.max() + 1

    return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))

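# Worked example (illustrative only): for a single packed row holding three sub-sequences of
# lengths 3, 2 and 1,
#     position_ids = torch.tensor([[0, 1, 2, 0, 1, 0]])
# `prepare_fa2_from_position_ids` derives
#     cu_seq_lens -> tensor([0, 3, 5, 6], dtype=torch.int32)  # boundaries where the position index resets
#     max_length  -> 3                                        # longest sub-sequence (max position id + 1)
# and flattens query/key/value from (1, 6, num_heads, head_dim) to (6, num_heads, head_dim).
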
def _flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    position_ids: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    softcap: Optional[float] = None,
    deterministic: Optional[bool] = None,
):
    """
    Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
    the input is first unpadded, the attention scores are computed on the unpadded data, and the output is then
    re-padded to the original shape.

    Args:
        query_states (`torch.Tensor`):
            Input query states to be passed to the Flash Attention API
        key_states (`torch.Tensor`):
            Input key states to be passed to the Flash Attention API
        value_states (`torch.Tensor`):
            Input value states to be passed to the Flash Attention API
        attention_mask (`torch.Tensor`):
            The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
            position of padding tokens and 1 for the position of non-padding tokens.
        dropout (`float`):
            Attention dropout
        softmax_scale (`float`, *optional*):
            The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
        use_top_left_mask (`bool`, defaults to `False`):
            flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is a bottom-right
            alignment, which became the default for flash_attn>=2.1. This attribute is used to handle the difference.
        softcap (`float`, *optional*):
            Softcap for the attention logits, used e.g. in gemma2.
        deterministic (`bool`, *optional*):
            Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
    """
    if not use_top_left_mask:
        causal = is_causal
    else:
        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
        causal = is_causal and query_length != 1

    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
    use_sliding_windows = (
        _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
    )
    flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

    if is_flash_attn_greater_or_equal("2.4.1"):
        if deterministic is None:
            deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
        flash_kwargs["deterministic"] = deterministic

    if softcap is not None:
        flash_kwargs["softcap"] = softcap

    # Contains at least one padding token in the sequence
    if attention_mask is not None:
        batch_size = query_states.shape[0]
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
            query_states, key_states, value_states, attention_mask, query_length
        )
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )
        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

    # If position_ids is provided and the positions are not monotonically increasing within each example,
    # the batch most likely contains packed sequences. In that case, and while in the pre-fill/training
    # stage (query_length != 1), use `flash_attn_varlen_func` to prevent cross-example attention and to
    # allow a padding-free approach.
    elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
        batch_size = query_states.size(0)
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
            query_states, key_states, value_states, position_ids
        )

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )

        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))

    else:
        attn_output = flash_attn_func(
            query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
        )

    return attn_output
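

# Hedged usage sketch (an addition for illustration, not part of the upstream transformers module):
# `_flash_attention_forward` dispatches between three paths: the unpad/repad path when a padded
# `attention_mask` is given, the packed-sequence varlen path when non-monotonic `position_ids` are
# given, and plain `flash_attn_func` otherwise. The smoke test below assumes a CUDA device and a
# working flash-attn install; tensors follow the (batch, seq_len, num_heads, head_dim) layout.
if __name__ == "__main__":
    if torch.cuda.is_available() and is_flash_attn_2_available():
        batch_size, seq_len, num_heads, head_dim = 2, 8, 4, 64
        query = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
        key = torch.randn_like(query)
        value = torch.randn_like(query)
        # Right-padded batch: the second sequence has 3 padding positions, so the unpad path is taken.
        attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int32, device="cuda")
        attention_mask[1, -3:] = 0
        out = _flash_attention_forward(
            query, key, value, attention_mask, query_length=seq_len, is_causal=True
        )
        print(out.shape)  # expected: torch.Size([2, 8, 4, 64])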