# mypy: allow-untyped-defs
from typing import Optional

import torch
from torch import Tensor
from torch.nn import functional as F, init
from torch.nn.parameter import Parameter

from .module import Module


__all__ = ["Embedding", "EmbeddingBag"]


class Embedding(Module):
    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.

    This module is often used to store word embeddings and retrieve them using indices.
    The input to the module is a list of indices, and the output is the corresponding
    word embeddings.

    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
                                     therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
                                     i.e. it remains as a fixed "pad". For a newly constructed Embedding,
                                     the embedding vector at :attr:`padding_idx` will default to all zeros,
                                     but can be updated to another value to be used as the padding vector.
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of
                                             the words in the mini-batch. Default ``False``.
        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
                                 See Notes for more details regarding sparse gradients.

    Attributes:
        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
                         initialized from :math:`\mathcal{N}(0, 1)`

    Shape:
        - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`

    .. note::
        Keep in mind that only a limited number of optimizers support
        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
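
    A minimal sketch of the effect of ``sparse=True`` (illustrative; the
    gradient layout, not its values, is the point here)::

        >>> embedding = nn.Embedding(10, 3, sparse=True)
        >>> embedding(torch.tensor([1, 2])).sum().backward()
        >>> embedding.weight.grad.is_sparse
        True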

    .. note::
        When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the
        :attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be
        modified in-place, performing a differentiable operation on ``Embedding.weight`` before
        calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when
        :attr:`max_norm` is not ``None``. For example::

            n, d, m = 3, 5, 7
            embedding = nn.Embedding(n, d, max_norm=1.0)
            W = torch.randn((m, d), requires_grad=True)
            idx = torch.tensor([1, 2])
            a = embedding.weight.clone() @ W.t()  # weight must be cloned for this to be differentiable
            b = embedding(idx) @ W.t()  # modifies weight in-place
            out = (a.unsqueeze(0) + b.unsqueeze(1))
            loss = out.sigmoid().prod()
            loss.backward()

    Examples::

        >>> # an Embedding module containing 10 tensors of size 3
        >>> embedding = nn.Embedding(10, 3)
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> embedding(input)
        tensor([[[-0.0251, -1.6902,  0.7172],
                 [-0.6431,  0.0748,  0.6969],
                 [ 1.4970,  1.3448, -0.9685],
                 [-0.3677, -2.7265, -0.1685]],

                [[ 1.4970,  1.3448, -0.9685],
                 [ 0.4362, -0.4004,  0.9400],
                 [-0.6431,  0.0748,  0.6969],
                 [ 0.9124, -2.3616,  1.1151]]])


        >>> # example with padding_idx
        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
        >>> input = torch.LongTensor([[0, 2, 0, 5]])
        >>> embedding(input)
        tensor([[[ 0.0000,  0.0000,  0.0000],
                 [ 0.1535, -2.0309,  0.9315],
                 [ 0.0000,  0.0000,  0.0000],
                 [-0.1655,  0.9897,  0.0635]]])

        >>> # example of changing `pad` vector
        >>> padding_idx = 0
        >>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
        >>> embedding.weight
        Parameter containing:
        tensor([[ 0.0000,  0.0000,  0.0000],
                [-0.7895, -0.7089, -0.0364],
                [ 0.6778,  0.5803,  0.2678]], requires_grad=True)
        >>> with torch.no_grad():
        ...     embedding.weight[padding_idx] = torch.ones(3)
        >>> embedding.weight
        Parameter containing:
        tensor([[ 1.0000,  1.0000,  1.0000],
                [-0.7895, -0.7089, -0.0364],
                [ 0.6778,  0.5803,  0.2678]], requires_grad=True)
    """

    __constants__ = [
        "num_embeddings",
        "embedding_dim",
        "padding_idx",
        "max_norm",
        "norm_type",
        "scale_grad_by_freq",
        "sparse",
    ]

    num_embeddings: int
    embedding_dim: int
    padding_idx: Optional[int]
    max_norm: Optional[float]
    norm_type: float
    scale_grad_by_freq: bool
    weight: Tensor
    freeze: bool
    sparse: bool

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        _freeze: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert (
                    padding_idx < self.num_embeddings
                ), "padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert (
                    padding_idx >= -self.num_embeddings
                ), "padding_idx must be within num_embeddings"
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if _weight is None:
            self.weight = Parameter(
                torch.empty((num_embeddings, embedding_dim), **factory_kwargs),
                requires_grad=not _freeze,
            )
            self.reset_parameters()
        else:
            assert list(_weight.shape) == [
                num_embeddings,
                embedding_dim,
            ], "Shape of weight does not match num_embeddings and embedding_dim"
            self.weight = Parameter(_weight, requires_grad=not _freeze)

        self.sparse = sparse

    def reset_parameters(self) -> None:
        init.normal_(self.weight)
        self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        return F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

    def extra_repr(self) -> str:
        s = "{num_embeddings}, {embedding_dim}"
        if self.padding_idx is not None:
            s += ", padding_idx={padding_idx}"
        if self.max_norm is not None:
            s += ", max_norm={max_norm}"
        if self.norm_type != 2:
            s += ", norm_type={norm_type}"
        if self.scale_grad_by_freq is not False:
            s += ", scale_grad_by_freq={scale_grad_by_freq}"
        if self.sparse is not False:
            s += ", sparse=True"
        return s.format(**self.__dict__)

    @classmethod
    def from_pretrained(
        cls,
        embeddings,
        freeze=True,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
    ):
        r"""Create Embedding instance from given 2-dimensional FloatTensor.

        Args:
            embeddings (Tensor): FloatTensor containing weights for the Embedding.
                First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
            freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process.
                Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
            padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
                therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
                i.e. it remains as a fixed "pad".
            max_norm (float, optional): See module initialization documentation.
            norm_type (float, optional): See module initialization documentation. Default ``2``.
            scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``.
            sparse (bool, optional): See module initialization documentation.

        Examples::

            >>> # FloatTensor containing pretrained weights
            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
            >>> embedding = nn.Embedding.from_pretrained(weight)
            >>> # Get embeddings for index 1
            >>> input = torch.LongTensor([1])
            >>> # xdoctest: +IGNORE_WANT("non-deterministic")
            >>> embedding(input)
            tensor([[ 4.0000,  5.1000,  6.3000]])
        """
        assert (
            embeddings.dim() == 2
        ), "Embeddings parameter is expected to be 2-dimensional"
        rows, cols = embeddings.shape
        embedding = cls(
            num_embeddings=rows,
            embedding_dim=cols,
            _weight=embeddings,
            _freeze=freeze,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
        )
        return embedding


class EmbeddingBag(Module):
    r"""Compute sums or means of 'bags' of embeddings, without instantiating the intermediate embeddings.

    For bags of constant length, no :attr:`per_sample_weights`, no indices equal to :attr:`padding_idx`,
    and with 2D inputs, this class

    * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=1)``,
    * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=1)``,
    * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=1)``.

    However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
    operations.
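
    A quick check of the ``mode="sum"`` equivalence above (an illustrative
    sketch; the weights are random, so only the comparison is meaningful)::

        >>> bag = nn.EmbeddingBag(10, 3, mode='sum')
        >>> emb = nn.Embedding.from_pretrained(bag.weight)
        >>> input = torch.tensor([[1, 2, 4], [4, 3, 2]])
        >>> torch.allclose(bag(input), emb(input).sum(dim=1))
        True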

    EmbeddingBag also supports per-sample weights as an argument to the forward
    pass. This scales the output of the Embedding before performing a weighted
    reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
    only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
    :attr:`per_sample_weights`.
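
    A minimal sketch of :attr:`per_sample_weights` (illustrative; the weight
    matrix is random, so only the comparison is meaningful)::

        >>> bag = nn.EmbeddingBag(5, 2, mode='sum')
        >>> input = torch.tensor([[0, 1]])
        >>> w = torch.tensor([[0.5, 2.0]])
        >>> expected = 0.5 * bag.weight[0] + 2.0 * bag.weight[1]
        >>> torch.allclose(bag(input, per_sample_weights=w), expected.unsqueeze(0))
        True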

    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of
                                             the words in the mini-batch. Default ``False``.
                                             Note: this option is not supported when ``mode="max"``.
        mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
                              ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
                              into consideration. ``"mean"`` computes the average of the values
                              in the bag, ``"max"`` computes the max value over each bag.
                              Default: ``"mean"``
        sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See
                                 Notes for more details regarding sparse gradients. Note: this option is not
                                 supported when ``mode="max"``.
        include_last_offset (bool, optional): if ``True``, :attr:`offsets` has one additional element, where the last element
                                              is equivalent to the size of `indices`. This matches the CSR format.
        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the
                                     gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated
                                     during training, i.e. it remains as a fixed "pad". For a newly constructed
                                     EmbeddingBag, the embedding vector at :attr:`padding_idx` will default to all
                                     zeros, but can be updated to another value to be used as the padding vector.
                                     Note that the embedding vector at :attr:`padding_idx` is excluded from the
                                     reduction.

    Attributes:
        weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
                         initialized from :math:`\mathcal{N}(0, 1)`.

    Examples::

        >>> # an EmbeddingBag module containing 10 tensors of size 3
        >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
        >>> offsets = torch.tensor([0, 4], dtype=torch.long)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> embedding_sum(input, offsets)
        tensor([[-0.8861, -5.4350, -0.0523],
                [ 1.1306, -2.5798, -1.0044]])

        >>> # Example with padding_idx
        >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum', padding_idx=2)
        >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9], dtype=torch.long)
        >>> offsets = torch.tensor([0, 4], dtype=torch.long)
        >>> embedding_sum(input, offsets)
        tensor([[ 0.0000,  0.0000,  0.0000],
                [-0.7082,  3.2145, -2.6251]])

        >>> # An EmbeddingBag can be loaded from an Embedding like so
        >>> embedding = nn.Embedding(10, 3, padding_idx=2)
        >>> embedding_sum = nn.EmbeddingBag.from_pretrained(
        ...     embedding.weight,
        ...     padding_idx=embedding.padding_idx,
        ...     mode='sum')
    """

    __constants__ = [
        "num_embeddings",
        "embedding_dim",
        "max_norm",
        "norm_type",
        "scale_grad_by_freq",
        "mode",
        "sparse",
        "include_last_offset",
        "padding_idx",
    ]

    num_embeddings: int
    embedding_dim: int
    max_norm: Optional[float]
    norm_type: float
    scale_grad_by_freq: bool
    weight: Tensor
    mode: str
    sparse: bool
    include_last_offset: bool
    padding_idx: Optional[int]

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        mode: str = "mean",
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        include_last_offset: bool = False,
        padding_idx: Optional[int] = None,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if padding_idx is not None:
            if padding_idx > 0:
                assert (
                    padding_idx < self.num_embeddings
                ), "padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert (
                    padding_idx >= -self.num_embeddings
                ), "padding_idx must be within num_embeddings"
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        if _weight is None:
            self.weight = Parameter(
                torch.empty((num_embeddings, embedding_dim), **factory_kwargs)
            )
            self.reset_parameters()
        else:
            assert list(_weight.shape) == [
                num_embeddings,
                embedding_dim,
            ], "Shape of weight does not match num_embeddings and embedding_dim"
            self.weight = Parameter(_weight)
        self.mode = mode
        self.sparse = sparse
        self.include_last_offset = include_last_offset

    def reset_parameters(self) -> None:
        init.normal_(self.weight)
        self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(
        self,
        input: Tensor,
        offsets: Optional[Tensor] = None,
        per_sample_weights: Optional[Tensor] = None,
    ) -> Tensor:
        """Forward pass of EmbeddingBag.

        Args:
            input (Tensor): Tensor containing bags of indices into the embedding matrix.
            offsets (Tensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines
                the starting index position of each bag (sequence) in :attr:`input`.
            per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
                to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
                must have exactly the same shape as input and is treated as having the same
                :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.

        Returns:
            Tensor output shape of `(B, embedding_dim)`.

        .. note::

            A few notes about ``input`` and ``offsets``:

            - :attr:`input` and :attr:`offsets` have to be of the same type, either int or long

            - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences)
              each of fixed length ``N``, and this will return ``B`` values aggregated in a way
              depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case.

            - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of
              multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing the
              starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets` of shape `(B)`,
              :attr:`input` will be viewed as having ``B`` bags. Empty bags (i.e., having 0-length) will have
              returned vectors filled by zeros.
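
        A quick sketch of the two input forms (illustrative; the weights are
        random, so only the comparison between the two calls is meaningful)::

            >>> bag = nn.EmbeddingBag(10, 3, mode='sum')
            >>> flat = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
            >>> offsets = torch.tensor([0, 4])  # two bags: [1, 2, 4, 5] and [4, 3, 2, 9]
            >>> torch.allclose(bag(flat, offsets), bag(flat.view(2, 4)))
            True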
        """
        return F.embedding_bag(
            input,
            self.weight,
            offsets,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.mode,
            self.sparse,
            per_sample_weights,
            self.include_last_offset,
            self.padding_idx,
        )

    def extra_repr(self) -> str:
        s = "{num_embeddings}, {embedding_dim}"
        if self.max_norm is not None:
            s += ", max_norm={max_norm}"
        if self.norm_type != 2:
            s += ", norm_type={norm_type}"
        if self.scale_grad_by_freq is not False:
            s += ", scale_grad_by_freq={scale_grad_by_freq}"
        s += ", mode={mode}"
        if self.padding_idx is not None:
            s += ", padding_idx={padding_idx}"
        return s.format(**{k: repr(v) for k, v in self.__dict__.items()})

    @classmethod
    def from_pretrained(
        cls,
        embeddings: Tensor,
        freeze: bool = True,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        mode: str = "mean",
        sparse: bool = False,
        include_last_offset: bool = False,
        padding_idx: Optional[int] = None,
    ) -> "EmbeddingBag":
        r"""Create EmbeddingBag instance from given 2-dimensional FloatTensor.

        Args:
            embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag.
                First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'.
            freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process.
                Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True``
            max_norm (float, optional): See module initialization documentation. Default: ``None``
            norm_type (float, optional): See module initialization documentation. Default ``2``.
            scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``.
            mode (str, optional): See module initialization documentation. Default: ``"mean"``
            sparse (bool, optional): See module initialization documentation.
                Default: ``False``.
            include_last_offset (bool, optional): See module initialization documentation. Default: ``False``.
            padding_idx (int, optional): See module initialization documentation. Default: ``None``.

        Examples::

            >>> # FloatTensor containing pretrained weights
            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
            >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
            >>> # Get the mean of the embeddings for indices 1 and 0
            >>> input = torch.LongTensor([[1, 0]])
            >>> # xdoctest: +IGNORE_WANT("non-deterministic")
            >>> embeddingbag(input)
            tensor([[ 2.5000,  3.7000,  4.6500]])
        """
        assert (
            embeddings.dim() == 2
        ), "Embeddings parameter is expected to be 2-dimensional"
        rows, cols = embeddings.shape
        embeddingbag = cls(
            num_embeddings=rows,
            embedding_dim=cols,
            _weight=embeddings,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            mode=mode,
            sparse=sparse,
            include_last_offset=include_last_offset,
            padding_idx=padding_idx,
        )
        embeddingbag.weight.requires_grad = not freeze
        return embeddingbag