xref: /aosp_15_r20/external/pytorch/test/jit/test_op_decompositions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import torch
4from torch.testing import FileCheck
5from torch.testing._internal.jit_utils import JitTestCase
6
7
if __name__ == "__main__":
    # Direct execution is refused: the message below tells the user to go
    # through test/test_jit.py and select this test by name instead.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
14
15
class TestOpDecompositions(JitTestCase):
    """Tests for the TorchScript decomposition pass.

    ``torch._C._jit_pass_run_decompositions`` rewrites ops in a scripted
    graph into other ops. Each test checks the graph text before and after
    the pass and compares the scripted result against eager execution.
    """

    def test_op_decomposition(self):
        """Built-in decomposition: ``aten::var`` is rewritten away."""

        def foo(x):
            return torch.var(x, unbiased=True)

        # TODO: more robust testing
        foo_s = torch.jit.script(foo)
        # Precondition: scripting emits aten::var, so its absence afterwards
        # is attributable to the decomposition pass.
        FileCheck().check("aten::var").run(foo_s.graph)
        torch._C._jit_pass_run_decompositions(foo_s.graph)
        inp = torch.rand([10, 10])
        # The decomposed scripted graph must agree with eager torch.var.
        self.assertEqual(foo(inp), foo_s(inp))
        FileCheck().check_not("aten::var").run(foo_s.graph)

    def test_registered_decomposition(self):
        """User-registered decomposition: ``aten::square`` -> ``aten::pow``."""

        @torch.jit.script
        def foo(x):
            return torch.square(x)

        @torch.jit.script
        def square_decomp(x):
            return torch.pow(x, 2)

        # Precondition (mirrors test_op_decomposition): the op to be replaced
        # is present before the pass runs, so the negative check below
        # cannot pass vacuously.
        FileCheck().check("aten::square").run(foo.graph)

        # NOTE(review): this registration is never undone by the test, so it
        # may affect later decomposition passes in the same process — confirm
        # that is acceptable.
        torch.jit._register_decomposition(
            torch.ops.aten.square.default, square_decomp.graph
        )
        torch._C._jit_pass_run_decompositions(foo.graph)

        # Run the absence and presence checks separately. A chained
        # check_not(...).check(...) only scans the text *before* the
        # aten::pow match, so an aten::square after it would slip through.
        FileCheck().check_not("aten::square").run(foo.graph)
        FileCheck().check("aten::pow").run(foo.graph)

        x = torch.rand([4])
        # The decomposed graph must agree with eager torch.square.
        self.assertEqual(foo(x), torch.square(x))