1# mypy: allow-untyped-defs 2# Owner(s): ["module: unknown"] 3 4import os 5import random 6import re 7import shutil 8import subprocess 9import sys 10import tempfile 11import textwrap 12import traceback 13import unittest 14import warnings 15from typing import Any, Dict, List 16 17import torch 18import torch.cuda 19import torch.nn as nn 20import torch.utils.cpp_extension 21import torch.utils.data 22from torch.autograd._functions.utils import check_onnx_broadcast 23from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings 24from torch.testing._internal.common_cuda import TEST_MULTIGPU 25from torch.testing._internal.common_device_type import ( 26 instantiate_device_type_tests, 27 onlyCPU, 28 ops, 29) 30from torch.testing._internal.common_methods_invocations import op_db 31from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] 32 IS_FBCODE, 33 IS_SANDCASTLE, 34 IS_WINDOWS, 35 load_tests, 36) 37from torch.utils._device import set_device 38from torch.utils._pytree import tree_all_only, tree_any 39from torch.utils._traceback import ( 40 CapturedTraceback, 41 format_traceback_short, 42 report_compile_source_on_error, 43) 44from torch.utils.checkpoint import ( 45 _infer_device_type, 46 checkpoint, 47 checkpoint_sequential, 48 get_device_states, 49) 50from torch.utils.data import DataLoader 51 52 53# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for 54# sharding on sandcastle. This line silences flake warnings 55load_tests = load_tests 56 57HAS_CUDA = torch.cuda.is_available() 58 59 60from torch.testing._internal.common_utils import run_tests, TestCase 61 62 63class RandomDatasetMock(torch.utils.data.Dataset): 64 def __getitem__(self, index): 65 return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)]) 66 67 def __len__(self): 68 return 1000 69 70 71class TestCheckpoint(TestCase): 72 # This runs checkpoint_sequential on each of the nets in 73 # module_lists_to_compare, and compares them against the uncheckpointed model. 74 # To compare, it checks outputs as well as input gradients and parameter gradients 75 def _check_checkpoint_sequential( 76 self, 77 model, 78 module_lists_to_compare, 79 num_chunks, 80 input, 81 use_reentrant, 82 ): 83 # not checkpointed 84 out = model(input) 85 out_not_checkpointed = out.detach().clone() 86 model.zero_grad() 87 out.sum().backward() 88 grad_not_checkpointed = { 89 name: param.grad.detach().clone() 90 for name, param in model.named_parameters() 91 } 92 input_grad_not_checkpointed = input.grad.detach().clone() 93 for model_to_compare in module_lists_to_compare: 94 # checkpointed model by passing list of modules 95 detached = input.detach() 96 detached.requires_grad = True 97 98 # pass list of modules to checkpoint 99 out = checkpoint_sequential( 100 model_to_compare, num_chunks, detached, use_reentrant=use_reentrant 101 ) 102 out_checkpointed = out.detach().clone() 103 model.zero_grad() 104 out.sum().backward() 105 grad_checkpointed = { 106 name: param.grad.detach().clone() 107 for name, param in model.named_parameters() 108 } 109 input_grad_checkpointed = detached.grad.detach().clone() 110 # compare outputs as well as the gradients of input and parameters 111 self.assertEqual(out_checkpointed, out_not_checkpointed) 112 self.assertEqual(input_grad_not_checkpointed, input_grad_checkpointed) 113 for name in grad_checkpointed: 114 self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name]) 115 116 # Test whether checkpoint is being triggered or not. For this, we check 117 # the number of times forward pass happens 118 def test_checkpoint_trigger(self): 119 class Net(nn.Module): 120 def __init__(self) -> None: 121 super().__init__() 122 self.counter = 0 123 124 def forward(self, input_var): 125 self.counter += 1 126 # For reentrant, need to have autograd actually 127 # pack a tensor to trigger recomp 128 ret = input_var * torch.tensor(2.0) 129 return ret 130 131 # checkpointed 132 for use_reentrant in [True, False]: 133 with self.subTest(use_reentrant=use_reentrant): 134 modules = [Net() for _ in range(10)] 135 for m in modules: 136 self.assertEqual(m.counter, 0) 137 input_var = torch.randn(3, 4, requires_grad=True) 138 out = checkpoint_sequential( 139 modules, 2, input_var, use_reentrant=use_reentrant 140 ) 141 for m in modules: 142 self.assertEqual(m.counter, 1) 143 out.sum().backward() 144 for m in modules[: (len(modules) // 2)]: 145 self.assertEqual(m.counter, 2) 146 for m in modules[(len(modules) // 2) :]: 147 self.assertEqual(m.counter, 1) 148 149 def test_checkpoint_valid(self): 150 model = nn.Sequential( 151 nn.Linear(100, 50), 152 nn.ReLU(), 153 nn.Linear(50, 20), 154 nn.ReLU(), 155 nn.Linear(20, 5), 156 nn.ReLU(), 157 ) 158 159 input_var = torch.randn(1, 100, requires_grad=True) 160 161 # checkpointed 162 chunks = 2 163 modules = list(model.children()) 164 out = checkpoint_sequential(modules, chunks, input_var, use_reentrant=True) 165 with self.assertRaisesRegex( 166 RuntimeError, "torch.utils.checkpoint is incompatible" 167 ): 168 torch.autograd.grad( 169 outputs=[out], 170 grad_outputs=[torch.ones(1, 5)], 171 inputs=[input_var], 172 create_graph=True, 173 ) 174 # works with use_reentrant=False, and grads are the same 175 out = model(input_var) 176 grads_no_checkpoint = torch.autograd.grad( 177 outputs=[out], 178 grad_outputs=[torch.ones(1, 5)], 179 inputs=[input_var], 180 create_graph=True, 181 ) 182 out_checkpoint = checkpoint_sequential( 183 modules, chunks, input_var, use_reentrant=False 184 ) 185 # check outputs are the same 186 self.assertEqual(out_checkpoint, out) 187 grads_checkpoint = torch.autograd.grad( 188 outputs=[out_checkpoint], 189 grad_outputs=[torch.ones(1, 5)], 190 inputs=[input_var], 191 create_graph=True, 192 ) 193 self.assertEqual(grads_no_checkpoint, grads_checkpoint) 194 195 def test_checkpoint(self): 196 for use_reentrant in [True, False]: 197 with self.subTest(use_reentrant=use_reentrant): 198 model = nn.Sequential( 199 nn.Linear(100, 50), 200 nn.ReLU(), 201 nn.Linear(50, 20), 202 nn.ReLU(), 203 nn.Linear(20, 5), 204 nn.ReLU(), 205 ) 206 207 # Compare uncheckpointed model with its checkpointed counterparts 208 # In addition to running checkpoint_sequential on the nn.Sequential 209 # instance, we also run the function on the list of functions within 210 # the module. 211 self._check_checkpoint_sequential( 212 model, 213 [list(model.children()), model], 214 2, 215 torch.randn(1, 100, requires_grad=True), 216 use_reentrant=use_reentrant, 217 ) 218 219 def test_checkpoint_module_list(self): 220 class ModuleListNet(nn.Module): 221 def __init__(self) -> None: 222 super().__init__() 223 module_list = [ 224 nn.Linear(100, 50), 225 nn.ReLU(), 226 nn.Linear(50, 20), 227 nn.ReLU(), 228 nn.Linear(20, 5), 229 nn.ReLU(), 230 ] 231 self.module_list = nn.ModuleList(module_list) 232 233 def forward(self, input): 234 for layer in self.module_list: 235 input = layer(input) 236 return input 237 238 for use_reentrant in [True, False]: 239 with self.subTest(use_reentrant=use_reentrant): 240 model = ModuleListNet() 241 242 # Compare uncheckpointed model with its checkpointed counterparts. 243 self._check_checkpoint_sequential( 244 model, 245 [list(model.module_list.children()), model.module_list], 246 2, 247 torch.randn(1, 100, requires_grad=True), 248 use_reentrant=use_reentrant, 249 ) 250 251 def test_checkpoint_sequential_deprecated_multiple_args(self): 252 class Two(nn.Module): 253 def forward(self, a, b): 254 return a, b 255 256 model = nn.Sequential(Two()) 257 a = torch.randn(1, 100, requires_grad=True) 258 b = torch.randn(1, 100, requires_grad=True) 259 260 for use_reentrant in [True, False]: 261 with self.subTest(use_reentrant=use_reentrant): 262 with self.assertRaises(TypeError): 263 checkpoint_sequential(model, 1, a, b) # type: ignore[call-arg] 264 265 def test_checkpoint_sequential_deprecated_no_args(self): 266 class Noop(nn.Module): 267 def forward(self): 268 pass 269 270 model = nn.Sequential(Noop()) 271 for use_reentrant in [True, False]: 272 with self.subTest(use_reentrant=use_reentrant): 273 with self.assertRaises(TypeError): 274 checkpoint_sequential(model, 1) # type: ignore[call-arg] 275 276 def test_checkpoint_rng_cpu(self): 277 for _ in range(5): 278 inp = torch.randn(20000, device="cpu").requires_grad_() 279 phase1 = torch.nn.Dropout() 280 phase2 = torch.nn.Dropout() 281 282 def run_fn(input): 283 return phase2(input) 284 285 state = torch.get_rng_state() 286 287 out = phase1(inp) 288 out = checkpoint(run_fn, out, use_reentrant=True) 289 out.sum().backward() 290 grad_with_checkpointing = inp.grad 291 292 torch.set_rng_state(state) 293 294 inp.grad = None 295 296 out = phase1(inp) 297 out = run_fn(out) 298 out.sum().backward() 299 grad_no_checkpointing = inp.grad 300 301 self.assertEqual(grad_with_checkpointing, grad_no_checkpointing) 302 303 @unittest.skipIf(not HAS_CUDA, "No CUDA") 304 def test_checkpoint_rng_cuda(self): 305 for _ in range(5): 306 inp = torch.randn(20000, device="cuda").requires_grad_() 307 phase1 = torch.nn.Dropout() 308 phase2 = torch.nn.Dropout() 309 310 def run_fn(input): 311 return phase2(input) 312 313 state = torch.cuda.get_rng_state() 314 315 out = phase1(inp) 316 out = checkpoint(run_fn, out, use_reentrant=True) 317 out.sum().backward() 318 grad_with_checkpointing = inp.grad 319 320 torch.cuda.set_rng_state(state) 321 322 inp.grad = None 323 324 out = phase1(inp) 325 out = run_fn(out) 326 out.sum().backward() 327 grad_no_checkpointing = inp.grad 328 329 self.assertEqual(grad_with_checkpointing, grad_no_checkpointing) 330 331 @unittest.skipIf(not HAS_CUDA, "No CUDA") 332 def test_checkpoint_not_preserve_rng_state_and_without_reentrant(self): 333 inp = torch.randn(2, device="cuda").requires_grad_() 334 layer = torch.nn.Dropout() 335 336 def run_fn(input): 337 return layer(input) 338 339 out = checkpoint(run_fn, inp, use_reentrant=False, preserve_rng_state=False) 340 out.sum().backward() 341 # This should run without error 342 343 def test_checkpoint_non_tensor(self): 344 def run_fn(tensor1, tensor2): 345 if tensor2 is None: 346 return tensor1 347 return tensor1 + tensor2 348 349 input_var = torch.randn(1, 100, requires_grad=True) 350 out = checkpoint(run_fn, input_var, None, use_reentrant=True) 351 out.sum().backward() 352 353 def test_checkpoint_non_tensor_inputs_outputs(self): 354 def foo(t1, t2, scale, t3): 355 t4 = t1 + t2 * t3 356 t5 = t1 * t2 + t3 357 t4 *= scale 358 t5 *= scale 359 return scale, t4, None, True, t5, "bar", t1 360 361 t1 = torch.rand(10, requires_grad=True) 362 t2 = torch.rand(10, requires_grad=True) 363 t3 = torch.rand(10) 364 scale = random.randint(0, 10) 365 res = checkpoint(foo, t1, t2, scale, t3, use_reentrant=True) 366 self.assertEqual(scale, res[0]) 367 self.assertEqual((t1 + t2 * t3) * scale, res[1]) 368 self.assertEqual(None, res[2]) 369 self.assertEqual(True, res[3]) 370 self.assertEqual((t1 * t2 + t3) * scale, res[4]) 371 self.assertEqual("bar", res[5]) 372 self.assertEqual(t1, res[6]) 373 374 # Validate running backward. 375 res[1].sum().backward(retain_graph=True) 376 res[4].sum().backward(retain_graph=True) 377 res[6].sum().backward() 378 with self.assertRaisesRegex( 379 RuntimeError, "Trying to backward through the graph a second time" 380 ): 381 res[6].sum().backward() 382 t1_grad = t1.grad 383 t2_grad = t2.grad 384 385 # Reset grads, run without checkpoint and validate we receive same grads. 386 t1.grad = None 387 t2.grad = None 388 res = foo(t1, t2, scale, t3) 389 torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()]) 390 self.assertEqual(t1.grad, t1_grad) 391 self.assertEqual(t2.grad, t2_grad) 392 393 def test_checkpoint_no_tensors(self): 394 def foo(t1, t2, scale, t3): 395 t4 = t1 + t2 * t3 396 t5 = t1 * t2 + t3 397 t4 *= scale 398 t5 *= scale 399 return scale, t4, None, True, t5, "bar", t1 400 401 t1 = random.random() 402 t2 = random.random() 403 t3 = random.random() 404 scale = random.randint(0, 10) 405 res = checkpoint(foo, t1, t2, scale, t3, use_reentrant=True) 406 self.assertEqual(scale, res[0]) 407 self.assertEqual((t1 + t2 * t3) * scale, res[1]) 408 self.assertEqual(None, res[2]) 409 self.assertEqual(True, res[3]) 410 self.assertEqual((t1 * t2 + t3) * scale, res[4]) 411 self.assertEqual("bar", res[5]) 412 self.assertEqual(t1, res[6]) 413 414 def test_checkpoint_partial_grad(self): 415 def run_fn(tensor1, tensor2): 416 # tensor 2 is used for other application logic 417 return tensor1, tensor2 418 419 input_var = torch.randn(1, 4, requires_grad=True) 420 input_var2 = torch.randn(1, 4, requires_grad=False) 421 out = checkpoint(run_fn, input_var, input_var2, use_reentrant=True) 422 out[0].sum().backward() 423 424 def run_fn2(tensor1, tensor2): 425 return tensor1 426 427 input_var = torch.randn(1, 4, requires_grad=False) 428 input_var2 = torch.randn(1, 4, requires_grad=True) 429 with self.assertRaisesRegex( 430 RuntimeError, 431 r"none of output has requires_grad=True, this checkpoint\(\) is not necessary", 432 ): 433 out = checkpoint(run_fn2, input_var, input_var2, use_reentrant=True) 434 out.sum().backward() 435 436 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") 437 def test_checkpointing_without_reentrant_early_free(self): 438 # I don't know how to check if the temporary saved variable buffer 439 # get de-allocated directly. So using cuda memory usage as a proxy 440 441 def _do_test(fn, should_free): 442 stats: List[int] = [] 443 444 def track(x, idx): 445 # Track that at each step of the backward, some Tensor were 446 # de-allocated (which correspond to the checkpoint storage being 447 # emptied at each step) 448 def hook(_unused): 449 self.assertEqual(len(stats), idx) 450 torch.cuda.synchronize() 451 stats.append(torch.cuda.memory_allocated()) 452 if idx > 0: 453 if should_free: 454 self.assertLess(stats[idx], stats[idx - 1]) 455 else: 456 self.assertEqual(stats[idx], stats[idx - 1]) 457 458 x.register_hook(hook) 459 460 def test_fn(x): 461 # The main property of this function is that it contains multiple 462 # operations that save gradients in a chain. 463 x = x**2 464 track(x, 2) 465 x = x**2 466 track(x, 1) 467 x = x**2 468 track(x, 0) 469 x = x**2 470 return x.sum() 471 472 fn(test_fn) 473 474 return stats 475 476 x = torch.zeros(10, device="cuda", requires_grad=True) 477 x.grad = torch.zeros_like(x) 478 479 # In a regular backward, buffers get eagerly freed 480 non_retain_stats = _do_test(lambda fn: fn(x).backward(), True) 481 482 # In a retain_grad backward, buffers get preserved 483 _unused_retain_stats = _do_test( 484 lambda fn: fn(x).backward(retain_graph=True), False 485 ) 486 487 # In a regular backward with checkpoint, buffers get eagerly freed 488 checkpoint_non_retain_stats = _do_test( 489 lambda fn: checkpoint(fn, x, use_reentrant=False).backward(), True 490 ) 491 492 # In a retain_grad backward with checkpoint, buffers get eagerly freed 493 checkpoint_retain_stats = _do_test( 494 lambda fn: checkpoint(fn, x, use_reentrant=False).backward( 495 retain_graph=True 496 ), 497 True, 498 ) 499 500 self.assertEqual(non_retain_stats, checkpoint_non_retain_stats) 501 self.assertEqual(non_retain_stats, checkpoint_retain_stats) 502 503 @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") 504 def test_get_device_states_recursive(self): 505 inp = { 506 "foo": torch.rand(10, device="cuda:0"), 507 "bar": [torch.rand(10, device="cuda:1")], 508 } 509 device_ids, device_states = get_device_states(inp) 510 self.assertEqual(2, len(device_ids)) 511 self.assertEqual(2, len(device_states)) 512 self.assertEqual(0, device_ids[0]) 513 self.assertEqual(1, device_ids[1]) 514 self.assertTrue(isinstance(device_states[0], torch.Tensor)) 515 self.assertTrue(isinstance(device_states[1], torch.Tensor)) 516 517 def test_infer_device_state_recursive_meta(self): 518 inp = {"foo": torch.rand(10, device="meta")} 519 device_type = _infer_device_type(inp) 520 self.assertEqual("meta", device_type) 521 522 @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") 523 def test_infer_device_state_recursive_multi_cuda(self): 524 # Check that no warning is issued for either cuda:0, cuda:1 or 525 # cuda:0, cuda:0 cases since they are both the same device type 526 inp = { 527 "foo": torch.rand(10, device="cuda:0"), 528 "bar": [torch.rand(10, device="cuda:1")], 529 } 530 with warnings.catch_warnings(): 531 warnings.simplefilter("error") 532 device_type = _infer_device_type(inp) 533 self.assertEqual("cuda", device_type) 534 inp = { 535 "foo": torch.rand(10, device="cuda:0"), 536 "bar": [torch.rand(10, device="cuda:0")], 537 } 538 with warnings.catch_warnings(): 539 warnings.simplefilter("error") 540 device_type = _infer_device_type(inp) 541 self.assertEqual("cuda", device_type) 542 # Check that a warning is issued for cuda:0, meta and that it includes 543 # device type information 544 inp = { 545 "foo": torch.rand(10, device="cuda:0"), 546 "bar": [torch.rand(10, device="meta")], 547 } 548 with warnings.catch_warnings(record=True) as w: 549 device_type = _infer_device_type(inp) 550 self.assertEqual("cuda", device_type) 551 self.assertEqual(len(w), 1) 552 warning_msg = str(w[-1].message) 553 self.assertTrue( 554 "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices" 555 in warning_msg 556 ) 557 self.assertTrue("Device types: ['cuda', 'meta']" in warning_msg) 558 self.assertTrue("first device type: cuda" in warning_msg) 559 560 561class TestDataLoaderUtils(TestCase): 562 MAX_TIMEOUT_IN_SECOND = 300 563 564 def setUp(self): 565 super().setUp() 566 self.dataset = torch.randn(5, 3, 3, 2) 567 self.batch_size = 3 568 569 def test_random_seed(self): 570 def run(): 571 dataloader = torch.utils.data.DataLoader( 572 RandomDatasetMock(), 573 batch_size=2, 574 num_workers=4, 575 shuffle=True, 576 timeout=self.MAX_TIMEOUT_IN_SECOND, 577 ) 578 return next(iter(dataloader)) 579 580 torch.manual_seed(2018) 581 x1 = run() 582 torch.manual_seed(2018) 583 x2 = run() 584 self.assertEqual(x1, x2) 585 586 def test_single_keep(self): 587 # self.dataset is a Tensor here; technically not a valid input because 588 # not a Dataset subclass, but needs to stay working so add ignore's 589 # for type checking with mypy 590 dataloader: DataLoader = DataLoader( 591 self.dataset, # type: ignore[arg-type] 592 batch_size=self.batch_size, 593 num_workers=0, 594 drop_last=False, 595 ) 596 dataiter = iter(dataloader) 597 self.assertEqual(len(list(dataiter)), 2) 598 599 def test_single_drop(self): 600 dataloader: DataLoader = DataLoader( 601 self.dataset, # type: ignore[arg-type] 602 batch_size=self.batch_size, 603 num_workers=0, 604 drop_last=True, 605 ) 606 dataiter = iter(dataloader) 607 self.assertEqual(len(list(dataiter)), 1) 608 609 @unittest.skip( 610 "FIXME: Intermittent CUDA out-of-memory error on Windows and time-out under ASAN" 611 ) 612 def test_multi_keep(self): 613 dataloader: DataLoader = DataLoader( 614 self.dataset, # type: ignore[arg-type] 615 batch_size=self.batch_size, 616 num_workers=2, 617 drop_last=False, 618 timeout=self.MAX_TIMEOUT_IN_SECOND, 619 ) 620 dataiter = iter(dataloader) 621 self.assertEqual(len(list(dataiter)), 2) 622 623 def test_multi_drop(self): 624 dataloader: DataLoader = DataLoader( 625 self.dataset, # type: ignore[arg-type] 626 batch_size=self.batch_size, 627 num_workers=2, 628 drop_last=True, 629 timeout=self.MAX_TIMEOUT_IN_SECOND, 630 ) 631 dataiter = iter(dataloader) 632 self.assertEqual(len(list(dataiter)), 1) 633 634 635test_dir = os.path.abspath(os.path.dirname(str(__file__))) 636 637 638@unittest.skipIf( 639 "SKIP_TEST_BOTTLENECK" in os.environ.keys(), "SKIP_TEST_BOTTLENECK is set" 640) 641class TestBottleneck(TestCase): 642 def _run(self, command, timeout=30): 643 """Returns (return-code, stdout, stderr)""" 644 import subprocess 645 646 p = subprocess.Popen( 647 command, 648 stdout=subprocess.PIPE, 649 stderr=subprocess.PIPE, 650 shell=True, 651 ) 652 try: 653 output, err = p.communicate(timeout=timeout) 654 except subprocess.TimeoutExpired: 655 p.kill() 656 output, err = p.communicate() 657 rc = p.returncode 658 output_str = output.decode("ascii") 659 err_str = err.decode("ascii") 660 return (rc, output_str, err_str) 661 662 def _run_bottleneck(self, test_file, scriptargs=""): 663 curdir = os.path.dirname(os.path.abspath(__file__)) 664 filepath = f"{curdir}/{test_file}" 665 if scriptargs != "": 666 scriptargs = f" {scriptargs}" 667 rc, out, err = self._run( 668 f"{sys.executable} -m torch.utils.bottleneck {filepath}{scriptargs}" 669 ) 670 return rc, out, err 671 672 def _check_run_args(self): 673 # Check that this fails due to missing args 674 rc, out, err = self._run_bottleneck("bottleneck_test/test_args.py") 675 self.assertEqual( 676 rc, 677 2, 678 atol=0, 679 rtol=0, 680 msg=self._fail_msg("Missing args should error", out + err), 681 ) 682 683 # This should succeed 684 rc, out, err = self._run_bottleneck( 685 "bottleneck_test/test_args.py", "--foo foo --bar bar" 686 ) 687 self.assertEqual( 688 rc, 689 0, 690 atol=0, 691 rtol=0, 692 msg=self._fail_msg("Should pass args to script", out + err), 693 ) 694 695 def _fail_msg(self, msg, output): 696 return f"{msg}, output was:\n{output}" 697 698 def _check_environment_summary(self, output): 699 results = re.search("Environment Summary", output) 700 self.assertIsNotNone( 701 results, self._fail_msg("Should have Environment Summary", output) 702 ) 703 704 # Up to five lines away from the heading, there should be the version number 705 results = re.search( 706 r"Environment Summary.*(\n.*){,5}\nPyTorch \d+\.\d+", output 707 ) 708 self.assertIsNotNone( 709 results, self._fail_msg("Should have PyTorch version", output) 710 ) 711 712 def _check_cprof_summary(self, output): 713 results = re.search("cProfile output", output) 714 self.assertIsNotNone( 715 results, self._fail_msg("Should have cProfile output", output) 716 ) 717 718 # This assumes that after the cProfile output section we have 719 # the autograd profiler output 720 results = re.search( 721 r"cProfile output.*(\n.*){6,50}\n.*autograd profiler output", output 722 ) 723 self.assertIsNotNone( 724 results, 725 self._fail_msg( 726 "Distance between cProfile and autograd prof out not in [6, 50] lines", 727 output, 728 ), 729 ) 730 731 def _check_autograd_summary(self, output): 732 results = re.search("autograd profiler output", output) 733 self.assertIsNotNone( 734 results, self._fail_msg("Should have autograd profiler output", output) 735 ) 736 737 # This assumes that after the autograd profiler output is the end of the 738 # output. 739 results = re.search(r"autograd profiler output.*(\n.*){6,100}", output) 740 self.assertIsNotNone( 741 results, 742 self._fail_msg( 743 "Distance between autograd prof output and end of output not in [6, 100] lines", 744 output, 745 ), 746 ) 747 748 def _check_cuda(self, output): 749 if HAS_CUDA: 750 results = re.search("CUDA mode", output) 751 self.assertIsNotNone( 752 results, self._fail_msg("Should tell users CUDA", output) 753 ) 754 else: 755 results = re.search("CUDA mode", output) 756 self.assertIsNone( 757 results, self._fail_msg("Should not tell users about CUDA", output) 758 ) 759 760 @unittest.skipIf(HAS_CUDA, "CPU-only test") 761 def test_bottleneck_cpu_only(self): 762 rc, out, err = self._run_bottleneck("bottleneck_test/test.py") 763 self.assertEqual(rc, 0, msg=f"Run failed with\n{err}") 764 765 self._check_run_args() 766 self._check_environment_summary(out) 767 self._check_autograd_summary(out) 768 self._check_cprof_summary(out) 769 self._check_cuda(out) 770 771 @unittest.skipIf(not HAS_CUDA, "No CUDA") 772 def test_bottleneck_cuda(self): 773 rc, out, err = self._run_bottleneck("bottleneck_test/test_cuda.py") 774 self.assertEqual(rc, 0, msg=f"Run failed with\n{err}") 775 776 self._check_run_args() 777 self._check_environment_summary(out) 778 self._check_autograd_summary(out) 779 self._check_cprof_summary(out) 780 self._check_cuda(out) 781 782 783from torch.utils.collect_env import get_pretty_env_info 784 785 786@unittest.skipIf(IS_FBCODE, "runs pip which is not available internally") 787class TestCollectEnv(TestCase): 788 def test_smoke(self): 789 info_output = get_pretty_env_info() 790 self.assertTrue(info_output.count("\n") >= 17) 791 792 793class TestONNXUtils(TestCase): 794 def test_prepare_onnx_paddings(self): 795 sizes = [2, 3, 4] 796 pad = [1, 2, 3, 4] 797 paddings = _prepare_onnx_paddings(len(sizes), pad) 798 self.assertEqual(paddings, [0, 3, 1, 0, 4, 2]) 799 800 def test_check_onnx_broadcast(self): 801 def try_check_onnx_broadcast(dims1, dims2, expect_broadcast, expect_fail): 802 broadcast = True 803 fail = False 804 try: 805 broadcast = check_onnx_broadcast(dims1, dims2) 806 except ValueError: 807 fail = True 808 self.assertEqual(broadcast, expect_broadcast) 809 self.assertEqual(fail, expect_fail) 810 811 # Case 1, check the case when len(dims1) < len(dims2) and numel(dims2) > 1 812 dims1 = [3, 4] 813 dims2 = [2, 3, 4] 814 try_check_onnx_broadcast(dims1, dims2, True, True) 815 816 # Case 2, check the case when len(dims1) < len(dims2) and numel(dims2) == 1 817 dims1 = [3, 4] 818 dims2 = [1, 1, 1] 819 try_check_onnx_broadcast(dims1, dims2, True, False) 820 821 # Case 3, check the case when len(dims1) > len(dims2) and numel(dims2) == 1 822 dims1 = [1, 1] 823 dims2 = [1] 824 try_check_onnx_broadcast(dims1, dims2, True, False) 825 826 # Case 4, check the case when len(dims1) > len(dims2) and dims1[x:] == dims2 827 dims1 = [2, 3, 4] 828 dims2 = [3, 4] 829 try_check_onnx_broadcast(dims1, dims2, True, False) 830 831 # Case 5, check the case when len(dims1) > len(dims2), but dims1[x:] != dims2 832 dims1 = [2, 3, 4] 833 dims2 = [1, 4] 834 try_check_onnx_broadcast(dims1, dims2, True, True) 835 836 # Case 6, check the equal case, no broadcast 837 dims1 = [3, 4] 838 dims2 = [3, 4] 839 try_check_onnx_broadcast(dims1, dims2, False, False) 840 841 # Case 7, check the case when len(dims1) == len(dims2), but dims1 != dims2 842 dims1 = [3, 4] 843 dims2 = [1, 4] 844 try_check_onnx_broadcast(dims1, dims2, True, True) 845 846 # Case 8, check the case when len(dims1) == len(dims2) and numel(s2) == 1 847 dims1 = [3, 4] 848 dims2 = [1, 1] 849 try_check_onnx_broadcast(dims1, dims2, True, False) 850 851 852class TestHipify(TestCase): 853 def test_import_hipify(self): 854 from torch.utils.hipify import hipify_python # noqa: F401 855 856 857class TestHipifyTrie(TestCase): 858 def setUp(self): 859 self.trie = torch.utils.hipify.hipify_python.Trie() 860 861 def test_add_and_search_trie(self): 862 self.trie.add("banana") 863 self.assertTrue(self.trie.search("banana")) 864 self.assertFalse(self.trie.search("ban")) 865 self.assertFalse(self.trie.search("dog")) 866 867 def test_add_multiple_and_search_trie(self): 868 words_to_add = ["banana", "apple", "orange"] 869 for word in words_to_add: 870 self.trie.add(word) 871 872 for word in words_to_add: 873 self.assertTrue(self.trie.search(word)) 874 875 for word in ["ban", "dog", "okay", "app"]: 876 self.assertFalse(self.trie.search(word)) 877 878 def test_quote_escape(self): 879 orig_chars = ["*", "[", ".", "+", "a", "z", "-"] 880 quoted_strs = ["\\*", "\\[", "\\.", "\\+", "a", "z", "\\-"] 881 for i in range(len(orig_chars)): 882 self.assertEqual(self.trie.quote(orig_chars[i]), quoted_strs[i]) 883 884 def test_export_trie_to_regex(self): 885 words_to_add = [ 886 "__CUDACC__", 887 "CUDA_ERROR_CONTEXT_ALREADY_CURRENT", 888 "CUDA_ERROR_ARRAY_IS_MAPPED", 889 "CUDA_ERROR_NOT_MAPPED", 890 "CUDA_ERROR_INVALID_SOURCE", 891 ] 892 for word in words_to_add: 893 self.trie.add(word) 894 regex = self.trie.export_to_regex() 895 expected_regex = r"(?:CUDA_ERROR_(?:ARRAY_IS_MAPPED|CONTEXT_ALREADY_CURRENT|INVALID_SOURCE|NOT_MAPPED)|__CUDACC__)" 896 self.assertEqual(regex, expected_regex) 897 898 def test_prefix_words_export_trie_to_regex(self): 899 # test case where some nodes have both children and are also leaf nodes. 900 words_to_add = ["apple", "app", "ban", "banana"] 901 for word in words_to_add: 902 self.trie.add(word) 903 regex = self.trie.export_to_regex() 904 expected_regex = r"(?:app(?:le)?|ban(?:ana)?)" 905 self.assertEqual(regex, expected_regex) 906 907 def test_single_export_trie_to_regex(self): 908 words_to_add = ["cudaErrorInvalidMemcpyDirection"] 909 for word in words_to_add: 910 self.trie.add(word) 911 regex = self.trie.export_to_regex() 912 expected_regex = "cudaErrorInvalidMemcpyDirection" 913 self.assertEqual(regex, expected_regex) 914 915 def test_char_export_trie_to_regex(self): 916 self.trie.add("a") 917 self.assertEqual(self.trie.export_to_regex(), "a") 918 self.trie.add("b") 919 self.assertEqual(self.trie.export_to_regex(), "[ab]") 920 921 def test_special_char_export_trie_to_regex(self): 922 self.trie.add(r"c*") 923 self.assertEqual(self.trie.export_to_regex(), r"c\*") 924 925 926class TestAssert(TestCase): 927 def test_assert_true(self): 928 # verify assertions work as expected 929 # bool argument 930 torch._assert(True, "foo") 931 with self.assertRaisesRegex(AssertionError, "bar"): 932 torch._assert(False, "bar") 933 # tensor argument 934 torch._assert(torch.tensor([True], dtype=torch.bool), "foo") 935 with self.assertRaisesRegex(AssertionError, "bar"): 936 torch._assert(torch.tensor([False], dtype=torch.bool), "bar") 937 938 def test_assert_scriptable(self): 939 class M(torch.nn.Module): 940 def forward(self, x): 941 torch._assert(x.sum() > 0, "foo") 942 return x 943 944 m = M() 945 # scriptable 946 ms = torch.jit.script(m) 947 # data can be passed without errors 948 x = torch.randn(4, 4).fill_(1.0) 949 ms(x) 950 with self.assertRaisesRegex(torch.jit.Error, "foo"): 951 ms(torch.tensor([False], dtype=torch.bool)) 952 953 954@unittest.skipIf(IS_SANDCASTLE, "cpp_extension is OSS only") 955class TestStandaloneCPPJIT(TestCase): 956 def test_load_standalone(self): 957 build_dir = tempfile.mkdtemp() 958 try: 959 src_path = os.path.join(build_dir, "main.cpp") 960 src = textwrap.dedent( 961 """\ 962 #include <iostream> 963 #include <torch/torch.h> 964 int main() { 965 auto x = torch::eye(3); 966 std::cout << x << std::endl; 967 } 968 """ 969 ) 970 with open(src_path, "w") as f: 971 f.write(src) 972 973 exec_path = torch.utils.cpp_extension.load( 974 "standalone_load_test", 975 src_path, 976 build_directory=build_dir, 977 is_python_module=False, 978 is_standalone=True, 979 ) 980 981 ext = ".exe" if IS_WINDOWS else "" 982 self.assertEqual( 983 exec_path, os.path.join(build_dir, f"standalone_load_test{ext}") 984 ) 985 986 for shell in [True, False]: 987 r = subprocess.run( 988 [exec_path], 989 shell=shell, 990 stdout=subprocess.PIPE, 991 ) 992 self.assertEqual(r.returncode, 0) 993 self.assertEqual( 994 # Windows prints "\r\n" for newlines. 995 textwrap.dedent(r.stdout.decode("utf-8")).replace("\r\n", "\n"), 996 textwrap.dedent( 997 """\ 998 1 0 0 999 0 1 0 1000 0 0 1 1001 [ CPUFloatType{3,3} ] 1002 """ 1003 ), 1004 ) 1005 1006 finally: 1007 shutil.rmtree(build_dir) 1008 1009 1010class DummyPrivateUse1Module: 1011 @staticmethod 1012 def is_available(): 1013 return True 1014 1015 @staticmethod 1016 def is_autocast_enabled(): 1017 return True 1018 1019 @staticmethod 1020 def get_autocast_dtype(): 1021 return torch.float16 1022 1023 @staticmethod 1024 def set_autocast_enabled(enable): 1025 pass 1026 1027 @staticmethod 1028 def set_autocast_dtype(dtype): 1029 pass 1030 1031 @staticmethod 1032 def get_amp_supported_dtype(): 1033 return [torch.float16] 1034 1035 1036class TestExtensionUtils(TestCase): 1037 def tearDown(self): 1038 # Clean up 1039 backend_name = torch._C._get_privateuse1_backend_name() 1040 if hasattr(torch, backend_name): 1041 delattr(torch, backend_name) 1042 if f"torch.{backend_name}" in sys.modules: 1043 del sys.modules[f"torch.{backend_name}"] 1044 1045 def test_external_module_register(self): 1046 # Built-in module 1047 with self.assertRaisesRegex(RuntimeError, "The runtime module of"): 1048 torch._register_device_module("cuda", torch.cuda) 1049 1050 # Wrong device type 1051 with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"): 1052 torch._register_device_module("dummmy", DummyPrivateUse1Module) 1053 1054 with self.assertRaises(AttributeError): 1055 torch.privateuseone.is_available() # type: ignore[attr-defined] 1056 1057 torch._register_device_module("privateuseone", DummyPrivateUse1Module) 1058 1059 torch.privateuseone.is_available() # type: ignore[attr-defined] 1060 1061 # No supporting for override 1062 with self.assertRaisesRegex(RuntimeError, "The runtime module of"): 1063 torch._register_device_module("privateuseone", DummyPrivateUse1Module) 1064 1065 def test_external_module_register_with_renamed_backend(self): 1066 torch.utils.rename_privateuse1_backend("foo") 1067 with self.assertRaisesRegex(RuntimeError, "has already been set"): 1068 torch.utils.rename_privateuse1_backend("dummmy") 1069 1070 custom_backend_name = torch._C._get_privateuse1_backend_name() 1071 self.assertEqual(custom_backend_name, "foo") 1072 1073 with self.assertRaises(AttributeError): 1074 torch.foo.is_available() # type: ignore[attr-defined] 1075 1076 with self.assertRaisesRegex(AssertionError, "Tried to use AMP with the"): 1077 with torch.autocast(device_type=custom_backend_name): 1078 pass 1079 torch._register_device_module("foo", DummyPrivateUse1Module) 1080 1081 torch.foo.is_available() # type: ignore[attr-defined] 1082 with torch.autocast(device_type=custom_backend_name): 1083 pass 1084 1085 self.assertEqual(torch._utils._get_device_index("foo:1"), 1) 1086 self.assertEqual(torch._utils._get_device_index(torch.device("foo:2")), 2) 1087 1088 1089class TestRenderUtils(TestCase): 1090 def test_basic(self): 1091 self.assertExpectedInline( 1092 torch._utils.render_call(torch.sum, [torch.randn(100)], {"dim": 0}), 1093 """torch.sum(tensor([...], size=(100,)), dim=0)""", 1094 ) 1095 self.assertExpectedInline( 1096 torch._utils.render_call(torch.sum, [torch.randn(100, 100)], {"dim": 0}), 1097 """torch.sum(tensor([...], size=(100, 100)), dim=0)""", 1098 ) 1099 1100 1101class TestDeviceUtils(TestCase): 1102 def test_basic(self): 1103 with torch.device("meta") as dev: 1104 x = torch.empty(3, 3) 1105 self.assertEqual(x.device.type, "meta") 1106 self.assertEqual(dev, torch.device("meta")) 1107 1108 def test_decorator(self): 1109 @set_device("meta") 1110 def f(): 1111 return torch.empty(3, 3) 1112 1113 self.assertEqual(f().device.type, "meta") 1114 1115 def test_decorator_generator(self): 1116 @set_device("meta") 1117 def f(): 1118 yield torch.empty(3, 3) 1119 yield torch.empty(3, 3) 1120 1121 r1, r2 = list(f()) 1122 self.assertEqual(r1.device.type, "meta") 1123 self.assertEqual(r2.device.type, "meta") 1124 1125 def test_nn_module(self): 1126 with torch.device("meta"): 1127 m = nn.Linear(40, 50) 1128 self.assertEqual(m.weight.device.type, "meta") 1129 1130 def test_set_default_device(self): 1131 try: 1132 torch.set_default_device("meta") 1133 r = torch.empty(2, 2) 1134 finally: 1135 torch.set_default_device(None) 1136 1137 self.assertEqual(r.device.type, "meta") 1138 1139 def test_get_default_device(self): 1140 torch.set_default_device("meta") 1141 self.assertEqual(torch.get_default_device().type, "meta") 1142 torch.set_default_device(None) 1143 1144 @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") 1145 def test_get_default_device_more(self): 1146 torch.set_default_device("cuda") 1147 self.assertEqual(torch.get_default_device(), torch.tensor([]).device) 1148 torch.set_default_device(None) 1149 1150 torch.set_default_device("cuda") 1151 torch.cuda.set_device("cuda:1") 1152 self.assertEqual(torch.get_default_device(), torch.tensor([]).device) 1153 torch.set_default_device(None) 1154 1155 torch.set_default_device("cuda:1") 1156 self.assertEqual(torch.get_default_device(), torch.tensor([]).device) 1157 torch.set_default_device(None) 1158 1159 @onlyCPU 1160 @ops(op_db) 1161 def test_device_mode_ops(self, device, dtype, op): 1162 func = op.get_op() 1163 samples = op.sample_inputs(device, dtype, requires_grad=False) 1164 for sample in samples: 1165 # Only test samples which don't have Tensor inputs. However, 1166 # we don't test the factory property on OpInfo as it is very, 1167 # very incomplete 1168 if tree_any( 1169 lambda x: isinstance(x, torch.Tensor), 1170 (sample.input, sample.args, sample.kwargs), 1171 ): 1172 continue 1173 # Many OpInfos will explicitly pass in a device. DeviceContext 1174 # will respect device if it is explicitly specified. To test 1175 # DeviceContext, we have to remove the device kwarg in this case. 1176 # NB: Can't pass None to sample_inputs, the function can't 1177 # handle it. 1178 kwargs = sample.kwargs.copy() 1179 kwargs.pop("device", None) 1180 with torch.device("meta"): 1181 r = func(sample.input, *sample.args, **kwargs) 1182 1183 def is_meta_device(x: torch.Tensor) -> bool: 1184 return x.device.type == "meta" 1185 1186 self.assertTrue(tree_all_only(torch.Tensor, is_meta_device, r)) 1187 1188 1189instantiate_device_type_tests(TestDeviceUtils, globals()) 1190 1191 1192class TestCppExtensionUtils(TestCase): 1193 def test_cpp_compiler_is_ok(self): 1194 self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform("c++")) 1195 1196 def test_cc_compiler_is_ok(self): 1197 self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform("cc")) 1198 1199 1200class TestTraceback(TestCase): 1201 def test_basic(self): 1202 source = """\ 1203def f(x): 1204 def g(x): 1205 raise RuntimeError # HEYA 1206 1207 x = x * 3 1208 return g(x) + 1 1209""" 1210 1211 out: Dict[str, Any] = {} 1212 scope = {"__compile_source__": source} 1213 exec(source, scope, out) 1214 1215 try: 1216 with report_compile_source_on_error(): 1217 out["f"](1) 1218 except RuntimeError as e: 1219 self.assertIn("HEYA", "".join(traceback.format_tb(e.__traceback__))) 1220 1221 def test_format_traceback_short(self): 1222 try: 1223 raise RuntimeError 1224 except RuntimeError as e: 1225 self.assertRegex( 1226 format_traceback_short(e.__traceback__), 1227 r".*test_utils.py:\d+ in test_format_traceback_short", 1228 ) 1229 1230 def test_captured_traceback(self): 1231 self.assertIn( 1232 "test_captured_traceback", "".join(CapturedTraceback.extract().format()) 1233 ) 1234 1235 def test_captured_traceback_format_all(self): 1236 rs = CapturedTraceback.format_all( 1237 [CapturedTraceback.extract(), CapturedTraceback.extract()] 1238 ) 1239 self.assertEqual(len(rs), 2) 1240 self.assertIn("test_captured_traceback_format_all", "".join(rs[0])) 1241 1242 def test_captured_traceback_format_all_cached(self): 1243 tb = CapturedTraceback.extract() 1244 tb.format() # cached 1245 rs = CapturedTraceback.format_all([tb, CapturedTraceback.extract()]) 1246 self.assertEqual(len(rs), 2) 1247 self.assertIn("test_captured_traceback_format_all", "".join(rs[0])) 1248 1249 1250if __name__ == "__main__": 1251 run_tests() 1252