#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <c10/macros/Macros.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/block_reduce.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/CUDAFunctions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/sum_cuda_dispatch.h>
#include <ATen/ops/multilabel_margin_loss.h>
#endif


namespace at::native {

namespace {
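// Threads per CUDA block. Both kernels below are launched with one block per
// sample, so this also bounds the parallelism available within a sample.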
const int MULTILABELMARGIN_THREADS = 128;

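// Checks that input/target shapes are consistent and derives nframe (batch
// size) and dim (number of classes). A 0-dim or 1-D input is treated as a
// single sample; a 2-D input is a [nframe, dim] batch whose target must have
// the same shape.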
void multilabel_margin_loss_shape_check(
    int64_t& nframe,
    int64_t& dim,
    const int64_t& ndims,
    const Tensor& input,
    const Tensor& target) {
    TORCH_CHECK(
        (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
        "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
        input.sizes());

    if (ndims <= 1) {
      nframe = 1;
      dim = ndims == 0 ? 1 : input.size(0);
      TORCH_CHECK(
          target.dim() <= 1 && target.numel() == dim,
          "inconsistent target size: ", target.sizes(), " for input of size: ",
          input.sizes());
    } else {
      nframe = input.size(0);
      dim = input.size(1);
      TORCH_CHECK(
          target.dim() == 2 && target.size(0) == nframe &&
          target.size(1) == dim,
          "inconsistent target size: ", target.sizes(), " for input of size: ",
          input.sizes());
    }
}

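// Forward kernel: one block per sample k. The target row lists class indices
// and is terminated by the first negative entry; those classes are marked in
// is_target (which is also returned and reused by the backward pass). The
// per-sample loss is
//   sum_{j, i} max(0, 1 - input_k[target_k[j]] + input_k[i]) / dim
// over valid targets j and non-target classes i, accumulated across the block
// with BlockReduceSum. When size_average is set, the value is additionally
// divided by nframe so that summing the per-block outputs yields the batch
// mean.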
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(MULTILABELMARGIN_THREADS)
__global__ void multilabel_margin_loss_forward_kernel(
    scalar_t* output,
    const scalar_t* input,
    const int64_t* target,
    scalar_t* is_target,
    int nframe,
    int dim,
    bool size_average) {

  // vectors:
  int k = blockIdx.x;
  const scalar_t* input_k = input + k * dim;
  const int64_t* target_k = target + k * dim;
  scalar_t* output_k = output + k;
  scalar_t* is_target_k = is_target + k * dim;

  // zero is_target
  for (int d = threadIdx.x; d < dim; d += blockDim.x) {
    is_target_k[d] = static_cast<scalar_t>(0);
  }
  __syncthreads();

  // mark targets in is_target
  if (threadIdx.x == 0) {
    for (int dt = 0; dt < dim; dt++) {
      int target_idx = target_k[dt];
      if (target_idx < 0) {
        break;
      }
      is_target_k[target_idx] = static_cast<scalar_t>(1);
    }
  }
  __syncthreads();

  // iterate over targets
  accscalar_t sum = 0;
  for (int dt = 0; dt < dim; dt++) {
    // next target:
    int target_idx = target_k[dt];
    if (target_idx < 0) {
      break;
    }

    // current value for target
    scalar_t input_target_k = input_k[target_idx];

    // compare to all inputs (multithreaded):
    for (int d = threadIdx.x; d < dim; d += blockDim.x) {
      // contribute to loss only if not a target
      if (!static_cast<int>(is_target_k[d])) {
        scalar_t z = 1 - input_target_k + input_k[d];
        if (z > 0) {
          sum += z;
        }
      }
    }
  }

  // Temporary sums (for mapreduce)
  __shared__ accscalar_t smem[MULTILABELMARGIN_THREADS];
  accscalar_t total_sum = cuda_utils::BlockReduceSum(sum, smem);
  if (threadIdx.x == 0) {
    if (size_average) {
      *output_k = static_cast<scalar_t>((total_sum / dim) / nframe);
    } else {
      *output_k = static_cast<scalar_t>(total_sum / dim);
    }
  }
}

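// Backward kernel: one block per sample k. For each valid target j and each
// non-target class d whose margin term is positive, the gradient gains +g at d
// and -g at the target class (accumulated via a block-wide reduction), where
// g = 1/dim, or 1/(nframe*dim) when the loss was averaged over the batch. The
// accumulated gradient is finally scaled by grad_output, which is a single
// scalar when reduce is true and one value per sample otherwise.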
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(MULTILABELMARGIN_THREADS)
__global__ void multilabel_margin_loss_backward_kernel(
    scalar_t* grad_input,
    const scalar_t* grad_output,
    const scalar_t* input,
    const int64_t* target,
    const scalar_t* is_target,
    int nframe,
    int dim,
    bool size_average,
    bool reduce) {

  int k = blockIdx.x;
  const scalar_t* input_k = input + k * dim;
  scalar_t* grad_input_k = grad_input + k * dim;
  const int64_t* target_k = target + k * dim;
  const scalar_t* is_target_k = is_target + k * dim;

  const scalar_t* grad_output_k = grad_output;
  if (!reduce) {
    grad_output_k += k;
  }

  // gain:
  scalar_t g = static_cast<scalar_t>(
      size_average && reduce ? 1. / static_cast<accscalar_t>(nframe * dim)
                             : 1. / static_cast<accscalar_t>(dim));

  // zero gradients:
  for (int d = threadIdx.x; d < dim; d += blockDim.x) {
    grad_input_k[d] = static_cast<scalar_t>(0);
  }
  __syncthreads();

  // iterate over targets
  for (int dt = 0; dt < dim; dt++) {
    // next target:
    int target_idx = static_cast<int>(target_k[dt]);
    if (target_idx < 0) {
      break;
    }

    // current value for target
    scalar_t input_target_k = input_k[target_idx];

    // compare to all inputs (multithreaded):
    accscalar_t sum = 0;
    for (int d = threadIdx.x; d < dim; d += blockDim.x) {
      // contribute to loss only if not a target
      if (!static_cast<int>(is_target_k[d])) {
        scalar_t z = 1 - input_target_k + input_k[d];
        if (z > 0) {
          sum -= g;
          grad_input_k[d] += g;
        }
      }
    }
    __syncthreads();

    // Temporary sums (for mapreduce)
    __shared__ accscalar_t smem[MULTILABELMARGIN_THREADS];
    accscalar_t total_sum = cuda_utils::BlockReduceSum(sum, smem);
    if (threadIdx.x == 0) {
      grad_input_k[target_idx] += static_cast<scalar_t>(total_sum);
    }
  }

  for (int d = threadIdx.x; d < dim; d += blockDim.x) {
    grad_input_k[d] *= *grad_output_k;
  }
}

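// Host-side launcher for the forward pass. Validates shapes, makes contiguous
// copies of input/target/is_target, dispatches over floating types (including
// Half and BFloat16), and launches one block per sample. When a reduction is
// requested for a 2-D input, the kernel first writes one partial value per
// sample into a temporary buffer, which is then summed into the scalar output
// with at::cuda::sum_out.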
void multilabel_margin_loss_forward_out_cuda_template(
    const Tensor& input,
    const Tensor& target,
    int64_t reduction,
    Tensor& output,
    Tensor& is_target) {
  int64_t nframe, dim;
  const int64_t ndims = input.dim();
  multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target);

  if (input.numel() == 0) {
    return;
  }

  auto input_ = input.contiguous();
  auto target_ = target.contiguous();
  auto is_target_ = is_target.contiguous();
  is_target_.resize_as_(target);

  if (input.dim() <= 1) {
    output.resize_({});

    dim3 blocks(1);
    dim3 threads(MULTILABELMARGIN_THREADS);

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "multilabel_margin_loss_forward_kernel",
        [&] {
          using accscalar_t = at::acc_type<scalar_t, true>;
          multilabel_margin_loss_forward_kernel<scalar_t, accscalar_t>
              <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                  output.mutable_data_ptr<scalar_t>(),
                  input_.const_data_ptr<scalar_t>(),
                  target_.const_data_ptr<int64_t>(),
                  is_target_.mutable_data_ptr<scalar_t>(),
                  1,
                  dim,
                  reduction == at::Reduction::Mean);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
  } else if (input.dim() == 2) {
    dim3 blocks(input.size(0));
    dim3 threads(MULTILABELMARGIN_THREADS);

    if (reduction != at::Reduction::None) {
      auto output_tmp = at::empty({input_.size(0)}, input_.options());
      output.resize_({});
      AT_DISPATCH_FLOATING_TYPES_AND2(
          at::ScalarType::Half,
          at::ScalarType::BFloat16,
          input.scalar_type(),
          "multilabel_margin_loss_forward_kernel",
          [&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            multilabel_margin_loss_forward_kernel<scalar_t, accscalar_t>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                    output_tmp.mutable_data_ptr<scalar_t>(),
                    input_.const_data_ptr<scalar_t>(),
                    target_.const_data_ptr<int64_t>(),
                    is_target_.mutable_data_ptr<scalar_t>(),
                    nframe,
                    dim,
                    reduction == at::Reduction::Mean);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          });
      at::cuda::sum_out(
          output,
          output_tmp,
          at::IntArrayRef(std::vector<int64_t>{}),
          false,
          output.scalar_type());
    } else {
      output.resize_({input.size(0)});
      AT_DISPATCH_FLOATING_TYPES_AND2(
          at::ScalarType::Half,
          at::ScalarType::BFloat16,
          input.scalar_type(),
          "multilabel_margin_loss_forward_kernel",
          [&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            multilabel_margin_loss_forward_kernel<scalar_t, accscalar_t>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                    output.mutable_data_ptr<scalar_t>(),
                    input_.const_data_ptr<scalar_t>(),
                    target_.const_data_ptr<int64_t>(),
                    is_target_.mutable_data_ptr<scalar_t>(),
                    nframe,
                    dim,
                    false);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          });
    }

  } else {
    TORCH_CHECK(
        false,
        "Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ",
        input.sizes());
  }
}

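// Host-side launcher for the backward pass. Mirrors the forward launcher:
// re-checks the target/is_target shapes, resizes grad_input to match the
// (contiguous) input, and launches one block per sample.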
void multilabel_margin_loss_backward_cuda_out_template(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target,
    Tensor& grad_input) {
  int64_t nframe, dim;
  const int64_t ndims = input.dim();
  multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target);

  if (input.numel() == 0) {
    return;
  }

  auto input_ = input.contiguous();
  auto target_ = target.contiguous();
  auto is_target_ = is_target.contiguous();
  auto grad_output_ = grad_output.contiguous();
  grad_input.resize_as_(input_);

  if (grad_input.dim() <= 1) {
    int target_size = target_.dim() == 0 ? 1 : target_.size(0);
    TORCH_CHECK(
        (target_.numel() != 0) && (target_.dim() <= 1) && (target_size == dim),
        "inconsistent target size");
    TORCH_CHECK(
        target_.sizes() == is_target_.sizes(), "inconsistent is_target size");
    dim3 blocks(1);
    dim3 threads(MULTILABELMARGIN_THREADS);

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "multilabel_margin_loss_backward_kernel",
        [&] {
          using accscalar_t = at::acc_type<scalar_t, true>;
          multilabel_margin_loss_backward_kernel<scalar_t, accscalar_t>
              <<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
                  grad_input.mutable_data_ptr<scalar_t>(),
                  grad_output_.const_data_ptr<scalar_t>(),
                  input_.const_data_ptr<scalar_t>(),
                  target_.const_data_ptr<int64_t>(),
                  is_target_.const_data_ptr<scalar_t>(),
                  1,
                  dim,
                  reduction == at::Reduction::Mean,
                  reduction != at::Reduction::None);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
  } else if (grad_input.dim() == 2) {
    TORCH_CHECK(
        (input_.size(1) != 0) && (target_.dim() == 2) &&
            (target_.size(0) == nframe) && (target_.size(1) == dim),
        "inconsistent target size");
    TORCH_CHECK(target_.sizes() == is_target_.sizes(), "inconsistent is_target size");
    dim3 blocks(grad_input.size(0));
    dim3 threads(MULTILABELMARGIN_THREADS);

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "multilabel_margin_loss_backward_kernel",
        [&] {
          using accscalar_t = at::acc_type<scalar_t, true>;
          multilabel_margin_loss_backward_kernel<scalar_t, accscalar_t>
              <<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
                  grad_input.mutable_data_ptr<scalar_t>(),
                  grad_output_.const_data_ptr<scalar_t>(),
                  input_.const_data_ptr<scalar_t>(),
                  target_.const_data_ptr<int64_t>(),
                  is_target_.const_data_ptr<scalar_t>(),
                  grad_input.size(0),
                  grad_input.size(1),
                  reduction == at::Reduction::Mean,
                  reduction != at::Reduction::None);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
  } else {
    TORCH_CHECK(
        false,
        "Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ",
        grad_input.sizes());
  }
}

} // namespace

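// Public CUDA entry points: the out-variants write into the provided tensors,
// the functional variants allocate their results; all of them forward to the
// launchers above.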
std::tuple<Tensor&, Tensor&> multilabel_margin_loss_forward_out_cuda(
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    Tensor& output,
    Tensor& is_target) {
  multilabel_margin_loss_forward_out_cuda_template(
      self, target, reduction, output, is_target);
  return std::tuple<Tensor&, Tensor&>(output, is_target);
}

std::tuple<Tensor, Tensor> multilabel_margin_loss_forward_cuda(
    const Tensor& self,
    const Tensor& target,
    int64_t reduction) {
  auto output = at::empty({0}, self.options());
  auto is_target = at::empty({0}, self.options());
  multilabel_margin_loss_forward_out_cuda_template(
      self, target, reduction, output, is_target);
  return std::make_tuple(output, is_target);
}

Tensor& multilabel_margin_loss_backward_cuda_out(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target,
    Tensor& grad_input) {
  multilabel_margin_loss_backward_cuda_out_template(
      grad_output, self, target, reduction, is_target, grad_input);
  return grad_input;
}

Tensor multilabel_margin_loss_backward_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target) {
  auto grad_input = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  multilabel_margin_loss_backward_cuda_out_template(
      grad_output, self, target, reduction, is_target, grad_input);
  return grad_input;
}

} // namespace at::native