# mypy: allow-untyped-defs

from collections import namedtuple
from typing import List, Sequence

import torch
import torch.nn.functional as F
from torch import Tensor

from .container import ModuleList, Sequential
from .linear import Linear
from .module import Module


__all__ = ["AdaptiveLogSoftmaxWithLoss"]

_ASMoutput = namedtuple("_ASMoutput", ["output", "loss"])


class AdaptiveLogSoftmaxWithLoss(Module):
    """Efficient softmax approximation.

    As described in
    `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
    Moustapha Cissé, David Grangier, and Hervé Jégou
    <https://arxiv.org/abs/1609.04309>`__.
    """ r"""
    Adaptive softmax is an approximate strategy for training models with large
    output spaces. It is most effective when the label distribution is highly
    imbalanced, for example in natural language modelling, where the word
    frequency distribution approximately follows `Zipf's law`_.

    Adaptive softmax partitions the labels into several clusters, according to
    their frequency. These clusters may contain different numbers of targets
    each.
    Additionally, clusters containing less frequent labels assign lower
    dimensional embeddings to those labels, which speeds up the computation.
    For each minibatch, only clusters for which at least one target is
    present are evaluated.

    The idea is that the clusters which are accessed frequently
    (like the first one, containing the most frequent labels), should also be
    cheap to compute -- that is, contain a small number of assigned labels.

    We highly recommend taking a look at the original paper for more details.

    * :attr:`cutoffs` should be an ordered Sequence of integers sorted
      in increasing order.
      It controls the number of clusters and the partitioning of targets into
      clusters. For example, setting ``cutoffs = [10, 100, 1000]``
      means that the first `10` targets will be assigned
      to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be
      assigned to the first cluster, and targets `101, 102, ..., 1000` will be
      assigned to the second cluster, while targets
      `1001, 1002, ..., n_classes - 1` will be assigned
      to the last, third cluster.

    * :attr:`div_value` is used to compute the size of each additional cluster,
      which is given as
      :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`,
      where :math:`idx` is the cluster index (with clusters
      for less frequent words having larger indices,
      and indices starting from :math:`1`).

    * :attr:`head_bias` if set to True, adds a bias term to the 'head' of the
      adaptive softmax. See paper for details. Set to False in the official
      implementation.

    .. warning::
        Labels passed as inputs to this module should be sorted according to
        their frequency. This means that the most frequent label should be
        represented by the index `0`, and the least frequent
        label should be represented by the index `n_classes - 1`.

    .. note::
        This module returns a ``NamedTuple`` with ``output``
        and ``loss`` fields. See further documentation for details.

    .. note::
        To compute log-probabilities for all classes, the ``log_prob``
        method can be used.

    Args:
        in_features (int): Number of features in the input tensor
        n_classes (int): Number of classes in the dataset
        cutoffs (Sequence): Cutoffs used to assign targets to their buckets
        div_value (float, optional): value used as an exponent to compute sizes
            of the clusters. Default: 4.0
        head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the
            adaptive softmax. Default: ``False``

    Returns:
        ``NamedTuple`` with ``output`` and ``loss`` fields:
            * **output** is a Tensor of size ``N`` containing computed target
              log probabilities for each example
            * **loss** is a Scalar representing the computed negative
              log likelihood loss

    Shape:
        - input: :math:`(N, \texttt{in\_features})` or :math:`(\texttt{in\_features})`
        - target: :math:`(N)` or :math:`()` where each value satisfies :math:`0 <= \texttt{target[i]} < \texttt{n\_classes}`
        - output1: :math:`(N)` or :math:`()`
        - output2: ``Scalar``
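
    A minimal usage sketch (the sizes below are arbitrary, illustrative values)::

        >>> asm = nn.AdaptiveLogSoftmaxWithLoss(
        ...     in_features=64, n_classes=1000, cutoffs=[10, 100, 500]
        ... )
        >>> input = torch.randn(128, 64)
        >>> target = torch.randint(0, 1000, (128,))
        >>> out, loss = asm(input, target)  # out: (128,) target log-probs, loss: scalar
        >>> full = asm.log_prob(input)  # (128, 1000) log-distribution over all classes
        >>> pred = asm.predict(input)  # (128,) most likely class for each example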

    .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
    """

    in_features: int
    n_classes: int
    cutoffs: List[int]
    div_value: float
    head_bias: bool
    head: Linear
    tail: ModuleList

    def __init__(
        self,
        in_features: int,
        n_classes: int,
        cutoffs: Sequence[int],
        div_value: float = 4.0,
        head_bias: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        cutoffs = list(cutoffs)

        if len(cutoffs) == 0:
            raise ValueError("cutoffs should be a sequence of length larger than 0")

        if (
            (cutoffs != sorted(cutoffs))
            or (min(cutoffs) <= 0)
            or (max(cutoffs) > (n_classes - 1))
            or (len(set(cutoffs)) != len(cutoffs))
            or any(int(c) != c for c in cutoffs)
        ):
            raise ValueError(
                "cutoffs should be a sequence of unique, positive "
                "integers sorted in an increasing order, where "
                "each value is between 1 and n_classes-1"
            )

        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = cutoffs + [n_classes]
        self.div_value = div_value
        self.head_bias = head_bias

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        self.head = Linear(
            self.in_features, self.head_size, bias=self.head_bias, **factory_kwargs
        )
        self.tail = ModuleList()

        for i in range(self.n_clusters):
            hsz = int(self.in_features // (self.div_value ** (i + 1)))
            osz = self.cutoffs[i + 1] - self.cutoffs[i]

            projection = Sequential(
                Linear(self.in_features, hsz, bias=False, **factory_kwargs),
                Linear(hsz, osz, bias=False, **factory_kwargs),
            )

            self.tail.append(projection)

    def reset_parameters(self) -> None:
        self.head.reset_parameters()
        for i2h, h2o in self.tail:
            i2h.reset_parameters()
            h2o.reset_parameters()
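
    # The forward pass below touches only the clusters that actually contain
    # targets in the current minibatch: for each occupied cluster it computes
    # the within-cluster log-probability of the matching rows and records which
    # head column (shortlist entry or cluster score) each row must read; a
    # single log_softmax over the head then completes log p(target | input),
    # and the returned loss is the mean negative log-likelihood over the batch.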
    def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput:
        targ_dim = target_.dim()

        if targ_dim == 1:
            if input_.size(0) != target_.size(0):
                raise RuntimeError(
                    "Input and target should have the same size "
                    "in the batch dimension."
                )
            if input_.dim() != 2:
                raise RuntimeError(
                    "1D target tensor expects 2D input tensors, "
                    f"but found inputs with size {input_.size()}"
                )
        elif targ_dim == 0:
            if input_.dim() != 1:
                raise RuntimeError(
                    "0D target tensor expects 1D input tensors, "
                    f"but found inputs with size {input_.size()}"
                )
        else:
            raise RuntimeError(
                "0D or 1D target tensor expected, multi-target not supported"
            )

        is_batched = targ_dim > 0
        input = input_ if is_batched else input_.unsqueeze(0)
        target = target_ if is_batched else target_.unsqueeze(0)

        used_rows = 0
        batch_size = target.size(0)

        output = input.new_zeros(batch_size)
        gather_inds = target.new_empty(batch_size)

        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):
            low_idx = cutoff_values[i]
            high_idx = cutoff_values[i + 1]

            target_mask = (target >= low_idx) & (target < high_idx)
            row_indices = target_mask.nonzero().squeeze()

            if row_indices.numel() == 0:
                continue

            if i == 0:
                gather_inds.index_copy_(0, row_indices, target[target_mask])

            else:
                relative_target = target[target_mask] - low_idx
                input_subset = input.index_select(0, row_indices)

                cluster_output = self.tail[i - 1](input_subset)
                cluster_index = self.shortlist_size + i - 1

                gather_inds.index_fill_(0, row_indices, cluster_index)
                cluster_logprob = F.log_softmax(cluster_output, dim=1)
                local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
                output.index_copy_(0, row_indices, local_logprob.squeeze(1))

            used_rows += row_indices.numel()

        if used_rows != batch_size:
            raise RuntimeError(
                f"Target values should be in [0, {self.n_classes - 1}], "
                f"but values in range [{target.min().item()}, {target.max().item()}] "
                "were found. "
            )

        head_output = self.head(input)
        head_logprob = F.log_softmax(head_output, dim=1)
        output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
        loss = (-output).mean()

        if not is_batched:
            output = output.squeeze(0)

        return _ASMoutput(output, loss)
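
    # The full distribution is assembled from the head and the cluster heads:
    # the head scores shortlist classes directly and assigns one score per
    # cluster, so a class belonging to cluster ``i`` gets
    # log p(class | x) = log p(cluster_i | x) + log p(class | cluster_i, x),
    # which ``_get_full_log_prob`` fills in one column block at a time.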
    def _get_full_log_prob(self, input, head_output):
        """Given input tensor, and output of ``self.head``, compute the log of the full distribution."""
        out = input.new_empty((head_output.size(0), self.n_classes))
        head_logprob = F.log_softmax(head_output, dim=1)

        out[:, : self.shortlist_size] = head_logprob[:, : self.shortlist_size]

        for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
            cluster_output = self.tail[i](input)
            cluster_logprob = F.log_softmax(cluster_output, dim=1)
            output_logprob = cluster_logprob + head_logprob[
                :, self.shortlist_size + i
            ].unsqueeze(1)

            out[:, start_idx:stop_idx] = output_logprob

        return out

    def log_prob(self, input: Tensor) -> Tensor:
        r"""Compute log probabilities for all :math:`\texttt{n\_classes}`.

        Args:
            input (Tensor): a minibatch of examples

        Returns:
            log-probabilities for each class :math:`c`
            in range :math:`0 <= c < \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a
            parameter passed to the ``AdaptiveLogSoftmaxWithLoss`` constructor.

        Shape:
            - Input: :math:`(N, \texttt{in\_features})`
            - Output: :math:`(N, \texttt{n\_classes})`

        """
        head_output = self.head(input)
        return self._get_full_log_prob(input, head_output)

    def predict(self, input: Tensor) -> Tensor:
        r"""Return the class with the highest probability for each example in the input minibatch.

        This is equivalent to ``self.log_prob(input).argmax(dim=1)``, but is more efficient in some cases.

        Args:
            input (Tensor): a minibatch of examples

        Returns:
            output (Tensor): a class with the highest probability for each example

        Shape:
            - Input: :math:`(N, \texttt{in\_features})`
            - Output: :math:`(N)`
        """
        head_output = self.head(input)
        output = torch.argmax(head_output, dim=1)
        not_in_shortlist = output >= self.shortlist_size
        all_in_shortlist = not (not_in_shortlist.any())

        if all_in_shortlist:
            return output

        elif not_in_shortlist.all():
            log_prob = self._get_full_log_prob(input, head_output)
            return torch.argmax(log_prob, dim=1)

        else:
            log_prob = self._get_full_log_prob(
                input[not_in_shortlist], head_output[not_in_shortlist]
            )
            output[not_in_shortlist] = torch.argmax(log_prob, dim=1)
            return output
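

if __name__ == "__main__":
    # Illustrative smoke test, not part of the library API: all sizes below are
    # arbitrary. With randomly initialized weights, ``predict`` should agree with
    # the argmax of the full log-distribution, and ``forward`` should return
    # per-example target log-probabilities together with their mean NLL.
    # Run this as a module (e.g. via ``python -m``) so the relative imports resolve.
    torch.manual_seed(0)
    asm = AdaptiveLogSoftmaxWithLoss(in_features=32, n_classes=200, cutoffs=[5, 50])
    x = torch.randn(16, 32)
    tgt = torch.randint(0, 200, (16,))
    out, loss = asm(x, tgt)
    print(out.shape, loss.item())
    print(torch.equal(asm.predict(x), asm.log_prob(x).argmax(dim=1)))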