# mypy: allow-untyped-defs

from collections import namedtuple
from typing import List, Sequence

import torch
import torch.nn.functional as F
from torch import Tensor

from .container import ModuleList, Sequential
from .linear import Linear
from .module import Module


__all__ = ["AdaptiveLogSoftmaxWithLoss"]

_ASMoutput = namedtuple("_ASMoutput", ["output", "loss"])


class AdaptiveLogSoftmaxWithLoss(Module):
    """Efficient softmax approximation.

    As described in
    `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
    Moustapha Ciss\u00e9, David Grangier, and Herv\u00e9 J\u00e9gou
    <https://arxiv.org/abs/1609.04309>`__.
""" r"""
    Adaptive softmax is an approximate strategy for training models with large
    output spaces. It is most effective when the label distribution is highly
    imbalanced, for example in natural language modelling, where the word
    frequency distribution approximately follows `Zipf's law`_.

    Adaptive softmax partitions the labels into several clusters, according to
    their frequency. These clusters may contain different numbers of targets
    each.
    Additionally, clusters containing less frequent labels assign
    lower-dimensional embeddings to those labels, which speeds up the
    computation.
    For each minibatch, only clusters for which at least one target is
    present are evaluated.

    The idea is that the clusters which are accessed frequently
    (like the first one, containing the most frequent labels) should also be
    cheap to compute -- that is, contain a small number of assigned labels.

    We highly recommend taking a look at the original paper for more details.

    * :attr:`cutoffs` should be an ordered Sequence of integers sorted
      in increasing order.
      It controls the number of clusters and the partitioning of targets into
      clusters. For example, setting ``cutoffs = [10, 100, 1000]``
      means that the first `10` targets will be assigned
      to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be
      assigned to the first cluster, and targets `101, 102, ..., 1000` will be
      assigned to the second cluster, while targets
      `1001, 1002, ..., n_classes - 1` will be assigned
      to the last, third cluster.

    * :attr:`div_value` is used to compute the size of each additional cluster,
      which is given as
      :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`,
      where :math:`idx` is the cluster index (with clusters
      for less frequent words having larger indices,
      and indices starting from :math:`1`).

    * :attr:`head_bias`, if set to True, adds a bias term to the 'head' of the
      adaptive softmax. See paper for details. Set to False in the official
      implementation.

    .. warning::
        Labels passed as inputs to this module should be sorted according to
        their frequency. This means that the most frequent label should be
        represented by the index `0`, and the least frequent
        label should be represented by the index `n_classes - 1`.

    .. note::
        This module returns a ``NamedTuple`` with ``output``
        and ``loss`` fields. See further documentation for details.

    .. note::
        To compute log-probabilities for all classes, the ``log_prob``
        method can be used.

    Args:
        in_features (int): Number of features in the input tensor
        n_classes (int): Number of classes in the dataset
        cutoffs (Sequence): Cutoffs used to assign targets to their buckets
        div_value (float, optional): value used as an exponent to compute sizes
            of the clusters. Default: 4.0
        head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the
            adaptive softmax. Default: ``False``

    Returns:
        ``NamedTuple`` with ``output`` and ``loss`` fields:
            * **output** is a Tensor of size ``N`` containing computed target
              log probabilities for each example
            * **loss** is a Scalar representing the computed negative
              log likelihood loss
    Shape:
        - input: :math:`(N, \texttt{in\_features})` or :math:`(\texttt{in\_features})`
        - target: :math:`(N)` or :math:`()` where each value satisfies :math:`0 <= \texttt{target[i]} <= \texttt{n\_classes} - 1`
        - output1: :math:`(N)` or :math:`()`
        - output2: ``Scalar``
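
    A minimal usage sketch; the sizes below are illustrative only::

        >>> asm = AdaptiveLogSoftmaxWithLoss(
        ...     in_features=64, n_classes=1000, cutoffs=[10, 100]
        ... )
        >>> input = torch.randn(128, 64)
        >>> target = torch.randint(0, 1000, (128,))
        >>> out, loss = asm(input, target)
        >>> out.shape
        torch.Size([128])
        >>> loss.shape
        torch.Size([])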

    .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
    """

    in_features: int
    n_classes: int
    cutoffs: List[int]
    div_value: float
    head_bias: bool
    head: Linear
    tail: ModuleList

    def __init__(
        self,
        in_features: int,
        n_classes: int,
        cutoffs: Sequence[int],
        div_value: float = 4.0,
        head_bias: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        cutoffs = list(cutoffs)

        if len(cutoffs) == 0:
            raise ValueError("cutoffs should be a sequence of length larger than 0")

        if (
            (cutoffs != sorted(cutoffs))
            or (min(cutoffs) <= 0)
            or (max(cutoffs) > (n_classes - 1))
            or (len(set(cutoffs)) != len(cutoffs))
            or any(int(c) != c for c in cutoffs)
        ):
            raise ValueError(
                "cutoffs should be a sequence of unique, positive "
                "integers sorted in an increasing order, where "
                "each value is between 1 and n_classes-1"
            )

        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = cutoffs + [n_classes]
        self.div_value = div_value
        self.head_bias = head_bias

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        self.head = Linear(
            self.in_features, self.head_size, bias=self.head_bias, **factory_kwargs
        )
        self.tail = ModuleList()

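        # Each tail cluster i (0-indexed here) projects the input down to
        # hsz = floor(in_features / div_value ** (i + 1)) dimensions and then maps
        # it to the osz labels assigned to that cluster. For example, with
        # in_features=512 and the default div_value=4.0, successive (rarer)
        # clusters use hidden sizes 128, 32, 8, ...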
        for i in range(self.n_clusters):
            hsz = int(self.in_features // (self.div_value ** (i + 1)))
            osz = self.cutoffs[i + 1] - self.cutoffs[i]

            projection = Sequential(
                Linear(self.in_features, hsz, bias=False, **factory_kwargs),
                Linear(hsz, osz, bias=False, **factory_kwargs),
            )

            self.tail.append(projection)

    def reset_parameters(self) -> None:
        self.head.reset_parameters()
        for i2h, h2o in self.tail:
            i2h.reset_parameters()
            h2o.reset_parameters()

    def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput:
        targ_dim = target_.dim()

        if targ_dim == 1:
            if input_.size(0) != target_.size(0):
                raise RuntimeError(
                    "Input and target should have the same size "
                    "in the batch dimension."
                )
            if input_.dim() != 2:
                raise RuntimeError(
                    "1D target tensor expects 2D input tensors, "
                    "but found inputs with size",
                    input_.size(),
                )
        elif targ_dim == 0:
            if input_.dim() != 1:
                raise RuntimeError(
                    "0D target tensor expects 1D input tensors, "
                    "but found inputs with size",
                    input_.size(),
                )
        else:
            raise RuntimeError(
                "0D or 1D target tensor expected, multi-target not supported"
            )

        is_batched = targ_dim > 0
        input = input_ if is_batched else input_.unsqueeze(0)
        target = target_ if is_batched else target_.unsqueeze(0)

        used_rows = 0
        batch_size = target.size(0)

        output = input.new_zeros(batch_size)
        gather_inds = target.new_empty(batch_size)

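        # For each cluster, gather_inds records which head entry each target maps
        # to: a shortlist target indexes its own head logit directly, while a
        # target in tail cluster i maps to the cluster token at position
        # shortlist_size + i - 1. For tail targets, the within-cluster
        # log-probability is accumulated into `output` here and combined with the
        # head log-probability below.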
        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):
            low_idx = cutoff_values[i]
            high_idx = cutoff_values[i + 1]

            target_mask = (target >= low_idx) & (target < high_idx)
            row_indices = target_mask.nonzero().squeeze()

            if row_indices.numel() == 0:
                continue

            if i == 0:
                gather_inds.index_copy_(0, row_indices, target[target_mask])

            else:
                relative_target = target[target_mask] - low_idx
                input_subset = input.index_select(0, row_indices)

                cluster_output = self.tail[i - 1](input_subset)
                cluster_index = self.shortlist_size + i - 1

                gather_inds.index_fill_(0, row_indices, cluster_index)
                cluster_logprob = F.log_softmax(cluster_output, dim=1)
                local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
                output.index_copy_(0, row_indices, local_logprob.squeeze(1))

            used_rows += row_indices.numel()

        if used_rows != batch_size:
            raise RuntimeError(
                f"Target values should be in [0, {self.n_classes - 1}], "
                f"but values in range [{target.min().item()}, {target.max().item()}] "
                "were found. "
            )

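        # `output` currently holds log P(target | cluster) for tail targets (and
        # zero for shortlist targets); adding the head log-probability gathered at
        # gather_inds yields the full target log-probability in both cases.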
        head_output = self.head(input)
        head_logprob = F.log_softmax(head_output, dim=1)
        output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
        loss = (-output).mean()

        if not is_batched:
            output = output.squeeze(0)

        return _ASMoutput(output, loss)

    def _get_full_log_prob(self, input, head_output):
        """Given the input tensor and the output of ``self.head``, compute the log of the full distribution."""
        out = input.new_empty((head_output.size(0), self.n_classes))
        head_logprob = F.log_softmax(head_output, dim=1)

        out[:, : self.shortlist_size] = head_logprob[:, : self.shortlist_size]

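        # A class in tail cluster i has log-probability equal to its
        # within-cluster log-probability plus the head log-probability of that
        # cluster's token at position shortlist_size + i.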
        for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
            cluster_output = self.tail[i](input)
            cluster_logprob = F.log_softmax(cluster_output, dim=1)
            output_logprob = cluster_logprob + head_logprob[
                :, self.shortlist_size + i
            ].unsqueeze(1)

            out[:, start_idx:stop_idx] = output_logprob

        return out

    def log_prob(self, input: Tensor) -> Tensor:
        r"""Compute log probabilities for all :math:`\texttt{n\_classes}`.

        Args:
            input (Tensor): a minibatch of examples

        Returns:
            log-probabilities for each class :math:`c`
            in range :math:`0 <= c <= \texttt{n\_classes} - 1`, where :math:`\texttt{n\_classes}` is a
            parameter passed to the ``AdaptiveLogSoftmaxWithLoss`` constructor.

        Shape:
            - Input: :math:`(N, \texttt{in\_features})`
            - Output: :math:`(N, \texttt{n\_classes})`

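        A minimal sketch of the expected shapes (the sizes are illustrative only)::

            >>> asm = AdaptiveLogSoftmaxWithLoss(16, 100, cutoffs=[5, 20])
            >>> asm.log_prob(torch.randn(8, 16)).shape
            torch.Size([8, 100])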
294        """
295        head_output = self.head(input)
296        return self._get_full_log_prob(input, head_output)
297
298    def predict(self, input: Tensor) -> Tensor:
299        r"""Return the class with the highest probability for each example in the input minibatch.
300
301        This is equivalent to ``self.log_prob(input).argmax(dim=1)``, but is more efficient in some cases.
302
303        Args:
304            input (Tensor): a minibatch of examples
305
306        Returns:
307            output (Tensor): a class with the highest probability for each example
308
309        Shape:
310            - Input: :math:`(N, \texttt{in\_features})`
311            - Output: :math:`(N)`
312        """
313        head_output = self.head(input)
314        output = torch.argmax(head_output, dim=1)
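        # A head argmax inside the shortlist already beats every tail class (a
        # tail class's probability is a fraction of its cluster token's
        # probability), so the full distribution is only needed for rows whose
        # argmax is a cluster token.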
        not_in_shortlist = output >= self.shortlist_size
        all_in_shortlist = not (not_in_shortlist.any())

        if all_in_shortlist:
            return output

        elif not_in_shortlist.all():
            log_prob = self._get_full_log_prob(input, head_output)
            return torch.argmax(log_prob, dim=1)

        else:
            log_prob = self._get_full_log_prob(
                input[not_in_shortlist], head_output[not_in_shortlist]
            )
            output[not_in_shortlist] = torch.argmax(log_prob, dim=1)
            return output
