#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/Resize.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/multi_margin_loss_native.h>
#include <ATen/ops/multi_margin_loss_backward_native.h>
#endif

namespace at::native {
namespace {
constexpr int MULTIMARGIN_THREADS = 128;

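// Forward kernel: one CUDA block per sample. Threads stride over the `dim`
// classes, accumulating the hinge terms max(0, margin - x[target] + x[i])
// (squared when P == 2, optionally scaled by weights[target]) into shared
// memory; thread 0 then sums the per-thread partials and writes the
// per-sample loss, dividing by `dim` (or by `nframe * dim` when
// size-averaging over the batch).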
template <int P, typename scalar_t>
__global__ void MultiMarginLoss_forward_kernel(
    scalar_t *output, const scalar_t *input, const int64_t *target, const scalar_t *weights,
    int nframe, int dim, bool sizeAverage, scalar_t margin) {
  using acc_t = at::acc_type<scalar_t, true>;
  __shared__ acc_t buffer[MULTIMARGIN_THREADS];
  int k = blockIdx.x;
  const scalar_t *input_k = input + k*dim;
  scalar_t *output_k = output + k;
  int target_k = static_cast<int>(target[k]);
  CUDA_KERNEL_ASSERT(target_k >= 0 && target_k < dim && "target index is out of bounds");
  scalar_t input_target_k = input_k[target_k];

  int i_start = threadIdx.x;
  int i_end = dim;
  int i_step = blockDim.x;

  buffer[threadIdx.x] = 0;
  for (int i = i_start; i < i_end; i += i_step) {
    scalar_t z = margin - input_target_k + input_k[i];
    if (i == target_k) {
      continue;
    }

    if (z > 0) {
      scalar_t h = (P==1) ? z : z*z;
      if (weights) {
        h *= weights[target_k];
      }
      buffer[threadIdx.x] += h;
    }
  }
  __syncthreads();

  // reduce
  if (threadIdx.x == 0) {
    acc_t sum = 0;
    for (int i=0; i < blockDim.x; i++)
      sum += buffer[i];

    const int denom = sizeAverage ? nframe * dim : dim;
    *output_k = static_cast<scalar_t>(sum / denom);
  }
}

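// Backward kernel: mirrors the forward pass. For every non-target class whose
// hinge term is active, the gradient w.r.t. that class is g (or 2*g*z for
// P == 2, optionally scaled by weights[target]), and the same amount is
// subtracted from the target class's gradient, which thread 0 accumulates
// from shared memory. Finally every entry is scaled by the incoming
// grad_output (a scalar when `reduce` is true, per-sample otherwise).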
template <int P, typename scalar_t>
__global__ void MultiMarginLoss_backward_kernel(
    scalar_t *gradInput, const scalar_t *gradOutput, const scalar_t *input, const int64_t *target,
    const scalar_t *weights, int nframe, int dim, bool sizeAverage, scalar_t margin,
    bool reduce) {
  using acc_t = at::acc_type<scalar_t, true>;
  __shared__ acc_t buffer[MULTIMARGIN_THREADS];
  int k = blockIdx.x;
  const scalar_t *input_k = input + k*dim;
  scalar_t *gradInput_k = gradInput + k*dim;
  int target_k = static_cast<int>(target[k]);
  scalar_t input_target_k = input_k[target_k];

  const scalar_t *gradOutput_k = gradOutput;
  if (!reduce) {
    gradOutput_k += k;
  }

  const int denom = sizeAverage && reduce ? nframe * dim : dim;
  const acc_t g = acc_t(1) / static_cast<acc_t>(denom);

  int i_start = threadIdx.x;
  int i_end = dim;
  int i_step = blockDim.x;

  buffer[threadIdx.x] = 0;
  for (int i=i_start; i<i_end; i+=i_step) {
    scalar_t z = margin - input_target_k + input_k[i];
    if (i == target_k) {
      continue;
    }

    if (z > 0) {
      acc_t h = (P == 1) ? g : 2*g*z;
      if (weights) {
        h *= weights[target_k];
      }

      buffer[threadIdx.x] -= static_cast<scalar_t>(h);
      gradInput_k[i] = static_cast<scalar_t>(h);
    } else {
      gradInput_k[i] = static_cast<scalar_t>(0);
    }
  }

  __syncthreads();

  // reduce
  if (threadIdx.x == 0) {
    acc_t gradInput_target_k = 0;
    for (int i=0; i<blockDim.x; i++) {
      gradInput_target_k += buffer[i];
    }
    gradInput_k[target_k] = static_cast<scalar_t>(gradInput_target_k);
  }
  // Make thread 0's write to gradInput_k[target_k] visible to the other
  // threads before they scale that entry by grad_output below.
  __syncthreads();

  for (int i=i_start; i<i_end; i+= i_step) {
    gradInput_k[i] *= *gradOutput_k;
  }
}

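// Validates input/target/weight shapes shared by the forward and backward
// entry points, and derives `nframe` (batch size) and `dim` (number of
// classes) from the input's dimensionality.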
void multi_margin_loss_shape_check(
    int64_t& nframe,
    int64_t& dim,
    const int64_t& ndims,
    const Tensor& input,
    const Tensor& target,
    const std::optional<Tensor>& weight) {
  TORCH_CHECK(
      (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
      "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
      input.sizes());

  if (ndims <= 1) {
    nframe = 1;
    dim = ndims == 0 ? 1 : input.size(0);
  } else {
    nframe = input.size(0);
    dim = input.size(1);
  }

  TORCH_CHECK(
      target.dim() <= 1 && target.numel() == nframe,
      "inconsistent target size, expected ", nframe, " but got ",
      target.sizes());
  if (weight && weight->defined()) {
    TORCH_CHECK(
        weight->dim() <= 1 && weight->numel() == dim,
        "inconsistent weight size, expected ", dim, " but got ",
        weight->sizes());
  }
}

} // namespace (anonymous)

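// Forward pass, out variant (the CUDA implementation behind
// torch.nn.functional.multi_margin_loss). Validates shapes, makes the
// operands contiguous, launches one block per sample, and reduces the
// per-sample losses with at::sum_out when a Mean/Sum reduction is requested.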
Tensor& multi_margin_loss_cuda_out(
    const Tensor &input_, const Tensor &target_, const Scalar &p_, const Scalar &margin_,
    const std::optional<Tensor> &weights_, int64_t reduction, Tensor& out_) {
  auto p = p_.toLong();
  int64_t nframe, dim;
  const auto ndims = input_.dim();

  TORCH_CHECK(p == 1 || p == 2, "multi_margin_loss: Invalid p, expected 1 or 2 but got ", p);

  multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_, weights_);

  // produce a scalar output for 1d input
  if (reduction == Reduction::None && target_.dim() > 0) {
    resize_output(out_, {nframe});
  } else {
    resize_output(out_, {});
  }
  if (input_.numel() == 0) {
    return out_;
  }

  auto input = input_.contiguous();
  auto target = target_.contiguous();
  Tensor weights;
  if (weights_ && weights_->defined()) {
    weights = weights_->contiguous();
  }
  auto out = (out_.is_contiguous() ? out_ :
              at::empty(out_.sizes(), input.options()));

  const auto stream = c10::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "multi_margin_loss_cuda", [&] {
    const scalar_t margin = margin_.to<scalar_t>();
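    // 0-D / 1-D input: a single sample, handled by one block.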
    if (input.dim() <= 1) {
      TORCH_CHECK(target.dim() <= 1 && target.numel() == nframe, "inconsistent target size");
      dim3 blocks(1);
      dim3 threads(MULTIMARGIN_THREADS);
      if (p == 1) {
        MultiMarginLoss_forward_kernel<1> <<<blocks, threads, 0, stream>>>(
            out.mutable_data_ptr<scalar_t>(),
            input.const_data_ptr<scalar_t>(),
            target.const_data_ptr<int64_t>(),
            weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
            1,
            input.dim() < 1 ? input.numel() : input.sizes()[0],
            reduction == at::Reduction::Mean,
            margin);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (p == 2) {
        MultiMarginLoss_forward_kernel<2> <<<blocks, threads, 0, stream>>>(
            out.mutable_data_ptr<scalar_t>(),
            input.const_data_ptr<scalar_t>(),
            target.const_data_ptr<int64_t>(),
            weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
            1,
            input.dim() < 1 ? input.numel() : input.sizes()[0],
            reduction == at::Reduction::Mean,
            margin);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    } else {
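      // 2-D input: one block per row (sample).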
      auto in_sizes = input.sizes();
      TORCH_INTERNAL_ASSERT(in_sizes.size() == 2);
      // allow zero-dim target for 2D input.
      TORCH_CHECK(in_sizes[1] != 0 && target.dim() <= 1 && target.numel() == nframe,
                  "inconsistent target size");
      dim3 blocks(nframe);
      dim3 threads(MULTIMARGIN_THREADS);

      if (reduction == at::Reduction::None) {
        if (p == 1) {
          MultiMarginLoss_forward_kernel<1> <<<blocks, threads, 0, stream>>>(
              out.mutable_data_ptr<scalar_t>(),
              input.const_data_ptr<scalar_t>(),
              target.const_data_ptr<int64_t>(),
              weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
              nframe, in_sizes[1],
              false,
              margin);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        } else if (p == 2) {
          MultiMarginLoss_forward_kernel<2> <<<blocks, threads, 0, stream>>>(
              out.mutable_data_ptr<scalar_t>(),
              input.const_data_ptr<scalar_t>(),
              target.const_data_ptr<int64_t>(),
              weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
              nframe, in_sizes[1],
              false,
              margin);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      } else {
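        // Mean/Sum reduction: write per-sample losses into a temporary
        // buffer, then reduce on device with at::sum_out. For Mean the kernel
        // has already divided each per-sample loss by nframe * dim, so the
        // plain sum yields the batch mean.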
        auto tmp_output = at::empty({nframe}, input.options());
        if (p == 1) {
          MultiMarginLoss_forward_kernel<1> <<<blocks, threads, 0, stream>>>(
              tmp_output.mutable_data_ptr<scalar_t>(),
              input.const_data_ptr<scalar_t>(),
              target.const_data_ptr<int64_t>(),
              weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
              nframe, in_sizes[1],
              reduction == Reduction::Mean,
              margin);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        } else if (p == 2) {
          MultiMarginLoss_forward_kernel<2> <<<blocks, threads, 0, stream>>>(
              tmp_output.mutable_data_ptr<scalar_t>(),
              input.const_data_ptr<scalar_t>(),
              target.const_data_ptr<int64_t>(),
              weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
              nframe, in_sizes[1],
              reduction == Reduction::Mean,
              margin);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
        at::sum_out(out, tmp_output, IntArrayRef{});
      }
    }
  });

  if (!out.is_alias_of(out_)) {
    out_.copy_(out);
  }
  return out_;
}

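// Functional (non-out) forward variant: allocates the output and forwards to
// the out variant above.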
Tensor multi_margin_loss_cuda(
    const Tensor &input, const Tensor &target, const Scalar &p, const Scalar &margin,
    const std::optional<Tensor> &weights, int64_t reduction) {
  auto out = at::empty({0}, input.options());
  multi_margin_loss_cuda_out(input, target, p, margin, weights, reduction, out);
  return out;
}

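// Backward pass, out variant: validates shapes, resizes grad_input to match
// the input, and launches one block per sample. The `reduce` flag passed to
// the kernel is true for Mean/Sum reductions, in which case grad_output is a
// scalar rather than a per-sample tensor.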
Tensor& multi_margin_loss_cuda_backward_out(
    const Tensor &grad_output_, const Tensor &input_, const Tensor &target_,
    const Scalar &p_, const Scalar &margin_, const std::optional<Tensor> &weights_,
    int64_t reduction, Tensor &grad_input_) {
  auto p = p_.toLong();
  int64_t nframe, dim;
  const auto ndims = input_.dim();

  TORCH_CHECK(p == 1 || p == 2,
              "multi_margin_loss_backward: Invalid p, expected 1 or 2 but got ", p);

  multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_, weights_);
  resize_output(grad_input_, input_.sizes());

  if (input_.numel() == 0) {
    return grad_input_;
  }

  auto input = input_.contiguous();
  auto grad_input = (grad_input_.is_contiguous() ? grad_input_ :
                     at::empty(grad_input_.sizes(), input.options()));
  auto grad_output = grad_output_.contiguous();
  auto target = target_.contiguous();
  Tensor weights;
  if (weights_ && weights_->defined()) {
    weights = weights_->contiguous();
  }

  const auto stream = c10::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
      "multi_margin_loss_backward_cuda", [&] {
    const scalar_t margin = margin_.to<scalar_t>();

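    // 0-D / 1-D input: a single sample, handled by one block.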
    if (input.dim() <= 1) {
      dim3 blocks(1);
      dim3 threads(MULTIMARGIN_THREADS);

      if (p == 1) {
        MultiMarginLoss_backward_kernel<1> <<<blocks, threads, 0, stream>>>(
            grad_input.mutable_data_ptr<scalar_t>(),
            grad_output.const_data_ptr<scalar_t>(),
            input.const_data_ptr<scalar_t>(),
            target.const_data_ptr<int64_t>(),
            weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
            1,
            input.dim() == 0 ? 1 : input.sizes()[0],
            reduction == at::Reduction::Mean,
            margin,
            reduction != at::Reduction::None);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (p == 2) {
        MultiMarginLoss_backward_kernel<2> <<<blocks, threads, 0, stream>>>(
            grad_input.mutable_data_ptr<scalar_t>(),
            grad_output.const_data_ptr<scalar_t>(),
            input.const_data_ptr<scalar_t>(),
            target.const_data_ptr<int64_t>(),
            weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
            1,
            input.dim() == 0 ? 1 : input.sizes()[0],
            reduction == at::Reduction::Mean,
            margin,
            reduction != at::Reduction::None);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    } else {
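      // 2-D input: one block per row (sample).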
      auto in_sizes = input.sizes();
      TORCH_INTERNAL_ASSERT(in_sizes.size() == 2);
      TORCH_CHECK((in_sizes[1] != 0) && (target.dim() <= 1) && (target.numel() == nframe),
                  "inconsistent target size");
      dim3 blocks(in_sizes[0]);
      dim3 threads(MULTIMARGIN_THREADS);

      if (p == 1) {
        MultiMarginLoss_backward_kernel<1> <<<blocks, threads, 0, stream>>>(
            grad_input.mutable_data_ptr<scalar_t>(),
            grad_output.const_data_ptr<scalar_t>(),
            input.const_data_ptr<scalar_t>(),
            target.const_data_ptr<int64_t>(),
            weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
            nframe, in_sizes[1],
            reduction == at::Reduction::Mean,
            margin,
            reduction != at::Reduction::None);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (p == 2) {
        MultiMarginLoss_backward_kernel<2> <<<blocks, threads, 0, stream>>>(
            grad_input.mutable_data_ptr<scalar_t>(),
            grad_output.const_data_ptr<scalar_t>(),
            input.const_data_ptr<scalar_t>(),
            target.const_data_ptr<int64_t>(),
            weights.defined() ? weights.const_data_ptr<scalar_t>() : nullptr,
            nframe, in_sizes[1],
            reduction == at::Reduction::Mean,
            margin,
            reduction != at::Reduction::None);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    }
  });

  if (!grad_input.is_alias_of(grad_input_)) {
    grad_input_.copy_(grad_input);
  }
  return grad_input_;
}

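// Functional (non-out) backward variant: allocates grad_input and forwards to
// the out variant above.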
Tensor multi_margin_loss_cuda_backward(
    const Tensor &grad_output, const Tensor &input, const Tensor &target,
    const Scalar &p, const Scalar &margin, const std::optional<Tensor> &weights,
    int64_t reduction) {
  auto grad_input = at::empty({0}, input.options());
  multi_margin_loss_cuda_backward_out(
      grad_output, input, target, p, margin, weights, reduction, grad_input);
  return grad_input;
}

} // namespace at::native