xref: /aosp_15_r20/external/executorch/profiler/test/test_profiler_e2e.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7"""End-to-end profiler tests.
8
9This must be built and run with `buck2 -c executorch.prof_enabled=true`.
10"""
11
12import unittest
13
14import torch
15
16from executorch.exir import to_edge
17
18from executorch.extension.pybindings.portable_lib import (
19    _create_profile_block,
20    _dump_profile_results,
21    _load_for_executorch_from_buffer,
22    _reset_profile_results,
23)
24from executorch.extension.pytree import tree_flatten
25from executorch.profiler.fb.parse_profiler_results import profile_table
26from executorch.profiler.parse_profiler_results import (
27    deserialize_profile_results,
28    profile_aggregate_framework_tax,
29    profile_framework_tax_table,
30)
31from torch.export import export
32
33
class Module(torch.nn.Module):
    """Minimal model computing ``a * x + b`` with constant 2x2 float buffers.

    Buffer ``a`` is filled with 3 and buffer ``b`` with 2. The forward pass
    issues one ``torch.mul`` followed by one ``torch.add``.
    """

    def __init__(self):
        super().__init__()
        # Registered (not parameters) so they are exported as constant state.
        self.register_buffer("a", torch.full((2, 2), 3.0, dtype=torch.float))
        self.register_buffer("b", torch.full((2, 2), 2.0, dtype=torch.float))

    def forward(self, x):
        # Keep the explicit torch.mul / torch.add calls: the profiler tests
        # look for the corresponding mul and add kernels in the exported program.
        return torch.add(torch.mul(self.a, x), self.b)
44
45
class TestCustomOps(unittest.TestCase):
    """End-to-end checks of the ExecuTorch profiler on the mul/add ``Module``.

    ``setUpClass`` exports ``Module``, loads it through the portable pybindings,
    runs ``forward`` once, and deserializes the resulting profile dump into
    ``prof_results`` / ``mem_results``. Most tests inspect that shared result;
    ``test_profiler_new_block`` resets the profiler and records its own dump.
    """

    _EMPTY_DUMP_MSG = (
        "prof_dump is empty; may need to build with `-c executorch.prof_enabled=true`"
    )

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        model = Module()
        inputs = (torch.ones(2, 2, dtype=torch.float),)

        # The serialized program file. This must live longer than cls.module,
        # because the C++ pybindings will have a pointer to it. But none of the
        # tests should need to touch it.
        cls.__buffer: bytes = to_edge(export(model, inputs)).to_executorch().buffer

        cls.module = _load_for_executorch_from_buffer(cls.__buffer)

        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        cls.inputs_flattened, _ = tree_flatten(inputs)
        cls.module.run_method("forward", tuple(cls.inputs_flattened))
        prof_dump = _dump_profile_results()
        # Raise explicitly instead of using `assert`, so this guard still
        # fires when Python runs with -O.
        if len(prof_dump) == 0:
            raise AssertionError(cls._EMPTY_DUMP_MSG)
        cls.prof_results, cls.mem_results = deserialize_profile_results(prof_dump)
        cls.expect_ops = ["native_call_add.out", "native_call_mul.out"]

    @staticmethod
    def _entry_hits(tables, needles) -> dict:
        """Return {needle: number of table entries whose rendered text contains it}.

        ``tables`` is iterated as a sequence of tables, each a sequence of
        entries exposing ``get_string()``.
        """
        hits = {needle: 0 for needle in needles}
        for table in tables:
            for entry in table:
                # Render each entry once rather than once per needle.
                text = entry.get_string()
                for needle in needles:
                    if needle in text:
                        hits[needle] += 1
        return hits

    def test_profiler_new_block(self) -> None:
        block_names = ["block_1", "block_2"]
        _reset_profile_results()
        # Open a named block before each forward run; the dump is expected to
        # contain exactly these blocks, in creation order.
        for block_name in block_names:
            _create_profile_block(block_name)
            self.module.run_method("forward", tuple(self.inputs_flattened))
        prof_dump = _dump_profile_results()
        self.assertGreater(len(prof_dump), 0, self._EMPTY_DUMP_MSG)
        prof_results, _ = deserialize_profile_results(prof_dump)
        # Compare the whole ordered key list: an unexpected extra block now
        # fails the assertion instead of raising IndexError mid-loop.
        self.assertEqual(list(prof_results), block_names)

    def test_profiler_expected_ops(self) -> None:
        found = {op: 0 for op in self.expect_ops}
        for block_name, prof_data_list in self.prof_results.items():
            # setUpClass never creates a block, so everything is recorded
            # under "default".
            self.assertEqual(block_name, "default")
            for prof_event in prof_data_list:
                if prof_event.name in found:
                    found[prof_event.name] += 1
        # One forward run: each expected kernel must appear exactly once.
        # A per-op check keeps one op counted twice from masking a missing one.
        self.assertEqual(found, {op: 1 for op in self.expect_ops})

    def test_profile_framework_tax(self) -> None:
        prof_agg_data = profile_aggregate_framework_tax(self.prof_results)
        # Checking the keys up front also stops the loop below from passing
        # vacuously on empty data.
        self.assertEqual(list(prof_agg_data), ["default"])
        for framework_tax in prof_agg_data.values():
            # forward ran once in setUpClass, so each metric holds one sample.
            self.assertEqual(len(framework_tax.exec_time), 1)
            self.assertEqual(len(framework_tax.kernel_and_delegate_time), 1)
            self.assertEqual(len(framework_tax.framework_tax), 1)
            # NOTE(review): presumably a percentage, hence the < 100 bound.
            self.assertLess(float(framework_tax.framework_tax[0]), 100)

    def test_gen_profile_table(self) -> None:
        prof_table = profile_table(self.prof_results)
        self.assertEqual(
            self._entry_hits(prof_table, self.expect_ops),
            {op: 1 for op in self.expect_ops},
        )

    def test_gen_profile_framework_tax_table(self) -> None:
        prof_agg_data = profile_aggregate_framework_tax(self.prof_results)
        prof_framework_tax_table = profile_framework_tax_table(prof_agg_data)
        expected_entries = [
            "Model execution time",
            "Time spent in kernels",
            "Framework tax",
        ]
        self.assertEqual(
            self._entry_hits(prof_framework_tax_table, expected_entries),
            {entry: 1 for entry in expected_entries},
        )
128
129
def main() -> None:
    """Run this module's test cases through the standard unittest runner."""
    unittest.main(module="__main__")


if __name__ == "__main__":
    # Only run the tests when executed as a script, not when imported.
    main()  # pragma: no cover
136