#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <c10/macros/Macros.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/block_reduce.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/CUDAFunctions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/sum_cuda_dispatch.h>
#include <ATen/ops/multilabel_margin_loss.h>
#endif

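// CUDA forward/backward kernels for multilabel_margin_loss.
//
// For a single sample x (size `dim`) with target indices y (padded with -1
// after the last valid label), the kernels below compute
//   loss(x, y) = sum_{j, d} max(0, 1 - (x[y[j]] - x[d])) / dim,
// where d ranges over the non-target classes; with Reduction::Mean the
// per-sample losses are additionally averaged over the batch. The `is_target`
// buffer records which classes are labels and is reused by the backward pass.
//
// Illustrative call through the public op (a sketch, not part of this file;
// the shapes and label values are arbitrary):
//   auto input  = at::randn({4, 10}, at::kCUDA);
//   auto target = at::full({4, 10}, -1, input.options().dtype(at::kLong));
//   target.select(1, 0).fill_(3);  // each row's only label is class 3
//   auto loss = at::multilabel_margin_loss(input, target, at::Reduction::Mean);
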
namespace at::native {

namespace {
const int MULTILABELMARGIN_THREADS = 128;

void multilabel_margin_loss_shape_check(
    int64_t& nframe,
    int64_t& dim,
    const int64_t& ndims,
    const Tensor& input,
    const Tensor& target) {
  TORCH_CHECK(
      (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
      "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
      input.sizes());

  if (ndims <= 1) {
    nframe = 1;
    dim = ndims == 0 ? 1 : input.size(0);
    TORCH_CHECK(
        target.dim() <= 1 && target.numel() == dim,
        "inconsistent target size: ", target.sizes(), " for input of size: ",
        input.sizes());
  } else {
    nframe = input.size(0);
    dim = input.size(1);
    TORCH_CHECK(
        target.dim() == 2 && target.size(0) == nframe &&
            target.size(1) == dim,
        "inconsistent target size: ", target.sizes(), " for input of size: ",
        input.sizes());
  }
}

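// Forward kernel: one CUDA block per sample (blockIdx.x), with the block's
// threads striding over the `dim` classes. Each thread accumulates its share
// of the hinge terms in `accscalar_t`, and the block-wide total is obtained
// with cuda_utils::BlockReduceSum before thread 0 writes the per-sample loss.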
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(MULTILABELMARGIN_THREADS)
__global__ void multilabel_margin_loss_forward_kernel(
    scalar_t* output,
    const scalar_t* input,
    const int64_t* target,
    scalar_t* is_target,
    int nframe,
    int dim,
    bool size_average) {

  // vectors:
  int k = blockIdx.x;
  const scalar_t* input_k = input + k * dim;
  const int64_t* target_k = target + k * dim;
  scalar_t* output_k = output + k;
  scalar_t* is_target_k = is_target + k * dim;

  // zero is_target
  for (int d = threadIdx.x; d < dim; d += blockDim.x) {
    is_target_k[d] = static_cast<scalar_t>(0);
  }
  __syncthreads();

  // mark targets in is_target
  if (threadIdx.x == 0) {
    for (int dt = 0; dt < dim; dt++) {
      int target_idx = target_k[dt];
      if (target_idx < 0) {
        break;
      }
      is_target_k[target_idx] = static_cast<scalar_t>(1);
    }
  }
  __syncthreads();

  // iterate over targets
  accscalar_t sum = 0;
  for (int dt = 0; dt < dim; dt++) {
    // next target:
    int target_idx = target_k[dt];
    if (target_idx < 0) {
      break;
    }

    // current value for target
    scalar_t input_target_k = input_k[target_idx];

    // compare to all inputs (multithreaded):
    for (int d = threadIdx.x; d < dim; d += blockDim.x) {
      // contribute to loss only if not a target
      if (!static_cast<int>(is_target_k[d])) {
        scalar_t z = 1 - input_target_k + input_k[d];
        if (z > 0) {
          sum += z;
        }
      }
    }
  }

  // Temporary sums (for mapreduce)
  __shared__ accscalar_t smem[MULTILABELMARGIN_THREADS];
  accscalar_t total_sum = cuda_utils::BlockReduceSum(sum, smem);
  if (threadIdx.x == 0) {
    if (size_average) {
      *output_k = static_cast<scalar_t>((total_sum / dim) / nframe);
    } else {
      *output_k = static_cast<scalar_t>(total_sum / dim);
    }
  }
}

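// Backward kernel: mirrors the forward pass. For every (target, non-target)
// pair that violates the margin, the non-target entry of grad_input gains +g
// and the target entry accumulates -g (collected via a block reduction), where
// g = 1/(nframe*dim) when averaging over the batch (Reduction::Mean) and 1/dim
// otherwise. The result is finally scaled by grad_output, which is a scalar
// when the loss was reduced and a per-sample value otherwise.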
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(MULTILABELMARGIN_THREADS)
__global__ void multilabel_margin_loss_backward_kernel(
    scalar_t* grad_input,
    const scalar_t* grad_output,
    const scalar_t* input,
    const int64_t* target,
    const scalar_t* is_target,
    int nframe,
    int dim,
    bool size_average,
    bool reduce) {

  int k = blockIdx.x;
  const scalar_t* input_k = input + k * dim;
  scalar_t* grad_input_k = grad_input + k * dim;
  const int64_t* target_k = target + k * dim;
  const scalar_t* is_target_k = is_target + k * dim;

  const scalar_t* grad_output_k = grad_output;
  if (!reduce) {
    grad_output_k += k;
  }

  // gain:
  scalar_t g = static_cast<scalar_t>(
      size_average && reduce ? 1. / static_cast<accscalar_t>(nframe * dim)
                             : 1. / static_cast<accscalar_t>(dim));

  // zero gradients:
  for (int d = threadIdx.x; d < dim; d += blockDim.x) {
    grad_input_k[d] = static_cast<scalar_t>(0);
  }
  __syncthreads();

  // iterate over targets
  for (int dt = 0; dt < dim; dt++) {
    // next target:
    int target_idx = static_cast<int>(target_k[dt]);
    if (target_idx < 0) {
      break;
    }

    // current value for target
    scalar_t input_target_k = input_k[target_idx];

    // compare to all inputs (multithreaded):
    accscalar_t sum = 0;
    for (int d = threadIdx.x; d < dim; d += blockDim.x) {
      // contribute to loss only if not a target
      if (!static_cast<int>(is_target_k[d])) {
        scalar_t z = 1 - input_target_k + input_k[d];
        if (z > 0) {
          sum -= g;
          grad_input_k[d] += g;
        }
      }
    }
    __syncthreads();

    // Temporary sums (for mapreduce)
    __shared__ accscalar_t smem[MULTILABELMARGIN_THREADS];
    accscalar_t total_sum = cuda_utils::BlockReduceSum(sum, smem);
    if (threadIdx.x == 0) {
      grad_input_k[target_idx] += static_cast<scalar_t>(total_sum);
    }
  }

  for (int d = threadIdx.x; d < dim; d += blockDim.x) {
    grad_input_k[d] *= *grad_output_k;
  }
}

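// Host-side forward: validates shapes, makes inputs contiguous, and launches
// the kernel. 0-D/1-D inputs are treated as a single sample; 2-D inputs get
// one block per row. When a reduction is requested, per-row losses are written
// to a temporary buffer and combined with at::cuda::sum_out (for Mean the
// kernel already divides by nframe and dim, so the sum of rows is the mean).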
void multilabel_margin_loss_forward_out_cuda_template(
    const Tensor& input,
    const Tensor& target,
    int64_t reduction,
    Tensor& output,
    Tensor& is_target) {
  int64_t nframe, dim;
  const int64_t ndims = input.dim();
  multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target);

  if (input.numel() == 0) {
    return;
  }

  auto input_ = input.contiguous();
  auto target_ = target.contiguous();
  auto is_target_ = is_target.contiguous();
  is_target_.resize_as_(target);

  if (input.dim() <= 1) {
    output.resize_({});

    dim3 blocks(1);
    dim3 threads(MULTILABELMARGIN_THREADS);

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "multilabel_margin_loss_forward_kernel",
        [&] {
          using accscalar_t = at::acc_type<scalar_t, true>;
          multilabel_margin_loss_forward_kernel<scalar_t, accscalar_t>
              <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                  output.mutable_data_ptr<scalar_t>(),
                  input_.const_data_ptr<scalar_t>(),
                  target_.const_data_ptr<int64_t>(),
                  is_target_.mutable_data_ptr<scalar_t>(),
                  1,
                  dim,
                  reduction == at::Reduction::Mean);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
  } else if (input.dim() == 2) {
    dim3 blocks(input.size(0));
    dim3 threads(MULTILABELMARGIN_THREADS);

    if (reduction != at::Reduction::None) {
      auto output_tmp = at::empty({input_.size(0)}, input_.options());
      output.resize_({});
      AT_DISPATCH_FLOATING_TYPES_AND2(
          at::ScalarType::Half,
          at::ScalarType::BFloat16,
          input.scalar_type(),
          "multilabel_margin_loss_forward_kernel",
          [&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            multilabel_margin_loss_forward_kernel<scalar_t, accscalar_t>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                    output_tmp.mutable_data_ptr<scalar_t>(),
                    input_.const_data_ptr<scalar_t>(),
                    target_.const_data_ptr<int64_t>(),
                    is_target_.mutable_data_ptr<scalar_t>(),
                    nframe,
                    dim,
                    reduction == at::Reduction::Mean);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          });
      at::cuda::sum_out(
          output,
          output_tmp,
          at::IntArrayRef(std::vector<int64_t>{}),
          false,
          output.scalar_type());
    } else {
      output.resize_({input.size(0)});
      AT_DISPATCH_FLOATING_TYPES_AND2(
          at::ScalarType::Half,
          at::ScalarType::BFloat16,
          input.scalar_type(),
          "multilabel_margin_loss_forward_kernel",
          [&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            multilabel_margin_loss_forward_kernel<scalar_t, accscalar_t>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                    output.mutable_data_ptr<scalar_t>(),
                    input_.const_data_ptr<scalar_t>(),
                    target_.const_data_ptr<int64_t>(),
                    is_target_.mutable_data_ptr<scalar_t>(),
                    nframe,
                    dim,
                    false);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          });
    }

  } else {
    TORCH_CHECK(
        false,
        "Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ",
        input.sizes());
  }
}

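// Host-side backward: same shape handling as the forward path, plus checks
// that `is_target` matches the target's sizes, since the backward kernel
// relies on the mask produced by the forward kernel.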
void multilabel_margin_loss_backward_cuda_out_template(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target,
    Tensor& grad_input) {
  int64_t nframe, dim;
  const int64_t ndims = input.dim();
  multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target);

  if (input.numel() == 0) {
    return;
  }

  auto input_ = input.contiguous();
  auto target_ = target.contiguous();
  auto is_target_ = is_target.contiguous();
  auto grad_output_ = grad_output.contiguous();
  grad_input.resize_as_(input_);

  if (grad_input.dim() <= 1) {
    int target_size = target_.dim() == 0 ? 1 : target_.size(0);
    TORCH_CHECK(
        (target_.numel() != 0) && (target_.dim() <= 1) && (target_size == dim),
        "inconsistent target size");
    TORCH_CHECK(
        target_.sizes() == is_target_.sizes(), "inconsistent is_target size");
    dim3 blocks(1);
    dim3 threads(MULTILABELMARGIN_THREADS);

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "multilabel_margin_loss_backward_kernel",
        [&] {
          using accscalar_t = at::acc_type<scalar_t, true>;
          multilabel_margin_loss_backward_kernel<scalar_t, accscalar_t>
              <<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
                  grad_input.mutable_data_ptr<scalar_t>(),
                  grad_output_.const_data_ptr<scalar_t>(),
                  input_.const_data_ptr<scalar_t>(),
                  target_.const_data_ptr<int64_t>(),
                  is_target_.const_data_ptr<scalar_t>(),
                  1,
                  dim,
                  reduction == at::Reduction::Mean,
                  reduction != at::Reduction::None);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
  } else if (grad_input.dim() == 2) {
    TORCH_CHECK(
        (input_.size(1) != 0) && (target_.dim() == 2) &&
            (target_.size(0) == nframe) && (target_.size(1) == dim),
        "inconsistent target size");
    TORCH_CHECK(target_.sizes() == is_target_.sizes(), "inconsistent is_target size");
    dim3 blocks(grad_input.size(0));
    dim3 threads(MULTILABELMARGIN_THREADS);

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "multilabel_margin_loss_backward_kernel",
        [&] {
          using accscalar_t = at::acc_type<scalar_t, true>;
          multilabel_margin_loss_backward_kernel<scalar_t, accscalar_t>
              <<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
                  grad_input.mutable_data_ptr<scalar_t>(),
                  grad_output_.const_data_ptr<scalar_t>(),
                  input_.const_data_ptr<scalar_t>(),
                  target_.const_data_ptr<int64_t>(),
                  is_target_.const_data_ptr<scalar_t>(),
                  grad_input.size(0),
                  grad_input.size(1),
                  reduction == at::Reduction::Mean,
                  reduction != at::Reduction::None);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
  } else {
    TORCH_CHECK(
        false,
        "Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ",
        grad_input.sizes());
  }
}

} // namespace

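// Public ATen entry points. These are the functions the dispatcher routes the
// CUDA multilabel_margin_loss_forward / _backward calls to (their bindings are
// declared in native_functions.yaml); they allocate outputs when needed and
// defer to the templates above.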
std::tuple<Tensor&, Tensor&> multilabel_margin_loss_forward_out_cuda(
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    Tensor& output,
    Tensor& is_target) {
  multilabel_margin_loss_forward_out_cuda_template(
      self, target, reduction, output, is_target);
  return std::tuple<Tensor&, Tensor&>(output, is_target);
}

std::tuple<Tensor, Tensor> multilabel_margin_loss_forward_cuda(
    const Tensor& self,
    const Tensor& target,
    int64_t reduction) {
  auto output = at::empty({0}, self.options());
  auto is_target = at::empty({0}, self.options());
  multilabel_margin_loss_forward_out_cuda_template(
      self, target, reduction, output, is_target);
  return std::make_tuple(output, is_target);
}

Tensor& multilabel_margin_loss_backward_cuda_out(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target,
    Tensor& grad_input) {
  multilabel_margin_loss_backward_cuda_out_template(
      grad_output, self, target, reduction, is_target, grad_input);
  return grad_input;
}

Tensor multilabel_margin_loss_backward_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target) {
  auto grad_input = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  multilabel_margin_loss_backward_cuda_out_template(
      grad_output, self, target, reduction, is_target, grad_input);
  return grad_input;
}

} // namespace at::native