danieldk (HF Staff) committed · Commit 557701f · 1 parent: 745fbe7

Fix ops backward compatibility tests

Files changed (1): tests/test_flash_attn.py (+10, -9)
tests/test_flash_attn.py CHANGED
@@ -20,7 +20,8 @@ from test_util import (
 import kernels

 flash_attn3 = kernels.get_kernel("kernels-community/flash-attn3")
-ops = flash_attn3._ops
+ops = flash_attn3._ops.ops
+add_op_namespace_prefix = flash_attn3._ops.add_op_namespace_prefix


 DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE"
@@ -1135,7 +1136,7 @@ def test_flash3_bw_compatibility() -> None:
     # 1/ Instead of removing arguments, error out if their value is no longer supported
     # 2/ When adding arguments, add them at the end with a default value
     assert ops.fwd.default._schema.is_backward_compatible_with(parse_schema(
-        "flash_attn_3::fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, "
+        add_op_namespace_prefix("fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, "
         "Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, "
         "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, "
         "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, "
@@ -1146,25 +1147,25 @@ def test_flash3_bw_compatibility() -> None:
         "int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, "
         "Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) "
         "-> (Tensor(out!), Tensor, Tensor, Tensor)"
-    ))
+    )))
     assert ops.bwd.default._schema.is_backward_compatible_with(parse_schema(
-        "flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, "
+        add_op_namespace_prefix("bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, "
         "Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, "
         "Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, "
         "int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, "
         "int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) "
         "-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)"
-    ))
+    )))
     assert ops.fwd_combine.default._schema.is_backward_compatible_with(parse_schema(
-        "flash_attn_3::fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, "
+        add_op_namespace_prefix("fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, "
         "ScalarType? out_dtype=None) -> (Tensor(out!), Tensor)"
-    ))
+    )))
     assert ops.get_scheduler_metadata.default._schema.is_backward_compatible_with(parse_schema(
-        "flash_attn_3::get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, "
+        add_op_namespace_prefix("get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, "
         "int num_heads, int num_heads_k, int headdim, int headdim_v, ScalarType qkv_dtype, Tensor seqused_k, "
         "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? cu_seqlens_k_new=None, "
         "Tensor? seqused_q=None, Tensor? leftpad_k=None, int? page_size=None, int max_seqlen_k_new=0, "
         "bool is_causal=False, int window_size_left=-1, int window_size_right=-1, "
         "int attention_chunk=0, bool has_softcap=False, int num_splits=0, bool? pack_gqa=None, "
         "int sm_margin=0) -> Tensor"
-    ))
+    )))
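
For context, the assertions in this diff combine two pieces: the add_op_namespace_prefix helper from the kernel's generated _ops module, which prepends the operator namespace to a bare schema string (presumably because the previously hard-coded flash_attn_3:: prefix no longer matches the namespace the ops are registered under), and PyTorch's parse_schema together with FunctionSchema.is_backward_compatible_with. The sketch below is a minimal, self-contained illustration of that pattern; the "example_ns" namespace and the local add_op_namespace_prefix stand-in are illustrative assumptions, not the kernel's actual generated code.

# Minimal sketch of the backward-compatibility check used in the test.
# NOTE: "example_ns" and this local add_op_namespace_prefix are stand-ins for
# illustration; the real helper comes from flash_attn3._ops.
from torch._C import parse_schema


def add_op_namespace_prefix(schema: str, namespace: str = "example_ns") -> str:
    # Prepend the (build-specific) operator namespace to a bare schema string.
    return f"{namespace}::{schema}"


# Schema that older callers were written against.
old = parse_schema(
    add_op_namespace_prefix("fwd(Tensor q, Tensor k, Tensor v, float softcap=0.) -> Tensor")
)

# Appending an argument with a default keeps the op backward compatible ...
new_ok = parse_schema(
    add_op_namespace_prefix("fwd(Tensor q, Tensor k, Tensor v, float softcap=0., int sm_margin=0) -> Tensor")
)
# ... while removing an argument breaks existing call sites.
new_bad = parse_schema(
    add_op_namespace_prefix("fwd(Tensor q, Tensor k, Tensor v) -> Tensor")
)

assert new_ok.is_backward_compatible_with(old)
assert not new_bad.is_backward_compatible_with(old)

This mirrors the rules stated in the test's comments: never remove arguments (error out on unsupported values instead), and only add new arguments at the end with a default value.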