xref: /aosp_15_r20/external/pytorch/test/inductor/test_codegen_triton.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: inductor"]
2import contextlib
3
4import sympy
5
6import torch
7import torch._inductor.config as inductor_config
8from torch._inductor.codegen import triton_utils
9from torch._inductor.codegen.common import SizeArg
10from torch._inductor.graph import GraphLowering
11from torch._inductor.test_case import TestCase as InductorTestCase
12from torch._inductor.virtualized import V
13from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU
14
15
16class TestCodegenTriton(InductorTestCase):
17    def setUp(self):
18        super().setUp()
19
20        class DummyModule(torch.nn.Module):
21            def forward(self, x):
22                return x * 2
23
24        self._gm = torch.fx.symbolic_trace(DummyModule())
25        self._graph = GraphLowering(self._gm)
26
27        self._stack = contextlib.ExitStack()
28        self._stack.enter_context(V.set_graph_handler(self._graph))
29
30    def tearDown(self):
31        self._stack.close()
32        super().tearDown()
33
34    @inductor_config.patch("triton.divisible_by_16", True)
35    def test_config_of_sizearg(self):
36        two = sympy.Integer(2)
37        eight = sympy.Integer(8)
38        sixteen = sympy.Integer(16)
39        s0 = sympy.Symbol("s0", positive=True, integer=True)
40        s1 = sympy.Symbol("s1", positive=True, integer=True)
41
42        self.assertEqual(
43            (2,),
44            triton_utils.config_of(
45                [
46                    SizeArg("A", two),  # no
47                    SizeArg("B", eight),  # no
48                    SizeArg("C", sixteen),  # yes
49                    SizeArg("D", s0),  # no
50                    SizeArg("E", s1),  # no
51                ]
52            ).divisible_by_16,
53        )
54
55        self.assertEqual(
56            (0, 2, 4, 5, 6),
57            triton_utils.config_of(
58                [
59                    SizeArg("A", two * eight),  # 0: yes
60                    SizeArg("B", eight * s0),  # 1: no
61                    SizeArg("C", two * eight * s0),  # 2: yes
62                    SizeArg("D", s0 * s1),  # 3: no
63                    SizeArg("E", sixteen * s0),  # 4: yes
64                    SizeArg("F", sixteen * eight * s0 * s1),  # 5: yes
65                    SizeArg("G", two * eight * s0 * s1),  # 6: yes
66                ]
67            ).divisible_by_16,
68        )
69
70
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Run the suite only when Inductor reports a usable CPU or GPU backend.
    have_backend = HAS_CPU or HAS_GPU
    if have_backend:
        # NOTE(review): "sympy" is forwarded to run_tests, presumably naming
        # a module the tests require -- confirm against run_tests' signature.
        run_tests("sympy")
76