xref: /aosp_15_r20/external/executorch/exir/lowered_backend_module.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport copy
10*523fa7a6SAndroid Build Coastguard Workerimport operator
11*523fa7a6SAndroid Build Coastguard Workerfrom collections import defaultdict
12*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, Dict, List, Optional, Set, Tuple, Union
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Workerimport torch
15*523fa7a6SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir._serialize import _serialize_pte_binary
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.compile_spec_schema import CompileSpec
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.emit import emit_program
20*523fa7a6SAndroid Build Coastguard Worker
21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.graph_module import _get_submodule
22*523fa7a6SAndroid Build Coastguard Worker
23*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
24*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass
25*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.schema import Program
26*523fa7a6SAndroid Build Coastguard Worker
27*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tracer import Value
28*523fa7a6SAndroid Build Coastguard Workerfrom torch._library.fake_class_registry import FakeScriptObject
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Workerfrom torch._subclasses import FakeTensor
31*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import (
32*523fa7a6SAndroid Build Coastguard Worker    ConstantArgument,
33*523fa7a6SAndroid Build Coastguard Worker    ExportedProgram,
34*523fa7a6SAndroid Build Coastguard Worker    ExportGraphSignature,
35*523fa7a6SAndroid Build Coastguard Worker    InputKind,
36*523fa7a6SAndroid Build Coastguard Worker    InputSpec,
37*523fa7a6SAndroid Build Coastguard Worker    ModuleCallEntry,
38*523fa7a6SAndroid Build Coastguard Worker    ModuleCallSignature,
39*523fa7a6SAndroid Build Coastguard Worker    OutputKind,
40*523fa7a6SAndroid Build Coastguard Worker    OutputSpec,
41*523fa7a6SAndroid Build Coastguard Worker    TensorArgument,
42*523fa7a6SAndroid Build Coastguard Worker)
43*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.passes.utils.fuser_utils import (
44*523fa7a6SAndroid Build Coastguard Worker    erase_nodes,
45*523fa7a6SAndroid Build Coastguard Worker    fuse_as_graphmodule,
46*523fa7a6SAndroid Build Coastguard Worker    insert_subgm,
47*523fa7a6SAndroid Build Coastguard Worker    legalize_graph,
48*523fa7a6SAndroid Build Coastguard Worker    NodeList,
49*523fa7a6SAndroid Build Coastguard Worker    topo_sort,
50*523fa7a6SAndroid Build Coastguard Worker)
51*523fa7a6SAndroid Build Coastguard Worker
52*523fa7a6SAndroid Build Coastguard Worker
class LoweredBackendModule(torch.nn.Module):
    """
    A subclass of nn.Module that is generated for modules containing
    delegated functions. This can be created by calling `to_backend`.

    The module keeps the original exported program alongside the backend id,
    the processed delegate payload and the compile specs, and can re-emit a
    standalone program whose graph consists of a single delegate call.
    """

    _backend_id: str  # The backend's name
    _processed_bytes: bytes  # The delegate blobs created from backend.preprocess
    _compile_specs: List[
        CompileSpec
    ]  # A list of backend-specific objects with static metadata to configure the "compilation" process.
    _original_exported_program: ExportedProgram  # The original EXIR module

    def __init__(
        self,
        edge_program: ExportedProgram,
        backend_id: str,
        processed_bytes: bytes,
        compile_specs: List[CompileSpec],
    ) -> None:
        """
        Stores the given arguments on the module; nothing is copied.

        Args:
            edge_program: The exported program that was lowered.
            backend_id: The backend's name.
            processed_bytes: The delegate blob created from backend.preprocess.
            compile_specs: Backend-specific compile specs.
        """
        super().__init__()
        self._original_exported_program = edge_program
        self._backend_id = backend_id
        self._processed_bytes = processed_bytes
        self._compile_specs = compile_specs

    # pyre-ignore
    def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
        """
        Returns a copy of this lowered module.

        The graph module, graph, graph signature, range constraints, module call
        graph and verifier of the original exported program are deep-copied, while
        its `state_dict` and `constants` are passed to the new program by reference.
        The backend id and processed bytes are reused as-is, the compile specs are
        deep-copied (using `memo`), and the `meta` attribute, if present, is
        shallow-copied.
        """
        original = self._original_exported_program
        # Copy exported program
        copied_program = ExportedProgram(
            root=copy.deepcopy(original.graph_module),
            graph=copy.deepcopy(original.graph),
            graph_signature=copy.deepcopy(original.graph_signature),
            state_dict=original.state_dict,
            range_constraints=copy.deepcopy(original.range_constraints),
            module_call_graph=copy.deepcopy(original.module_call_graph),
            constants=original.constants,
            verifiers=[copy.deepcopy(original.verifier)],
        )

        res = LoweredBackendModule(
            edge_program=copied_program,
            backend_id=self._backend_id,
            processed_bytes=self._processed_bytes,
            compile_specs=copy.deepcopy(self._compile_specs, memo),
        )
        # pyre-fixme[16]: `LoweredBackendModule` has no attribute `meta`.
        res.meta = copy.copy(getattr(self, "meta", {}))
        return res

    @property
    def backend_id(self) -> str:
        """
        Returns the backends name.
        """
        return self._backend_id

    @property
    def processed_bytes(self) -> bytes:
        """
        Returns the delegate blob created from backend.preprocess
        """
        return self._processed_bytes

    @property
    def compile_specs(self) -> List[CompileSpec]:
        """
        Returns a list of backend-specific objects with static metadata to configure the "compilation" process.
        """
        return self._compile_specs

    @property
    def original_module(self) -> ExportedProgram:
        """
        Returns the original EXIR module
        """
        return self._original_exported_program

    # TODO(chenlai): consolidate the serialization config with serialize_to_flatbuffer api
    def buffer(
        self,
        extract_delegate_segments: bool = False,
        segment_alignment: int = 128,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
        memory_planning: Optional[MemoryPlanningPass] = None,
    ) -> bytes:
        """
        Returns a buffer containing the serialized ExecuTorch binary.

        The program is built by `self.program(memory_planning=memory_planning)`;
        the remaining arguments are forwarded unchanged to `_serialize_pte_binary`.
        """
        # TODO(T181463742): avoid calling bytes(..) which incurs large copies.
        out = bytes(
            _serialize_pte_binary(
                program=self.program(memory_planning=memory_planning),
                extract_delegate_segments=extract_delegate_segments,
                segment_alignment=segment_alignment,
                constant_tensor_alignment=constant_tensor_alignment,
                delegate_alignment=delegate_alignment,
            )
        )
        return out

    # TODO(chenlai): re-consider recapture instead of manually constructing the program because
    # the meta data construction is done manually.
    def program(
        self,
        emit_stacktrace: bool = False,
        memory_planning: Optional[MemoryPlanningPass] = None,
    ) -> Program:
        """
        Returns the object that represents the ExecuTorch binary before serialization.

        Args:
            emit_stacktrace: Forwarded to `emit_program`.
            memory_planning: The memory planning pass to run after `SpecPropPass`.
                A default-constructed `MemoryPlanningPass` is used when None.
        """
        # Fix autodeps introduces cyclic dependencies:
        # program -> verifier -> lowered_backend_module -> program
        # @manual
        from executorch.exir.program._program import (
            _get_updated_graph_signature,
            _transform,
        )

        # Creates a new module based on the original module. The original module will
        # look something like following:
        #
        # opcode         name                 target            args                                        kwargs
        # -------------  -------------------  ----------------  ------------------------------------------  --------
        # placeholder    arg0_1               arg0_1            ()                                          {}
        # placeholder    arg1_1               arg1_1            ()                                          {}
        # call_function  aten_repeat_default  *                 (arg1_1, [4, 1])                            {}
        # call_function  aten_mul_tensor      *                 (aten_repeat_default, aten_repeat_default)  {}
        # call_function  aten_add_tensor      *                 (arg1_1, arg1_1)                            {}
        # output         output               output            ([aten_mul_tensor, aten_add_tensor],)       {}
        #
        # if the whole module is lowered, the resulting lowered module look like
        #
        # opcode         name                      target                       args                                kwargs
        # -------------  ------------------------  ---------------------------  ----------------------------------  --------
        # placeholder    arg0_1                    arg0_1                       ()                                  {}
        # placeholder    arg1_1                    arg1_1                       ()                                  {}
        # get_attr       lowered_module_0          lowered_module_0             ()                                  {}
        # call_function  executorch_call_delegate  executorch_call_delegate     (lowered_module_0, arg0_1, arg1_1)  {}
        # call_function  getitem                   <built-in function getitem>  (executorch_call_delegate, 0)       {}
        # call_function  getitem_1                 <built-in function getitem>  (executorch_call_delegate, 1)       {}
        # output         output_1                  output                       ([getitem, getitem_1],)             {}
        #
        # We'll remove all call_function nodes, insert a call_delegate node, insert getitem nodes to get the
        # results of the call_delegate node, and return the list of getitems as the output

        lowered_exported_program = copy.deepcopy(self._original_exported_program)
        graph = lowered_exported_program.graph

        # Look the signature mappings up once instead of once per node; the graph
        # edits below do not touch the graph signature.
        inputs_to_parameters = (
            lowered_exported_program.graph_signature.inputs_to_parameters
        )
        inputs_to_buffers = lowered_exported_program.graph_signature.inputs_to_buffers

        # The real input nodes are the ones not buffer or parameter
        all_input_nodes = [
            node
            for node in graph.nodes
            if (
                node.op == "placeholder"
                and node.name not in inputs_to_buffers
                and node.name not in inputs_to_parameters
            )
        ]

        output_node = [node for node in graph.nodes if node.op == "output"]
        assert len(output_node) == 1, "There should be only one output node"

        # Step 1. Cleaning up the graph before inserting the call_delegate node
        # Remove the original output node
        graph.erase_node(output_node[0])

        # Remove everything else except the inputs. Walk in reverse so that a
        # node's users are erased before the node itself.
        for node in reversed(graph.nodes):
            if node.op != "placeholder":
                graph.erase_node(node)

        # Find placeholders that are parameters or buffers, remove them from the main graph
        for node in graph.nodes:
            if node.op == "placeholder" and (
                node.name in inputs_to_buffers or node.name in inputs_to_parameters
            ):
                graph.erase_node(node)

        # Step 2. Start constructing the graph
        lowered_name = get_lowered_module_name(
            lowered_exported_program.graph_module, self
        )
        # Insert the lowered module to the graph module as an attribute
        lowered_node = graph.get_attr(lowered_name)

        # Insert a call_delegate node to the graph module, with arguments from the arg list
        delegate_node = graph.call_function(
            executorch_call_delegate, (lowered_node, *all_input_nodes)
        )
        # Get the output list. Since the output node is a tuple of list, like ([aten_mul_tensor, aten_add_tensor],)
        # We add some handling logic to get the list `[aten_mul_tensor, aten_add_tensor]` properly
        original_output_nodes = [
            node
            for node in self._original_exported_program.graph.nodes
            if node.op == "output"
        ][0].args[0]

        delegate_node.meta["spec"] = tuple(
            [make_spec(node.meta["val"]) for node in original_output_nodes]
        )
        delegate_node.meta["val"] = tuple(
            [node.meta["val"] for node in original_output_nodes]
        )

        # The getitem nodes that are going to be inserted to the lowered graph module
        getitem_nodes = []
        for i in range(len(original_output_nodes)):
            getitem_node = graph.call_function(
                operator.getitem,
                args=(delegate_node, i),
            )
            getitem_node.meta["val"] = delegate_node.meta["val"][i]
            getitem_nodes.append(getitem_node)
        graph.output(getitem_nodes)

        lowered_exported_program.graph_module.recompile()
        graph.lint()

        # Users output will be the get items nodes instead
        output_specs = [
            OutputSpec(
                kind=OutputKind.USER_OUTPUT,
                arg=TensorArgument(name=getitem_node.name),
                target=None,
            )
            for getitem_node in getitem_nodes
        ]
        # All data are consumed by the delegates so they should be removed from the state dict.
        # Each spec is named after the user input it describes. (Using a `node`
        # variable left over from the loops above would give every spec the
        # same, unrelated name.)
        input_specs = [
            InputSpec(
                kind=InputKind.USER_INPUT,
                arg=TensorArgument(name=user_input),
                target=None,
            )
            for user_input in lowered_exported_program.graph_signature.user_inputs
            if user_input not in inputs_to_parameters
            and user_input not in inputs_to_buffers
        ]

        # Double check the ExportedProgram data(especially everything except graph) is good
        exported_program = ExportedProgram(
            root=lowered_exported_program.graph_module,
            graph=graph,
            graph_signature=_get_updated_graph_signature(
                ExportGraphSignature(
                    input_specs=input_specs, output_specs=output_specs
                ),
                lowered_exported_program.graph_module,
            ),
            # TODO: May need to set lowered_exported_program.call_spec = CallSpec(None, None)
            # somewhere as we should pass it a list of tensors to the lowered module and output a
            # list of tensors. Putting call_spec=lowered_exported_program.call_spec is correct here as the
            # inputs/outputs to the toplevel program will be in the format of the eager module.
            state_dict={},  # None because all data are consumed by delegate
            range_constraints=lowered_exported_program.range_constraints,
            module_call_graph=lowered_exported_program.module_call_graph,
            example_inputs=None,
            verifiers=[lowered_exported_program.verifier],
        )
        if memory_planning is None:
            memory_planning = MemoryPlanningPass()
        exported_program = _transform(exported_program, SpecPropPass(), memory_planning)
        emitted_program = emit_program(
            exported_program, emit_stacktrace=emit_stacktrace
        ).program
        return emitted_program

    # Used to patch each delegated function with a call_delegate call
    def forward(
        self,
        *args: Value,
        **kwargs: Tuple[Value, ...],
    ) -> Value:
        """
        Calls `executorch_call_delegate` with this module and the positional
        arguments. Keyword arguments are accepted but not forwarded.
        """
        return executorch_call_delegate(self, *args)
345*523fa7a6SAndroid Build Coastguard Worker
346*523fa7a6SAndroid Build Coastguard Worker
347*523fa7a6SAndroid Build Coastguard Worker# TODO(zhxchen17) Try ExportPass
def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
    """
    Rewrites the args of the graph's output node into a single flat list.

    The last `output` node of `gm.graph` is located; if there is none the graph
    is left untouched. When the output is a single node whose `meta["val"]` is a
    list, one getitem node per element is inserted before the output node and
    used in its place. The resulting structure is then flattened with pytree and
    stored as the sole argument of the output node.
    """
    output_node = next(
        (candidate for candidate in reversed(gm.graph.nodes) if candidate.op == "output"),
        None,
    )
    if output_node is None:
        return

    with gm.graph.inserting_before(output_node):
        assert len(output_node.args) == 1
        result = output_node.args[0]
        if isinstance(result, torch.fx.Node):
            meta_val = result.meta.get("val")
            if isinstance(meta_val, list):
                # If a list is returned, in some cases it is represented as a
                # singular node, like `split_copy_tensor` but EXIR will return a
                # opened-up list like `[getitem1, getitem2]`
                result = [
                    torch.fx.Proxy(result)[index].node
                    for index in range(len(meta_val))
                ]

    flat_results, _ = pytree.tree_flatten(result)
    output_node.args = (flat_results,)
366*523fa7a6SAndroid Build Coastguard Worker
367*523fa7a6SAndroid Build Coastguard Worker
368*523fa7a6SAndroid Build Coastguard Workerdef arrange_graph_placeholders(
369*523fa7a6SAndroid Build Coastguard Worker    gm: torch.fx.GraphModule, owning_program: ExportedProgram
370*523fa7a6SAndroid Build Coastguard Worker) -> torch.fx.GraphModule:
371*523fa7a6SAndroid Build Coastguard Worker    """
372*523fa7a6SAndroid Build Coastguard Worker    Modifies the graph of the given graphmodule with one that contains the same nodes as the original,
373*523fa7a6SAndroid Build Coastguard Worker    but with placeholders in order of (Params + Buffers) (User Inputs)
374*523fa7a6SAndroid Build Coastguard Worker
375*523fa7a6SAndroid Build Coastguard Worker    This is used by the delegate api which disturbs the placeholder ordering when creating a submodule
376*523fa7a6SAndroid Build Coastguard Worker    from partitioned nodes
377*523fa7a6SAndroid Build Coastguard Worker
378*523fa7a6SAndroid Build Coastguard Worker    Args:
379*523fa7a6SAndroid Build Coastguard Worker        gm: The graph module that we want arranged
380*523fa7a6SAndroid Build Coastguard Worker        owning_program: ExportedProgram that the submodule (gm) belongs to
381*523fa7a6SAndroid Build Coastguard Worker
382*523fa7a6SAndroid Build Coastguard Worker    Returns:
383*523fa7a6SAndroid Build Coastguard Worker        The graph module in-placed arranged
384*523fa7a6SAndroid Build Coastguard Worker    """
385*523fa7a6SAndroid Build Coastguard Worker    new_graph = torch.fx.Graph()
386*523fa7a6SAndroid Build Coastguard Worker
387*523fa7a6SAndroid Build Coastguard Worker    node_map = {}  # mapping of nodes from old graph to new graph
388*523fa7a6SAndroid Build Coastguard Worker
389*523fa7a6SAndroid Build Coastguard Worker    graph_sign = owning_program.graph_signature
390*523fa7a6SAndroid Build Coastguard Worker
391*523fa7a6SAndroid Build Coastguard Worker    # Add all placeholders into the graph first:
392*523fa7a6SAndroid Build Coastguard Worker    param_nodes = []
393*523fa7a6SAndroid Build Coastguard Worker    buffer_nodes = []
394*523fa7a6SAndroid Build Coastguard Worker    input_nodes = []
395*523fa7a6SAndroid Build Coastguard Worker    for node in gm.graph.nodes:
396*523fa7a6SAndroid Build Coastguard Worker        if node.op != "placeholder":
397*523fa7a6SAndroid Build Coastguard Worker            continue
398*523fa7a6SAndroid Build Coastguard Worker
399*523fa7a6SAndroid Build Coastguard Worker        if node.name in graph_sign.inputs_to_parameters:
400*523fa7a6SAndroid Build Coastguard Worker            param_nodes.append(node)
401*523fa7a6SAndroid Build Coastguard Worker        elif node.name in graph_sign.inputs_to_buffers:
402*523fa7a6SAndroid Build Coastguard Worker            buffer_nodes.append(node)
403*523fa7a6SAndroid Build Coastguard Worker        else:
404*523fa7a6SAndroid Build Coastguard Worker            input_nodes.append(node)
405*523fa7a6SAndroid Build Coastguard Worker
406*523fa7a6SAndroid Build Coastguard Worker    for param_node in param_nodes:
407*523fa7a6SAndroid Build Coastguard Worker        new_node = new_graph.node_copy(param_node, lambda x: node_map[x])
408*523fa7a6SAndroid Build Coastguard Worker        node_map[param_node] = new_node
409*523fa7a6SAndroid Build Coastguard Worker    for buffer_node in buffer_nodes:
410*523fa7a6SAndroid Build Coastguard Worker        new_node = new_graph.node_copy(buffer_node, lambda x: node_map[x])
411*523fa7a6SAndroid Build Coastguard Worker        node_map[buffer_node] = new_node
412*523fa7a6SAndroid Build Coastguard Worker    for input_node in input_nodes:
413*523fa7a6SAndroid Build Coastguard Worker        new_node = new_graph.node_copy(input_node, lambda x: node_map[x])
414*523fa7a6SAndroid Build Coastguard Worker        node_map[input_node] = new_node
415*523fa7a6SAndroid Build Coastguard Worker
416*523fa7a6SAndroid Build Coastguard Worker    # Now add all the other nodes in order
417*523fa7a6SAndroid Build Coastguard Worker    for node in gm.graph.nodes:
418*523fa7a6SAndroid Build Coastguard Worker        if node.op == "placeholder":
419*523fa7a6SAndroid Build Coastguard Worker            continue
420*523fa7a6SAndroid Build Coastguard Worker
421*523fa7a6SAndroid Build Coastguard Worker        new_node = new_graph.node_copy(node, lambda x: node_map[x])
422*523fa7a6SAndroid Build Coastguard Worker        node_map[node] = new_node
423*523fa7a6SAndroid Build Coastguard Worker
424*523fa7a6SAndroid Build Coastguard Worker    # lint to ensure correctness
425*523fa7a6SAndroid Build Coastguard Worker    new_graph.lint()
426*523fa7a6SAndroid Build Coastguard Worker
427*523fa7a6SAndroid Build Coastguard Worker    new_graph._codegen = gm.graph._codegen
428*523fa7a6SAndroid Build Coastguard Worker    gm.graph = new_graph
429*523fa7a6SAndroid Build Coastguard Worker
430*523fa7a6SAndroid Build Coastguard Worker    return gm
431*523fa7a6SAndroid Build Coastguard Worker
432*523fa7a6SAndroid Build Coastguard Worker
433*523fa7a6SAndroid Build Coastguard Worker# TODO Don't regenerate new signature manually.
434*523fa7a6SAndroid Build Coastguard Workerdef _get_new_signature(  # noqa: C901
435*523fa7a6SAndroid Build Coastguard Worker    original_program: ExportedProgram,
436*523fa7a6SAndroid Build Coastguard Worker    gm: torch.fx.GraphModule,
437*523fa7a6SAndroid Build Coastguard Worker    call_module_node: torch.fx.Node,
438*523fa7a6SAndroid Build Coastguard Worker    tag: str,
439*523fa7a6SAndroid Build Coastguard Worker    is_submodule: bool = False,
440*523fa7a6SAndroid Build Coastguard Worker) -> Tuple[
441*523fa7a6SAndroid Build Coastguard Worker    ExportGraphSignature,
442*523fa7a6SAndroid Build Coastguard Worker    Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
443*523fa7a6SAndroid Build Coastguard Worker    Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
444*523fa7a6SAndroid Build Coastguard Worker    Dict[str, InputSpec],
445*523fa7a6SAndroid Build Coastguard Worker    Dict[str, OutputSpec],
446*523fa7a6SAndroid Build Coastguard Worker]:
447*523fa7a6SAndroid Build Coastguard Worker    """
448*523fa7a6SAndroid Build Coastguard Worker    Args:
449*523fa7a6SAndroid Build Coastguard Worker        original_program: The original program that we are paritioning
450*523fa7a6SAndroid Build Coastguard Worker        gm: The partitioned graph module.
451*523fa7a6SAndroid Build Coastguard Worker        call_module_node: The node in the original program that is calling the
452*523fa7a6SAndroid Build Coastguard Worker            partitioned graph module.
453*523fa7a6SAndroid Build Coastguard Worker        tag: The tag being used for this partitioned submodule. This is used to
454*523fa7a6SAndroid Build Coastguard Worker            tell if a particular parameter/buffer/constant node is being tagged,
455*523fa7a6SAndroid Build Coastguard Worker            aka consumed by the delegate.
456*523fa7a6SAndroid Build Coastguard Worker        is_submodule: True if we are currently partitioning inside of a
457*523fa7a6SAndroid Build Coastguard Worker            submodule (like cond's submodule). If we are inside of a submodule,
458*523fa7a6SAndroid Build Coastguard Worker            we do not care about consuming params/buffers.
459*523fa7a6SAndroid Build Coastguard Worker
460*523fa7a6SAndroid Build Coastguard Worker    Returns:
461*523fa7a6SAndroid Build Coastguard Worker
462*523fa7a6SAndroid Build Coastguard Worker        new_signature (ExportGraphSignature): The new signature for the
463*523fa7a6SAndroid Build Coastguard Worker            partitioned graph module.
464*523fa7a6SAndroid Build Coastguard Worker        new_state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]): The
465*523fa7a6SAndroid Build Coastguard Worker            new state dict containing the consumed params/buffers.
466*523fa7a6SAndroid Build Coastguard Worker        new_constants (Dict[str, Union[torch.Tensor, FakeScriptObject,
467*523fa7a6SAndroid Build Coastguard Worker            torch.ScriptObject]]): The new constants table containing the
468*523fa7a6SAndroid Build Coastguard Worker            consumed constants .
469*523fa7a6SAndroid Build Coastguard Worker        input_specs_to_delete (Dict[str, InputSpec]): The input specs that have
470*523fa7a6SAndroid Build Coastguard Worker            been consumed by the delegate (param/buffer input nodes) and should
471*523fa7a6SAndroid Build Coastguard Worker            be removed from the toplevel ExportedProgram.
472*523fa7a6SAndroid Build Coastguard Worker        output_specs_to_delete (Dict[str, InputSpec]): The output specs that have
473*523fa7a6SAndroid Build Coastguard Worker            been consumed by the delegate (buffer mutation nodes) and should be
474*523fa7a6SAndroid Build Coastguard Worker            removed from the toplevel ExportedProgram.
475*523fa7a6SAndroid Build Coastguard Worker    """
476*523fa7a6SAndroid Build Coastguard Worker    old_signature = original_program.graph_signature
477*523fa7a6SAndroid Build Coastguard Worker
478*523fa7a6SAndroid Build Coastguard Worker    input_specs = []
479*523fa7a6SAndroid Build Coastguard Worker    output_specs = []
480*523fa7a6SAndroid Build Coastguard Worker    input_specs_to_delete = {}
481*523fa7a6SAndroid Build Coastguard Worker    output_specs_to_delete = {}
482*523fa7a6SAndroid Build Coastguard Worker    new_state_dict = {}
483*523fa7a6SAndroid Build Coastguard Worker    new_constants = {}
484*523fa7a6SAndroid Build Coastguard Worker
485*523fa7a6SAndroid Build Coastguard Worker    # If we are within a submodule, we do not need to care about consuming
486*523fa7a6SAndroid Build Coastguard Worker    # parameter/buffers
487*523fa7a6SAndroid Build Coastguard Worker    input_node_to_sig: Dict[str, InputSpec] = (
488*523fa7a6SAndroid Build Coastguard Worker        {input_spec.arg.name: input_spec for input_spec in old_signature.input_specs}
489*523fa7a6SAndroid Build Coastguard Worker        if not is_submodule
490*523fa7a6SAndroid Build Coastguard Worker        else {}
491*523fa7a6SAndroid Build Coastguard Worker    )
492*523fa7a6SAndroid Build Coastguard Worker
493*523fa7a6SAndroid Build Coastguard Worker    toplevel_output_node_to_sig: Dict[str, List[OutputSpec]] = defaultdict(list)
494*523fa7a6SAndroid Build Coastguard Worker    if not is_submodule:
495*523fa7a6SAndroid Build Coastguard Worker        for output_spec in old_signature.output_specs:
496*523fa7a6SAndroid Build Coastguard Worker            toplevel_output_node_to_sig[output_spec.arg.name].append(output_spec)
497*523fa7a6SAndroid Build Coastguard Worker
498*523fa7a6SAndroid Build Coastguard Worker    for node in gm.graph.nodes:
499*523fa7a6SAndroid Build Coastguard Worker        if node.op == "placeholder":
500*523fa7a6SAndroid Build Coastguard Worker
501*523fa7a6SAndroid Build Coastguard Worker            if node.name not in input_node_to_sig:
502*523fa7a6SAndroid Build Coastguard Worker                input_specs.append(
503*523fa7a6SAndroid Build Coastguard Worker                    InputSpec(
504*523fa7a6SAndroid Build Coastguard Worker                        kind=InputKind.USER_INPUT,
505*523fa7a6SAndroid Build Coastguard Worker                        arg=TensorArgument(name=node.name),
506*523fa7a6SAndroid Build Coastguard Worker                        target=None,
507*523fa7a6SAndroid Build Coastguard Worker                    )
508*523fa7a6SAndroid Build Coastguard Worker                )
509*523fa7a6SAndroid Build Coastguard Worker                continue
510*523fa7a6SAndroid Build Coastguard Worker
511*523fa7a6SAndroid Build Coastguard Worker            orig_input_spec = input_node_to_sig[node.name]
512*523fa7a6SAndroid Build Coastguard Worker
513*523fa7a6SAndroid Build Coastguard Worker            if not isinstance(orig_input_spec.arg, TensorArgument):
514*523fa7a6SAndroid Build Coastguard Worker                input_specs.append(orig_input_spec)
515*523fa7a6SAndroid Build Coastguard Worker
516*523fa7a6SAndroid Build Coastguard Worker            elif node.meta.get("delegation_tag", None) == tag:
517*523fa7a6SAndroid Build Coastguard Worker                input_specs.append(orig_input_spec)
518*523fa7a6SAndroid Build Coastguard Worker
519*523fa7a6SAndroid Build Coastguard Worker                if orig_input_spec.kind == InputKind.USER_INPUT:
520*523fa7a6SAndroid Build Coastguard Worker                    continue
521*523fa7a6SAndroid Build Coastguard Worker
522*523fa7a6SAndroid Build Coastguard Worker                # The following input specs are all attributes that should be
523*523fa7a6SAndroid Build Coastguard Worker                # consumed by the delegate, so we want to remove it from the
524*523fa7a6SAndroid Build Coastguard Worker                # toplevel module input/output
525*523fa7a6SAndroid Build Coastguard Worker                input_specs_to_delete[node.name] = orig_input_spec
526*523fa7a6SAndroid Build Coastguard Worker
527*523fa7a6SAndroid Build Coastguard Worker                input_target = orig_input_spec.target
528*523fa7a6SAndroid Build Coastguard Worker                if input_target in original_program.state_dict:
529*523fa7a6SAndroid Build Coastguard Worker                    assert orig_input_spec.kind in (
530*523fa7a6SAndroid Build Coastguard Worker                        InputKind.PARAMETER,
531*523fa7a6SAndroid Build Coastguard Worker                        InputKind.BUFFER,
532*523fa7a6SAndroid Build Coastguard Worker                    )
533*523fa7a6SAndroid Build Coastguard Worker
534*523fa7a6SAndroid Build Coastguard Worker                    new_state_dict[input_target] = original_program.state_dict[
535*523fa7a6SAndroid Build Coastguard Worker                        input_target
536*523fa7a6SAndroid Build Coastguard Worker                    ]
537*523fa7a6SAndroid Build Coastguard Worker                elif input_target in original_program.constants:
538*523fa7a6SAndroid Build Coastguard Worker                    assert orig_input_spec.kind in (
539*523fa7a6SAndroid Build Coastguard Worker                        InputKind.CONSTANT_TENSOR,
540*523fa7a6SAndroid Build Coastguard Worker                        InputKind.CUSTOM_OBJ,
541*523fa7a6SAndroid Build Coastguard Worker                        InputKind.BUFFER,
542*523fa7a6SAndroid Build Coastguard Worker                    )
543*523fa7a6SAndroid Build Coastguard Worker
544*523fa7a6SAndroid Build Coastguard Worker                    new_constants[input_target] = original_program.constants[
545*523fa7a6SAndroid Build Coastguard Worker                        input_target
546*523fa7a6SAndroid Build Coastguard Worker                    ]
547*523fa7a6SAndroid Build Coastguard Worker                else:
548*523fa7a6SAndroid Build Coastguard Worker                    raise RuntimeError(f"Invalid input spec {orig_input_spec} received")
549*523fa7a6SAndroid Build Coastguard Worker
550*523fa7a6SAndroid Build Coastguard Worker            else:
551*523fa7a6SAndroid Build Coastguard Worker                input_specs.append(
552*523fa7a6SAndroid Build Coastguard Worker                    InputSpec(
553*523fa7a6SAndroid Build Coastguard Worker                        kind=InputKind.USER_INPUT,
554*523fa7a6SAndroid Build Coastguard Worker                        arg=TensorArgument(name=node.name),
555*523fa7a6SAndroid Build Coastguard Worker                        target=None,
556*523fa7a6SAndroid Build Coastguard Worker                    )
557*523fa7a6SAndroid Build Coastguard Worker                )
558*523fa7a6SAndroid Build Coastguard Worker
559*523fa7a6SAndroid Build Coastguard Worker        if node.op == "output":
560*523fa7a6SAndroid Build Coastguard Worker            buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list)
561*523fa7a6SAndroid Build Coastguard Worker            for user in call_module_node.users.keys():
562*523fa7a6SAndroid Build Coastguard Worker                if user.name in toplevel_output_node_to_sig:
563*523fa7a6SAndroid Build Coastguard Worker                    assert (
564*523fa7a6SAndroid Build Coastguard Worker                        user.op == "call_function" and user.target == operator.getitem
565*523fa7a6SAndroid Build Coastguard Worker                    ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
566*523fa7a6SAndroid Build Coastguard Worker                    getitem_idx = user.args[1]
567*523fa7a6SAndroid Build Coastguard Worker                    assert isinstance(
568*523fa7a6SAndroid Build Coastguard Worker                        getitem_idx, int
569*523fa7a6SAndroid Build Coastguard Worker                    ), f"Invalid getitem type: {type(getitem_idx)}"
570*523fa7a6SAndroid Build Coastguard Worker                    buffer_mutation_idxs[getitem_idx].extend(
571*523fa7a6SAndroid Build Coastguard Worker                        toplevel_output_node_to_sig[user.name]
572*523fa7a6SAndroid Build Coastguard Worker                    )
573*523fa7a6SAndroid Build Coastguard Worker
574*523fa7a6SAndroid Build Coastguard Worker            for i, output_node in enumerate(node.args[0]):
575*523fa7a6SAndroid Build Coastguard Worker                if i in buffer_mutation_idxs:
576*523fa7a6SAndroid Build Coastguard Worker                    assert isinstance(output_node, torch.fx.Node)
577*523fa7a6SAndroid Build Coastguard Worker                    orig_output_specs = buffer_mutation_idxs[i]
578*523fa7a6SAndroid Build Coastguard Worker
579*523fa7a6SAndroid Build Coastguard Worker                    if any(
580*523fa7a6SAndroid Build Coastguard Worker                        orig_output_spec.kind == OutputKind.BUFFER_MUTATION
581*523fa7a6SAndroid Build Coastguard Worker                        and orig_output_spec.target in new_state_dict
582*523fa7a6SAndroid Build Coastguard Worker                        for orig_output_spec in orig_output_specs
583*523fa7a6SAndroid Build Coastguard Worker                    ):
584*523fa7a6SAndroid Build Coastguard Worker                        # If the delegate wants to consume the buffer, then the
585*523fa7a6SAndroid Build Coastguard Worker                        # delegate should also consume the buffer mutation
586*523fa7a6SAndroid Build Coastguard Worker                        # (output spec would be a BUFFER_MUTATION).  Otherwise
587*523fa7a6SAndroid Build Coastguard Worker                        # the delegate will just return the result of the
588*523fa7a6SAndroid Build Coastguard Worker                        # mutation as a USER_OUTPUT.
589*523fa7a6SAndroid Build Coastguard Worker
590*523fa7a6SAndroid Build Coastguard Worker                        orig_output_spec = [
591*523fa7a6SAndroid Build Coastguard Worker                            orig_output_spec
592*523fa7a6SAndroid Build Coastguard Worker                            for orig_output_spec in orig_output_specs
593*523fa7a6SAndroid Build Coastguard Worker                            if orig_output_spec.kind == OutputKind.BUFFER_MUTATION
594*523fa7a6SAndroid Build Coastguard Worker                            and orig_output_spec.target in new_state_dict
595*523fa7a6SAndroid Build Coastguard Worker                        ][0]
596*523fa7a6SAndroid Build Coastguard Worker
597*523fa7a6SAndroid Build Coastguard Worker                        assert len(orig_output_specs) == 1, (
598*523fa7a6SAndroid Build Coastguard Worker                            f"Constant {orig_output_spec.target} was tagged to be "
599*523fa7a6SAndroid Build Coastguard Worker                            "consumed by the buffer, and was found to also contain "
600*523fa7a6SAndroid Build Coastguard Worker                            "a buffer mutation. However this buffer mutation node "
601*523fa7a6SAndroid Build Coastguard Worker                            "was found to also be used as other types of outputs "
602*523fa7a6SAndroid Build Coastguard Worker                            "which is currently not supported. Please file an "
603*523fa7a6SAndroid Build Coastguard Worker                            "issue on Github. \n\n"
604*523fa7a6SAndroid Build Coastguard Worker                            f"The toplevel program: {original_program}\n"
605*523fa7a6SAndroid Build Coastguard Worker                        )
606*523fa7a6SAndroid Build Coastguard Worker                        output_specs.append(
607*523fa7a6SAndroid Build Coastguard Worker                            OutputSpec(
608*523fa7a6SAndroid Build Coastguard Worker                                kind=OutputKind.BUFFER_MUTATION,
609*523fa7a6SAndroid Build Coastguard Worker                                arg=TensorArgument(name=output_node.name),
610*523fa7a6SAndroid Build Coastguard Worker                                target=orig_output_spec.target,
611*523fa7a6SAndroid Build Coastguard Worker                            )
612*523fa7a6SAndroid Build Coastguard Worker                        )
613*523fa7a6SAndroid Build Coastguard Worker                        output_specs_to_delete[orig_output_spec.arg.name] = (
614*523fa7a6SAndroid Build Coastguard Worker                            orig_output_spec
615*523fa7a6SAndroid Build Coastguard Worker                        )
616*523fa7a6SAndroid Build Coastguard Worker                    else:
617*523fa7a6SAndroid Build Coastguard Worker                        output_specs.append(
618*523fa7a6SAndroid Build Coastguard Worker                            OutputSpec(
619*523fa7a6SAndroid Build Coastguard Worker                                kind=OutputKind.USER_OUTPUT,
620*523fa7a6SAndroid Build Coastguard Worker                                arg=TensorArgument(name=output_node.name),
621*523fa7a6SAndroid Build Coastguard Worker                                target=None,
622*523fa7a6SAndroid Build Coastguard Worker                            )
623*523fa7a6SAndroid Build Coastguard Worker                        )
624*523fa7a6SAndroid Build Coastguard Worker
625*523fa7a6SAndroid Build Coastguard Worker                elif not isinstance(output_node, torch.fx.Node):
626*523fa7a6SAndroid Build Coastguard Worker                    output_specs.append(
627*523fa7a6SAndroid Build Coastguard Worker                        OutputSpec(
628*523fa7a6SAndroid Build Coastguard Worker                            kind=OutputKind.USER_OUTPUT,
629*523fa7a6SAndroid Build Coastguard Worker                            arg=ConstantArgument(name="", value=output_node),
630*523fa7a6SAndroid Build Coastguard Worker                            target=None,
631*523fa7a6SAndroid Build Coastguard Worker                        )
632*523fa7a6SAndroid Build Coastguard Worker                    )
633*523fa7a6SAndroid Build Coastguard Worker
634*523fa7a6SAndroid Build Coastguard Worker                else:
635*523fa7a6SAndroid Build Coastguard Worker                    output_specs.append(
636*523fa7a6SAndroid Build Coastguard Worker                        OutputSpec(
637*523fa7a6SAndroid Build Coastguard Worker                            kind=OutputKind.USER_OUTPUT,
638*523fa7a6SAndroid Build Coastguard Worker                            arg=TensorArgument(name=output_node.name),
639*523fa7a6SAndroid Build Coastguard Worker                            target=None,
640*523fa7a6SAndroid Build Coastguard Worker                        )
641*523fa7a6SAndroid Build Coastguard Worker                    )
642*523fa7a6SAndroid Build Coastguard Worker
643*523fa7a6SAndroid Build Coastguard Worker    new_signature = ExportGraphSignature(
644*523fa7a6SAndroid Build Coastguard Worker        input_specs=input_specs, output_specs=output_specs
645*523fa7a6SAndroid Build Coastguard Worker    )
646*523fa7a6SAndroid Build Coastguard Worker
647*523fa7a6SAndroid Build Coastguard Worker    return (
648*523fa7a6SAndroid Build Coastguard Worker        new_signature,
649*523fa7a6SAndroid Build Coastguard Worker        new_state_dict,
650*523fa7a6SAndroid Build Coastguard Worker        new_constants,
651*523fa7a6SAndroid Build Coastguard Worker        input_specs_to_delete,
652*523fa7a6SAndroid Build Coastguard Worker        output_specs_to_delete,
653*523fa7a6SAndroid Build Coastguard Worker    )
654*523fa7a6SAndroid Build Coastguard Worker
655*523fa7a6SAndroid Build Coastguard Worker
def create_exported_program_from_submodule(
    submodule: torch.fx.GraphModule,
    owning_program: ExportedProgram,
    tag: str,
    call_module_node: torch.fx.Node,
    is_submodule: bool,
) -> Tuple[ExportedProgram, Dict[str, InputSpec], Dict[str, OutputSpec]]:
    """
    Creates an ExportedProgram from the given submodule using the parameters and buffers
    from the top-level owning program

    Args:
        submodule: submodule to create an exported program from. Its
            placeholders are rearranged in place before the program is built.
        owning_program: exported program containing the parameters and buffers used within
            the submodule
        tag: The delegation tag of this partition, forwarded to
            `_get_new_signature`.
        call_module_node: The node in the owning program that calls the
            submodule, forwarded to `_get_new_signature`.
        is_submodule: Whether we are partitioning inside of a submodule,
            forwarded to `_get_new_signature`.

    Returns:
        The ExportedProgram created from submodule
        input_specs_to_delete (Dict[str, InputSpec]): The input specs that have
            been consumed by the delegate (param/buffer input nodes) and should
            be removed from the toplevel ExportedProgram.
        output_specs_to_delete (Dict[str, OutputSpec]): The output specs that
            have been consumed by the delegate (buffer mutation nodes) and
            should be removed from the toplevel ExportedProgram.
    """
    # Arrange the submodule's placeholders in order
    submodule = arrange_graph_placeholders(submodule, owning_program)

    # TODO: we probably need to arrange the outputs wrt buffer mutations.

    # Get updated graph signature
    (
        subgraph_signature,
        subgraph_state_dict,
        subgraph_constants,
        toplevel_input_specs_to_delete,
        toplevel_output_specs_to_delete,
    ) = _get_new_signature(
        owning_program, submodule, call_module_node, tag, is_submodule
    )

    # Only the tree specs are needed; inputs are described as (args, kwargs)
    # with empty kwargs.
    _, in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))
    _, out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)

    root_call_entry = ModuleCallEntry(
        "",
        ModuleCallSignature(inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec),
    )

    subgraph_program = ExportedProgram(
        root=submodule,
        graph=submodule.graph,
        graph_signature=subgraph_signature,
        state_dict=subgraph_state_dict,
        # Deep copy so the new program does not share the owning program's
        # range constraints object.
        range_constraints=copy.deepcopy(owning_program.range_constraints),
        module_call_graph=[root_call_entry],
        constants=subgraph_constants,
        verifiers=[owning_program.verifier],
    )

    return (
        subgraph_program,
        toplevel_input_specs_to_delete,
        toplevel_output_specs_to_delete,
    )
721*523fa7a6SAndroid Build Coastguard Worker
722*523fa7a6SAndroid Build Coastguard Worker
723*523fa7a6SAndroid Build Coastguard Workerdef create_submodule_from_nodes(
724*523fa7a6SAndroid Build Coastguard Worker    gm: torch.fx.GraphModule,
725*523fa7a6SAndroid Build Coastguard Worker    node_list: NodeList,
726*523fa7a6SAndroid Build Coastguard Worker    tag: str,
727*523fa7a6SAndroid Build Coastguard Worker    skip_legalize_graph: bool = False,
728*523fa7a6SAndroid Build Coastguard Worker) -> Tuple[torch.fx.GraphModule, torch.fx.Node]:
729*523fa7a6SAndroid Build Coastguard Worker    """
730*523fa7a6SAndroid Build Coastguard Worker    Modifies the given graph module in-place to separate out the given nodes
731*523fa7a6SAndroid Build Coastguard Worker    into a submodule. The given node_list should form a fully connected
732*523fa7a6SAndroid Build Coastguard Worker    subgraph.
733*523fa7a6SAndroid Build Coastguard Worker
734*523fa7a6SAndroid Build Coastguard Worker    Args:
735*523fa7a6SAndroid Build Coastguard Worker        gm: The graph module that we want to partition
736*523fa7a6SAndroid Build Coastguard Worker        node_list: A list of nodes that belong in the partition
737*523fa7a6SAndroid Build Coastguard Worker
738*523fa7a6SAndroid Build Coastguard Worker    Returns:
739*523fa7a6SAndroid Build Coastguard Worker        The submodule that has been partitioned, the call_module node in the
740*523fa7a6SAndroid Build Coastguard Worker        toplevel graph module calling the submodule
741*523fa7a6SAndroid Build Coastguard Worker    """
742*523fa7a6SAndroid Build Coastguard Worker    sorted_nodes = topo_sort(node_list)
743*523fa7a6SAndroid Build Coastguard Worker
744*523fa7a6SAndroid Build Coastguard Worker    submodule_name = "fused_" + tag
745*523fa7a6SAndroid Build Coastguard Worker    sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
746*523fa7a6SAndroid Build Coastguard Worker        gm, sorted_nodes, submodule_name
747*523fa7a6SAndroid Build Coastguard Worker    )
748*523fa7a6SAndroid Build Coastguard Worker
749*523fa7a6SAndroid Build Coastguard Worker    _fixup_output_node(sub_gm)
750*523fa7a6SAndroid Build Coastguard Worker
751*523fa7a6SAndroid Build Coastguard Worker    gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
752*523fa7a6SAndroid Build Coastguard Worker    submodule_node = None
753*523fa7a6SAndroid Build Coastguard Worker    for node in gm.graph.nodes:
754*523fa7a6SAndroid Build Coastguard Worker        if node.op == "call_module":
755*523fa7a6SAndroid Build Coastguard Worker            if node.target == submodule_name:
756*523fa7a6SAndroid Build Coastguard Worker                submodule_node = node
757*523fa7a6SAndroid Build Coastguard Worker            else:
758*523fa7a6SAndroid Build Coastguard Worker                raise RuntimeError(
759*523fa7a6SAndroid Build Coastguard Worker                    f"The submodule created with nodes {node_list} did not form \
760*523fa7a6SAndroid Build Coastguard Worker                    one fully contained subgraph. Check that these nodes form a \
761*523fa7a6SAndroid Build Coastguard Worker                    fully contained graph. Partitioned graph: {gm.graph}."
762*523fa7a6SAndroid Build Coastguard Worker                )
763*523fa7a6SAndroid Build Coastguard Worker
764*523fa7a6SAndroid Build Coastguard Worker    if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor):
765*523fa7a6SAndroid Build Coastguard Worker        # If the original output is a single tensor, it has been
766*523fa7a6SAndroid Build Coastguard Worker        # pytree.tree_flatten-ed to be a singleton list, so we want to replace
767*523fa7a6SAndroid Build Coastguard Worker        # all uses with a getitem call to the 0th index of the result
768*523fa7a6SAndroid Build Coastguard Worker        with gm.graph.inserting_after(submodule_node):
769*523fa7a6SAndroid Build Coastguard Worker            proxy_out = torch.fx.Proxy(submodule_node)[0].node  # type: ignore[index]
770*523fa7a6SAndroid Build Coastguard Worker            submodule_node.replace_all_uses_with(proxy_out)
771*523fa7a6SAndroid Build Coastguard Worker            proxy_out.meta["val"] = submodule_node.meta["val"]
772*523fa7a6SAndroid Build Coastguard Worker            # Reset the args since it was overwritten in the previous line
773*523fa7a6SAndroid Build Coastguard Worker            proxy_out.args = (submodule_node, 0)
774*523fa7a6SAndroid Build Coastguard Worker    else:
775*523fa7a6SAndroid Build Coastguard Worker        # fuse_as_graphmodule will automatically propagate the metadata of the
776*523fa7a6SAndroid Build Coastguard Worker        # partition's last node to the getitem nodes that appear after the
777*523fa7a6SAndroid Build Coastguard Worker        # call_module node. However, in the case of delegation we do not want
778*523fa7a6SAndroid Build Coastguard Worker        # these getitem nodes to contain irrelevant previous metadata
779*523fa7a6SAndroid Build Coastguard Worker        # (ex. source_fn, # nn_module_stack)
780*523fa7a6SAndroid Build Coastguard Worker        for user_node in submodule_node.users:
781*523fa7a6SAndroid Build Coastguard Worker            user_node.meta.pop("nn_module_stack", None)
782*523fa7a6SAndroid Build Coastguard Worker            user_node.meta.pop("source_fn_stack", None)
783*523fa7a6SAndroid Build Coastguard Worker
784*523fa7a6SAndroid Build Coastguard Worker    erase_nodes(gm, sorted_nodes)
785*523fa7a6SAndroid Build Coastguard Worker
786*523fa7a6SAndroid Build Coastguard Worker    # Topological sort original gm with newly created sub_gm
787*523fa7a6SAndroid Build Coastguard Worker    # TODO : T153794167 Get rid of support for skipping legalize graph in create_submodule_from_nodes
788*523fa7a6SAndroid Build Coastguard Worker    # once we transition to using fuse_by_partitions.
789*523fa7a6SAndroid Build Coastguard Worker    if not skip_legalize_graph:
790*523fa7a6SAndroid Build Coastguard Worker        legalize_graph(gm)
791*523fa7a6SAndroid Build Coastguard Worker
792*523fa7a6SAndroid Build Coastguard Worker    # Get the call_module node
793*523fa7a6SAndroid Build Coastguard Worker    submodule_node = None
794*523fa7a6SAndroid Build Coastguard Worker    for node in gm.graph.nodes:
795*523fa7a6SAndroid Build Coastguard Worker        if node.op == "call_module" and node.target == submodule_name:
796*523fa7a6SAndroid Build Coastguard Worker            submodule_node = node
797*523fa7a6SAndroid Build Coastguard Worker        elif node.op == "call_module":
798*523fa7a6SAndroid Build Coastguard Worker            raise RuntimeError(
799*523fa7a6SAndroid Build Coastguard Worker                f"The submodule created with nodes {node_list} did not form \
800*523fa7a6SAndroid Build Coastguard Worker                one fully contained subgraph. Check that these nodes form a \
801*523fa7a6SAndroid Build Coastguard Worker                fully contained graph. Partitioned graph: {gm.graph}."
802*523fa7a6SAndroid Build Coastguard Worker            )
803*523fa7a6SAndroid Build Coastguard Worker
804*523fa7a6SAndroid Build Coastguard Worker    assert (
805*523fa7a6SAndroid Build Coastguard Worker        submodule_node is not None
806*523fa7a6SAndroid Build Coastguard Worker    ), f"No submodule was created with the nodes {node_list} in the graph {gm.graph}"
807*523fa7a6SAndroid Build Coastguard Worker
808*523fa7a6SAndroid Build Coastguard Worker    return sub_gm, submodule_node
809*523fa7a6SAndroid Build Coastguard Worker
810*523fa7a6SAndroid Build Coastguard Worker
def get_lowered_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, LoweredBackendModule, torch.fx.Node]]:
    """
    Collects the lowered modules that are invoked directly from the given
    graph module's graph. Nested submodules are not searched.

    Args:
        graph_module: The graph module whose top-level graph is scanned for
            `executorch_call_delegate` calls.

    Returns:
        One tuple per delegate call, in graph order, holding (the name under
        which the lowered module is stored in the graph module, the lowered
        module itself, the fx node that called this lowered module).

    Raises:
        AssertionError: If the module resolved for a delegate call is not a
            LoweredBackendModule.
    """

    def _is_delegate_call(candidate: torch.fx.Node) -> bool:
        return (
            candidate.op == "call_function"
            and candidate.target == executorch_call_delegate
        )

    found = []
    for delegate_call in filter(_is_delegate_call, graph_module.graph.nodes):
        # The lowered module is always the first argument of the delegate call.
        submodule_name, submodule, calling_node = _get_submodule(
            graph_module, delegate_call, 0
        )
        assert isinstance(submodule, LoweredBackendModule)
        found.append((submodule_name, submodule, calling_node))
    return found
827*523fa7a6SAndroid Build Coastguard Worker
828*523fa7a6SAndroid Build Coastguard Worker
def get_lowered_backend_modules(
    graph_module: torch.fx.GraphModule,
) -> List[LoweredBackendModule]:
    """
    Returns the lowered backend modules that are called through
    `executorch_call_delegate` in the given graph (does not look into
    submodules).

    Args:
        graph_module: The graph module whose top-level graph is scanned for
            delegate calls.

    Returns:
        The lowered modules, one per delegate call, in graph order.
    """
    lowered_modules = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == executorch_call_delegate:
            # Resolve the module through the shared helper (the same one
            # get_lowered_submodules uses) instead of looking it up by the fx
            # node's *name*: node names are sanitized/uniquified by the graph
            # and are not guaranteed to match the attribute the get_attr node
            # actually targets.
            _, lowered_backend_module, _ = _get_submodule(graph_module, node, 0)
            lowered_modules.append(lowered_backend_module)

    return lowered_modules
842*523fa7a6SAndroid Build Coastguard Worker
843*523fa7a6SAndroid Build Coastguard Worker
844*523fa7a6SAndroid Build Coastguard Workerdef _unsafe_adjust_original_program(  # noqa: C901
845*523fa7a6SAndroid Build Coastguard Worker    original_program: ExportedProgram,
846*523fa7a6SAndroid Build Coastguard Worker    call_delegate_node: torch.fx.Node,
847*523fa7a6SAndroid Build Coastguard Worker    input_specs_to_delete: Dict[str, InputSpec],
848*523fa7a6SAndroid Build Coastguard Worker    output_specs_to_delete: Dict[str, OutputSpec],
849*523fa7a6SAndroid Build Coastguard Worker) -> None:
850*523fa7a6SAndroid Build Coastguard Worker    """
851*523fa7a6SAndroid Build Coastguard Worker    Directly modify the original exported program's signature and state dict
852*523fa7a6SAndroid Build Coastguard Worker    based on the consumed params/buffers in the delegate.
853*523fa7a6SAndroid Build Coastguard Worker    """
854*523fa7a6SAndroid Build Coastguard Worker    original_program._graph_signature.input_specs = [
855*523fa7a6SAndroid Build Coastguard Worker        input_spec
856*523fa7a6SAndroid Build Coastguard Worker        for input_spec in original_program.graph_signature.input_specs
857*523fa7a6SAndroid Build Coastguard Worker        if input_spec.arg.name not in input_specs_to_delete
858*523fa7a6SAndroid Build Coastguard Worker    ]
859*523fa7a6SAndroid Build Coastguard Worker
860*523fa7a6SAndroid Build Coastguard Worker    currently_used_targets: Set[str] = {
861*523fa7a6SAndroid Build Coastguard Worker        input_spec.target
862*523fa7a6SAndroid Build Coastguard Worker        for input_spec in original_program._graph_signature.input_specs
863*523fa7a6SAndroid Build Coastguard Worker        if input_spec.target is not None
864*523fa7a6SAndroid Build Coastguard Worker    }
865*523fa7a6SAndroid Build Coastguard Worker
866*523fa7a6SAndroid Build Coastguard Worker    original_program._graph_signature.output_specs = [
867*523fa7a6SAndroid Build Coastguard Worker        output_spec
868*523fa7a6SAndroid Build Coastguard Worker        for output_spec in original_program.graph_signature.output_specs
869*523fa7a6SAndroid Build Coastguard Worker        if output_spec.arg.name not in output_specs_to_delete
870*523fa7a6SAndroid Build Coastguard Worker    ]
871*523fa7a6SAndroid Build Coastguard Worker
872*523fa7a6SAndroid Build Coastguard Worker    # Delete all parameters/buffers consumed by the created exported program
873*523fa7a6SAndroid Build Coastguard Worker    # from the graph signature, state dict, constants table
874*523fa7a6SAndroid Build Coastguard Worker    for node in original_program.graph.nodes:
875*523fa7a6SAndroid Build Coastguard Worker        if node.op == "placeholder":
876*523fa7a6SAndroid Build Coastguard Worker            if node.name in input_specs_to_delete:
877*523fa7a6SAndroid Build Coastguard Worker                assert len(node.users) == 0
878*523fa7a6SAndroid Build Coastguard Worker                original_program.graph.erase_node(node)
879*523fa7a6SAndroid Build Coastguard Worker        else:
880*523fa7a6SAndroid Build Coastguard Worker            break
881*523fa7a6SAndroid Build Coastguard Worker
882*523fa7a6SAndroid Build Coastguard Worker    for input_spec in input_specs_to_delete.values():
883*523fa7a6SAndroid Build Coastguard Worker        input_target = input_spec.target
884*523fa7a6SAndroid Build Coastguard Worker        assert input_target is not None
885*523fa7a6SAndroid Build Coastguard Worker
886*523fa7a6SAndroid Build Coastguard Worker        if input_target in currently_used_targets:
887*523fa7a6SAndroid Build Coastguard Worker            continue
888*523fa7a6SAndroid Build Coastguard Worker
889*523fa7a6SAndroid Build Coastguard Worker        if input_spec.kind == InputKind.PARAMETER:
890*523fa7a6SAndroid Build Coastguard Worker            del original_program._state_dict[input_target]
891*523fa7a6SAndroid Build Coastguard Worker        elif input_spec.kind == InputKind.BUFFER:
892*523fa7a6SAndroid Build Coastguard Worker            if input_spec.persistent:
893*523fa7a6SAndroid Build Coastguard Worker                del original_program._state_dict[input_target]
894*523fa7a6SAndroid Build Coastguard Worker            else:
895*523fa7a6SAndroid Build Coastguard Worker                del original_program._constants[input_spec.target]
896*523fa7a6SAndroid Build Coastguard Worker        elif input_spec.kind == InputKind.CONSTANT_TENSOR:
897*523fa7a6SAndroid Build Coastguard Worker            del original_program._constants[input_spec.target]
898*523fa7a6SAndroid Build Coastguard Worker        else:
899*523fa7a6SAndroid Build Coastguard Worker            raise RuntimeError(f"Invalid input spec {input_spec} received")
900*523fa7a6SAndroid Build Coastguard Worker
901*523fa7a6SAndroid Build Coastguard Worker    # Delete buffer mutations from the output which were consumed by the delegate
902*523fa7a6SAndroid Build Coastguard Worker    toplevel_output_node = None
903*523fa7a6SAndroid Build Coastguard Worker    for node in reversed(original_program.graph.nodes):
904*523fa7a6SAndroid Build Coastguard Worker        if node.op == "output":
905*523fa7a6SAndroid Build Coastguard Worker            toplevel_output_node = node
906*523fa7a6SAndroid Build Coastguard Worker            break
907*523fa7a6SAndroid Build Coastguard Worker
908*523fa7a6SAndroid Build Coastguard Worker    assert toplevel_output_node is not None
909*523fa7a6SAndroid Build Coastguard Worker    assert (
910*523fa7a6SAndroid Build Coastguard Worker        len(toplevel_output_node.args) == 1
911*523fa7a6SAndroid Build Coastguard Worker    ), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}"
912*523fa7a6SAndroid Build Coastguard Worker
913*523fa7a6SAndroid Build Coastguard Worker    new_output_args = [
914*523fa7a6SAndroid Build Coastguard Worker        arg
915*523fa7a6SAndroid Build Coastguard Worker        for arg in toplevel_output_node.args[0]
916*523fa7a6SAndroid Build Coastguard Worker        if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete
917*523fa7a6SAndroid Build Coastguard Worker    ]
918*523fa7a6SAndroid Build Coastguard Worker    toplevel_output_node.args = (tuple(new_output_args),)
919*523fa7a6SAndroid Build Coastguard Worker
920*523fa7a6SAndroid Build Coastguard Worker    # Delete the buffer mutation getitem nodes
921*523fa7a6SAndroid Build Coastguard Worker    getitem_idxs: List[int] = []
922*523fa7a6SAndroid Build Coastguard Worker    user_nodes = list(call_delegate_node.users.keys())
923*523fa7a6SAndroid Build Coastguard Worker    for user in user_nodes:
924*523fa7a6SAndroid Build Coastguard Worker        if user.name in output_specs_to_delete:
925*523fa7a6SAndroid Build Coastguard Worker            assert (
926*523fa7a6SAndroid Build Coastguard Worker                user.op == "call_function" and user.target == operator.getitem
927*523fa7a6SAndroid Build Coastguard Worker            ), f"Invalid user {user}, node.op is {node.op} and node.target is {node.target}"
928*523fa7a6SAndroid Build Coastguard Worker            user_idx = user.args[1]
929*523fa7a6SAndroid Build Coastguard Worker            assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}"
930*523fa7a6SAndroid Build Coastguard Worker            getitem_idxs.append(user_idx)
931*523fa7a6SAndroid Build Coastguard Worker            original_program.graph.erase_node(user)
932*523fa7a6SAndroid Build Coastguard Worker
933*523fa7a6SAndroid Build Coastguard Worker    getitem_idxs.sort(reverse=True)
934*523fa7a6SAndroid Build Coastguard Worker
935*523fa7a6SAndroid Build Coastguard Worker    # Adjust all the getitem indices after the deleted getitems
936*523fa7a6SAndroid Build Coastguard Worker    user_nodes = list(call_delegate_node.users.keys())
937*523fa7a6SAndroid Build Coastguard Worker    for user in user_nodes:
938*523fa7a6SAndroid Build Coastguard Worker        assert user.op == "call_function" and user.target == operator.getitem
939*523fa7a6SAndroid Build Coastguard Worker        user_idx = user.args[1]
940*523fa7a6SAndroid Build Coastguard Worker        assert isinstance(user_idx, int)
941*523fa7a6SAndroid Build Coastguard Worker        for i, idx in enumerate(getitem_idxs):
942*523fa7a6SAndroid Build Coastguard Worker            if user_idx > idx:
943*523fa7a6SAndroid Build Coastguard Worker                user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
944*523fa7a6SAndroid Build Coastguard Worker                break
945*523fa7a6SAndroid Build Coastguard Worker
946*523fa7a6SAndroid Build Coastguard Worker    original_program._validate()
947