Fix ops backward compatibility tests
tests/test_flash_attn.py  (+10 -9)

@@ -20,7 +20,8 @@ from test_util import (
 import kernels
 
 flash_attn3 = kernels.get_kernel("kernels-community/flash-attn3")
-ops = flash_attn3._ops
+ops = flash_attn3._ops.ops
+add_op_namespace_prefix = flash_attn3._ops.add_op_namespace_prefix
 
 
 DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE"
@@ -1135,7 +1136,7 @@ def test_flash3_bw_compatibility() -> None:
     # 1/ Instead of removing arguments, error out if their value is no longer supported
     # 2/ When adding arguments, add them at the end with a default value
     assert ops.fwd.default._schema.is_backward_compatible_with(parse_schema(
-        "fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, "
+        add_op_namespace_prefix("fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, "
         "Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, "
         "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, "
         "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, "
@@ -1146,25 +1147,25 @@ def test_flash3_bw_compatibility() -> None:
         "int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, "
         "Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) "
         "-> (Tensor(out!), Tensor, Tensor, Tensor)"
-    ))
+    )))
     assert ops.bwd.default._schema.is_backward_compatible_with(parse_schema(
-        "bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, "
+        add_op_namespace_prefix("bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, "
         "Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, "
         "Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, "
         "int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, "
         "int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) "
         "-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)"
-    ))
+    )))
     assert ops.fwd_combine.default._schema.is_backward_compatible_with(parse_schema(
-        "fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, "
+        add_op_namespace_prefix("fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, "
         "ScalarType? out_dtype=None) -> (Tensor(out!), Tensor)"
-    ))
+    )))
     assert ops.get_scheduler_metadata.default._schema.is_backward_compatible_with(parse_schema(
-        "get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, "
+        add_op_namespace_prefix("get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, "
         "int num_heads, int num_heads_k, int headdim, int headdim_v, ScalarType qkv_dtype, Tensor seqused_k, "
         "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? cu_seqlens_k_new=None, "
         "Tensor? seqused_q=None, Tensor? leftpad_k=None, int? page_size=None, int max_seqlen_k_new=0, "
         "bool is_causal=False, int window_size_left=-1, int window_size_right=-1, "
         "int attention_chunk=0, bool has_softcap=False, int num_splits=0, bool? pack_gqa=None, "
         "int sm_margin=0) -> Tensor"
-    ))
+    )))
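
Note on the fix: the schema strings in the test are written without an operator namespace, while the ops registered by the kernels package live under a build-specific namespace, so the change imports add_op_namespace_prefix from flash_attn3._ops and wraps each baseline string with it before parse_schema. A minimal sketch of what such a helper might do, assuming a placeholder namespace (the real value and implementation come from the kernel's generated _ops module, not from this sketch):

    # Hypothetical sketch only: the real helper is provided by flash_attn3._ops
    # and uses the kernel's actual, build-specific namespace.
    _OP_NAMESPACE = "_flash_attn3"  # assumed placeholder value

    def add_op_namespace_prefix(schema: str) -> str:
        # Turn "fwd(Tensor q, ...)" into "_flash_attn3::fwd(Tensor q, ...)" so the
        # parsed baseline schema and the registered op share the same qualified name.
        return f"{_OP_NAMESPACE}::{schema}"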
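The two rules spelled out in the test comments (never remove an argument, only append new ones with default values) are exactly what FunctionSchema.is_backward_compatible_with enforces. A small self-contained illustration using PyTorch's schema parser; the demo:: op name and argument list are made up for the example:

    from torch._C import parse_schema

    old = parse_schema("demo::fwd(Tensor q, Tensor k, Tensor v) -> Tensor")

    # Appending an argument with a default keeps the schema backward compatible ...
    new_ok = parse_schema("demo::fwd(Tensor q, Tensor k, Tensor v, int num_splits=0) -> Tensor")
    assert new_ok.is_backward_compatible_with(old)

    # ... while dropping an argument breaks compatibility.
    new_bad = parse_schema("demo::fwd(Tensor q, Tensor k) -> Tensor")
    assert not new_bad.is_backward_compatible_with(old)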