#include #include "pytorch_shim.h" #include "registration.h" #include "torch_binding.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("fwd(" "Tensor q," "Tensor k," "Tensor v," "Tensor(k_new!)? k_new = None," "Tensor(v_new!)? v_new = None," "Tensor? q_v = None," "Tensor(out!)? out = None," "Tensor? cu_seqlens_q = None," "Tensor? cu_seqlens_k = None," "Tensor? cu_seqlens_k_new = None," "Tensor? seqused_q = None," "Tensor? seqused_k = None," "int? max_seqlen_q = None," "int? max_seqlen_k = None," "Tensor? page_table = None," "Tensor? kv_batch_idx = None," "Tensor? leftpad_k = None," "Tensor? rotary_cos = None," "Tensor? rotary_sin = None," "Tensor? seqlens_rotary = None," "Tensor? q_descale = None," "Tensor? k_descale = None," "Tensor? v_descale = None," "float? softmax_scale = None," "bool is_causal = False," "int window_size_left = -1," "int window_size_right = -1," "int attention_chunk = 0," "float softcap = 0.0," "bool is_rotary_interleaved = False," "Tensor? scheduler_metadata = None," "int num_splits = 0," "bool? pack_gqa = None," "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)"); ops.def("bwd(" "Tensor dout," "Tensor q," "Tensor k," "Tensor v," "Tensor out," "Tensor softmax_lse," "Tensor(dq!)? dq = None," "Tensor(dk!)? dk = None," "Tensor(dv!)? dv = None," "Tensor? cu_seqlens_q = None," "Tensor? cu_seqlens_k = None," "Tensor? seqused_q = None," "Tensor? seqused_k = None," "int? max_seqlen_q = None," "int? max_seqlen_k = None," "float? softmax_scale = None," "bool is_causal = False," "int window_size_left = -1," "int window_size_right = -1," "float softcap = 0.0," "bool deterministic = False," "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)"); ops.def("fwd_combine(" "Tensor out_partial," "Tensor lse_partial," "Tensor(out!)? out = None," "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)"); ops.def("get_scheduler_metadata(" "int batch_size," "int max_seqlen_q," "int max_seqlen_k," "int num_heads," "int num_heads_k," "int headdim," "int headdim_v," "ScalarType qkv_dtype," "Tensor seqused_k," "Tensor? cu_seqlens_q = None," "Tensor? cu_seqlens_k = None," "Tensor? cu_seqlens_k_new = None," "Tensor? seqused_q = None," "Tensor? leftpad_k = None," "int? page_size = None," "int max_seqlen_k_new = 0," "bool is_causal = False," "int window_size_left = -1," "int window_size_right = -1," "int attention_chunk = 0," "bool has_softcap = False," "int num_splits = 0," "bool? pack_gqa = None," "int sm_margin = 0) -> Tensor"); ops.impl("fwd", &mha_fwd); ops.impl("bwd", &mha_bwd); ops.impl("fwd_combine", &mha_combine); ops.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)