// Copyright (c) 2018 MathInf GmbH, Thomas Viehmann
// Licensed under the BSD-3-Clause license
// This is the GPU implementation of the Connectionist Temporal Classification (CTC) loss.
// We mostly follow Graves.
// 1. Graves et al.: http://www.cs.toronto.edu/~graves/icml_2006.pdf
// We use the equations from the above link, but note that [1] uses 1-based indexing while we (of course) use 0-based indexing.
// Graves et al. call the probabilities y; we use log_probs (also calling them inputs).
// A few optimizations (similar to those used here, but also some I did not adopt) are described in
// 2. Minmin Sun: http://on-demand.gputechconf.com/gtc/2016/presentation/s6383-minmin-sun-speech-recognition.pdf
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TensorUtils.h>
#include <c10/util/Exception.h>
#include <c10/macros/Macros.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorOperators.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_ctc_loss_backward_native.h>
#include <ATen/ops/_ctc_loss_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/exp.h>
#include <ATen/ops/full_like.h>
#include <ATen/ops/imag.h>
#include <ATen/ops/logsumexp.h>
#include <ATen/ops/tensor.h>
#include <ATen/ops/where.h>
#include <ATen/ops/zeros.h>
#endif

#include <type_traits>
#include <numeric>

namespace at::native {

namespace {

// This ad-hoc function converts from targets (l in [1]) to augmented targets (l' in [1]).
// So if l is l_0 l_1 ... l_(tl-1), then this looks up idx in
// l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
// - note that no bounds checking is done
// - it is important to only call it with idx == 0 if the target length is 0
// - the impact of __restrict__ is yet to be measured, see
//   https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
template <typename target_t>
__device__ static inline int64_t get_target_prime(
    const target_t* __restrict__ target,
    int64_t offset,
    int64_t stride,
    int64_t idx,
    int64_t BLANK) {
  if (idx % 2 == 0) {
    return BLANK;
  } else {
    return target[offset + stride * (idx / 2)];
  }
}
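
// Illustrative example (hypothetical values, not part of the library): for a
// concatenated target sequence l = {3, 5} with offset = 0, stride = 1, and
// BLANK = 0, the augmented sequence is l' = {0, 3, 0, 5, 0}, so
//   get_target_prime(l, 0, 1, 0, 0) == 0   // BLANK at even idx
//   get_target_prime(l, 0, 1, 1, 0) == 3   // l_0
//   get_target_prime(l, 0, 1, 4, 0) == 0   // trailing BLANK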

// This kernel is a relatively straightforward implementation of the alpha calculation in the forward-backward algorithm (section 4.1).
// A (minor) twist is that we are using log-calculations to enhance numerical stability (log_probs and log_alpha).
// Overall it would be more efficient to compute the betas in the same kernel (e.g. cuDNN does this). While the betas are not
// needed for the loss itself (only for the grad), we could return log_alpha+log_beta (so the same space as currently) and the overhead
// would be small; also, the use-case for loss without grad is relatively limited.
// We parallelize by batch and target sequence. Empirically, it is faster to loop over the input (log probs) sequence and do the
// targets in parallel, even if it means more frequent __syncthreads.
// In contrast to the cuDNN implementation, we allow large target lengths. For this we need all previous `s` to have been
// computed when we start a new block_s. This is why we have our own for loop here.
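// The log-space accumulation below relies on the standard log-sum-exp identity
// (a sketch of the arithmetic used in the kernels, not additional code):
//   log(exp(a) + exp(b) + exp(c)) = m + log(exp(a-m) + exp(b-m) + exp(c-m)), with m = max(a, b, c),
// which keeps every exponential in [0, 1] and so avoids overflow. When all inputs
// are -inf, m is clamped to 0 below and the expression still evaluates to -inf.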
template<typename scalar_t, typename target_t>
__global__ void
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
    const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
    const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
    scalar_t* __restrict__ neg_log_likelihood_data,
    int64_t lp_input_stride, int64_t lp_batch_stride, int64_t lp_char_stride,
    int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
    const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
    int64_t batch_size, int64_t BLANK) {

  constexpr scalar_t neginf = -INFINITY;

  // bookkeeping
  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  // check the batch bound before reading any per-batch data (avoids an out-of-bounds read)
  if (b >= batch_size)
    return;

  int64_t input_length = input_lengths[b];
  int64_t target_length = target_lengths[b];
  int64_t lp_batch_offset = b*lp_batch_stride;
  int64_t la_batch_offset = b*la_batch_stride;
  int64_t tg_batch_offset = tg_batch_offsets[b];

  if (input_length == 0) {
    if (threadIdx.x == 0) {
      scalar_t log_likelihood = target_length == 0 ? 0 : neginf;
      neg_log_likelihood_data[b] = -log_likelihood;
    }
    return;
  }

  // first row (t=0), the three equations for alpha_1 above eq (6)
  for (int64_t block_s = 0; block_s < 2*max_target_length+1; block_s += blockDim.x) {
    int64_t s = threadIdx.x + block_s;
    scalar_t la;
    switch (s) {
    case 0:
      la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
      break;
    case 1:
      la = target_length == 0 ? neginf
                              : log_probs_data
                                    [lp_batch_offset +
                                     lp_char_stride *
                                         get_target_prime(
                                             targets_data,
                                             tg_batch_offset,
                                             tg_target_stride,
                                             1,
                                             BLANK)];
      break;
    default:
      la = neginf;
    }
    if (s < 2*max_target_length+1)
      log_alpha_data[la_batch_offset + /* la_input_stride * 0 */ + la_target_stride * s] = la;
  }

  for (int64_t block_s = 0; block_s < 2*max_target_length+1; block_s += blockDim.x) {
    int64_t s = threadIdx.x + block_s;

    // These two only depend on s, so we can cache them.
    int64_t current_char; // l_s in eq (6)
    bool have_three;      // flag which of the two cases in eq (6) we have
    if (s < 2 * target_length + 1 && target_length > 0) {
      current_char = get_target_prime(
          targets_data,
          tg_batch_offset,
          tg_target_stride,
          s,
          BLANK);
      have_three =
          ((s > 1) &&
           (get_target_prime(
                targets_data,
                tg_batch_offset,
                tg_target_stride,
                s - 2,
                BLANK) != current_char));
    } else {
      current_char = BLANK;
      have_three = false;
    }
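    // Illustrative (hypothetical targets): for l = {a, a} we get l' = {BLANK, a, BLANK, a, BLANK};
    // at s = 3, l'_{s-2} == l'_s == a, so have_three stays false and eq (6) uses only the
    // two-term sum, exactly the repeated-label case of [1]. For l = {a, b}, l'_1 = a != b = l'_3,
    // so have_three is true at s = 3 and the three-term sum applies.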
    for (int64_t t=1; t < max_input_length; t++) {
      __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch item
      if ((t < input_length) && (s < 2 * target_length + 1)) {
        // only for valid t, s. These are equations (6) and (7); la1, la2, la3 are the three summands,
        // lamax is the maximum for the logsumexp trick.
        scalar_t la1 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * s];
        scalar_t lamax = la1;
        scalar_t la2, la3;
        if (s > 0) {
          la2 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * (s-1)];
          if (la2 > lamax)
            lamax = la2;
        } else {
          la2 = neginf;
        }
        if (have_three) {
          la3 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * (s-2)];
          if (la3 > lamax)
            lamax = la3;
        } else {
          la3 = neginf;
        }
        if (lamax == neginf) // when all are neginf the result is neginf anyway, so we can safely use 0 as the max
          lamax = 0;

        log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s] = std::log(std::exp(la1-lamax)+std::exp(la2-lamax)+std::exp(la3-lamax))+lamax
          + log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_char];
      } else {
        // otherwise we just set it to neginf
        if (s < 2*max_target_length+1)
          log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s] = neginf;
      }
    }
  }
  __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch item

  // compute the loss (eq (8))
  if (threadIdx.x == 0) {
    scalar_t l1 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2)];
    scalar_t l2 = target_length > 0
        ? log_alpha_data
              [la_batch_offset + la_input_stride * (input_length - 1) +
               la_target_stride * (target_length * 2 - 1)]
        : neginf;
    scalar_t m = ((l1 > l2) ? l1 : l2);
    m = ((m == neginf) ? 0 : m);
    scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
    neg_log_likelihood_data[b] = -log_likelihood;
  }
}
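
// Layout note (a sketch under the assumption of the contiguous allocation made in
// ctc_loss_gpu_template below): log_alpha has shape (batch_size, max_input_length,
// 2*max_target_length+1), so for a contiguous tensor
//   la_batch_stride  = max_input_length * (2*max_target_length+1)
//   la_input_stride  = 2*max_target_length+1
//   la_target_stride = 1
// and element (b, t, s) lives at log_alpha_data[b*la_batch_stride + t*la_input_stride + s].
// The kernel only ever uses the strides, so non-contiguous layouts would work as well.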

// The forward computation. Lots of admin and a call to the alpha kernel.
// Note: we do not check that the labels are in the valid range. As we use
// them for indexing in the kernels, you will see memory errors when you
// pass corrupt labels.
// We support both a 2-dimensional tensor as targets (one set of targets in each row) and
// a 1-dimensional tensor where all targets are concatenated (and we use target_lengths
// to figure out where they begin).
// We return log_alpha (currently; this might change to (log_alpha+log_beta) to be passed to the
// backward). The dispatch function will only return the loss.
template<typename scalar_t, ScalarType target_scalar_type>
std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK) {
  // log_probs: input_len x batch_size x num_labels
  // targets [int64]: batch_size x target_length OR sum(target_lengths)
  CheckedFrom c = "ctc_loss_gpu";
  using target_t = typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
  auto log_probs_arg = TensorArg(log_probs, "log_probs", 1);
  auto targets_arg = TensorArg(targets, "targets", 2);
  checkAllSameGPU(c, {log_probs_arg, targets_arg});

  checkScalarType(c, targets_arg, target_scalar_type);
  checkDim(c, log_probs_arg, 3);
  checkDimRange(c, targets_arg, 1, 3);

  int64_t batch_size = log_probs.size(1);
  int64_t num_labels = log_probs.size(2);
  TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels), "blank must be in label range");
  TORCH_CHECK(input_lengths.size() == static_cast<size_t>(batch_size), "input_lengths must be of size batch_size");
  TORCH_CHECK(target_lengths.size() == static_cast<size_t>(batch_size), "target_lengths must be of size batch_size");

  int64_t tg_target_stride;

  int64_t max_target_length = 0;
  auto tg_batch_offsets = at::empty({batch_size}, at::device(at::kCPU).dtype(at::kLong));
  auto tg_batch_offsets_data = tg_batch_offsets.mutable_data_ptr<int64_t>();
  if (targets.dim() == 1) { // concatenated targets
    int64_t pos = 0;
    for (int64_t i = 0; i < batch_size; i++) {
      TORCH_CHECK(target_lengths[i] >= 0,
                  "Expected target_lengths to have value at least ", 0, ", but got value ", target_lengths[i],
                  " (while checking arguments for ", c, ")");
      tg_batch_offsets_data[i] = pos;
      pos += target_lengths[i];
      if (max_target_length < target_lengths[i])
        max_target_length = target_lengths[i];
    }
    tg_target_stride = targets.stride(0);
    checkSize(c, targets_arg, 0, pos);
  }
  else { // batch x max_target_length
    // dim is 2
    int64_t tg_batch_stride = targets.stride(0);
    for (int64_t i = 0; i < batch_size; i++) {
      TORCH_CHECK(target_lengths[i] >= 0,
                  "Expected target_lengths to have value at least ", 0, ", but got value ", target_lengths[i],
                  " (while checking arguments for ", c, ")");
      tg_batch_offsets_data[i] = i * tg_batch_stride;
      if (max_target_length < target_lengths[i])
        max_target_length = target_lengths[i];
    }
    tg_target_stride = targets.stride(1);
    checkSize(c, targets_arg, 0, batch_size);
    TORCH_CHECK(targets.size(1) >= max_target_length,
                "Expected tensor to have size at least ", max_target_length, " at dimension 1, but got size ", targets.size(1), " for ", targets_arg,
                " (while checking arguments for ", c, ")");
  }
  int64_t max_input_length = log_probs.size(0);
  for (int64_t b = 0; b < batch_size; b++) {
    TORCH_CHECK(input_lengths[b] >= 0,
                "Expected input_lengths to have value at least ", 0, ", but got value ", input_lengths[b],
                " (while checking arguments for ", c, ")");
    TORCH_CHECK(input_lengths[b] <= max_input_length,
                "Expected input_lengths to have value at most ", max_input_length, ", but got value ", input_lengths[b],
                " (while checking arguments for ", c, ")");
  }

  auto target_lengths_t = at::tensor(target_lengths, targets.options().dtype(kLong));
  auto input_lengths_t = at::tensor(input_lengths, targets.options().dtype(kLong));
  tg_batch_offsets = tg_batch_offsets.cuda();

  Tensor log_alpha = at::empty({batch_size, log_probs.size(0), 2*max_target_length+1}, log_probs.options());
  Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options());

  // Very likely, we could be more clever here, e.g. learning (or generalizing and reusing) from SoftMax.cu...
  constexpr int max_threads = std::is_same<scalar_t, float>::value ? 1024 : 768; // we need 72 or so 32 bit registers for double
  int threads_target = max_threads;
  while (threads_target / 2 >= 2*max_target_length+1) {
    threads_target /= 2;
  }
  int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
  dim3 block(threads_target, threads_batch);
  dim3 grid(1, (batch_size+threads_batch-1)/threads_batch);
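  // Worked example (hypothetical sizes): with max_target_length = 50 we need
  // 2*50+1 = 101 threads along the target dimension; starting from max_threads = 1024
  // (float) the loop above halves down to threads_target = 128 (since 64 < 101).
  // Then threads_batch = min(1024/128, batch_size), e.g. 8 for batch_size = 16,
  // giving block = (128, 8) and grid = (1, 2).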
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  ctc_loss_log_alpha_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>(
      log_alpha.mutable_data_ptr<scalar_t>(),
      log_probs.const_data_ptr<scalar_t>(), input_lengths_t.const_data_ptr<int64_t>(), log_probs.size(0),
      targets.const_data_ptr<target_t>(), target_lengths_t.const_data_ptr<int64_t>(), max_target_length,
      neg_log_likelihood.mutable_data_ptr<scalar_t>(),
      log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
      log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
      tg_batch_offsets.const_data_ptr<int64_t>(), tg_target_stride,
      batch_size, BLANK);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return std::make_tuple(neg_log_likelihood, log_alpha);
}

// The second (backward) half of the forward-backward algorithm, eqs (10) and (11). This is parallel to the
// alpha kernel above. (As mentioned above, it might make sense to do this calculation in the alpha kernel.)
template<typename scalar_t, typename target_t>
__global__ void
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
    const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
    const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
    int64_t lp_input_stride, int64_t lp_batch_stride, int64_t lp_char_stride,
    int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride,
    const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
    int64_t batch_size, int64_t BLANK) {
  constexpr scalar_t neginf = -INFINITY;

  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  // check the batch bound before reading any per-batch data (avoids an out-of-bounds read)
  if (b >= batch_size)
    return;

  int64_t input_length = input_lengths[b];
  int64_t target_length = target_lengths[b];
  int64_t lp_batch_offset = b*lp_batch_stride;
  int64_t lb_batch_offset = b*lb_batch_stride;
  int64_t tg_batch_offset = tg_batch_offsets[b];

  if (input_length == 0)
    return;

  // "first" row, the beta initialization before eq (10) (at t = input_length - 1, which differs per batch element)
  for (int64_t block_s = 2*max_target_length - (2*max_target_length % blockDim.x); block_s >= 0; block_s -= blockDim.x) {
    int64_t s = threadIdx.x + block_s;
    scalar_t lb;
    if (s == 2*target_length) {
      lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * BLANK];
    } else if (s == 2 * target_length - 1) { // false for target_length == 0
      int64_t current_target_prime = get_target_prime(
          targets_data,
          tg_batch_offset,
          tg_target_stride,
          s,
          BLANK);
      lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * current_target_prime];
    } else {
      lb = neginf;
    }
    if (s < 2*max_target_length+1) {
      log_beta_data[lb_batch_offset + (input_length-1) * lb_input_stride + lb_target_stride * s] = lb;
    }
  }

  // go backward in s
  for (int64_t block_s = 2*max_target_length - (2*max_target_length % blockDim.x); block_s >= 0; block_s -= blockDim.x) {
    int64_t s = threadIdx.x + block_s;
    int64_t current_target_prime;
    bool have_three;
    if (s < 2 * target_length + 1 && target_length > 0) {
      current_target_prime = get_target_prime(
          targets_data,
          tg_batch_offset,
          tg_target_stride,
          s,
          BLANK);
      have_three =
          ((s < 2 * target_length - 1) &&
           (get_target_prime(
                targets_data,
                tg_batch_offset,
                tg_target_stride,
                s + 2,
                BLANK) != current_target_prime));
    } else {
      current_target_prime = BLANK;
      have_three = false;
    }
    // now go backward in t. Note that we need to skip the last timestep that we did above.
    for (int64_t t=max_input_length-2; t>=0; t--) {
      __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch item
      if ((t < input_length - 1) && (s < 2 * target_length + 1)) {
        scalar_t lb1 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * s];
        scalar_t lbmax = lb1;
        scalar_t lb2, lb3;

        if (s < 2*target_length) {
          lb2 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * (s+1)];
          if (lb2 > lbmax)
            lbmax = lb2;
        } else {
          lb2 = neginf;
        }
        if (have_three) {
          lb3 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * (s+2)];
          if (lb3 > lbmax)
            lbmax = lb3;
        } else {
          lb3 = neginf;
        }
        if (lbmax == neginf)
          lbmax = 0;

        scalar_t lb = std::log(std::exp(lb1-lbmax)+std::exp(lb2-lbmax)+std::exp(lb3-lbmax))+lbmax
          + log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime];

        log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb;
      } else if (
          (s < 2 * max_target_length + 1) &&
          (((target_length == 0) && (s > 0)) || (s >= 2 * target_length + 1) ||
           (t >= input_length))) {
        log_beta_data
            [lb_batch_offset + lb_input_stride * t + lb_target_stride * s] =
                neginf;
      }
    }
  }
}
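
// Reachability note for the kernel above: at t = input_length-1 only s = 2*target_length
// (the trailing BLANK) and s = 2*target_length-1 (the last target symbol) carry
// probability mass; every other s is initialized to -inf. The backward-in-t loop then
// mirrors eqs (10)/(11) of [1] in the same way the alpha kernel mirrors eqs (6)/(7).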

// This implements the subtrahend of equation (16) for all *nonblank* characters.
// It assumes you have probs in gradient_data when called
// and it modifies gradient_data to be the gradient.
// In order to facilitate this inplace update, we don't actually do this in log space.
// (The other variant implemented uses log space and the differences seem to be
// not so problematic, at least with unit-normal-distributed test activations.)
// Internally this uses atomicAdd because different threads may write to the same
// gradient position.
// This is parallelized over b and s again.
// Note that for us, the Z of eq (16) is actually constant for all t and it is the
// likelihood - this is why we use the negative log likelihood below.
// We also multiply by the input gradient to keep with standard autograd style.
// I took this trick from [2]; for moderate alphabet sizes a log-space
// calculation (with an atomic log add) is similar in performance, but for large
// alphabets the inplace nature is a considerable advantage.
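// In code terms, the gradient assembled by this kernel and the tensor operations in
// ctc_loss_backward_gpu_template is (a sketch of eq (16), with nll = -log p(l|x) and
// gr = grad_out[b]):
//   grad[t, b, c] = (exp(log_probs[t, b, c])
//                    - exp(logsumexp_{s : l'_s == c}(log_alpha[b, t, s] + log_beta[b, t, s])
//                          + nll - log_probs[t, b, c])) * gr
// This kernel only adds the subtrahend for non-blank c; the minuend and the blank
// column are handled by the tensor operations before it is launched.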
template<typename scalar_t, typename target_t>
__global__ void
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data,
    const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
    const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data,
    const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths,
    const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths,
    const scalar_t* __restrict__ neg_log_likelihood_data,
    int64_t gr_input_stride, int64_t gr_batch_stride, int64_t gr_char_stride,
    int64_t lp_input_stride, int64_t lp_batch_stride, int64_t lp_char_stride,
    int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
    int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride,
    const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
    int64_t batch_size, bool zero_infinity) {
  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  int64_t s = threadIdx.x + blockIdx.x * blockDim.x; // note, this directly indexes into targets, not targets prime!

  if (b >= batch_size)
    return;

  int64_t input_length = input_lengths[b];
  int64_t target_length = target_lengths[b];
  int64_t gr_batch_offset = b*gr_batch_stride;
  int64_t lp_batch_offset = b*lp_batch_stride;
  int64_t la_batch_offset = b*la_batch_stride;
  int64_t lb_batch_offset = b*lb_batch_stride;
  int64_t tg_batch_offset = tg_batch_offsets[b];

  if (s >= target_length)
    return;

  int64_t target = targets_data[tg_batch_offset + s * tg_target_stride];
  scalar_t nll = neg_log_likelihood_data[b];
  scalar_t gr = grad_out_data[b * grad_out_batch_stride];

  if (zero_infinity && nll == INFINITY)
    return;

  for (int64_t t = 0; t < input_length; t++) {
    scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * target];
    gpuAtomicAddNoReturn(&gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * target],
                         -std::exp(log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * (s*2+1)]
                                   + log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * (s*2+1)]
                                   + nll - lp) * gr);
  }
}

// This is the naive implementation of equation (16). It is parallelized in batch and input timestep.
// It appears to be faster than the above method for small batch sizes.
template<typename scalar_t, typename target_t>
__global__ void
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
    const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
    const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data,
    const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
    const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
    const scalar_t* __restrict__ neg_log_likelihood_data,
    int64_t gr_input_stride, int64_t gr_batch_stride, int64_t gr_char_stride,
    int64_t lp_input_stride, int64_t lp_batch_stride, int64_t lp_char_stride,
    int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
    int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride,
    const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
    int64_t batch_size, int64_t num_labels, int64_t BLANK, bool zero_infinity) {

  constexpr scalar_t neginf = -INFINITY;
  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  int64_t t = threadIdx.x + blockIdx.x * blockDim.x;

  if ((t >= max_input_length) || (b >= batch_size))
    return;

  int64_t input_length = input_lengths[b];
  int64_t target_length = target_lengths[b];
  int64_t gr_batch_offset = b*gr_batch_stride;
  int64_t lp_batch_offset = b*lp_batch_stride;
  int64_t la_batch_offset = b*la_batch_stride;
  int64_t lb_batch_offset = b*lb_batch_stride;
  int64_t tg_batch_offset = tg_batch_offsets[b];

  // collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s]
  for (int s = 0; s < 2*max_target_length+1; s++) {
    if (s < 2 * target_length + 1) { // if target_length == 0, this is only true for s == 0
      int64_t current_target_prime = get_target_prime(
          targets_data,
          tg_batch_offset,
          tg_target_stride,
          s,
          BLANK);
      scalar_t log_alpha_beta = (log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s]
                                 + log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s]);
      scalar_t& lcab = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * current_target_prime];
      if (lcab == neginf) {
        lcab = log_alpha_beta;
      } else {
        scalar_t max = ((lcab > log_alpha_beta) ? lcab : log_alpha_beta);
        lcab = std::log(std::exp(lcab-max)+std::exp(log_alpha_beta-max))+max;
      }
    }
  }

  scalar_t nll = neg_log_likelihood_data[b];
  scalar_t gr = grad_out_data[b * grad_out_batch_stride];

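  // At this point gradient_data[b, t, c] holds logsumexp over all s with l'_s == c of
  // log_alpha + log_beta (or neginf if label c never occurs in l'); the loop below turns
  // that into the gradient of eq (16), (exp(lp) - exp(accumulated + nll - lp)) * gr,
  // and zeroes entries beyond the valid input length (or everything if zero_infinity
  // applies and the loss is infinite).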
  for (int64_t c = 0; c < num_labels; c++) {
    scalar_t& res = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * c];
    if (t < input_length && (! zero_infinity || nll != INFINITY)) {
      scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * c];
      res = (std::exp(lp)-std::exp(res + nll - lp)) * gr;
    }
    else {
      res = 0.;
    }
  }
}

// This zeroes the gradients that correspond to out-of-sequence positions.
// Those gradients should not be used in any model update since the corresponding
// input elements are padding.
template<typename scalar_t>
__global__ void
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_zero_padded_gradients(
    scalar_t* __restrict__ gradient_data, /* (T, B, D) layout */
    const int64_t* __restrict__ input_lengths, /* (B, ) layout */
    int64_t gr_timestep_stride,
    int64_t gr_batch_stride,
    int64_t gr_label_stride,
    int64_t max_input_length, /* T */
    int64_t batch_size, /* B */
    int64_t num_labels /* D */ ) {
  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  int64_t t = threadIdx.x + blockIdx.x * blockDim.x;

  if (b >= batch_size || t >= max_input_length) {
    return;
  }

  int64_t input_length = input_lengths[b]; // lengths are integral, so keep them as int64_t rather than scalar_t
  if (t >= input_length) {
    for (int l = 0; l < num_labels; l++)
      gradient_data[
          t * gr_timestep_stride + b * gr_batch_stride + l * gr_label_stride]
          = 0.0f;
  }
}


// The backward. It essentially computes eq 16 by using the above kernels.
// We don't do a lot of checking as we envision this to be called only when backpropagating through a (well-checked) forward.
template<typename scalar_t, ScalarType target_scalar_type>
Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths,
                                      const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) {
  constexpr scalar_t neginf = -INFINITY;
  using target_t = typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
  int64_t batch_size = log_probs.size(1);
  int64_t num_labels = log_probs.size(2);
  int64_t tg_target_stride;

  int64_t max_target_length;
  auto tg_batch_offsets = at::empty({batch_size}, TensorOptions(at::CPU(kLong)));
  auto tg_batch_offsets_data = tg_batch_offsets.mutable_data_ptr<int64_t>();
  if (targets.dim() == 1) { // concatenated targets
    int64_t pos = 0;
    max_target_length = 0;
    for (int64_t i = 0; i < batch_size; i++) {
      tg_batch_offsets_data[i] = pos;
      pos += target_lengths[i];
      if (max_target_length < target_lengths[i])
        max_target_length = target_lengths[i];
    }
    tg_target_stride = targets.stride(0);
  }
  else { // batch x max_target_length
    // dim is 2
    int64_t tg_batch_stride = targets.stride(0);
    for (int64_t i = 0; i < batch_size; i++) {
      tg_batch_offsets_data[i] = i * tg_batch_stride;
    }
    tg_target_stride = targets.stride(1);
    max_target_length = log_alpha.size(2)/2; // targets.size(1) might be larger
  }
  auto target_lengths_t = at::tensor(target_lengths, targets.options().dtype(kLong));
  auto input_lengths_t = at::tensor(input_lengths, targets.options().dtype(kLong));
  tg_batch_offsets = tg_batch_offsets.cuda();

  Tensor log_beta = at::empty_like(log_alpha, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  log_beta.fill_(neginf);

  Tensor grad = at::full_like(log_probs, neginf, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // initialization for log(sum (alpha beta))

  // As above, there may be better configurations to use.
  constexpr int max_threads = std::is_same<scalar_t, float>::value ? 1024 : 896; // we need 72 or so 32 bit registers for double
  int threads_target = max_threads;
  while (threads_target / 2 >= 2*max_target_length+1) {
    threads_target /= 2;
  }
  int threads_batch = std::min(max_threads / threads_target, (int) batch_size);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  {
    dim3 block(threads_target, threads_batch);
    dim3 grid(1, (batch_size+threads_batch-1)/threads_batch);
    ctc_loss_backward_log_beta_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
      (log_beta.mutable_data_ptr<scalar_t>(),
       log_probs.const_data_ptr<scalar_t>(), input_lengths_t.const_data_ptr<int64_t>(), log_probs.size(0),
       targets.const_data_ptr<target_t>(), target_lengths_t.const_data_ptr<int64_t>(), max_target_length,
       log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
       log_beta.stride(0), log_beta.stride(1), log_beta.stride(2),
       tg_batch_offsets.const_data_ptr<int64_t>(), tg_target_stride,
       batch_size, BLANK);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }

  // Very crude heuristic for what is a small problem, based on linearly regressing problem dimensions on
  // the (capped) difference of timings.
  // Note that for valid problems target length <= input length, so we
  // only consider the input length.
  bool is_large = (2*log_probs.size(0)+(24*batch_size)/10+(2*num_labels)/10) > 450;
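  // Worked example (hypothetical sizes): for input length T = 200, batch_size = 32 and
  // num_labels = 1000 this evaluates to 2*200 + (24*32)/10 + (2*1000)/10 = 400 + 76 + 200 = 676 > 450,
  // so such a problem takes the "large" path below.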
  if (is_large) { // large alphabet, large batch
    // this computes the probs, minuend in (16)
    at::exp_out(grad, log_probs);
    // now we compute the subtrahend for the blanks. It is a straightforward reduction because we know that
    // blanks are in every other position.
    // maybe we should kernelize this, too.
    auto grad_blank = grad.narrow(2, BLANK, 1);
    grad_blank -= (at::logsumexp(log_alpha.as_strided({batch_size, log_alpha.size(1), max_target_length+1},
                                                      {log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2)*2})
                                 + log_beta.as_strided({batch_size, log_beta.size(1), max_target_length+1},
                                                       {log_beta.stride(0), log_beta.stride(1), log_beta.stride(2)*2}),
                                 2, true)
                       .permute({1, 0, 2})
                       .add_(neg_log_likelihood.view({1, batch_size, 1}))
                       .sub_(log_probs.narrow(2, BLANK, 1))
                       .exp_()
                   );
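    // The as_strided views above double the stride along s and so select only the even
    // positions s = 0, 2, ..., 2*max_target_length, i.e. exactly the BLANK slots of l';
    // the logsumexp over that dimension therefore yields log sum_{even s} alpha_t(s)*beta_t(s),
    // the blank part of the sum in eq (16).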
    // scale by output gradient (blanks and first summand of non-blanks)
    grad *= grad_out.view({1, batch_size, 1});
    if (zero_infinity) {
      grad = at::where(neg_log_likelihood.view({1, batch_size, 1}) == Scalar(INFINITY), at::zeros({}, grad.options()), grad);
    }

    // For the non-blank characters, we use a kernel to compute the subtrahend.
    // Again we might configure block and grid in a better way.
    int threads_target = max_threads;
    while (threads_target / 2 >= max_target_length && threads_target > 1) {
      threads_target /= 2;
    }
    int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
    dim3 block(threads_target, threads_batch);
    dim3 grid(
        std::max<int>(
            (max_target_length + threads_target - 1) / threads_target, 1),
        (batch_size + threads_batch - 1) / threads_batch,
        1);
    ctc_loss_backward_collect_nonblank_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
      (grad.mutable_data_ptr<scalar_t>(),
       grad_out.const_data_ptr<scalar_t>(), grad_out.stride(0),
       log_alpha.const_data_ptr<scalar_t>(), log_beta.const_data_ptr<scalar_t>(),
       log_probs.const_data_ptr<scalar_t>(), input_lengths_t.const_data_ptr<int64_t>(),
       targets.const_data_ptr<target_t>(), target_lengths_t.const_data_ptr<int64_t>(),
       neg_log_likelihood.const_data_ptr<scalar_t>(),
       grad.stride(0), grad.stride(1), grad.stride(2),
       log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
       log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
       log_beta.stride(0), log_beta.stride(1), log_beta.stride(2),
       tg_batch_offsets.const_data_ptr<int64_t>(), tg_target_stride,
       batch_size, zero_infinity);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else { // small problem, use naive algorithm
    // Still no block/grid configuration guru...
    int threads_input = max_threads;
    while (threads_input / 2 >= log_probs.size(0) && threads_input > 1) {
      threads_input /= 2;
    }
    threads_batch = std::min(max_threads / threads_input, (int) batch_size);
    dim3 block(threads_input, threads_batch);
    dim3 grid((log_probs.size(0) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch);
    ctc_loss_backward_collect_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
      (grad.mutable_data_ptr<scalar_t>(),
       grad_out.const_data_ptr<scalar_t>(), grad_out.stride(0),
       log_alpha.const_data_ptr<scalar_t>(), log_beta.const_data_ptr<scalar_t>(),
       log_probs.const_data_ptr<scalar_t>(), input_lengths_t.const_data_ptr<int64_t>(), log_probs.size(0),
       targets.const_data_ptr<target_t>(), target_lengths_t.const_data_ptr<int64_t>(), max_target_length,
       neg_log_likelihood.const_data_ptr<scalar_t>(),
       grad.stride(0), grad.stride(1), grad.stride(2),
       log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
       log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
       log_beta.stride(0), log_beta.stride(1), log_beta.stride(2),
       tg_batch_offsets.const_data_ptr<int64_t>(), tg_target_stride,
       batch_size, num_labels, BLANK, zero_infinity);
    C10_CUDA_KERNEL_LAUNCH_CHECK(); // catch launch errors
  }

  // zero those gradient elements that are invalid due to padding
  {
    int threads_input = max_threads;
    while (threads_input / 2 >= log_probs.size(0)) {
      threads_input /= 2;
    }
    threads_batch = std::min(max_threads / threads_input, (int) batch_size);
    dim3 block(threads_input, threads_batch);
    dim3 grid(
        (log_probs.size(0) + threads_input-1)/threads_input,
        (batch_size+threads_batch-1)/threads_batch);
    ctc_loss_zero_padded_gradients<scalar_t><<<grid, block, 0, stream>>>(
        grad.mutable_data_ptr<scalar_t>(),
        input_lengths_t.const_data_ptr<int64_t>(),
        grad.stride(0),
        grad.stride(1),
        grad.stride(2),
        grad.size(0),
        grad.size(1),
        grad.size(2)
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }

  return grad;
}

} // namespace

std::tuple<Tensor, Tensor> ctc_loss_gpu(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool zero_infinity) {
  (void)zero_infinity; // only used for backward
  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_cuda", [&] {
    if (targets.scalar_type() == kLong) {
      return ctc_loss_gpu_template<scalar_t, kLong>(log_probs, targets, input_lengths, target_lengths, BLANK);
    } else {
      return ctc_loss_gpu_template<scalar_t, kInt>(log_probs, targets, input_lengths, target_lengths, BLANK);
    }
  });
}
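
// A minimal usage sketch (an assumption about the surrounding dispatch, not additional API):
// ctc_loss_gpu is registered as the CUDA kernel for at::_ctc_loss, so a caller with
// (T, N, C) log_probs on the GPU and integer targets would do roughly
//   auto [loss, log_alpha] = at::_ctc_loss(log_probs, targets,
//                                          input_lengths, target_lengths,
//                                          /*blank=*/0, /*zero_infinity=*/false);
// where `loss` holds the per-sample negative log likelihoods that
// torch.nn.functional.ctc_loss later reduces.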

Tensor ctc_loss_backward_gpu(const Tensor& grad, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths,
                             const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("ctc_loss_backward_gpu");
  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_backward_cuda", [&] {
    if (targets.scalar_type() == kLong) {
      return ctc_loss_backward_gpu_template<scalar_t, kLong>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
    } else {
      return ctc_loss_backward_gpu_template<scalar_t, kInt>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
    }
  });
}

} // at::native