#include <torch/library.h>

#include "pytorch_shim.h"
#include "registration.h"
#include "torch_binding.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
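  // Forward pass. The schema covers fixed-length batches as well as varlen
  // batches (cu_seqlens_*/seqused_*), paged KV caches (page_table,
  // kv_batch_idx), in-place KV-cache appends (k_new/v_new), rotary
  // embeddings, FP8 descale factors, causal/local/chunked masking,
  // softcapping, and split-KV scheduling (num_splits, scheduler_metadata).
  // Returns the attention output and softmax LSE, plus two auxiliary
  // tensors (split accumulators).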
ops.def("fwd("
"Tensor q,"
"Tensor k,"
"Tensor v,"
"Tensor(k_new!)? k_new = None,"
"Tensor(v_new!)? v_new = None,"
"Tensor? q_v = None,"
"Tensor(out!)? out = None,"
"Tensor? cu_seqlens_q = None,"
"Tensor? cu_seqlens_k = None,"
"Tensor? cu_seqlens_k_new = None,"
"Tensor? seqused_q = None,"
"Tensor? seqused_k = None,"
"int? max_seqlen_q = None,"
"int? max_seqlen_k = None,"
"Tensor? page_table = None,"
"Tensor? kv_batch_idx = None,"
"Tensor? leftpad_k = None,"
"Tensor? rotary_cos = None,"
"Tensor? rotary_sin = None,"
"Tensor? seqlens_rotary = None,"
"Tensor? q_descale = None,"
"Tensor? k_descale = None,"
"Tensor? v_descale = None,"
"float? softmax_scale = None,"
"bool is_causal = False,"
"int window_size_left = -1,"
"int window_size_right = -1,"
"int attention_chunk = 0,"
"float softcap = 0.0,"
"bool is_rotary_interleaved = False,"
"Tensor? scheduler_metadata = None,"
"int num_splits = 0,"
"bool? pack_gqa = None,"
"int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)");
ops.def("bwd("
"Tensor dout,"
"Tensor q,"
"Tensor k,"
"Tensor v,"
"Tensor out,"
"Tensor softmax_lse,"
"Tensor(dq!)? dq = None,"
"Tensor(dk!)? dk = None,"
"Tensor(dv!)? dv = None,"
"Tensor? cu_seqlens_q = None,"
"Tensor? cu_seqlens_k = None,"
"Tensor? seqused_q = None,"
"Tensor? seqused_k = None,"
"int? max_seqlen_q = None,"
"int? max_seqlen_k = None,"
"float? softmax_scale = None,"
"bool is_causal = False,"
"int window_size_left = -1,"
"int window_size_right = -1,"
"float softcap = 0.0,"
"bool deterministic = False,"
"int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)");
ops.def("fwd_combine("
"Tensor out_partial,"
"Tensor lse_partial,"
"Tensor(out!)? out = None,"
"ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)");
ops.def("get_scheduler_metadata("
"int batch_size,"
"int max_seqlen_q,"
"int max_seqlen_k,"
"int num_heads,"
"int num_heads_k,"
"int headdim,"
"int headdim_v,"
"ScalarType qkv_dtype,"
"Tensor seqused_k,"
"Tensor? cu_seqlens_q = None,"
"Tensor? cu_seqlens_k = None,"
"Tensor? cu_seqlens_k_new = None,"
"Tensor? seqused_q = None,"
"Tensor? leftpad_k = None,"
"int? page_size = None,"
"int max_seqlen_k_new = 0,"
"bool is_causal = False,"
"int window_size_left = -1,"
"int window_size_right = -1,"
"int attention_chunk = 0,"
"bool has_softcap = False,"
"int num_splits = 0,"
"bool? pack_gqa = None,"
"int sm_margin = 0) -> Tensor");
ops.impl("fwd", &mha_fwd);
ops.impl("bwd", &mha_bwd);
ops.impl("fwd_combine", &mha_combine);
ops.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
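
// Usage sketch (hypothetical): once the extension is built and loaded, the
// ops are reachable through the dispatcher under the namespace chosen at
// build time via TORCH_EXTENSION_NAME, e.g. from Python:
//   out, lse, *rest = torch.ops.<namespace>.fwd(
//       q, k, v, softmax_scale=q.shape[-1] ** -0.5, is_causal=True)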