# Owner(s): ["module: sparse"] # # Test to ensure sparsity information propagates properly into traced graph. # import sys import unittest import torch from torch._dynamo.config import is_fbcode from torch._subclasses.fake_tensor import FakeTensor from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, subtest, TestCase, ) # Various data types (preserved over operations). DTYPES = [ torch.int64, torch.float16, torch.bfloat16, torch.float32, torch.float64, ] # Various index types. ITYPES = [torch.int32, torch.int64] # Constructs a subtest for every sparse layout currently supported in torch.sparse. def all_sparse_layouts(test_name="layout"): return parametrize( test_name, [ subtest(torch.sparse_coo, name="SparseCOO"), subtest(torch.sparse_csr, name="SparseCSR"), subtest(torch.sparse_csc, name="SparseCSC"), subtest(torch.sparse_bsr, name="SparseBSR"), subtest(torch.sparse_bsc, name="SparseBSC"), ], ) # # Various network examples. # class IdNet(torch.nn.Module): def forward(self, x): return x class SumNet(torch.nn.Module): def forward(self, x): return x.sum() class EltwiseNet(torch.nn.Module): def forward(self, x): return torch.nn.functional.relu(2 * torch.abs(-x)) class ToDenseNet(torch.nn.Module): def forward(self, x): return x.to_dense() class AddNet(torch.nn.Module): def forward(self, x, y): return torch.add(x, y) class SparseActivationCOO(torch.nn.Module): def forward(self, x): return [xi.to_sparse() for xi in x] class SparseActivationCSR(torch.nn.Module): def forward(self, x): return [xi.to_sparse_csr() for xi in x] # # The test driver. 
#


@unittest.skipIf(is_fbcode(), "See torch._dynamo.config")
@unittest.skipIf(
    sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
class TestSparseProp(TestCase):
    """Checks the ``val`` metadata that ``torch.export.export`` records on
    graph nodes when the inputs and/or results are sparse tensors."""

    def setUp(self):
        super().setUp()

    def assertEqualMeta(self, x, y):
        """Assert that fake tensor ``x`` matches the meta version of ``y``.

        Args:
            x: the value found in a graph node's ``meta["val"]``; must be a
                ``FakeTensor``.
            y: the expected (real) tensor; it is moved to the meta device
                before being compared.
        """
        self.assertIsInstance(x, FakeTensor)
        self.assertIsInstance(y, torch.Tensor)

        # Convert expected value to meta for comparison.
        y = y.to("meta")
        self.assertEqual(x, y, exact_layout=True, exact_is_coalesced=True)

        # When x or y is a meta tensor (say, `x.device == "meta"`), then
        # assertEqual(x, y) compares only x and y attributes but skips
        # comparing their values. In the case of sparse tensors, this means
        # that comparing indices and values attributes are skipped as well,
        # which is why we are doing that explicitly below.
        if x.layout is torch.strided:
            pass
        elif x.layout is torch.sparse_coo:
            self.assertEqual(x._indices(), y._indices(), exact_layout=True)
            self.assertEqual(x._values(), y._values(), exact_layout=True)
        else:
            if x.layout in {torch.sparse_csr, torch.sparse_bsr}:
                x_meta1, y_meta1 = (x.crow_indices(), y.crow_indices())
                x_meta2, y_meta2 = (x.col_indices(), y.col_indices())
            elif x.layout in {torch.sparse_csc, torch.sparse_bsc}:
                x_meta1, y_meta1 = (x.ccol_indices(), y.ccol_indices())
                x_meta2, y_meta2 = (x.row_indices(), y.row_indices())
            else:
                # Unreachable for the layouts handled above. An explicit
                # failure (rather than a bare `assert`) still fires when
                # Python runs with -O, and it names the offending layout.
                self.fail(f"unexpected layout: {x.layout}")
            self.assertEqual(x_meta1, y_meta1, exact_layout=True)
            self.assertEqual(x_meta2, y_meta2, exact_layout=True)
            self.assertEqual(x.values(), y.values(), exact_layout=True)

    def assertNodeMetas(self, prog, expected):
        """Check the ``val`` metadata of every node in an exported graph.

        The i-th node of ``prog.graph`` must carry a ``val`` that matches
        ``expected[i]`` (via ``assertEqualMeta``); every node past the end of
        ``expected`` must carry no ``val`` at all.

        Args:
            prog: the object returned by ``torch.export.export``.
            expected: sequence of tensors, one per leading graph node.
        """
        for i, node in enumerate(prog.graph.nodes):
            meta = node.meta.get("val", None)
            if i < len(expected):
                self.assertEqualMeta(meta, expected[i])
            else:
                self.assertIsNone(meta)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_idnet(self, dtype, itype, layout):
        """Export ``IdNet`` on sparse inputs and check the node metadata."""
        net = IdNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/output.
            self.assertNodeMetas(prog, [sparse_input])

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_sumnet(self, dtype, itype, layout):
        """Export ``SumNet`` on sparse inputs and check the node metadata."""
        net = SumNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/sum/output.
            self.assertNodeMetas(prog, [sparse_input, result])

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_eltwisenet(self, dtype, itype, layout):
        """Export ``EltwiseNet`` on sparse inputs and check the node metadata."""
        net = EltwiseNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/neg/abs/mul/relu/output.
            # The first five nodes (including the input placeholder) are all
            # compared against the final result.
            # NOTE(review): presumably valid because the elementwise ops keep
            # shape, dtype, and layout unchanged - confirm.
            self.assertNodeMetas(prog, [result] * 5)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_todensenet(self, dtype, itype, layout):
        """Export ``ToDenseNet`` on sparse inputs and check the node metadata."""
        net = ToDenseNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/todense/output.
            self.assertNodeMetas(prog, [sparse_input, result])

    def test_add(self):
        """Export ``AddNet`` on a CSR matrix plus a dense matrix."""
        net = AddNet()
        Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)
        A = torch.tensor(
            [
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 2.0],
                [0.0, 0.0, 1.0, 1.0],
                [3.0, 0.0, 3.0, 0.0],
            ],
            dtype=torch.float32,
        )
        S = A.to_sparse_csr()
        result = net(S, Y)
        # Build the traced graph.
        prog = torch.export.export(net, (S, Y))
        # Test args/add/output.
        self.assertNodeMetas(prog, [S, Y, result])

    def test_activation_coo(self):
        """Export ``SparseActivationCOO`` on a list of three dense inputs."""
        net = SparseActivationCOO()
        x = [torch.randn(3, 3) for _ in range(3)]
        result = net(x)
        # Build the traced graph.
        prog = torch.export.export(net, args=(x,))
        # Test args/to_sparse/output.
        # Nodes 0-2 are checked against the inputs, nodes 3-5 against the
        # corresponding results.
        self.assertNodeMetas(prog, [*x, *result])

    def test_activation_csr(self):
        """Export ``SparseActivationCSR`` on a list of three dense inputs."""
        net = SparseActivationCSR()
        x = [torch.randn(3, 3) for _ in range(3)]
        result = net(x)
        # Build the traced graph.
        prog = torch.export.export(net, args=(x,))
        # Test args/to_sparse/output.
        # Nodes 0-2 are checked against the inputs, nodes 3-5 against the
        # corresponding results.
        self.assertNodeMetas(prog, [*x, *result])


instantiate_parametrized_tests(TestSparseProp)

if __name__ == "__main__":
    run_tests()