# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""End-to-end profiler tests.

This must be built and run with `buck2 -c executorch.prof_enabled=true`.
"""

import unittest

import torch

from executorch.exir import to_edge

from executorch.extension.pybindings.portable_lib import (
    _create_profile_block,
    _dump_profile_results,
    _load_for_executorch_from_buffer,
    _reset_profile_results,
)
from executorch.extension.pytree import tree_flatten
from executorch.profiler.fb.parse_profiler_results import profile_table
from executorch.profiler.parse_profiler_results import (
    deserialize_profile_results,
    profile_aggregate_framework_tax,
    profile_framework_tax_table,
)
from torch.export import export


class Module(torch.nn.Module):
    """Small model computing ``a * x + b`` with two constant 2x2 buffers."""

    def __init__(self):
        super().__init__()
        self.register_buffer("a", 3 * torch.ones(2, 2, dtype=torch.float))
        self.register_buffer("b", 2 * torch.ones(2, 2, dtype=torch.float))

    def forward(self, x):
        """Multiply ``x`` by buffer ``a`` and add buffer ``b``."""
        a = torch.mul(self.a, x)
        b = torch.add(a, self.b)
        return b


class TestCustomOps(unittest.TestCase):
    """Profiler tests that run ``Module`` through the pybindings runtime."""

    @classmethod
    def setUpClass(cls) -> None:
        """Export and load the model, run it once, and keep the profile dump.

        Stores the loaded module, the flattened inputs, the deserialized
        profiling/memory results and the op names the tests look for as class
        attributes.

        Raises:
            AssertionError: if the profiler returned an empty dump.
        """
        model = Module()
        inputs = (torch.ones(2, 2, dtype=torch.float),)

        # The serialized program file. This must live longer than cls.module,
        # because the C++ pybindings will have a pointer to it. But none of the
        # tests should need to touch it.
        cls.__buffer: bytes = to_edge(export(model, inputs)).to_executorch().buffer

        cls.module = _load_for_executorch_from_buffer(cls.__buffer)

        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        cls.inputs_flattened, _ = tree_flatten(inputs)
        cls.module.run_method("forward", tuple(cls.inputs_flattened))
        prof_dump = _dump_profile_results()
        # Raise explicitly instead of using a bare `assert`, so the guard is
        # still enforced when Python runs with -O.
        if len(prof_dump) == 0:
            raise AssertionError(
                "prof_dump is empty; may need to build with `-c executorch.prof_enabled=true`"
            )
        cls.prof_results, cls.mem_results = deserialize_profile_results(prof_dump)
        cls.expect_ops = ["native_call_add.out", "native_call_mul.out"]

    def test_profiler_new_block(self) -> None:
        """Two named profile blocks come back, in creation order."""
        block_names = ["block_1", "block_2"]
        _reset_profile_results()
        _create_profile_block(block_names[0])
        self.module.run_method("forward", tuple(self.inputs_flattened))
        _create_profile_block(block_names[1])
        self.module.run_method("forward", tuple(self.inputs_flattened))
        prof_dump = _dump_profile_results()
        self.assertGreater(
            len(prof_dump),
            0,
            "prof_dump is empty; may need to build with `-c executorch.prof_enabled=true`",
        )
        prof_results, _ = deserialize_profile_results(prof_dump)
        # A single ordered comparison checks both the names and the count, and
        # reports the actual names on failure instead of raising IndexError
        # when an unexpected extra block is present.
        reported_names = [name for name, _ in prof_results.items()]
        self.assertEqual(reported_names, block_names)

    def test_profiler_expected_ops(self) -> None:
        """Every expected op shows up once, all under the "default" block."""
        found_count = 0
        for block_name, prof_data_list in self.prof_results.items():
            for prof_event in prof_data_list:
                if prof_event.name in self.expect_ops:
                    found_count += 1
            self.assertEqual(block_name, "default")
        self.assertEqual(found_count, len(self.expect_ops))

    def test_profile_framework_tax(self) -> None:
        """Aggregated framework-tax data has one sample per series, below 100."""
        prof_agg_data = profile_aggregate_framework_tax(self.prof_results)
        for name, framework_tax in prof_agg_data.items():
            self.assertEqual(len(framework_tax.exec_time), 1)
            self.assertEqual(len(framework_tax.kernel_and_delegate_time), 1)
            self.assertEqual(len(framework_tax.framework_tax), 1)
            self.assertLess(float(framework_tax.framework_tax[0]), 100)
            self.assertEqual(name, "default")

    def test_gen_profile_table(self) -> None:
        """The rendered profile table mentions each expected op exactly once."""
        prof_table = profile_table(self.prof_results)
        found_count = 0
        for table in prof_table:
            for entry in table:
                # Render once per entry rather than once per expected op.
                rendered = entry.get_string()
                for op in self.expect_ops:
                    if op in rendered:
                        found_count += 1
        self.assertEqual(found_count, len(self.expect_ops))

    def test_gen_profile_framework_tax_table(self) -> None:
        """The framework-tax table contains each expected row label once."""
        prof_agg_data = profile_aggregate_framework_tax(self.prof_results)
        prof_framework_tax_table = profile_framework_tax_table(prof_agg_data)
        expected_entries = [
            "Model execution time",
            "Time spent in kernels",
            "Framework tax",
        ]
        found_count = 0
        for table in prof_framework_tax_table:
            for entry in table:
                # Render once per entry rather than once per expected label.
                rendered = entry.get_string()
                for expected_entry in expected_entries:
                    if expected_entry in rendered:
                        found_count += 1
        self.assertEqual(found_count, len(expected_entries))


def main() -> None:
    """Run the tests in this module with the unittest runner."""
    unittest.main()


if __name__ == "__main__":
    main()  # pragma: no cover