xref: /aosp_15_r20/external/executorch/exir/backend/backend_api.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import copy
8import logging
9from contextlib import contextmanager, nullcontext
10from functools import singledispatch
11from typing import Generator, List
12
13import torch
14
15from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
16from executorch.exir.backend.compile_spec_schema import CompileSpec
17
18from executorch.exir.backend.partitioner import Partitioner, PartitionResult
19from executorch.exir.backend.utils import (
20    _maybe_duplicate_constant_nodes,
21    is_identical_graph,
22)
23
24from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
25
26from executorch.exir.graph_module import get_control_flow_submodules
27from executorch.exir.lowered_backend_module import (
28    _unsafe_adjust_original_program,
29    create_exported_program_from_submodule,
30    create_submodule_from_nodes,
31    LoweredBackendModule,
32)
33from executorch.exir.program._fake_program import (
34    get_fake_program,
35    update_to_real_program,
36)
37from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
38from torch.export import ExportedProgram
39
40
@singledispatch
def to_backend(args):
    """
    Generic entry point for lowering; the implementation is chosen from the
    runtime type of the first positional argument.

    Note: Python is dynamically typed, so there is no compile-time method
    overloading. Concrete implementations are attached with
    ``@to_backend.register`` and keyed on the type annotation of their first
    parameter (the annotation is required). Dispatch only ever looks at that
    first argument, never at the remaining ones.

    The overloads registered in this file are:

    ::

     def to_backend(
         backend_id: str,
         edge_program: ExportedProgram,
         compile_specs: List[CompileSpec],
     ) -> LoweredBackendModule:

     def to_backend(
         edge_program: ExportedProgram,
         partitioner_instance: Partitioner,
     ) -> ExportedProgram:

    A call whose first argument matches no registered type lands in this base
    implementation, which does nothing and returns None.
    """
    return None
64
65
@to_backend.register
def _(
    backend_id: str,
    edge_program: ExportedProgram,
    compile_specs: List[CompileSpec],
) -> LoweredBackendModule:
    """
    Overload of to_backend that lowers a whole program to one backend:

    ::

     def to_backend(
         backend_id: str,
         edge_program: ExportedProgram,
         compile_specs: List[CompileSpec],
     ) -> LoweredBackendModule:

    The backend is located by comparing ``backend_id`` with the class name of
    every direct subclass of ``BackendDetails``; the first match has its
    ``preprocess`` called on a deep copy of ``edge_program``.

    Args:
        backend_id: The backend identifier (the name of a ``BackendDetails``
            subclass).
        edge_program: An exported program in Edge dialect to target for
            lowering to the backend.
        compile_specs: A list of backend-specific objects with static
            metadata to configure the "compilation" process (e.g. it could be
            another dictionary itself).

    Returns:
        LoweredBackendModule: A module that has been lowered to the target
        backend. It is constructed from the original (un-copied)
        ``edge_program``, ``backend_id``, the processed bytes returned by
        ``preprocess`` and ``compile_specs``; its ``meta`` holds the
        ``debug_handle_map`` returned by ``preprocess``.

    Raises:
        AssertionError: ``edge_program`` is not an ExportedProgram.
        NotImplementedError: No backend named ``backend_id`` was found.
        This exception is derived from RuntimeError and should be caught accordingly.
        RuntimeError: The module cannot be processed by the backend.
    """
    assert isinstance(edge_program, ExportedProgram)

    # All backend implementations are final, so looking at direct subclasses
    # only (no nested subclasses) is sufficient.
    backend_cls = next(
        (
            candidate
            for candidate in BackendDetails.__subclasses__()
            if backend_id == candidate.__name__
        ),
        None,
    )
    if backend_cls is None:
        raise NotImplementedError(f"Backend {backend_id} was not found.")

    # preprocess() works on its own deep copy; the lowered module below keeps
    # a reference to the program the caller passed in.
    preprocess_result: PreprocessResult = backend_cls.preprocess(
        copy.deepcopy(edge_program),
        compile_specs,
    )
    lowered_module = LoweredBackendModule(
        edge_program=edge_program,
        backend_id=backend_id,
        processed_bytes=preprocess_result.processed_bytes,
        compile_specs=compile_specs,
    )
    lowered_module.meta = {"debug_handle_map": preprocess_result.debug_handle_map}
    return lowered_module
128
129
# Module-level switch read by the partitioner overload of to_backend: while it
# is True, the tagged graph module is asserted to be identical to the original
# graph module; while it is False, that check is skipped and a warning is logged.
_ENABLE_VALIDATION: bool = True
131
132
def disable_validation() -> None:
    """
    Switch off validation by setting the module-level ``_ENABLE_VALIDATION``
    flag to False.

    The previous value is not remembered and nothing in this function turns
    the flag back on.
    """
    global _ENABLE_VALIDATION
    _ENABLE_VALIDATION = False
137
138
@contextmanager
def validation_disabled() -> Generator[None, None, None]:
    """
    Context manager that turns validation off for the duration of the
    ``with`` body (e.g. the check that the partitioned graph is identical to
    the original graph).

    The value of ``_ENABLE_VALIDATION`` seen on entry is put back on exit,
    whether the body finishes normally or raises. Intended only for specific
    situations, such as when profiling has shown the checks to be too slow
    and they are not strictly needed.
    """
    global _ENABLE_VALIDATION
    setting_on_entry = _ENABLE_VALIDATION
    disable_validation()
    try:
        yield
    finally:
        # Restore (rather than force True) so nested uses compose correctly.
        _ENABLE_VALIDATION = setting_on_entry
154
155
156def _get_node_list_with_same_tag(
157    tagged_graph_module: torch.fx.GraphModule,
158    tag: str,
159    owning_program: ExportedProgram,
160) -> List[torch.fx.Node]:
161    """
162    Return a list of nodes with the same tag.
163    """
164    node_list = []
165
166    for node in tagged_graph_module.graph.nodes:
167        if node.meta.get("delegation_tag", "") == tag:
168            if node.op == "output":
169                raise RuntimeError(f"output node {node} should not be tagged")
170            if node.op == "placeholder":
171                if (
172                    not is_param(owning_program, node)
173                    and not is_buffer(owning_program, node)
174                    and not is_lifted_tensor_constant(owning_program, node)
175                ):
176                    raise RuntimeError(
177                        f"placeholder node for non-params, non-buffer, and non-tensor constants should not be tagged: {node} "
178                    )
179                else:
180                    # check that the users all belong to the same tag
181                    for user in node.users:
182                        users_tag = user.meta.get("delegation_tag", None)
183                        if users_tag != tag:
184                            raise RuntimeError(
185                                f"constant data node ({node}) is tagged with ({tag}) but has user ({user}) which has tag ({users_tag})"
186                            )
187            node_list.append(node)
188    return node_list
189
190
def _partition_and_lower_one_graph_module(
    tagged_graph_module: torch.fx.GraphModule,
    partition_result: PartitionResult,
    owning_program: ExportedProgram,
    is_submodule: bool,
) -> torch.fx.GraphModule:
    """
    Partition and lower a single graph module based on its partition tags.

    For every ``(tag, delegation_spec)`` in ``partition_result.partition_tags``
    the nodes of ``tagged_graph_module`` carrying that tag are fused into one
    submodule, the submodule is converted into an exported program, lowered
    through ``to_backend`` with the spec's backend id and compile specs, and
    the fused call node is replaced by an ``executorch_call_delegate`` call on
    the lowered module. Tags with no matching nodes are skipped.

    Args:
        tagged_graph_module: Graph module to rewrite; it is modified in place.
        partition_result: Partitioner output providing ``partition_tags``.
        owning_program: The exported program the graph module belongs to. It
            supplies the graph signature's replace hook, is used to classify
            tagged placeholders and to pick the next debug handle, and may be
            adjusted when top-level input/output specs have to be deleted.
        is_submodule: When True no replace hook is installed and it is
            asserted that no top-level specs need deleting.

    Returns:
        ``tagged_graph_module`` itself, after the in-place rewrite.
    """

    def generate_debug_handle(ep: ExportedProgram) -> int:
        """
        Generate a debug handle for the given ExportedProgram: one more than
        the largest ``debug_handle`` found in its graph (so at least 1).
        """
        debug_handle = 0
        for node in ep.graph_module.graph.nodes:
            debug_handle = max(debug_handle, node.meta.get("debug_handle", 0))
        return debug_handle + 1

    for tag, delegation_spec in partition_result.partition_tags.items():
        # Create partition with nodes containing this tag. There should only be
        # one contained submodule per tag
        node_list = _get_node_list_with_same_tag(
            tagged_graph_module, tag, owning_program
        )

        if not node_list:
            # Lazy %-style args: nothing is stringified unless DEBUG is enabled.
            logging.debug("Did not find any nodes for tag %s", tag)
            continue

        logging.debug("For tag %s, found nodes %s", tag, node_list)

        # Only for the top-level graph: keep the graph signature's replace hook
        # installed while the tagged nodes are fused into a submodule.
        replace_ctx = (
            tagged_graph_module._set_replace_hook(
                owning_program.graph_signature.get_replace_hook()
            )
            if not is_submodule
            else nullcontext()
        )
        with replace_ctx:
            submodule, call_module_node = create_submodule_from_nodes(
                tagged_graph_module, node_list, tag
            )

        tagged_graph_module_output_node = [
            node for node in tagged_graph_module.graph.nodes if node.op == "output"
        ][0]
        submodule_output_node = [
            node for node in submodule.graph.nodes if node.op == "output"
        ][0]
        # Reuse the output node meta from the original output node, because
        # create_submodule_from_nodes doesn't cover the meta field.
        # Note this shares the same dict object rather than copying it.
        submodule_output_node.meta = tagged_graph_module_output_node.meta
        # Passing the module as an argument avoids rendering the whole graph
        # module to text for every partition when DEBUG logging is off.
        logging.debug("Partitioned graph module: %s", tagged_graph_module)

        (
            submodule_program,
            toplevel_input_specs_to_delete,
            toplevel_output_specs_to_delete,
        ) = create_exported_program_from_submodule(
            submodule,
            owning_program,
            tag,
            call_module_node,
            is_submodule,
        )

        lowered_submodule = to_backend(
            delegation_spec.backend_id,
            submodule_program,
            delegation_spec.compile_specs,
        )

        # call delegate args should only use user_inputs
        call_delegate_args = []
        # Preserve input order as user_inputs
        for inp_name in submodule_program.graph_signature.user_inputs:
            for inp_node in call_module_node.all_input_nodes:
                if inp_node.name == inp_name:
                    call_delegate_args.append(inp_node)
                    break

        # Replace the partitioned submodule with a lowered submodule:
        # a get_attr node for the lowered module followed by a call_function
        # node targeting executorch_call_delegate.
        with tagged_graph_module.graph.inserting_before(call_module_node):
            lowered_name = get_lowered_module_name(
                tagged_graph_module, lowered_submodule
            )
            lowered_node = tagged_graph_module.graph.get_attr(lowered_name)
            call_delegate_node = tagged_graph_module.graph.call_function(
                executorch_call_delegate,
                (lowered_node,) + tuple(call_delegate_args),
                call_module_node.kwargs,
            )
            call_delegate_node.meta["debug_handle"] = generate_debug_handle(
                owning_program
            )
            call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
            call_module_node.replace_all_uses_with(call_delegate_node)
            tagged_graph_module.graph.erase_node(call_module_node)

        if is_submodule:
            assert len(toplevel_input_specs_to_delete) == 0
            assert len(toplevel_output_specs_to_delete) == 0
        elif (
            len(toplevel_input_specs_to_delete) > 0
            or len(toplevel_output_specs_to_delete) > 0
        ):
            _unsafe_adjust_original_program(
                owning_program,
                call_delegate_node,
                toplevel_input_specs_to_delete,
                toplevel_output_specs_to_delete,
            )

    return tagged_graph_module
307
308
def _partition_and_lower(
    tagged_graph_module: torch.fx.GraphModule,
    partition_result: PartitionResult,
    owning_program: ExportedProgram,
    is_submodule: bool = False,
) -> torch.fx.GraphModule:
    """
    Partition ``tagged_graph_module`` by tag and lower each group of
    same-tagged nodes as one lowered module, then do the same recursively for
    every control flow submodule it contains.

    Args:
        tagged_graph_module: Graph module to partition and lower.
        partition_result: Partitioner output holding the tag -> spec mapping.
        owning_program: Exported program the graph module belongs to.
        is_submodule: False for the top-level call; the recursion passes True
            for control flow submodules.

    Returns:
        ``tagged_graph_module``, with each control flow submodule re-registered
        under its original name after being processed.
    """
    lowered_module = _partition_and_lower_one_graph_module(
        tagged_graph_module, partition_result, owning_program, is_submodule
    )

    # Descend into control flow submodules of the already-lowered module.
    for submodule_name, control_flow_module, _ in get_control_flow_submodules(
        lowered_module
    ):
        lowered_control_flow_module = _partition_and_lower(
            control_flow_module,
            partition_result,
            owning_program,
            is_submodule=True,
        )
        tagged_graph_module.add_module(submodule_name, lowered_control_flow_module)

    return tagged_graph_module
331
332
@to_backend.register
def _(
    edge_program: ExportedProgram,
    partitioner_instance: Partitioner,
) -> ExportedProgram:
    """
    Overload of to_backend that delegates parts of a program chosen by a partitioner:

    ::

     def to_backend(
         edge_program: ExportedProgram,
         partitioner_instance: Partitioner,
     ) -> ExportedProgram:

    Returns a semantically-equivalent program to the one given as input
    (represented as a graph module in Edge dialect), but with portions of the
    program targeted for delegation as determined by the partitioner.

    Args:
        edge_program: Program in Edge dialect.
        partitioner_instance: An instance of the partitioner, in charge of
            tagging portions of the input program for delegation. Calling it
            must return a result exposing ``tagged_exported_program`` and
            ``partition_tags``, a mapping from tag name to delegation spec.
            Nodes sharing a tag are fused into one subgraph and delegated to
            the backend specified in that spec.

    Returns:
        ExportedProgram: A new program built around the lowered graph module,
        reusing the tagged program's graph signature, state dict, constants
        and verifier, with deep-copied range constraints and module call
        graph, and with ``example_inputs`` set to None.

    Raises:
        AssertionError: Validation is enabled and the partitioner changed the
            graph, or the partitioner result has no ``partition_tags``.
    """
    edge_program._validate()

    # Hand the partitioner a fake program (FakeTensors in the state dict) so
    # that large constant values are not copied. If that fails, fall back to a
    # deepcopy. TODO(T182910699): Remove this fallback.
    try:
        fake_edge_program = get_fake_program(edge_program)
    except Exception as err:
        logging.warning(
            f"Error in get_fake_program for graph {edge_program.graph_module}, fallback to deepcopy: {err}"
        )
        fake_edge_program = copy.deepcopy(edge_program)

    partitioner_result = partitioner_instance(fake_edge_program)
    tagged_exported_program = partitioner_result.tagged_exported_program

    # The partitioner is only allowed to tag nodes, not to change the graph.
    if not _ENABLE_VALIDATION:
        logging.warning("Disabled validating the partitioner.")
    else:
        assert is_identical_graph(
            tagged_exported_program.graph_module,
            edge_program.graph_module,
        ), f"The partitioner {partitioner_instance} should not modify the graph module"

    partition_tags = partitioner_result.partition_tags
    assert (
        partition_tags is not None
    ), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"

    update_to_real_program(tagged_exported_program, edge_program)

    # Only the tag names are needed here, so iterate the mapping's keys.
    for tag in partition_tags:
        _maybe_duplicate_constant_nodes(tagged_exported_program, tag)

    lowered_graph_module = _partition_and_lower(
        tagged_exported_program.graph_module,
        partitioner_result,
        tagged_exported_program,
    )

    return ExportedProgram(
        root=lowered_graph_module,
        graph=lowered_graph_module.graph,
        graph_signature=tagged_exported_program.graph_signature,
        state_dict=tagged_exported_program.state_dict,
        range_constraints=copy.deepcopy(tagged_exported_program.range_constraints),
        module_call_graph=copy.deepcopy(tagged_exported_program.module_call_graph),
        example_inputs=None,
        constants=tagged_exported_program.constants,
        verifiers=[tagged_exported_program.verifier],
    )
413