# mypy: ignore-errors

r"""This file is allowed to initialize CUDA context when imported."""

import functools
import torch
import torch.cuda
from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS
import inspect
import contextlib
import os


CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()


TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None
# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
if TEST_WITH_ROCM:
    TEST_CUDNN = LazyVal(lambda: TEST_CUDA)
else:
    TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))

TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0)

SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3))
SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0))
SM70OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 0))
SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5))
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))

IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)])

def evaluate_gfx_arch_exact(matching_arch):
    if not torch.cuda.is_available():
        return False
    gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
    arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
    return arch == matching_arch

GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-'))
GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-'))

def evaluate_platform_supports_flash_attention():
    if TEST_WITH_ROCM:
        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
    if TEST_CUDA:
        return not IS_WINDOWS and SM80OrLater
    return False

def evaluate_platform_supports_efficient_attention():
    if TEST_WITH_ROCM:
        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
    if TEST_CUDA:
        return True
    return False

PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention())
# TODO(eqy): gate this against a cuDNN version
PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and not TEST_WITH_ROCM and
                                                  torch.backends.cuda.cudnn_sdp_enabled())
# This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, but we keep it separate for logical clarity.
PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION)

PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM

PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
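
# Illustrative usage of the capability flags above (hypothetical test, not part
# of this module): tests typically gate themselves with unittest.skipIf, e.g.
#     import unittest
#     from torch.testing._internal.common_cuda import SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION
#
#     class MyAttentionTest(unittest.TestCase):
#         @unittest.skipIf(not SM80OrLater, "requires SM80 (Ampere) or newer")
#         @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention unsupported on this platform")
#         def test_flash_attention_path(self):
#             ...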

if TEST_NUMBA:
    try:
        import numba.cuda
        TEST_NUMBA_CUDA = numba.cuda.is_available()
    except Exception as e:
        TEST_NUMBA_CUDA = False
        TEST_NUMBA = False
else:
    TEST_NUMBA_CUDA = False

# Used below in `initialize_cuda_context_rng` to ensure that CUDA context and
# RNG have been initialized.
__cuda_ctx_rng_initialized = False


# After this call, the CUDA context and RNG must have been initialized on each GPU.
def initialize_cuda_context_rng():
    global __cuda_ctx_rng_initialized
    assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
    if not __cuda_ctx_rng_initialized:
        # initialize cuda context and rng for memory tests
        for i in range(torch.cuda.device_count()):
            torch.randn(1, device=f"cuda:{i}")
        __cuda_ctx_rng_initialized = True


# Test whether hardware TF32 math mode is enabled. It is enabled only on:
# - CUDA >= 11
# - arch >= Ampere
def tf32_is_not_fp32():
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    if int(torch.version.cuda.split('.')[0]) < 11:
        return False
    return True


@contextlib.contextmanager
def tf32_off():
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul


@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    old_precision = self.precision
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        self.precision = tf32_precision
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
        self.precision = old_precision
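
# Illustrative usage of the context managers above (hypothetical test method,
# not part of this module; tf32_on expects a test instance with a `precision`
# attribute, such as a device-type test class):
#     def test_matmul_reference(self, device):
#         a = torch.randn(64, 64, device=device)
#         b = torch.randn(64, 64, device=device)
#         with tf32_off():
#             expected = a @ b          # computed in full FP32
#         with tf32_on(self, tf32_precision=0.005):
#             actual = a @ b            # TF32 allowed, looser tolerance
#         self.assertEqual(actual, expected)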


# This is a wrapper that runs a test twice, once with allow_tf32=True and once
# with allow_tf32=False. When running with allow_tf32=True, it will use the
# reduced precision specified by the argument. For example:
#     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
#     @tf32_on_and_off(0.005)
#     def test_matmul(self, device, dtype):
#         a = ...; b = ...;
#         c = torch.matmul(a, b)
#         self.assertEqual(c, expected)
# In the above example, when testing torch.float32 and torch.complex64 on CUDA
# on a CUDA >= 11 build on an >= Ampere architecture, the matmul runs once with
# TF32 mode on and once with TF32 mode off; when TF32 is on, the assertEqual
# uses the reduced precision to check values.
#
# This decorator can be used for functions with or without device/dtype, such as:
#     @tf32_on_and_off(0.005)
#     def test_my_op(self)
#     @tf32_on_and_off(0.005)
#     def test_my_op(self, device)
#     @tf32_on_and_off(0.005)
#     def test_my_op(self, device, dtype)
#     @tf32_on_and_off(0.005)
#     def test_my_op(self, dtype)
# If neither device nor dtype is specified, it checks whether the system has an
# Ampere (or later) device; if device is specified, it checks whether the device
# is CUDA; if dtype is specified, it checks whether the dtype is float32 or
# complex64. TF32 and FP32 differ only when all three checks pass.
def tf32_on_and_off(tf32_precision=1e-5):
    def with_tf32_disabled(self, function_call):
        with tf32_off():
            function_call()

    def with_tf32_enabled(self, function_call):
        with tf32_on(self, tf32_precision):
            function_call()

    def wrapper(f):
        params = inspect.signature(f).parameters
        arg_names = tuple(params.keys())

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            for k, v in zip(arg_names, args):
                kwargs[k] = v
            cond = tf32_is_not_fp32()
            if 'device' in kwargs:
                cond = cond and (torch.device(kwargs['device']).type == 'cuda')
            if 'dtype' in kwargs:
                cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64})
            if cond:
                with_tf32_disabled(kwargs['self'], lambda: f(**kwargs))
                with_tf32_enabled(kwargs['self'], lambda: f(**kwargs))
            else:
                f(**kwargs)

        return wrapped
    return wrapper


# This is a wrapper that wraps a test to run it with TF32 turned off.
# This wrapper is designed to be used when a test uses matmul or convolutions
# but the purpose of that test is not to test matmul or convolutions.
# Disabling TF32 ensures that torch.float tensors are always computed at full
# precision.
def with_tf32_off(f):
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        with tf32_off():
            return f(*args, **kwargs)

    return wrapped
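
# Illustrative usage of with_tf32_off (hypothetical test, not part of this
# module): the decorated test keeps its matmul in full FP32 even on Ampere+
# GPUs, because matmul accuracy is not what it is testing.
#     @with_tf32_off
#     def test_unrelated_feature(self, device):
#         weight = torch.randn(32, 32, device=device)
#         out = torch.randn(32, 32, device=device) @ weight
#         ...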

def _get_magma_version():
    if 'Magma' not in torch.__config__.show():
        return (0, 0)
    position = torch.__config__.show().find('Magma ')
    version_str = torch.__config__.show()[position + len('Magma '):].split('\n')[0]
    return tuple(int(x) for x in version_str.split("."))

def _get_torch_cuda_version():
    if torch.version.cuda is None:
        return (0, 0)
    cuda_version = str(torch.version.cuda)
    return tuple(int(x) for x in cuda_version.split("."))

def _get_torch_rocm_version():
    if not TEST_WITH_ROCM:
        return (0, 0)
    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]  # ignore git sha
    return tuple(int(x) for x in rocm_version.split("."))

def _check_cusparse_generic_available():
    return not TEST_WITH_ROCM

def _check_hipsparse_generic_available():
    if not TEST_WITH_ROCM:
        return False

    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]  # ignore git sha
    rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
    return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1))


TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available()
TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available()

# Shared by test_torch.py and test_multigpu.py
def _create_scaling_models_optimizers(device="cuda", optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
    # Create a module+optimizer that will use scaling, and a control module+optimizer
    # that will not use scaling, against which the scaling-enabled module+optimizer can be compared.
    mod_control = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
    mod_scaling = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
    with torch.no_grad():
        for c, s in zip(mod_control.parameters(), mod_scaling.parameters()):
            s.copy_(c)

    kwargs = {"lr": 1.0}
    if optimizer_kwargs is not None:
        kwargs.update(optimizer_kwargs)
    opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
    opt_scaling = optimizer_ctor(mod_scaling.parameters(), **kwargs)

    return mod_control, mod_scaling, opt_control, opt_scaling

# Shared by test_torch.py, test_cuda.py and test_multigpu.py
def _create_scaling_case(device="cuda", dtype=torch.float, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
    data = [(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device))]

    loss_fn = torch.nn.MSELoss().to(device)

    skip_iter = 2

    return _create_scaling_models_optimizers(
        device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs,
    ) + (data, loss_fn, skip_iter)


# Importing this module should NOT eagerly initialize CUDA
if not CUDA_ALREADY_INITIALIZED_ON_IMPORT:
    assert not torch.cuda.is_initialized()
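
# Illustrative sketch of consuming _create_scaling_case in a gradient-scaling
# test (hypothetical loop, not part of this module; the real tests additionally
# use skip_iter to inject a non-finite gradient so the scaler skips that step):
#     mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, skip_iter = \
#         _create_scaling_case()
#     scaler = torch.cuda.amp.GradScaler()
#     for input, target in data:
#         opt_scaling.zero_grad()
#         loss = loss_fn(mod_scaling(input), target)
#         scaler.scale(loss).backward()
#         scaler.step(opt_scaling)
#         scaler.update()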