xref: /aosp_15_r20/external/executorch/exir/tests/test_joint_graph.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
8import unittest
9
10import torch
11import torch._dynamo
12
13from executorch.exir import to_edge
14
15from executorch.extension.pybindings.portable_lib import (
16    _load_for_executorch_from_buffer,
17)
18from torch.export._trace import _export
19from torch.export.experimental import _export_forward_backward
20from torch.export.exported_program import OutputKind
21
22
class TestJointGraph(unittest.TestCase):
    """Checks that a joint forward/backward graph survives lowering to ExecuTorch."""

    def _count_graph_outputs(self, program: "torch.export.ExportedProgram") -> int:
        """Return the number of values returned by ``program``'s output node.

        Fails the test with an explicit message (instead of an opaque
        ``AttributeError`` on ``None``) when the graph has no output node.
        """
        output_node = next(
            (node for node in program.graph.nodes if node.op == "output"), None
        )
        if output_node is None:
            self.fail("exported program graph has no output node")
        return len(output_node.args[0])

    def test_joint_graph(self) -> None:
        """Export a small trainable module and validate the emitted program.

        The module is exported, turned into a joint forward/backward program,
        lowered with ``to_edge`` and ``to_executorch``, and then:

        * the graph outputs and output specs are checked for the two extra
          entries (weight and bias of the trainable linear layer),
        * buffer indices and memory offsets of the first two plan values are
          checked,
        * the serialized program is loaded and run, and its loss, gradients
          and parameters are compared with eager PyTorch,
        * the additional execution plans carrying training metadata are
          checked.
        """

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)
                # Frozen layer (requires_grad=False): only ``self.linear`` is
                # expected to account for the 2 parameters counted below.
                self.linear_no_train = torch.nn.Linear(3, 3)
                for param in self.linear_no_train.parameters():
                    param.requires_grad = False
                self.loss = torch.nn.CrossEntropyLoss()

            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return self.loss(self.linear_no_train(self.linear(x)).softmax(dim=0), y)

        m = Module()
        example_inputs = (torch.ones(3), torch.tensor([1.0, 0.0, 0.0]))
        # Eager call whose result is discarded; presumably a sanity check that
        # the module runs on the example inputs before export.
        m(*example_inputs)
        ep = _export(m, example_inputs, pre_dispatch=True)
        joint_ep = _export_forward_backward(ep)
        edge = to_edge(joint_ep)

        # Must be captured before to_executorch() so the two counts can be
        # compared afterwards.
        orig_outputs = self._count_graph_outputs(edge.exported_program())

        et = edge.to_executorch()

        # The specs counted as weight outputs are the ones tagged TOKEN.
        weight_output_specs = [
            spec
            for spec in et.exported_program().graph_signature.output_specs
            if spec.kind == OutputKind.TOKEN
        ]
        weight_outputs = self._count_graph_outputs(et.exported_program())

        # make sure 2 new outputs are added to both the node and the spec
        self.assertEqual(len(weight_output_specs), 2)  # linear layer weight and bias
        self.assertEqual(
            weight_outputs - orig_outputs, 2
        )  # linear layer weight and bias

        execution_plan = et.executorch_program.execution_plan
        forward_values = execution_plan[0].values

        # assert that the weight and bias have proper data_buffer_idx and allocation_info
        self.assertEqual(forward_values[0].val.data_buffer_idx, 1)  # pyre-ignore
        self.assertEqual(forward_values[1].val.data_buffer_idx, 2)  # pyre-ignore
        self.assertEqual(
            forward_values[0].val.allocation_info.memory_offset_low, 0  # pyre-ignore
        )
        # NOTE(review): 48 presumably is the 36-byte 3x3 float32 weight plus
        # alignment padding -- confirm against the memory planner.
        self.assertEqual(
            forward_values[1].val.allocation_info.memory_offset_low, 48  # pyre-ignore
        )

        loss = m(*example_inputs)
        loss.backward()
        et_mod = _load_for_executorch_from_buffer(et.buffer)
        et_outputs = et_mod.forward(
            example_inputs
        )  # ET outputs are [loss, grads, weights]

        self.assertTrue(torch.allclose(loss, et_outputs[0]))
        self.assertTrue(
            torch.allclose(m.linear.weight.grad, et_outputs[1])  # pyre-ignore[6]
        )
        self.assertTrue(torch.allclose(m.linear.bias.grad, et_outputs[2]))
        self.assertTrue(torch.allclose(m.linear.weight, et_outputs[3]))
        self.assertTrue(torch.allclose(m.linear.bias, et_outputs[4]))

        # forward + 3 training metadata plans, each inspected below
        self.assertEqual(len(execution_plan), 4)

        # gradient outputs start at index 1
        self.assertEqual(execution_plan[1].values[0].val.int_val, 1)  # pyre-ignore

        # first recorded parameter name
        self.assertEqual(
            execution_plan[2].values[0].val.string_val,  # pyre-ignore
            "linear.weight",
        )

        # parameter outputs start at index 3
        self.assertEqual(execution_plan[3].values[0].val.int_val, 3)  # pyre-ignore
141