kernel
danieldk (HF Staff) committed
Commit 77427db · 1 Parent(s): f935a74

Make tests work
tests/__init__.py ADDED
File without changes (empty file)
tests/padding.py ADDED
@@ -0,0 +1,53 @@
+# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+
+def unpad_input(hidden_states, attention_mask, unused_mask=None):
+    """
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+    Return:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
+    """
+    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
+    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
+    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
+    # index with integer indices.
+    return (
+        rearrange(hidden_states, "b s ... -> (b s) ...")[indices],
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+        used_seqlens_in_batch,
+    )
+
+
+def pad_input(hidden_states, indices, batch, seqlen):
+    """
+    Arguments:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+        batch: int, batch size for the padded sequence.
+        seqlen: int, maximum sequence length for the padded sequence.
+    Return:
+        hidden_states: (batch, seqlen, ...)
+    """
+    dim = hidden_states.shape[1:]
+    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
+    output[indices] = hidden_states
+    return rearrange(output, "(b s) ... -> b s ...", b=batch)
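For reference, the two helpers round-trip: unpad_input packs the valid tokens of a padded batch into a flat (total_nnz, ...) tensor plus cu_seqlens bookkeeping, and pad_input scatters such a tensor back to (batch, seqlen, ...). A minimal sketch of that round trip (tensor sizes are illustrative, not taken from the tests):

import torch
from padding import pad_input, unpad_input

batch, seqlen, nheads, d = 2, 8, 4, 16
hidden = torch.randn(batch, seqlen, nheads, d)
mask = torch.zeros(batch, seqlen, dtype=torch.bool)
mask[0, :5] = True  # first sequence: 5 valid tokens
mask[1, :3] = True  # second sequence: 3 valid tokens

# hidden_unpad has shape (total_nnz, nheads, d) with total_nnz = 8; cu_seqlens = [0, 5, 8]
hidden_unpad, indices, cu_seqlens, max_seqlen, seqused = unpad_input(hidden, mask)

# Scatter back to (batch, seqlen, nheads, d); padded positions are zero-filled.
repadded = pad_input(hidden_unpad, indices, batch, seqlen)
assert torch.equal(repadded[mask], hidden[mask])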
tests/test_flash_attn.py CHANGED
@@ -8,10 +8,7 @@ import torch.nn.functional as F
 from torch._C import parse_schema
 
 from einops import rearrange, repeat
-try:
-    from flash_attn.layers.rotary import apply_rotary_emb
-except ImportError:
-    apply_rotary_emb = None
+apply_rotary_emb = None
 
 from padding import pad_input, unpad_input
 from test_util import (
@@ -20,10 +17,10 @@ from test_util import (
     generate_random_padding_mask,
 )
 
-from flash_attn3 import flash_attn_func, flash_attn_varlen_func, flash_attn_combine
-from flash_attn3 import flash_attn_with_kvcache, get_scheduler_metadata
+import kernels
 
-from flash_attn3._ops import ops
+flash_attn3 = kernels.get_kernel("kernels-community/flash-attn3")
+ops = flash_attn3._ops
 
 
 DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE"
@@ -195,7 +192,7 @@ def test_flash_attn_output(
     pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
     num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
     for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
-        out, lse = flash_attn_func(
+        out, lse = flash_attn3.flash_attn_func(
            q,
            k,
            v,
@@ -462,7 +459,7 @@ def test_flash_attn_varlen_output(
     pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
     num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
     for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
-        out_unpad, lse = flash_attn_varlen_func(
+        out_unpad, lse = flash_attn3.flash_attn_varlen_func(
            q_unpad,
            k_unpad,
            v_unpad,
@@ -856,7 +853,7 @@ def test_flash_attn_kvcache(
     precompute_metadata_vals = [False, True]
     for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals):
         if precompute_metadata:
-            scheduler_metadata = get_scheduler_metadata(
+            scheduler_metadata = flash_attn3.get_scheduler_metadata(
                batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d,
                cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad,
@@ -874,7 +871,7 @@ def test_flash_attn_kvcache(
         else:
            k_cache_paged.copy_(k_cache_saved)
            v_cache_paged.copy_(v_cache_saved)
-        out, lse, *rest = flash_attn_with_kvcache(
+        out, lse, *rest = flash_attn3.flash_attn_with_kvcache(
            q if not varlen_q else q_unpad,
            k_cache if page_size is None else k_cache_paged,
            v_cache if page_size is None else v_cache_paged,
@@ -1008,7 +1005,7 @@ def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype):
     k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)
     v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)
     for _ in range(100):
-        flash_attn_func(q, k, v, causal=causal)
+        flash_attn3.flash_attn_func(q, k, v, causal=causal)
 
 
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@@ -1052,7 +1049,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
     k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
     v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
     torch.random.manual_seed(42)
-    out0, lse0 = flash_attn_func(q, k, v, causal=causal)
+    out0, lse0 = flash_attn3.flash_attn_func(q, k, v, causal=causal)
     g = torch.randn_like(out0)
     dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)
     # Numerical error if we just do any arithmetic on dq
@@ -1060,7 +1057,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
 
     for i in range(1000):
         torch.random.manual_seed(42)
-        out, lse = flash_attn_func(q, k, v, causal=causal)
+        out, lse = flash_attn3.flash_attn_func(q, k, v, causal=causal)
         assert torch.equal(out, out0)
         assert torch.equal(lse, lse0)
 
@@ -1111,7 +1108,7 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype):
     lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads]  # To test non-contiguous tensor
     # To test short-circuiting based on num_splits
     lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf")
-    out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype)
+    out, lse = flash_attn3.flash_attn_combine(out_partial, lse_partial, out_dtype=dtype)
     out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial)
     out_pt = out_ref.to(dtype)
 
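The substance of this change: instead of importing a locally built flash_attn3 Python package, the tests now fetch the compiled kernel from the Hub at runtime via the kernels library. A minimal sketch of the resulting call pattern (shapes, dtype, and the CUDA device are illustrative assumptions, not part of the diff):

import torch
import kernels

# Resolve the flash-attn3 kernel from the kernels-community repo on the Hub,
# exactly as the updated test module does at import time.
flash_attn3 = kernels.get_kernel("kernels-community/flash-attn3")

batch, seqlen, nheads, d = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, nheads, d, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seqlen, nheads, d, device="cuda", dtype=torch.bfloat16)

# Same entry point the tests exercise; returns the attention output and the
# log-sum-exp of the attention scores.
out, lse = flash_attn3.flash_attn_func(q, k, v, causal=True)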
tests/test_util.py ADDED
@@ -0,0 +1,348 @@
+import math
+
+import torch
+from einops import rearrange, repeat
+
+from padding import pad_input, unpad_input
+
+
+def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
+    assert mode in ["full", "random", "third"]
+    if mode == "full":
+        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
+    elif mode == "random":
+        lengths = torch.randint(
+            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
+        )
+    elif mode == "third":
+        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
+
+    if zero_lengths:
+        # Generate zero-lengths every 5 batches and the last batch.
+        for i in range(batch_size):
+            if i % 5 == 0:
+                lengths[i] = 0
+        lengths[-1] = 0
+    padding_mask = (
+        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
+    )
+    return padding_mask
+
+
+def generate_qkv(
+    q, k, v, query_padding_mask=None, key_padding_mask=None, qv=None, kvpacked=False, qkvpacked=False,
+    query_unused_mask=None, key_unused_mask=None,
+):
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, d)
+        k: (batch_size, seqlen_k, nheads_k, d)
+        v: (batch_size, seqlen_k, nheads_k, d_v)
+        query_padding_mask: (batch_size, seqlen), bool
+        key_padding_mask: (batch_size, seqlen), bool
+    """
+    assert not (kvpacked and qkvpacked)
+    batch_size, seqlen_q, nheads, d = q.shape
+    d_v = v.shape[-1]
+    _, seqlen_k, nheads_k, _ = k.shape
+    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
+    assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)
+    if query_unused_mask is not None or key_unused_mask is not None:
+        assert not kvpacked
+        assert not qkvpacked
+
+    if query_padding_mask is not None:
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
+            q, query_padding_mask, query_unused_mask
+        )
+        output_pad_fn = lambda output_unpad: pad_input(
+            output_unpad, indices_q, batch_size, seqlen_q
+        )
+        qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None
+    else:
+        q_unpad = rearrange(q, "b s h d -> (b s) h d")
+        cu_seqlens_q = torch.arange(
+            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
+        )
+        seqused_q = None
+        max_seqlen_q = seqlen_q
+        output_pad_fn = lambda output_unpad: rearrange(
+            output_unpad, "(b s) h d -> b s h d", b=batch_size
+        )
+        qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None
+
+    if key_padding_mask is not None:
+        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
+            k, key_padding_mask, key_unused_mask
+        )
+        v_unpad, *rest = unpad_input(v, key_padding_mask, key_unused_mask)
+    else:
+        k_unpad = rearrange(k, "b s h d -> (b s) h d")
+        v_unpad = rearrange(v, "b s h d -> (b s) h d")
+        cu_seqlens_k = torch.arange(
+            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
+        )
+        seqused_k = None
+        max_seqlen_k = seqlen_k
+
+    if qkvpacked:
+        assert (query_padding_mask == key_padding_mask).all()
+        assert nheads == nheads_k
+        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
+        qkv = torch.stack([q, k, v], dim=2)
+        if query_padding_mask is not None:
+            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
+        else:
+            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
+                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
+            )
+        return (
+            qkv_unpad.detach().requires_grad_(),
+            cu_seqlens_q,
+            max_seqlen_q,
+            qkv.detach().requires_grad_(),
+            output_pad_fn,
+            dqkv_pad_fn,
+        )
+    elif kvpacked:
+        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
+        kv = torch.stack([k, v], dim=2)
+        dq_pad_fn = output_pad_fn
+        if key_padding_mask is not None:
+            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
+        else:
+            dkv_pad_fn = lambda dkv_unpad: rearrange(
+                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
+            )
+        return (
+            q_unpad.detach().requires_grad_(),
+            kv_unpad.detach().requires_grad_(),
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            q.detach().requires_grad_(),
+            kv.detach().requires_grad_(),
+            output_pad_fn,
+            dq_pad_fn,
+            dkv_pad_fn,
+        )
+    else:
+        dq_pad_fn = output_pad_fn
+        if key_padding_mask is not None:
+            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
+        else:
+            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
+        return (
+            q_unpad.detach().requires_grad_(),
+            k_unpad.detach().requires_grad_(),
+            v_unpad.detach().requires_grad_(),
+            qv_unpad.detach() if qv is not None else None,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            seqused_q,
+            seqused_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            q.detach().requires_grad_(),
+            k.detach().requires_grad_(),
+            v.detach().requires_grad_(),
+            qv.detach() if qv is not None else None,
+            output_pad_fn,
+            dq_pad_fn,
+            dk_pad_fn,
+        )
+
+
+def construct_local_mask(
+    seqlen_q,
+    seqlen_k,
+    window_size=(-1, -1),  # -1 means infinite window size
+    sink_token_length=0,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    key_leftpad=None,
+    device=None,
+):
+    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
+    sk = (
+        seqlen_k
+        if key_padding_mask is None
+        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    sq = (
+        seqlen_q
+        if query_padding_mask is None
+        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    if window_size[0] < 0:
+        return col_idx > row_idx + sk - sq + window_size[1]
+    else:
+        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+        return torch.logical_or(
+            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
+            torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length),
+        )
+
+
+def construct_chunk_mask(
+    seqlen_q,
+    seqlen_k,
+    attention_chunk,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    key_leftpad=None,
+    device=None,
+):
+    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
+    sk = (
+        seqlen_k
+        if key_padding_mask is None
+        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    sq = (
+        seqlen_q
+        if query_padding_mask is None
+        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+    # Subtract remainder instead of divide and then multiply to take care of negative values
+    col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk
+    return torch.logical_or(
+        col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk
+    )
+
+
+def attention_ref(
+    q,
+    k,
+    v,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    key_leftpad=None,
+    attn_bias=None,
+    dropout_p=0.0,
+    dropout_mask=None,
+    causal=False,
+    qv=None,
+    q_descale=None, k_descale=None, v_descale=None,
+    window_size=(-1, -1),  # -1 means infinite window size
+    attention_chunk=0,
+    sink_token_length=0,
+    softcap=0.0,
+    upcast=True,
+    reorder_ops=False,
+    intermediate_dtype=None,
+):
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        k: (batch_size, seqlen_k, nheads, head_dim)
+        v: (batch_size, seqlen_k, nheads, head_dim_v)
+        qv: (batch_size, seqlen_q, nheads, head_dim_v)
+        query_padding_mask: (batch_size, seqlen_q)
+        key_padding_mask: (batch_size, seqlen_k)
+        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
+        dropout_p: float
+        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
+        causal: whether to apply causal masking
+        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
+            output back to fp16/bf16.
+        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
+            without changing the math. This is to estimate the numerical error from operation
+            reordering.
+    Output:
+        output: (batch_size, seqlen_q, nheads, head_dim_v)
+        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
+    """
+    if causal:
+        window_size = (window_size[0], 0)
+    dtype_og = q.dtype
+    if upcast:
+        q, k, v = q.float(), k.float(), v.float()
+        qv = qv.float() if qv is not None else None
+    if q_descale is not None:
+        q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2])
+        q = (q.float() * q_descale).to(q.dtype)
+        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None
+    if k_descale is not None:
+        k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
+    if v_descale is not None:
+        v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
+    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
+    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    d = q.shape[-1]
+    dv = v.shape[-1]
+    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)
+    if not reorder_ops:
+        scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
+    else:
+        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+    if qv is not None:
+        scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v)
+    if softcap > 0:
+        scores = torch.tanh(scores / softcap) * softcap
+    if key_padding_mask is not None:
+        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+    local_mask = None
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        local_mask = construct_local_mask(
+            seqlen_q,
+            seqlen_k,
+            window_size,
+            sink_token_length,
+            query_padding_mask,
+            key_padding_mask,
+            key_leftpad=key_leftpad,
+            device=q.device,
+        )
+    if attention_chunk > 0:
+        chunk_mask = construct_chunk_mask(
+            seqlen_q,
+            seqlen_k,
+            attention_chunk,
+            query_padding_mask,
+            key_padding_mask,
+            key_leftpad=key_leftpad,
+            device=q.device,
+        )
+        local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask
+    if local_mask is not None:
+        scores.masked_fill_(local_mask, float("-inf"))
+    if attn_bias is not None:
+        scores = scores + attn_bias
+    attention = torch.softmax(scores, dim=-1).to(v.dtype)
+    # We want to mask here so that the attention matrix doesn't have any NaNs
+    # Otherwise we'll get NaN in dV
+    if query_padding_mask is not None:
+        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+    # Without this we might get NaN in dv
+    if key_padding_mask is not None:
+        attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
+    # Some rows might be completely masked out so we fill them with zero instead of NaN
+    if local_mask is not None:
+        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
+    dropout_scaling = 1.0 / (1 - dropout_p)
+    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
+    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
+    if dropout_mask is not None:
+        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
+    else:
+        attention_drop = attention
+    if intermediate_dtype is not None:
+        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
+    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
+    if query_padding_mask is not None:
+        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
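These helpers are the pieces the varlen tests compose: generate_random_padding_mask draws per-sequence lengths, generate_qkv turns padded q/k/v into the unpadded tensors plus cu_seqlens/seqused bookkeeping a varlen kernel call needs (and pad-back functions for outputs and gradients), and attention_ref computes the pure-PyTorch ground truth in fp32. A compact sketch of that flow for the default non-packed path (shapes, dtype, and device are illustrative; the comparison against the kernel output is omitted):

import torch
from test_util import attention_ref, generate_qkv, generate_random_padding_mask

batch, seqlen, nheads, d = 2, 64, 4, 64
device = "cuda"
q = torch.randn(batch, seqlen, nheads, d, device=device, dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, nheads, d, device=device, dtype=torch.bfloat16)
v = torch.randn(batch, seqlen, nheads, d, device=device, dtype=torch.bfloat16)

query_padding_mask = generate_random_padding_mask(seqlen, batch, device, mode="random")
key_padding_mask = generate_random_padding_mask(seqlen, batch, device, mode="random")

# Unpadded tensors + cu_seqlens/seqused for a varlen kernel call, plus functions
# that pad kernel outputs/gradients back to (batch, seqlen, ...).
(
    q_unpad, k_unpad, v_unpad, qv_unpad,
    cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k,
    max_seqlen_q, max_seqlen_k,
    q_ref, k_ref, v_ref, qv_ref,
    output_pad_fn, dq_pad_fn, dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask)

# Reference attention: padded key positions are masked to -inf before the softmax
# and padded query rows are zeroed afterwards, so the result stays NaN-free.
out_ref, attn_ref = attention_ref(
    q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=True
)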