xref: /aosp_15_r20/external/executorch/devtools/backend_debug/tests/test_delegation_info.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import pandas as pd
10
11import torch
12from executorch.devtools.backend_debug import DelegationBreakdown, get_delegation_info
13from executorch.exir import to_edge
14from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
15from pandas.testing import assert_frame_equal
16
17
class TestUtils(unittest.TestCase):
    """Tests for ``get_delegation_info`` on a partially delegated edge program."""

    def test_get_delegation_info(self):
        # Two mm -> add chains joined by a sub. Per the expectations below,
        # the mm/add nodes end up delegated (in two subgraphs) while the sub
        # and the getitem nodes remain outside any delegate.
        class Model(torch.nn.Module):
            def forward(self, a, x, b):
                first = torch.mm(a, x) + b
                residual = first - a
                return torch.mm(residual, x) + b

        example_inputs = tuple(torch.randn(2, 2) for _ in range(3))
        exported = torch.export.export(Model(), example_inputs)
        edge = to_edge(exported).to_backend(AddMulPartitionerDemo())
        info = get_delegation_info(edge.exported_program().graph_module)

        self.assertEqual(info.num_delegated_subgraphs, 2)
        self.assertEqual(info.num_delegated_nodes, 4)
        self.assertEqual(info.num_non_delegated_nodes, 3)

        # One row per operator: (op_type, delegated count, non-delegated count).
        expected_counts = [
            ("aten_add_tensor", 2, 0),
            ("aten_mm_default", 2, 0),
            ("aten_sub_tensor", 0, 1),
            ("getitem", 0, 2),
        ]
        expected_by_op = {
            op: DelegationBreakdown(op_type=op, delegated=d, non_delegated=nd)
            for op, d, nd in expected_counts
        }
        self.assertEqual(info.delegation_by_operator, expected_by_op)

        self.assertIn("Total delegated subgraphs", info.get_summary())

        actual_df = info.get_operator_delegation_dataframe()

        # Same operator rows as above, followed by a "Total" row holding the
        # column sums.
        op_types, delegated, non_delegated = map(list, zip(*expected_counts))
        expected_df = pd.DataFrame(
            {
                "op_type": op_types + ["Total"],
                "occurrences_in_delegated_graphs": delegated + [sum(delegated)],
                "occurrences_in_non_delegated_graphs": (
                    non_delegated + [sum(non_delegated)]
                ),
            }
        )
        assert_frame_equal(expected_df, actual_df)
80