xref: /aosp_15_r20/external/pytorch/test/package/test_importer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: package/deploy"]
2
3from io import BytesIO
4
5import torch
6from torch.package import (
7    Importer,
8    OrderedImporter,
9    PackageExporter,
10    PackageImporter,
11    sys_importer,
12)
13from torch.testing._internal.common_utils import run_tests
14
15
16try:
17    from .common import PackageTestCase
18except ImportError:
19    # Support the case where we run this file directly.
20    from common import PackageTestCase
21
22
class TestImporter(PackageTestCase):
    """Exercise the ``Importer`` interface and its concrete implementations:
    ``sys_importer``, ``PackageImporter`` and ``OrderedImporter``.
    """

    @staticmethod
    def _packaged_importer(module_name):
        """Save ``module_name`` into an in-memory package and return a
        ``PackageImporter`` that reads that package back.
        """
        archive = BytesIO()
        with PackageExporter(archive) as exporter:
            exporter.save_module(module_name)
        # Rewind so the importer starts reading at the first byte.
        archive.seek(0)
        return PackageImporter(archive)

    @staticmethod
    def _pickle_roundtrip(obj, **exporter_kwargs):
        """Pickle ``obj`` as resource ``foo.pkl`` of package ``foo`` into a
        fresh in-memory package, then open that package and unpickle it.

        ``exporter_kwargs`` are forwarded unchanged to ``PackageExporter``
        (e.g. ``importer=...``). Returns ``(package_importer, loaded_obj)``.
        """
        archive = BytesIO()
        with PackageExporter(archive, **exporter_kwargs) as exporter:
            exporter.save_pickle("foo", "foo.pkl", obj)
        archive.seek(0)
        reader = PackageImporter(archive)
        return reader, reader.load_pickle("foo", "foo.pkl")

    def test_sys_importer(self):
        import package_a
        import package_a.subpackage

        # sys_importer must hand back the very module objects that a plain
        # ``import`` statement produced.
        expected_modules = {
            "package_a": package_a,
            "package_a.subpackage": package_a.subpackage,
        }
        for dotted_name, module in expected_modules.items():
            self.assertIs(sys_importer.import_module(dotted_name), module)

    def test_sys_importer_roundtrip(self):
        import package_a
        import package_a.subpackage

        # get_name() followed by import_module() + getattr() must recover the
        # original type object.
        target = package_a.subpackage.PackageASubpackageObject
        owner_name, attr_name = sys_importer.get_name(target)
        owner = sys_importer.import_module(owner_name)
        self.assertIs(getattr(owner, attr_name), target)

    def test_single_ordered_importer(self):
        # module_a is imported only so that it is present in the outer Python
        # environment; the final check shows the package-only importer still
        # cannot reach it.
        import module_a  # noqa: F401
        import package_a

        packaged = self._packaged_importer(package_a.__name__)
        # An OrderedImporter whose sole source is the package.
        package_only = OrderedImporter(packaged)

        # Resolution goes to the packaged copy of package_a ...
        self.assertIs(
            package_only.import_module("package_a"),
            packaged.import_module("package_a"),
        )
        # ... and not to the copy loaded by the ordinary import above.
        self.assertIsNot(package_only.import_module("package_a"), package_a)

        # module_a was never saved into the package, so it must be unavailable.
        with self.assertRaises(ModuleNotFoundError):
            package_only.import_module("module_a")

    def test_ordered_importer_basic(self):
        import package_a

        packaged = self._packaged_importer(package_a.__name__)

        # Whichever importer is listed first in the OrderedImporter wins.
        sys_first = OrderedImporter(sys_importer, packaged)
        self.assertIs(sys_first.import_module("package_a"), package_a)

        package_first = OrderedImporter(packaged, sys_importer)
        self.assertIs(
            package_first.import_module("package_a"),
            packaged.import_module("package_a"),
        )

    def test_ordered_importer_whichmodule(self):
        """OrderedImporter.whichmodule must consult the wrapped importers in
        order, moving past any that answer "not found".
        """

        class FixedAnswerImporter(Importer):
            """Importer stub whose whichmodule() always returns ``answer``."""

            def __init__(self, answer):
                self._answer = answer

            def import_module(self, module_name):
                raise NotImplementedError

            def whichmodule(self, obj, name):
                return self._answer

        class Probe:
            pass

        says_foo = FixedAnswerImporter("foo")
        says_bar = FixedAnswerImporter("bar")
        # CPython uses "__main__" as a stand-in for "not found", so this stub
        # models an importer that cannot locate the object.
        says_not_found = FixedAnswerImporter("__main__")

        scenarios = (
            ((says_foo, says_bar), "foo"),
            ((says_bar, says_foo), "bar"),
            ((says_not_found, says_foo), "foo"),
        )
        for chain, expected in scenarios:
            ordered = OrderedImporter(*chain)
            self.assertEqual(ordered.whichmodule(Probe(), ""), expected)

    def test_package_importer_whichmodule_no_dunder_module(self):
        """Pickling through a PackageImporter must still work for an object
        with no usable ``__module__`` because it comes from a C extension.
        """
        # torch.float16 is such an object. The stock pickler finds it by
        # scanning sys.modules for an attribute named ``float16`` (see
        # pickle.py:whichmodule); PackageImporter has to emulate that search.
        reference = torch.float16

        # Hop 1: package the dtype without specifying an importer, load it back.
        first_importer, first_copy = self._pickle_roundtrip(reference)

        # Hop 2: re-package the loaded copy with the first package as the
        # *only* importer, then load it again.
        _, second_copy = self._pickle_roundtrip(first_copy, importer=first_importer)

        # Both hops must give back the identical torch.float16 object.
        self.assertIs(reference, first_copy)
        self.assertIs(reference, second_copy)
159
160
if __name__ == "__main__":
    # Entry point when this file is executed as a script rather than being
    # collected by a test runner; run_tests() comes from
    # torch.testing._internal.common_utils.
    run_tests()
163