# Owner(s): ["oncall: package/deploy"]

from io import BytesIO

from torch.package import PackageExporter
from torch.testing._internal.common_utils import run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


def _forwarding_hook(action):
    """Build a dependency hook that hands the visited module name to *action*.

    The returned callable has the ``(package_exporter, module_name)`` signature
    used for extern/mock hooks. It ignores the exporter, calls
    ``action(module_name)``, and lets anything *action* raises propagate.
    """

    def hook(package_exporter, module_name):
        action(module_name)

    return hook


class TestDependencyHooks(PackageTestCase):
    """Tests for PackageExporter's dependency-management hook API.

    Covers:
    - register_extern_hook()
    - register_mock_hook()
    - detaching a hook via the handle returned at registration
    """

    def test_single_hook(self):
        externed = set()

        with PackageExporter(BytesIO()) as exporter:
            exporter.extern(["package_a.subpackage", "module_a"])
            exporter.register_extern_hook(_forwarding_hook(externed.add))
            exporter.save_source_string("foo", "import module_a")

        # Only the module that "foo" actually imports reaches the hook, even
        # though the extern patterns also name "package_a.subpackage".
        self.assertEqual(externed, {"module_a"})

    def test_multiple_extern_hooks(self):
        externed = set()

        with PackageExporter(BytesIO()) as exporter:
            exporter.extern(["package_a.subpackage", "module_a"])
            # Registration order is checked implicitly: set.remove() in the
            # second hook raises KeyError unless the first hook ran before it.
            exporter.register_extern_hook(_forwarding_hook(externed.add))
            exporter.register_extern_hook(_forwarding_hook(externed.remove))
            exporter.save_source_string("foo", "import module_a")

        self.assertEqual(externed, set())

    def test_multiple_mock_hooks(self):
        mocked = set()

        with PackageExporter(BytesIO()) as exporter:
            exporter.mock(["package_a.subpackage", "module_a"])
            # Registration order is checked implicitly: set.remove() in the
            # second hook raises KeyError unless the first hook ran before it.
            exporter.register_mock_hook(_forwarding_hook(mocked.add))
            exporter.register_mock_hook(_forwarding_hook(mocked.remove))
            exporter.save_source_string("foo", "import module_a")

        self.assertEqual(mocked, set())

    def test_remove_hooks(self):
        first_seen = set()
        second_seen = set()

        with PackageExporter(BytesIO()) as exporter:
            exporter.extern(["package_a.subpackage", "module_a"])
            handle = exporter.register_extern_hook(_forwarding_hook(first_seen.add))
            exporter.register_extern_hook(_forwarding_hook(second_seen.add))
            # Detach the first hook before anything is saved; only the second
            # hook should observe "module_a".
            handle.remove()
            exporter.save_source_string("foo", "import module_a")

        self.assertEqual(first_seen, set())
        self.assertEqual(second_seen, {"module_a"})

    def test_extern_and_mock_hook(self):
        externed, mocked = set(), set()

        with PackageExporter(BytesIO()) as exporter:
            exporter.extern("module_a")
            exporter.mock("package_a")
            exporter.register_extern_hook(_forwarding_hook(externed.add))
            exporter.register_mock_hook(_forwarding_hook(mocked.add))
            exporter.save_source_string("foo", "import module_a; import package_a")

        # Each hook kind sees only the module handled by its own action.
        self.assertEqual(externed, {"module_a"})
        self.assertEqual(mocked, {"package_a"})


if __name__ == "__main__":
    run_tests()