|
#include <torch/library.h> |
|
|
|
#include "pytorch_shim.h" |
|
#include "registration.h" |
|
#include "torch_binding.h" |
|
|
|
// Register the attention custom ops with the PyTorch dispatcher under this
// extension's namespace. Each ops.def(...) declares an operator schema
// (adjacent string literals concatenated by the compiler); the matching
// ops.impl(...) calls at the bottom bind each schema to its C++ entry point
// declared in torch_binding.h.
//
// Schema notation (PyTorch operator-schema syntax):
//   "Tensor(x!)"  — alias annotation: the argument may be written in place
//                   and/or aliased by the corresponding output slot.
//   "Tensor? ... = None" — optional tensor argument, absent by default.
// NOTE(review): the semantics of individual arguments (paging, rotary
// embedding, FP8 descales, etc.) live in the kernel implementations, which
// are not visible in this file — names below are documented only as far as
// the schema itself shows.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Forward attention pass. `out` is an optional pre-allocated output
  // (aliased into the first return); `k_new`/`v_new` are mutable optional
  // inputs per their (!) annotations. Returns 4 tensors, the first aliasing
  // `out`.
  ops.def("fwd("
          "Tensor q,"
          "Tensor k,"
          "Tensor v,"
          "Tensor(k_new!)? k_new = None,"
          "Tensor(v_new!)? v_new = None,"
          "Tensor? q_v = None,"
          "Tensor(out!)? out = None,"
          "Tensor? cu_seqlens_q = None,"
          "Tensor? cu_seqlens_k = None,"
          "Tensor? cu_seqlens_k_new = None,"
          "Tensor? seqused_q = None,"
          "Tensor? seqused_k = None,"
          "int? max_seqlen_q = None,"
          "int? max_seqlen_k = None,"
          "Tensor? page_table = None,"
          "Tensor? kv_batch_idx = None,"
          "Tensor? leftpad_k = None,"
          "Tensor? rotary_cos = None,"
          "Tensor? rotary_sin = None,"
          "Tensor? seqlens_rotary = None,"
          "Tensor? q_descale = None,"
          "Tensor? k_descale = None,"
          "Tensor? v_descale = None,"
          "float? softmax_scale = None,"
          "bool is_causal = False,"
          "int window_size_left = -1,"   // -1 presumably means "unlimited window" — confirm in kernel
          "int window_size_right = -1,"
          "int attention_chunk = 0,"
          "float softcap = 0.0,"
          "bool is_rotary_interleaved = False,"
          "Tensor? scheduler_metadata = None,"
          "int num_splits = 0,"
          "bool? pack_gqa = None,"
          "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)");

  // Backward attention pass. `dq`/`dk`/`dv` are optional pre-allocated
  // gradient buffers; their (!) annotations alias them into the first three
  // returns. Returns 8 tensors total.
  ops.def("bwd("
          "Tensor dout,"
          "Tensor q,"
          "Tensor k,"
          "Tensor v,"
          "Tensor out,"
          "Tensor softmax_lse,"
          "Tensor(dq!)? dq = None,"
          "Tensor(dk!)? dk = None,"
          "Tensor(dv!)? dv = None,"
          "Tensor? cu_seqlens_q = None,"
          "Tensor? cu_seqlens_k = None,"
          "Tensor? seqused_q = None,"
          "Tensor? seqused_k = None,"
          "int? max_seqlen_q = None,"
          "int? max_seqlen_k = None,"
          "float? softmax_scale = None,"
          "bool is_causal = False,"
          "int window_size_left = -1,"
          "int window_size_right = -1,"
          "float softcap = 0.0,"
          "bool deterministic = False,"
          "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)");

  // Combine partial forward outputs (e.g. from a split computation) into a
  // final result; `out` optionally supplies the output buffer (aliased into
  // the first return).
  ops.def("fwd_combine("
          "Tensor out_partial,"
          "Tensor lse_partial,"
          "Tensor(out!)? out = None,"
          "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)");

  // Precompute scheduler metadata for `fwd` (returned tensor is intended to
  // be passed back as fwd's `scheduler_metadata` argument, judging by the
  // matching parameter names — confirm against the implementation).
  ops.def("get_scheduler_metadata("
          "int batch_size,"
          "int max_seqlen_q,"
          "int max_seqlen_k,"
          "int num_heads,"
          "int num_heads_k,"
          "int headdim,"
          "int headdim_v,"
          "ScalarType qkv_dtype,"
          "Tensor seqused_k,"
          "Tensor? cu_seqlens_q = None,"
          "Tensor? cu_seqlens_k = None,"
          "Tensor? cu_seqlens_k_new = None,"
          "Tensor? seqused_q = None,"
          "Tensor? leftpad_k = None,"
          "int? page_size = None,"
          "int max_seqlen_k_new = 0,"
          "bool is_causal = False,"
          "int window_size_left = -1,"
          "int window_size_right = -1,"
          "int attention_chunk = 0,"
          "bool has_softcap = False,"
          "int num_splits = 0,"
          "bool? pack_gqa = None,"
          "int sm_margin = 0) -> Tensor");

  // Bind each schema to its implementation (declared in torch_binding.h).
  ops.impl("fwd", &mha_fwd);
  ops.impl("bwd", &mha_bwd);
  ops.impl("fwd_combine", &mha_combine);
  ops.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata);
}
|
|
|
// Emit the Python module init boilerplate for this extension (macro from
// registration.h), keyed on the build-time TORCH_EXTENSION_NAME.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|