# mypy: allow-untyped-defs
import torch
from torch.nn import (
    BatchNorm1d,
    BatchNorm2d,
    BatchNorm3d,
    Conv1d,
    Conv2d,
    Conv3d,
    Linear,
    ReLU,
)
from torch.nn.utils.parametrize import type_before_parametrizations


__all__ = [
    "ConvReLU1d",
    "ConvReLU2d",
    "ConvReLU3d",
    "LinearReLU",
    "ConvBn1d",
    "ConvBn2d",
    "ConvBnReLU1d",
    "ConvBnReLU2d",
    "ConvBn3d",
    "ConvBnReLU3d",
    "BNReLU2d",
    "BNReLU3d",
    "LinearBn1d",
    "LinearLeakyReLU",
    "LinearTanh",
    "ConvAdd2d",
    "ConvAddReLU2d",
]


# Used for identifying intrinsic modules used in quantization
class _FusedModule(torch.nn.Sequential):
    pass

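# A minimal sketch (illustrative only) of how this marker base class is used:
# quantization passes check ``isinstance(m, _FusedModule)`` to recognize a
# fused pattern that should later be swapped for a single quantized module.
#
#     m = ConvReLU1d(Conv1d(8, 16, kernel_size=3), ReLU())
#     assert isinstance(m, _FusedModule)
#     assert isinstance(m, torch.nn.Sequential)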

class ConvReLU1d(_FusedModule):
    r"""This is a sequential container which calls the Conv1d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, relu):
        assert (
            type_before_parametrizations(conv) == Conv1d
            and type_before_parametrizations(relu) == ReLU
        ), f"Incorrect types for input modules {type_before_parametrizations(conv)} {type_before_parametrizations(relu)}"
        super().__init__(conv, relu)


class ConvReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, relu):
        assert (
            type_before_parametrizations(conv) == Conv2d
            and type_before_parametrizations(relu) == ReLU
        ), f"Incorrect types for input modules {type_before_parametrizations(conv)} {type_before_parametrizations(relu)}"
        super().__init__(conv, relu)

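# A minimal usage sketch for the ConvReLU containers (illustrative only; the
# channel and kernel sizes below are arbitrary). These containers are
# typically created by ``torch.ao.quantization.fuse_modules`` rather than by
# hand:
#
#     fused = ConvReLU2d(Conv2d(3, 16, kernel_size=3), ReLU())
#     out = fused(torch.randn(1, 3, 32, 32))  # equivalent to relu(conv(x))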

class ConvReLU3d(_FusedModule):
    r"""This is a sequential container which calls the Conv3d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, relu):
        assert (
            type_before_parametrizations(conv) == Conv3d
            and type_before_parametrizations(relu) == ReLU
        ), f"Incorrect types for input modules {type_before_parametrizations(conv)} {type_before_parametrizations(relu)}"
        super().__init__(conv, relu)


class LinearReLU(_FusedModule):
    r"""This is a sequential container which calls the Linear and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, linear, relu):
        assert (
            type_before_parametrizations(linear) == Linear
            and type_before_parametrizations(relu) == ReLU
        ), f"Incorrect types for input modules {type_before_parametrizations(linear)} {type_before_parametrizations(relu)}"
        super().__init__(linear, relu)

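# A minimal usage sketch (illustrative only; the feature sizes are arbitrary):
#
#     fused = LinearReLU(Linear(20, 30), ReLU())
#     out = fused(torch.randn(128, 20))  # equivalent to relu(linear(x))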

class ConvBn1d(_FusedModule):
    r"""This is a sequential container which calls the Conv1d and BatchNorm1d modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, bn):
        assert (
            type_before_parametrizations(conv) == Conv1d
            and type_before_parametrizations(bn) == BatchNorm1d
        ), f"Incorrect types for input modules {type_before_parametrizations(conv)} {type_before_parametrizations(bn)}"
        super().__init__(conv, bn)


class ConvBn2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d and BatchNorm2d modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, bn):
        assert (
            type_before_parametrizations(conv) == Conv2d
            and type_before_parametrizations(bn) == BatchNorm2d
        ), f"Incorrect types for input modules {type_before_parametrizations(conv)} {type_before_parametrizations(bn)}"
        super().__init__(conv, bn)

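# A minimal usage sketch (illustrative only; sizes are arbitrary). Conv +
# BatchNorm pairs like this are typically produced when preparing a model for
# quantization-aware training, so the pair can later be folded into a single
# quantized conv:
#
#     fused = ConvBn2d(Conv2d(3, 16, kernel_size=3), BatchNorm2d(16))
#     out = fused(torch.randn(1, 3, 32, 32))  # equivalent to bn(conv(x))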

class ConvBnReLU1d(_FusedModule):
    r"""This is a sequential container which calls the Conv1d, BatchNorm1d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, bn, relu):
        assert (
            type_before_parametrizations(conv) == Conv1d
            and type_before_parametrizations(bn) == BatchNorm1d
            and type_before_parametrizations(relu) == ReLU
        ), f"Incorrect types for input modules {type_before_parametrizations(conv)} {type_before_parametrizations(bn)} {type_before_parametrizations(relu)}"  # noqa: B950
        super().__init__(conv, bn, relu)


class ConvBnReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d, BatchNorm2d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, bn, relu):
        assert (
            type_before_parametrizations(conv) == Conv2d
            and type_before_parametrizations(bn) == BatchNorm2d
            and type_before_parametrizations(relu) == ReLU
        ), f"Incorrect types for input modules {type_before_parametrizations(conv)} {type_before_parametrizations(bn)} {type_before_parametrizations(relu)}"  # noqa: B950
        super().__init__(conv, bn, relu)

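# A minimal sketch of the three-module variant (illustrative only; sizes are
# arbitrary):
#
#     fused = ConvBnReLU2d(Conv2d(3, 16, 3), BatchNorm2d(16), ReLU())
#     out = fused(torch.randn(1, 3, 32, 32))  # relu(bn(conv(x)))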

class ConvBn3d(_FusedModule):
    r"""This is a sequential container which calls the Conv3d and BatchNorm3d modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, bn):
        assert (
            type_before_parametrizations(conv) == Conv3d
            and type_before_parametrizations(bn) == BatchNorm3d
        ), f"Incorrect types for input modules {type_before_parametrizations(conv)} {type_before_parametrizations(bn)}"
        super().__init__(conv, bn)


class ConvBnReLU3d(_FusedModule):
    r"""This is a sequential container which calls the Conv3d, BatchNorm3d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, bn, relu):
        assert (
            type_before_parametrizations(conv) == Conv3d
            and type_before_parametrizations(bn) == BatchNorm3d
            and type_before_parametrizations(relu) == ReLU
        ), f"Incorrect types for input modules {type_before_parametrizations(conv)} {type_before_parametrizations(bn)} {type_before_parametrizations(relu)}"  # noqa: B950
        super().__init__(conv, bn, relu)


class BNReLU2d(_FusedModule):
    r"""This is a sequential container which calls the BatchNorm2d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, batch_norm, relu):
        assert (
            type_before_parametrizations(batch_norm) == BatchNorm2d
            and type_before_parametrizations(relu) == ReLU
        ), f"Incorrect types for input modules {type_before_parametrizations(batch_norm)} {type_before_parametrizations(relu)}"
        super().__init__(batch_norm, relu)


class BNReLU3d(_FusedModule):
    r"""This is a sequential container which calls the BatchNorm3d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, batch_norm, relu):
        assert (
            type_before_parametrizations(batch_norm) == BatchNorm3d
            and type_before_parametrizations(relu) == ReLU
        ), f"Incorrect types for input modules {type_before_parametrizations(batch_norm)} {type_before_parametrizations(relu)}"
        super().__init__(batch_norm, relu)


class LinearBn1d(_FusedModule):
    r"""This is a sequential container which calls the Linear and BatchNorm1d modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, linear, bn):
        assert (
            type_before_parametrizations(linear) == Linear
            and type_before_parametrizations(bn) == BatchNorm1d
        ), f"Incorrect types for input modules {type_before_parametrizations(linear)} {type_before_parametrizations(bn)}"
        super().__init__(linear, bn)


class LinearLeakyReLU(_FusedModule):
    r"""This is a sequential container which calls the Linear and LeakyReLU modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, linear, leaky_relu):
        assert (
            type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU
        ), f"Incorrect types for input modules {type(linear)} {type(leaky_relu)}"
        super().__init__(linear, leaky_relu)


class LinearTanh(_FusedModule):
    r"""This is a sequential container which calls the Linear and Tanh modules.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, linear, tanh):
        assert (
            type(linear) == Linear and type(tanh) == torch.nn.Tanh
        ), f"Incorrect types for input modules {type(linear)} {type(tanh)}"
        super().__init__(linear, tanh)


class ConvAdd2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d module and adds
    a second input tensor to its output.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, add):
        super().__init__(conv)
        self.add = add

    def forward(self, x1, x2):
        return self.add(self[0](x1), x2)

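# Unlike the containers above, ConvAdd2d takes two inputs in forward(): the
# second tensor is added to the convolution output. A minimal sketch
# (illustrative only; ``torch.add`` stands in for whatever add callable the
# fusion pass supplies, and the shapes are arbitrary but must broadcast):
#
#     fused = ConvAdd2d(Conv2d(16, 16, 3, padding=1), torch.add)
#     x1 = torch.randn(1, 16, 32, 32)
#     x2 = torch.randn(1, 16, 32, 32)
#     out = fused(x1, x2)  # equivalent to conv(x1) + x2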

class ConvAddReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d module, adds a
    second input tensor to its output, and then applies ReLU.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, add, relu):
        super().__init__(conv)
        self.add = add
        self.relu = relu

    def forward(self, x1, x2):
        return self.relu(self.add(self[0](x1), x2))

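# A minimal sketch for the add + ReLU variant (illustrative only; as above,
# ``torch.add`` and ``torch.relu`` stand in for the callables supplied by the
# fusion pass):
#
#     fused = ConvAddReLU2d(Conv2d(16, 16, 3, padding=1), torch.add, torch.relu)
#     out = fused(torch.randn(1, 16, 32, 32), torch.randn(1, 16, 32, 32))
#     # equivalent to relu(conv(x1) + x2)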