# Owner(s): ["oncall: jit"]

# Tests for TorchScript module interfaces (`torch.jit.interface`): declaring
# them, calling through them, swapping the implementing submodule, and
# freezing modules that hold interface-typed attributes.

import os
import sys
from typing import Any, List

import torch
import torch.nn as nn
from torch import Tensor
from torch.testing._internal.jit_utils import JitTestCase, make_global


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


# Baseline implementation used as the initial `proxy_mod` in most tests.
# forward(x) == x + (x + x + 1) + 1 == 3 * x + 2.
class OrigModule(nn.Module):
    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        return inp1 + inp2 + 1

    def two(self, input: Tensor) -> Tensor:
        return input + 2

    def forward(self, input: Tensor) -> Tensor:
        return input + self.one(input, input) + 1


# Alternative implementation with the same `one`/`forward` signatures but
# different arithmetic: forward(x) == x * (x + 1) + 1. It has no `two`.
class NewModule(nn.Module):
    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        return inp1 * inp2 + 1

    def forward(self, input: Tensor) -> Tensor:
        return self.one(input, input + 1)


class TestModuleInterface(JitTestCase):
    # `two` exists on OrigModule but is not declared by the interface, so
    # calling it through the interface-typed attribute must fail to script.
    def test_not_submodule_interface_call(self):
        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

        class TestNotModuleInterfaceCall(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.two(input)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "object has no attribute or method", "self.proxy_mod.two"
        ):
            torch.jit.script(TestNotModuleInterfaceCall())

    # Scripted modules are passed where a module interface and where a class
    # interface are expected; calling a method outside the interface fails.
    def test_module_interface(self):
        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                pass

            def two(self, x: Tensor) -> Tensor:
                pass

            def forward(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.interface
        class OneTwoClass:
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                pass

            def two(self, x: Tensor) -> Tensor:
                pass

        class FooMod(nn.Module):
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                return x + y

            def two(self, x: Tensor) -> Tensor:
                return 2 * x

            def forward(self, x: Tensor) -> Tensor:
                return self.one(self.two(x), x)

        class BarMod(nn.Module):
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                return x * y

            def two(self, x: Tensor) -> Tensor:
                return 2 / x

            def forward(self, x: Tensor) -> Tensor:
                return self.two(self.one(x, x))

            @torch.jit.export
            def forward2(self, x: Tensor) -> Tensor:
                return self.two(self.one(x, x)) + 1

        make_global(OneTwoModule, OneTwoClass)

        def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
            return mod_list[0].forward(x) + mod_list[1].forward(x)

        def use_class_interface(mod_list: List[OneTwoClass], x: Tensor) -> Tensor:
            return mod_list[0].two(x) + mod_list[1].one(x, x)

        scripted_foo_mod = torch.jit.script(FooMod())
        scripted_bar_mod = torch.jit.script(BarMod())
        self.checkScript(
            use_module_interface,
            (
                [scripted_foo_mod, scripted_bar_mod],
                torch.rand(3, 4),
            ),
        )
        self.checkScript(
            use_class_interface,
            (
                [scripted_foo_mod, scripted_bar_mod],
                torch.rand(3, 4),
            ),
        )

        def call_module_interface_on_other_method(
            mod_interface: OneTwoModule, x: Tensor
        ) -> Tensor:
            return mod_interface.forward2(x)

        # ensure error out when we call the module on the method other than the interface specified.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "object has no attribute or method", "mod_interface.forward2"
        ):
            self.checkScript(
                call_module_interface_on_other_method,
                (
                    scripted_bar_mod,
                    torch.rand(3, 4),
                ),
            )

    # Interface methods declared with `# type:` comments and with string
    # statements placed around `pass` in the body must still be usable.
    def test_module_doc_string(self):
        @torch.jit.interface
        class TestInterface(nn.Module):
            def one(self, inp1, inp2):
                # type: (Tensor, Tensor) -> Tensor
                pass

            def forward(self, input):
                # type: (Tensor) -> Tensor
                r"""stuff 1"""
                r"""stuff 2"""
                pass  # noqa: PIE790
                r"""stuff 3"""

        class TestModule(nn.Module):
            proxy_mod: TestInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input):
                # type: (Tensor) -> Tensor
                return self.proxy_mod.forward(input)

        input = torch.randn(3, 4)
        self.checkModule(TestModule(), (input,))

    # Which values are accepted where a module interface is expected:
    # script classes and incompatible modules are rejected, while
    # implementations with wider argument / narrower return types pass.
    def test_module_interface_subtype(self):
        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                pass

            def two(self, x: Tensor) -> Tensor:
                pass

            def forward(self, x: Tensor) -> Tensor:
                pass

        make_global(OneTwoModule)

        @torch.jit.script
        def as_module_interface(x: OneTwoModule) -> OneTwoModule:
            return x

        @torch.jit.script
        class Foo:
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                return x + y

            def two(self, x: Tensor) -> Tensor:
                return 2 * x

            def forward(self, x: Tensor) -> Tensor:
                return self.one(self.two(x), x)

        # check class object is not a subtype of module interface
        with self.assertRaisesRegex(
            RuntimeError, "ScriptModule class can be subtype of module interface"
        ):
            as_module_interface(Foo())

        class WrongMod(nn.Module):
            def two(self, x: int) -> int:
                return 2 * x

            def forward(self, x: Tensor) -> Tensor:
                return x + torch.randn(3, self.two(3))

        scripted_wrong_mod = torch.jit.script(WrongMod())

        # wrong module that is not compatible with module interface
        with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
            as_module_interface(scripted_wrong_mod)

        # Check that interface implementations can be contravariant in argument
        # types and covariant in return type.
        @torch.jit.interface
        class TensorToAny(nn.Module):
            def forward(self, input: torch.Tensor) -> Any:
                pass

        make_global(TensorToAny)

        @torch.jit.script
        def as_tensor_to_any(x: TensorToAny) -> TensorToAny:
            return x

        @torch.jit.interface
        class AnyToAny(nn.Module):
            def forward(self, input: Any) -> Any:
                pass

        make_global(AnyToAny)

        @torch.jit.script
        def as_any_to_any(x: AnyToAny) -> AnyToAny:
            return x

        class TensorToAnyImplA(nn.Module):
            def forward(self, input: Any) -> Any:
                return input

        class TensorToAnyImplB(nn.Module):
            def forward(self, input: Any) -> torch.Tensor:
                return torch.tensor([1])

        class AnyToAnyImpl(nn.Module):
            def forward(self, input: Any) -> torch.Tensor:
                return torch.tensor([1])

        # None of these calls may raise.
        as_tensor_to_any(torch.jit.script(TensorToAnyImplA()))
        as_tensor_to_any(torch.jit.script(TensorToAnyImplB()))
        as_any_to_any(torch.jit.script(AnyToAnyImpl()))

    # Declaring an interface that derives from a concrete module (nn.ReLU)
    # must be rejected at decoration time.
    def test_module_interface_inheritance(self):
        with self.assertRaisesRegex(
            RuntimeError, "does not support inheritance yet. Please directly"
        ):

            @torch.jit.interface
            class InheritMod(nn.ReLU):
                def three(self, x: Tensor) -> Tensor:
                    return 3 * x

    # An interface-typed submodule can be replaced by another scripted module
    # implementing the interface, but not by a plain nn.Module.
    def test_module_swap(self):
        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

            def forward(self, input: Tensor) -> Tensor:
                pass

        class TestModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.forward(input)

        scripted_mod = torch.jit.script(TestModule())
        input = torch.randn(3, 4)
        # OrigModule.forward(x) == 3 * x + 2
        self.assertEqual(scripted_mod(input), 3 * input + 2)

        # module swap with module that have the same interface
        scripted_mod.proxy_mod = torch.jit.script(NewModule())
        # NewModule.forward(x) == x * (x + 1) + 1
        self.assertEqual(scripted_mod(input), input * (input + 1) + 1)

        # module swap with non-scripted module should throw error
        with self.assertRaisesRegex(
            RuntimeError, "a ScriptModule with non-scripted module"
        ):
            scripted_mod.proxy_mod = NewModule()

    # Swapping in a scripted module whose methods do not match the interface
    # (int -> int forward, no `one`) must be rejected.
    def test_module_swap_wrong_module(self):
        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

            def forward(self, input: Tensor) -> Tensor:
                pass

        class NewModuleWrong(nn.Module):
            def forward(self, input: int) -> int:
                return input + 1

        class TestModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.forward(input)

        scripted_mod = torch.jit.script(TestModule())
        # module swap with in-compatible interface
        with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
            scripted_mod.proxy_mod = torch.jit.script(NewModuleWrong())

    # A method that `forward` never calls is not compiled unless exported, so
    # the module only satisfies the interface once `one` is marked
    # `@torch.jit.export`.
    def test_module_swap_no_lazy_compile(self):
        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

            def forward(self, input: Tensor) -> Tensor:
                pass

        class TestModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.forward(input)

        class NewModuleMethodNotLazyCompile(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                return inp1 * inp2 + 1

            def forward(self, input: Tensor) -> Tensor:
                return input + 1

        scripted_mod = torch.jit.script(TestModule())
        # module swap with module that have the same interface, but the method not get
        # lazily compiled from forward, user need to export it explicitly for swap to work
        with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
            scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodNotLazyCompile())

        class NewModuleMethodManualExport(nn.Module):
            @torch.jit.export
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                return inp1 * inp2 + 1

            def forward(self, input: Tensor) -> Tensor:
                return input + 1

        scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodManualExport())
        input = torch.randn(3, 4)
        self.assertEqual(scripted_mod(input), input + 1)

    # Without an interface annotation, a submodule can only be replaced by a
    # scripted module of the same JIT type.
    def test_module_swap_no_module_interface(self):
        # test module swapping with no module interface
        class TestNoModuleInterface(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod(input)

        scripted_no_module_interface = torch.jit.script(TestNoModuleInterface())
        # proxy mod is swapped with the new ScriptModule that share the same JIT type, should succeed.
        scripted_no_module_interface.proxy_mod = torch.jit.script(OrigModule())
        # proxy_mod is neither a module interface or have the same JIT type, should fail
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' "
            + r"for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'",
        ):
            scripted_no_module_interface.proxy_mod = torch.jit.script(NewModule())

    # Same swap scenario as test_module_swap, but the implementations are
    # legacy `torch.jit.ScriptModule` subclasses using `script_method`.
    def test_script_module_as_interface_swap(self):
        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

            def forward(self, input: Tensor) -> Tensor:
                pass

        class OrigScriptModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                return inp1 + inp2 + 1

            @torch.jit.script_method
            def forward(self, input: Tensor) -> Tensor:
                return input + self.one(input, input) + 1

        class NewScriptModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                return inp1 * inp2 + 1

            @torch.jit.script_method
            def forward(self, input: Tensor) -> Tensor:
                return self.one(input, input + 1)

        class TestNNModuleWithScriptModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigScriptModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.forward(input)

        input = torch.randn(3, 4)
        scripted_mod = torch.jit.script(TestNNModuleWithScriptModule())
        self.assertEqual(scripted_mod(input), 3 * input + 2)

        # NewScriptModule is assigned directly, without a torch.jit.script call.
        scripted_mod.proxy_mod = NewScriptModule()
        self.assertEqual(scripted_mod(input), input * (input + 1) + 1)

    # Freeze a module holding an interface-typed submodule, first with the
    # default options and then with freezeInterfaces=True, and check that the
    # frozen module returns the same value as the scripted one.
    def test_freeze_module_with_interface(self):
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = 20

            def forward(self, x):
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 0

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> int:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()  # folded

            def forward(self, x):
                return self.proxy_mod(x) + self.sub(x)

        m = torch.jit.script(TestModule())
        m.eval()
        # Default freezing must not raise; its result is not inspected.
        torch._C._freeze_module(m._c)
        # Assume interface has no aliasing
        mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
        input = torch.tensor([1])
        out_s = m.forward(input)
        out_f = mf.forward(input)
        self.assertEqual(out_s, out_f)

    # The interface attribute is re-pointed at `sub`, whose forward mutates
    # its own attribute. The test has no assertions: it only requires that
    # freezing with freezeInterfaces=True does not raise.
    def test_freeze_module_with_setattr_in_interface(self):
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = 20

            def forward(self, x):
                self.b += 2
                return self.b

            @torch.jit.export
            def getb(self, x):
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 0

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> int:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()

            def forward(self, x):
                return self.proxy_mod(x) + self.sub.getb(x)

        m = torch.jit.script(TestModule())
        m.proxy_mod = m.sub
        m.eval()
        torch._C._freeze_module(m._c, freezeInterfaces=True)

    # Like the setattr variant, but the attribute is a tensor mutated in
    # place by `forward`. Only requires that freezing does not raise.
    def test_freeze_module_with_inplace_mutation_in_interface(self):
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.tensor([1.5])

            def forward(self, x):
                self.b[0] += 2
                return self.b

            @torch.jit.export
            def getb(self, x):
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([0.5])

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> Tensor:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()

            def forward(self, x):
                y = self.proxy_mod(x)
                z = self.sub.getb(x)
                return y[0] + z[0]

        m = torch.jit.script(TestModule())
        m.proxy_mod = m.sub
        m.sub.b = m.proxy_mod.b
        m.eval()
        torch._C._freeze_module(m._c, freezeInterfaces=True)

    # Assigning to the interface-typed attribute inside `forward` must make
    # freezing fail with an explicit error.
    def test_freeze_module_with_mutated_interface(self):
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.tensor([1.5])

            def forward(self, x):
                return self.b

            @torch.jit.export
            def getb(self, x):
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([0.5])

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> Tensor:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()

            def forward(self, x):
                self.proxy_mod = self.sub
                y = self.proxy_mod(x)
                z = self.sub.getb(x)
                return y[0] + z[0]

        m = torch.jit.script(TestModule())
        m.eval()
        with self.assertRaisesRegex(
            RuntimeError, "Freezing does not support SetAttr on an interface type."
        ):
            torch._C._freeze_module(m._c, freezeInterfaces=True)

    # The module with the interface attribute is invoked both through
    # torch.jit._fork and directly. Only requires that freezing does not raise.
    def test_freeze_module_with_interface_and_fork(self):
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.tensor([1.5])

            def forward(self, x):
                self.b[0] += 3.2
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([0.5])

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> Tensor:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()

            def forward(self, x):
                y = self.proxy_mod(x)
                z = self.sub(x)
                return y + z

        class MainModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.test = TestModule()

            def forward(self, x):
                fut = torch.jit._fork(self.test.forward, x)
                y = self.test(x)
                z = torch.jit._wait(fut)
                return y + z

        m = torch.jit.script(MainModule())
        m.eval()
        torch._C._freeze_module(m._c, freezeInterfaces=True)

    # The exported `method` iterates over `self.modules()` and calls each
    # module while an interface-typed attribute is present; scripting the
    # module is expected to fail.
    def test_module_apis_interface(self):
        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

        class TestModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input):
                return input * 2

            @torch.jit.export
            def method(self, input):
                for module in self.modules():
                    input = module(input)
                return input

        with self.assertRaisesRegex(Exception, "Could not compile"):
            torch.jit.script(TestModule())