# mypy: ignore-errors

r"""This file is allowed to initialize CUDA context when imported."""

import functools
import torch
import torch.cuda
from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS
import inspect
import contextlib
import os


CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()


TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None
# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
if TEST_WITH_ROCM:
    TEST_CUDNN = LazyVal(lambda: TEST_CUDA)
else:
    TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))

TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0)

SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3))
SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0))
SM70OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 0))
SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5))
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))

IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)])
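
# Editorial example (illustrative sketch, not part of the upstream helpers): tests
# typically consume these lazily evaluated gates through skip decorators, e.g.
#
#     @unittest.skipIf(not TEST_MULTIGPU, "requires at least two GPUs")
#     @unittest.skipIf(not SM80OrLater, "requires Ampere or newer")
#     def test_my_cuda_feature(self):   # hypothetical test method
#         ...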

def evaluate_gfx_arch_exact(matching_arch):
    if not torch.cuda.is_available():
        return False
    gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
    arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
    return arch == matching_arch

GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-'))
GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-'))
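
# Editorial note (hedged): as the helper above shows, the reported gcnArchName can be
# overridden for debugging via the PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE
# environment variable, e.g.
#
#     PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE="gfx90a:sramecc+:xnack-" python run_my_tests.py
#
# (run_my_tests.py is a hypothetical test entry point used only for illustration.)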

def evaluate_platform_supports_flash_attention():
    if TEST_WITH_ROCM:
        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
    if TEST_CUDA:
        return not IS_WINDOWS and SM80OrLater
    return False

def evaluate_platform_supports_efficient_attention():
    if TEST_WITH_ROCM:
        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
    if TEST_CUDA:
        return True
    return False

PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention())
# TODO(eqy): gate this against a cuDNN version
PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and not TEST_WITH_ROCM and
                                                  torch.backends.cuda.cudnn_sdp_enabled())
# This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate
PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION)

PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM

PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
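
# Editorial example (illustrative sketch): scaled-dot-product-attention tests
# typically gate on these flags, e.g.
#
#     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention not supported on this platform")
#     def test_flash_attention_vs_math_reference(self, device):   # hypothetical test
#         ...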

if TEST_NUMBA:
    try:
        import numba.cuda
        TEST_NUMBA_CUDA = numba.cuda.is_available()
    except Exception:
        TEST_NUMBA_CUDA = False
        TEST_NUMBA = False
else:
    TEST_NUMBA_CUDA = False

# Used below in `initialize_cuda_context_rng` to ensure that the CUDA context and
# RNG have been initialized.
__cuda_ctx_rng_initialized = False


# After this call, the CUDA context and RNG have been initialized on every GPU.
def initialize_cuda_context_rng():
    global __cuda_ctx_rng_initialized
    assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
    if not __cuda_ctx_rng_initialized:
        # initialize cuda context and rng for memory tests
        for i in range(torch.cuda.device_count()):
            torch.randn(1, device=f"cuda:{i}")
        __cuda_ctx_rng_initialized = True
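
# Editorial example (illustrative sketch): CUDA memory / RNG sensitive tests call the
# helper above once up front so that lazy context creation does not skew what they
# measure, e.g.
#
#     def test_memory_stats(self):          # hypothetical test method
#         initialize_cuda_context_rng()
#         ...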


# Test whether the hardware TF32 math mode is enabled. It is available only with:
# - CUDA >= 11
# - arch >= Ampere
def tf32_is_not_fp32():
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    if int(torch.version.cuda.split('.')[0]) < 11:
        return False
    return True


@contextlib.contextmanager
def tf32_off():
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul


@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    old_precision = self.precision
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        self.precision = tf32_precision
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
        self.precision = old_precision
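
# Editorial example (illustrative sketch): the two context managers above can also be
# used directly inside a TestCase method, e.g.
#
#     with tf32_off():
#         expected = a @ b          # matmul forced to run at full fp32 precision
#     with tf32_on(self, 0.005):
#         actual = a @ b            # TF32 allowed; self.precision temporarily relaxed
#     self.assertEqual(actual, expected)
#
# (a and b are hypothetical CUDA float32 tensors used only for illustration.)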


# This is a decorator that wraps a test so that it runs twice: once with
# allow_tf32=True and once with allow_tf32=False. When running with
# allow_tf32=True, it uses the reduced precision specified by the argument.
# For example:
#    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
#    @tf32_on_and_off(0.005)
#    def test_matmul(self, device, dtype):
#        a = ...; b = ...;
#        c = torch.matmul(a, b)
#        self.assertEqual(c, expected)
# In the above example, when testing torch.float32 and torch.complex64 on CUDA,
# on a CUDA >= 11 build, and on an >= Ampere architecture, the matmul runs once
# with TF32 mode on and once with TF32 mode off; in the TF32 run, assertEqual
# uses the reduced precision to check values.
#
# This decorator can be used for functions with or without device/dtype, such as
# @tf32_on_and_off(0.005)
# def test_my_op(self)
# @tf32_on_and_off(0.005)
# def test_my_op(self, device)
# @tf32_on_and_off(0.005)
# def test_my_op(self, device, dtype)
# @tf32_on_and_off(0.005)
# def test_my_op(self, dtype)
# If neither device nor dtype is specified, it checks whether the system has an
# Ampere (or newer) device on a CUDA >= 11 build.
# If device is specified, it checks whether the device is CUDA.
# If dtype is specified, it checks whether the dtype is float32 or complex64.
# TF32 and fp32 differ only when all of these checks pass.
def tf32_on_and_off(tf32_precision=1e-5):
    def with_tf32_disabled(self, function_call):
        with tf32_off():
            function_call()

    def with_tf32_enabled(self, function_call):
        with tf32_on(self, tf32_precision):
            function_call()

    def wrapper(f):
        params = inspect.signature(f).parameters
        arg_names = tuple(params.keys())

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            for k, v in zip(arg_names, args):
                kwargs[k] = v
            cond = tf32_is_not_fp32()
            if 'device' in kwargs:
                cond = cond and (torch.device(kwargs['device']).type == 'cuda')
            if 'dtype' in kwargs:
                cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64})
            if cond:
                with_tf32_disabled(kwargs['self'], lambda: f(**kwargs))
                with_tf32_enabled(kwargs['self'], lambda: f(**kwargs))
            else:
                f(**kwargs)

        return wrapped
    return wrapper


# This is a wrapper that wraps a test so that it runs with TF32 turned off.
# It is designed for tests that use matmul or convolutions incidentally,
# i.e. where testing matmul or convolutions is not the purpose of the test.
# Disabling TF32 forces torch.float computations to always run at full precision.
def with_tf32_off(f):
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        with tf32_off():
            return f(*args, **kwargs)

    return wrapped
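
# Editorial example (illustrative sketch): typical use is as a plain decorator, e.g.
#
#     @with_tf32_off
#     def test_uses_matmul_internally(self, device):   # hypothetical test
#         ...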

def _get_magma_version():
    if 'Magma' not in torch.__config__.show():
        return (0, 0)
    position = torch.__config__.show().find('Magma ')
    version_str = torch.__config__.show()[position + len('Magma '):].split('\n')[0]
    return tuple(int(x) for x in version_str.split("."))

def _get_torch_cuda_version():
    if torch.version.cuda is None:
        return (0, 0)
    cuda_version = str(torch.version.cuda)
    return tuple(int(x) for x in cuda_version.split("."))

def _get_torch_rocm_version():
    if not TEST_WITH_ROCM:
        return (0, 0)
    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]    # ignore git sha
    return tuple(int(x) for x in rocm_version.split("."))
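
# Editorial example (illustrative sketch): these helpers return version tuples, which
# compare naturally against literals when gating tests, e.g.
#
#     @unittest.skipIf(_get_torch_cuda_version() < (11, 6), "requires CUDA 11.6+")
#     def test_needs_recent_cuda(self):   # hypothetical test
#         ...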

def _check_cusparse_generic_available():
    return not TEST_WITH_ROCM

def _check_hipsparse_generic_available():
    if not TEST_WITH_ROCM:
        return False

    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]    # ignore git sha
    rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
    return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1))


TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available()
TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available()

# Shared by test_torch.py and test_multigpu.py
def _create_scaling_models_optimizers(device="cuda", optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
    # Create a module+optimizer that will use scaling, and a control module+optimizer
    # that will not use scaling, against which the scaling-enabled module+optimizer can be compared.
    mod_control = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
    mod_scaling = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
    with torch.no_grad():
        for c, s in zip(mod_control.parameters(), mod_scaling.parameters()):
            s.copy_(c)

    kwargs = {"lr": 1.0}
    if optimizer_kwargs is not None:
        kwargs.update(optimizer_kwargs)
    opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
    opt_scaling = optimizer_ctor(mod_scaling.parameters(), **kwargs)

    return mod_control, mod_scaling, opt_control, opt_scaling

# Shared by test_torch.py, test_cuda.py and test_multigpu.py
def _create_scaling_case(device="cuda", dtype=torch.float, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
    data = [(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device))]

    loss_fn = torch.nn.MSELoss().to(device)

    skip_iter = 2

    return _create_scaling_models_optimizers(
        device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs,
    ) + (data, loss_fn, skip_iter)
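
# Editorial example (illustrative sketch, not part of this module): a gradient-scaling
# test typically unpacks the helper above like so
#
#     mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, skip_iter = \
#         _create_scaling_case()
#     scaler = torch.cuda.amp.GradScaler()
#     for i, (input, target) in enumerate(data):
#         opt_scaling.zero_grad()
#         loss = loss_fn(mod_scaling(input), target)
#         scaler.scale(loss).backward()
#         scaler.step(opt_scaling)
#         scaler.update()
#
# (The shared tests additionally use skip_iter to inject a non-finite gradient at that
# iteration and exercise GradScaler's step skipping; the loop above omits that for brevity.)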


# Importing this module should NOT eagerly initialize CUDA
if not CUDA_ALREADY_INITIALIZED_ON_IMPORT:
    assert not torch.cuda.is_initialized()