xref: /aosp_15_r20/external/pytorch/test/package/test_resources.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: package/deploy"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom io import BytesIO
4*da0073e9SAndroid Build Coastguard Workerfrom sys import version_info
5*da0073e9SAndroid Build Coastguard Workerfrom textwrap import dedent
6*da0073e9SAndroid Build Coastguard Workerfrom unittest import skipIf
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerfrom torch.package import PackageExporter, PackageImporter
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workertry:
13*da0073e9SAndroid Build Coastguard Worker    from .common import PackageTestCase
14*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
15*da0073e9SAndroid Build Coastguard Worker    # Support the case where we run this file directly.
16*da0073e9SAndroid Build Coastguard Worker    from common import PackageTestCase
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
@skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
class TestResources(PackageTestCase):
    """Tests for access APIs for packaged resources."""

    @staticmethod
    def _read_resource(reader, name):
        """Return the bytes of resource ``name`` opened through ``reader``.

        Uses only the file-like contract of ``open_resource`` (``read`` plus
        context-manager close) rather than implementation-specific methods,
        and guarantees the returned stream is closed.
        """
        with reader.open_resource(name) as stream:
            return stream.read()

    def test_resource_reader(self):
        """Test compliance with the get_resource_reader importlib API."""
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            # Layout looks like:
            #    package
            #    |-- one/
            #    |   |-- a.txt
            #    |   |-- b.txt
            #    |   |-- c.txt
            #    |   +-- three/
            #    |       |-- d.txt
            #    |       +-- e.txt
            #    +-- two/
            #        |-- f.txt
            #        +-- g.txt
            pe.save_text("one", "a.txt", "hello, a!")
            pe.save_text("one", "b.txt", "hello, b!")
            pe.save_text("one", "c.txt", "hello, c!")

            pe.save_text("one.three", "d.txt", "hello, d!")
            pe.save_text("one.three", "e.txt", "hello, e!")

            pe.save_text("two", "f.txt", "hello, f!")
            pe.save_text("two", "g.txt", "hello, g!")

        buffer.seek(0)
        importer = PackageImporter(buffer)

        reader_one = importer.get_resource_reader("one")
        # The package is backed by an in-memory archive, so the reader is
        # expected to refuse to hand out a filesystem path.
        with self.assertRaises(FileNotFoundError):
            reader_one.resource_path("a.txt")

        self.assertTrue(reader_one.is_resource("a.txt"))
        self.assertEqual(self._read_resource(reader_one, "a.txt"), b"hello, a!")
        # "three" is a sub-package (directory), not a resource.
        self.assertFalse(reader_one.is_resource("three"))
        reader_one_contents = list(reader_one.contents())
        self.assertSequenceEqual(
            reader_one_contents, ["a.txt", "b.txt", "c.txt", "three"]
        )

        reader_two = importer.get_resource_reader("two")
        self.assertTrue(reader_two.is_resource("f.txt"))
        self.assertEqual(self._read_resource(reader_two, "f.txt"), b"hello, f!")
        reader_two_contents = list(reader_two.contents())
        self.assertSequenceEqual(reader_two_contents, ["f.txt", "g.txt"])

        reader_one_three = importer.get_resource_reader("one.three")
        self.assertTrue(reader_one_three.is_resource("d.txt"))
        self.assertEqual(
            self._read_resource(reader_one_three, "d.txt"), b"hello, d!"
        )
        reader_one_three_contents = list(reader_one_three.contents())
        self.assertSequenceEqual(reader_one_three_contents, ["d.txt", "e.txt"])

        self.assertIsNone(importer.get_resource_reader("nonexistent_package"))

    def test_package_resource_access(self):
        """Packaged modules should be able to use the importlib.resources API to access
        resources saved in the package.
        """
        mod_src = dedent(
            """\
            import importlib.resources
            import my_cool_resources

            def secret_message():
                return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
            """
        )
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_source_string("foo.bar", mod_src)
            pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")

        buffer.seek(0)
        importer = PackageImporter(buffer)
        self.assertEqual(
            importer.import_module("foo.bar").secret_message(), "my sekrit plays"
        )

    def test_importer_access(self):
        """Packaged code can load text and binary resources through the
        injected ``torch_package_importer`` module.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_text("main", "main", "my string")
            pe.save_binary("main", "main_binary", b"my string")
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                t = resources.load_text('main', 'main')
                b = resources.load_binary('main', 'main_binary')
                """
            )
            pe.save_source_string("main", src, is_package=True)
        buffer.seek(0)
        importer = PackageImporter(buffer)
        m = importer.import_module("main")
        self.assertEqual(m.t, "my string")
        self.assertEqual(m.b, b"my string")

    def test_resource_access_by_path(self):
        """
        Tests that packaged code can use importlib.resources.path.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_binary("string_module", "my_string", b"my string")
            src = dedent(
                """\
                import importlib.resources
                import string_module

                with importlib.resources.path(string_module, 'my_string') as path:
                    with open(path, mode='r', encoding='utf-8') as f:
                        s = f.read()
                """
            )
            pe.save_source_string("main", src, is_package=True)
        buffer.seek(0)
        importer = PackageImporter(buffer)
        m = importer.import_module("main")
        self.assertEqual(m.s, "my string")
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Executed as a script: hand control to PyTorch's shared test runner.
    run_tests()
151