xref: /aosp_15_r20/external/pytorch/test/jit/test_backends.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import io
4import os
5import sys
6import unittest
7
8import torch
9import torch._C
10from torch.jit.mobile import _load_for_lite_interpreter
11from torch.testing import FileCheck
12from torch.testing._internal.common_utils import (
13    find_library_location,
14    IS_FBCODE,
15    IS_MACOS,
16    IS_SANDCASTLE,
17    IS_WINDOWS,
18    skipIfRocm,
19    TEST_WITH_ROCM,
20)
21from torch.testing._internal.jit_utils import JitTestCase
22
23
# Make the helper files in test/ importable. This file lives in test/jit/, so the
# test/ directory is the parent of the directory that contains this file.
_jit_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(_jit_test_dir)
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    # These tests are collected through test/test_jit.py; refuse direct execution.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
34
35
def to_test_backend(module, method_compile_spec):
    """Lower ``module`` to the "test_backend" backend, compiling only ``forward``.

    ``method_compile_spec`` is used as the compile spec of the ``forward`` method.
    """
    compile_specs = {"forward": method_compile_spec}
    return torch._C._jit_to_backend("test_backend", module, compile_specs)
40
41
def to_test_backend_multi(module, method_compile_spec):
    """Lower ``module`` to the "test_backend" backend with a per-method compile spec.

    ``method_compile_spec`` maps method names to their compile specs and is passed
    through unchanged.
    """
    backend_name = "test_backend"
    return torch._C._jit_to_backend(backend_name, module, method_compile_spec)
44
45
def to_test_backend_selective(module, method_compile_spec, submodules):
    """Selectively lower the listed ``submodules`` of ``module`` to "test_backend".

    ``submodules`` holds qualified submodule names (e.g. ``"sub1.submodule"``).
    The callback handed to ``torch._C._jit_to_backend_selective`` lowers a module
    with :func:`to_test_backend`, using ``method_compile_spec`` as the compile
    spec of its ``forward`` method.
    """
    return torch._C._jit_to_backend_selective(
        module,
        lambda submodule: to_test_backend(submodule, method_compile_spec),
        submodules,
    )
51
52
class BasicModule(torch.nn.Module):
    """
    A minimal Module exercised by the to_backend lowering machinery.

    ``forward`` returns the results of its two helper methods as a tuple; the
    helpers are separate methods so they can also be lowered and checked.
    """

    def forward(self, x, h):
        summed = self.accum(x, h)
        diff = self.sub_accum(x, h)
        return summed, diff

    def accum(self, x, h):
        return x + h

    def sub_accum(self, x, h):
        return x - h
66
67
# unittest applies this skip only when it runs these test cases itself. TestBackends
# constructs the subclasses and calls their methods directly, which bypasses it
# (e.g. on Windows or macOS), so TestBackends repeats the same skipIf.
@unittest.skipIf(
    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class JitBackendTestCase(JitTestCase):
    """
    Shared base class for the JIT backend tests.

    It provides helpers that compare a module's Python, TorchScript and lowered
    versions and that round-trip the lowered version through serialization.
    Subclasses must assign three attributes in their ``setUp``:

    * ``module`` - the regular Python module under test,
    * ``scripted_module`` - a scripted version of ``module``,
    * ``lowered_module`` - a version of ``module`` lowered to a backend.
    """

    def setUp(self):
        super().setUp()
        library_path = find_library_location("libjitbackend_test.so")
        torch.ops.load_library(str(library_path))

    def check_function(self, function_name, input):
        """
        Assert that the method ``function_name`` gives the same result in Python,
        in TorchScript and on the backend when called with the arguments in ``input``.
        """
        # Resolve all three methods up front, then call them in order.
        python_method = self.module.__getattribute__(function_name)
        jit_method, backend_method = (
            candidate.__getattr__(function_name)
            for candidate in (self.scripted_module, self.lowered_module)
        )
        python_output, jit_output, backend_output = (
            method(*input) for method in (python_method, jit_method, backend_method)
        )

        # The backend output is the reference that both other outputs must match.
        self.assertEqual(python_output, backend_output)
        self.assertEqual(jit_output, backend_output)

    def save_load(self):
        """
        Replace ``lowered_module`` with a saved-then-loaded copy of itself.
        """
        reloaded = self.getExportImportCopy(self.lowered_module)
        self.lowered_module = reloaded

    def test_execution(self):
        """
        Placeholder for correctness tests; subclasses override it as needed.
        """

    def test_save_load(self):
        """
        Placeholder for serialization tests; subclasses override it as needed.
        """

    def test_errors(self):
        """
        Placeholder for error-checking tests; subclasses override it as needed.
        """
127
128
class BasicModuleTest(JitBackendTestCase):
    """
    Tests that lower every method of BasicModule to the test backend.
    """

    def setUp(self):
        super().setUp()
        # Python, TorchScript and lowered versions of BasicModule; each of the three
        # methods is lowered with an empty compile spec.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        self.lowered_module = to_test_backend_multi(
            self.scripted_module,
            {name: {"": ""} for name in ("accum", "sub_accum", "forward")},
        )

    def test_execution(self):
        # The backend result of each method must match Python and TorchScript.
        input = torch.randn(5)
        args = (input, input)
        for method_name in ("accum", "sub_accum", "forward"):
            self.check_function(method_name, args)

    @skipIfRocm
    def test_save_load(self):
        # The lowered module must work before serialization...
        self.test_execution()

        def compile_spec_of(lowered):
            # The compile spec is stored on the wrapped __loweredModule__.
            return lowered.__getattr__("__loweredModule__").__getattr__(
                "__method_compile_spec"
            )

        pre_compile_spec = compile_spec_of(self.lowered_module)
        self.save_load()
        post_compile_spec = compile_spec_of(self.lowered_module)

        # ...must keep its compile spec across a save/load round trip...
        self.assertEqual(pre_compile_spec, post_compile_spec)

        # ...and must still produce the same outputs afterwards.
        self.test_execution()
176
177
class BasicModuleUnavailableTest(JitBackendTestCase):
    """
    Tests for BasicModule lowered to a backend that is not available.

    Expected behaviour:
      * lowering with _jit_to_backend succeeds,
      * executing the lowered module raises an exception,
      * saving the lowered module succeeds,
      * loading it back raises an exception.
    """

    def setUp(self):
        super().setUp()
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        # Only forward is lowered, with an empty compile spec.
        self.lowered_module = torch._C._jit_to_backend(
            "test_backend_unavailable",
            self.scripted_module,
            {"forward": {"": ""}},
        )

    def test_execution(self):
        # Calling the lowered forward must fail because the backend is unavailable.
        x = torch.randn(5)
        with self.assertRaisesRegexWithHighlight(
            Exception,
            r"Backend is not available.",
            'raise Exception("Backend is not available."',
        ):
            self.lowered_module.__getattr__("forward")(x, x)

    @skipIfRocm
    def test_save_load(self):
        # Saving the lowered module must succeed; loading it back must fail.
        buffer = io.BytesIO()
        torch.jit.save(self.lowered_module, buffer)
        buffer.seek(0)
        with self.assertRaisesRegexWithHighlight(
            Exception,
            r"Backend is not available.",
            'raise Exception("Backend is not available."',
        ):
            torch.jit.load(buffer)
224
225
class NestedModuleTest(JitBackendTestCase):
    """
    Tests for NestedModule that check that a module lowered to a backend can be used
    as a submodule.
    """

    class NestedModule(torch.nn.Module):
        """
        A Module with one submodule that is used to test that lowered Modules
        can be used as submodules.
        """

        def __init__(self, submodule):
            super().__init__()
            self.submodule = submodule

        def forward(self, x, h):
            return self.submodule.forward(x, h)

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of NestedModule.
        # Both modules in self.module are regular Python modules.
        self.module = NestedModuleTest.NestedModule(BasicModule())
        # Both modules in self.scripted_module are ScriptModules.
        self.scripted_module = torch.jit.script(
            NestedModuleTest.NestedModule(BasicModule())
        )

        # First, script a separate BasicModule instance and lower all three of its
        # methods (accum, sub_accum, forward) to the test backend.
        lowered_module = to_test_backend_multi(
            torch.jit.script(BasicModule()),
            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
        )
        # self.lowered_module is a ScriptModule, but its submodule is a lowered module.
        self.lowered_module = torch.jit.script(
            NestedModuleTest.NestedModule(lowered_module)
        )

    def test_execution(self):
        # Test execution with backend against Python and JIT.
        input = torch.randn(5)

        # Test forward.
        self.check_function("forward", (input, input))

    def test_save_load(self):
        # Lowered module should produce the same outputs.
        self.test_execution()

        # Save and load the lowered module.
        self.save_load()

        # Loaded module should produce the same outputs.
        self.test_execution()
282
283
class SelectiveLoweringTest(JitBackendTestCase):
    """
    Tests for the selective lowering API.
    """

    class OuterModule(torch.nn.Module):
        def __init__(self, sub1, sub2, other):
            super().__init__()
            self.sub1 = sub1
            self.sub2 = sub2
            self.other = other

        def forward(self, x, y):
            # Call the module that will be lowered directly to test
            # type remapping in modules that are not its parent.
            a, b = self.sub1.submodule.forward(x, y)
            c, d = self.sub2.forward(x, y)
            e, f = self.other.forward(x, y)
            return a + c + e, b + d + f

    class MiddleModule(torch.nn.Module):
        def __init__(self, submodule):
            super().__init__()
            self.submodule = submodule

        def forward(self, x, y):
            return self.submodule.forward(x, y)

    def setUp(self):
        super().setUp()
        OuterModule = SelectiveLoweringTest.OuterModule
        MiddleModule = SelectiveLoweringTest.MiddleModule

        def script_without_type_sharing(mod):
            return torch.jit._recursive.create_script_module(
                mod, torch.jit._recursive.infer_methods_to_compile, share_types=False
            )

        # Create Python, JIT and backend versions of a hierarchy that looks like this:
        #                 --------- OuterModule --------
        #                 |              |              |
        #           MiddleModule    MiddleModule   MiddleModule
        #                |               |              |
        #           BasicModule     BasicModule    BasicModule
        #
        # Two BasicModules will be lowered and the third will not.
        self.module = OuterModule(
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
        )
        self.scripted_module = script_without_type_sharing(
            OuterModule(
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
            )
        )
        self.lowered_module = script_without_type_sharing(
            OuterModule(
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
            )
        )
        self.lowered_module = to_test_backend_selective(
            self.lowered_module, {"forward": ""}, ["sub1.submodule", "sub2.submodule"]
        )

    def test_execution(self):
        input = torch.randn(5)
        self.check_function("forward", (input, input))

        self.test_selective_lowering_type_remap()

    def test_save_load(self):
        self.test_execution()
        self.save_load()
        self.test_execution()

        self.test_selective_lowering_type_remap()

    def test_selective_lowering_type_remap(self):
        """
        Check that type remapping and replacement occurred during selective lowering.
        """
        # self.scripted_module was never lowered, so its graph refers to the plain
        # BasicModule type. self.lowered_module is not lowered itself (its graph does not
        # mention the test_backend TorchBind class), but its graph does refer to the
        # lowered wrapper type because OuterModule.forward calls sub1.submodule directly.
        FileCheck().check("OuterModule").check("BasicModule").run(
            self.scripted_module.graph
        )
        FileCheck().check("OuterModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check("LoweredWrapper.test_backend").run(self.lowered_module.graph)

        # Check that self.lowered_module.sub1/sub2 were not lowered but that BasicModule has been replaced in their graphs.
        FileCheck().check("MiddleModule").check("BasicModule").check_not(
            "LoweredWrapper.test_backend"
        ).run(self.scripted_module.sub1.graph)
        FileCheck().check("MiddleModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check("LoweredWrapper.test_backend").run(self.lowered_module.sub1.graph)

        FileCheck().check("MiddleModule").check("BasicModule").check_not(
            "LoweredWrapper.test_backend"
        ).run(self.scripted_module.sub2.graph)
        FileCheck().check("MiddleModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check("LoweredWrapper.test_backend").run(self.lowered_module.sub2.graph)

        # Check that self.lowered_module.sub1/sub2.submodule were lowered. They should have a new attribute
        # __loweredModule__ whose graph should mention __torch__.torch.classes.__backends__.test_backend,
        # the TorchBind class for executing functions on the test JIT backend.
        FileCheck().check("LoweredModule.test_backend").check(
            "__torch__.torch.classes.__backends__.test_backend"
        ).run(self.lowered_module.sub1.submodule.__loweredModule__.graph)

        FileCheck().check("LoweredModule.test_backend").check(
            "__torch__.torch.classes.__backends__.test_backend"
        ).run(self.lowered_module.sub2.submodule.__loweredModule__.graph)

        # Check that self.lowered_module.other and its submodule were left untouched by
        # the selective lowering process. These checks must inspect self.lowered_module:
        # self.scripted_module is never lowered, so checking it would not exercise
        # selective lowering at all.
        FileCheck().check("MiddleModule").check("BasicModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check_not("LoweredWrapper.test_backend").run(self.lowered_module.other.graph)
        FileCheck().check("BasicModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check_not("LoweredModule.test_backend").run(
            self.lowered_module.other.submodule.graph
        )

    def test_errors(self):
        """
        Check errors associated with selective lowering.
        """
        # Check error messages thrown when attempting to lower something that is not a ScriptModule.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Object .* is not a ScriptModule", ""
        ):
            to_test_backend_selective(torch.nn.ReLU(), {"forward": ""}, ["submodule"])

        MiddleModule = SelectiveLoweringTest.MiddleModule
        mod = MiddleModule(BasicModule())
        mod.new_attr = 3

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Attribute named new_attr is not a Module", ""
        ):
            to_test_backend_selective(
                torch.jit.script(mod), {"forward": ""}, ["new_attr"]
            )

        # Check error message thrown when module hierarchy doesn't have unique types.
        OuterModule = SelectiveLoweringTest.OuterModule
        mod = OuterModule(
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
        )

        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            r"Selective lowering is only supported for module hierarchies with unique types",
            "",
        ):
            to_test_backend_selective(
                torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"]
            )
452
453
# Repeat the skip here: the class-level skip on JitBackendTestCase is bypassed for
# the cases below because their setUp and test methods are called directly.
@unittest.skipIf(
    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class TestBackends(JitTestCase):
    """
    Runs the JitBackendTestCase subclasses constructed in ``__init__`` so that each
    one does not have to be individually imported in test_jit.py.
    """

    def __init__(self, name):
        super().__init__(name)
        # Each wrapped case is built with this test's method name, so that method
        # name must also exist on every wrapped case class.
        self.basic_module_test = BasicModuleTest(name)
        self.basic_module_unavailable_test = BasicModuleUnavailableTest(name)
        self.nested_module_test = NestedModuleTest(name)
        self.selective_lowering_test = SelectiveLoweringTest(name)

    def _wrapped_cases(self):
        # Fixed order in which the wrapped cases are set up and run.
        return (
            self.basic_module_test,
            self.basic_module_unavailable_test,
            self.nested_module_test,
            self.selective_lowering_test,
        )

    def setUp(self):
        super().setUp()
        if TEST_WITH_ROCM:
            # The wrapped cases load a non-portable library; skip their setup on ROCm.
            return
        for case in self._wrapped_cases():
            case.setUp()

    @skipIfRocm
    def test_execution(self):
        for case in self._wrapped_cases():
            case.test_execution()

    @skipIfRocm
    def test_save_load(self):
        for case in self._wrapped_cases():
            case.test_save_load()

    @skipIfRocm
    def test_errors(self):
        # Only selective lowering has error checks beyond the base-class placeholder.
        self.selective_lowering_test.test_errors()
497
498
"""
Unit tests for backends with a compiler.

These tests are kept separate from TestBackends because they cover different aspects:
the backend implementation used here is different and includes a simple demo compiler
for testing the end-to-end flow on mobile. These tests do not yet cover selective
lowering, which is covered by TestBackends.
"""
506
507
class BasicModuleAdd(torch.nn.Module):
    """
    A one-operation Module (elementwise addition) used to exercise to_backend lowering.
    """

    def forward(self, x, h):
        total = x + h
        return total
515
516
# unittest applies this skip only when it runs these test cases itself.
# TestBackendsWithCompiler constructs the subclasses and calls their methods
# directly, which bypasses it, so TestBackendsWithCompiler has its own skipIf.
@unittest.skipIf(
    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class JitBackendTestCaseWithCompiler(JitTestCase):
    """
    Shared base class for JIT backend tests that use a backend with a compiler.

    Subclasses must assign four attributes in their ``setUp``:

    * ``module`` - the regular Python module under test,
    * ``scripted_module`` - a scripted version of ``module``,
    * ``lowered_module`` - a version of ``module`` lowered to a backend,
    * ``mobile_module`` - a version that PyTorch Mobile's lite interpreter can run.
    """

    def setUp(self):
        super().setUp()
        library_path = find_library_location("libbackend_with_compiler.so")
        torch.ops.load_library(str(library_path))

    def check_forward(self, input):
        """
        Assert that ``forward`` gives the same result in Python, TorchScript, on the
        backend and on mobile when called with the arguments in ``input``.
        """
        python_output = self.module.forward(*input)
        jit_output = self.scripted_module.forward(*input)
        backend_output = self.lowered_module(*input)
        mobile_output = self.mobile_module(*input)

        # Every other output is compared against the backend output.
        for other_output in (python_output, jit_output, mobile_output):
            self.assertEqual(other_output, backend_output)

    def test_execution(self):
        """
        Placeholder for correctness tests; subclasses override it as needed.
        """

    def test_errors(self):
        """
        Placeholder for error-checking tests; subclasses override it as needed.
        """
564
565
class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
    """
    Tests for BasicModuleAdd lowered with the demo compiler backend.
    """

    def setUp(self):
        super().setUp()
        # Python, TorchScript and lowered versions of BasicModuleAdd.
        self.module = BasicModuleAdd()
        self.scripted_module = torch.jit.script(BasicModuleAdd())
        forward_spec = {
            "input_shapes": "((1, 1, 320, 240), (1, 3))",
            "some_other_option": "True",
        }
        self.lowered_module = torch._C._jit_to_backend(
            "backend_with_compiler_demo",
            self.scripted_module,
            {"forward": forward_spec},
        )
        # Round-trip the lowered module through the lite-interpreter format to get
        # the mobile version.
        serialized = self.lowered_module._save_to_buffer_for_lite_interpreter()
        buffer = io.BytesIO(serialized)
        buffer.seek(0)
        self.mobile_module = _load_for_lite_interpreter(buffer)

    def test_execution(self):
        # Compare the backend against Python, TorchScript and mobile.
        ones = torch.ones(1, dtype=torch.float)
        self.check_forward((ones, ones))
594
595
class ErrorMessagesWithCompiler(JitBackendTestCaseWithCompiler):
    """
    Tests for errors that occur with compiler, specifically:
        * an operator is not supported by the backend

    This derives from JitBackendTestCaseWithCompiler so that setUp loads
    libbackend_with_compiler.so, the library that base class loads for the
    "backend_with_compiler_demo" backend used below. This avoids relying on
    another test case's setUp having loaded it first.
    """

    class ModuleNotSupported(torch.nn.Module):
        """
        A module with an operator that is not supported.

        NOTE: test_errors matches the highlighted TorchScript source of forward,
        including its exact indentation and the unreachable line after the
        return, so forward's body must not be reformatted or edited.
        """

        def forward(self, x, h):
            return x * h
            self._loweredmodule.forward()

    def test_errors(self):
        scripted_module_n = torch.jit.script(
            ErrorMessagesWithCompiler.ModuleNotSupported()
        )
        # Test exception is thrown when lowering a module with an unsupported operator
        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            # Special escape characters are replaced with '.'
            r"""The node of aten::mul is not supported in this compiler. .*
        def forward.self, x, h.:
            return x . h
                   ~~~~~ <--- HERE
            self._loweredmodule.forward..
""",
            "",
        ):
            torch._C._jit_to_backend(
                "backend_with_compiler_demo", scripted_module_n, {"forward": {"": ""}}
            )
630
631
class CompModuleTestWithCompiler(JitBackendTestCase):
    """
    Tests for CompModule, which is a module with two lowered submodules
    """

    class BasicModuleSub(torch.nn.Module):
        """
        A simple subtraction Module to be used in CompModule.
        """

        def forward(self, x, h):
            return x - h

    class CompModule(torch.nn.Module):
        """
        A module with two lowered submodules.
        """

        def __init__(self, addmodule, submodule):
            super().__init__()
            self.lowered_add = addmodule
            self.lowered_sub = submodule

        def forward(self, a, b, s):
            c = self.lowered_add.forward(a, b)
            d = self.lowered_sub.forward(a, b)
            y = s * (c * d)
            return y

    def setUp(self):
        super().setUp()
        # JitBackendTestCase.setUp only loads libjitbackend_test.so. Also load
        # libbackend_with_compiler.so (what JitBackendTestCaseWithCompiler loads for
        # the "backend_with_compiler_demo" backend), so this test does not depend on
        # another test case's setUp having loaded it first.
        torch.ops.load_library(
            str(find_library_location("libbackend_with_compiler.so"))
        )

        # Create Python and JIT versions of CompModule with lowered submodules.
        compile_spec = {
            "forward": {
                "input_shapes": "((1, 1, 320, 240), (1, 3))",
                "some_other_option": "True",
            },
        }
        lowered_add = torch._C._jit_to_backend(
            "backend_with_compiler_demo",
            torch.jit.script(BasicModuleAdd()),
            compile_spec,
        )
        lowered_sub = torch._C._jit_to_backend(
            "backend_with_compiler_demo",
            torch.jit.script(CompModuleTestWithCompiler.BasicModuleSub()),
            {"forward": {"": ""}},
        )
        self.module = CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
        self.scripted_module = torch.jit.script(
            CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
        )
        # No backend version of CompModule currently, so this is filler.
        self.lowered_module = self.scripted_module
        # Create a mobile version of CompModule from JIT version
        buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        self.mobile_module = _load_for_lite_interpreter(buffer)

    def test_execution(self):
        # Test execution with backend against Python and JIT.
        input1 = torch.ones(1, dtype=torch.float)
        input2 = torch.ones(1, dtype=torch.float)

        # Test forward.
        self.check_function("forward", (input1, input2, input2))
698
699
# Repeat the skip here: the class-level skips on the wrapped cases are bypassed
# because their setUp and test methods are called directly.
@unittest.skipIf(
    IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class TestBackendsWithCompiler(JitTestCase):
    """
    Runs the compiler-backend test cases constructed in ``__init__`` so that each
    one does not have to be individually imported in test_jit.py.
    """

    def __init__(self, name):
        super().__init__(name)
        self.basic_module_compiler_test = BasicModuleTestWithCompiler(name)
        self.error_module_compiler_test = ErrorMessagesWithCompiler(name)
        self.comp_module_compiler_test = CompModuleTestWithCompiler(name)

    def setUp(self):
        super().setUp()
        # Keep this order: basic_module_compiler_test's setUp (inherited from
        # JitBackendTestCaseWithCompiler) loads libbackend_with_compiler.so.
        for case in (
            self.basic_module_compiler_test,
            self.error_module_compiler_test,
            self.comp_module_compiler_test,
        ):
            case.setUp()

    def test_execution(self):
        for case in (self.basic_module_compiler_test, self.comp_module_compiler_test):
            case.test_execution()

    def test_errors(self):
        self.error_module_compiler_test.test_errors()
729
730
class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
    """
    Tests for CompModule, which is a module with two lowered submodules with same module name
    """

    class ModuleAdd(torch.nn.Module):
        """
        A simple Module used to test to_backend lowering machinery.
        """

        def forward(self, x, h):
            return x + h

    class CompModule(torch.nn.Module):
        """
        A module with two lowered submodules.

        Both submodules are lowered from ModuleAdd, so they originate from the
        same module type.
        """

        def __init__(self) -> None:
            super().__init__()
            compile_spec = {
                "forward": {
                    "some_other_option": "True",
                },
            }
            # ModuleAdd is nested in the test class, so it must be qualified: a bare
            # `ModuleAdd` is not in scope inside this method.
            module_add = CompModuleTestSameNameWithCompiler.ModuleAdd
            self.add = torch._C._jit_to_backend(
                "backend_with_compiler_demo",
                torch.jit.script(module_add()),
                compile_spec,
            )
            self.sub = torch._C._jit_to_backend(
                "backend_with_compiler_demo",
                torch.jit.script(module_add()),
                compile_spec,
            )

        def forward(self, a, b, s: int):
            c = self.add.forward(a, b)
            d = self.sub.forward(a, b)
            y = s * (c * d)
            return y

    def setUp(self):
        super().setUp()
        # JitBackendTestCase.setUp only loads libjitbackend_test.so. Also load
        # libbackend_with_compiler.so (what JitBackendTestCaseWithCompiler loads for
        # the "backend_with_compiler_demo" backend used by CompModule).
        torch.ops.load_library(
            str(find_library_location("libbackend_with_compiler.so"))
        )

        self.module = CompModuleTestSameNameWithCompiler.CompModule()
        self.scripted_module = torch.jit.script(self.module)
        # No backend version of CompModule itself; check_function still reads
        # self.lowered_module, so use the scripted module as filler.
        self.lowered_module = self.scripted_module
        buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        self.mobile_module = _load_for_lite_interpreter(buffer)

    def test_execution(self):
        a = torch.ones(1)
        b = 3 * torch.ones(1)
        s = 3
        # Test forward.
        self.check_function("forward", (a, b, s))
788
789
class AddedAttributesTest(JitBackendTestCase):
    """
    Tests for adding attributes to a model after lowering.
    """

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of BasicModule.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        self.lowered_module = to_test_backend_multi(
            self.scripted_module,
            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
        )

    def test_attribute(self):
        # torch.utils.bundled_inputs is not imported at the top of this file, so
        # import the helper explicitly. Use a from-import: a function-level
        # `import torch.utils...` would make `torch` a local name for the whole
        # function and break the earlier uses of torch below.
        from torch.utils.bundled_inputs import augment_model_with_bundled_inputs

        # BasicModule.forward takes two tensors (x, h), so each bundled input is a pair.
        inputs = [(torch.ones(5), torch.ones(5))]
        pre_bundled = self.lowered_module(*inputs[0])
        # Attach bundled inputs, which adds several attributes and functions to the
        # model in place. The helper returns None, so do not rebind the module to
        # its result.
        augment_model_with_bundled_inputs(self.lowered_module, inputs)
        post_bundled = self.lowered_module(
            *self.lowered_module.get_all_bundled_inputs()[0]
        )
        # Save and load the lowered module.
        self.save_load()
        # Use the bundled inputs after save and load to prove they are preserved.
        post_load = self.lowered_module(
            *self.lowered_module.get_all_bundled_inputs()[0]
        )
        self.assertEqual(pre_bundled, post_bundled)
        self.assertEqual(post_bundled, post_load)
825