1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-unsafe 8 9from torch.export import ExportedProgram 10from torch.export.exported_program import OutputKind, OutputSpec, TensorArgument 11 12 13def weights_to_outputs_pass( 14 exported_program: ExportedProgram, 15) -> ExportedProgram: 16 """ 17 This pass is for training graphs with gradients returned. It flags the weights as having a gradient attached, 18 and appends them to the outputs in order to make the weights easier to handle in memory planning and the emitter. 19 20 Args: 21 exported_program: The ExportedProgram to update. 22 23 Returns: 24 The modified ExportedProgram. 25 """ 26 if ( 27 len([node for node in exported_program.graph.nodes if node.op == "placeholder"]) 28 == 0 29 ): 30 return exported_program 31 32 gs = exported_program.graph_signature 33 gm = exported_program.graph_module 34 35 # Check for/ get gradients 36 grad_targets = [ 37 spec.target 38 for spec in gs.output_specs 39 if spec.kind == OutputKind.GRADIENT_TO_PARAMETER 40 ] 41 42 # If no gradients, return 43 if len(grad_targets) == 0: 44 return exported_program 45 46 inputs_to_params = gs.inputs_to_parameters 47 48 # Get output node 49 output_node = None 50 for node in gm.graph.nodes: 51 if node.op == "output": 52 output_node = node 53 break 54 assert output_node is not None 55 56 # Get input nodes that are weights with an associated gradient 57 placeholder_nodes = [ 58 node 59 for node in gm.graph.nodes 60 if node.op == "placeholder" 61 and node.target in inputs_to_params.keys() 62 and inputs_to_params[node.target] in grad_targets 63 ] 64 65 # Flag these placeholder nodes as having a gradient attached so that memory planning will operate on them. 66 for node in placeholder_nodes: 67 node.meta["weight_has_gradient"] = True 68 69 # add to output node 70 new_output_nodes = [] 71 new_output_nodes.extend(output_node.args[0]) 72 new_output_nodes.extend(placeholder_nodes) 73 # Remove old outputs 74 new_output = gm.graph.output(tuple(new_output_nodes)) 75 output_node.replace_all_uses_with(new_output) 76 gm.graph.erase_node(output_node) 77 78 # add to output signature 79 for node in placeholder_nodes: 80 gs.output_specs.append( 81 OutputSpec( 82 OutputKind.TOKEN, # This is a hack. We are returning the raw weights here to make it easier for memory 83 # planning and the emitter. There is no outputkind.Parameter so I am using TOKEN which is currently unused in Edge. 84 TensorArgument(node.target), 85 None, 86 ) 87 ) 88 89 # Cleanup the graph. 90 exported_program.graph.eliminate_dead_code() 91 exported_program.graph_module.recompile() 92 93 return exported_program 94