# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch
import torch.nn as nn
from torch import Tensor  # noqa: F401
from torch._jit_internal import List, Optional  # noqa: F401

from .utils import _hide_packed_params_repr, _quantize_weight


__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"]


class EmbeddingPackedParams(torch.nn.Module):
    _version = 1

    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
        super().__init__()
        self.dtype = dtype
        if self.dtype in [torch.quint8, torch.quint4x2]:
            scales = torch.ones(num_embeddings, dtype=torch.float)
            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
            wq = torch._empty_per_channel_affine_quantized(
                [num_embeddings, embedding_dim],
                scales=scales,
                zero_points=zero_points,
                axis=0,
                dtype=self.dtype,
            )
            self.set_weight(wq)
        else:
            raise NotImplementedError(
                f"Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}"
            )

    @torch.jit.export
    def set_weight(self, weight: torch.Tensor) -> None:
        if self.dtype in [torch.quint8, torch.quint4x2]:
            self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
        else:
            raise NotImplementedError(
                "Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2."
            )

    @torch.jit.export
    def _weight(self):
        if self.dtype in [torch.quint8, torch.quint4x2]:
            return torch.ops.quantized.embedding_bag_unpack(self._packed_weight)
        else:
            raise NotImplementedError(
                "Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2."
            )

    def forward(self, x):
        return x

    # Version 1
    #   self
    #   |--- _packed_weight : Tensor representing weight of EmbeddingPackedParamsBase
    #   |--- dtype : torch.dtype

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "dtype"] = self.dtype
        destination[prefix + "_packed_weight"] = self._weight()

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        self.dtype = state_dict[prefix + "dtype"]
        state_dict.pop(prefix + "dtype")

        weight = state_dict[prefix + "_packed_weight"]
        state_dict.pop(prefix + "_packed_weight")
        self.set_weight(weight)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def __repr__(self):
        return self._weight().__repr__()
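

# A minimal round-trip sketch for EmbeddingPackedParams (illustrative only, not
# part of the module API; the shapes below are assumptions). The holder prepacks
# a per-channel quantized weight on construction and recovers it via _weight():
#
#     params = EmbeddingPackedParams(num_embeddings=10, embedding_dim=12)
#     wq = params._weight()   # unpacked quantized tensor of shape (10, 12)
#     params.set_weight(wq)   # re-prepack; state_dict round-trips the unpacked form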


class Embedding(torch.nn.Module):
    r"""
    A quantized Embedding module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.Embedding`; see
    https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html for documentation.

    Similar to :class:`~torch.nn.Embedding`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
        >>> output = m(indices)
        >>> print(output.size())
        torch.Size([9, 12])

    """
    _version = 1

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        dtype=torch.quint8,
    ) -> None:
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.dtype = dtype

        if _weight is None:
            scales = torch.ones(num_embeddings, dtype=torch.float)
            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
            qweight = torch._empty_per_channel_affine_quantized(
                [num_embeddings, embedding_dim],
                scales=scales,
                zero_points=zero_points,
                axis=0,
                dtype=torch.quint8,
            )
        else:
            assert list(_weight.shape) == [
                num_embeddings,
                embedding_dim,
            ], "Shape of weight does not match num_embeddings and embedding_dim"
            qweight = _weight

        self._packed_params = EmbeddingPackedParams(
            num_embeddings, embedding_dim, dtype
        )
        self._packed_params.set_weight(qweight)

    def forward(self, indices: Tensor) -> Tensor:
        # Dispatch on dtype: quint4x2 rows use the packed 4-bit kernel, quint8
        # the byte kernel. Both return dequantized float output.
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_4bit(
                self._packed_params._packed_weight, indices
            )
        else:
            return torch.ops.quantized.embedding_byte(
                self._packed_params._packed_weight, indices
            )

    def _get_name(self):
        return "QuantizedEmbedding"

    def __repr__(self):
        return _hide_packed_params_repr(self, EmbeddingPackedParams)

    def extra_repr(self):
        extra_repr_str = (
            f"num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, "
            f"dtype={self._packed_params.dtype}, qscheme={self.weight().qscheme()}"
        )

        return extra_repr_str

    def set_weight(self, w: torch.Tensor) -> None:
        self._packed_params.set_weight(w)

    def weight(self):
        return self._packed_params._weight()
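
    # Construction-from-weight sketch (illustrative; float zero_points are
    # assumed to yield the per_channel_affine_float_qparams scheme the pack op
    # expects, and the values are made up):
    #
    #     w = torch.quantize_per_channel(
    #         torch.randn(10, 12), torch.rand(10), torch.zeros(10),
    #         axis=0, dtype=torch.quint8,
    #     )
    #     qemb = Embedding(num_embeddings=10, embedding_dim=12, _weight=w)
    #     out = qemb(torch.tensor([0, 3, 9]))  # float32 output, shape (3, 12)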

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a quantized embedding module from a float module.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        if hasattr(mod, "weight_fake_quant"):
            assert type(mod) == torch.ao.nn.qat.Embedding, (
                f"nnq.{cls.__name__}.from_float with fake quant only works for "
                f"{torch.ao.nn.qat.Embedding.__name__}"
            )
            weight_observer = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == nn.Embedding, (
                f"nnq.{cls.__name__}.from_float only works for {nn.Embedding.__name__}"
            )
            assert hasattr(
                mod, "qconfig"
            ), "Embedding input float module must have qconfig defined"
            from torch.ao.quantization import float_qparams_weight_only_qconfig

            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = (
            weight_observer.qscheme == torch.per_channel_affine_float_qparams
        )
        assert (
            is_float_qparams_qconfig
        ), "Embedding quantization is only supported with float_qparams_weight_only_qconfig."

        assert (
            dtype == torch.quint8 or dtype == torch.quint4x2
        ), f"The only supported dtypes for nnq.Embedding are torch.quint8 and torch.quint4x2, got {dtype}"

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create the quantized Embedding module and pass in the quantized weight.
        # dtype is forwarded so a quint4x2 observer dispatches to the 4-bit kernel
        # in forward, mirroring EmbeddingBag.from_float below.
        qembedding = Embedding(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
        qembedding.set_weight(qweight)
        return qembedding

    @classmethod
    def from_reference(cls, ref_embedding):
        qembedding = cls(
            ref_embedding.num_embeddings,
            ref_embedding.embedding_dim,
            ref_embedding.padding_idx,
            ref_embedding.max_norm,
            ref_embedding.norm_type,
            ref_embedding.scale_grad_by_freq,
            ref_embedding.sparse,
            ref_embedding.get_quantized_weight(),
            ref_embedding.weight_dtype,
        )
        return qembedding
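

# Eager-mode conversion sketch (a minimal illustration of the from_float path
# above, assuming the stock float_qparams_weight_only_qconfig; not a tested
# recipe):
#
#     from torch.ao.quantization import float_qparams_weight_only_qconfig
#
#     emb = nn.Embedding(num_embeddings=10, embedding_dim=12)
#     emb.qconfig = float_qparams_weight_only_qconfig
#     qemb = Embedding.from_float(emb)     # observes, quantizes, and packs the weight
#     out = qemb(torch.tensor([1, 2, 3]))  # float32 output from the quantized lookup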


class EmbeddingBag(Embedding):
    r"""
    A quantized EmbeddingBag module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.EmbeddingBag`; see
    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html for documentation.

    Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum')
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        >>> offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        >>> output = m(indices, offsets)
        >>> print(output.size())
        torch.Size([5, 12])

    """
    _version = 1

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        mode: str = "sum",
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        include_last_offset: bool = False,
        dtype=torch.quint8,
    ) -> None:
        super().__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype)

        self.mode = mode
        self.pruned_weights = False
        self.include_last_offset = include_last_offset
        self.dtype = dtype

    def forward(
        self,
        indices: Tensor,
        offsets: Optional[Tensor] = None,
        per_sample_weights: Optional[Tensor] = None,
        compressed_indices_mapping: Optional[Tensor] = None,
    ) -> Tensor:
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_bag_4bit(
                self._packed_params._packed_weight,
                indices,
                offsets,
                False,  # scale_grad_by_freq (unused at inference time)
                0,  # mode 0 == "sum"
                self.pruned_weights,
                per_sample_weights,
                compressed_indices_mapping,
                self.include_last_offset,
            )
        else:
            return torch.ops.quantized.embedding_bag_byte(
                self._packed_params._packed_weight,
                indices,
                offsets,
                False,  # scale_grad_by_freq (unused at inference time)
                0,  # mode 0 == "sum"
                self.pruned_weights,
                per_sample_weights,
                compressed_indices_mapping,
                self.include_last_offset,
            )

    def _get_name(self):
        return "QuantizedEmbeddingBag"
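
    # Weighted-bag sketch (illustrative; mirrors torch.nn.EmbeddingBag, where
    # per_sample_weights is only meaningful for mode="sum"; values are made up):
    #
    #     m = EmbeddingBag(num_embeddings=10, embedding_dim=12, mode="sum")
    #     indices = torch.tensor([1, 2, 4, 5])
    #     offsets = torch.tensor([0, 2])         # two bags of two indices each
    #     weights = torch.rand(indices.numel())  # one float weight per lookup
    #     out = m(indices, offsets, per_sample_weights=weights)  # shape (2, 12)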

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a quantized embedding_bag module from a float module.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        if hasattr(mod, "weight_fake_quant"):
            weight_observer = mod.weight_fake_quant
        else:
            assert type(mod) == nn.EmbeddingBag, (
                f"nnq.{cls.__name__}.from_float only works for {nn.EmbeddingBag.__name__}"
            )
            assert hasattr(
                mod, "qconfig"
            ), "EmbeddingBag input float module must have qconfig defined"
            from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig

            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = (
            weight_observer.qscheme == torch.per_channel_affine_float_qparams
        )
        assert (
            is_float_qparams_qconfig
        ), "EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig."

        assert (
            dtype == torch.quint8 or dtype == torch.quint4x2
        ), f"The only supported dtypes for nnq.EmbeddingBag are torch.quint8 and torch.quint4x2, got {dtype}"

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create the quantized EmbeddingBag module and pass in the quantized weight.
        qembedding_bag = EmbeddingBag(
            mod.num_embeddings, mod.embedding_dim, dtype=dtype
        )
        qembedding_bag.set_weight(qweight)
        return qembedding_bag

    @classmethod
    def from_reference(cls, ref_embedding_bag):
        qembedding_bag = cls(
            ref_embedding_bag.num_embeddings,
            ref_embedding_bag.embedding_dim,
            ref_embedding_bag.max_norm,
            ref_embedding_bag.norm_type,
            ref_embedding_bag.scale_grad_by_freq,
            ref_embedding_bag.mode,
            ref_embedding_bag.sparse,
            ref_embedding_bag.get_quantized_weight(),
            ref_embedding_bag.include_last_offset,
            ref_embedding_bag.weight_dtype,
        )
        return qembedding_bag
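

# Serialization sketch (illustrative): per EmbeddingPackedParams above, state_dict
# stores the *unpacked* quantized weight plus its dtype, and loading re-packs it
# through set_weight, so checkpoints stay independent of the packed in-memory layout.
#
#     m = EmbeddingBag(num_embeddings=10, embedding_dim=12)
#     sd = m.state_dict()      # '_packed_params._packed_weight' holds the unpacked tensor
#     m2 = EmbeddingBag(num_embeddings=10, embedding_dim=12)
#     m2.load_state_dict(sd)   # re-packs via EmbeddingPackedParams.set_weight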