xref: /aosp_15_r20/external/pytorch/test/jit/test_sparse.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import io
4import unittest
5
6import torch
7from torch.testing._internal.common_utils import IS_WINDOWS, TEST_MKL
8from torch.testing._internal.jit_utils import JitTestCase
9
10
class TestSparse(JitTestCase):
    """TorchScript freezing and serialization checks for modules holding sparse tensors.

    Each test runs an eager module on an input, then checks that the scripted
    module (frozen and/or saved and reloaded) produces an equal result on the
    same input.
    """

    @staticmethod
    def _make_coo_module():
        """Return an eager module with two random 3x4 sparse COO attributes.

        Its ``forward`` returns ``x + a + b``.
        """

        class SparseTensorModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.rand(3, 4).to_sparse()
                self.b = torch.rand(3, 4).to_sparse()

            def forward(self, x):
                return x + self.a + self.b

        return SparseTensorModule()

    @staticmethod
    def _make_csr_module():
        """Return an eager module with two random 4x4 sparse CSR attributes.

        Its ``forward`` returns ``x @ a @ b`` via chained ``matmul`` calls.
        """

        class SparseTensorModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.rand(4, 4).to_sparse_csr()
                self.b = torch.rand(4, 4).to_sparse_csr()

            def forward(self, x):
                return x.matmul(self.a).matmul(self.b)

        return SparseTensorModule()

    @staticmethod
    def _save_and_load(scripted):
        """Save ``scripted`` with ``torch.jit.save`` to an in-memory buffer and load it back."""
        buffer = io.BytesIO()
        torch.jit.save(scripted, buffer)
        buffer.seek(0)  # rewind so torch.jit.load reads from the beginning
        return torch.jit.load(buffer)

    def _assert_results_equal(self, expected, actual, *, densify):
        """Assert ``expected == actual``; if ``densify``, compare their ``to_dense()`` forms."""
        if densify:
            expected, actual = expected.to_dense(), actual.to_dense()
        self.assertEqual(expected, actual)

    def _check_freeze(self, module, x, *, densify=False):
        """Check that freezing ``module``, and then saving/reloading it, preserves its output on ``x``."""
        unfrozen_result = module.forward(x)

        # torch.jit.freeze operates on modules in eval mode.
        module.eval()
        frozen = torch.jit.freeze(torch.jit.script(module))
        self._assert_results_equal(
            unfrozen_result, frozen.forward(x), densify=densify
        )

        loaded_model = self._save_and_load(frozen)
        self._assert_results_equal(
            unfrozen_result, loaded_model.forward(x), densify=densify
        )

    def _check_serialize(self, module, x, *, densify=False):
        """Check that scripting, saving and reloading ``module`` preserves its output on ``x``."""
        expected_result = module.forward(x)

        loaded_model = self._save_and_load(torch.jit.script(module))
        self._assert_results_equal(
            expected_result, loaded_model.forward(x), densify=densify
        )

    def test_freeze_sparse_coo(self):
        # The input is created before the module's random attributes.
        x = torch.rand(3, 4).to_sparse()
        self._check_freeze(self._make_coo_module(), x)

    def test_serialize_sparse_coo(self):
        x = torch.rand(3, 4).to_sparse()
        self._check_serialize(self._make_coo_module(), x)

    @unittest.skipIf(IS_WINDOWS or not TEST_MKL, "Need MKL to run CSR matmul")
    def test_freeze_sparse_csr(self):
        x = torch.rand(4, 4).to_sparse_csr()
        # CSR results are compared in dense form.
        self._check_freeze(self._make_csr_module(), x, densify=True)

    @unittest.skipIf(IS_WINDOWS or not TEST_MKL, "Need MKL to run CSR matmul")
    def test_serialize_sparse_csr(self):
        x = torch.rand(4, 4).to_sparse_csr()
        self._check_serialize(self._make_csr_module(), x, densify=True)
121