# mypy: allow-untyped-defs
import torch


__all__ = ["Dropout"]


class Dropout(torch.nn.Dropout):
    r"""Quantized counterpart of :class:`~torch.nn.Dropout`.

    This module is a pass-through placeholder. Models whose fp32 version
    contained a dropout layer can run with quantized tensors in both train
    and eval mode. ``forward`` returns its input unchanged, so no element
    is ever zeroed.

    Args:
        p: probability of an element to be zeroed. It is stored on the
            module like :class:`~torch.nn.Dropout`, but ``forward`` never
            applies it.
        inplace: can optionally do the operation in-place. It is stored
            but has no effect here because ``forward`` is the identity.
            Default: ``False``
    """

    def forward(self, input):
        # Deliberate no-op: the same tensor object is handed back whether
        # or not the module is in training mode.
        result = input
        return result

    def _get_name(self):
        # Display name reported for this module (e.g. in its repr).
        display_name = "QuantizedDropout"
        return display_name

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Create an instance from a float dropout module ``mod``.

        Only ``mod.p`` and ``mod.inplace`` are copied.
        ``use_precomputed_fake_quant`` is accepted but unused.
        """
        drop_prob, in_place = mod.p, mod.inplace
        return cls(drop_prob, in_place)

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Create an instance from a reference dropout module ``mod``.

        Only ``mod.p`` and ``mod.inplace`` are copied. ``scale`` and
        ``zero_point`` are accepted but unused.
        """
        drop_prob, in_place = mod.p, mod.inplace
        return cls(drop_prob, in_place)