1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Workerimport copy 8*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any 9*523fa7a6SAndroid Build Coastguard Worker 10*523fa7a6SAndroid Build Coastguard Workerimport torch 11*523fa7a6SAndroid Build Coastguard Workerimport torch._export 12*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo 13*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree 14*523fa7a6SAndroid Build Coastguard Worker 15*523fa7a6SAndroid Build Coastguard Worker 16*523fa7a6SAndroid Build Coastguard WorkerVal = Any 17*523fa7a6SAndroid Build Coastguard Worker 18*523fa7a6SAndroid Build Coastguard Worker 19*523fa7a6SAndroid Build Coastguard Workerdef _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict): 20*523fa7a6SAndroid Build Coastguard Worker count = 0 21*523fa7a6SAndroid Build Coastguard Worker # Step 1: make lifted params as get_attr 22*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes: 23*523fa7a6SAndroid Build Coastguard Worker if node.op == "placeholder": 24*523fa7a6SAndroid Build Coastguard Worker if count in inp_pos_to_param_buffer_name: 25*523fa7a6SAndroid Build Coastguard Worker with gm.graph.inserting_after(node): 26*523fa7a6SAndroid Build Coastguard Worker getattr_node = gm.graph.get_attr( 27*523fa7a6SAndroid Build Coastguard Worker inp_pos_to_param_buffer_name[count] 28*523fa7a6SAndroid Build Coastguard Worker ) 29*523fa7a6SAndroid Build Coastguard Worker node.replace_all_uses_with(getattr_node) 30*523fa7a6SAndroid Build Coastguard Worker metadata = node.meta 31*523fa7a6SAndroid Build Coastguard Worker gm.graph.erase_node(node) 32*523fa7a6SAndroid Build Coastguard Worker getattr_node.meta = metadata 33*523fa7a6SAndroid Build Coastguard Worker count += 1 34*523fa7a6SAndroid Build Coastguard Worker 35*523fa7a6SAndroid Build Coastguard Worker # Step 2: Fix the input/output of the graph now that we deleted 36*523fa7a6SAndroid Build Coastguard Worker # some args. 37*523fa7a6SAndroid Build Coastguard Worker gm.graph.lint() 38*523fa7a6SAndroid Build Coastguard Worker names = [f"arg_{i}" for i in range(len(in_spec.children_specs))] 39*523fa7a6SAndroid Build Coastguard Worker gm.graph._codegen = _PyTreeCodeGen( 40*523fa7a6SAndroid Build Coastguard Worker _PyTreeInfo( 41*523fa7a6SAndroid Build Coastguard Worker names, 42*523fa7a6SAndroid Build Coastguard Worker in_spec, 43*523fa7a6SAndroid Build Coastguard Worker out_spec, 44*523fa7a6SAndroid Build Coastguard Worker ) 45*523fa7a6SAndroid Build Coastguard Worker ) 46*523fa7a6SAndroid Build Coastguard Worker gm.recompile() 47*523fa7a6SAndroid Build Coastguard Worker 48*523fa7a6SAndroid Build Coastguard Worker # Step 3: Find state references in HigherOrderOps and recursively 49*523fa7a6SAndroid Build Coastguard Worker # fix them. 50*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes: 51*523fa7a6SAndroid Build Coastguard Worker if node.op == "call_function" and node.target == torch.ops.cond: 52*523fa7a6SAndroid Build Coastguard Worker pred, true_graph, false_graph, operands = node.args 53*523fa7a6SAndroid Build Coastguard Worker true_gm = getattr(gm, true_graph.name) 54*523fa7a6SAndroid Build Coastguard Worker false_gm = getattr(gm, false_graph.name) 55*523fa7a6SAndroid Build Coastguard Worker inp_pos_to_param_buffer_name_for_submod = {} 56*523fa7a6SAndroid Build Coastguard Worker real_operands = [] 57*523fa7a6SAndroid Build Coastguard Worker for ix, operand in enumerate(operands): 58*523fa7a6SAndroid Build Coastguard Worker if operand.target in inp_pos_to_param_buffer_name.values(): 59*523fa7a6SAndroid Build Coastguard Worker inp_pos_to_param_buffer_name_for_submod[ix] = operand.target 60*523fa7a6SAndroid Build Coastguard Worker true_gm.register_buffer(operand.target, state_dict[operand.target]) 61*523fa7a6SAndroid Build Coastguard Worker false_gm.register_buffer(operand.target, state_dict[operand.target]) 62*523fa7a6SAndroid Build Coastguard Worker else: 63*523fa7a6SAndroid Build Coastguard Worker real_operands.append(operand) 64*523fa7a6SAndroid Build Coastguard Worker node.args = (pred, true_graph, false_graph, real_operands) 65*523fa7a6SAndroid Build Coastguard Worker 66*523fa7a6SAndroid Build Coastguard Worker _, in_spec = pytree.tree_flatten(real_operands) 67*523fa7a6SAndroid Build Coastguard Worker 68*523fa7a6SAndroid Build Coastguard Worker _unlift( 69*523fa7a6SAndroid Build Coastguard Worker true_gm, 70*523fa7a6SAndroid Build Coastguard Worker inp_pos_to_param_buffer_name_for_submod, 71*523fa7a6SAndroid Build Coastguard Worker in_spec, 72*523fa7a6SAndroid Build Coastguard Worker None, 73*523fa7a6SAndroid Build Coastguard Worker state_dict, 74*523fa7a6SAndroid Build Coastguard Worker ) 75*523fa7a6SAndroid Build Coastguard Worker _unlift( 76*523fa7a6SAndroid Build Coastguard Worker false_gm, 77*523fa7a6SAndroid Build Coastguard Worker inp_pos_to_param_buffer_name_for_submod, 78*523fa7a6SAndroid Build Coastguard Worker in_spec, 79*523fa7a6SAndroid Build Coastguard Worker None, 80*523fa7a6SAndroid Build Coastguard Worker state_dict, 81*523fa7a6SAndroid Build Coastguard Worker ) 82*523fa7a6SAndroid Build Coastguard Worker if node.op == "call_function" and node.target.__name__ == "map_impl": 83*523fa7a6SAndroid Build Coastguard Worker body_graph, num_mapped, *operands = node.args 84*523fa7a6SAndroid Build Coastguard Worker body_gm = getattr(gm, body_graph.name) 85*523fa7a6SAndroid Build Coastguard Worker inp_pos_to_buffer_name_for_submod = {} 86*523fa7a6SAndroid Build Coastguard Worker real_operands = [] 87*523fa7a6SAndroid Build Coastguard Worker for ix, operand in enumerate(operands): 88*523fa7a6SAndroid Build Coastguard Worker if operand.target in inp_pos_to_param_buffer_name.values(): 89*523fa7a6SAndroid Build Coastguard Worker inp_pos_to_buffer_name_for_submod[ix] = operand.target 90*523fa7a6SAndroid Build Coastguard Worker body_gm.register_buffer(operand.target, state_dict[operand.target]) 91*523fa7a6SAndroid Build Coastguard Worker else: 92*523fa7a6SAndroid Build Coastguard Worker real_operands.append(operand) 93*523fa7a6SAndroid Build Coastguard Worker node.args = (body_graph, num_mapped, *real_operands) 94*523fa7a6SAndroid Build Coastguard Worker 95*523fa7a6SAndroid Build Coastguard Worker _, in_spec = pytree.tree_flatten(real_operands) 96*523fa7a6SAndroid Build Coastguard Worker 97*523fa7a6SAndroid Build Coastguard Worker _unlift( 98*523fa7a6SAndroid Build Coastguard Worker body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict 99*523fa7a6SAndroid Build Coastguard Worker ) 100*523fa7a6SAndroid Build Coastguard Worker gm.graph.lint() 101*523fa7a6SAndroid Build Coastguard Worker gm.graph.eliminate_dead_code() 102*523fa7a6SAndroid Build Coastguard Worker gm.recompile() 103*523fa7a6SAndroid Build Coastguard Worker return gm 104*523fa7a6SAndroid Build Coastguard Worker 105*523fa7a6SAndroid Build Coastguard Worker 106*523fa7a6SAndroid Build Coastguard Workerdef unlift_exported_program_lifted_states( 107*523fa7a6SAndroid Build Coastguard Worker ep: torch.export.exported_program.ExportedProgram, 108*523fa7a6SAndroid Build Coastguard Worker): 109*523fa7a6SAndroid Build Coastguard Worker new_gm = copy.deepcopy(ep.graph_module) 110*523fa7a6SAndroid Build Coastguard Worker 111*523fa7a6SAndroid Build Coastguard Worker # TODO Fix the period in params/buffers names later 112*523fa7a6SAndroid Build Coastguard Worker # maybe a pass to replace graph signature with fixed names 113*523fa7a6SAndroid Build Coastguard Worker param_buffer_name_to_corrected_name = {} 114*523fa7a6SAndroid Build Coastguard Worker 115*523fa7a6SAndroid Build Coastguard Worker for name, stuff in ep.state_dict.items(): 116*523fa7a6SAndroid Build Coastguard Worker if name in ep.graph_signature.buffers: 117*523fa7a6SAndroid Build Coastguard Worker if "." in name: 118*523fa7a6SAndroid Build Coastguard Worker new_gm.register_buffer(name.replace(".", "_"), stuff) 119*523fa7a6SAndroid Build Coastguard Worker param_buffer_name_to_corrected_name[name] = name.replace(".", "_") 120*523fa7a6SAndroid Build Coastguard Worker else: 121*523fa7a6SAndroid Build Coastguard Worker new_gm.register_buffer(name, stuff) 122*523fa7a6SAndroid Build Coastguard Worker elif name in ep.graph_signature.parameters: 123*523fa7a6SAndroid Build Coastguard Worker if "." in name: 124*523fa7a6SAndroid Build Coastguard Worker new_gm.register_parameter(name.replace(".", "_"), stuff) 125*523fa7a6SAndroid Build Coastguard Worker param_buffer_name_to_corrected_name[name] = name.replace(".", "_") 126*523fa7a6SAndroid Build Coastguard Worker else: 127*523fa7a6SAndroid Build Coastguard Worker new_gm.register_parameter(name, stuff) 128*523fa7a6SAndroid Build Coastguard Worker else: 129*523fa7a6SAndroid Build Coastguard Worker raise AssertionError("encountered not registered param/buffer") 130*523fa7a6SAndroid Build Coastguard Worker 131*523fa7a6SAndroid Build Coastguard Worker count = 0 132*523fa7a6SAndroid Build Coastguard Worker inp_pos_to_param_buffer_name = {} 133*523fa7a6SAndroid Build Coastguard Worker for node in new_gm.graph.nodes: 134*523fa7a6SAndroid Build Coastguard Worker if node.op == "placeholder": 135*523fa7a6SAndroid Build Coastguard Worker if node.name in ep.graph_signature.inputs_to_buffers: 136*523fa7a6SAndroid Build Coastguard Worker buffer_name = ep.graph_signature.inputs_to_buffers[node.name] 137*523fa7a6SAndroid Build Coastguard Worker if buffer_name in param_buffer_name_to_corrected_name: 138*523fa7a6SAndroid Build Coastguard Worker inp_pos_to_param_buffer_name[count] = ( 139*523fa7a6SAndroid Build Coastguard Worker param_buffer_name_to_corrected_name[buffer_name] 140*523fa7a6SAndroid Build Coastguard Worker ) 141*523fa7a6SAndroid Build Coastguard Worker else: 142*523fa7a6SAndroid Build Coastguard Worker inp_pos_to_param_buffer_name[count] = buffer_name 143*523fa7a6SAndroid Build Coastguard Worker if node.name in ep.graph_signature.inputs_to_parameters: 144*523fa7a6SAndroid Build Coastguard Worker param_name = ep.graph_signature.inputs_to_parameters[node.name] 145*523fa7a6SAndroid Build Coastguard Worker if param_name in param_buffer_name_to_corrected_name: 146*523fa7a6SAndroid Build Coastguard Worker inp_pos_to_param_buffer_name[count] = ( 147*523fa7a6SAndroid Build Coastguard Worker param_buffer_name_to_corrected_name[param_name] 148*523fa7a6SAndroid Build Coastguard Worker ) 149*523fa7a6SAndroid Build Coastguard Worker else: 150*523fa7a6SAndroid Build Coastguard Worker inp_pos_to_param_buffer_name[count] = param_name 151*523fa7a6SAndroid Build Coastguard Worker count += 1 152*523fa7a6SAndroid Build Coastguard Worker new_gm = _unlift( 153*523fa7a6SAndroid Build Coastguard Worker new_gm, 154*523fa7a6SAndroid Build Coastguard Worker inp_pos_to_param_buffer_name, 155*523fa7a6SAndroid Build Coastguard Worker ep.call_spec.in_spec, 156*523fa7a6SAndroid Build Coastguard Worker ep.call_spec.out_spec, 157*523fa7a6SAndroid Build Coastguard Worker ep.state_dict, 158*523fa7a6SAndroid Build Coastguard Worker ) 159*523fa7a6SAndroid Build Coastguard Worker return new_gm 160