# xref: /aosp_15_r20/external/pytorch/test/jit/test_parametrization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: jit"]


import torch
import torch.nn.utils.parametrize as parametrize
from torch import nn
from torch.testing._internal.jit_utils import JitTestCase


10if __name__ == "__main__":
11    raise RuntimeError(
12        "This test file is not meant to be run directly, use:\n\n"
13        "\tpython test/test_jit.py TESTNAME\n\n"
14        "instead."
15    )


class TestParametrization(JitTestCase):
    """JIT tracing/scripting coverage for ``torch.nn.utils.parametrize``."""

    # A minimal parametrization for the tests below: rebuilds a matrix as
    # the symmetric combination of its upper triangle.
    class Symmetric(nn.Module):
        def forward(self, X):
            return X.triu() + X.triu(1).mT

    def test_traceable(self):
        r"""Test the jit scripting and tracing of a parametrized model."""
        linear = nn.Linear(5, 5)
        parametrize.register_parametrization(linear, "weight", self.Symmetric())

        inp = torch.randn(3, 5)
        expected = linear(inp)

        # Tracing should succeed. Traced functions cannot be compared
        # directly, so equality is checked on the activations instead.
        traced = torch.jit.trace_module(linear, {"forward": inp})
        self.assertEqual(expected, traced(inp))

        # The traced module must keep producing the same output while the
        # parametrization cache is active.
        with parametrize.cached():
            self.assertEqual(expected, traced(inp))

        # Tracing itself, however, is rejected while caching is enabled.
        with self.assertRaisesRegex(RuntimeError, "Cannot trace a model while caching"):
            with parametrize.cached():
                traced = torch.jit.trace_module(linear, {"forward": inp})

    def test_scriptable(self):
        # TODO: Need to fix the scripting in parametrizations
        #       Currently, all the tests below will throw torch.jit.Error
        linear = nn.Linear(5, 5)
        parametrize.register_parametrization(linear, "weight", self.Symmetric())

        inp = torch.randn(3, 5)
        expected = linear(inp)

        with self.assertRaises(torch.jit.Error):
            # Scripting currently raises here; the rest of the body documents
            # the intended behavior once scripting is supported.
            scripted = torch.jit.script(linear)
            self.assertEqual(expected, scripted(inp))

            with parametrize.cached():
                # A scripted model should keep working when caching ...
                self.assertEqual(expected, scripted(inp))

                # ... while scripting during caching should fail explicitly.
                with self.assertRaisesRegex(RuntimeError, "Caching is not implemented"):
                    scripted = torch.jit.trace_module(linear)
71