# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
import torch.nn.functional as F


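# Reuse the internal padding helper from the quantized conv module: it reverses
# the per-dimension padding tuple and repeats each entry twice, matching the
# (last dimension first) argument layout that F.pad expects.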
_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding


class ConvAdd2d(nnq.Conv2d):
    r"""
    A ConvAdd2d module is a fused module of Conv2d and Add.

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

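    Example (an illustrative sketch; the shapes and quantization parameters
    below are arbitrary choices, not part of the API)::

        >>> # xdoctest: +SKIP
        >>> m = torch.ao.nn.intrinsic.quantized.ConvAdd2d(3, 3, 3, padding=1)
        >>> x = torch.quantize_per_tensor(
        ...     torch.randn(1, 3, 8, 8), scale=0.1, zero_point=0, dtype=torch.quint8
        ... )
        >>> extra = torch.quantize_per_tensor(
        ...     torch.randn(1, 3, 8, 8), scale=0.1, zero_point=0, dtype=torch.quint8
        ... )
        >>> out = m(x, extra)  # quantized conv(x) + extra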
    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAdd2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input, extra_input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != "zeros":
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
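        # The fused kernel convolves `input`, adds `extra_input` to the result,
        # and requantizes using the module's output scale and zero point.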
        return torch.ops.quantized.conv2d_add(
            input, extra_input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvAdd2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
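        # The reference module is a fused (Sequential-style) container;
        # element 0 is the conv whose weight gets packed.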
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)


class ConvAddReLU2d(nnq.Conv2d):
    r"""
    A ConvAddReLU2d module is a fused module of Conv2d, Add, and ReLU.

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

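    Example (an illustrative sketch; the shapes and quantization parameters
    below are arbitrary choices, not part of the API)::

        >>> # xdoctest: +SKIP
        >>> m = torch.ao.nn.intrinsic.quantized.ConvAddReLU2d(3, 3, 3, padding=1)
        >>> x = torch.quantize_per_tensor(
        ...     torch.randn(1, 3, 8, 8), scale=0.1, zero_point=0, dtype=torch.quint8
        ... )
        >>> extra = torch.quantize_per_tensor(
        ...     torch.randn(1, 3, 8, 8), scale=0.1, zero_point=0, dtype=torch.quint8
        ... )
        >>> out = m(x, extra)  # relu(quantized conv(x) + extra)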
    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input, extra_input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != "zeros":
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
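        # The fused kernel convolves `input`, adds `extra_input`, applies ReLU,
        # and requantizes using the module's output scale and zero point.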
        return torch.ops.quantized.conv2d_add_relu(
            input, extra_input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvAddReLU2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
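        # The reference module is a fused (Sequential-style) container;
        # element 0 is the conv whose weight gets packed.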
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)