xref: /aosp_15_r20/external/executorch/exir/capture/_unlift.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport copy
8*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any
9*523fa7a6SAndroid Build Coastguard Worker
10*523fa7a6SAndroid Build Coastguard Workerimport torch
11*523fa7a6SAndroid Build Coastguard Workerimport torch._export
12*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
13*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Worker
# Generic alias for "any value"; it is not referenced by the functions in this file.
Val = Any
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker
19*523fa7a6SAndroid Build Coastguard Workerdef _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict):
20*523fa7a6SAndroid Build Coastguard Worker    count = 0
21*523fa7a6SAndroid Build Coastguard Worker    # Step 1: make lifted params as get_attr
22*523fa7a6SAndroid Build Coastguard Worker    for node in gm.graph.nodes:
23*523fa7a6SAndroid Build Coastguard Worker        if node.op == "placeholder":
24*523fa7a6SAndroid Build Coastguard Worker            if count in inp_pos_to_param_buffer_name:
25*523fa7a6SAndroid Build Coastguard Worker                with gm.graph.inserting_after(node):
26*523fa7a6SAndroid Build Coastguard Worker                    getattr_node = gm.graph.get_attr(
27*523fa7a6SAndroid Build Coastguard Worker                        inp_pos_to_param_buffer_name[count]
28*523fa7a6SAndroid Build Coastguard Worker                    )
29*523fa7a6SAndroid Build Coastguard Worker                    node.replace_all_uses_with(getattr_node)
30*523fa7a6SAndroid Build Coastguard Worker                    metadata = node.meta
31*523fa7a6SAndroid Build Coastguard Worker                    gm.graph.erase_node(node)
32*523fa7a6SAndroid Build Coastguard Worker                    getattr_node.meta = metadata
33*523fa7a6SAndroid Build Coastguard Worker            count += 1
34*523fa7a6SAndroid Build Coastguard Worker
35*523fa7a6SAndroid Build Coastguard Worker    # Step 2: Fix the input/output of the graph now that we deleted
36*523fa7a6SAndroid Build Coastguard Worker    # some args.
37*523fa7a6SAndroid Build Coastguard Worker    gm.graph.lint()
38*523fa7a6SAndroid Build Coastguard Worker    names = [f"arg_{i}" for i in range(len(in_spec.children_specs))]
39*523fa7a6SAndroid Build Coastguard Worker    gm.graph._codegen = _PyTreeCodeGen(
40*523fa7a6SAndroid Build Coastguard Worker        _PyTreeInfo(
41*523fa7a6SAndroid Build Coastguard Worker            names,
42*523fa7a6SAndroid Build Coastguard Worker            in_spec,
43*523fa7a6SAndroid Build Coastguard Worker            out_spec,
44*523fa7a6SAndroid Build Coastguard Worker        )
45*523fa7a6SAndroid Build Coastguard Worker    )
46*523fa7a6SAndroid Build Coastguard Worker    gm.recompile()
47*523fa7a6SAndroid Build Coastguard Worker
48*523fa7a6SAndroid Build Coastguard Worker    # Step 3: Find state references in HigherOrderOps and recursively
49*523fa7a6SAndroid Build Coastguard Worker    # fix them.
50*523fa7a6SAndroid Build Coastguard Worker    for node in gm.graph.nodes:
51*523fa7a6SAndroid Build Coastguard Worker        if node.op == "call_function" and node.target == torch.ops.cond:
52*523fa7a6SAndroid Build Coastguard Worker            pred, true_graph, false_graph, operands = node.args
53*523fa7a6SAndroid Build Coastguard Worker            true_gm = getattr(gm, true_graph.name)
54*523fa7a6SAndroid Build Coastguard Worker            false_gm = getattr(gm, false_graph.name)
55*523fa7a6SAndroid Build Coastguard Worker            inp_pos_to_param_buffer_name_for_submod = {}
56*523fa7a6SAndroid Build Coastguard Worker            real_operands = []
57*523fa7a6SAndroid Build Coastguard Worker            for ix, operand in enumerate(operands):
58*523fa7a6SAndroid Build Coastguard Worker                if operand.target in inp_pos_to_param_buffer_name.values():
59*523fa7a6SAndroid Build Coastguard Worker                    inp_pos_to_param_buffer_name_for_submod[ix] = operand.target
60*523fa7a6SAndroid Build Coastguard Worker                    true_gm.register_buffer(operand.target, state_dict[operand.target])
61*523fa7a6SAndroid Build Coastguard Worker                    false_gm.register_buffer(operand.target, state_dict[operand.target])
62*523fa7a6SAndroid Build Coastguard Worker                else:
63*523fa7a6SAndroid Build Coastguard Worker                    real_operands.append(operand)
64*523fa7a6SAndroid Build Coastguard Worker            node.args = (pred, true_graph, false_graph, real_operands)
65*523fa7a6SAndroid Build Coastguard Worker
66*523fa7a6SAndroid Build Coastguard Worker            _, in_spec = pytree.tree_flatten(real_operands)
67*523fa7a6SAndroid Build Coastguard Worker
68*523fa7a6SAndroid Build Coastguard Worker            _unlift(
69*523fa7a6SAndroid Build Coastguard Worker                true_gm,
70*523fa7a6SAndroid Build Coastguard Worker                inp_pos_to_param_buffer_name_for_submod,
71*523fa7a6SAndroid Build Coastguard Worker                in_spec,
72*523fa7a6SAndroid Build Coastguard Worker                None,
73*523fa7a6SAndroid Build Coastguard Worker                state_dict,
74*523fa7a6SAndroid Build Coastguard Worker            )
75*523fa7a6SAndroid Build Coastguard Worker            _unlift(
76*523fa7a6SAndroid Build Coastguard Worker                false_gm,
77*523fa7a6SAndroid Build Coastguard Worker                inp_pos_to_param_buffer_name_for_submod,
78*523fa7a6SAndroid Build Coastguard Worker                in_spec,
79*523fa7a6SAndroid Build Coastguard Worker                None,
80*523fa7a6SAndroid Build Coastguard Worker                state_dict,
81*523fa7a6SAndroid Build Coastguard Worker            )
82*523fa7a6SAndroid Build Coastguard Worker        if node.op == "call_function" and node.target.__name__ == "map_impl":
83*523fa7a6SAndroid Build Coastguard Worker            body_graph, num_mapped, *operands = node.args
84*523fa7a6SAndroid Build Coastguard Worker            body_gm = getattr(gm, body_graph.name)
85*523fa7a6SAndroid Build Coastguard Worker            inp_pos_to_buffer_name_for_submod = {}
86*523fa7a6SAndroid Build Coastguard Worker            real_operands = []
87*523fa7a6SAndroid Build Coastguard Worker            for ix, operand in enumerate(operands):
88*523fa7a6SAndroid Build Coastguard Worker                if operand.target in inp_pos_to_param_buffer_name.values():
89*523fa7a6SAndroid Build Coastguard Worker                    inp_pos_to_buffer_name_for_submod[ix] = operand.target
90*523fa7a6SAndroid Build Coastguard Worker                    body_gm.register_buffer(operand.target, state_dict[operand.target])
91*523fa7a6SAndroid Build Coastguard Worker                else:
92*523fa7a6SAndroid Build Coastguard Worker                    real_operands.append(operand)
93*523fa7a6SAndroid Build Coastguard Worker            node.args = (body_graph, num_mapped, *real_operands)
94*523fa7a6SAndroid Build Coastguard Worker
95*523fa7a6SAndroid Build Coastguard Worker            _, in_spec = pytree.tree_flatten(real_operands)
96*523fa7a6SAndroid Build Coastguard Worker
97*523fa7a6SAndroid Build Coastguard Worker            _unlift(
98*523fa7a6SAndroid Build Coastguard Worker                body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict
99*523fa7a6SAndroid Build Coastguard Worker            )
100*523fa7a6SAndroid Build Coastguard Worker    gm.graph.lint()
101*523fa7a6SAndroid Build Coastguard Worker    gm.graph.eliminate_dead_code()
102*523fa7a6SAndroid Build Coastguard Worker    gm.recompile()
103*523fa7a6SAndroid Build Coastguard Worker    return gm
104*523fa7a6SAndroid Build Coastguard Worker
105*523fa7a6SAndroid Build Coastguard Worker
def unlift_exported_program_lifted_states(
    ep: torch.export.exported_program.ExportedProgram,
):
    """Return a copy of ``ep.graph_module`` whose lifted state is held as attributes.

    Every ``ep.state_dict`` entry is registered on a deep copy of the graph
    module: buffers via ``register_buffer``, parameters via
    ``register_parameter``. Names containing ``.`` are registered with each
    ``.`` replaced by ``_``. The copy is then passed to ``_unlift`` together
    with a map from placeholder position to the attribute name it carries.

    Raises:
        AssertionError: if a ``state_dict`` entry is neither a buffer nor a
            parameter of ``ep.graph_signature``.
    """
    unlifted_gm = copy.deepcopy(ep.graph_module)
    signature = ep.graph_signature

    # TODO Fix the period in params/buffers names later
    # maybe a pass to replace graph signature with fixed names
    buffer_names = set(signature.buffers)
    parameter_names = set(signature.parameters)

    # Original state name -> attribute name, recorded only for names that
    # actually had to be rewritten (i.e. contained a ".").
    renamed = {}
    for state_name, tensor in ep.state_dict.items():
        if state_name in buffer_names:
            register = unlifted_gm.register_buffer
        elif state_name in parameter_names:
            register = unlifted_gm.register_parameter
        else:
            raise AssertionError("encountered not registered param/buffer")
        attr_name = state_name.replace(".", "_")
        register(attr_name, tensor)
        if attr_name != state_name:
            renamed[state_name] = attr_name

    # Position counts placeholders only. A placeholder listed as both an input
    # to a parameter and to a buffer resolves to the parameter.
    inputs_to_parameters = signature.inputs_to_parameters
    inputs_to_buffers = signature.inputs_to_buffers
    placeholders = (n for n in unlifted_gm.graph.nodes if n.op == "placeholder")
    inp_pos_to_param_buffer_name = {}
    for position, placeholder in enumerate(placeholders):
        if placeholder.name in inputs_to_parameters:
            state_name = inputs_to_parameters[placeholder.name]
        elif placeholder.name in inputs_to_buffers:
            state_name = inputs_to_buffers[placeholder.name]
        else:
            continue
        inp_pos_to_param_buffer_name[position] = renamed.get(state_name, state_name)

    return _unlift(
        unlifted_gm,
        inp_pos_to_param_buffer_name,
        ep.call_spec.in_spec,
        ep.call_spec.out_spec,
        ep.state_dict,
    )
159*523fa7a6SAndroid Build Coastguard Worker    return new_gm
160