# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from pathlib import Path
from typing import Callable, cast, Optional

import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer

from executorch.backends.cadence.aot.replace_ops import ReplaceSafeSoftmaxWithSoftmax
from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
from executorch.backends.transforms.decompose_sdpa import (
    DecomposeScaledDotProductAttention,
)
from executorch.devtools import generate_etrecord
from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    ExecutorchProgramManager,
    to_edge,
)
from executorch.exir.pass_base import PassResult
from torch.ao.quantization.pt2e.export_utils import model_is_exported
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from torch.export import export
from torch.export.exported_program import ExportedProgram

from .passes import get_cadence_passes

from .utils import print_ops_info


# Note: this is not meant as a primary API since it can create inconsistencies
# if the quantizer here is different from the quantizer used to fuse. It is
# however useful for unit tests to separate the converted model from the fused
# model, to be able to get reference numerics.
# If this does not apply, please use quantize_pt2 instead.
def convert_pt2(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    quantizer: CadenceQuantizer,
) -> torch.fx.GraphModule:
    """
    Prepare and convert a model using the given quantizer.
    The quantizer must be supplied and be the same as the one used to
    fuse the model later, if applicable. If you do not expect that behavior,
    please use quantize_pt2 instead, which will instantiate a
    default quantizer for you if needed.
    Returns a GraphModule with the converted model.
    """

    # Export with dynamo
    model_gm = torch.export.export_for_training(model, inputs).module()

    if model_gm_has_SDPA(model_gm):  # pyre-fixme[6]
        # Decompose SDPA
        DecomposeScaledDotProductAttention(False)(model_gm)  # pyre-fixme[6]

        # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
        # for details).
        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)  # pyre-fixme[6]
        assert result is not None
        model_gm = result.graph_module

    # Prepare
    prepared_model = prepare_pt2e(model_gm, quantizer)

    # Calibrate
    prepared_model(*inputs)

    # Convert
    converted_model = convert_pt2e(prepared_model)

    return converted_model


# Note: this is not meant as a primary API since it can create inconsistencies
# if the quantizer here is different from the quantizer used to convert. It is
# however useful for unit tests to separate the converted model from the fused
# model, to be able to get reference numerics.
# If this does not apply, please use quantize_pt2 instead.
def fuse_pt2(
    converted_graph_module: torch.fx.GraphModule,
    quantizer: CadenceQuantizer,
) -> torch.fx.GraphModule:
    """
    Fuse a converted graph module using the given quantizer.
    The quantizer must be the same as the one used to convert the model.
    If you do not expect that behavior, please use quantize_pt2 instead,
    which will instantiate a default quantizer for you if needed.
    Returns a GraphModule with the fused model.
    """
    # Get patterns and apply fusion of dq -> op -> q to qop
    # pyre-ignore[16]: no attribute
    patterns = [q.pattern for q in quantizer.quantizers]
    QuantFusion(patterns)(converted_graph_module)

    return converted_graph_module
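
# Example usage (illustrative sketch; `MyModel` and the example inputs are
# hypothetical): convert and fuse separately, e.g. in a unit test that needs
# reference numerics from the converted model before fusion.
#
#   quantizer = CadenceQuantizer()
#   example_inputs = (torch.randn(1, 16),)
#   converted_gm = convert_pt2(MyModel(), example_inputs, quantizer)
#   ref_out = converted_gm(*example_inputs)  # reference numerics, pre-fusion
#   fused_gm = fuse_pt2(converted_gm, quantizer)  # fuses in place and returns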


# Note: this is the one-liner API to quantize and fuse a model.
def quantize_pt2(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    quantizer: Optional[CadenceQuantizer] = None,
) -> torch.fx.GraphModule:
    """
    Prepare, convert and fuse the model using the given quantizer.
    Returns a GraphModule with the quantized model.
    """
    # Quantizer
    if not quantizer:
        quantizer = CadenceQuantizer()

    # Get converted graph module
    converted_gm = convert_pt2(model, inputs, quantizer)

    # Get fused model
    fused_gm = fuse_pt2(converted_gm, quantizer)

    return fused_gm
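
# Example usage (illustrative sketch; the model and inputs are hypothetical):
#
#   quantized_gm = quantize_pt2(MyModel(), (torch.randn(1, 16),))
#
# A default CadenceQuantizer is instantiated when none is passed in.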


# Export the model and lower it to an ExportedProgram (in aten IR)
def export_program(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
) -> ExportedProgram:
    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

    # We don't support training mode. Make the model inference mode by
    # calling model.eval() or an equivalent call for quantized models.
    # GraphModules cannot call eval(), so we skip them.
    if not isinstance(model, torch.fx.GraphModule):
        if hasattr(model, "eval"):
            model.eval()
    else:
        # If the model is quantized, call the suggested torch.ao.quantization API,
        # which only handles dropout and batchnorm.
        if model_is_quantized(model):
            torch.ao.quantization.move_exported_model_to_eval(model)
        else:
            # If we get a GraphModule which is _not_ quantized, then it should already
            # have been exported.
            assert model_is_exported(model), "model should be from an ExportedProgram"

    # Prevent mkldnn decompositions
    torch._C._set_mkldnn_enabled(False)

    # Capture the model and return it.
    expo_program = export(model, inputs)

    if dump_graphs:
        logging.info("Exported graph:")
        expo_program.graph_module.graph.print_tabular()

    return expo_program
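
# Example usage (illustrative sketch; the model and inputs are hypothetical):
#
#   ep = export_program(MyModel(), (torch.randn(1, 16),), dump_graphs=True)
#   # `ep` is an ExportedProgram in aten IR; dump_graphs=True logs the graph.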


# Export the model and lower it to an EdgeProgramManager (in edge IR).
def export_to_edge(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
) -> EdgeProgramManager:
    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

    # Export the model into an ExportedProgram.
    expo_program = export_program(model, inputs, dump_graphs=dump_graphs)

    # Call to_edge to convert the graph to edge IR.
    # Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
    edge_prog_manager = to_edge(
        expo_program,
        compile_config=EdgeCompileConfig(
            _check_ir_validity=False, _skip_dim_order=True
        ),
    )

    if dump_graphs:
        logging.info("Edge graph:")
        edge_prog_manager.exported_program().graph_module.graph.print_tabular()

    return edge_prog_manager
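
# Example usage (illustrative sketch; the model and inputs are hypothetical):
#
#   edge_manager = export_to_edge(MyModel(), (torch.randn(1, 16),))
#   edge_gm = edge_manager.exported_program().graph_module  # edge-IR graph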


def export_to_cadence(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
    output_dir: Optional[str] = None,
    opt_level: int = 1,
) -> EdgeProgramManager:
    edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
    cadence_passes = get_cadence_passes(opt_level)

    # Run a couple of required passes for quant/dequant ops
    cadence_prog_manager = edge_prog_manager.transform(
        cast(
            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
        )
    )
    return cadence_prog_manager


def quantize_and_export_to_cadence(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
    opt_level: int = 1,
) -> EdgeProgramManager:
    quantized_model = quantize_pt2(model, inputs)

    return export_to_cadence(
        quantized_model,
        inputs,
        opt_level=opt_level,
        dump_graphs=dump_graphs,
    )
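
# Example usage (illustrative sketch; the model and inputs are hypothetical):
#
#   cadence_edge_manager = quantize_and_export_to_cadence(
#       MyModel(), (torch.randn(1, 16),), opt_level=1
#   )
#   # The model is quantized with the default CadenceQuantizer, lowered to
#   # edge IR, and run through the Cadence-specific passes.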


# Export the model and lower it to an EdgeProgramManager (in edge IR), and
# apply passes specific to Cadence DSP execution. Print the op-level
# differences between the two graphs, then lower to an ExecutorchProgramManager
# and return it (generating an ETRecord if an output directory is provided).
def export_to_executorch_gen_etrecord(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
    output_dir: Optional[str] = None,
    opt_level: int = 1,
) -> ExecutorchProgramManager:
    edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
    cadence_passes = get_cadence_passes(opt_level)

    # Run a couple of required passes for quant/dequant ops
    cadence_prog_manager = edge_prog_manager.transform(
        cast(
            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
        )
    )

    # Print some information to terminal
    print_ops_info(
        edge_prog_manager.exported_program().graph_module,
        cadence_prog_manager.exported_program().graph_module,
    )

    # Get executorch program after Cadence specific passes
    exec_prog: ExecutorchProgramManager = cadence_prog_manager.to_executorch()
    if output_dir:
        _gen_etrecord(edge_prog_manager, exec_prog, Path(output_dir))
    else:
        logging.warning("No output directory provided, skipping ETRecord generation")

    return exec_prog
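
# Example usage (illustrative sketch; the model, inputs, and output path are
# hypothetical):
#
#   exec_prog = export_to_executorch_gen_etrecord(
#       MyModel(), (torch.randn(1, 16),), output_dir="/tmp/cadence_out"
#   )
#   with open("/tmp/cadence_out/model.pte", "wb") as f:
#       f.write(exec_prog.buffer)  # serialized ExecuTorch program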


def _gen_etrecord(
    edge_program: EdgeProgramManager,
    et_program: ExecutorchProgramManager,
    output_dir: Path,
) -> None:
    etrec_path = output_dir / "etrecord.bin"
    try:
        generate_etrecord(
            et_record=etrec_path,
            edge_dialect_program=edge_program,
            executorch_program=et_program,
        )
        logging.info(f"Generated ETRecord at {etrec_path}")
    except Exception:
        # Any errors here shouldn't block the rest of the flow
        logging.exception("Encountered exception while generating ETRecord")