xref: /aosp_15_r20/external/executorch/exir/passes/constant_prop_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9from collections import OrderedDict
10from typing import cast, Mapping, Optional
11
12import torch
13from executorch.exir.dialects._ops import ops as exir_ops
14from executorch.exir.dialects.edge._ops import EdgeOpOverload
15from torch._export.utils import (
16    get_buffer,
17    get_lifted_tensor_constant,
18    get_param,
19    is_buffer,
20    is_lifted_tensor_constant,
21    is_param,
22)
23from torch._guards import detect_fake_mode
24from torch.export import ExportedProgram
25from torch.export.exported_program import InputKind, InputSpec, TensorArgument
26from torch.utils import _pytree as pytree
27
28
# Targets that are never constant-folded unless the caller supplies its own
# skip set. Folding `exir.ops.edge.aten.full.default` would bake the filled
# tensor into the program, which can significantly increase compiled model size.
_DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}

# Leaf argument types that `is_const` always treats as constant and that
# `get_data` hands back unchanged.
_PRIMITIVE_TYPES = (
    # Python scalars and strings (bool is listed explicitly alongside int).
    float,
    int,
    bool,
    str,
    # Torch value types accepted as-is.
    torch.Tensor,
    torch.device,
    torch.dtype,
    torch.layout,
)
43
44
def is_const(
    arg,
    exported_program: ExportedProgram,
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
) -> bool:
    """
    Report whether `arg` can be evaluated without any runtime input.

    Tuples, lists and dict values are constant only when every element is.
    Instances of `_PRIMITIVE_TYPES` are always constant. An `fx.Node` is
    constant exactly when it already has an entry in `const_node_to_tensor`.
    Everything else (including `None`) is considered non-constant.

    `exported_program` is only forwarded on recursive calls.
    """
    if isinstance(arg, (tuple, list)):
        elements = arg
    elif isinstance(arg, dict):
        elements = arg.values()
    else:
        elements = None

    if elements is not None:
        # A container is constant iff all of its members are.
        return all(
            is_const(element, exported_program, const_node_to_tensor)
            for element in elements
        )

    if isinstance(arg, _PRIMITIVE_TYPES):
        return True

    # Graph nodes count as constant only once they have been recorded/folded.
    return isinstance(arg, torch.fx.Node) and arg in const_node_to_tensor
63
64
def get_data(
    arg,
    exported_program: ExportedProgram,
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
):
    """
    Materialize the concrete value of an argument that `is_const` accepted.

    Args:
        arg: An op argument: a container, a primitive, or an `fx.Node`.
        exported_program: Only forwarded on recursive calls.
        const_node_to_tensor: Node -> value mapping of already-known constants.

    Returns:
        * tuple/list: rebuilt with the same container type, elements materialized;
        * dict: rebuilt as a plain dict with values materialized (the same
          containers `is_const` recurses into);
        * instance of `_PRIMITIVE_TYPES`: returned unchanged;
        * `fx.Node` present in `const_node_to_tensor`: its recorded value;
        * anything else: None.
    """
    if isinstance(arg, (tuple, list)):
        return type(arg)(
            get_data(x, exported_program, const_node_to_tensor) for x in arg
        )
    elif isinstance(arg, dict):
        # `is_const` accepts dicts, so they must be materialized here too;
        # otherwise a dict would reach the hash lookup below and raise
        # TypeError (dicts are unhashable).
        return {
            key: get_data(value, exported_program, const_node_to_tensor)
            for key, value in arg.items()
        }
    elif isinstance(arg, _PRIMITIVE_TYPES):
        return arg
    elif isinstance(arg, torch.fx.Node) and arg in const_node_to_tensor:
        # The isinstance guard mirrors `is_const` and keeps unhashable
        # non-node leaves from raising TypeError in the membership test.
        return const_node_to_tensor[arg]
    return None
79
80
81def get_constant_placeholder_dict(
82    exported_program: ExportedProgram,
83) -> OrderedDict[torch.fx.Node, torch.Tensor]:
84    """
85    Returns a dictionary of placeholder node -> constant tensor.
86    """
87    const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
88    for node in exported_program.graph.nodes:
89        if node.op != "placeholder":
90            continue
91
92        if is_param(exported_program, node):
93            const_node_to_tensor[node] = cast(
94                torch.Tensor, get_param(exported_program, node)
95            )
96        elif is_buffer(exported_program, node):
97            const_node_to_tensor[node] = cast(
98                torch.Tensor, get_buffer(exported_program, node)
99            )
100        elif is_lifted_tensor_constant(exported_program, node):
101            const_node_to_tensor[node] = cast(
102                torch.Tensor, get_lifted_tensor_constant(exported_program, node)
103            )
104    return const_node_to_tensor
105
106
def get_propagated_const_tensor_dict(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]],
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """
    Evaluate every foldable call_function node and return node -> constant value.

    The mapping is seeded with the constant placeholders from
    `get_constant_placeholder_dict`. It is then extended, in graph order, with
    each call_function node whose target is not skipped and whose args and
    kwargs are all constant. Such a node is executed eagerly by calling
    `node.target` on the concrete values.

    Args:
        exported_program: Program whose graph is scanned; the graph itself is
            only read. NOTE(review): targets run on the real constant tensors,
            so an in-place target would mutate them. Presumably edge graphs are
            functional here; confirm.
        custom_skip_targets: Targets that are never folded. When provided, it
            replaces `_DEFAULT_SKIP_TARGETS` rather than extending it.

    Returns:
        OrderedDict of constant placeholders first, then folded nodes in graph order.
    """
    folded = get_constant_placeholder_dict(exported_program)
    skip_targets = (
        _DEFAULT_SKIP_TARGETS if custom_skip_targets is None else custom_skip_targets
    )

    def materialize(value):
        return get_data(value, exported_program, folded)

    for node in exported_program.graph.nodes:
        if node.op != "call_function" or node.target in skip_targets:
            continue

        inputs_are_constant = is_const(
            node.args, exported_program, folded
        ) and is_const(node.kwargs, exported_program, folded)
        if not inputs_are_constant:
            continue

        args_data, kwargs_data = pytree.tree_map(
            materialize, (node.args, node.kwargs)
        )
        # Run without autograd: a result carrying a grad_fn could not be
        # copied later on.
        with torch.no_grad():
            folded[node] = node.target(*args_data, **kwargs_data)

    return folded
150
151
152def get_first_user_input(exported_program: ExportedProgram) -> torch.fx.Node:
153    """Returns the first user input node in the graph."""
154    first_user_input = None
155    for node in exported_program.graph.nodes:
156        if (
157            node.op == "placeholder"
158            and node.name in exported_program.graph_signature.user_inputs
159        ):
160            first_user_input = node
161            break
162    return first_user_input
163
164
165def replace_with_constant_node(
166    node: torch.fx.Node,
167    prop_constant_tensor: torch.Tensor,
168    first_user_input: torch.fx.Node,
169    fake_mode,
170    exported_program: ExportedProgram,
171) -> tuple[torch.fx.Node, str]:
172    # Add `prop_constant_tensor` to program.state_dict.
173    prop_constant_tensor_fqn = f"_prop_tensor_constant{len(exported_program.constants)}"
174    exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor
175
176    # Insert a new placeholder node for the propagated constant tensor.
177    with exported_program.graph.inserting_before(first_user_input):
178        const_placeholder_node = exported_program.graph.placeholder(
179            prop_constant_tensor_fqn
180        )
181
182    # Update the meta data of the new placeholder (buffer) node.
183    for k, v in node.meta.items():
184        const_placeholder_node.meta[k] = v
185    const_placeholder_node.meta["val"] = fake_mode.from_tensor(
186        prop_constant_tensor, static_shapes=True
187    )
188    const_placeholder_node.meta["val"].constant = prop_constant_tensor
189
190    # Replace the original node with the new constant node.
191    node.replace_all_uses_with(const_placeholder_node)
192    exported_program.graph.erase_node(node)
193
194    return const_placeholder_node, prop_constant_tensor_fqn
195
196
197def get_fake_mode(exported_program: ExportedProgram):
198    fake_mode = detect_fake_mode(
199        tuple(
200            node.meta["val"]
201            for node in exported_program.graph.nodes
202            if node.op == "placeholder"
203        )
204    )
205    assert fake_mode is not None
206    return fake_mode
207
208
209def erase_constant_node(
210    exported_program: ExportedProgram,
211    node: torch.fx.Node,
212) -> None:
213    # Remove corresponding tensor from param/constants dict.
214    signature = exported_program.graph_signature
215    if name := signature.inputs_to_parameters.get(node.name, None):
216        exported_program.state_dict.pop(name, None)
217    elif name := signature.inputs_to_lifted_tensor_constants.get(node.name, None):
218        exported_program.constants.pop(name, None)
219    elif name := signature.inputs_to_buffers.get(node.name, None):
220        exported_program.constants.pop(name, None)
221        exported_program.state_dict.pop(name, None)
222
223    # Remove from graph.
224    exported_program.graph.erase_node(node)
225
226
def create_constant_nodes_and_return_specs(
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
    exported_program: ExportedProgram,
) -> dict[str, InputSpec]:
    """
    Turn the entries of `const_node_to_tensor` into graph inputs where needed.

    Entries are visited in reverse insertion order. For each (node, tensor):
      * if every remaining user of the node is itself a key of
        `const_node_to_tensor` (vacuously true when it has no users), the node
        is removed with `erase_constant_node`, which also drops the backing
        tensor of a parameter/buffer/lifted-constant placeholder;
      * otherwise, an existing placeholder is kept unchanged;
      * otherwise, the call_function node is replaced by a new constant
        placeholder via `replace_with_constant_node`, and a CONSTANT_TENSOR
        InputSpec is recorded for it.

    The graph is mutated in place. The graph signature is NOT updated here;
    the returned specs must be merged by the caller.

    Args:
        const_node_to_tensor: Node -> constant value, presumably built in graph
            order (placeholders first, then folded nodes), as
            `get_propagated_const_tensor_dict` does.
        exported_program: Program whose graph and constants are modified.

    Returns:
        Mapping from each newly created placeholder's name to its InputSpec.
    """
    name_to_spec_dict: dict[str, InputSpec] = {}

    fake_mode = get_fake_mode(exported_program)
    first_user_input = get_first_user_input(exported_program)

    # Iterate over nodes in reverse order, so that (assuming graph-order
    # insertion) every user of a node is handled before the node itself.
    for node, prop_constant_tensor in reversed(const_node_to_tensor.items()):
        # Constant users visited earlier have already been erased or replaced
        # (and thus removed from the graph), so this effectively asks whether
        # any non-constant user still depends on `node`.
        if all(x in const_node_to_tensor for x in node.users):
            # All users of this constant node are also constant, so we don't need to create a new constant node.
            erase_constant_node(exported_program, node)
            continue

        # Existing constant placeholders already have an input spec; keep them.
        if node.op == "placeholder":
            continue

        const_placeholder_node, prop_constant_tensor_fqn = replace_with_constant_node(
            node, prop_constant_tensor, first_user_input, fake_mode, exported_program
        )

        # Create input spec for lifted constant.
        name_to_spec_dict[const_placeholder_node.name] = InputSpec(
            kind=InputKind.CONSTANT_TENSOR,
            arg=TensorArgument(name=const_placeholder_node.name),
            target=prop_constant_tensor_fqn,
            persistent=True,
        )
    return name_to_spec_dict
261
262
def constant_prop_pass(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
) -> ExportedProgram:
    """
    Constant-fold an ExportedProgram whose parameters are lifted.

    Lifted parameters appear in the graph as `placeholder` nodes rather than
    `get_attr`, so folding starts from those placeholders. Foldable nodes are
    evaluated, replaced by constant-tensor placeholders, and the graph
    signature's input specs are rebuilt to match the new placeholders.

    Args:
        exported_program: The ExportedProgram to perform constant propagation on.
        custom_skip_targets: Optional set of EdgeOpOverload targets to skip during
            constant propagation; when given it replaces the default skip set.

    Returns:
        The same ExportedProgram, modified in place with constant propagation applied.

    Raises:
        RuntimeError: If the graph contains `torch.ops.higher_order.cond`.
    """
    # Nothing can be folded without placeholders.
    if not any(node.op == "placeholder" for node in exported_program.graph.nodes):
        return exported_program

    if any(
        node.target == torch.ops.higher_order.cond
        for node in exported_program.graph.nodes
    ):
        raise RuntimeError("constant_prop_pass for control flow is not supported yet.")

    const_node_to_tensor = get_propagated_const_tensor_dict(
        exported_program, custom_skip_targets
    )

    # Existing specs keyed by placeholder name, taken before the graph changes,
    # then extended with the specs of newly created constant placeholders.
    name_to_spec_dict = {
        spec.arg.name: spec for spec in exported_program.graph_signature.input_specs
    }
    name_to_spec_dict.update(
        create_constant_nodes_and_return_specs(const_node_to_tensor, exported_program)
    )

    # Input specs must follow the placeholders' order in the rewritten graph.
    exported_program.graph_signature.input_specs = [
        name_to_spec_dict[node.name]
        for node in exported_program.graph.nodes
        if node.op == "placeholder"
    ]

    # Drop nodes made dead by folding and regenerate the module code.
    exported_program.graph.eliminate_dead_code()
    exported_program.graph_module.recompile()

    return exported_program
318