1import os 2import sys 3from tempfile import NamedTemporaryFile 4 5import torch.package.package_exporter 6from torch.testing._internal.common_utils import IS_WINDOWS, TestCase 7 8 9class PackageTestCase(TestCase): 10 def __init__(self, *args, **kwargs): 11 super().__init__(*args, **kwargs) 12 self._temporary_files = [] 13 14 def temp(self): 15 t = NamedTemporaryFile() 16 name = t.name 17 if IS_WINDOWS: 18 t.close() # can't read an open file in windows 19 else: 20 self._temporary_files.append(t) 21 return name 22 23 def setUp(self): 24 """Add test/package/ to module search path. This ensures that 25 importing our fake packages via, e.g. `import package_a` will always 26 work regardless of how we invoke the test. 27 """ 28 super().setUp() 29 self.package_test_dir = os.path.dirname(os.path.realpath(__file__)) 30 self.orig_sys_path = sys.path.copy() 31 sys.path.append(self.package_test_dir) 32 torch.package.package_exporter._gate_torchscript_serialization = False 33 34 def tearDown(self): 35 super().tearDown() 36 sys.path = self.orig_sys_path 37 38 # remove any temporary files 39 for t in self._temporary_files: 40 t.close() 41 self._temporary_files = [] 42