# Owner(s): ["module: cuda"]

import multiprocessing
import os
import sys
import unittest
from unittest.mock import patch

import torch


# NOTE: Each of the tests in this module needs to be run in a brand new process to ensure CUDA is uninitialized
# prior to test initiation.
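# One way to satisfy this requirement when invoking the file directly (the --subprocess flag is
# also referenced in SUBPROCESS_REMINDER_MSG below) is, for example:
#     python test/test_cuda_nvml_based_avail.py --subprocess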
with patch.dict(os.environ, {"PYTORCH_NVML_BASED_CUDA_CHECK": "1"}):
    # Before executing the desired tests, we need to disable the CUDA initialization and fork_handler additions that
    # would otherwise be triggered by the `torch.testing._internal.common_utils` module import.
    from torch.testing._internal.common_utils import (
        instantiate_parametrized_tests,
        IS_JETSON,
        IS_WINDOWS,
        NoTest,
        parametrize,
        run_tests,
        TestCase,
    )

    # NOTE: Because `remove_device_and_dtype_suffixes` initializes a CUDA context (triggered via the import of
    # `torch.testing._internal.common_device_type`, which imports `torch.testing._internal.common_cuda`), we need
    # to bypass that method here; it should be irrelevant to the parametrized tests in this module.
    torch.testing._internal.common_utils.remove_device_and_dtype_suffixes = lambda x: x

    TEST_CUDA = torch.cuda.is_available()
    if not TEST_CUDA:
        print("CUDA not available, skipping tests", file=sys.stderr)
        TestCase = NoTest  # type: ignore[misc, assignment] # noqa: F811


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestExtendedCUDAIsAvail(TestCase):
    SUBPROCESS_REMINDER_MSG = (
        "\n REMINDER: Tests defined in test_cuda_nvml_based_avail.py must be run in a process "
        "where the CUDA Driver API has not been initialized. Before further debugging, ensure you are either using "
        "run_test.py or have added --subprocess to run each test in a different subprocess."
    )

    def setUp(self):
        super().setUp()
        torch.cuda._cached_device_count = (
            None  # clear the cached device count before our test
        )

    @staticmethod
    def in_bad_fork_test() -> bool:
        _ = torch.cuda.is_available()
        return torch.cuda._is_in_bad_fork()

    # These tests validate the behavior and activation of the weaker, NVML-based, user-requested
    # `torch.cuda.is_available()` assessment. The NVML-based assessment should be attempted when
    # `PYTORCH_NVML_BASED_CUDA_CHECK` is set to 1, reverting to the default CUDA Runtime API check otherwise.
    # If the NVML-based assessment is attempted but fails, the CUDA Runtime API check should be executed.
    # (An illustrative usage sketch of the user-facing opt-in follows this class.)
    @unittest.skipIf(IS_WINDOWS, "Needs fork")
    @parametrize("nvml_avail", [True, False])
    @parametrize("avoid_init", ["1", "0", None])
    def test_cuda_is_available(self, avoid_init, nvml_avail):
        if IS_JETSON and nvml_avail and avoid_init == "1":
            self.skipTest("Not working for Jetson")
        patch_env = {"PYTORCH_NVML_BASED_CUDA_CHECK": avoid_init} if avoid_init else {}
        with patch.dict(os.environ, **patch_env):
            if nvml_avail:
                _ = torch.cuda.is_available()
            else:
                with patch.object(torch.cuda, "_device_count_nvml", return_value=-1):
                    _ = torch.cuda.is_available()
            with multiprocessing.get_context("fork").Pool(1) as pool:
                in_bad_fork = pool.apply(TestExtendedCUDAIsAvail.in_bad_fork_test)
            if os.getenv("PYTORCH_NVML_BASED_CUDA_CHECK") == "1" and nvml_avail:
                self.assertFalse(
                    in_bad_fork, TestExtendedCUDAIsAvail.SUBPROCESS_REMINDER_MSG
                )
            else:
                assert in_bad_fork


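# An illustrative sketch of the user-facing opt-in exercised above (the helper name is
# hypothetical and is not used by the tests): `PYTORCH_NVML_BASED_CUDA_CHECK` must be set
# before the first call to `torch.cuda.is_available()` in a process where CUDA has not yet
# been initialized. When the NVML query succeeds, the CUDA Runtime API (and hence the CUDA
# context) is left untouched, so subsequently fork()ed workers are not poisoned.
def _example_nvml_based_availability_check() -> bool:
    with patch.dict(os.environ, {"PYTORCH_NVML_BASED_CUDA_CHECK": "1"}):
        return torch.cuda.is_available()

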
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestVisibleDeviceParses(TestCase):
    def test_env_var_parsing(self):
        def _parse_visible_devices(val):
            from torch.cuda import _parse_visible_devices as _pvd

            with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": val}, clear=True):
                return _pvd()

        # Rest of the string after the leading integer is ignored
        self.assertEqual(_parse_visible_devices("1gpu2,2ampere"), [1, 2])
        # Negatives abort parsing
        self.assertEqual(_parse_visible_devices("0, 1, 2, -1, 3"), [0, 1, 2])
        # Duplicate ordinals return an empty set
        self.assertEqual(_parse_visible_devices("0, 1, 2, 1"), [])
        # Unary pluses and minuses are accepted
        self.assertEqual(_parse_visible_devices("2, +3, -0, 5"), [2, 3, 0, 5])
        # A non-numeric leading token yields an empty set
        self.assertEqual(_parse_visible_devices("one,two,3,4"), [])
        # Parsing stops at the first non-numeric token
        self.assertEqual(_parse_visible_devices("4,3,two,one"), [4, 3])
        # GPU ids are parsed
        self.assertEqual(_parse_visible_devices("GPU-9e8d35e3"), ["GPU-9e8d35e3"])
        # Ordinals are not mixed into a GPU id set
        self.assertEqual(_parse_visible_devices("GPU-123, 2"), ["GPU-123"])
        # MIG ids are parsed
        self.assertEqual(_parse_visible_devices("MIG-89c850dc"), ["MIG-89c850dc"])

    def test_partial_uuid_resolver(self):
        from torch.cuda import _transform_uuid_to_ordinals

        uuids = [
            "GPU-9942190a-aa31-4ff1-4aa9-c388d80f85f1",
            "GPU-9e8d35e3-a134-0fdd-0e01-23811fdbd293",
            "GPU-e429a63e-c61c-4795-b757-5132caeb8e70",
            "GPU-eee1dfbc-0a0f-6ad8-5ff6-dc942a8b9d98",
            "GPU-bbcd6503-5150-4e92-c266-97cc4390d04e",
            "GPU-472ea263-58d7-410d-cc82-f7fdece5bd28",
            "GPU-e56257c4-947f-6a5b-7ec9-0f45567ccf4e",
            "GPU-1c20e77d-1c1a-d9ed-fe37-18b8466a78ad",
        ]
        self.assertEqual(_transform_uuid_to_ordinals(["GPU-9e8d35e3"], uuids), [1])
        self.assertEqual(
            _transform_uuid_to_ordinals(["GPU-e4", "GPU-9e8d35e3"], uuids), [2, 1]
        )
        self.assertEqual(
            _transform_uuid_to_ordinals("GPU-9e8d35e3,GPU-1,GPU-47".split(","), uuids),
            [1, 7, 5],
        )
        # First invalid UUID aborts parsing
        self.assertEqual(
            _transform_uuid_to_ordinals(["GPU-123", "GPU-9e8d35e3"], uuids), []
        )
        self.assertEqual(
            _transform_uuid_to_ordinals(["GPU-9e8d35e3", "GPU-123", "GPU-47"], uuids),
            [1],
        )
        # First ambiguous UUID aborts parsing
        self.assertEqual(
            _transform_uuid_to_ordinals(["GPU-9e8d35e3", "GPU-e", "GPU-47"], uuids), [1]
        )
        # Duplicate UUIDs result in an empty set
        self.assertEqual(
            _transform_uuid_to_ordinals(["GPU-9e8d35e3", "GPU-47", "GPU-9e8"], uuids),
            [],
        )

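    # Illustrative example (hypothetical setting, mirroring the assertions above): if the
    # machine's devices matched the `uuids` list above, launching a process with
    #     CUDA_VISIBLE_DEVICES="GPU-9e8d35e3,GPU-e4"
    # would be resolved by torch.cuda's NVML-based path to the ordinals [1, 2].
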
    def test_ordinal_parse_visible_devices(self):
        def _device_count_nvml(val):
            from torch.cuda import _device_count_nvml as _dc

            with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": val}, clear=True):
                return _dc()

        with patch.object(torch.cuda, "_raw_device_count_nvml", return_value=2):
            self.assertEqual(_device_count_nvml("1, 0"), 2)
            # Ordinal out of bounds aborts parsing
            self.assertEqual(_device_count_nvml("1, 5, 0"), 1)


instantiate_parametrized_tests(TestExtendedCUDAIsAvail)

if __name__ == "__main__":
    run_tests()