# Owner(s): ["module: inductor"]
import contextlib

import sympy

import torch
import torch._inductor.config as inductor_config
from torch._inductor.codegen import triton_utils
from torch._inductor.codegen.common import SizeArg
from torch._inductor.graph import GraphLowering
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.virtualized import V
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU


class TestCodegenTriton(InductorTestCase):
    def setUp(self):
        super().setUp()

        class DummyModule(torch.nn.Module):
            def forward(self, x):
                return x * 2

        self._gm = torch.fx.symbolic_trace(DummyModule())
        self._graph = GraphLowering(self._gm)

        # Codegen helpers such as config_of() consult the active graph via V,
        # so install a minimal graph handler for the duration of each test.
        self._stack = contextlib.ExitStack()
        self._stack.enter_context(V.set_graph_handler(self._graph))

    def tearDown(self):
        self._stack.close()
        super().tearDown()

    @inductor_config.patch("triton.divisible_by_16", True)
    def test_config_of_sizearg(self):
        """config_of() should report the indices of the size arguments whose
        values are statically known to be divisible by 16."""
        two = sympy.Integer(2)
        eight = sympy.Integer(8)
        sixteen = sympy.Integer(16)
        s0 = sympy.Symbol("s0", positive=True, integer=True)
        s1 = sympy.Symbol("s1", positive=True, integer=True)

        self.assertEqual(
            (2,),
            triton_utils.config_of(
                [
                    SizeArg("A", two),  # no
                    SizeArg("B", eight),  # no
                    SizeArg("C", sixteen),  # yes
                    SizeArg("D", s0),  # no
                    SizeArg("E", s1),  # no
                ]
            ).divisible_by_16,
        )

        self.assertEqual(
            (0, 2, 4, 5, 6),
            triton_utils.config_of(
                [
                    SizeArg("A", two * eight),  # 0: yes
                    SizeArg("B", eight * s0),  # 1: no
                    SizeArg("C", two * eight * s0),  # 2: yes
                    SizeArg("D", s0 * s1),  # 3: no
                    SizeArg("E", sixteen * s0),  # 4: yes
                    SizeArg("F", sixteen * eight * s0 * s1),  # 5: yes
                    SizeArg("G", two * eight * s0 * s1),  # 6: yes
                ]
            ).divisible_by_16,
        )


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_GPU:
        run_tests("sympy")