# mypy: allow-untyped-defs
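"""Binary folding FX passes for Inductor freezing.

Folds pointwise add/sub/mul/div with constant tensors into a preceding
aten.convolution by precomputing new weights/biases, including a mixed-dtype
path that folds fp16/bf16 convs in fp32 and then restores the original dtype.
"""
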
import functools
import itertools

import torch

from ..._dynamo.utils import counters
from ..pattern_matcher import Arg, CallFunction, KeywordArg
from .freezing_patterns import register_binary_folding_pattern


aten = torch.ops.aten
prims = torch.ops.prims


def mark_mixed_dtype_conv(conv):
    """
    Mark a float16/bfloat16 convolution whose single chain of users consists of
    binary ops computed in float32 and ends in a prims.convert_element_type
    cast, so it can be binary-folded in the higher precision.
    """
    conv_dtype = conv.meta["val"].dtype
    if conv_dtype not in (torch.float16, torch.bfloat16):
        return

    if len(conv.users) != 1:
        return

    conv_user = next(iter(conv.users.keys()))
    if not isinstance(conv_user.meta["val"], torch.Tensor):
        return

    if conv_user.meta["val"].dtype != torch.float32:
        return

    # Walk the chain of single-use binary ops consuming the conv output.
    while conv_user.target in _binary_ops:
        if len(conv_user.users) != 1:
            return

        conv_user = next(iter(conv_user.users.keys()))

    if conv_user.target != prims.convert_element_type.default:
        return

    conv.meta["_allow_conv_mixed_dtype_folding"] = conv_dtype

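# A sketch of the user chain mark_mixed_dtype_conv looks for (names here are
# illustrative, not part of the pass):
#   conv = aten.convolution.default(x_f16, w_f16, b_f16, ...)   # fp16 output
#   y    = aten.mul.Tensor(conv, scale_f32)                     # promoted to fp32
#   out  = prims.convert_element_type.default(y, torch.float16)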

def mark_mixed_dtype_allowed_convs(gm):
    """
    Mark convolutions that we will binary fold even with mixed precision
    constants. We constant fold in the higher precision for better accuracy
    and then recover the original precision afterwards.
    """
    for node in gm.graph.find_nodes(
        op="call_function", target=aten.convolution.default
    ):
        mark_mixed_dtype_conv(node)


def recover_original_precision_folded_convs(gm):
    """
    After binary folding conv weights and biases to a higher dtype, recover the
    original precision they were in.
    """
    graph = gm.graph
    for node in graph.find_nodes(op="call_function", target=aten.convolution.default):
        orig_dtype = node.meta.get("_allow_conv_mixed_dtype_folding", None)
        if orig_dtype is None:
            continue

        with graph.inserting_before(node):
            # args[1] is the weight, args[2] is the (optional) bias.
            for idx in [1, 2]:
                old_input = node.args[idx]
                if old_input is None:
                    continue

                new_input = graph.create_node(
                    "call_function",
                    prims.convert_element_type.default,
                    (old_input, orig_dtype),
                )
                node.replace_input_with(old_input, new_input)

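# A sketch of what the recovery pass inserts (illustrative): a folded conv whose
# parameters were upcast to fp32 gets explicit casts back to the recorded dtype,
#   aten.convolution.default(x, w_f32, b_f32, ...)
# becoming
#   w = prims.convert_element_type.default(w_f32, torch.float16)
#   b = prims.convert_element_type.default(b_f32, torch.float16)
#   aten.convolution.default(x, w, b, ...)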

_binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor]


@functools.lru_cache(None)
def binary_folding_init():
    _conv_args = [Arg() for _ in range(9)]
    _computation_ops = [aten.convolution.default]
    _computation_calls = [CallFunction(aten.convolution.default, *_conv_args, _users=1)]

    """
    In order to fuse add/sub/mul/div with conv, the dimensions of its
    constant tensor must satisfy the following:
    - it can be resized to broadcast with the weight/bias tensor shape
    - it broadcasts to the conv output shape
    It needs a shape that can resize to the weight/bias tensor shape
    because we need to run the op with the conv weights/bias without
    changing their sizes.
    It needs to broadcast to the conv output shape so that we do not
    accidentally change the shape of the op output by pre-fusing it
    compared to eager.
    The only dimension value shared by weight/bias/conv output is that
    they all contain a dim with value == channels-out. In the conv
    output tensor, this is the second dimension, so the pointwise op
    tensor may have a second dimension of value == channels-out, but
    all the other dimensions have to be 1.
    """
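    # A concrete example (illustrative shapes): for a conv with weight
    # [o=32, i=3, 3, 3], bias [32], and output [N, 32, H, W], a constant of
    # shape [1, 32, 1, 1] (or any 1-element tensor) qualifies, while one of
    # shape [1, 3, 1, 1] or [N, 32, H, W] does not.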

    def _op_not_broadcasting_with_conv(weight_tensor, other_tensor):
        # According to opDoesNotBroadCastWithConv of frozen_conv_folding.cpp
        weight_shape = weight_tensor.shape
        other_shape = other_tensor.shape
        if len(weight_shape) < len(other_shape):
            return False
        if len(weight_shape) == len(other_shape) + 1:
            # weight shape is [o, i, *], other_shape is [o, 1...]
            for i in reversed(range(len(other_shape))):
                if i == 0 and weight_shape[0] == other_shape[i]:
                    continue
                if other_shape[i] != 1:
                    return False
        else:
            # weight shape is [o, i, *], other_shape is [1, o, 1...]
            for i in reversed(range(len(other_shape))):
                if i == 1 and weight_shape[0] == other_shape[i]:
                    continue
                if other_shape[i] != 1:
                    return False
        return True

    def _check_conv_and_broadcast_op(conv_node, other):
        # According to checkConvAndBroadcastingOpPreConditions of frozen_conv_folding.cpp.
        # conv.weight
        if conv_node.args[1].op != "get_attr":
            return False
        # conv.bias
        if conv_node.args[2] is not None and conv_node.args[2].op != "get_attr":
            return False
        if not isinstance(other, (int, float)) and other.op != "get_attr":
            return False

        if len(conv_node.args[1].users) != 1:
            return False

        weight_meta_value = conv_node.args[1].meta.get("val")
        if weight_meta_value is None:
            return False
        # Avoid fusing an op that causes type promotion;
        # restricting to float avoids int/float difficulties with the scalar overload
        if not weight_meta_value.is_floating_point():
            return False
        if isinstance(other, torch.fx.Node) and other.op == "get_attr":
            other_meta_value = other.meta.get("val")
            if not other_meta_value.is_floating_point():  # type: ignore[union-attr]
                return False
            if (
                torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype)  # type: ignore[union-attr]
                != weight_meta_value.dtype
            ):
                if not conv_node.meta.get("_allow_conv_mixed_dtype_folding", False):
                    return False

                if (
                    other_meta_value.dtype != torch.float  # type: ignore[union-attr]
                    and weight_meta_value.dtype not in (torch.float16, torch.bfloat16)
                ):
                    return False

            if not _op_not_broadcasting_with_conv(weight_meta_value, other_meta_value):
                return False
        else:
            # TODO: support scalar case
            return False

        return True

    def _is_foldable_pattern(match):
        binary_node = match.output_node()
        computation_node = binary_node.args[0]
        other = binary_node.args[1]
        if binary_node.args[0].target not in _computation_ops:
            computation_node = binary_node.args[1]
            other = binary_node.args[0]
        if computation_node.target == aten.convolution.default:
            return _check_conv_and_broadcast_op(computation_node, other)

        return False

    def resize_scalar_or_tensor_to_shape(graph, other, shape):
        # TODO: support scalar case
        if other.meta.get("val").numel() == 1:
            # expand errors if the shape input has fewer dims than the tensor input
            res = graph.create_node(
                "call_function",
                aten.reshape.default,
                (other, (1,)),
            )
            res = graph.create_node(
                "call_function",
                aten.expand.default,
                (res, shape),
            )
        else:
            res = graph.create_node(
                "call_function",
                aten.reshape.default,
                (other, shape),
            )
        return res
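
    # Illustrative shapes (a sketch, not executed): a 1-element constant headed
    # for a weight-broadcast shape (o, 1, 1, 1) is reshaped to (1,) and then
    # expanded to (o, 1, 1, 1); a per-channel constant of shape [1, o, 1, 1] is
    # reshaped directly to (o, 1, 1, 1).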

    def _create_new_conv_node(graph, conv_node, binary_node, other):
        assert conv_node.target == aten.convolution.default
        conv_args = list(conv_node.args)
        weight_meta_value = conv_node.args[1].meta.get("val")
        bias = conv_args[2]
        if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]:
            # add/sub fold into the bias only: conv(x, w, b) + c == conv(x, w, b + c)
            other_reshape = resize_scalar_or_tensor_to_shape(
                graph, other, (weight_meta_value.size(0),)
            )
            new_bias = graph.create_node(
                "call_function",
                binary_node.target,
                (0 if bias is None else bias, other_reshape),
            )
            conv_args[2] = new_bias
        else:
            # mul/div scale both the weight and the bias:
            # conv(x, w, b) * c == conv(x, w * c, b * c)
            assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor]
            weight_broadcast_shape = [1 for _ in range(len(weight_meta_value.shape))]
            weight_broadcast_shape[0] = weight_meta_value.size(0)
            other_reshape1 = resize_scalar_or_tensor_to_shape(
                graph, other, tuple(weight_broadcast_shape)
            )
            new_weight = graph.create_node(
                "call_function", binary_node.target, (conv_args[1], other_reshape1)
            )
            new_weight.meta.update(conv_args[1].meta)
            conv_args[1] = new_weight
            if bias is not None:
                other_reshape = resize_scalar_or_tensor_to_shape(
                    graph, other, (weight_meta_value.size(0),)
                )
                new_bias = graph.create_node(
                    "call_function", binary_node.target, (bias, other_reshape)
                )
                new_bias.meta.update(bias.meta)
                conv_args[2] = new_bias
        return graph.create_node("call_function", conv_node.target, tuple(conv_args))

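    # Register one pattern per (computation op, binary op) pair; with the single
    # computation op and four binary ops above, this registers four patterns, e.g.
    #   CallFunction(aten.add.Tensor,
    #                CallFunction(aten.convolution.default, ...),
    #                KeywordArg("other"))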
    for _computation_call, binary_op in itertools.product(
        _computation_calls, _binary_ops
    ):

        @register_binary_folding_pattern(
            CallFunction(binary_op, _computation_call, KeywordArg("other")),
            extra_check=_is_foldable_pattern,
        )
        def folded_op(match, *args, **kwargs):
            counters["inductor"]["binary_folding"] += 1
            other = kwargs.get("other")
            binary_node = match.output_node()
            computation_node = (
                binary_node.args[0]
                if binary_node.args[0].target in _computation_ops
                else binary_node.args[1]
            )
            graph = match.graph
            with graph.inserting_before(binary_node):
                # TODO: support linear?
                assert computation_node.target == aten.convolution.default
                new_computation_node = _create_new_conv_node(
                    graph, computation_node, binary_node, other
                )
                binary_node.replace_all_uses_with(new_computation_node)
                new_computation_node.meta.update(computation_node.meta)
                graph.erase_node(binary_node)
                graph.erase_node(computation_node)

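
# End-to-end sketch of the rewrite performed above (illustrative graph, where
# `scale` is a get_attr constant of shape [1, o, 1, 1]):
#
#   before:
#     conv = aten.convolution.default(x, w, b, ...)
#     out  = aten.mul.Tensor(conv, scale)
#   after:
#     w2   = aten.mul.Tensor(w, reshape(scale, (o, 1, 1, 1)))
#     b2   = aten.mul.Tensor(b, reshape(scale, (o,)))
#     out  = aten.convolution.default(x, w2, b2, ...)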