# mypy: allow-untyped-defs
from typing import Optional

import torch
from torch import Tensor
from torch.nn import functional as F, init
from torch.nn.parameter import Parameter

from .module import Module


__all__ = ["Embedding", "EmbeddingBag"]


class Embedding(Module):
    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.

    This module is often used to store word embeddings and retrieve them using indices.
    The input to the module is a list of indices, and the output is the corresponding
    word embeddings.

    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
                                     therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
                                     i.e. it remains as a fixed "pad". For a newly constructed Embedding,
                                     the embedding vector at :attr:`padding_idx` will default to all zeros,
                                     but can be updated to another value to be used as the padding vector.
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of
                                                the words in the mini-batch. Default ``False``.
        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
                                 See Notes for more details regarding sparse gradients.

    Attributes:
        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
                         initialized from :math:`\mathcal{N}(0, 1)`

    Shape:
        - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`

    .. note::
        Keep in mind that only a limited number of optimizers support
        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)

    .. note::
        When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the
        :attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be
        modified in-place, performing a differentiable operation on ``Embedding.weight`` before
        calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when
        :attr:`max_norm` is not ``None``. For example::

            n, d, m = 3, 5, 7
            embedding = nn.Embedding(n, d, max_norm=1.0)
            W = torch.randn((m, d), requires_grad=True)
            idx = torch.tensor([1, 2])
            a = embedding.weight.clone() @ W.t()  # weight must be cloned for this to be differentiable
            b = embedding(idx) @ W.t()  # modifies weight in-place
            out = (a.unsqueeze(0) + b.unsqueeze(1))
            loss = out.sigmoid().prod()
            loss.backward()

    Examples::

        >>> # an Embedding module containing 10 tensors of size 3
        >>> embedding = nn.Embedding(10, 3)
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> embedding(input)
        tensor([[[-0.0251, -1.6902,  0.7172],
                 [-0.6431,  0.0748,  0.6969],
                 [ 1.4970,  1.3448, -0.9685],
                 [-0.3677, -2.7265, -0.1685]],

                [[ 1.4970,  1.3448, -0.9685],
                 [ 0.4362, -0.4004,  0.9400],
                 [-0.6431,  0.0748,  0.6969],
                 [ 0.9124, -2.3616,  1.1151]]])


        >>> # example with padding_idx
        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
        >>> input = torch.LongTensor([[0, 2, 0, 5]])
        >>> embedding(input)
        tensor([[[ 0.0000,  0.0000,  0.0000],
                 [ 0.1535, -2.0309,  0.9315],
                 [ 0.0000,  0.0000,  0.0000],
                 [-0.1655,  0.9897,  0.0635]]])

        >>> # example of changing `pad` vector
        >>> padding_idx = 0
        >>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
        >>> embedding.weight
        Parameter containing:
        tensor([[ 0.0000,  0.0000,  0.0000],
                [-0.7895, -0.7089, -0.0364],
                [ 0.6778,  0.5803,  0.2678]], requires_grad=True)
        >>> with torch.no_grad():
        ...     embedding.weight[padding_idx] = torch.ones(3)
        >>> embedding.weight
        Parameter containing:
        tensor([[ 1.0000,  1.0000,  1.0000],
                [-0.7895, -0.7089, -0.0364],
                [ 0.6778,  0.5803,  0.2678]], requires_grad=True)
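
        >>> # a minimal check (sketch) of the padding behavior described above:
        >>> # the vector at padding_idx receives no gradient
        >>> embedding = nn.Embedding(5, 3, padding_idx=0)
        >>> out = embedding(torch.LongTensor([0, 1]))
        >>> out.sum().backward()
        >>> embedding.weight.grad[0]
        tensor([0., 0., 0.])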
    """

    __constants__ = [
        "num_embeddings",
        "embedding_dim",
        "padding_idx",
        "max_norm",
        "norm_type",
        "scale_grad_by_freq",
        "sparse",
    ]

    num_embeddings: int
    embedding_dim: int
    padding_idx: Optional[int]
    max_norm: Optional[float]
    norm_type: float
    scale_grad_by_freq: bool
    weight: Tensor
    freeze: bool
    sparse: bool

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        _freeze: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert (
                    padding_idx < self.num_embeddings
                ), "padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert (
                    padding_idx >= -self.num_embeddings
                ), "padding_idx must be within num_embeddings"
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if _weight is None:
            self.weight = Parameter(
                torch.empty((num_embeddings, embedding_dim), **factory_kwargs),
                requires_grad=not _freeze,
            )
            self.reset_parameters()
        else:
            assert list(_weight.shape) == [
                num_embeddings,
                embedding_dim,
            ], "Shape of weight does not match num_embeddings and embedding_dim"
            self.weight = Parameter(_weight, requires_grad=not _freeze)

        self.sparse = sparse

    def reset_parameters(self) -> None:
        init.normal_(self.weight)
        self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        return F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

    def extra_repr(self) -> str:
        s = "{num_embeddings}, {embedding_dim}"
        if self.padding_idx is not None:
            s += ", padding_idx={padding_idx}"
        if self.max_norm is not None:
            s += ", max_norm={max_norm}"
        if self.norm_type != 2:
            s += ", norm_type={norm_type}"
        if self.scale_grad_by_freq is not False:
            s += ", scale_grad_by_freq={scale_grad_by_freq}"
        if self.sparse is not False:
            s += ", sparse=True"
        return s.format(**self.__dict__)

    @classmethod
    def from_pretrained(
        cls,
        embeddings,
        freeze=True,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
    ):
        r"""Create Embedding instance from given 2-dimensional FloatTensor.

        Args:
            embeddings (Tensor): FloatTensor containing weights for the Embedding.
                First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
            freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process.
                Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
            padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
                                         therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
                                         i.e. it remains as a fixed "pad".
            max_norm (float, optional): See module initialization documentation.
            norm_type (float, optional): See module initialization documentation. Default ``2``.
            scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``.
            sparse (bool, optional): See module initialization documentation.

        Examples::

            >>> # FloatTensor containing pretrained weights
            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
            >>> embedding = nn.Embedding.from_pretrained(weight)
            >>> # Get embeddings for index 1
            >>> input = torch.LongTensor([1])
            >>> # xdoctest: +IGNORE_WANT("non-deterministic")
            >>> embedding(input)
            tensor([[ 4.0000,  5.1000,  6.3000]])
        """
        assert (
            embeddings.dim() == 2
        ), "Embeddings parameter is expected to be 2-dimensional"
        rows, cols = embeddings.shape
        embedding = cls(
            num_embeddings=rows,
            embedding_dim=cols,
            _weight=embeddings,
            _freeze=freeze,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
        )
        return embedding


class EmbeddingBag(Module):
    r"""Compute sums or means of 'bags' of embeddings, without instantiating the intermediate embeddings.

    For bags of constant length, no :attr:`per_sample_weights`, no indices equal to :attr:`padding_idx`,
    and with 2D inputs, this class

        * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=1)``,
        * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=1)``,
        * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=1)``.

    However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
    operations (the ``mode="sum"`` case is illustrated in the last example below).

    EmbeddingBag also supports per-sample weights as an argument to the forward
    pass. This scales the output of the Embedding before performing a weighted
    reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
    only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
    :attr:`per_sample_weights`.

    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of
                                                the words in the mini-batch. Default ``False``.
                                                Note: this option is not supported when ``mode="max"``.
        mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
                                 ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
                                 into consideration. ``"mean"`` computes the average of the values
                                 in the bag, ``"max"`` computes the max value over each bag.
                                 Default: ``"mean"``
        sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See
                                 Notes for more details regarding sparse gradients. Note: this option is not
                                 supported when ``mode="max"``.
        include_last_offset (bool, optional): if ``True``, :attr:`offsets` has one additional element, where the last element
                                      is equivalent to the size of `indices`. This matches the CSR format.
        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the
                                     gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated
                                     during training, i.e. it remains as a fixed "pad". For a newly constructed
                                     EmbeddingBag, the embedding vector at :attr:`padding_idx` will default to all
                                     zeros, but can be updated to another value to be used as the padding vector.
                                     Note that the embedding vector at :attr:`padding_idx` is excluded from the
                                     reduction.

    Attributes:
        weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
                         initialized from :math:`\mathcal{N}(0, 1)`.

    Examples::

        >>> # an EmbeddingBag module containing 10 tensors of size 3
        >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
        >>> offsets = torch.tensor([0, 4], dtype=torch.long)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> embedding_sum(input, offsets)
        tensor([[-0.8861, -5.4350, -0.0523],
                [ 1.1306, -2.5798, -1.0044]])

        >>> # Example with padding_idx
        >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum', padding_idx=2)
        >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9], dtype=torch.long)
        >>> offsets = torch.tensor([0, 4], dtype=torch.long)
        >>> embedding_sum(input, offsets)
        tensor([[ 0.0000,  0.0000,  0.0000],
                [-0.7082,  3.2145, -2.6251]])

        >>> # An EmbeddingBag can be loaded from an Embedding like so
        >>> embedding = nn.Embedding(10, 3, padding_idx=2)
        >>> embedding_sum = nn.EmbeddingBag.from_pretrained(
        ...     embedding.weight,
        ...     padding_idx=embedding.padding_idx,
        ...     mode='sum')
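
        >>> # a quick check (sketch) of the ``mode="sum"`` equivalence described above,
        >>> # using a fixed-length 2D input and no padding_idx
        >>> embedding = nn.Embedding(10, 3)
        >>> embedding_sum = nn.EmbeddingBag.from_pretrained(embedding.weight, mode='sum')
        >>> input = torch.tensor([[1, 2], [3, 4]], dtype=torch.long)
        >>> torch.allclose(embedding_sum(input), embedding(input).sum(dim=1))
        True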
    """

    __constants__ = [
        "num_embeddings",
        "embedding_dim",
        "max_norm",
        "norm_type",
        "scale_grad_by_freq",
        "mode",
        "sparse",
        "include_last_offset",
        "padding_idx",
    ]

    num_embeddings: int
    embedding_dim: int
    max_norm: Optional[float]
    norm_type: float
    scale_grad_by_freq: bool
    weight: Tensor
    mode: str
    sparse: bool
    include_last_offset: bool
    padding_idx: Optional[int]

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        mode: str = "mean",
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        include_last_offset: bool = False,
        padding_idx: Optional[int] = None,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if padding_idx is not None:
            if padding_idx > 0:
                assert (
                    padding_idx < self.num_embeddings
                ), "padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert (
                    padding_idx >= -self.num_embeddings
                ), "padding_idx must be within num_embeddings"
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        if _weight is None:
            self.weight = Parameter(
                torch.empty((num_embeddings, embedding_dim), **factory_kwargs)
            )
            self.reset_parameters()
        else:
            assert list(_weight.shape) == [
                num_embeddings,
                embedding_dim,
            ], "Shape of weight does not match num_embeddings and embedding_dim"
            self.weight = Parameter(_weight)
        self.mode = mode
        self.sparse = sparse
        self.include_last_offset = include_last_offset

    def reset_parameters(self) -> None:
        init.normal_(self.weight)
        self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(
        self,
        input: Tensor,
        offsets: Optional[Tensor] = None,
        per_sample_weights: Optional[Tensor] = None,
    ) -> Tensor:
        """Forward pass of EmbeddingBag.

        Args:
            input (Tensor): Tensor containing bags of indices into the embedding matrix.
            offsets (Tensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines
                the starting index position of each bag (sequence) in :attr:`input`.
            per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
                to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
                must have exactly the same shape as input and is treated as having the same
                :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.

        Returns:
            Tensor of shape `(B, embedding_dim)`, where ``B`` is the number of bags.

        .. note::

            A few notes about ``input`` and ``offsets``:

            - :attr:`input` and :attr:`offsets` have to be of the same type, either int or long

            - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences)
              each of fixed length ``N``, and this will return ``B`` values aggregated in a way
              depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case.

            - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of
              multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing the
              starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets` of shape `(B)`,
              :attr:`input` will be viewed as having ``B`` bags. Empty bags (i.e., having 0-length) will have
              returned vectors filled by zeros.
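
        A minimal sketch of using :attr:`per_sample_weights` (only supported with
        ``mode='sum'``); the weights scale each looked-up vector before the reduction::

            >>> bag = nn.EmbeddingBag(10, 3, mode='sum')
            >>> input = torch.tensor([1, 2, 4, 5], dtype=torch.long)
            >>> offsets = torch.tensor([0, 2], dtype=torch.long)
            >>> weights = torch.tensor([0.5, 1.0, 2.0, 0.1])
            >>> out = bag(input, offsets, per_sample_weights=weights)
            >>> out.shape
            torch.Size([2, 3])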
        """
        return F.embedding_bag(
            input,
            self.weight,
            offsets,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.mode,
            self.sparse,
            per_sample_weights,
            self.include_last_offset,
            self.padding_idx,
        )

    def extra_repr(self) -> str:
        s = "{num_embeddings}, {embedding_dim}"
        if self.max_norm is not None:
            s += ", max_norm={max_norm}"
        if self.norm_type != 2:
            s += ", norm_type={norm_type}"
        if self.scale_grad_by_freq is not False:
            s += ", scale_grad_by_freq={scale_grad_by_freq}"
        s += ", mode={mode}"
        if self.padding_idx is not None:
            s += ", padding_idx={padding_idx}"
        return s.format(**{k: repr(v) for k, v in self.__dict__.items()})

    @classmethod
    def from_pretrained(
        cls,
        embeddings: Tensor,
        freeze: bool = True,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        mode: str = "mean",
        sparse: bool = False,
        include_last_offset: bool = False,
        padding_idx: Optional[int] = None,
    ) -> "EmbeddingBag":
        r"""Create EmbeddingBag instance from given 2-dimensional FloatTensor.

        Args:
            embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag.
                First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'.
            freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process.
                Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True``
            max_norm (float, optional): See module initialization documentation. Default: ``None``
            norm_type (float, optional): See module initialization documentation. Default ``2``.
            scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``.
            mode (str, optional): See module initialization documentation. Default: ``"mean"``
            sparse (bool, optional): See module initialization documentation. Default: ``False``.
            include_last_offset (bool, optional): See module initialization documentation. Default: ``False``.
            padding_idx (int, optional): See module initialization documentation. Default: ``None``.

        Examples::

            >>> # FloatTensor containing pretrained weights
            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
            >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
            >>> # Get the mean of the embeddings for the bag of indices [1, 0]
            >>> input = torch.LongTensor([[1, 0]])
            >>> # xdoctest: +IGNORE_WANT("non-deterministic")
            >>> embeddingbag(input)
            tensor([[ 2.5000,  3.7000,  4.6500]])
        """
        assert (
            embeddings.dim() == 2
        ), "Embeddings parameter is expected to be 2-dimensional"
        rows, cols = embeddings.shape
        embeddingbag = cls(
            num_embeddings=rows,
            embedding_dim=cols,
            _weight=embeddings,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            mode=mode,
            sparse=sparse,
            include_last_offset=include_last_offset,
            padding_idx=padding_idx,
        )
        embeddingbag.weight.requires_grad = not freeze
        return embeddingbag
