# Owner(s): ["module: inductor"] import contextlib import torch from torch._inductor.dependencies import MemoryDep from torch._inductor.graph import GraphLowering from torch._inductor.ir import Buffer, FixedLayout, Pointwise from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import sympy_index_symbol from torch._inductor.virtualized import ops, V from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU class TestDependencies(InductorTestCase): def _create_buffer(self, name, shape, dtype=torch.float32): return Buffer(name, FixedLayout(torch.device(GPU_TYPE), dtype, shape)) def setUp(self): super().setUp() class DummyModule(torch.nn.Module): def forward(self, x): return x * 2 self._gm = torch.fx.symbolic_trace(DummyModule()) self._graph = GraphLowering(self._gm) self._stack = contextlib.ExitStack() self._stack.enter_context(V.set_graph_handler(self._graph)) def tearDown(self): self._stack.close() super().tearDown() def test_bucketize_dependencies(self): offsets = self._create_buffer("offsets", (1025,), torch.int32) def inner_fn(index): idx = index[0] return ops.bucketize( values=idx, offsets_name=offsets.get_name(), offsets_size=offsets.get_size()[0], indexing_dtype=torch.int32, right=True, ) pointwise = Pointwise.create( device=torch.device(GPU_TYPE), dtype=torch.int32, inner_fn=inner_fn, ranges=[1024 * 4], ) self.assertEqual(len(pointwise.get_reads()), 1) def test_get_offset(self): x = sympy_index_symbol("x") y = sympy_index_symbol("y") var_ranges = { x: 1024, y: 2048, } dep1 = MemoryDep( "dep1", x * 2048 + y, list(var_ranges.keys()), list(var_ranges.values()), ) dep2 = MemoryDep( "dep2", x * 2048 + y + 1024, list(var_ranges.keys()), list(var_ranges.values()), ) self.assertEqual(dep1.get_offset(), 0) self.assertEqual(dep2.get_offset(), 1024) def test_normalize_with_stride_order_equal(self): x = sympy_index_symbol("x") y = sympy_index_symbol("y") var_ranges = { x: 1024, y: 2048, } loop_order1 = MemoryDep( "access_the_same_buffer", x * 2048 + y, [x, y], [1024, 2048], ) loop_order2 = MemoryDep( "access_the_same_buffer", x * 2048 + y, [y, x], [2048, 1024], ) self.assertTrue(loop_order1 != loop_order2) normalized_loop_order1 = loop_order1.normalize_with_stride_order() normalized_loop_order2 = loop_order2.normalize_with_stride_order() self.assertTrue(normalized_loop_order1 == normalized_loop_order2) def test_normalize_with_stride_order_unequal(self): x = sympy_index_symbol("x") y = sympy_index_symbol("y") var_ranges = { x: 1024, y: 2048, } loop_order1 = MemoryDep( "access_the_same_buffer", x * 2048 + y, [x, y], [1024, 2048], ) loop_order2 = MemoryDep( "access_the_same_buffer", x * 2048 + y + 5, [y, x], [2048, 1024], ) self.assertTrue(loop_order1 != loop_order2) normalized_loop_order1 = loop_order1.normalize_with_stride_order() normalized_loop_order2 = loop_order2.normalize_with_stride_order() # unequal due to different offset self.assertTrue(normalized_loop_order1 != normalized_loop_order2) if __name__ == "__main__": from torch._inductor.test_case import run_tests if HAS_CPU and HAS_GPU: run_tests("sympy")