# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

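"""Quantizer fusion pass for the Cadence backend.

Replaces dequantize -> op -> quantize chains that match the quantizer
patterns below with single fused quantized ops, materializing the
requantization parameters (multipliers, shifts, zero points) in the graph.
"""
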
from typing import Any, Dict, List, Tuple

import torch
from executorch.backends.cadence.aot.quantizer.patterns import (
    AddmmPattern,
    BmmPattern,
    Conv1dPattern,
    Conv2dPattern,
    LayerNormPattern,
    LinearPattern,
    MatmulPattern,
    ReluPattern0,
    ReluPattern1,
)
from executorch.backends.cadence.aot.quantizer.utils import (
    create_zero_bias_int32,
    find_sequential_partitions_aten,
    get_conv_args,
    quantize_tensor_multiplier,
)
from executorch.exir.pass_base import ExportPass
from torch import fx
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.fuser_utils import legalize_graph


# Use this to avoid pyre errors
# pyre-ignore[33]: `ArgsType` cannot alias to `Any`.
ArgsType = Any

# Group the relu patterns, since the fusion logic below treats them identically
ReluPatterns = (ReluPattern0, ReluPattern1)


# Helper function to get the args and kwargs for the linear replacement op
def get_args_and_kwargs_linear(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """
    Returns the args and kwargs for the linear replacement op.
    """
    weight_scale = dequants_weights[0].args[1]
    # pyre-fixme[58]: Unsupported operand types
    bias_scale = dequants_inputs[0].args[1] * weight_scale
    requantize_scale = bias_scale / quant_node.args[1]
    requantize_scale_t = torch.tensor([requantize_scale])

    (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)

    # If bias is not available, create a zero bias tensor with the shape of weight[0]
    if not bias_inputs:
        weight_node = dequants_weights[0].args[0]
        assert isinstance(weight_node, fx.Node)
        bias = create_zero_bias_int32(graph_module, weight_node, bias_scale)
    else:
        bias = bias_inputs[0]

    # Create single element tensors for weight_zero_point, out_multiplier, out_shift.
    # Note that the function expects int32_t, whereas it would default to int64_t,
    # so we explicitly request that dtype.
    weight_zero_point_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], dequants_weights[0].args[2]),
        {"dtype": torch.int32},
    )
    out_multiplier_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], out_multiplier[0].item()),
        {"dtype": torch.int32},
    )
    out_shift_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], out_shift[0].item()),
        {"dtype": torch.int32},
    )

    args = tuple(inputs_inputs + weights_inputs + [bias])
    kwargs = {
        "src_zero_point": dequants_inputs[0].args[2],
        "weight_zero_point": weight_zero_point_,
        "out_multiplier": out_multiplier_,
        "out_shift": out_shift_,
        "out_zero_point": quant_node.args[2],
        "offset": None,
    }
    return args, kwargs


# Helper function to get the args and kwargs for the layer norm replacement op
def get_args_and_kwargs_layer_norm(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    other_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """
    Returns the args and kwargs for the layer norm replacement op.
    """
    # Check if the input is per-channel quantized
    # TODO(matthiascremon): add proper support and testing for per-channel quantization
    assert isinstance(dequants_inputs[0].args[1], float) and isinstance(
        dequants_inputs[0].args[2], int
    ), "per-channel quantization is not supported for layer norm; both scale and zero_point must be scalars"

    # Make the scale and zero_point tensors
    scale_tensor = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        (
            [1],
            dequants_inputs[0].args[1],
        ),
        {"dtype": torch.float32},
    )
    zero_point_tensor = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        (
            [1],
            dequants_inputs[0].args[2],
        ),
        {"dtype": torch.int32},
    )

    weight = other_inputs[1] if len(other_inputs) > 1 else None

    # If no weight was given, default to an all-ones weight tensor
    if weight is None:
        weight = graph_module.graph.call_function(
            torch.ops.aten.full.default,
            (
                other_inputs[0],
                1,
            ),
            {"dtype": torch.float32},
        )

    bias = other_inputs[2] if len(other_inputs) > 2 else None

    # If no bias was given, default to an all-zeros bias tensor
    if bias is None:
        bias = graph_module.graph.call_function(
            torch.ops.aten.full.default,
            (
                other_inputs[0],
                0,
            ),
            {"dtype": torch.float32},
        )

    # Make the args and kwargs for the replacement op
    args = tuple(inputs_inputs + [scale_tensor] + [zero_point_tensor])
    kwargs = {
        "normalized_shape": other_inputs[0],
        "weight": weight,
        "bias": bias,
        "eps": 1e-05,
        "output_scale": quant_node.args[1],
        "output_zero_point": quant_node.args[2],
    }
    return args, kwargs


def get_args_and_kwargs_matmul(
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
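    """
    Returns the args and kwargs for the matmul replacement op.
    """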
    requantize_scale = (
        # pyre-ignore[58]: Unsupported operand
        dequants_inputs[0].args[1]
        * dequants_inputs[1].args[1]
    ) / quant_node.args[1]
    requantize_scale_t = torch.tensor([requantize_scale])

    (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)

    args = (
        inputs_inputs[0],
        dequants_inputs[0].args[2],
        inputs_inputs[1],
        dequants_inputs[1].args[2],
        None,
    )

    kwargs = {
        "out_multiplier": out_multiplier[0].item(),
        "out_shift": out_shift[0].item(),
        "out_zero_point": quant_node.args[2],
        "transposed": False,
    }
    return args, kwargs


def get_args_and_kwargs_conv(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    quant_node: fx.Node,
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
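    """
    Returns the args and kwargs for the conv replacement op.
    """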
    weight_scale = dequants_weights[0].args[1]
    weight_zero_point = dequants_weights[0].args[2]
    # pyre-fixme[58]: Unsupported operand types
    bias_scale = dequants_inputs[0].args[1] * weight_scale
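    # Read the optional conv parameters from the op node, falling back to the
    # aten defaults when an arg is not present.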
    stride = [1, 1] if len(op_node.args) < 4 else get_conv_args(op_node.args[3], 1)
    padding = [0, 0] if len(op_node.args) < 5 else get_conv_args(op_node.args[4], 0)
    dilation = [1, 1] if len(op_node.args) < 6 else get_conv_args(op_node.args[5], 1)
    groups = 1 if len(op_node.args) < 7 else op_node.args[6]

    # If bias is not available, create a zero bias tensor with the shape of weight[0]
    if not bias_inputs:
        weight_node = dequants_weights[0].args[0]
        assert isinstance(weight_node, fx.Node)
        bias = create_zero_bias_int32(graph_module, weight_node, bias_scale)
    else:
        bias = bias_inputs[0]

    # Compute the out multiplier and out shift used by the quantized replacement
    # op. We compute them a priori for simplicity, but may revisit the decision.
    requantize_scale = bias_scale / quant_node.args[1]
    requantize_scale_t = torch.tensor([requantize_scale])

    (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)

    out_multiplier_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], out_multiplier[0].item()),
        {"dtype": torch.int32},
    )
    out_shift_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], out_shift[0].item()),
        {"dtype": torch.int32},
    )

    # Create a single element tensor for the weight zero point
    weight_zero_point_tensor = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], weight_zero_point),
        {"dtype": torch.int32},
    )

    # Create a single element tensor for the bias scale
    bias_scale_tensor = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], bias_scale),
        {"dtype": torch.float32},
    )

    # Make the args and kwargs for the replacement op
    args = tuple(inputs_inputs + weights_inputs + [bias])
    kwargs = {
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "groups": groups,
        "input_zero_point": dequants_inputs[0].args[2],
        "weight_zero_point": weight_zero_point_tensor,
        "bias_scale": bias_scale_tensor,
        "out_scale": quant_node.args[1],
        "out_zero_point": quant_node.args[2],
        "out_multiplier": out_multiplier_,
        "out_shift": out_shift_,
        "channel_last": False,
    }
    return args, kwargs


def get_args_and_kwargs_relu(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
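    """
    Returns the args and kwargs for the relu replacement op.
    """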
    input_scale = dequants_inputs[0].args[1]
    # pyre-fixme[58]: Unsupported operand types
    requantize_scale = input_scale / quant_node.args[1]
    requantize_scale_t = torch.tensor([requantize_scale])

    (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)

    # Make the args and kwargs for the replacement op
    args = tuple(inputs_inputs)

    X_zero_point = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], dequants_inputs[0].args[2]),
        {"dtype": torch.int32},
    )
    out_multiplier_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], out_multiplier[0].item()),
        {"dtype": torch.int32},
    )
    out_shift_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], out_shift[0].item()),
        {"dtype": torch.int32},
    )

    kwargs = {
        "X_zero_point": X_zero_point,
        "out_zero_point": quant_node.args[2],
        "out_multiplier": out_multiplier_,
        "out_shift": out_shift_,
    }
    return args, kwargs


class QuantFusion(ExportPass):
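    """
    Fuses dequantize -> op -> quantize chains that match the given quantizer
    patterns into single quantized replacement ops.

    Minimal usage sketch (assuming `patterns` is a list of pattern instances
    such as the ones imported above):

        QuantFusion(patterns).call(graph_module)
    """
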
    # pyre-ignore[2]: Parameter `patterns` has no type specified
    def __init__(self, patterns) -> None:
        super().__init__()
        # pyre-ignore[4]: Parameter `patterns` of class `QuantFusion` has no type specified
        self.patterns = patterns

    def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
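        """
        For each pattern, finds the matching partitions in the graph and
        replaces each one with its fused quantized replacement op.
        """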
        for pattern in self.patterns:
            fused_partitions = find_sequential_partitions_aten(
                graph_module,
                pattern.partition_types(),
            )
            for fused_partition in fused_partitions:
                anchors = pattern.get_anchors(graph_module, fused_partition)
                if not anchors:
                    continue
                if any(self.is_fused(p.nodes) for p in fused_partition):
                    continue

                for p in fused_partition:
                    self.mark_fused(p.nodes)

                dequants_inputs = []
                for node, idx in anchors.inputs:
                    if (
                        node.args[idx].target
                        == torch.ops.quantized_decomposed.dequantize_per_tensor.default
                    ):
                        dequants_inputs.append(node.args[idx])
                dequants_weights = []
                for node, idx in anchors.weights:
                    if (
                        node.args[idx].target
                        == torch.ops.quantized_decomposed.dequantize_per_tensor.default
                    ):
                        dequants_weights.append(node.args[idx])
                dequants_biases = []
                for node, idx, *_spec in anchors.biases:
                    if (
                        node.args[idx].target
                        == torch.ops.quantized_decomposed.dequantize_per_tensor.default
                    ):
                        dequants_biases.append(node.args[idx])

                inputs_inputs = [node.args[0] for node in dequants_inputs]
                weights_inputs = [node.args[0] for node in dequants_weights]
                bias_inputs = [node.args[0] for node in dequants_biases]
                other_inputs = [node.args[idx] for node, idx in anchors.others]

                # The op node is the first node of the first tuple in anchors.output
                op_node = anchors.output[0][0]

                assert len(op_node.users) == 1
                quant_node = list(op_node.users.keys())[0]

                with graph_module.graph.inserting_after(op_node):
                    args = tuple(
                        inputs_inputs + weights_inputs + other_inputs + bias_inputs
                    )
                    kwargs = {}
                    if isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
                        args, kwargs = get_args_and_kwargs_conv(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            quant_node,
                            op_node,
                        )
                    elif isinstance(pattern, LinearPattern):
                        args, kwargs = get_args_and_kwargs_linear(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, LayerNormPattern):
                        args, kwargs = get_args_and_kwargs_layer_norm(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            other_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, (BmmPattern, MatmulPattern)):
                        args, kwargs = get_args_and_kwargs_matmul(
                            inputs_inputs,
                            dequants_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, AddmmPattern):
                        # Transpose the weight tensor
                        transposed_weights = graph_module.graph.call_function(
                            torch.ops.aten.transpose.int,
                            (weights_inputs[0], 0, 1),
                        )
                        # Call linear with the transposed weight
                        args, kwargs = get_args_and_kwargs_linear(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            [transposed_weights],
                            dequants_weights,
                            bias_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, ReluPatterns):
                        args, kwargs = get_args_and_kwargs_relu(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            quant_node,
                        )
                    fused = graph_module.graph.call_function(
                        pattern.replacement_op(),
                        args,
                        kwargs,
                    )
                    fused.meta = quant_node.meta
                    quant_node.replace_all_uses_with(fused)

            legalize_graph(graph_module)
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
        return PassResult(graph_module, True)

    @classmethod
    # pyre-ignore[2]: Parameter `nodes` has no type specified
    def is_fused(cls, nodes) -> bool:
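        """Returns True if any of the given nodes has already been fused."""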
        return any(cls.__qualname__ in n.meta for n in nodes)

    @classmethod
    # pyre-ignore[2]: Parameter `nodes` has no type specified
    def mark_fused(cls, nodes) -> None:
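        """Marks all of the given nodes as fused."""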
        for n in nodes:
            n.meta[cls.__qualname__] = True