Make tests work
- tests/__init__.py +0 -0
- tests/padding.py +53 -0
- tests/test_flash_attn.py +12 -15
- tests/test_util.py +348 -0

tests/__init__.py
ADDED
File without changes
tests/padding.py
ADDED
@@ -0,0 +1,53 @@

# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

import torch
import torch.nn.functional as F
from einops import rearrange


def unpad_input(hidden_states, attention_mask, unused_mask=None):
    """
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
    """
    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices.
    return (
        rearrange(hidden_states, "b s ... -> (b s) ...")[indices],
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )


def pad_input(hidden_states, indices, batch, seqlen):
    """
    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[1:]
    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
    output[indices] = hidden_states
    return rearrange(output, "(b s) ... -> b s ...", b=batch)
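For reference, the two helpers above are inverses on the valid tokens: unpad_input drops padded positions and pad_input scatters them back, zero-filling the padding. A minimal sketch of that round trip (the tensors, mask, and expected values below are illustrative and not part of the commit; it assumes the snippet runs from the tests/ directory so padding.py is importable):

import torch

from padding import pad_input, unpad_input

batch, seqlen, dim = 2, 4, 8
hidden = torch.randn(batch, seqlen, dim)
# 1 = valid token, 0 = padding; sequence lengths are 3 and 2.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.bool)

unpadded, indices, cu_seqlens, max_seqlen, seqused = unpad_input(hidden, mask)
assert unpadded.shape == (5, dim)        # 3 + 2 valid tokens, packed together
assert cu_seqlens.tolist() == [0, 3, 5]  # cumulative lengths, starting at 0
assert max_seqlen == 3

repadded = pad_input(unpadded, indices, batch, seqlen)
assert torch.equal(repadded[mask], hidden[mask])  # valid tokens survive the round trip
assert torch.equal(repadded[~mask], torch.zeros_like(repadded[~mask]))  # padding is zeroed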
tests/test_flash_attn.py
CHANGED

@@ -8,10 +8,7 @@ import torch.nn.functional as F
 from torch._C import parse_schema
 
 from einops import rearrange, repeat
-try:
-    from flash_attn.layers.rotary import apply_rotary_emb
-except ImportError:
-    apply_rotary_emb = None
+apply_rotary_emb = None
 
 from padding import pad_input, unpad_input
 from test_util import (

@@ -20,10 +17,10 @@ from test_util import (
     generate_random_padding_mask,
 )
 
-from flash_attn3 import flash_attn_with_kvcache, get_scheduler_metadata
+import kernels
 
+flash_attn3 = kernels.get_kernel("kernels-community/flash-attn3")
+ops = flash_attn3._ops
 
 
 DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE"

@@ -195,7 +192,7 @@ def test_flash_attn_output(
     pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
     num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
     for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
-        out, lse = flash_attn_func(
+        out, lse = flash_attn3.flash_attn_func(
             q,
             k,
             v,

@@ -462,7 +459,7 @@ def test_flash_attn_varlen_output(
     pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
     num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
     for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
-        out_unpad, lse = flash_attn_varlen_func(
+        out_unpad, lse = flash_attn3.flash_attn_varlen_func(
             q_unpad,
             k_unpad,
             v_unpad,

@@ -856,7 +853,7 @@ def test_flash_attn_kvcache(
     precompute_metadata_vals = [False, True]
     for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals):
         if precompute_metadata:
-            scheduler_metadata = get_scheduler_metadata(
+            scheduler_metadata = flash_attn3.get_scheduler_metadata(
                 batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d,
                 cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q,
                 cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad,

@@ -874,7 +871,7 @@ def test_flash_attn_kvcache(
         else:
             k_cache_paged.copy_(k_cache_saved)
             v_cache_paged.copy_(v_cache_saved)
-        out, lse, *rest = flash_attn_with_kvcache(
+        out, lse, *rest = flash_attn3.flash_attn_with_kvcache(
             q if not varlen_q else q_unpad,
             k_cache if page_size is None else k_cache_paged,
             v_cache if page_size is None else v_cache_paged,

@@ -1008,7 +1005,7 @@ def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype):
     k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)
     v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)
     for _ in range(100):
-        flash_attn_func(q, k, v, causal=causal)
+        flash_attn3.flash_attn_func(q, k, v, causal=causal)
 
 
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))

@@ -1052,7 +1049,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
     k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
     v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
     torch.random.manual_seed(42)
-    out0, lse0 = flash_attn_func(q, k, v, causal=causal)
+    out0, lse0 = flash_attn3.flash_attn_func(q, k, v, causal=causal)
     g = torch.randn_like(out0)
     dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)
     # Numerical error if we just do any arithmetic on dq

@@ -1060,7 +1057,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
 
     for i in range(1000):
         torch.random.manual_seed(42)
-        out, lse = flash_attn_func(q, k, v, causal=causal)
+        out, lse = flash_attn3.flash_attn_func(q, k, v, causal=causal)
         assert torch.equal(out, out0)
         assert torch.equal(lse, lse0)
 

@@ -1111,7 +1108,7 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype):
     lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads]  # To test non-contiguous tensor
     # To test short-circuiting based on num_splits
     lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf")
-    out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype)
+    out, lse = flash_attn3.flash_attn_combine(out_partial, lse_partial, out_dtype=dtype)
     out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial)
     out_pt = out_ref.to(dtype)
 
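For reference, the import change above replaces a locally installed flash_attn3 package with a compiled kernel fetched from the Hugging Face Hub through the kernels package. A minimal sketch of the same loading pattern outside the test suite (it assumes the kernels package is installed and a GPU supported by the flash-attn3 build is available; the tensor shapes are illustrative):

import torch
import kernels

# Download (or reuse a cached copy of) the compiled kernel from the Hub.
flash_attn3 = kernels.get_kernel("kernels-community/flash-attn3")

# (batch, seqlen, nheads, head_dim) layout, bf16 on the GPU.
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)

# Same call shape the tests use: returns the attention output and the logsumexp.
out, lse = flash_attn3.flash_attn_func(q, k, v, causal=True)
print(out.shape, lse.shape)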
tests/test_util.py
ADDED
@@ -0,0 +1,348 @@

import math

import torch
from einops import rearrange, repeat

from padding import pad_input, unpad_input


def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)

    if zero_lengths:
        # Generate zero-lengths every 5 batches and the last batch.
        for i in range(batch_size):
            if i % 5 == 0:
                lengths[i] = 0
        lengths[-1] = 0
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
    )
    return padding_mask


def generate_qkv(
    q, k, v, query_padding_mask=None, key_padding_mask=None, qv=None, kvpacked=False, qkvpacked=False,
    query_unused_mask=None, key_unused_mask=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d_v)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    d_v = v.shape[-1]
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)
    if query_unused_mask is not None or key_unused_mask is not None:
        assert not kvpacked
        assert not qkvpacked

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
            q, query_padding_mask, query_unused_mask
        )
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
        qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )
        qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
            k, key_padding_mask, key_unused_mask
        )
        v_unpad, *rest = unpad_input(v, key_padding_mask, key_unused_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        seqused_k = None
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            qv_unpad.detach() if qv is not None else None,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            qv.detach() if qv is not None else None,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )


def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    sink_token_length=0,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] < 0:
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length),
        )


def construct_chunk_mask(
    seqlen_q,
    seqlen_k,
    attention_chunk,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
    # Subtract remainder instead of divide and then multiply to take care of negative values
    col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk
    return torch.logical_or(
        col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk
    )


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),  # -1 means infinite window size
    attention_chunk=0,
    sink_token_length=0,
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    intermediate_dtype=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads, head_dim)
        v: (batch_size, seqlen_k, nheads, head_dim_v)
        qv: (batch_size, seqlen_q, nheads, head_dim_v)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim_v)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
        qv = qv.float() if qv is not None else None
    if q_descale is not None:
        q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2])
        q = (q.float() * q_descale).to(q.dtype)
        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None
    if k_descale is not None:
        k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
    if v_descale is not None:
        v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    dv = v.shape[-1]
    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
    if qv is not None:
        scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v)
    if softcap > 0:
        scores = torch.tanh(scores / softcap) * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    local_mask = None
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            sink_token_length,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
    if attention_chunk > 0:
        chunk_mask = construct_chunk_mask(
            seqlen_q,
            seqlen_k,
            attention_chunk,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
        local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask
    if local_mask is not None:
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    # Without this we might get NaN in dv
    if key_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if local_mask is not None:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    if intermediate_dtype is not None:
        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
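For reference, attention_ref is the pure-PyTorch oracle the tests compare the kernel outputs against. On an unpadded, equal-length causal case with no dropout it should agree with torch.nn.functional.scaled_dot_product_attention up to numerical tolerance. A minimal sketch of that sanity check (the shapes and tolerance are illustrative, not part of the commit; it assumes the snippet runs from the tests/ directory so test_util.py is importable):

import torch
import torch.nn.functional as F

from test_util import attention_ref

batch, seqlen, nheads, d = 2, 16, 4, 32
q = torch.randn(batch, seqlen, nheads, d)
k = torch.randn(batch, seqlen, nheads, d)
v = torch.randn(batch, seqlen, nheads, d)

# attention_ref works in (batch, seqlen, nheads, head_dim) layout and returns
# (output, attention_weights); only the output is compared here.
out_ref, _ = attention_ref(q, k, v, causal=True)

# SDPA expects (batch, nheads, seqlen, head_dim), so transpose in and out.
out_sdpa = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)

assert torch.allclose(out_ref, out_sdpa, atol=1e-4)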