# Owner(s): ["oncall: jit"]

import contextlib
import os
import sys
import unittest

import torch


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    run_tests,
    set_default_dtype,
    suppress_warnings,
)
from torch.testing._internal.jit_metaprogramming_utils import (
    get_all_nn_module_tests,
    get_nn_functional_compiled_fn_and_inputs,
    get_nn_mod_test_name,
    nn_functional_tests,
    try_get_nn_module_compiled_mod_and_inputs,
)
from torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase


def num_ifs_loops(graph):
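    """Count the prim::If and prim::Loop nodes in the body of a graph.

    Only the portion of the graph's string representation before the
    top-level return is inspected.
    """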
    graph_str = str(graph)
    # only look at body of graph
    graph_body = graph_str[0 : graph_str.find("return")]
    return graph_body.count("prim::Loop") + graph_body.count("prim::If")


def num_non_tensor_nodes(block):
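    """Recursively count nodes that do not produce a Tensor output.

    Constants, bailouts, and GetAttr nodes are skipped; sub-blocks are
    traversed recursively.
    """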
    num_non_tensor = 0
    for node in block.nodes():
        kind = node.kind()
        # GetAttr nodes don't provide a useful signal here, since they are non-optimizable except with freezing
        # Constants are not executed, and bailouts should be covered by a separate
        # test, so neither provides a useful signal here
        if kind == "prim::Constant" or "prim::Bailout" in kind or "GetAttr" in kind:
            continue
        for b in node.blocks():
            num_non_tensor += num_non_tensor_nodes(b)
        tensor_out = False
        for out in node.outputs():
            if "Tensor" in str(out.type()):
                tensor_out = True
                break
        num_non_tensor += int(not tensor_out)
    return num_non_tensor
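

# Illustrative sketch of how the helpers above behave (not part of the test
# suite): for a scripted function such as
#
#     @torch.jit.script
#     def fn(x, n: int):
#         for _ in range(n):
#             if bool(x.sum() > 0):
#                 x = x - 1
#         return x
#
# the graph contains one prim::Loop and one prim::If, so num_ifs_loops(fn.graph)
# returns 2, while num_non_tensor_nodes counts the scalar bookkeeping nodes
# (loop bounds, the branch condition, etc.) that produce no Tensor output.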


class TestComplexity(JitTestCase):
    def setUp(self):
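        # Run each test with autograd disabled and float64 as the default dtype.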
        super().setUp()
        self.grad_enabled = torch.is_grad_enabled()
        torch.set_grad_enabled(False)
        self._stack = contextlib.ExitStack()
        self._stack.enter_context(set_default_dtype(torch.double))

    def tearDown(self):
        self._stack.close()
        torch.set_grad_enabled(self.grad_enabled)
        super().tearDown()

    @suppress_warnings
    def test_generated_functional_tests(self):
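        """Print Ifs/Loops and non-tensor-op counts for each nn.functional test.

        Every functional test is compiled and run several times under
        profiling mode, and its last optimized graph is inspected.
        """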
        with enable_profiling_mode():
            stats = [("Name", "Ifs/Loops", "non-tensor ops")]
            for test in nn_functional_tests:
                test_name = test[0]

                fn, inputs = get_nn_functional_compiled_fn_and_inputs(*test)
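                # Warm up the profiling executor so that an optimized graph is
                # recorded for this function.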
                for _ in range(6):
                    fn(*inputs)

                g = torch.jit.last_executed_optimized_graph()
                stats.append((test_name, num_ifs_loops(g), num_non_tensor_nodes(g)))
        for line in stats:
            print(line)

    @suppress_warnings
    @unittest.skipIf(IS_FBCODE, "Causes a RecursionError in fbcode")
    def test_nn_module_tests(self):
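        """Print Ifs/Loops and non-tensor-op counts for each nn.Module test.

        Module tests that do not yield a compiled module are skipped.
        """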
        with enable_profiling_mode():
            stats = [("Name", "Ifs/Loops", "non-tensor ops")]
            for test in get_all_nn_module_tests():
                out = try_get_nn_module_compiled_mod_and_inputs(**test)
                if not out:
                    continue

                mod, inputs = out
                test_name = get_nn_mod_test_name(**test)
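                # Warm up the profiling executor so that an optimized graph is
                # recorded for the module's forward.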
                for _ in range(6):
                    mod(*inputs)

                g = torch.jit.last_executed_optimized_graph()
                stats.append((test_name, num_ifs_loops(g), num_non_tensor_nodes(g)))

            for line in stats:
                print(line)


if __name__ == "__main__":
    run_tests()