1*da0073e9SAndroid Build Coastguard Workerimport os 2*da0073e9SAndroid Build Coastguard Workerimport sys 3*da0073e9SAndroid Build Coastguard Workerfrom tempfile import NamedTemporaryFile 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport torch.package.package_exporter 6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import IS_WINDOWS, TestCase 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerclass PackageTestCase(TestCase): 10*da0073e9SAndroid Build Coastguard Worker def __init__(self, *args, **kwargs): 11*da0073e9SAndroid Build Coastguard Worker super().__init__(*args, **kwargs) 12*da0073e9SAndroid Build Coastguard Worker self._temporary_files = [] 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker def temp(self): 15*da0073e9SAndroid Build Coastguard Worker t = NamedTemporaryFile() 16*da0073e9SAndroid Build Coastguard Worker name = t.name 17*da0073e9SAndroid Build Coastguard Worker if IS_WINDOWS: 18*da0073e9SAndroid Build Coastguard Worker t.close() # can't read an open file in windows 19*da0073e9SAndroid Build Coastguard Worker else: 20*da0073e9SAndroid Build Coastguard Worker self._temporary_files.append(t) 21*da0073e9SAndroid Build Coastguard Worker return name 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker def setUp(self): 24*da0073e9SAndroid Build Coastguard Worker """Add test/package/ to module search path. This ensures that 25*da0073e9SAndroid Build Coastguard Worker importing our fake packages via, e.g. `import package_a` will always 26*da0073e9SAndroid Build Coastguard Worker work regardless of how we invoke the test. 27*da0073e9SAndroid Build Coastguard Worker """ 28*da0073e9SAndroid Build Coastguard Worker super().setUp() 29*da0073e9SAndroid Build Coastguard Worker self.package_test_dir = os.path.dirname(os.path.realpath(__file__)) 30*da0073e9SAndroid Build Coastguard Worker self.orig_sys_path = sys.path.copy() 31*da0073e9SAndroid Build Coastguard Worker sys.path.append(self.package_test_dir) 32*da0073e9SAndroid Build Coastguard Worker torch.package.package_exporter._gate_torchscript_serialization = False 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 35*da0073e9SAndroid Build Coastguard Worker super().tearDown() 36*da0073e9SAndroid Build Coastguard Worker sys.path = self.orig_sys_path 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker # remove any temporary files 39*da0073e9SAndroid Build Coastguard Worker for t in self._temporary_files: 40*da0073e9SAndroid Build Coastguard Worker t.close() 41*da0073e9SAndroid Build Coastguard Worker self._temporary_files = [] 42