1# Owner(s): ["oncall: jit"] 2 3 4import torch 5import torch.nn.utils.parametrize as parametrize 6from torch import nn 7from torch.testing._internal.jit_utils import JitTestCase 8 9 10if __name__ == "__main__": 11 raise RuntimeError( 12 "This test file is not meant to be run directly, use:\n\n" 13 "\tpython test/test_jit.py TESTNAME\n\n" 14 "instead." 15 ) 16 17 18class TestParametrization(JitTestCase): 19 # Define some parametrization 20 class Symmetric(nn.Module): 21 def forward(self, X): 22 return X.triu() + X.triu(1).mT 23 24 def test_traceable(self): 25 r"""Test the jit scripting and tracing of a parametrized model.""" 26 model = nn.Linear(5, 5) 27 parametrize.register_parametrization(model, "weight", self.Symmetric()) 28 29 x = torch.randn(3, 5) 30 y = model(x) 31 32 # Check the tracing works. Because traced functions cannot be called 33 # directly, we run the comparison on the activations. 34 traced_model = torch.jit.trace_module(model, {"forward": x}) 35 y_hat = traced_model(x) 36 self.assertEqual(y, y_hat) 37 38 # Check traced model works with caching 39 with parametrize.cached(): 40 y_hat = traced_model(x) 41 self.assertEqual(y, y_hat) 42 43 # Check the tracing throws an error when caching 44 with self.assertRaisesRegex(RuntimeError, "Cannot trace a model while caching"): 45 with parametrize.cached(): 46 traced_model = torch.jit.trace_module(model, {"forward": x}) 47 48 def test_scriptable(self): 49 # TODO: Need to fix the scripting in parametrizations 50 # Currently, all the tests below will throw torch.jit.Error 51 model = nn.Linear(5, 5) 52 parametrize.register_parametrization(model, "weight", self.Symmetric()) 53 54 x = torch.randn(3, 5) 55 y = model(x) 56 57 with self.assertRaises(torch.jit.Error): 58 # Check scripting works 59 scripted_model = torch.jit.script(model) 60 y_hat = scripted_model(x) 61 self.assertEqual(y, y_hat) 62 63 with parametrize.cached(): 64 # Check scripted model works when caching 65 y_hat = scripted_model(x) 66 self.assertEqual(y, y_hat) 67 68 # Check the scripting process throws an error when caching 69 with self.assertRaisesRegex(RuntimeError, "Caching is not implemented"): 70 scripted_model = torch.jit.trace_module(model) 71