#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/Resize.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/multi_margin_loss_native.h>
#include <ATen/ops/multi_margin_loss_backward_native.h>
#endif

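// CUDA implementation of multi-margin loss, a multi-class hinge loss:
//   loss(x, y) = sum_{i != y} max(0, margin - x[y] + x[i])^p / dim
// with p in {1, 2}, an optional per-class weight applied via the target
// class, and an optional mean over the batch (which divides by nframe * dim
// instead of dim).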
namespace at::native {
namespace {
constexpr int MULTIMARGIN_THREADS = 128;

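// Forward kernel: one block per sample, MULTIMARGIN_THREADS threads striding
// over the dim classes. Each thread accumulates its share of the hinge terms
// into shared memory; thread 0 then reduces the partial sums and writes the
// per-sample loss.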
template <int P, typename scalar_t>
__global__ void MultiMarginLoss_forward_kernel(
    scalar_t *output, const scalar_t *input, const int64_t *target, const scalar_t *weights,
    int nframe, int dim, bool sizeAverage, scalar_t margin) {
  using acc_t = at::acc_type<scalar_t, true>;
  __shared__ acc_t buffer[MULTIMARGIN_THREADS];
  int k = blockIdx.x;
  const scalar_t *input_k = input + k*dim;
  scalar_t *output_k = output + k;
  int target_k = static_cast<int>(target[k]);
  CUDA_KERNEL_ASSERT(target_k >= 0 && target_k < dim && "target index is out of bounds");
  scalar_t input_target_k = input_k[target_k];

  int i_start = threadIdx.x;
  int i_end = dim;
  int i_step = blockDim.x;

  buffer[threadIdx.x] = 0;
  for (int i = i_start; i < i_end; i += i_step) {
    scalar_t z = margin - input_target_k + input_k[i];
    if (i == target_k) {
      continue;
    }

    if (z > 0) {
      scalar_t h = (P==1) ? z : z*z;
      if (weights) {
        h *= weights[target_k];
      }
      buffer[threadIdx.x] += h;
    }
  }
  __syncthreads();

  // reduce
  if (threadIdx.x == 0) {
    acc_t sum = 0;
    for (int i=0; i < blockDim.x; i++)
      sum += buffer[i];

    const int denom = sizeAverage ? nframe * dim : dim;
    *output_k = static_cast<scalar_t>(sum / denom);
  }
}

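// Backward kernel: same block/thread layout as the forward pass. Each thread
// writes the gradient for its classes (g, or 2*g*z for P == 2, for every
// violated margin) and accumulates the negated contribution for the target
// class in shared memory; thread 0 reduces that into gradInput[target].
// Finally every element is scaled by grad_output, which is a scalar when
// `reduce` is true and per-sample otherwise.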
template <int P, typename scalar_t>
__global__ void MultiMarginLoss_backward_kernel(
    scalar_t *gradInput, const scalar_t *gradOutput, const scalar_t *input, const int64_t *target,
    const scalar_t *weights, int nframe, int dim, bool sizeAverage, scalar_t margin,
    bool reduce) {
  using acc_t = at::acc_type<scalar_t, true>;
  __shared__ acc_t buffer[MULTIMARGIN_THREADS];
  int k = blockIdx.x;
  const scalar_t *input_k = input + k*dim;
  scalar_t *gradInput_k = gradInput + k*dim;
  int target_k = static_cast<int>(target[k]);
  scalar_t input_target_k = input_k[target_k];

  const scalar_t *gradOutput_k = gradOutput;
  if (!reduce) {
    gradOutput_k += k;
  }

  const int denom = sizeAverage && reduce ? nframe * dim : dim;
  const acc_t g = acc_t(1) / static_cast<acc_t>(denom);

  int i_start = threadIdx.x;
  int i_end = dim;
  int i_step = blockDim.x;

  buffer[threadIdx.x] = 0;
  for (int i=i_start; i<i_end; i+=i_step) {
    scalar_t z = margin - input_target_k + input_k[i];
    if (i == target_k) {
      continue;
    }

    if (z > 0) {
      acc_t h = (P == 1) ? g : 2*g*z;
      if (weights) {
        h *= weights[target_k];
      }

      buffer[threadIdx.x] -= static_cast<scalar_t>(h);
      gradInput_k[i] = static_cast<scalar_t>(h);
    } else {
      gradInput_k[i] = static_cast<scalar_t>(0);
    }
  }

  __syncthreads();

  // reduce
  if (threadIdx.x == 0) {
    acc_t gradInput_target_k = 0;
    for (int i=0; i<blockDim.x; i++) {
      gradInput_target_k += buffer[i];
    }
    gradInput_k[target_k] = static_cast<scalar_t>(gradInput_target_k);
  }

  // Make thread 0's write to gradInput_k[target_k] visible before all
  // threads scale the row by grad_output.
  __syncthreads();

  for (int i=i_start; i<i_end; i+= i_step) {
    gradInput_k[i] *= *gradOutput_k;
  }
}

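// Validates input/target/weight shapes and derives (nframe, dim): a 0-dim or
// 1-dim input is treated as a single sample (nframe = 1), a 2-dim input as a
// batch of nframe samples with dim classes each.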
void multi_margin_loss_shape_check(
    int64_t& nframe,
    int64_t& dim,
    const int64_t& ndims,
    const Tensor& input,
    const Tensor& target,
    const std::optional<Tensor>& weight) {
    TORCH_CHECK(
        (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
        "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
        input.sizes());

    if (ndims <= 1) {
      nframe = 1;
      dim = ndims == 0 ? 1 : input.size(0);
    } else {
      nframe = input.size(0);
      dim = input.size(1);
    }

    TORCH_CHECK(
        target.dim() <= 1 && target.numel() == nframe,
        "inconsistent target size, expected ", nframe, " but got ",
        target.sizes());
    if (weight && weight->defined()) {
      TORCH_CHECK(
          weight->dim() <= 1 && weight->numel() == dim,
          "inconsistent weight size, expected ", dim, " but got ",
          weight->sizes());
    }
}

}  // namespace (anonymous)

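// Forward entry point: checks shapes, makes all operands contiguous,
// launches one block per sample, and for Mean/Sum reductions sums the
// per-sample losses into a 0-dim output.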
Tensor& multi_margin_loss_cuda_out(
    const Tensor &input_, const Tensor &target_, const Scalar &p_, const Scalar &margin_,
    const std::optional<Tensor> &weights_, int64_t reduction, Tensor& out_) {
  auto p = p_.toLong();
  int64_t nframe, dim;
  const auto ndims = input_.dim();

  TORCH_CHECK(p == 1 || p == 2, "multi_margin_loss: Invalid p, expected 1 or 2 but got ", p);

  multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_, weights_);

  // produce a scalar output for 1d input
  if (reduction == Reduction::None && target_.dim() > 0) {
    resize_output(out_, {nframe});
  } else {
    resize_output(out_, {});
  }
  if (input_.numel() == 0) {
    return out_;
  }

  auto input = input_.contiguous();
  auto target = target_.contiguous();
  Tensor weights;
  if (weights_ && weights_->defined()) {
    weights = weights_->contiguous();
  }
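  // Write directly into out_ when it is contiguous; otherwise compute into a
  // contiguous temporary that is copied back after the kernels run.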
  auto out = (out_.is_contiguous() ? out_ :
              at::empty(out_.sizes(), input.options()));

  const auto stream = c10::cuda::getCurrentCUDAStream();

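  // Dispatch over floating-point types (including half and bfloat16); the
  // kernels accumulate in at::acc_type for precision.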
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "multi_margin_loss_cuda", [&] {
    const scalar_t margin = margin_.to<scalar_t>();
    if (input.dim() <= 1) {
      TORCH_CHECK(target.dim() <= 1 && target.numel() == nframe, "inconsistent target size");
      dim3 blocks(1);
      dim3 threads(MULTIMARGIN_THREADS);
      if (p == 1) {
        MultiMarginLoss_forward_kernel<1> <<<blocks, threads, 0, stream>>>(
            out.mutable_data_ptr<scalar_t>(),
            input.const_data_ptr<scalar_t>(),
            target.const_data_ptr<int64_t>(),
            weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
            1,
            input.dim() < 1 ? input.numel() : input.sizes()[0],
            reduction == at::Reduction::Mean,
            margin);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (p == 2) {
        MultiMarginLoss_forward_kernel<2> <<<blocks, threads, 0, stream>>>(
            out.mutable_data_ptr<scalar_t>(),
            input.const_data_ptr<scalar_t>(),
            target.const_data_ptr<int64_t>(),
            weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
            1,
            input.dim() < 1 ? input.numel() : input.sizes()[0],
            reduction == at::Reduction::Mean,
            margin);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    } else {
      auto in_sizes = input.sizes();
      TORCH_INTERNAL_ASSERT(in_sizes.size() == 2);
      // allow zero-dim target for 2D input.
      TORCH_CHECK(in_sizes[1] != 0 && target.dim() <= 1 && target.numel() == nframe,
                "inconsistent target size");
      dim3 blocks(nframe);
      dim3 threads(MULTIMARGIN_THREADS);

      if (reduction == at::Reduction::None) {
        if (p == 1) {
          MultiMarginLoss_forward_kernel<1> <<<blocks, threads, 0, stream>>>(
              out.mutable_data_ptr<scalar_t>(),
              input.const_data_ptr<scalar_t>(),
              target.const_data_ptr<int64_t>(),
              weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
              nframe, in_sizes[1],
              false,
              margin);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        } else if (p == 2) {
          MultiMarginLoss_forward_kernel<2> <<<blocks, threads, 0, stream>>>(
              out.mutable_data_ptr<scalar_t>(),
              input.const_data_ptr<scalar_t>(),
              target.const_data_ptr<int64_t>(),
              weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
              nframe, in_sizes[1],
              false,
              margin);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      } else {
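        // Reduced case: compute per-sample losses into a temporary, then sum.
        // For Mean the kernel already divides each sample's loss by
        // nframe * dim, so summing tmp_output yields the batch mean.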
        auto tmp_output = at::empty({nframe}, input.options());
        if (p == 1) {
          MultiMarginLoss_forward_kernel<1> <<<blocks, threads, 0, stream>>>(
              tmp_output.mutable_data_ptr<scalar_t>(),
              input.const_data_ptr<scalar_t>(),
              target.const_data_ptr<int64_t>(),
              weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
              nframe, in_sizes[1],
              reduction == Reduction::Mean,
              margin);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        } else if (p == 2) {
          MultiMarginLoss_forward_kernel<2> <<<blocks, threads, 0, stream>>>(
              tmp_output.mutable_data_ptr<scalar_t>(),
              input.const_data_ptr<scalar_t>(),
              target.const_data_ptr<int64_t>(),
              weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
              nframe, in_sizes[1],
              reduction == Reduction::Mean,
              margin);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
        at::sum_out(out, tmp_output, IntArrayRef{});
      }
    }
  });

  if (!out.is_alias_of(out_)) {
    out_.copy_(out);
  }
  return out_;
}

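// Functional (non-out) variant: allocates the output and forwards to the out
// overload above.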
Tensor multi_margin_loss_cuda(
    const Tensor &input, const Tensor &target, const Scalar &p, const Scalar &margin,
    const std::optional<Tensor> &weights, int64_t reduction) {
  auto out = at::empty({0}, input.options());
  multi_margin_loss_cuda_out(input, target, p, margin, weights, reduction, out);
  return out;
}

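// Backward entry point: mirrors the forward launch structure. `reduce` tells
// the kernel whether grad_output is a 0-dim scalar (Mean/Sum) or a
// per-sample vector (None); grad_input is computed into a contiguous buffer
// and copied back if grad_input_ is not contiguous.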
Tensor& multi_margin_loss_cuda_backward_out(
    const Tensor &grad_output_, const Tensor &input_, const Tensor &target_,
    const Scalar &p_, const Scalar &margin_, const std::optional<Tensor> &weights_,
    int64_t reduction, Tensor &grad_input_) {
  auto p = p_.toLong();
  int64_t nframe, dim;
  const auto ndims = input_.dim();

  TORCH_CHECK(p == 1 || p == 2,
              "multi_margin_loss_backward: Invalid p, expected 1 or 2 but got ", p);

  multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_, weights_);
  resize_output(grad_input_, input_.sizes());

  if (input_.numel() == 0) {
    return grad_input_;
  }

  auto input = input_.contiguous();
  auto grad_input = (grad_input_.is_contiguous() ? grad_input_ :
                     at::empty(grad_input_.sizes(), input.options()));
  auto grad_output = grad_output_.contiguous();
  auto target = target_.contiguous();
  Tensor weights;
  if (weights_ && weights_->defined()) {
    weights = weights_->contiguous();
  }

  const auto stream = c10::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
                                  "multi_margin_loss_backward_cuda", [&] {
    const scalar_t margin = margin_.to<scalar_t>();

    if (input.dim() <= 1) {
      dim3 blocks(1);
      dim3 threads(MULTIMARGIN_THREADS);

      if (p == 1) {
        MultiMarginLoss_backward_kernel<1> <<<blocks, threads, 0, stream>>>(
            grad_input.mutable_data_ptr<scalar_t>(),
            grad_output.const_data_ptr<scalar_t>(),
            input.const_data_ptr<scalar_t>(),
            target.const_data_ptr<int64_t>(),
            weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
            1,
            input.dim() == 0 ? 1 : input.sizes()[0],
            reduction == at::Reduction::Mean,
            margin,
            reduction != at::Reduction::None);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (p == 2) {
        MultiMarginLoss_backward_kernel<2> <<<blocks, threads, 0, stream>>>(
            grad_input.mutable_data_ptr<scalar_t>(),
            grad_output.const_data_ptr<scalar_t>(),
            input.const_data_ptr<scalar_t>(),
            target.const_data_ptr<int64_t>(),
            weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
            1,
            input.dim() == 0 ? 1 : input.sizes()[0],
            reduction == at::Reduction::Mean,
            margin,
            reduction != at::Reduction::None);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    } else {
      auto in_sizes = input.sizes();
      TORCH_INTERNAL_ASSERT(in_sizes.size() == 2);
      TORCH_CHECK((in_sizes[1] != 0) && (target.dim() <= 1) && (target.numel() == nframe),
                  "inconsistent target size");
      dim3 blocks(in_sizes[0]);
      dim3 threads(MULTIMARGIN_THREADS);

      if (p == 1) {
        MultiMarginLoss_backward_kernel<1> <<<blocks, threads, 0, stream>>>(
            grad_input.mutable_data_ptr<scalar_t>(),
            grad_output.const_data_ptr<scalar_t>(),
            input.const_data_ptr<scalar_t>(),
            target.const_data_ptr<int64_t>(),
            weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
            nframe, in_sizes[1],
            reduction == at::Reduction::Mean,
            margin,
            reduction != at::Reduction::None);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (p == 2) {
        MultiMarginLoss_backward_kernel<2> <<<blocks, threads, 0, stream>>>(
            grad_input.mutable_data_ptr<scalar_t>(),
            grad_output.const_data_ptr<scalar_t>(),
            input.const_data_ptr<scalar_t>(),
            target.const_data_ptr<int64_t>(),
            weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
            nframe, in_sizes[1],
            reduction == at::Reduction::Mean,
            margin,
            reduction != at::Reduction::None);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    }
  });

  if (!grad_input.is_alias_of(grad_input_)) {
    grad_input_.copy_(grad_input);
  }
  return grad_input_;
}

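// Functional (non-out) variant of the backward pass.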
Tensor multi_margin_loss_cuda_backward(
    const Tensor &grad_output, const Tensor &input, const Tensor &target,
    const Scalar &p, const Scalar &margin, const std::optional<Tensor> &weights,
    int64_t reduction) {
  auto grad_input = at::empty({0}, input.options());
  multi_margin_loss_cuda_backward_out(
      grad_output, input, target, p, margin, weights, reduction, grad_input);
  return grad_input;
}

}  // namespace at::native