xref: /aosp_15_r20/external/pytorch/test/jit/test_module_interface.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5from typing import Any, List
6
7import torch
8import torch.nn as nn
9from torch import Tensor
10from torch.testing._internal.jit_utils import JitTestCase, make_global
11
12
# Make the helper files in test/ importable.
# `pytorch_test_dir` is the parent of the directory that contains this file
# (two dirname() steps up from the resolved path of this file); it is appended
# to sys.path.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# Running this file as a script is rejected; the message below points at
# test/test_jit.py as the entry point to use instead.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
23
24
class OrigModule(nn.Module):
    # Baseline module shared by the tests in this file. It provides three
    # methods: `one`, `two` and `forward`.

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        pair_sum = inp1 + inp2
        return pair_sum + 1

    def two(self, input: Tensor) -> Tensor:
        plus_two = input + 2
        return plus_two

    def forward(self, input: Tensor) -> Tensor:
        via_one = self.one(input, input)
        return input + via_one + 1
34
35
class NewModule(nn.Module):
    # Replacement module used by the swap tests. It provides `one` and
    # `forward`, with different arithmetic from OrigModule, and has no `two`.

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        product = inp1 * inp2
        return product + 1

    def forward(self, input: Tensor) -> Tensor:
        shifted = input + 1
        return self.one(input, shifted)
42
43
44class TestModuleInterface(JitTestCase):
45    def test_not_submodule_interface_call(self):
46        @torch.jit.interface
47        class ModuleInterface(nn.Module):
48            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
49                pass
50
51        class TestNotModuleInterfaceCall(nn.Module):
52            proxy_mod: ModuleInterface
53
54            def __init__(self) -> None:
55                super().__init__()
56                self.proxy_mod = OrigModule()
57
58            def forward(self, input: Tensor) -> Tensor:
59                return self.proxy_mod.two(input)
60
61        with self.assertRaisesRegexWithHighlight(
62            RuntimeError, "object has no attribute or method", "self.proxy_mod.two"
63        ):
64            torch.jit.script(TestNotModuleInterfaceCall())
65
66    def test_module_interface(self):
67        @torch.jit.interface
68        class OneTwoModule(nn.Module):
69            def one(self, x: Tensor, y: Tensor) -> Tensor:
70                pass
71
72            def two(self, x: Tensor) -> Tensor:
73                pass
74
75            def forward(self, x: Tensor) -> Tensor:
76                pass
77
78        @torch.jit.interface
79        class OneTwoClass:
80            def one(self, x: Tensor, y: Tensor) -> Tensor:
81                pass
82
83            def two(self, x: Tensor) -> Tensor:
84                pass
85
86        class FooMod(nn.Module):
87            def one(self, x: Tensor, y: Tensor) -> Tensor:
88                return x + y
89
90            def two(self, x: Tensor) -> Tensor:
91                return 2 * x
92
93            def forward(self, x: Tensor) -> Tensor:
94                return self.one(self.two(x), x)
95
96        class BarMod(nn.Module):
97            def one(self, x: Tensor, y: Tensor) -> Tensor:
98                return x * y
99
100            def two(self, x: Tensor) -> Tensor:
101                return 2 / x
102
103            def forward(self, x: Tensor) -> Tensor:
104                return self.two(self.one(x, x))
105
106            @torch.jit.export
107            def forward2(self, x: Tensor) -> Tensor:
108                return self.two(self.one(x, x)) + 1
109
110        make_global(OneTwoModule, OneTwoClass)
111
112        def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
113            return mod_list[0].forward(x) + mod_list[1].forward(x)
114
115        def use_class_interface(mod_list: List[OneTwoClass], x: Tensor) -> Tensor:
116            return mod_list[0].two(x) + mod_list[1].one(x, x)
117
118        scripted_foo_mod = torch.jit.script(FooMod())
119        scripted_bar_mod = torch.jit.script(BarMod())
120        self.checkScript(
121            use_module_interface,
122            (
123                [scripted_foo_mod, scripted_bar_mod],
124                torch.rand(3, 4),
125            ),
126        )
127        self.checkScript(
128            use_class_interface,
129            (
130                [scripted_foo_mod, scripted_bar_mod],
131                torch.rand(3, 4),
132            ),
133        )
134
135        def call_module_interface_on_other_method(
136            mod_interface: OneTwoModule, x: Tensor
137        ) -> Tensor:
138            return mod_interface.forward2(x)
139
140        # ensure error out when we call the module on the method other than the interface specified.
141        with self.assertRaisesRegexWithHighlight(
142            RuntimeError, "object has no attribute or method", "mod_interface.forward2"
143        ):
144            self.checkScript(
145                call_module_interface_on_other_method,
146                (
147                    scripted_bar_mod,
148                    torch.rand(3, 4),
149                ),
150            )
151
152    def test_module_doc_string(self):
153        @torch.jit.interface
154        class TestInterface(nn.Module):
155            def one(self, inp1, inp2):
156                # type: (Tensor, Tensor) -> Tensor
157                pass
158
159            def forward(self, input):
160                # type: (Tensor) -> Tensor
161                r"""stuff 1"""
162                r"""stuff 2"""
163                pass  # noqa: PIE790
164                r"""stuff 3"""
165
166        class TestModule(nn.Module):
167            proxy_mod: TestInterface
168
169            def __init__(self) -> None:
170                super().__init__()
171                self.proxy_mod = OrigModule()
172
173            def forward(self, input):
174                # type: (Tensor) -> Tensor
175                return self.proxy_mod.forward(input)
176
177        input = torch.randn(3, 4)
178        self.checkModule(TestModule(), (input,))
179
180    def test_module_interface_subtype(self):
181        @torch.jit.interface
182        class OneTwoModule(nn.Module):
183            def one(self, x: Tensor, y: Tensor) -> Tensor:
184                pass
185
186            def two(self, x: Tensor) -> Tensor:
187                pass
188
189            def forward(self, x: Tensor) -> Tensor:
190                pass
191
192        make_global(OneTwoModule)
193
194        @torch.jit.script
195        def as_module_interface(x: OneTwoModule) -> OneTwoModule:
196            return x
197
198        @torch.jit.script
199        class Foo:
200            def one(self, x: Tensor, y: Tensor) -> Tensor:
201                return x + y
202
203            def two(self, x: Tensor) -> Tensor:
204                return 2 * x
205
206            def forward(self, x: Tensor) -> Tensor:
207                return self.one(self.two(x), x)
208
209        # check class object is not a subtype of module interface
210        with self.assertRaisesRegex(
211            RuntimeError, "ScriptModule class can be subtype of module interface"
212        ):
213            as_module_interface(Foo())
214
215        class WrongMod(nn.Module):
216            def two(self, x: int) -> int:
217                return 2 * x
218
219            def forward(self, x: Tensor) -> Tensor:
220                return x + torch.randn(3, self.two(3))
221
222        scripted_wrong_mod = torch.jit.script(WrongMod())
223
224        # wrong module that is not compatible with module interface
225        with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
226            as_module_interface(scripted_wrong_mod)
227
228        # Check that interface implementations can be contravariant in argument types and covariant in return type.
229        @torch.jit.interface
230        class TensorToAny(nn.Module):
231            def forward(self, input: torch.Tensor) -> Any:
232                pass
233
234        make_global(TensorToAny)
235
236        @torch.jit.script
237        def as_tensor_to_any(x: TensorToAny) -> TensorToAny:
238            return x
239
240        @torch.jit.interface
241        class AnyToAny(nn.Module):
242            def forward(self, input: Any) -> Any:
243                pass
244
245        make_global(AnyToAny)
246
247        @torch.jit.script
248        def as_any_to_any(x: AnyToAny) -> AnyToAny:
249            return x
250
251        class TensorToAnyImplA(nn.Module):
252            def forward(self, input: Any) -> Any:
253                return input
254
255        class TensorToAnyImplB(nn.Module):
256            def forward(self, input: Any) -> torch.Tensor:
257                return torch.tensor([1])
258
259        class AnyToAnyImpl(nn.Module):
260            def forward(self, input: Any) -> torch.Tensor:
261                return torch.tensor([1])
262
263        as_tensor_to_any(torch.jit.script(TensorToAnyImplA()))
264        as_tensor_to_any(torch.jit.script(TensorToAnyImplB()))
265        as_any_to_any(torch.jit.script(AnyToAnyImpl()))
266
267    def test_module_interface_inheritance(self):
268        with self.assertRaisesRegex(
269            RuntimeError, "does not support inheritance yet. Please directly"
270        ):
271
272            @torch.jit.interface
273            class InheritMod(nn.ReLU):
274                def three(self, x: Tensor) -> Tensor:
275                    return 3 * x
276
277    def test_module_swap(self):
278        @torch.jit.interface
279        class ModuleInterface(nn.Module):
280            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
281                pass
282
283            def forward(self, input: Tensor) -> Tensor:
284                pass
285
286        class TestModule(nn.Module):
287            proxy_mod: ModuleInterface
288
289            def __init__(self) -> None:
290                super().__init__()
291                self.proxy_mod = OrigModule()
292
293            def forward(self, input: Tensor) -> Tensor:
294                return self.proxy_mod.forward(input)
295
296        scripted_mod = torch.jit.script(TestModule())
297        input = torch.randn(3, 4)
298        self.assertEqual(scripted_mod(input), 3 * input + 2)
299
300        # module swap with module that have the same interface
301        scripted_mod.proxy_mod = torch.jit.script(NewModule())
302        self.assertEqual(scripted_mod(input), input * (input + 1) + 1)
303
304        # module swap with non-scripted module should throw error
305        with self.assertRaisesRegex(
306            RuntimeError, "a ScriptModule with non-scripted module"
307        ):
308            scripted_mod.proxy_mod = NewModule()
309
310    def test_module_swap_wrong_module(self):
311        @torch.jit.interface
312        class ModuleInterface(nn.Module):
313            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
314                pass
315
316            def forward(self, input: Tensor) -> Tensor:
317                pass
318
319        class NewModuleWrong(nn.Module):
320            def forward(self, input: int) -> int:
321                return input + 1
322
323        class TestModule(nn.Module):
324            proxy_mod: ModuleInterface
325
326            def __init__(self) -> None:
327                super().__init__()
328                self.proxy_mod = OrigModule()
329
330            def forward(self, input: Tensor) -> Tensor:
331                return self.proxy_mod.forward(input)
332
333        scripted_mod = torch.jit.script(TestModule())
334        # module swap with in-compatible interface
335        with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
336            scripted_mod.proxy_mod = torch.jit.script(NewModuleWrong())
337
338    def test_module_swap_no_lazy_compile(self):
339        @torch.jit.interface
340        class ModuleInterface(nn.Module):
341            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
342                pass
343
344            def forward(self, input: Tensor) -> Tensor:
345                pass
346
347        class TestModule(nn.Module):
348            proxy_mod: ModuleInterface
349
350            def __init__(self) -> None:
351                super().__init__()
352                self.proxy_mod = OrigModule()
353
354            def forward(self, input: Tensor) -> Tensor:
355                return self.proxy_mod.forward(input)
356
357        class NewModuleMethodNotLazyCompile(nn.Module):
358            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
359                return inp1 * inp2 + 1
360
361            def forward(self, input: Tensor) -> Tensor:
362                return input + 1
363
364        scripted_mod = torch.jit.script(TestModule())
365        # module swap with module that have the same interface, but the method not get
366        # lazily compiled from forward, user need to export it explicitly for swap to work
367        with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
368            scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodNotLazyCompile())
369
370        class NewModuleMethodManualExport(nn.Module):
371            @torch.jit.export
372            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
373                return inp1 * inp2 + 1
374
375            def forward(self, input: Tensor) -> Tensor:
376                return input + 1
377
378        scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodManualExport())
379        input = torch.randn(3, 4)
380        self.assertEqual(scripted_mod(input), input + 1)
381
382    def test_module_swap_no_module_interface(self):
383        # test module swapping with no module interface
384        class TestNoModuleInterface(nn.Module):
385            def __init__(self) -> None:
386                super().__init__()
387                self.proxy_mod = OrigModule()
388
389            def forward(self, input: Tensor) -> Tensor:
390                return self.proxy_mod(input)
391
392        scripted_no_module_interface = torch.jit.script(TestNoModuleInterface())
393        # proxy mod is swapped with the new ScriptModule that share the same JIT type, should succeed.
394        scripted_no_module_interface.proxy_mod = torch.jit.script(OrigModule())
395        # proxy_mod is neither a module interface or have the same JIT type, should fail
396        with self.assertRaisesRegex(
397            RuntimeError,
398            r"Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' "
399            + r"for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'",
400        ):
401            scripted_no_module_interface.proxy_mod = torch.jit.script(NewModule())
402
403    def test_script_module_as_interface_swap(self):
404        @torch.jit.interface
405        class ModuleInterface(nn.Module):
406            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
407                pass
408
409            def forward(self, input: Tensor) -> Tensor:
410                pass
411
412        class OrigScriptModule(torch.jit.ScriptModule):
413            @torch.jit.script_method
414            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
415                return inp1 + inp2 + 1
416
417            @torch.jit.script_method
418            def forward(self, input: Tensor) -> Tensor:
419                return input + self.one(input, input) + 1
420
421        class NewScriptModule(torch.jit.ScriptModule):
422            @torch.jit.script_method
423            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
424                return inp1 * inp2 + 1
425
426            @torch.jit.script_method
427            def forward(self, input: Tensor) -> Tensor:
428                return self.one(input, input + 1)
429
430        class TestNNModuleWithScriptModule(nn.Module):
431            proxy_mod: ModuleInterface
432
433            def __init__(self) -> None:
434                super().__init__()
435                self.proxy_mod = OrigScriptModule()
436
437            def forward(self, input: Tensor) -> Tensor:
438                return self.proxy_mod.forward(input)
439
440        input = torch.randn(3, 4)
441        scripted_mod = torch.jit.script(TestNNModuleWithScriptModule())
442        self.assertEqual(scripted_mod(input), 3 * input + 2)
443
444        scripted_mod.proxy_mod = NewScriptModule()
445        self.assertEqual(scripted_mod(input), input * (input + 1) + 1)
446
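    # The call to forward of proxy_mod goes through an interface type. Make
    # sure freezing succeeds (with the default options and with
    # freezeInterfaces=True) and that the module frozen with
    # freezeInterfaces=True returns the same result as the scripted module.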
449    def test_freeze_module_with_interface(self):
450        class SubModule(torch.nn.Module):
451            def __init__(self) -> None:
452                super().__init__()
453                self.b = 20
454
455            def forward(self, x):
456                return self.b
457
458        class OrigMod(torch.nn.Module):
459            def __init__(self) -> None:
460                super().__init__()
461                self.a = 0
462
463            def forward(self, x):
464                return self.a
465
466        @torch.jit.interface
467        class ModInterface(torch.nn.Module):
468            def forward(self, x: Tensor) -> int:
469                pass
470
471        class TestModule(torch.nn.Module):
472            proxy_mod: ModInterface
473
474            def __init__(self) -> None:
475                super().__init__()
476                self.proxy_mod = OrigMod()
477                self.sub = SubModule()  # folded
478
479            def forward(self, x):
480                return self.proxy_mod(x) + self.sub(x)
481
482        m = torch.jit.script(TestModule())
483        m.eval()
484        mf = torch._C._freeze_module(m._c)
485        # Assume interface has no aliasing
486        mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
487        input = torch.tensor([1])
488        out_s = m.forward(input)
489        out_f = mf.forward(input)
490        self.assertEqual(out_s, out_f)
491
492    def test_freeze_module_with_setattr_in_interface(self):
493        class SubModule(torch.nn.Module):
494            def __init__(self) -> None:
495                super().__init__()
496                self.b = 20
497
498            def forward(self, x):
499                self.b += 2
500                return self.b
501
502            @torch.jit.export
503            def getb(self, x):
504                return self.b
505
506        class OrigMod(torch.nn.Module):
507            def __init__(self) -> None:
508                super().__init__()
509                self.a = 0
510
511            def forward(self, x):
512                return self.a
513
514        @torch.jit.interface
515        class ModInterface(torch.nn.Module):
516            def forward(self, x: Tensor) -> int:
517                pass
518
519        class TestModule(torch.nn.Module):
520            proxy_mod: ModInterface
521
522            def __init__(self) -> None:
523                super().__init__()
524                self.proxy_mod = OrigMod()
525                self.sub = SubModule()
526
527            def forward(self, x):
528                return self.proxy_mod(x) + self.sub.getb(x)
529
530        m = torch.jit.script(TestModule())
531        m.proxy_mod = m.sub
532        m.eval()
533        mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
534
535    def test_freeze_module_with_inplace_mutation_in_interface(self):
536        class SubModule(torch.nn.Module):
537            def __init__(self) -> None:
538                super().__init__()
539                self.b = torch.tensor([1.5])
540
541            def forward(self, x):
542                self.b[0] += 2
543                return self.b
544
545            @torch.jit.export
546            def getb(self, x):
547                return self.b
548
549        class OrigMod(torch.nn.Module):
550            def __init__(self) -> None:
551                super().__init__()
552                self.a = torch.tensor([0.5])
553
554            def forward(self, x):
555                return self.a
556
557        @torch.jit.interface
558        class ModInterface(torch.nn.Module):
559            def forward(self, x: Tensor) -> Tensor:
560                pass
561
562        class TestModule(torch.nn.Module):
563            proxy_mod: ModInterface
564
565            def __init__(self) -> None:
566                super().__init__()
567                self.proxy_mod = OrigMod()
568                self.sub = SubModule()
569
570            def forward(self, x):
571                y = self.proxy_mod(x)
572                z = self.sub.getb(x)
573                return y[0] + z[0]
574
575        m = torch.jit.script(TestModule())
576        m.proxy_mod = m.sub
577        m.sub.b = m.proxy_mod.b
578        m.eval()
579        mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
580
581    def test_freeze_module_with_mutated_interface(self):
582        class SubModule(torch.nn.Module):
583            def __init__(self) -> None:
584                super().__init__()
585                self.b = torch.tensor([1.5])
586
587            def forward(self, x):
588                return self.b
589
590            @torch.jit.export
591            def getb(self, x):
592                return self.b
593
594        class OrigMod(torch.nn.Module):
595            def __init__(self) -> None:
596                super().__init__()
597                self.a = torch.tensor([0.5])
598
599            def forward(self, x):
600                return self.a
601
602        @torch.jit.interface
603        class ModInterface(torch.nn.Module):
604            def forward(self, x: Tensor) -> Tensor:
605                pass
606
607        class TestModule(torch.nn.Module):
608            proxy_mod: ModInterface
609
610            def __init__(self) -> None:
611                super().__init__()
612                self.proxy_mod = OrigMod()
613                self.sub = SubModule()
614
615            def forward(self, x):
616                self.proxy_mod = self.sub
617                y = self.proxy_mod(x)
618                z = self.sub.getb(x)
619                return y[0] + z[0]
620
621        m = torch.jit.script(TestModule())
622        m.eval()
623        with self.assertRaisesRegex(
624            RuntimeError, "Freezing does not support SetAttr on an interface type."
625        ):
626            mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
627
628    def test_freeze_module_with_interface_and_fork(self):
629        class SubModule(torch.nn.Module):
630            def __init__(self) -> None:
631                super().__init__()
632                self.b = torch.tensor([1.5])
633
634            def forward(self, x):
635                self.b[0] += 3.2
636                return self.b
637
638        class OrigMod(torch.nn.Module):
639            def __init__(self) -> None:
640                super().__init__()
641                self.a = torch.tensor([0.5])
642
643            def forward(self, x):
644                return self.a
645
646        @torch.jit.interface
647        class ModInterface(torch.nn.Module):
648            def forward(self, x: Tensor) -> Tensor:
649                pass
650
651        class TestModule(torch.nn.Module):
652            proxy_mod: ModInterface
653
654            def __init__(self) -> None:
655                super().__init__()
656                self.proxy_mod = OrigMod()
657                self.sub = SubModule()
658
659            def forward(self, x):
660                y = self.proxy_mod(x)
661                z = self.sub(x)
662                return y + z
663
664        class MainModule(torch.nn.Module):
665            def __init__(self) -> None:
666                super().__init__()
667                self.test = TestModule()
668
669            def forward(self, x):
670                fut = torch.jit._fork(self.test.forward, x)
671                y = self.test(x)
672                z = torch.jit._wait(fut)
673                return y + z
674
675        m = torch.jit.script(MainModule())
676        m.eval()
677        mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
678
679    def test_module_apis_interface(self):
680        @torch.jit.interface
681        class ModuleInterface(nn.Module):
682            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
683                pass
684
685        class TestModule(nn.Module):
686            proxy_mod: ModuleInterface
687
688            def __init__(self) -> None:
689                super().__init__()
690                self.proxy_mod = OrigModule()
691
692            def forward(self, input):
693                return input * 2
694
695            @torch.jit.export
696            def method(self, input):
697                for module in self.modules():
698                    input = module(input)
699                return input
700
701        with self.assertRaisesRegex(Exception, "Could not compile"):
702            scripted_mod = torch.jit.script(TestModule())
703