1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"] 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Workerimport os 5*da0073e9SAndroid Build Coastguard Workerimport random 6*da0073e9SAndroid Build Coastguard Workerimport re 7*da0073e9SAndroid Build Coastguard Workerimport shutil 8*da0073e9SAndroid Build Coastguard Workerimport subprocess 9*da0073e9SAndroid Build Coastguard Workerimport sys 10*da0073e9SAndroid Build Coastguard Workerimport tempfile 11*da0073e9SAndroid Build Coastguard Workerimport textwrap 12*da0073e9SAndroid Build Coastguard Workerimport traceback 13*da0073e9SAndroid Build Coastguard Workerimport unittest 14*da0073e9SAndroid Build Coastguard Workerimport warnings 15*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Dict, List 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Workerimport torch 18*da0073e9SAndroid Build Coastguard Workerimport torch.cuda 19*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn 20*da0073e9SAndroid Build Coastguard Workerimport torch.utils.cpp_extension 21*da0073e9SAndroid Build Coastguard Workerimport torch.utils.data 22*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd._functions.utils import check_onnx_broadcast 23*da0073e9SAndroid Build Coastguard Workerfrom torch.onnx.symbolic_opset9 import _prepare_onnx_paddings 24*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_MULTIGPU 25*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 26*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 27*da0073e9SAndroid Build Coastguard Worker onlyCPU, 28*da0073e9SAndroid Build Coastguard Worker ops, 29*da0073e9SAndroid Build Coastguard Worker) 30*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations 
import op_db 31*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( # type: ignore[attr-defined] 32*da0073e9SAndroid Build Coastguard Worker IS_FBCODE, 33*da0073e9SAndroid Build Coastguard Worker IS_SANDCASTLE, 34*da0073e9SAndroid Build Coastguard Worker IS_WINDOWS, 35*da0073e9SAndroid Build Coastguard Worker load_tests, 36*da0073e9SAndroid Build Coastguard Worker) 37*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._device import set_device 38*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_all_only, tree_any 39*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._traceback import ( 40*da0073e9SAndroid Build Coastguard Worker CapturedTraceback, 41*da0073e9SAndroid Build Coastguard Worker format_traceback_short, 42*da0073e9SAndroid Build Coastguard Worker report_compile_source_on_error, 43*da0073e9SAndroid Build Coastguard Worker) 44*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.checkpoint import ( 45*da0073e9SAndroid Build Coastguard Worker _infer_device_type, 46*da0073e9SAndroid Build Coastguard Worker checkpoint, 47*da0073e9SAndroid Build Coastguard Worker checkpoint_sequential, 48*da0073e9SAndroid Build Coastguard Worker get_device_states, 49*da0073e9SAndroid Build Coastguard Worker) 50*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data import DataLoader 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for 54*da0073e9SAndroid Build Coastguard Worker# sharding on sandcastle. 
This line silences flake warnings 55*da0073e9SAndroid Build Coastguard Workerload_tests = load_tests 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard WorkerHAS_CUDA = torch.cuda.is_available() 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Workerclass RandomDatasetMock(torch.utils.data.Dataset): 64*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, index): 65*da0073e9SAndroid Build Coastguard Worker return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)]) 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker def __len__(self): 68*da0073e9SAndroid Build Coastguard Worker return 1000 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Workerclass TestCheckpoint(TestCase): 72*da0073e9SAndroid Build Coastguard Worker # This runs checkpoint_sequential on each of the nets in 73*da0073e9SAndroid Build Coastguard Worker # module_lists_to_compare, and compares them against the uncheckpointed model. 
74*da0073e9SAndroid Build Coastguard Worker # To compare, it checks outputs as well as input gradients and parameter gradients 75*da0073e9SAndroid Build Coastguard Worker def _check_checkpoint_sequential( 76*da0073e9SAndroid Build Coastguard Worker self, 77*da0073e9SAndroid Build Coastguard Worker model, 78*da0073e9SAndroid Build Coastguard Worker module_lists_to_compare, 79*da0073e9SAndroid Build Coastguard Worker num_chunks, 80*da0073e9SAndroid Build Coastguard Worker input, 81*da0073e9SAndroid Build Coastguard Worker use_reentrant, 82*da0073e9SAndroid Build Coastguard Worker ): 83*da0073e9SAndroid Build Coastguard Worker # not checkpointed 84*da0073e9SAndroid Build Coastguard Worker out = model(input) 85*da0073e9SAndroid Build Coastguard Worker out_not_checkpointed = out.detach().clone() 86*da0073e9SAndroid Build Coastguard Worker model.zero_grad() 87*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 88*da0073e9SAndroid Build Coastguard Worker grad_not_checkpointed = { 89*da0073e9SAndroid Build Coastguard Worker name: param.grad.detach().clone() 90*da0073e9SAndroid Build Coastguard Worker for name, param in model.named_parameters() 91*da0073e9SAndroid Build Coastguard Worker } 92*da0073e9SAndroid Build Coastguard Worker input_grad_not_checkpointed = input.grad.detach().clone() 93*da0073e9SAndroid Build Coastguard Worker for model_to_compare in module_lists_to_compare: 94*da0073e9SAndroid Build Coastguard Worker # checkpointed model by passing list of modules 95*da0073e9SAndroid Build Coastguard Worker detached = input.detach() 96*da0073e9SAndroid Build Coastguard Worker detached.requires_grad = True 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker # pass list of modules to checkpoint 99*da0073e9SAndroid Build Coastguard Worker out = checkpoint_sequential( 100*da0073e9SAndroid Build Coastguard Worker model_to_compare, num_chunks, detached, use_reentrant=use_reentrant 101*da0073e9SAndroid Build Coastguard Worker ) 
102*da0073e9SAndroid Build Coastguard Worker out_checkpointed = out.detach().clone() 103*da0073e9SAndroid Build Coastguard Worker model.zero_grad() 104*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 105*da0073e9SAndroid Build Coastguard Worker grad_checkpointed = { 106*da0073e9SAndroid Build Coastguard Worker name: param.grad.detach().clone() 107*da0073e9SAndroid Build Coastguard Worker for name, param in model.named_parameters() 108*da0073e9SAndroid Build Coastguard Worker } 109*da0073e9SAndroid Build Coastguard Worker input_grad_checkpointed = detached.grad.detach().clone() 110*da0073e9SAndroid Build Coastguard Worker # compare outputs as well as the gradients of input and parameters 111*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_checkpointed, out_not_checkpointed) 112*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input_grad_not_checkpointed, input_grad_checkpointed) 113*da0073e9SAndroid Build Coastguard Worker for name in grad_checkpointed: 114*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name]) 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker # Test whether checkpoint is being triggered or not. 
For this, we check 117*da0073e9SAndroid Build Coastguard Worker # the number of times forward pass happens 118*da0073e9SAndroid Build Coastguard Worker def test_checkpoint_trigger(self): 119*da0073e9SAndroid Build Coastguard Worker class Net(nn.Module): 120*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 121*da0073e9SAndroid Build Coastguard Worker super().__init__() 122*da0073e9SAndroid Build Coastguard Worker self.counter = 0 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker def forward(self, input_var): 125*da0073e9SAndroid Build Coastguard Worker self.counter += 1 126*da0073e9SAndroid Build Coastguard Worker # For reentrant, need to have autograd actually 127*da0073e9SAndroid Build Coastguard Worker # pack a tensor to trigger recomp 128*da0073e9SAndroid Build Coastguard Worker ret = input_var * torch.tensor(2.0) 129*da0073e9SAndroid Build Coastguard Worker return ret 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker # checkpointed 132*da0073e9SAndroid Build Coastguard Worker for use_reentrant in [True, False]: 133*da0073e9SAndroid Build Coastguard Worker with self.subTest(use_reentrant=use_reentrant): 134*da0073e9SAndroid Build Coastguard Worker modules = [Net() for _ in range(10)] 135*da0073e9SAndroid Build Coastguard Worker for m in modules: 136*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.counter, 0) 137*da0073e9SAndroid Build Coastguard Worker input_var = torch.randn(3, 4, requires_grad=True) 138*da0073e9SAndroid Build Coastguard Worker out = checkpoint_sequential( 139*da0073e9SAndroid Build Coastguard Worker modules, 2, input_var, use_reentrant=use_reentrant 140*da0073e9SAndroid Build Coastguard Worker ) 141*da0073e9SAndroid Build Coastguard Worker for m in modules: 142*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.counter, 1) 143*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 144*da0073e9SAndroid Build Coastguard 
Worker for m in modules[: (len(modules) // 2)]: 145*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.counter, 2) 146*da0073e9SAndroid Build Coastguard Worker for m in modules[(len(modules) // 2) :]: 147*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.counter, 1) 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Worker def test_checkpoint_valid(self): 150*da0073e9SAndroid Build Coastguard Worker model = nn.Sequential( 151*da0073e9SAndroid Build Coastguard Worker nn.Linear(100, 50), 152*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 153*da0073e9SAndroid Build Coastguard Worker nn.Linear(50, 20), 154*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 155*da0073e9SAndroid Build Coastguard Worker nn.Linear(20, 5), 156*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 157*da0073e9SAndroid Build Coastguard Worker ) 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker input_var = torch.randn(1, 100, requires_grad=True) 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker # checkpointed 162*da0073e9SAndroid Build Coastguard Worker chunks = 2 163*da0073e9SAndroid Build Coastguard Worker modules = list(model.children()) 164*da0073e9SAndroid Build Coastguard Worker out = checkpoint_sequential(modules, chunks, input_var, use_reentrant=True) 165*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 166*da0073e9SAndroid Build Coastguard Worker RuntimeError, "torch.utils.checkpoint is incompatible" 167*da0073e9SAndroid Build Coastguard Worker ): 168*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad( 169*da0073e9SAndroid Build Coastguard Worker outputs=[out], 170*da0073e9SAndroid Build Coastguard Worker grad_outputs=[torch.ones(1, 5)], 171*da0073e9SAndroid Build Coastguard Worker inputs=[input_var], 172*da0073e9SAndroid Build Coastguard Worker create_graph=True, 173*da0073e9SAndroid Build Coastguard Worker ) 
174*da0073e9SAndroid Build Coastguard Worker # works with use_reentrant=False, and grads are the same 175*da0073e9SAndroid Build Coastguard Worker out = model(input_var) 176*da0073e9SAndroid Build Coastguard Worker grads_no_checkpoint = torch.autograd.grad( 177*da0073e9SAndroid Build Coastguard Worker outputs=[out], 178*da0073e9SAndroid Build Coastguard Worker grad_outputs=[torch.ones(1, 5)], 179*da0073e9SAndroid Build Coastguard Worker inputs=[input_var], 180*da0073e9SAndroid Build Coastguard Worker create_graph=True, 181*da0073e9SAndroid Build Coastguard Worker ) 182*da0073e9SAndroid Build Coastguard Worker out_checkpoint = checkpoint_sequential( 183*da0073e9SAndroid Build Coastguard Worker modules, chunks, input_var, use_reentrant=False 184*da0073e9SAndroid Build Coastguard Worker ) 185*da0073e9SAndroid Build Coastguard Worker # check outputs are the same 186*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_checkpoint, out) 187*da0073e9SAndroid Build Coastguard Worker grads_checkpoint = torch.autograd.grad( 188*da0073e9SAndroid Build Coastguard Worker outputs=[out_checkpoint], 189*da0073e9SAndroid Build Coastguard Worker grad_outputs=[torch.ones(1, 5)], 190*da0073e9SAndroid Build Coastguard Worker inputs=[input_var], 191*da0073e9SAndroid Build Coastguard Worker create_graph=True, 192*da0073e9SAndroid Build Coastguard Worker ) 193*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grads_no_checkpoint, grads_checkpoint) 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker def test_checkpoint(self): 196*da0073e9SAndroid Build Coastguard Worker for use_reentrant in [True, False]: 197*da0073e9SAndroid Build Coastguard Worker with self.subTest(use_reentrant=use_reentrant): 198*da0073e9SAndroid Build Coastguard Worker model = nn.Sequential( 199*da0073e9SAndroid Build Coastguard Worker nn.Linear(100, 50), 200*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 201*da0073e9SAndroid Build Coastguard Worker 
nn.Linear(50, 20), 202*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 203*da0073e9SAndroid Build Coastguard Worker nn.Linear(20, 5), 204*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 205*da0073e9SAndroid Build Coastguard Worker ) 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker # Compare uncheckpointed model with its checkpointed counterparts 208*da0073e9SAndroid Build Coastguard Worker # In addition to running checkpoint_sequential on the nn.Sequential 209*da0073e9SAndroid Build Coastguard Worker # instance, we also run the function on the list of functions within 210*da0073e9SAndroid Build Coastguard Worker # the module. 211*da0073e9SAndroid Build Coastguard Worker self._check_checkpoint_sequential( 212*da0073e9SAndroid Build Coastguard Worker model, 213*da0073e9SAndroid Build Coastguard Worker [list(model.children()), model], 214*da0073e9SAndroid Build Coastguard Worker 2, 215*da0073e9SAndroid Build Coastguard Worker torch.randn(1, 100, requires_grad=True), 216*da0073e9SAndroid Build Coastguard Worker use_reentrant=use_reentrant, 217*da0073e9SAndroid Build Coastguard Worker ) 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker def test_checkpoint_module_list(self): 220*da0073e9SAndroid Build Coastguard Worker class ModuleListNet(nn.Module): 221*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 222*da0073e9SAndroid Build Coastguard Worker super().__init__() 223*da0073e9SAndroid Build Coastguard Worker module_list = [ 224*da0073e9SAndroid Build Coastguard Worker nn.Linear(100, 50), 225*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 226*da0073e9SAndroid Build Coastguard Worker nn.Linear(50, 20), 227*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 228*da0073e9SAndroid Build Coastguard Worker nn.Linear(20, 5), 229*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 230*da0073e9SAndroid Build Coastguard Worker ] 231*da0073e9SAndroid Build Coastguard 
Worker self.module_list = nn.ModuleList(module_list) 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 234*da0073e9SAndroid Build Coastguard Worker for layer in self.module_list: 235*da0073e9SAndroid Build Coastguard Worker input = layer(input) 236*da0073e9SAndroid Build Coastguard Worker return input 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker for use_reentrant in [True, False]: 239*da0073e9SAndroid Build Coastguard Worker with self.subTest(use_reentrant=use_reentrant): 240*da0073e9SAndroid Build Coastguard Worker model = ModuleListNet() 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker # Compare uncheckpointed model with its checkpointed counterparts. 243*da0073e9SAndroid Build Coastguard Worker self._check_checkpoint_sequential( 244*da0073e9SAndroid Build Coastguard Worker model, 245*da0073e9SAndroid Build Coastguard Worker [list(model.module_list.children()), model.module_list], 246*da0073e9SAndroid Build Coastguard Worker 2, 247*da0073e9SAndroid Build Coastguard Worker torch.randn(1, 100, requires_grad=True), 248*da0073e9SAndroid Build Coastguard Worker use_reentrant=use_reentrant, 249*da0073e9SAndroid Build Coastguard Worker ) 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker def test_checkpoint_sequential_deprecated_multiple_args(self): 252*da0073e9SAndroid Build Coastguard Worker class Two(nn.Module): 253*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 254*da0073e9SAndroid Build Coastguard Worker return a, b 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker model = nn.Sequential(Two()) 257*da0073e9SAndroid Build Coastguard Worker a = torch.randn(1, 100, requires_grad=True) 258*da0073e9SAndroid Build Coastguard Worker b = torch.randn(1, 100, requires_grad=True) 259*da0073e9SAndroid Build Coastguard Worker 
260*da0073e9SAndroid Build Coastguard Worker for use_reentrant in [True, False]: 261*da0073e9SAndroid Build Coastguard Worker with self.subTest(use_reentrant=use_reentrant): 262*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 263*da0073e9SAndroid Build Coastguard Worker checkpoint_sequential(model, 1, a, b) # type: ignore[call-arg] 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker def test_checkpoint_sequential_deprecated_no_args(self): 266*da0073e9SAndroid Build Coastguard Worker class Noop(nn.Module): 267*da0073e9SAndroid Build Coastguard Worker def forward(self): 268*da0073e9SAndroid Build Coastguard Worker pass 269*da0073e9SAndroid Build Coastguard Worker 270*da0073e9SAndroid Build Coastguard Worker model = nn.Sequential(Noop()) 271*da0073e9SAndroid Build Coastguard Worker for use_reentrant in [True, False]: 272*da0073e9SAndroid Build Coastguard Worker with self.subTest(use_reentrant=use_reentrant): 273*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 274*da0073e9SAndroid Build Coastguard Worker checkpoint_sequential(model, 1) # type: ignore[call-arg] 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker def test_checkpoint_rng_cpu(self): 277*da0073e9SAndroid Build Coastguard Worker for _ in range(5): 278*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(20000, device="cpu").requires_grad_() 279*da0073e9SAndroid Build Coastguard Worker phase1 = torch.nn.Dropout() 280*da0073e9SAndroid Build Coastguard Worker phase2 = torch.nn.Dropout() 281*da0073e9SAndroid Build Coastguard Worker 282*da0073e9SAndroid Build Coastguard Worker def run_fn(input): 283*da0073e9SAndroid Build Coastguard Worker return phase2(input) 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker state = torch.get_rng_state() 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Worker out = 
phase1(inp) 288*da0073e9SAndroid Build Coastguard Worker out = checkpoint(run_fn, out, use_reentrant=True) 289*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 290*da0073e9SAndroid Build Coastguard Worker grad_with_checkpointing = inp.grad 291*da0073e9SAndroid Build Coastguard Worker 292*da0073e9SAndroid Build Coastguard Worker torch.set_rng_state(state) 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker inp.grad = None 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid Build Coastguard Worker out = phase1(inp) 297*da0073e9SAndroid Build Coastguard Worker out = run_fn(out) 298*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 299*da0073e9SAndroid Build Coastguard Worker grad_no_checkpointing = inp.grad 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_with_checkpointing, grad_no_checkpointing) 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not HAS_CUDA, "No CUDA") 304*da0073e9SAndroid Build Coastguard Worker def test_checkpoint_rng_cuda(self): 305*da0073e9SAndroid Build Coastguard Worker for _ in range(5): 306*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(20000, device="cuda").requires_grad_() 307*da0073e9SAndroid Build Coastguard Worker phase1 = torch.nn.Dropout() 308*da0073e9SAndroid Build Coastguard Worker phase2 = torch.nn.Dropout() 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker def run_fn(input): 311*da0073e9SAndroid Build Coastguard Worker return phase2(input) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker state = torch.cuda.get_rng_state() 314*da0073e9SAndroid Build Coastguard Worker 315*da0073e9SAndroid Build Coastguard Worker out = phase1(inp) 316*da0073e9SAndroid Build Coastguard Worker out = checkpoint(run_fn, out, use_reentrant=True) 317*da0073e9SAndroid 
Build Coastguard Worker out.sum().backward() 318*da0073e9SAndroid Build Coastguard Worker grad_with_checkpointing = inp.grad 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Worker torch.cuda.set_rng_state(state) 321*da0073e9SAndroid Build Coastguard Worker 322*da0073e9SAndroid Build Coastguard Worker inp.grad = None 323*da0073e9SAndroid Build Coastguard Worker 324*da0073e9SAndroid Build Coastguard Worker out = phase1(inp) 325*da0073e9SAndroid Build Coastguard Worker out = run_fn(out) 326*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 327*da0073e9SAndroid Build Coastguard Worker grad_no_checkpointing = inp.grad 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_with_checkpointing, grad_no_checkpointing) 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not HAS_CUDA, "No CUDA") 332*da0073e9SAndroid Build Coastguard Worker def test_checkpoint_not_preserve_rng_state_and_without_reentrant(self): 333*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(2, device="cuda").requires_grad_() 334*da0073e9SAndroid Build Coastguard Worker layer = torch.nn.Dropout() 335*da0073e9SAndroid Build Coastguard Worker 336*da0073e9SAndroid Build Coastguard Worker def run_fn(input): 337*da0073e9SAndroid Build Coastguard Worker return layer(input) 338*da0073e9SAndroid Build Coastguard Worker 339*da0073e9SAndroid Build Coastguard Worker out = checkpoint(run_fn, inp, use_reentrant=False, preserve_rng_state=False) 340*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 341*da0073e9SAndroid Build Coastguard Worker # This should run without error 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Worker def test_checkpoint_non_tensor(self): 344*da0073e9SAndroid Build Coastguard Worker def run_fn(tensor1, tensor2): 345*da0073e9SAndroid Build Coastguard Worker if tensor2 is None: 
346*da0073e9SAndroid Build Coastguard Worker return tensor1 347*da0073e9SAndroid Build Coastguard Worker return tensor1 + tensor2 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker input_var = torch.randn(1, 100, requires_grad=True) 350*da0073e9SAndroid Build Coastguard Worker out = checkpoint(run_fn, input_var, None, use_reentrant=True) 351*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 352*da0073e9SAndroid Build Coastguard Worker 353*da0073e9SAndroid Build Coastguard Worker def test_checkpoint_non_tensor_inputs_outputs(self): 354*da0073e9SAndroid Build Coastguard Worker def foo(t1, t2, scale, t3): 355*da0073e9SAndroid Build Coastguard Worker t4 = t1 + t2 * t3 356*da0073e9SAndroid Build Coastguard Worker t5 = t1 * t2 + t3 357*da0073e9SAndroid Build Coastguard Worker t4 *= scale 358*da0073e9SAndroid Build Coastguard Worker t5 *= scale 359*da0073e9SAndroid Build Coastguard Worker return scale, t4, None, True, t5, "bar", t1 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker t1 = torch.rand(10, requires_grad=True) 362*da0073e9SAndroid Build Coastguard Worker t2 = torch.rand(10, requires_grad=True) 363*da0073e9SAndroid Build Coastguard Worker t3 = torch.rand(10) 364*da0073e9SAndroid Build Coastguard Worker scale = random.randint(0, 10) 365*da0073e9SAndroid Build Coastguard Worker res = checkpoint(foo, t1, t2, scale, t3, use_reentrant=True) 366*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scale, res[0]) 367*da0073e9SAndroid Build Coastguard Worker self.assertEqual((t1 + t2 * t3) * scale, res[1]) 368*da0073e9SAndroid Build Coastguard Worker self.assertEqual(None, res[2]) 369*da0073e9SAndroid Build Coastguard Worker self.assertEqual(True, res[3]) 370*da0073e9SAndroid Build Coastguard Worker self.assertEqual((t1 * t2 + t3) * scale, res[4]) 371*da0073e9SAndroid Build Coastguard Worker self.assertEqual("bar", res[5]) 372*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(t1, res[6]) 373*da0073e9SAndroid Build Coastguard Worker 374*da0073e9SAndroid Build Coastguard Worker # Validate running backward. 375*da0073e9SAndroid Build Coastguard Worker res[1].sum().backward(retain_graph=True) 376*da0073e9SAndroid Build Coastguard Worker res[4].sum().backward(retain_graph=True) 377*da0073e9SAndroid Build Coastguard Worker res[6].sum().backward() 378*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 379*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Trying to backward through the graph a second time" 380*da0073e9SAndroid Build Coastguard Worker ): 381*da0073e9SAndroid Build Coastguard Worker res[6].sum().backward() 382*da0073e9SAndroid Build Coastguard Worker t1_grad = t1.grad 383*da0073e9SAndroid Build Coastguard Worker t2_grad = t2.grad 384*da0073e9SAndroid Build Coastguard Worker 385*da0073e9SAndroid Build Coastguard Worker # Reset grads, run without checkpoint and validate we receive same grads. 386*da0073e9SAndroid Build Coastguard Worker t1.grad = None 387*da0073e9SAndroid Build Coastguard Worker t2.grad = None 388*da0073e9SAndroid Build Coastguard Worker res = foo(t1, t2, scale, t3) 389*da0073e9SAndroid Build Coastguard Worker torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()]) 390*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.grad, t1_grad) 391*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t2.grad, t2_grad) 392*da0073e9SAndroid Build Coastguard Worker 393*da0073e9SAndroid Build Coastguard Worker def test_checkpoint_no_tensors(self): 394*da0073e9SAndroid Build Coastguard Worker def foo(t1, t2, scale, t3): 395*da0073e9SAndroid Build Coastguard Worker t4 = t1 + t2 * t3 396*da0073e9SAndroid Build Coastguard Worker t5 = t1 * t2 + t3 397*da0073e9SAndroid Build Coastguard Worker t4 *= scale 398*da0073e9SAndroid Build Coastguard Worker t5 *= scale 399*da0073e9SAndroid Build Coastguard Worker return scale, t4, None, True, t5, "bar", t1 400*da0073e9SAndroid 
Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker t1 = random.random() 402*da0073e9SAndroid Build Coastguard Worker t2 = random.random() 403*da0073e9SAndroid Build Coastguard Worker t3 = random.random() 404*da0073e9SAndroid Build Coastguard Worker scale = random.randint(0, 10) 405*da0073e9SAndroid Build Coastguard Worker res = checkpoint(foo, t1, t2, scale, t3, use_reentrant=True) 406*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scale, res[0]) 407*da0073e9SAndroid Build Coastguard Worker self.assertEqual((t1 + t2 * t3) * scale, res[1]) 408*da0073e9SAndroid Build Coastguard Worker self.assertEqual(None, res[2]) 409*da0073e9SAndroid Build Coastguard Worker self.assertEqual(True, res[3]) 410*da0073e9SAndroid Build Coastguard Worker self.assertEqual((t1 * t2 + t3) * scale, res[4]) 411*da0073e9SAndroid Build Coastguard Worker self.assertEqual("bar", res[5]) 412*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1, res[6]) 413*da0073e9SAndroid Build Coastguard Worker 414*da0073e9SAndroid Build Coastguard Worker def test_checkpoint_partial_grad(self): 415*da0073e9SAndroid Build Coastguard Worker def run_fn(tensor1, tensor2): 416*da0073e9SAndroid Build Coastguard Worker # tensor 2 is used for other application logic 417*da0073e9SAndroid Build Coastguard Worker return tensor1, tensor2 418*da0073e9SAndroid Build Coastguard Worker 419*da0073e9SAndroid Build Coastguard Worker input_var = torch.randn(1, 4, requires_grad=True) 420*da0073e9SAndroid Build Coastguard Worker input_var2 = torch.randn(1, 4, requires_grad=False) 421*da0073e9SAndroid Build Coastguard Worker out = checkpoint(run_fn, input_var, input_var2, use_reentrant=True) 422*da0073e9SAndroid Build Coastguard Worker out[0].sum().backward() 423*da0073e9SAndroid Build Coastguard Worker 424*da0073e9SAndroid Build Coastguard Worker def run_fn2(tensor1, tensor2): 425*da0073e9SAndroid Build Coastguard Worker return tensor1 426*da0073e9SAndroid Build Coastguard Worker 
427*da0073e9SAndroid Build Coastguard Worker input_var = torch.randn(1, 4, requires_grad=False) 428*da0073e9SAndroid Build Coastguard Worker input_var2 = torch.randn(1, 4, requires_grad=True) 429*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 430*da0073e9SAndroid Build Coastguard Worker RuntimeError, 431*da0073e9SAndroid Build Coastguard Worker r"none of output has requires_grad=True, this checkpoint\(\) is not necessary", 432*da0073e9SAndroid Build Coastguard Worker ): 433*da0073e9SAndroid Build Coastguard Worker out = checkpoint(run_fn2, input_var, input_var2, use_reentrant=True) 434*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 435*da0073e9SAndroid Build Coastguard Worker 436*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") 437*da0073e9SAndroid Build Coastguard Worker def test_checkpointing_without_reentrant_early_free(self): 438*da0073e9SAndroid Build Coastguard Worker # I don't know how to check if the temporary saved variable buffer 439*da0073e9SAndroid Build Coastguard Worker # get de-allocated directly. 
So using cuda memory usage as a proxy 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker def _do_test(fn, should_free): 442*da0073e9SAndroid Build Coastguard Worker stats: List[int] = [] 443*da0073e9SAndroid Build Coastguard Worker 444*da0073e9SAndroid Build Coastguard Worker def track(x, idx): 445*da0073e9SAndroid Build Coastguard Worker # Track that at each step of the backward, some Tensor were 446*da0073e9SAndroid Build Coastguard Worker # de-allocated (which correspond to the checkpoint storage being 447*da0073e9SAndroid Build Coastguard Worker # emptied at each step) 448*da0073e9SAndroid Build Coastguard Worker def hook(_unused): 449*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(stats), idx) 450*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize() 451*da0073e9SAndroid Build Coastguard Worker stats.append(torch.cuda.memory_allocated()) 452*da0073e9SAndroid Build Coastguard Worker if idx > 0: 453*da0073e9SAndroid Build Coastguard Worker if should_free: 454*da0073e9SAndroid Build Coastguard Worker self.assertLess(stats[idx], stats[idx - 1]) 455*da0073e9SAndroid Build Coastguard Worker else: 456*da0073e9SAndroid Build Coastguard Worker self.assertEqual(stats[idx], stats[idx - 1]) 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker x.register_hook(hook) 459*da0073e9SAndroid Build Coastguard Worker 460*da0073e9SAndroid Build Coastguard Worker def test_fn(x): 461*da0073e9SAndroid Build Coastguard Worker # The main property of this function is that it contains multiple 462*da0073e9SAndroid Build Coastguard Worker # operations that save gradients in a chain. 
463*da0073e9SAndroid Build Coastguard Worker x = x**2 464*da0073e9SAndroid Build Coastguard Worker track(x, 2) 465*da0073e9SAndroid Build Coastguard Worker x = x**2 466*da0073e9SAndroid Build Coastguard Worker track(x, 1) 467*da0073e9SAndroid Build Coastguard Worker x = x**2 468*da0073e9SAndroid Build Coastguard Worker track(x, 0) 469*da0073e9SAndroid Build Coastguard Worker x = x**2 470*da0073e9SAndroid Build Coastguard Worker return x.sum() 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker fn(test_fn) 473*da0073e9SAndroid Build Coastguard Worker 474*da0073e9SAndroid Build Coastguard Worker return stats 475*da0073e9SAndroid Build Coastguard Worker 476*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(10, device="cuda", requires_grad=True) 477*da0073e9SAndroid Build Coastguard Worker x.grad = torch.zeros_like(x) 478*da0073e9SAndroid Build Coastguard Worker 479*da0073e9SAndroid Build Coastguard Worker # In a regular backward, buffers get eagerly freed 480*da0073e9SAndroid Build Coastguard Worker non_retain_stats = _do_test(lambda fn: fn(x).backward(), True) 481*da0073e9SAndroid Build Coastguard Worker 482*da0073e9SAndroid Build Coastguard Worker # In a retain_grad backward, buffers get preserved 483*da0073e9SAndroid Build Coastguard Worker _unused_retain_stats = _do_test( 484*da0073e9SAndroid Build Coastguard Worker lambda fn: fn(x).backward(retain_graph=True), False 485*da0073e9SAndroid Build Coastguard Worker ) 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker # In a regular backward with checkpoint, buffers get eagerly freed 488*da0073e9SAndroid Build Coastguard Worker checkpoint_non_retain_stats = _do_test( 489*da0073e9SAndroid Build Coastguard Worker lambda fn: checkpoint(fn, x, use_reentrant=False).backward(), True 490*da0073e9SAndroid Build Coastguard Worker ) 491*da0073e9SAndroid Build Coastguard Worker 492*da0073e9SAndroid Build Coastguard Worker # In a retain_grad 
backward with checkpoint, buffers get eagerly freed 493*da0073e9SAndroid Build Coastguard Worker checkpoint_retain_stats = _do_test( 494*da0073e9SAndroid Build Coastguard Worker lambda fn: checkpoint(fn, x, use_reentrant=False).backward( 495*da0073e9SAndroid Build Coastguard Worker retain_graph=True 496*da0073e9SAndroid Build Coastguard Worker ), 497*da0073e9SAndroid Build Coastguard Worker True, 498*da0073e9SAndroid Build Coastguard Worker ) 499*da0073e9SAndroid Build Coastguard Worker 500*da0073e9SAndroid Build Coastguard Worker self.assertEqual(non_retain_stats, checkpoint_non_retain_stats) 501*da0073e9SAndroid Build Coastguard Worker self.assertEqual(non_retain_stats, checkpoint_retain_stats) 502*da0073e9SAndroid Build Coastguard Worker 503*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") 504*da0073e9SAndroid Build Coastguard Worker def test_get_device_states_recursive(self): 505*da0073e9SAndroid Build Coastguard Worker inp = { 506*da0073e9SAndroid Build Coastguard Worker "foo": torch.rand(10, device="cuda:0"), 507*da0073e9SAndroid Build Coastguard Worker "bar": [torch.rand(10, device="cuda:1")], 508*da0073e9SAndroid Build Coastguard Worker } 509*da0073e9SAndroid Build Coastguard Worker device_ids, device_states = get_device_states(inp) 510*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, len(device_ids)) 511*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, len(device_states)) 512*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, device_ids[0]) 513*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, device_ids[1]) 514*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(device_states[0], torch.Tensor)) 515*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(device_states[1], torch.Tensor)) 516*da0073e9SAndroid Build Coastguard Worker 517*da0073e9SAndroid Build Coastguard Worker def test_infer_device_state_recursive_meta(self): 
518*da0073e9SAndroid Build Coastguard Worker inp = {"foo": torch.rand(10, device="meta")} 519*da0073e9SAndroid Build Coastguard Worker device_type = _infer_device_type(inp) 520*da0073e9SAndroid Build Coastguard Worker self.assertEqual("meta", device_type) 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") 523*da0073e9SAndroid Build Coastguard Worker def test_infer_device_state_recursive_multi_cuda(self): 524*da0073e9SAndroid Build Coastguard Worker # Check that no warning is issued for either cuda:0, cuda:1 or 525*da0073e9SAndroid Build Coastguard Worker # cuda:0, cuda:0 cases since they are both the same device type 526*da0073e9SAndroid Build Coastguard Worker inp = { 527*da0073e9SAndroid Build Coastguard Worker "foo": torch.rand(10, device="cuda:0"), 528*da0073e9SAndroid Build Coastguard Worker "bar": [torch.rand(10, device="cuda:1")], 529*da0073e9SAndroid Build Coastguard Worker } 530*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(): 531*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("error") 532*da0073e9SAndroid Build Coastguard Worker device_type = _infer_device_type(inp) 533*da0073e9SAndroid Build Coastguard Worker self.assertEqual("cuda", device_type) 534*da0073e9SAndroid Build Coastguard Worker inp = { 535*da0073e9SAndroid Build Coastguard Worker "foo": torch.rand(10, device="cuda:0"), 536*da0073e9SAndroid Build Coastguard Worker "bar": [torch.rand(10, device="cuda:0")], 537*da0073e9SAndroid Build Coastguard Worker } 538*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(): 539*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("error") 540*da0073e9SAndroid Build Coastguard Worker device_type = _infer_device_type(inp) 541*da0073e9SAndroid Build Coastguard Worker self.assertEqual("cuda", device_type) 542*da0073e9SAndroid Build Coastguard Worker # Check that a warning is issued for cuda:0, 
meta and that it includes 543*da0073e9SAndroid Build Coastguard Worker # device type information 544*da0073e9SAndroid Build Coastguard Worker inp = { 545*da0073e9SAndroid Build Coastguard Worker "foo": torch.rand(10, device="cuda:0"), 546*da0073e9SAndroid Build Coastguard Worker "bar": [torch.rand(10, device="meta")], 547*da0073e9SAndroid Build Coastguard Worker } 548*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 549*da0073e9SAndroid Build Coastguard Worker device_type = _infer_device_type(inp) 550*da0073e9SAndroid Build Coastguard Worker self.assertEqual("cuda", device_type) 551*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 1) 552*da0073e9SAndroid Build Coastguard Worker warning_msg = str(w[-1].message) 553*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 554*da0073e9SAndroid Build Coastguard Worker "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices" 555*da0073e9SAndroid Build Coastguard Worker in warning_msg 556*da0073e9SAndroid Build Coastguard Worker ) 557*da0073e9SAndroid Build Coastguard Worker self.assertTrue("Device types: ['cuda', 'meta']" in warning_msg) 558*da0073e9SAndroid Build Coastguard Worker self.assertTrue("first device type: cuda" in warning_msg) 559*da0073e9SAndroid Build Coastguard Worker 560*da0073e9SAndroid Build Coastguard Worker 561*da0073e9SAndroid Build Coastguard Workerclass TestDataLoaderUtils(TestCase): 562*da0073e9SAndroid Build Coastguard Worker MAX_TIMEOUT_IN_SECOND = 300 563*da0073e9SAndroid Build Coastguard Worker 564*da0073e9SAndroid Build Coastguard Worker def setUp(self): 565*da0073e9SAndroid Build Coastguard Worker super().setUp() 566*da0073e9SAndroid Build Coastguard Worker self.dataset = torch.randn(5, 3, 3, 2) 567*da0073e9SAndroid Build Coastguard Worker self.batch_size = 3 568*da0073e9SAndroid Build Coastguard Worker 569*da0073e9SAndroid Build Coastguard Worker def test_random_seed(self): 570*da0073e9SAndroid Build 
Coastguard Worker def run(): 571*da0073e9SAndroid Build Coastguard Worker dataloader = torch.utils.data.DataLoader( 572*da0073e9SAndroid Build Coastguard Worker RandomDatasetMock(), 573*da0073e9SAndroid Build Coastguard Worker batch_size=2, 574*da0073e9SAndroid Build Coastguard Worker num_workers=4, 575*da0073e9SAndroid Build Coastguard Worker shuffle=True, 576*da0073e9SAndroid Build Coastguard Worker timeout=self.MAX_TIMEOUT_IN_SECOND, 577*da0073e9SAndroid Build Coastguard Worker ) 578*da0073e9SAndroid Build Coastguard Worker return next(iter(dataloader)) 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(2018) 581*da0073e9SAndroid Build Coastguard Worker x1 = run() 582*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(2018) 583*da0073e9SAndroid Build Coastguard Worker x2 = run() 584*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x1, x2) 585*da0073e9SAndroid Build Coastguard Worker 586*da0073e9SAndroid Build Coastguard Worker def test_single_keep(self): 587*da0073e9SAndroid Build Coastguard Worker # self.dataset is a Tensor here; technically not a valid input because 588*da0073e9SAndroid Build Coastguard Worker # not a Dataset subclass, but needs to stay working so add ignore's 589*da0073e9SAndroid Build Coastguard Worker # for type checking with mypy 590*da0073e9SAndroid Build Coastguard Worker dataloader: DataLoader = DataLoader( 591*da0073e9SAndroid Build Coastguard Worker self.dataset, # type: ignore[arg-type] 592*da0073e9SAndroid Build Coastguard Worker batch_size=self.batch_size, 593*da0073e9SAndroid Build Coastguard Worker num_workers=0, 594*da0073e9SAndroid Build Coastguard Worker drop_last=False, 595*da0073e9SAndroid Build Coastguard Worker ) 596*da0073e9SAndroid Build Coastguard Worker dataiter = iter(dataloader) 597*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(dataiter)), 2) 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard 
Worker def test_single_drop(self): 600*da0073e9SAndroid Build Coastguard Worker dataloader: DataLoader = DataLoader( 601*da0073e9SAndroid Build Coastguard Worker self.dataset, # type: ignore[arg-type] 602*da0073e9SAndroid Build Coastguard Worker batch_size=self.batch_size, 603*da0073e9SAndroid Build Coastguard Worker num_workers=0, 604*da0073e9SAndroid Build Coastguard Worker drop_last=True, 605*da0073e9SAndroid Build Coastguard Worker ) 606*da0073e9SAndroid Build Coastguard Worker dataiter = iter(dataloader) 607*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(dataiter)), 1) 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker @unittest.skip( 610*da0073e9SAndroid Build Coastguard Worker "FIXME: Intermittent CUDA out-of-memory error on Windows and time-out under ASAN" 611*da0073e9SAndroid Build Coastguard Worker ) 612*da0073e9SAndroid Build Coastguard Worker def test_multi_keep(self): 613*da0073e9SAndroid Build Coastguard Worker dataloader: DataLoader = DataLoader( 614*da0073e9SAndroid Build Coastguard Worker self.dataset, # type: ignore[arg-type] 615*da0073e9SAndroid Build Coastguard Worker batch_size=self.batch_size, 616*da0073e9SAndroid Build Coastguard Worker num_workers=2, 617*da0073e9SAndroid Build Coastguard Worker drop_last=False, 618*da0073e9SAndroid Build Coastguard Worker timeout=self.MAX_TIMEOUT_IN_SECOND, 619*da0073e9SAndroid Build Coastguard Worker ) 620*da0073e9SAndroid Build Coastguard Worker dataiter = iter(dataloader) 621*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(dataiter)), 2) 622*da0073e9SAndroid Build Coastguard Worker 623*da0073e9SAndroid Build Coastguard Worker def test_multi_drop(self): 624*da0073e9SAndroid Build Coastguard Worker dataloader: DataLoader = DataLoader( 625*da0073e9SAndroid Build Coastguard Worker self.dataset, # type: ignore[arg-type] 626*da0073e9SAndroid Build Coastguard Worker batch_size=self.batch_size, 627*da0073e9SAndroid Build Coastguard 
Worker num_workers=2, 628*da0073e9SAndroid Build Coastguard Worker drop_last=True, 629*da0073e9SAndroid Build Coastguard Worker timeout=self.MAX_TIMEOUT_IN_SECOND, 630*da0073e9SAndroid Build Coastguard Worker ) 631*da0073e9SAndroid Build Coastguard Worker dataiter = iter(dataloader) 632*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(dataiter)), 1) 633*da0073e9SAndroid Build Coastguard Worker 634*da0073e9SAndroid Build Coastguard Worker 635*da0073e9SAndroid Build Coastguard Workertest_dir = os.path.abspath(os.path.dirname(str(__file__))) 636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker 638*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 639*da0073e9SAndroid Build Coastguard Worker "SKIP_TEST_BOTTLENECK" in os.environ.keys(), "SKIP_TEST_BOTTLENECK is set" 640*da0073e9SAndroid Build Coastguard Worker) 641*da0073e9SAndroid Build Coastguard Workerclass TestBottleneck(TestCase): 642*da0073e9SAndroid Build Coastguard Worker def _run(self, command, timeout=30): 643*da0073e9SAndroid Build Coastguard Worker """Returns (return-code, stdout, stderr)""" 644*da0073e9SAndroid Build Coastguard Worker import subprocess 645*da0073e9SAndroid Build Coastguard Worker 646*da0073e9SAndroid Build Coastguard Worker p = subprocess.Popen( 647*da0073e9SAndroid Build Coastguard Worker command, 648*da0073e9SAndroid Build Coastguard Worker stdout=subprocess.PIPE, 649*da0073e9SAndroid Build Coastguard Worker stderr=subprocess.PIPE, 650*da0073e9SAndroid Build Coastguard Worker shell=True, 651*da0073e9SAndroid Build Coastguard Worker ) 652*da0073e9SAndroid Build Coastguard Worker try: 653*da0073e9SAndroid Build Coastguard Worker output, err = p.communicate(timeout=timeout) 654*da0073e9SAndroid Build Coastguard Worker except subprocess.TimeoutExpired: 655*da0073e9SAndroid Build Coastguard Worker p.kill() 656*da0073e9SAndroid Build Coastguard Worker output, err = p.communicate() 657*da0073e9SAndroid Build Coastguard Worker rc = 
p.returncode 658*da0073e9SAndroid Build Coastguard Worker output_str = output.decode("ascii") 659*da0073e9SAndroid Build Coastguard Worker err_str = err.decode("ascii") 660*da0073e9SAndroid Build Coastguard Worker return (rc, output_str, err_str) 661*da0073e9SAndroid Build Coastguard Worker 662*da0073e9SAndroid Build Coastguard Worker def _run_bottleneck(self, test_file, scriptargs=""): 663*da0073e9SAndroid Build Coastguard Worker curdir = os.path.dirname(os.path.abspath(__file__)) 664*da0073e9SAndroid Build Coastguard Worker filepath = f"{curdir}/{test_file}" 665*da0073e9SAndroid Build Coastguard Worker if scriptargs != "": 666*da0073e9SAndroid Build Coastguard Worker scriptargs = f" {scriptargs}" 667*da0073e9SAndroid Build Coastguard Worker rc, out, err = self._run( 668*da0073e9SAndroid Build Coastguard Worker f"{sys.executable} -m torch.utils.bottleneck {filepath}{scriptargs}" 669*da0073e9SAndroid Build Coastguard Worker ) 670*da0073e9SAndroid Build Coastguard Worker return rc, out, err 671*da0073e9SAndroid Build Coastguard Worker 672*da0073e9SAndroid Build Coastguard Worker def _check_run_args(self): 673*da0073e9SAndroid Build Coastguard Worker # Check that this fails due to missing args 674*da0073e9SAndroid Build Coastguard Worker rc, out, err = self._run_bottleneck("bottleneck_test/test_args.py") 675*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 676*da0073e9SAndroid Build Coastguard Worker rc, 677*da0073e9SAndroid Build Coastguard Worker 2, 678*da0073e9SAndroid Build Coastguard Worker atol=0, 679*da0073e9SAndroid Build Coastguard Worker rtol=0, 680*da0073e9SAndroid Build Coastguard Worker msg=self._fail_msg("Missing args should error", out + err), 681*da0073e9SAndroid Build Coastguard Worker ) 682*da0073e9SAndroid Build Coastguard Worker 683*da0073e9SAndroid Build Coastguard Worker # This should succeed 684*da0073e9SAndroid Build Coastguard Worker rc, out, err = self._run_bottleneck( 685*da0073e9SAndroid Build Coastguard Worker 
"bottleneck_test/test_args.py", "--foo foo --bar bar" 686*da0073e9SAndroid Build Coastguard Worker ) 687*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 688*da0073e9SAndroid Build Coastguard Worker rc, 689*da0073e9SAndroid Build Coastguard Worker 0, 690*da0073e9SAndroid Build Coastguard Worker atol=0, 691*da0073e9SAndroid Build Coastguard Worker rtol=0, 692*da0073e9SAndroid Build Coastguard Worker msg=self._fail_msg("Should pass args to script", out + err), 693*da0073e9SAndroid Build Coastguard Worker ) 694*da0073e9SAndroid Build Coastguard Worker 695*da0073e9SAndroid Build Coastguard Worker def _fail_msg(self, msg, output): 696*da0073e9SAndroid Build Coastguard Worker return f"{msg}, output was:\n{output}" 697*da0073e9SAndroid Build Coastguard Worker 698*da0073e9SAndroid Build Coastguard Worker def _check_environment_summary(self, output): 699*da0073e9SAndroid Build Coastguard Worker results = re.search("Environment Summary", output) 700*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone( 701*da0073e9SAndroid Build Coastguard Worker results, self._fail_msg("Should have Environment Summary", output) 702*da0073e9SAndroid Build Coastguard Worker ) 703*da0073e9SAndroid Build Coastguard Worker 704*da0073e9SAndroid Build Coastguard Worker # Up to five lines away from the heading, there should be the version number 705*da0073e9SAndroid Build Coastguard Worker results = re.search( 706*da0073e9SAndroid Build Coastguard Worker r"Environment Summary.*(\n.*){,5}\nPyTorch \d+\.\d+", output 707*da0073e9SAndroid Build Coastguard Worker ) 708*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone( 709*da0073e9SAndroid Build Coastguard Worker results, self._fail_msg("Should have PyTorch version", output) 710*da0073e9SAndroid Build Coastguard Worker ) 711*da0073e9SAndroid Build Coastguard Worker 712*da0073e9SAndroid Build Coastguard Worker def _check_cprof_summary(self, output): 713*da0073e9SAndroid Build Coastguard Worker results = re.search("cProfile 
output", output) 714*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone( 715*da0073e9SAndroid Build Coastguard Worker results, self._fail_msg("Should have cProfile output", output) 716*da0073e9SAndroid Build Coastguard Worker ) 717*da0073e9SAndroid Build Coastguard Worker 718*da0073e9SAndroid Build Coastguard Worker # This assumes that after the cProfile output section we have 719*da0073e9SAndroid Build Coastguard Worker # the autograd profiler output 720*da0073e9SAndroid Build Coastguard Worker results = re.search( 721*da0073e9SAndroid Build Coastguard Worker r"cProfile output.*(\n.*){6,50}\n.*autograd profiler output", output 722*da0073e9SAndroid Build Coastguard Worker ) 723*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone( 724*da0073e9SAndroid Build Coastguard Worker results, 725*da0073e9SAndroid Build Coastguard Worker self._fail_msg( 726*da0073e9SAndroid Build Coastguard Worker "Distance between cProfile and autograd prof out not in [6, 50] lines", 727*da0073e9SAndroid Build Coastguard Worker output, 728*da0073e9SAndroid Build Coastguard Worker ), 729*da0073e9SAndroid Build Coastguard Worker ) 730*da0073e9SAndroid Build Coastguard Worker 731*da0073e9SAndroid Build Coastguard Worker def _check_autograd_summary(self, output): 732*da0073e9SAndroid Build Coastguard Worker results = re.search("autograd profiler output", output) 733*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone( 734*da0073e9SAndroid Build Coastguard Worker results, self._fail_msg("Should have autograd profiler output", output) 735*da0073e9SAndroid Build Coastguard Worker ) 736*da0073e9SAndroid Build Coastguard Worker 737*da0073e9SAndroid Build Coastguard Worker # This assumes that after the autograd profiler output is the end of the 738*da0073e9SAndroid Build Coastguard Worker # output. 
739*da0073e9SAndroid Build Coastguard Worker results = re.search(r"autograd profiler output.*(\n.*){6,100}", output) 740*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone( 741*da0073e9SAndroid Build Coastguard Worker results, 742*da0073e9SAndroid Build Coastguard Worker self._fail_msg( 743*da0073e9SAndroid Build Coastguard Worker "Distance between autograd prof output and end of output not in [6, 100] lines", 744*da0073e9SAndroid Build Coastguard Worker output, 745*da0073e9SAndroid Build Coastguard Worker ), 746*da0073e9SAndroid Build Coastguard Worker ) 747*da0073e9SAndroid Build Coastguard Worker 748*da0073e9SAndroid Build Coastguard Worker def _check_cuda(self, output): 749*da0073e9SAndroid Build Coastguard Worker if HAS_CUDA: 750*da0073e9SAndroid Build Coastguard Worker results = re.search("CUDA mode", output) 751*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone( 752*da0073e9SAndroid Build Coastguard Worker results, self._fail_msg("Should tell users CUDA", output) 753*da0073e9SAndroid Build Coastguard Worker ) 754*da0073e9SAndroid Build Coastguard Worker else: 755*da0073e9SAndroid Build Coastguard Worker results = re.search("CUDA mode", output) 756*da0073e9SAndroid Build Coastguard Worker self.assertIsNone( 757*da0073e9SAndroid Build Coastguard Worker results, self._fail_msg("Should not tell users about CUDA", output) 758*da0073e9SAndroid Build Coastguard Worker ) 759*da0073e9SAndroid Build Coastguard Worker 760*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(HAS_CUDA, "CPU-only test") 761*da0073e9SAndroid Build Coastguard Worker def test_bottleneck_cpu_only(self): 762*da0073e9SAndroid Build Coastguard Worker rc, out, err = self._run_bottleneck("bottleneck_test/test.py") 763*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rc, 0, msg=f"Run failed with\n{err}") 764*da0073e9SAndroid Build Coastguard Worker 765*da0073e9SAndroid Build Coastguard Worker self._check_run_args() 766*da0073e9SAndroid Build Coastguard Worker 
self._check_environment_summary(out) 767*da0073e9SAndroid Build Coastguard Worker self._check_autograd_summary(out) 768*da0073e9SAndroid Build Coastguard Worker self._check_cprof_summary(out) 769*da0073e9SAndroid Build Coastguard Worker self._check_cuda(out) 770*da0073e9SAndroid Build Coastguard Worker 771*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not HAS_CUDA, "No CUDA") 772*da0073e9SAndroid Build Coastguard Worker def test_bottleneck_cuda(self): 773*da0073e9SAndroid Build Coastguard Worker rc, out, err = self._run_bottleneck("bottleneck_test/test_cuda.py") 774*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rc, 0, msg=f"Run failed with\n{err}") 775*da0073e9SAndroid Build Coastguard Worker 776*da0073e9SAndroid Build Coastguard Worker self._check_run_args() 777*da0073e9SAndroid Build Coastguard Worker self._check_environment_summary(out) 778*da0073e9SAndroid Build Coastguard Worker self._check_autograd_summary(out) 779*da0073e9SAndroid Build Coastguard Worker self._check_cprof_summary(out) 780*da0073e9SAndroid Build Coastguard Worker self._check_cuda(out) 781*da0073e9SAndroid Build Coastguard Worker 782*da0073e9SAndroid Build Coastguard Worker 783*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.collect_env import get_pretty_env_info 784*da0073e9SAndroid Build Coastguard Worker 785*da0073e9SAndroid Build Coastguard Worker 786*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(IS_FBCODE, "runs pip which is not available internally") 787*da0073e9SAndroid Build Coastguard Workerclass TestCollectEnv(TestCase): 788*da0073e9SAndroid Build Coastguard Worker def test_smoke(self): 789*da0073e9SAndroid Build Coastguard Worker info_output = get_pretty_env_info() 790*da0073e9SAndroid Build Coastguard Worker self.assertTrue(info_output.count("\n") >= 17) 791*da0073e9SAndroid Build Coastguard Worker 792*da0073e9SAndroid Build Coastguard Worker 793*da0073e9SAndroid Build Coastguard Workerclass TestONNXUtils(TestCase): 794*da0073e9SAndroid 
Build Coastguard Worker def test_prepare_onnx_paddings(self): 795*da0073e9SAndroid Build Coastguard Worker sizes = [2, 3, 4] 796*da0073e9SAndroid Build Coastguard Worker pad = [1, 2, 3, 4] 797*da0073e9SAndroid Build Coastguard Worker paddings = _prepare_onnx_paddings(len(sizes), pad) 798*da0073e9SAndroid Build Coastguard Worker self.assertEqual(paddings, [0, 3, 1, 0, 4, 2]) 799*da0073e9SAndroid Build Coastguard Worker 800*da0073e9SAndroid Build Coastguard Worker def test_check_onnx_broadcast(self): 801*da0073e9SAndroid Build Coastguard Worker def try_check_onnx_broadcast(dims1, dims2, expect_broadcast, expect_fail): 802*da0073e9SAndroid Build Coastguard Worker broadcast = True 803*da0073e9SAndroid Build Coastguard Worker fail = False 804*da0073e9SAndroid Build Coastguard Worker try: 805*da0073e9SAndroid Build Coastguard Worker broadcast = check_onnx_broadcast(dims1, dims2) 806*da0073e9SAndroid Build Coastguard Worker except ValueError: 807*da0073e9SAndroid Build Coastguard Worker fail = True 808*da0073e9SAndroid Build Coastguard Worker self.assertEqual(broadcast, expect_broadcast) 809*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fail, expect_fail) 810*da0073e9SAndroid Build Coastguard Worker 811*da0073e9SAndroid Build Coastguard Worker # Case 1, check the case when len(dims1) < len(dims2) and numel(dims2) > 1 812*da0073e9SAndroid Build Coastguard Worker dims1 = [3, 4] 813*da0073e9SAndroid Build Coastguard Worker dims2 = [2, 3, 4] 814*da0073e9SAndroid Build Coastguard Worker try_check_onnx_broadcast(dims1, dims2, True, True) 815*da0073e9SAndroid Build Coastguard Worker 816*da0073e9SAndroid Build Coastguard Worker # Case 2, check the case when len(dims1) < len(dims2) and numel(dims2) == 1 817*da0073e9SAndroid Build Coastguard Worker dims1 = [3, 4] 818*da0073e9SAndroid Build Coastguard Worker dims2 = [1, 1, 1] 819*da0073e9SAndroid Build Coastguard Worker try_check_onnx_broadcast(dims1, dims2, True, False) 820*da0073e9SAndroid Build Coastguard Worker 
821*da0073e9SAndroid Build Coastguard Worker # Case 3, check the case when len(dims1) > len(dims2) and numel(dims2) == 1 822*da0073e9SAndroid Build Coastguard Worker dims1 = [1, 1] 823*da0073e9SAndroid Build Coastguard Worker dims2 = [1] 824*da0073e9SAndroid Build Coastguard Worker try_check_onnx_broadcast(dims1, dims2, True, False) 825*da0073e9SAndroid Build Coastguard Worker 826*da0073e9SAndroid Build Coastguard Worker # Case 4, check the case when len(dims1) > len(dims2) and dims1[x:] == dims2 827*da0073e9SAndroid Build Coastguard Worker dims1 = [2, 3, 4] 828*da0073e9SAndroid Build Coastguard Worker dims2 = [3, 4] 829*da0073e9SAndroid Build Coastguard Worker try_check_onnx_broadcast(dims1, dims2, True, False) 830*da0073e9SAndroid Build Coastguard Worker 831*da0073e9SAndroid Build Coastguard Worker # Case 5, check the case when len(dims1) > len(dims2), but dims1[x:] != dims2 832*da0073e9SAndroid Build Coastguard Worker dims1 = [2, 3, 4] 833*da0073e9SAndroid Build Coastguard Worker dims2 = [1, 4] 834*da0073e9SAndroid Build Coastguard Worker try_check_onnx_broadcast(dims1, dims2, True, True) 835*da0073e9SAndroid Build Coastguard Worker 836*da0073e9SAndroid Build Coastguard Worker # Case 6, check the equal case, no broadcast 837*da0073e9SAndroid Build Coastguard Worker dims1 = [3, 4] 838*da0073e9SAndroid Build Coastguard Worker dims2 = [3, 4] 839*da0073e9SAndroid Build Coastguard Worker try_check_onnx_broadcast(dims1, dims2, False, False) 840*da0073e9SAndroid Build Coastguard Worker 841*da0073e9SAndroid Build Coastguard Worker # Case 7, check the case when len(dims1) == len(dims2), but dims1 != dims2 842*da0073e9SAndroid Build Coastguard Worker dims1 = [3, 4] 843*da0073e9SAndroid Build Coastguard Worker dims2 = [1, 4] 844*da0073e9SAndroid Build Coastguard Worker try_check_onnx_broadcast(dims1, dims2, True, True) 845*da0073e9SAndroid Build Coastguard Worker 846*da0073e9SAndroid Build Coastguard Worker # Case 8, check the case when len(dims1) == len(dims2) and 
numel(s2) == 1 847*da0073e9SAndroid Build Coastguard Worker dims1 = [3, 4] 848*da0073e9SAndroid Build Coastguard Worker dims2 = [1, 1] 849*da0073e9SAndroid Build Coastguard Worker try_check_onnx_broadcast(dims1, dims2, True, False) 850*da0073e9SAndroid Build Coastguard Worker 851*da0073e9SAndroid Build Coastguard Worker 852*da0073e9SAndroid Build Coastguard Workerclass TestHipify(TestCase): 853*da0073e9SAndroid Build Coastguard Worker def test_import_hipify(self): 854*da0073e9SAndroid Build Coastguard Worker from torch.utils.hipify import hipify_python # noqa: F401 855*da0073e9SAndroid Build Coastguard Worker 856*da0073e9SAndroid Build Coastguard Worker 857*da0073e9SAndroid Build Coastguard Workerclass TestHipifyTrie(TestCase): 858*da0073e9SAndroid Build Coastguard Worker def setUp(self): 859*da0073e9SAndroid Build Coastguard Worker self.trie = torch.utils.hipify.hipify_python.Trie() 860*da0073e9SAndroid Build Coastguard Worker 861*da0073e9SAndroid Build Coastguard Worker def test_add_and_search_trie(self): 862*da0073e9SAndroid Build Coastguard Worker self.trie.add("banana") 863*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.trie.search("banana")) 864*da0073e9SAndroid Build Coastguard Worker self.assertFalse(self.trie.search("ban")) 865*da0073e9SAndroid Build Coastguard Worker self.assertFalse(self.trie.search("dog")) 866*da0073e9SAndroid Build Coastguard Worker 867*da0073e9SAndroid Build Coastguard Worker def test_add_multiple_and_search_trie(self): 868*da0073e9SAndroid Build Coastguard Worker words_to_add = ["banana", "apple", "orange"] 869*da0073e9SAndroid Build Coastguard Worker for word in words_to_add: 870*da0073e9SAndroid Build Coastguard Worker self.trie.add(word) 871*da0073e9SAndroid Build Coastguard Worker 872*da0073e9SAndroid Build Coastguard Worker for word in words_to_add: 873*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.trie.search(word)) 874*da0073e9SAndroid Build Coastguard Worker 875*da0073e9SAndroid Build Coastguard 
Worker for word in ["ban", "dog", "okay", "app"]: 876*da0073e9SAndroid Build Coastguard Worker self.assertFalse(self.trie.search(word)) 877*da0073e9SAndroid Build Coastguard Worker 878*da0073e9SAndroid Build Coastguard Worker def test_quote_escape(self): 879*da0073e9SAndroid Build Coastguard Worker orig_chars = ["*", "[", ".", "+", "a", "z", "-"] 880*da0073e9SAndroid Build Coastguard Worker quoted_strs = ["\\*", "\\[", "\\.", "\\+", "a", "z", "\\-"] 881*da0073e9SAndroid Build Coastguard Worker for i in range(len(orig_chars)): 882*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.trie.quote(orig_chars[i]), quoted_strs[i]) 883*da0073e9SAndroid Build Coastguard Worker 884*da0073e9SAndroid Build Coastguard Worker def test_export_trie_to_regex(self): 885*da0073e9SAndroid Build Coastguard Worker words_to_add = [ 886*da0073e9SAndroid Build Coastguard Worker "__CUDACC__", 887*da0073e9SAndroid Build Coastguard Worker "CUDA_ERROR_CONTEXT_ALREADY_CURRENT", 888*da0073e9SAndroid Build Coastguard Worker "CUDA_ERROR_ARRAY_IS_MAPPED", 889*da0073e9SAndroid Build Coastguard Worker "CUDA_ERROR_NOT_MAPPED", 890*da0073e9SAndroid Build Coastguard Worker "CUDA_ERROR_INVALID_SOURCE", 891*da0073e9SAndroid Build Coastguard Worker ] 892*da0073e9SAndroid Build Coastguard Worker for word in words_to_add: 893*da0073e9SAndroid Build Coastguard Worker self.trie.add(word) 894*da0073e9SAndroid Build Coastguard Worker regex = self.trie.export_to_regex() 895*da0073e9SAndroid Build Coastguard Worker expected_regex = r"(?:CUDA_ERROR_(?:ARRAY_IS_MAPPED|CONTEXT_ALREADY_CURRENT|INVALID_SOURCE|NOT_MAPPED)|__CUDACC__)" 896*da0073e9SAndroid Build Coastguard Worker self.assertEqual(regex, expected_regex) 897*da0073e9SAndroid Build Coastguard Worker 898*da0073e9SAndroid Build Coastguard Worker def test_prefix_words_export_trie_to_regex(self): 899*da0073e9SAndroid Build Coastguard Worker # test case where some nodes have both children and are also leaf nodes. 
900*da0073e9SAndroid Build Coastguard Worker words_to_add = ["apple", "app", "ban", "banana"] 901*da0073e9SAndroid Build Coastguard Worker for word in words_to_add: 902*da0073e9SAndroid Build Coastguard Worker self.trie.add(word) 903*da0073e9SAndroid Build Coastguard Worker regex = self.trie.export_to_regex() 904*da0073e9SAndroid Build Coastguard Worker expected_regex = r"(?:app(?:le)?|ban(?:ana)?)" 905*da0073e9SAndroid Build Coastguard Worker self.assertEqual(regex, expected_regex) 906*da0073e9SAndroid Build Coastguard Worker 907*da0073e9SAndroid Build Coastguard Worker def test_single_export_trie_to_regex(self): 908*da0073e9SAndroid Build Coastguard Worker words_to_add = ["cudaErrorInvalidMemcpyDirection"] 909*da0073e9SAndroid Build Coastguard Worker for word in words_to_add: 910*da0073e9SAndroid Build Coastguard Worker self.trie.add(word) 911*da0073e9SAndroid Build Coastguard Worker regex = self.trie.export_to_regex() 912*da0073e9SAndroid Build Coastguard Worker expected_regex = "cudaErrorInvalidMemcpyDirection" 913*da0073e9SAndroid Build Coastguard Worker self.assertEqual(regex, expected_regex) 914*da0073e9SAndroid Build Coastguard Worker 915*da0073e9SAndroid Build Coastguard Worker def test_char_export_trie_to_regex(self): 916*da0073e9SAndroid Build Coastguard Worker self.trie.add("a") 917*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.trie.export_to_regex(), "a") 918*da0073e9SAndroid Build Coastguard Worker self.trie.add("b") 919*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.trie.export_to_regex(), "[ab]") 920*da0073e9SAndroid Build Coastguard Worker 921*da0073e9SAndroid Build Coastguard Worker def test_special_char_export_trie_to_regex(self): 922*da0073e9SAndroid Build Coastguard Worker self.trie.add(r"c*") 923*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.trie.export_to_regex(), r"c\*") 924*da0073e9SAndroid Build Coastguard Worker 925*da0073e9SAndroid Build Coastguard Worker 926*da0073e9SAndroid Build 
class TestAssert(TestCase):
    """Checks for ``torch._assert`` in eager mode and under TorchScript."""

    def test_assert_true(self):
        # torch._assert is a no-op for truthy conditions and raises
        # AssertionError carrying the given message otherwise.

        # Conditions given as plain Python bools.
        torch._assert(True, "foo")
        with self.assertRaisesRegex(AssertionError, "bar"):
            torch._assert(False, "bar")

        # Conditions given as one-element bool tensors.
        passing = torch.tensor([True], dtype=torch.bool)
        torch._assert(passing, "foo")
        failing = torch.tensor([False], dtype=torch.bool)
        with self.assertRaisesRegex(AssertionError, "bar"):
            torch._assert(failing, "bar")

    def test_assert_scriptable(self):
        # A module calling torch._assert must survive torch.jit.script. In the
        # scripted module, a failed assert surfaces as torch.jit.Error.
        class PositiveSumGuard(torch.nn.Module):
            def forward(self, x):
                torch._assert(x.sum() > 0, "foo")
                return x

        scripted = torch.jit.script(PositiveSumGuard())

        # An all-ones 4x4 input sums to 16 (> 0), so the guard passes.
        scripted(torch.randn(4, 4).fill_(1.0))

        # A single False sums to 0, which trips the guard.
        with self.assertRaisesRegex(torch.jit.Error, "foo"):
            scripted(torch.tensor([False], dtype=torch.bool))
@unittest.skipIf(IS_SANDCASTLE, "cpp_extension is OSS only")
class TestStandaloneCPPJIT(TestCase):
    """Building standalone (non-Python-module) executables via cpp_extension."""

    def test_load_standalone(self):
        # Compile a tiny C++ program with cpp_extension.load(is_standalone=True).
        # Then check the returned executable path, run the executable and
        # compare its stdout. The temporary build dir is always removed.
        build_dir = tempfile.mkdtemp()
        try:
            src_path = os.path.join(build_dir, "main.cpp")
            src = textwrap.dedent(
                """\
                #include <iostream>
                #include <torch/torch.h>
                int main() {
                    auto x = torch::eye(3);
                    std::cout << x << std::endl;
                }
            """
            )
            with open(src_path, "w") as f:
                f.write(src)

            exec_path = torch.utils.cpp_extension.load(
                "standalone_load_test",
                src_path,
                build_directory=build_dir,
                is_python_module=False,
                is_standalone=True,
            )

            ext = ".exe" if IS_WINDOWS else ""
            self.assertEqual(
                exec_path, os.path.join(build_dir, f"standalone_load_test{ext}")
            )

            # Run the executable once through the shell and once directly.
            for shell in [True, False]:
                r = subprocess.run(
                    [exec_path],
                    shell=shell,
                    stdout=subprocess.PIPE,
                )
                self.assertEqual(r.returncode, 0)
                self.assertEqual(
                    # Windows prints "\r\n" for newlines.
                    textwrap.dedent(r.stdout.decode("utf-8")).replace("\r\n", "\n"),
                    textwrap.dedent(
                        """\
                     1  0  0
                     0  1  0
                     0  0  1
                    [ CPUFloatType{3,3} ]
                    """
                    ),
                )

        finally:
            shutil.rmtree(build_dir)


class DummyPrivateUse1Module:
    """Minimal stand-in for a PrivateUse1 backend runtime module.

    The tests below pass it to ``torch._register_device_module``. It always
    reports itself available and autocast-enabled with float16.
    """

    @staticmethod
    def is_available():
        return True

    @staticmethod
    def is_autocast_enabled():
        return True

    @staticmethod
    def get_autocast_dtype():
        return torch.float16

    @staticmethod
    def set_autocast_enabled(enable):
        pass

    @staticmethod
    def set_autocast_dtype(dtype):
        pass

    @staticmethod
    def get_amp_supported_dtype():
        return [torch.float16]


class TestExtensionUtils(TestCase):
    """Registration of external device runtime modules (PrivateUse1)."""

    def tearDown(self):
        # Clean up: drop the module registered under the current PrivateUse1
        # backend name, both as a torch attribute and from sys.modules.
        backend_name = torch._C._get_privateuse1_backend_name()
        if hasattr(torch, backend_name):
            delattr(torch, backend_name)
        if f"torch.{backend_name}" in sys.modules:
            del sys.modules[f"torch.{backend_name}"]

    def test_external_module_register(self):
        # Built-in module
        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
            torch._register_device_module("cuda", torch.cuda)

        # Wrong device type
        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
            torch._register_device_module("dummmy", DummyPrivateUse1Module)

        with self.assertRaises(AttributeError):
            torch.privateuseone.is_available()  # type: ignore[attr-defined]

        torch._register_device_module("privateuseone", DummyPrivateUse1Module)

        torch.privateuseone.is_available()  # type: ignore[attr-defined]

        # No supporting for override
        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
            torch._register_device_module("privateuseone", DummyPrivateUse1Module)

    def test_external_module_register_with_renamed_backend(self):
        # The backend may be renamed only once; a second rename must fail.
        torch.utils.rename_privateuse1_backend("foo")
        with self.assertRaisesRegex(RuntimeError, "has already been set"):
            torch.utils.rename_privateuse1_backend("dummmy")

        custom_backend_name = torch._C._get_privateuse1_backend_name()
        self.assertEqual(custom_backend_name, "foo")

        with self.assertRaises(AttributeError):
            torch.foo.is_available()  # type: ignore[attr-defined]

        # Autocast on the renamed backend is rejected until a runtime module
        # has been registered for it.
        with self.assertRaisesRegex(AssertionError, "Tried to use AMP with the"):
            with torch.autocast(device_type=custom_backend_name):
                pass
        torch._register_device_module("foo", DummyPrivateUse1Module)

        torch.foo.is_available()  # type: ignore[attr-defined]
        with torch.autocast(device_type=custom_backend_name):
            pass

        self.assertEqual(torch._utils._get_device_index("foo:1"), 1)
        self.assertEqual(torch._utils._get_device_index(torch.device("foo:2")), 2)


class TestRenderUtils(TestCase):
    """Formatting of calls by ``torch._utils.render_call``."""

    def test_basic(self):
        self.assertExpectedInline(
            torch._utils.render_call(torch.sum, [torch.randn(100)], {"dim": 0}),
            """torch.sum(tensor([...], size=(100,)), dim=0)""",
        )
        self.assertExpectedInline(
            torch._utils.render_call(torch.sum, [torch.randn(100, 100)], {"dim": 0}),
            """torch.sum(tensor([...], size=(100, 100)), dim=0)""",
        )


class TestDeviceUtils(TestCase):
    """Device context managers, decorators and the global default device."""

    def test_basic(self):
        with torch.device("meta") as dev:
            x = torch.empty(3, 3)
        self.assertEqual(x.device.type, "meta")
        self.assertEqual(dev, torch.device("meta"))

    def test_decorator(self):
        @set_device("meta")
        def f():
            return torch.empty(3, 3)

        self.assertEqual(f().device.type, "meta")

    def test_decorator_generator(self):
        @set_device("meta")
        def f():
            yield torch.empty(3, 3)
            yield torch.empty(3, 3)

        r1, r2 = list(f())
        self.assertEqual(r1.device.type, "meta")
        self.assertEqual(r2.device.type, "meta")

    def test_nn_module(self):
        with torch.device("meta"):
            m = nn.Linear(40, 50)
        self.assertEqual(m.weight.device.type, "meta")

    def test_set_default_device(self):
        try:
            torch.set_default_device("meta")
            r = torch.empty(2, 2)
        finally:
            torch.set_default_device(None)

        self.assertEqual(r.device.type, "meta")

    def test_get_default_device(self):
        # Reset in `finally` so a failing assertion cannot leave "meta" as the
        # process-wide default device for later tests.
        try:
            torch.set_default_device("meta")
            self.assertEqual(torch.get_default_device().type, "meta")
        finally:
            torch.set_default_device(None)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_get_default_device_more(self):
        # get_default_device() must match where factory functions allocate.
        # Cases: a bare "cuda" default, a bare "cuda" default after switching
        # the current CUDA device, and an explicit "cuda:1" default.
        # The current CUDA device is process-global state; save it so this
        # test does not leave later tests running on cuda:1.
        prev_cuda_device = torch.cuda.current_device()
        try:
            torch.set_default_device("cuda")
            self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
            torch.set_default_device(None)

            torch.set_default_device("cuda")
            torch.cuda.set_device("cuda:1")
            self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
            torch.set_default_device(None)

            torch.set_default_device("cuda:1")
            self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
        finally:
            torch.set_default_device(None)
            torch.cuda.set_device(prev_cuda_device)

    @onlyCPU
    @ops(op_db)
    def test_device_mode_ops(self, device, dtype, op):
        # For every OpInfo sample with no Tensor inputs, calling the op under a
        # torch.device("meta") context must produce only meta tensors.
        def is_meta_device(x: torch.Tensor) -> bool:
            return x.device.type == "meta"

        func = op.get_op()
        samples = op.sample_inputs(device, dtype, requires_grad=False)
        for sample in samples:
            # Only test samples which don't have Tensor inputs. However,
            # we don't test the factory property on OpInfo as it is very,
            # very incomplete
            if tree_any(
                lambda x: isinstance(x, torch.Tensor),
                (sample.input, sample.args, sample.kwargs),
            ):
                continue
            # Many OpInfos will explicitly pass in a device. DeviceContext
            # will respect device if it is explicitly specified. To test
            # DeviceContext, we have to remove the device kwarg in this case.
            # NB: Can't pass None to sample_inputs, the function can't
            # handle it.
            kwargs = sample.kwargs.copy()
            kwargs.pop("device", None)
            with torch.device("meta"):
                r = func(sample.input, *sample.args, **kwargs)

            self.assertTrue(tree_all_only(torch.Tensor, is_meta_device, r))


instantiate_device_type_tests(TestDeviceUtils, globals())


class TestCppExtensionUtils(TestCase):
    """Compiler sanity checks performed by ``torch.utils.cpp_extension``."""

    def test_cpp_compiler_is_ok(self):
        self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform("c++"))

    def test_cc_compiler_is_ok(self):
        self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform("cc"))


class TestTraceback(TestCase):
    """Helpers in ``torch.utils._traceback``."""

    def test_basic(self):
        # Code compiled with exec() has no file on disk. Under
        # report_compile_source_on_error, the formatted traceback must still
        # show the raising line (marked "HEYA"). Presumably that line comes
        # from the __compile_source__ entry placed in the exec globals.
        source = """\
def f(x):
    def g(x):
        raise RuntimeError  # HEYA

    x = x * 3
    return g(x) + 1
"""

        out: Dict[str, Any] = {}
        scope = {"__compile_source__": source}
        exec(source, scope, out)

        # Catch manually rather than with assertRaises: its context manager
        # strips __traceback__ from the stored exception. The `else` branch
        # makes the test fail if nothing is raised at all.
        try:
            with report_compile_source_on_error():
                out["f"](1)
        except RuntimeError as e:
            self.assertIn("HEYA", "".join(traceback.format_tb(e.__traceback__)))
        else:
            self.fail("expected f(1) to raise RuntimeError")

    def test_format_traceback_short(self):
        try:
            raise RuntimeError
        except RuntimeError as e:
            self.assertRegex(
                format_traceback_short(e.__traceback__),
                r".*test_utils.py:\d+ in test_format_traceback_short",
            )

    def test_captured_traceback(self):
        self.assertIn(
            "test_captured_traceback", "".join(CapturedTraceback.extract().format())
        )

    def test_captured_traceback_format_all(self):
        rs = CapturedTraceback.format_all(
            [CapturedTraceback.extract(), CapturedTraceback.extract()]
        )
        self.assertEqual(len(rs), 2)
        self.assertIn("test_captured_traceback_format_all", "".join(rs[0]))

    def test_captured_traceback_format_all_cached(self):
        # format_all must also work when one traceback was already formatted.
        tb = CapturedTraceback.extract()
        tb.format()  # cached
        rs = CapturedTraceback.format_all([tb, CapturedTraceback.extract()])
        self.assertEqual(len(rs), 2)
        self.assertIn("test_captured_traceback_format_all", "".join(rs[0]))


if __name__ == "__main__":
    run_tests()