1# mypy: allow-untyped-defs
2from typing import Callable, Optional
3from typing_extensions import deprecated
4
5from torch import Tensor
6from torch.nn import _reduction as _Reduction, functional as F
7
8from .distance import PairwiseDistance
9from .module import Module
10
11
12__all__ = [
13    "L1Loss",
14    "NLLLoss",
15    "NLLLoss2d",
16    "PoissonNLLLoss",
17    "GaussianNLLLoss",
18    "KLDivLoss",
19    "MSELoss",
20    "BCELoss",
21    "BCEWithLogitsLoss",
22    "HingeEmbeddingLoss",
23    "MultiLabelMarginLoss",
24    "SmoothL1Loss",
25    "HuberLoss",
26    "SoftMarginLoss",
27    "CrossEntropyLoss",
28    "MultiLabelSoftMarginLoss",
29    "CosineEmbeddingLoss",
30    "MarginRankingLoss",
31    "MultiMarginLoss",
32    "TripletMarginLoss",
33    "TripletMarginWithDistanceLoss",
34    "CTCLoss",
35]
36
37
38class _Loss(Module):
39    reduction: str
40
41    def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
42        super().__init__()
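        # Map the deprecated ``size_average``/``reduce`` flags onto the ``reduction``
        # string; if neither legacy flag is given, ``reduction`` is used as-is.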
43        if size_average is not None or reduce is not None:
44            self.reduction: str = _Reduction.legacy_get_string(size_average, reduce)
45        else:
46            self.reduction = reduction
47
48
49class _WeightedLoss(_Loss):
50    def __init__(
51        self,
52        weight: Optional[Tensor] = None,
53        size_average=None,
54        reduce=None,
55        reduction: str = "mean",
56    ) -> None:
57        super().__init__(size_average, reduce, reduction)
58        self.register_buffer("weight", weight)
59        self.weight: Optional[Tensor]
60
61
62class L1Loss(_Loss):
63    r"""Creates a criterion that measures the mean absolute error (MAE) between each element in
64    the input :math:`x` and target :math:`y`.
65
66    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
67
68    .. math::
69        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
70        l_n = \left| x_n - y_n \right|,
71
72    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
73    (default ``'mean'``), then:
74
75    .. math::
76        \ell(x, y) =
77        \begin{cases}
78            \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
79            \operatorname{sum}(L),  & \text{if reduction} = \text{`sum'.}
80        \end{cases}
81
82    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
83    of :math:`N` elements each.
84
85    The sum operation still operates over all the elements, and divides by :math:`N`.
86
87    The division by :math:`N` can be avoided if one sets ``reduction = 'sum'``.
88
89    Supports real-valued and complex-valued inputs.
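
    As a small sketch of how the two reductions relate (the values below are
    illustrative only)::

        >>> input = torch.tensor([1.0, 2.0, 3.0])
        >>> target = torch.tensor([1.5, 2.0, 5.0])
        >>> nn.L1Loss(reduction='sum')(input, target)  # |1 - 1.5| + |2 - 2| + |3 - 5|
        tensor(2.5000)
        >>> nn.L1Loss(reduction='mean')(input, target)  # 2.5 / 3
        tensor(0.8333)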
90
91    Args:
92        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
93            the losses are averaged over each loss element in the batch. Note that for
94            some losses, there are multiple elements per sample. If the field :attr:`size_average`
95            is set to ``False``, the losses are instead summed for each minibatch. Ignored
96            when :attr:`reduce` is ``False``. Default: ``True``
97        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
98            losses are averaged or summed over observations for each minibatch depending
99            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
100            batch element instead and ignores :attr:`size_average`. Default: ``True``
101        reduction (str, optional): Specifies the reduction to apply to the output:
102            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
103            ``'mean'``: the sum of the output will be divided by the number of
104            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
105            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
106            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
107
108    Shape:
109        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
110        - Target: :math:`(*)`, same shape as the input.
111        - Output: scalar. If :attr:`reduction` is ``'none'``, then
112          :math:`(*)`, same shape as the input.
113
114    Examples::
115
116        >>> loss = nn.L1Loss()
117        >>> input = torch.randn(3, 5, requires_grad=True)
118        >>> target = torch.randn(3, 5)
119        >>> output = loss(input, target)
120        >>> output.backward()
121    """
122    __constants__ = ["reduction"]
123
124    def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
125        super().__init__(size_average, reduce, reduction)
126
127    def forward(self, input: Tensor, target: Tensor) -> Tensor:
128        return F.l1_loss(input, target, reduction=self.reduction)
129
130
131class NLLLoss(_WeightedLoss):
132    r"""The negative log likelihood loss. It is useful when training a classification
133    problem with `C` classes.
134
135    If provided, the optional argument :attr:`weight` should be a 1D Tensor assigning
136    weight to each of the classes. This is particularly useful when you have an
137    unbalanced training set.
138
139    The `input` given through a forward call is expected to contain
140    log-probabilities of each class. `input` has to be a Tensor of size either
141    :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)`
142    with :math:`K \geq 1` for the `K`-dimensional case. The latter is useful for
143    higher-dimensional inputs, such as computing NLL loss per-pixel for 2D images.
144
145    Obtaining log-probabilities in a neural network is easily achieved by
146    adding a `LogSoftmax` layer as the last layer of your network.
147    You may use `CrossEntropyLoss` instead, if you prefer not to add an extra
148    layer.
149
150    The `target` that this loss expects should be a class index in the range :math:`[0, C-1]`
151    where `C = number of classes`; if `ignore_index` is specified, this loss also accepts
152    this class index (this index may not necessarily be in the class range).
153
154    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
155
156    .. math::
157        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
158        l_n = - w_{y_n} x_{n,y_n}, \quad
159        w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\},
160
161    where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and
162    :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
163    (default ``'mean'``), then
164
165    .. math::
166        \ell(x, y) = \begin{cases}
167            \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, &
168            \text{if reduction} = \text{`mean';}\\
169            \sum_{n=1}^N l_n,  &
170            \text{if reduction} = \text{`sum'.}
171        \end{cases}
172
173    Args:
174        weight (Tensor, optional): a manual rescaling weight given to each
175            class. If given, it has to be a Tensor of size `C`. Otherwise, it is
176            treated as if having all ones.
177        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
178            the losses are averaged over each loss element in the batch. Note that for
179            some losses, there are multiple elements per sample. If the field :attr:`size_average`
180            is set to ``False``, the losses are instead summed for each minibatch. Ignored
181            when :attr:`reduce` is ``False``. Default: ``None``
182        ignore_index (int, optional): Specifies a target value that is ignored
183            and does not contribute to the input gradient. When
184            :attr:`size_average` is ``True``, the loss is averaged over
185            non-ignored targets.
186        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
187            losses are averaged or summed over observations for each minibatch depending
188            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
189            batch element instead and ignores :attr:`size_average`. Default: ``None``
190        reduction (str, optional): Specifies the reduction to apply to the output:
191            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
192            be applied, ``'mean'``: the weighted mean of the output is taken,
193            ``'sum'``: the output will be summed. Note: :attr:`size_average`
194            and :attr:`reduce` are in the process of being deprecated, and in
195            the meantime, specifying either of those two args will override
196            :attr:`reduction`. Default: ``'mean'``
197
198    Shape:
199        - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, `N = batch size`, or
200          :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
201          in the case of `K`-dimensional loss.
202        - Target: :math:`(N)` or :math:`()`, where each value is
203          :math:`0 \leq \text{targets}[i] \leq C-1`, or
204          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
205          K-dimensional loss.
206        - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)` or
207          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss.
208          Otherwise, scalar.
209
210    Examples::
211
212        >>> log_softmax = nn.LogSoftmax(dim=1)
213        >>> loss_fn = nn.NLLLoss()
214        >>> # input to NLLLoss is of size N x C = 3 x 5
215        >>> input = torch.randn(3, 5, requires_grad=True)
216        >>> # each element in target must have 0 <= value < C
217        >>> target = torch.tensor([1, 0, 4])
218        >>> loss = loss_fn(log_softmax(input), target)
219        >>> loss.backward()
220        >>>
221        >>>
222        >>> # 2D loss example (used, for example, with image inputs)
223        >>> N, C = 5, 4
224        >>> loss_fn = nn.NLLLoss()
225        >>> data = torch.randn(N, 16, 10, 10)
226        >>> conv = nn.Conv2d(16, C, (3, 3))
227        >>> log_softmax = nn.LogSoftmax(dim=1)
228        >>> # output of conv forward is of shape [N, C, 8, 8]
229        >>> output = log_softmax(conv(data))
230        >>> # each element in target must have 0 <= value < C
231        >>> target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
232        >>> # input to NLLLoss is of size N x C x height (8) x width (8)
233        >>> loss = loss_fn(output, target)
234        >>> loss.backward()
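        >>>
        >>> # A small sketch of ``ignore_index`` (illustrative values): targets equal to
        >>> # the ignored index contribute nothing to the loss or the gradient
        >>> loss_fn = nn.NLLLoss(ignore_index=-100)
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.tensor([1, -100, 4])
        >>> loss = loss_fn(nn.LogSoftmax(dim=1)(input), target)
        >>> loss.backward()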
235    """
236    __constants__ = ["ignore_index", "reduction"]
237    ignore_index: int
238
239    def __init__(
240        self,
241        weight: Optional[Tensor] = None,
242        size_average=None,
243        ignore_index: int = -100,
244        reduce=None,
245        reduction: str = "mean",
246    ) -> None:
247        super().__init__(weight, size_average, reduce, reduction)
248        self.ignore_index = ignore_index
249
250    def forward(self, input: Tensor, target: Tensor) -> Tensor:
251        return F.nll_loss(
252            input,
253            target,
254            weight=self.weight,
255            ignore_index=self.ignore_index,
256            reduction=self.reduction,
257        )
258
259
260@deprecated(
261    "`NLLLoss2d` has been deprecated. "
262    "Please use `NLLLoss` instead as a drop-in replacement and see "
263    "https://pytorch.org/docs/main/nn.html#torch.nn.NLLLoss for more details.",
264    category=FutureWarning,
265)
266class NLLLoss2d(NLLLoss):
267    def __init__(
268        self,
269        weight: Optional[Tensor] = None,
270        size_average=None,
271        ignore_index: int = -100,
272        reduce=None,
273        reduction: str = "mean",
274    ) -> None:
275        super().__init__(weight, size_average, ignore_index, reduce, reduction)
276
277
278class PoissonNLLLoss(_Loss):
279    r"""Negative log likelihood loss with Poisson distribution of target.
280
281    The loss can be described as:
282
283    .. math::
284        \text{target} \sim \mathrm{Poisson}(\text{input})
285
286        \text{loss}(\text{input}, \text{target}) = \text{input} - \text{target} * \log(\text{input})
287                                    + \log(\text{target!})
288
289    The last term can be omitted or approximated with Stirling's formula. The
290    approximation is used for target values greater than 1. For targets less than
291    or equal to 1, zeros are added to the loss.
292
293    Args:
294        log_input (bool, optional): if ``True`` the loss is computed as
295            :math:`\exp(\text{input}) - \text{target}*\text{input}`, if ``False`` the loss is
296            :math:`\text{input} - \text{target}*\log(\text{input}+\text{eps})`.
297        full (bool, optional): whether to compute the full loss, i.e. to add the
298            Stirling approximation term
299
300            .. math::
301                \text{target}*\log(\text{target}) - \text{target} + 0.5 * \log(2\pi\text{target}).
302        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
303            the losses are averaged over each loss element in the batch. Note that for
304            some losses, there are multiple elements per sample. If the field :attr:`size_average`
305            is set to ``False``, the losses are instead summed for each minibatch. Ignored
306            when :attr:`reduce` is ``False``. Default: ``True``
307        eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when
308            :attr:`log_input = False`. Default: 1e-8
309        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
310            losses are averaged or summed over observations for each minibatch depending
311            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
312            batch element instead and ignores :attr:`size_average`. Default: ``True``
313        reduction (str, optional): Specifies the reduction to apply to the output:
314            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
315            ``'mean'``: the sum of the output will be divided by the number of
316            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
317            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
318            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
319
320    Examples::
321
322        >>> loss = nn.PoissonNLLLoss()
323        >>> log_input = torch.randn(5, 2, requires_grad=True)
324        >>> target = torch.randn(5, 2)
325        >>> output = loss(log_input, target)
326        >>> output.backward()
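        >>>
        >>> # A rough sketch with ``log_input=False`` (illustrative values): the input is
        >>> # then interpreted as a rate, so it should be positive
        >>> loss = nn.PoissonNLLLoss(log_input=False)
        >>> rate = torch.rand(5, 2, requires_grad=True) + 0.1
        >>> target = torch.poisson(torch.ones(5, 2) * 2)
        >>> output = loss(rate, target)
        >>> output.backward()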
327
328    Shape:
329        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
330        - Target: :math:`(*)`, same shape as the input.
331        - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(*)`,
332          the same shape as the input.
333    """
334    __constants__ = ["log_input", "full", "eps", "reduction"]
335    log_input: bool
336    full: bool
337    eps: float
338
339    def __init__(
340        self,
341        log_input: bool = True,
342        full: bool = False,
343        size_average=None,
344        eps: float = 1e-8,
345        reduce=None,
346        reduction: str = "mean",
347    ) -> None:
348        super().__init__(size_average, reduce, reduction)
349        self.log_input = log_input
350        self.full = full
351        self.eps = eps
352
353    def forward(self, log_input: Tensor, target: Tensor) -> Tensor:
354        return F.poisson_nll_loss(
355            log_input,
356            target,
357            log_input=self.log_input,
358            full=self.full,
359            eps=self.eps,
360            reduction=self.reduction,
361        )
362
363
364class GaussianNLLLoss(_Loss):
365    r"""Gaussian negative log likelihood loss.
366
367    The targets are treated as samples from Gaussian distributions with
368    expectations and variances predicted by the neural network. For a
369    ``target`` tensor modelled as having Gaussian distribution with a tensor
370    of expectations ``input`` and a tensor of positive variances ``var`` the loss is:
371
372    .. math::
373        \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
374        \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2}
375        {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}
376
377    where :attr:`eps` is used for stability. By default, the constant term of
378    the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same
379    size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension
380    of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting.
381
382    Args:
383        full (bool, optional): include the constant term in the loss
384            calculation. Default: ``False``.
385        eps (float, optional): value used to clamp ``var`` (see note below), for
386            stability. Default: 1e-6.
387        reduction (str, optional): specifies the reduction to apply to the
388            output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
389            will be applied, ``'mean'``: the output is the average of all batch
390            member losses, ``'sum'``: the output is the sum of all batch member
391            losses. Default: ``'mean'``.
392
393    Shape:
394        - Input: :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional
395          dimensions
396        - Target: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input
397          but with one dimension equal to 1 (to allow for broadcasting)
398        - Var: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but
399          with one dimension equal to 1, or same shape as the input but with one fewer
400          dimension (to allow for broadcasting)
401        - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
402          ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same
403          shape as the input
404
405    Examples::

406        >>> loss = nn.GaussianNLLLoss()
407        >>> input = torch.randn(5, 2, requires_grad=True)
408        >>> target = torch.randn(5, 2)
409        >>> var = torch.ones(5, 2, requires_grad=True)  # heteroscedastic
410        >>> output = loss(input, target, var)
411        >>> output.backward()
412
413        >>> loss = nn.GaussianNLLLoss()
414        >>> input = torch.randn(5, 2, requires_grad=True)
415        >>> target = torch.randn(5, 2)
416        >>> var = torch.ones(5, 1, requires_grad=True)  # homoscedastic
417        >>> output = loss(input, target, var)
418        >>> output.backward()
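        >>> # A sketch of the "one fewer dimension" broadcasting mentioned above,
        >>> # reusing the tensors from the previous lines (illustrative only)
        >>> var = torch.ones(5, requires_grad=True)  # one variance value per sample
        >>> output = loss(input, target, var)
        >>> output.backward()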
419
420    Note:
421        The clamping of ``var`` is ignored with respect to autograd, and so the
422        gradients are unaffected by it.
423
424    Reference:
425        Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the
426        target probability distribution", Proceedings of 1994 IEEE International
427        Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60
428        vol.1, doi: 10.1109/ICNN.1994.374138.
429    """
430    __constants__ = ["full", "eps", "reduction"]
431    full: bool
432    eps: float
433
434    def __init__(
435        self, *, full: bool = False, eps: float = 1e-6, reduction: str = "mean"
436    ) -> None:
437        super().__init__(None, None, reduction)
438        self.full = full
439        self.eps = eps
440
441    def forward(self, input: Tensor, target: Tensor, var: Tensor) -> Tensor:
442        return F.gaussian_nll_loss(
443            input, target, var, full=self.full, eps=self.eps, reduction=self.reduction
444        )
445
446
447class KLDivLoss(_Loss):
448    r"""The Kullback-Leibler divergence loss.
449
450    For tensors of the same shape :math:`y_{\text{pred}},\ y_{\text{true}}`,
451    where :math:`y_{\text{pred}}` is the :attr:`input` and :math:`y_{\text{true}}` is the
452    :attr:`target`, we define the **pointwise KL-divergence** as
453
454    .. math::
455
456        L(y_{\text{pred}},\ y_{\text{true}})
457            = y_{\text{true}} \cdot \log \frac{y_{\text{true}}}{y_{\text{pred}}}
458            = y_{\text{true}} \cdot (\log y_{\text{true}} - \log y_{\text{pred}})
459
460    To avoid underflow issues when computing this quantity, this loss expects the argument
461    :attr:`input` in the log-space. The argument :attr:`target` may also be provided in the
462    log-space if :attr:`log_target`\ `= True`.
463
464    To summarise, this function is roughly equivalent to computing
465
466    .. code-block:: python
467
468        if not log_target: # default
469            loss_pointwise = target * (target.log() - input)
470        else:
471            loss_pointwise = target.exp() * (target - input)
472
473    and then reducing this result depending on the argument :attr:`reduction` as
474
475    .. code-block:: python
476
477        if reduction == "mean":  # default
478            loss = loss_pointwise.mean()
479        elif reduction == "batchmean":  # mathematically correct
480            loss = loss_pointwise.sum() / input.size(0)
481        elif reduction == "sum":
482            loss = loss_pointwise.sum()
483        else:  # reduction == "none"
484            loss = loss_pointwise
485
486    .. note::
487        As with all the other losses in PyTorch, this function expects the first argument,
488        :attr:`input`, to be the output of the model (e.g. the neural network)
489        and the second, :attr:`target`, to be the observations in the dataset.
490        This differs from the standard mathematical notation :math:`KL(P\ ||\ Q)` where
491        :math:`P` denotes the distribution of the observations and :math:`Q` denotes the model.
492
493    .. warning::
494        :attr:`reduction`\ `= "mean"` doesn't return the true KL divergence value, please use
495        :attr:`reduction`\ `= "batchmean"` which aligns with the mathematical definition.
496
497    Args:
498        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
499            the losses are averaged over each loss element in the batch. Note that for
500            some losses, there are multiple elements per sample. If the field :attr:`size_average`
501            is set to `False`, the losses are instead summed for each minibatch. Ignored
502            when :attr:`reduce` is `False`. Default: `True`
503        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
504            losses are averaged or summed over observations for each minibatch depending
505            on :attr:`size_average`. When :attr:`reduce` is `False`, returns a loss per
506            batch element instead and ignores :attr:`size_average`. Default: `True`
507        reduction (str, optional): Specifies the reduction to apply to the output. Default: `"mean"`
508        log_target (bool, optional): Specifies whether `target` is in the log space. Default: `False`
509
510    Shape:
511        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
512        - Target: :math:`(*)`, same shape as the input.
513        - Output: scalar by default. If :attr:`reduction` is `'none'`, then :math:`(*)`,
514          same shape as the input.
515
516    Examples::

517        >>> kl_loss = nn.KLDivLoss(reduction="batchmean")
518        >>> # input should be a distribution in the log space
519        >>> input = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1)
520        >>> # Sample a batch of distributions. Usually this would come from the dataset
521        >>> target = F.softmax(torch.rand(3, 5), dim=1)
522        >>> output = kl_loss(input, target)
523
524        >>> kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
525        >>> log_target = F.log_softmax(torch.rand(3, 5), dim=1)
526        >>> output = kl_loss(input, log_target)
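        >>>
        >>> # A small sketch of the reduction relationship described above (illustrative):
        >>> # "batchmean" equals the "sum" reduction divided by the batch size
        >>> batchmean_loss = nn.KLDivLoss(reduction="batchmean")(input, target)
        >>> sum_loss = nn.KLDivLoss(reduction="sum")(input, target)
        >>> torch.allclose(batchmean_loss, sum_loss / input.size(0))
        True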
527    """
528    __constants__ = ["reduction"]
529
530    def __init__(
531        self,
532        size_average=None,
533        reduce=None,
534        reduction: str = "mean",
535        log_target: bool = False,
536    ) -> None:
537        super().__init__(size_average, reduce, reduction)
538        self.log_target = log_target
539
540    def forward(self, input: Tensor, target: Tensor) -> Tensor:
541        return F.kl_div(
542            input, target, reduction=self.reduction, log_target=self.log_target
543        )
544
545
546class MSELoss(_Loss):
547    r"""Creates a criterion that measures the mean squared error (squared L2 norm) between
548    each element in the input :math:`x` and target :math:`y`.
549
550    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
551
552    .. math::
553        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
554        l_n = \left( x_n - y_n \right)^2,
555
556    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
557    (default ``'mean'``), then:
558
559    .. math::
560        \ell(x, y) =
561        \begin{cases}
562            \operatorname{mean}(L), &  \text{if reduction} = \text{`mean';}\\
563            \operatorname{sum}(L),  &  \text{if reduction} = \text{`sum'.}
564        \end{cases}
565
566    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
567    of :math:`N` elements each.
568
569    The mean operation still operates over all the elements, and divides by :math:`N`.
570
571    The division by :math:`N` can be avoided if one sets ``reduction = 'sum'``.
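
    As a small sketch of the definition above (values are illustrative only), the
    ``'mean'`` reduction is simply the mean of the squared differences::

        >>> input = torch.randn(3, 5)
        >>> target = torch.randn(3, 5)
        >>> out = nn.MSELoss()(input, target)
        >>> torch.allclose(out, ((input - target) ** 2).mean())
        True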
572
573    Args:
574        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
575            the losses are averaged over each loss element in the batch. Note that for
576            some losses, there are multiple elements per sample. If the field :attr:`size_average`
577            is set to ``False``, the losses are instead summed for each minibatch. Ignored
578            when :attr:`reduce` is ``False``. Default: ``True``
579        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
580            losses are averaged or summed over observations for each minibatch depending
581            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
582            batch element instead and ignores :attr:`size_average`. Default: ``True``
583        reduction (str, optional): Specifies the reduction to apply to the output:
584            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
585            ``'mean'``: the sum of the output will be divided by the number of
586            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
587            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
588            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
589
590    Shape:
591        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
592        - Target: :math:`(*)`, same shape as the input.
593
594    Examples::
595
596        >>> loss = nn.MSELoss()
597        >>> input = torch.randn(3, 5, requires_grad=True)
598        >>> target = torch.randn(3, 5)
599        >>> output = loss(input, target)
600        >>> output.backward()
601    """
602    __constants__ = ["reduction"]
603
604    def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
605        super().__init__(size_average, reduce, reduction)
606
607    def forward(self, input: Tensor, target: Tensor) -> Tensor:
608        return F.mse_loss(input, target, reduction=self.reduction)
609
610
611class BCELoss(_WeightedLoss):
612    r"""Creates a criterion that measures the Binary Cross Entropy between the target and
613    the input probabilities:
614
615    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
616
617    .. math::
618        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
619        l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right],
620
621    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
622    (default ``'mean'``), then
623
624    .. math::
625        \ell(x, y) = \begin{cases}
626            \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
627            \operatorname{sum}(L),  & \text{if reduction} = \text{`sum'.}
628        \end{cases}
629
630    This is used for measuring the error of a reconstruction in, for example,
631    an auto-encoder. Note that the targets :math:`y` should be numbers
632    between 0 and 1.
633
634    Notice that if :math:`x_n` is either 0 or 1, one of the log terms would be
635    mathematically undefined in the above loss equation. PyTorch chooses to set
636    :math:`\log (0) = -\infty`, since :math:`\lim_{x\to 0} \log (x) = -\infty`.
637    However, an infinite term in the loss equation is not desirable for several reasons.
638
639    For one, if either :math:`y_n = 0` or :math:`(1 - y_n) = 0`, then we would be
640    multiplying 0 with infinity. Secondly, if we have an infinite loss value, then
641    we would also have an infinite term in our gradient, since
642    :math:`\lim_{x\to 0} \frac{d}{dx} \log (x) = \infty`.
643    This would make BCELoss's backward method nonlinear with respect to :math:`x_n`,
644    and using it for things like linear regression would not be straightforward.
645
646    Our solution is that BCELoss clamps its log function outputs to be greater than
647    or equal to -100. This way, we can always have a finite loss value and a linear
648    backward method.
649
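    As a small sketch of the clamping described above (an illustrative edge case)::

        >>> loss = nn.BCELoss()
        >>> # a prediction of exactly 0 for a positive target would give log(0) = -inf;
        >>> # the clamp keeps the loss finite at 100 instead
        >>> loss(torch.tensor([0.0]), torch.tensor([1.0]))
        tensor(100.)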
650
651    Args:
652        weight (Tensor, optional): a manual rescaling weight given to the loss
653            of each batch element. If given, has to be a Tensor of size `nbatch`.
654        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
655            the losses are averaged over each loss element in the batch. Note that for
656            some losses, there are multiple elements per sample. If the field :attr:`size_average`
657            is set to ``False``, the losses are instead summed for each minibatch. Ignored
658            when :attr:`reduce` is ``False``. Default: ``True``
659        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
660            losses are averaged or summed over observations for each minibatch depending
661            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
662            batch element instead and ignores :attr:`size_average`. Default: ``True``
663        reduction (str, optional): Specifies the reduction to apply to the output:
664            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
665            ``'mean'``: the sum of the output will be divided by the number of
666            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
667            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
668            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
669
670    Shape:
671        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
672        - Target: :math:`(*)`, same shape as the input.
673        - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same
674          shape as input.
675
676    Examples::
677
678        >>> m = nn.Sigmoid()
679        >>> loss = nn.BCELoss()
680        >>> input = torch.randn(3, 2, requires_grad=True)
681        >>> target = torch.rand(3, 2, requires_grad=False)
682        >>> output = loss(m(input), target)
683        >>> output.backward()
684    """
685    __constants__ = ["reduction"]
686
687    def __init__(
688        self,
689        weight: Optional[Tensor] = None,
690        size_average=None,
691        reduce=None,
692        reduction: str = "mean",
693    ) -> None:
694        super().__init__(weight, size_average, reduce, reduction)
695
696    def forward(self, input: Tensor, target: Tensor) -> Tensor:
697        return F.binary_cross_entropy(
698            input, target, weight=self.weight, reduction=self.reduction
699        )
700
701
702class BCEWithLogitsLoss(_Loss):
703    r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single
704    class. This version is more numerically stable than using a plain `Sigmoid`
705    followed by a `BCELoss` as, by combining the operations into one layer,
706    we take advantage of the log-sum-exp trick for numerical stability.
707
708    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
709
710    .. math::
711        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
712        l_n = - w_n \left[ y_n \cdot \log \sigma(x_n)
713        + (1 - y_n) \cdot \log (1 - \sigma(x_n)) \right],
714
715    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
716    (default ``'mean'``), then
717
718    .. math::
719        \ell(x, y) = \begin{cases}
720            \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
721            \operatorname{sum}(L),  & \text{if reduction} = \text{`sum'.}
722        \end{cases}
723
724    This is used for measuring the error of a reconstruction in, for example,
725    an auto-encoder. Note that the targets `t[i]` should be numbers
726    between 0 and 1.
727
728    It's possible to trade off recall and precision by adding weights to positive examples.
729    In the case of multi-label classification the loss can be described as:
730
731    .. math::
732        \ell_c(x, y) = L_c = \{l_{1,c},\dots,l_{N,c}\}^\top, \quad
733        l_{n,c} = - w_{n,c} \left[ p_c y_{n,c} \cdot \log \sigma(x_{n,c})
734        + (1 - y_{n,c}) \cdot \log (1 - \sigma(x_{n,c})) \right],
735
736    where :math:`c` is the class number (:math:`c > 1` for multi-label binary classification,
737    :math:`c = 1` for single-label binary classification),
738    :math:`n` is the number of the sample in the batch and
739    :math:`p_c` is the weight of the positive answer for the class :math:`c`.
740
741    :math:`p_c > 1` increases the recall, :math:`p_c < 1` increases the precision.
742
743    For example, if a dataset contains 100 positive and 300 negative examples of a single class,
744    then ``pos_weight`` for the class should be equal to :math:`\frac{300}{100}=3`.
745    The loss would act as if the dataset contains :math:`3\times 100=300` positive examples.
746
747    Examples::
748
749        >>> target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
750        >>> output = torch.full([10, 64], 1.5)  # A prediction (logit)
751        >>> pos_weight = torch.ones([64])  # All weights are equal to 1
752        >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
753        >>> criterion(output, target)  # -log(sigmoid(1.5))
754        tensor(0.20...)
755
756    In the above example, the ``pos_weight`` tensor's elements correspond to the 64 distinct classes
757    in a multi-label binary classification scenario. Each element in ``pos_weight`` is designed to adjust the
758    loss function based on the imbalance between negative and positive samples for the respective class.
759    This approach is useful in datasets with varying levels of class imbalance, ensuring that the loss
760    calculation accurately accounts for the distribution in each class.
761
762    Args:
763        weight (Tensor, optional): a manual rescaling weight given to the loss
764            of each batch element. If given, has to be a Tensor of size `nbatch`.
765        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
766            the losses are averaged over each loss element in the batch. Note that for
767            some losses, there are multiple elements per sample. If the field :attr:`size_average`
768            is set to ``False``, the losses are instead summed for each minibatch. Ignored
769            when :attr:`reduce` is ``False``. Default: ``True``
770        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
771            losses are averaged or summed over observations for each minibatch depending
772            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
773            batch element instead and ignores :attr:`size_average`. Default: ``True``
774        reduction (str, optional): Specifies the reduction to apply to the output:
775            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
776            ``'mean'``: the sum of the output will be divided by the number of
777            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
778            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
779            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
780        pos_weight (Tensor, optional): a weight of positive examples to be broadcasted with target.
781            Must be a tensor with equal size along the class dimension to the number of classes.
782            Pay close attention to PyTorch's broadcasting semantics in order to achieve the desired
783            operations. For a target of size [B, C, H, W] (where B is the batch size), a pos_weight
784            of size [B, C, H, W] applies a different positive weight to each element of the batch,
785            while a pos_weight of size [C, H, W] applies the same weights across the batch. To apply
786            the same positive weight along all spatial dimensions of a 2D target [C, H, W], use [C, 1, 1].
787            Default: ``None``
788
789    Shape:
790        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
791        - Target: :math:`(*)`, same shape as the input.
792        - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same
793          shape as input.
794
795    Examples::
796
797        >>> loss = nn.BCEWithLogitsLoss()
798        >>> input = torch.randn(3, requires_grad=True)
799        >>> target = torch.empty(3).random_(2)
800        >>> output = loss(input, target)
801        >>> output.backward()
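        >>>
        >>> # A rough sketch of the ``pos_weight`` broadcasting described above (shapes
        >>> # are illustrative): one weight per class, shared over batch and spatial dims
        >>> target = torch.rand([8, 4, 16, 16])  # B, C, H, W
        >>> logits = torch.randn([8, 4, 16, 16], requires_grad=True)
        >>> pos_weight = torch.ones([4, 1, 1])  # C, 1, 1
        >>> criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        >>> criterion(logits, target).backward()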
802    """
803
804    def __init__(
805        self,
806        weight: Optional[Tensor] = None,
807        size_average=None,
808        reduce=None,
809        reduction: str = "mean",
810        pos_weight: Optional[Tensor] = None,
811    ) -> None:
812        super().__init__(size_average, reduce, reduction)
813        self.register_buffer("weight", weight)
814        self.register_buffer("pos_weight", pos_weight)
815        self.weight: Optional[Tensor]
816        self.pos_weight: Optional[Tensor]
817
818    def forward(self, input: Tensor, target: Tensor) -> Tensor:
819        return F.binary_cross_entropy_with_logits(
820            input,
821            target,
822            self.weight,
823            pos_weight=self.pos_weight,
824            reduction=self.reduction,
825        )
826
827
828class HingeEmbeddingLoss(_Loss):
829    r"""Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y`
830    (containing 1 or -1).
831    This is usually used for measuring whether two inputs are similar or
832    dissimilar, e.g. using the L1 pairwise distance as :math:`x`, and is typically
833    used for learning nonlinear embeddings or semi-supervised learning.
834
835    The loss function for :math:`n`-th sample in the mini-batch is
836
837    .. math::
838        l_n = \begin{cases}
839            x_n, & \text{if}\; y_n = 1,\\
840            \max \{0, margin - x_n\}, & \text{if}\; y_n = -1,
841        \end{cases}
842
843    and the total loss function is
844
845    .. math::
846        \ell(x, y) = \begin{cases}
847            \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
848            \operatorname{sum}(L),  & \text{if reduction} = \text{`sum'.}
849        \end{cases}
850
851    where :math:`L = \{l_1,\dots,l_N\}^\top`.
852
853    Args:
854        margin (float, optional): Has a default value of `1`.
855        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
856            the losses are averaged over each loss element in the batch. Note that for
857            some losses, there are multiple elements per sample. If the field :attr:`size_average`
858            is set to ``False``, the losses are instead summed for each minibatch. Ignored
859            when :attr:`reduce` is ``False``. Default: ``True``
860        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
861            losses are averaged or summed over observations for each minibatch depending
862            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
863            batch element instead and ignores :attr:`size_average`. Default: ``True``
864        reduction (str, optional): Specifies the reduction to apply to the output:
865            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
866            ``'mean'``: the sum of the output will be divided by the number of
867            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
868            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
869            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
870
871    Shape:
872        - Input: :math:`(*)` where :math:`*` means any number of dimensions. The sum operation
873          operates over all the elements.
874        - Target: :math:`(*)`, same shape as the input
875        - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input
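
    A short usage sketch (values and shapes are illustrative only)::

        >>> x1 = torch.randn(4, 8, requires_grad=True)
        >>> x2 = torch.randn(4, 8)
        >>> distances = nn.PairwiseDistance(p=1)(x1, x2)  # L1 pairwise distance as x
        >>> target = torch.tensor([1.0, -1.0, 1.0, -1.0])
        >>> loss = nn.HingeEmbeddingLoss(margin=1.0)
        >>> output = loss(distances, target)
        >>> output.backward()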
876    """
877    __constants__ = ["margin", "reduction"]
878    margin: float
879
880    def __init__(
881        self,
882        margin: float = 1.0,
883        size_average=None,
884        reduce=None,
885        reduction: str = "mean",
886    ) -> None:
887        super().__init__(size_average, reduce, reduction)
888        self.margin = margin
889
890    def forward(self, input: Tensor, target: Tensor) -> Tensor:
891        return F.hinge_embedding_loss(
892            input, target, margin=self.margin, reduction=self.reduction
893        )
894
895
896class MultiLabelMarginLoss(_Loss):
897    r"""Creates a criterion that optimizes a multi-class multi-classification
898    hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
899    and output :math:`y` (which is a 2D `Tensor` of target class indices).
900    For each sample in the mini-batch:
901
902    .. math::
903        \text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)}
904
905    where :math:`i \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`, \
906    :math:`j \in \left\{0, \; \cdots , \; \text{y.size}(0) - 1\right\}`, \
907    :math:`0 \leq y[j] \leq \text{x.size}(0)-1`, \
908    and :math:`i \neq y[j]` for all :math:`i` and :math:`j`.
909
910    :math:`y` and :math:`x` must have the same size.
911
912    The criterion only considers a contiguous block of non-negative targets that
913    starts at the front.
914
915    This allows for different samples to have variable amounts of target classes.
916
917    Args:
918        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
919            the losses are averaged over each loss element in the batch. Note that for
920            some losses, there are multiple elements per sample. If the field :attr:`size_average`
921            is set to ``False``, the losses are instead summed for each minibatch. Ignored
922            when :attr:`reduce` is ``False``. Default: ``True``
923        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
924            losses are averaged or summed over observations for each minibatch depending
925            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
926            batch element instead and ignores :attr:`size_average`. Default: ``True``
927        reduction (str, optional): Specifies the reduction to apply to the output:
928            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
929            ``'mean'``: the sum of the output will be divided by the number of
930            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
931            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
932            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
933
934    Shape:
935        - Input: :math:`(C)` or :math:`(N, C)` where `N` is the batch size and `C`
936          is the number of classes.
937        - Target: :math:`(C)` or :math:`(N, C)`, label targets padded by -1 ensuring same shape as the input.
938        - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
939
940    Examples::
941
942        >>> loss = nn.MultiLabelMarginLoss()
943        >>> x = torch.FloatTensor([[0.1, 0.2, 0.4, 0.8]])
944        >>> # for target y, only consider labels 3 and 0, not after label -1
945        >>> y = torch.LongTensor([[3, 0, -1, 1]])
946        >>> # 0.25 * ((1-(0.1-0.2)) + (1-(0.1-0.4)) + (1-(0.8-0.2)) + (1-(0.8-0.4)))
947        >>> loss(x, y)
948        tensor(0.85...)
949
950    """
951    __constants__ = ["reduction"]
952
953    def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
954        super().__init__(size_average, reduce, reduction)
955
956    def forward(self, input: Tensor, target: Tensor) -> Tensor:
957        return F.multilabel_margin_loss(input, target, reduction=self.reduction)
958
959
960class SmoothL1Loss(_Loss):
961    r"""Creates a criterion that uses a squared term if the absolute
962    element-wise error falls below beta and an L1 term otherwise.
963    It is less sensitive to outliers than :class:`torch.nn.MSELoss` and in some cases
964    prevents exploding gradients (e.g. see the paper `Fast R-CNN`_ by Ross Girshick).
965
966    For a batch of size :math:`N`, the unreduced loss can be described as:
967
968    .. math::
969        \ell(x, y) = L = \{l_1, ..., l_N\}^T
970
971    with
972
973    .. math::
974        l_n = \begin{cases}
975        0.5 (x_n - y_n)^2 / beta, & \text{if } |x_n - y_n| < beta \\
976        |x_n - y_n| - 0.5 * beta, & \text{otherwise }
977        \end{cases}
978
979    If `reduction` is not `none`, then:
980
981    .. math::
982        \ell(x, y) =
983        \begin{cases}
984            \operatorname{mean}(L), &  \text{if reduction} = \text{`mean';}\\
985            \operatorname{sum}(L),  &  \text{if reduction} = \text{`sum'.}
986        \end{cases}
987
988    .. note::
989        Smooth L1 loss can be seen as exactly :class:`L1Loss`, but with the :math:`|x - y| < beta`
990        portion replaced with a quadratic function such that its slope is 1 at :math:`|x - y| = beta`.
991        The quadratic segment smooths the L1 loss near :math:`|x - y| = 0`.
992
993    .. note::
994        Smooth L1 loss is closely related to :class:`HuberLoss`, being
995        equivalent to :math:`huber(x, y) / beta` (note that Smooth L1's beta hyper-parameter is
996        also known as delta for Huber). This leads to the following differences:
997
998        * As beta -> 0, Smooth L1 loss converges to :class:`L1Loss`, while :class:`HuberLoss`
999          converges to a constant 0 loss. When beta is 0, Smooth L1 loss is equivalent to L1 loss.
1000        * As beta -> :math:`+\infty`, Smooth L1 loss converges to a constant 0 loss, while
1001          :class:`HuberLoss` converges to :class:`MSELoss`.
1002        * For Smooth L1 loss, as beta varies, the L1 segment of the loss has a constant slope of 1.
1003          For :class:`HuberLoss`, the slope of the L1 segment is beta.
1004
1005    .. _`Fast R-CNN`: https://arxiv.org/abs/1504.08083
1006
1007    Args:
1008        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
1009            the losses are averaged over each loss element in the batch. Note that for
1010            some losses, there are multiple elements per sample. If the field :attr:`size_average`
1011            is set to ``False``, the losses are instead summed for each minibatch. Ignored
1012            when :attr:`reduce` is ``False``. Default: ``True``
1013        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
1014            losses are averaged or summed over observations for each minibatch depending
1015            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
1016            batch element instead and ignores :attr:`size_average`. Default: ``True``
1017        reduction (str, optional): Specifies the reduction to apply to the output:
1018            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
1019            ``'mean'``: the sum of the output will be divided by the number of
1020            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
1021            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
1022            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
1023        beta (float, optional): Specifies the threshold at which to change between L1 and L2 loss.
1024            The value must be non-negative. Default: 1.0
1025
1026    Shape:
1027        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
1028        - Target: :math:`(*)`, same shape as the input.
1029        - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input.
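
    A short usage sketch (illustrative only)::

        >>> loss = nn.SmoothL1Loss(beta=1.0)
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5)
        >>> output = loss(input, target)
        >>> output.backward()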
1030    """
1031    __constants__ = ["reduction"]
1032
1033    def __init__(
1034        self, size_average=None, reduce=None, reduction: str = "mean", beta: float = 1.0
1035    ) -> None:
1036        super().__init__(size_average, reduce, reduction)
1037        self.beta = beta
1038
1039    def forward(self, input: Tensor, target: Tensor) -> Tensor:
1040        return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)
1041
1042
1043class HuberLoss(_Loss):
1044    r"""Creates a criterion that uses a squared term if the absolute
1045    element-wise error falls below delta and a delta-scaled L1 term otherwise.
1046    This loss combines advantages of both :class:`L1Loss` and :class:`MSELoss`; the
1047    delta-scaled L1 region makes the loss less sensitive to outliers than :class:`MSELoss`,
1048    while the L2 region provides smoothness over :class:`L1Loss` near 0. See
1049    `Huber loss <https://en.wikipedia.org/wiki/Huber_loss>`_ for more information.
1050
1051    For a batch of size :math:`N`, the unreduced loss can be described as:
1052
1053    .. math::
1054        \ell(x, y) = L = \{l_1, ..., l_N\}^T
1055
1056    with
1057
1058    .. math::
1059        l_n = \begin{cases}
1060        0.5 (x_n - y_n)^2, & \text{if } |x_n - y_n| < delta \\
1061        delta * (|x_n - y_n| - 0.5 * delta), & \text{otherwise }
1062        \end{cases}
1063
1064    If `reduction` is not `none`, then:
1065
1066    .. math::
1067        \ell(x, y) =
1068        \begin{cases}
1069            \operatorname{mean}(L), &  \text{if reduction} = \text{`mean';}\\
1070            \operatorname{sum}(L),  &  \text{if reduction} = \text{`sum'.}
1071        \end{cases}
1072
1073    .. note::
1074        When delta is set to 1, this loss is equivalent to :class:`SmoothL1Loss`.
1075        In general, this loss differs from :class:`SmoothL1Loss` by a factor of delta (AKA beta
1076        in Smooth L1).
1077        See :class:`SmoothL1Loss` for additional discussion on the differences in behavior
1078        between the two losses.
1079
1080    Args:
1081        reduction (str, optional): Specifies the reduction to apply to the output:
1082            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
1083            ``'mean'``: the sum of the output will be divided by the number of
1084            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``
1085        delta (float, optional): Specifies the threshold at which to change between delta-scaled L1 and L2 loss.
1086            The value must be positive.  Default: 1.0
1087
1088    Shape:
1089        - Input: :math:`(*)` where :math:`*` means any number of dimensions.
1090        - Target: :math:`(*)`, same shape as the input.
1091        - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input.
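
    A short usage sketch, also checking the relation to :class:`SmoothL1Loss` noted
    above (illustrative only)::

        >>> loss = nn.HuberLoss(delta=1.0)
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5)
        >>> output = loss(input, target)
        >>> output.backward()
        >>> # with delta == 1 this matches SmoothL1Loss with beta == 1
        >>> torch.allclose(output, nn.SmoothL1Loss(beta=1.0)(input, target))
        True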
1092    """
1093    __constants__ = ["reduction", "delta"]
1094
1095    def __init__(self, reduction: str = "mean", delta: float = 1.0) -> None:
1096        super().__init__(reduction=reduction)
1097        self.delta = delta
1098
1099    def forward(self, input: Tensor, target: Tensor) -> Tensor:
1100        return F.huber_loss(input, target, reduction=self.reduction, delta=self.delta)
1101
1102
1103class SoftMarginLoss(_Loss):
1104    r"""Creates a criterion that optimizes a two-class classification
1105    logistic loss between input tensor :math:`x` and target tensor :math:`y`
1106    (containing 1 or -1).
1107
1108    .. math::
1109        \text{loss}(x, y) = \sum_i \frac{\log(1 + \exp(-y[i]*x[i]))}{\text{x.nelement}()}
1110
1111    Args:
1112        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
1113            the losses are averaged over each loss element in the batch. Note that for
1114            some losses, there are multiple elements per sample. If the field :attr:`size_average`
1115            is set to ``False``, the losses are instead summed for each minibatch. Ignored
1116            when :attr:`reduce` is ``False``. Default: ``True``
1117        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
1118            losses are averaged or summed over observations for each minibatch depending
1119            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
1120            batch element instead and ignores :attr:`size_average`. Default: ``True``
1121        reduction (str, optional): Specifies the reduction to apply to the output:
1122            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
1123            ``'mean'``: the sum of the output will be divided by the number of
1124            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
1125            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
1126            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
1127
1128    Shape:
1129        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
1130        - Target: :math:`(*)`, same shape as the input.
1131        - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same
1132          shape as input.
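
    A short usage sketch (illustrative only)::

        >>> loss = nn.SoftMarginLoss()
        >>> input = torch.randn(5, requires_grad=True)
        >>> target = torch.tensor([1.0, -1.0, 1.0, 1.0, -1.0])
        >>> output = loss(input, target)
        >>> output.backward()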
1133
1134    """
1135    __constants__ = ["reduction"]
1136
1137    def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
1138        super().__init__(size_average, reduce, reduction)
1139
1140    def forward(self, input: Tensor, target: Tensor) -> Tensor:
1141        return F.soft_margin_loss(input, target, reduction=self.reduction)
1142
1143
1144class CrossEntropyLoss(_WeightedLoss):
1145    r"""This criterion computes the cross entropy loss between input logits
1146    and target.
1147
1148    It is useful when training a classification problem with `C` classes.
1149    If provided, the optional argument :attr:`weight` should be a 1D `Tensor`
1150    assigning weight to each of the classes.
1151    This is particularly useful when you have an unbalanced training set.
1152
1153    The `input` is expected to contain the unnormalized logits for each class (which do `not` need
1154    to be positive or sum to 1, in general).
1155    `input` has to be a Tensor of size :math:`(C)` for unbatched input,
1156    :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` for the
1157    `K`-dimensional case. The latter is useful for higher dimension inputs, such
1158    as computing cross entropy loss per-pixel for 2D images.
1159
1160    The `target` that this criterion expects should contain either:
1161
1162    - Class indices in the range :math:`[0, C)` where :math:`C` is the number of classes; if
1163      `ignore_index` is specified, this loss also accepts this class index (this index
1164      may not necessarily be in the class range). The unreduced (i.e. with :attr:`reduction`
1165      set to ``'none'``) loss for this case can be described as:
1166
1167      .. math::
1168          \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
1169          l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})}
1170          \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}
1171
1172      where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight,
1173      :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as
1174      :math:`d_1, ..., d_k` for the `K`-dimensional case. If
1175      :attr:`reduction` is not ``'none'`` (default ``'mean'``), then
1176
1177      .. math::
1178          \ell(x, y) = \begin{cases}
1179              \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, &
1180               \text{if reduction} = \text{`mean';}\\
1181                \sum_{n=1}^N l_n,  &
1182                \text{if reduction} = \text{`sum'.}
1183            \end{cases}
1184
1185      Note that this case is equivalent to applying :class:`~torch.nn.LogSoftmax`
1186      on an input, followed by :class:`~torch.nn.NLLLoss`.
1187
1188    - Probabilities for each class; useful when labels beyond a single class per minibatch item
1189      are required, such as for blended labels, label smoothing, etc. The unreduced (i.e. with
1190      :attr:`reduction` set to ``'none'``) loss for this case can be described as:
1191
1192      .. math::
1193          \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
1194          l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c}
1195
1196      where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight,
1197      :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as
1198      :math:`d_1, ..., d_k` for the `K`-dimensional case. If
1199      :attr:`reduction` is not ``'none'`` (default ``'mean'``), then
1200
1201      .. math::
1202          \ell(x, y) = \begin{cases}
1203              \frac{\sum_{n=1}^N l_n}{N}, &
1204               \text{if reduction} = \text{`mean';}\\
1205                \sum_{n=1}^N l_n,  &
1206                \text{if reduction} = \text{`sum'.}
1207            \end{cases}
1208
1209    .. note::
1210        The performance of this criterion is generally better when `target` contains class
1211        indices, as this allows for optimized computation. Consider providing `target` as
1212        class probabilities only when a single class label per minibatch item is too restrictive.
1213
1214    Args:
1215        weight (Tensor, optional): a manual rescaling weight given to each class.
1216            If given, has to be a Tensor of size `C` and floating point dtype.
1217        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
1218            the losses are averaged over each loss element in the batch. Note that for
1219            some losses, there are multiple elements per sample. If the field :attr:`size_average`
1220            is set to ``False``, the losses are instead summed for each minibatch. Ignored
1221            when :attr:`reduce` is ``False``. Default: ``True``
1222        ignore_index (int, optional): Specifies a target value that is ignored
1223            and does not contribute to the input gradient. When :attr:`size_average` is
1224            ``True``, the loss is averaged over non-ignored targets. Note that
1225            :attr:`ignore_index` is only applicable when the target contains class indices.
1226        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
1227            losses are averaged or summed over observations for each minibatch depending
1228            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
1229            batch element instead and ignores :attr:`size_average`. Default: ``True``
1230        reduction (str, optional): Specifies the reduction to apply to the output:
1231            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
1232            be applied, ``'mean'``: the weighted mean of the output is taken,
1233            ``'sum'``: the output will be summed. Note: :attr:`size_average`
1234            and :attr:`reduce` are in the process of being deprecated, and in
1235            the meantime, specifying either of those two args will override
1236            :attr:`reduction`. Default: ``'mean'``
1237        label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
1238            of smoothing when computing the loss, where 0.0 means no smoothing. The targets
1239            become a mixture of the original ground truth and a uniform distribution as described in
1240            `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`.
1241
1242    Shape:
1243        - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
1244          in the case of `K`-dimensional loss.
1245        - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with
1246          :math:`K \geq 1` in the case of K-dimensional loss where each value should be in :math:`[0, C)`.
1247          If containing class probabilities, same shape as the input and each value should be in :math:`[0, 1]`.
1248        - Output: If reduction is 'none', shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
1249          in the case of K-dimensional loss, depending on the shape of the input. Otherwise, scalar.
1250
1251
1252        where:
1253
1254        .. math::
1255            \begin{aligned}
1256                C ={} & \text{number of classes} \\
1257                N ={} & \text{batch size} \\
1258            \end{aligned}
1259
1260    Examples::
1261
1262        >>> # Example of target with class indices
1263        >>> loss = nn.CrossEntropyLoss()
1264        >>> input = torch.randn(3, 5, requires_grad=True)
1265        >>> target = torch.empty(3, dtype=torch.long).random_(5)
1266        >>> output = loss(input, target)
1267        >>> output.backward()
1268        >>>
1269        >>> # Example of target with class probabilities
1270        >>> input = torch.randn(3, 5, requires_grad=True)
1271        >>> target = torch.randn(3, 5).softmax(dim=1)
1272        >>> output = loss(input, target)
1273        >>> output.backward()
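        >>>
        >>> # Example with per-class weights and label smoothing (illustrative values)
        >>> weight = torch.rand(5)
        >>> loss = nn.CrossEntropyLoss(weight=weight, label_smoothing=0.1)
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = loss(input, target)
        >>> output.backward()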
1274    """
1275    __constants__ = ["ignore_index", "reduction", "label_smoothing"]
1276    ignore_index: int
1277    label_smoothing: float
1278
1279    def __init__(
1280        self,
1281        weight: Optional[Tensor] = None,
1282        size_average=None,
1283        ignore_index: int = -100,
1284        reduce=None,
1285        reduction: str = "mean",
1286        label_smoothing: float = 0.0,
1287    ) -> None:
1288        super().__init__(weight, size_average, reduce, reduction)
1289        self.ignore_index = ignore_index
1290        self.label_smoothing = label_smoothing
1291
1292    def forward(self, input: Tensor, target: Tensor) -> Tensor:
1293        return F.cross_entropy(
1294            input,
1295            target,
1296            weight=self.weight,
1297            ignore_index=self.ignore_index,
1298            reduction=self.reduction,
1299            label_smoothing=self.label_smoothing,
1300        )
1301
1302
1303class MultiLabelSoftMarginLoss(_WeightedLoss):
1304    r"""Creates a criterion that optimizes a multi-label one-versus-all
1305    loss based on max-entropy, between input :math:`x` and target :math:`y` of size
1306    :math:`(N, C)`.
1307    For each sample in the minibatch:
1308
1309    .. math::
1310        loss(x, y) = - \frac{1}{C} * \sum_i y[i] * \log((1 + \exp(-x[i]))^{-1})
1311                         + (1-y[i]) * \log\left(\frac{\exp(-x[i])}{(1 + \exp(-x[i]))}\right)
1312
1313    where :math:`i \in \left\{0, \; \cdots , \; \text{x.nElement}() - 1\right\}`,
1314    :math:`y[i] \in \left\{0, \; 1\right\}`.
1315
1316    Args:
1317        weight (Tensor, optional): a manual rescaling weight given to each
1318            class. If given, it has to be a Tensor of size `C`. Otherwise, it is
1319            treated as if having all ones.
1320        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
1321            the losses are averaged over each loss element in the batch. Note that for
1322            some losses, there are multiple elements per sample. If the field :attr:`size_average`
1323            is set to ``False``, the losses are instead summed for each minibatch. Ignored
1324            when :attr:`reduce` is ``False``. Default: ``True``
1325        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
1326            losses are averaged or summed over observations for each minibatch depending
1327            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
1328            batch element instead and ignores :attr:`size_average`. Default: ``True``
1329        reduction (str, optional): Specifies the reduction to apply to the output:
1330            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
1331            ``'mean'``: the sum of the output will be divided by the number of
1332            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
1333            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
1334            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
1335
1336    Shape:
1337        - Input: :math:`(N, C)` where `N` is the batch size and `C` is the number of classes.
1338        - Target: :math:`(N, C)`, label targets must have the same shape as the input.
1339        - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
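
    Examples::

        >>> # each sample may belong to several classes, encoded as a multi-hot target
        >>> loss = nn.MultiLabelSoftMarginLoss()
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, 5).random_(2)
        >>> output = loss(input, target)
        >>> output.backward()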
1340    """
1341    __constants__ = ["reduction"]
1342
1343    def __init__(
1344        self,
1345        weight: Optional[Tensor] = None,
1346        size_average=None,
1347        reduce=None,
1348        reduction: str = "mean",
1349    ) -> None:
1350        super().__init__(weight, size_average, reduce, reduction)
1351
1352    def forward(self, input: Tensor, target: Tensor) -> Tensor:
1353        return F.multilabel_soft_margin_loss(
1354            input, target, weight=self.weight, reduction=self.reduction
1355        )
1356
1357
1358class CosineEmbeddingLoss(_Loss):
1359    r"""Creates a criterion that measures the loss given input tensors
1360    :math:`x_1`, :math:`x_2` and a `Tensor` label :math:`y` with values 1 or -1.
1361    Use :math:`y=1` to maximize the cosine similarity of two inputs, and :math:`y=-1` otherwise.
1362    This is typically used for learning nonlinear
1363    embeddings or semi-supervised learning.
1364
1365    The loss function for each sample is:
1366
1367    .. math::
1368        \text{loss}(x, y) =
1369        \begin{cases}
1370        1 - \cos(x_1, x_2), & \text{if } y = 1 \\
1371        \max(0, \cos(x_1, x_2) - \text{margin}), & \text{if } y = -1
1372        \end{cases}
1373
1374    Args:
1375        margin (float, optional): Should be a number from :math:`-1` to :math:`1`,
1376            :math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the
1377            default value is :math:`0`.
1378        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
1379            the losses are averaged over each loss element in the batch. Note that for
1380            some losses, there are multiple elements per sample. If the field :attr:`size_average`
1381            is set to ``False``, the losses are instead summed for each minibatch. Ignored
1382            when :attr:`reduce` is ``False``. Default: ``True``
1383        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
1384            losses are averaged or summed over observations for each minibatch depending
1385            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
1386            batch element instead and ignores :attr:`size_average`. Default: ``True``
1387        reduction (str, optional): Specifies the reduction to apply to the output:
1388            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
1389            ``'mean'``: the sum of the output will be divided by the number of
1390            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
1391            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
1392            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
1393
1394    Shape:
1395        - Input1: :math:`(N, D)` or :math:`(D)`, where `N` is the batch size and `D` is the embedding dimension.
1396        - Input2: :math:`(N, D)` or :math:`(D)`, same shape as Input1.
1397        - Target: :math:`(N)` or :math:`()`.
1398        - Output: If :attr:`reduction` is ``'none'``, then :math:`(N)`, otherwise scalar.
1399
1400    Examples::
1401
1402        >>> loss = nn.CosineEmbeddingLoss()
1403        >>> input1 = torch.randn(3, 5, requires_grad=True)
1404        >>> input2 = torch.randn(3, 5, requires_grad=True)
1405        >>> target = torch.ones(3)
1406        >>> output = loss(input1, input2, target)
1407        >>> output.backward()
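        >>>
        >>> # mixed targets: 1 pulls a pair together, -1 pushes it apart
        >>> target = torch.tensor([1.0, -1.0, 1.0])
        >>> output = loss(input1, input2, target)
        >>> output.backward()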
1408    """
1409    __constants__ = ["margin", "reduction"]
1410    margin: float
1411
1412    def __init__(
1413        self,
1414        margin: float = 0.0,
1415        size_average=None,
1416        reduce=None,
1417        reduction: str = "mean",
1418    ) -> None:
1419        super().__init__(size_average, reduce, reduction)
1420        self.margin = margin
1421
1422    def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor:
1423        return F.cosine_embedding_loss(
1424            input1, input2, target, margin=self.margin, reduction=self.reduction
1425        )
1426
1427
1428class MarginRankingLoss(_Loss):
1429    r"""Creates a criterion that measures the loss given
1430    inputs :math:`x1`, :math:`x2`, two 1D mini-batch or 0D `Tensors`,
1431    and a 1D mini-batch or 0D label `Tensor` :math:`y` (containing 1 or -1).
1432
1433    If :math:`y = 1` then it is assumed that the first input should be ranked higher
1434    (have a larger value) than the second input, and vice-versa for :math:`y = -1`.
1435
1436    The loss function for each pair of samples in the mini-batch is:
1437
1438    .. math::
1439        \text{loss}(x1, x2, y) = \max(0, -y * (x1 - x2) + \text{margin})
1440
1441    Args:
1442        margin (float, optional): Has a default value of :math:`0`.
1443        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
1444            the losses are averaged over each loss element in the batch. Note that for
1445            some losses, there are multiple elements per sample. If the field :attr:`size_average`
1446            is set to ``False``, the losses are instead summed for each minibatch. Ignored
1447            when :attr:`reduce` is ``False``. Default: ``True``
1448        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
1449            losses are averaged or summed over observations for each minibatch depending
1450            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
1451            batch element instead and ignores :attr:`size_average`. Default: ``True``
1452        reduction (str, optional): Specifies the reduction to apply to the output:
1453            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
1454            ``'mean'``: the sum of the output will be divided by the number of
1455            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
1456            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
1457            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
1458
1459    Shape:
1460        - Input1: :math:`(N)` or :math:`()` where `N` is the batch size.
1461        - Input2: :math:`(N)` or :math:`()`, same shape as the Input1.
1462        - Target: :math:`(N)` or :math:`()`, same shape as the inputs.
1463        - Output: scalar. If :attr:`reduction` is ``'none'`` and Input size is not :math:`()`, then :math:`(N)`.
1464
1465    Examples::
1466
1467        >>> loss = nn.MarginRankingLoss()
1468        >>> input1 = torch.randn(3, requires_grad=True)
1469        >>> input2 = torch.randn(3, requires_grad=True)
1470        >>> target = torch.randn(3).sign()
1471        >>> output = loss(input1, input2, target)
1472        >>> output.backward()
1473    """
1474    __constants__ = ["margin", "reduction"]
1475    margin: float
1476
1477    def __init__(
1478        self,
1479        margin: float = 0.0,
1480        size_average=None,
1481        reduce=None,
1482        reduction: str = "mean",
1483    ) -> None:
1484        super().__init__(size_average, reduce, reduction)
1485        self.margin = margin
1486
1487    def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor:
1488        return F.margin_ranking_loss(
1489            input1, input2, target, margin=self.margin, reduction=self.reduction
1490        )
1491
1492
1493class MultiMarginLoss(_WeightedLoss):
1494    r"""Creates a criterion that optimizes a multi-class classification hinge
1495    loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and
1496    target :math:`y` (which is a 1D tensor of class indices,
1497    :math:`0 \leq y \leq \text{x.size}(1)-1`):
1498
1499    For each mini-batch sample, the loss in terms of the 1D input :math:`x` and scalar
1500    target :math:`y` is:
1501
1502    .. math::
1503        \text{loss}(x, y) = \frac{\sum_i \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)}
1504
1505    where :math:`i \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`
1506    and :math:`i \neq y`.
1507
1508    Optionally, you can give non-equal weighting on the classes by passing
1509    a 1D :attr:`weight` tensor into the constructor.
1510
1511    The loss function then becomes:
1512
1513    .. math::
1514        \text{loss}(x, y) = \frac{\sum_i w[y] * \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)}
1515
1516    Args:
1517        p (int, optional): Has a default value of :math:`1`. :math:`1` and :math:`2`
1518            are the only supported values.
1519        margin (float, optional): Has a default value of :math:`1`.
1520        weight (Tensor, optional): a manual rescaling weight given to each
1521            class. If given, it has to be a Tensor of size `C`. Otherwise, it is
1522            treated as if having all ones.
1523        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
1524            the losses are averaged over each loss element in the batch. Note that for
1525            some losses, there are multiple elements per sample. If the field :attr:`size_average`
1526            is set to ``False``, the losses are instead summed for each minibatch. Ignored
1527            when :attr:`reduce` is ``False``. Default: ``True``
1528        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
1529            losses are averaged or summed over observations for each minibatch depending
1530            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
1531            batch element instead and ignores :attr:`size_average`. Default: ``True``
1532        reduction (str, optional): Specifies the reduction to apply to the output:
1533            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
1534            ``'mean'``: the sum of the output will be divided by the number of
1535            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
1536            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
1537            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
1538
1539    Shape:
1540        - Input: :math:`(N, C)` or :math:`(C)`, where :math:`N` is the batch size and :math:`C` is the number of classes.
1541        - Target: :math:`(N)` or :math:`()`, where each value is :math:`0 \leq \text{targets}[i] \leq C-1`.
1542        - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the target.
1543
1544    Examples::
1545
1546        >>> loss = nn.MultiMarginLoss()
1547        >>> x = torch.tensor([[0.1, 0.2, 0.4, 0.8]])
1548        >>> y = torch.tensor([3])
1549        >>> # 0.25 * ((1-(0.8-0.1)) + (1-(0.8-0.2)) + (1-(0.8-0.4)))
1550        >>> loss(x, y)
1551        tensor(0.32...)
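        >>>
        >>> # with per-class weights, each margin term is scaled by weight[y]
        >>> loss = nn.MultiMarginLoss(weight=torch.tensor([0.2, 0.2, 0.3, 0.3]))
        >>> # 0.3 * 0.25 * ((1-(0.8-0.1)) + (1-(0.8-0.2)) + (1-(0.8-0.4)))
        >>> loss(x, y)
        tensor(0.09...)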
1552    """
1553    __constants__ = ["p", "margin", "reduction"]
1554    margin: float
1555    p: int
1556
1557    def __init__(
1558        self,
1559        p: int = 1,
1560        margin: float = 1.0,
1561        weight: Optional[Tensor] = None,
1562        size_average=None,
1563        reduce=None,
1564        reduction: str = "mean",
1565    ) -> None:
1566        super().__init__(weight, size_average, reduce, reduction)
1567        if p != 1 and p != 2:
1568            raise ValueError("only p == 1 and p == 2 supported")
1569        if weight is not None and weight.dim() != 1:
1570            raise ValueError(
1571                f"MultiMarginLoss: expected weight to be None or 1D tensor, got {weight.dim()}D instead"
1572            )
1573        self.p = p
1574        self.margin = margin
1575
1576    def forward(self, input: Tensor, target: Tensor) -> Tensor:
1577        return F.multi_margin_loss(
1578            input,
1579            target,
1580            p=self.p,
1581            margin=self.margin,
1582            weight=self.weight,
1583            reduction=self.reduction,
1584        )
1585
1586
1587class TripletMarginLoss(_Loss):
1588    r"""Creates a criterion that measures the triplet loss given an input
1589    tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
1590    This is used for measuring a relative similarity between samples. A triplet
1591    is composed by `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative
1592    examples` respectively). The shapes of all input tensors should be
1593    :math:`(N, D)`.
1594
1595    The distance swap is described in detail in the paper `Learning shallow
1596    convolutional feature descriptors with triplet losses`_ by
1597    V. Balntas, E. Riba et al.
1598
1599    The loss function for each sample in the mini-batch is:
1600
1601    .. math::
1602        L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
1603
1604
1605    where
1606
1607    .. math::
1608        d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p
1609
1610    The norm is calculated using the specified p value and a small constant :math:`\varepsilon` is
1611    added for numerical stability.
1612
1613    See also :class:`~torch.nn.TripletMarginWithDistanceLoss`, which computes the
1614    triplet margin loss for input tensors using a custom distance function.
1615
1616    Args:
1617        margin (float, optional): Default: :math:`1`.
1618        p (int, optional): The norm degree for pairwise distance. Default: :math:`2`.
1619        eps (float, optional): Small constant for numerical stability. Default: :math:`1e-6`.
1620        swap (bool, optional): The distance swap is described in detail in the paper
1621            `Learning shallow convolutional feature descriptors with triplet losses` by
1622            V. Balntas, E. Riba et al. Default: ``False``.
1623        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
1624            the losses are averaged over each loss element in the batch. Note that for
1625            some losses, there are multiple elements per sample. If the field :attr:`size_average`
1626            is set to ``False``, the losses are instead summed for each minibatch. Ignored
1627            when :attr:`reduce` is ``False``. Default: ``True``
1628        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
1629            losses are averaged or summed over observations for each minibatch depending
1630            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
1631            batch element instead and ignores :attr:`size_average`. Default: ``True``
1632        reduction (str, optional): Specifies the reduction to apply to the output:
1633            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
1634            ``'mean'``: the sum of the output will be divided by the number of
1635            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
1636            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
1637            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
1638
1639    Shape:
1640        - Input: :math:`(N, D)` or :math:`(D)` where :math:`D` is the vector dimension.
1641        - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'`` and
1642          input shape is :math:`(N, D)`; a scalar otherwise.
1643
1644    Examples::
1645
1646        >>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
1647        >>> anchor = torch.randn(100, 128, requires_grad=True)
1648        >>> positive = torch.randn(100, 128, requires_grad=True)
1649        >>> negative = torch.randn(100, 128, requires_grad=True)
1650        >>> output = triplet_loss(anchor, positive, negative)
1651        >>> output.backward()
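        >>>
        >>> # with distance swap enabled (see the reference below)
        >>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2, swap=True)
        >>> output = triplet_loss(anchor, positive, negative)
        >>> output.backward()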
1652
1653    .. _Learning shallow convolutional feature descriptors with triplet losses:
1654        http://www.bmva.org/bmvc/2016/papers/paper119/index.html
1655    """
1656    __constants__ = ["margin", "p", "eps", "swap", "reduction"]
1657    margin: float
1658    p: float
1659    eps: float
1660    swap: bool
1661
1662    def __init__(
1663        self,
1664        margin: float = 1.0,
1665        p: float = 2.0,
1666        eps: float = 1e-6,
1667        swap: bool = False,
1668        size_average=None,
1669        reduce=None,
1670        reduction: str = "mean",
1671    ):
1672        super().__init__(size_average, reduce, reduction)
1673        if margin <= 0:
1674            raise ValueError(
1675                f"TripletMarginLoss: expected margin to be greater than 0, got {margin} instead"
1676            )
1677        self.margin = margin
1678        self.p = p
1679        self.eps = eps
1680        self.swap = swap
1681
1682    def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
1683        return F.triplet_margin_loss(
1684            anchor,
1685            positive,
1686            negative,
1687            margin=self.margin,
1688            p=self.p,
1689            eps=self.eps,
1690            swap=self.swap,
1691            reduction=self.reduction,
1692        )
1693
1694
1695class TripletMarginWithDistanceLoss(_Loss):
1696    r"""Creates a criterion that measures the triplet loss given input
1697    tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
1698    positive, and negative examples, respectively), and a nonnegative,
1699    real-valued function ("distance function") used to compute the relationship
1700    between the anchor and positive example ("positive distance") and the
1701    anchor and negative example ("negative distance").
1702
1703    The unreduced loss (i.e., with :attr:`reduction` set to ``'none'``)
1704    can be described as:
1705
1706    .. math::
1707        \ell(a, p, n) = L = \{l_1,\dots,l_N\}^\top, \quad
1708        l_i = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
1709
1710    where :math:`N` is the batch size; :math:`d` is a nonnegative, real-valued function
1711    quantifying the closeness of two tensors, referred to as the :attr:`distance_function`;
1712    and :math:`margin` is a nonnegative margin representing the minimum difference
1713    between the positive and negative distances that is required for the loss to
1714    be 0.  The input tensors have :math:`N` elements each and can be of any shape
1715    that the distance function can handle.
1716
1717    If :attr:`reduction` is not ``'none'``
1718    (default ``'mean'``), then:
1719
1720    .. math::
1721        \ell(x, y) =
1722        \begin{cases}
1723            \operatorname{mean}(L), &  \text{if reduction} = \text{`mean';}\\
1724            \operatorname{sum}(L),  &  \text{if reduction} = \text{`sum'.}
1725        \end{cases}
1726
1727    See also :class:`~torch.nn.TripletMarginLoss`, which computes the triplet
1728    loss for input tensors using the :math:`l_p` distance as the distance function.
1729
1730    Args:
1731        distance_function (Callable, optional): A nonnegative, real-valued function that
1732            quantifies the closeness of two tensors. If not specified,
1733            `nn.PairwiseDistance` will be used.  Default: ``None``
1734        margin (float, optional): A nonnegative margin representing the minimum difference
1735            between the positive and negative distances required for the loss to be 0. Larger
1736            margins penalize cases where the negative examples are not distant enough from the
1737            anchors, relative to the positives. Default: :math:`1`.
1738        swap (bool, optional): Whether to use the distance swap described in the paper
1739            `Learning shallow convolutional feature descriptors with triplet losses` by
1740            V. Balntas, E. Riba et al. If True, and if the positive example is closer to the
1741            negative example than the anchor is, swaps the positive example and the anchor in
1742            the loss computation. Default: ``False``.
1743        reduction (str, optional): Specifies the reduction to apply to the output:
1744            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
1745            ``'mean'``: the sum of the output will be divided by the number of
1746            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``
1747
1748
1749    Shape:
1750        - Input: :math:`(N, *)` where :math:`*` represents any number of additional dimensions
1751          as supported by the distance function.
1752        - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
1753          otherwise.
1754
1755    Examples::
1756
1757        >>> # Initialize embeddings
1758        >>> embedding = nn.Embedding(1000, 128)
1759        >>> anchor_ids = torch.randint(0, 1000, (1,))
1760        >>> positive_ids = torch.randint(0, 1000, (1,))
1761        >>> negative_ids = torch.randint(0, 1000, (1,))
1762        >>> anchor = embedding(anchor_ids)
1763        >>> positive = embedding(positive_ids)
1764        >>> negative = embedding(negative_ids)
1765        >>>
1766        >>> # Built-in Distance Function
1767        >>> triplet_loss = \
1768        >>>     nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance())
1769        >>> output = triplet_loss(anchor, positive, negative)
1770        >>> output.backward()
1771        >>>
1772        >>> # Custom Distance Function
1773        >>> def l_infinity(x1, x2):
1774        >>>     return torch.max(torch.abs(x1 - x2), dim=1).values
1775        >>>
1776        >>> # xdoctest: +SKIP("FIXME: Would call backwards a second time")
1777        >>> triplet_loss = (
1778        >>>     nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5))
1779        >>> output = triplet_loss(anchor, positive, negative)
1780        >>> output.backward()
1781        >>>
1782        >>> # Custom Distance Function (Lambda)
1783        >>> triplet_loss = (
1784        >>>     nn.TripletMarginWithDistanceLoss(
1785        >>>         distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)))
1786        >>> output = triplet_loss(anchor, positive, negative)
1787        >>> output.backward()
1788
1789    Reference:
1790        V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses:
1791        http://www.bmva.org/bmvc/2016/papers/paper119/index.html
1792    """
1793    __constants__ = ["margin", "swap", "reduction"]
1794    margin: float
1795    swap: bool
1796
1797    def __init__(
1798        self,
1799        *,
1800        distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
1801        margin: float = 1.0,
1802        swap: bool = False,
1803        reduction: str = "mean",
1804    ):
1805        super().__init__(size_average=None, reduce=None, reduction=reduction)
1806        if margin <= 0:
1807            raise ValueError(
1808                f"TripletMarginWithDistanceLoss: expected margin to be greater than 0, got {margin} instead"
1809            )
1810        self.distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = (
1811            distance_function if distance_function is not None else PairwiseDistance()
1812        )
1813        self.margin = margin
1814        self.swap = swap
1815
1816    def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
1817        return F.triplet_margin_with_distance_loss(
1818            anchor,
1819            positive,
1820            negative,
1821            distance_function=self.distance_function,
1822            margin=self.margin,
1823            swap=self.swap,
1824            reduction=self.reduction,
1825        )
1826
1827
1828class CTCLoss(_Loss):
1829    r"""The Connectionist Temporal Classification loss.
1830
1831    Calculates loss between a continuous (unsegmented) time series and a target sequence. CTCLoss sums over the
1832    probability of possible alignments of input to target, producing a loss value which is differentiable
1833    with respect to each input node. The alignment of input to target is assumed to be "many-to-one", which
1834    limits the length of the target sequence such that it must be :math:`\leq` the input length.
1835
1836    Args:
1837        blank (int, optional): blank label. Default :math:`0`.
1838        reduction (str, optional): Specifies the reduction to apply to the output:
1839            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
1840            ``'mean'``: the output losses will be divided by the target lengths and
1841            then the mean over the batch is taken, ``'sum'``: the output losses will be summed.
1842            Default: ``'mean'``
1843        zero_infinity (bool, optional):
1844            Whether to zero infinite losses and the associated gradients.
1845            Infinite losses mainly occur when the inputs are too short
1846            to be aligned to the targets.
1847            Default: ``False``
1848
1849    Shape:
1850        - Log_probs: Tensor of size :math:`(T, N, C)` or :math:`(T, C)`,
1851          where :math:`T = \text{input length}`,
1852          :math:`N = \text{batch size}`, and
1853          :math:`C = \text{number of classes (including blank)}`.
1854          The logarithmized probabilities of the outputs (e.g. obtained with
1855          :func:`torch.nn.functional.log_softmax`).
1856        - Targets: Tensor of size :math:`(N, S)` or
1857          :math:`(\operatorname{sum}(\text{target\_lengths}))`,
1858          where :math:`N = \text{batch size}` and
1859          :math:`S = \text{max target length, if shape is } (N, S)`.
1860          It represents the target sequences. Each element in the target
1861          sequence is a class index. The target index cannot be the blank index (default=0).
1862          In the :math:`(N, S)` form, targets are padded to the
1863          length of the longest sequence, and stacked.
1864          In the :math:`(\operatorname{sum}(\text{target\_lengths}))` form,
1865          the targets are assumed to be un-padded and
1866          concatenated within 1 dimension.
1867        - Input_lengths: Tuple or tensor of size :math:`(N)` or :math:`()`,
1868          where :math:`N = \text{batch size}`. It represents the lengths of the
1869          inputs (must each be :math:`\leq T`). Lengths are specified
1870          for each sequence to achieve masking under the assumption that sequences
1871          are padded to equal lengths.
1872        - Target_lengths: Tuple or tensor of size :math:`(N)` or :math:`()`,
1873          where :math:`N = \text{batch size}`. It represents lengths of the targets.
1874          Lengths are specified for each sequence to achieve masking under the
1875          assumption that sequences are padded to equal lengths. If target shape is
1876          :math:`(N,S)`, target_lengths are effectively the stop index
1877          :math:`s_n` for each target sequence, such that ``target_n = targets[n,0:s_n]`` for
1878          each target in a batch. Lengths must each be :math:`\leq S`.
1879          If the targets are given as a 1d tensor that is the concatenation of individual
1880          targets, the target_lengths must add up to the total length of the tensor.
1881        - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
1882          ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N)` if input is batched or
1883          :math:`()` if input is unbatched, where :math:`N = \text{batch size}`.
1884
1885    Examples::
1886
1887        >>> # Target are to be padded
1888        >>> T = 50      # Input sequence length
1889        >>> C = 20      # Number of classes (including blank)
1890        >>> N = 16      # Batch size
1891        >>> S = 30      # Target sequence length of longest target in batch (padding length)
1892        >>> S_min = 10  # Minimum target length, for demonstration purposes
1893        >>>
1894        >>> # Initialize random batch of input vectors, for *size = (T,N,C)
1895        >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
1896        >>>
1897        >>> # Initialize random batch of targets (0 = blank, 1:C = classes)
1898        >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
1899        >>>
1900        >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
1901        >>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
1902        >>> ctc_loss = nn.CTCLoss()
1903        >>> loss = ctc_loss(input, target, input_lengths, target_lengths)
1904        >>> loss.backward()
1905        >>>
1906        >>>
1907        >>> # Target are to be un-padded
1908        >>> T = 50      # Input sequence length
1909        >>> C = 20      # Number of classes (including blank)
1910        >>> N = 16      # Batch size
1911        >>>
1912        >>> # Initialize random batch of input vectors, for *size = (T,N,C)
1913        >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
1914        >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
1915        >>>
1916        >>> # Initialize random batch of targets (0 = blank, 1:C = classes)
1917        >>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
1918        >>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
1919        >>> ctc_loss = nn.CTCLoss()
1920        >>> loss = ctc_loss(input, target, input_lengths, target_lengths)
1921        >>> loss.backward()
1922        >>>
1923        >>>
1924        >>> # Target are to be un-padded and unbatched (effectively N=1)
1925        >>> T = 50      # Input sequence length
1926        >>> C = 20      # Number of classes (including blank)
1927        >>>
1928        >>> # Initialize random batch of input vectors, for *size = (T,C)
1929        >>> # xdoctest: +SKIP("FIXME: error in doctest")
1930        >>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_()
1931        >>> input_lengths = torch.tensor(T, dtype=torch.long)
1932        >>>
1933        >>> # Initialize random batch of targets (0 = blank, 1:C = classes)
1934        >>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
1935        >>> target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long)
1936        >>> ctc_loss = nn.CTCLoss()
1937        >>> loss = ctc_loss(input, target, input_lengths, target_lengths)
1938        >>> loss.backward()
1939
1940    Reference:
1941        A. Graves et al.: Connectionist Temporal Classification:
1942        Labelling Unsegmented Sequence Data with Recurrent Neural Networks:
1943        https://www.cs.toronto.edu/~graves/icml_2006.pdf
1944
1945    Note:
1946        In order to use CuDNN, the following must be satisfied: :attr:`targets` must be
1947        in concatenated format, all :attr:`input_lengths` must be `T`, :math:`blank=0`,
1948        :attr:`target_lengths` :math:`\leq 256`, and the integer arguments must be of
1949        dtype :attr:`torch.int32`.
1950
1951        The regular implementation uses the (more common in PyTorch) `torch.long` dtype.
1952
1953
1954    Note:
1955        In some circumstances when using the CUDA backend with CuDNN, this operator
1956        may select a nondeterministic algorithm to increase performance. If this is
1957        undesirable, you can try to make the operation deterministic (potentially at
1958        a performance cost) by setting ``torch.backends.cudnn.deterministic =
1959        True``.
1960        Please see the notes on :doc:`/notes/randomness` for background.
1961    """
1962    __constants__ = ["blank", "reduction"]
1963    blank: int
1964    zero_infinity: bool
1965
1966    def __init__(
1967        self, blank: int = 0, reduction: str = "mean", zero_infinity: bool = False
1968    ):
1969        super().__init__(reduction=reduction)
1970        self.blank = blank
1971        self.zero_infinity = zero_infinity
1972
1973    def forward(
1974        self,
1975        log_probs: Tensor,
1976        targets: Tensor,
1977        input_lengths: Tensor,
1978        target_lengths: Tensor,
1979    ) -> Tensor:
1980        return F.ctc_loss(
1981            log_probs,
1982            targets,
1983            input_lengths,
1984            target_lengths,
1985            self.blank,
1986            self.reduction,
1987            self.zero_infinity,
1988        )
1989
1990
1991# TODO: L1HingeEmbeddingCriterion
1992# TODO: MSECriterion weight
1993# TODO: ClassSimplexCriterion
1994