# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from torch.export import ExportedProgram
from torch.export.exported_program import OutputKind, OutputSpec, TensorArgument


def weights_to_outputs_pass(
    exported_program: ExportedProgram,
) -> ExportedProgram:
    """
    This pass is for training graphs whose outputs include gradients. It flags each
    weight that has a gradient attached and appends those weights to the graph
    outputs, making the weights easier to handle in memory planning and the emitter.

    Args:
        exported_program: The ExportedProgram to update.

    Returns:
        The modified ExportedProgram.
    """
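    # If the graph has no placeholders, there are no weights to process; exit early.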
    if not any(node.op == "placeholder" for node in exported_program.graph.nodes):
        return exported_program

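    # Shorthand handles for the graph signature and graph module, both mutated below.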
    gs = exported_program.graph_signature
    gm = exported_program.graph_module

    # Collect the fully-qualified names of all parameters that have a gradient
    # among the outputs.
    grad_targets = [
        spec.target
        for spec in gs.output_specs
        if spec.kind == OutputKind.GRADIENT_TO_PARAMETER
    ]

    # If there are no gradients, there is nothing to do.
    if len(grad_targets) == 0:
        return exported_program

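    # Maps each placeholder (input) name to the fully-qualified name of the
    # parameter it carries.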
    inputs_to_params = gs.inputs_to_parameters

    # Get the output node.
    output_node = None
    for node in gm.graph.nodes:
        if node.op == "output":
            output_node = node
            break
    assert output_node is not None, "Graph has no output node."

    # Get input nodes that are weights with an associated gradient.
    placeholder_nodes = [
        node
        for node in gm.graph.nodes
        if node.op == "placeholder"
        and node.target in inputs_to_params
        and inputs_to_params[node.target] in grad_targets
    ]

    # Flag these placeholder nodes as having a gradient attached so that memory
    # planning will operate on them.
    for node in placeholder_nodes:
        node.meta["weight_has_gradient"] = True

    # Append the weight placeholders to the existing outputs.
    new_output_nodes = []
    new_output_nodes.extend(output_node.args[0])
    new_output_nodes.extend(placeholder_nodes)
    # Replace the old output node with one that also returns the weights.
    new_output = gm.graph.output(tuple(new_output_nodes))
    output_node.replace_all_uses_with(new_output)
    gm.graph.erase_node(output_node)

    # Add the weights to the output signature. Using OutputKind.TOKEN here is a
    # hack: the raw weights are returned to make memory planning and the emitter
    # easier, but there is no OutputKind for parameters, so TOKEN (currently
    # unused in Edge) is used instead.
    for node in placeholder_nodes:
        gs.output_specs.append(
            OutputSpec(
                OutputKind.TOKEN,
                TensorArgument(node.target),
                None,
            )
        )

    # Clean up the graph.
    exported_program.graph.eliminate_dead_code()
    exported_program.graph_module.recompile()

    return exported_program
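
# A minimal usage sketch (illustrative only, not part of this module). The module,
# inputs, and joint-graph capture step below are hypothetical stand-ins; the pass
# is a no-op unless the program's outputs already contain
# OutputKind.GRADIENT_TO_PARAMETER entries, i.e. a joint forward/backward
# training graph:
#
#     from torch.export import export
#
#     ep = export(MyTrainingModule(), example_inputs)  # hypothetical module/inputs
#     ep = capture_joint_graph(ep)  # hypothetical step producing gradient outputs
#     ep = weights_to_outputs_pass(ep)
#     # Each weight placeholder now has meta["weight_has_gradient"] = True and is
#     # mirrored as a TOKEN-kind output for memory planning and the emitter.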