1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: PrivateUse1"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport os 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerclass TestDeviceBackendAutoload(TestCase): 9*da0073e9SAndroid Build Coastguard Worker def test_autoload(self): 10*da0073e9SAndroid Build Coastguard Worker switch = os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "0") 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Worker # After importing the extension, the value of this environment variable should be true 13*da0073e9SAndroid Build Coastguard Worker # See: test/cpp_extensions/torch_test_cpp_extension/__init__.py 14*da0073e9SAndroid Build Coastguard Worker is_imported = os.getenv("IS_CUSTOM_DEVICE_BACKEND_IMPORTED", "0") 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker # Both values should be equal 17*da0073e9SAndroid Build Coastguard Worker self.assertEqual(is_imported, switch) 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 21*da0073e9SAndroid Build Coastguard Worker run_tests() 22