# mypy: allow-untyped-defs
import itertools
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch.ao.nn.intrinsic as nni
import torch.nn as nn
from torch.ao.quantization.utils import get_combined_dict, MatchAllNode, Pattern


__all__ = [
    "fuse_conv_bn",
    "fuse_conv_bn_relu",
    "fuse_linear_bn",
    "fuse_convtranspose_bn",
    "get_fuser_method",
    "get_fuser_method_new",
]


def fuse_conv_bn(is_qat, conv, bn):
    r"""Return the fused conv and bn modules.

    Given the conv and bn modules, fuses them and returns the fused module.

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
            or post training quantization fusion
        conv: Module instance of type Conv1d/Conv2d/Conv3d
        bn: Spatial BN instance that needs to be fused with the conv

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_conv_bn(False, m1, b1)
    """
    assert (
        conv.training == bn.training
    ), "Conv and BN both must be in the same mode (train or eval)."

    fused_module_class_map = {
        nn.Conv1d: nni.ConvBn1d,
        nn.Conv2d: nni.ConvBn2d,
        nn.Conv3d: nni.ConvBn3d,
    }

    if is_qat:
        assert (
            bn.num_features == conv.out_channels
        ), "Output channels of Conv must match num_features of BatchNorm"
        assert bn.affine, "Only support fusing BatchNorm with affine set to True"
        assert (
            bn.track_running_stats
        ), "Only support fusing BatchNorm with track_running_stats set to True"
        fused_module_class = fused_module_class_map.get(type(conv), None)
        if fused_module_class is not None:
            return fused_module_class(conv, bn)
        else:
            raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn)}")
    else:
        return nn.utils.fuse_conv_bn_eval(conv, bn)


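# Example (minimal usage sketch): post-training fusion (is_qat=False) folds the
# BatchNorm statistics into the convolution's weight and bias and returns a single
# Conv module; QAT fusion (is_qat=True) instead returns an nni.ConvBn wrapper that
# keeps both modules for fake-quantized training.
#
#     >>> conv = nn.Conv2d(10, 20, 3).eval()
#     >>> bn = nn.BatchNorm2d(20).eval()
#     >>> fused = fuse_conv_bn(False, conv, bn)  # plain nn.Conv2d with folded BN
#     >>> qat_pair = fuse_conv_bn(True, nn.Conv2d(10, 20, 3), nn.BatchNorm2d(20))  # nni.ConvBn2d

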
def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    r"""Return the fused conv, bn and relu modules.

    Given the conv, bn and relu modules, fuses them and returns the fused module.

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
            or post training quantization fusion
        conv: Module instance of type Conv1d/Conv2d/Conv3d
        bn: Spatial BN instance that needs to be fused with the conv
        relu: ReLU instance that needs to be fused with the conv and bn

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> r1 = nn.ReLU(inplace=False)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_conv_bn_relu(False, m1, b1, r1)
    """
    assert (
        conv.training == bn.training == relu.training
    ), "Conv, BN and ReLU must all be in the same mode (train or eval)."
    fused_module: Optional[Type[nn.Sequential]] = None
    if is_qat:
        map_to_fused_module_train = {
            nn.Conv1d: nni.ConvBnReLU1d,
            nn.Conv2d: nni.ConvBnReLU2d,
            nn.Conv3d: nni.ConvBnReLU3d,
        }
        assert (
            bn.num_features == conv.out_channels
        ), "Output channel of Conv must match num_features of BatchNorm"
        assert bn.affine, "Only support fusing BatchNorm with affine set to True"
        assert (
            bn.track_running_stats
        ), "Only support fusing BatchNorm with track_running_stats set to True"
        fused_module = map_to_fused_module_train.get(type(conv), None)
        if fused_module is not None:
            return fused_module(conv, bn, relu)
        else:
            raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, relu)}")
    else:
        map_to_fused_module_eval = {
            nn.Conv1d: nni.ConvReLU1d,
            nn.Conv2d: nni.ConvReLU2d,
            nn.Conv3d: nni.ConvReLU3d,
        }
        fused_module = map_to_fused_module_eval.get(type(conv), None)
        if fused_module is not None:
            fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
            return fused_module(fused_conv, relu)
        else:
            raise NotImplementedError(f"Cannot fuse eval modules: {(conv, bn, relu)}")


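# Example (minimal usage sketch): in eval mode the BN statistics are folded into
# the conv and the result is wrapped together with the ReLU in an nni.ConvReLU
# module; in QAT mode the three modules are kept together as an nni.ConvBnReLU.
#
#     >>> conv = nn.Conv2d(3, 8, 3).eval()
#     >>> bn = nn.BatchNorm2d(8).eval()
#     >>> relu = nn.ReLU().eval()
#     >>> fused = fuse_conv_bn_relu(False, conv, bn, relu)  # nni.ConvReLU2d

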
def fuse_linear_bn(is_qat, linear, bn):
    r"""Return the fused linear and bn modules.

    Given the linear and bn modules, fuses them and returns the fused module.

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
            or post training quantization fusion
        linear: Module instance of type Linear
        bn: BatchNorm1d instance that needs to be fused with the linear layer

    Examples::

        >>> m1 = nn.Linear(20, 10)
        >>> b1 = nn.BatchNorm1d(10)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_linear_bn(False, m1, b1)
    """
    assert (
        linear.training == bn.training
    ), "Linear and BN both must be in the same mode (train or eval)."

    if is_qat:
        assert (
            bn.num_features == linear.out_features
        ), "Output features of Linear must match num_features of BatchNorm1d"
        assert bn.affine, "Only support fusing BatchNorm1d with affine set to True"
        assert (
            bn.track_running_stats
        ), "Only support fusing BatchNorm1d with track_running_stats set to True"
        return nni.LinearBn1d(linear, bn)
    else:
        return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)


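# Example (minimal usage sketch): post-training fusion of Linear + BatchNorm1d
# folds the BN statistics into the linear layer's weight and bias.
#
#     >>> linear = nn.Linear(20, 10).eval()
#     >>> bn = nn.BatchNorm1d(10).eval()
#     >>> fused = fuse_linear_bn(False, linear, bn)  # plain nn.Linear with folded BN

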
def fuse_convtranspose_bn(is_qat, convt, bn):
    r"""Return the fused ConvTranspose and bn modules.

    Given ConvTranspose and bn modules, fuses them and returns the fused module.

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
            or post training quantization fusion
        convt: Module instance of type ConvTransposeNd
        bn: BatchNormNd instance that needs to be fused with the ConvTranspose layer.
            batch norm N should match the ConvTranspose N

    Examples::

        >>> m1 = nn.ConvTranspose2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_convtranspose_bn(False, m1, b1)
    """
    assert (
        convt.training == bn.training
    ), "ConvTranspose and BN both must be in the same mode (train or eval)."

    if is_qat:
        raise Exception(  # noqa: TRY002
            "Fusing ConvTranspose+BatchNorm not yet supported in QAT."
        )
    else:
        return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)


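# Example (minimal usage sketch): ConvTranspose + BN fusion is only supported for
# post-training quantization, where the BN statistics are folded into the
# transposed convolution (note the transpose=True flag above).
#
#     >>> convt = nn.ConvTranspose2d(10, 20, 3).eval()
#     >>> bn = nn.BatchNorm2d(20).eval()
#     >>> fused = fuse_convtranspose_bn(False, convt, bn)  # nn.ConvTranspose2d with folded BN

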
def _sequential_wrapper2(sequential):
    """Return a fuser method that wraps two modules in the given sequential class.

    Given a sequential class for two modules, return a function that takes
    is_qat and two modules as arguments, ignores the is_qat flag,
    and always returns the sequential that combines the two input modules.
    """

    def fuser_method(is_qat, m1, m2):
        return sequential(m1, m2)

    return fuser_method


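# Example (minimal usage sketch): the wrapper produces fuser methods with the
# standard (is_qat, m1, m2) signature for pairs that fuse by simple wrapping.
#
#     >>> fuse_linear_relu = _sequential_wrapper2(nni.LinearReLU)
#     >>> fused = fuse_linear_relu(False, nn.Linear(8, 4), nn.ReLU())  # nni.LinearReLU

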
_DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv1d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU1d),
    (nn.Conv2d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU2d),
    (nn.Conv3d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU3d),
    (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
    (nn.Linear, nn.ReLU): _sequential_wrapper2(nni.LinearReLU),
    (nn.BatchNorm2d, nn.ReLU): _sequential_wrapper2(nni.BNReLU2d),
    (nn.BatchNorm3d, nn.ReLU): _sequential_wrapper2(nni.BNReLU3d),
    (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
    (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
    (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
}


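# Example (minimal usage sketch): keys are tuples of module types in program
# order, values are fuser methods with the (is_qat, *modules) signature.
#
#     >>> method = _DEFAULT_OP_LIST_TO_FUSER_METHOD[(nn.Conv2d, nn.BatchNorm2d, nn.ReLU)]
#     >>> method is fuse_conv_bn_relu
#     True

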
def get_fuser_method(op_list, additional_fuser_method_mapping=None):
    """Get fuser method for the given list of module types.

    Get fuser method for the given list of module types;
    raises an AssertionError if no fuser method is found.
    """
    if additional_fuser_method_mapping is None:
        additional_fuser_method_mapping = {}
    all_mappings = get_combined_dict(
        _DEFAULT_OP_LIST_TO_FUSER_METHOD, additional_fuser_method_mapping
    )
    fuser_method = all_mappings.get(op_list, None)
    assert fuser_method is not None, f"did not find fuser method for: {op_list} "
    return fuser_method


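# Example (minimal usage sketch): look up the fuser method for a module-type
# tuple, optionally extending the default mapping with custom entries.
#
#     >>> method = get_fuser_method((nn.Linear, nn.ReLU))
#     >>> fused = method(False, nn.Linear(8, 4), nn.ReLU())  # nni.LinearReLU

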
def _reverse2(f):
    """Return a fuser method that passes its two module arguments to ``f`` in reverse order."""

    def reversed(is_qat, x, y):
        return f(is_qat, y, x)

    return reversed


def _reverse3(f):
    """Return a fuser method that flattens ``(x, (y, z))`` and calls ``f`` with ``(z, y, x)``."""

    def reversed(is_qat, x, w):
        y, z = w
        return f(is_qat, z, y, x)

    return reversed


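# Example (minimal usage sketch): these helpers adapt a fuser method so it can be
# registered under a pattern whose arguments arrive in reversed (nested) order,
# e.g. (bn, conv) instead of (conv, bn).
#
#     >>> fuse_bn_conv = _reverse2(fuse_conv_bn)
#     >>> fused = fuse_bn_conv(False, nn.BatchNorm2d(20).eval(), nn.Conv2d(10, 20, 3).eval())

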
def _get_valid_patterns(op_pattern):
    """Return a list of valid patterns generated from the op_pattern.

    Returns a list of valid patterns generated from the op_pattern,
    since MatchAllNode can match all types of nodes,
    e.g. pattern (torch.nn.Conv2d, torch.add) should also be able to match keys like
    (MatchAllNode, torch.add) and (torch.nn.Conv2d, MatchAllNode)

    Example Input:
    (torch.add, (torch.nn.ReLU, torch.nn.Conv2d))

    Example Output:
    [(torch.add, (torch.nn.ReLU, torch.nn.Conv2d)),
     (torch.add, (torch.nn.ReLU, MatchAllNode)),
     (torch.add, (MatchAllNode, torch.nn.Conv2d)),
     (torch.add, (MatchAllNode, MatchAllNode)),
     (MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d)),
     (MatchAllNode, (torch.nn.ReLU, MatchAllNode)),
     (MatchAllNode, (MatchAllNode, torch.nn.Conv2d)),
     (MatchAllNode, (MatchAllNode, MatchAllNode)),
    ]
    """
    result: List[Any]
    if isinstance(op_pattern, (tuple, list)):
        sub_combs = []
        for sub_pattern in op_pattern:
            sub_combs.append(_get_valid_patterns(sub_pattern))
        result = list(itertools.product(*sub_combs))
    else:
        result = [op_pattern, MatchAllNode]
    return result


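# Example (minimal usage sketch): each leaf in a pattern expands to
# [leaf, MatchAllNode] and tuples take the cross product of their children's
# expansions, so a flat two-element pattern yields four candidates:
#
#     >>> _get_valid_patterns((nn.Conv2d, nn.ReLU))
#     [(nn.Conv2d, nn.ReLU), (nn.Conv2d, MatchAllNode),
#      (MatchAllNode, nn.ReLU), (MatchAllNode, MatchAllNode)]

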
def get_fuser_method_new(
    op_pattern: Pattern,
    fuser_method_mapping: Dict[Pattern, Union[nn.Sequential, Callable]],
):
    """Get fuser method for the given op pattern.

    This will become the default after ``get_fuser_method`` is deprecated;
    it is implemented first so that the deprecation can land in a separate PR.
    """
    op_patterns = _get_valid_patterns(op_pattern)
    fuser_method = None
    for valid_pattern in op_patterns:
        fuser_method = fuser_method_mapping.get(valid_pattern, None)
        if fuser_method is not None:
            break
    assert fuser_method is not None, f"did not find fuser method for: {op_pattern} "
    return fuser_method

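
# Example (minimal usage sketch): unlike get_fuser_method, this lookup also tries
# MatchAllNode-relaxed variants of the pattern, so a mapping keyed on a wildcard
# entry can still be found.
#
#     >>> mapping = {(MatchAllNode, nn.ReLU): _sequential_wrapper2(nni.ConvReLU2d)}
#     >>> method = get_fuser_method_new((nn.Conv2d, nn.ReLU), mapping)  # matches via MatchAllNode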