File size: 10,763 Bytes
6fc683c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 |
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> // @manual=//caffe2/aten:ATen-cu
#include <cuda_runtime.h>
#include <algorithm> // std::min/max
#include <cub/cub.cuh>
#include "alignment_train_cuda.h"
#include "utils.h"
namespace {
// Threads per block along X for the 2-D launches of the elementwise kernels
// (oneMinusPKernel, clampKernel).
constexpr int BLOCK_DIM_X{128};
// Threads per block along Y for the same 2-D launches.
constexpr int BLOCK_DIM_Y{8};
// Threads per block for the kernels built on cub::BlockScan; also passed to
// them as the TPB template argument.
constexpr int SCAN_BLOCK{512};
// Checks the status of a CUDA runtime call (or of cudaGetLastError() right
// after a kernel launch) and reports the call site on failure.
// The do { } while (0) wrapper makes the expansion a single statement, so the
// macro is safe inside unbraced if/else bodies.
#define gpuErrchk(ans)                    \
  do {                                    \
    gpuAssert((ans), __FILE__, __LINE__); \
  } while (0)

// Reports a failed CUDA call.
//
// When `code` is not cudaSuccess, prints the CUDA error string together with
// `file` and `line` to stderr. If `abort` is true (the default) the process
// is then terminated with `code` as its exit status; otherwise execution
// continues after the message.
inline void
gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) {
  if (code == cudaSuccess) {
    return;
  }
  fprintf(
      stderr,
      "\nGPUassert: %s %s %d\n",
      cudaGetErrorString(code),
      file,
      line);
  if (abort) {
    exit(code);
  }
}
// Binary multiplication functor; used in this file as the scan operator
// handed to cub::BlockScan.
template <typename T>
struct Prod {
  /// Multiplies the operands and returns <tt>lhs * rhs</tt>.
  __host__ __device__ __forceinline__ T
  operator()(const T& lhs, const T& rhs) const {
    const T product = lhs * rhs;
    return product;
  }
};
// Stateful prefix functor for a tiled block-wide product scan: it carries the
// product of all tiles scanned so far from one BlockScan call to the next.
template <typename T>
struct BlockPrefixProdCallbackOp {
  // Product accumulated over every tile already processed.
  T running_total;

  // Seeds the accumulator with `initial`.
  __device__ BlockPrefixProdCallbackOp(T initial) : running_total(initial) {}

  // Invoked by the first warp of the block during a scan. Hands back the
  // product accumulated before this tile (thread 0's return value seeds the
  // block-wide scan) and then folds this tile's aggregate into the
  // accumulator.
  __device__ T operator()(const T block_aggregate) {
    const T seed = running_total;
    running_total *= block_aggregate;
    return seed;
  }
};
// Stateful prefix functor for a tiled block-wide sum scan: it carries the sum
// of all tiles scanned so far from one BlockScan call to the next.
template <typename T>
struct BlockPrefixSumCallbackOp {
  // Sum accumulated over every tile already processed.
  T running_total;

  // Seeds the accumulator with `initial`.
  __device__ BlockPrefixSumCallbackOp(T initial) : running_total(initial) {}

  // Invoked by the first warp of the block during a scan. Hands back the sum
  // accumulated before this tile (thread 0's return value seeds the
  // block-wide scan) and then adds this tile's aggregate to the accumulator.
  __device__ T operator()(const T block_aggregate) {
    const T seed = running_total;
    running_total += block_aggregate;
    return seed;
  }
};
// Elementwise complement: cumprod_1mp[b][tgt][src] = 1 - p_choose[b][tgt][src].
//
// Both buffers are indexed as dense row-major (bsz, tgt_len, src_len) arrays.
// Work split: blocks stride over the batch (blockIdx.x / gridDim.x),
// threadIdx.y strides over tgt and threadIdx.x strides over src, so any 1-D
// grid with a 2-D block visits every element.
//
// Flat offsets are computed in size_t so they do not wrap when
// bsz * tgt_len * src_len exceeds the 32-bit range.
template <typename T>
__global__ void oneMinusPKernel(
    const T* __restrict__ p_choose,
    T* __restrict__ cumprod_1mp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len) {
  const size_t row_stride = src_len;
  const size_t batch_stride = static_cast<size_t>(tgt_len) * src_len;
  for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
    const size_t batch_base = b * batch_stride;
    for (uint32_t tgt = threadIdx.y; tgt < tgt_len; tgt += blockDim.y) {
      const size_t row_base = batch_base + tgt * row_stride;
      for (uint32_t src = threadIdx.x; src < src_len; src += blockDim.x) {
        const size_t idx = row_base + src;
        cumprod_1mp[idx] = 1 - p_choose[idx];
      }
    }
  }
}
// In-place exclusive cumulative product along the innermost (src) dimension
// of a dense row-major (bsz, tgt_len, src_len) buffer:
//   out[b][tgt][0] = 1
//   out[b][tgt][s] = in[b][tgt][0] * ... * in[b][tgt][s - 1]
//
// Each thread block processes one (b, tgt) row at a time, striding over tgt
// with blockIdx.x / gridDim.x and over the batch with blockIdx.y / gridDim.y.
// A row is consumed in tiles of blockDim.x elements; the product carried from
// one tile to the next lives in prefix_op.
//
// Precondition: the kernel must be launched with blockDim.x == TPB, because
// cub::BlockScan is specialized for exactly TPB threads.
//
// Flat offsets are computed in size_t so they do not wrap when
// bsz * tgt_len * src_len exceeds the 32-bit range.
template <typename T, int TPB>
__global__ void innermostScanKernel(
    T* __restrict__ cumprod_1mp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len) {
  // Specialize BlockScan for a 1D block of TPB threads on type T
  typedef cub::BlockScan<T, TPB> BlockScan;
  // Shared scratch space for BlockScan, reused by every tile and every row
  __shared__ typename BlockScan::TempStorage temp_storage;

  const size_t batch_stride = static_cast<size_t>(tgt_len) * src_len;
  const uint32_t tid = threadIdx.x;
  for (uint32_t b = blockIdx.y; b < bsz; b += gridDim.y) {
    for (uint32_t tgt = blockIdx.x; tgt < tgt_len; tgt += gridDim.x) {
      const size_t row_base =
          b * batch_stride + static_cast<size_t>(tgt) * src_len;
      // Running product restarts at 1 for every row
      BlockPrefixProdCallbackOp<T> prefix_op(1);
      for (uint32_t block_src = 0; block_src < src_len;
           block_src += blockDim.x) {
        const uint32_t src = block_src + tid;
        const bool in_range = src < src_len;
        const size_t idx = row_base + src;
        // Lanes past the end of the row contribute the multiplicative
        // identity. They only occur in the last tile, after every valid
        // lane, and their results are never written back.
        T thread_data = in_range ? cumprod_1mp[idx] : (T)1;
        // Collectively compute the block-wide exclusive prefix product,
        // seeded with the product of all previous tiles of this row
        BlockScan(temp_storage)
            .ExclusiveScan(thread_data, thread_data, Prod<T>(), prefix_op);
        // temp_storage is reused by the next tile / row
        __syncthreads();
        // write the scanned value to output
        if (in_range) {
          cumprod_1mp[idx] = thread_data;
        }
      }
    }
  }
}
// Elementwise clamp of cumprod_1mp into [min_val, max_val], written to
// cumprod_1mp_clamp. Values that compare neither below min_val nor above
// max_val are copied through unchanged.
//
// Both buffers are indexed as dense row-major (bsz, tgt_len, src_len) arrays.
// Work split: blocks stride over the batch (blockIdx.x / gridDim.x),
// threadIdx.y strides over tgt and threadIdx.x strides over src.
//
// Flat offsets are computed in size_t so they do not wrap when
// bsz * tgt_len * src_len exceeds the 32-bit range.
template <typename T>
__global__ void clampKernel(
    const T* __restrict__ cumprod_1mp,
    T* __restrict__ cumprod_1mp_clamp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    T min_val,
    T max_val) {
  const size_t row_stride = src_len;
  const size_t batch_stride = static_cast<size_t>(tgt_len) * src_len;
  for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
    const size_t batch_base = b * batch_stride;
    for (uint32_t tgt = threadIdx.y; tgt < tgt_len; tgt += blockDim.y) {
      const size_t row_base = batch_base + tgt * row_stride;
      for (uint32_t src = threadIdx.x; src < src_len; src += blockDim.x) {
        const size_t idx = row_base + src;
        // Read the input once; the lower bound is tested first, as before.
        const T value = cumprod_1mp[idx];
        if (value < min_val) {
          cumprod_1mp_clamp[idx] = min_val;
        } else if (value > max_val) {
          cumprod_1mp_clamp[idx] = max_val;
        } else {
          cumprod_1mp_clamp[idx] = value;
        }
      }
    }
  }
}
// Sets alpha[b][0][0] = 1 for every batch entry b of a dense row-major
// (bsz, tgt_len, src_len) buffer; no other element is touched.
//
// Blocks stride over the batch with blockIdx.x / gridDim.x. The body does not
// use threadIdx, so every thread of a block performs the same writes.
//
// The batch offset is computed in size_t so it does not wrap when
// bsz * tgt_len * src_len exceeds the 32-bit range.
template <typename T>
__global__ void initAlphaCUDAKernel(
    T* alpha,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len) {
  const size_t batch_stride = static_cast<size_t>(tgt_len) * src_len;
  // alpha[:, 0, 0] = 1.0
  for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
    alpha[b * batch_stride] = (T)1.0;
  }
}
// Computes one target step `tgt` of alpha for every batch entry:
//
//   alpha[b][tgt][s] = clamp(
//       p_choose[b][tgt][s] * cumprod_1mp[b][tgt][s] *
//           sum_{k <= s} alpha[b][prev][k] / cumprod_1mp_clamp[b][tgt][k],
//       0, 1)
//
// where prev = tgt - 1, or 0 when tgt == 0 (the row then reads its own
// current contents). All buffers are indexed as dense row-major
// (bsz, tgt_len, src_len) arrays.
//
// Each thread block handles one batch entry at a time, striding over the
// batch with blockIdx.x / gridDim.x. The src dimension is consumed in tiles
// of blockDim.x elements; the sum carried between tiles lives in prefix_op.
//
// Precondition: the kernel must be launched with blockDim.x == TPB, because
// cub::BlockScan is specialized for exactly TPB threads.
//
// Flat offsets are computed in size_t so they do not wrap when
// bsz * tgt_len * src_len exceeds the 32-bit range.
template <typename T, int TPB>
__global__ void alignmentTrainCUDAKernel(
    const T* __restrict__ p_choose,
    const T* __restrict__ cumprod_1mp,
    const T* __restrict__ cumprod_1mp_clamp,
    T* __restrict__ alpha,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    uint32_t tgt) {
  // Specialize BlockScan for a 1D block of TPB threads on type T
  typedef cub::BlockScan<T, TPB> BlockScan;
  // Shared scratch space for BlockScan, reused by every tile and batch entry
  __shared__ typename BlockScan::TempStorage temp_storage;

  const size_t batch_stride = static_cast<size_t>(tgt_len) * src_len;
  // Row of alpha that feeds the scan: [tgt - 1], or [0] for the first step
  const uint32_t prev_tgt = (tgt == 0) ? 0 : tgt - 1;
  const uint32_t tid = threadIdx.x;

  for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
    // Running sum restarts at 0 for every batch entry
    BlockPrefixSumCallbackOp<T> prefix_op(0);
    const size_t b_offset = b * batch_stride;
    const size_t prev_row = b_offset + static_cast<size_t>(prev_tgt) * src_len;
    const size_t cur_row = b_offset + static_cast<size_t>(tgt) * src_len;
    for (uint32_t block_src = 0; block_src < src_len; block_src += blockDim.x) {
      const uint32_t src = block_src + tid;
      const bool in_range = src < src_len;
      // alpha is read at [b][prev_tgt][src]; everything else at [b][tgt][src]
      const size_t alpha_idx = prev_row + src;
      const size_t inout_idx = cur_row + src;
      // Lanes past the end of the row contribute 0 to the sum
      T thread_data = (T)0;
      if (in_range) {
        thread_data = alpha[alpha_idx] / cumprod_1mp_clamp[inout_idx];
      }
      // Collectively compute the block-wide inclusive prefix sum, seeded with
      // the sum of all previous tiles of this row
      BlockScan(temp_storage).InclusiveSum(thread_data, thread_data, prefix_op);
      // temp_storage is reused by the next tile / batch entry
      __syncthreads();
      if (in_range) {
        T out = thread_data * p_choose[inout_idx] * cumprod_1mp[inout_idx];
        // Clamps all elements into the range [ 0, 1.0 ]
        alpha[inout_idx] = std::min<T>(std::max<T>(out, 0), (T)1.0);
      }
    }
  }
}
// Fills cumprod_1mp with the exclusive cumulative product of (1 - p_choose)
// along the innermost (src) dimension, using two launches on `stream`:
//   1. oneMinusPKernel:     cumprod_1mp = 1 - p_choose
//   2. innermostScanKernel: in-place exclusive product scan of each row
//
// p_choose / cumprod_1mp: device buffers of bsz * tgt_len * src_len elements.
// max_grid_x / max_grid_y: upper bounds applied to the grid dimensions.
//
// Launch errors are reported through gpuErrchk. bsz and tgt_len are used as
// grid dimensions, so both must be non-zero.
template <typename T>
void exclusiveCumprod(
    const T* p_choose,
    T* cumprod_1mp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    uint32_t max_grid_x,
    uint32_t max_grid_y,
    cudaStream_t& stream) {
  // Grid sizes are plain integers: compute them in uint32_t rather than in the
  // tensor element type T, which (for half / bfloat16 in particular) cannot
  // represent the device limits or large sizes exactly.
  const uint32_t batch_blocks_x = std::min<uint32_t>(max_grid_x, bsz);
  const uint32_t tgt_blocks_x = std::min<uint32_t>(max_grid_x, tgt_len);
  const uint32_t batch_blocks_y = std::min<uint32_t>(max_grid_y, bsz);

  // cumprod_1mp = 1 - p_choose
  dim3 grid(batch_blocks_x, 1, 1);
  dim3 block(BLOCK_DIM_X, BLOCK_DIM_Y, 1);
  oneMinusPKernel<T><<<grid, block, 0, stream>>>(
      p_choose, cumprod_1mp, bsz, tgt_len, src_len);
  gpuErrchk(cudaGetLastError());

  // scan on the innermost dimension of cumprod_1mp
  // cumprod_1mp = exclusive_cumprod(cumprod_1mp)
  dim3 grid_scan(tgt_blocks_x, batch_blocks_y, 1);
  innermostScanKernel<T, SCAN_BLOCK><<<grid_scan, SCAN_BLOCK, 0, stream>>>(
      cumprod_1mp, bsz, tgt_len, src_len);
  gpuErrchk(cudaGetLastError());
}
// Computes alpha from p_choose on the current CUDA stream.
//
// p_choose: device buffer, (bsz, tgt_len, src_len), read only.
// alpha:    device buffer, (bsz, tgt_len, src_len), written in place.
//           alpha[:, 0, 0] is set to 1 here; the first step also reads the
//           rest of row 0 as provided by the caller.
//           NOTE(review): presumably the caller zero-initializes alpha;
//           confirm against the call site.
// eps:      lower bound used when clamping the cumulative product that
//           appears as a divisor.
//
// Two temporary device buffers of the same size are allocated with cudaMalloc
// and released with cudaFree before returning. Any CUDA error is reported via
// gpuErrchk. If any dimension is zero the function returns without touching
// alpha or launching anything.
template <typename T>
void alignmentTrainCUDAImpl(
    const T* p_choose,
    T* alpha,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    float eps) {
  // p_choose: bsz , tgt_len, src_len
  // cumprod_1mp: bsz , tgt_len, src_len
  // cumprod_1mp_clamp : bsz, tgt_len, src_len
  // alpha: bsz, tgt_len, src_len

  // Nothing to compute for an empty tensor. Bail out before a zero-sized grid
  // is launched or alpha[0] of an empty buffer is written.
  if (bsz == 0 || tgt_len == 0 || src_len == 0) {
    return;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const auto* device_props = at::cuda::getCurrentDeviceProperties();
  const uint32_t max_grid_x = device_props->maxGridSize[0];
  const uint32_t max_grid_y = device_props->maxGridSize[1];
  // One block per batch entry, capped at the device limit. Computed in
  // uint32_t, not in the element type T, so the value is exact.
  const uint32_t batch_blocks = std::min<uint32_t>(max_grid_x, bsz);

  // Implementing exclusive cumprod.
  // cumprod_1mp = cumprod(1 - p_choose)
  // There is cumprod in pytorch, however there is no exclusive mode.
  // cumprod(x) = [x1, x1x2, x1x2x3, ..., prod_{i=1}^n x_i]
  // exclusive means
  // cumprod(x) = [1, x1, x1x2, x1x2x3, ..., prod_{i=1}^{n-1} x_i]

  // 64-bit element count so the allocation size cannot wrap around.
  const size_t elements = static_cast<size_t>(bsz) * tgt_len * src_len;
  T* cumprod_1mp = nullptr;
  gpuErrchk(cudaMalloc(&cumprod_1mp, elements * sizeof(T)));
  exclusiveCumprod<T>(
      p_choose,
      cumprod_1mp,
      bsz,
      tgt_len,
      src_len,
      max_grid_x,
      max_grid_y,
      stream);

  // clamp cumprod_1mp to the range [eps, 1.0]
  T* cumprod_1mp_clamp = nullptr;
  gpuErrchk(cudaMalloc(&cumprod_1mp_clamp, elements * sizeof(T)));
  dim3 grid_clamp(batch_blocks, 1, 1);
  dim3 block_clamp(BLOCK_DIM_X, BLOCK_DIM_Y, 1);
  clampKernel<T><<<grid_clamp, block_clamp, 0, stream>>>(
      cumprod_1mp, cumprod_1mp_clamp, bsz, tgt_len, src_len, (T)eps, (T)1.0);
  gpuErrchk(cudaGetLastError());

  // ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi))
  dim3 grid_init(batch_blocks, 1, 1);
  initAlphaCUDAKernel<T>
      <<<grid_init, 1, 0, stream>>>(alpha, bsz, tgt_len, src_len);
  gpuErrchk(cudaGetLastError());

  // Step i reads row i - 1 of alpha, so target steps are launched one after
  // another on the same stream.
  for (uint32_t i = 0; i < tgt_len; i++) {
    alignmentTrainCUDAKernel<T, SCAN_BLOCK>
        <<<batch_blocks, SCAN_BLOCK, 0, stream>>>(
            p_choose,
            cumprod_1mp,
            cumprod_1mp_clamp,
            alpha,
            bsz,
            tgt_len,
            src_len,
            i);
    gpuErrchk(cudaGetLastError());
  }

  gpuErrchk(cudaFree(cumprod_1mp));
  gpuErrchk(cudaFree(cumprod_1mp_clamp));
}
} // namespace
// Host entry point: dispatches alignmentTrainCUDAImpl on p_choose's dtype
// (float, double, Half or BFloat16).
//
// p_choose: tensor whose first three sizes are read as (bsz, tgt_len, src_len).
// alpha:    output tensor; its data is accessed with p_choose's scalar type.
//           NOTE(review): presumably same shape as p_choose, contiguous and on
//           the same device; confirm against the caller's input checks.
// eps:      forwarded unchanged to alignmentTrainCUDAImpl.
//
// The active CUDA device is switched to p_choose's device and is not restored
// afterwards.
void alignmentTrainCUDAWrapper(
    const torch::Tensor& p_choose,
    torch::Tensor& alpha,
    float eps) {
  // p_choose dimension: bsz, tgt_len, src_len
  uint32_t bsz = p_choose.size(0);
  uint32_t tgt_len = p_choose.size(1);
  uint32_t src_len = p_choose.size(2);

  // Check the return code instead of dropping it: a failed device switch
  // would otherwise only surface later as an unrelated-looking error.
  gpuErrchk(cudaSetDevice(p_choose.get_device()));

  AT_DISPATCH_FLOATING_TYPES_AND2(
      torch::ScalarType::Half,
      torch::ScalarType::BFloat16,
      p_choose.scalar_type(),
      "alignmentTrainCUDAImpl",
      [&]() {
        alignmentTrainCUDAImpl<scalar_t>(
            p_choose.data_ptr<scalar_t>(),
            alpha.data_ptr<scalar_t>(),
            bsz,
            tgt_len,
            src_len,
            eps);
      });
}
|