#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/TensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/Resize.h>

#include <type_traits>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/binary_cross_entropy_backward_native.h>
#include <ATen/ops/binary_cross_entropy_native.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/exp.h>
#include <ATen/ops/nll_loss_backward_native.h>
#include <ATen/ops/nll_loss_forward_native.h>
#include <ATen/ops/squeeze.h>
#endif

constexpr float EPSILON = 1e-12;

namespace {

using namespace at;

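// Elementwise backward for binary cross entropy:
//   d(loss)/d(input) = grad * (input - target) / max(input * (1 - input), EPSILON)
// EPSILON keeps the denominator away from zero when input is exactly 0 or 1.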
void binary_cross_entropy_backward_out_kernel(Tensor& grad_input, const Tensor& grad, const Tensor& input, const Tensor& target) {
  at::TensorIterator iter = TensorIteratorConfig()
      .add_output(grad_input)
      .add_input(grad)
      .add_input(input)
      .add_input(target)
      .build();
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "binary_cross_entropy_backward_out_cuda", [&]() {
    at::native::gpu_kernel(iter, [] GPU_LAMBDA (
        scalar_t grad_val,
        scalar_t input_val,
        scalar_t target_val
      ) -> scalar_t {
        const scalar_t one = 1;
        const scalar_t epsilon = EPSILON;

        scalar_t grad_input_denominator = max(
            (one - input_val) * input_val,
            epsilon
        );

        return grad_val * (input_val - target_val) / grad_input_denominator;
      }
    );
  });
}

} // namespace

namespace at::native {

Tensor binary_cross_entropy_cuda(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  Tensor loss = at::empty_like(input);
  return at::native::binary_cross_entropy_out_cuda(
      input, target, weight, reduction, loss);
}

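// Elementwise forward pass:
//   loss = -(target * log(input) + (1 - target) * log(1 - input)),
// with both log terms clamped at -100 so inputs of exactly 0 or 1 give a
// finite loss. The optional weight and the reduction are applied afterwards.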
Tensor& binary_cross_entropy_out_cuda(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction, Tensor& loss) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  Tensor loss_squeezed = at::squeeze(loss);

  TensorIterator iter = TensorIteratorConfig()
      .add_output(loss_squeezed)
      .add_owned_input(at::squeeze(input))
      .add_owned_input(at::squeeze(target))
      .build();
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "binary_cross_entropy_out_cuda", [&]() {
    gpu_kernel(iter,
      [] GPU_LAMBDA (scalar_t input_val, scalar_t target_val) -> scalar_t {
        const scalar_t zero = 0;
        const scalar_t one = 1;
        const scalar_t neg_100 = -100;

        CUDA_KERNEL_ASSERT(input_val >= zero && input_val <= one);
        CUDA_KERNEL_ASSERT(target_val >= zero && target_val <= one);

        scalar_t log_input_val = std::log(input_val);
        scalar_t log_1_minus_input_val = std::log1p(-input_val);

        log_input_val = std::max(log_input_val, neg_100);
        log_1_minus_input_val = std::max(log_1_minus_input_val, neg_100);

        return ((target_val - one) * log_1_minus_input_val) - (target_val * log_input_val);
      }
    );
  });
  if (weight.defined()) {
    loss.mul_(weight);
  }

  if (reduction != at::Reduction::None) {
    Tensor loss_reduced;
    if (reduction == at::Reduction::Mean) {
      loss_reduced = loss.mean();
    } else if (reduction == at::Reduction::Sum) {
      loss_reduced = loss.sum();
    }
    loss.resize_as_(loss_reduced).copy_(loss_reduced);
  }

  return loss;
}

Tensor binary_cross_entropy_backward_cuda(const Tensor& grad, const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  Tensor grad_input = at::empty_like(input);
  return at::native::binary_cross_entropy_backward_out_cuda(
      grad, input, target, weight, reduction, grad_input);
}

Tensor& binary_cross_entropy_backward_out_cuda(const Tensor& grad, const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction, Tensor& grad_input) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  Tensor grad_expand = grad.expand_as(input);
  binary_cross_entropy_backward_out_kernel(grad_input, grad_expand, input, target);

  if (weight.defined()) {
    grad_input.mul_(weight);
  }
  if (reduction == at::Reduction::Mean) {
    grad_input.div_(input.numel());
  }
  return grad_input;
}

// -----------------------------------
// nll_loss
// -----------------------------------
namespace {

constexpr int NLL_LOSS_THREADS = 32;

// NOTE(crcrpar): `Byte` support was added for https://github.com/pytorch/pytorch/issues/59765.
#define AT_DISPATCH_NLL_LOSS_INDEX_TYPES(TYPE, NAME, ...)                          \
  AT_DISPATCH_SWITCH(TYPE, NAME,                                                   \
      AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Byte, index_t, __VA_ARGS__)  \
      AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Long, index_t, __VA_ARGS__))

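// Asserts that a target index lies in [0, N_CLASSES). For unsigned index types
// the `>= 0` half of the check is dropped to avoid a tautological comparison.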
#define CHECK_INDEX_IN_CLASS(INDEX, N_CLASSES)              \
  if constexpr (std::is_unsigned<decltype(INDEX)>::value) { \
    CUDA_KERNEL_ASSERT(INDEX < N_CLASSES);                  \
  } else {                                                  \
    CUDA_KERNEL_ASSERT(INDEX >= 0 && INDEX < N_CLASSES);    \
  }

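// reduction == None: one output element per batch entry,
//   output[i] = -weight[target[i]] * input[i][target[i]],
// and 0 for entries whose target equals ignore_index.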
template <typename scalar_t, typename index_t>
__global__ void nll_loss_forward_no_reduce_cuda_kernel(
    int64_t batch_size,
    PackedTensorAccessor64<scalar_t, 2> input,
    const index_t* target,
    scalar_t* output,
    const scalar_t* weights,
    int64_t n_classes,
    int64_t ignore_index) {
  CUDA_KERNEL_LOOP(index, batch_size) {
    index_t cur_target = target[index];
    if (cur_target == ignore_index) {
      output[index] = static_cast<scalar_t>(0);
      continue;
    }
    CHECK_INDEX_IN_CLASS(cur_target, n_classes);
    auto cur_weight =
        weights != nullptr ? weights[cur_target] : static_cast<scalar_t>(1);
    output[index] = -cur_weight * input[index][cur_target];
  }
}

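// Reduction for a 1-D (single sample) input; launched with a single thread.
// Writes the loss and the total weight; for mean reduction the loss is
// normalized by the class weight.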
template <typename scalar_t, typename index_t>
__global__ void nll_loss_forward_reduce_cuda_kernel_1d(
    scalar_t* output,
    scalar_t* total_weight,
    const scalar_t* input,
    const index_t* target,
    const scalar_t* weights,
    bool size_average,
    int64_t n_classes,
    int64_t ignore_index) {
  CUDA_KERNEL_ASSERT(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0);

  const index_t t = *target;
  if (t != ignore_index) {
    CHECK_INDEX_IN_CLASS(t, n_classes);
    const auto cur_weight = weights != nullptr ? weights[t] : scalar_t{1};
    *total_weight = cur_weight;

    if (size_average) {
      // If we try to normalize a zero then we return a NaN
      if (cur_weight == 0) {
        *output = std::numeric_limits<scalar_t>::quiet_NaN();
      } else {
        *output = -input[t];
      }
    } else {
      *output = -cur_weight * input[t];
    }
  } else {
    // If the only element was omitted, we get 0. See the discussion in
    // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    *output = scalar_t{0};
    *total_weight = scalar_t{0};
  }
}

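// Reduction for a 2-D (batched) input; launched as a single block of
// NLL_LOSS_THREADS threads. Each thread accumulates partial sums of the
// weighted losses and weights in shared memory, and thread 0 combines them.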
template <typename scalar_t, typename accscalar_t, typename index_t>
__global__ void nll_loss_forward_reduce_cuda_kernel_2d(
    scalar_t* output,
    scalar_t* total_weight,
    const scalar_t* input,
    const index_t* target,
    const scalar_t* weights,
    bool size_average,
    int64_t nframe,
    int64_t ndim,
    int64_t n_classes,
    int64_t ignore_index) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  __shared__ accscalar_t sh_inputs[NLL_LOSS_THREADS],
      acc_weight[NLL_LOSS_THREADS];

  sh_inputs[threadIdx.x] = static_cast<accscalar_t>(0);
  acc_weight[threadIdx.x] = static_cast<accscalar_t>(0);
  for (int i = threadIdx.x; i < nframe; i += NLL_LOSS_THREADS) {
    index_t t = target[i];
    if (t != ignore_index) {
      CHECK_INDEX_IN_CLASS(t, n_classes);
      scalar_t cur_weight =
          weights != nullptr ? weights[t] : static_cast<scalar_t>(1);
      sh_inputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
      acc_weight[threadIdx.x] += cur_weight;
    }
  }

  __syncthreads();

  if (threadIdx.x == 0) {
    accscalar_t output_acc = 0;
    accscalar_t total_weight_acc = 0;
    for (int i = 0; i < NLL_LOSS_THREADS; ++i) {
      output_acc += sh_inputs[i];
      total_weight_acc += acc_weight[i];
    }
    *total_weight = static_cast<scalar_t>(total_weight_acc);
    if (size_average) {
      *output = static_cast<scalar_t>(output_acc / total_weight_acc);
    } else {
      *output = static_cast<scalar_t>(output_acc);
    }
  }
}

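// Resizes output/total_weight and dispatches to one of the kernels above based
// on the reduction mode and the dimensionality of the input (1-D or 2-D).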
void nll_loss_forward_out_cuda_template(
    const Tensor& output,
    const Tensor& total_weight,
    const Tensor& input_,
    const Tensor& target_,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  auto input = *input_.expect_contiguous();
  auto target = *target_.expect_contiguous();

  int64_t n_classes = input.size(-1);
  int64_t n_dims = input.dim();
  int64_t batch_size = n_dims == 1 ? 1 : input.size(0);

  auto weight_ = weight.defined() ? weight.contiguous() : weight;

  if (reduction == Reduction::None && n_dims == 2) {
    at::native::resize_output(output, {batch_size});
    total_weight.zero_();
    if (batch_size == 0) {
      // This guards from unnecessary operations and launching CUDA kernel with
      // 0 blocks.
      return;
    }

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss_forward_no_reduce_cuda_kernel",
        [&] {
          AT_DISPATCH_NLL_LOSS_INDEX_TYPES(
              target.scalar_type(),
              "nll_loss_forward_no_reduce_cuda_kernel_index",
              [&] {
                nll_loss_forward_no_reduce_cuda_kernel<scalar_t, index_t>
                    <<<at::cuda::detail::GET_BLOCKS(batch_size),
                       at::cuda::detail::CUDA_NUM_THREADS,
                       0,
                       at::cuda::getCurrentCUDAStream()>>>(
                        batch_size,
                        input.packed_accessor64<scalar_t, 2>(),
                        target.const_data_ptr<index_t>(),
                        output.mutable_data_ptr<scalar_t>(),
                        weight_.defined() ? weight_.const_data_ptr<scalar_t>()
                                          : nullptr,
                        n_classes,
                        ignore_index);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              });
        });
    return;
  }

  // produce scalar outputs for the reduction case
  at::native::resize_output(output, {});
  total_weight.resize_({});

  if (target.numel() == 0) {
    // Here target (and input) have zero elements
    // Mean reduction on empty tensors produces NaN. See the discussion in
    // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if (reduction == Reduction::Mean) {
      output.fill_(std::numeric_limits<double>::quiet_NaN());
    } else {
      output.zero_();
    }
    total_weight.zero_();
    return;
  }

  if (n_dims == 1) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss_forward_reduce_cuda_kernel_1d",
        [&] {
          AT_DISPATCH_NLL_LOSS_INDEX_TYPES(
              target.scalar_type(),
              "nll_loss_forward_reduce_cuda_kernel_1d_index",
              [&] {
                nll_loss_forward_reduce_cuda_kernel_1d<scalar_t, index_t>
                    <<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
                        output.mutable_data_ptr<scalar_t>(),
                        total_weight.mutable_data_ptr<scalar_t>(),
                        input.const_data_ptr<scalar_t>(),
                        target.const_data_ptr<index_t>(),
                        weight_.defined() ? weight_.const_data_ptr<scalar_t>()
                                          : nullptr,
                        reduction == at::Reduction::Mean,
                        n_classes,
                        ignore_index);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              });
        });
  } else if (n_dims == 2) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss_forward_reduce_cuda_kernel_2d",
        [&] {
          AT_DISPATCH_NLL_LOSS_INDEX_TYPES(
              target.scalar_type(),
              "nll_loss_forward_reduce_cuda_kernel_2d_index",
              [&] {
                using accscalar_t = at::acc_type<scalar_t, /*is_cuda*/true>;
                nll_loss_forward_reduce_cuda_kernel_2d<scalar_t, accscalar_t, index_t>
                    <<<1,
                       NLL_LOSS_THREADS,
                       0,
                       at::cuda::getCurrentCUDAStream()>>>(
                        output.mutable_data_ptr<scalar_t>(),
                        total_weight.mutable_data_ptr<scalar_t>(),
                        input.const_data_ptr<scalar_t>(),
                        target.const_data_ptr<index_t>(),
                        weight_.defined() ? weight_.const_data_ptr<scalar_t>()
                                          : nullptr,
                        reduction == at::Reduction::Mean,
                        input.size(0),
                        input.size(1),
                        n_classes,
                        ignore_index);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              });
        });
  }
}

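// reduction == None: scatter -weight[target[i]] * grad_output[i] into
// grad_input[i][target[i]]; rows whose target equals ignore_index are skipped
// (grad_input is zero-initialized by the caller).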
template <typename scalar_t, typename index_t>
__global__ void nll_loss_backward_no_reduce_cuda_kernel(
    int batch_size,
    const index_t* target,
    PackedTensorAccessor64<const scalar_t, 1> grad_output,
    PackedTensorAccessor64<scalar_t, 2> grad_input,
    const scalar_t* weights,
    int64_t n_classes,
    int64_t ignore_index) {
  CUDA_KERNEL_LOOP(index, batch_size) {
    index_t cur_target = target[index];
    if (cur_target == ignore_index) {
      continue;
    }
    CHECK_INDEX_IN_CLASS(cur_target, n_classes);
    scalar_t weight = weights != nullptr ? weights[cur_target] : static_cast<scalar_t>(1);
    grad_input[index][cur_target] = -weight * grad_output[index];
  }
}

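// Backward of the 1-D reduction: a single gradient value, scaled by the class
// weight and, for mean reduction, divided by the total weight.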
template <typename scalar_t, typename index_t>
__global__ void nll_loss_backward_reduce_cuda_kernel_1d(
    scalar_t* grad_input,
    const scalar_t* grad_output,
    const scalar_t* weights,
    const index_t* target,
    const scalar_t* total_weight,
    bool size_average,
    int64_t n_classes,
    int64_t ignore_index) {
  const index_t t = *target;
  if (t != ignore_index) {
    CHECK_INDEX_IN_CLASS(t, n_classes);
    const auto grad = -(size_average ? *grad_output / *total_weight : *grad_output);
    grad_input[t] = weights != nullptr ? weights[t] * grad : grad;
  }
}

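// Index type used when addressing grad_input in the 2-D backward kernel.
// A wider or unsigned type is chosen so that `i * ndim + t` does not overflow
// the original index type (see the NOTE inside the kernel below).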
template <typename T> struct bwd_index_type { using type = T; };
template<> struct bwd_index_type<uint8_t> { using type = int; };
template<> struct bwd_index_type<int64_t> { using type = uint64_t; };

template <typename scalar_t, typename index_t>
__global__ void nll_loss_backward_reduce_cuda_kernel_2d(
    scalar_t* grad_input,
    const scalar_t* grad_output,
    const index_t* target,
    const scalar_t* weights,
    const scalar_t* total_weight,
    bool size_average,
    int nframe,
    int ndim,
    int64_t n_classes,
    int64_t ignore_index) {
  using bwd_index_t = typename bwd_index_type<index_t>::type;
  const auto grad = -(size_average ? *grad_output / *total_weight
                                   : *grad_output);

  for (int i = threadIdx.x; i < nframe; i += NLL_LOSS_THREADS) {
    const index_t t = target[i];
    if (t != ignore_index) {
      CHECK_INDEX_IN_CLASS(t, n_classes);
      // NOTE(crcrpar): this index could overflow in int64_t as `t` itself can be close to the max.
      const bwd_index_t index = static_cast<bwd_index_t>(i) * ndim + t;
      if constexpr (!std::is_unsigned<decltype(index)>::value) {
        CUDA_KERNEL_ASSERT(index >= 0);
      }
      grad_input[index] = weights != nullptr ? weights[t] * grad : grad;
    }
  }
}

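// Mirrors nll_loss_forward_out_cuda_template: picks the matching backward
// kernel based on the reduction mode and input dimensionality.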
void nll_loss_backward_out_cuda_template(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    const Tensor& input_,
    const Tensor& target_,
    const Tensor& total_weight,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  auto target = *target_.expect_contiguous();
  auto input = *input_.expect_contiguous();
  auto grad_input = *grad_input_.expect_contiguous();
  auto grad_output = *grad_output_.expect_contiguous();

  int64_t n_dims = input.dim();
  int64_t n_classes = input.size(-1);
  int64_t batch_size = n_dims == 1 ? 1 : input.size(0);

  auto weight_ = weight.defined() ? weight.contiguous() : weight;

  if (reduction == at::Reduction::None && n_dims == 2) {
    if (batch_size == 0) {
      // This guards from unnecessary operations and launching CUDA kernel with 0 blocks.
      return;
    }
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss_backward_no_reduce_cuda_kernel",
        [&] {
          AT_DISPATCH_NLL_LOSS_INDEX_TYPES(
              target.scalar_type(),
              "nll_loss_backward_no_reduce_cuda_kernel_index",
              [&] {
                nll_loss_backward_no_reduce_cuda_kernel<scalar_t, index_t>
                    <<<at::cuda::detail::GET_BLOCKS(batch_size),
                       at::cuda::detail::CUDA_NUM_THREADS,
                       0,
                       at::cuda::getCurrentCUDAStream()>>>(
                        batch_size,
                        target.const_data_ptr<index_t>(),
                        grad_output.packed_accessor64<const scalar_t, 1>(),
                        grad_input.packed_accessor64<scalar_t, 2>(),
                        weight.defined() ? weight_.const_data_ptr<scalar_t>() : nullptr,
                        n_classes,
                        ignore_index);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              });
        });
    return;
  }

  if (n_dims == 1) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss_backward_reduce_cuda_kernel_1d",
        [&] {
          AT_DISPATCH_NLL_LOSS_INDEX_TYPES(
              target.scalar_type(),
              "nll_loss_backward_reduce_cuda_kernel_1d_index",
              [&] {
                nll_loss_backward_reduce_cuda_kernel_1d<scalar_t, index_t>
                    <<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input.mutable_data_ptr<scalar_t>(),
                        grad_output.const_data_ptr<scalar_t>(),
                        weight.defined() ? weight_.const_data_ptr<scalar_t>()
                                         : nullptr,
                        target.const_data_ptr<index_t>(),
                        total_weight.const_data_ptr<scalar_t>(),
                        reduction == at::Reduction::Mean,
                        n_classes,
                        ignore_index);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              });
        });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss_backward_reduce_cuda_kernel_2d",
        [&] {
          AT_DISPATCH_NLL_LOSS_INDEX_TYPES(
              target.scalar_type(),
              "nll_loss_backward_reduce_cuda_kernel_2d_index",
              [&] {
                nll_loss_backward_reduce_cuda_kernel_2d<scalar_t, index_t>
                    <<<1, NLL_LOSS_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input.mutable_data_ptr<scalar_t>(),
                        grad_output.const_data_ptr<scalar_t>(),
                        target.const_data_ptr<index_t>(),
                        weight.defined() ? weight_.const_data_ptr<scalar_t>() : nullptr,
                        total_weight.const_data_ptr<scalar_t>(),
                        reduction == at::Reduction::Mean,
                        input.size(0),
                        input.size(1),
                        n_classes,
                        ignore_index);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              });
        });
  }
}

#undef AT_DISPATCH_NLL_LOSS_INDEX_TYPES

} // namespace

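// Structured-kernel entry points registered for CUDA: they unpack the optional
// weight and forward to the templates above.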
TORCH_IMPL_FUNC(nll_loss_forward_out_cuda)
(const Tensor& self,
 const Tensor& target,
 const OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& output,
 const Tensor& total_weight) {
  const Tensor& weight = weight_opt.getTensorRef();
  nll_loss_forward_out_cuda_template(
      output, total_weight, self, target, weight, reduction, ignore_index);
}

TORCH_IMPL_FUNC(nll_loss_backward_out_cuda)
(const Tensor& grad_output,
 const Tensor& self,
 const Tensor& target,
 OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& total_weight,
 const Tensor& grad_input) {
  const Tensor& weight = weight_opt.getTensorRef();
  grad_input.zero_();
  nll_loss_backward_out_cuda_template(
      grad_input,
      grad_output,
      self,
      target,
      total_weight,
      weight,
      reduction,
      ignore_index);
}
} // namespace at::native
627