#pragma once

#include <torch/library.h>

#include <cmath>
#include <limits>
#include <optional>

/**
 * Unfortunately, the type signatures of the flash_attn ops are not compatible
 * with the PyTorch library bindings. To get around that we use
 * `make_pytorch_shim`, which creates a lambda that exposes the API using
 * PyTorch-compatible types and then converts the arguments to the types
 * expected by the flash_attn ops. This shim allows us to make minimal changes
 * to `flash_api.cpp`, making it easier to synchronize with upstream changes.
 *
 * The `pytorch_library_compatible_type` struct is used to map from the
 * flash_attn ops types to a PyTorch library compatible one. The main issue is
 * that the following types are not supported by the PyTorch library bindings:
 *  - `int`
 *  - `float`
 *  - `std::optional<T> &`
 *  - `std::optional<const at::Tensor> &`
 * So we convert them to (respectively):
 *  - `int64_t`
 *  - `double`
 *  - `const std::optional<T>&`
 *  - `const std::optional<at::Tensor>&`
 */

template <typename T>
struct pytorch_library_compatible_type {
  using type = T;
  static T convert_from_type(T arg) { return arg; }
};

template <typename T>
using pytorch_library_compatible_type_t =
    typename pytorch_library_compatible_type<T>::type;

template <typename T>
T convert_from_pytorch_compatible_type(
    pytorch_library_compatible_type_t<T> arg) {
  return pytorch_library_compatible_type<T>::convert_from_type(arg);
}

// Map `std::optional<T> &` -> `const std::optional<T>&`
//  (NOTE: this is a bit unsafe, but none of the ops in flash_attn mutate
//   the optional container)
template <typename T>
struct pytorch_library_compatible_type<std::optional<T> &> {
  using type = const std::optional<T>&;
  static std::optional<T>& convert_from_type(const std::optional<T> &arg) {
    return const_cast<std::optional<T>&>(arg);
  }
};

// Map `std::optional<T>` ->
//  `std::optional<pytorch_library_compatible_type_t<T>>`
//  (NOTE: tested for `std::optional<int>` -> `std::optional<int64_t>`)
template <typename T>
struct pytorch_library_compatible_type<std::optional<T>> {
  using type = std::optional<pytorch_library_compatible_type_t<T>>;
  static std::optional<pytorch_library_compatible_type_t<T>> convert_from_type(
      std::optional<T> arg) {
    return arg;
  }
};

// Map `std::optional<const at::Tensor>&` -> `const std::optional<at::Tensor>&`
template <>
struct pytorch_library_compatible_type<std::optional<const at::Tensor> &> {
  using type = const std::optional<at::Tensor>&;
  static std::optional<const at::Tensor>& convert_from_type(
      const std::optional<at::Tensor> &arg) {
    return const_cast<std::optional<const at::Tensor>&>(
        reinterpret_cast<const std::optional<const at::Tensor>&>(arg));
  }
};

// Map `int` -> `int64_t`
template <>
struct pytorch_library_compatible_type<int> {
  using type = int64_t;
  static int convert_from_type(int64_t arg) {
    TORCH_CHECK(arg <= std::numeric_limits<int>::max(),
                "int64_t value is too large to be converted to int");
    TORCH_CHECK(arg >= std::numeric_limits<int>::min(),
                "int64_t value is too small to be converted to int");
    return arg;
  }
};

// Map `float` -> `double`
template <>
struct pytorch_library_compatible_type<float> {
  using type = double;
  static float convert_from_type(double arg) {
    TORCH_CHECK(std::abs(arg) <= std::numeric_limits<float>::max(),
                "double value is too large to be converted to float");
    return arg;
  }
};

//
//  Shim Utils
//

template <typename Ret, typename... Args>
auto make_pytorch_shim(Ret (*fun)(Args... args)) {
  return [fun](pytorch_library_compatible_type_t<Args>... args) {
    return fun(convert_from_pytorch_compatible_type<Args>(args)...);
  };
}
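
//
//  Example
//

// A minimal usage sketch (illustrative only): `my_attn_fwd` and the
// `my_flash` library name below are hypothetical and not part of flash_attn.
// The op is declared with shim-unfriendly parameter types; `make_pytorch_shim`
// exposes a dispatcher-friendly signature
// (`at::Tensor, const std::optional<at::Tensor>&, int64_t, double`) and
// converts the arguments back before calling the op.
//
//   at::Tensor my_attn_fwd(at::Tensor q, std::optional<at::Tensor>& bias,
//                          int num_splits, float softmax_scale);
//
//   TORCH_LIBRARY(my_flash, m) {
//     m.def(
//         "fwd(Tensor q, Tensor? bias, int num_splits, float softmax_scale)"
//         " -> Tensor");
//     m.impl("fwd", torch::kCUDA, make_pytorch_shim(&my_attn_fwd));
//   }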