# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from typing import Dict, final, List

import torch

from executorch.backends.xnnpack._passes import XNNPACKPassManager
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
    TagImplicitQDqPass,
)
from executorch.backends.xnnpack.operators.node_visitor import get_node_visitors

from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    ConstantDataOffset,
    XNNGraph,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import (
    serialize_xnnpack_binary,
)
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.backends.xnnpack.utils.utils import is_param_node

from executorch.backends.xnnpack.utils.xnnpack_constants import (
    XNN_VALUE_FLAG_EXTERNAL_INPUT,
    XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
)

from executorch.exir.backend.backend_details import (
    BackendDetails,
    CompileSpec,
    PreprocessResult,
)
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
from torch.export.exported_program import ExportedProgram

DEFAULT_DEBUG_HANDLE = 65535

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


@dataclass
class ExternalMeta:
    external_id: int
    io_type: int


def generate_node_to_external_map(
    exported_program: ExportedProgram,
    edge_graph_module: torch.fx.GraphModule,
) -> Dict[torch.fx.Node, ExternalMeta]:
    node_to_external_map = {}
    for node in edge_graph_module.graph.nodes:
        # The order in which we visit the placeholder nodes is the same as the
        # *args order of the forward(*args) signature for this gm, so we use
        # the node order as the external_id to extract the right arg from
        # *args at runtime.
        #
        # Skip parameters/buffers since they will disappear from the signature
        # at runtime.
        if node.op == "placeholder" and not is_param_node(exported_program, node):
            node_to_external_map[node] = ExternalMeta(
                external_id=len(node_to_external_map),
                io_type=XNN_VALUE_FLAG_EXTERNAL_INPUT,
            )
    for node in edge_graph_module.graph.nodes:
        if node.op == "output":
            for output_nodes in node.args:
                for output_node in output_nodes:
                    node_to_external_map[output_node] = ExternalMeta(
                        external_id=len(node_to_external_map),
                        io_type=XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
                    )
    return node_to_external_map
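
# Example (a sketch of the mapping above, for illustration only): for a module
# whose signature is forward(self, x, y) and which returns a single tensor
# `out`, generate_node_to_external_map would produce, in visit order:
#
#   x   -> ExternalMeta(external_id=0, io_type=XNN_VALUE_FLAG_EXTERNAL_INPUT)
#   y   -> ExternalMeta(external_id=1, io_type=XNN_VALUE_FLAG_EXTERNAL_INPUT)
#   out -> ExternalMeta(external_id=2, io_type=XNN_VALUE_FLAG_EXTERNAL_OUTPUT)
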
def assert_default_dim_order(edge_graph_module: torch.fx.GraphModule) -> None:
    for node in edge_graph_module.graph.nodes:
        if node.op != "placeholder":
            continue

        # We expect the default dim order for all tensor-like inputs, i.e.
        # inputs, buffers, and params.
        t = node.meta.get("val", None)
        if t is not None and getattr(t, "dim_order", None) is not None:
            default_dim_order = tuple(range(t.dim()))
            if t.dim_order() != default_dim_order:
                raise RuntimeError(
                    f"XNNPACK backend only supports contiguous memory format for inputs. "
                    f"Expected dim_order {default_dim_order}, but got {t.dim_order()} for placeholder node {node}."
                )


@final
class XnnpackBackend(BackendDetails):
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:

        xnnpack_edge_compile_config = get_xnnpack_edge_compile_config()

        # We need to wrap the EP here because the XNNPACK delegate transforms
        # addmm into linear, which makes the resulting graph non-ATen
        # compliant (aten.linear is not a core ATen op). The ideal fix would
        # be an XNNPACK verifier that bypasses most checks, but the base
        # Verifier itself has some strict checks, and bypassing those would
        # basically mean copying what EdgeDialectVerifier does. So for now,
        # instead of copy-pasting that, just instantiate EdgeDialectVerifier
        # but disable it.
        # TODO (task link) to implement NullVerifier or something similar
        ep = ExportedProgram(
            root=edge_program.graph_module,
            graph=edge_program.graph,
            graph_signature=edge_program.graph_signature,
            state_dict=edge_program.state_dict,
            range_constraints=edge_program.range_constraints,
            module_call_graph=edge_program.module_call_graph,
            example_inputs=edge_program.example_inputs,
            constants=edge_program.constants,
            verifiers=[
                EXIREdgeDialectVerifier(
                    edge_compile_config=xnnpack_edge_compile_config, class_only=True
                )
            ],
        )

        passes = []
        for spec in compile_specs:
            if spec.key == "dqlinear_partitioner":
                passes.append(ConvertToLinearPass)
                passes.append(TagImplicitQDqPass)

        passes = passes if len(passes) > 0 else None
        # XNNPACK-delegate-specific passes
        ep = XNNPACKPassManager(ep, passes=passes).transform()
        graph_module = ep.graph_module

        node_to_external_map = generate_node_to_external_map(ep, graph_module)

        # Make sure all inputs use the default (contiguous / NCHW) dim order
        assert_default_dim_order(graph_module)

        # TODO: retrace the graph module to lift any new params that may have
        # been added to the graph by the passes

        vals_to_ids = {}
        xnnpack_graph = XNNGraph(
            version="0",
            xnodes=[],
            xvalues=[],
            num_externs=len(node_to_external_map),
            input_ids=[],
            output_ids=[],
            constant_data=[ConstantDataOffset(0, 0)],
        )

        constant_data_bytes = bytearray()
        node_visitors = get_node_visitors(ep, node_to_external_map, constant_data_bytes)

        for node in graph_module.graph.nodes:
            if node.op == "call_function":
                logger.info(f"Visiting: {node}, {node.target.__name__}")
                if node.target.__name__ in node_visitors:
                    node_visitors[node.target.__name__].define_node(
                        node,
                        xnnpack_graph,
                        vals_to_ids,
                        node.meta.get("debug_handle", DEFAULT_DEBUG_HANDLE),
                    )
                else:
                    raise RuntimeError(
                        f"For {node}, {node.op}:{node.target.__name__} is not supported in the XNNPACK delegate"
                    )
            elif node.op in [
                "get_attr",
                "placeholder",
                "output",
            ]:
                continue
            else:
                raise RuntimeError(f"{node.op} is not supported in XNNPACK")
        return PreprocessResult(
            processed_bytes=serialize_xnnpack_binary(
                xnnpack_graph, constant_data_bytes
            ),
            debug_handle_map={},
        )
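
# Example usage (a minimal sketch; `edge_program` is an assumed edge-dialect
# ExportedProgram). XnnpackBackend is not called directly: it is invoked
# through EXIR's to_backend entry point, which routes the program to
# XnnpackBackend.preprocess:
#
#     from executorch.exir.backend.backend_api import to_backend
#
#     lowered_module = to_backend("XnnpackBackend", edge_program, [])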