xref: /aosp_15_r20/external/pytorch/test/jit/test_complexity.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport contextlib
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport sys
6*da0073e9SAndroid Build Coastguard Workerimport unittest
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport torch
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable
12*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
13*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir)
14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
15*da0073e9SAndroid Build Coastguard Worker    IS_FBCODE,
16*da0073e9SAndroid Build Coastguard Worker    run_tests,
17*da0073e9SAndroid Build Coastguard Worker    set_default_dtype,
18*da0073e9SAndroid Build Coastguard Worker    suppress_warnings,
19*da0073e9SAndroid Build Coastguard Worker)
20*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_metaprogramming_utils import (
21*da0073e9SAndroid Build Coastguard Worker    get_all_nn_module_tests,
22*da0073e9SAndroid Build Coastguard Worker    get_nn_functional_compiled_fn_and_inputs,
23*da0073e9SAndroid Build Coastguard Worker    get_nn_mod_test_name,
24*da0073e9SAndroid Build Coastguard Worker    nn_functional_tests,
25*da0073e9SAndroid Build Coastguard Worker    try_get_nn_module_compiled_mod_and_inputs,
26*da0073e9SAndroid Build Coastguard Worker)
27*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker
def num_ifs_loops(graph):
    """Count control-flow nodes in the printed body of ``graph``.

    The graph is rendered with ``str()`` and only the text before the first
    occurrence of ``"return"`` is inspected, so anything printed after that
    point is ignored.

    Args:
        graph: Any object whose ``str()`` is the textual graph dump.

    Returns:
        The number of ``"prim::Loop"`` occurrences plus the number of
        ``"prim::If"`` occurrences in the inspected text.
    """
    # only look at body of graph
    # partition() hands back the whole string when "return" is absent; slicing
    # with find()'s -1 result would instead silently drop the last character.
    graph_body, _, _ = str(graph).partition("return")
    return graph_body.count("prim::Loop") + graph_body.count("prim::If")
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker
def num_non_tensor_nodes(block):
    """Recursively count nodes in ``block`` that produce no Tensor-typed output.

    A node is counted when none of its outputs has ``"Tensor"`` in the string
    form of its type; a node without any outputs is therefore counted too.
    Nodes inside nested blocks are included through recursion.

    Args:
        block: Object exposing ``nodes()``; each node must expose ``kind()``,
            ``blocks()`` and ``outputs()``, and each output ``type()``.

    Returns:
        The total count as an ``int``.
    """
    count = 0
    for node in block.nodes():
        kind = node.kind()
        # GetAttr nodes give no useful signal here: they are non-optimizable
        # except with freezing. Constants are not executed, and bailouts
        # should be covered by separate tests. Skipped nodes are not recursed
        # into either.
        if kind == "prim::Constant" or "prim::Bailout" in kind or "GetAttr" in kind:
            continue
        count += sum(num_non_tensor_nodes(sub_block) for sub_block in node.blocks())
        if not any("Tensor" in str(out.type()) for out in node.outputs()):
            count += 1
    return count
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker
class TestComplexity(JitTestCase):
    """Print control-flow and non-tensor node counts of optimized JIT graphs.

    Both tests only collect and print a table of statistics; neither asserts
    on the collected numbers.
    """

    def setUp(self):
        """Disable grad mode and make double the default dtype for the test."""
        super().setUp()
        # Remember the current grad mode so tearDown can put it back.
        self.grad_enabled = torch.is_grad_enabled()
        torch.set_grad_enabled(False)
        # The ExitStack keeps the default-dtype context open until tearDown.
        self._stack = contextlib.ExitStack()
        self._stack.enter_context(set_default_dtype(torch.double))

    def tearDown(self):
        """Undo setUp: leave the dtype context, then restore the grad mode."""
        self._stack.close()
        torch.set_grad_enabled(self.grad_enabled)
        super().tearDown()

    @suppress_warnings
    def test_generated_functional_tests(self):
        """Print graph statistics for every entry of ``nn_functional_tests``."""
        rows = [("Name", "Ifs/Loops", "non-tensor ops")]
        # NOTE(review): presumably enough repeated calls for the profiling
        # executor to produce its optimized graph -- confirm.
        repeat_calls = 6
        with enable_profiling_mode():
            for case in nn_functional_tests:
                case_name = case[0]
                compiled_fn, args = get_nn_functional_compiled_fn_and_inputs(*case)

                for _ in range(repeat_calls):
                    compiled_fn(*args)

                optimized = torch.jit.last_executed_optimized_graph()
                rows.append(
                    (
                        case_name,
                        num_ifs_loops(optimized),
                        num_non_tensor_nodes(optimized),
                    )
                )
        for row in rows:
            print(row)

    @suppress_warnings
    @unittest.skipIf(IS_FBCODE, "Causes a RecursionError in fbcode")
    def test_nn_module_tests(self):
        """Print graph statistics for every compilable nn module test."""
        rows = [("Name", "Ifs/Loops", "non-tensor ops")]
        # NOTE(review): presumably enough repeated calls for the profiling
        # executor to produce its optimized graph -- confirm.
        repeat_calls = 6
        with enable_profiling_mode():
            for case in get_all_nn_module_tests():
                compiled = try_get_nn_module_compiled_mod_and_inputs(**case)
                # A falsy result means there is nothing to run for this entry.
                if not compiled:
                    continue

                module, args = compiled
                case_name = get_nn_mod_test_name(**case)
                for _ in range(repeat_calls):
                    module(*args)

                optimized = torch.jit.last_executed_optimized_graph()
                rows.append(
                    (
                        case_name,
                        num_ifs_loops(optimized),
                        num_non_tensor_nodes(optimized),
                    )
                )

            for row in rows:
                print(row)
105*da0073e9SAndroid Build Coastguard Worker
106*da0073e9SAndroid Build Coastguard Worker
# Run the tests through run_tests() when this file is executed directly.
if __name__ == "__main__":
    run_tests()
109