|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <torch/extension.h>

#include <algorithm>
#include <cstdint>
#include <vector>
|
|
|
namespace { |
|
|
|
/// Exclusive cumulative product of (1 - p_choose) along the source axis.
///
/// Both buffers are indexed as dense row-major [bsz][tgt_len][src_len] arrays.
/// For every (b, tgt) row the result is
///   cumprod_1mp[b][tgt][0]   = 1
///   cumprod_1mp[b][tgt][src] = prod_{k < src} (1 - p_choose[b][tgt][k])
///
/// @param p_choose     input values, bsz * tgt_len * src_len elements
/// @param cumprod_1mp  output buffer of the same size
/// @param bsz          extent of the first (outermost) axis
/// @param tgt_len      extent of the second axis
/// @param src_len      extent of the innermost axis the product runs along
template <typename T>
void exclusiveCumprod(
    const T* p_choose,
    T* cumprod_1mp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len) {
  // An empty source axis means there is no element to seed with 1; returning
  // here keeps the function from touching memory outside the buffers.
  if (src_len == 0) {
    return;
  }

  // Offsets are computed in 64 bits so that tensors with more than 2^32
  // elements do not wrap around.
  const uint64_t rows = static_cast<uint64_t>(bsz) * tgt_len;
  for (uint64_t row = 0; row < rows; ++row) {
    const T* p_row = p_choose + row * src_len;
    T* out_row = cumprod_1mp + row * src_len;

    // `running` holds the product of all (1 - p) factors strictly before
    // `src`. It is kept in T so every step rounds exactly as a store to the
    // output buffer would.
    T running = static_cast<T>(1.0);
    for (uint32_t src = 0; src < src_len; ++src) {
      // Read the input before writing the output so the two buffers may alias.
      const T one_minus_p = 1 - p_row[src];
      out_row[src] = running;
      running = running * one_minus_p;
    }
  }
}
|
|
|
/// Element-wise clamp of a dense [bsz][tgt_len][src_len] buffer.
///
/// Each output element is min_val if the input is below min_val, max_val if
/// the input is above max_val, and the input itself otherwise. A value for
/// which both comparisons are false (e.g. NaN) is copied through unchanged.
///
/// @param cumprod_1mp        input buffer, bsz * tgt_len * src_len elements
/// @param cumprod_1mp_clamp  output buffer of the same size
/// @param bsz, tgt_len, src_len  extents whose product is the element count
/// @param min_val            lower bound (checked first)
/// @param max_val            upper bound
template <typename T>
void clamp(
    const T* cumprod_1mp,
    T* cumprod_1mp_clamp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    T min_val,
    T max_val) {
  // The operation is purely element-wise, so one flat loop suffices. The
  // count is formed in 64 bits so it cannot wrap the way a 32-bit flattened
  // index would for very large tensors.
  const uint64_t count = static_cast<uint64_t>(bsz) * tgt_len * src_len;
  for (uint64_t i = 0; i < count; ++i) {
    const T value = cumprod_1mp[i];
    if (value < min_val) {
      cumprod_1mp_clamp[i] = min_val;
    } else if (value > max_val) {
      cumprod_1mp_clamp[i] = max_val;
    } else {
      cumprod_1mp_clamp[i] = value;
    }
  }
}
|
|
|
template <typename T> |
|
void alignmentTrainCPUImpl( |
|
const T* p_choose, |
|
T* alpha, |
|
uint32_t bsz, |
|
uint32_t tgt_len, |
|
uint32_t src_len, |
|
float eps) { |
|
|
|
|
|
|
|
|
|
|
|
uint32_t elements = bsz * tgt_len * src_len; |
|
T* cumprod_1mp = new T[elements]; |
|
T* cumprod_1mp_clamp = new T[elements]; |
|
|
|
exclusiveCumprod<T>(p_choose, cumprod_1mp, bsz, tgt_len, src_len); |
|
clamp<T>( |
|
cumprod_1mp, cumprod_1mp_clamp, bsz, tgt_len, src_len, (T)eps, (T)1.0); |
|
|
|
|
|
|
|
|
|
for (uint32_t b = 0; b < bsz; b++) { |
|
alpha[b * tgt_len * src_len] = 1.0; |
|
} |
|
|
|
for (uint32_t tgt = 0; tgt < tgt_len; tgt++) { |
|
for (uint32_t b = 0; b < bsz; b++) { |
|
uint32_t alpha_idx, inout_idx; |
|
T prev_scan = 0, curr_scan, out; |
|
for (uint32_t src = 0; src < src_len; src++) { |
|
|
|
if (tgt == 0) { |
|
|
|
alpha_idx = b * tgt_len * src_len + src; |
|
} else { |
|
|
|
alpha_idx = b * tgt_len * src_len + (tgt - 1) * src_len + src; |
|
} |
|
|
|
inout_idx = b * tgt_len * src_len + tgt * src_len + src; |
|
curr_scan = prev_scan + alpha[alpha_idx] / cumprod_1mp_clamp[inout_idx]; |
|
|
|
out = curr_scan * p_choose[inout_idx] * cumprod_1mp[inout_idx]; |
|
alpha[inout_idx] = std::min<T>(std::max<T>(out, 0), 1.0); |
|
prev_scan = curr_scan; |
|
} |
|
} |
|
} |
|
|
|
free(cumprod_1mp); |
|
free(cumprod_1mp_clamp); |
|
} |
|
|
|
void alignmentTrainCPU( |
|
const torch::Tensor& p_choose, |
|
torch::Tensor& alpha, |
|
float eps) { |
|
uint32_t bsz = p_choose.size(0); |
|
uint32_t tgt_len = p_choose.size(1); |
|
uint32_t src_len = p_choose.size(2); |
|
|
|
AT_DISPATCH_FLOATING_TYPES_AND2( |
|
torch::ScalarType::Half, |
|
torch::ScalarType::BFloat16, |
|
p_choose.scalar_type(), |
|
"alignmentCPUImpl", |
|
[&]() { |
|
alignmentTrainCPUImpl<scalar_t>( |
|
p_choose.data_ptr<scalar_t>(), |
|
alpha.data_ptr<scalar_t>(), |
|
bsz, |
|
tgt_len, |
|
src_len, |
|
eps); |
|
}); |
|
} |
|
|
|
// Python module definition: registers alignmentTrainCPU under the Python
// name `alignment_train_cpu`. The module name itself comes from the
// TORCH_EXTENSION_NAME macro.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "alignment_train_cpu",
      &alignmentTrainCPU,
      "expected_alignment_from_p_choose (CPU)");
}
|
|
|
} |
|
|