1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Workerimport copy 10*523fa7a6SAndroid Build Coastguard Workerimport operator 11*523fa7a6SAndroid Build Coastguard Workerfrom collections import defaultdict 12*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, Dict, List, Optional, Set, Tuple, Union 13*523fa7a6SAndroid Build Coastguard Worker 14*523fa7a6SAndroid Build Coastguard Workerimport torch 15*523fa7a6SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree 16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir._serialize import _serialize_pte_binary 17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.compile_spec_schema import CompileSpec 18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name 19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.emit import emit_program 20*523fa7a6SAndroid Build Coastguard Worker 21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.graph_module import _get_submodule 22*523fa7a6SAndroid Build Coastguard Worker 23*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.memory_planning_pass import MemoryPlanningPass 24*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass 25*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.schema import Program 
26*523fa7a6SAndroid Build Coastguard Worker 27*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tracer import Value 28*523fa7a6SAndroid Build Coastguard Workerfrom torch._library.fake_class_registry import FakeScriptObject 29*523fa7a6SAndroid Build Coastguard Worker 30*523fa7a6SAndroid Build Coastguard Workerfrom torch._subclasses import FakeTensor 31*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import ( 32*523fa7a6SAndroid Build Coastguard Worker ConstantArgument, 33*523fa7a6SAndroid Build Coastguard Worker ExportedProgram, 34*523fa7a6SAndroid Build Coastguard Worker ExportGraphSignature, 35*523fa7a6SAndroid Build Coastguard Worker InputKind, 36*523fa7a6SAndroid Build Coastguard Worker InputSpec, 37*523fa7a6SAndroid Build Coastguard Worker ModuleCallEntry, 38*523fa7a6SAndroid Build Coastguard Worker ModuleCallSignature, 39*523fa7a6SAndroid Build Coastguard Worker OutputKind, 40*523fa7a6SAndroid Build Coastguard Worker OutputSpec, 41*523fa7a6SAndroid Build Coastguard Worker TensorArgument, 42*523fa7a6SAndroid Build Coastguard Worker) 43*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.passes.utils.fuser_utils import ( 44*523fa7a6SAndroid Build Coastguard Worker erase_nodes, 45*523fa7a6SAndroid Build Coastguard Worker fuse_as_graphmodule, 46*523fa7a6SAndroid Build Coastguard Worker insert_subgm, 47*523fa7a6SAndroid Build Coastguard Worker legalize_graph, 48*523fa7a6SAndroid Build Coastguard Worker NodeList, 49*523fa7a6SAndroid Build Coastguard Worker topo_sort, 50*523fa7a6SAndroid Build Coastguard Worker) 51*523fa7a6SAndroid Build Coastguard Worker 52*523fa7a6SAndroid Build Coastguard Worker 53*523fa7a6SAndroid Build Coastguard Workerclass LoweredBackendModule(torch.nn.Module): 54*523fa7a6SAndroid Build Coastguard Worker """ 55*523fa7a6SAndroid Build Coastguard Worker A subclass of nn.Module that is generated for modules containing 56*523fa7a6SAndroid Build Coastguard Worker delegated functions. 
class LoweredBackendModule(torch.nn.Module):
    """
    A subclass of nn.Module that is generated for modules containing
    delegated functions. This can be created by calling `to_backend`.
    """

    _backend_id: str  # The backend's name
    _processed_bytes: bytes  # The delegate blobs created from backend.preprocess
    _compile_specs: List[
        CompileSpec
    ]  # A list of backend-specific objects with static metadata to configure the "compilation" process.
    _original_exported_program: ExportedProgram  # The original EXIR module

    def __init__(
        self,
        edge_program: ExportedProgram,
        backend_id: str,
        processed_bytes: bytes,
        compile_specs: List[CompileSpec],
    ) -> None:
        """
        Stores the original program together with the backend's output.

        Args:
            edge_program: The original EXIR program that was lowered.
            backend_id: The backend's name.
            processed_bytes: The delegate blob created from backend.preprocess.
            compile_specs: Backend-specific static metadata used to configure
                the "compilation" process.
        """
        super().__init__()
        self._original_exported_program = edge_program
        self._backend_id = backend_id
        self._processed_bytes = processed_bytes
        self._compile_specs = compile_specs

    # pyre-ignore
    def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
        """
        Returns a copy of this module.

        The graph module, graph, graph signature, range constraints, module
        call graph and verifier of the original program are deep-copied. The
        state dict and constants are shared with the original (not copied), as
        are the backend id and processed bytes. Compile specs are deep-copied
        using `memo`, and `meta` (if present) is shallow-copied.
        """
        original = self._original_exported_program

        # Copy exported program
        copied_program = ExportedProgram(
            root=copy.deepcopy(original.graph_module),
            graph=copy.deepcopy(original.graph),
            graph_signature=copy.deepcopy(original.graph_signature),
            state_dict=original.state_dict,
            range_constraints=copy.deepcopy(original.range_constraints),
            module_call_graph=copy.deepcopy(original.module_call_graph),
            constants=original.constants,
            verifiers=[copy.deepcopy(original.verifier)],
        )

        res = LoweredBackendModule(
            edge_program=copied_program,
            backend_id=self._backend_id,
            processed_bytes=self._processed_bytes,
            compile_specs=copy.deepcopy(self._compile_specs, memo),
        )
        # pyre-fixme[16]: `LoweredBackendModule` has no attribute `meta`.
        res.meta = copy.copy(getattr(self, "meta", {}))
        return res

    @property
    def backend_id(self) -> str:
        """
        Returns the backend's name.
        """
        return self._backend_id

    @property
    def processed_bytes(self) -> bytes:
        """
        Returns the delegate blob created from backend.preprocess
        """
        return self._processed_bytes

    @property
    def compile_specs(self) -> List[CompileSpec]:
        """
        Returns a list of backend-specific objects with static metadata to configure the "compilation" process.
        """
        return self._compile_specs

    @property
    def original_module(self) -> ExportedProgram:
        """
        Returns the original EXIR module
        """
        return self._original_exported_program

    # TODO(chenlai): consolidate the serialization config with serialize_to_flatbuffer api
    def buffer(
        self,
        extract_delegate_segments: bool = False,
        segment_alignment: int = 128,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
        memory_planning: Optional[MemoryPlanningPass] = None,
    ) -> bytes:
        """
        Returns a buffer containing the serialized ExecuTorch binary.

        The program is built by `self.program(memory_planning=memory_planning)`;
        the remaining arguments are forwarded unchanged to
        `_serialize_pte_binary`.
        """
        # TODO(T181463742): avoid calling bytes(..) which incurs large copies.
        out = bytes(
            _serialize_pte_binary(
                program=self.program(memory_planning=memory_planning),
                extract_delegate_segments=extract_delegate_segments,
                segment_alignment=segment_alignment,
                constant_tensor_alignment=constant_tensor_alignment,
                delegate_alignment=delegate_alignment,
            )
        )
        return out

    # TODO(chenlai): re-consider recapture instead of manually constructing the program because
    # the meta data construction is done manually.
    def program(
        self,
        emit_stacktrace: bool = False,
        memory_planning: Optional[MemoryPlanningPass] = None,
    ) -> Program:
        """
        Returns the object that represents the ExecuTorch binary before serialization.

        Args:
            emit_stacktrace: Forwarded to `emit_program`.
            memory_planning: The memory planning pass to run after
                `SpecPropPass`. A default `MemoryPlanningPass()` is used when
                this is None.
        """
        # Fix autodeps introduces cyclic dependencies:
        # program -> verifier -> lowered_backend_module -> program
        # @manual
        from executorch.exir.program._program import (
            _get_updated_graph_signature,
            _transform,
        )

        # Creates a new module based on the original module. The original module will
        # look something like following:
        #
        # opcode         name                 target            args                                        kwargs
        # -------------  -------------------  ----------------  ------------------------------------------  --------
        # placeholder    arg0_1               arg0_1            ()                                          {}
        # placeholder    arg1_1               arg1_1            ()                                          {}
        # call_function  aten_repeat_default  *                 (arg1_1, [4, 1])                            {}
        # call_function  aten_mul_tensor      *                 (aten_repeat_default, aten_repeat_default)  {}
        # call_function  aten_add_tensor      *                 (arg1_1, arg1_1)                            {}
        # output         output               output            ([aten_mul_tensor, aten_add_tensor],)       {}
        #
        # if the whole module is lowered, the resulting lowered module look like
        #
        # opcode         name                      target                       args                                kwargs
        # -------------  ------------------------  ---------------------------  ----------------------------------  --------
        # placeholder    arg0_1                    arg0_1                       ()                                  {}
        # placeholder    arg1_1                    arg1_1                       ()                                  {}
        # get_attr       lowered_module_0          lowered_module_0             ()                                  {}
        # call_function  executorch_call_delegate  executorch_call_delegate     (lowered_module_0, arg0_1, arg1_1)  {}
        # call_function  getitem                   <built-in function getitem>  (executorch_call_delegate, 0)       {}
        # call_function  getitem_1                 <built-in function getitem>  (executorch_call_delegate, 1)       {}
        # output         output_1                  output                       ([getitem, getitem_1],)             {}
        #
        # We'll remove all call_function nodes, insert an call_delegate node, inserting getitems nodes to get the result for call_delegate node
        # and return the list of getitems as the output

        # Work on a copy so the original exported program is left untouched.
        lowered_exported_program = copy.deepcopy(self._original_exported_program)
        graph = lowered_exported_program.graph
        graph_signature = lowered_exported_program.graph_signature

        # The real input nodes are the ones not buffer or parameter
        all_input_nodes = [
            node
            for node in graph.nodes
            if (
                node.op == "placeholder"
                and node.name not in graph_signature.inputs_to_buffers
                and node.name not in graph_signature.inputs_to_parameters
            )
        ]

        output_node = [node for node in graph.nodes if node.op == "output"]
        assert len(output_node) == 1, "There should be only one output node"

        # Step 1. Cleaning up the graph before inserting the call_delegate node
        # Remove the original output node
        graph.erase_node(output_node[0])

        # Remove everything else except the inputs. Walk in reverse so that a
        # node's users are erased before the node itself.
        for node in reversed(graph.nodes):
            if node.op != "placeholder":
                graph.erase_node(node)

        # Find placeholders that are parameters or buffers, remove them from the main graph
        for node in graph.nodes:
            if node.op == "placeholder" and (
                node.name in graph_signature.inputs_to_buffers
                or node.name in graph_signature.inputs_to_parameters
            ):
                graph.erase_node(node)

        # Step 2. Start constructing the graph
        lowered_name = get_lowered_module_name(
            lowered_exported_program.graph_module, self
        )
        # Insert the lowered module to the graph module as an attribute
        lowered_node = graph.get_attr(lowered_name)

        # Insert a call_delegate node to the graph module, with arguments from the arg list
        delegate_node = graph.call_function(
            executorch_call_delegate, (lowered_node, *all_input_nodes)
        )
        # Get the output list. Since the output node is a tuple of list, like ([aten_mul_tensor, aten_add_tensor],)
        # We add some handling logic to get the list `[aten_mul_tensor, aten_add_tensor]` properly.
        # These are read from the *original* program because the copied output
        # node has already been erased above.
        original_output_nodes = next(
            node
            for node in self._original_exported_program.graph.nodes
            if node.op == "output"
        ).args[0]

        delegate_node.meta["spec"] = tuple(
            make_spec(node.meta["val"]) for node in original_output_nodes
        )
        delegate_node.meta["val"] = tuple(
            node.meta["val"] for node in original_output_nodes
        )

        # The getitem nodes that are going to be inserted to the lowered graph module
        getitem_nodes = []
        for i, val in enumerate(delegate_node.meta["val"]):
            getitem_node = graph.call_function(
                operator.getitem,
                args=(delegate_node, i),
            )
            getitem_node.meta["val"] = val
            getitem_nodes.append(getitem_node)
        graph.output(getitem_nodes)

        lowered_exported_program.graph_module.recompile()
        graph.lint()

        # Users output will be the get items nodes instead
        output_specs = [
            OutputSpec(
                kind=OutputKind.USER_OUTPUT,
                arg=TensorArgument(name=getitem_node.name),
                target=None,
            )
            for getitem_node in getitem_nodes
        ]
        # All data are consumed by the delegates so they should be removed from the state dict.
        inputs_to_parameters = graph_signature.inputs_to_parameters
        inputs_to_buffers = graph_signature.inputs_to_buffers
        # Each spec must be named after the user input it describes. Using the
        # `node` variable left over from the loops above would give every spec
        # the name of the last placeholder that was iterated.
        input_specs = [
            InputSpec(
                kind=InputKind.USER_INPUT,
                arg=TensorArgument(name=user_input),
                target=None,
            )
            for user_input in graph_signature.user_inputs
            if user_input not in inputs_to_parameters
            and user_input not in inputs_to_buffers
        ]

        # Double check the ExportedProgram data(especially everything except graph) is good
        exported_program = ExportedProgram(
            root=lowered_exported_program.graph_module,
            graph=graph,
            graph_signature=_get_updated_graph_signature(
                ExportGraphSignature(
                    input_specs=input_specs, output_specs=output_specs
                ),
                lowered_exported_program.graph_module,
            ),
            # TODO: May need to set lowered_exported_program.call_spec = CallSpec(None, None)
            # somewhere as we should pass it a list of tensors to the lowered module and output a
            # list of tensors. Putting call_spec=lowered_exported_program.call_spec is correct here as the
            # inputs/outputs to the toplevel program will be in the format of the eager module.
            state_dict={},  # None because all data are consumed by delegate
            range_constraints=lowered_exported_program.range_constraints,
            module_call_graph=lowered_exported_program.module_call_graph,
            example_inputs=None,
            verifiers=[lowered_exported_program.verifier],
        )
        if memory_planning is None:
            memory_planning = MemoryPlanningPass()
        exported_program = _transform(exported_program, SpecPropPass(), memory_planning)
        emitted_program = emit_program(
            exported_program, emit_stacktrace=emit_stacktrace
        ).program
        return emitted_program

    # Used to patch each delegated function with a call_delegate call
    # @staticmethod
    def forward(
        self,
        *args: Value,
        **kwargs: Tuple[Value, ...],
    ) -> Value:
        """
        Calls `executorch_call_delegate` with this module and the positional
        arguments. Keyword arguments are accepted but not forwarded.
        """
        return executorch_call_delegate(self, *args)
# TODO(zhxchen17) Try ExportPass
def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
    """
    Rewrites the output node of `gm` in place so that its single argument is a
    flat list of values.

    If the output is a single node whose `meta["val"]` is a list, it is first
    expanded into one getitem node per list element (inserted right before the
    output node). The result is then flattened with pytree. Only the last
    output node found is processed; nothing happens when there is none.
    """
    graph = gm.graph
    output_node = next(
        (candidate for candidate in reversed(graph.nodes) if candidate.op == "output"),
        None,
    )
    if output_node is None:
        return

    assert len(output_node.args) == 1
    result = output_node.args[0]

    if isinstance(result, torch.fx.Node):
        fake_value = result.meta.get("val")
        if isinstance(fake_value, list):
            # If a list is returned, in some cases it is represented as a
            # singular node, like `split_copy_tensor` but EXIR will return a
            # opened-up list like `[getitem1, getitem2]`
            with graph.inserting_before(output_node):
                source = result
                result = [
                    torch.fx.Proxy(source)[index].node
                    for index in range(len(fake_value))
                ]

    flat_outputs, _ = pytree.tree_flatten(result)
    output_node.args = (flat_outputs,)
torch.fx.GraphModule, owning_program: ExportedProgram 370*523fa7a6SAndroid Build Coastguard Worker) -> torch.fx.GraphModule: 371*523fa7a6SAndroid Build Coastguard Worker """ 372*523fa7a6SAndroid Build Coastguard Worker Modifies the graph of the given graphmodule with one that contains the same nodes as the original, 373*523fa7a6SAndroid Build Coastguard Worker but with placeholders in order of (Params + Buffers) (User Inputs) 374*523fa7a6SAndroid Build Coastguard Worker 375*523fa7a6SAndroid Build Coastguard Worker This is used by the delegate api which disturbs the placeholder ordering when creating a submodule 376*523fa7a6SAndroid Build Coastguard Worker from partitioned nodes 377*523fa7a6SAndroid Build Coastguard Worker 378*523fa7a6SAndroid Build Coastguard Worker Args: 379*523fa7a6SAndroid Build Coastguard Worker gm: The graph module that we want arranged 380*523fa7a6SAndroid Build Coastguard Worker owning_program: ExportedProgram that the submodule (gm) belongs to 381*523fa7a6SAndroid Build Coastguard Worker 382*523fa7a6SAndroid Build Coastguard Worker Returns: 383*523fa7a6SAndroid Build Coastguard Worker The graph module in-placed arranged 384*523fa7a6SAndroid Build Coastguard Worker """ 385*523fa7a6SAndroid Build Coastguard Worker new_graph = torch.fx.Graph() 386*523fa7a6SAndroid Build Coastguard Worker 387*523fa7a6SAndroid Build Coastguard Worker node_map = {} # mapping of nodes from old graph to new graph 388*523fa7a6SAndroid Build Coastguard Worker 389*523fa7a6SAndroid Build Coastguard Worker graph_sign = owning_program.graph_signature 390*523fa7a6SAndroid Build Coastguard Worker 391*523fa7a6SAndroid Build Coastguard Worker # Add all placeholders into the graph first: 392*523fa7a6SAndroid Build Coastguard Worker param_nodes = [] 393*523fa7a6SAndroid Build Coastguard Worker buffer_nodes = [] 394*523fa7a6SAndroid Build Coastguard Worker input_nodes = [] 395*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes: 396*523fa7a6SAndroid Build 
Coastguard Worker if node.op != "placeholder": 397*523fa7a6SAndroid Build Coastguard Worker continue 398*523fa7a6SAndroid Build Coastguard Worker 399*523fa7a6SAndroid Build Coastguard Worker if node.name in graph_sign.inputs_to_parameters: 400*523fa7a6SAndroid Build Coastguard Worker param_nodes.append(node) 401*523fa7a6SAndroid Build Coastguard Worker elif node.name in graph_sign.inputs_to_buffers: 402*523fa7a6SAndroid Build Coastguard Worker buffer_nodes.append(node) 403*523fa7a6SAndroid Build Coastguard Worker else: 404*523fa7a6SAndroid Build Coastguard Worker input_nodes.append(node) 405*523fa7a6SAndroid Build Coastguard Worker 406*523fa7a6SAndroid Build Coastguard Worker for param_node in param_nodes: 407*523fa7a6SAndroid Build Coastguard Worker new_node = new_graph.node_copy(param_node, lambda x: node_map[x]) 408*523fa7a6SAndroid Build Coastguard Worker node_map[param_node] = new_node 409*523fa7a6SAndroid Build Coastguard Worker for buffer_node in buffer_nodes: 410*523fa7a6SAndroid Build Coastguard Worker new_node = new_graph.node_copy(buffer_node, lambda x: node_map[x]) 411*523fa7a6SAndroid Build Coastguard Worker node_map[buffer_node] = new_node 412*523fa7a6SAndroid Build Coastguard Worker for input_node in input_nodes: 413*523fa7a6SAndroid Build Coastguard Worker new_node = new_graph.node_copy(input_node, lambda x: node_map[x]) 414*523fa7a6SAndroid Build Coastguard Worker node_map[input_node] = new_node 415*523fa7a6SAndroid Build Coastguard Worker 416*523fa7a6SAndroid Build Coastguard Worker # Now add all the other nodes in order 417*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes: 418*523fa7a6SAndroid Build Coastguard Worker if node.op == "placeholder": 419*523fa7a6SAndroid Build Coastguard Worker continue 420*523fa7a6SAndroid Build Coastguard Worker 421*523fa7a6SAndroid Build Coastguard Worker new_node = new_graph.node_copy(node, lambda x: node_map[x]) 422*523fa7a6SAndroid Build Coastguard Worker node_map[node] = new_node 
423*523fa7a6SAndroid Build Coastguard Worker 424*523fa7a6SAndroid Build Coastguard Worker # lint to ensure correctness 425*523fa7a6SAndroid Build Coastguard Worker new_graph.lint() 426*523fa7a6SAndroid Build Coastguard Worker 427*523fa7a6SAndroid Build Coastguard Worker new_graph._codegen = gm.graph._codegen 428*523fa7a6SAndroid Build Coastguard Worker gm.graph = new_graph 429*523fa7a6SAndroid Build Coastguard Worker 430*523fa7a6SAndroid Build Coastguard Worker return gm 431*523fa7a6SAndroid Build Coastguard Worker 432*523fa7a6SAndroid Build Coastguard Worker 433*523fa7a6SAndroid Build Coastguard Worker# TODO Don't regenerate new signature manually. 434*523fa7a6SAndroid Build Coastguard Workerdef _get_new_signature( # noqa: C901 435*523fa7a6SAndroid Build Coastguard Worker original_program: ExportedProgram, 436*523fa7a6SAndroid Build Coastguard Worker gm: torch.fx.GraphModule, 437*523fa7a6SAndroid Build Coastguard Worker call_module_node: torch.fx.Node, 438*523fa7a6SAndroid Build Coastguard Worker tag: str, 439*523fa7a6SAndroid Build Coastguard Worker is_submodule: bool = False, 440*523fa7a6SAndroid Build Coastguard Worker) -> Tuple[ 441*523fa7a6SAndroid Build Coastguard Worker ExportGraphSignature, 442*523fa7a6SAndroid Build Coastguard Worker Dict[str, Union[torch.Tensor, torch.nn.Parameter]], 443*523fa7a6SAndroid Build Coastguard Worker Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]], 444*523fa7a6SAndroid Build Coastguard Worker Dict[str, InputSpec], 445*523fa7a6SAndroid Build Coastguard Worker Dict[str, OutputSpec], 446*523fa7a6SAndroid Build Coastguard Worker]: 447*523fa7a6SAndroid Build Coastguard Worker """ 448*523fa7a6SAndroid Build Coastguard Worker Args: 449*523fa7a6SAndroid Build Coastguard Worker original_program: The original program that we are paritioning 450*523fa7a6SAndroid Build Coastguard Worker gm: The partitioned graph module. 
451*523fa7a6SAndroid Build Coastguard Worker call_module_node: The node in the original program that is calling the 452*523fa7a6SAndroid Build Coastguard Worker partitioned graph module. 453*523fa7a6SAndroid Build Coastguard Worker tag: The tag being used for this partitioned submodule. This is used to 454*523fa7a6SAndroid Build Coastguard Worker tell if a particular parameter/buffer/constant node is being tagged, 455*523fa7a6SAndroid Build Coastguard Worker aka consumed by the delegate. 456*523fa7a6SAndroid Build Coastguard Worker is_submodule: True if we are currently partitioning inside of a 457*523fa7a6SAndroid Build Coastguard Worker submodule (like cond's submodule). If we are inside of a submodule, 458*523fa7a6SAndroid Build Coastguard Worker we do not care about consuming params/buffers. 459*523fa7a6SAndroid Build Coastguard Worker 460*523fa7a6SAndroid Build Coastguard Worker Returns: 461*523fa7a6SAndroid Build Coastguard Worker 462*523fa7a6SAndroid Build Coastguard Worker new_signature (ExportGraphSignature): The new signature for the 463*523fa7a6SAndroid Build Coastguard Worker partitioned graph module. 464*523fa7a6SAndroid Build Coastguard Worker new_state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]): The 465*523fa7a6SAndroid Build Coastguard Worker new state dict containing the consumed params/buffers. 466*523fa7a6SAndroid Build Coastguard Worker new_constants (Dict[str, Union[torch.Tensor, FakeScriptObject, 467*523fa7a6SAndroid Build Coastguard Worker torch.ScriptObject]]): The new constants table containing the 468*523fa7a6SAndroid Build Coastguard Worker consumed constants . 469*523fa7a6SAndroid Build Coastguard Worker input_specs_to_delete (Dict[str, InputSpec]): The input specs that have 470*523fa7a6SAndroid Build Coastguard Worker been consumed by the delegate (param/buffer input nodes) and should 471*523fa7a6SAndroid Build Coastguard Worker be removed from the toplevel ExportedProgram. 
472*523fa7a6SAndroid Build Coastguard Worker output_specs_to_delete (Dict[str, InputSpec]): The output specs that have 473*523fa7a6SAndroid Build Coastguard Worker been consumed by the delegate (buffer mutation nodes) and should be 474*523fa7a6SAndroid Build Coastguard Worker removed from the toplevel ExportedProgram. 475*523fa7a6SAndroid Build Coastguard Worker """ 476*523fa7a6SAndroid Build Coastguard Worker old_signature = original_program.graph_signature 477*523fa7a6SAndroid Build Coastguard Worker 478*523fa7a6SAndroid Build Coastguard Worker input_specs = [] 479*523fa7a6SAndroid Build Coastguard Worker output_specs = [] 480*523fa7a6SAndroid Build Coastguard Worker input_specs_to_delete = {} 481*523fa7a6SAndroid Build Coastguard Worker output_specs_to_delete = {} 482*523fa7a6SAndroid Build Coastguard Worker new_state_dict = {} 483*523fa7a6SAndroid Build Coastguard Worker new_constants = {} 484*523fa7a6SAndroid Build Coastguard Worker 485*523fa7a6SAndroid Build Coastguard Worker # If we are within a submodule, we do not need to care about consuming 486*523fa7a6SAndroid Build Coastguard Worker # parameter/buffers 487*523fa7a6SAndroid Build Coastguard Worker input_node_to_sig: Dict[str, InputSpec] = ( 488*523fa7a6SAndroid Build Coastguard Worker {input_spec.arg.name: input_spec for input_spec in old_signature.input_specs} 489*523fa7a6SAndroid Build Coastguard Worker if not is_submodule 490*523fa7a6SAndroid Build Coastguard Worker else {} 491*523fa7a6SAndroid Build Coastguard Worker ) 492*523fa7a6SAndroid Build Coastguard Worker 493*523fa7a6SAndroid Build Coastguard Worker toplevel_output_node_to_sig: Dict[str, List[OutputSpec]] = defaultdict(list) 494*523fa7a6SAndroid Build Coastguard Worker if not is_submodule: 495*523fa7a6SAndroid Build Coastguard Worker for output_spec in old_signature.output_specs: 496*523fa7a6SAndroid Build Coastguard Worker toplevel_output_node_to_sig[output_spec.arg.name].append(output_spec) 497*523fa7a6SAndroid Build Coastguard Worker 
498*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes: 499*523fa7a6SAndroid Build Coastguard Worker if node.op == "placeholder": 500*523fa7a6SAndroid Build Coastguard Worker 501*523fa7a6SAndroid Build Coastguard Worker if node.name not in input_node_to_sig: 502*523fa7a6SAndroid Build Coastguard Worker input_specs.append( 503*523fa7a6SAndroid Build Coastguard Worker InputSpec( 504*523fa7a6SAndroid Build Coastguard Worker kind=InputKind.USER_INPUT, 505*523fa7a6SAndroid Build Coastguard Worker arg=TensorArgument(name=node.name), 506*523fa7a6SAndroid Build Coastguard Worker target=None, 507*523fa7a6SAndroid Build Coastguard Worker ) 508*523fa7a6SAndroid Build Coastguard Worker ) 509*523fa7a6SAndroid Build Coastguard Worker continue 510*523fa7a6SAndroid Build Coastguard Worker 511*523fa7a6SAndroid Build Coastguard Worker orig_input_spec = input_node_to_sig[node.name] 512*523fa7a6SAndroid Build Coastguard Worker 513*523fa7a6SAndroid Build Coastguard Worker if not isinstance(orig_input_spec.arg, TensorArgument): 514*523fa7a6SAndroid Build Coastguard Worker input_specs.append(orig_input_spec) 515*523fa7a6SAndroid Build Coastguard Worker 516*523fa7a6SAndroid Build Coastguard Worker elif node.meta.get("delegation_tag", None) == tag: 517*523fa7a6SAndroid Build Coastguard Worker input_specs.append(orig_input_spec) 518*523fa7a6SAndroid Build Coastguard Worker 519*523fa7a6SAndroid Build Coastguard Worker if orig_input_spec.kind == InputKind.USER_INPUT: 520*523fa7a6SAndroid Build Coastguard Worker continue 521*523fa7a6SAndroid Build Coastguard Worker 522*523fa7a6SAndroid Build Coastguard Worker # The following input specs are all attributes that should be 523*523fa7a6SAndroid Build Coastguard Worker # consumed by the delegate, so we want to remove it from the 524*523fa7a6SAndroid Build Coastguard Worker # toplevel module input/output 525*523fa7a6SAndroid Build Coastguard Worker input_specs_to_delete[node.name] = orig_input_spec 526*523fa7a6SAndroid Build 
Coastguard Worker 527*523fa7a6SAndroid Build Coastguard Worker input_target = orig_input_spec.target 528*523fa7a6SAndroid Build Coastguard Worker if input_target in original_program.state_dict: 529*523fa7a6SAndroid Build Coastguard Worker assert orig_input_spec.kind in ( 530*523fa7a6SAndroid Build Coastguard Worker InputKind.PARAMETER, 531*523fa7a6SAndroid Build Coastguard Worker InputKind.BUFFER, 532*523fa7a6SAndroid Build Coastguard Worker ) 533*523fa7a6SAndroid Build Coastguard Worker 534*523fa7a6SAndroid Build Coastguard Worker new_state_dict[input_target] = original_program.state_dict[ 535*523fa7a6SAndroid Build Coastguard Worker input_target 536*523fa7a6SAndroid Build Coastguard Worker ] 537*523fa7a6SAndroid Build Coastguard Worker elif input_target in original_program.constants: 538*523fa7a6SAndroid Build Coastguard Worker assert orig_input_spec.kind in ( 539*523fa7a6SAndroid Build Coastguard Worker InputKind.CONSTANT_TENSOR, 540*523fa7a6SAndroid Build Coastguard Worker InputKind.CUSTOM_OBJ, 541*523fa7a6SAndroid Build Coastguard Worker InputKind.BUFFER, 542*523fa7a6SAndroid Build Coastguard Worker ) 543*523fa7a6SAndroid Build Coastguard Worker 544*523fa7a6SAndroid Build Coastguard Worker new_constants[input_target] = original_program.constants[ 545*523fa7a6SAndroid Build Coastguard Worker input_target 546*523fa7a6SAndroid Build Coastguard Worker ] 547*523fa7a6SAndroid Build Coastguard Worker else: 548*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError(f"Invalid input spec {orig_input_spec} received") 549*523fa7a6SAndroid Build Coastguard Worker 550*523fa7a6SAndroid Build Coastguard Worker else: 551*523fa7a6SAndroid Build Coastguard Worker input_specs.append( 552*523fa7a6SAndroid Build Coastguard Worker InputSpec( 553*523fa7a6SAndroid Build Coastguard Worker kind=InputKind.USER_INPUT, 554*523fa7a6SAndroid Build Coastguard Worker arg=TensorArgument(name=node.name), 555*523fa7a6SAndroid Build Coastguard Worker target=None, 556*523fa7a6SAndroid Build 
Coastguard Worker ) 557*523fa7a6SAndroid Build Coastguard Worker ) 558*523fa7a6SAndroid Build Coastguard Worker 559*523fa7a6SAndroid Build Coastguard Worker if node.op == "output": 560*523fa7a6SAndroid Build Coastguard Worker buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list) 561*523fa7a6SAndroid Build Coastguard Worker for user in call_module_node.users.keys(): 562*523fa7a6SAndroid Build Coastguard Worker if user.name in toplevel_output_node_to_sig: 563*523fa7a6SAndroid Build Coastguard Worker assert ( 564*523fa7a6SAndroid Build Coastguard Worker user.op == "call_function" and user.target == operator.getitem 565*523fa7a6SAndroid Build Coastguard Worker ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}" 566*523fa7a6SAndroid Build Coastguard Worker getitem_idx = user.args[1] 567*523fa7a6SAndroid Build Coastguard Worker assert isinstance( 568*523fa7a6SAndroid Build Coastguard Worker getitem_idx, int 569*523fa7a6SAndroid Build Coastguard Worker ), f"Invalid getitem type: {type(getitem_idx)}" 570*523fa7a6SAndroid Build Coastguard Worker buffer_mutation_idxs[getitem_idx].extend( 571*523fa7a6SAndroid Build Coastguard Worker toplevel_output_node_to_sig[user.name] 572*523fa7a6SAndroid Build Coastguard Worker ) 573*523fa7a6SAndroid Build Coastguard Worker 574*523fa7a6SAndroid Build Coastguard Worker for i, output_node in enumerate(node.args[0]): 575*523fa7a6SAndroid Build Coastguard Worker if i in buffer_mutation_idxs: 576*523fa7a6SAndroid Build Coastguard Worker assert isinstance(output_node, torch.fx.Node) 577*523fa7a6SAndroid Build Coastguard Worker orig_output_specs = buffer_mutation_idxs[i] 578*523fa7a6SAndroid Build Coastguard Worker 579*523fa7a6SAndroid Build Coastguard Worker if any( 580*523fa7a6SAndroid Build Coastguard Worker orig_output_spec.kind == OutputKind.BUFFER_MUTATION 581*523fa7a6SAndroid Build Coastguard Worker and orig_output_spec.target in new_state_dict 582*523fa7a6SAndroid Build Coastguard Worker for 
orig_output_spec in orig_output_specs 583*523fa7a6SAndroid Build Coastguard Worker ): 584*523fa7a6SAndroid Build Coastguard Worker # If the delegate wants to consume the buffer, then the 585*523fa7a6SAndroid Build Coastguard Worker # delegate should also consume the buffer mutation 586*523fa7a6SAndroid Build Coastguard Worker # (output spec would be a BUFFER_MUTATION). Otherwise 587*523fa7a6SAndroid Build Coastguard Worker # the delegate will just return the result of the 588*523fa7a6SAndroid Build Coastguard Worker # mutation as a USER_OUTPUT. 589*523fa7a6SAndroid Build Coastguard Worker 590*523fa7a6SAndroid Build Coastguard Worker orig_output_spec = [ 591*523fa7a6SAndroid Build Coastguard Worker orig_output_spec 592*523fa7a6SAndroid Build Coastguard Worker for orig_output_spec in orig_output_specs 593*523fa7a6SAndroid Build Coastguard Worker if orig_output_spec.kind == OutputKind.BUFFER_MUTATION 594*523fa7a6SAndroid Build Coastguard Worker and orig_output_spec.target in new_state_dict 595*523fa7a6SAndroid Build Coastguard Worker ][0] 596*523fa7a6SAndroid Build Coastguard Worker 597*523fa7a6SAndroid Build Coastguard Worker assert len(orig_output_specs) == 1, ( 598*523fa7a6SAndroid Build Coastguard Worker f"Constant {orig_output_spec.target} was tagged to be " 599*523fa7a6SAndroid Build Coastguard Worker "consumed by the buffer, and was found to also contain " 600*523fa7a6SAndroid Build Coastguard Worker "a buffer mutation. However this buffer mutation node " 601*523fa7a6SAndroid Build Coastguard Worker "was found to also be used as other types of outputs " 602*523fa7a6SAndroid Build Coastguard Worker "which is currently not supported. Please file an " 603*523fa7a6SAndroid Build Coastguard Worker "issue on Github. 
\n\n" 604*523fa7a6SAndroid Build Coastguard Worker f"The toplevel program: {original_program}\n" 605*523fa7a6SAndroid Build Coastguard Worker ) 606*523fa7a6SAndroid Build Coastguard Worker output_specs.append( 607*523fa7a6SAndroid Build Coastguard Worker OutputSpec( 608*523fa7a6SAndroid Build Coastguard Worker kind=OutputKind.BUFFER_MUTATION, 609*523fa7a6SAndroid Build Coastguard Worker arg=TensorArgument(name=output_node.name), 610*523fa7a6SAndroid Build Coastguard Worker target=orig_output_spec.target, 611*523fa7a6SAndroid Build Coastguard Worker ) 612*523fa7a6SAndroid Build Coastguard Worker ) 613*523fa7a6SAndroid Build Coastguard Worker output_specs_to_delete[orig_output_spec.arg.name] = ( 614*523fa7a6SAndroid Build Coastguard Worker orig_output_spec 615*523fa7a6SAndroid Build Coastguard Worker ) 616*523fa7a6SAndroid Build Coastguard Worker else: 617*523fa7a6SAndroid Build Coastguard Worker output_specs.append( 618*523fa7a6SAndroid Build Coastguard Worker OutputSpec( 619*523fa7a6SAndroid Build Coastguard Worker kind=OutputKind.USER_OUTPUT, 620*523fa7a6SAndroid Build Coastguard Worker arg=TensorArgument(name=output_node.name), 621*523fa7a6SAndroid Build Coastguard Worker target=None, 622*523fa7a6SAndroid Build Coastguard Worker ) 623*523fa7a6SAndroid Build Coastguard Worker ) 624*523fa7a6SAndroid Build Coastguard Worker 625*523fa7a6SAndroid Build Coastguard Worker elif not isinstance(output_node, torch.fx.Node): 626*523fa7a6SAndroid Build Coastguard Worker output_specs.append( 627*523fa7a6SAndroid Build Coastguard Worker OutputSpec( 628*523fa7a6SAndroid Build Coastguard Worker kind=OutputKind.USER_OUTPUT, 629*523fa7a6SAndroid Build Coastguard Worker arg=ConstantArgument(name="", value=output_node), 630*523fa7a6SAndroid Build Coastguard Worker target=None, 631*523fa7a6SAndroid Build Coastguard Worker ) 632*523fa7a6SAndroid Build Coastguard Worker ) 633*523fa7a6SAndroid Build Coastguard Worker 634*523fa7a6SAndroid Build Coastguard Worker else: 635*523fa7a6SAndroid 
Build Coastguard Worker output_specs.append( 636*523fa7a6SAndroid Build Coastguard Worker OutputSpec( 637*523fa7a6SAndroid Build Coastguard Worker kind=OutputKind.USER_OUTPUT, 638*523fa7a6SAndroid Build Coastguard Worker arg=TensorArgument(name=output_node.name), 639*523fa7a6SAndroid Build Coastguard Worker target=None, 640*523fa7a6SAndroid Build Coastguard Worker ) 641*523fa7a6SAndroid Build Coastguard Worker ) 642*523fa7a6SAndroid Build Coastguard Worker 643*523fa7a6SAndroid Build Coastguard Worker new_signature = ExportGraphSignature( 644*523fa7a6SAndroid Build Coastguard Worker input_specs=input_specs, output_specs=output_specs 645*523fa7a6SAndroid Build Coastguard Worker ) 646*523fa7a6SAndroid Build Coastguard Worker 647*523fa7a6SAndroid Build Coastguard Worker return ( 648*523fa7a6SAndroid Build Coastguard Worker new_signature, 649*523fa7a6SAndroid Build Coastguard Worker new_state_dict, 650*523fa7a6SAndroid Build Coastguard Worker new_constants, 651*523fa7a6SAndroid Build Coastguard Worker input_specs_to_delete, 652*523fa7a6SAndroid Build Coastguard Worker output_specs_to_delete, 653*523fa7a6SAndroid Build Coastguard Worker ) 654*523fa7a6SAndroid Build Coastguard Worker 655*523fa7a6SAndroid Build Coastguard Worker 656*523fa7a6SAndroid Build Coastguard Workerdef create_exported_program_from_submodule( 657*523fa7a6SAndroid Build Coastguard Worker submodule: torch.fx.GraphModule, 658*523fa7a6SAndroid Build Coastguard Worker owning_program: ExportedProgram, 659*523fa7a6SAndroid Build Coastguard Worker tag: str, 660*523fa7a6SAndroid Build Coastguard Worker call_module_node: torch.fx.Node, 661*523fa7a6SAndroid Build Coastguard Worker is_submodule: bool, 662*523fa7a6SAndroid Build Coastguard Worker) -> Tuple[ExportedProgram, Dict[str, InputSpec], Dict[str, OutputSpec]]: 663*523fa7a6SAndroid Build Coastguard Worker """ 664*523fa7a6SAndroid Build Coastguard Worker Creates an ExportedProgram from the given submodule using the parameters and buffers 
665*523fa7a6SAndroid Build Coastguard Worker from the top-level owning program 666*523fa7a6SAndroid Build Coastguard Worker 667*523fa7a6SAndroid Build Coastguard Worker Args: 668*523fa7a6SAndroid Build Coastguard Worker submodule: submodule to create and exported program from 669*523fa7a6SAndroid Build Coastguard Worker owning_program: exported program containing the parameters and buffers used within 670*523fa7a6SAndroid Build Coastguard Worker the submodule 671*523fa7a6SAndroid Build Coastguard Worker 672*523fa7a6SAndroid Build Coastguard Worker Returns: 673*523fa7a6SAndroid Build Coastguard Worker The ExportedProgram created from submodule 674*523fa7a6SAndroid Build Coastguard Worker input_specs_to_delete (Dict[str, InputSpec]): The input specs that have 675*523fa7a6SAndroid Build Coastguard Worker been consumed by the delegate (param/buffer input nodes) and should 676*523fa7a6SAndroid Build Coastguard Worker be removed from the toplevel ExportedProgram. 677*523fa7a6SAndroid Build Coastguard Worker output_specs_to_delete (Dict[str, InputSpec]): The output specs that have 678*523fa7a6SAndroid Build Coastguard Worker been consumed by the delegate (buffer mutation nodes) and should be 679*523fa7a6SAndroid Build Coastguard Worker removed from the toplevel ExportedProgram. 680*523fa7a6SAndroid Build Coastguard Worker """ 681*523fa7a6SAndroid Build Coastguard Worker # Arrange the submodule's placeholders in order 682*523fa7a6SAndroid Build Coastguard Worker submodule = arrange_graph_placeholders(submodule, owning_program) 683*523fa7a6SAndroid Build Coastguard Worker 684*523fa7a6SAndroid Build Coastguard Worker # TODO: we probably need to arrange the outputs wrt buffer mutations. 
685*523fa7a6SAndroid Build Coastguard Worker 686*523fa7a6SAndroid Build Coastguard Worker # Get updated graph signature 687*523fa7a6SAndroid Build Coastguard Worker ( 688*523fa7a6SAndroid Build Coastguard Worker subgraph_signature, 689*523fa7a6SAndroid Build Coastguard Worker subgraph_state_dict, 690*523fa7a6SAndroid Build Coastguard Worker subgraph_constants, 691*523fa7a6SAndroid Build Coastguard Worker toplevel_input_specs_to_delete, 692*523fa7a6SAndroid Build Coastguard Worker toplevel_output_specs_to_delete, 693*523fa7a6SAndroid Build Coastguard Worker ) = _get_new_signature( 694*523fa7a6SAndroid Build Coastguard Worker owning_program, submodule, call_module_node, tag, is_submodule 695*523fa7a6SAndroid Build Coastguard Worker ) 696*523fa7a6SAndroid Build Coastguard Worker 697*523fa7a6SAndroid Build Coastguard Worker in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1] 698*523fa7a6SAndroid Build Coastguard Worker out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1] 699*523fa7a6SAndroid Build Coastguard Worker 700*523fa7a6SAndroid Build Coastguard Worker return ( 701*523fa7a6SAndroid Build Coastguard Worker ExportedProgram( 702*523fa7a6SAndroid Build Coastguard Worker root=submodule, 703*523fa7a6SAndroid Build Coastguard Worker graph=submodule.graph, 704*523fa7a6SAndroid Build Coastguard Worker graph_signature=subgraph_signature, 705*523fa7a6SAndroid Build Coastguard Worker state_dict=subgraph_state_dict, 706*523fa7a6SAndroid Build Coastguard Worker range_constraints=copy.deepcopy(owning_program.range_constraints), 707*523fa7a6SAndroid Build Coastguard Worker module_call_graph=[ 708*523fa7a6SAndroid Build Coastguard Worker ModuleCallEntry( 709*523fa7a6SAndroid Build Coastguard Worker "", 710*523fa7a6SAndroid Build Coastguard Worker ModuleCallSignature( 711*523fa7a6SAndroid Build Coastguard Worker inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec 712*523fa7a6SAndroid Build Coastguard Worker ), 713*523fa7a6SAndroid 
Build Coastguard Worker ) 714*523fa7a6SAndroid Build Coastguard Worker ], 715*523fa7a6SAndroid Build Coastguard Worker constants=subgraph_constants, 716*523fa7a6SAndroid Build Coastguard Worker verifiers=[owning_program.verifier], 717*523fa7a6SAndroid Build Coastguard Worker ), 718*523fa7a6SAndroid Build Coastguard Worker toplevel_input_specs_to_delete, 719*523fa7a6SAndroid Build Coastguard Worker toplevel_output_specs_to_delete, 720*523fa7a6SAndroid Build Coastguard Worker ) 721*523fa7a6SAndroid Build Coastguard Worker 722*523fa7a6SAndroid Build Coastguard Worker 723*523fa7a6SAndroid Build Coastguard Workerdef create_submodule_from_nodes( 724*523fa7a6SAndroid Build Coastguard Worker gm: torch.fx.GraphModule, 725*523fa7a6SAndroid Build Coastguard Worker node_list: NodeList, 726*523fa7a6SAndroid Build Coastguard Worker tag: str, 727*523fa7a6SAndroid Build Coastguard Worker skip_legalize_graph: bool = False, 728*523fa7a6SAndroid Build Coastguard Worker) -> Tuple[torch.fx.GraphModule, torch.fx.Node]: 729*523fa7a6SAndroid Build Coastguard Worker """ 730*523fa7a6SAndroid Build Coastguard Worker Modifies the given graph module in-place to separate out the given nodes 731*523fa7a6SAndroid Build Coastguard Worker into a submodule. The given node_list should form a fully connected 732*523fa7a6SAndroid Build Coastguard Worker subgraph. 
733*523fa7a6SAndroid Build Coastguard Worker 734*523fa7a6SAndroid Build Coastguard Worker Args: 735*523fa7a6SAndroid Build Coastguard Worker gm: The graph module that we want to partition 736*523fa7a6SAndroid Build Coastguard Worker node_list: A list of nodes that belong in the partition 737*523fa7a6SAndroid Build Coastguard Worker 738*523fa7a6SAndroid Build Coastguard Worker Returns: 739*523fa7a6SAndroid Build Coastguard Worker The submodule that has been partitioned, the call_module node in the 740*523fa7a6SAndroid Build Coastguard Worker toplevel graph module calling the submodule 741*523fa7a6SAndroid Build Coastguard Worker """ 742*523fa7a6SAndroid Build Coastguard Worker sorted_nodes = topo_sort(node_list) 743*523fa7a6SAndroid Build Coastguard Worker 744*523fa7a6SAndroid Build Coastguard Worker submodule_name = "fused_" + tag 745*523fa7a6SAndroid Build Coastguard Worker sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule( 746*523fa7a6SAndroid Build Coastguard Worker gm, sorted_nodes, submodule_name 747*523fa7a6SAndroid Build Coastguard Worker ) 748*523fa7a6SAndroid Build Coastguard Worker 749*523fa7a6SAndroid Build Coastguard Worker _fixup_output_node(sub_gm) 750*523fa7a6SAndroid Build Coastguard Worker 751*523fa7a6SAndroid Build Coastguard Worker gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) 752*523fa7a6SAndroid Build Coastguard Worker submodule_node = None 753*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes: 754*523fa7a6SAndroid Build Coastguard Worker if node.op == "call_module": 755*523fa7a6SAndroid Build Coastguard Worker if node.target == submodule_name: 756*523fa7a6SAndroid Build Coastguard Worker submodule_node = node 757*523fa7a6SAndroid Build Coastguard Worker else: 758*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError( 759*523fa7a6SAndroid Build Coastguard Worker f"The submodule created with nodes {node_list} did not form \ 760*523fa7a6SAndroid Build Coastguard Worker one fully contained subgraph. 
Check that these nodes form a \ 761*523fa7a6SAndroid Build Coastguard Worker fully contained graph. Partitioned graph: {gm.graph}." 762*523fa7a6SAndroid Build Coastguard Worker ) 763*523fa7a6SAndroid Build Coastguard Worker 764*523fa7a6SAndroid Build Coastguard Worker if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor): 765*523fa7a6SAndroid Build Coastguard Worker # If the original output is a single tensor, it has been 766*523fa7a6SAndroid Build Coastguard Worker # pytree.tree_flatten-ed to be a singleton list, so we want to replace 767*523fa7a6SAndroid Build Coastguard Worker # all uses with a getitem call to the 0th index of the result 768*523fa7a6SAndroid Build Coastguard Worker with gm.graph.inserting_after(submodule_node): 769*523fa7a6SAndroid Build Coastguard Worker proxy_out = torch.fx.Proxy(submodule_node)[0].node # type: ignore[index] 770*523fa7a6SAndroid Build Coastguard Worker submodule_node.replace_all_uses_with(proxy_out) 771*523fa7a6SAndroid Build Coastguard Worker proxy_out.meta["val"] = submodule_node.meta["val"] 772*523fa7a6SAndroid Build Coastguard Worker # Reset the args since it was overwritten in the previous line 773*523fa7a6SAndroid Build Coastguard Worker proxy_out.args = (submodule_node, 0) 774*523fa7a6SAndroid Build Coastguard Worker else: 775*523fa7a6SAndroid Build Coastguard Worker # fuse_as_graphmodule will automatically propagate the metadata of the 776*523fa7a6SAndroid Build Coastguard Worker # partition's last node to the getitem nodes that appear after the 777*523fa7a6SAndroid Build Coastguard Worker # call_module node. However, in the case of delegation we do not want 778*523fa7a6SAndroid Build Coastguard Worker # these getitem nodes to contain irrelevant previous metadata 779*523fa7a6SAndroid Build Coastguard Worker # (ex. 
source_fn, # nn_module_stack) 780*523fa7a6SAndroid Build Coastguard Worker for user_node in submodule_node.users: 781*523fa7a6SAndroid Build Coastguard Worker user_node.meta.pop("nn_module_stack", None) 782*523fa7a6SAndroid Build Coastguard Worker user_node.meta.pop("source_fn_stack", None) 783*523fa7a6SAndroid Build Coastguard Worker 784*523fa7a6SAndroid Build Coastguard Worker erase_nodes(gm, sorted_nodes) 785*523fa7a6SAndroid Build Coastguard Worker 786*523fa7a6SAndroid Build Coastguard Worker # Topological sort original gm with newly created sub_gm 787*523fa7a6SAndroid Build Coastguard Worker # TODO : T153794167 Get rid of support for skipping legalize graph in create_submodule_from_nodes 788*523fa7a6SAndroid Build Coastguard Worker # once we transition to using fuse_by_partitions. 789*523fa7a6SAndroid Build Coastguard Worker if not skip_legalize_graph: 790*523fa7a6SAndroid Build Coastguard Worker legalize_graph(gm) 791*523fa7a6SAndroid Build Coastguard Worker 792*523fa7a6SAndroid Build Coastguard Worker # Get the call_module node 793*523fa7a6SAndroid Build Coastguard Worker submodule_node = None 794*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes: 795*523fa7a6SAndroid Build Coastguard Worker if node.op == "call_module" and node.target == submodule_name: 796*523fa7a6SAndroid Build Coastguard Worker submodule_node = node 797*523fa7a6SAndroid Build Coastguard Worker elif node.op == "call_module": 798*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError( 799*523fa7a6SAndroid Build Coastguard Worker f"The submodule created with nodes {node_list} did not form \ 800*523fa7a6SAndroid Build Coastguard Worker one fully contained subgraph. Check that these nodes form a \ 801*523fa7a6SAndroid Build Coastguard Worker fully contained graph. Partitioned graph: {gm.graph}." 
802*523fa7a6SAndroid Build Coastguard Worker ) 803*523fa7a6SAndroid Build Coastguard Worker 804*523fa7a6SAndroid Build Coastguard Worker assert ( 805*523fa7a6SAndroid Build Coastguard Worker submodule_node is not None 806*523fa7a6SAndroid Build Coastguard Worker ), f"No submodule was created with the nodes {node_list} in the graph {gm.graph}" 807*523fa7a6SAndroid Build Coastguard Worker 808*523fa7a6SAndroid Build Coastguard Worker return sub_gm, submodule_node 809*523fa7a6SAndroid Build Coastguard Worker 810*523fa7a6SAndroid Build Coastguard Worker 811*523fa7a6SAndroid Build Coastguard Workerdef get_lowered_submodules( 812*523fa7a6SAndroid Build Coastguard Worker graph_module: torch.fx.GraphModule, 813*523fa7a6SAndroid Build Coastguard Worker) -> List[Tuple[str, LoweredBackendModule, torch.fx.Node]]: 814*523fa7a6SAndroid Build Coastguard Worker """ 815*523fa7a6SAndroid Build Coastguard Worker Returns a list of lowered modules that are in the given graph (does not look 816*523fa7a6SAndroid Build Coastguard Worker into submodules). Specifically, the returned value is a list containing a 817*523fa7a6SAndroid Build Coastguard Worker tuple of (name of the lowered module that's stored in the graph module, the 818*523fa7a6SAndroid Build Coastguard Worker lowered module itself, and the fx node that called this lowered module). 
819*523fa7a6SAndroid Build Coastguard Worker """ 820*523fa7a6SAndroid Build Coastguard Worker lowered_submodules = [] 821*523fa7a6SAndroid Build Coastguard Worker for node in graph_module.graph.nodes: 822*523fa7a6SAndroid Build Coastguard Worker if node.op == "call_function" and node.target == executorch_call_delegate: 823*523fa7a6SAndroid Build Coastguard Worker name, module, node = _get_submodule(graph_module, node, 0) 824*523fa7a6SAndroid Build Coastguard Worker assert isinstance(module, LoweredBackendModule) 825*523fa7a6SAndroid Build Coastguard Worker lowered_submodules.append((name, module, node)) 826*523fa7a6SAndroid Build Coastguard Worker return lowered_submodules 827*523fa7a6SAndroid Build Coastguard Worker 828*523fa7a6SAndroid Build Coastguard Worker 829*523fa7a6SAndroid Build Coastguard Workerdef get_lowered_backend_modules( 830*523fa7a6SAndroid Build Coastguard Worker graph_module: torch.fx.GraphModule, 831*523fa7a6SAndroid Build Coastguard Worker) -> List[LoweredBackendModule]: 832*523fa7a6SAndroid Build Coastguard Worker """ 833*523fa7a6SAndroid Build Coastguard Worker Returns a list of exported programs which were lowered by backen delegates 834*523fa7a6SAndroid Build Coastguard Worker """ 835*523fa7a6SAndroid Build Coastguard Worker lowered_programs = [] 836*523fa7a6SAndroid Build Coastguard Worker for node in graph_module.graph.nodes: 837*523fa7a6SAndroid Build Coastguard Worker if node.op == "call_function" and node.target == executorch_call_delegate: 838*523fa7a6SAndroid Build Coastguard Worker lowered_backend_module = getattr(graph_module, node.args[0].name) 839*523fa7a6SAndroid Build Coastguard Worker lowered_programs.append(lowered_backend_module) 840*523fa7a6SAndroid Build Coastguard Worker 841*523fa7a6SAndroid Build Coastguard Worker return lowered_programs 842*523fa7a6SAndroid Build Coastguard Worker 843*523fa7a6SAndroid Build Coastguard Worker 844*523fa7a6SAndroid Build Coastguard Workerdef _unsafe_adjust_original_program( # noqa: C901 
845*523fa7a6SAndroid Build Coastguard Worker original_program: ExportedProgram, 846*523fa7a6SAndroid Build Coastguard Worker call_delegate_node: torch.fx.Node, 847*523fa7a6SAndroid Build Coastguard Worker input_specs_to_delete: Dict[str, InputSpec], 848*523fa7a6SAndroid Build Coastguard Worker output_specs_to_delete: Dict[str, OutputSpec], 849*523fa7a6SAndroid Build Coastguard Worker) -> None: 850*523fa7a6SAndroid Build Coastguard Worker """ 851*523fa7a6SAndroid Build Coastguard Worker Directly modify the original exported program's signature and state dict 852*523fa7a6SAndroid Build Coastguard Worker based on the consumed params/buffers in the delegate. 853*523fa7a6SAndroid Build Coastguard Worker """ 854*523fa7a6SAndroid Build Coastguard Worker original_program._graph_signature.input_specs = [ 855*523fa7a6SAndroid Build Coastguard Worker input_spec 856*523fa7a6SAndroid Build Coastguard Worker for input_spec in original_program.graph_signature.input_specs 857*523fa7a6SAndroid Build Coastguard Worker if input_spec.arg.name not in input_specs_to_delete 858*523fa7a6SAndroid Build Coastguard Worker ] 859*523fa7a6SAndroid Build Coastguard Worker 860*523fa7a6SAndroid Build Coastguard Worker currently_used_targets: Set[str] = { 861*523fa7a6SAndroid Build Coastguard Worker input_spec.target 862*523fa7a6SAndroid Build Coastguard Worker for input_spec in original_program._graph_signature.input_specs 863*523fa7a6SAndroid Build Coastguard Worker if input_spec.target is not None 864*523fa7a6SAndroid Build Coastguard Worker } 865*523fa7a6SAndroid Build Coastguard Worker 866*523fa7a6SAndroid Build Coastguard Worker original_program._graph_signature.output_specs = [ 867*523fa7a6SAndroid Build Coastguard Worker output_spec 868*523fa7a6SAndroid Build Coastguard Worker for output_spec in original_program.graph_signature.output_specs 869*523fa7a6SAndroid Build Coastguard Worker if output_spec.arg.name not in output_specs_to_delete 870*523fa7a6SAndroid Build Coastguard Worker ] 
871*523fa7a6SAndroid Build Coastguard Worker 872*523fa7a6SAndroid Build Coastguard Worker # Delete all parameters/buffers consumed by the created exported program 873*523fa7a6SAndroid Build Coastguard Worker # from the graph signature, state dict, constants table 874*523fa7a6SAndroid Build Coastguard Worker for node in original_program.graph.nodes: 875*523fa7a6SAndroid Build Coastguard Worker if node.op == "placeholder": 876*523fa7a6SAndroid Build Coastguard Worker if node.name in input_specs_to_delete: 877*523fa7a6SAndroid Build Coastguard Worker assert len(node.users) == 0 878*523fa7a6SAndroid Build Coastguard Worker original_program.graph.erase_node(node) 879*523fa7a6SAndroid Build Coastguard Worker else: 880*523fa7a6SAndroid Build Coastguard Worker break 881*523fa7a6SAndroid Build Coastguard Worker 882*523fa7a6SAndroid Build Coastguard Worker for input_spec in input_specs_to_delete.values(): 883*523fa7a6SAndroid Build Coastguard Worker input_target = input_spec.target 884*523fa7a6SAndroid Build Coastguard Worker assert input_target is not None 885*523fa7a6SAndroid Build Coastguard Worker 886*523fa7a6SAndroid Build Coastguard Worker if input_target in currently_used_targets: 887*523fa7a6SAndroid Build Coastguard Worker continue 888*523fa7a6SAndroid Build Coastguard Worker 889*523fa7a6SAndroid Build Coastguard Worker if input_spec.kind == InputKind.PARAMETER: 890*523fa7a6SAndroid Build Coastguard Worker del original_program._state_dict[input_target] 891*523fa7a6SAndroid Build Coastguard Worker elif input_spec.kind == InputKind.BUFFER: 892*523fa7a6SAndroid Build Coastguard Worker if input_spec.persistent: 893*523fa7a6SAndroid Build Coastguard Worker del original_program._state_dict[input_target] 894*523fa7a6SAndroid Build Coastguard Worker else: 895*523fa7a6SAndroid Build Coastguard Worker del original_program._constants[input_spec.target] 896*523fa7a6SAndroid Build Coastguard Worker elif input_spec.kind == InputKind.CONSTANT_TENSOR: 897*523fa7a6SAndroid Build 
Coastguard Worker del original_program._constants[input_spec.target] 898*523fa7a6SAndroid Build Coastguard Worker else: 899*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError(f"Invalid input spec {input_spec} received") 900*523fa7a6SAndroid Build Coastguard Worker 901*523fa7a6SAndroid Build Coastguard Worker # Delete buffer mutations from the output which were consumed by the delegate 902*523fa7a6SAndroid Build Coastguard Worker toplevel_output_node = None 903*523fa7a6SAndroid Build Coastguard Worker for node in reversed(original_program.graph.nodes): 904*523fa7a6SAndroid Build Coastguard Worker if node.op == "output": 905*523fa7a6SAndroid Build Coastguard Worker toplevel_output_node = node 906*523fa7a6SAndroid Build Coastguard Worker break 907*523fa7a6SAndroid Build Coastguard Worker 908*523fa7a6SAndroid Build Coastguard Worker assert toplevel_output_node is not None 909*523fa7a6SAndroid Build Coastguard Worker assert ( 910*523fa7a6SAndroid Build Coastguard Worker len(toplevel_output_node.args) == 1 911*523fa7a6SAndroid Build Coastguard Worker ), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}" 912*523fa7a6SAndroid Build Coastguard Worker 913*523fa7a6SAndroid Build Coastguard Worker new_output_args = [ 914*523fa7a6SAndroid Build Coastguard Worker arg 915*523fa7a6SAndroid Build Coastguard Worker for arg in toplevel_output_node.args[0] 916*523fa7a6SAndroid Build Coastguard Worker if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete 917*523fa7a6SAndroid Build Coastguard Worker ] 918*523fa7a6SAndroid Build Coastguard Worker toplevel_output_node.args = (tuple(new_output_args),) 919*523fa7a6SAndroid Build Coastguard Worker 920*523fa7a6SAndroid Build Coastguard Worker # Delete the buffer mutation getitem nodes 921*523fa7a6SAndroid Build Coastguard Worker getitem_idxs: List[int] = [] 922*523fa7a6SAndroid Build Coastguard Worker user_nodes = list(call_delegate_node.users.keys()) 923*523fa7a6SAndroid 
Build Coastguard Worker for user in user_nodes: 924*523fa7a6SAndroid Build Coastguard Worker if user.name in output_specs_to_delete: 925*523fa7a6SAndroid Build Coastguard Worker assert ( 926*523fa7a6SAndroid Build Coastguard Worker user.op == "call_function" and user.target == operator.getitem 927*523fa7a6SAndroid Build Coastguard Worker ), f"Invalid user {user}, node.op is {node.op} and node.target is {node.target}" 928*523fa7a6SAndroid Build Coastguard Worker user_idx = user.args[1] 929*523fa7a6SAndroid Build Coastguard Worker assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}" 930*523fa7a6SAndroid Build Coastguard Worker getitem_idxs.append(user_idx) 931*523fa7a6SAndroid Build Coastguard Worker original_program.graph.erase_node(user) 932*523fa7a6SAndroid Build Coastguard Worker 933*523fa7a6SAndroid Build Coastguard Worker getitem_idxs.sort(reverse=True) 934*523fa7a6SAndroid Build Coastguard Worker 935*523fa7a6SAndroid Build Coastguard Worker # Adjust all the getitem indices after the deleted getitems 936*523fa7a6SAndroid Build Coastguard Worker user_nodes = list(call_delegate_node.users.keys()) 937*523fa7a6SAndroid Build Coastguard Worker for user in user_nodes: 938*523fa7a6SAndroid Build Coastguard Worker assert user.op == "call_function" and user.target == operator.getitem 939*523fa7a6SAndroid Build Coastguard Worker user_idx = user.args[1] 940*523fa7a6SAndroid Build Coastguard Worker assert isinstance(user_idx, int) 941*523fa7a6SAndroid Build Coastguard Worker for i, idx in enumerate(getitem_idxs): 942*523fa7a6SAndroid Build Coastguard Worker if user_idx > idx: 943*523fa7a6SAndroid Build Coastguard Worker user.args = (user.args[0], user_idx - (len(getitem_idxs) - i)) 944*523fa7a6SAndroid Build Coastguard Worker break 945*523fa7a6SAndroid Build Coastguard Worker 946*523fa7a6SAndroid Build Coastguard Worker original_program._validate() 947