xref: /aosp_15_r20/external/pytorch/test/test_autoload.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: PrivateUse1"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
class TestDeviceBackendAutoload(TestCase):
    """Checks that the custom backend's import marker agrees with the autoload switch."""

    def test_autoload(self):
        """Assert that the "imported" marker equals the autoload switch.

        Both values are read from the process environment as strings and
        fall back to ``"0"`` when the variable is unset. The test passes only
        when the two strings are equal.
        """
        autoload_switch = os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "0")

        # The test extension is expected to record its own import in this
        # environment variable.
        # See: test/cpp_extensions/torch_test_cpp_extension/__init__.py
        import_marker = os.getenv("IS_CUSTOM_DEVICE_BACKEND_IMPORTED", "0")

        # The marker must mirror the switch: set when autoload is on, unset
        # (or "0") when it is off.
        self.assertEqual(
            import_marker,
            autoload_switch,
            msg=(
                "IS_CUSTOM_DEVICE_BACKEND_IMPORTED does not match "
                "TORCH_DEVICE_BACKEND_AUTOLOAD"
            ),
        )
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
# Script entry point: hand control to the test runner imported at the top of
# the file when this module is executed directly.
if __name__ == "__main__":
    run_tests()
22