1""" 2We will recreate all the RNN modules as we require the modules to be decomposed 3into its building blocks to be able to observe. 4""" 5 6# mypy: allow-untyped-defs 7 8import numbers 9import warnings 10from typing import Optional, Tuple 11 12import torch 13from torch import Tensor 14 15 16__all__ = ["LSTMCell", "LSTM"] 17 18 19class LSTMCell(torch.nn.Module): 20 r"""A quantizable long short-term memory (LSTM) cell. 21 22 For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell` 23 24 Examples:: 25 26 >>> import torch.ao.nn.quantizable as nnqa 27 >>> rnn = nnqa.LSTMCell(10, 20) 28 >>> input = torch.randn(6, 10) 29 >>> hx = torch.randn(3, 20) 30 >>> cx = torch.randn(3, 20) 31 >>> output = [] 32 >>> for i in range(6): 33 ... hx, cx = rnn(input[i], (hx, cx)) 34 ... output.append(hx) 35 """ 36 _FLOAT_MODULE = torch.nn.LSTMCell 37 38 def __init__( 39 self, 40 input_dim: int, 41 hidden_dim: int, 42 bias: bool = True, 43 device=None, 44 dtype=None, 45 ) -> None: 46 factory_kwargs = {"device": device, "dtype": dtype} 47 super().__init__() 48 self.input_size = input_dim 49 self.hidden_size = hidden_dim 50 self.bias = bias 51 52 self.igates = torch.nn.Linear( 53 input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs 54 ) 55 self.hgates = torch.nn.Linear( 56 hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs 57 ) 58 self.gates = torch.ao.nn.quantized.FloatFunctional() 59 60 self.input_gate = torch.nn.Sigmoid() 61 self.forget_gate = torch.nn.Sigmoid() 62 self.cell_gate = torch.nn.Tanh() 63 self.output_gate = torch.nn.Sigmoid() 64 65 self.fgate_cx = torch.ao.nn.quantized.FloatFunctional() 66 self.igate_cgate = torch.ao.nn.quantized.FloatFunctional() 67 self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional() 68 69 self.ogate_cy = torch.ao.nn.quantized.FloatFunctional() 70 71 self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0) 72 self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0) 73 self.hidden_state_dtype: torch.dtype = torch.quint8 74 self.cell_state_dtype: torch.dtype = torch.quint8 75 76 def forward( 77 self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None 78 ) -> Tuple[Tensor, Tensor]: 79 if hidden is None or hidden[0] is None or hidden[1] is None: 80 hidden = self.initialize_hidden(x.shape[0], x.is_quantized) 81 hx, cx = hidden 82 83 igates = self.igates(x) 84 hgates = self.hgates(hx) 85 gates = self.gates.add(igates, hgates) 86 87 input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1) 88 89 input_gate = self.input_gate(input_gate) 90 forget_gate = self.forget_gate(forget_gate) 91 cell_gate = self.cell_gate(cell_gate) 92 out_gate = self.output_gate(out_gate) 93 94 fgate_cx = self.fgate_cx.mul(forget_gate, cx) 95 igate_cgate = self.igate_cgate.mul(input_gate, cell_gate) 96 fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate) 97 cy = fgate_cx_igate_cgate 98 99 # TODO: make this tanh a member of the module so its qparams can be configured 100 tanh_cy = torch.tanh(cy) 101 hy = self.ogate_cy.mul(out_gate, tanh_cy) 102 return hy, cy 103 104 def initialize_hidden( 105 self, batch_size: int, is_quantized: bool = False 106 ) -> Tuple[Tensor, Tensor]: 107 h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros( 108 (batch_size, self.hidden_size) 109 ) 110 if is_quantized: 111 (h_scale, h_zp) = self.initial_hidden_state_qparams 112 (c_scale, c_zp) = self.initial_cell_state_qparams 113 h = torch.quantize_per_tensor( 114 h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype 115 ) 116 c = torch.quantize_per_tensor( 117 c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype 118 ) 119 return h, c 120 121 def _get_name(self): 122 return "QuantizableLSTMCell" 123 124 @classmethod 125 def from_params(cls, wi, wh, bi=None, bh=None): 126 """Uses the weights and biases to create a new LSTM cell. 127 128 Args: 129 wi, wh: Weights for the input and hidden layers 130 bi, bh: Biases for the input and hidden layers 131 """ 132 assert (bi is None) == (bh is None) # Either both None or both have values 133 input_size = wi.shape[1] 134 hidden_size = wh.shape[1] 135 cell = cls(input_dim=input_size, hidden_dim=hidden_size, bias=(bi is not None)) 136 cell.igates.weight = torch.nn.Parameter(wi) 137 if bi is not None: 138 cell.igates.bias = torch.nn.Parameter(bi) 139 cell.hgates.weight = torch.nn.Parameter(wh) 140 if bh is not None: 141 cell.hgates.bias = torch.nn.Parameter(bh) 142 return cell 143 144 @classmethod 145 def from_float(cls, other, use_precomputed_fake_quant=False): 146 assert type(other) == cls._FLOAT_MODULE 147 assert hasattr(other, "qconfig"), "The float module must have 'qconfig'" 148 observed = cls.from_params( 149 other.weight_ih, other.weight_hh, other.bias_ih, other.bias_hh 150 ) 151 observed.qconfig = other.qconfig 152 observed.igates.qconfig = other.qconfig 153 observed.hgates.qconfig = other.qconfig 154 return observed 155 156 157class _LSTMSingleLayer(torch.nn.Module): 158 r"""A single one-directional LSTM layer. 159 160 The difference between a layer and a cell is that the layer can process a 161 sequence, while the cell only expects an instantaneous value. 162 """ 163 164 def __init__( 165 self, 166 input_dim: int, 167 hidden_dim: int, 168 bias: bool = True, 169 device=None, 170 dtype=None, 171 ) -> None: 172 factory_kwargs = {"device": device, "dtype": dtype} 173 super().__init__() 174 self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs) 175 176 def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None): 177 result = [] 178 seq_len = x.shape[0] 179 for i in range(seq_len): 180 hidden = self.cell(x[i], hidden) 181 result.append(hidden[0]) # type: ignore[index] 182 result_tensor = torch.stack(result, 0) 183 return result_tensor, hidden 184 185 @classmethod 186 def from_params(cls, *args, **kwargs): 187 cell = LSTMCell.from_params(*args, **kwargs) 188 layer = cls(cell.input_size, cell.hidden_size, cell.bias) 189 layer.cell = cell 190 return layer 191 192 193class _LSTMLayer(torch.nn.Module): 194 r"""A single bi-directional LSTM layer.""" 195 196 def __init__( 197 self, 198 input_dim: int, 199 hidden_dim: int, 200 bias: bool = True, 201 batch_first: bool = False, 202 bidirectional: bool = False, 203 device=None, 204 dtype=None, 205 ) -> None: 206 factory_kwargs = {"device": device, "dtype": dtype} 207 super().__init__() 208 self.batch_first = batch_first 209 self.bidirectional = bidirectional 210 self.layer_fw = _LSTMSingleLayer( 211 input_dim, hidden_dim, bias=bias, **factory_kwargs 212 ) 213 if self.bidirectional: 214 self.layer_bw = _LSTMSingleLayer( 215 input_dim, hidden_dim, bias=bias, **factory_kwargs 216 ) 217 218 def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None): 219 if self.batch_first: 220 x = x.transpose(0, 1) 221 if hidden is None: 222 hx_fw, cx_fw = (None, None) 223 else: 224 hx_fw, cx_fw = hidden 225 hidden_bw: Optional[Tuple[Tensor, Tensor]] = None 226 if self.bidirectional: 227 if hx_fw is None: 228 hx_bw = None 229 else: 230 hx_bw = hx_fw[1] 231 hx_fw = hx_fw[0] 232 if cx_fw is None: 233 cx_bw = None 234 else: 235 cx_bw = cx_fw[1] 236 cx_fw = cx_fw[0] 237 if hx_bw is not None and cx_bw is not None: 238 hidden_bw = hx_bw, cx_bw 239 if hx_fw is None and cx_fw is None: 240 hidden_fw = None 241 else: 242 hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional( 243 cx_fw 244 ) 245 result_fw, hidden_fw = self.layer_fw(x, hidden_fw) 246 247 if hasattr(self, "layer_bw") and self.bidirectional: 248 x_reversed = x.flip(0) 249 result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw) 250 result_bw = result_bw.flip(0) 251 252 result = torch.cat([result_fw, result_bw], result_fw.dim() - 1) 253 if hidden_fw is None and hidden_bw is None: 254 h = None 255 c = None 256 elif hidden_fw is None: 257 (h, c) = torch.jit._unwrap_optional(hidden_bw) 258 elif hidden_bw is None: 259 (h, c) = torch.jit._unwrap_optional(hidden_fw) 260 else: 261 h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore[list-item] 262 c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore[list-item] 263 else: 264 result = result_fw 265 h, c = torch.jit._unwrap_optional(hidden_fw) # type: ignore[assignment] 266 267 if self.batch_first: 268 result.transpose_(0, 1) 269 270 return result, (h, c) 271 272 @classmethod 273 def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs): 274 r""" 275 There is no FP equivalent of this class. This function is here just to 276 mimic the behavior of the `prepare` within the `torch.ao.quantization` 277 flow. 278 """ 279 assert hasattr(other, "qconfig") or (qconfig is not None) 280 281 input_size = kwargs.get("input_size", other.input_size) 282 hidden_size = kwargs.get("hidden_size", other.hidden_size) 283 bias = kwargs.get("bias", other.bias) 284 batch_first = kwargs.get("batch_first", other.batch_first) 285 bidirectional = kwargs.get("bidirectional", other.bidirectional) 286 287 layer = cls(input_size, hidden_size, bias, batch_first, bidirectional) 288 layer.qconfig = getattr(other, "qconfig", qconfig) 289 wi = getattr(other, f"weight_ih_l{layer_idx}") 290 wh = getattr(other, f"weight_hh_l{layer_idx}") 291 bi = getattr(other, f"bias_ih_l{layer_idx}", None) 292 bh = getattr(other, f"bias_hh_l{layer_idx}", None) 293 294 layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh) 295 296 if other.bidirectional: 297 wi = getattr(other, f"weight_ih_l{layer_idx}_reverse") 298 wh = getattr(other, f"weight_hh_l{layer_idx}_reverse") 299 bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None) 300 bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None) 301 layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh) 302 return layer 303 304 305class LSTM(torch.nn.Module): 306 r"""A quantizable long short-term memory (LSTM). 307 308 For the description and the argument types, please, refer to :class:`~torch.nn.LSTM` 309 310 Attributes: 311 layers : instances of the `_LSTMLayer` 312 313 .. note:: 314 To access the weights and biases, you need to access them per layer. 315 See examples below. 316 317 Examples:: 318 319 >>> import torch.ao.nn.quantizable as nnqa 320 >>> rnn = nnqa.LSTM(10, 20, 2) 321 >>> input = torch.randn(5, 3, 10) 322 >>> h0 = torch.randn(2, 3, 20) 323 >>> c0 = torch.randn(2, 3, 20) 324 >>> output, (hn, cn) = rnn(input, (h0, c0)) 325 >>> # To get the weights: 326 >>> # xdoctest: +SKIP 327 >>> print(rnn.layers[0].weight_ih) 328 tensor([[...]]) 329 >>> print(rnn.layers[0].weight_hh) 330 AssertionError: There is no reverse path in the non-bidirectional layer 331 """ 332 _FLOAT_MODULE = torch.nn.LSTM 333 334 def __init__( 335 self, 336 input_size: int, 337 hidden_size: int, 338 num_layers: int = 1, 339 bias: bool = True, 340 batch_first: bool = False, 341 dropout: float = 0.0, 342 bidirectional: bool = False, 343 device=None, 344 dtype=None, 345 ) -> None: 346 factory_kwargs = {"device": device, "dtype": dtype} 347 super().__init__() 348 self.input_size = input_size 349 self.hidden_size = hidden_size 350 self.num_layers = num_layers 351 self.bias = bias 352 self.batch_first = batch_first 353 self.dropout = float(dropout) 354 self.bidirectional = bidirectional 355 self.training = False # Default to eval mode. If we want to train, we will explicitly set to training. 356 num_directions = 2 if bidirectional else 1 357 358 if ( 359 not isinstance(dropout, numbers.Number) 360 or not 0 <= dropout <= 1 361 or isinstance(dropout, bool) 362 ): 363 raise ValueError( 364 "dropout should be a number in range [0, 1] " 365 "representing the probability of an element being " 366 "zeroed" 367 ) 368 if dropout > 0: 369 warnings.warn( 370 "dropout option for quantizable LSTM is ignored. " 371 "If you are training, please, use nn.LSTM version " 372 "followed by `prepare` step." 373 ) 374 if num_layers == 1: 375 warnings.warn( 376 "dropout option adds dropout after all but last " 377 "recurrent layer, so non-zero dropout expects " 378 f"num_layers greater than 1, but got dropout={dropout} " 379 f"and num_layers={num_layers}" 380 ) 381 382 layers = [ 383 _LSTMLayer( 384 self.input_size, 385 self.hidden_size, 386 self.bias, 387 batch_first=False, 388 bidirectional=self.bidirectional, 389 **factory_kwargs, 390 ) 391 ] 392 for layer in range(1, num_layers): 393 layers.append( 394 _LSTMLayer( 395 self.hidden_size, 396 self.hidden_size, 397 self.bias, 398 batch_first=False, 399 bidirectional=self.bidirectional, 400 **factory_kwargs, 401 ) 402 ) 403 self.layers = torch.nn.ModuleList(layers) 404 405 def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None): 406 if self.batch_first: 407 x = x.transpose(0, 1) 408 409 max_batch_size = x.size(1) 410 num_directions = 2 if self.bidirectional else 1 411 if hidden is None: 412 zeros = torch.zeros( 413 num_directions, 414 max_batch_size, 415 self.hidden_size, 416 dtype=torch.float, 417 device=x.device, 418 ) 419 zeros.squeeze_(0) 420 if x.is_quantized: 421 zeros = torch.quantize_per_tensor( 422 zeros, scale=1.0, zero_point=0, dtype=x.dtype 423 ) 424 hxcx = [(zeros, zeros) for _ in range(self.num_layers)] 425 else: 426 hidden_non_opt = torch.jit._unwrap_optional(hidden) 427 if isinstance(hidden_non_opt[0], Tensor): 428 hx = hidden_non_opt[0].reshape( 429 self.num_layers, num_directions, max_batch_size, self.hidden_size 430 ) 431 cx = hidden_non_opt[1].reshape( 432 self.num_layers, num_directions, max_batch_size, self.hidden_size 433 ) 434 hxcx = [ 435 (hx[idx].squeeze(0), cx[idx].squeeze(0)) 436 for idx in range(self.num_layers) 437 ] 438 else: 439 hxcx = hidden_non_opt 440 441 hx_list = [] 442 cx_list = [] 443 for idx, layer in enumerate(self.layers): 444 x, (h, c) = layer(x, hxcx[idx]) 445 hx_list.append(torch.jit._unwrap_optional(h)) 446 cx_list.append(torch.jit._unwrap_optional(c)) 447 hx_tensor = torch.stack(hx_list) 448 cx_tensor = torch.stack(cx_list) 449 450 # We are creating another dimension for bidirectional case 451 # need to collapse it 452 hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1]) 453 cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1]) 454 455 if self.batch_first: 456 x = x.transpose(0, 1) 457 458 return x, (hx_tensor, cx_tensor) 459 460 def _get_name(self): 461 return "QuantizableLSTM" 462 463 @classmethod 464 def from_float(cls, other, qconfig=None): 465 assert isinstance(other, cls._FLOAT_MODULE) 466 assert hasattr(other, "qconfig") or qconfig 467 observed = cls( 468 other.input_size, 469 other.hidden_size, 470 other.num_layers, 471 other.bias, 472 other.batch_first, 473 other.dropout, 474 other.bidirectional, 475 ) 476 observed.qconfig = getattr(other, "qconfig", qconfig) 477 for idx in range(other.num_layers): 478 observed.layers[idx] = _LSTMLayer.from_float( 479 other, idx, qconfig, batch_first=False 480 ) 481 482 # Prepare the model 483 if other.training: 484 observed.train() 485 observed = torch.ao.quantization.prepare_qat(observed, inplace=True) 486 else: 487 observed.eval() 488 observed = torch.ao.quantization.prepare(observed, inplace=True) 489 return observed 490 491 @classmethod 492 def from_observed(cls, other): 493 # The whole flow is float -> observed -> quantized 494 # This class does float -> observed only 495 raise NotImplementedError( 496 "It looks like you are trying to convert a " 497 "non-quantizable LSTM module. Please, see " 498 "the examples on quantizable LSTMs." 499 ) 500