# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
import torch.nn.functional as F


# Alias of the helper defined in ``torch.ao.nn.quantized.modules.conv``. The
# ``forward`` methods in this file use it to build the ``F.pad`` argument when
# ``padding_mode`` is not ``"zeros"``.
_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding


class ConvAdd2d(nnq.Conv2d):
    r"""
    Quantized fused module of ``Conv2d`` followed by ``Add``.

    Construction arguments, attributes and conversion helpers mirror
    :class:`torch.ao.nn.quantized.Conv2d`. The one interface difference is
    that :meth:`forward` takes a second tensor, ``extra_input``. That tensor
    is passed to the fused ``torch.ops.quantized.conv2d_add`` kernel together
    with the convolution input.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d
    """

    # Float fused module this quantized module is converted from.
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAdd2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        # Construction is delegated unchanged to nnq.Conv2d; this subclass
        # stores no additional state of its own.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input, extra_input):
        """Run the fused quantized conv2d + add.

        Args:
            input: 4-D quantized tensor laid out as ``(N, C, H, W)``.
            extra_input: tensor passed as the other operand of the fused add.

        Returns:
            The tensor produced by ``torch.ops.quantized.conv2d_add``, with
            ``self.scale`` / ``self.zero_point`` supplied as its output
            quantization parameters.

        Raises:
            ValueError: if ``input`` is not 4-dimensional.
        """
        # `len(input.shape)` instead of `input.ndim`: workaround for the JIT
        # issue https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        conv_input = input
        if self.padding_mode != "zeros":
            # Non-"zeros" padding modes are applied explicitly with F.pad
            # before the fused kernel runs.
            conv_input = F.pad(
                input,
                _reverse_repeat_padding(self.padding),
                mode=self.padding_mode,
            )
        return torch.ops.quantized.conv2d_add(
            conv_input,
            extra_input,
            self._packed_params,
            self.scale,
            self.zero_point,
        )

    def _get_name(self):
        """Return this module's display name (overrides ``nn.Module._get_name``)."""
        return "QuantizedConvAdd2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized ConvAdd2d from a float/observed module via nnq.Conv2d."""
        quantized = super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )
        return quantized

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        """Build from a reference fused module.

        Element 0 of ``ref_qconv`` is passed on as the reference conv to
        ``nnq.Conv2d.from_reference``, together with the output quantization
        parameters.
        """
        reference_conv = ref_qconv[0]
        return super().from_reference(
            reference_conv, output_scale, output_zero_point
        )


class ConvAddReLU2d(nnq.Conv2d):
    r"""
    Quantized fused module of ``Conv2d``, ``Add`` and ``ReLU``.

    Construction arguments, attributes and conversion helpers mirror
    :class:`torch.ao.nn.quantized.Conv2d`. The one interface difference is
    that :meth:`forward` takes a second tensor, ``extra_input``. That tensor
    is passed to the fused ``torch.ops.quantized.conv2d_add_relu`` kernel
    together with the convolution input.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d
    """

    # Float fused module this quantized module is converted from.
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        # Construction is delegated unchanged to nnq.Conv2d; this subclass
        # stores no additional state of its own.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input, extra_input):
        """Run the fused quantized conv2d + add + relu.

        Args:
            input: 4-D quantized tensor laid out as ``(N, C, H, W)``.
            extra_input: tensor passed as the other operand of the fused add.

        Returns:
            The tensor produced by ``torch.ops.quantized.conv2d_add_relu``,
            with ``self.scale`` / ``self.zero_point`` supplied as its output
            quantization parameters.

        Raises:
            ValueError: if ``input`` is not 4-dimensional.
        """
        # `len(input.shape)` instead of `input.ndim`: workaround for the JIT
        # issue https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        conv_input = input
        if self.padding_mode != "zeros":
            # Non-"zeros" padding modes are applied explicitly with F.pad
            # before the fused kernel runs.
            conv_input = F.pad(
                input,
                _reverse_repeat_padding(self.padding),
                mode=self.padding_mode,
            )
        return torch.ops.quantized.conv2d_add_relu(
            conv_input,
            extra_input,
            self._packed_params,
            self.scale,
            self.zero_point,
        )

    def _get_name(self):
        """Return this module's display name (overrides ``nn.Module._get_name``)."""
        return "QuantizedConvAddReLU2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized ConvAddReLU2d from a float/observed module via nnq.Conv2d."""
        quantized = super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )
        return quantized

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        """Build from a reference fused module.

        Element 0 of ``ref_qconv`` is passed on as the reference conv to
        ``nnq.Conv2d.from_reference``, together with the output quantization
        parameters.
        """
        reference_conv = ref_qconv[0]
        return super().from_reference(
            reference_conv, output_scale, output_zero_point
        )