# mypy: allow-untyped-defs
from typing import Tuple, TypeVar, Union

import torch
import torch.nn as nn
from torch.ao.nn.intrinsic import _FusedModule
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch.nn.modules.utils import _pair, _single, _triple


__all__ = ["Conv1d", "Conv2d", "Conv3d"]

MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)


class _ConvNd(nn.modules.conv._ConvNd):
    _FLOAT_MODULE = MOD

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, ...],
        stride: Tuple[int, ...],
        padding: Tuple[int, ...],
        dilation: Tuple[int, ...],
        transposed: bool,
        output_padding: Tuple[int, ...],
        groups: int,
        bias: bool,
        padding_mode: str,
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        nn.modules.conv._ConvNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        # The weight fake-quantizer is constructed from the provided qconfig.
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input):
        return self._conv_forward(
            input, self.weight_fake_quant(self.weight), self.bias
        )

    @staticmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a qat module from a float module.

        Args:
            `mod`: a float module, either produced by torch.ao.quantization
                utilities or directly from the user
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        if issubclass(type(mod), _FusedModule):
            # For fused modules (e.g. conv + relu), the conv is the first submodule.
            mod = mod[0]  # type: ignore[index]
        qconfig = mod.qconfig
        qat_conv = cls(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,
            stride=mod.stride,
            padding=mod.padding,
            dilation=mod.dilation,
            groups=mod.groups,
            bias=mod.bias is not None,
            padding_mode=mod.padding_mode,
            qconfig=qconfig,
        )
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv

    def to_float(self):
        """Convert the qat module back to a floating point module.

        This works both for a single qat conv and for fused qat conv-relu modules.
        """
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined, operator]
            self.in_channels,
            self.out_channels,
            self.kernel_size,  # type: ignore[arg-type]
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.dilation,  # type: ignore[arg-type]
            self.groups,
            self.bias is not None,
            self.padding_mode,
        )
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())
        # conv relu
        if issubclass(cls, _FusedModule):
            modules = [conv]
            assert hasattr(cls, "_FLOAT_RELU_MODULE")
            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
            modules.append(relu)
            fused = cls._FLOAT_MODULE(*modules)  # type: ignore[arg-type, attr-defined, operator]
            fused.train(self.training)
            return fused
        else:
            return conv
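

# The helper below is an illustrative sketch only and is not part of this
# module's API (the name `_example_qat_conv_roundtrip` is hypothetical).  It
# shows, assuming the default QAT qconfig is acceptable, how `from_float`
# builds a qat conv from a float conv and how `to_float` converts it back.
def _example_qat_conv_roundtrip() -> None:
    import torch.ao.quantization as tq

    float_conv = nn.Conv2d(16, 33, kernel_size=3)
    # `from_float` requires a qconfig attached to the float module.
    float_conv.qconfig = tq.get_default_qat_qconfig("fbgemm")
    qat_conv = Conv2d.from_float(float_conv)

    # The forward pass fake-quantizes the weight before the convolution.
    out = qat_conv(torch.randn(1, 16, 32, 32))
    assert out.shape == (1, 33, 30, 30)

    # `to_float` detaches the trained weight back into a plain nn.Conv2d.
    restored = qat_conv.to_float()
    assert isinstance(restored, nn.Conv2d)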


class Conv1d(_ConvNd, nn.Conv1d):
    r"""
    A Conv1d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as :class:`~torch.nn.Conv1d`.

    Similar to :class:`~torch.nn.Conv1d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = nn.Conv1d
    _FLOAT_CONV_MODULE = nn.Conv1d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        kernel_size_ = _single(kernel_size)
        stride_ = _single(stride)
        padding_ = padding if isinstance(padding, str) else _single(padding)
        dilation_ = _single(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_single(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class Conv2d(_ConvNd, nn.Conv2d):
    r"""
    A Conv2d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv2d`; please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
    for documentation.

    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = nn.Conv2d
    _FLOAT_CONV_MODULE = nn.Conv2d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        padding_ = padding if isinstance(padding, str) else _pair(padding)
        dilation_ = _pair(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_pair(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        return self._conv_forward(
            input, self.weight_fake_quant(self.weight), self.bias
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )
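

# Illustrative sketch only, not part of this module's API (the name
# `_example_direct_qat_conv2d` is hypothetical).  A qat Conv2d can also be
# constructed directly, as long as a qconfig is supplied; here we assume the
# default QAT qconfig.
def _example_direct_qat_conv2d() -> None:
    import torch.ao.quantization as tq

    qat_conv = Conv2d(
        16,
        33,
        kernel_size=3,
        stride=2,
        qconfig=tq.get_default_qat_qconfig("fbgemm"),
    )
    # `weight_fake_quant` simulates quantization of the weight on every
    # forward call; otherwise the module behaves like nn.Conv2d.
    out = qat_conv(torch.randn(20, 16, 50, 100))
    assert out.shape == (20, 33, 24, 49)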


class Conv3d(_ConvNd, nn.Conv3d):
    r"""
    A Conv3d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv3d`; please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d
    for documentation.

    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = nn.Conv3d
    _FLOAT_CONV_MODULE = nn.Conv3d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: Union[str, _size_3_t] = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        kernel_size_ = _triple(kernel_size)
        stride_ = _triple(stride)
        padding_ = padding if isinstance(padding, str) else _triple(padding)
        dilation_ = _triple(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_triple(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        return self._conv_forward(
            input, self.weight_fake_quant(self.weight), self.bias
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )
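

# Illustrative sketch only, not part of this module's API (the name
# `_example_prepare_qat_swap` is hypothetical).  In the eager-mode QAT
# workflow these modules are usually created indirectly by
# `torch.ao.quantization.prepare_qat`, which swaps float convs for their qat
# counterparts according to the default QAT module mapping; here we assume
# the default qconfig for the fbgemm backend.
def _example_prepare_qat_swap() -> None:
    import torch.ao.quantization as tq

    model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
    model.qconfig = tq.get_default_qat_qconfig("fbgemm")
    model.train()  # prepare_qat expects a model in training mode
    prepared = tq.prepare_qat(model)

    # The float conv has been replaced by the qat Conv2d defined in this
    # file, which carries a `weight_fake_quant` observer.
    assert isinstance(prepared[0], Conv2d)
    _ = prepared(torch.randn(2, 3, 16, 16))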