# mypy: allow-untyped-defs

import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
import torch.nn.functional as F
from torch.nn.utils import fuse_conv_bn_weights


__all__ = [
    "ConvReLU1d",
    "ConvReLU2d",
    "ConvReLU3d",
]

_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding


def _fuse_conv_bn_weights_if_needed(mod, conv_bn_relu_type):
    """Fold ``mod.bn`` into ``mod.weight`` / ``mod.bias`` when ``mod`` is exactly ``conv_bn_relu_type``.

    Shared by the ``from_float`` classmethods below. Only an exact type match
    triggers folding (subclasses are left untouched), and ``mod`` is modified
    in place so the parent ``from_float`` sees the fused weight and bias.

    Args:
        mod: the module handed to ``from_float``.
        conv_bn_relu_type: the ``torch.ao.nn.intrinsic.qat.ConvBnReLU{1,2,3}d``
            class whose BatchNorm should be folded into the conv.
    """
    if type(mod) is not conv_bn_relu_type:
        return
    bn = mod.bn
    assert (
        bn.running_var is not None and bn.running_mean is not None
    ), "BatchNorm running statistics are required to fuse BatchNorm into Conv"
    mod.weight, mod.bias = fuse_conv_bn_weights(
        mod.weight,
        mod.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
    )


# TODO: factor out the remaining common parts (forward, from_reference) to ConvNd
class ConvReLU1d(nnq.Conv1d):
    r"""
    A ConvReLU1d module is a fused module of Conv1d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv1d

    """

    # Float fused module type this quantized module corresponds to.
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        """Run the fused ``quantized.conv1d_relu`` op on ``input``.

        Raises:
            ValueError: if ``input`` is not 3-dimensional ``(N, C, L)``.
        """
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        if self.padding_mode != "zeros":
            # Padding in Conv1d is stored as (p, p), need to get (p,)
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return torch.ops.quantized.conv1d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU1d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Create a quantized ConvReLU1d from a float or QAT module.

        If ``mod`` is exactly a QAT ``ConvBnReLU1d``, its BatchNorm is first
        folded into ``mod.weight`` / ``mod.bias`` (mutating ``mod``).
        """
        _fuse_conv_bn_weights_if_needed(mod, torch.ao.nn.intrinsic.qat.ConvBnReLU1d)
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        """Create a quantized ConvReLU1d from a reference ConvReLU module.

        ``ref_qconv[0]`` (presumably the reference conv submodule) is passed to
        the parent ``from_reference``; BatchNorm must already be fused.
        """
        assert (
            type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d
        ), "BatchNorm1d should be fused into Conv1d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)


class ConvReLU2d(nnq.Conv2d):
    r"""
    A ConvReLU2d module is a fused module of Conv2d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

    """

    # Float fused module type this quantized module corresponds to.
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        """Run the fused ``quantized.conv2d_relu`` op on ``input``.

        Raises:
            ValueError: if ``input`` is not 4-dimensional ``(N, C, H, W)``.
        """
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != "zeros":
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return torch.ops.quantized.conv2d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Create a quantized ConvReLU2d from a float or QAT module.

        If ``mod`` is exactly a QAT ``ConvBnReLU2d``, its BatchNorm is first
        folded into ``mod.weight`` / ``mod.bias`` (mutating ``mod``).
        """
        _fuse_conv_bn_weights_if_needed(mod, torch.ao.nn.intrinsic.qat.ConvBnReLU2d)
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        """Create a quantized ConvReLU2d from a reference ConvReLU module.

        ``ref_qconv[0]`` (presumably the reference conv submodule) is passed to
        the parent ``from_reference``; BatchNorm must already be fused.
        """
        assert (
            type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d
        ), "BatchNorm2d should be fused into Conv2d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)


class ConvReLU3d(nnq.Conv3d):
    r"""
    A ConvReLU3d module is a fused module of Conv3d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv3d

    """

    # Float fused module type this quantized module corresponds to.
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        # Reject "reflect" up front, before the parent constructor runs.
        assert padding_mode != "reflect", "Conv3d does not support reflection padding"
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        """Run the fused ``quantized.conv3d_relu`` op on ``input``.

        Raises:
            ValueError: if ``input`` is not 5-dimensional ``(N, C, D, H, W)``.
        """
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        if self.padding_mode != "zeros":
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return torch.ops.quantized.conv3d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU3d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Create a quantized ConvReLU3d from a float or QAT module.

        If ``mod`` is exactly a QAT ``ConvBnReLU3d``, its BatchNorm is first
        folded into ``mod.weight`` / ``mod.bias`` (mutating ``mod``).
        """
        _fuse_conv_bn_weights_if_needed(mod, torch.ao.nn.intrinsic.qat.ConvBnReLU3d)
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        """Create a quantized ConvReLU3d from a reference ConvReLU module.

        ``ref_qconv[0]`` (presumably the reference conv submodule) is passed to
        the parent ``from_reference``; BatchNorm must already be fused.
        """
        assert (
            type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d
        ), "BatchNorm3d should be fused into Conv3d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)