/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <torch/extension.h> // @manual=//caffe2:torch_extension
#include <algorithm>
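
// CPU implementation of the expected-alignment computation for monotonic
// attention: given selection probabilities p_choose of shape
// (bsz, tgt_len, src_len), it fills alpha with the expected alignment
// probabilities, row by row over the target dimension. Exposed to Python as
// "alignment_train_cpu" at the bottom of this file.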

namespace {

template <typename T>
void exclusiveCumprod(
    const T* p_choose,
    T* cumprod_1mp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len) {
  // cumprod_1mp = 1 - p_choose
  for (uint32_t b = 0; b < bsz; b++) {
    for (uint32_t tgt = 0; tgt < tgt_len; tgt++) {
      for (uint32_t src = 0; src < src_len; src++) {
        uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
        cumprod_1mp[idx] = 1 - p_choose[idx];
      }
    }
  }

  // Implement the exclusive cumprod along the innermost (src) dimension:
  // cumprod_1mp = exclusive_cumprod(1 - p_choose)
  // PyTorch provides cumprod, but it has no exclusive mode.
  // cumprod(x)           = [x1, x1*x2, x1*x2*x3, ..., prod_{i=1}^{n} x_i]
  // exclusive_cumprod(x) = [1,  x1,    x1*x2,    ..., prod_{i=1}^{n-1} x_i]
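  // For example, a row of 1 - p_choose equal to [0.2, 0.5, 0.5] becomes
  // [1.0, 0.2, 0.1] after the pass below (illustrative values only).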
  for (uint32_t b = 0; b < bsz; b++) {
    for (uint32_t tgt = 0; tgt < tgt_len; tgt++) {
      uint32_t idx_offset = b * tgt_len * src_len + tgt * src_len;
      // Save the original value at [b][tgt][0], then set that slot to 1
      // (the empty product), as required by the exclusive cumprod.
      T prev = cumprod_1mp[idx_offset];
      cumprod_1mp[idx_offset] = (T)1.0;
      T curr;
      for (uint32_t src = 1; src < src_len; src++) {
        uint32_t idx = idx_offset + src;
        curr = cumprod_1mp[idx];
        cumprod_1mp[idx] = cumprod_1mp[idx - 1] * prev;
        prev = curr;
      }
    }
  }
}

template <typename T>
void clamp(
    const T* cumprod_1mp,
    T* cumprod_1mp_clamp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    T min_val,
    T max_val) {
  for (uint32_t b = 0; b < bsz; b++) {
    for (uint32_t tgt = 0; tgt < tgt_len; tgt++) {
      for (uint32_t src = 0; src < src_len; src++) {
        uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
        if (cumprod_1mp[idx] < min_val) {
          cumprod_1mp_clamp[idx] = min_val;
        } else if (cumprod_1mp[idx] > max_val) {
          cumprod_1mp_clamp[idx] = max_val;
        } else {
          cumprod_1mp_clamp[idx] = cumprod_1mp[idx];
        }
      }
    }
  }
}

template <typename T>
void alignmentTrainCPUImpl(
    const T* p_choose,
    T* alpha,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    float eps) {
  // p_choose:          bsz, tgt_len, src_len
  // cumprod_1mp:       bsz, tgt_len, src_len
  // cumprod_1mp_clamp: bsz, tgt_len, src_len
  // alpha:             bsz, tgt_len, src_len

  uint32_t elements = bsz * tgt_len * src_len;
  T* cumprod_1mp = new T[elements];
  T* cumprod_1mp_clamp = new T[elements];

  exclusiveCumprod<T>(p_choose, cumprod_1mp, bsz, tgt_len, src_len);
  clamp<T>(
      cumprod_1mp, cumprod_1mp_clamp, bsz, tgt_len, src_len, (T)eps, (T)1.0);

  // Monotonic-attention recurrence over target steps i:
  //   alpha_i = p_i * cumprod_1mp_i * cumsum(alpha_{i-1} / cumprod_1mp_clamp_i)
  // where cumprod_1mp_i is the exclusive cumprod of (1 - p_i) along src, and
  // the clamped copy keeps the denominator of the cumsum away from zero.

  // Initialize alpha[:, 0, 0] = 1. The remainder of row 0 is expected to be
  // zero-initialized by the caller, since it is read before being written
  // when tgt == 0.
  for (uint32_t b = 0; b < bsz; b++) {
    alpha[b * tgt_len * src_len] = 1.0;
  }

  // Process one target row at a time; each row is a single left-to-right pass
  // over src that keeps the running cumulative sum in prev_scan / curr_scan.
  for (uint32_t tgt = 0; tgt < tgt_len; tgt++) {
    for (uint32_t b = 0; b < bsz; b++) {
      uint32_t alpha_idx, inout_idx;
      T prev_scan = 0, curr_scan, out;
      for (uint32_t src = 0; src < src_len; src++) {
        // Accumulate the cumsum term: alpha from the previous target row
        // (or row 0 itself when tgt == 0), divided by the clamped cumprod.
        if (tgt == 0) {
          // alpha index is [b][tgt][src]
          alpha_idx = b * tgt_len * src_len + src;
        } else {
          // alpha index is [b][tgt-1][src]
          alpha_idx = b * tgt_len * src_len + (tgt - 1) * src_len + src;
        }
        // input index is [b][tgt][src]
        inout_idx = b * tgt_len * src_len + tgt * src_len + src;
        curr_scan = prev_scan + alpha[alpha_idx] / cumprod_1mp_clamp[inout_idx];

        out = curr_scan * p_choose[inout_idx] * cumprod_1mp[inout_idx];
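        // Clamp the final value into [0, 1] to guard against numerical drift.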
        alpha[inout_idx] = std::min<T>(std::max<T>(out, 0), 1.0);
        prev_scan = curr_scan;
      }
    }
  }

  delete[] cumprod_1mp;
  delete[] cumprod_1mp_clamp;
}

void alignmentTrainCPU(
    const torch::Tensor& p_choose,
    torch::Tensor& alpha,
    float eps) {
  uint32_t bsz = p_choose.size(0);
  uint32_t tgt_len = p_choose.size(1);
  uint32_t src_len = p_choose.size(2);
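  // The raw-pointer indexing above assumes both tensors are contiguous with
  // shape (bsz, tgt_len, src_len); alpha must share p_choose's dtype for the
  // data_ptr<scalar_t>() calls below to succeed.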

  AT_DISPATCH_FLOATING_TYPES_AND2(
      torch::ScalarType::Half,
      torch::ScalarType::BFloat16,
      p_choose.scalar_type(),
      "alignmentCPUImpl",
      [&]() {
        alignmentTrainCPUImpl<scalar_t>(
            p_choose.data_ptr<scalar_t>(),
            alpha.data_ptr<scalar_t>(),
            bsz,
            tgt_len,
            src_len,
            eps);
      });
}

} // namespace

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "alignment_train_cpu",
      &alignmentTrainCPU,
      "expected_alignment_from_p_choose (CPU)");
}
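
// Minimal usage sketch from Python (illustrative only; the module and file
// names below are assumptions and depend on how the extension is built, e.g.
// via torch.utils.cpp_extension):
//
//   import torch
//   from torch.utils.cpp_extension import load
//   ext = load(name="alignment_train_cpu_ext", sources=[<path to this file>])
//
//   p_choose = torch.rand(2, 5, 7)      # (bsz, tgt_len, src_len)
//   alpha = torch.zeros_like(p_choose)  # filled in place by the op
//   ext.alignment_train_cpu(p_choose, alpha, 1e-6)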