#include <torch/library.h>

#include "pytorch_shim.h"
#include "registration.h"
#include "torch_binding.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
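  // fwd: FlashAttention forward pass. The optional arguments cover the
  // extended use cases visible in the schema: in-place KV-cache appends
  // (k_new/v_new), variable-length batches (cu_seqlens_*/seqused_*),
  // paged KV (page_table, kv_batch_idx, leftpad_k), fused rotary
  // embeddings, FP8 descale factors, causal/sliding-window/chunked
  // masking, softcapping, and split-KV scheduling.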
ops.def("fwd("
"Tensor q,"
"Tensor k,"
"Tensor v,"
"Tensor(k_new!)? k_new = None,"
"Tensor(v_new!)? v_new = None,"
"Tensor? q_v = None,"
"Tensor(out!)? out = None,"
"Tensor? cu_seqlens_q = None,"
"Tensor? cu_seqlens_k = None,"
"Tensor? cu_seqlens_k_new = None,"
"Tensor? seqused_q = None,"
"Tensor? seqused_k = None,"
"int? max_seqlen_q = None,"
"int? max_seqlen_k = None,"
"Tensor? page_table = None,"
"Tensor? kv_batch_idx = None,"
"Tensor? leftpad_k = None,"
"Tensor? rotary_cos = None,"
"Tensor? rotary_sin = None,"
"Tensor? seqlens_rotary = None,"
"Tensor? q_descale = None,"
"Tensor? k_descale = None,"
"Tensor? v_descale = None,"
"float? softmax_scale = None,"
"bool is_causal = False,"
"int window_size_left = -1,"
"int window_size_right = -1,"
"int attention_chunk = 0,"
"float softcap = 0.0,"
"bool is_rotary_interleaved = False,"
"Tensor? scheduler_metadata = None,"
"int num_splits = 0,"
"bool? pack_gqa = None,"
"int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)");
ops.def("bwd("
"Tensor dout,"
"Tensor q,"
"Tensor k,"
"Tensor v,"
"Tensor out,"
"Tensor softmax_lse,"
"Tensor(dq!)? dq = None,"
"Tensor(dk!)? dk = None,"
"Tensor(dv!)? dv = None,"
"Tensor? cu_seqlens_q = None,"
"Tensor? cu_seqlens_k = None,"
"Tensor? seqused_q = None,"
"Tensor? seqused_k = None,"
"int? max_seqlen_q = None,"
"int? max_seqlen_k = None,"
"float? softmax_scale = None,"
"bool is_causal = False,"
"int window_size_left = -1,"
"int window_size_right = -1,"
"float softcap = 0.0,"
"bool deterministic = False,"
"int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)");
ops.def("fwd_combine("
"Tensor out_partial,"
"Tensor lse_partial,"
"Tensor(out!)? out = None,"
"ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)");
ops.def("get_scheduler_metadata("
"int batch_size,"
"int max_seqlen_q,"
"int max_seqlen_k,"
"int num_heads,"
"int num_heads_k,"
"int headdim,"
"int headdim_v,"
"ScalarType qkv_dtype,"
"Tensor seqused_k,"
"Tensor? cu_seqlens_q = None,"
"Tensor? cu_seqlens_k = None,"
"Tensor? cu_seqlens_k_new = None,"
"Tensor? seqused_q = None,"
"Tensor? leftpad_k = None,"
"int? page_size = None,"
"int max_seqlen_k_new = 0,"
"bool is_causal = False,"
"int window_size_left = -1,"
"int window_size_right = -1,"
"int attention_chunk = 0,"
"bool has_softcap = False,"
"int num_splits = 0,"
"bool? pack_gqa = None,"
"int sm_margin = 0) -> Tensor");
ops.impl("fwd", &mha_fwd);
ops.impl("bwd", &mha_bwd);
ops.impl("fwd_combine", &mha_combine);
ops.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
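
// Example (illustrative only; the actual op namespace is whatever
// TORCH_EXTENSION_NAME expands to for this build, and the library path
// below is a placeholder):
//
//   import torch
//   torch.ops.load_library("path/to/extension.so")  # if not auto-loaded
//   out, *rest = torch.ops.<namespace>.fwd(q, k, v, is_causal=True)
//
// Per the fwd schema, the call returns four tensors, the first of which
// aliases the (optionally preallocated) `out` argument.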