# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import unittest

import torch
import torch._dynamo

from executorch.exir import to_edge

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)
from torch.export._trace import _export
from torch.export.experimental import _export_forward_backward
from torch.export.exported_program import OutputKind


class TestJointGraph(unittest.TestCase):
    def test_joint_graph(self) -> None:
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)
                # Frozen layer: its parameters should not show up among the
                # gradient/weight outputs of the joint graph.
                self.linear_no_train = torch.nn.Linear(3, 3)
                for param in self.linear_no_train.parameters():
                    param.requires_grad = False
                self.loss = torch.nn.CrossEntropyLoss()

            def forward(self, x, y):
                return self.loss(
                    self.linear_no_train(self.linear(x)).softmax(dim=0), y
                )

        m = Module()
        example_inputs = (torch.ones(3), torch.tensor([1.0, 0.0, 0.0]))
        # Sanity-check eager execution before exporting.
        m(*example_inputs)
        # Pre-dispatch export keeps autograd information intact so the
        # backward graph can still be traced.
        ep = _export(m, example_inputs, pre_dispatch=True)
        # Append the backward pass, producing a joint forward/backward graph.
        joint_ep = _export_forward_backward(ep)
        edge = to_edge(joint_ep)

        output_node = None
        for node in edge.exported_program().graph.nodes:
            if node.op == "output":
                output_node = node
                break

        # Number of outputs before lowering appends the trained weights.
        orig_outputs = len(output_node.args[0])

        et = edge.to_executorch()

        # Lowering tags the appended weight outputs as TOKEN outputs.
        weight_output_specs = [
            spec
            for spec in et.exported_program().graph_signature.output_specs
            if spec.kind == OutputKind.TOKEN
        ]

        output_node = None
        for node in et.exported_program().graph.nodes:
            if node.op == "output":
                output_node = node
                break

        weight_outputs = len(output_node.args[0])

        # Make sure 2 new outputs are added to both the output node and the
        # graph signature.
        self.assertEqual(len(weight_output_specs), 2)  # linear layer weight and bias
        self.assertEqual(
            weight_outputs - orig_outputs, 2
        )  # linear layer weight and bias

        # Assert that the weight and bias have proper data_buffer_idx and
        # allocation_info.
        self.assertEqual(
            et.executorch_program.execution_plan[0]  # pyre-ignore
            .values[0]
            .val.data_buffer_idx,
            1,
        )
        self.assertEqual(
            et.executorch_program.execution_plan[0]  # pyre-ignore
            .values[1]
            .val.data_buffer_idx,
            2,
        )
        self.assertEqual(
            et.executorch_program.execution_plan[0]  # pyre-ignore
            .values[0]
            .val.allocation_info.memory_offset_low,
            0,
        )
        # The 3x3 fp32 weight occupies 36 bytes, padded for alignment, so the
        # bias is planned at offset 48.
        self.assertEqual(
            et.executorch_program.execution_plan[0]  # pyre-ignore
            .values[1]
            .val.allocation_info.memory_offset_low,
            48,
        )

        # Compute the eager reference loss and gradients to compare against.
        loss = m(*example_inputs)
        loss.backward()
        et_mod = _load_for_executorch_from_buffer(et.buffer)
        et_outputs = et_mod.forward(
            example_inputs
        )  # ET outputs are [loss, grads, weights]

        self.assertTrue(torch.allclose(loss, et_outputs[0]))
        self.assertTrue(
            torch.allclose(m.linear.weight.grad, et_outputs[1])  # pyre-ignore[6]
        )
        self.assertTrue(torch.allclose(m.linear.bias.grad, et_outputs[2]))
        self.assertTrue(torch.allclose(m.linear.weight, et_outputs[3]))
        self.assertTrue(torch.allclose(m.linear.bias, et_outputs[4]))

        self.assertEqual(
            len(et.executorch_program.execution_plan), 4
        )  # joint program + 3 training metadata methods

        # Gradient outputs start at index 1.
        self.assertEqual(
            et.executorch_program.execution_plan[1]  # pyre-ignore
            .values[0]
            .val.int_val,
            1,
        )

        # Fully qualified name of the first trained parameter.
        self.assertEqual(
            et.executorch_program.execution_plan[2]  # pyre-ignore
            .values[0]
            .val.string_val,
            "linear.weight",
        )

        # Parameter outputs start at index 3.
        self.assertEqual(
            et.executorch_program.execution_plan[3]  # pyre-ignore
            .values[0]
            .val.int_val,
            3,
        )
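

# A minimal entry point, added as a convenience sketch so the file can be run
# directly with `python <this file>`; the original presumably relies on an
# external test runner (e.g. pytest or buck).
if __name__ == "__main__":
    unittest.main()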