# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch
import torch.nn as nn
from torch import Tensor  # noqa: F401
from torch._jit_internal import List, Optional  # noqa: F401

from .utils import _hide_packed_params_repr, _quantize_weight


__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"]


class EmbeddingPackedParams(torch.nn.Module):
    _version = 1

    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
        super().__init__()
        self.dtype = dtype
        if self.dtype in [torch.quint8, torch.quint4x2]:
            scales = torch.ones(num_embeddings, dtype=torch.float)
            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
            wq = torch._empty_per_channel_affine_quantized(
                [num_embeddings, embedding_dim],
                scales=scales,
                zero_points=zero_points,
                axis=0,
                dtype=self.dtype,
            )
            self.set_weight(wq)
        else:
            raise NotImplementedError(
                f"Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}"
            )

    @torch.jit.export
    def set_weight(self, weight: torch.Tensor) -> None:
        if self.dtype in [torch.quint8, torch.quint4x2]:
            self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
        else:
            raise NotImplementedError(
                "Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2."
            )

    @torch.jit.export
    def _weight(self):
        if self.dtype in [torch.quint8, torch.quint4x2]:
            return torch.ops.quantized.embedding_bag_unpack(self._packed_weight)
        else:
            raise NotImplementedError(
                "Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2."
            )

    def forward(self, x):
        return x

    # Version 1
    #   self
    #   |--- _packed_weight : Tensor representing weight of EmbeddingPackedParamsBase
    #   |--- dtype : torch.dtype

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "dtype"] = self.dtype
        destination[prefix + "_packed_weight"] = self._weight()

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        self.dtype = state_dict[prefix + "dtype"]
        state_dict.pop(prefix + "dtype")

        weight = state_dict[prefix + "_packed_weight"]
        state_dict.pop(prefix + "_packed_weight")
        self.set_weight(weight)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

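    # Serialization round-trip sketch (illustrative, not part of the upstream
    # file; assumes the default quint8 construction above). The state dict
    # carries the *unpacked* weight plus the dtype, and loading re-packs it:
    #
    #     >>> params = EmbeddingPackedParams(3, 4)
    #     >>> sorted(params.state_dict().keys())
    #     ['_packed_weight', 'dtype']
    #     >>> fresh = EmbeddingPackedParams(3, 4)
    #     >>> _ = fresh.load_state_dict(params.state_dict())  # calls set_weight()
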
    def __repr__(self):
        return self._weight().__repr__()


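# A minimal usage sketch for EmbeddingPackedParams (illustrative, not part of
# the upstream file): the default construction packs a placeholder per-row
# quantized weight (scale 1.0, zero point 0.0), which `_weight()` reads back:
#
#     >>> params = EmbeddingPackedParams(num_embeddings=10, embedding_dim=12)
#     >>> unpacked = params._weight()  # quantized.embedding_bag_unpack
#     >>> unpacked.shape
#     torch.Size([10, 12])
#     >>> unpacked.dtype
#     torch.quint8

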
class Embedding(torch.nn.Module):
    r"""
    A quantized Embedding module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.Embedding`; please see
    https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html for documentation.

    Similar to :class:`~torch.nn.Embedding`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
        >>> output = m(indices)
        >>> print(output.size())
        torch.Size([9, 12])

    """
    _version = 1

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        dtype=torch.quint8,
    ) -> None:
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.dtype = dtype

        if _weight is None:
            scales = torch.ones(num_embeddings, dtype=torch.float)
            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
            qweight = torch._empty_per_channel_affine_quantized(
                [num_embeddings, embedding_dim],
                scales=scales,
                zero_points=zero_points,
                axis=0,
                dtype=torch.quint8,
            )
        else:
            assert list(_weight.shape) == [
                num_embeddings,
                embedding_dim,
            ], "Shape of weight does not match num_embeddings and embedding_dim"
            qweight = _weight

        self._packed_params = EmbeddingPackedParams(
            num_embeddings, embedding_dim, dtype
        )
        self._packed_params.set_weight(qweight)

    def forward(self, indices: Tensor) -> Tensor:
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_4bit(
                self._packed_params._packed_weight, indices
            )
        else:
            return torch.ops.quantized.embedding_byte(
                self._packed_params._packed_weight, indices
            )

    def _get_name(self):
        return "QuantizedEmbedding"

    def __repr__(self):
        return _hide_packed_params_repr(self, EmbeddingPackedParams)

    def extra_repr(self):
        extra_repr_str = (
            f"num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, "
            f"dtype={self._packed_params.dtype}, qscheme={self.weight().qscheme()}"
        )

        return extra_repr_str

    def set_weight(self, w: torch.Tensor) -> None:
        self._packed_params.set_weight(w)

    def weight(self):
        return self._packed_params._weight()

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
192        r"""Create a quantized embedding module from a float module
193
194        Args:
195            mod (Module): a float module, either produced by torch.ao.quantization
196                          utilities or provided by user
197        """
        if hasattr(mod, "weight_fake_quant"):
            assert type(mod) == torch.ao.nn.qat.Embedding, (
                "nnq."
                + cls.__name__
                + ".from_float "
                + "with fake quant only works for "
                + torch.ao.nn.qat.Embedding.__name__
            )
            weight_observer = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == nn.Embedding, (
                "nnq."
                + cls.__name__
                + ".from_float only works for "
                + nn.Embedding.__name__
            )
            assert hasattr(
                mod, "qconfig"
            ), "Embedding input float module must have qconfig defined"
            from torch.ao.quantization import float_qparams_weight_only_qconfig

            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = (
            weight_observer.qscheme == torch.per_channel_affine_float_qparams
        )
        assert (
            is_float_qparams_qconfig
        ), "Embedding quantization is only supported with float_qparams_weight_only_qconfig."

        assert (
            dtype == torch.quint8 or dtype == torch.quint4x2
        ), f"The only supported dtypes for nnq.Embedding are torch.quint8 and torch.quint4x2, got {dtype}"

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create quantized Embedding module and pass in the quantized weight
        qembedding = Embedding(mod.num_embeddings, mod.embedding_dim)
        qembedding.set_weight(qweight)
        return qembedding

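    # Eager-mode conversion sketch (illustrative, not part of the upstream
    # file). `float_qparams_weight_only_qconfig` is the qconfig this path
    # expects; the converted module dispatches to quantized.embedding_byte:
    #
    #     >>> from torch.ao.quantization import float_qparams_weight_only_qconfig
    #     >>> float_emb = nn.Embedding(10, 12)
    #     >>> float_emb.qconfig = float_qparams_weight_only_qconfig
    #     >>> q_emb = Embedding.from_float(float_emb)
    #     >>> q_emb(torch.tensor([0, 1, 2])).shape
    #     torch.Size([3, 12])
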
    @classmethod
    def from_reference(cls, ref_embedding):
        qembedding = cls(
            ref_embedding.num_embeddings,
            ref_embedding.embedding_dim,
            ref_embedding.padding_idx,
            ref_embedding.max_norm,
            ref_embedding.norm_type,
            ref_embedding.scale_grad_by_freq,
            ref_embedding.sparse,
            ref_embedding.get_quantized_weight(),
            ref_embedding.weight_dtype,
        )
        return qembedding


class EmbeddingBag(Embedding):
    r"""
    A quantized EmbeddingBag module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.EmbeddingBag`; please see
    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html for documentation.

    Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum')
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        >>> offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        >>> output = m(indices, offsets)
        >>> print(output.size())
        torch.Size([5, 12])

    """
    _version = 1

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        mode: str = "sum",
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        include_last_offset: bool = False,
        dtype=torch.quint8,
    ) -> None:
        super().__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype)

        self.mode = mode
        self.pruned_weights = False
        self.include_last_offset = include_last_offset
        self.dtype = dtype

    def forward(
        self,
        indices: Tensor,
        offsets: Optional[Tensor] = None,
        per_sample_weights: Optional[Tensor] = None,
        compressed_indices_mapping: Optional[Tensor] = None,
    ) -> Tensor:
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_bag_4bit(
                self._packed_params._packed_weight,
                indices,
                offsets,
                False,
                0,
                self.pruned_weights,
                per_sample_weights,
                compressed_indices_mapping,
                self.include_last_offset,
            )
        else:
            return torch.ops.quantized.embedding_bag_byte(
                self._packed_params._packed_weight,
                indices,
                offsets,
                False,
                0,
                self.pruned_weights,
                per_sample_weights,
                compressed_indices_mapping,
                self.include_last_offset,
            )

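    # Forward sketch with per-index scaling (illustrative, not part of the
    # upstream file; assumes the default quint8 module and a backend build
    # that supports weighted lookups). `per_sample_weights` supplies one float
    # multiplier per looked-up index, as in float `nn.EmbeddingBag` with
    # mode='sum':
    #
    #     >>> m = EmbeddingBag(num_embeddings=10, embedding_dim=12)
    #     >>> indices = torch.tensor([0, 1, 2, 3])
    #     >>> offsets = torch.tensor([0, 2])  # two bags: [0, 1] and [2, 3]
    #     >>> psw = torch.rand(indices.numel())
    #     >>> m(indices, offsets, per_sample_weights=psw).shape
    #     torch.Size([2, 12])
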
    def _get_name(self):
        return "QuantizedEmbeddingBag"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
343        r"""Create a quantized embedding_bag module from a float module
344
345        Args:
346            mod (Module): a float module, either produced by torch.ao.quantization
347                          utilities or provided by user
348        """
        if hasattr(mod, "weight_fake_quant"):
            weight_observer = mod.weight_fake_quant
        else:
            assert type(mod) == nn.EmbeddingBag, (
                "nnq."
                + cls.__name__
                + ".from_float only works for "
                + nn.EmbeddingBag.__name__
            )
            assert hasattr(
                mod, "qconfig"
            ), "EmbeddingBag input float module must have qconfig defined"
            from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig

            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = (
            weight_observer.qscheme == torch.per_channel_affine_float_qparams
        )
        assert (
            is_float_qparams_qconfig
        ), "EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig."

        assert (
            dtype == torch.quint8 or dtype == torch.quint4x2
        ), f"The only supported dtypes for nnq.EmbeddingBag are torch.quint8 and torch.quint4x2, got {dtype}"

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create quantized EmbeddingBag module and pass in the quantized weight
        qembedding_bag = EmbeddingBag(
            mod.num_embeddings, mod.embedding_dim, dtype=dtype
        )
        qembedding_bag.set_weight(qweight)
        return qembedding_bag

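    # Eager-mode conversion sketch (illustrative, not part of the upstream
    # file). Note that, as written above, only num_embeddings, embedding_dim
    # and dtype carry over; mode and include_last_offset keep their defaults:
    #
    #     >>> from torch.ao.quantization import float_qparams_weight_only_qconfig
    #     >>> float_bag = nn.EmbeddingBag(10, 12, mode="sum")
    #     >>> float_bag.qconfig = float_qparams_weight_only_qconfig
    #     >>> q_bag = EmbeddingBag.from_float(float_bag)
    #     >>> q_bag(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 2])).shape
    #     torch.Size([2, 12])
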
    @classmethod
    def from_reference(cls, ref_embedding_bag):
        qembedding_bag = cls(
            ref_embedding_bag.num_embeddings,
            ref_embedding_bag.embedding_dim,
            ref_embedding_bag.max_norm,
            ref_embedding_bag.norm_type,
            ref_embedding_bag.scale_grad_by_freq,
            ref_embedding_bag.mode,
            ref_embedding_bag.sparse,
            ref_embedding_bag.get_quantized_weight(),
            ref_embedding_bag.include_last_offset,
            ref_embedding_bag.weight_dtype,
        )
        return qembedding_bag