# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging
import random
import sys

from abc import ABC, abstractmethod
from collections import Counter, OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch
from executorch.backends.xnnpack._passes import XNNPACKPassManager
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    ExecutorchBackendConfig,
    ExecutorchProgramManager,
    to_edge,
    to_edge_transform_and_lower,
)
from executorch.exir.backend.backend_api import validation_disabled
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.print_program import pretty_print, print_program
from torch.export import export_for_training

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

try:
    from executorch.extension.pybindings.portable_lib import (  # @manual
        _load_for_executorch_from_buffer,
    )
except ImportError as e:
    logger.warning(f"{e=}")

from executorch.exir.program._program import _transform
from torch._export.pass_base import PassType
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
from torch.export import export, ExportedProgram
from torch.testing import FileCheck
from torch.utils._pytree import tree_flatten


class Stage(ABC):
    """
    Interface for a Stage in the PT2.0 lowering pipeline.
    """

    @abstractmethod
    def run(self, artifact, inputs):
        """
        Executes this stage, generating the 'artifact' consumed by later stages.
        """
        pass

    @property
    @abstractmethod
    def artifact(self):
        """
        Returns the artifact generated by this stage, to be used by the next
        stage in the pipeline.
        """
        pass

    @property
    @abstractmethod
    def graph_module(self):
        """
        Returns the artifact's graph module for this stage.
        """
        pass

    def run_artifact(self, inputs):
        """
        Returns the output of calling the artifact generated by this stage
        with the given inputs.
        """
        if isinstance(self.artifact, ExportedProgram):
            return self.artifact(*inputs)
        else:
            return self.artifact.exported_program().module()(*inputs)

    # Debug tools for stages.
    def artifact_str(self):
        """
        Returns a printable string representation of this stage's artifact.
        """
        if isinstance(self.artifact, EdgeProgramManager):
            return self.artifact.exported_program()
        return self.artifact

    def stage_banner(self):
        """
        Returns the banner string for this stage.
        """
        return "#" * 36 + " " + str(self.__class__.__name__) + " " + "#" * 36 + "\n"

    def dump_artifact(self, path_to_dump: Optional[str]):
        """
        Dumps the string-printable artifact to path_to_dump. If path_to_dump
        is None, the artifact is printed to the terminal instead.
        """
        if path_to_dump:
            with open(path_to_dump, "a") as fp:
                fp.write(str(self.stage_banner() + "\n"))
                fp.write(str(self.artifact_str()))
        else:
            print(self.stage_banner() + "\n")
            print(self.artifact_str())


_stages_: Dict[str, Stage] = {}


def register_stage(stage: Stage):
    """
    Register a Stage to be used in the Tester.
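
    Example (illustrative sketch; the stage name `IdentityStage` is
    hypothetical and not part of this module):

        @register_stage
        class IdentityStage(Stage):
            ...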
""" assert isinstance(stage, type) name = stage.__qualname__ if name in _stages_: raise RuntimeError(f"Duplicate stage in Tester, {name}") _stages_[name] = stage return stage @register_stage class Quantize(Stage): def __init__( self, quantizer: Optional[Quantizer] = None, quantization_config: Optional[QuantizationConfig] = None, calibrate: bool = True, ): self.quantizer = quantizer or XNNPACKQuantizer() self.quantization_config = ( quantization_config or get_symmetric_quantization_config() ) self.calibrate = calibrate self.quantizer.set_global(self.quantization_config) self.converted_graph = None def run( self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]] ) -> None: assert inputs is not None captured_graph = export_for_training(artifact, inputs).module() assert isinstance(captured_graph, torch.fx.GraphModule) prepared = prepare_pt2e(captured_graph, self.quantizer) if self.calibrate: # Calibrate prepared model to provide data to quantization observers. prepared(*inputs) converted = convert_pt2e(prepared) self.converted_graph = converted @property def artifact(self) -> torch.fx.GraphModule: return self.converted_graph @property def graph_module(self) -> str: return self.converted_graph def run_artifact(self, inputs): return self.converted_graph.forward(*inputs) @register_stage class Export(Stage): def __init__(self, dynamic_shapes: Optional[Tuple[Any]] = None): self.exported_program = None self.dynamic_shapes = dynamic_shapes def run( self, artifact: torch.nn.Module, inputs: Tuple[torch.Tensor], ) -> None: self.exported_program = export( artifact, inputs, dynamic_shapes=self.dynamic_shapes ) @property def artifact(self) -> ExportedProgram: return self.exported_program @property def graph_module(self) -> str: return self.exported_program.graph_module @register_stage class ToEdge(Stage): def __init__(self, edge_compile_config: Optional[EdgeCompileConfig] = None): self.edge_compile_conf = ( edge_compile_config or get_xnnpack_edge_compile_config() ) self.edge_dialect_program = None def run(self, artifact: ExportedProgram, inputs=None) -> None: self.edge_dialect_program = to_edge( artifact, compile_config=self.edge_compile_conf ) @property def artifact(self) -> EdgeProgramManager: return self.edge_dialect_program @property def graph_module(self) -> str: return self.edge_dialect_program.exported_program().graph_module @register_stage class RunPasses(Stage): def __init__( self, pass_list: Optional[List[Type[PassType]]] = None, pass_functions: Optional[List[Callable]] = None, ): self.pass_list = pass_list self.pass_functions = pass_functions self.edge_or_aten_program = None def run( self, artifact: Union[EdgeProgramManager, ExportedProgram], inputs=None ) -> None: if isinstance(artifact, EdgeProgramManager): self.edge_or_aten_program = artifact if self.pass_list: pass_manager = XNNPACKPassManager( artifact.exported_program(), self.pass_list ) self.edge_or_aten_program._edge_programs["forward"] = ( pass_manager.transform() ) if self.pass_functions: assert isinstance(self.pass_functions, list) for pass_function in self.pass_functions: self.edge_or_aten_program._edge_programs["forward"] = pass_function( self.edge_or_aten_program.exported_program() ) else: transformed_ep = artifact if self.pass_list: assert isinstance(self.pass_list, list) for pass_ in self.pass_list: transformed_ep = _transform(transformed_ep, pass_()) if self.pass_functions: assert isinstance(self.pass_functions, list) for pass_function in self.pass_functions: transformed_ep = pass_function(transformed_ep) 

    @property
    def artifact(self) -> Union[EdgeProgramManager, ExportedProgram]:
        return self.edge_or_aten_program

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        if isinstance(self.edge_or_aten_program, EdgeProgramManager):
            return self.edge_or_aten_program.exported_program().graph_module
        else:
            return self.edge_or_aten_program.graph_module


@register_stage
class ToEdgeTransformAndLower(Stage):
    def __init__(
        self,
        partitioners: Optional[List[Partitioner]] = None,
        edge_compile_config: Optional[EdgeCompileConfig] = None,
    ):
        self.partitioners = partitioners or [XnnpackPartitioner()]
        self.edge_compile_conf = (
            edge_compile_config or get_xnnpack_edge_compile_config()
        )
        self.edge_dialect_program = None

    def run(self, artifact: ExportedProgram, inputs=None) -> None:
        artifact_to_run = copy.deepcopy(artifact)
        self.edge_dialect_program = to_edge_transform_and_lower(
            artifact_to_run,
            compile_config=self.edge_compile_conf,
            partitioner=self.partitioners,
        )

    @property
    def artifact(self) -> EdgeProgramManager:
        return self.edge_dialect_program

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.edge_dialect_program.exported_program().graph_module


@register_stage
class Partition(Stage):
    def __init__(self, partitioner: Optional[Partitioner] = None):
        self.partitioner = partitioner or XnnpackPartitioner()
        self.delegate_module = None

    def run(self, artifact: EdgeProgramManager, inputs=None):
        with validation_disabled():
            self.delegate_module = artifact
            self.delegate_module = self.delegate_module.to_backend(self.partitioner)

    @property
    def artifact(self) -> EdgeProgramManager:
        return self.delegate_module

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.delegate_module.exported_program().graph_module


@register_stage
class ToExecutorch(Stage):
    def __init__(
        self,
        config: Optional[ExecutorchBackendConfig] = None,
    ):
        self.config = config or ExecutorchBackendConfig(
            extract_delegate_segments=True,
            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
        )
        self.executorch_program = None

    def run(self, artifact: EdgeProgramManager, inputs=None):
        self.executorch_program = artifact.to_executorch(self.config)

    @property
    def artifact(self) -> ExecutorchProgramManager:
        return self.executorch_program

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.executorch_program.exported_program().graph_module

    def dump_artifact(self, path_to_dump: Optional[str]):
        """
        dump_artifact is overridden to dump the serialized program.
        """
        original_stdout = sys.stdout

        sys.stdout = open(path_to_dump, "a") if path_to_dump else sys.stdout
        print(self.stage_banner() + "\n")
        pretty_print(self.artifact._emitter_output.program)
        print_program(
            self.artifact._emitter_output.program,
            show_meminfo=True,
            mark_dynamic_shape_tensor=True,
        )
        sys.stdout = original_stdout


@register_stage
class Serialize(Stage):
    def __init__(self):
        self.buffer = None

    def run(self, artifact: ExecutorchProgramManager, inputs=None) -> None:
        self.buffer = artifact.buffer

    @property
    def artifact(self) -> bytes:
        return self.buffer

    @property
    def graph_module(self) -> None:
        return None

    def run_artifact(self, inputs):
        inputs_flattened, _ = tree_flatten(inputs)
        executorch_module = _load_for_executorch_from_buffer(self.buffer)
        executorch_output = copy.deepcopy(
            executorch_module.run_method("forward", tuple(inputs_flattened))
        )
        return executorch_output

    def dump_artifact(self, path_to_dump: Optional[str]):
        """
        dump_artifact is overridden to dump the serialized bytes into a .pte file.
        """
        if not path_to_dump:
            raise RuntimeError("path_to_dump file not provided")

        with open(path_to_dump, "wb") as f:
            f.write(self.artifact)
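

# A minimal end-to-end sketch of the fluent API defined by the Tester below
# (`MyModule` and `example_inputs` are stand-ins for a user's eager module and
# its sample inputs):
#
#   (
#       Tester(MyModule(), example_inputs)
#       .export()
#       .to_edge_transform_and_lower()
#       .to_executorch()
#       .serialize()
#       .run_method_and_compare_outputs()
#   )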
"wb") as f: f.write(self.artifact) class Tester: def __init__( self, module: torch.nn.Module, example_inputs: Tuple[torch.Tensor], dynamic_shapes: Optional[Tuple[Any]] = None, ): module.eval() self.original_module = module self.example_inputs = example_inputs self.dynamic_shapes = dynamic_shapes self.stages: Dict[str, Stage] = OrderedDict.fromkeys(list(_stages_.keys())) self.pipeline = { self.stage_name(Quantize): [self.stage_name(Export)], self.stage_name(Export): [ self.stage_name(RunPasses), self.stage_name(ToEdge), self.stage_name(ToEdgeTransformAndLower), ], self.stage_name(ToEdgeTransformAndLower): [ self.stage_name(RunPasses), self.stage_name(ToExecutorch), ], self.stage_name(ToEdge): [ self.stage_name(Partition), self.stage_name(RunPasses), ], self.stage_name(RunPasses): [ self.stage_name(Partition), self.stage_name(ToEdgeTransformAndLower), ], # TODO Make this Stage optional self.stage_name(Partition): [self.stage_name(ToExecutorch)], self.stage_name(ToExecutorch): [self.stage_name(Serialize)], self.stage_name(Serialize): [], } assert all( stage in self.pipeline for stage in self.stages ), "Invalid Tester internal state!" # Current stage name self.cur: str = "" # Reference output from eager mode self.reference_output = None # Quantization scale from eager mode self.quantization_scale: Optional[float] = None # Artifact output from stage self.stage_output = None def generate_random_inputs(self): # Get shapes of inputs input_shapes = [] if self.dynamic_shapes is None: for tensor_arg in self.example_inputs: assert isinstance(tensor_arg, torch.Tensor) input_shapes.append(tensor_arg.shape) else: # Random shapes depending on dynamic shape constraint dim_name_to_size = {} for arg_idx in range(len(self.example_inputs)): assert isinstance(self.example_inputs[arg_idx], torch.Tensor) ex_shape = list(self.example_inputs[arg_idx].shape) dynamic_dim_spec = self.dynamic_shapes[arg_idx] for dim_idx, dim_spec in dynamic_dim_spec.items(): assert dim_idx < len(ex_shape) if isinstance(dim_spec, torch.export.dynamic_shapes._DerivedDim): # derived dims are of the form {0: 2 * torch.export.Dim() // 2} # The root contains the min/max of the export dim and fn contains # the function to compute the derived dim. 
                        fn = dim_spec.fn
                        dim_spec = dim_spec.root
                    elif isinstance(dim_spec, torch.export.dynamic_shapes._Dim):
                        # Not a derived dim, so fn is the identity.
                        def fn(x):
                            return x

                    else:
                        raise RuntimeError(
                            f"Expected Dynamic Dims to be of type _DerivedDim or _Dim but got {type(dim_spec)}"
                        )
                    dim_name = dim_spec.__name__
                    if dim_name not in dim_name_to_size:
                        upper_bound = min(
                            dim_spec.max, 1000
                        )  # unbounded int max is too large
                        lower_bound = (
                            dim_spec.min if dim_spec.min >= 2 else 1
                        )  # 0/1 specialization means dim_spec.min can never be 1
                        dim_name_to_size[dim_name] = fn(
                            random.randint(lower_bound, upper_bound)
                        )
                    ex_shape[dim_idx] = dim_name_to_size[dim_spec.__name__]
                input_shapes.append(torch.Size(ex_shape))
        # Create random tensor inputs with the shapes given above.
        random_inputs = []
        for arg_idx in range(len(self.example_inputs)):
            random_inputs.append(
                torch.randn(input_shapes[arg_idx]).to(
                    dtype=self.example_inputs[arg_idx].dtype
                )
            )
        yield tuple(random_inputs)

    @staticmethod
    def stage_name(stage) -> str:
        t = stage if isinstance(stage, type) else type(stage)
        return t.__qualname__

    def _pre(self, stage):
        name: str = self.stage_name(stage)
        assert isinstance(name, str) and name in self.stages and not self.stages[name]

        last_artifact = self.original_module
        if self.cur:
            assert self.cur in self.pipeline, f"Invalid state: {self.cur}"
            allowed_next_stages = self.pipeline[self.cur]
            assert name in allowed_next_stages, f"Invalid next stage: {name}"
            last_artifact = self.get_artifact()
        self.cur = name
        return last_artifact

    def _post(self, stage):
        name = self.stage_name(stage)
        assert name in self.stages
        self.stages[name] = stage

    def _run_stage(self, stage_instance, inputs=None):
        assert isinstance(stage_instance, Stage)
        prev_stage_artifact = self._pre(stage_instance)
        stage_instance.run(prev_stage_artifact, inputs=inputs)
        self._post(stage_instance)
        return self

    # Stages
    def quantize(self, quantize_stage: Optional[Quantize] = None):
        return self._run_stage(quantize_stage or Quantize(), self.example_inputs)

    def export(self, export_stage: Optional[Export] = None):
        return self._run_stage(
            export_stage or Export(dynamic_shapes=self.dynamic_shapes),
            self.example_inputs,
        )
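
    # The quantized path simply prepends the Quantize stage, per the pipeline
    # transitions above, e.g. (illustrative sketch; `MyModule` is a stand-in):
    #
    #   Tester(MyModule(), example_inputs).quantize().export().to_edge()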
    def to_edge(self, to_edge_stage: Optional[ToEdge] = None):
        # TODO(T182187531): Skip dim order for now. Support dim order and its op after alpha release.
        if not to_edge_stage:
            to_edge_stage = ToEdge()
        to_edge_stage.edge_compile_conf._skip_dim_order = True
        res = self._run_stage(to_edge_stage)
        return res

    def to_edge_transform_and_lower(
        self, to_edge_and_transform_stage: Optional[ToEdgeTransformAndLower] = None
    ):
        return self._run_stage(
            to_edge_and_transform_stage or ToEdgeTransformAndLower()
        )

    def run_passes(self, run_passes_stage: Optional[RunPasses] = None):
        return self._run_stage(run_passes_stage or RunPasses())

    def partition(self, partition_stage: Optional[Partition] = None):
        return self._run_stage(partition_stage or Partition())

    def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] = None):
        return self._run_stage(to_executorch_stage or ToExecutorch())

    def serialize(self, serialize_stage: Optional[Serialize] = None):
        return self._run_stage(serialize_stage or Serialize())

    # Util functions
    def dump_artifact(self, path: Optional[str] = None, stage: Optional[str] = None):
        stage = stage or self.cur
        self.stages[stage].dump_artifact(path)
        return self

    def get_artifact(self, stage: Optional[str] = None):
        stage = stage or self.cur
        return self.stages[stage].artifact

    def check(self, input: List[str]):
        for key in input:
            FileCheck().check(key).run(self.stages[self.cur].graph_module.code)
        return self

    def check_not(self, input: List[str]):
        for key in input:
            FileCheck().check_not(key).run(self.stages[self.cur].graph_module.code)
        return self

    def check_count(self, input: Dict[Any, int]):
        # TODO: target checks similar to checkGraphModuleNodes()
        for key, count in input.items():
            FileCheck().check_count(key, count, exactly=True).run(
                self.stages[self.cur].graph_module.code
            )
        return self

    def check_node_count(self, input: Dict[Any, int]):
        # Count the occurrences of each target in the graph.
        target_ops = [
            node.target
            for node in self.stages[self.cur].graph_module.graph.nodes
            if node.op == "call_function"
        ]
        op_counts = Counter(target_ops)

        for key, count in input.items():
            if count != op_counts[key]:
                print(f"Nodes: {op_counts}")
                raise AssertionError(
                    f"Expected {count} {key} nodes but found {op_counts[key]}."
                )

        return self

    def run_method_and_compare_outputs(
        self,
        stage: Optional[str] = None,
        inputs: Optional[Tuple[torch.Tensor]] = None,
        num_runs=1,
        atol=1e-03,
        rtol=1e-03,
        qtol=0,
    ):
        number_of_runs = 1 if inputs is not None else num_runs
        reference_stage = self.stages[self.stage_name(Export)]

        stage = stage or self.cur

        print(f"Comparing Stage {stage} with Stage {reference_stage}")
        for run_iteration in range(number_of_runs):
            inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
            input_shapes = [generated_input.shape for generated_input in inputs_to_run]
            print(f"Run {run_iteration} with input shapes: {input_shapes}")

            # Reference output (and quantization scale)
            (
                reference_output,
                quantization_scale,
            ) = self._calculate_reference_output(
                reference_stage.artifact, inputs_to_run
            )

            # Output from running the artifact at the given stage
            stage_output = self.stages[stage].run_artifact(inputs_to_run)

            self._compare_outputs(
                reference_output, stage_output, quantization_scale, atol, rtol, qtol
            )

        return self
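
    # For quantized graphs it is common to allow one quantized step of error,
    # e.g. (illustrative): tester.run_method_and_compare_outputs(qtol=1),
    # which relaxes atol by one quantization scale (see _compare_outputs).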
    @staticmethod
    def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
        """
        Helper testing function that asserts that the model output and the
        reference output are equal with some tolerance. Due to numerical
        differences between eager mode and the XNNPACK backend, we relax the
        tolerances such that the absolute tolerance is 1e-3 and the relative
        tolerance is 1e-3. In the event that the computation was quantized,
        we further relax the tolerance by one quantized step (equal to the
        quantization scale). This allows the quantized value to differ by 1
        between the reference and model output.
        """
        assert len(model_output) == len(ref_output)

        for i in range(len(model_output)):
            model = model_output[i]
            ref = ref_output[i]
            assert torch.allclose(
                model,
                ref,
                atol=atol,
                rtol=rtol,
            ), (
                f"Output {i} does not match reference output.\n"
                f"\tGiven atol: {atol}, rtol: {rtol}.\n"
                f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
                f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n"
                f"\t-- Model vs. Reference --\n"
                f"\t Numel: {model.numel()}, {ref.numel()}\n"
                f"\tMedian: {model.median()}, {ref.median()}\n"
                f"\t  Mean: {model.mean()}, {ref.mean()}\n"
                f"\t   Max: {model.max()}, {ref.max()}\n"
                f"\t   Min: {model.min()}, {ref.min()}\n"
            )

    @staticmethod
    def _compare_outputs(
        reference_output,
        stage_output,
        quantization_scale=None,
        atol=1e-03,
        rtol=1e-03,
        qtol=0,
    ):
        """
        Compares the output of the original nn.Module against the output of
        the generated artifact. This requires calling run_method before
        calling compare_outputs, as that runs the generated artifact on the
        sample inputs and sets the stage output to be compared against the
        reference.
        """
        # Wrap both outputs as tuples, since executor output is always a tuple
        # even for a single tensor.
        if isinstance(reference_output, torch.Tensor):
            reference_output = (reference_output,)
        if isinstance(stage_output, torch.Tensor):
            stage_output = (stage_output,)

        # If a qtol is provided and we found a dequantization node prior to the
        # output, relax the atol by qtol quant units.
        if quantization_scale is not None:
            atol += quantization_scale * qtol

        Tester._assert_outputs_equal(
            stage_output,
            reference_output,
            atol=atol,
            rtol=rtol,
        )

    @staticmethod
    def _calculate_reference_output(
        program: ExportedProgram, inputs
    ) -> Tuple[torch.Tensor, Optional[float]]:
        """
        Executes the reference program and returns the output. If the output
        comes from a dequantize node, also returns the quantization scale.
        """

        # Locate the output node.
        output_node = None
        for node in program.graph.nodes:
            if node.op == "output":
                output_node = node
                break
        assert output_node is not None

        # Look for a dequantization node in the output node args. Returned
        # values are found in the first argument of the output node.
        dequant_node = None
        for arg_node in output_node.args[0]:
            if (
                arg_node.op == "call_function"
                and arg_node.target
                == torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ):
                dequant_node = arg_node
                break

        scale = None
        if dequant_node is not None:
            original_target = dequant_node.target

            # Replace the dequant node with a shim to intercept the quantization
            # parameters. It will be invoked when we evaluate the program to
            # find the reference outputs.
            def dequant_shim(*args):
                nonlocal scale
                scale = args[1]
                result = original_target(*args)
                return result

            dequant_node.target = dequant_shim

        output = program.module()(*inputs)
        return output, scale
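

# Graph checks can be interleaved between stages via check/check_count, e.g.
# (illustrative sketch; `MyModule` is a stand-in, and the delegate target
# string is the one commonly checked in XNNPACK lowering tests):
#
#   (
#       Tester(MyModule(), example_inputs)
#       .export()
#       .to_edge_transform_and_lower()
#       .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
#       .to_executorch()
#   )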