// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/Loss.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/TensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/Resize.h>

#include <type_traits>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/binary_cross_entropy_backward_native.h>
#include <ATen/ops/binary_cross_entropy_native.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/exp.h>
#include <ATen/ops/nll_loss_backward_native.h>
#include <ATen/ops/nll_loss_forward_native.h>
#include <ATen/ops/squeeze.h>
#endif

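// Lower bound for the denominator in the BCE backward kernel below; keeps the
// gradient finite when input * (1 - input) is zero (input exactly 0 or 1).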
constexpr float EPSILON = 1e-12;

namespace {

using namespace at;

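// Elementwise BCE backward:
//   d(loss)/d(input) = grad * (input - target) / max(input * (1 - input), EPSILON)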
void binary_cross_entropy_backward_out_kernel(Tensor& grad_input, const Tensor& grad, const Tensor& input, const Tensor& target) {
  at::TensorIterator iter = TensorIteratorConfig()
      .add_output(grad_input)
      .add_input(grad)
      .add_input(input)
      .add_input(target)
      .build();
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "binary_cross_entropy_backward_out_cuda", [&]() {
    at::native::gpu_kernel(iter, [] GPU_LAMBDA (
        scalar_t grad_val,
        scalar_t input_val,
        scalar_t target_val
      ) -> scalar_t {
        const scalar_t one = 1;
        const scalar_t epsilon = EPSILON;

        scalar_t grad_input_denominator = max(
          (one - input_val) * input_val,
          epsilon
        );

        return grad_val * (input_val - target_val) / grad_input_denominator;
      }
    );
  });
}

} // namespace

namespace at::native {

Tensor binary_cross_entropy_cuda(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  Tensor loss = at::empty_like(input);
  return at::native::binary_cross_entropy_out_cuda(
      input, target, weight, reduction, loss);
}

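// Forward BCE: loss = -(target * log(input) + (1 - target) * log(1 - input)),
// computed elementwise on squeezed views. Both log terms are clamped at -100
// so inputs of exactly 0 or 1 do not produce infinite loss; the optional
// weight and the requested reduction are applied afterwards on the host side.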
Tensor& binary_cross_entropy_out_cuda(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction, Tensor& loss) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  Tensor loss_squeezed = at::squeeze(loss);

  TensorIterator iter = TensorIteratorConfig()
      .add_output(loss_squeezed)
      .add_owned_input(at::squeeze(input))
      .add_owned_input(at::squeeze(target))
      .build();
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "binary_cross_entropy_out_cuda", [&]() {
    gpu_kernel(iter,
      [] GPU_LAMBDA (scalar_t input_val, scalar_t target_val) -> scalar_t {
        const scalar_t zero = 0;
        const scalar_t one = 1;
        const scalar_t neg_100 = -100;

        CUDA_KERNEL_ASSERT(input_val >= zero && input_val <= one);
        CUDA_KERNEL_ASSERT(target_val >= zero && target_val <= one);

        scalar_t log_input_val = std::log(input_val);
        scalar_t log_1_minus_input_val = std::log1p(-input_val);

        log_input_val = std::max(log_input_val, neg_100);
        log_1_minus_input_val = std::max(log_1_minus_input_val, neg_100);

        return ((target_val - one) * log_1_minus_input_val) - (target_val * log_input_val);
      }
    );
  });
  if (weight.defined()) {
    loss.mul_(weight);
  }

  if (reduction != at::Reduction::None) {
    Tensor loss_reduced;
    if (reduction == at::Reduction::Mean) {
      loss_reduced = loss.mean();
    } else if (reduction == at::Reduction::Sum) {
      loss_reduced = loss.sum();
    }
    loss.resize_as_(loss_reduced).copy_(loss_reduced);
  }

  return loss;
}

Tensor binary_cross_entropy_backward_cuda(const Tensor& grad, const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  Tensor grad_input = at::empty_like(input);
  return at::native::binary_cross_entropy_backward_out_cuda(
      grad, input, target, weight, reduction, grad_input);
}

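// Out-of-place BCE backward: broadcast the incoming gradient to the input's
// shape, run the elementwise kernel above, then apply the optional weight and
// divide by numel() when the reduction was a mean.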
Tensor& binary_cross_entropy_backward_out_cuda(const Tensor& grad, const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction, Tensor& grad_input) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  Tensor grad_expand = grad.expand_as(input);
  binary_cross_entropy_backward_out_kernel(grad_input, grad_expand, input, target);

  if (weight.defined()) {
    grad_input.mul_(weight);
  }
  if (reduction == at::Reduction::Mean) {
    grad_input.div_(input.numel());
  }
  return grad_input;
}

// -----------------------------------
// nll_loss
// -----------------------------------
namespace {

constexpr int NLL_LOSS_THREADS = 32;

// NOTE(crcrpar): `Byte` support was added for https://github.com/pytorch/pytorch/issues/59765.
#define AT_DISPATCH_NLL_LOSS_INDEX_TYPES(TYPE, NAME, ...)                     \
  AT_DISPATCH_SWITCH(TYPE, NAME,                                              \
  AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Byte, index_t, __VA_ARGS__) \
  AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Long, index_t, __VA_ARGS__))

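// Device-side bounds check for a target class index. The `INDEX >= 0` half of
// the check is dropped for unsigned index types (e.g. Byte targets), where it
// would be tautological.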
#define CHECK_INDEX_IN_CLASS(INDEX, N_CLASSES)                                \
  if constexpr(std::is_unsigned<decltype(INDEX)>::value) {                    \
    CUDA_KERNEL_ASSERT(INDEX < N_CLASSES);                                    \
  } else {                                                                    \
    CUDA_KERNEL_ASSERT(INDEX >= 0 && INDEX < N_CLASSES);                      \
  }

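// reduction == None, 2-D input: one output element per batch entry, produced
// by a grid-stride loop. Targets equal to ignore_index yield a loss of 0.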
template <typename scalar_t, typename index_t>
__global__ void nll_loss_forward_no_reduce_cuda_kernel(
    int64_t batch_size,
    PackedTensorAccessor64<scalar_t, 2> input,
    const index_t* target,
    scalar_t* output,
    const scalar_t* weights,
    int64_t n_classes,
    int64_t ignore_index) {
  CUDA_KERNEL_LOOP(index, batch_size) {
    index_t cur_target = target[index];
    if (cur_target == ignore_index) {
      output[index] = static_cast<scalar_t>(0);
      continue;
    }
    CHECK_INDEX_IN_CLASS(cur_target, n_classes);
    auto cur_weight =
        weights != nullptr ? weights[cur_target] : static_cast<scalar_t>(1);
    output[index] = -cur_weight * input[index][cur_target];
  }
}

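// Reduction over a single (1-D) sample: launched with one thread, which
// writes both the scalar output and total_weight.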
template <typename scalar_t, typename index_t>
__global__ void nll_loss_forward_reduce_cuda_kernel_1d(
    scalar_t* output,
    scalar_t* total_weight,
    const scalar_t* input,
    const index_t* target,
    const scalar_t* weights,
    bool size_average,
    int64_t n_classes,
    int64_t ignore_index) {
  CUDA_KERNEL_ASSERT(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0);

  const index_t t = *target;
  if (t != ignore_index) {
    CHECK_INDEX_IN_CLASS(t, n_classes);
    const auto cur_weight = weights != nullptr ? weights[t] : scalar_t{1};
    *total_weight = cur_weight;

    if (size_average) {
      // If we try to normalize a zero then we return a NaN
      if (cur_weight == 0) {
        *output = std::numeric_limits<scalar_t>::quiet_NaN();
      } else {
        *output = -input[t];
      }
    } else {
      *output = -cur_weight * input[t];
    }
  } else {
    // If the only element was omitted, we get 0. See the discussion in
    // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    *output = scalar_t{0};
    *total_weight = scalar_t{0};
  }
}

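// Reduction over a 2-D batch: launched as a single block of NLL_LOSS_THREADS
// threads. Each thread accumulates a strided slice of the batch into shared
// memory (in accscalar_t precision), then thread 0 sums the partial results
// and writes output and total_weight.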
template <typename scalar_t, typename accscalar_t, typename index_t>
__global__ void nll_loss_forward_reduce_cuda_kernel_2d(
    scalar_t* output,
    scalar_t* total_weight,
    const scalar_t* input,
    const index_t* target,
    const scalar_t* weights,
    bool size_average,
    int64_t nframe,
    int64_t ndim,
    int64_t n_classes,
    int64_t ignore_index) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  __shared__ accscalar_t sh_inputs[NLL_LOSS_THREADS],
      acc_weight[NLL_LOSS_THREADS];

  sh_inputs[threadIdx.x] = static_cast<accscalar_t>(0);
  acc_weight[threadIdx.x] = static_cast<accscalar_t>(0);
  for (int i = threadIdx.x; i < nframe; i += NLL_LOSS_THREADS) {
    index_t t = target[i];
    if (t != ignore_index) {
      CHECK_INDEX_IN_CLASS(t, n_classes);
      scalar_t cur_weight =
          weights != nullptr ? weights[t] : static_cast<scalar_t>(1);
      sh_inputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
      acc_weight[threadIdx.x] += cur_weight;
    }
  }

  __syncthreads();

  if (threadIdx.x == 0) {
    accscalar_t output_acc = 0;
    accscalar_t total_weight_acc = 0;
    for (int i = 0; i < NLL_LOSS_THREADS; ++i) {
      output_acc += sh_inputs[i];
      total_weight_acc += acc_weight[i];
    }
    *total_weight = static_cast<scalar_t>(total_weight_acc);
    if (size_average) {
      *output = static_cast<scalar_t>(output_acc / total_weight_acc);
    } else {
      *output = static_cast<scalar_t>(output_acc);
    }
  }
}

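// Host-side driver for the forward pass. Dispatches on reduction mode and
// input dimensionality: the no-reduction 2-D case launches one thread per
// batch element, while the reducing cases launch a single (1-thread or
// NLL_LOSS_THREADS-thread) block and produce scalar output/total_weight.
// Empty batches and empty targets are handled on the host to avoid
// zero-block kernel launches.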
void nll_loss_forward_out_cuda_template(
    const Tensor& output,
    const Tensor& total_weight,
    const Tensor& input_,
    const Tensor& target_,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  auto input = *input_.expect_contiguous();
  auto target = *target_.expect_contiguous();

  int64_t n_classes = input.size(-1);
  int64_t n_dims = input.dim();
  int64_t batch_size = n_dims == 1 ? 1 : input.size(0);

  auto weight_ = weight.defined() ? weight.contiguous() : weight;

  if (reduction == Reduction::None && n_dims == 2) {
    at::native::resize_output(output, {batch_size});
    total_weight.zero_();
    if (batch_size == 0) {
      // This guards from unnecessary operations and launching CUDA kernel with
      // 0 blocks.
      return;
    }

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss_forward_no_reduce_cuda_kernel",
        [&] {
          AT_DISPATCH_NLL_LOSS_INDEX_TYPES(
              target.scalar_type(),
              "nll_loss_forward_no_reduce_cuda_kernel_index",
              [&] {
                nll_loss_forward_no_reduce_cuda_kernel<scalar_t, index_t>
                    <<<at::cuda::detail::GET_BLOCKS(batch_size),
                       at::cuda::detail::CUDA_NUM_THREADS,
                       0,
                       at::cuda::getCurrentCUDAStream()>>>(
                        batch_size,
                        input.packed_accessor64<scalar_t, 2>(),
                        target.const_data_ptr<index_t>(),
                        output.mutable_data_ptr<scalar_t>(),
                        weight_.defined() ? weight_.const_data_ptr<scalar_t>()
                                          : nullptr,
                        n_classes,
                        ignore_index);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              });
        });
    return;
  }

  // produce scalar outputs for the reduction case
  at::native::resize_output(output, {});
  total_weight.resize_({});

  if (target.numel() == 0) {
    // Here target (and input) have zero elements
    // Mean reduction on empty tensors produces NaN. See the discussion in
    // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if (reduction == Reduction::Mean) {
      output.fill_(std::numeric_limits<double>::quiet_NaN());
    } else {
      output.zero_();
    }
    total_weight.zero_();
    return;
  }

  if (n_dims == 1) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss_forward_reduce_cuda_kernel_1d",
        [&] {
          AT_DISPATCH_NLL_LOSS_INDEX_TYPES(
              target.scalar_type(),
              "nll_loss_forward_reduce_cuda_kernel_1d_index",
              [&] {
                nll_loss_forward_reduce_cuda_kernel_1d<scalar_t, index_t>
                    <<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
                        output.mutable_data_ptr<scalar_t>(),
                        total_weight.mutable_data_ptr<scalar_t>(),
                        input.const_data_ptr<scalar_t>(),
                        target.const_data_ptr<index_t>(),
                        weight_.defined() ? weight_.const_data_ptr<scalar_t>()
                                          : nullptr,
                        reduction == at::Reduction::Mean,
                        n_classes,
                        ignore_index);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              });
        });
  } else if (n_dims == 2) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss_forward_reduce_cuda_kernel_2d",
        [&] {
          AT_DISPATCH_NLL_LOSS_INDEX_TYPES(
              target.scalar_type(),
              "nll_loss_forward_reduce_cuda_kernel_2d_index",
              [&] {
                using accscalar_t = at::acc_type<scalar_t, /*is_cuda*/true>;
                nll_loss_forward_reduce_cuda_kernel_2d<scalar_t, accscalar_t, index_t>
                    <<<1,
                       NLL_LOSS_THREADS,
                       0,
                       at::cuda::getCurrentCUDAStream()>>>(
                        output.mutable_data_ptr<scalar_t>(),
                        total_weight.mutable_data_ptr<scalar_t>(),
                        input.const_data_ptr<scalar_t>(),
                        target.const_data_ptr<index_t>(),
                        weight_.defined() ? weight_.const_data_ptr<scalar_t>()
                                          : nullptr,
                        reduction == at::Reduction::Mean,
                        input.size(0),
                        input.size(1),
                        n_classes,
                        ignore_index);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              });
        });
  }
}

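// reduction == None, 2-D input: scatter -weight * grad_output[i] into the
// target column of each row of grad_input; rows with ignore_index are left
// untouched (grad_input is zeroed by the caller).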
template <typename scalar_t, typename index_t>
__global__ void nll_loss_backward_no_reduce_cuda_kernel(
  int batch_size,
  const index_t *target,
  PackedTensorAccessor64<const scalar_t, 1> grad_output,
  PackedTensorAccessor64<scalar_t, 2> grad_input,
  const scalar_t *weights,
  int64_t n_classes,
  int64_t ignore_index) {

  CUDA_KERNEL_LOOP(index, batch_size) {
    index_t cur_target = target[index];
    if (cur_target == ignore_index) {
      continue;
    }
    CHECK_INDEX_IN_CLASS(cur_target, n_classes);
    scalar_t weight = weights != nullptr ? weights[cur_target] : static_cast<scalar_t>(1);
    grad_input[index][cur_target] = -weight * grad_output[index];
  }
};

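// Reduction backward for a single (1-D) sample: one thread writes the
// gradient for the target class, normalizing by *total_weight when
// size_average is set.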
template <typename scalar_t, typename index_t>
__global__ void nll_loss_backward_reduce_cuda_kernel_1d(
  scalar_t *grad_input,
  const scalar_t *grad_output,
  const scalar_t *weights,
  const index_t *target,
  const scalar_t *total_weight,
  bool size_average,
  int64_t n_classes,
  int64_t ignore_index
) {
  const index_t t = *target;
  if (t != ignore_index) {
    CHECK_INDEX_IN_CLASS(t, n_classes);
    const auto grad = -(size_average ? *grad_output / *total_weight : *grad_output);
    grad_input[t] = weights != nullptr ? weights[t] * grad : grad;
  }
}

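// Index type used for the flattened grad_input offset in the 2-D backward
// kernel below. int64_t targets are widened to uint64_t because, as noted in
// the kernel, `i * ndim + t` could overflow a signed 64-bit index.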
template <typename T> struct bwd_index_type { using type = T; };
template<> struct bwd_index_type<uint8_t> { using type = int; };
template<> struct bwd_index_type<int64_t> { using type = uint64_t; };

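// Reduction backward for a 2-D batch: a single block of NLL_LOSS_THREADS
// threads strides over the batch and scatters the (optionally weighted,
// optionally total_weight-normalized) gradient into each row's target column.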
template <typename scalar_t, typename index_t>
__global__ void nll_loss_backward_reduce_cuda_kernel_2d(
    scalar_t* grad_input,
    const scalar_t* grad_output,
    const index_t* target,
    const scalar_t* weights,
    const scalar_t* total_weight,
    bool size_average,
    int nframe,
    int ndim,
    int64_t n_classes,
    int64_t ignore_index) {
  using bwd_index_t = typename bwd_index_type<index_t>::type;
  const auto grad = -(size_average ? *grad_output / *total_weight
                                   : *grad_output);

  for (int i = threadIdx.x; i < nframe; i += NLL_LOSS_THREADS) {
    const index_t t = target[i];
    if (t != ignore_index) {
      CHECK_INDEX_IN_CLASS(t, n_classes);
      // NOTE(crcrpar): this index could overflow in int64_t as `t` itself can be close to the max.
      const bwd_index_t index = static_cast<bwd_index_t>(i) * ndim + t;
      if constexpr(!std::is_unsigned<decltype(index)>::value) {
        CUDA_KERNEL_ASSERT(index >= 0);
      }
      grad_input[index] = weights != nullptr ? weights[t] * grad : grad;
    }
  }
}

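// Host-side driver for the backward pass. Mirrors the forward dispatch:
// no-reduction 2-D inputs get one thread per batch element, the reducing
// 1-D case a single thread, and the reducing 2-D case a single block of
// NLL_LOSS_THREADS threads.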
void nll_loss_backward_out_cuda_template(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    const Tensor& input_,
    const Tensor& target_,
    const Tensor& total_weight,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  auto target = *target_.expect_contiguous();
  auto input = *input_.expect_contiguous();
  auto grad_input = *grad_input_.expect_contiguous();
  auto grad_output = *grad_output_.expect_contiguous();

  int64_t n_dims = input.dim();
  int64_t n_classes = input.size(-1);
  int64_t batch_size = n_dims == 1 ? 1 : input.size(0);

  auto weight_ = weight.defined() ? weight.contiguous() : weight;

  if (reduction == at::Reduction::None && n_dims == 2) {
    if (batch_size == 0) {
      // This guards from unnecessary operations and launching CUDA kernel with 0 blocks.
      return;
    }
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss_backward_no_reduce_cuda_kernel",
        [&] {
          AT_DISPATCH_NLL_LOSS_INDEX_TYPES(
              target.scalar_type(),
              "nll_loss_backward_no_reduce_cuda_kernel_index",
              [&] {
                nll_loss_backward_no_reduce_cuda_kernel<scalar_t, index_t>
                    <<<at::cuda::detail::GET_BLOCKS(batch_size),
                       at::cuda::detail::CUDA_NUM_THREADS,
                       0,
                       at::cuda::getCurrentCUDAStream()>>>(
                        batch_size,
                        target.const_data_ptr<index_t>(),
                        grad_output.packed_accessor64<const scalar_t, 1>(),
                        grad_input.packed_accessor64<scalar_t, 2>(),
                        weight.defined() ? weight_.const_data_ptr<scalar_t>() : nullptr,
                        n_classes,
                        ignore_index);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              });
        });
    return;
  }

  if (n_dims == 1) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss_backward_reduce_cuda_kernel_1d",
        [&] {
          AT_DISPATCH_NLL_LOSS_INDEX_TYPES(
              target.scalar_type(),
              "nll_loss_backward_reduce_cuda_kernel_1d_index",
              [&] {
                nll_loss_backward_reduce_cuda_kernel_1d<scalar_t, index_t>
                    <<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input.mutable_data_ptr<scalar_t>(),
                        grad_output.const_data_ptr<scalar_t>(),
                        weight.defined() ? weight_.const_data_ptr<scalar_t>()
                                         : nullptr,
                        target.const_data_ptr<index_t>(),
                        total_weight.const_data_ptr<scalar_t>(),
                        reduction == at::Reduction::Mean,
                        n_classes,
                        ignore_index);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              });
        });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss_backward_reduce_cuda_kernel_2d",
        [&] {
          AT_DISPATCH_NLL_LOSS_INDEX_TYPES(
              target.scalar_type(),
              "nll_loss_backward_reduce_cuda_kernel_2d_index",
              [&] {
            nll_loss_backward_reduce_cuda_kernel_2d<scalar_t, index_t>
                <<<1, NLL_LOSS_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
                    grad_input.mutable_data_ptr<scalar_t>(),
                    grad_output.const_data_ptr<scalar_t>(),
                    target.const_data_ptr<index_t>(),
                    weight.defined() ? weight_.const_data_ptr<scalar_t>() : nullptr,
                    total_weight.const_data_ptr<scalar_t>(),
                    reduction == at::Reduction::Mean,
                    input.size(0),
                    input.size(1),
                    n_classes,
                    ignore_index);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          });
        });
  }
}

#undef AT_DISPATCH_NLL_LOSS_INDEX_TYPES

} // namespace

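// Structured-kernel entry points: unwrap the optional weight and forward to
// the templates above.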
TORCH_IMPL_FUNC(nll_loss_forward_out_cuda)
(const Tensor& self,
 const Tensor& target,
 const OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& output,
 const Tensor& total_weight) {
  const Tensor& weight = weight_opt.getTensorRef();
  nll_loss_forward_out_cuda_template(
      output, total_weight, self, target, weight, reduction, ignore_index);
}

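// The backward entry point zeroes grad_input first because the kernels above
// only write the entries corresponding to each sample's target class.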
TORCH_IMPL_FUNC(nll_loss_backward_out_cuda)
(const Tensor& grad_output,
 const Tensor& self,
 const Tensor& target,
 OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& total_weight,
 const Tensor& grad_input) {
  const Tensor& weight = weight_opt.getTensorRef();
  grad_input.zero_();
  nll_loss_backward_out_cuda_template(
      grad_input,
      grad_output,
      self,
      target,
      total_weight,
      weight,
      reduction,
      ignore_index);
}
}  // namespace at::native