xref: /aosp_15_r20/external/pytorch/test/package/test_dependency_hooks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: package/deploy"]
2
3from io import BytesIO
4
5from torch.package import PackageExporter
6from torch.testing._internal.common_utils import run_tests
7
8
9try:
10    from .common import PackageTestCase
11except ImportError:
12    # Support the case where we run this file directly.
13    from common import PackageTestCase
14
15
class TestDependencyHooks(PackageTestCase):
    """Tests for the dependency-management hook API of ``PackageExporter``.

    Entry points exercised:
    - register_extern_hook()
    - register_mock_hook()
    - the handle returned by a registration call, whose ``remove()``
      unregisters the hook
    """

    @staticmethod
    def _export_foo(configure, source):
        """Package one source module named ``foo`` into a throwaway buffer.

        ``configure`` is called with the open exporter before ``foo`` is
        saved, so it can declare extern/mock patterns and register hooks.
        The tests only inspect hook side effects after this helper returns,
        i.e. after the ``with`` block has closed the exporter.
        """
        with PackageExporter(BytesIO()) as packager:
            configure(packager)
            packager.save_source_string("foo", source)

    def test_single_hook(self):
        seen_externs = set()

        def on_extern(_exporter, name):
            seen_externs.add(name)

        def configure(packager):
            packager.extern(["package_a.subpackage", "module_a"])
            packager.register_extern_hook(on_extern)

        self._export_foo(configure, "import module_a")

        self.assertEqual(seen_externs, {"module_a"})

    def test_multiple_extern_hooks(self):
        pending = set()

        def add_first(_exporter, name):
            pending.add(name)

        # set.remove() raises KeyError unless add_first already ran for this
        # name, so this also checks that hooks run in registration order.
        def remove_second(_exporter, name):
            pending.remove(name)

        def configure(packager):
            packager.extern(["package_a.subpackage", "module_a"])
            packager.register_extern_hook(add_first)
            packager.register_extern_hook(remove_second)

        self._export_foo(configure, "import module_a")

        self.assertEqual(pending, set())

    def test_multiple_mock_hooks(self):
        pending = set()

        def add_first(_exporter, name):
            pending.add(name)

        # set.remove() raises KeyError unless add_first already ran for this
        # name, so this also checks that hooks run in registration order.
        def remove_second(_exporter, name):
            pending.remove(name)

        def configure(packager):
            packager.mock(["package_a.subpackage", "module_a"])
            packager.register_mock_hook(add_first)
            packager.register_mock_hook(remove_second)

        self._export_foo(configure, "import module_a")

        self.assertEqual(pending, set())

    def test_remove_hooks(self):
        removed_calls = set()
        kept_calls = set()

        def removed_hook(_exporter, name):
            removed_calls.add(name)

        def kept_hook(_exporter, name):
            kept_calls.add(name)

        def configure(packager):
            packager.extern(["package_a.subpackage", "module_a"])
            handle = packager.register_extern_hook(removed_hook)
            packager.register_extern_hook(kept_hook)
            # Unregister the first hook only after a second one exists; the
            # surviving hook must still be invoked.
            handle.remove()

        self._export_foo(configure, "import module_a")

        self.assertEqual(removed_calls, set())
        self.assertEqual(kept_calls, {"module_a"})

    def test_extern_and_mock_hook(self):
        externed = set()
        mocked = set()

        def on_extern(_exporter, name):
            externed.add(name)

        def on_mock(_exporter, name):
            mocked.add(name)

        def configure(packager):
            packager.extern("module_a")
            packager.mock("package_a")
            packager.register_extern_hook(on_extern)
            packager.register_mock_hook(on_mock)

        self._export_foo(configure, "import module_a; import package_a")

        # Each hook kind should see only the module matched by its own action.
        self.assertEqual(externed, {"module_a"})
        self.assertEqual(mocked, {"package_a"})
121
if __name__ == "__main__":
    # Entry point when this file is executed directly instead of being
    # collected by a test runner; delegates to PyTorch's shared test driver.
    run_tests()
124