// Copyright (c) 2018 MathInf GmbH, Thomas Viehmann
// Licensed under the BSD-3-Clause license
// This is the GPU implementation of the Connectionist Temporal Classification (CTC) loss.
// We mostly follow Graves.
// 1. Graves et al.: http://www.cs.toronto.edu/~graves/icml_2006.pdf
// We use the equations from the above link, but note that [1] has 1-based indexing and we (of course) use 0-based.
// Graves et al. call the probabilities y; we use log_probs (also calling them inputs).
// A few optimizations (similar to those here, but also some I didn't take) are described in
// 2. Minmin Sun: http://on-demand.gputechconf.com/gtc/2016/presentation/s6383-minmin-sun-speech-recognition.pdf
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TensorUtils.h>
#include <c10/util/Exception.h>
#include <c10/macros/Macros.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorOperators.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_ctc_loss_backward_native.h>
#include <ATen/ops/_ctc_loss_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/exp.h>
#include <ATen/ops/full_like.h>
#include <ATen/ops/imag.h>
#include <ATen/ops/logsumexp.h>
#include <ATen/ops/tensor.h>
#include <ATen/ops/where.h>
#include <ATen/ops/zeros.h>
#endif

#include <type_traits>
#include <numeric>

namespace at::native {

namespace {

// this ad-hoc helper converts from targets (l in [1]) to augmented targets (l' in [1])
// so if l is l_0 l_1 ... l_(tl-1) then this looks up idx in
// l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
// - note that no bounds checking is done
// - it is important to only call it with idx == 0 if the target length is 0
// - __restrict__ impact to be measured, see
//   https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
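// - example: for l = {a, b}, l' = BLANK a BLANK b BLANK, so idx == 3 returns b
//   (i.e. target[offset + stride * 1]) while every even idx returns BLANK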
template <typename target_t>
__device__ static inline int64_t get_target_prime(
    const target_t* __restrict__ target,
    int64_t offset,
    int64_t stride,
    int64_t idx,
    int64_t BLANK) {
  if (idx % 2 == 0) {
    return BLANK;
  } else {
    return target[offset + stride * (idx / 2)];
  }
}

// this kernel is a relatively straightforward implementation of the alpha calculation in the forward-backward algorithm (section 4.1).
// A (minor) twist is that we are using log-calculations to enhance numerical stability (log_probs and log_alpha).
// Overall it would be more efficient to compute the betas in the same kernel (e.g. cudnn does this). While the betas are not
// needed for the loss itself (just the grad), we can return log_alpha+log_beta (so same space as currently) and the overhead
// is small and the use-case for loss without grad is relatively limited.
// We parallelize by batch and target sequence. Empirically, it is faster to loop over the input (log probs) sequence and do
// the targets in parallel, even if it means more frequent __syncthreads.
// In contrast to the cuDNN implementation, we allow large target lengths. For this we need that all previous `s` have been
// computed when we start a new block_s. This is why we have our own for loop here.
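// In formulas (eq (6)/(7) of [1], in log space, which is what the loop below computes):
//   log_alpha[t, s] = logsumexp(log_alpha[t-1, s],
//                               log_alpha[t-1, s-1],   (if s > 0)
//                               log_alpha[t-1, s-2])   (only if s > 1 and l'_s != l'_{s-2})
//                     + log_probs[t, l'_s]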
template<typename scalar_t, typename target_t>
__global__ void
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
                              const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
                              const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
                              scalar_t* __restrict__ neg_log_likelihood_data,
                              int64_t lp_input_stride, int64_t lp_batch_stride, int64_t lp_char_stride,
                              int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
                              const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
                              int64_t batch_size, int64_t BLANK) {

  constexpr scalar_t neginf = -INFINITY;

  // bookkeeping
  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  int64_t input_length = input_lengths[b];
  int64_t target_length = target_lengths[b];
  int64_t lp_batch_offset = b*lp_batch_stride;
  int64_t la_batch_offset = b*la_batch_stride;
  int64_t tg_batch_offset = tg_batch_offsets[b];

  if (b >= batch_size)
    return;

  if (input_length == 0) {
    if (threadIdx.x == 0) {
      scalar_t log_likelihood = target_length == 0 ? 0 : neginf;
      neg_log_likelihood_data[b] = -log_likelihood;
    }
    return;
  }

  // first row (t=0), the three equations for alpha_1 above eq (6)
  for (int64_t block_s = 0; block_s < 2*max_target_length+1; block_s += blockDim.x) {
    int64_t s = threadIdx.x + block_s;
    scalar_t la;
    switch (s) {
    case 0:
      la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
      break;
    case 1:
      la = target_length == 0 ? neginf
                              : log_probs_data
                                    [lp_batch_offset +
                                     lp_char_stride *
                                         get_target_prime(
                                             targets_data,
                                             tg_batch_offset,
                                             tg_target_stride,
                                             1,
                                             BLANK)];
      break;
    default:
      la = neginf;
    }
    if (s < 2*max_target_length+1)
      log_alpha_data[la_batch_offset + /* la_input_stride * 0 */ + la_target_stride * s] = la;
  }

  for (int64_t block_s = 0; block_s < 2*max_target_length+1; block_s += blockDim.x) {
    int64_t s = threadIdx.x + block_s;

    // These two only depend on s, so we can cache them.
    int64_t current_char;       // l_s in eq (6)
    bool have_three;            // flag which of the two cases in eq (6) we have
    if (s < 2 * target_length + 1 && target_length > 0) {
      current_char = get_target_prime(
          targets_data,
          tg_batch_offset,
          tg_target_stride,
          s,
          BLANK);
      have_three =
          ((s > 1) &&
           (get_target_prime(
                targets_data,
                tg_batch_offset,
                tg_target_stride,
                s - 2,
                BLANK) != current_char));
    } else {
      current_char = BLANK;
      have_three = false;
    }
    for (int64_t t=1; t < max_input_length; t++) {
      __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch
      if ((t < input_length) && (s < 2 * target_length + 1)) {
        // only for valid t, s. This is equation (6) and (7), la1, la2, la3 are the three summands,
        // lamax is the maximum for the logsumexp trick.
        scalar_t la1 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * s];
        scalar_t lamax = la1;
        scalar_t la2, la3;
        if (s > 0) {
          la2 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * (s-1)];
          if (la2 > lamax)
            lamax = la2;
        } else {
          la2 = neginf;
        }
        if (have_three) {
          la3 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * (s-2)];
          if (la3 > lamax)
            lamax = la3;
        } else {
          la3 = neginf;
        }
        if (lamax == neginf) // when all are neginf. (then the whole thing is neginf, but we can pretend)
          lamax = 0;

        log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s] = std::log(std::exp(la1-lamax)+std::exp(la2-lamax)+std::exp(la3-lamax))+lamax
          + log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_char];
      } else {
        // otherwise we just set to neginf
        if (s < 2*max_target_length+1)
          log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s] = neginf;
      }
    }
  }
  __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch

  // compute the loss (eq (8))
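  // i.e. p(l|x) = exp(log_alpha[input_length-1, 2*target_length]) + exp(log_alpha[input_length-1, 2*target_length-1])
  // and the per-sample loss is -log p(l|x); the max-shift below is the usual logsumexp trick.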
  if (threadIdx.x == 0) {
    scalar_t l1 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2)];
    scalar_t l2 = target_length > 0
        ? log_alpha_data
              [la_batch_offset + la_input_stride * (input_length - 1) +
               la_target_stride * (target_length * 2 - 1)]
        : neginf;
    scalar_t m = ((l1 > l2) ? l1 : l2);
    m = ((m == neginf) ? 0 : m);
    scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
    neg_log_likelihood_data[b] = -log_likelihood;
  }
}

// The forward computation. Lots of admin and a call to the alpha kernel.
// Note: we do not check that the labels are in the valid range. As we use
// them for indexing in the kernels, you'll see memory errors when you
// pass corrupt labels.
// We support both a 2-dimensional tensor as targets (one set of targets in each row) and
// a 1-dimensional tensor where all targets are concatenated (and we use target_lengths
// to figure out where they begin).
// We return log_alpha (currently; this might change to (log_alpha+log_beta) to be passed to the
// backward). The dispatch function will only return the loss.
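// Shapes as allocated below: neg_log_likelihood is (batch_size,) and
// log_alpha is (batch_size, input_len, 2*max_target_length+1), both with log_probs' options.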
template<typename scalar_t, ScalarType target_scalar_type>
std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK) {
  // log_probs: input_len x batch_size x num_labels
  // targets [int32/int64]: batch_size x target_length OR sum(target_lengths)
  CheckedFrom c = "ctc_loss_gpu";
  using target_t = typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
  auto log_probs_arg = TensorArg(log_probs, "log_probs", 1);
  auto targets_arg = TensorArg(targets, "targets", 2);
  checkAllSameGPU(c, {log_probs_arg, targets_arg});

  checkScalarType(c, targets_arg, target_scalar_type);
  checkDim(c, log_probs_arg, 3);
  checkDimRange(c, targets_arg, 1, 3);

  int64_t batch_size = log_probs.size(1);
  int64_t num_labels = log_probs.size(2);
  TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels), "blank must be in label range");
  TORCH_CHECK(input_lengths.size() == static_cast<size_t>(batch_size), "input_lengths must be of size batch_size");
  TORCH_CHECK(target_lengths.size() == static_cast<size_t>(batch_size), "target_lengths must be of size batch_size");

  int64_t tg_target_stride;

  int64_t max_target_length = 0;
  auto tg_batch_offsets = at::empty({batch_size}, at::device(at::kCPU).dtype(at::kLong));
  auto tg_batch_offsets_data = tg_batch_offsets.mutable_data_ptr<int64_t>();
  if (targets.dim() == 1) { // concatenated targets
    int64_t pos = 0;
    for (int64_t i = 0; i < batch_size; i++) {
      TORCH_CHECK(target_lengths[i] >= 0,
                  "Expected target_lengths to have value at least ", 0, ", but got value ", target_lengths[i],
                  " (while checking arguments for ", c, ")");
      tg_batch_offsets_data[i] = pos;
      pos += target_lengths[i];
      if (max_target_length < target_lengths[i])
        max_target_length = target_lengths[i];
    }
    tg_target_stride = targets.stride(0);
    checkSize(c, targets_arg, 0, pos);
  }
  else { // batch x max_target_length
    // dim is 2
    int64_t tg_batch_stride = targets.stride(0);
    for (int64_t i = 0; i < batch_size; i++) {
      TORCH_CHECK(target_lengths[i] >= 0,
                  "Expected target_lengths to have value at least ", 0, ", but got value ", target_lengths[i],
                  " (while checking arguments for ", c, ")");
      tg_batch_offsets_data[i] = i * tg_batch_stride;
      if (max_target_length < target_lengths[i])
        max_target_length = target_lengths[i];
    }
    tg_target_stride = targets.stride(1);
    checkSize(c, targets_arg, 0, batch_size);
    TORCH_CHECK(targets.size(1) >= max_target_length,
             "Expected tensor to have size at least ", max_target_length, " at dimension 1, but got size ", targets.size(1), " for ", targets_arg,
             " (while checking arguments for ", c, ")");
  }
  int64_t max_input_length = log_probs.size(0);
  for (int64_t b = 0; b < batch_size; b++) {
    TORCH_CHECK(input_lengths[b] >= 0,
             "Expected input_lengths to have value at least ", 0, ", but got value ", input_lengths[b],
             " (while checking arguments for ", c, ")");
    TORCH_CHECK(input_lengths[b] <= max_input_length,
             "Expected input_lengths to have value at most ", max_input_length, ", but got value ", input_lengths[b],
             " (while checking arguments for ", c, ")");
  }

  auto target_lengths_t = at::tensor(target_lengths, targets.options().dtype(kLong));
  auto input_lengths_t = at::tensor(input_lengths, targets.options().dtype(kLong));
  tg_batch_offsets = tg_batch_offsets.cuda();

  Tensor log_alpha = at::empty({batch_size, log_probs.size(0), 2*max_target_length+1}, log_probs.options());
  Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options());

  // Very likely, we could be more clever here, e.g. learning (or generalizing and reusing) from SoftMax.cu...
  constexpr int max_threads = std::is_same<scalar_t, float>::value ? 1024 : 768; // we need 72 or so 32 bit registers for double
  int threads_target = max_threads;
  while (threads_target / 2 >= 2*max_target_length+1) {
    threads_target /= 2;
  }
  int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
  dim3 block(threads_target, threads_batch);
  dim3 grid(1, (batch_size+threads_batch-1)/threads_batch);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  ctc_loss_log_alpha_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>(
                      log_alpha.mutable_data_ptr<scalar_t>(),
                      log_probs.const_data_ptr<scalar_t>(), input_lengths_t.const_data_ptr<int64_t>(), log_probs.size(0),
                      targets.const_data_ptr<target_t>(), target_lengths_t.const_data_ptr<int64_t>(), max_target_length,
                      neg_log_likelihood.mutable_data_ptr<scalar_t>(),
                      log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
                      log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
                      tg_batch_offsets.const_data_ptr<int64_t>(), tg_target_stride,
                      batch_size, BLANK);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return std::make_tuple(neg_log_likelihood, log_alpha);
}

// The second (backward) half of the forward-backward algorithm, (10) and (11). This is parallel to the
// alpha kernel above. (As mentioned above, it might make sense to do the calculation in the alpha kernel.)
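// In formulas (eq (10)/(11) of [1], in log space, matching the loop below):
//   log_beta[t, s] = logsumexp(log_beta[t+1, s],
//                              log_beta[t+1, s+1],   (if s < 2*target_length)
//                              log_beta[t+1, s+2])   (only if s < 2*target_length-1 and l'_s != l'_{s+2})
//                    + log_probs[t, l'_s]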
template<typename scalar_t, typename target_t>
__global__ void
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
                                      const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
                                      const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
                                      int64_t lp_input_stride, int64_t lp_batch_stride, int64_t lp_char_stride,
                                      int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride,
                                      const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
                                      int64_t batch_size, int64_t BLANK) {
  constexpr scalar_t neginf = -INFINITY;

  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;

  int64_t input_length = input_lengths[b];
  int64_t target_length = target_lengths[b];
  int64_t lp_batch_offset = b*lp_batch_stride;
  int64_t lb_batch_offset = b*lb_batch_stride;
  int64_t tg_batch_offset = tg_batch_offsets[b];

  if (b >= batch_size)
    return;

  if (input_length == 0)
    return;

  // "first" row, the beta initialization before eq (10) (t = input_length - 1, differs per batch)
  for (int64_t block_s = 2*max_target_length - (2*max_target_length % blockDim.x); block_s >= 0; block_s -= blockDim.x) {
    int64_t s = threadIdx.x + block_s;
    scalar_t lb;
    if (s == 2*target_length) {
      lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * BLANK];
    } else if (s == 2 * target_length - 1) { // false for target_length == 0
      int64_t current_target_prime = get_target_prime(
          targets_data,
          tg_batch_offset,
          tg_target_stride,
          s,
          BLANK);
      lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * current_target_prime];
    } else {
      lb = neginf;
    }
    if (s < 2*max_target_length+1) {
      log_beta_data[lb_batch_offset + (input_length-1) * lb_input_stride + lb_target_stride * s] = lb;
    }
  }

  // go backward in s
  for (int64_t block_s = 2*max_target_length - (2*max_target_length % blockDim.x); block_s >= 0; block_s -= blockDim.x) {
    int64_t s = threadIdx.x + block_s;
    int64_t current_target_prime;
    bool have_three;
    if (s < 2 * target_length + 1 && target_length > 0) {
      current_target_prime = get_target_prime(
          targets_data,
          tg_batch_offset,
          tg_target_stride,
          s,
          BLANK);
      have_three =
          ((s < 2 * target_length - 1) &&
           (get_target_prime(
                targets_data,
                tg_batch_offset,
                tg_target_stride,
                s + 2,
                BLANK) != current_target_prime));
    } else {
      current_target_prime = BLANK;
      have_three = false;
    }
    // now go backward in t. Note that we need to skip the last timestep that we did above.
    for (int64_t t=max_input_length-2; t>=0; t--) {
      __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch item
      if ((t < input_length - 1) && (s < 2 * target_length + 1)) {
        scalar_t lb1 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * s];
        scalar_t lbmax = lb1;
        scalar_t lb2, lb3;

        if (s < 2*target_length) {
          lb2 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * (s+1)];
          if (lb2 > lbmax)
            lbmax = lb2;
        } else {
          lb2 = neginf;
        }
        if (have_three) {
          lb3 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * (s+2)];
          if (lb3 > lbmax)
            lbmax = lb3;
        } else {
          lb3 = neginf;
        }
        if (lbmax == neginf)
          lbmax = 0;

        scalar_t lb = std::log(std::exp(lb1-lbmax)+std::exp(lb2-lbmax)+std::exp(lb3-lbmax))+lbmax
          + log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime];

        log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb;
      } else if (
          (s < 2 * max_target_length + 1) &&
          (((target_length == 0) && (s > 0)) || (s >= 2 * target_length + 1) ||
           (t >= input_length))) {
        log_beta_data
            [lb_batch_offset + lb_input_stride * t + lb_target_stride * s] =
                neginf;
      }
    }
  }
}

// This implements the subtrahend of equation (16) for all *nonblank* characters.
// It assumes you have probs in gradient_data when called
// and it modifies gradient_data to be the gradient.
// In order to facilitate this inplace update, we don't actually do this in logspace.
// (The other variant implemented uses log_space and the differences seem to be
//  not so problematic at least with unit normal distributed test activations.)
// Internally this uses atomicAdd because different threads may write to the same
// gradient position.
// This is parallelised over b and s again.
// Note that for us, the Z of eqn (16) is actually constant for all t and it is the
// likelihood - this is why we use the negative log likelihood below.
// We also multiply by the input gradient to keep with standard autograd style.
// I took this trick from [2]; for moderate alphabet sizes a log-space
// calculation (with an atomic log add) is similar in performance, but for large
// alphabets the inplace nature is a considerable advantage.
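// Concretely, for each target position s (indexing the *unaugmented* targets) the loop below does
//   grad[t, b, targets[s]] -= exp(log_alpha[t, 2s+1] + log_beta[t, 2s+1] + nll - log_probs[t, targets[s]]) * grad_out[b]
// for every valid t; the atomicAdd handles repeated labels hitting the same gradient slot.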
template<typename scalar_t, typename target_t>
__global__ void
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data,
                                              const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
                                              const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data,
                                              const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths,
                                              const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths,
                                              const scalar_t* __restrict__ neg_log_likelihood_data,
                                              int64_t gr_input_stride, int64_t gr_batch_stride, int64_t gr_char_stride,
                                              int64_t lp_input_stride, int64_t lp_batch_stride, int64_t lp_char_stride,
                                              int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
                                              int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride,
                                              const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
                                              int64_t batch_size, bool zero_infinity) {
  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  int64_t s = threadIdx.x + blockIdx.x * blockDim.x; // note, this directly indexes into targets, not targets prime!

  if (b >= batch_size)
    return;

  int64_t input_length = input_lengths[b];
  int64_t target_length = target_lengths[b];
  int64_t gr_batch_offset = b*gr_batch_stride;
  int64_t lp_batch_offset = b*lp_batch_stride;
  int64_t la_batch_offset = b*la_batch_stride;
  int64_t lb_batch_offset = b*lb_batch_stride;
  int64_t tg_batch_offset = tg_batch_offsets[b];

  if (s >= target_length)
    return;

  int64_t target = targets_data[tg_batch_offset + s * tg_target_stride];
  scalar_t nll = neg_log_likelihood_data[b];
  scalar_t gr = grad_out_data[b * grad_out_batch_stride];

  if (zero_infinity && nll == INFINITY)
    return;

  for (int64_t t = 0; t < input_length; t++) {
    scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * target];
    gpuAtomicAddNoReturn(&gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * target],
              -std::exp(log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * (s*2+1)]
                        + log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * (s*2+1)]
                        + nll - lp) * gr);
  }
}

// This is the naive implementation of equation (16). It is parallelised in batch and input timestep.
// It appears to be faster than the above method for small batch sizes.
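// For a fixed (b, t) it first accumulates, per character c, lcab = logsumexp over all s with l'_s == c of
// (log_alpha[t, s] + log_beta[t, s]) into the gradient buffer (which was initialized to -inf), and then
// converts that into grad[t, b, c] = (exp(log_probs[t, c]) - exp(lcab + nll - log_probs[t, c])) * grad_out[b].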
template<typename scalar_t, typename target_t>
__global__ void
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
                                     const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
                                     const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data,
                                     const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
                                     const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
                                     const scalar_t* __restrict__ neg_log_likelihood_data,
                                     int64_t gr_input_stride, int64_t gr_batch_stride, int64_t gr_char_stride,
                                     int64_t lp_input_stride, int64_t lp_batch_stride, int64_t lp_char_stride,
                                     int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
                                     int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride,
                                     const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
                                     int64_t batch_size, int64_t num_labels, int64_t BLANK, bool zero_infinity) {

  constexpr scalar_t neginf = -INFINITY;
  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  int64_t t = threadIdx.x + blockIdx.x * blockDim.x;

  if ((t >= max_input_length) || (b >= batch_size))
    return;

  int64_t input_length = input_lengths[b];
  int64_t target_length = target_lengths[b];
  int64_t gr_batch_offset = b*gr_batch_stride;
  int64_t lp_batch_offset = b*lp_batch_stride;
  int64_t la_batch_offset = b*la_batch_stride;
  int64_t lb_batch_offset = b*lb_batch_stride;
  int64_t tg_batch_offset = tg_batch_offsets[b];

  // collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s]
  for (int s = 0; s < 2*max_target_length+1; s++) {
    if (s < 2 * target_length + 1) { // if target_length == 0, s == 0
      int64_t current_target_prime = get_target_prime(
          targets_data,
          tg_batch_offset,
          tg_target_stride,
          s,
          BLANK);
      scalar_t log_alpha_beta = (log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s]
                                 + log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s]);
      scalar_t& lcab = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * current_target_prime];
      if (lcab == neginf) {
        lcab = log_alpha_beta;
      } else {
        scalar_t max = ((lcab > log_alpha_beta) ? lcab : log_alpha_beta);
        lcab = std::log(std::exp(lcab-max)+std::exp(log_alpha_beta-max))+max;
      }
    }
  }

  scalar_t nll = neg_log_likelihood_data[b];
  scalar_t gr = grad_out_data[b * grad_out_batch_stride];

  for (int64_t c = 0; c < num_labels; c++) {
    scalar_t& res = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * c];
    if (t < input_length && (! zero_infinity || nll != INFINITY)) {
      scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * c];
      res = (std::exp(lp)-std::exp(res + nll - lp)) * gr;
    }
    else {
      res = 0.;
    }
  }
}

// This zeroes gradients corresponding to out-of-sequence positions.
// Those gradients should not be used in any model update since the input
// elements are padded.
template<typename scalar_t>
__global__ void
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_zero_padded_gradients(
    scalar_t* __restrict__ gradient_data,   /* (T, B, D) layout */
    const int64_t* __restrict__ input_lengths, /* (B, ) layout */
    int64_t gr_timestep_stride,
    int64_t gr_batch_stride,
    int64_t gr_label_stride,
    int64_t max_input_length, /* T */
    int64_t batch_size, /* B */
    int64_t num_labels  /* D */ ) {
      int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
      int64_t t = threadIdx.x + blockIdx.x * blockDim.x;

      if (b >= batch_size || t >= max_input_length) {
        return;
      }

      scalar_t input_length = input_lengths[b];
      if (t >= input_length) {
        for (int l = 0; l < num_labels; l++)
          gradient_data[
            t * gr_timestep_stride + b * gr_batch_stride + l * gr_label_stride]
          = 0.0f;
      }
  }


// The backward. It essentially computes eq 16 by using the above kernels.
// We don't do a lot of checking as we envision this to be called only when backpropagating through a (well-checked) forward.
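// The plan: fill log_beta with the beta kernel above, then assemble eq (16) either via exp_out/logsumexp
// plus the nonblank collect kernel (large problems) or via the naive collect kernel (small problems),
// and finally zero the gradient at padded timesteps.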
template<typename scalar_t, ScalarType target_scalar_type>
Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths,
                                      const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) {
  constexpr scalar_t neginf = -INFINITY;
  using target_t = typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
  int64_t batch_size = log_probs.size(1);
  int64_t num_labels = log_probs.size(2);
  int64_t tg_target_stride;

  int64_t max_target_length;
  auto tg_batch_offsets = at::empty({batch_size}, TensorOptions(at::CPU(kLong)));
  auto tg_batch_offsets_data = tg_batch_offsets.mutable_data_ptr<int64_t>();
  if (targets.dim() == 1) { // concatenated targets
    int64_t pos = 0;
    max_target_length = 0;
    for (int64_t i = 0; i < batch_size; i++) {
      tg_batch_offsets_data[i] = pos;
      pos += target_lengths[i];
      if (max_target_length < target_lengths[i])
        max_target_length = target_lengths[i];
    }
    tg_target_stride = targets.stride(0);
  }
  else { // batch x max_target_length
    // dim is 2
    int64_t tg_batch_stride = targets.stride(0);
    for (int64_t i = 0; i < batch_size; i++) {
      tg_batch_offsets_data[i] = i * tg_batch_stride;
    }
    tg_target_stride = targets.stride(1);
    max_target_length = log_alpha.size(2)/2; // targets.size(1) might be larger
  }
  auto target_lengths_t = at::tensor(target_lengths, targets.options().dtype(kLong));
  auto input_lengths_t = at::tensor(input_lengths, targets.options().dtype(kLong));
  tg_batch_offsets = tg_batch_offsets.cuda();

  Tensor log_beta = at::empty_like(log_alpha, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  log_beta.fill_(neginf);

  Tensor grad = at::full_like(log_probs, neginf, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // initialization for log(sum (alpha beta))

  // As above, there may be better configurations to use.
  constexpr int max_threads = std::is_same<scalar_t, float>::value ? 1024 : 896; // we need 72 or so 32 bit registers for double
  int threads_target = max_threads;
  while (threads_target / 2 >= 2*max_target_length+1) {
    threads_target /= 2;
  }
  int threads_batch = std::min(max_threads / threads_target, (int) batch_size);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  {
    dim3 block(threads_target, threads_batch);
    dim3 grid(1, (batch_size+threads_batch-1)/threads_batch);
    ctc_loss_backward_log_beta_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
      (log_beta.mutable_data_ptr<scalar_t>(),
       log_probs.const_data_ptr<scalar_t>(), input_lengths_t.const_data_ptr<int64_t>(), log_probs.size(0),
       targets.const_data_ptr<target_t>(), target_lengths_t.const_data_ptr<int64_t>(), max_target_length,
       log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
       log_beta.stride(0), log_beta.stride(1), log_beta.stride(2),
       tg_batch_offsets.const_data_ptr<int64_t>(), tg_target_stride,
       batch_size, BLANK);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }

  // Very crude heuristic for what is a small problem, based on linearly regressing problem dimensions on
  // the (capped) difference of timings.
  // Note that for OK problems target length <= input length, so we
  // only consider input length.
  bool is_large = (2*log_probs.size(0)+(24*batch_size)/10+(2*num_labels)/10) > 450;
  if (is_large) { // large alphabet, large batch
    // this computes the probs, minuend in (16)
    at::exp_out(grad, log_probs);
    // now we compute the subtrahend for the blanks. It is a straightforward reduction because we know that
    // blanks are in every other position.
    // maybe we should kernelize this, too.
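    // The as_strided views below (last stride doubled) select the even positions s = 0, 2, ..., 2*max_target_length
    // of the augmented target, i.e. exactly the BLANK entries, which is what makes this a plain reduction.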
    auto grad_blank = grad.narrow(2, BLANK, 1);
    grad_blank -= (at::logsumexp(log_alpha.as_strided({batch_size, log_alpha.size(1), max_target_length+1},
                                                      {log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2)*2})
                                 + log_beta.as_strided({batch_size, log_beta.size(1), max_target_length+1},
                                                       {log_beta.stride(0), log_beta.stride(1), log_beta.stride(2)*2}),
                                 2, true)
                   .permute({1, 0, 2})
                   .add_(neg_log_likelihood.view({1, batch_size, 1}))
                   .sub_(log_probs.narrow(2, BLANK, 1))
                   .exp_()
                   );
    // scale by output gradient (blanks and first summand of non-blanks)
    grad *= grad_out.view({1, batch_size, 1});
    if (zero_infinity) {
      grad = at::where(neg_log_likelihood.view({1, batch_size, 1}) == Scalar(INFINITY), at::zeros({}, grad.options()), grad);
    }

    // For the non-blank characters, we use a kernel to compute the subtrahend.
    // Again we might configure block and grid in a better way.
    int threads_target = max_threads;
    while (threads_target / 2 >= max_target_length && threads_target > 1) {
      threads_target /= 2;
    }
    int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
    dim3 block(threads_target, threads_batch);
    dim3 grid(
        std::max<int>(
            (max_target_length + threads_target - 1) / threads_target, 1),
        (batch_size + threads_batch - 1) / threads_batch,
        1);
    ctc_loss_backward_collect_nonblank_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
      (grad.mutable_data_ptr<scalar_t>(),
       grad_out.const_data_ptr<scalar_t>(), grad_out.stride(0),
       log_alpha.const_data_ptr<scalar_t>(), log_beta.const_data_ptr<scalar_t>(),
       log_probs.const_data_ptr<scalar_t>(), input_lengths_t.const_data_ptr<int64_t>(),
       targets.const_data_ptr<target_t>(), target_lengths_t.const_data_ptr<int64_t>(),
       neg_log_likelihood.const_data_ptr<scalar_t>(),
       grad.stride(0), grad.stride(1), grad.stride(2),
       log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
       log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
       log_beta.stride(0), log_beta.stride(1), log_beta.stride(2),
       tg_batch_offsets.const_data_ptr<int64_t>(), tg_target_stride,
       batch_size, zero_infinity);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else { // small problem, use naive algorithm
    // Still no block/grid configuration guru...
    int threads_input = max_threads;
    while (threads_input / 2 >= log_probs.size(0) && threads_input > 1) {
      threads_input /= 2;
    }
    threads_batch = std::min(max_threads / threads_input, (int) batch_size);
    dim3 block(threads_input, threads_batch);
    dim3 grid((log_probs.size(0) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch);
    ctc_loss_backward_collect_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
      (grad.mutable_data_ptr<scalar_t>(),
       grad_out.const_data_ptr<scalar_t>(), grad_out.stride(0),
       log_alpha.const_data_ptr<scalar_t>(), log_beta.const_data_ptr<scalar_t>(),
       log_probs.const_data_ptr<scalar_t>(), input_lengths_t.const_data_ptr<int64_t>(), log_probs.size(0),
       targets.const_data_ptr<target_t>(), target_lengths_t.const_data_ptr<int64_t>(), max_target_length,
       neg_log_likelihood.const_data_ptr<scalar_t>(),
       grad.stride(0), grad.stride(1), grad.stride(2),
       log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
       log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
       log_beta.stride(0), log_beta.stride(1), log_beta.stride(2),
       tg_batch_offsets.const_data_ptr<int64_t>(), tg_target_stride,
       batch_size, num_labels, BLANK, zero_infinity);
    C10_CUDA_KERNEL_LAUNCH_CHECK(); // catch launch errors
  }

  // zero those invalid gradient elements due to padding
  {
    int threads_input = max_threads;
    while (threads_input / 2 >= log_probs.size(0)) {
      threads_input /= 2;
    }
    threads_batch = std::min(max_threads / threads_input, (int) batch_size);
    dim3 block(threads_input, threads_batch);
    dim3 grid(
      (log_probs.size(0) + threads_input-1)/threads_input,
      (batch_size+threads_batch-1)/threads_batch);
    ctc_loss_zero_padded_gradients<scalar_t><<<grid, block, 0, stream>>>(
      grad.mutable_data_ptr<scalar_t>(),
      input_lengths_t.const_data_ptr<int64_t>(),
      grad.stride(0),
      grad.stride(1),
      grad.stride(2),
      grad.size(0),
      grad.size(1),
      grad.size(2)
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }

  return grad;
}

} // namespace

std::tuple<Tensor, Tensor> ctc_loss_gpu(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool zero_infinity) {
  (void)zero_infinity; // only used for backward
  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_cuda", [&] {
      if (targets.scalar_type() == kLong) {
        return ctc_loss_gpu_template<scalar_t, kLong>(log_probs, targets, input_lengths, target_lengths, BLANK);
      } else {
        return ctc_loss_gpu_template<scalar_t, kInt>(log_probs, targets, input_lengths, target_lengths, BLANK);
      }
    });
}

Tensor ctc_loss_backward_gpu(const Tensor& grad, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths,
                             const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("ctc_loss_backward_gpu");
  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_backward_cuda", [&] {
      if (targets.scalar_type() == kLong) {
        return ctc_loss_backward_gpu_template<scalar_t, kLong>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
      } else {
        return ctc_loss_backward_gpu_template<scalar_t, kInt>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
      }
    });
}

} // at::native