# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from typing import Dict, final, List

import torch

from executorch.backends.xnnpack._passes import XNNPACKPassManager
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
    TagImplicitQDqPass,
)
from executorch.backends.xnnpack.operators.node_visitor import get_node_visitors

from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    ConstantDataOffset,
    XNNGraph,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import (
    serialize_xnnpack_binary,
)
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.backends.xnnpack.utils.utils import is_param_node

from executorch.backends.xnnpack.utils.xnnpack_constants import (
    XNN_VALUE_FLAG_EXTERNAL_INPUT,
    XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
)

from executorch.exir.backend.backend_details import (
    BackendDetails,
    CompileSpec,
    PreprocessResult,
)
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
from torch.export.exported_program import ExportedProgram

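# Fallback debug handle assigned to a node when it has no "debug_handle" metadata.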
DEFAULT_DEBUG_HANDLE = 65535

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


@dataclass
class ExternalMeta:
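    """Metadata for a user-facing graph input/output: its positional id in the
    runtime signature and its XNN_VALUE_FLAG_EXTERNAL_INPUT/OUTPUT tag."""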
    external_id: int
    io_type: int


def generate_node_to_external_map(
    exported_program: ExportedProgram,
    edge_graph_module: torch.fx.GraphModule,
) -> Dict[torch.fx.Node, ExternalMeta]:
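    """Map each user-facing input (placeholder that is not a param/buffer) and
    each output node to an ExternalMeta holding its sequential external_id and
    external input/output flag."""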
    node_to_external_map = {}
    for node in edge_graph_module.graph.nodes:
        # The order in which we visit the placeholder nodes is the same as the
        # *args order of this graph module's forward(*args) signature, so we
        # use that order as the external_id to extract the right arg from
        # *args at runtime.
        #
        # Parameters/buffers are skipped since they disappear from the
        # signature at runtime.
        if node.op == "placeholder" and not is_param_node(exported_program, node):
            node_to_external_map[node] = ExternalMeta(
                external_id=len(node_to_external_map),
                io_type=XNN_VALUE_FLAG_EXTERNAL_INPUT,
            )
    for node in edge_graph_module.graph.nodes:
        if node.op == "output":
            for output_nodes in node.args:
                for output_node in output_nodes:
                    node_to_external_map[output_node] = ExternalMeta(
                        external_id=len(node_to_external_map),
                        io_type=XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
                    )
    return node_to_external_map


def assert_default_dim_order(edge_graph_module: torch.fx.GraphModule) -> None:
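    """Raise a RuntimeError if any placeholder tensor (input, buffer, or param)
    does not use the default, contiguous dim order."""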
    for node in edge_graph_module.graph.nodes:
        if node.op != "placeholder":
            continue

        # We expect the default dim order for all tensor-like inputs, i.e.
        # user inputs, buffers, and params.
        t = node.meta.get("val", None)
        if t is not None and getattr(t, "dim_order", None) is not None:
            default_dim_order = tuple(range(t.dim()))
            if t.dim_order() != default_dim_order:
                raise RuntimeError(
                    "XNNPACK backend only supports contiguous memory format for inputs. "
                    f"Expecting dim_order: {default_dim_order}, but got "
                    f"{t.dim_order()} for placeholder node {node}."
                )


@final
class XnnpackBackend(BackendDetails):
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
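        """Lower edge_program to an XNNPACK delegate payload: run the
        XNNPACK-specific passes, translate supported call_function nodes into
        an XNNGraph via the node visitors, and serialize the graph plus its
        constant data into the returned PreprocessResult."""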

        xnnpack_edge_compile_config = get_xnnpack_edge_compile_config()

        # We need to wrap the ExportedProgram here because the XNNPACK passes
        # turn addmm into linear, which makes the resulting graph not ATen
        # compliant: aten.linear is not a core ATen op.
        # The ideal fix would be an XNNPACK verifier that bypasses most checks,
        # but the base Verifier itself has some strict checks, and to bypass
        # those we would essentially have to copy what EdgeDialectVerifier does.
        # So for now, instead of copy-pasting that, we just instantiate
        # EdgeDialectVerifier but disable it.
        # TODO (task link): implement NullVerifier or something similar
        ep = ExportedProgram(
            root=edge_program.graph_module,
            graph=edge_program.graph,
            graph_signature=edge_program.graph_signature,
            state_dict=edge_program.state_dict,
            range_constraints=edge_program.range_constraints,
            module_call_graph=edge_program.module_call_graph,
            example_inputs=edge_program.example_inputs,
            constants=edge_program.constants,
            verifiers=[
                EXIREdgeDialectVerifier(
                    edge_compile_config=xnnpack_edge_compile_config, class_only=True
                )
            ],
        )

        passes = []
        for spec in compile_specs:
            if spec.key == "dqlinear_partitioner":
                passes.append(ConvertToLinearPass)
                passes.append(TagImplicitQDqPass)

        passes = passes if len(passes) > 0 else None
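        # An empty list means no compile-spec passes were requested; leaving
        # passes as None lets XNNPACKPassManager fall back to its default pass list.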
        # XNNPACK Delegate Specific Passes
        ep = XNNPACKPassManager(ep, passes=passes).transform()
        graph_module = ep.graph_module

        node_to_external_map = generate_node_to_external_map(ep, graph_module)

        # Make sure all inputs use contiguous_format / NCHW / the default dim order.
        assert_default_dim_order(graph_module)

        # TODO: retrace the graph module to lift any new params that may have
        # been added to the graph by the passes.

        vals_to_ids = {}
        xnnpack_graph = XNNGraph(
            version="0",
            xnodes=[],
            xvalues=[],
            num_externs=len(node_to_external_map),
            input_ids=[],
            output_ids=[],
            constant_data=[ConstantDataOffset(0, 0)],
        )
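        # constant_data starts with a zeroed placeholder entry; the node visitors
        # below append real offsets as they serialize constant tensors into
        # constant_data_bytes.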

        constant_data_bytes = bytearray()
        node_visitors = get_node_visitors(ep, node_to_external_map, constant_data_bytes)

        for node in graph_module.graph.nodes:
            if node.op == "call_function":
                logger.info(f"Visiting: {node}, {node.target.__name__}")
                if node.target.__name__ in node_visitors:
                    node_visitors[node.target.__name__].define_node(
                        node,
                        xnnpack_graph,
                        vals_to_ids,
                        node.meta.get("debug_handle", DEFAULT_DEBUG_HANDLE),
                    )
                else:
                    raise RuntimeError(
                        f"For {node}, {node.op}:{node.target.__name__} is not supported in XNNPACK Delegate"
                    )
            elif node.op in [
                "get_attr",
                "placeholder",
                "output",
            ]:
                continue
            else:
                raise RuntimeError(f"{node.op} is not supported in XNNPACK")
        return PreprocessResult(
            processed_bytes=serialize_xnnpack_binary(
                xnnpack_graph, constant_data_bytes
            ),
            debug_handle_map={},
        )