1# Owner(s): ["oncall: jit"] 2 3import io 4import unittest 5 6import torch 7from torch.testing._internal.common_utils import IS_WINDOWS, TEST_MKL 8from torch.testing._internal.jit_utils import JitTestCase 9 10 11class TestSparse(JitTestCase): 12 def test_freeze_sparse_coo(self): 13 class SparseTensorModule(torch.nn.Module): 14 def __init__(self) -> None: 15 super().__init__() 16 self.a = torch.rand(3, 4).to_sparse() 17 self.b = torch.rand(3, 4).to_sparse() 18 19 def forward(self, x): 20 return x + self.a + self.b 21 22 x = torch.rand(3, 4).to_sparse() 23 24 m = SparseTensorModule() 25 unfrozen_result = m.forward(x) 26 27 m.eval() 28 frozen = torch.jit.freeze(torch.jit.script(m)) 29 30 frozen_result = frozen.forward(x) 31 32 self.assertEqual(unfrozen_result, frozen_result) 33 34 buffer = io.BytesIO() 35 torch.jit.save(frozen, buffer) 36 buffer.seek(0) 37 loaded_model = torch.jit.load(buffer) 38 39 loaded_result = loaded_model.forward(x) 40 41 self.assertEqual(unfrozen_result, loaded_result) 42 43 def test_serialize_sparse_coo(self): 44 class SparseTensorModule(torch.nn.Module): 45 def __init__(self) -> None: 46 super().__init__() 47 self.a = torch.rand(3, 4).to_sparse() 48 self.b = torch.rand(3, 4).to_sparse() 49 50 def forward(self, x): 51 return x + self.a + self.b 52 53 x = torch.rand(3, 4).to_sparse() 54 m = SparseTensorModule() 55 expected_result = m.forward(x) 56 57 buffer = io.BytesIO() 58 torch.jit.save(torch.jit.script(m), buffer) 59 buffer.seek(0) 60 loaded_model = torch.jit.load(buffer) 61 62 loaded_result = loaded_model.forward(x) 63 64 self.assertEqual(expected_result, loaded_result) 65 66 @unittest.skipIf(IS_WINDOWS or not TEST_MKL, "Need MKL to run CSR matmul") 67 def test_freeze_sparse_csr(self): 68 class SparseTensorModule(torch.nn.Module): 69 def __init__(self) -> None: 70 super().__init__() 71 self.a = torch.rand(4, 4).to_sparse_csr() 72 self.b = torch.rand(4, 4).to_sparse_csr() 73 74 def forward(self, x): 75 return x.matmul(self.a).matmul(self.b) 76 77 x = torch.rand(4, 4).to_sparse_csr() 78 79 m = SparseTensorModule() 80 unfrozen_result = m.forward(x) 81 82 m.eval() 83 frozen = torch.jit.freeze(torch.jit.script(m)) 84 85 frozen_result = frozen.forward(x) 86 87 self.assertEqual(unfrozen_result.to_dense(), frozen_result.to_dense()) 88 89 buffer = io.BytesIO() 90 torch.jit.save(frozen, buffer) 91 buffer.seek(0) 92 loaded_model = torch.jit.load(buffer) 93 94 loaded_result = loaded_model.forward(x) 95 96 self.assertEqual(unfrozen_result.to_dense(), loaded_result.to_dense()) 97 98 @unittest.skipIf(IS_WINDOWS or not TEST_MKL, "Need MKL to run CSR matmul") 99 def test_serialize_sparse_csr(self): 100 class SparseTensorModule(torch.nn.Module): 101 def __init__(self) -> None: 102 super().__init__() 103 self.a = torch.rand(4, 4).to_sparse_csr() 104 self.b = torch.rand(4, 4).to_sparse_csr() 105 106 def forward(self, x): 107 return x.matmul(self.a).matmul(self.b) 108 109 x = torch.rand(4, 4).to_sparse_csr() 110 m = SparseTensorModule() 111 expected_result = m.forward(x) 112 113 buffer = io.BytesIO() 114 torch.jit.save(torch.jit.script(m), buffer) 115 buffer.seek(0) 116 loaded_model = torch.jit.load(buffer) 117 118 loaded_result = loaded_model.forward(x) 119 120 self.assertEqual(expected_result.to_dense(), loaded_result.to_dense()) 121