# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from pathlib import Path
from typing import Callable, cast, Optional

import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer

from executorch.backends.cadence.aot.replace_ops import ReplaceSafeSoftmaxWithSoftmax
from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
from executorch.backends.transforms.decompose_sdpa import (
    DecomposeScaledDotProductAttention,
)
from executorch.devtools import generate_etrecord
from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    ExecutorchProgramManager,
    to_edge,
)
from executorch.exir.pass_base import PassResult
from torch.ao.quantization.pt2e.export_utils import model_is_exported
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from torch.export import export
from torch.export.exported_program import ExportedProgram

from .passes import get_cadence_passes

from .utils import print_ops_info


# Note: this is not meant as a primary API, since it can create inconsistencies
# if the quantizer used here differs from the one later used to fuse. It is,
# however, useful for unit tests that need the converted model separately from
# the fused model in order to obtain reference numerics.
# If that does not apply, please use quantize_pt2 instead.
def convert_pt2(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    quantizer: CadenceQuantizer,
) -> torch.fx.GraphModule:
    """
    Prepare and convert a model using the given quantizer.
    The quantizer must be supplied and must be the same one that will later be
    used to fuse the model, if applicable. If you do not need that behavior,
    please use quantize_pt2 instead, which will instantiate a default
    quantizer for you if needed.
    Returns a GraphModule with the converted model.
    """

    # Export with dynamo
    model_gm = torch.export.export_for_training(model, inputs).module()

    if model_gm_has_SDPA(model_gm):  # pyre-fixme[6]
        # Decompose SDPA
        DecomposeScaledDotProductAttention(False)(model_gm)  # pyre-fixme[6]

        # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
        # for details).
        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)  # pyre-fixme[6]
        assert result is not None
        model_gm = result.graph_module

    # Prepare
    prepared_model = prepare_pt2e(model_gm, quantizer)

    # Calibrate
    prepared_model(*inputs)

    # Convert
    converted_model = convert_pt2e(prepared_model)

    return converted_model

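# Illustrative sketch (not part of the public API): use convert_pt2 on its own
# to obtain reference numerics from the converted, not-yet-fused model. The toy
# model, input shape, and helper name below are assumptions for demonstration;
# the same quantizer instance must later be passed to fuse_pt2 (defined below)
# if fusion is desired.
def _example_convert_for_reference_numerics() -> torch.Tensor:
    model = torch.nn.Linear(16, 16)  # hypothetical toy model
    inputs = (torch.randn(1, 16),)
    quantizer = CadenceQuantizer()
    converted_gm = convert_pt2(model, inputs, quantizer)
    # Run the converted model to get reference outputs, with explicit
    # quantize/dequantize ops still present in the graph.
    return converted_gm(*inputs)
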
# Note: this is not meant as a primary API, since it can create inconsistencies
# if the quantizer used here differs from the one used to convert. It is,
# however, useful for unit tests that need the converted model separately from
# the fused model in order to obtain reference numerics.
# If that does not apply, please use quantize_pt2 instead.
def fuse_pt2(
    converted_graph_module: torch.fx.GraphModule,
    quantizer: CadenceQuantizer,
) -> torch.fx.GraphModule:
    """
    Fuse a converted graph module using the given quantizer.
    The quantizer must be the same one that was used to convert the model.
    If you do not need that behavior, please use quantize_pt2 instead,
    which will instantiate a default quantizer for you if needed.
    Returns a GraphModule with the fused model.
    """
    # Get patterns and apply fusion of dq -> op -> q to qop
    # pyre-ignore[16]: no attribute
    patterns = [q.pattern for q in quantizer.quantizers]
    QuantFusion(patterns)(converted_graph_module)

    return converted_graph_module


# Note: this is the one-liner API to quantize and fuse a model.
def quantize_pt2(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    quantizer: Optional[CadenceQuantizer] = None,
) -> torch.fx.GraphModule:
    """
    Prepare, convert and fuse the model using the given quantizer.
    If no quantizer is provided, a default CadenceQuantizer is used.
    Returns a GraphModule with the quantized model.
    """
    # Quantizer
    if not quantizer:
        quantizer = CadenceQuantizer()

    # Get converted graph module
    converted_gm = convert_pt2(model, inputs, quantizer)

    # Get fused model
    fused_gm = fuse_pt2(converted_gm, quantizer)

    return fused_gm


# Export the model and lower it to an ExportedProgram (in aten IR)
def export_program(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
) -> ExportedProgram:
    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

    # We don't support training mode. Put the model in inference mode by
    # calling model.eval() or an equivalent call for quantized models.
    # GraphModules cannot call eval(), so we handle them separately below.
    if not isinstance(model, torch.fx.GraphModule):
        if hasattr(model, "eval"):
            model.eval()
    else:
        # If the model is quantized, call the suggested torch.ao.quantization API,
        # which only handles dropout and batchnorm.
        if model_is_quantized(model):
            torch.ao.quantization.move_exported_model_to_eval(model)
        else:
            # If we get a GraphModule that is _not_ quantized, it should already
            # have been exported.
            assert model_is_exported(model), "model should be from an ExportedProgram"

    # Prevent mkldnn decompositions
    torch._C._set_mkldnn_enabled(False)

    # Capture the model and return the ExportedProgram.
    expo_program = export(model, inputs)

    if dump_graphs:
        logging.info("Exported graph:")
        expo_program.graph_module.graph.print_tabular()

    return expo_program


# Export the model and lower it to an EdgeProgramManager (in edge IR).
def export_to_edge(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
) -> EdgeProgramManager:
    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

    # Export the model into an ExportedProgram.
    expo_program = export_program(model, inputs, dump_graphs=dump_graphs)

    # Call to_edge to convert the graph to edge IR.
    # Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
    edge_prog_manager = to_edge(
        expo_program,
        compile_config=EdgeCompileConfig(
            _check_ir_validity=False, _skip_dim_order=True
        ),
    )

    if dump_graphs:
        logging.info("Edge graph:")
        edge_prog_manager.exported_program().graph_module.graph.print_tabular()

    return edge_prog_manager

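# Illustrative sketch (not part of the public API): quantize a toy model with
# the one-liner quantize_pt2 and lower it to edge IR. The model, input shape,
# and helper name below are assumptions for demonstration only.
def _example_quantize_and_lower_to_edge() -> EdgeProgramManager:
    model = torch.nn.Linear(16, 16)  # hypothetical toy model
    inputs = (torch.randn(1, 16),)
    # quantize_pt2 instantiates a default CadenceQuantizer when none is given.
    quantized_gm = quantize_pt2(model, inputs)
    # dump_graphs=True logs the exported and edge graphs for inspection.
    return export_to_edge(quantized_gm, inputs, dump_graphs=True)
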
def export_to_cadence(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
    output_dir: Optional[str] = None,
    opt_level: int = 1,
) -> EdgeProgramManager:
    edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
    cadence_passes = get_cadence_passes(opt_level)

    # Run a couple of required passes for quant/dequant ops
    cadence_prog_manager = edge_prog_manager.transform(
        cast(
            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
        )
    )
    return cadence_prog_manager


def quantize_and_export_to_cadence(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
    opt_level: int = 1,
) -> EdgeProgramManager:
    quantized_model = quantize_pt2(model, inputs)

    return export_to_cadence(
        quantized_model,
        inputs,
        opt_level=opt_level,
        dump_graphs=dump_graphs,
    )


# Export the model, lower it to an EdgeProgramManager (in edge IR), apply
# passes specific to Cadence DSP execution, and lower the result to an
# ExecutorchProgramManager. Prints op-level information comparing the edge
# and Cadence graphs, and optionally generates an ETRecord.
def export_to_executorch_gen_etrecord(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
    output_dir: Optional[str] = None,
    opt_level: int = 1,
) -> ExecutorchProgramManager:
    edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
    cadence_passes = get_cadence_passes(opt_level)

    # Run a couple of required passes for quant/dequant ops
    cadence_prog_manager = edge_prog_manager.transform(
        cast(
            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
        )
    )

    # Print some information to the terminal
    print_ops_info(
        edge_prog_manager.exported_program().graph_module,
        cadence_prog_manager.exported_program().graph_module,
    )

    # Get the executorch program after Cadence-specific passes
    exec_prog: ExecutorchProgramManager = cadence_prog_manager.to_executorch()
    if output_dir:
        _gen_etrecord(edge_prog_manager, exec_prog, Path(output_dir))
    else:
        logging.warning("No output directory provided, skipping ETRecord generation")

    return exec_prog


def _gen_etrecord(
    edge_program: EdgeProgramManager,
    et_program: ExecutorchProgramManager,
    output_dir: Path,
) -> None:
    etrec_path = output_dir / "etrecord.bin"
    try:
        generate_etrecord(
            et_record=etrec_path,
            edge_dialect_program=edge_program,
            executorch_program=et_program,
        )
        logging.info(f"Generated ETRecord at {etrec_path}")
    except Exception:
        # Any errors here shouldn't block the rest of the flow
        logging.exception("Encountered exception while generating ETRecord")

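# Illustrative end-to-end sketch (not part of the public API): quantize a toy
# model and lower it all the way to an ExecuTorch program, writing an ETRecord
# alongside it. The model, input shape, and output directory below are
# assumptions for demonstration only.
if __name__ == "__main__":
    example_model = torch.nn.Linear(16, 16)  # hypothetical toy model
    example_inputs = (torch.randn(1, 16),)
    # One-liner quantization (prepare, calibrate, convert, fuse).
    example_quantized_gm = quantize_pt2(example_model, example_inputs)
    # Lower to edge IR, run Cadence passes, and emit the ExecuTorch program;
    # the ETRecord is generated because an output_dir is provided.
    example_exec_prog = export_to_executorch_gen_etrecord(
        example_quantized_gm,
        example_inputs,
        output_dir="/tmp/cadence_example",  # hypothetical output directory
    )
    logging.info("ExecuTorch program generated: %s", example_exec_prog)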