1 #pragma once
2
3 #include <ATen/ExpandUtils.h>
4 #include <torch/nn/functional/activation.h>
5 #include <torch/nn/options/loss.h>
6
7 namespace torch {
8 namespace nn {
9 namespace functional {
10
11 #ifndef DOXYGEN_SHOULD_SKIP_THIS
12 namespace detail {
l1_loss(const Tensor & input,const Tensor & target,L1LossFuncOptions::reduction_t reduction)13 inline Tensor l1_loss(
14 const Tensor& input,
15 const Tensor& target,
16 L1LossFuncOptions::reduction_t reduction) {
17 return torch::l1_loss(input, target, enumtype::reduction_get_enum(reduction));
18 }
19 } // namespace detail
20 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
21
22 /// See
23 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.l1_loss
24 /// about the exact behavior of this functional.
25 ///
26 /// See the documentation for `torch::nn::functional::L1LossFuncOptions` class
27 /// to learn what optional arguments are supported for this functional.
28 ///
29 /// Example:
30 /// ```
31 /// namespace F = torch::nn::functional;
32 /// F::l1_loss(input, target, F::L1LossFuncOptions(torch::kNone));
33 /// ```
34 inline Tensor l1_loss(
35 const Tensor& input,
36 const Tensor& target,
37 const L1LossFuncOptions& options = {}) {
38 return detail::l1_loss(input, target, options.reduction());
39 }
40
41 // ============================================================================
42
43 #ifndef DOXYGEN_SHOULD_SKIP_THIS
44 namespace detail {
45 inline Tensor kl_div(
46 const Tensor& input,
47 const Tensor& target,
48 KLDivFuncOptions::reduction_t reduction,
49 bool log_target = false) {
50 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
51 torch::Reduction::Reduction reduction_enum;
52
53 if (std::holds_alternative<enumtype::kMean>(reduction)) {
54 TORCH_WARN(
55 "reduction: 'mean' divides the total loss by both the batch size and the support size."
56 "'batchmean' divides only by the batch size, and aligns with the KL div math definition."
57 "'mean' will be changed to behave the same as 'batchmean' in the next major release.");
58 }
59
60 // special case for batchmean
61 if (std::holds_alternative<enumtype::kBatchMean>(reduction)) {
62 reduction_enum = torch::Reduction::Sum;
63 } else {
64 reduction_enum = enumtype::reduction_get_enum(reduction);
65 }
66
67 auto reduced = torch::kl_div(input, target, reduction_enum, log_target);
68
69 if (std::holds_alternative<enumtype::kBatchMean>(reduction) &&
70 input.dim() != 0) {
71 reduced = reduced / input.sizes()[0];
72 }
73
74 return reduced;
75 }
76 } // namespace detail
77 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
78
79 /// See
80 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.kl_div
81 /// about the exact behavior of this functional.
82 ///
83 /// See the documentation for `torch::nn::functional::KLDivFuncOptions` class to
84 /// learn what optional arguments are supported for this functional.
85 ///
86 /// Example:
87 /// ```
88 /// namespace F = torch::nn::functional;
89 /// F::kl_div(input, target,
90 /// F::KLDivFuncOptions.reduction(torch::kNone).log_target(false));
91 /// ```
92 inline Tensor kl_div(
93 const Tensor& input,
94 const Tensor& target,
95 const KLDivFuncOptions& options = {}) {
96 return detail::kl_div(
97 input, target, options.reduction(), options.log_target());
98 }
99
100 // ============================================================================
101
102 #ifndef DOXYGEN_SHOULD_SKIP_THIS
103 namespace detail {
mse_loss(const Tensor & input,const Tensor & target,MSELossFuncOptions::reduction_t reduction)104 inline Tensor mse_loss(
105 const Tensor& input,
106 const Tensor& target,
107 MSELossFuncOptions::reduction_t reduction) {
108 if (!(target.sizes() == input.sizes())) {
109 TORCH_WARN(
110 "Using a target size (",
111 target.sizes(),
112 ") that is different to the input size (",
113 input.sizes(),
114 "). ",
115 "This will likely lead to incorrect results due to broadcasting. ",
116 "Please ensure they have the same size.");
117 }
118 std::vector<torch::Tensor> broadcast_tensors =
119 torch::broadcast_tensors({input, target});
120 auto expanded_input = broadcast_tensors[0];
121 auto expanded_target = broadcast_tensors[1];
122 return torch::mse_loss(
123 expanded_input, expanded_target, enumtype::reduction_get_enum(reduction));
124 }
125 } // namespace detail
126 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
127
128 /// See
129 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.mse_loss
130 /// about the exact behavior of this functional.
131 ///
132 /// See the documentation for `torch::nn::functional::MSELossFuncOptions` class
133 /// to learn what optional arguments are supported for this functional.
134 ///
135 /// Example:
136 /// ```
137 /// namespace F = torch::nn::functional;
138 /// F::mse_loss(input, target, F::MSELossFuncOptions(torch::kNone));
139 /// ```
140 inline Tensor mse_loss(
141 const Tensor& input,
142 const Tensor& target,
143 const MSELossFuncOptions& options = {}) {
144 return detail::mse_loss(input, target, options.reduction());
145 }
146
147 // ============================================================================
148
149 #ifndef DOXYGEN_SHOULD_SKIP_THIS
150 namespace detail {
binary_cross_entropy(const Tensor & input,const Tensor & target,const Tensor & weight,BinaryCrossEntropyFuncOptions::reduction_t reduction)151 inline Tensor binary_cross_entropy(
152 const Tensor& input,
153 const Tensor& target,
154 const Tensor& weight,
155 BinaryCrossEntropyFuncOptions::reduction_t reduction) {
156 auto reduction_enum = enumtype::reduction_get_enum(reduction);
157
158 if (target.sizes() != input.sizes()) {
159 TORCH_CHECK(
160 false,
161 "Using a target size (",
162 target.sizes(),
163 ") ",
164 "that is different to the input size (",
165 input.sizes(),
166 ") is deprecated. ",
167 "Please ensure they have the same size.");
168 }
169
170 auto weight_ = weight;
171 if (weight_.defined()) {
172 auto new_size = at::infer_size(target.sizes(), weight_.sizes());
173 weight_ = weight_.expand(new_size);
174 }
175
176 return torch::binary_cross_entropy(input, target, weight_, reduction_enum);
177 }
178 } // namespace detail
179 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
180
181 /// See
182 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.binary_cross_entropy
183 /// about the exact behavior of this functional.
184 ///
185 /// See the documentation for
186 /// `torch::nn::functional::BinaryCrossEntropyFuncOptions` class to learn what
187 /// optional arguments are supported for this functional.
188 ///
189 /// Example:
190 /// ```
191 /// namespace F = torch::nn::functional;
192 /// F::binary_cross_entropy(input, target,
193 /// F::BinaryCrossEntropyFuncOptions().weight(weight));
194 /// ```
195 inline Tensor binary_cross_entropy(
196 const Tensor& input,
197 const Tensor& target,
198 const BinaryCrossEntropyFuncOptions& options = {}) {
199 return detail::binary_cross_entropy(
200 input, target, options.weight(), options.reduction());
201 }
202
203 // ============================================================================
204
205 #ifndef DOXYGEN_SHOULD_SKIP_THIS
206 namespace detail {
hinge_embedding_loss(const Tensor & input,const Tensor & target,double margin,HingeEmbeddingLossFuncOptions::reduction_t reduction)207 inline Tensor hinge_embedding_loss(
208 const Tensor& input,
209 const Tensor& target,
210 double margin,
211 HingeEmbeddingLossFuncOptions::reduction_t reduction) {
212 return torch::hinge_embedding_loss(
213 input, target, margin, enumtype::reduction_get_enum(reduction));
214 }
215 } // namespace detail
216 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
217
218 /// See
219 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.hinge_embedding_loss
220 /// about the exact behavior of this functional.
221 ///
222 /// See the documentation for
223 /// `torch::nn::functional::HingeEmbeddingLossFuncOptions` class to learn what
224 /// optional arguments are supported for this functional.
225 ///
226 /// Example:
227 /// ```
228 /// namespace F = torch::nn::functional;
229 /// F::hinge_embedding_loss(input, target,
230 /// F::HingeEmbeddingLossFuncOptions().margin(2));
231 /// ```
232 inline Tensor hinge_embedding_loss(
233 const Tensor& input,
234 const Tensor& target,
235 const HingeEmbeddingLossFuncOptions& options = {}) {
236 return detail::hinge_embedding_loss(
237 input, target, options.margin(), options.reduction());
238 }
239
240 // ============================================================================
241
242 #ifndef DOXYGEN_SHOULD_SKIP_THIS
243 namespace detail {
multi_margin_loss(const Tensor & input,const Tensor & target,int64_t p,double margin,const Tensor & weight,MultiMarginLossFuncOptions::reduction_t reduction)244 inline Tensor multi_margin_loss(
245 const Tensor& input,
246 const Tensor& target,
247 int64_t p,
248 double margin,
249 const Tensor& weight,
250 MultiMarginLossFuncOptions::reduction_t reduction) {
251 TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");
252 if (weight.defined()) {
253 TORCH_CHECK(weight.dim() == 1, "weight must be one-dimensional");
254 }
255
256 return torch::multi_margin_loss(
257 input,
258 target,
259 p,
260 margin,
261 weight,
262 enumtype::reduction_get_enum(reduction));
263 }
264 } // namespace detail
265 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
266
267 /// See
268 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.multi_margin_loss
269 /// about the exact behavior of this functional.
270 ///
271 /// See the documentation for
272 /// `torch::nn::functional::MultiMarginLossFuncOptions` class to learn what
273 /// optional arguments are supported for this functional.
274 ///
275 /// Example:
276 /// ```
277 /// namespace F = torch::nn::functional;
278 /// F::multi_margin_loss(input, target,
279 /// F::MultiMarginLossFuncOptions().margin(2).weight(weight));
280 /// ```
281 inline Tensor multi_margin_loss(
282 const Tensor& input,
283 const Tensor& target,
284 const MultiMarginLossFuncOptions& options = {}) {
285 return detail::multi_margin_loss(
286 input,
287 target,
288 options.p(),
289 options.margin(),
290 options.weight(),
291 options.reduction());
292 }
293
294 // ============================================================================
295
296 #ifndef DOXYGEN_SHOULD_SKIP_THIS
297 namespace detail {
cosine_embedding_loss(const Tensor & input1,const Tensor & input2,const Tensor & target,double margin,CosineEmbeddingLossFuncOptions::reduction_t reduction)298 inline Tensor cosine_embedding_loss(
299 const Tensor& input1,
300 const Tensor& input2,
301 const Tensor& target,
302 double margin,
303 CosineEmbeddingLossFuncOptions::reduction_t reduction) {
304 return torch::cosine_embedding_loss(
305 input1, input2, target, margin, enumtype::reduction_get_enum(reduction));
306 }
307 } // namespace detail
308 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
309
310 /// See
311 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.cosine_embedding_loss
312 /// about the exact behavior of this functional.
313 ///
314 /// See the documentation for
315 /// `torch::nn::functional::CosineEmbeddingLossFuncOptions` class to learn what
316 /// optional arguments are supported for this functional.
317 ///
318 /// Example:
319 /// ```
320 /// namespace F = torch::nn::functional;
321 /// F::cosine_embedding_loss(input1, input2, target,
322 /// F::CosineEmbeddingLossFuncOptions().margin(0.5));
323 /// ```
324 inline Tensor cosine_embedding_loss(
325 const Tensor& input1,
326 const Tensor& input2,
327 const Tensor& target,
328 const CosineEmbeddingLossFuncOptions& options = {}) {
329 return detail::cosine_embedding_loss(
330 input1, input2, target, options.margin(), options.reduction());
331 }
332
333 // ============================================================================
334
335 inline Tensor _smooth_l1_loss(
336 const Tensor& input,
337 const Tensor& target,
338 double beta = 1.) {
339 auto t = torch::abs(input - target);
340 return torch::where(t < beta, 0.5 * torch::pow(t, 2) / beta, t - 0.5 * beta);
341 }
342
343 #ifndef DOXYGEN_SHOULD_SKIP_THIS
344 namespace detail {
345 inline Tensor smooth_l1_loss(
346 const Tensor& input,
347 const Tensor& target,
348 SmoothL1LossFuncOptions::reduction_t reduction,
349 std::optional<double> beta_opt = std::nullopt) {
350 if (target.sizes() != input.sizes()) {
351 TORCH_WARN(
352 "Using a target size (",
353 target.sizes(),
354 ") that is different to the input size (",
355 input.sizes(),
356 "). ",
357 "This will likely lead to incorrect results due to broadcasting. ",
358 "Please ensure they have the same size.");
359 }
360 double beta = beta_opt.value_or(1.0);
361
362 std::vector<Tensor> expanded_tensors =
363 torch::broadcast_tensors({input, target});
364 return torch::smooth_l1_loss(
365 expanded_tensors[0],
366 expanded_tensors[1],
367 enumtype::reduction_get_enum(reduction),
368 beta);
369 }
370 } // namespace detail
371 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
372
373 /// See
374 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.smooth_l1_loss
375 /// about the exact behavior of this functional.
376 ///
377 /// See the documentation for `torch::nn::functional::SmoothL1LossFuncOptions`
378 /// class to learn what optional arguments are supported for this functional.
379 ///
380 /// Example:
381 /// ```
382 /// namespace F = torch::nn::functional;
383 /// F::smooth_l1_loss(input, target, F::SmoothL1LossFuncOptions(torch::kNone));
384 /// ```
385 inline Tensor smooth_l1_loss(
386 const Tensor& input,
387 const Tensor& target,
388 const SmoothL1LossFuncOptions& options = {}) {
389 return detail::smooth_l1_loss(
390 input, target, options.reduction(), options.beta());
391 }
392
393 /// See
394 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.smooth_l1_loss
395 /// about the exact behavior of this functional.
396 ///
397 /// Example:
398 /// ```
399 /// namespace F = torch::nn::functional;
400 /// F::smooth_l1_loss(input, target, /*options=*/torch::kNone, /*beta=*/0.5);
401 /// ```
smooth_l1_loss(const Tensor & input,const Tensor & target,const SmoothL1LossFuncOptions & options,double beta)402 inline Tensor smooth_l1_loss(
403 const Tensor& input,
404 const Tensor& target,
405 const SmoothL1LossFuncOptions& options,
406 double beta) {
407 TORCH_CHECK(
408 options.beta() == std::nullopt,
409 "expected beta not to be provided in 'options', but got ",
410 options.beta().value());
411 return detail::smooth_l1_loss(input, target, options.reduction(), beta);
412 }
413
414 // ============================================================================
415
416 #ifndef DOXYGEN_SHOULD_SKIP_THIS
417 namespace detail {
418 inline Tensor huber_loss(
419 const Tensor& input,
420 const Tensor& target,
421 HuberLossFuncOptions::reduction_t reduction,
422 double delta = 1.) {
423 if (target.sizes() != input.sizes()) {
424 TORCH_WARN(
425 "Using a target size (",
426 target.sizes(),
427 ") that is different to the input size (",
428 input.sizes(),
429 "). ",
430 "This will likely lead to incorrect results due to broadcasting. ",
431 "Please ensure they have the same size.");
432 }
433
434 std::vector<Tensor> expanded_tensors =
435 torch::broadcast_tensors({input, target});
436 return torch::huber_loss(
437 expanded_tensors[0],
438 expanded_tensors[1],
439 enumtype::reduction_get_enum(reduction),
440 delta);
441 }
442 } // namespace detail
443 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
444
445 /// See
446 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.huber_loss
447 /// about the exact behavior of this functional.
448 ///
449 /// See the documentation for `torch::nn::functional::HuberLossFuncOptions`
450 /// class to learn what optional arguments are supported for this functional.
451 ///
452 /// Example:
453 /// ```
454 /// namespace F = torch::nn::functional;
455 /// F::huber_loss(input, target,
456 /// F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5));
457 /// ```
458 inline Tensor huber_loss(
459 const Tensor& input,
460 const Tensor& target,
461 const HuberLossFuncOptions& options = {}) {
462 return detail::huber_loss(
463 input, target, options.reduction(), options.delta());
464 }
465
466 // ============================================================================
467
468 #ifndef DOXYGEN_SHOULD_SKIP_THIS
469 namespace detail {
multilabel_margin_loss(const Tensor & input,const Tensor & target,MultilabelMarginLossFuncOptions::reduction_t reduction)470 inline Tensor multilabel_margin_loss(
471 const Tensor& input,
472 const Tensor& target,
473 MultilabelMarginLossFuncOptions::reduction_t reduction) {
474 return torch::multilabel_margin_loss(
475 input, target, enumtype::reduction_get_enum(reduction));
476 }
477 } // namespace detail
478 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
479
480 /// See
481 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.multilabel_margin_loss
482 /// about the exact behavior of this functional.
483 ///
484 /// See the documentation for
485 /// `torch::nn::functional::MultilabelMarginLossFuncOptions` class to learn what
486 /// optional arguments are supported for this functional.
487 ///
488 /// Example:
489 /// ```
490 /// namespace F = torch::nn::functional;
491 /// F::multilabel_margin_loss(input, target,
492 /// F::MultilabelMarginLossFuncOptions(torch::kNone));
493 /// ```
494 inline Tensor multilabel_margin_loss(
495 const Tensor& input,
496 const Tensor& target,
497 const MultilabelMarginLossFuncOptions& options = {}) {
498 return detail::multilabel_margin_loss(input, target, options.reduction());
499 }
500
501 // ============================================================================
502
503 #ifndef DOXYGEN_SHOULD_SKIP_THIS
504 namespace detail {
soft_margin_loss(const Tensor & input,const Tensor & target,SoftMarginLossFuncOptions::reduction_t reduction)505 inline Tensor soft_margin_loss(
506 const Tensor& input,
507 const Tensor& target,
508 SoftMarginLossFuncOptions::reduction_t reduction) {
509 return torch::soft_margin_loss(
510 input, target, enumtype::reduction_get_enum(reduction));
511 }
512 } // namespace detail
513 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
514
515 /// See
516 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.soft_margin_loss
517 /// about the exact behavior of this functional.
518 ///
519 /// See the documentation for `torch::nn::functional::SoftMarginLossFuncOptions`
520 /// class to learn what optional arguments are supported for this functional.
521 ///
522 /// Example:
523 /// ```
524 /// namespace F = torch::nn::functional;
525 /// F::soft_margin_loss(input, target,
526 /// F::SoftMarginLossFuncOptions(torch::kNone));
527 /// ```
528 inline Tensor soft_margin_loss(
529 const Tensor& input,
530 const Tensor& target,
531 const SoftMarginLossFuncOptions& options = {}) {
532 return detail::soft_margin_loss(input, target, options.reduction());
533 }
534
535 // ============================================================================
536
537 #ifndef DOXYGEN_SHOULD_SKIP_THIS
538 namespace detail {
multilabel_soft_margin_loss(const Tensor & input,const Tensor & target,const Tensor & weight,MultilabelSoftMarginLossFuncOptions::reduction_t reduction)539 inline Tensor multilabel_soft_margin_loss(
540 const Tensor& input,
541 const Tensor& target,
542 const Tensor& weight,
543 MultilabelSoftMarginLossFuncOptions::reduction_t reduction) {
544 auto loss =
545 -(target * torch::log_sigmoid(input) +
546 (1 - target) * torch::log_sigmoid(-input));
547 if (weight.defined()) {
548 loss = loss * weight;
549 }
550
551 auto class_dim = input.dim() - 1;
552 auto C = input.size(class_dim);
553 loss = loss.sum(class_dim) / C; // only return N loss values
554
555 Tensor ret;
556
557 if (std::holds_alternative<enumtype::kNone>(reduction)) {
558 ret = loss;
559 } else if (std::holds_alternative<enumtype::kMean>(reduction)) {
560 ret = loss.mean();
561 } else if (std::holds_alternative<enumtype::kSum>(reduction)) {
562 ret = loss.sum();
563 } else {
564 ret = input;
565 TORCH_INTERNAL_ASSERT(
566 false, enumtype::get_enum_name(reduction), " is not valid");
567 }
568 return ret;
569 }
570 } // namespace detail
571 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
572
573 /// See
574 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.multilabel_soft_margin_loss
575 /// about the exact behavior of this functional.
576 ///
577 /// See the documentation for
578 /// `torch::nn::functional::MultilabelSoftMarginLossFuncOptions` class to learn
579 /// what optional arguments are supported for this functional.
580 ///
581 /// Example:
582 /// ```
583 /// namespace F = torch::nn::functional;
584 /// F::multilabel_soft_margin_loss(input, target,
585 /// F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone).weight(weight));
586 /// ```
587 inline Tensor multilabel_soft_margin_loss(
588 const Tensor& input,
589 const Tensor& target,
590 const MultilabelSoftMarginLossFuncOptions& options = {}) {
591 return detail::multilabel_soft_margin_loss(
592 input, target, options.weight(), options.reduction());
593 }
594
595 // ============================================================================
596
597 #ifndef DOXYGEN_SHOULD_SKIP_THIS
598 namespace detail {
triplet_margin_loss(const Tensor & anchor,const Tensor & positive,const Tensor & negative,double margin,double p,double eps,bool swap,TripletMarginLossFuncOptions::reduction_t reduction)599 inline Tensor triplet_margin_loss(
600 const Tensor& anchor,
601 const Tensor& positive,
602 const Tensor& negative,
603 double margin,
604 double p,
605 double eps,
606 bool swap,
607 TripletMarginLossFuncOptions::reduction_t reduction) {
608 return torch::triplet_margin_loss(
609 anchor,
610 positive,
611 negative,
612 margin,
613 p,
614 eps,
615 swap,
616 enumtype::reduction_get_enum(reduction));
617 }
618 } // namespace detail
619 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
620
621 /// See
622 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.triplet_margin_loss
623 /// about the exact behavior of this functional.
624 ///
625 /// See the documentation for
626 /// `torch::nn::functional::TripletMarginLossFuncOptions` class to learn what
627 /// optional arguments are supported for this functional.
628 ///
629 /// Example:
630 /// ```
631 /// namespace F = torch::nn::functional;
632 /// F::triplet_margin_loss(anchor, positive, negative,
633 /// F::TripletMarginLossFuncOptions().margin(1.0));
634 /// ```
635 inline Tensor triplet_margin_loss(
636 const Tensor& anchor,
637 const Tensor& positive,
638 const Tensor& negative,
639 const TripletMarginLossFuncOptions& options = {}) {
640 return detail::triplet_margin_loss(
641 anchor,
642 positive,
643 negative,
644 options.margin(),
645 options.p(),
646 options.eps(),
647 options.swap(),
648 options.reduction());
649 }
650
651 // ============================================================================
652
653 #ifndef DOXYGEN_SHOULD_SKIP_THIS
654 namespace detail {
triplet_margin_with_distance_loss(const Tensor & anchor,const Tensor & positive,const Tensor & negative,std::optional<TripletMarginWithDistanceLossFuncOptions::distance_function_t> distance_function,double margin,bool swap,TripletMarginWithDistanceLossFuncOptions::reduction_t reduction)655 inline Tensor triplet_margin_with_distance_loss(
656 const Tensor& anchor,
657 const Tensor& positive,
658 const Tensor& negative,
659 std::optional<TripletMarginWithDistanceLossFuncOptions::distance_function_t>
660 distance_function,
661 double margin,
662 bool swap,
663 TripletMarginWithDistanceLossFuncOptions::reduction_t reduction) {
664 Tensor dist_pos, dist_neg;
665 if (distance_function.has_value()) {
666 auto distance_function_impl = distance_function.value();
667 dist_pos = distance_function_impl(anchor, positive);
668 dist_neg = distance_function_impl(anchor, negative);
669 } else {
670 dist_pos = pairwise_distance(anchor, positive);
671 dist_neg = pairwise_distance(anchor, negative);
672 }
673
674 if (swap) {
675 Tensor dist_swap;
676 if (distance_function.has_value()) {
677 dist_swap = distance_function.value()(positive, negative);
678 } else {
679 dist_swap = pairwise_distance(positive, negative);
680 }
681 dist_neg = torch::min(dist_neg, dist_swap);
682 }
683
684 auto loss = torch::clamp_min(dist_pos - dist_neg + margin, 0);
685
686 Tensor ret;
687 if (std::holds_alternative<enumtype::kNone>(reduction)) {
688 ret = loss;
689 } else if (std::holds_alternative<enumtype::kMean>(reduction)) {
690 ret = loss.mean();
691 } else if (std::holds_alternative<enumtype::kSum>(reduction)) {
692 ret = loss.sum();
693 } else {
694 ret = anchor;
695 TORCH_INTERNAL_ASSERT(
696 false, enumtype::get_enum_name(reduction), " is not valid");
697 }
698 return ret;
699 }
700 } // namespace detail
701 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
702
703 /// See
704 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.triplet_margin_with_distance_loss
705 /// about the exact behavior of this functional.
706 ///
707 /// See the documentation for
708 /// `torch::nn::functional::TripletMarginWithDistanceLossFuncOptions` class to
709 /// learn what optional arguments are supported for this functional.
710 ///
711 /// Example:
712 /// ```
713 /// namespace F = torch::nn::functional;
714 /// F::triplet_margin_with_distance_loss(anchor, positive, negative,
715 /// F::TripletMarginWithDistanceLossFuncOptions().margin(1.0));
716 /// ```
717 inline Tensor triplet_margin_with_distance_loss(
718 const Tensor& anchor,
719 const Tensor& positive,
720 const Tensor& negative,
721 const TripletMarginWithDistanceLossFuncOptions& options = {}) {
722 return detail::triplet_margin_with_distance_loss(
723 anchor,
724 positive,
725 negative,
726 options.distance_function(),
727 options.margin(),
728 options.swap(),
729 options.reduction());
730 }
731
732 // ============================================================================
733
734 #ifndef DOXYGEN_SHOULD_SKIP_THIS
735 namespace detail {
ctc_loss(const Tensor & log_probs,const Tensor & targets,const Tensor & input_lengths,const Tensor & target_lengths,int64_t blank,CTCLossFuncOptions::reduction_t reduction,bool zero_infinity)736 inline Tensor ctc_loss(
737 const Tensor& log_probs,
738 const Tensor& targets,
739 const Tensor& input_lengths,
740 const Tensor& target_lengths,
741 int64_t blank,
742 CTCLossFuncOptions::reduction_t reduction,
743 bool zero_infinity) {
744 return torch::ctc_loss(
745 log_probs,
746 targets,
747 input_lengths,
748 target_lengths,
749 blank,
750 enumtype::reduction_get_enum(reduction),
751 zero_infinity);
752 }
753 } // namespace detail
754 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
755
756 /// See
757 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.ctc_loss
758 /// about the exact behavior of this functional.
759 ///
760 /// See the documentation for `torch::nn::functional::CTCLossFuncOptions` class
761 /// to learn what optional arguments are supported for this functional.
762 ///
763 /// Example:
764 /// ```
765 /// namespace F = torch::nn::functional;
766 /// F::ctc_loss(log_probs, targets, input_lengths, target_lengths,
767 /// F::CTCLossFuncOptions().reduction(torch::kNone));
768 /// ```
769 inline Tensor ctc_loss(
770 const Tensor& log_probs,
771 const Tensor& targets,
772 const Tensor& input_lengths,
773 const Tensor& target_lengths,
774 const CTCLossFuncOptions& options = {}) {
775 return detail::ctc_loss(
776 log_probs,
777 targets,
778 input_lengths,
779 target_lengths,
780 options.blank(),
781 options.reduction(),
782 options.zero_infinity());
783 }
784
785 // ============================================================================
786
787 #ifndef DOXYGEN_SHOULD_SKIP_THIS
788 namespace detail {
poisson_nll_loss(const Tensor & input,const Tensor & target,bool log_input,bool full,double eps,PoissonNLLLossFuncOptions::reduction_t reduction)789 inline Tensor poisson_nll_loss(
790 const Tensor& input,
791 const Tensor& target,
792 bool log_input,
793 bool full,
794 double eps,
795 PoissonNLLLossFuncOptions::reduction_t reduction) {
796 return torch::poisson_nll_loss(
797 input,
798 target,
799 log_input,
800 full,
801 eps,
802 enumtype::reduction_get_enum(reduction));
803 }
804 } // namespace detail
805 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
806
807 /// See
808 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.poisson_nll_loss
809 /// about the exact behavior of this functional.
810 ///
811 /// See the documentation for `torch::nn::functional::PoissonNLLLossFuncOptions`
812 /// class to learn what optional arguments are supported for this functional.
813 ///
814 /// Example:
815 /// ```
816 /// namespace F = torch::nn::functional;
817 /// F::poisson_nll_loss(input, target,
818 /// F::PoissonNLLLossFuncOptions().reduction(torch::kNone));
819 /// ```
820 inline Tensor poisson_nll_loss(
821 const Tensor& input,
822 const Tensor& target,
823 const PoissonNLLLossFuncOptions& options = {}) {
824 return detail::poisson_nll_loss(
825 input,
826 target,
827 options.log_input(),
828 options.full(),
829 options.eps(),
830 options.reduction());
831 }
832
833 // ============================================================================
834
835 #ifndef DOXYGEN_SHOULD_SKIP_THIS
836 namespace detail {
margin_ranking_loss(const Tensor & input1,const Tensor & input2,const Tensor & target,double margin,MarginRankingLossFuncOptions::reduction_t reduction)837 inline Tensor margin_ranking_loss(
838 const Tensor& input1,
839 const Tensor& input2,
840 const Tensor& target,
841 double margin,
842 MarginRankingLossFuncOptions::reduction_t reduction) {
843 TORCH_CHECK(
844 input1.dim() == input2.dim() && input1.dim() == target.dim(),
845 "margin_ranking_loss : All input tensors should have same dimension but got sizes: "
846 "input1: ",
847 input1.sizes(),
848 ", input2: ",
849 input2.sizes(),
850 ", target: ",
851 target.sizes());
852 return torch::margin_ranking_loss(
853 input1, input2, target, margin, enumtype::reduction_get_enum(reduction));
854 }
855 } // namespace detail
856 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
857
858 /// See
859 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.margin_ranking_loss
860 /// about the exact behavior of this functional.
861 ///
862 /// See the documentation for
863 /// `torch::nn::functional::MarginRankingLossFuncOptions` class to learn what
864 /// optional arguments are supported for this functional.
865 ///
866 /// Example:
867 /// ```
868 /// namespace F = torch::nn::functional;
869 /// F::margin_ranking_loss(input1, input2, target,
870 /// F::MarginRankingLossFuncOptions().margin(0.5).reduction(torch::kSum));
871 /// ```
872 inline Tensor margin_ranking_loss(
873 const Tensor& input1,
874 const Tensor& input2,
875 const Tensor& target,
876 const MarginRankingLossFuncOptions& options = {}) {
877 return detail::margin_ranking_loss(
878 input1, input2, target, options.margin(), options.reduction());
879 }
880
881 // ============================================================================
882
883 #ifndef DOXYGEN_SHOULD_SKIP_THIS
884 namespace detail {
nll_loss(const Tensor & input,const Tensor & target,const Tensor & weight,int64_t ignore_index,const NLLLossFuncOptions::reduction_t & reduction)885 inline Tensor nll_loss(
886 const Tensor& input,
887 const Tensor& target,
888 const Tensor& weight,
889 int64_t ignore_index,
890 const NLLLossFuncOptions::reduction_t& reduction) {
891 if (input.dim() < 2) {
892 TORCH_CHECK(false, "Expected 2 or more dimensions (got ", input.dim(), ")");
893 }
894
895 if (input.sizes()[0] != target.sizes()[0]) {
896 TORCH_CHECK(
897 false,
898 "Expected input batch_size (",
899 input.sizes()[0],
900 ") to match target batch_size (",
901 target.sizes()[0],
902 ").");
903 }
904
905 return torch::nll_loss_nd(
906 input,
907 target,
908 weight,
909 enumtype::reduction_get_enum(reduction),
910 ignore_index);
911 }
912 } // namespace detail
913 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
914
915 /// See
916 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.nll_loss
917 /// about the exact behavior of this functional.
918 ///
919 /// See the documentation for `torch::nn::functional::NLLLossFuncOptions` class
920 /// to learn what optional arguments are supported for this functional.
921 ///
922 /// Example:
923 /// ```
924 /// namespace F = torch::nn::functional;
925 /// F::nll_loss(input, target,
926 /// F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean));
927 /// ```
928 inline Tensor nll_loss(
929 const Tensor& input,
930 const Tensor& target,
931 const NLLLossFuncOptions& options = {}) {
932 return detail::nll_loss(
933 input,
934 target,
935 options.weight(),
936 options.ignore_index(),
937 options.reduction());
938 }
939
940 // ============================================================================
941
942 #ifndef DOXYGEN_SHOULD_SKIP_THIS
943 namespace detail {
cross_entropy(const Tensor & input,const Tensor & target,const Tensor & weight,int64_t ignore_index,CrossEntropyFuncOptions::reduction_t reduction,double label_smoothing)944 inline Tensor cross_entropy(
945 const Tensor& input,
946 const Tensor& target,
947 const Tensor& weight,
948 int64_t ignore_index,
949 CrossEntropyFuncOptions::reduction_t reduction,
950 double label_smoothing) {
951 return torch::cross_entropy_loss(
952 input,
953 target,
954 weight,
955 enumtype::reduction_get_enum(reduction),
956 ignore_index,
957 label_smoothing);
958 }
959 } // namespace detail
960 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
961
962 /// See
963 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.cross_entropy
964 /// about the exact behavior of this functional.
965 ///
966 /// See the documentation for `torch::nn::functional::CrossEntropyFuncOptions`
967 /// class to learn what optional arguments are supported for this functional.
968 ///
969 /// Example:
970 /// ```
971 /// namespace F = torch::nn::functional;
972 /// F::cross_entropy(input, target,
973 /// F::CrossEntropyFuncOptions().ignore_index(-100).reduction(torch::kMean));
974 /// ```
975 inline Tensor cross_entropy(
976 const Tensor& input,
977 const Tensor& target,
978 const CrossEntropyFuncOptions& options = {}) {
979 return detail::cross_entropy(
980 input,
981 target,
982 options.weight(),
983 options.ignore_index(),
984 options.reduction(),
985 options.label_smoothing());
986 }
987
988 // ============================================================================
989
990 #ifndef DOXYGEN_SHOULD_SKIP_THIS
991 namespace detail {
binary_cross_entropy_with_logits(const Tensor & input,const Tensor & target,const Tensor & weight,BinaryCrossEntropyWithLogitsFuncOptions::reduction_t reduction,const Tensor & pos_weight)992 inline Tensor binary_cross_entropy_with_logits(
993 const Tensor& input,
994 const Tensor& target,
995 const Tensor& weight,
996 BinaryCrossEntropyWithLogitsFuncOptions::reduction_t reduction,
997 const Tensor& pos_weight) {
998 TORCH_CHECK(
999 target.sizes() == input.sizes(),
1000 "Target size (",
1001 target.sizes(),
1002 ") must be the same as input size (",
1003 input.sizes(),
1004 ")");
1005
1006 return torch::binary_cross_entropy_with_logits(
1007 input,
1008 target,
1009 weight,
1010 pos_weight,
1011 enumtype::reduction_get_enum(reduction));
1012 }
1013 } // namespace detail
1014 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
1015
1016 /// See
1017 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.binary_cross_entropy_with_logits
1018 /// about the exact behavior of this functional.
1019 ///
1020 /// See the documentation for
1021 /// `torch::nn::functional::BinaryCrossEntropyWithLogitsFuncOptions` class to
1022 /// learn what optional arguments are supported for this functional.
1023 ///
1024 /// Example:
1025 /// ```
1026 /// namespace F = torch::nn::functional;
1027 /// F::binary_cross_entropy_with_logits(input, target,
1028 /// F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight).reduction(torch::kSum));
1029 /// ```
1030 inline Tensor binary_cross_entropy_with_logits(
1031 const Tensor& input,
1032 const Tensor& target,
1033 const BinaryCrossEntropyWithLogitsFuncOptions& options = {}) {
1034 return detail::binary_cross_entropy_with_logits(
1035 input,
1036 target,
1037 options.weight(),
1038 options.reduction(),
1039 options.pos_weight());
1040 }
1041
1042 } // namespace functional
1043 } // namespace nn
1044 } // namespace torch
1045