# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
import torch.fx
from executorch.exir.emit._emitter import (
    _DelegateDebugIdentifierMap,
    _EmitterState,
    _ProgramState,
    _TopLevelEmitter,
)
from executorch.exir.error import ExportError, ExportErrorType
from executorch.exir.schema import Buffer, Program, SubsegmentOffsets
from executorch.exir.version import EXECUTORCH_SCHEMA_VERSION
from torch.export.exported_program import ExportedProgram, OutputKind
from torch.utils import _pytree as pytree


@dataclass
class EmitterOutput:
    """
    The outputs of program emission. Contains the ExecuTorch program object as
    well as a mapping of instruction ids to debug handles.
    """

    # The ExecuTorch program
    program: Program

    # This dictionary maps each instruction id to its corresponding debug
    # handle, or to a list of debug handles in the case of delegate calls.
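    # For example (hypothetical values): {0: 7, 1: [8, 9]} would map
    # instruction 0 to debug handle 7, and a delegate call at instruction 1
    # to the fused debug handles 8 and 9.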
    debug_handle_map: Dict[int, Union[int, List[int]]]

    # This dictionary maps each method name to a dict that maps each delegate
    # instruction id to its delegate name and delegate debug identifier mapping.
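    # Illustrative shape only (hypothetical values; the exact inner keys come
    # from the emitter's instr_id_to_delegate_debug_id_map):
    #   {"forward": {3: {"name": "SomeBackend", "delegate_map": {...}}}}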
    method_to_delegate_debug_id_map: Dict[
        str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]
    ]

    # Initial data for mutable buffers, if any; used to build the mutable
    # data segments at serialization time.
    mutable_data: Optional[List[Buffer]]


def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
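    """Return ``exported_program``'s graph module with its output node rewritten
    to drop buffer-mutation outputs, keeping only user-visible outputs."""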
    gm = exported_program.graph_module
    output_node = None
    for node in gm.graph.nodes:
        if node.op == "output":
            output_node = node
    assert output_node is not None

    # Mark each output that exists only to signal a buffer mutation with the
    # name of its target; all other slots stay None.
    mutated_outputs: List[Optional[str]] = [
        out_spec.target if out_spec.kind in (OutputKind.BUFFER_MUTATION,) else None
        for out_spec in exported_program.graph_signature.output_specs
    ]
    outputs = pytree.tree_flatten(output_node.args)[0]

    # Keep only the outputs that are not buffer mutations.
    user_output_nodes = []
    for return_node, mutated_node_name in zip(outputs, mutated_outputs):
        if mutated_node_name is None:
            user_output_nodes.append(return_node)

    with gm.graph.inserting_before(output_node):
        # Only return user outputs
        new_output = gm.graph.output(tuple(user_output_nodes))
        new_output.meta = output_node.meta.copy()
        output_node.replace_all_uses_with(new_output)
        gm.graph.erase_node(output_node)

    return gm


# For each entry point in the model, determine whether it is a joint graph,
# and if it is, return a map of the indices in the model output at which the
# gradient outputs and the parameter outputs start.
def _get_training_metadata(
    methods: Dict[str, ExportedProgram],
) -> Dict[str, Union[int, List[str]]]:
    gradients_method_prefix = "__et_training_gradients_index_"
    parameters_method_prefix = "__et_training_parameters_index_"
    fqn_method_prefix = "__et_training_fqn_"
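    # Resulting shape, e.g. for a hypothetical joint-graph method "forward"
    # (the indices and fqns here are made up for illustration):
    #   {"__et_training_gradients_index_forward": 2,
    #    "__et_training_parameters_index_forward": 4,
    #    "__et_training_fqn_forward": ["linear.weight", "linear.bias"]}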
    training_metadata: Dict[str, Union[int, List[str]]] = {}
    for name, method in methods.items():
        found_grad = False
        found_param = False
        fqns: List[str] = []
        i = 0
        for output_spec in method.graph_signature.output_specs:
            if output_spec.kind == OutputKind.GRADIENT_TO_PARAMETER:
                if not found_grad:
                    training_metadata[gradients_method_prefix + name] = i
                    found_grad = True
                fqns.append(output_spec.target)
            elif output_spec.kind == OutputKind.TOKEN and not found_param:
                assert found_grad  # Params must come after gradients
                training_metadata[parameters_method_prefix + name] = i
                found_param = True
            i += 1
        if len(fqns) > 0:
            training_metadata[fqn_method_prefix + name] = fqns
    return training_metadata


def emit_program(
    methods: Union[ExportedProgram, Dict[str, ExportedProgram]],
    emit_stacktrace: bool = False,
    prim_getters: Optional[Dict[str, Any]] = None,
) -> EmitterOutput:
    """
    Given an exported program or a dictionary of them, return the program in
    the format of the Python version of the flatbuffer Program schema.

    Args:
        methods: Either a single ExportedProgram to emit into the flatbuffer
            (registered under the method name "forward"), or a dictionary of
            method names to ExportedPrograms.
        emit_stacktrace: Flag to enable emission of a stacktrace for each
            instruction for debugging purposes.
        prim_getters: Optional mapping of names to primitive values; each
            entry is emitted as an additional getter execution plan.

    Returns:
        The program in a Python class which mimics the flatbuffer schema.
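
    Example (a minimal sketch; ``MyModule`` and its inputs are hypothetical):

        ep = torch.export.export(MyModule(), (torch.randn(2, 2),))
        out = emit_program(ep)
        # out.program is a schema.Program ready for flatbuffer serialization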
128    """
129
130    if isinstance(methods, ExportedProgram):
131        methods = {"forward": methods}
132
133    # validation
134    bad_methods = []
135    for name, exported_program in methods.items():
136        if not isinstance(exported_program, ExportedProgram):
137            bad_methods.append(name)
138    if len(bad_methods) != 0:
139        raise ExportError(
140            ExportErrorType.INVALID_INPUT_TYPE,
141            f"Did not receive ExportedProgram for the following methods {str(bad_methods)}",
142        )
143
144    plans = []
145    debug_handle_map = {}
146    method_to_delegate_debug_id_map = {}
147    program_state = _ProgramState()
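    # program_state is shared across all methods so that constant buffers and
    # backend delegate data are pooled program-wide.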

    # Emit each entry point, ordered by method name.
    for name, exported_program in sorted(methods.items()):
        # Create fresh per-method emitter state.
        emitter_state = _EmitterState(
            values=[],
            operators=[],
            delegates=[],
            operator_cache={},
            delegate_cache={},
            emit_stacktrace=emit_stacktrace,
        )

        gm = _remove_non_user_outputs(exported_program)

        emitter = _TopLevelEmitter(
            name, exported_program, gm, program_state, emitter_state
        )

        emitter.run()
        plans.append(emitter.plan())

        debug_handle_map[name] = emitter.debug_handle_map
        method_to_delegate_debug_id_map[name] = (
            emitter.instr_id_to_delegate_debug_id_map
        )

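    # Training metadata and any prim getters are emitted as extra execution
    # plans through the emitter left over from the final method above.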
    training_metadata = _get_training_metadata(methods)
    if len(training_metadata) > 0:
        plans.extend(emitter._emit_prim_getters(training_metadata))

    # Emit any primitive getters.
    if prim_getters is not None:
        plans.extend(emitter._emit_prim_getters(prim_getters))

    return EmitterOutput(
        debug_handle_map=debug_handle_map,
        method_to_delegate_debug_id_map=method_to_delegate_debug_id_map,
        program=Program(
            version=EXECUTORCH_SCHEMA_VERSION,
            execution_plan=plans,
            constant_buffer=program_state.constant_buffer,
            backend_delegate_data=program_state.backend_delegate_data,
            # Segments may be added at serialization time.
            segments=[],
            # Subsegment offsets may be added at serialization time.
            constant_segment=SubsegmentOffsets(segment_index=0, offsets=[]),
            mutable_data_segments=None,  # Will be filled in during serialization
        ),
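        # mutable_buffer reserves index 0 as an empty placeholder (an
        # assumption based on _ProgramState's initialization), so only more
        # than one entry indicates real mutable data.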
        mutable_data=(
            program_state.mutable_buffer
            if len(program_state.mutable_buffer) > 1
            else None
        ),
    )