# mypy: allow-untyped-defs

import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq


__all__ = ["BNReLU2d", "BNReLU3d"]


class BNReLU2d(nnq.BatchNorm2d):
    r"""
    Quantized fusion of ``BatchNorm2d`` followed by ``ReLU``.

    Shares its constructor arguments and attributes with
    :class:`torch.ao.nn.quantized.BatchNorm2d`; the only differences are that
    ``forward`` dispatches to the fused ``quantized::batch_norm2d_relu`` kernel
    and that conversion accepts the fused float module
    ``torch.ao.nn.intrinsic.BNReLU2d``.

    Attributes:
        Same as torch.ao.nn.quantized.BatchNorm2d

    """

    # Float (fused) module type this quantized module is converted from.
    _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU2d

    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(num_features, eps=eps, momentum=momentum, **factory_kwargs)

    def forward(self, input):
        """Run the fused quantized batch-norm + ReLU kernel on a 4-D input.

        The input is presumably a quantized ``(N, C, H, W)`` tensor (it is fed
        to a ``torch.ops.quantized`` op); ``self.scale`` / ``self.zero_point``
        are passed to the op as its final arguments.

        Raises:
            ValueError: if ``input`` does not have exactly 4 dimensions.
        """
        # ``len(input.shape)`` instead of ``input.ndim`` works around a JIT
        # issue: https://github.com/pytorch/pytorch/issues/23890
        rank = len(input.shape)
        if rank != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        return torch.ops.quantized.batch_norm2d_relu(
            input,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            self.eps,
            self.scale,
            self.zero_point,
        )

    def _get_name(self):
        return "QuantizedBNReLU2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Convert a float/observed module by deferring to the parent class."""
        # TODO: Add qat support for BNReLU2d
        return super().from_float(
            mod,
            use_precomputed_fake_quant=use_precomputed_fake_quant,
        )

    @classmethod
    def from_reference(cls, bn_relu, output_scale, output_zero_point):
        """Build from a fused reference module, using its BatchNorm (index 0).

        The ReLU half is applied inside the fused op in ``forward``, so only
        the BatchNorm submodule is handed to the parent's conversion.
        """
        batch_norm = bn_relu[0]
        return super().from_reference(batch_norm, output_scale, output_zero_point)


class BNReLU3d(nnq.BatchNorm3d):
    r"""
    Quantized fusion of ``BatchNorm3d`` followed by ``ReLU``.

    Shares its constructor arguments and attributes with
    :class:`torch.ao.nn.quantized.BatchNorm3d`; the only differences are that
    ``forward`` dispatches to the fused ``quantized::batch_norm3d_relu`` kernel
    and that conversion accepts the fused float module
    ``torch.ao.nn.intrinsic.BNReLU3d``.

    Attributes:
        Same as torch.ao.nn.quantized.BatchNorm3d

    """

    # Float (fused) module type this quantized module is converted from.
    _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU3d

    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(num_features, eps=eps, momentum=momentum, **factory_kwargs)

    def forward(self, input):
        """Run the fused quantized batch-norm + ReLU kernel on a 5-D input.

        The input is presumably a quantized ``(N, C, D, H, W)`` tensor (it is
        fed to a ``torch.ops.quantized`` op); ``self.scale`` /
        ``self.zero_point`` are passed to the op as its final arguments.

        Raises:
            ValueError: if ``input`` does not have exactly 5 dimensions.
        """
        # ``len(input.shape)`` instead of ``input.ndim`` works around a JIT
        # issue: https://github.com/pytorch/pytorch/issues/23890
        rank = len(input.shape)
        if rank != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        return torch.ops.quantized.batch_norm3d_relu(
            input,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            self.eps,
            self.scale,
            self.zero_point,
        )

    def _get_name(self):
        return "QuantizedBNReLU3d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Convert a float/observed module by deferring to the parent class."""
        # TODO: Add qat support for BNReLU3d
        return super().from_float(
            mod,
            use_precomputed_fake_quant=use_precomputed_fake_quant,
        )

    @classmethod
    def from_reference(cls, bn_relu, output_scale, output_zero_point):
        """Build from a fused reference module, using its BatchNorm (index 0).

        The ReLU half is applied inside the fused op in ``forward``, so only
        the BatchNorm submodule is handed to the parent's conversion.
        """
        batch_norm = bn_relu[0]
        return super().from_reference(batch_norm, output_scale, output_zero_point)