#include #include // CUDA forward declarations std::vector projective_transform_cuda( torch::Tensor poses, torch::Tensor disps, torch::Tensor intrinsics, torch::Tensor ii, torch::Tensor jj); torch::Tensor depth_filter_cuda( torch::Tensor poses, torch::Tensor disps, torch::Tensor intrinsics, torch::Tensor ix, torch::Tensor thresh); torch::Tensor frame_distance_cuda( torch::Tensor poses, torch::Tensor disps, torch::Tensor intrinsics, torch::Tensor ii, torch::Tensor jj, const float beta); std::vector projmap_cuda( torch::Tensor poses, torch::Tensor disps, torch::Tensor intrinsics, torch::Tensor ii, torch::Tensor jj); torch::Tensor iproj_cuda( torch::Tensor poses, torch::Tensor disps, torch::Tensor intrinsics); std::vector ba_cuda( torch::Tensor poses, torch::Tensor disps, torch::Tensor intrinsics, torch::Tensor disps_sens, torch::Tensor targets, torch::Tensor weights, torch::Tensor eta, torch::Tensor ii, torch::Tensor jj, const int t0, const int t1, const int iterations, const float lm, const float ep, const bool motion_only); std::vector corr_index_cuda_forward( torch::Tensor volume, torch::Tensor coords, int radius); std::vector corr_index_cuda_backward( torch::Tensor volume, torch::Tensor coords, torch::Tensor corr_grad, int radius); std::vector altcorr_cuda_forward( torch::Tensor fmap1, torch::Tensor fmap2, torch::Tensor coords, int radius); std::vector altcorr_cuda_backward( torch::Tensor fmap1, torch::Tensor fmap2, torch::Tensor coords, torch::Tensor corr_grad, int radius); #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CONTIGUOUS(x) std::vector ba( torch::Tensor poses, torch::Tensor disps, torch::Tensor intrinsics, torch::Tensor disps_sens, torch::Tensor targets, torch::Tensor weights, torch::Tensor eta, torch::Tensor ii, torch::Tensor jj, const int t0, const int t1, const int iterations, const float lm, const float ep, const bool motion_only) { CHECK_INPUT(targets); CHECK_INPUT(weights); CHECK_INPUT(poses); CHECK_INPUT(disps); CHECK_INPUT(intrinsics); CHECK_INPUT(disps_sens); CHECK_INPUT(ii); CHECK_INPUT(jj); return ba_cuda(poses, disps, intrinsics, disps_sens, targets, weights, eta, ii, jj, t0, t1, iterations, lm, ep, motion_only); } torch::Tensor frame_distance( torch::Tensor poses, torch::Tensor disps, torch::Tensor intrinsics, torch::Tensor ii, torch::Tensor jj, const float beta) { CHECK_INPUT(poses); CHECK_INPUT(disps); CHECK_INPUT(intrinsics); CHECK_INPUT(ii); CHECK_INPUT(jj); return frame_distance_cuda(poses, disps, intrinsics, ii, jj, beta); } std::vector projmap( torch::Tensor poses, torch::Tensor disps, torch::Tensor intrinsics, torch::Tensor ii, torch::Tensor jj) { CHECK_INPUT(poses); CHECK_INPUT(disps); CHECK_INPUT(intrinsics); CHECK_INPUT(ii); CHECK_INPUT(jj); return projmap_cuda(poses, disps, intrinsics, ii, jj); } torch::Tensor iproj( torch::Tensor poses, torch::Tensor disps, torch::Tensor intrinsics) { CHECK_INPUT(poses); CHECK_INPUT(disps); CHECK_INPUT(intrinsics); return iproj_cuda(poses, disps, intrinsics); } // c++ python binding std::vector corr_index_forward( torch::Tensor volume, torch::Tensor coords, int radius) { CHECK_INPUT(volume); CHECK_INPUT(coords); return corr_index_cuda_forward(volume, coords, radius); } std::vector corr_index_backward( torch::Tensor volume, torch::Tensor coords, torch::Tensor corr_grad, int radius) { CHECK_INPUT(volume); CHECK_INPUT(coords); CHECK_INPUT(corr_grad); auto volume_grad = corr_index_cuda_backward(volume, coords, corr_grad, radius); return {volume_grad}; } std::vector altcorr_forward( torch::Tensor fmap1, torch::Tensor fmap2, torch::Tensor coords, int radius) { CHECK_INPUT(fmap1); CHECK_INPUT(fmap2); CHECK_INPUT(coords); return altcorr_cuda_forward(fmap1, fmap2, coords, radius); } std::vector altcorr_backward( torch::Tensor fmap1, torch::Tensor fmap2, torch::Tensor coords, torch::Tensor corr_grad, int radius) { CHECK_INPUT(fmap1); CHECK_INPUT(fmap2); CHECK_INPUT(coords); CHECK_INPUT(corr_grad); return altcorr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); } torch::Tensor depth_filter( torch::Tensor poses, torch::Tensor disps, torch::Tensor intrinsics, torch::Tensor ix, torch::Tensor thresh) { CHECK_INPUT(poses); CHECK_INPUT(disps); CHECK_INPUT(intrinsics); CHECK_INPUT(ix); CHECK_INPUT(thresh); return depth_filter_cuda(poses, disps, intrinsics, ix, thresh); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // bundle adjustment kernels m.def("ba", &ba, "bundle adjustment"); m.def("frame_distance", &frame_distance, "frame_distance"); m.def("projmap", &projmap, "projmap"); m.def("depth_filter", &depth_filter, "depth_filter"); m.def("iproj", &iproj, "back projection"); // correlation volume kernels m.def("altcorr_forward", &altcorr_forward, "ALTCORR forward"); m.def("altcorr_backward", &altcorr_backward, "ALTCORR backward"); m.def("corr_index_forward", &corr_index_forward, "INDEX forward"); m.def("corr_index_backward", &corr_index_backward, "INDEX backward"); }