1# Owner(s): ["oncall: pt2"] 2 3# Copyright (c) Facebook, Inc. and its affiliates. 4# All rights reserved. 5# 6# This source code is licensed under the BSD-style license found in the 7# LICENSE file in the root directory of this source tree. 8 9import copy 10import itertools 11import unittest 12import warnings 13from contextlib import nullcontext 14from functools import partial, wraps 15from typing import Any, Callable, Dict, List, Optional, Union 16from unittest.mock import patch 17 18from common_utils import decorate, decorateForModules, skip, skipOps, xfail 19 20import torch 21import torch._dynamo as torchdynamo 22import torch.nn as nn 23import torch.utils._pytree as pytree 24from functorch import grad, jacrev, make_fx, vjp, vmap 25from functorch.compile import ( 26 aot_function, 27 aot_module, 28 aot_module_simplified, 29 compiled_function, 30 compiled_module, 31 default_decompositions, 32 default_partition, 33 get_aot_compilation_context, 34 make_boxed_compiler, 35 make_boxed_func, 36 memory_efficient_fusion, 37 min_cut_rematerialization_partition, 38 nnc_jit, 39 nop, 40) 41from functorch.experimental import control_flow 42from torch._decomp import decomposition_table 43from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache 44from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module 45from torch._higher_order_ops.out_dtype import out_dtype 46from torch._inductor.codecache import compiled_fx_graph_hash 47from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode 48from torch.fx.experimental.proxy_tensor import is_sym_node 49from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv 50from torch.nn.utils.rnn import PackedSequence 51from torch.testing._internal.common_device_type import ( 52 instantiate_device_type_tests, 53 ops, 54 tol, 55 toleranceOverride, 56) 57from torch.testing._internal.common_methods_invocations import op_db 58from torch.testing._internal.common_modules import module_db, modules 59from torch.testing._internal.common_utils import ( 60 compare_equal_outs_and_grads, 61 instantiate_parametrized_tests, 62 IS_ARM64, 63 IS_MACOS, 64 IS_WINDOWS, 65 IS_X86, 66 outs_and_grads, 67 parametrize, 68 run_tests, 69 skipIfRocm, 70 skipIfTorchDynamo, 71 TestCase, 72 xfail_inherited_tests, 73 xfailIfTorchDynamo, 74) 75from torch.testing._internal.custom_tensor import ConstantExtraMetadataTensor 76from torch.testing._internal.hop_db import hop_db 77from torch.testing._internal.optests import ( 78 _test_aot_autograd_forwards_backwards_helper, 79 aot_autograd_check, 80) 81from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode 82 83 84USE_TORCHVISION = False 85try: 86 import torchvision 87 88 USE_TORCHVISION = True 89except ImportError: 90 warnings.warn( 91 "Couldn't import torchvision. Some of our tests use it, try " 92 "to install it with commands from pytorch.org, post-fixed with " 93 "`--no-deps` to avoid overwriting the pytorch installation", 94 UserWarning, 95 ) 96 97USE_NETWORKX = False 98try: 99 import networkx # noqa: F401 100 101 USE_NETWORKX = True 102except ImportError: 103 warnings.warn("Some tests use networkx but it was not installed", UserWarning) 104 105# NB: numpy is a testing dependency! 
106 107 108class AOTTestCase(TestCase): 109 pass 110 111 112class TestPythonKey(AOTTestCase): 113 def test_make_fx(self, device): 114 def f(x): 115 return torch.sin(x) 116 117 inp = torch.randn(3) 118 fx_f = make_fx(f)(inp) 119 120 new_inp = torch.randn(3) 121 self.assertEqual(fx_f(new_inp), f(new_inp)) 122 123 def test_make_fx_grad(self, device): 124 def f(x): 125 return torch.sin(x).sum() 126 127 inp = torch.randn(3) 128 f = grad(f) 129 fx_f = make_fx(f)(inp) 130 131 new_inp = torch.randn(3) 132 self.assertEqual(fx_f(new_inp), f(new_inp)) 133 134 def test_scalar_device(self, device): 135 def f(a, b): 136 return a + b 137 138 inps = [torch.randn(3, device=device), torch.tensor(5)] 139 fx_f = make_fx(f)(*inps) 140 self.assertEqual(fx_f(*inps), f(*inps)) 141 142 def test_make_fx_vmap(self, device): 143 def f(x): 144 return torch.sin(x) 145 146 inp = torch.randn(5, 3) 147 f = vmap(f) 148 fx_f = make_fx(f)(inp) 149 new_inp = torch.randn(5, 3) 150 self.assertEqual(fx_f(new_inp), f(new_inp)) 151 152 def test_make_fx_jacrev(self, device): 153 def f(x): 154 return x.sin().sum() 155 156 inp = torch.randn(3) 157 f = jacrev(jacrev(f)) 158 fx_f = make_fx(f)(inp) 159 new_inp = torch.randn(3) 160 self.assertEqual(fx_f(new_inp), f(new_inp)) 161 162 def test_make_fx_vjp(self, device): 163 def f(x): 164 return torch.sin(x).sum() 165 166 primals = torch.randn(3) 167 _, vjp_fn = vjp(f, primals) 168 cotangent = torch.randn(()) 169 fx_f = make_fx(vjp_fn)(cotangent, True, True) 170 new_cotangent = torch.randn(()) 171 self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent)) 172 173 def test_make_fx_functionalize(self, device): 174 from functorch.experimental import functionalize 175 176 def fn(a): 177 a = a * 2 178 a.relu_() 179 return a 180 181 a = torch.randn(3, device=device) 182 symbolic_gm = torch.fx.symbolic_trace(fn) 183 includes_method_relu_ = any( 184 str(n.target) == "relu_" for n in symbolic_gm.graph.nodes 185 ) 186 self.assertTrue(includes_method_relu_) 187 # Also verifies fix for https://github.com/pytorch/pytorch/issues/84570 188 gm = make_fx(functionalize(symbolic_gm))(a) 189 includes_aten_relu = any( 190 n.target == torch.ops.aten.relu.default for n in gm.graph.nodes 191 ) 192 self.assertTrue(includes_aten_relu) 193 194 def test_make_fx_no_decompose(self, device): 195 # FIXME 196 return self.skipTest("error: maximum recursion reached") 197 198 def f(x): 199 return torch.tanh(x).sum() 200 201 fx_f = make_fx(grad(f))(torch.randn(5)) 202 ops = {i.target for i in fx_f.graph.nodes} 203 204 self.assertEqual(torch.ops.aten.tanh_backward in ops, True) 205 206 fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5)) 207 ops = {i.target for i in fx_f.graph.nodes} 208 self.assertEqual(torch.ops.aten.tanh_backward in ops, False) 209 210 def test_nnc_jit(self, device): 211 def f(x): 212 return torch.sin(x) 213 214 jit_f = nnc_jit(f) 215 216 inp = torch.randn(3) 217 self.assertEqual(jit_f(inp), f(inp)) 218 219 def test_nnc_scalar(self, device): 220 def f(x): 221 return torch.sin(x) 222 223 jit_f = nnc_jit(f) 224 225 inp = torch.randn(()) 226 self.assertEqual(jit_f(inp), f(inp)) 227 228 def test_nnc_pytrees(self, device): 229 def f(x): 230 return [torch.sin(x[0])] 231 232 jit_f = nnc_jit(f) 233 234 inp = [torch.randn(3)] 235 self.assertEqual(jit_f(inp), f(inp)) 236 237 def test_external_calls(self, device): 238 def f(a, b): 239 return torch.mv(a, b) 240 241 jit_f = nnc_jit(f) 242 inp = [torch.randn(3, 3), torch.randn(3)] 243 self.assertEqual(jit_f(*inp), f(*inp)) 244 245 def 
test_nnc_passthrough(self, device): 246 def f(x, y): 247 return x + y, y 248 249 inp = (torch.randn(3), torch.randn(3)) 250 jit_f = nnc_jit(f) 251 self.assertEqual(jit_f(*inp), f(*inp)) 252 253 def f(x): 254 x["a"] = x["a"] * 2 255 return x 256 257 inp = ({"a": torch.randn(3), "b": torch.randn(3)},) 258 jit_f = nnc_jit(f) 259 self.assertEqual(jit_f(*inp), f(*inp)) 260 261 @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") 262 def test_resnet18_backward_trace(self, device): 263 mod = torchvision.models.resnet18() 264 265 def f(x): 266 out = mod(x) 267 out.sum().backward() 268 return [a.grad for a in mod.parameters()] 269 270 inp = torch.randn(3, 3, 250, 250, requires_grad=True) 271 grads = f(inp) 272 273 mod.zero_grad() 274 mod(inp).sum().backward() 275 grads2 = [a.grad for a in mod.parameters()] 276 self.assertEqual(grads, grads2) 277 278 279def get_base(t): 280 return t._base if t._is_view() else t 281 282 283def is_in_base(t, maybe_tensors): 284 t_base = get_base(t) 285 for maybe_tensor in maybe_tensors: 286 if isinstance(maybe_tensor, torch.Tensor): 287 if t_base is get_base(maybe_tensor): 288 return True 289 return False 290 291 292def skipIfDynamoInput(reason): 293 """ 294 Skip TestAOTAutograd if running with dynamo input 295 """ 296 297 def decorator(func): 298 @wraps(func) 299 def wrapper(self, *args, **kwargs): 300 if isinstance(self, TestAOTAutogradWithDynamo): 301 self.skipTest( 302 f"Skipping {self._testMethodName} in TestAOTAutogradWithDynamo because {reason}" 303 ) 304 else: 305 func(self, *args, **kwargs) 306 307 return wrapper 308 309 return decorator 310 311 312class TestAOTAutograd(AOTTestCase): 313 def run_autograd( 314 self, 315 f: Callable, 316 fw_graph_cell: List[Optional[Callable]], 317 decompositions: Optional[Dict], 318 keep_input_mutations: bool, 319 dynamic: bool, 320 ): 321 """ 322 Runs aot_autograd with the specified settings on f. 323 """ 324 if isinstance(f, nn.Module): 325 compiled_f = aot_module( 326 f, 327 fw_compiler=make_boxed_compiler( 328 partial(extract_graph, graph_cell=fw_graph_cell) 329 ), 330 bw_compiler=nop, 331 decompositions=decompositions, 332 keep_inference_input_mutations=keep_input_mutations, 333 dynamic=dynamic, 334 ) 335 else: 336 compiled_f = aot_function( 337 f, 338 fw_compiler=make_boxed_compiler( 339 partial(extract_graph, graph_cell=fw_graph_cell) 340 ), 341 bw_compiler=nop, 342 decompositions=decompositions, 343 keep_inference_input_mutations=keep_input_mutations, 344 dynamic=dynamic, 345 ) 346 return compiled_f 347 348 # test_mutation will: 349 # - Ensure that inputs are non-leaves, so our graphs can mutate them 350 # - try to mutate outputs of the graph (to ensure that autograd meta is set properly on outputs) 351 @patch("functorch.compile.config.debug_assert", True) 352 def verify_aot_autograd( 353 self, 354 f, 355 inp_: Union[Callable, List[Any]], 356 *, 357 test_mutation: bool = False, 358 keep_inp_mutations: bool = False, 359 decompositions: Optional[Dict] = None, 360 dynamic: bool = False, 361 # Only active when inp_ is Callable. 362 # TODO: probably consolidate all tests to make inp a Callable. 
363 make_inputs_subclasses: bool = False, 364 ): 365 def make_inputs(inp_): 366 # Some tests pass in a callable for inp, to generate the inputs 367 # (useful if we want to generate complicated aliasing inputs) 368 if isinstance(inp_, Callable): 369 inp_callable = inp_ 370 # The callable should return a tuple of f_inputs, f_graph_inputs 371 # (The idea is that we might want to compile a function with the graph inputs, 372 # but test autograd backprop all the way through the actual inputs) 373 with TwoTensorMode() if make_inputs_subclasses else nullcontext(): 374 inp, graph_inps = inp_callable() 375 else: 376 inp = [] 377 # Our input clones need to mimic when inputs are duplicates of one another 378 dupes_map = {} 379 for i, x in enumerate(inp_): 380 if x in dupes_map: 381 x_dupe_idx = dupes_map[x] 382 inp.append(inp[x_dupe_idx]) 383 else: 384 dupes_map[x] = i 385 if not isinstance(x, torch.Tensor): 386 x_copy = x 387 else: 388 x_copy = x.clone().detach().requires_grad_(x.requires_grad) 389 if x.requires_grad and not x.is_leaf: 390 x_copy = x_copy.clone() 391 392 inp.append(x_copy) 393 394 if test_mutation: 395 # For graphs where we mutate inputs, need our test to make sure inputs aren't leaves 396 graph_inps = [x.add(1) for x in inp] 397 else: 398 graph_inps = inp 399 400 return inp, graph_inps 401 402 def check_results( 403 ref_results, 404 test_results, 405 ref_graph_inps, 406 test_graph_inps, 407 ref_inp, 408 test_inp, 409 ): 410 ref_out, ref_grad = ref_results 411 test_out, test_grad = test_results 412 self.assertEqual(ref_grad, test_grad) 413 if isinstance(ref_out, torch.Tensor): 414 self.assertTrue(isinstance(test_out, torch.Tensor)) 415 ref_out, test_out = [ref_out], [test_out] 416 for ref_o, test_o in zip(ref_out, test_out): 417 if isinstance(ref_o, torch.Tensor): 418 self.assertEqual(ref_o.requires_grad, test_o.requires_grad) 419 self.assertEqual(ref_o.is_leaf, test_o.is_leaf) 420 ref_is_view_of_non_interm = is_in_base( 421 ref_o, ref_graph_inps 422 ) or is_in_base(ref_o, ref_out) 423 test_is_view_of_non_interm = is_in_base( 424 test_o, test_graph_inps 425 ) or is_in_base(test_o, test_out) 426 self.assertEqual( 427 ref_is_view_of_non_interm, test_is_view_of_non_interm 428 ) 429 self.assertEqual(ref_o, test_o) 430 if test_mutation: 431 # This tests that autograd meta is set properly on the output we can 432 # mutate it. 
433 ref_o.add_(2) 434 test_o.add_(2) 435 self.assertEqual(ref_o, test_o) 436 # Reverse the modification 437 ref_o.sub_(2) 438 test_o.sub_(2) 439 self.assertEqual(ref_o, test_o) 440 for ref_i, test_i in zip(ref_inp, test_inp): 441 if isinstance(ref_i, torch.Tensor): 442 self.assertEqual(ref_i.requires_grad, test_i.requires_grad) 443 self.assertEqual(ref_i, test_i) 444 445 for keep_input_mutations in [True] if keep_inp_mutations else [True, False]: 446 inp, graph_inps = make_inputs(inp_) 447 test_inp, test_graph_inps = make_inputs(inp_) 448 fw_graph_cell = [None] 449 compiled_f = self.run_autograd( 450 f, fw_graph_cell, decompositions, keep_input_mutations, dynamic 451 ) 452 ref_results = outs_and_grads(f, graph_inps, inp) 453 test_results = outs_and_grads(compiled_f, test_graph_inps, test_inp) 454 455 check_results( 456 ref_results, test_results, graph_inps, test_graph_inps, inp, test_inp 457 ) 458 if isinstance(self, TestAOTAutogradWithCache): 459 # When testing with cache, run compiled_f a second time 460 cached_inp, cached_graph_inps = make_inputs(inp_) 461 cached_results = outs_and_grads( 462 compiled_f, cached_graph_inps, cached_inp 463 ) 464 check_results( 465 ref_results, 466 cached_results, 467 graph_inps, 468 cached_graph_inps, 469 inp, 470 cached_inp, 471 ) 472 473 return fw_graph_cell[0] 474 475 def test_non_tensor_and_none_inputs(self): 476 # int, None, Tensor 477 def f(a, b, c): 478 return a * c 479 480 inp = [2, None, torch.ones(3, 3, dtype=torch.float32, requires_grad=True)] 481 self.verify_aot_autograd(f, inp) 482 inp = [2, None, torch.ones(3, 3, dtype=torch.float32, requires_grad=False)] 483 self.verify_aot_autograd(f, inp) 484 485 def test_single_output(self): 486 def f(a, b): 487 return a + b 488 489 inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] 490 self.verify_aot_autograd(f, inp) 491 inp = [torch.randn(3, 3, requires_grad=False), torch.randn(3, 3)] 492 self.verify_aot_autograd(f, inp) 493 494 def test_multi_output(self): 495 def f(a, b): 496 return a + b, a - b 497 498 inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] 499 self.verify_aot_autograd(f, inp) 500 inp = [torch.randn(3, 3, requires_grad=False), torch.randn(3, 3)] 501 self.verify_aot_autograd(f, inp) 502 503 def test_multi_output_list(self): 504 def f(a, b): 505 return [a + b, a - b] 506 507 inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] 508 self.verify_aot_autograd(f, inp) 509 inp = [torch.randn(3, 3, requires_grad=False), torch.randn(3, 3)] 510 self.verify_aot_autograd(f, inp) 511 512 # Test for bug occurring at the intersection of fake tensors & functionalization. 
513 def test_squeeze_mutation(self): 514 def f(a): 515 b = a.clone().squeeze(-1) 516 b.add_(1.0) 517 return a + b 518 519 inp = [torch.randn(3, 1, requires_grad=True)] 520 self.verify_aot_autograd(f, inp, dynamic=True) 521 inp = [torch.randn(3, 1, requires_grad=False)] 522 self.verify_aot_autograd(f, inp, dynamic=True) 523 524 def test_complex_linear(self): 525 # https://github.com/pytorch/pytorch/issues/93424 526 inp = [torch.randn(1, 10, 10, dtype=torch.complex64)] 527 528 class F(torch.nn.Module): 529 def __init__(self) -> None: 530 super().__init__() 531 self.linear = nn.Linear(10, 10, dtype=torch.complex64) 532 533 def forward(self, x): 534 return self.linear(x).sum().abs() 535 536 self.verify_aot_autograd(F(), inp) 537 538 def test_embedding_bag_view_dynamic(self): 539 # Backwards pass tries to wrap a sparse tensor in a FunctionalTensorWrapper; 540 # test that this works even though the sparse tensor has no storage. 541 542 class F(torch.nn.Module): 543 def __init__(self) -> None: 544 super().__init__() 545 self.emb = torch.nn.EmbeddingBag(100, 8, sparse=True) 546 547 def forward(self, x, y): 548 return self.emb(x, y).view(-1) 549 550 x = torch.arange(3) 551 y = torch.arange(3) 552 self.verify_aot_autograd(F(), [x, y], dynamic=False) 553 self.verify_aot_autograd(F(), [x, y], dynamic=True) 554 555 def test_input_mutation_simple(self): 556 def f(a): 557 a.mul_(2) 558 return a * 3 559 560 inp = [torch.ones(3, 3, requires_grad=True)] 561 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 562 inp = [torch.ones(3, 3, requires_grad=False)] 563 self.verify_aot_autograd(f, inp, test_mutation=True) 564 # Things to note: 565 # - the extra clone is because we need to pass the pre-mutated input to grad(), 566 # but autograd operates above functionalization so we need to manually clone. 567 # Hopefully backends can optimize this easily. 568 # - The extra return arg is because the compiled forward returns (mutated inputs + outputs) 569 self.assertExpectedInline( 570 fw_graph.code.strip(), 571 """\ 572def forward(self, primals_1): 573 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None 574 mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None 575 mul_1 = torch.ops.aten.mul.Tensor(mul, 3) 576 return (mul, mul_1)""", 577 ) 578 579 def test_input_mutation_set__input_mutation(self): 580 def f(a): 581 b = torch.arange(9, dtype=a.dtype).reshape(3, 3) 582 with torch.no_grad(): 583 a.set_(b) 584 return a * b 585 586 inp = [torch.ones(3, 3, requires_grad=True)] 587 self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True) 588 inp = [torch.ones(3, 3, requires_grad=False)] 589 self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True) 590 591 def test_set__steals_view_chain(self): 592 def f(a, b): 593 a_ = a.mul(2) 594 b_ = b.mul(2) 595 b_slice = b_[1].view(3, 3) 596 # a_clone should inherit the view chain from b_slice 597 a_.set_(b_slice) 598 # Also mutates b_, 599 a_.view(-1).mul_(2) 600 return a_ * b_slice 601 602 inp = [ 603 torch.ones(3, 3, requires_grad=False), 604 torch.zeros(3, 9, requires_grad=False), 605 ] 606 self.verify_aot_autograd(f, inp, keep_inp_mutations=True) 607 608 @skipIfDynamoInput( 609 "Test doesn't make sense with dynamo, which changes order of mutations" 610 ) 611 def test_set__and_data_mutation_good(self): 612 def f(a, b): 613 # The data mutation happens *after* the set_(). 
This is ok (see the graph below) 614 with torch.no_grad(): 615 a.set_(b) 616 b.mul_(2) 617 return a + b 618 619 inp = [ 620 torch.ones(3, 3, requires_grad=True), 621 torch.ones(3, 3, requires_grad=True), 622 ] 623 fw_graph = self.verify_aot_autograd( 624 f, inp, test_mutation=True, keep_inp_mutations=True 625 ) 626 inp = [ 627 torch.ones(3, 3, requires_grad=False), 628 torch.zeros(3, 3, requires_grad=False), 629 ] 630 self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True) 631 # Important things to note: 632 # - "return a.set_(b)" desugars into "return b" 633 # - Both a and b are recorded as experiencing mutations, 634 # which is why we see "b_updated" (output of the mul) twice in the graph outputs. 635 # a is recorded as both a data mutation and a metadata mutation (due to set_ swapping its storage). 636 # - the runtime epilogue for a is "a.set_(mul)" 637 # - the runtime epilogue for b is "b.copy_(mul)" 638 self.assertExpectedInline( 639 fw_graph.code.strip(), 640 """\ 641def forward(self, primals_1, primals_2): 642 mul = torch.ops.aten.mul.Tensor(primals_2, 2) 643 add = torch.ops.aten.add.Tensor(mul, mul) 644 set_ = torch.ops.aten.set_.source_Tensor(primals_1, mul); primals_1 = set_ = None 645 copy_ = torch.ops.aten.copy_.default(primals_2, mul); primals_2 = mul = copy_ = None 646 return (add,)""", 647 ) 648 649 # This is a (hopefully) extremely rare case that is difficult to handle, 650 # so we ban it. 651 # https://github.com/pytorch/pytorch/issues/126236 652 # https://github.com/pytorch/pytorch/pull/126113 653 @xfailIfTorchDynamo 654 def test_set__and_data_mutation_bad(self): 655 def f(a): 656 a_view = a.view(-1) 657 tmp = torch.ones(3, 3, requires_grad=True) 658 # Now, any mutations on either tmp 659 # will be tracked as graph input mutations. 660 with torch.no_grad(): 661 a.set_(tmp) 662 # BAD: a_view is now detached from every graph input, 663 # so we won't recognize that this caused an input mutation! 664 a_view.mul_(2) 665 return a + tmp 666 667 inp = [torch.ones(3, 3, requires_grad=True)] 668 with self.assertRaisesRegex( 669 RuntimeError, "cannot mutate tensors with frozen storage" 670 ): 671 self.verify_aot_autograd( 672 f, inp, test_mutation=True, keep_inp_mutations=True 673 ) 674 675 @skipIfDynamoInput( 676 "Test doesn't make sense with dynamo, which changes order of mutations" 677 ) 678 def test_set__not_allowed(self): 679 def f(a, b): 680 with torch.no_grad(): 681 a.set_(b) 682 # Mutating a will change a's grad_fn, which requires us to replay the mutation outside of the graph. 683 # We currently ban this today, when the input also received a set_() input mutation. 
684 a.mul_(2) 685 return a + b 686 687 inp = [ 688 torch.ones(3, 3, requires_grad=True), 689 torch.ones(3, 3, requires_grad=True), 690 ] 691 with self.assertRaisesRegex( 692 AssertionError, "but the input has other mutations that we cannot" 693 ): 694 fw_graph = self.verify_aot_autograd( 695 f, inp, test_mutation=True, keep_inp_mutations=True 696 ) 697 698 def test_input_mutation_set__nop(self): 699 def f(a): 700 b = torch.arange(9, dtype=a.dtype) 701 a_old = torch.ops.aten.alias.default(a) 702 with torch.no_grad(): 703 a.set_(b) 704 a.set_(a_old) 705 return a + b.reshape(3, 3) 706 707 inp = [torch.ones(3, 3, requires_grad=True)] 708 fw_graph = self.verify_aot_autograd( 709 f, inp, test_mutation=True, keep_inp_mutations=True 710 ) 711 inp = [torch.ones(3, 3, requires_grad=False)] 712 self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True) 713 # Things to note: 714 # - There are no set_() calls in the graph (we functionalize a.set_(b) into "b") 715 # - There is only **1** graph output. We properly realized that the two set_() calls 716 # undo each other, and so effectively no inputs are mutated. 717 self.assertExpectedInline( 718 fw_graph.code.strip(), 719 """\ 720def forward(self, primals_1): 721 arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) 722 alias = torch.ops.aten.alias.default(primals_1); primals_1 = None 723 view = torch.ops.aten.view.default(arange, [3, 3]); arange = None 724 add = torch.ops.aten.add.Tensor(alias, view); alias = view = None 725 return (add,)""", 726 ) 727 728 @unittest.skipIf(IS_WINDOWS, "TODO: need to fix the test case") 729 @unittest.skipIf(IS_MACOS, "TODO: need to fix the test case") 730 def test_input_mutation_fsdp_set__into_same_input(self): 731 import torch.distributed._composable.fsdp._fsdp_param 732 733 def f(a): 734 b = torch.arange(9, dtype=a.dtype).view(3, 3) 735 c = torch.arange(9, dtype=a.dtype).view(3, 3) 736 d = torch.arange(9, dtype=a.dtype).view(3, 3) 737 with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a): 738 torch.ops.fsdp.set_.default(a, b) 739 x = a * a 740 with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a): 741 torch.ops.fsdp.set_.default(a, c) 742 y = a * a 743 with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a): 744 torch.ops.fsdp.set_.default(a, c) 745 z = a * a 746 return x + y + z 747 748 inp = [torch.ones(3, 3, requires_grad=True)] 749 fw_graph = self.verify_aot_autograd( 750 f, inp, test_mutation=True, keep_inp_mutations=True 751 ) 752 inp = [torch.ones(3, 3, requires_grad=False)] 753 self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True) 754 """ 755 Expected behavior: 756 (1) When there are multiple set_() calls on the same graph input primal_X, 757 we want those set_() calls to all show up with primal_X as the first arg in the graph. 758 (2) Behavior (1) is not the case today with normal aten.set_ (blocked on #129892), 759 but using a custom fsdp.set_ op with no returns is a simple workaround to achieve that behavior. 
760 """ 761 self.assertExpectedInline( 762 fw_graph.code.strip(), 763 """\ 764def forward(self, primals_1): 765 arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) 766 view = torch.ops.aten.view.default(arange, [3, 3]); arange = None 767 arange_1 = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) 768 view_1 = torch.ops.aten.view.default(arange_1, [3, 3]); arange_1 = None 769 set_ = torch.ops.fsdp.set_.default(primals_1, view); view = set_ = None 770 mul = torch.ops.aten.mul.Tensor(primals_1, primals_1) 771 set__1 = torch.ops.fsdp.set_.default(primals_1, view_1); set__1 = None 772 mul_1 = torch.ops.aten.mul.Tensor(primals_1, primals_1) 773 set__2 = torch.ops.fsdp.set_.default(primals_1, view_1); view_1 = set__2 = None 774 mul_2 = torch.ops.aten.mul.Tensor(primals_1, primals_1) 775 add = torch.ops.aten.add.Tensor(mul, mul_1); mul = mul_1 = None 776 add_1 = torch.ops.aten.add.Tensor(add, mul_2); add = mul_2 = None 777 return (add_1, primals_1)""", 778 ) 779 self.assertEqual(torch.compile(f, backend="inductor")(*inp), f(*inp)) 780 781 def test_input_mutation_simple_with_none_and_nontensor(self): 782 # Tensor, None, int 783 def f(a, b, c): 784 return a * c 785 786 f_compiled = aot_function(f, nop) 787 for req_grad in [True, False]: 788 inp = [torch.ones(3, 3, requires_grad=req_grad), None, 3] 789 out_ref = f(*inp) 790 out_test = f_compiled(*inp) 791 self.assertEqual(out_ref, out_test) 792 793 # https://github.com/pytorch/pytorch/issues/93363 794 def test_mutates_input_noncontiguous(self): 795 def f(a): 796 a.add_(1) 797 return () 798 799 f_compiled = aot_function(f, nop) 800 ref = torch.ones(4, requires_grad=True) + 0 801 ref_view = ref[0::2] 802 803 test = torch.ones(4, requires_grad=True) + 0 804 test_view = test[0::2] 805 806 out_ref = f(ref_view) 807 out_test = f_compiled(test_view) 808 self.assertEqual(ref, test) 809 810 def test_input_mutation_modifies_autograd_meta_of_aliases(self): 811 def f(a): 812 a.mul_(2) 813 out = a + 1 814 return out.detach() 815 816 x_ref = torch.ones(3, 3, requires_grad=True).clone() 817 x_ref_view = x_ref.view(3, 3) 818 819 x_test = torch.ones(3, 3, requires_grad=True).clone() 820 x_test_view = x_test.view(3, 3) 821 822 f_compiled = aot_function(f, nop, keep_inference_input_mutations=True) 823 f(x_ref) 824 f_compiled(x_test) 825 # f will mutate aliases of the input, including its autograd metadata! 
826 # y.grad_fn is AsStridedBackward 827 self.assertEqual(x_ref_view, x_test_view) 828 self.assertEqual(x_ref_view._version, x_test_view._version) 829 self.assertEqual(x_ref_view.grad_fn.__class__, x_test_view.grad_fn.__class__) 830 # Test the actual gradients are correct 831 (x_ref * x_ref_view).sum().backward() 832 (x_test * x_test_view).sum().backward() 833 self.assertEqual(x_ref.grad, x_test.grad) 834 self.assertEqual(x_ref_view.grad, x_test_view.grad) 835 836 @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470") 837 def test_nested_subclasses(self): 838 @torch.compile(backend="aot_eager") 839 def f(x): 840 return x.sin().cos() 841 842 a = torch.ones(4, requires_grad=True) 843 a2 = a.clone().detach().requires_grad_() 844 a3 = a.clone().detach().requires_grad_() 845 a4 = a.clone().detach().requires_grad_() 846 aa = TwoTensor(a, a2) 847 aa2 = TwoTensor(a3, a4) 848 aaaa = TwoTensor(aa, aa2) 849 out = f(aaaa) 850 self.assertTrue(isinstance(out, TwoTensor)) 851 self.assertTrue(isinstance(out.a, TwoTensor)) 852 self.assertTrue(isinstance(out.b, TwoTensor)) 853 self.assertTrue(isinstance(out.a.a, torch.Tensor)) 854 self.assertTrue(isinstance(out.a.b, torch.Tensor)) 855 self.assertTrue(isinstance(out.b.a, torch.Tensor)) 856 self.assertTrue(isinstance(out.b.b, torch.Tensor)) 857 858 out.sum().backward() 859 self.assertTrue(isinstance(aaaa.grad, TwoTensor)) 860 self.assertTrue(isinstance(aaaa.grad.a, TwoTensor)) 861 self.assertTrue(isinstance(aaaa.grad.b, TwoTensor)) 862 863 @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470") 864 def test_nested_subclasses_non_nested_grad(self): 865 @torch.compile(backend="aot_eager") 866 def f(x): 867 return x.sin().cos() 868 869 a = torch.ones(4, requires_grad=True) 870 a2 = a.clone().detach().requires_grad_() 871 a3 = a.clone().detach().requires_grad_() 872 a4 = a.clone().detach().requires_grad_() 873 new_aa = TwoTensor(a3, a4) 874 aa = TwoTensor(a, a2) 875 876 aa2 = aa.clone().detach().requires_grad_() 877 aaaa = TwoTensor(aa, aa2) 878 out = f(new_aa) 879 new_out = out + aaaa 880 with self.assertRaisesRegex( 881 RuntimeError, 882 "The grad inputs should be same tensor subclass type as forward output", 883 ): 884 new_out.sum().backward() 885 886 @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") 887 @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470") 888 def test_custom_tensor_metadata(self): 889 def f(x): 890 x_elem = x.elem 891 x_elem_elem = x_elem.elem 892 x_elem_metadata = x_elem.constant_attribute 893 return x * x_elem * x_elem_elem * x_elem_metadata 894 895 a = torch.ones(4, requires_grad=True) 896 custom_a = ConstantExtraMetadataTensor(a) 897 custom_a.constant_attribute = 6 898 custom_aa = ConstantExtraMetadataTensor(custom_a) 899 custom_aa.constant_attribute = 4 900 901 custom_aa_compile = custom_aa.clone().detach().requires_grad_() 902 custom_aa_compile.elem.constant_attribute = 6 903 out_eager = f(custom_aa) 904 905 compiled_f = torch.compile(f, backend="aot_eager") 906 out = compiled_f(custom_aa_compile) 907 908 self.assertTrue(torch.allclose(out_eager, out)) 909 910 out.sum().backward() 911 912 self.assertTrue(isinstance(custom_aa_compile.grad, ConstantExtraMetadataTensor)) 913 self.assertTrue( 914 isinstance(custom_aa_compile.grad.elem, ConstantExtraMetadataTensor) 915 ) 916 917 @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470") 918 def test_nested_subclasses_complicated_inps(self): 919 def f(x, y, z): 920 temp = x + y 921 temp_plain = x.a + y.b 922 res 
= temp.sum() + temp_plain.sum() 923 return x.sin().cos() + res 924 925 x = torch.ones(4, requires_grad=True) 926 x2 = x.clone().detach().requires_grad_() 927 xx = TwoTensor(x, x2) 928 xx2 = xx.clone().detach().requires_grad_() 929 930 x_nested = TwoTensor(xx, xx2) 931 x_nested_compile = x_nested.clone().detach().requires_grad_() 932 933 y_nested = x_nested.clone().detach().requires_grad_() 934 y_nested_compile = y_nested.clone().detach().requires_grad_() 935 936 z = x.clone().detach().requires_grad_() 937 z_compile = z.clone().detach().requires_grad_() 938 939 out_eager = f(x_nested, y_nested, z) 940 compiled_f = torch.compile(f, backend="aot_eager") 941 out = compiled_f(x_nested_compile, y_nested_compile, z_compile) 942 self.assertTrue(torch.allclose(out_eager, out)) 943 944 self.assertTrue(isinstance(out, TwoTensor)) 945 self.assertTrue(isinstance(out.a, TwoTensor)) 946 self.assertTrue(isinstance(out.b, TwoTensor)) 947 self.assertTrue(isinstance(out.a.a, torch.Tensor)) 948 self.assertTrue(isinstance(out.a.b, torch.Tensor)) 949 self.assertTrue(isinstance(out.b.a, torch.Tensor)) 950 self.assertTrue(isinstance(out.b.b, torch.Tensor)) 951 952 out.sum().backward() 953 out_eager.sum().backward() 954 955 self.assertTrue(isinstance(x_nested_compile.grad, TwoTensor)) 956 self.assertTrue(isinstance(x_nested_compile.grad.a, TwoTensor)) 957 self.assertTrue(isinstance(x_nested_compile.grad.b, TwoTensor)) 958 959 self.assertTrue(isinstance(y_nested_compile.grad, TwoTensor)) 960 self.assertTrue(isinstance(y_nested_compile.grad.a, TwoTensor)) 961 self.assertTrue(isinstance(y_nested_compile.grad.b, TwoTensor)) 962 963 self.assertTrue(torch.allclose(x_nested_compile.grad.a.a, x_nested.grad.a.a)) 964 self.assertTrue(torch.allclose(x_nested_compile.grad.a.b, x_nested.grad.a.b)) 965 self.assertTrue(torch.allclose(y_nested_compile.grad.a.a, y_nested.grad.a.a)) 966 self.assertTrue(torch.allclose(y_nested_compile.grad.a.b, y_nested.grad.a.b)) 967 968 @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") 969 @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470") 970 def test_nested_subclasses_complicated_inps_mixed(self): 971 def f(x, y): 972 y_elem = y.elem 973 y_elem_elem = y_elem.elem 974 y_elem_metadata = y_elem.constant_attribute 975 return y * y_elem * y_elem_elem * y_elem_metadata + x 976 977 x = torch.ones(4, requires_grad=True) 978 x2 = x.clone().detach().requires_grad_() 979 xx = TwoTensor(x, x2) 980 xx2 = xx.clone().detach().requires_grad_() 981 982 x_nested = TwoTensor(xx, xx2) 983 x_nested_compile = x_nested.clone().detach().requires_grad_() 984 985 a = torch.ones(4, requires_grad=True) 986 custom_a = ConstantExtraMetadataTensor(a) 987 custom_a.constant_attribute = 6 988 custom_aa = ConstantExtraMetadataTensor(custom_a) 989 custom_aa.constant_attribute = 4 990 991 custom_aa_compile = custom_aa.clone().detach().requires_grad_() 992 custom_aa_compile.constant_attribute = 4 993 custom_aa_compile.elem.constant_attribute = 6 994 995 compiled_f = torch.compile(f, backend="aot_eager") 996 out_eager = f(x_nested, custom_aa) 997 out = compiled_f(x_nested_compile, custom_aa_compile) 998 self.assertTrue(torch.allclose(out_eager, out)) 999 1000 out.sum().backward() 1001 out_eager.sum().backward() 1002 1003 self.assertTrue(torch.allclose(x_nested_compile.grad, x_nested.grad)) 1004 self.assertTrue(torch.allclose(custom_aa_compile.grad, custom_aa.grad)) 1005 1006 @skipIfTorchDynamo("This test suite already uses dynamo") 1007 def test_composite_impl_compile(self): 1008 class 
Foo(torch.nn.Module): 1009 def __init__(self) -> None: 1010 super().__init__() 1011 self.linear = torch.nn.Linear(3, 3) 1012 1013 def forward(self, a): 1014 return self.linear(a) 1015 1016 inp = [torch.ones(3, 3, requires_grad=True)] 1017 fw_graph = self.verify_aot_autograd(Foo(), inp, test_mutation=True) 1018 inp = [torch.ones(3, 3, requires_grad=False)] 1019 self.assertExpectedInline( 1020 fw_graph.code.strip(), 1021 """\ 1022def forward(self, primals_1, primals_2, primals_3): 1023 t = torch.ops.aten.t.default(primals_1); primals_1 = None 1024 addmm = torch.ops.aten.addmm.default(primals_2, primals_3, t); primals_2 = None 1025 return (addmm, primals_3, t)""", 1026 ) 1027 1028 with torch.inference_mode(): 1029 fw_graph = self.verify_aot_autograd(Foo(), inp, test_mutation=True) 1030 inp = [torch.ones(3, 3, requires_grad=False)] 1031 self.assertExpectedInline( 1032 fw_graph.code.strip(), 1033 """\ 1034def forward(self, arg0_1, arg1_1, arg2_1): 1035 t = torch.ops.aten.t.default(arg0_1); arg0_1 = None 1036 addmm = torch.ops.aten.addmm.default(arg1_1, arg2_1, t); arg1_1 = arg2_1 = t = None 1037 return (addmm,)""", 1038 ) 1039 1040 def test_outputs_are_aliased(self): 1041 # Tensor, None, int 1042 def f(a): 1043 b = a.mul(2) 1044 c = b.view(-1) 1045 return b, c 1046 1047 f_compiled = aot_function(f, nop) 1048 for req_grad in [True, False]: 1049 inp = torch.ones(3, requires_grad=req_grad) 1050 out_ref = f(inp) 1051 out_test = f_compiled(inp) 1052 self.assertEqual(out_ref[0], out_test[0]) 1053 self.assertEqual(out_ref[1], out_test[1]) 1054 # Try mutating one of the outputs, which is aliased. 1055 out_ref[0].mul_(3) 1056 out_test[0].mul_(3) 1057 # Assert that the aliasing relationship was preserved 1058 self.assertEqual(out_ref[0], out_test[0]) 1059 self.assertEqual(out_ref[1], out_test[1]) 1060 1061 def test_input_mutation_is_output(self): 1062 def f(a): 1063 a.mul_(2) 1064 return a 1065 1066 inp = [torch.ones(3, 3, requires_grad=True)] 1067 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 1068 inp = [torch.ones(3, 3, requires_grad=False)] 1069 self.verify_aot_autograd(f, inp, test_mutation=True) 1070 self.assertExpectedInline( 1071 fw_graph.code.strip(), 1072 """\ 1073def forward(self, primals_1): 1074 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None 1075 mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None 1076 return (mul, mul)""", 1077 ) 1078 1079 def test_input_mutation_multiple(self): 1080 def f(a, b, c): 1081 a.mul_(2) 1082 c.mul_(2) 1083 return a + b + c 1084 1085 def create_inp(req_grad): 1086 return [ 1087 torch.ones(3, 3, requires_grad=req_grad), 1088 torch.ones(3, 3, requires_grad=req_grad), 1089 torch.ones(3, 3, requires_grad=req_grad), 1090 ] 1091 1092 self.verify_aot_autograd(f, create_inp(False), test_mutation=True) 1093 1094 fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) 1095 self.assertExpectedInline( 1096 fw_graph.code.strip(), 1097 """\ 1098def forward(self, primals_1, primals_2, primals_3): 1099 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None 1100 clone_1 = torch.ops.aten.clone.default(primals_3); primals_3 = None 1101 mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None 1102 mul_1 = torch.ops.aten.mul.Tensor(clone_1, 2); clone_1 = None 1103 add = torch.ops.aten.add.Tensor(mul, primals_2); primals_2 = None 1104 add_1 = torch.ops.aten.add.Tensor(add, mul_1); add = None 1105 return (mul, mul_1, add_1)""", 1106 ) 1107 1108 def test_input_mutation_return(self): 1109 def f(a, b): 1110 return 
torch.sin(a, out=b) 1111 1112 inp = [torch.randn(3, 3), torch.ones(3, 3)] 1113 1114 fw_graph = self.verify_aot_autograd( 1115 f, inp, test_mutation=True, keep_inp_mutations=True 1116 ) 1117 self.assertExpectedInline( 1118 fw_graph.code.strip(), 1119 """\ 1120def forward(self, arg0_1, arg1_1): 1121 sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None 1122 copy_ = torch.ops.aten.copy_.default(arg1_1, sin); arg1_1 = sin = None 1123 return (copy_,)""", 1124 ) 1125 1126 def test_input_mutation_metadata(self): 1127 def f(a, b): 1128 a.transpose_(1, 0) 1129 return a + b 1130 1131 def create_inp(req_grad): 1132 return [ 1133 torch.ones(3, 3, requires_grad=req_grad), 1134 torch.ones(3, 3, requires_grad=req_grad), 1135 ] 1136 1137 self.verify_aot_autograd(f, create_inp(True), test_mutation=True) 1138 self.verify_aot_autograd(f, create_inp(False), test_mutation=True) 1139 1140 def test_input_mutation_storage_resize_up(self): 1141 def f(a): 1142 torch.ops.inductor.resize_storage_bytes_(a, 32) 1143 # float32, 4 bytes per element, 32 bytes == 8 elements 1144 with torch.no_grad(): 1145 a.copy_(torch.ones(8)) 1146 return a + 1 1147 1148 inp = torch.zeros(8, requires_grad=True) 1149 # Input starts with zero-size-storage 1150 inp.untyped_storage().resize_(0) 1151 1152 fw_graph_cell = [None] 1153 compiled_f = aot_function( 1154 f, 1155 fw_compiler=make_boxed_compiler( 1156 partial(extract_graph, graph_cell=fw_graph_cell) 1157 ), 1158 bw_compiler=nop, 1159 decompositions={}, 1160 keep_inference_input_mutations=True, 1161 dynamic=False, 1162 ) 1163 out = compiled_f(inp) 1164 # Final functionalized graph has two mutation ops: 1165 # (1) a resize_() to resize input tensor up 1166 # (2) a copy_() to fill in the resized input with valid data 1167 self.assertExpectedInline( 1168 fw_graph_cell[0].code.strip(), 1169 """\ 1170def forward(self, primals_1): 1171 resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 32); resize_storage_bytes_ = None 1172 ones = torch.ops.aten.ones.default([8], device = device(type='cpu'), pin_memory = False) 1173 copy = torch.ops.aten.copy.default(primals_1, ones); ones = None 1174 add = torch.ops.aten.add.Tensor(copy, 1) 1175 copy_ = torch.ops.aten.copy_.default(primals_1, copy); primals_1 = copy = copy_ = None 1176 return (add,)""", 1177 ) 1178 1179 def test_input_mutation_storage_resize_down(self): 1180 def f(a): 1181 out = a.sin() 1182 torch.ops.inductor.resize_storage_bytes_(a, 0) 1183 return out 1184 1185 inp = torch.zeros(8, requires_grad=True) 1186 1187 fw_graph_cell = [None] 1188 compiled_f = aot_function( 1189 f, 1190 fw_compiler=make_boxed_compiler( 1191 partial(extract_graph, graph_cell=fw_graph_cell) 1192 ), 1193 bw_compiler=nop, 1194 decompositions={}, 1195 keep_inference_input_mutations=True, 1196 dynamic=False, 1197 ) 1198 out = compiled_f(inp) 1199 # Final functionalized graph has one mutation ops: 1200 # (1) a resize_() to resize input tensor down 1201 # Even though there was technically a "data mutation" on the input (from a.copy_()), 1202 # We don't include it in the graph since the final input size has zero storage 1203 self.assertExpectedInline( 1204 fw_graph_cell[0].code.strip(), 1205 """\ 1206def forward(self, primals_1): 1207 sin = torch.ops.aten.sin.default(primals_1) 1208 resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 0); resize_storage_bytes_ = None 1209 return (sin, primals_1)""", 1210 ) 1211 1212 # def test_input_mutation_storage_resize_up_down(self): 1213 # def f(a): 1214 # 
torch.ops.inductor.resize_storage_bytes_(a, 32) 1215 # # float32, 4 bytes per element, 32 bytes == 8 elements 1216 # with torch.no_grad(): 1217 # a.copy_(torch.ones(8)) 1218 # out = a.sin() 1219 # torch.ops.inductor.resize_storage_bytes_(a, 0) 1220 # return out 1221 1222 # inp = torch.zeros(8, requires_grad=True) 1223 # # Input starts with zero-size-storage 1224 # inp.untyped_storage().resize_(0) 1225 1226 # fw_graph_cell = [None] 1227 # compiled_f = aot_function( 1228 # f, 1229 # fw_compiler=make_boxed_compiler( 1230 # partial(extract_graph, graph_cell=fw_graph_cell) 1231 # ), 1232 # bw_compiler=nop, 1233 # decompositions={}, 1234 # keep_inference_input_mutations=True, 1235 # dynamic=False, 1236 # ) 1237 # out = compiled_f(inp) 1238 # # Final graph has two interesting properties: 1239 # # (1) no resizes in the functional graph, since the two resizes cancel out 1240 # # and the final size is zero 1241 # # (2) no copy_ in the functional graph, even though we copied data into the input, 1242 # # because the input has no storage at the end of graph execution (so no data to copy) 1243 # self.assertExpectedInline( 1244 # fw_graph_cell[0].code.strip(), 1245 # """\ 1246 # def forward(self, primals_1): 1247 # ones = torch.ops.aten.ones.default([8], device = device(type='cpu'), pin_memory = False) 1248 # copy = torch.ops.aten.copy.default(primals_1, ones); primals_1 = ones = None 1249 # sin = torch.ops.aten.sin.default(copy) 1250 # return [sin, copy]""", 1251 # ) 1252 1253 def test_input_mutation_storage_resize_down_and_set_(self): 1254 # Meant to mimic ppFSDP 1255 class TracableCreateParameter(torch.autograd.Function): 1256 @staticmethod 1257 def forward(ctx, tensor, placeholder): 1258 assert not tensor.requires_grad 1259 return placeholder.set_(tensor) 1260 1261 @staticmethod 1262 def backward(ctx, grad): 1263 return None, grad # grad flows to placeholder 1264 1265 def f(dummy_param, param_shard): 1266 # simulate allgather 1267 with torch.no_grad(): 1268 allgather_param = torch.cat([param_shard, param_shard]) 1269 # simulate propagating grad state through dummy param, using data of allgather param 1270 dummy_param_with_grad_state = TracableCreateParameter.apply( 1271 allgather_param, dummy_param 1272 ) 1273 out = dummy_param.sin() 1274 # Resize out dummy param, which now has the allgather data 1275 torch.ops.inductor.resize_storage_bytes_(dummy_param, 0) 1276 return out 1277 1278 # Simulates the local shard of our param 1279 param_shard = torch.zeros(8, requires_grad=True) 1280 # The dummy, zero-sized allgathered param that autograd will actually compute gradients on 1281 dummy_param = torch.zeros(16, requires_grad=True) 1282 dummy_param.untyped_storage().resize_(0) 1283 1284 fw_graph_cell = [None] 1285 compiled_f = aot_function( 1286 f, 1287 fw_compiler=make_boxed_compiler( 1288 partial(extract_graph, graph_cell=fw_graph_cell) 1289 ), 1290 bw_compiler=nop, 1291 decompositions={}, 1292 keep_inference_input_mutations=True, 1293 dynamic=False, 1294 ) 1295 out = compiled_f(dummy_param, param_shard) 1296 # Important stuff to point out: 1297 # (1) We save cat for backward (input to the sin()). 1298 # While the original code was dummy_param.sin(), 1299 # dummy_param actually contains the `cat` tensor due to the set_() call 1300 # (2) We emit a cat.resize_storage_(0) in the graph. 
1301 # After the set_(), cat is the actually data of dummy_param, which is what we call resize_() on 1302 self.assertExpectedInline( 1303 fw_graph_cell[0].code.strip(), 1304 """\ 1305def forward(self, primals_1, primals_2): 1306 cat = torch.ops.aten.cat.default([primals_2, primals_2]); primals_2 = None 1307 sin = torch.ops.aten.sin.default(cat) 1308 resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(cat, 0); resize_storage_bytes_ = None 1309 set_ = torch.ops.aten.set_.source_Tensor(primals_1, cat); primals_1 = set_ = None 1310 return (sin, cat)""", 1311 ) 1312 1313 def test_input_mutation_storage_resize_before_set_(self): 1314 def f(a): 1315 with torch.no_grad(): 1316 torch.ops.inductor.resize_storage_bytes_(a, 0) 1317 a.set_(torch.ones(2)) 1318 1319 inp = torch.zeros(8, requires_grad=True) 1320 1321 compiled_f = aot_function( 1322 f, 1323 fw_compiler=nop, 1324 bw_compiler=nop, 1325 decompositions={}, 1326 keep_inference_input_mutations=True, 1327 dynamic=False, 1328 ) 1329 out = compiled_f(inp) 1330 1331 # def test_input_mutation_storage_resize_not_supported(self): 1332 # def f(a): 1333 # a.mul_(2) 1334 # torch.ops.inductor.resize_storage_bytes_(a, 0) 1335 # return a 1336 1337 # inp = torch.zeros(8, requires_grad=True) 1338 1339 # with self.assertRaisesRegex( 1340 # AssertionError, "the input has other mutations that we cannot" 1341 # ): 1342 # compiled_f = aot_function( 1343 # f, 1344 # fw_compiler=nop, 1345 # bw_compiler=nop, 1346 # decompositions={}, 1347 # keep_inference_input_mutations=True, 1348 # dynamic=False, 1349 # ) 1350 # out = compiled_f(inp) 1351 1352 def test_input_output_aliase_custom_autograd_function(self): 1353 class Foo(torch.autograd.Function): 1354 @staticmethod 1355 def forward(ctx, x): 1356 return x 1357 1358 @staticmethod 1359 def backward(ctx, gx): 1360 return gx * 0.5 1361 1362 def f(x): 1363 return Foo.apply(x) 1364 1365 inp = [torch.ones(2, 2, requires_grad=True)] 1366 self.verify_aot_autograd(f, inp, test_mutation=False) 1367 1368 def test_input_mutation_requires_grad_detach(self): 1369 # Here, "a" requires grad, and gets mutated, so we append a copy_() to the end of the graph. 1370 # Its mutation doesn't take part in autograd though, because we mutated a detach'd view. 1371 # Need to make sure that this copy_() doesn't error, and doesn't participate in autograd either. 1372 def f(a): 1373 a.detach().mul_(2) 1374 return a + 3 1375 1376 inp = [torch.ones(4, requires_grad=True)] 1377 self.verify_aot_autograd(f, inp, test_mutation=False) 1378 inp = [torch.ones(4, requires_grad=True)] 1379 # test_mutation=True will first do some compute on inp, so it is no longer an autograd leaf 1380 # by the time it becomes a graph input. Good to test both cases. 
1381 self.verify_aot_autograd(f, inp, test_mutation=True) 1382 1383 def test_input_mutation_hidden_from_autograd_aliasing(self): 1384 def f(a): 1385 a_alias = a.view(-1) 1386 with torch.no_grad(): 1387 a_alias.mul_(2) 1388 return a + 1 1389 1390 inp = [torch.ones(4, requires_grad=True)] 1391 # The important bit: we detected that the input mutation is safe 1392 # to include **inside** the graph, since it was under no_grad 1393 # (so all we need to do is use mark_dirty() on the input to bump the VC) 1394 fw_graph = self.verify_aot_autograd( 1395 f, inp, test_mutation=True, keep_inp_mutations=True 1396 ) 1397 self.assertExpectedInline( 1398 fw_graph.code.strip(), 1399 """\ 1400def forward(self, primals_1): 1401 view = torch.ops.aten.view.default(primals_1, [-1]) 1402 mul = torch.ops.aten.mul.Tensor(view, 2); view = None 1403 view_1 = torch.ops.aten.view.default(mul, [4]); mul = None 1404 add = torch.ops.aten.add.Tensor(view_1, 1) 1405 copy_ = torch.ops.aten.copy_.default(primals_1, view_1); primals_1 = view_1 = copy_ = None 1406 return (add,)""", 1407 ) 1408 1409 def test_input_mutation_requires_grad_no_grad(self): 1410 def f(a): 1411 with torch.no_grad(): 1412 a.mul_(2) 1413 return a + 3 1414 1415 inp = [torch.ones(4, requires_grad=True)] 1416 fw_graph = self.verify_aot_autograd( 1417 f, inp, test_mutation=True, keep_inp_mutations=True 1418 ) 1419 # Even though the input requires_grad, we expect the keep the input mutation in the graph 1420 # (Even though this is a training graph!) 1421 self.assertExpectedInline( 1422 fw_graph.code.strip(), 1423 """\ 1424def forward(self, primals_1): 1425 mul = torch.ops.aten.mul.Tensor(primals_1, 2) 1426 add = torch.ops.aten.add.Tensor(mul, 3) 1427 copy_ = torch.ops.aten.copy_.default(primals_1, mul); primals_1 = mul = copy_ = None 1428 return (add,)""", 1429 ) 1430 1431 def test_input_mutation_requires_grad_no_grad_inference_graph(self): 1432 def f(a): 1433 with torch.no_grad(): 1434 a.mul_(2) 1435 return a + 3 1436 1437 inp = [torch.ones(4, requires_grad=True)] 1438 # Even though the input requires_grad, we expect the keep the input mutation in the graph 1439 fw_graph = self.verify_aot_autograd( 1440 f, inp, test_mutation=True, keep_inp_mutations=True 1441 ) 1442 1443 self.assertExpectedInline( 1444 fw_graph.code.strip(), 1445 """\ 1446def forward(self, arg0_1): 1447 mul = torch.ops.aten.mul.Tensor(arg0_1, 2) 1448 add = torch.ops.aten.add.Tensor(mul, 3) 1449 copy_ = torch.ops.aten.copy_.default(arg0_1, mul); arg0_1 = mul = copy_ = None 1450 return (add,)""", 1451 ) 1452 1453 def test_input_mutation_requires_grad_no_grad_detach_mixed(self): 1454 # Perform a mix of mutations on a: 1455 # 1 normal, 1 in no_grad, 1 on a detach'd tensor. 1456 # Only the first should participate in gradient computation. 
1457 def f(a): 1458 a.detach().mul_(2) 1459 a.mul_(3) 1460 with torch.no_grad(): 1461 a.mul_(4) 1462 return a + 5 1463 1464 inp = [torch.ones(4, requires_grad=True)] 1465 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 1466 1467 def test_input_mutation_metadata2(self): 1468 def f(a): 1469 a.transpose_(1, 0) 1470 a.mul_(2) 1471 return a + 1 1472 1473 inp = [torch.ones(3, 3, requires_grad=True)] 1474 self.verify_aot_autograd(f, inp, test_mutation=True) 1475 inp = [torch.ones(3, 3, requires_grad=False)] 1476 self.verify_aot_autograd(f, inp, test_mutation=True) 1477 1478 def test_input_mutation_batchnorm(self): 1479 def f(inpt, weight, bias, running_mean, running_var): 1480 # This is additionally a good test, because the input tensors that we mutate 1481 # are *also* saved for backwards. 1482 # This tests that what we save for the backward is actually cloned inputs, 1483 # and not the original inputs that got mutated. 1484 return torch._native_batch_norm_legit( 1485 inpt, weight, bias, running_mean, running_var, True, 0.5, 1e-5 1486 ) 1487 1488 def create_inp(req_grad): 1489 return [ 1490 torch.ones(2, 5, 5, 5, requires_grad=req_grad), 1491 torch.ones(5, requires_grad=req_grad), 1492 torch.ones(5, requires_grad=req_grad), 1493 torch.ones(5), 1494 torch.ones(5), 1495 ] 1496 1497 from torch._decomp import get_decompositions 1498 1499 # This simulates what inductor does (running the fw + bw decompositions) 1500 decompositions = get_decompositions( 1501 [ 1502 torch.ops.aten._native_batch_norm_legit_functional, 1503 torch.ops.aten.native_batch_norm_backward, 1504 ] 1505 ) 1506 self.verify_aot_autograd( 1507 f, create_inp(True), test_mutation=True, decompositions=decompositions 1508 ) 1509 self.verify_aot_autograd( 1510 f, create_inp(False), test_mutation=True, decompositions=decompositions 1511 ) 1512 1513 def test_batchnorm_inference(self): 1514 inp = [ 1515 torch.ones(2, 5, 5, 5, requires_grad=True), 1516 torch.ones(5, requires_grad=True), 1517 torch.ones(5, requires_grad=True), 1518 torch.ones(5), 1519 torch.ones(5), 1520 ] 1521 1522 m = torch.nn.BatchNorm2d(4, 4) 1523 m.eval() 1524 fw_graph_cell = [None] 1525 inp = torch.ones(4, 4, 4, 4) 1526 fw_graph_cell = [None] 1527 compiled_m = aot_module( 1528 m, 1529 fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), 1530 bw_compiler=nop, 1531 keep_inference_input_mutations=True, 1532 ) 1533 inp = torch.ones(4, 4, 4, 4) 1534 with torch.no_grad(): 1535 out = compiled_m(inp) 1536 # expectation: there are no copy_() calls in the decomposed batch norm when running under training=False (eval mode) 1537 code = fw_graph_cell[0].code.strip() 1538 self.assertTrue("copy_" not in str(code)) 1539 1540 def test_input_output_view_simple(self): 1541 def f(a): 1542 return a.view(-1) 1543 1544 inp = [torch.ones(2, 2, requires_grad=False).add(1)] 1545 self.verify_aot_autograd(f, inp, test_mutation=True) 1546 inp = [torch.ones(2, 2, requires_grad=True).add(1)] 1547 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 1548 # Outputs that alias inputs are pulled out of the graph entirely, so we don't compile anything here 1549 self.assertExpectedInline( 1550 fw_graph.code.strip(), 1551 """\ 1552def forward(self, arg0_1): 1553 view = torch.ops.aten.view.default(arg0_1, [-1]); arg0_1 = None 1554 return (view,)""", 1555 ) 1556 1557 def test_input_output_view_mutate_multiple(self): 1558 def f(a, b, c): 1559 a.mul_(2) 1560 c.mul_(3) 1561 return b.view(2, 2), c.view(2, 2) 1562 1563 def create_inp(req_grad): 1564 return [ 1565 torch.ones(2, 2, 
requires_grad=req_grad).add(1), 1566 torch.ones(2, 2, requires_grad=req_grad).add(1), 1567 torch.ones(2, 2, requires_grad=req_grad).add(1), 1568 ] 1569 1570 self.verify_aot_autograd(f, create_inp(False), test_mutation=True) 1571 fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) 1572 # The original function returned two outputs, both of which aliased inputs. 1573 # We expect two outputs in the functional graph, a_updated and c_updated. 1574 # The actual aliased outputs themselves aren't in the compiled forward graph; 1575 # Instead, they're generated outside of the graph. 1576 self.assertExpectedInline( 1577 fw_graph.code.strip(), 1578 """\ 1579def forward(self, primals_1, primals_2, primals_3): 1580 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None 1581 clone_1 = torch.ops.aten.clone.default(primals_3); primals_3 = None 1582 mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None 1583 mul_1 = torch.ops.aten.mul.Tensor(clone_1, 3); clone_1 = None 1584 view = torch.ops.aten.view.default(primals_2, [2, 2]); primals_2 = None 1585 view_2 = torch.ops.aten.view.default(mul_1, [2, 2]) 1586 return (mul, mul_1, view, view_2)""", 1587 ) 1588 1589 def test_input_output_view_metadata_mutate_multiple(self): 1590 def f(a, b, c): 1591 b.mul_(3) 1592 c.t_() 1593 return a.view(2, 2), b.view(2, 2), c.view(2, 2) 1594 1595 def create_inp(req_grad): 1596 return [ 1597 torch.ones(2, 2, requires_grad=req_grad).add(1), 1598 torch.ones(2, 2, requires_grad=req_grad).add(1), 1599 torch.ones(2, 2, requires_grad=req_grad).add(1), 1600 ] 1601 1602 self.verify_aot_autograd(f, create_inp(False), test_mutation=True) 1603 fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) 1604 # Important thing to check here: of the three inputs: 1605 # Only the b.mul_(3) should show up in the graph (we functionalize it and return it). 
1606 # Everything else that does not show up in the graph includes: 1607 # - The metadata mutation on c (we do it outside the graph) 1608 # - All 3 original fw outputs, which are aliases of inputs (we regenerate them outside of the graph) 1609 self.assertExpectedInline( 1610 fw_graph.code.strip(), 1611 """\ 1612def forward(self, primals_1, primals_2, primals_3): 1613 clone = torch.ops.aten.clone.default(primals_2); primals_2 = None 1614 view = torch.ops.aten.view.default(primals_3, [2, 2]); primals_3 = None 1615 mul = torch.ops.aten.mul.Tensor(clone, 3); clone = None 1616 t = torch.ops.aten.t.default(view); view = None 1617 view_1 = torch.ops.aten.view.default(primals_1, [2, 2]); primals_1 = None 1618 view_3 = torch.ops.aten.view.default(t, [2, 2]) 1619 view_4 = torch.ops.aten.view.default(mul, [2, 2]) 1620 return (mul, t, view_1, view_4, view_3)""", 1621 ) 1622 1623 def test_input_mutation_and_output_view(self): 1624 def f(a): 1625 a.add_(1) 1626 return a.view(-1) 1627 1628 inp = [torch.ones(2, 2, requires_grad=False).add(1)] 1629 self.verify_aot_autograd(f, inp, test_mutation=True) 1630 inp = [torch.ones(2, 2, requires_grad=True).add(1)] 1631 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 1632 # Here, total # of outputs is 1 because: 1633 # - num_mutated_inps = 1 (a_updated) 1634 # - num_fw_outputs = 0 (the output is an alias of the input, so we move it outside the compiled fw) 1635 self.assertExpectedInline( 1636 fw_graph.code.strip(), 1637 """\ 1638def forward(self, primals_1): 1639 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None 1640 add = torch.ops.aten.add.Tensor(clone, 1); clone = None 1641 view_1 = torch.ops.aten.view.default(add, [-1]) 1642 return (add, view_1)""", 1643 ) 1644 1645 def test_input_mutation_output_view_multiple(self): 1646 def f(a, b, c, d): 1647 b.transpose_(1, 0) 1648 c.add_(1) 1649 return d + 1, b.diagonal(), a + c 1650 1651 def create_inp(req_grad): 1652 return [ 1653 torch.arange(4, requires_grad=req_grad, dtype=torch.float32) 1654 .view(2, 2) 1655 .add(1), 1656 torch.arange(4, requires_grad=req_grad, dtype=torch.float32) 1657 .view(2, 2) 1658 .add(1), 1659 torch.ones(2, 2, requires_grad=req_grad).add(1), 1660 torch.ones(2, 2, requires_grad=req_grad).add(1), 1661 ] 1662 1663 self.verify_aot_autograd(f, create_inp(False), test_mutation=True) 1664 fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) 1665 self.assertExpectedInline( 1666 fw_graph.code.strip(), 1667 """\ 1668def forward(self, primals_1, primals_2, primals_3, primals_4): 1669 view = torch.ops.aten.view.default(primals_2, [2, 2]); primals_2 = None 1670 clone = torch.ops.aten.clone.default(primals_3); primals_3 = None 1671 transpose = torch.ops.aten.transpose.int(view, 1, 0); view = None 1672 add = torch.ops.aten.add.Tensor(clone, 1); clone = None 1673 add_1 = torch.ops.aten.add.Tensor(primals_4, 1); primals_4 = None 1674 diagonal = torch.ops.aten.diagonal.default(transpose) 1675 add_2 = torch.ops.aten.add.Tensor(primals_1, add); primals_1 = None 1676 return (transpose, add, add_1, diagonal, add_2)""", 1677 ) 1678 1679 def test_output_aliases_intermediate_single(self): 1680 def f(a): 1681 out = torch.mul(a, 3) 1682 return out.view(-1) 1683 1684 inp = [torch.ones(3, 3, requires_grad=False)] 1685 self.verify_aot_autograd(f, inp, test_mutation=True) 1686 inp = [torch.ones(3, 3, requires_grad=True)] 1687 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 1688 # In AOTAutograd, we are obligated to make the compiled forward directly 
return `out`, 1689 # and reconstruct `out.view(-1)` as a fresh output. 1690 self.assertExpectedInline( 1691 fw_graph.code.strip(), 1692 """\ 1693def forward(self, primals_1): 1694 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None 1695 view = torch.ops.aten.view.default(mul, [-1]); mul = None 1696 return (view,)""", 1697 ) 1698 1699 def test_output_aliases_input_multi_output_view_should_raise_autograd_error(self): 1700 def f1(a): 1701 return list(a.unbind(0)) 1702 1703 f1_compiled = aot_function(f1, nop) 1704 1705 inp1 = torch.ones(3, 3, requires_grad=True).clone() 1706 inp2 = torch.ones(3, 3, requires_grad=True).clone() 1707 inp3 = torch.ones(3, 3, requires_grad=True).clone() 1708 1709 with self.assertRaisesRegex( 1710 RuntimeError, "Such functions do not allow the output views" 1711 ): 1712 out_test1 = f1_compiled(inp1) 1713 # This raises a runtime error from autograd in eager mode 1714 out_test1[0].mul_(2) 1715 1716 with self.assertRaisesRegex( 1717 RuntimeError, "Such functions do not allow the output views" 1718 ): 1719 out_test2 = f1_compiled(inp2) 1720 inp2.mul_(2) 1721 # In eager mode, if we mutate a tensor, any multi-output-view aliases 1722 # get their grad_fn replaced with error nodes, so accessing grad_fn should error 1723 grad_fn = out_test2[0].grad_fn 1724 1725 with self.assertRaisesRegex( 1726 RuntimeError, "Such functions do not allow the output views" 1727 ): 1728 out_test3 = f1_compiled(inp3) 1729 out_test1[0].detach().mul_(2) 1730 # The above case also applies to detached aliases (they turn the multi-output-view 1731 # alias's grad_fns into error nodes) 1732 grad_fn = out_test2[0].grad_fn 1733 1734 def test_output_aliases_input_multi_output_view(self): 1735 # All aliased outs are from multi-output views, so AOTAutograd will hide the aliasing from autograd. 1736 def f1(a): 1737 return list(a.unbind(0)) 1738 1739 inp = torch.ones(3, 3, requires_grad=True) 1740 inp_ref = torch.ones(3, 3, requires_grad=True) 1741 f1_compiled = aot_function(f1, nop) 1742 1743 out_ref = f1(inp_ref) 1744 out_test = f1_compiled(inp) 1745 # Assert that we get CompiledFunctionBackward in the backward graph, 1746 # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case. 1747 # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] 1748 self.assertTrue( 1749 all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test) 1750 ) 1751 1752 sum(out_ref).sum().backward() 1753 sum(out_test).sum().backward() 1754 self.assertEqual(inp_ref.grad, inp.grad) 1755 1756 # Several of the outputs are from multi-output views. 1757 # However: they are part of the same alias set as "a", and "a.view(out.shape)", 1758 # which are both user-visible. 1759 # AOTAutograd will not try to be smart here and hide the aliasing relationships from autograd. 1760 # Instead, it will perform its "output aliases input" logic, and regenerate all aliases. 1761 def f3(a): 1762 return *list(a.unbind(0)), a.view(a.shape) 1763 1764 inp = torch.ones(3, 3, requires_grad=True) 1765 inp_ref = torch.ones(3, 3, requires_grad=True) 1766 f3_compiled = aot_function(f3, nop) 1767 1768 inp_ref_clone = inp_ref.clone() 1769 inp_clone = inp.clone() 1770 out_ref = f3(inp_ref_clone) 1771 out_test = f3_compiled(inp_clone) 1772 self.assertTrue(all("UnbindBackward" in str(o.grad_fn) for o in out_test[:3])) 1773 1774 # The last output is not from a multi-output view, so autograd will let us mutate it. 
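# Illustrative sketch of the eager-mode contrast relied on above (kept entirely in comments
# so it does not affect the test; tensor names are hypothetical, and the error text matches
# the regex asserted above):
#     base = torch.ones(3, 3, requires_grad=True).clone()
#     pieces = base.unbind(0)          # multi-output view call
#     plain = base.view(base.shape)    # ordinary, single-output view
#     pieces[0].mul_(2)                # raises RuntimeError: "... Such functions do not
#                                      # allow the output views to be modified inplace ..."
#     plain.mul_(2)                    # fine: autograd allows mutating a normal view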
1775 out_ref[-1].mul_(2) 1776 out_test[-1].mul_(2) 1777 # Also mutate the input, which should affect the aliased output. 1778 inp_ref_clone.view(-1).mul_(3) 1779 inp_clone.view(-1).mul_(3) 1780 # Do backward 1781 (inp_ref + out_ref[-1]).sum().backward() 1782 (inp + out_test[-1]).sum().backward() 1783 self.assertEqual(inp_ref.grad, inp.grad) 1784 1785 def test_output_aliases_intermediate_multi_output_view(self): 1786 # All aliased outs are from multi-output views, so AOTAutograd will hide the aliasing from autograd. 1787 def f1(a): 1788 out = torch.mul(a, 3) 1789 return list(out.unbind(0)) 1790 1791 inp = torch.ones(3, 3, requires_grad=True) 1792 inp_ref = torch.ones(3, 3, requires_grad=True) 1793 f1_compiled = aot_function(f1, nop) 1794 1795 out_ref = f1(inp_ref) 1796 out_test = f1_compiled(inp) 1797 # Assert that we get CompiledFunctionBackward in the backward graph, 1798 # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case. 1799 # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] 1800 self.assertTrue( 1801 all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test) 1802 ) 1803 1804 sum(out_ref).sum().backward() 1805 sum(out_test).sum().backward() 1806 self.assertEqual(inp_ref.grad, inp.grad) 1807 1808 # All aliased outs but one are from multi-output views, so AOTAutograd will hide the aliasing from autograd. 1809 def f2(a): 1810 out = torch.mul(a, 3) 1811 return *list(out.unbind(0)), out 1812 1813 inp = torch.ones(3, 3, requires_grad=True) 1814 inp_ref = torch.ones(3, 3, requires_grad=True) 1815 f2_compiled = aot_function(f2, nop) 1816 1817 out_ref = f2(inp_ref) 1818 out_test = f2_compiled(inp) 1819 # Assert that we get CompiledFunctionBackward in the backward graph, 1820 # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case. 1821 # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] 1822 self.assertTrue( 1823 all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test) 1824 ) 1825 1826 # The last output is not from a multi-output view, so autograd will let us mutate it. 1827 out_ref[-1].mul_(2) 1828 out_test[-1].mul_(2) 1829 out_ref[-1].sum().backward() 1830 out_test[-1].sum().backward() 1831 self.assertEqual(inp_ref.grad, inp.grad) 1832 1833 # All aliased outs but one are from multi-output views, so AOTAutograd will hide the aliasing from autograd. 1834 def f3(a): 1835 out = torch.mul(a, 3) 1836 return *list(out.unbind(0)), out.view(out.shape) 1837 1838 inp = torch.ones(3, 3, requires_grad=True) 1839 inp_ref = torch.ones(3, 3, requires_grad=True) 1840 f3_compiled = aot_function(f3, nop) 1841 1842 out_ref = f3(inp_ref) 1843 out_test = f3_compiled(inp) 1844 # Assert that we get CompiledFunctionBackward in the backward graph, 1845 # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case. 1846 # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] 1847 self.assertTrue( 1848 all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test) 1849 ) 1850 1851 # The last output is not from a multi-output view, so autograd will let us mutate it. 1852 out_ref[-1].mul_(2) 1853 out_test[-1].mul_(2) 1854 out_ref[-1].sum().backward() 1855 out_test[-1].sum().backward() 1856 self.assertEqual(inp_ref.grad, inp.grad) 1857 1858 # There are 5 outputs that all alias each other. 
1859 # 3 of them come from multi-output views, but the other 2 are "ordinary" aliases.
1860 # Therefore, AOTAutograd will not attempt the multi-output-view optimization,
1861 # and will instead apply the intermediate_base logic to all aliases.
1862 # (In theory we could probably get AOTAutograd to only apply the intermediate base
1863 # logic to the last 2 outputs and not the first 3. We should probably
1864 # just do the graph partitioning defined in this doc instead though).
1865 # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit
1866 def f4(a):
1867 out = torch.mul(a, 3)
1868 # also return the graph intermediate directly,
1869 # which will force AOTAutograd to do the "intermediate base" logic.
1870 # (Why? The user can mutate "out", which should change the autograd metadata
1871 # of the other aliased outputs)
1872 return *list(out.unbind(0)), out, out.view(out.shape)
1873
1874 inp = torch.ones(3, 3, requires_grad=True)
1875 inp_ref = torch.ones(3, 3, requires_grad=True)
1876 f4_compiled = aot_function(f4, nop)
1877
1878 out_ref = f4(inp_ref)
1879 out_test = f4_compiled(inp)
1880 # Mutate the last output of f4 (autograd will allow this, since it is not a multi-output view,
1881 # as long as *only* the non-multi-output views participate in the backward)
1882 # Note: We could probably try to hide **only** the multi-output views from autograd here
1883 # and only do the intermediate base logic for the last two aliases.
1884 # Longer term solution of graph partitioning is probably cleaner though (see the note).
1885 out_ref[-1].mul_(2)
1886 out_test[-1].mul_(2)
1887
1888 out_ref_sum = out_ref[-1] + out_ref[-2]
1889 out_test_sum = out_test[-1] + out_test[-2]
1890 out_ref_sum.sum().backward()
1891 out_test_sum.sum().backward()
1892 self.assertEqual(inp_ref.grad, inp.grad)
1893
1894 def test_output_aliases_intermediate_mutation_linear(self):
1895 def f(x):
1896 return (x + 1).view(-1)
1897
1898 inp = [torch.ones(3, 3, requires_grad=True)]
1899 # use inductor's decomps (which will e.g. turn _unsafe_view() into view())
1900 from torch._inductor.decomposition import decompositions
1901
1902 f_compiled = aot_function(f, nop, decompositions=decompositions)
1903
1904 out_ref = f(*inp)
1905 out_test = f_compiled(*inp)
1906
1907 out_ref.mul_(2)
1908 out_test.mul_(2)
1909 self.assertEqual(out_ref, out_test)
1910
1911 def test_output_aliases_intermediate_no_grad(self):
1912 def f(a, b):
1913 out = torch.mul(a, 3)
1914 # First output is an alias of an intermediate that doesn't require grad
1915 return out.view(-1), b.add(1)
1916
1917 inp = [torch.ones(3, 3), torch.ones(3, 3, requires_grad=False)]
1918 self.verify_aot_autograd(f, inp, test_mutation=True)
1919 inp = [torch.ones(3, 3), torch.ones(3, 3, requires_grad=True)]
1920 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
1921 # important bit: we don't bother generating an intermediate base as an output in the graph,
1922 # because the intermediate base itself didn't require gradients.
1923 # (the only problematic case is when both the base and the aliased output require gradients).
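# Illustrative sketch (comment only, hypothetical function) of the "problematic case"
# mentioned above, where the intermediate base *does* require grad:
#     def g(a):                      # a.requires_grad == True
#         base = torch.mul(a, 3)     # intermediate that requires grad
#         return base.view(-1), base.view(-1)
# For g, the compiled forward also has to return the intermediate base itself (see
# test_output_aliases_intermediate_multiple below, whose expected graph returns `mul`),
# so that the runtime epilogue can regenerate both user-visible views off of one common
# base. Here the intermediate doesn't require grad, so no extra graph output is needed.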
1924 self.assertExpectedInline( 1925 fw_graph.code.strip(), 1926 """\ 1927def forward(self, primals_1, primals_2): 1928 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None 1929 view = torch.ops.aten.view.default(mul, [-1]); mul = None 1930 add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None 1931 return (view, add)""", 1932 ) 1933 1934 def test_output_aliases_intermediate_returned_multiple_times(self): 1935 def f(a): 1936 out = torch.mul(a, 3) 1937 out_view = out.view(-1) 1938 return out, out_view, out 1939 1940 inp = [torch.ones(3, 3, requires_grad=False)] 1941 self.verify_aot_autograd(f, inp, test_mutation=True) 1942 inp = [torch.ones(3, 3, requires_grad=True)] 1943 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 1944 1945 def test_output_aliases_intermediate_multiple(self): 1946 def f(a): 1947 out = torch.mul(a, 3) 1948 # AOTAutograd should manually generate these two output views in the epilogue. 1949 return out.view(-1), out.view(-1) 1950 1951 inp = [torch.ones(3, 3, requires_grad=False)] 1952 self.verify_aot_autograd(f, inp, test_mutation=True) 1953 inp = [torch.ones(3, 3, requires_grad=True)] 1954 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 1955 self.assertExpectedInline( 1956 fw_graph.code.strip(), 1957 """\ 1958def forward(self, primals_1): 1959 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None 1960 view = torch.ops.aten.view.default(mul, [-1]) 1961 view_1 = torch.ops.aten.view.default(mul, [-1]) 1962 return (view, view_1, mul)""", 1963 ) 1964 1965 def test_output_aliases_intermediate_and_returned(self): 1966 def f(a): 1967 out = torch.mul(a, 3) 1968 # AOTAutograd should manually generate the first output (a view of an intermediate) 1969 # but not the second (which is itself the intermediate for the first) 1970 return out.view(-1), out 1971 1972 inp = [torch.ones(3, 3, requires_grad=False)] 1973 self.verify_aot_autograd(f, inp, test_mutation=True) 1974 inp = [torch.ones(3, 3, requires_grad=True)] 1975 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 1976 self.assertExpectedInline( 1977 fw_graph.code.strip(), 1978 """\ 1979def forward(self, primals_1): 1980 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None 1981 view = torch.ops.aten.view.default(mul, [-1]) 1982 return (view, mul)""", 1983 ) 1984 1985 def test_output_aliases_intermediate_and_returned_flipped(self): 1986 def f(a): 1987 out = torch.mul(a, 3) 1988 # AOTAutograd should manually generate the first output (a view of an intermediate) 1989 # but not the second (which is itself the intermediate for the first) 1990 return out, out.view(-1) 1991 1992 inp = [torch.ones(3, 3, requires_grad=False)] 1993 self.verify_aot_autograd(f, inp, test_mutation=True) 1994 inp = [torch.ones(3, 3, requires_grad=True)] 1995 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 1996 self.assertExpectedInline( 1997 fw_graph.code.strip(), 1998 """\ 1999def forward(self, primals_1): 2000 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None 2001 view = torch.ops.aten.view.default(mul, [-1]) 2002 return (mul, view)""", 2003 ) 2004 2005 def test_output_aliases_intermediate_and_returned_different_grad(self): 2006 def f(a): 2007 out = torch.mul(a, 3) 2008 # AOTAutograd should manually generate the first output (a view of an intermediate) 2009 # but not the second (which is itself the intermediate for the first) 2010 return out.view(-1), out, out[0].detach() 2011 2012 inp = [torch.ones(3, 3, requires_grad=False)] 2013 
self.verify_aot_autograd(f, inp, test_mutation=True) 2014 inp = [torch.ones(3, 3, requires_grad=True)] 2015 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 2016 self.assertExpectedInline( 2017 fw_graph.code.strip(), 2018 """\ 2019def forward(self, primals_1): 2020 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None 2021 view = torch.ops.aten.view.default(mul, [-1]) 2022 select = torch.ops.aten.select.int(mul, 0, 0) 2023 detach = torch.ops.aten.detach.default(select); select = None 2024 detach_1 = torch.ops.aten.detach.default(detach); detach = None 2025 detach_2 = torch.ops.aten.detach.default(detach_1); detach_1 = None 2026 return (view, mul, detach_2)""", 2027 ) 2028 2029 def test_output_aliases_intermediate_inplace_view(self): 2030 def f(a): 2031 out = torch.mul(a, 3) 2032 out.t_() 2033 return out 2034 2035 inp = [torch.ones(2, 4, requires_grad=True)] 2036 2037 # TODO: fix this test. 2038 # See https://github.com/pytorch/pytorch/issues/90507 2039 # self.verify_aot_autograd(f, inp, test_mutation=True) 2040 2041 def test_output_aliases_intermediate_inplace_view_with_detach(self): 2042 def f(a): 2043 out = torch.mul(a, 3) 2044 out.t_() 2045 out.detach_() 2046 # Thanks to the detach_() AOT Autograd doesn't need to do anything. 2047 # `out` will show up as having OutputType.non_alias, 2048 # and ._is_view() == False 2049 return out, a + 1 2050 2051 inp = [torch.ones(2, 4, requires_grad=False)] 2052 self.verify_aot_autograd(f, inp, test_mutation=True) 2053 inp = [torch.ones(2, 4, requires_grad=True)] 2054 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 2055 self.assertExpectedInline( 2056 fw_graph.code.strip(), 2057 """\ 2058def forward(self, primals_1): 2059 mul = torch.ops.aten.mul.Tensor(primals_1, 3) 2060 t = torch.ops.aten.t.default(mul); mul = None 2061 add = torch.ops.aten.add.Tensor(primals_1, 1); primals_1 = None 2062 return (t, add)""", 2063 ) 2064 2065 def test_output_aliases_intermediate_inplace_view_and_view(self): 2066 def f(a): 2067 out = torch.mul(a, 3) 2068 out_view = out.unsqueeze(0) 2069 out.t_() 2070 out_view2 = out.unsqueeze(0) 2071 return out_view, out, out_view2 2072 2073 inp = [torch.ones(2, 4, requires_grad=True)] 2074 2075 # TODO: fix this test. 2076 # See <github issue link> 2077 # self.verify_aot_autograd(f, inp, test_mutation=True) 2078 2079 def test_output_aliases_intermediate_multiple_mixed(self): 2080 def f(a): 2081 out1 = torch.mul(a, 3) 2082 out2 = torch.mul(a, 4) 2083 # AOTAutograd should manually generate these two output views in the epilogue. 
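# Rough sketch (comment only) of what "generate ... in the epilogue" means at runtime:
# the compiled forward returns the intermediate base (`mul` in the expected graph below),
# and the runtime wrapper then re-creates the user-visible aliases as fresh views of it,
# conceptually something like:
#     user_out0 = mul.view(-1)
#     user_out2 = mul.transpose(1, 0)
# so that autograd records them as views of a common base rather than as unrelated outputs.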
2084 return out1.view(-1), out2.transpose(1, 0), out1.transpose(1, 0) 2085 2086 inp = [torch.ones(3, 3, requires_grad=False)] 2087 self.verify_aot_autograd(f, inp, test_mutation=True) 2088 inp = [torch.ones(3, 3, requires_grad=True)] 2089 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 2090 self.assertExpectedInline( 2091 fw_graph.code.strip(), 2092 """\ 2093def forward(self, primals_1): 2094 mul = torch.ops.aten.mul.Tensor(primals_1, 3) 2095 mul_1 = torch.ops.aten.mul.Tensor(primals_1, 4); primals_1 = None 2096 view = torch.ops.aten.view.default(mul, [-1]) 2097 transpose = torch.ops.aten.transpose.int(mul_1, 1, 0); mul_1 = None 2098 transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0) 2099 return (view, transpose, transpose_1, mul)""", 2100 ) 2101 2102 def test_output_all_alias_types(self): 2103 # There are 3 types of aliasing that require us to return metadata in the compiled fw: 2104 # (1) outputs that are views of inputs 2105 # (2) outputs that are views of intermediates 2106 # (3) inputs that get metadata mutations 2107 # test all 3 of them here 2108 def f(a): 2109 a.transpose_(1, 0) 2110 tmp = a.mul(2) 2111 return tmp.squeeze(), tmp.transpose(1, 0), a.unsqueeze(0) 2112 2113 def inp_callable(req_grad): 2114 x = torch.ones(1, 2, 4, requires_grad=req_grad).clone() 2115 return [(x,), (x,)] 2116 2117 self.verify_aot_autograd( 2118 f, partial(inp_callable, req_grad=False), test_mutation=True 2119 ) 2120 fw_graph = self.verify_aot_autograd( 2121 f, partial(inp_callable, req_grad=True), test_mutation=True 2122 ) 2123 # TODO: make this test run with dynamic shapes so it is more meaningful 2124 # metadata output order: (a_updated_meta, out1_meta, out2_meta, out3_meta) 2125 self.assertExpectedInline( 2126 fw_graph.code.strip(), 2127 """\ 2128def forward(self, primals_1): 2129 view = torch.ops.aten.view.default(primals_1, [1, 2, 4]); primals_1 = None 2130 transpose = torch.ops.aten.transpose.int(view, 1, 0); view = None 2131 mul = torch.ops.aten.mul.Tensor(transpose, 2) 2132 squeeze = torch.ops.aten.squeeze.default(mul) 2133 transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0) 2134 unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0) 2135 return (transpose, squeeze, transpose_1, unsqueeze, mul)""", 2136 ) 2137 2138 @parametrize("req_grad", [False, True]) 2139 def test_subclass_metadata_mutation(self, req_grad): 2140 def f(a): 2141 a.transpose_(1, 0) 2142 tmp = a.mul(2) 2143 return tmp.transpose(1, 0) 2144 2145 def inp_callable(req_grad): 2146 x = torch.ones(1, 2, 4, requires_grad=req_grad).clone() 2147 return [(x,), (x,)] 2148 2149 # See https://github.com/pytorch/pytorch/issues/114975 2150 with self.assertRaisesRegex( 2151 RuntimeError, 2152 "Metadata mutations are currently not allowed on tensor subclasses", 2153 ): 2154 self.verify_aot_autograd( 2155 f, 2156 partial(inp_callable, req_grad=req_grad), 2157 test_mutation=True, 2158 make_inputs_subclasses=True, 2159 ) 2160 2161 def test_input_data_and_metadata_mutation(self): 2162 def f(a): 2163 a.t_() 2164 a[0].mul_(2) 2165 return a.view(a.shape) 2166 2167 inp = [torch.ones(3, 3, requires_grad=False)] 2168 self.verify_aot_autograd(f, inp, test_mutation=True) 2169 inp = [torch.ones(3, 3, requires_grad=True)] 2170 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 2171 self.assertExpectedInline( 2172 fw_graph.code.strip(), 2173 """\ 2174def forward(self, primals_1): 2175 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None 2176 t = torch.ops.aten.t.default(clone) 2177 select = 
torch.ops.aten.select.int(t, 0, 0); t = None 2178 mul = torch.ops.aten.mul.Tensor(select, 2); select = None 2179 t_1 = torch.ops.aten.t.default(clone); clone = None 2180 select_scatter = torch.ops.aten.select_scatter.default(t_1, mul, 0, 0); t_1 = mul = None 2181 t_2 = torch.ops.aten.t.default(select_scatter); select_scatter = None 2182 t_4 = torch.ops.aten.t.default(t_2) 2183 t_6 = torch.ops.aten.t.default(t_2); t_2 = None 2184 view_1 = torch.ops.aten.view.default(t_6, [3, 3]); t_6 = None 2185 return (t_4, view_1)""", 2186 ) 2187 2188 def test_view_and_inplace_view(self): 2189 def f(a, b): 2190 a.t_() 2191 return b.view(b.shape), a.view(a.shape) 2192 2193 def create_inp(req_grad): 2194 return [ 2195 torch.ones(3, 3, requires_grad=req_grad), 2196 torch.ones(3, 3, requires_grad=req_grad), 2197 ] 2198 2199 self.verify_aot_autograd(f, create_inp(False), test_mutation=True) 2200 fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) 2201 self.assertExpectedInline( 2202 fw_graph.code.strip(), 2203 """\ 2204def forward(self, arg0_1, arg1_1): 2205 t = torch.ops.aten.t.default(arg0_1); arg0_1 = None 2206 view = torch.ops.aten.view.default(arg1_1, [3, 3]); arg1_1 = None 2207 view_1 = torch.ops.aten.view.default(t, [3, 3]) 2208 return (t, view, view_1)""", 2209 ) 2210 2211 def test_view_detach(self): 2212 def f(a): 2213 tmp = a.detach() 2214 a.mul_(2) 2215 return a, tmp 2216 2217 inp = [torch.ones(3, 3, requires_grad=True)] 2218 self.verify_aot_autograd(f, inp, test_mutation=True) 2219 inp = [torch.ones(3, 3, requires_grad=False)] 2220 self.verify_aot_autograd(f, inp, test_mutation=True) 2221 2222 def test_input_inplace_requires_grad_true(self): 2223 def f(a, b): 2224 a.requires_grad_(True) 2225 return a.mul(3), b.mul(4) 2226 2227 inp = [ 2228 # First inp doesnt require grad, but we switch it on 2229 torch.ones(3, 3, requires_grad=False), 2230 torch.ones(3, 3, requires_grad=True), 2231 ] 2232 2233 fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) 2234 self.assertExpectedInline( 2235 fw_graph.code.strip(), 2236 """\ 2237def forward(self, primals_1, primals_2): 2238 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None 2239 mul_1 = torch.ops.aten.mul.Tensor(primals_2, 4); primals_2 = None 2240 return (mul, mul_1)""", 2241 ) 2242 2243 # This is a torture test: 2244 # a and b get turned into a synthetic base in the compiled graph 2245 # One gets a data mutation, the other gets a metadata mutation. 2246 # We need to make sure that the metadata mutation gets propagated 2247 # back to the original input. 2248 @skipIfDynamoInput("Dynamo removes runtime error") 2249 def test_input_data_and_metadata_mutation_aliases_other_input(self): 2250 # a and b are aliased 2251 def f(a, b): 2252 a.mul_(2) 2253 b.t_() 2254 return a.mul(b) 2255 2256 def inp_callable(req_grad): 2257 base = torch.ones(2, 2, requires_grad=req_grad) 2258 # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. 
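# Commented-out illustration of why the add() matters (hypothetical names): autograd rejects
# in-place ops on a leaf tensor that requires grad, so the graph inputs must be non-leaves:
#     leaf = torch.ones(2, 2, requires_grad=True)
#     leaf.mul_(2)            # RuntimeError: a leaf Variable that requires grad
#                             # is being used in an in-place operation.
#     nonleaf = leaf.add(1)   # has a grad_fn, so in-place mutation is allowed
#     nonleaf.mul_(2)         # OK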
2259 x = base.add(1) 2260 inp1 = x[0] 2261 inp2 = x[0] 2262 return [base], [inp1, inp2] 2263 2264 self.verify_aot_autograd( 2265 f, partial(inp_callable, req_grad=False), test_mutation=True 2266 ) 2267 self.verify_aot_autograd( 2268 f, partial(inp_callable, req_grad=True), test_mutation=True 2269 ) 2270 with self.assertRaisesRegex( 2271 RuntimeError, 2272 "Encountered aliased inputs that are mutated in the graph, but", 2273 ): 2274 self.verify_aot_autograd( 2275 f, 2276 partial(inp_callable, req_grad=False), 2277 test_mutation=True, 2278 make_inputs_subclasses=True, 2279 ) 2280 with self.assertRaisesRegex( 2281 RuntimeError, 2282 "Encountered aliased inputs that are mutated in the graph, but", 2283 ): 2284 self.verify_aot_autograd( 2285 f, 2286 partial(inp_callable, req_grad=True), 2287 test_mutation=True, 2288 make_inputs_subclasses=True, 2289 ) 2290 2291 # https://github.com/pytorch/pytorch/issues/106456 2292 def test_input_mutation_noncontiguous(self): 2293 def f(a): 2294 a.mul_(2) 2295 return a + 1 2296 2297 def inp_callable(req_grad): 2298 base = torch.ones(2, 2, requires_grad=req_grad) 2299 x = base.add(1) 2300 # create a non-contiguous view to pass as an input to the compiler 2301 inp = x[:, 0] 2302 return [base], [inp] 2303 2304 self.verify_aot_autograd( 2305 f, partial(inp_callable, req_grad=False), test_mutation=True 2306 ) 2307 self.verify_aot_autograd( 2308 f, partial(inp_callable, req_grad=True), test_mutation=True 2309 ) 2310 with self.assertRaisesRegex( 2311 RuntimeError, 2312 "Mutations on non-contiguous inputs are currently not allowed on tensor subclasses", 2313 ): 2314 self.verify_aot_autograd( 2315 f, 2316 partial(inp_callable, req_grad=False), 2317 test_mutation=True, 2318 make_inputs_subclasses=True, 2319 ) 2320 with self.assertRaisesRegex( 2321 RuntimeError, 2322 "Mutations on non-contiguous inputs are currently not allowed on tensor subclasses", 2323 ): 2324 self.verify_aot_autograd( 2325 f, 2326 partial(inp_callable, req_grad=True), 2327 test_mutation=True, 2328 make_inputs_subclasses=True, 2329 ) 2330 2331 def test_backward_mutation_data(self): 2332 class BwMutation(torch.autograd.Function): 2333 @staticmethod 2334 def forward(ctx, x): 2335 ctx.save_for_backward(x) 2336 return x.clone() 2337 2338 @staticmethod 2339 def backward(ctx, grad_output): 2340 (x,) = ctx.saved_tensors 2341 # bw mutation 2342 x.mul_(2) 2343 return grad_output.clone() 2344 2345 def f(a, b): 2346 out = BwMutation.apply(b) 2347 return a * out 2348 2349 inp_no_grad = [ 2350 torch.ones(3, 3, requires_grad=True), 2351 torch.ones(3, 3, requires_grad=False), 2352 ] 2353 2354 # Mutation on buffer that does not require grad during the backward is allowed 2355 self.verify_aot_autograd(f, inp_no_grad, test_mutation=True) 2356 2357 inp_grad = [ 2358 torch.ones(3, 3, requires_grad=True), 2359 torch.ones(3, 3, requires_grad=True), 2360 ] 2361 self.verify_aot_autograd(f, inp_grad, test_mutation=True) 2362 2363 def test_backward_mutation_metadata(self): 2364 class BwMutation(torch.autograd.Function): 2365 @staticmethod 2366 def forward(ctx, a, b): 2367 ctx.save_for_backward(b) 2368 return a.clone(), b.clone() 2369 2370 @staticmethod 2371 def backward(ctx, grad_a, grad_b): 2372 (b,) = ctx.saved_tensors 2373 # bw metadata mutation 2374 b.transpose_(1, 0) 2375 return grad_a.clone(), grad_b.clone() 2376 2377 def f(a, b): 2378 a_, b_ = BwMutation.apply(a, b) 2379 out = a_ * b_ 2380 return out 2381 2382 inp_no_grad = [ 2383 torch.ones(3, 3, requires_grad=True), 2384 torch.ones(3, 3, requires_grad=False), 2385 ] 
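# (Comment-only note on the contrast with test_backward_mutation_data above: there, a tensor
# saved for backward has its *data* mutated during the backward and verify_aot_autograd
# passes. Here the saved tensor has its *metadata* changed in the backward, via transpose_,
# and AOTAutograd rejects that with the assertion checked below.)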
2386 2387 with self.assertRaisesRegex( 2388 AssertionError, "input that had its metadata mutated in the backward" 2389 ): 2390 self.verify_aot_autograd(f, inp_no_grad, test_mutation=True) 2391 2392 def test_backward_mutation_on_grad_out(self): 2393 class BwMutation(torch.autograd.Function): 2394 @staticmethod 2395 def forward(ctx, x): 2396 return x.clone() 2397 2398 @staticmethod 2399 def backward(ctx, grad_output): 2400 grad_output.mul_(2) 2401 return grad_output.clone() 2402 2403 def f(a, b): 2404 tmp = a * b 2405 out = BwMutation.apply(tmp) 2406 return out 2407 2408 inp_grad = [ 2409 torch.ones(3, 3, requires_grad=True), 2410 torch.ones(3, 3, requires_grad=True), 2411 ] 2412 f_compiled = aot_function(f, nop) 2413 with self.assertRaisesRegex( 2414 AssertionError, "input to the backward that was mutated during the backward" 2415 ): 2416 out = f_compiled(*inp_grad) 2417 2418 def test_backward_mutation_forward_inputs(self): 2419 @torch.library.custom_op("_test::_clone", mutates_args={}) 2420 def f(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: 2421 return x.clone() 2422 2423 def f_fake(x, x1): 2424 return torch.empty_like(x) 2425 2426 def backward(ctx, grad): 2427 with torch.no_grad(): 2428 ctx.x1.zero_() 2429 return grad * 2, None 2430 2431 def setup_context(ctx, inputs, output): 2432 (x, x1) = inputs 2433 ctx.x = x 2434 ctx.x1 = x1 2435 2436 f.register_fake(f_fake) 2437 f.register_autograd(backward, setup_context=setup_context) 2438 2439 def fn(x: torch.Tensor, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: 2440 x2.mul_(5) 2441 return torch.ops._test._clone(x, x1) + x2 2442 2443 inp_x, inp_x1, inp_x2 = ( 2444 torch.randn(3, requires_grad=True), 2445 torch.randn(3, requires_grad=False), 2446 torch.randn(3, requires_grad=False), 2447 ) 2448 2449 ref_x, ref_x1, ref_x2 = inp_x.clone(), inp_x1.clone(), inp_x2.clone() 2450 ref_y = fn(ref_x, ref_x1, ref_x2) 2451 2452 compiled_f = aot_function(fn, nop, keep_inference_input_mutations=True) 2453 2454 x, x1, x2 = inp_x.clone(), inp_x1.clone(), inp_x2.clone() 2455 y = compiled_f(x, x1, x2) 2456 2457 # Verify mutation in forward applied and mutation in backward is not in forward 2458 self.assertEqual(ref_x, x) 2459 self.assertEqual(ref_x1, x1) 2460 self.assertEqual(ref_x2, x2) 2461 self.assertEqual(ref_y, y) 2462 2463 ref_y.sum().backward() 2464 y.sum().backward() 2465 2466 # Verify mutations in backward applied 2467 self.assertEqual(ref_x, x) 2468 self.assertEqual(ref_x1, x1) 2469 self.assertEqual(ref_x2, x2) 2470 self.assertEqual(ref_y, y) 2471 2472 self.assertEqual(ref_x.grad, x.grad) 2473 self.assertEqual(ref_x1.grad, x1.grad) 2474 self.assertEqual(ref_x2.grad, x2.grad) 2475 2476 def test_backward_mutation_forward_inputs_create_graph(self): 2477 @torch.library.custom_op("_test::_clone_create_graph", mutates_args={}) 2478 def f(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: 2479 return x.clone() 2480 2481 def f_fake(x, x1): 2482 return torch.empty_like(x) 2483 2484 def backward(ctx, grad): 2485 with torch.no_grad(): 2486 ctx.x1.zero_() 2487 return grad * 2, None 2488 2489 def setup_context(ctx, inputs, output): 2490 (x, x1) = inputs 2491 ctx.x = x 2492 ctx.x1 = x1 2493 2494 f.register_fake(f_fake) 2495 f.register_autograd(backward, setup_context=setup_context) 2496 2497 def fn(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: 2498 return torch.ops._test._clone_create_graph(x, x1) 2499 2500 inp_x, inp_x1 = torch.randn(3, requires_grad=True), torch.randn( 2501 3, requires_grad=True 2502 ) 2503 2504 ref_x, ref_x1 = inp_x.clone(), 
inp_x1.clone() 2505 ref_y = f(ref_x, ref_x1) 2506 ref_y.sum().backward() 2507 x, x1 = inp_x.clone(), inp_x1.clone() 2508 compiled_f = aot_function(fn, nop) 2509 y = compiled_f(x, x1) 2510 loss = y.sum() 2511 with self.assertRaisesRegex( 2512 RuntimeError, 2513 "aot_autograd does not support input mutations with requires_grad in backward for create_graph=True", 2514 ): 2515 torch.autograd.grad(loss, inp_x, create_graph=True) 2516 # Not checking equality of ref and x as Exception is expected 2517 2518 # Partially addresses https://github.com/pytorch/pytorch/issues/106457 2519 def test_input_mutation_false_aliasing(self): 2520 def f(a, b): 2521 a.mul_(3) 2522 b.mul_(2) 2523 return a.clone().view(-1) + b.clone().view(-1) 2524 2525 # No overlap, contiguous 2526 def inp_callable1(req_grad): 2527 base = torch.ones(4, 4, requires_grad=req_grad) 2528 x = base.add(1) 2529 # create two views that share storage, but are actually non-overlapping 2530 a = x[0:2] 2531 b = x[2:4] 2532 return [base], [a, b] 2533 2534 fw_graph = self.verify_aot_autograd( 2535 f, partial(inp_callable1, req_grad=False), test_mutation=True 2536 ) 2537 self.verify_aot_autograd( 2538 f, partial(inp_callable1, req_grad=True), test_mutation=True 2539 ) 2540 self.verify_aot_autograd( 2541 f, 2542 partial(inp_callable1, req_grad=False), 2543 test_mutation=True, 2544 make_inputs_subclasses=True, 2545 ) 2546 # Input mutations on subclasses with training graphs fail backward guards today. 2547 with self.assertRaisesRegex( 2548 AssertionError, 2549 "attempted to compile the backward with incorrect subclass metadata", 2550 ): 2551 self.verify_aot_autograd( 2552 f, 2553 partial(inp_callable1, req_grad=True), 2554 test_mutation=True, 2555 make_inputs_subclasses=True, 2556 ) 2557 2558 # Important characteristic: the graph takes in 2 inputs! 2559 # That shows that we didn't try to run our complicated synthetic base logic, 2560 # because we successfully detected false aliasing across the two inputs. 
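# Comment-only sketch of the storage-overlap reasoning being relied on here. The hypothetical
# helper below only covers the contiguous-slice case of inp_callable1; the real check in
# AOTAutograd also has to handle the strided layouts exercised further below:
#     def definitely_disjoint(t1, t2):
#         # assumes both tensors are contiguous views of the same storage
#         s1, e1 = t1.storage_offset(), t1.storage_offset() + t1.numel()
#         s2, e2 = t2.storage_offset(), t2.storage_offset() + t2.numel()
#         return e1 <= s2 or e2 <= s1
# For a = x[0:2] and b = x[2:4] of the (4, 4) base, a covers flat offsets [0, 8) and b covers
# [8, 16): no overlap, so no synthetic base is needed and the graph takes a and b directly.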
2561 self.assertExpectedInline( 2562 fw_graph.code.strip(), 2563 """\ 2564def forward(self, arg0_1, arg1_1): 2565 mul = torch.ops.aten.mul.Tensor(arg0_1, 3); arg0_1 = None 2566 mul_1 = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None 2567 clone = torch.ops.aten.clone.default(mul) 2568 view = torch.ops.aten.view.default(clone, [-1]); clone = None 2569 clone_1 = torch.ops.aten.clone.default(mul_1) 2570 view_1 = torch.ops.aten.view.default(clone_1, [-1]); clone_1 = None 2571 add = torch.ops.aten.add.Tensor(view, view_1); view = view_1 = None 2572 return (mul, mul_1, add)""", 2573 ) 2574 2575 # No overlap, non-contiguous: first tensor ends before second tensor start 2576 def inp_callable2(req_grad): 2577 base = torch.ones(256, requires_grad=req_grad) 2578 x = base.add(1) 2579 a = x.as_strided((4, 4), (8, 1), storage_offset=0) 2580 b = x.as_strided((4, 4), (8, 1), storage_offset=28) 2581 return [base], [a, b] 2582 2583 # No overlap, non-contiguous: tensors are perfectly interleaved 2584 def inp_callable3(req_grad): 2585 base = torch.ones(4, 4, requires_grad=req_grad) 2586 x = base.add(1) 2587 a = x[:, 0:2] 2588 b = x[:, 2:4] 2589 return [base], [a, b] 2590 2591 # No overlap, non-contiguous 2592 def inp_callable4(req_grad): 2593 base = torch.ones(256, requires_grad=req_grad) 2594 x = base.add(1) 2595 a = x.as_strided((4, 4), (9, 1), storage_offset=0) 2596 b = x.as_strided((4, 4), (9, 1), storage_offset=22) 2597 return [base], [a, b] 2598 2599 # No overlap, non-contiguous 2600 def inp_callable5(req_grad): 2601 base = torch.ones(256, requires_grad=req_grad) 2602 x = base.add(1) 2603 a = x.as_strided((4, 4), (9, 1), storage_offset=0) 2604 b = x.as_strided((4, 4), (9, 1), storage_offset=23) 2605 return [base], [a, b] 2606 2607 # No overlap, non-contiguous 2608 def inp_callable6(req_grad): 2609 base = torch.ones(256, requires_grad=req_grad) 2610 x = base.add(1) 2611 # a's last element is at offset 195 (24 total elements) 2612 a = x.as_strided((2, 4, 3), (110, 24, 4), storage_offset=5) 2613 # b's first element is at offset 196: no overlap 2614 b = x[196 : 196 + a.numel()] 2615 return [base], [a, b] 2616 2617 # overlap! non-contiguous 2618 def inp_callable_overlap1(req_grad): 2619 base = torch.ones(256, requires_grad=req_grad) 2620 x = base.add(1) 2621 a = x.as_strided((4, 4), (9, 1), storage_offset=0) 2622 b = x.as_strided((4, 4), (9, 1), storage_offset=24) 2623 return [base], [a, b] 2624 2625 # overlap! non-contiguous 2626 def inp_callable_overlap2(req_grad): 2627 base = torch.ones(256, requires_grad=req_grad) 2628 x = base.add(1) 2629 a = x.as_strided((4, 4), (9, 1), storage_offset=0) 2630 b = x.as_strided((4, 4), (9, 1), storage_offset=25) 2631 return [base], [a, b] 2632 2633 # overlap! non-contiguous 2634 def inp_callable_overlap3(req_grad): 2635 base = torch.ones(256, requires_grad=req_grad) 2636 x = base.add(1) 2637 # a's last element is at offset 195 (24 total elements) 2638 a = x.as_strided((2, 4, 3), (110, 24, 4), storage_offset=5) 2639 # b's first element is at offset 195: overlap! 
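# Worked out (comment only): for a = x.as_strided((2, 4, 3), (110, 24, 4), storage_offset=5),
# the largest reachable flat offset is
#     5 + (2 - 1) * 110 + (4 - 1) * 24 + (3 - 1) * 4 = 5 + 110 + 72 + 8 = 195,
# so a really does touch offset 195. A slice of x starting at offset 195 therefore shares
# exactly one element with a, unlike inp_callable6 above where b starts one element later,
# at offset 196, and there is no overlap.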
2640 b = x[195 : 195 + a.numel()] 2641 return [base], [a, b] 2642 2643 fw_graph2 = self.verify_aot_autograd( 2644 f, partial(inp_callable2, req_grad=False), test_mutation=True 2645 ) 2646 fw_graph3 = self.verify_aot_autograd( 2647 f, partial(inp_callable3, req_grad=False), test_mutation=True 2648 ) 2649 fw_graph4 = self.verify_aot_autograd( 2650 f, partial(inp_callable4, req_grad=False), test_mutation=True 2651 ) 2652 fw_graph5 = self.verify_aot_autograd( 2653 f, partial(inp_callable5, req_grad=False), test_mutation=True 2654 ) 2655 fw_graph6 = self.verify_aot_autograd( 2656 f, partial(inp_callable6, req_grad=False), test_mutation=True 2657 ) 2658 2659 fw_graph_overlap1 = self.verify_aot_autograd( 2660 f, partial(inp_callable_overlap2, req_grad=False), test_mutation=True 2661 ) 2662 fw_graph_overlap2 = self.verify_aot_autograd( 2663 f, partial(inp_callable_overlap1, req_grad=False), test_mutation=True 2664 ) 2665 2666 # All non-overlap graphs should be the same since we detected false aliasing 2667 self.assertEqual(str(fw_graph.code), str(fw_graph2.code)) 2668 self.assertEqual(str(fw_graph.code), str(fw_graph3.code)) 2669 self.assertEqual(str(fw_graph.code), str(fw_graph4.code)) 2670 self.assertEqual(str(fw_graph.code), str(fw_graph5.code)) 2671 self.assertEqual(str(fw_graph.code), str(fw_graph6.code)) 2672 2673 # All overlap graphs should be the same since we detected real aliasing 2674 self.assertNotEqual(str(fw_graph.code), str(fw_graph_overlap1.code)) 2675 self.assertNotEqual(str(fw_graph.code), str(fw_graph_overlap2.code)) 2676 self.assertTrue("as_strided_scatter" in str(fw_graph_overlap1.code)) 2677 self.assertTrue("as_strided_scatter" in str(fw_graph_overlap2.code)) 2678 2679 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") 2680 def test_mem_leak_from_save_for_bw(self): 2681 # See a full diagnosis at this issue: https://github.com/pytorch/pytorch/issues/94990 2682 # Note [Detaching saved tensors in AOTAutograd] 2683 # This program creates a ref-cycle. Long term, we should fix this ref cycle 2684 # (since it can arise, naturally albeit rarely, from uses of autograd.Function). 2685 # But AOTAutograd makes it more likely to show up from tracing user programs, 2686 # so we deal with it by manually detaching the tensors that we save for backward. 2687 # This is completely wrong and would give wrong results if we were to do double backward. 2688 # Fortunately today, double backward is explicitly banned in AOTAutograd. 2689 def f(a, b): 2690 add = a + a 2691 split = torch.functional.split(add, [4, 4], dim=1) 2692 getitem_2 = split[1] 2693 unsqueeze = getitem_2.unsqueeze(-1) 2694 mul = unsqueeze * b 2695 return (getitem_2, mul) 2696 2697 f_compiled = aot_function(f, nop) 2698 inps = [ 2699 torch.ones(8, 8, device="cuda", requires_grad=True), 2700 torch.ones(1, 4, 1, device="cuda", requires_grad=True), 2701 ] 2702 mem_before = torch.cuda.memory_allocated() 2703 f_compiled(*inps) 2704 mem_after = torch.cuda.memory_allocated() 2705 self.assertTrue(mem_after == mem_before) 2706 2707 def test_output_aliases_multiple_inputs_get_correct_one(self): 2708 # a and b are aliased, but have different shapes 2709 # The first output should view off the first input, the 2nd output should view off the 2nd input 2710 def f(a, b): 2711 return a.view(a.shape), b.view(b.shape) 2712 2713 def inp_callable(req_grad): 2714 base = torch.ones(2, 2, requires_grad=req_grad) 2715 # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. 
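# Comment-only note: inp1 and inp2 constructed below share storage but have different shapes
# ((4,) for x.view(-1) vs. (2,) for x[0]), which is why the comment above insists that each
# output view must be regenerated off of its own corresponding input: rebuilding both off of
# a single input would produce the wrong shape for one of them.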
2716 x = base.mul(2) 2717 inp1 = x.view(-1) 2718 inp2 = x[0] 2719 return [base], [inp1, inp2] 2720 2721 self.verify_aot_autograd( 2722 f, partial(inp_callable, req_grad=False), test_mutation=True 2723 ) 2724 self.verify_aot_autograd( 2725 f, partial(inp_callable, req_grad=True), test_mutation=True 2726 ) 2727 self.verify_aot_autograd( 2728 f, 2729 partial(inp_callable, req_grad=False), 2730 test_mutation=True, 2731 make_inputs_subclasses=True, 2732 ) 2733 self.verify_aot_autograd( 2734 f, 2735 partial(inp_callable, req_grad=True), 2736 test_mutation=True, 2737 make_inputs_subclasses=True, 2738 ) 2739 2740 def test_input_mutation_aliases_other_input(self): 2741 def f(a, b): 2742 a.add_(1) 2743 return a + b 2744 2745 def inp_callable(req_grad): 2746 base = torch.ones(4, 2, requires_grad=req_grad) 2747 # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. 2748 x = base.add(1) 2749 inp1 = x[0] 2750 inp2 = x[0] 2751 return [base], [inp1, inp2] 2752 2753 self.verify_aot_autograd( 2754 f, partial(inp_callable, req_grad=False), test_mutation=True 2755 ) 2756 fw_graph = self.verify_aot_autograd( 2757 f, partial(inp_callable, req_grad=True), test_mutation=True 2758 ) 2759 # Important parts of the graph: 2760 # - the compiled graph takes in a base, and we generate a and b (the views) off of the base 2761 # - clone() is still in the graph, because we need to call grad() on the original (non-mutated) inputs 2762 # - We re-generate the views *after* the clone, to preserve view relationships. 2763 self.assertExpectedInline( 2764 fw_graph.code.strip(), 2765 """\ 2766def forward(self, primals_1): 2767 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None 2768 as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0) 2769 add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None 2770 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None 2771 as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0) 2772 as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0) 2773 add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None 2774 return (as_strided_scatter, add_1)""", 2775 ) # noqa: B950 2776 2777 def test_input_mutation_aliases_other_input2(self): 2778 def f(a, b): 2779 a.add_(1) 2780 return a + b 2781 2782 def inp_callable(req_grad): 2783 base = torch.ones(2, 2, requires_grad=req_grad) 2784 x = base.add(1) 2785 inp1 = x[0] 2786 # Here, one of the aliased inputs is the base itself 2787 inp2 = x 2788 return [base], [inp1, inp2] 2789 2790 self.verify_aot_autograd( 2791 f, partial(inp_callable, req_grad=False), test_mutation=True 2792 ) 2793 fw_graph = self.verify_aot_autograd( 2794 f, partial(inp_callable, req_grad=True), test_mutation=True 2795 ) 2796 self.assertExpectedInline( 2797 fw_graph.code.strip(), 2798 """\ 2799def forward(self, primals_1): 2800 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None 2801 as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0) 2802 add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None 2803 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None 2804 as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0) 2805 as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0) 2806 add_1 = 
torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
2807 return (as_strided_scatter, add_1)""",
2808 ) # noqa: B950
2809
2810 def test_input_mutation_aliases_and_output_alias(self):
2811 def f(a, b):
2812 # Here we need to be careful: since a and b are aliased,
2813 # we generate the output view off of the "updated" b
2814 a.add_(1)
2815 return b.view(b.shape)
2816
2817 def inp_callable(req_grad):
2818 base = torch.ones(2, 2, requires_grad=req_grad)
2819 x = base.add(1)
2820 return [base], [x.view(-1), x.view(-1)]
2821
2822 self.verify_aot_autograd(
2823 f, partial(inp_callable, req_grad=False), test_mutation=True
2824 )
2825 fw_graph = self.verify_aot_autograd(
2826 f, partial(inp_callable, req_grad=True), test_mutation=True
2827 )
2828 self.assertExpectedInline(
2829 fw_graph.code.strip(),
2830 """\
2831def forward(self, primals_1):
2832 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
2833 as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
2834 add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
2835 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
2836 as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
2837 view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None
2838 return (as_strided_scatter, view_1)""",
2839 ) # noqa: B950
2840
2841 def test_input_aliased_with_mutation_output_alias(self):
2842 def f(a, b, c):
2843 # a and c alias
2844 c.mul_(2)
2845 # The main thing we're testing here is that
2846 # (1) We need to reconstruct c.view(-1) from the 3rd input to the forward
2847 # (2) But we need to be careful to do this *before* converting aliased inputs into synthetic bases.
2848 # The original fw takes in 3 args, but the compiled fw takes in only 2 args.
2849 return b.add(1), c.view(-1)
2850
2851 def inp_callable(req_grad):
2852 base1 = torch.ones(2, 2, requires_grad=req_grad)
2853 base2 = torch.ones(2, 2, requires_grad=req_grad)
2854 x = base1.add(1)
2855 y = base2.add(1)
2856 return [base1, base2], [x.view(-1), y, x.view(-1)]
2857
2858 self.verify_aot_autograd(
2859 f, partial(inp_callable, req_grad=False), test_mutation=True
2860 )
2861 fw_graph = self.verify_aot_autograd(
2862 f, partial(inp_callable, req_grad=True), test_mutation=True
2863 )
2864 self.assertExpectedInline(
2865 fw_graph.code.strip(),
2866 """\
2867def forward(self, primals_1, primals_2):
2868 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
2869 as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
2870 mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
2871 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
2872 add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
2873 as_strided_7 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
2874 view_1 = torch.ops.aten.view.default(as_strided_7, [-1]); as_strided_7 = None
2875 return (as_strided_scatter, add, view_1)""",
2876 ) # noqa: B950
2877
2878 def test_input_metadata_mutation_aliases(self):
2879 def f(a, b):
2880 # a and b alias, and we do a metadata mutation on a
2881 # Since we're not mutating data, b isn't affected at all.
2882 # We expect aot autograd to not bother with constructing a synthetic base.
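# Comment-only contrast with the tests above: when one of two aliased inputs gets a *data*
# mutation, the update has to be scattered back into a common synthetic base (hence the
# as_strided / as_strided_scatter pattern in the expected graphs above), because the other
# alias has to observe the new values. A metadata-only mutation like t_() changes just the
# sizes/strides of `a` itself, so the compiled graph can take a and b as two ordinary inputs
# and the metadata change is applied to the real input outside the graph; the expected graph
# below is just t + add.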
2883 a.t_() 2884 return a + b 2885 2886 def inp_callable(req_grad): 2887 base = torch.ones(2, 2, requires_grad=req_grad) 2888 x = base.add(1) 2889 return [base], [x.view(-1), x.view(-1)] 2890 2891 self.verify_aot_autograd( 2892 f, partial(inp_callable, req_grad=False), test_mutation=True 2893 ) 2894 fw_graph = self.verify_aot_autograd( 2895 f, partial(inp_callable, req_grad=True), test_mutation=True 2896 ) 2897 # Expectation: fwd() takes in 2 args, and we don't construct a synthetic base. 2898 self.assertExpectedInline( 2899 fw_graph.code.strip(), 2900 """\ 2901def forward(self, primals_1, primals_2): 2902 t = torch.ops.aten.t.default(primals_1); primals_1 = None 2903 add = torch.ops.aten.add.Tensor(t, primals_2); t = primals_2 = None 2904 return (add,)""", 2905 ) 2906 2907 def test_input_mutation_aliases_and_none_require_gradients(self): 2908 def f(a, b, c): 2909 # a and b alias, but neither require gradients (so they don't have a _base) 2910 # aot autograd should construct the synthetic base from `torch.Tensor(a.storage())` 2911 a.mul_(2) 2912 return b + 1, c + 1 2913 2914 def inp_callable(req_grad): 2915 base = torch.ones(2, 2) 2916 c_arg = torch.ones(2, 2, requires_grad=req_grad) 2917 x = base.add(1) 2918 return [base, c_arg], [x.view(-1), x.view(-1), c_arg] 2919 2920 self.verify_aot_autograd( 2921 f, partial(inp_callable, req_grad=False), test_mutation=True 2922 ) 2923 2924 with self.assertRaisesRegex( 2925 RuntimeError, "is a tensor subclass. This is not supported today" 2926 ): 2927 self.verify_aot_autograd( 2928 f, 2929 partial(inp_callable, req_grad=False), 2930 test_mutation=True, 2931 make_inputs_subclasses=True, 2932 ) 2933 2934 fw_graph = self.verify_aot_autograd( 2935 f, partial(inp_callable, req_grad=True), test_mutation=True 2936 ) 2937 self.assertExpectedInline( 2938 fw_graph.code.strip(), 2939 """\ 2940def forward(self, primals_1, primals_2): 2941 as_strided = torch.ops.aten.as_strided.default(primals_1, [4], [1], 0) 2942 mul = torch.ops.aten.mul.Tensor(as_strided, 2); as_strided = None 2943 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(primals_1, mul, [4], [1], 0); primals_1 = mul = None 2944 as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) 2945 add = torch.ops.aten.add.Tensor(as_strided_3, 1); as_strided_3 = None 2946 add_1 = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None 2947 return (as_strided_scatter, add, add_1)""", 2948 ) # noqa: B950 2949 2950 @skipIfDynamoInput("Fails with dynamo") 2951 def test_input_mutation_aliases_bases_out_of_order(self): 2952 # This tests our calling convention: if b and d are aliased, then the outer calling convention 2953 # that we send to the compiled forward becomes: 2954 # (b_d_base, a, c) 2955 # Importantly, even though a and c alias in our test, neither inputs are mutated, 2956 # So we don't need to do the base construction / deconstruction 2957 def f(a, b, c, d): 2958 b.add_(1) 2959 d.unsqueeze_(0) 2960 return a + c + d, b.view(-1) 2961 2962 def inp_callable(req_grad): 2963 base1 = torch.ones(2, 2, requires_grad=req_grad) 2964 base2 = torch.ones(2, 2, requires_grad=req_grad) 2965 x1 = base1.add(1) 2966 x2 = base2.add(1) 2967 # a and c alias, b and d alias 2968 return [base1, base2], [x1.view(-1), x2.view(-1), x1.view(-1), x2.view(-1)] 2969 2970 self.verify_aot_autograd( 2971 f, partial(inp_callable, req_grad=False), test_mutation=True 2972 ) 2973 2974 with self.assertRaisesRegex( 2975 RuntimeError, 2976 "Metadata mutations are currently not allowed on tensor subclasses", 
2977 ): 2978 self.verify_aot_autograd( 2979 f, 2980 partial(inp_callable, req_grad=False), 2981 test_mutation=True, 2982 make_inputs_subclasses=True, 2983 ) 2984 2985 fw_graph = self.verify_aot_autograd( 2986 f, partial(inp_callable, req_grad=True), test_mutation=True 2987 ) 2988 # 3 graph inputs: (b_d_base, a, c) 2989 # 2 returns: (b_updated, a+c+d) 2990 # (there are 2 original fw outs, but one is a view of b so it's not part of the graph) 2991 # (there are also 2 input mutations, but one is a metadata-only mutation so the compiled forward doesn't return it) 2992 self.assertExpectedInline( 2993 fw_graph.code.strip(), 2994 """\ 2995def forward(self, primals_1, primals_2, primals_3): 2996 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None 2997 as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0) 2998 add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None 2999 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None 3000 add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None 3001 as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) 3002 unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_5, 0); as_strided_5 = None 3003 add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1); add_1 = None 3004 as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) 3005 view_2 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None 3006 return (as_strided_scatter, add_2, view_2, unsqueeze_1)""", 3007 ) # noqa: B950 3008 3009 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") 3010 def test_synthetic_base_base_attribute_is_none(self): 3011 def f(a, b): 3012 a.add_(1) 3013 return a + b 3014 3015 def inp_callable(): 3016 base = torch.ones(4, 4, device="cuda") 3017 # detach() so that none of the inputs have a ._base attribute. 3018 a = base[0].detach() 3019 b = base[1].detach() 3020 base2 = torch.ones(2, 2, requires_grad=True) 3021 return [base], [a, b] 3022 3023 self.verify_aot_autograd(f, inp_callable, test_mutation=True) 3024 3025 def test_input_mutation_alias_everything(self): 3026 # Mondo test that tests a combination of: 3027 # input is mutated, that aliases another input (so we make a synthetic base) 3028 # an output is an alias of another output 3029 # an output is an alias of an intermediate 3030 # a and c are aliased 3031 def f(a, b, c): 3032 c.mul_(2) # mutates c 3033 b.t_() # metadata mutate b 3034 tmp = a + c 3035 out1 = tmp.view(-1) 3036 out2 = b.t() 3037 out3 = out1.unsqueeze(0) 3038 # out1 and out3 are aliases of an intermediate, and alias each other! 3039 # out2 aliases an input, so we don't return it 3040 return out1, out2, out3 3041 3042 def inp_callable(req_grad): 3043 base1 = torch.ones(2, 2, requires_grad=req_grad) 3044 base2 = torch.ones(2, 2, requires_grad=req_grad) 3045 # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. 
3046 base1_ = base1.add(1) 3047 base2_ = base2.add(1) 3048 a = base1_.view(-1) 3049 b = base2_ 3050 c = base1_.view(-1) 3051 return [base1, base2], [a, b, c] 3052 3053 self.verify_aot_autograd( 3054 f, partial(inp_callable, req_grad=False), test_mutation=True 3055 ) 3056 fw_graph = self.verify_aot_autograd( 3057 f, partial(inp_callable, req_grad=True), test_mutation=True 3058 ) 3059 # Expected: 3060 # - 2 inputs in the forward: synthetic_base_a_c, b 3061 # - 1 output in the forward: "tmp" 3062 # out2 is an alias of an input, and will be generated off of b outside of the compiled fn 3063 # out1 and out3 are aliases of tmp, that we generate outside of the compiled function 3064 self.assertExpectedInline( 3065 fw_graph.code.strip(), 3066 """\ 3067def forward(self, primals_1, primals_2): 3068 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None 3069 view = torch.ops.aten.view.default(primals_2, [2, 2]); primals_2 = None 3070 as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0) 3071 mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None 3072 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None 3073 as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) 3074 t = torch.ops.aten.t.default(view); view = None 3075 as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) 3076 add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2); as_strided_5 = as_strided_2 = None 3077 view_1 = torch.ops.aten.view.default(add, [-1]) 3078 t_1 = torch.ops.aten.t.default(t) 3079 unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0) 3080 return (as_strided_scatter, t, view_1, t_1, unsqueeze, add)""", 3081 ) # noqa: B950 3082 3083 def test_dynamic_shape_output_not_in_bw_graph(self): 3084 def f(x): 3085 return [x + 1, x.shape[0]] 3086 3087 inp = torch.ones(5, requires_grad=True) 3088 bw_graph_cell = [None] 3089 compiled_f = aot_function( 3090 f, 3091 fw_compiler=nop, 3092 bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), 3093 decompositions={}, 3094 keep_inference_input_mutations=False, 3095 dynamic=True, 3096 ) 3097 out = compiled_f(inp) 3098 out[0].sum().backward() 3099 # The important bit: the forward fn returns 2 outputs, 3100 # but one of them is a symint so we should only see 3101 # 1 grad_output as an input to the backward graph. 3102 # (Otherwise, autograd will plumb a None as the value of the grad_output, 3103 # which causes inductor to complain). 
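# Comment-only note: with dynamic=True the second element of f's output, x.shape[0], is a
# SymInt rather than a tensor, so it is not differentiable and gets no tangent. That is why
# the backward graph checked below takes a single tangents_1 (for x + 1) and simply passes
# it through, instead of receiving a (tangents_1, None) pair.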
3104 self.assertExpectedInline( 3105 bw_graph_cell[0].code.strip(), 3106 """\ 3107def forward(self, tangents_1): 3108 return (tangents_1,)""", 3109 ) 3110 3111 def test_no_grad_input_output(self): 3112 def f(a, b): 3113 return a.cos(), b.cos(), a * b 3114 3115 inp_thunks = [ 3116 lambda: torch.randn(5, requires_grad=True), 3117 lambda: torch.randn(5, requires_grad=False), 3118 ] 3119 for inps in itertools.product(inp_thunks, repeat=2): 3120 inps = [i() for i in inps] 3121 self.verify_aot_autograd(f, inps) 3122 3123 def test_some_output_requires_grad_input_doesnt(self): 3124 def f(a, b): 3125 a_view = a.view(-1) 3126 a_view.requires_grad_(True) 3127 return a_view 3128 3129 inp = [torch.randn(3, 3), torch.randn(3, 3, requires_grad=True)] 3130 self.verify_aot_autograd(f, inp) 3131 3132 def test_some_outputs_dont_require_grad_view(self): 3133 def f(a, b): 3134 return a.detach(), b 3135 3136 inp = [ 3137 torch.randn(3, 3, requires_grad=True), 3138 torch.randn(3, 3, requires_grad=True), 3139 ] 3140 self.verify_aot_autograd(f, inp) 3141 3142 def test_some_outputs_dont_require_grad_non_view(self): 3143 def f(a, b): 3144 return a.add(1).detach(), b 3145 3146 inp = [ 3147 torch.randn(3, 3, requires_grad=True), 3148 torch.randn(3, 3, requires_grad=True), 3149 ] 3150 self.verify_aot_autograd(f, inp) 3151 3152 def test_inner_grad(self): 3153 def foo(x): 3154 y = torch.exp(x) 3155 z = torch.autograd.grad(y, x) 3156 return z 3157 3158 inps = [torch.randn((), requires_grad=True)] 3159 self.verify_aot_autograd(foo, inps) 3160 3161 def test_grad_context(self): 3162 def foo(x): 3163 return x * 2 3164 3165 inps = [torch.randn((), requires_grad=True)] 3166 graph_size = None 3167 3168 def get_graph_size(fx_g, _): 3169 nonlocal graph_size 3170 graph_size = len(fx_g.graph.nodes) 3171 return fx_g 3172 3173 f = aot_function(foo, nop, get_graph_size) 3174 with torch.set_grad_enabled(False): 3175 f(*inps) 3176 self.assertIsNone(graph_size) 3177 3178 f = aot_function(foo, nop, get_graph_size) 3179 with torch.set_grad_enabled(True): 3180 out = f(*inps) 3181 self.assertIsNone(graph_size) 3182 out.sum().backward() 3183 self.assertTrue(graph_size > 2) 3184 3185 def test_output_dict(self): 3186 def f(x): 3187 return {"a": x, "b": x} 3188 3189 inp = [torch.randn(3, 3, requires_grad=True)] 3190 self.verify_aot_autograd(f, inp) 3191 3192 def f(x, y): 3193 return {"a": x, "b": y + x} 3194 3195 inp = [torch.randn(3, requires_grad=True), torch.randn(3)] 3196 self.verify_aot_autograd(f, inp) 3197 3198 def f(x): 3199 new_d = {} 3200 for k in x: 3201 new_d[k] = x[k] * 2 3202 return new_d 3203 3204 a = torch.randn(3, requires_grad=True) 3205 b = torch.randn(3, requires_grad=True) 3206 3207 def inp_callable(): 3208 inps = [{"a": a, "b": b}] 3209 return inps, inps 3210 3211 self.verify_aot_autograd(f, inp_callable) 3212 3213 def test_module(self): 3214 mod = nn.Sequential(nn.Linear(32, 32), nn.ReLU()) 3215 compiled_mod = compiled_module(mod, nop, nop) 3216 inp = torch.randn(32, 32) 3217 ref_out = mod(inp) 3218 ref_out.sum().backward() 3219 ref_grads = sorted([(name, p.grad) for name, p in mod.named_parameters()]) 3220 out = compiled_mod(inp) 3221 out.sum().backward() 3222 grads = sorted([(name, p.grad) for name, p in mod.named_parameters()]) 3223 self.assertEqual((out, grads), (ref_out, ref_grads)) 3224 3225 def test_batchnorm(self): 3226 mod = compiled_module(nn.BatchNorm2d(4), nop, nop) 3227 x = torch.ones(1, 4, 2, 2) 3228 mod(x).sum().backward() 3229 3230 def test_list_codegen(self): 3231 def list_nop(f, _): 3232 def g(inps): 3233 
return f(*inps) 3234 3235 g._boxed_call = True 3236 return g 3237 3238 def f(a, b, c): 3239 return a.sin() * b.cos() * c.sin() 3240 3241 f = aot_function(f, list_nop) 3242 inp = [torch.randn(5, requires_grad=True) for _ in range(3)] 3243 f(*inp).sum().backward() 3244 3245 @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count) 3246 def test_compilation_context(self, counter): 3247 def f(x): 3248 return x.sin().sin() 3249 3250 count = [] 3251 3252 def compiler(fx_g, _): 3253 context = get_aot_compilation_context() 3254 count.append((context[0], len(fx_g.graph.nodes))) 3255 return fx_g 3256 3257 f = aot_function(f, compiler) 3258 out = f(torch.randn(5, requires_grad=True)) 3259 f = aot_function(f, compiler) 3260 f(torch.randn(5)) 3261 out.sum().backward() 3262 self.assertExpectedInline( 3263 str(count), 3264 """[(['0_forward'], 4), (['1_inference'], 4), (['0_backward'], 8)]""", 3265 ) 3266 3267 def test_dupe_arg(self): 3268 def f(x, y): 3269 return x + y 3270 3271 x = torch.randn(3, 3, requires_grad=True) 3272 self.verify_aot_autograd(f, [x, x]) 3273 3274 def test_dupe_arg_torture(self): 3275 def f(x, y): 3276 x.t_() 3277 y.unsqueeze_(0) 3278 return x + y 3279 3280 x = torch.randn(3, 3, requires_grad=True).clone() 3281 self.verify_aot_autograd(f, [x, x]) 3282 3283 # See https://github.com/pytorch/pytorch/issues/100224 3284 def test_dupe_arg_returned_as_output(self): 3285 def f(a, b, a_): 3286 a[0].add_(1) 3287 return a_ 3288 3289 f_compiled = aot_function(f, nop) 3290 a = torch.ones(2) 3291 b = torch.ones(2) 3292 out_ref = f(a, b, a) 3293 3294 a2 = torch.ones(2) 3295 b2 = torch.ones(2) 3296 out_test = f_compiled(a2, b2, a2) 3297 3298 self.assertEqual(out_ref, out_test) 3299 self.assertEqual(a, a2) 3300 3301 @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count) 3302 @patch("torch._functorch.config.debug_assert", True) 3303 def test_invalid_dupe_left_bias(self, counter): 3304 # This test checks that, just because only the first 3305 # argument did a metadata mutation, we still correctly 3306 # switch to strategy 2 (deduplicate) 3307 # See: https://github.com/pytorch/pytorch/pull/89896#discussion_r1036224447 3308 class F(torch.nn.Module): 3309 def forward(self, x, y): 3310 x.t_() 3311 return (x + y,) 3312 3313 x = torch.randn(3, 3, requires_grad=True).clone() 3314 y = torch.randn(3, 3, requires_grad=True) 3315 self.verify_aot_autograd(F(), [x, x]) 3316 3317 fxx = aot_module_simplified(F(), (x, x), nop) 3318 self.assertExpectedRaisesInline( 3319 AssertionError, 3320 lambda: fxx(x, y), 3321 """At compilation time, graph 2 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. 
This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""", # noqa: B950 3322 ) 3323 3324 @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count) 3325 @patch("torch._functorch.config.debug_assert", True) 3326 def test_invalid_dupe(self, counter): 3327 self._test_invalid_dupe(counter, fake=False) 3328 3329 # See Note: Dynamo recompilation guarding invalid grad for why this test exists 3330 @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count) 3331 @patch("torch._functorch.config.debug_assert", True) 3332 def test_invalid_dupe_fake(self, counter): 3333 self._test_invalid_dupe(counter, fake=True) 3334 3335 def _test_invalid_dupe(self, counter, fake): 3336 class F(torch.nn.Module): 3337 def forward(self, x, y): 3338 x.unsqueeze_(0) 3339 y.unsqueeze_(0) 3340 return (x + y,) 3341 3342 x = torch.randn(3, 3, requires_grad=True).clone() 3343 y = torch.randn(3, 3, requires_grad=True).clone() 3344 3345 if fake: 3346 shape_env = ShapeEnv() 3347 fake_mode = FakeTensorMode(shape_env=shape_env) 3348 3349 fake_x = fake_mode.from_tensor(x) 3350 fake_y = fake_mode.from_tensor(y) 3351 3352 if fake: 3353 fxy = aot_module_simplified(F(), (fake_x, fake_y), nop) 3354 else: 3355 fxy = aot_module_simplified(F(), (x, y), nop) 3356 3357 fxy(x, y) 3358 x = torch.randn(3, 3, requires_grad=True).clone() 3359 y = torch.randn(3, 3, requires_grad=True).clone() 3360 fxy(x, x) # is ok! 3361 3362 if fake: 3363 fxx = aot_module_simplified(F(), (fake_x, fake_x), nop) 3364 else: 3365 fxx = aot_module_simplified(F(), (x, x), nop) 3366 3367 x = torch.randn(3, 3, requires_grad=True).clone() 3368 y = torch.randn(3, 3, requires_grad=True).clone() 3369 fxx(x, x) 3370 # Note This should not raise! Once we have guards in place here, 3371 # we will have this working correctly, as it should recompile. 3372 x = torch.randn(3, 3, requires_grad=True).clone() 3373 y = torch.randn(3, 3, requires_grad=True).clone() 3374 self.assertExpectedRaisesInline( 3375 AssertionError, 3376 lambda: fxx(x, y), 3377 """At compilation time, graph 1 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. 
This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""", # noqa: B950 3378 ) 3379 3380 @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count) 3381 @patch("torch._functorch.config.debug_assert", True) 3382 def test_invalid_requires_grad(self, counter): 3383 self._test_invalid_requires_grad(counter, fake=False) 3384 3385 # See Note: Dynamo recompilation guarding invalid grad for why this test exists 3386 @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count) 3387 @patch("torch._functorch.config.debug_assert", True) 3388 def test_invalid_requires_grad_fake(self, counter): 3389 self._test_invalid_requires_grad(counter, fake=True) 3390 3391 def _test_invalid_requires_grad(self, counter, fake): 3392 class F(torch.nn.Module): 3393 def forward(self, x, y): 3394 return (x + y,) 3395 3396 x = torch.randn(3, 3, requires_grad=True) 3397 y = torch.randn(3, 3, requires_grad=True) 3398 z = torch.randn(3, 3, requires_grad=False) 3399 3400 if fake: 3401 shape_env = ShapeEnv() 3402 fake_mode = FakeTensorMode(shape_env=shape_env) 3403 3404 fake_x = fake_mode.from_tensor(x) 3405 fake_y = fake_mode.from_tensor(y) 3406 fake_z = fake_mode.from_tensor(z) 3407 3408 if fake: 3409 fxy = aot_module_simplified(F(), (fake_x, fake_y), nop) 3410 else: 3411 fxy = aot_module_simplified(F(), (x, y), nop) 3412 3413 compare_equal_outs_and_grads(self, F(), fxy, (x, y)) 3414 compare_equal_outs_and_grads(self, F(), fxy, (x, z)) 3415 3416 if fake: 3417 fxz = aot_module_simplified(F(), (fake_x, fake_z), nop) 3418 else: 3419 fxz = aot_module_simplified(F(), (x, z), nop) 3420 3421 compare_equal_outs_and_grads(self, F(), fxz, (x, z)) 3422 3423 self.assertExpectedRaisesInline( 3424 AssertionError, 3425 lambda: fxz(x, y), 3426 """At compilation time, graph 1 was compiled under the assumption that input 1 would not require grad, but at runtime this was not the case. 
This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""", # noqa: B950 3427 ) 3428 3429 def test_custom_autograd(self): 3430 class CustomFn(torch.autograd.Function): 3431 @staticmethod 3432 def forward(ctx, x): 3433 return x.clone() 3434 3435 @staticmethod 3436 def backward(ctx, grad_output): 3437 return grad_output + 1 3438 3439 def f(x): 3440 return CustomFn.apply(x) 3441 3442 self.verify_aot_autograd(f, [torch.randn(3)]) 3443 3444 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") 3445 def test_autocast_disable_guard(self): 3446 with torch._C._DisableAutocast(): 3447 x = torch.rand([4, 4]).cuda() 3448 y = x @ x 3449 self.assertEqual(y.dtype, torch.float32) 3450 3451 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") 3452 def test_nonidempotent_amp(self): 3453 def f(self_s_emb, add_3): 3454 einsum_2 = torch.functional.einsum("ah,th->t", self_s_emb, add_3) 3455 log_softmax_2 = einsum_2.log_softmax(-1) 3456 return (log_softmax_2,) 3457 3458 args = [ 3459 torch.rand((1, 256), dtype=torch.float32, device="cuda"), 3460 torch.rand((30, 256), dtype=torch.float16, device="cuda"), 3461 ] 3462 with torch.cuda.amp.autocast(enabled=True): 3463 self.verify_aot_autograd(f, args) 3464 3465 args = [e.requires_grad_(True) for e in args] 3466 with torch.cuda.amp.autocast(enabled=True): 3467 self.verify_aot_autograd(f, args) 3468 3469 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") 3470 @unittest.skipIf(not torch.backends.cudnn.is_available(), "CUDNN is unavailable") 3471 @skipIfRocm # https://github.com/pytorch/pytorch/issues/96560 3472 def test_batch_norm_amp(self): 3473 device = "cuda" 3474 input_dtype = torch.float16 3475 param_dtype = torch.float32 3476 weight, bias = ( 3477 torch.ones(64, device=device, dtype=param_dtype, requires_grad=True) 3478 for _ in range(2) 3479 ) 3480 running_mean, running_var = ( 3481 torch.ones(64, device=device, dtype=param_dtype) for _ in range(2) 3482 ) 3483 3484 def bn(x): 3485 return torch.ops.aten.cudnn_batch_norm( 3486 x, 3487 weight, 3488 bias, 3489 running_mean, 3490 running_var, 3491 False, 3492 0.1, 3493 1e-05, 3494 ) 3495 3496 inp = torch.ones( 3497 torch.Size([16, 64, 112, 112]), dtype=input_dtype, device=device 3498 ) 3499 3500 ref = bn(inp) 3501 cudnn_batch_norm_decomp = torch._decomp.get_decompositions( 3502 {torch.ops.aten.cudnn_batch_norm} 3503 ) 3504 aot_fn = make_fx(bn, decomposition_table=cudnn_batch_norm_decomp)(inp) 3505 res = aot_fn(inp) 3506 for a, b in zip(ref, res): 3507 assert torch.allclose(a, b) 3508 3509 def test_output_op_depending_on_symint(self): 3510 """ 3511 It won't be obvious from reading this test what it's testing for. We should probably make it into a more 3512 focused unit test. 3513 3514 An issue with the following program was the expand op would end up depending on a symint whose proxy was 3515 incorrectly associated with one of the grad tensors rather than input tensors. It broke partitioner logic 3516 and the net result was aot_function failed to produce a function and threw an exception instead. 
3517 """ 3518 inp = torch.randn(5, requires_grad=True) 3519 3520 def f(x): 3521 return x.expand(x.shape) 3522 3523 # TODO(whc) make this work (test setup is wrong somehow) 3524 # joint_forward_backward = create_joint_forward_backward(f) 3525 # out = f(inp) 3526 # joint_inputs = ([inp], [out.detach().contiguous()]) 3527 # fx_g = make_fx(joint_forward_backward)(*joint_inputs) 3528 # TODO: assert outputs of fwd graph trace to correct symint 3529 3530 # e2e test that fails without symint clone fix 3531 af = aot_function( 3532 f, 3533 nop, 3534 partition_fn=partial( 3535 min_cut_rematerialization_partition, compiler="inductor" 3536 ), 3537 dynamic=True, 3538 ) 3539 out = af(inp) 3540 self.assertEqual(out, f(inp)) 3541 3542 def test_inference_mode(self): 3543 m = torch.nn.Linear(4, 4) 3544 inp = torch.randn(4, 4) 3545 3546 aot_mod = aot_module(m, fw_compiler=nop) 3547 3548 with torch.inference_mode(): 3549 out_ref = m(inp) 3550 out_test = aot_mod(inp) 3551 self.assertEqual(out_ref, out_test) 3552 3553 def test_default_partitioner_saves_symints_not_tensors_for_bw(self): 3554 """ 3555 In this test, the important thing is that primals_1 is **only** needed in the backward 3556 in order to grab its sizes. 3557 We need to assert that what we save for the backward are the tensor's sizes, and not the tensor itself. 3558 3559 The way this test is set up, it will actually fail if we try to save the input tensor for backward. 3560 Why? 3561 b.masked_fill_(c, 0) has a backward that requires knowing a's sizes 3562 b.masked_fill_(c, 0) **also** mutates a (because b and a are aliased) 3563 The autograd engine yells at us if we save "a" for backward, and then try to mutate it. 3564 """ 3565 inp = torch.randn(2, 2, requires_grad=True) 3566 3567 def f(a): 3568 b = a[0] 3569 c = torch.ones_like(b, dtype=torch.bool) 3570 d = b.masked_fill_(c, 0) 3571 return d 3572 3573 compiled_f = aot_function(f, nop, dynamic=True) 3574 inp_ref = torch.ones(2, 2, requires_grad=True) 3575 inp_test = torch.ones(2, 2, requires_grad=True) 3576 3577 out_ref = f(inp_ref.clone()) 3578 out_test = compiled_f(inp_test.clone()) 3579 3580 self.assertEqual(out_ref, out_test) 3581 3582 out_ref.sum().backward() 3583 out_test.sum().backward() 3584 3585 self.assertEqual(inp_ref.grad, inp_test.grad) 3586 3587 def test_buffer_copied_in_graph(self): 3588 class MyModel(torch.nn.Module): 3589 def __init__(self) -> None: 3590 super().__init__() 3591 self.buf = torch.nn.Buffer(torch.zeros(1)) 3592 self.w1 = torch.nn.Parameter(torch.zeros(1)) 3593 self.w2 = torch.nn.Parameter(torch.zeros(1)) 3594 3595 def forward(self, x): 3596 self.buf.add_(1) 3597 return (self.w1 * x * self.w2).sum() + self.buf.sum() 3598 3599 model_for_eager = MyModel() 3600 model_for_compile = copy.deepcopy(model_for_eager) 3601 3602 fw_graph_cell = [None] 3603 compiled_f = aot_module( 3604 model_for_compile, 3605 fw_compiler=make_boxed_compiler( 3606 partial(extract_graph, graph_cell=fw_graph_cell) 3607 ), 3608 bw_compiler=nop, 3609 keep_inference_input_mutations=True, 3610 ) 3611 inp_ref = torch.ones(1, requires_grad=True) 3612 inp_test = torch.ones(1, requires_grad=True) 3613 3614 out_ref = model_for_eager(inp_ref.clone()) 3615 out_test = compiled_f(inp_test.clone()) 3616 3617 self.assertExpectedInline( 3618 fw_graph_cell[0].code.strip(), 3619 """\ 3620def forward(self, primals_1, primals_2, primals_3, primals_4): 3621 add = torch.ops.aten.add.Tensor(primals_3, 1) 3622 mul = torch.ops.aten.mul.Tensor(primals_1, primals_4) 3623 mul_1 = torch.ops.aten.mul.Tensor(mul, primals_2) 3624 
sum_1 = torch.ops.aten.sum.default(mul_1); mul_1 = None 3625 sum_2 = torch.ops.aten.sum.default(add) 3626 add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None 3627 copy_ = torch.ops.aten.copy_.default(primals_3, add); primals_3 = add = copy_ = None 3628 return (add_1, primals_1, primals_2, primals_4, mul)""", 3629 ) 3630 3631 self.assertEqual(out_ref, out_test) 3632 3633 out_ref.sum().backward() 3634 out_test.sum().backward() 3635 3636 eager_grads = [p.grad for _, p in model_for_eager.named_parameters()] 3637 compile_grads = [p.grad for _, p in model_for_compile.named_parameters()] 3638 3639 self.assertEqual(eager_grads, compile_grads) 3640 self.assertEqual(inp_ref.grad, inp_test.grad) 3641 3642 def test_buffer_copied_in_graph_with_different_shapes(self): 3643 class MyModel(torch.nn.Module): 3644 def __init__(self) -> None: 3645 super().__init__() 3646 self.buf = torch.nn.Buffer(torch.ones(4, 4)) 3647 self.w = torch.nn.Parameter( 3648 torch.Tensor([[4, 5], [1, 2], [6, 7], [8, 9]]) 3649 ) 3650 3651 def forward(self, x): 3652 self.buf.add_(1) 3653 return (self.w @ x).sum() + self.buf.sum() 3654 3655 model_for_eager = MyModel() 3656 model_for_compile = copy.deepcopy(model_for_eager) 3657 3658 fw_graph_cell = [None] 3659 compiled_f = aot_module( 3660 model_for_compile, 3661 fw_compiler=make_boxed_compiler( 3662 partial(extract_graph, graph_cell=fw_graph_cell) 3663 ), 3664 bw_compiler=nop, 3665 keep_inference_input_mutations=True, 3666 ) 3667 inp_ref = torch.ones(2, 4, requires_grad=True) 3668 inp_test = torch.ones(2, 4, requires_grad=True) 3669 3670 out_ref = model_for_eager(inp_ref.clone()) 3671 out_test = compiled_f(inp_test.clone()) 3672 3673 self.assertExpectedInline( 3674 fw_graph_cell[0].code.strip(), 3675 """\ 3676def forward(self, primals_1, primals_2, primals_3): 3677 add = torch.ops.aten.add.Tensor(primals_2, 1) 3678 mm = torch.ops.aten.mm.default(primals_1, primals_3) 3679 sum_1 = torch.ops.aten.sum.default(mm); mm = None 3680 sum_2 = torch.ops.aten.sum.default(add) 3681 add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None 3682 copy_ = torch.ops.aten.copy_.default(primals_2, add); primals_2 = add = copy_ = None 3683 return (add_1, primals_1, primals_3)""", 3684 ) 3685 self.assertEqual(out_ref, out_test) 3686 3687 out_ref.sum().backward() 3688 out_test.sum().backward() 3689 3690 eager_grads = [p.grad for _, p in model_for_eager.named_parameters()] 3691 compile_grads = [p.grad for _, p in model_for_compile.named_parameters()] 3692 3693 self.assertEqual(eager_grads, compile_grads) 3694 3695 self.assertEqual(inp_ref.grad, inp_test.grad) 3696 3697 def test_buffer_batch_norm(self): 3698 class MyModel(torch.nn.Module): 3699 def __init__(self) -> None: 3700 super().__init__() 3701 self.m = torch.nn.BatchNorm1d(100) 3702 3703 def forward(self, x): 3704 return self.m(x) 3705 3706 model_for_eager = MyModel() 3707 model_for_compile = copy.deepcopy(model_for_eager) 3708 3709 fw_graph_cell = [None] 3710 bw_graph_cell = [None] 3711 compiled_f = aot_module( 3712 model_for_compile, 3713 fw_compiler=make_boxed_compiler( 3714 partial(extract_graph, graph_cell=fw_graph_cell) 3715 ), 3716 bw_compiler=make_boxed_compiler( 3717 partial(extract_graph, graph_cell=bw_graph_cell) 3718 ), 3719 keep_inference_input_mutations=True, 3720 ) 3721 inp_ref = torch.ones(20, 100, requires_grad=True) 3722 inp_test = torch.ones(20, 100, requires_grad=True) 3723 3724 out_ref = model_for_eager(inp_ref.clone()) 3725 out_test = compiled_f(inp_test.clone()) 3726 3727 
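        # [editor's note, illustrative aside] The assertions below pin the exact printed
        # forward graph. A looser way to check the same property, namely that with
        # keep_inference_input_mutations=True the buffer updates stay inside the forward
        # graph as aten.copy_ nodes writing back into the lifted buffer inputs, is
        # sketched here. The helper name is hypothetical and it is not used by this test.
        def _count_input_copy_nodes(gm):
            # Count copy_ nodes, which apply the mutated buffer values to the graph inputs.
            return sum(
                1
                for n in gm.graph.nodes
                if n.op == "call_function" and n.target == torch.ops.aten.copy_.default
            )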
self.assertExpectedInline( 3728 fw_graph_cell[0].code.strip(), 3729 """\ 3730def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6): 3731 add = torch.ops.aten.add.Tensor(primals_5, 1) 3732 _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(primals_6, primals_1, primals_2, primals_3, primals_4, True, 0.1, 1e-05); primals_2 = None 3733 getitem = _native_batch_norm_legit_functional[0] 3734 getitem_1 = _native_batch_norm_legit_functional[1] 3735 getitem_2 = _native_batch_norm_legit_functional[2] 3736 getitem_3 = _native_batch_norm_legit_functional[3] 3737 getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None 3738 copy_ = torch.ops.aten.copy_.default(primals_3, getitem_3); primals_3 = copy_ = None 3739 copy__1 = torch.ops.aten.copy_.default(primals_4, getitem_4); primals_4 = copy__1 = None 3740 copy__2 = torch.ops.aten.copy_.default(primals_5, add); primals_5 = add = copy__2 = None 3741 return (getitem, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4)""", # noqa: B950 3742 ) 3743 3744 self.assertEqual(out_ref, out_test) 3745 3746 out_ref.sum().backward() 3747 out_test.sum().backward() 3748 3749 eager_grads = [p.grad for _, p in model_for_eager.named_parameters()] 3750 compile_grads = [p.grad for _, p in model_for_compile.named_parameters()] 3751 self.assertEqual(eager_grads, compile_grads) 3752 3753 self.assertExpectedInline( 3754 bw_graph_cell[0].code.strip(), 3755 """\ 3756def forward(self, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4, tangents_1): 3757 native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(tangents_1, primals_6, primals_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); tangents_1 = primals_6 = primals_1 = getitem_3 = getitem_4 = getitem_1 = getitem_2 = None 3758 getitem_5 = native_batch_norm_backward[0] 3759 getitem_6 = native_batch_norm_backward[1] 3760 getitem_7 = native_batch_norm_backward[2]; native_batch_norm_backward = None 3761 return (getitem_6, getitem_7, None, None, None, getitem_5)""", # noqa: B950 3762 ) 3763 3764 self.assertEqual(inp_ref.grad, inp_test.grad) 3765 3766 def test_new_inp_requires_grad_now(self): 3767 def f(x, y): 3768 return x.add_(y) 3769 3770 fw_graph_cell = [None] 3771 bw_graph_cell = [None] 3772 compiled_f = aot_function( 3773 f, 3774 fw_compiler=make_boxed_compiler( 3775 partial(extract_graph, graph_cell=fw_graph_cell) 3776 ), 3777 bw_compiler=make_boxed_compiler( 3778 partial(extract_graph, graph_cell=bw_graph_cell) 3779 ), 3780 keep_inference_input_mutations=True, 3781 ) 3782 3783 inp_ref = ( 3784 torch.ones(20, 100, requires_grad=False), 3785 torch.ones(20, 100, requires_grad=True), 3786 ) 3787 inp_test = ( 3788 torch.ones(20, 100, requires_grad=False), 3789 torch.ones(20, 100, requires_grad=True), 3790 ) 3791 3792 out_ref = f(*inp_ref) 3793 out_test = compiled_f(*inp_test) 3794 3795 # There is no copy_ method 3796 self.assertExpectedInline( 3797 fw_graph_cell[0].code.strip(), 3798 """\ 3799def forward(self, primals_1, primals_2): 3800 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None 3801 add = torch.ops.aten.add.Tensor(clone, primals_2); clone = primals_2 = None 3802 return (add, add)""", 3803 ) # noqa: B950 3804 3805 self.assertEqual(out_ref, out_test) 3806 3807 out_ref.sum().backward() 3808 out_test.sum().backward() 3809 3810 self.assertExpectedInline( 3811 bw_graph_cell[0].code.strip(), 3812 """\ 3813def 
forward(self, tangents_1): 3814 return (None, tangents_1)""", 3815 ) # noqa: B950 3816 3817 def test_real_weights_in_symbolic_mode(self): 3818 from functorch.experimental import functionalize 3819 3820 class M(torch.nn.Module): 3821 def __init__(self) -> None: 3822 super().__init__() 3823 self.linear = torch.nn.Linear(5, 5) 3824 3825 def forward(self, x): 3826 x = self.linear(x) 3827 return x 3828 3829 m = M().eval() 3830 3831 inp = torch.randn(2, 5) 3832 3833 gm = make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp) 3834 self.assertEqual(gm(torch.ones(2, 5)), m(torch.ones(2, 5))) 3835 3836 gm_functionalized = make_fx( 3837 functionalize( 3838 gm, 3839 ), 3840 tracing_mode="symbolic", 3841 _allow_non_fake_inputs=True, 3842 )(inp) 3843 self.assertEqual(gm_functionalized(torch.ones(2, 5)), m(torch.ones(2, 5))) 3844 3845 inp_count = 0 3846 for node in gm.graph.nodes: 3847 if node.op == "placeholder": 3848 inp_count += 1 3849 3850 # No more param lifting 3851 self.assertEqual(inp_count, 1) 3852 3853 inp_count = 0 3854 for node in gm_functionalized.graph.nodes: 3855 if node.op == "placeholder": 3856 inp_count += 1 3857 3858 # No more param lifting 3859 self.assertEqual(inp_count, 1) 3860 3861 with self.assertRaisesRegex( 3862 Exception, "Please convert all Tensors to FakeTensors" 3863 ): 3864 make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=False)( 3865 torch.randn(2, 5) 3866 ) 3867 3868 def test_real_weights_in_symbolic_mode_with_inplace_ops(self): 3869 class M(torch.nn.Module): 3870 def __init__(self) -> None: 3871 super().__init__() 3872 self.buffer = torch.nn.Buffer(torch.ones(4, 5)) 3873 3874 def forward(self, x): 3875 y = self.buffer.add_(3) 3876 y.resize_([20]) 3877 assert y.shape == self.buffer.shape 3878 return x.sum() + self.buffer.sum() 3879 3880 m = M().eval() 3881 inp = torch.randn(2, 5) 3882 # inplace mutation on attr is not allowed 3883 with self.assertRaisesRegex(Exception, "Can't call metadata"): 3884 make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp) 3885 3886 def _compile_and_erase_bases(self, *output_view_indices): 3887 # Overrides _base and _view_func tensor attributes, so as to avoid the view-replay 3888 # execution path when reconstructing views. 3889 class NoViewReplayTensor(torch.Tensor): 3890 @property 3891 def _base(self): 3892 return None 3893 3894 @property 3895 def _view_func(self): 3896 return None 3897 3898 # Wraps the outputs that are views of the FX graph 'g' with NoViewReplayTensor, 3899 # since they are the only ones that will get reconstructed. 
3900 def wrapper(g, *args, **kwargs): 3901 outs = list(g(*args, **kwargs)) 3902 for i in output_view_indices: 3903 outs[i] = NoViewReplayTensor(outs[i]) 3904 return tuple(outs) 3905 3906 return lambda f: aot_function(f, fw_compiler=lambda g, _: partial(wrapper, g)) 3907 3908 def test_output_aliases_input_view_meta_replay(self): 3909 @self._compile_and_erase_bases(0) 3910 def f(a): 3911 return a.view(-1) 3912 3913 inp = torch.ones(2, 2, requires_grad=True) 3914 out = f(inp) 3915 3916 self.assertIsNotNone(out.grad_fn) 3917 self.assertExpectedInline( 3918 str(out.grad_fn.__class__), """<class 'ViewBackward0'>""" 3919 ) 3920 3921 def test_output_aliases_intermediate_view_meta_replay(self): 3922 @self._compile_and_erase_bases(0, 1) 3923 def f(a): 3924 b = a.clone() 3925 return b.view(-1), b.view(-1) 3926 3927 inp = torch.ones(2, 2, requires_grad=True) 3928 out1, out2 = f(inp) 3929 3930 self.assertIsNotNone(out1.grad_fn) 3931 self.assertExpectedInline( 3932 str(out1.grad_fn.__class__), """<class 'ViewBackward0'>""" 3933 ) 3934 3935 self.assertIsNotNone(out2.grad_fn) 3936 self.assertExpectedInline( 3937 str(out2.grad_fn.__class__), """<class 'ViewBackward0'>""" 3938 ) 3939 3940 def test_output_aliases_output_view_meta_replay(self): 3941 @self._compile_and_erase_bases(1) 3942 def f(a): 3943 b = a.add(10) 3944 return b, b.view(-1) 3945 3946 inp = torch.ones(2, 2, requires_grad=True) 3947 out1, out2 = f(inp) 3948 3949 self.assertEqual(out1.untyped_storage(), out2.untyped_storage()) 3950 self.assertIsNotNone(out2.grad_fn) 3951 self.assertExpectedInline( 3952 str(out2.grad_fn.__class__), """<class 'ViewBackward0'>""" 3953 ) 3954 3955 @skipIfTorchDynamo() 3956 @patch("torch._dynamo.config.assume_static_by_default", False) 3957 def test_dynamic_output_aliases_input_view_meta_replay(self): 3958 # - torch.compile: using it so we can have a SymInt in the FX graph. 3959 # - Compiling with inductor, so that tensor._base isn't tracked. 3960 # 3961 # This should force the use of as_strided in the view reconstruction path. 
3962 # The first 2 view-replay paths won't be taken because: 3963 # - target_functional_tensor will be symbolic (_functionalize_is_symbolic call) 3964 # - tensor._base will be None 3965 @torch.compile(backend="inductor") 3966 def f(a, sz): 3967 return a.view(sz), a.view(-1) 3968 3969 inp = torch.ones(2, 2, requires_grad=True) 3970 out1, out2 = f(inp, (4,)) 3971 3972 self.assertIsNotNone(out1.grad_fn) 3973 self.assertExpectedInline( 3974 str(out1.grad_fn.__class__), """<class 'AsStridedBackward0'>""" 3975 ) 3976 3977 self.assertIsNotNone(out2.grad_fn) 3978 self.assertExpectedInline( 3979 str(out2.grad_fn.__class__), """<class 'ViewBackward0'>""" 3980 ) 3981 3982 3983def extract_graph(fx_g, _, graph_cell): 3984 graph_cell[0] = fx_g 3985 return fx_g 3986 3987 3988def get_ins_outs(fx_g): 3989 ins = [] 3990 outs = [] 3991 for n in fx_g.graph.nodes: 3992 if n.op == "placeholder": 3993 ins.append(n) 3994 elif n.op == "output": 3995 outs = tuple(n.args[0]) 3996 return ins, outs 3997 3998 3999def get_num_ins_outs(fx_g): 4000 return tuple(len(i) for i in get_ins_outs(fx_g)) 4001 4002 4003def get_fw_bw_graph( 4004 f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False 4005): 4006 fw_graph_cell = [None] 4007 bw_graph_cell = [None] 4008 aot_function( 4009 f, 4010 fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), 4011 bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), 4012 partition_fn=partitioner, 4013 decompositions=default_decompositions, 4014 dynamic=dynamic, 4015 )(*inps).sum().backward() 4016 return (fw_graph_cell[0], bw_graph_cell[0]) 4017 4018 4019class TestMod(torch.nn.Module): 4020 def __init__(self, fn): 4021 super().__init__() 4022 self.p = torch.nn.Parameter(torch.ones(2, requires_grad=True)) 4023 self.fn = fn 4024 4025 def forward(self, *args): 4026 return self.fn(self.p, *args) 4027 4028 4029class TestAOTExport(AOTTestCase): 4030 def test_aot_export_ban_dropout_mut_pre_dispatch(self): 4031 def fn(p, x): 4032 y = torch.ops.aten.dropout.default(x, 0.1, train=False) 4033 y.add_(1) 4034 return (y,) 4035 4036 mod = TestMod(fn) 4037 inp = torch.randn(2, 2) 4038 4039 with self.assertRaisesRegex( 4040 RuntimeError, "cannot mutate tensors with frozen storage" 4041 ): 4042 aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) 4043 4044 gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=False) 4045 self.assertExpectedInline( 4046 str(gm.code).strip(), 4047 """\ 4048def forward(self, arg0_1, arg1_1): 4049 clone = torch.ops.aten.clone.default(arg1_1); arg1_1 = None 4050 add = torch.ops.aten.add.Tensor(clone, 1); clone = None 4051 return (add,)""", 4052 ) 4053 4054 fw_graph_cell = [None] 4055 bw_graph_cell = [None] 4056 4057 compiled_outs = aot_function( 4058 fn, 4059 fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), 4060 bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), 4061 partition_fn=default_partition, 4062 decompositions=default_decompositions, 4063 dynamic=True, 4064 )(*inp) 4065 fw_graph = fw_graph_cell[0] 4066 bw_graph = bw_graph_cell[0] 4067 4068 self.assertExpectedInline( 4069 str(fw_graph.code).strip(), 4070 """\ 4071def forward(self, arg0_1, arg1_1): 4072 clone = torch.ops.aten.clone.default(arg1_1); arg1_1 = None 4073 add = torch.ops.aten.add.Tensor(clone, 1); clone = None 4074 return (add,)""", 4075 ) 4076 4077 def test_aot_export_predispatch_func_simple(self): 4078 def fn(p, x): 4079 y = x + 2 4080 with torch.no_grad(): 4081 y.add_(2) 4082 return (x * 2 + y,) 4083 4084 mod = TestMod(fn) 4085 
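        # [editor's note, illustrative aside] With pre_dispatch=True, the export below is
        # traced above the autograd dispatch key, so grad-mode context managers such as
        # torch.no_grad() are not baked away; they appear in the graph as explicit
        # torch._C._set_grad_enabled(...) calls, which the expected output asserts.
        # A hypothetical helper (not used by this test) for pulling those markers out of
        # an exported GraphModule might look roughly like this:
        def _grad_mode_switches(gm):
            # Collect the booleans passed to the grad-mode toggles left in the graph.
            return [
                n.args[0]
                for n in gm.graph.nodes
                if n.op == "call_function" and n.target is torch._C._set_grad_enabled
            ]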
inp = torch.randn(2, 2) 4086 4087 with torch.no_grad(): 4088 gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) 4089 self.assertExpectedInline( 4090 str(gm.code).strip(), 4091 """\ 4092def forward(self, arg0_1, arg1_1): 4093 add = torch.ops.aten.add.Tensor(arg1_1, 2) 4094 _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None 4095 add_1 = torch.ops.aten.add.Tensor(add, 2); add = None 4096 _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None 4097 mul = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None 4098 add_2 = torch.ops.aten.add.Tensor(mul, add_1); mul = add_1 = None 4099 return (add_2,)""", 4100 ) 4101 4102 def test_aot_export_predispatch_func_composite_implicit(self): 4103 def fn(p, x): 4104 with torch.enable_grad(): 4105 y = x @ x 4106 y.add_(2) 4107 return (x.sum() + y.sum(),) 4108 4109 mod = TestMod(fn) 4110 inp = torch.randn(2, 2) 4111 4112 with torch.no_grad(): 4113 gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) 4114 self.assertExpectedInline( 4115 str(gm.code).strip(), 4116 """\ 4117def forward(self, arg0_1, arg1_1): 4118 _set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None 4119 matmul = torch.ops.aten.matmul.default(arg1_1, arg1_1) 4120 _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None 4121 add = torch.ops.aten.add.Tensor(matmul, 2); matmul = None 4122 sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None 4123 sum_2 = torch.ops.aten.sum.default(add); add = None 4124 add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None 4125 return (add_1,)""", 4126 ) 4127 4128 def test_aot_export_predispatch_composite_implicit_inplace(self): 4129 def fn(x, p): 4130 return (torch.ops.aten.absolute_.default(x.clone()),) 4131 4132 mod = TestMod(fn) 4133 inp = torch.randn(2, 2) 4134 4135 gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) 4136 self.assertExpectedInline( 4137 str(gm.code).strip(), 4138 """\ 4139def forward(self, arg0_1, arg1_1): 4140 clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 4141 abs_1 = torch.ops.aten.abs.default(clone); clone = None 4142 return (abs_1,)""", 4143 ) 4144 4145 def test_aot_export_predispatch_composite_implicit_linear(self): 4146 class MM(torch.nn.Module): 4147 def __init__(self) -> None: 4148 super().__init__() 4149 self.linear = torch.nn.Linear(2, 2) 4150 4151 def forward(self, x): 4152 return (self.linear(x),) 4153 4154 mod = MM() 4155 inp = torch.randn(2, 2) 4156 4157 gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) 4158 self.assertExpectedInline( 4159 str(gm.code).strip(), 4160 """\ 4161def forward(self, arg0_1, arg1_1, arg2_1): 4162 linear = torch.ops.aten.linear.default(arg2_1, arg0_1, arg1_1); arg2_1 = arg0_1 = arg1_1 = None 4163 return (linear,)""", 4164 ) 4165 4166 @unittest.expectedFailure 4167 def test_aot_export_predispatch_outdtype(self): 4168 class M(torch.nn.Module): 4169 def __init__(self, weight): 4170 super().__init__() 4171 self.weight = weight 4172 4173 def forward(self, x): 4174 y = x + 2 4175 y.add_(5) 4176 return ( 4177 out_dtype(torch.ops.aten.mm.default, torch.int32, y, self.weight), 4178 ) 4179 4180 weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8) 4181 mod = M(weight) 4182 inp = torch.randint(-128, 127, (5, 5), dtype=torch.int8) 4183 4184 gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) 4185 self.assertExpectedInline( 4186 str(gm.code).strip(), 
4187 """\ 4188def forward(self, arg0_1, arg1_1): 4189 _set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None 4190 mm = torch.ops.aten.mm.default(arg1_1, arg1_1) 4191 _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None 4192 add = torch.ops.aten.add.Tensor(mm, 2); mm = None 4193 sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None 4194 sum_2 = torch.ops.aten.sum.default(add); add = None 4195 add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None 4196 return (add_1,)""", 4197 ) 4198 4199 def test_aot_export_predispatch_func_view(self): 4200 def fn(p, x): 4201 y = x @ x 4202 y.add_(2) 4203 return (x.sum() + y.view(1, 4).sum(),) 4204 4205 mod = TestMod(fn) 4206 inp = torch.randn(2, 2) 4207 4208 gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) 4209 self.assertExpectedInline( 4210 str(gm.code).strip(), 4211 """\ 4212def forward(self, arg0_1, arg1_1): 4213 matmul = torch.ops.aten.matmul.default(arg1_1, arg1_1) 4214 add = torch.ops.aten.add.Tensor(matmul, 2); matmul = None 4215 sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None 4216 view_1 = torch.ops.aten.view.default(add, [1, 4]); add = None 4217 sum_2 = torch.ops.aten.sum.default(view_1); view_1 = None 4218 add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None 4219 return (add_1,)""", 4220 ) 4221 4222 def test_aot_export_predispatch_buffer_mutation_metadata(self): 4223 class Foo(torch.nn.Module): 4224 def __init__(self) -> None: 4225 super().__init__() 4226 self.foo = torch.nn.Buffer(torch.zeros(2, 2)) 4227 4228 def forward(self, x): 4229 self.foo.add_(4) 4230 return (x.sum() + self.foo.sum(),) 4231 4232 inp = torch.randn(2, 2) 4233 4234 gm, graph_sig = aot_export_module( 4235 Foo(), [inp], trace_joint=False, pre_dispatch=True 4236 ) 4237 self.assertExpectedInline( 4238 str(gm.code).strip(), 4239 """\ 4240def forward(self, arg0_1, arg1_1): 4241 add = torch.ops.aten.add.Tensor(arg0_1, 4); arg0_1 = None 4242 sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None 4243 sum_2 = torch.ops.aten.sum.default(add) 4244 add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None 4245 return (add, add_1)""", 4246 ) 4247 eager_mod = Foo() 4248 output_1, output_2 = gm(torch.zeros(2, 2), inp) 4249 eager_output = eager_mod(inp) 4250 self.assertTrue(torch.allclose(output_2, eager_output[0])) 4251 4252 _, output_2 = gm(output_1, inp) 4253 eager_output = eager_mod(inp) 4254 self.assertTrue(torch.allclose(output_2, eager_output[0])) 4255 self.assertTrue("foo" in graph_sig.buffers) 4256 self.assertTrue(graph_sig.inputs_to_buffers["arg0_1"] == "foo") 4257 4258 def test_aot_export_predispatch_with_autograd_op(self): 4259 def foo(p, x): 4260 with torch.enable_grad(): 4261 y = x + 5 4262 y.add_(5) 4263 y.add_(7) 4264 return (x.cos() + y.sin(),) 4265 4266 inp = torch.randn(2, 2) 4267 mod = TestMod(foo) 4268 4269 with torch.no_grad(): 4270 gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) 4271 self.assertExpectedInline( 4272 str(gm.code).strip(), 4273 """\ 4274def forward(self, arg0_1, arg1_1): 4275 _set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None 4276 add = torch.ops.aten.add.Tensor(arg1_1, 5) 4277 add_1 = torch.ops.aten.add.Tensor(add, 5); add = None 4278 add_2 = torch.ops.aten.add.Tensor(add_1, 7); add_1 = None 4279 cos = torch.ops.aten.cos.default(arg1_1); arg1_1 = None 4280 sin = torch.ops.aten.sin.default(add_2); add_2 = None 4281 add_3 = torch.ops.aten.add.Tensor(cos, sin); 
cos = sin = None 4282 _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None 4283 return (add_3,)""", 4284 ) 4285 4286 @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") 4287 @unittest.skipIf( 4288 not torchdynamo.is_dynamo_supported(), "TorchDynamo is not supported" 4289 ) 4290 def test_aot_export_predispatch_with_cond_nested(self): 4291 class M(torch.nn.Module): 4292 def __init__(self) -> None: 4293 super().__init__() 4294 4295 def forward(self, x): 4296 def true_fn(x): 4297 y = x.sin() 4298 y.add_(5) 4299 4300 def true_true_fn(x): 4301 y = x.sin() 4302 y.add_(7) 4303 return y.sin() 4304 4305 def true_false_fn(x): 4306 return x.cos() 4307 4308 return torch.cond( 4309 y.cos().sum() > 5, true_true_fn, true_false_fn, [y.cos()] 4310 ) 4311 4312 def false_fn(x): 4313 z = x.cos() 4314 z.add_(6) 4315 return z.sin() 4316 4317 a = torch.cond(x.sum() > 4, true_fn, false_fn, [x]) 4318 return (a + 3, a + 4) 4319 4320 inp = torch.randn(2, 2) 4321 gm, _ = aot_export_module(M(), [inp], trace_joint=False, pre_dispatch=True) 4322 self.assertExpectedInline( 4323 str(gm.code).strip(), 4324 """\ 4325def forward(self, arg0_1): 4326 sum_1 = torch.ops.aten.sum.default(arg0_1) 4327 gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None 4328 true_graph_0 = self.true_graph_0 4329 false_graph_0 = self.false_graph_0 4330 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None 4331 getitem = cond[0]; cond = None 4332 add = torch.ops.aten.add.Tensor(getitem, 3) 4333 add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None 4334 return (add, add_1)""", # noqa: B950 4335 ) 4336 4337 self.assertExpectedInline( 4338 str(gm.true_graph_0.code).strip(), 4339 """\ 4340def forward(self, arg0_1): 4341 sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None 4342 add = torch.ops.aten.add.Tensor(sin, 5); sin = None 4343 cos = torch.ops.aten.cos.default(add) 4344 sum_1 = torch.ops.aten.sum.default(cos); cos = None 4345 gt = torch.ops.aten.gt.Scalar(sum_1, 5); sum_1 = None 4346 cos_1 = torch.ops.aten.cos.default(add); add = None 4347 true_graph_0 = self.true_graph_0 4348 false_graph_0 = self.false_graph_0 4349 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [cos_1]); gt = true_graph_0 = false_graph_0 = cos_1 = None 4350 getitem = cond[0]; cond = None 4351 return (getitem,)""", # noqa: B950 4352 ) 4353 4354 self.assertExpectedInline( 4355 str(gm.true_graph_0.true_graph_0.code).strip(), 4356 """\ 4357def forward(self, arg0_1): 4358 sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None 4359 add = torch.ops.aten.add.Tensor(sin, 7); sin = None 4360 sin_1 = torch.ops.aten.sin.default(add); add = None 4361 return (sin_1,)""", 4362 ) 4363 4364 @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") 4365 @unittest.skipIf( 4366 not torchdynamo.is_dynamo_supported(), "TorchDynamo is not supported" 4367 ) 4368 def test_aot_export_predispatch_map_1(self): 4369 class M(torch.nn.Module): 4370 def __init__(self) -> None: 4371 super().__init__() 4372 4373 def forward(self, x, y): 4374 def true_fn(x, r): 4375 y = x.sin() 4376 y.add_(5) 4377 return y.cos() + r.sum() 4378 4379 def false_fn(x, r): 4380 z = x.cos() 4381 4382 def f(x, y): 4383 a = x.cos() 4384 a.add_(5) 4385 return a + y 4386 4387 return ( 4388 z 4389 + control_flow.map(f, z, r).sum() 4390 + control_flow.map(f, z, r).sum() 4391 ) 4392 4393 a = torch.cond(x.sum() > 4, true_fn, false_fn, [x, y]) 4394 return (a + 3, a + 4) 4395 4396 
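        # [editor's note, illustrative aside] For readers unfamiliar with control_flow.map
        # (used in false_fn above): it applies the body function over the leading dimension
        # of its mapped operand, passing the extra operand through unchanged, roughly like
        # the eager sketch below. The helper is hypothetical and is not used by this test.
        def _map_reference(body, xs, extra):
            # Eager reference: run `body` on each slice of `xs` along dim 0 and stack.
            return torch.stack([body(xs[i], extra) for i in range(xs.shape[0])])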
inps = [torch.randn(2, 2), torch.ones(2)] 4397 gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True) 4398 self.assertExpectedInline( 4399 str(gm.code).strip(), 4400 """\ 4401def forward(self, arg0_1, arg1_1): 4402 sum_1 = torch.ops.aten.sum.default(arg0_1) 4403 gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None 4404 true_graph_0 = self.true_graph_0 4405 false_graph_0 = self.false_graph_0 4406 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1, arg1_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None 4407 getitem = cond[0]; cond = None 4408 add = torch.ops.aten.add.Tensor(getitem, 3) 4409 add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None 4410 return (add, add_1)""", # noqa: B950 4411 ) 4412 self.assertExpectedInline( 4413 str(gm.true_graph_0.code).strip(), 4414 """\ 4415def forward(self, arg0_1, arg1_1): 4416 sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None 4417 add = torch.ops.aten.add.Tensor(sin, 5); sin = None 4418 cos = torch.ops.aten.cos.default(add); add = None 4419 sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None 4420 add_1 = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None 4421 return (add_1,)""", 4422 ) 4423 self.assertExpectedInline( 4424 str(gm.false_graph_0.code).strip(), 4425 """\ 4426def forward(self, arg0_1, arg1_1): 4427 cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None 4428 select = torch.ops.aten.select.int(cos, 0, 0); select = None 4429 body_graph_0 = self.body_graph_0 4430 map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = None 4431 getitem = map_impl[0]; map_impl = None 4432 sum_1 = torch.ops.aten.sum.default(getitem); getitem = None 4433 add = torch.ops.aten.add.Tensor(cos, sum_1); sum_1 = None 4434 select_1 = torch.ops.aten.select.int(cos, 0, 0); select_1 = None 4435 body_graph_1 = self.body_graph_1 4436 map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]); body_graph_1 = cos = arg1_1 = None 4437 getitem_1 = map_impl_1[0]; map_impl_1 = None 4438 sum_2 = torch.ops.aten.sum.default(getitem_1); getitem_1 = None 4439 add_1 = torch.ops.aten.add.Tensor(add, sum_2); add = sum_2 = None 4440 return (add_1,)""", 4441 ) 4442 self.assertExpectedInline( 4443 str(gm.false_graph_0.body_graph_0.code).strip(), 4444 """\ 4445def forward(self, arg0_1, arg1_1): 4446 cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None 4447 add = torch.ops.aten.add.Tensor(cos, 5); cos = None 4448 add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = arg1_1 = None 4449 return (add_1,)""", 4450 ) 4451 4452 def test_aot_export_predispatch_map_2(self): 4453 class M(torch.nn.Module): 4454 def __init__(self) -> None: 4455 super().__init__() 4456 4457 def forward(self, x, y): 4458 z = x.cos() 4459 4460 def f(x, y): 4461 a = x.cos() 4462 a.add_(5) 4463 return a + y 4464 4465 return (z + control_flow.map(f, z, y).sum(),) 4466 4467 inps = [torch.randn(2, 2), torch.ones(2)] 4468 gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True) 4469 self.assertExpectedInline( 4470 str(gm.code).strip(), 4471 """\ 4472def forward(self, arg0_1, arg1_1): 4473 cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None 4474 body_graph_0 = self.body_graph_0 4475 map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = arg1_1 = None 4476 getitem = map_impl[0]; map_impl = None 4477 sum_1 = torch.ops.aten.sum.default(getitem); getitem = None 4478 add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None 4479 return 
(add,)""", 4480 ) # noqa: B950 4481 self.assertExpectedInline( 4482 str(gm.body_graph_0.code).strip(), 4483 """\ 4484def forward(self, arg0_1, arg1_1): 4485 cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None 4486 add = torch.ops.aten.add.Tensor(cos, 5); cos = None 4487 add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = arg1_1 = None 4488 return [add_1]""", 4489 ) 4490 4491 @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") 4492 @unittest.skipIf( 4493 not torchdynamo.is_dynamo_supported(), "TorchDynamo is not supported" 4494 ) 4495 def test_aot_export_predispatch_with_cond(self): 4496 class M(torch.nn.Module): 4497 def __init__(self) -> None: 4498 super().__init__() 4499 4500 def forward(self, x): 4501 def true_fn(x): 4502 y = x.sin() 4503 z = torch.ops.aten.linear.default(y, torch.randn(2, 2)) 4504 z.add_(5) 4505 return z.cos() 4506 4507 def false_fn(x): 4508 z = x.cos() 4509 z.add_(6) 4510 return z.sin() 4511 4512 a = torch.cond(x.sum() > 4, true_fn, false_fn, [x]) 4513 return (a + 3, a + 4) 4514 4515 inp = torch.randn(2, 2) 4516 gm, _ = aot_export_module(M(), [inp], trace_joint=False, pre_dispatch=True) 4517 self.assertExpectedInline( 4518 str(gm.code).strip(), 4519 """\ 4520def forward(self, arg0_1): 4521 sum_1 = torch.ops.aten.sum.default(arg0_1) 4522 gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None 4523 true_graph_0 = self.true_graph_0 4524 false_graph_0 = self.false_graph_0 4525 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None 4526 getitem = cond[0]; cond = None 4527 add = torch.ops.aten.add.Tensor(getitem, 3) 4528 add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None 4529 return (add, add_1)""", # noqa: B950 4530 ) 4531 self.assertExpectedInline( 4532 str(gm.true_graph_0.code).strip(), 4533 """\ 4534def forward(self, arg0_1): 4535 sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None 4536 randn = torch.ops.aten.randn.default([2, 2], device = device(type='cpu'), pin_memory = False) 4537 linear = torch.ops.aten.linear.default(sin, randn); sin = randn = None 4538 add = torch.ops.aten.add.Tensor(linear, 5); linear = None 4539 cos = torch.ops.aten.cos.default(add); add = None 4540 return (cos,)""", 4541 ) 4542 4543 def test_aot_export_predispatch_conv_and_bn(self): 4544 class ConvBatchnorm(torch.nn.Module): 4545 def __init__(self) -> None: 4546 super().__init__() 4547 self.conv = torch.nn.Conv2d(1, 3, 1, 1) 4548 self.bn = torch.nn.BatchNorm2d(3) 4549 4550 def forward(self, x): 4551 x = self.conv(x) 4552 x = self.bn(x) 4553 return (x,) 4554 4555 mod = ConvBatchnorm() 4556 mod.train() 4557 inp = torch.randn(1, 1, 3, 3) 4558 4559 gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) 4560 self.assertExpectedInline( 4561 str(gm.code).strip(), 4562 """\ 4563def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1): 4564 conv2d = torch.ops.aten.conv2d.default(arg7_1, arg0_1, arg1_1); arg7_1 = arg0_1 = arg1_1 = None 4565 add = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None 4566 _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); conv2d = arg2_1 = arg3_1 = arg4_1 = arg5_1 = None 4567 getitem = _native_batch_norm_legit_functional[0] 4568 getitem_3 = _native_batch_norm_legit_functional[3] 4569 getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None 4570 return (getitem_3, getitem_4, add, 
getitem)""", # noqa: B950 4571 ) 4572 4573 def test_aot_export_predispatch_reshape(self): 4574 class Reshape(torch.nn.Module): 4575 def forward(self, x): 4576 y = x.reshape(4, 4) 4577 return (y.sum(),) 4578 4579 mod = Reshape() 4580 inp = torch.randn(2, 8) 4581 4582 gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) 4583 self.assertExpectedInline( 4584 str(gm.code).strip(), 4585 """\ 4586def forward(self, arg0_1): 4587 view = torch.ops.aten.view.default(arg0_1, [4, 4]); arg0_1 = None 4588 sum_1 = torch.ops.aten.sum.default(view); view = None 4589 return (sum_1,)""", 4590 ) # noqa: B950 4591 4592 def test_aot_export_predispatch_contiguous(self): 4593 class Cont(torch.nn.Module): 4594 def forward(self, x): 4595 y = torch.ops.aten.contiguous.default(x) 4596 return (y.sum(),) 4597 4598 mod = Cont() 4599 inp = torch.randn(2, 8) 4600 4601 gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) 4602 self.assertExpectedInline( 4603 str(gm.code).strip(), 4604 """\ 4605def forward(self, arg0_1): 4606 sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None 4607 return (sum_1,)""", 4608 ) # noqa: B950 4609 4610 def test_aot_export_module_joint(self): 4611 class ConvBatchnormRelu(torch.nn.Module): 4612 def __init__(self) -> None: 4613 super().__init__() 4614 self.conv = torch.nn.Conv2d(1, 3, 1, 1) 4615 self.bn = torch.nn.BatchNorm2d(3) 4616 4617 def forward(self, x): 4618 x = self.conv(x) 4619 x = self.bn(x) 4620 user_out = torch.nn.functional.relu(x) 4621 loss = user_out.sum() 4622 return loss, user_out.detach() 4623 4624 mod = ConvBatchnormRelu() 4625 mod.train() 4626 inp = torch.randn(1, 1, 3, 3) 4627 o_ref = mod(inp) 4628 fx_g, signature = aot_export_module( 4629 mod, [inp], trace_joint=True, output_loss_index=0 4630 ) 4631 # Some important characteristics of the exported graph below: 4632 # 8 arguments: 2 params from conv, 2 params from batchnorm, 2 buffers from 1 batchnorm, 1 user input 4633 # 9 outputs: 3 mutated buffers (from batchnorm), 2 user outputs and 4 gradients (since there were 4 parameters) 4634 for node in fx_g.graph.nodes: 4635 node.meta.pop("stack_trace", None) 4636 self.assertExpectedInline( 4637 fx_g.print_readable(print_output=False), 4638 """\ 4639class <lambda>(torch.nn.Module): 4640 def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"): 4641 # No stacktrace found for following nodes 4642 convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg1_1 = None 4643 add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None 4644 _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); arg3_1 = arg4_1 = arg5_1 = None 4645 getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] 4646 getitem_1: "f32[3]" = _native_batch_norm_legit_functional[1] 4647 getitem_2: "f32[3]" = _native_batch_norm_legit_functional[2] 4648 getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] 4649 getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None 4650 relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None 4651 detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); detach = None 4652 detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu) 
4653 detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None 4654 detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_2); detach_2 = None 4655 detach_4: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_3); detach_3 = None 4656 sum_1: "f32[]" = torch.ops.aten.sum.default(relu) 4657 detach_5: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None 4658 detach_6: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_5); detach_5 = None 4659 detach_7: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_6); detach_6 = None 4660 detach_8: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_7); detach_7 = None 4661 detach_9: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_8); detach_8 = None 4662 detach_10: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_9); detach_9 = None 4663 ones_like: "f32[]" = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format) 4664 expand: "f32[1, 3, 3, 3]" = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]); ones_like = None 4665 detach_11: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_4); detach_4 = None 4666 detach_12: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_11); detach_11 = None 4667 detach_13: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_12); detach_12 = None 4668 detach_14: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_13); detach_13 = None 4669 threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_14, 0); expand = detach_14 = None 4670 native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None 4671 getitem_5: "f32[1, 3, 3, 3]" = native_batch_norm_backward[0] 4672 getitem_6: "f32[3]" = native_batch_norm_backward[1] 4673 getitem_7: "f32[3]" = native_batch_norm_backward[2]; native_batch_norm_backward = None 4674 convolution_backward = torch.ops.aten.convolution_backward.default(getitem_5, arg7_1, arg0_1, [3], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]); getitem_5 = arg7_1 = arg0_1 = None 4675 getitem_8 = convolution_backward[0]; getitem_8 = None 4676 getitem_9: "f32[3, 1, 1, 1]" = convolution_backward[1] 4677 getitem_10: "f32[3]" = convolution_backward[2]; convolution_backward = None 4678 return (getitem_3, getitem_4, add, sum_1, detach_10, getitem_9, getitem_10, getitem_6, getitem_7) 4679 """, # noqa: B950 4680 ) 4681 4682 self.assertExpectedInline( 4683 str(signature.parameters), 4684 """['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias']""", 4685 ) 4686 self.assertExpectedInline( 4687 str(signature.buffers), 4688 """['bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked']""", 4689 ) 4690 self.assertExpectedInline(str(signature.user_inputs), """['arg7_1']""") 4691 self.assertExpectedInline( 4692 str(signature.inputs_to_parameters), 4693 """{'arg0_1': 'conv.weight', 'arg1_1': 'conv.bias', 'arg2_1': 'bn.weight', 'arg3_1': 'bn.bias'}""", 4694 ) # noqa: B950 4695 self.assertExpectedInline( 4696 str(signature.inputs_to_buffers), 4697 """{'arg4_1': 'bn.running_mean', 'arg5_1': 'bn.running_var', 'arg6_1': 'bn.num_batches_tracked'}""", 4698 ) # noqa: B950 4699 self.assertExpectedInline( 4700 str(signature.buffers_to_mutate), 4701 """{'getitem_3': 'bn.running_mean', 'getitem_4': 'bn.running_var', 
'add': 'bn.num_batches_tracked'}""", 4702 ) # noqa: B950 4703 self.assertExpectedInline( 4704 str(signature.backward_signature.gradients_to_parameters), 4705 """{'getitem_9': 'conv.weight', 'getitem_10': 'conv.bias', 'getitem_6': 'bn.weight', 'getitem_7': 'bn.bias'}""", 4706 ) # noqa: B950 4707 self.assertExpectedInline( 4708 str(signature.backward_signature.gradients_to_user_inputs), """{}""" 4709 ) 4710 self.assertExpectedInline( 4711 str(signature.backward_signature.loss_output), """getitem_3""" 4712 ) 4713 4714 # Also check the inference graph 4715 # Main important thing here is that there are 5 total outputs: 3 total mutated buffers (from batchnorm), 2 user outputs. 4716 fx_g_inference, signature_inference = aot_export_module( 4717 mod, [inp], trace_joint=False 4718 ) 4719 for node in fx_g_inference.graph.nodes: 4720 node.meta.pop("stack_trace", None) 4721 self.assertExpectedInline( 4722 fx_g_inference.print_readable(print_output=False), 4723 """\ 4724class <lambda>(torch.nn.Module): 4725 def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"): 4726 # No stacktrace found for following nodes 4727 convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg7_1 = arg0_1 = arg1_1 = None 4728 add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None 4729 _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); convolution = arg2_1 = arg3_1 = arg4_1 = arg5_1 = None 4730 getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] 4731 getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] 4732 getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None 4733 relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None 4734 sum_1: "f32[]" = torch.ops.aten.sum.default(relu) 4735 detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None 4736 detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach); detach = None 4737 detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None 4738 return (getitem_3, getitem_4, add, sum_1, detach_2) 4739 """, # noqa: B950 4740 ) 4741 # Some important characteristics of the exported graph below: 4742 # 8 arguments: 2 params from conv, 2 params from batchnorm, 2 buffers from 1 batchnorm, 1 user input 4743 # 9 outputs: 2 mutated buffers (from batchnorm), 2 user outputs and 4 gradients (since there were 4 parameters) 4744 4745 def test_aot_export_simplified_basic(self): 4746 def f(x, y): 4747 return x * y, y * y.detach() 4748 4749 x = torch.randn(2, requires_grad=True) 4750 y = torch.randn(2, requires_grad=True) 4751 4752 f_graph_fw = aot_export_joint_simple(f, [x, y], trace_joint=False) 4753 out_ref = f(x, y) 4754 # No calling convention changes necessary to invoke the traced graph 4755 out_test = f_graph_fw(x, y) 4756 self.assertEqual(out_ref, out_test) 4757 4758 # Now test the backward 4759 x = torch.randn(2, requires_grad=True) 4760 y = torch.randn(2, requires_grad=True) 4761 x2 = x.clone().detach().requires_grad_(True) 4762 y2 = y.clone().detach().requires_grad_(True) 4763 x3 = x.clone().detach().requires_grad_(True) 4764 y3 = y.clone().detach().requires_grad_(True) 4765 f_graph_joint = aot_export_joint_simple(f, [x, y], 
trace_joint=True) 4766 num_fw_outputs = 2 4767 fw_g, bw_g = default_partition( 4768 f_graph_joint, [x, y], num_fwd_outputs=num_fw_outputs 4769 ) 4770 out_ref2 = f(x2, y2) 4771 fw_outs = fw_g(x3, y3) 4772 out_test2, activations = fw_outs[:num_fw_outputs], fw_outs[num_fw_outputs:] 4773 self.assertEqual(out_ref2, out_test2) 4774 4775 # Test running the traced backward graph with a mocked-up grad_output 4776 grad_outs = [torch.ones_like(x) for x in out_ref2] 4777 grads_ref = torch.autograd.grad(out_ref2, [x2, y2], grad_outputs=grad_outs) 4778 grads_test = bw_g(*activations, *grad_outs) 4779 for g_ref, g_test in zip(grads_ref, grads_test): 4780 self.assertEqual(g_ref, g_test) 4781 4782 def test_aot_export_metadata_mutation_banned(self): 4783 def fn(p, x): 4784 x.t_() 4785 return (x * 2,) 4786 4787 mod = TestMod(fn) 4788 inp = torch.randn(2, 4) 4789 with self.assertRaisesRegex( 4790 RuntimeError, "Found an input that received a metadata mutation" 4791 ): 4792 aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) 4793 aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) 4794 aot_export_module(mod, [inp], trace_joint=False) 4795 4796 def test_aot_export_forward_mutation_no_buffer_mut(self): 4797 class M(torch.nn.Module): 4798 def __init__(self) -> None: 4799 super().__init__() 4800 self.buffer1 = torch.nn.Buffer(torch.ones(6, 4)) 4801 4802 def forward(self, x): 4803 x.add_(4) 4804 return (x.cos().sum() + self.buffer1.sum(),) 4805 4806 mod = M() 4807 inp = torch.ones(6, 4) 4808 gm, sig = aot_export_module(mod, [inp], trace_joint=False) 4809 self.assertExpectedInline( 4810 str(gm.code).strip(), 4811 """\ 4812def forward(self, arg0_1, arg1_1): 4813 add = torch.ops.aten.add.Tensor(arg1_1, 4); arg1_1 = None 4814 cos = torch.ops.aten.cos.default(add) 4815 sum_1 = torch.ops.aten.sum.default(cos); cos = None 4816 sum_2 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None 4817 add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None 4818 return (add, add_1)""", 4819 ) # noqa: B950 4820 self.assertEqual(sig.user_inputs_to_mutate, {"add": "arg1_1"}) 4821 4822 def test_aot_export_forward_mutation_multiple_mut(self): 4823 class M(torch.nn.Module): 4824 def __init__(self) -> None: 4825 super().__init__() 4826 self.buffer1 = torch.nn.Buffer(torch.ones(6, 4)) 4827 4828 def forward(self, x, y): 4829 y.add_(4) 4830 self.buffer1.add_(5) 4831 return ( 4832 x.cos().sum() + y.sin().sum(), 4833 self.buffer1.sum(), 4834 ) 4835 4836 mod = M() 4837 inp = [torch.ones(6, 4), torch.zeros(6, 4)] 4838 gm, sig = aot_export_module(mod, inp, trace_joint=False) 4839 self.assertExpectedInline( 4840 str(gm.code).strip(), 4841 """\ 4842def forward(self, arg0_1, arg1_1, arg2_1): 4843 add = torch.ops.aten.add.Tensor(arg2_1, 4); arg2_1 = None 4844 add_1 = torch.ops.aten.add.Tensor(arg0_1, 5); arg0_1 = None 4845 cos = torch.ops.aten.cos.default(arg1_1); arg1_1 = None 4846 sum_1 = torch.ops.aten.sum.default(cos); cos = None 4847 sin = torch.ops.aten.sin.default(add) 4848 sum_2 = torch.ops.aten.sum.default(sin); sin = None 4849 add_2 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None 4850 sum_3 = torch.ops.aten.sum.default(add_1) 4851 return (add_1, add, add_2, sum_3)""", 4852 ) # noqa: B950 4853 self.assertEqual(sig.user_inputs_to_mutate, {"add": "arg2_1"}) 4854 self.assertEqual(sig.buffers_to_mutate, {"add_1": "buffer1"}) 4855 4856 def test_aot_export_input_mutation_on_input_requiring_grad_banned(self): 4857 class M(torch.nn.Module): 4858 def forward(self, x): 4859 x.add_(4) 4860 return (x,) 4861 
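        # x is a leaf with requires_grad=True, so the in-place add above should be
        # rejected by export (roughly: there is no runtime epilogue in which the
        # functionalized mutation could be replayed with correct autograd semantics).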
4862 mod = M() 4863 inp = torch.randn(2, requires_grad=True) 4864 with self.assertRaisesRegex( 4865 RuntimeError, 4866 "Found a graph input that requires gradients, and received a mutation", 4867 ): 4868 aot_export_module(mod, [inp], trace_joint=False) 4869 4870 def test_aot_export_input_mutation_on_parameter_banned(self): 4871 def fn(p, x): 4872 p.mul_(2) 4873 return (p + x,) 4874 4875 mod = TestMod(fn) 4876 inp = torch.randn(2) 4877 with self.assertRaisesRegex( 4878 RuntimeError, 4879 "Found a graph input that requires gradients, and received a mutation", 4880 ): 4881 aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) 4882 aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) 4883 aot_export_module(mod, [inp], trace_joint=False) 4884 4885 def test_aot_export_synthetic_bases_banned(self): 4886 def fn(p, x, y): 4887 x.mul_(2) 4888 return (x + y,) 4889 4890 mod = TestMod(fn) 4891 inp = torch.randn(2) 4892 inp2 = inp.view(-1) 4893 with self.assertRaisesRegex( 4894 RuntimeError, "Encountered aliased inputs that are mutated" 4895 ): 4896 aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=False) 4897 aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=True) 4898 aot_export_module(mod, [inp, inp2], trace_joint=False) 4899 4900 def test_aot_export_input_dupes_banned(self): 4901 def fn(p, x, y): 4902 x.mul_(2) 4903 return (x + y,) 4904 4905 mod = TestMod(fn) 4906 inp = torch.randn(2) 4907 with self.assertRaisesRegex( 4908 RuntimeError, "Encountered duplicated inputs that are mutated in the graph" 4909 ): 4910 aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=False) 4911 aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=True) 4912 aot_export_module(mod, [inp, inp], trace_joint=False) 4913 4914 def test_aot_export_multiple_outputs_require_grad_banned(self): 4915 def fn(p, x): 4916 out = p * x 4917 return out, out.sum() 4918 4919 mod = TestMod(fn) 4920 inp = torch.randn(2) 4921 with self.assertRaisesRegex( 4922 RuntimeError, 4923 "Found an output of the forward that requires gradients, that was not", 4924 ): 4925 aot_export_module(mod, [inp], trace_joint=True, output_loss_index=1) 4926 4927 @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") 4928 @unittest.skipIf( 4929 not torch._dynamo.is_dynamo_supported(), "Cond needs dynamo to run" 4930 ) 4931 def test_aot_export_with_torch_cond(self): 4932 class M(torch.nn.Module): 4933 def __init__(self) -> None: 4934 super().__init__() 4935 4936 def forward(self, x): 4937 def true_fn(x): 4938 y = x + 4 4939 y.add_(5) 4940 return x.cos() 4941 4942 def false_fn(x): 4943 y = x + 5 4944 y.add_(6) 4945 return x.sin() 4946 4947 a = torch.cond(x.sum() > 4, true_fn, false_fn, [x]) 4948 return (a + 3, a + 4) 4949 4950 inp = torch.randn(3, 4) 4951 gm, _ = aot_export_module(M(), (inp,), trace_joint=False) 4952 self.assertExpectedInline( 4953 gm.code.strip(), 4954 """\ 4955def forward(self, arg0_1): 4956 sum_1 = torch.ops.aten.sum.default(arg0_1) 4957 gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None 4958 true_graph_0 = self.true_graph_0 4959 false_graph_0 = self.false_graph_0 4960 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None 4961 getitem = cond[0]; cond = None 4962 add = torch.ops.aten.add.Tensor(getitem, 3) 4963 add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None 4964 return (add, add_1)""", # noqa: B950 4965 ) 4966 4967 self.assertExpectedInline( 4968 gm.true_graph_0.code.strip(), 4969 """\ 4970def 
forward(self, arg0_1): 4971 add = torch.ops.aten.add.Tensor(arg0_1, 4) 4972 add_1 = torch.ops.aten.add.Tensor(add, 5); add = add_1 = None 4973 cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None 4974 return (cos,)""", 4975 ) 4976 4977 self.assertExpectedInline( 4978 gm.false_graph_0.code.strip(), 4979 """\ 4980def forward(self, arg0_1): 4981 add = torch.ops.aten.add.Tensor(arg0_1, 5) 4982 add_1 = torch.ops.aten.add.Tensor(add, 6); add = add_1 = None 4983 sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None 4984 return (sin,)""", 4985 ) 4986 4987 def test_aot_export_simplified_pytrees_banned(self): 4988 def fn(inps): 4989 return (inps[0] + inps[1],) 4990 4991 inp1 = torch.randn(2) 4992 inp2 = torch.randn(2) 4993 inps = [inp1, inp2] 4994 with self.assertRaisesRegex( 4995 RuntimeError, 4996 "aot_export_joint_simple requires individual inputs not to be pytrees", 4997 ): 4998 aot_export_joint_simple(fn, [inps], trace_joint=False) 4999 aot_export_joint_simple(fn, [inps], trace_joint=True) 5000 5001 def test_aot_export_functionalized_rng_banned(self): 5002 def fn(p, x): 5003 return (p + x,) 5004 5005 mod = TestMod(fn) 5006 inp = torch.randn(2) 5007 with patch( 5008 "functorch.compile.config.functionalize_rng_ops", True 5009 ), self.assertRaisesRegex( 5010 RuntimeError, 5011 "Functionalized RNG is not currently supported in the aot_export", 5012 ): 5013 aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) 5014 aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) 5015 aot_export_module(mod, [inp], trace_joint=False) 5016 5017 def test_aot_export_unbacked_arg(self): 5018 class M(torch.nn.Module): 5019 def forward(self): 5020 full = torch.full((), 11) 5021 i0 = full.item() 5022 return (torch.full((i0,), 0),) 5023 5024 gm, _ = aot_export_module( 5025 mod=M(), args=(), trace_joint=False, dynamic_shapes=True 5026 ) 5027 self.assertExpectedInline( 5028 gm.code.strip(), 5029 """\ 5030def forward(self): 5031 full = torch.ops.aten.full.default([], 11, device = device(type='cpu'), pin_memory = False) 5032 _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(full); full = None 5033 full_1 = torch.ops.aten.full.default([_local_scalar_dense], 0, device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None 5034 return (full_1,)""", # noqa: B950 5035 ) 5036 5037 5038class TestPartitioning(AOTTestCase): 5039 @unittest.skipIf(not USE_NETWORKX, "networkx not available") 5040 def test_recompute_partitioning(self): 5041 def fn(a, b): 5042 return torch.sin(torch.sin(a)) + b 5043 5044 # Reference calculation 5045 ref_a = torch.rand(10, 10, requires_grad=True) 5046 ref_b = torch.rand(10, 10, requires_grad=True) 5047 ref = fn(ref_a, ref_b) 5048 ref.sum().backward() 5049 5050 # Compiled function calculation 5051 res_a = ref_a.clone().detach().requires_grad_(True) 5052 res_b = ref_b.clone().detach().requires_grad_(True) 5053 5054 def compile_fn(x, _): 5055 return x 5056 5057 compiled_fn = compiled_function( 5058 fn, compile_fn, compile_fn, min_cut_rematerialization_partition 5059 ) 5060 res = compiled_fn(res_a, res_b) 5061 res.sum().backward() 5062 assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3) 5063 assert torch.allclose(ref_a.grad, res_a.grad, atol=1e-3, rtol=1e-3) 5064 assert torch.allclose(ref_b.grad, res_b.grad, atol=1e-3, rtol=1e-3) 5065 5066 def test_meta_tensor_inplace_op(self): 5067 # Following module results in inplace ops while tracing. The test checks 5068 # that the meta tensor information is stored for inplace ops. 
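        # Concretely, check_meta_tensor below asserts that every non-output node in
        # the traced graph carries a "tensor_meta" entry in node.meta.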
5069 class MockModule(torch.nn.Module): 5070 def __init__(self) -> None: 5071 super().__init__() 5072 self.weight = torch.nn.Parameter( 5073 torch.randn(3072, 768, requires_grad=True) 5074 ) 5075 self.bias = torch.nn.Parameter(torch.randn(3072, requires_grad=True)) 5076 5077 def forward(self, add_4): 5078 linear_4 = torch.nn.functional.linear( 5079 add_4, self.weight, bias=self.bias 5080 ) 5081 gelu = torch.nn.functional.gelu(linear_4) 5082 return gelu 5083 5084 def check_meta_tensor(fx_g, _): 5085 for node in fx_g.graph.nodes: 5086 if node.op != "output": 5087 assert "tensor_meta" in node.meta 5088 return fx_g 5089 5090 inp0 = torch.randn(16, 128, 768, requires_grad=True) 5091 inputs = [ 5092 inp0, 5093 ] 5094 mod = MockModule().to(device="cpu") 5095 aot_mod = aot_module(mod, fw_compiler=check_meta_tensor) 5096 aot_mod(*inputs) 5097 5098 def test_default_partitioner_getitem(self): 5099 mod = nn.LayerNorm([10]) 5100 5101 def f(x, mod_weight, mod_bias): 5102 return torch.nn.functional.layer_norm( 5103 x, [10], mod_weight, mod_bias, eps=1e-6 5104 ) 5105 5106 fw_graph, bw_graph = get_fw_bw_graph( 5107 f, 5108 [torch.randn(3, 10, requires_grad=True), mod.weight, mod.bias], 5109 partitioner=default_partition, 5110 ) 5111 self.assertEqual(get_num_ins_outs(fw_graph), (3, 6)) 5112 self.assertEqual(get_num_ins_outs(bw_graph), (6, 3)) 5113 5114 @unittest.skipIf(not USE_NETWORKX, "networkx not available") 5115 def test_min_cut_partitioner_save_shape(self): 5116 def f(x): 5117 s = x.sum(dim=1) 5118 return s 5119 5120 inp = [torch.ones([10, 10], requires_grad=True)] 5121 fw_graph, bw_graph = get_fw_bw_graph(f, inp, dynamic=True) 5122 _, fw_output = get_ins_outs(fw_graph) 5123 self.assertEqual(get_num_ins_outs(fw_graph), (1, 3)) 5124 self.assertEqual(get_num_ins_outs(bw_graph), (3, 1)) 5125 self.assertEqual(str(fw_output[0]), "sum_1") 5126 # make sure we don't do the suboptimal thing of saving the bigger primals input to sum, 5127 # rather than saving the sizes of the primals input for use in backward expand 5128 self.assertEqual(str(fw_output[1]), "sym_size_int") 5129 self.assertEqual(str(fw_output[2]), "sym_size_int_1") 5130 5131 inp = [ 5132 torch.randn(10, requires_grad=True), 5133 torch.randn((3, 10), requires_grad=True), 5134 torch.randn((2, 10), requires_grad=True), 5135 ] 5136 5137 def f(a, b, c): 5138 # tried to test what happens if we save a size tuple in the graph; 5139 # turns out we never will due to how we trace, but this is probably 5140 # still a good test case for various size manipulations 5141 sb = torch.ops.aten.sym_size(b) 5142 sc = c.size() 5143 x = sb[0] + sc[0] 5144 a_sz = (x, a.size(0)) 5145 return torch.cat([a.expand(a_sz), b, c]) 5146 5147 fw_graph, bw_graph = get_fw_bw_graph(f, inp, dynamic=True) 5148 self.assertEqual(get_num_ins_outs(fw_graph), (3, 4)) 5149 self.assertEqual(get_num_ins_outs(bw_graph), (4, 3)) 5150 _, outs = get_ins_outs(fw_graph) 5151 self.assertTrue(all(is_sym_node(n) for n in outs[1:])) 5152 5153 def test_default_partitioner_output_tensor_shape_tensor(self): 5154 inp = [ 5155 torch.randn(10, requires_grad=True), 5156 torch.randn((3, 10), requires_grad=True), 5157 torch.randn((2, 10), requires_grad=True), 5158 torch.randn((10, 1), requires_grad=True), 5159 ] 5160 5161 def f(a, b, c, d): 5162 # Try to force symints intermixed with outputs in the function's returns 5163 sb = b.size() 5164 sc = c.size() 5165 x = sb[0] + sc[0] 5166 a_sz = (x, a.size(0)) 5167 cat = torch.cat([a.expand(a_sz), b, c]) 5168 mm = torch.mm(cat, d) 5169 mm2 = torch.mm( 5170 mm, 
a.view(mm.size(1), a.size(0)) 5171 ) # this saves 4 new ints for backward. why? 5172 # and what do i have to do to make it save a tensor for backward? 5173 return cat, sb, c, mm2 5174 5175 fw_graph_cell = [None] 5176 bw_graph_cell = [None] 5177 compiled_outs = aot_function( 5178 f, 5179 fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), 5180 bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), 5181 partition_fn=default_partition, 5182 decompositions=default_decompositions, 5183 dynamic=True, 5184 )(*inp) 5185 fw_graph = fw_graph_cell[0] 5186 (compiled_outs[0].sum() + compiled_outs[2].sum()).backward() 5187 bw_graph = bw_graph_cell[0] 5188 5189 # in the fwd graph, 13 outs because: 5190 # - 5 original outputs (sb is a tuple, gets expanded to 2 symints) 5191 # - 8 saved outputs for backward: 5 tensors, 3 symints 5192 self.assertEqual(get_num_ins_outs(fw_graph), (4, 13)) 5193 # in the bwd graph, 10 inputs (grad outs) because: 5194 # - The fwd graph had 13 outputs 5195 # - 1 was a view of an input, which gets regenerated outside of the graph 5196 # and doesn't participate in the backward 5197 # - 2 user outs were symints (b.size()), which don't get tangents in the backward 5198 self.assertEqual(get_num_ins_outs(bw_graph), (10, 4)) 5199 _, fw_graph_out_nodes = get_ins_outs(fw_graph) 5200 self.assertEqual( 5201 # fw outputs include b.size() which expands to 2 symints, 5202 # 5203 # TODO(whc)- are the saved-tensors/saved-symints correct here? 5204 # i just made the test pass based on what default partition did 5205 # Of the 5 original forward outputs, the 4th (c) is an input, 5206 # which won't show up in the compiled forward graph 5207 [False, True, True, False, False] + [False] * 4 + [True] * 4, 5208 [is_sym_node(n) for n in fw_graph_out_nodes], 5209 ) 5210 5211 real_outs = f(*inp) 5212 self.assertEqual(compiled_outs, real_outs) 5213 self.assertTrue(isinstance(real_outs[1], torch.Size)) 5214 5215 # TODO(whc) we should learn to return torch.Sizes 5216 self.assertFalse(isinstance(compiled_outs[1], torch.Size)) 5217 5218 @unittest.skipIf(not USE_NETWORKX, "networkx not available") 5219 def test_min_cut_partitioner_output_tensor_shape_tensor(self): 5220 inp = [ 5221 torch.randn(10, requires_grad=True), 5222 torch.randn((3, 10), requires_grad=True), 5223 torch.randn((2, 10), requires_grad=True), 5224 torch.randn((10, 1), requires_grad=True), 5225 ] 5226 5227 def f(a, b, c, d): 5228 # Try to force symints intermixed with outputs in the function's returns 5229 sb = b.size() 5230 sc = c.size() 5231 x = sb[0] + sc[0] 5232 a_sz = (x, a.size(0)) 5233 cat = torch.cat([a.expand(a_sz), b, c]) 5234 mm = torch.mm(cat, d) 5235 mm2 = torch.mm( 5236 mm, a.view(mm.size(1), a.size(0)) 5237 ) # this saves 4 new ints for backward. why? 5238 # and what do i have to do to make it save a tensor for backward? 
            return cat, sb, c, mm2

        fw_graph_cell = [None]
        bw_graph_cell = [None]
        compiled_outs = aot_function(
            f,
            fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
            bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
            partition_fn=min_cut_rematerialization_partition,
            decompositions=default_decompositions,
            dynamic=True,
        )(*inp)
        fw_graph = fw_graph_cell[0]
        (compiled_outs[0].sum() + compiled_outs[2].sum()).backward()
        bw_graph = bw_graph_cell[0]

        self.assertEqual(get_num_ins_outs(fw_graph), (4, 12))
        self.assertEqual(get_num_ins_outs(bw_graph), (9, 4))
        _, fw_graph_out_nodes = get_ins_outs(fw_graph)
        self.assertEqual(
            # fw outputs include b.size() which expands to 2 symints,
            # then 4 tensors (transposes of matrices used for mm) are saved,
            # and finally 3 symints are saved
            [False, True, True, False, False] + [False] * 4 + [True] * 3,
            [is_sym_node(n) for n in fw_graph_out_nodes],
        )

        real_outs = f(*inp)
        self.assertEqual(compiled_outs, real_outs)
        self.assertTrue(isinstance(real_outs[1], torch.Size))

        # TODO(whc) we should learn to return torch.Sizes
        self.assertFalse(isinstance(compiled_outs[1], torch.Size))

    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
    def test_min_cut_partitioner(self):
        def f(x):
            return x.cos().cos().cos()

        fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)])
        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
        self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))

        def f(a, b, c, d):
            x = a + b + c + d
            return x.cos().cos()

        fw_graph, bw_graph = get_fw_bw_graph(
            f, [torch.randn(3, requires_grad=True) for _ in range(4)]
        )
        self.assertEqual(get_num_ins_outs(fw_graph), (4, 2))
        self.assertEqual(get_num_ins_outs(bw_graph), (2, 4))

    def test_contiguous(self):
        # The test simulates the condition where a transpose followed by a view
        # happens in the backward pass.
5295 # https://discuss.pytorch.org/t/error-on-transpose-and-view/434 5296 def f(x): 5297 return x.view(2, 3).t() 5298 5299 inp = torch.randn(6, requires_grad=True) 5300 out = aot_function(f, nop)(inp) 5301 torch.autograd.grad(out, inp, torch.randn(3, 2)) 5302 5303 def test_preserve_random(self): 5304 def fn(x): 5305 return torch.nn.functional.dropout(x, 0.5) + x 5306 5307 x = torch.randn(4) 5308 5309 torch.manual_seed(0) 5310 ref = fn(x) 5311 5312 torch.manual_seed(0) 5313 aot_fn = aot_function(fn, nop) 5314 res = aot_fn(x) 5315 5316 assert torch.allclose(ref, res) 5317 5318 # https://github.com/pytorch/pytorch/issues/110666 5319 def test_generate_gives_inference_graph(self): 5320 # We expect this to give an inference graph 5321 def generate(x): 5322 with torch.no_grad(): 5323 return torch.mul(x, x) 5324 5325 inference_graph_cell = [None] 5326 inference_compiler = make_boxed_compiler( 5327 partial(extract_graph, graph_cell=inference_graph_cell) 5328 ) 5329 aot_fn = aot_function(generate, nop, inference_compiler=inference_compiler) 5330 # Even though x requires grad, we should still get an inference graph 5331 x = torch.randn(4, requires_grad=True) 5332 res = aot_fn(x) 5333 self.assertTrue(inference_graph_cell[0] is not None) 5334 5335 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") 5336 @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") 5337 def test_autocast(self): 5338 mod = torchvision.models.resnet18().cuda() 5339 mod.train() 5340 5341 x = torch.randn(16, 3, 32, 32, device="cuda") 5342 aot_mod = memory_efficient_fusion(mod) 5343 5344 # Ensure that AOT Autograd works with AMP 5345 with torch.cuda.amp.autocast(True): 5346 res = aot_mod(x) 5347 res.sum().backward() 5348 5349 5350class TestAOTDispatch(AOTTestCase): 5351 # Tests to add cases for (non-exhaustive list, mostly for my notes): 5352 # - subclass / mode introduced in the middle of the compiled fn 5353 # - various input mutation / intermediate base tests 5354 # - input mutation that changes a tensor into a subclass 5355 # - metadata mutation? 
(TBD) 5356 # - guard tests (fw guards *and* bw guards) 5357 # - subclass test involving _indices_of_inps_to_detach 5358 def test_aot_dispatch_simple(self): 5359 # a is a subclass, b is not 5360 def f(a, b): 5361 aa = torch.mul(a, 6) 5362 bb = torch.div(b, 2) 5363 return aa + bb 5364 5365 a1_ref = torch.ones(3, 3, requires_grad=True) 5366 a2_ref = torch.ones(3, 3, requires_grad=True) 5367 a_ref = TwoTensor(a1_ref, a2_ref) 5368 b_ref = torch.ones(3, 3, requires_grad=True) 5369 5370 a1_test = a1_ref.clone().detach().requires_grad_(True) 5371 a2_test = a2_ref.clone().detach().requires_grad_(True) 5372 a_test = TwoTensor(a1_test, a2_test) 5373 b_test = b_ref.clone().detach().requires_grad_(True) 5374 5375 fw_graph_cell = [None] 5376 bw_graph_cell = [None] 5377 5378 compiled_f = aot_function( 5379 f, 5380 fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), 5381 bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), 5382 partition_fn=min_cut_rematerialization_partition, 5383 ) 5384 out_ref = f(a_ref, b_ref) 5385 out_test = compiled_f(a_test, b_test) 5386 5387 # Output is a TwoTensor (check both inner tensors) 5388 self.assertEqual(out_ref.a, out_test.a) 5389 self.assertEqual(out_ref.b, out_test.b) 5390 5391 out_ref.sum().backward() 5392 out_test.sum().backward() 5393 # Both grad_inputs are TwoTensor 5394 self.assertEqual(a_ref.grad.a, a_test.grad.a) 5395 self.assertEqual(a_ref.grad.b, a_test.grad.b) 5396 self.assertEqual(b_ref.grad.a, b_test.grad.a) 5397 self.assertEqual(b_ref.grad.b, b_test.grad.b) 5398 5399 # Important pieces of the graph: 5400 # - mul() and div() show up twice, because we called them on a TwoTensor 5401 # - add() shows up once, because we called it on a plain Tensor 5402 # - The user forward() fn returns 1 output (the result of add), 5403 # while the graph itself returns two outputs (add, add_1) 5404 # - add, add_1 correspond to the two inner dense tensors that will be wrapped 5405 # - into a single TwoTensor output. 5406 self.assertExpectedInline( 5407 fw_graph_cell[0].code.strip(), 5408 """\ 5409def forward(self, primals_1, primals_2, primals_3): 5410 mul = torch.ops.aten.mul.Tensor(primals_1, 6); primals_1 = None 5411 mul_1 = torch.ops.aten.mul.Tensor(primals_2, 6); primals_2 = None 5412 div = torch.ops.aten.div.Tensor(primals_3, 2); primals_3 = None 5413 add = torch.ops.aten.add.Tensor(mul, div); mul = None 5414 add_1 = torch.ops.aten.add.Tensor(mul_1, div); mul_1 = div = None 5415 return (add, add_1)""", 5416 ) 5417 5418 # Important pieces of the graph: 5419 # - 4 total dense outputs. 
5420 # This corresponds to the fact that each user fwd inpt (a, b) 5421 # will get a gradient that is a TwoTensor subclass, 5422 # so (mul_2, mul_3) will be wrapped into a.grad 5423 # and (div_1, div_2) will be wrapped into b.grad 5424 # - 4 total dense outputs, 5425 self.assertExpectedInline( 5426 bw_graph_cell[0].code.strip(), 5427 """\ 5428def forward(self, tangents_1, tangents_2): 5429 div_1 = torch.ops.aten.div.Tensor(tangents_1, 2) 5430 div_2 = torch.ops.aten.div.Tensor(tangents_2, 2) 5431 mul_2 = torch.ops.aten.mul.Tensor(tangents_1, 6); tangents_1 = None 5432 mul_3 = torch.ops.aten.mul.Tensor(tangents_2, 6); tangents_2 = None 5433 return (mul_2, mul_3, div_1, div_2)""", 5434 ) 5435 5436 def test_aot_dispatch_inference(self): 5437 # a is a subclass, b is not 5438 def f(a, b): 5439 aa = torch.mul(a, 6) 5440 bb = torch.div(b, 2) 5441 return aa + bb 5442 5443 a1_ref = torch.ones(3, 3) 5444 a2_ref = torch.ones(3, 3) 5445 a_ref = TwoTensor(a1_ref, a2_ref) 5446 b_ref = torch.ones(3, 3) 5447 5448 a1_test = a1_ref.clone() 5449 a2_test = a2_ref.clone() 5450 a_test = TwoTensor(a1_test, a2_test) 5451 b_test = b_ref.clone() 5452 5453 compiled_f = aot_function( 5454 f, 5455 fw_compiler=nop, 5456 bw_compiler=nop, 5457 partition_fn=min_cut_rematerialization_partition, 5458 ) 5459 out_ref = f(a_ref, b_ref) 5460 out_test = compiled_f(a_test, b_test) 5461 5462 # Output is a TwoTensor (check both inner tensors) 5463 self.assertEqual(out_ref.a, out_test.a) 5464 self.assertEqual(out_ref.b, out_test.b) 5465 5466 def test_aot_dispatch_incorrect_backward(self): 5467 # a is a subclass, b is not 5468 def f(a, b): 5469 aa = torch.mul(a, 2) 5470 bb = torch.add(b, 3) 5471 out_subclass = torch.div(aa, bb) 5472 out_reg = torch.add(b, b) 5473 # When creating the joint, we assume that the second grad_out 5474 # is not a subclass. 5475 # In the below test case though, we end up being wrong. 5476 # This would require re-tracing and recompiling the backward. 5477 return out_subclass, out_reg 5478 5479 a1_ref = torch.ones(3, 3, requires_grad=True) 5480 a2_ref = torch.ones(3, 3, requires_grad=True) 5481 a_ref = TwoTensor(a1_ref, a2_ref) 5482 b_ref = torch.ones(3, 3, requires_grad=True) 5483 5484 a1_test = a1_ref.clone().detach().requires_grad_(True) 5485 a2_test = a2_ref.clone().detach().requires_grad_(True) 5486 a_test = TwoTensor(a1_test, a2_test) 5487 b_test = b_ref.clone().detach().requires_grad_(True) 5488 5489 compiled_f = aot_function( 5490 f, 5491 fw_compiler=nop, 5492 bw_compiler=nop, 5493 partition_fn=min_cut_rematerialization_partition, 5494 ) 5495 out_ref = f(a_ref, b_ref) 5496 out_test = compiled_f(a_test, b_test) 5497 # First out is a TwoTensor, second is an ordinary tensor 5498 self.assertEqual(out_ref[0].a, out_test[0].a) 5499 self.assertEqual(out_ref[0].b, out_test[0].b) 5500 self.assertEqual(out_ref[1], out_test[1]) 5501 5502 # We compiled our graph assuming type(grad_out[1]) == torch.Tensor, 5503 # but we were wrong: in the below tests, it is a subclass. 
        # This will eventually require a repartition + recompile
        with self.assertRaisesRegex(
            AssertionError,
            "incorrectly attempted to compile the backward with incorrect subclass metadata",
        ):
            (out_test[0] + out_test[1]).sum().backward()

    def test_aot_dispatch_output_alias(self):
        # a is a tensor, b is a TwoTensor
        def f(a, b):
            return b.view(b.shape), a * b

        b1_ref = torch.ones(3, 3, requires_grad=True)
        b2_ref = torch.ones(3, 3, requires_grad=True)
        b_ref = TwoTensor(b1_ref, b2_ref)
        a_ref = torch.ones(3, 3, requires_grad=True)

        b1_test = b1_ref.clone().detach().requires_grad_(True)
        b2_test = b2_ref.clone().detach().requires_grad_(True)
        b_test = TwoTensor(b1_test, b2_test)
        a_test = a_ref.clone().detach().requires_grad_(True)

        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        out_ref1, out_ref2 = f(a_ref, b_ref)
        out_test1, out_test2 = compiled_f(a_test, b_test)
        self.assertEqual(out_ref1, out_test1)
        self.assertEqual(out_ref2.a, out_test2.a)
        self.assertEqual(out_ref2.b, out_test2.b)

        (out_ref1 + out_ref2).sum().backward()
        (out_test1 + out_test2).sum().backward()
        # Both grad_inputs are TwoTensor
        self.assertEqual(a_ref.grad.a, a_test.grad.a)
        self.assertEqual(a_ref.grad.b, a_test.grad.b)
        self.assertEqual(b_ref.grad.a, b_test.grad.a)
        self.assertEqual(b_ref.grad.b, b_test.grad.b)

    def test_aot_dispatch_input_mutation(self):
        def f(a, b):
            a.mul_(2)
            b.mul_(3)
            return a + b

        b1_ref = torch.ones(3, 3, requires_grad=True)
        b2_ref = torch.ones(3, 3, requires_grad=True)
        b_ref_base = TwoTensor(b1_ref, b2_ref)
        a_ref_base = torch.ones(3, 3, requires_grad=True)
        b_ref = b_ref_base + 1
        a_ref = a_ref_base + 1

        b1_test = b1_ref.clone().detach().requires_grad_(True)
        b2_test = b2_ref.clone().detach().requires_grad_(True)
        b_test_base = TwoTensor(b1_test, b2_test)
        a_test_base = a_ref_base.clone().detach().requires_grad_(True)
        b_test = b_test_base + 1
        a_test = a_test_base + 1

        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        out_ref = f(a_ref, b_ref)
        out_test = compiled_f(a_test, b_test)
        self.assertEqual(out_ref.a, out_test.a)
        self.assertEqual(out_ref.b, out_test.b)

        # confirm input mutations worked
        self.assertEqual(a_test, a_ref)
        self.assertEqual(b_test.a, b_ref.a)
        self.assertEqual(b_test.b, b_ref.b)

        # NOTE: we need to use b in our gradient compute. Otherwise we would need to recompile the backward.
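        # (out is a TwoTensor here, so AOTAutograd traced the backward assuming a
        # TwoTensor tangent; the gradient of (b * out).sum() w.r.t. out is b, which
        # is itself a TwoTensor, so that assumption holds and no recompile is needed.)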
5583 (b_ref * out_ref).sum().backward() 5584 (b_test * out_test).sum().backward() 5585 # Both grad_inputs are TwoTensor 5586 self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a) 5587 self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b) 5588 self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a) 5589 self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b) 5590 5591 # NB: Metadata mutation for subclasses is currently broken and disabled 5592 # See https://github.com/pytorch/pytorch/issues/114975 5593 @unittest.expectedFailure 5594 def test_aot_dispatch_input_metadata_mutation(self): 5595 def f(a, b): 5596 a.t_() 5597 b.unsqueeze_(0) 5598 return a + b 5599 5600 b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) 5601 b2_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) 5602 b_ref_base = TwoTensor(b1_ref, b2_ref) 5603 a_ref_base = ( 5604 torch.arange(9, dtype=torch.float32) 5605 .reshape(3, 3) 5606 .detach() 5607 .requires_grad_(True) 5608 ) 5609 b_ref = b_ref_base + 1 5610 a_ref = a_ref_base + 1 5611 5612 b1_test = b1_ref.clone().detach().requires_grad_(True) 5613 b2_test = b2_ref.clone().detach().requires_grad_(True) 5614 b_test_base = TwoTensor(b1_test, b2_test) 5615 a_test_base = a_ref_base.clone().detach().requires_grad_(True) 5616 b_test = b_test_base + 1 5617 a_test = a_test_base + 1 5618 5619 compiled_f = aot_function( 5620 f, 5621 fw_compiler=nop, 5622 bw_compiler=nop, 5623 partition_fn=min_cut_rematerialization_partition, 5624 ) 5625 out_ref = f(a_ref, b_ref) 5626 out_test = compiled_f(a_test, b_test) 5627 self.assertEqual(out_ref.a, out_test.a) 5628 self.assertEqual(out_ref.b, out_test.b) 5629 5630 # confirm input mutations worked 5631 self.assertEqual(a_test, a_ref) 5632 self.assertEqual(b_test.a, b_ref.a) 5633 self.assertEqual(b_test.b, b_ref.b) 5634 5635 # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward. 
5636 (b_ref * out_ref).sum().backward() 5637 (b_test * out_test).sum().backward() 5638 # Both grad_inputs are TwoTensor 5639 self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a) 5640 self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b) 5641 self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a) 5642 self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b) 5643 5644 # NB: Metadata mutation for subclasses is currently broken and disabled 5645 # See https://github.com/pytorch/pytorch/issues/114975 5646 @unittest.expectedFailure 5647 def test_aot_dispatch_input_data_and_metadata_mutation(self): 5648 def f(a, b): 5649 a.t_() 5650 b.unsqueeze_(0) 5651 a.mul_(2) 5652 b.mul_(3) 5653 return a + b 5654 5655 b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) 5656 b2_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) 5657 b_ref_base = TwoTensor(b1_ref, b2_ref) 5658 a_ref_base = ( 5659 torch.arange(9, dtype=torch.float32) 5660 .reshape(3, 3) 5661 .detach() 5662 .requires_grad_(True) 5663 ) 5664 b_ref = b_ref_base + 1 5665 a_ref = a_ref_base + 1 5666 5667 b1_test = b1_ref.clone().detach().requires_grad_(True) 5668 b2_test = b2_ref.clone().detach().requires_grad_(True) 5669 b_test_base = TwoTensor(b1_test, b2_test) 5670 a_test_base = a_ref_base.clone().detach().requires_grad_(True) 5671 b_test = b_test_base + 1 5672 a_test = a_test_base + 1 5673 5674 compiled_f = aot_function( 5675 f, 5676 fw_compiler=nop, 5677 bw_compiler=nop, 5678 partition_fn=min_cut_rematerialization_partition, 5679 ) 5680 out_ref = f(a_ref, b_ref) 5681 out_test = compiled_f(a_test, b_test) 5682 self.assertEqual(out_ref.a, out_test.a) 5683 self.assertEqual(out_ref.b, out_test.b) 5684 5685 # confirm input mutations worked 5686 self.assertEqual(a_test, a_ref) 5687 self.assertEqual(b_test.a, b_ref.a) 5688 self.assertEqual(b_test.b, b_ref.b) 5689 5690 # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward. 
5691 (b_ref * out_ref).sum().backward() 5692 (b_test * out_test).sum().backward() 5693 # Both grad_inputs are TwoTensor 5694 self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a) 5695 self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b) 5696 self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a) 5697 self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b) 5698 5699 def test_aot_dispatch_input_mutation_and_output_alias(self): 5700 def f(a, b): 5701 a.mul_(2) 5702 b.mul_(3) 5703 return b.view(b.shape), a + b 5704 5705 b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) 5706 b2_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) 5707 b_ref_base = TwoTensor(b1_ref, b2_ref) 5708 a_ref_base = ( 5709 torch.arange(9, dtype=torch.float32) 5710 .reshape(3, 3) 5711 .detach() 5712 .requires_grad_(True) 5713 ) 5714 b_ref = b_ref_base + 1 5715 a_ref = a_ref_base + 1 5716 5717 b1_test = b1_ref.clone().detach().requires_grad_(True) 5718 b2_test = b2_ref.clone().detach().requires_grad_(True) 5719 b_test_base = TwoTensor(b1_test, b2_test) 5720 a_test_base = a_ref_base.clone().detach().requires_grad_(True) 5721 b_test = b_test_base + 1 5722 a_test = a_test_base + 1 5723 5724 compiled_f = aot_function( 5725 f, 5726 fw_compiler=nop, 5727 bw_compiler=nop, 5728 partition_fn=min_cut_rematerialization_partition, 5729 ) 5730 out_ref1, out_ref2 = f(a_ref, b_ref) 5731 out_test1, out_test2 = compiled_f(a_test, b_test) 5732 self.assertEqual(out_ref1.a, out_test1.a) 5733 self.assertEqual(out_ref1.b, out_test1.b) 5734 self.assertEqual(out_ref2.a, out_test2.a) 5735 self.assertEqual(out_ref2.b, out_test2.b) 5736 5737 # confirm input mutations worked 5738 self.assertEqual(a_test, a_ref) 5739 self.assertEqual(b_test.a, b_ref.a) 5740 self.assertEqual(b_test.b, b_ref.b) 5741 5742 (out_ref1 * out_ref2).sum().backward() 5743 (out_test1 * out_test2).sum().backward() 5744 # Both grad_inputs are TwoTensors 5745 self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a) 5746 self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b) 5747 5748 def test_aot_dispatch_output_requires_grad_in_no_grad(self): 5749 def fn(x): 5750 out1 = x.sin() 5751 with torch.enable_grad(): 5752 out2 = x.cos() 5753 return out1, out2 5754 5755 inp_fns = [ 5756 lambda: torch.ones(10, requires_grad=True), 5757 lambda: torch.ones(10, requires_grad=False), 5758 ] 5759 5760 compiled_f = aot_function(fn, nop) 5761 for inp_fn in inp_fns: 5762 with torch.no_grad(): 5763 ref_x = inp_fn() 5764 ref_out = fn(ref_x) 5765 x = inp_fn() 5766 out = compiled_f(x) 5767 for r, o in zip(ref_out, out): 5768 self.assertEqual(r.requires_grad, o.requires_grad) 5769 if ref_x.requires_grad: 5770 with torch.enable_grad(): 5771 (ref_out[0] + ref_out[1]).sum().backward() 5772 (out[0] + out[1]).sum().backward() 5773 self.assertEqual(ref_x.grad, x.grad) 5774 assert torch.allclose(ref_x.grad, x.grad, atol=1e-3, rtol=1e-3) 5775 5776 def test_aot_dispatch_output_requires_grad_in_no_grad_views(self): 5777 # view-type ops preserve requires_grad even in no_grad. 
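        # e.g. under no_grad, x.view(-1) of a requires_grad=True tensor still reports
        # requires_grad=True (it aliases its base), while x.sin() does not.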
5778 def fn(x): 5779 return x.view(-1), x.sin() 5780 5781 inference_graph_cell = [None] 5782 inference_compiler = make_boxed_compiler( 5783 partial(extract_graph, graph_cell=inference_graph_cell) 5784 ) 5785 compiled_fn = aot_function(fn, nop, inference_compiler=inference_compiler) 5786 5787 inp_x0 = torch.ones(2, 3, requires_grad=True) 5788 # Clone in no_grad will make requires_grad=False tensors, keep clone outside of no_grad 5789 ref_x0 = inp_x0.clone() 5790 x0 = inp_x0.clone() 5791 with torch.no_grad(): 5792 ref_out1, ref_out2 = fn(ref_x0) 5793 5794 out1, out2 = compiled_fn(x0) 5795 # Assert that we executed inference graph 5796 self.assertTrue(inference_graph_cell[0] is not None) 5797 5798 self.assertEqual(ref_out1.requires_grad, out1.requires_grad) 5799 self.assertEqual(ref_out2.requires_grad, out2.requires_grad) 5800 5801 5802class TestAOTModuleSimplified(AOTTestCase): 5803 def test_aot_module_simplified(self): 5804 class MockModule(torch.nn.Module): 5805 def __init__(self) -> None: 5806 super().__init__() 5807 self.linear = torch.nn.Linear(20, 30) 5808 5809 def forward(self, x, y): 5810 return (self.linear(x) + y,) 5811 5812 mod = MockModule() 5813 mod.zero_grad() 5814 5815 x = torch.randn(128, 20, requires_grad=True) 5816 y = torch.randn(128, 30, requires_grad=True) 5817 inputs = [x, y] 5818 cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs] 5819 5820 ref = mod(*inputs) 5821 ref[0].sum().backward() 5822 5823 compiled_f = aot_module_simplified(mod, cloned_inputs, nop) 5824 mod.zero_grad() 5825 res = compiled_f(*cloned_inputs) 5826 res[0].sum().backward() 5827 5828 assert torch.allclose(ref[0], res[0]) 5829 assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad) 5830 assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad) 5831 5832 def test_aot_module_simplified_dynamic(self): 5833 class MockModule(torch.nn.Module): 5834 def __init__(self) -> None: 5835 super().__init__() 5836 self.linear = torch.nn.Linear(20, 30) 5837 5838 def forward(self, x, y): 5839 return (self.linear(x) + y,) 5840 5841 mod = MockModule() 5842 5843 shape_env = ShapeEnv() 5844 fake_mode = FakeTensorMode(shape_env=shape_env) 5845 5846 x = torch.randn(128, 20, requires_grad=True) 5847 y = torch.randn(128, 30, requires_grad=True) 5848 5849 inputs = [x, y] 5850 fake_inputs = [fake_mode.from_tensor(x) for x in inputs] 5851 compiled_f = aot_module_simplified(mod, fake_inputs, nop) 5852 5853 ref = mod(*inputs) 5854 ref[0].sum().backward() 5855 5856 cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs] 5857 res = compiled_f(*cloned_inputs) 5858 res[0].sum().backward() 5859 5860 self.assertExpectedInline( 5861 shape_env.format_guards(), 5862 """\ 5863 - Eq(s1, 20) 5864 - Eq(s2, 30)""", 5865 ) 5866 5867 assert torch.allclose(ref[0], res[0]) 5868 assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad) 5869 assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad) 5870 5871 # https://github.com/pytorch/pytorch/issues/105327 5872 def test_lift_fresh_copy_in_graph(self): 5873 class MyMod(torch.nn.Module): 5874 def forward(self, x): 5875 _tensor_constant0 = torch.tensor([1]) 5876 lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default( 5877 _tensor_constant0 5878 ) 5879 y = x.mul(lift_fresh_copy) 5880 return (y,) 5881 5882 mod = MyMod() 5883 shape_env = ShapeEnv() 5884 fake_mode = FakeTensorMode(shape_env=shape_env) 5885 x = torch.ones(4, requires_grad=True) 5886 inputs = [x] 5887 fake_inputs = [fake_mode.from_tensor(x) for x in inputs] 5888 compiled_f = 
aot_module_simplified(mod, fake_inputs, nop) 5889 5890 out_ref = mod(x) 5891 out_test = compiled_f(x) 5892 self.assertEqual(out_ref[0].detach(), out_test[0].detach()) 5893 5894 def test_inference_python_dispatcher(self): 5895 # Extracted from unet 5896 class MockModule(torch.nn.Module): 5897 def __init__(self) -> None: 5898 super().__init__() 5899 self.upsample = torch.nn.Upsample( 5900 scale_factor=2, mode="bilinear", align_corners=True 5901 ) 5902 5903 def forward(self, x): 5904 return (self.upsample(x),) 5905 5906 mod = MockModule() 5907 shape_env = ShapeEnv() 5908 fake_mode = FakeTensorMode(shape_env=shape_env) 5909 x = torch.randn(2, 512, 40, 59) # NB: must not require grad 5910 inputs = [x] 5911 fake_inputs = [fake_mode.from_tensor(x) for x in inputs] 5912 compiled_f = aot_module_simplified(mod, fake_inputs, nop) 5913 5914 def test_aot_module_simplified_preserves_stack_trace(self): 5915 class MockModule(torch.nn.Module): 5916 def __init__(self) -> None: 5917 super().__init__() 5918 self.linear = torch.nn.Linear(20, 30) 5919 5920 def forward(self, x, y): 5921 z = self.linear(x) 5922 z = z + y 5923 z = z.relu() 5924 return (z,) 5925 5926 tracer = torch.fx.Tracer() 5927 tracer.record_stack_traces = True 5928 graph = tracer.trace(MockModule()) 5929 mod = torch.fx.GraphModule(tracer.root, graph) 5930 5931 for node in mod.graph.nodes: 5932 if node.op == "output": 5933 continue 5934 self.assertTrue(node.stack_trace is not None) 5935 assert "test_aotdispatch.py" in node.stack_trace 5936 5937 def assert_compiler(gm: torch.fx.GraphModule, _): 5938 for node in gm.graph.nodes: 5939 if node.op == "output" or node.op == "placeholder": 5940 continue 5941 self.assertTrue(node.stack_trace is not None) 5942 assert "test_aotdispatch.py" in node.stack_trace 5943 return gm.forward # return a python callable 5944 5945 x = torch.randn(128, 20, requires_grad=True) 5946 y = torch.randn(128, 30, requires_grad=True) 5947 inputs = [x, y] 5948 5949 compiled_f = aot_module_simplified( 5950 mod, inputs, fw_compiler=assert_compiler, bw_compiler=assert_compiler 5951 ) 5952 res = compiled_f(*inputs) 5953 res[0].sum().backward() 5954 5955 def test_aot_module_simplified_preserves_stack_trace_from_mutation(self): 5956 class MockModule(torch.nn.Module): 5957 def __init__(self) -> None: 5958 super().__init__() 5959 5960 def forward(self, x): 5961 x_view = x[0] 5962 x_view.mul_(2) 5963 return (x + x,) 5964 5965 tracer = torch.fx.Tracer() 5966 tracer.record_stack_traces = True 5967 graph = tracer.trace(MockModule()) 5968 mod = torch.fx.GraphModule(tracer.root, graph) 5969 5970 for node in mod.graph.nodes: 5971 if node.op == "output": 5972 continue 5973 self.assertTrue(node.stack_trace is not None) 5974 assert "test_aotdispatch.py" in node.stack_trace 5975 5976 def assert_compiler(gm: torch.fx.GraphModule, _): 5977 assert torch.ops.aten.copy_.default in [x.target for x in gm.graph.nodes] 5978 for node in gm.graph.nodes: 5979 if node.target == torch.ops.aten.copy_.default: 5980 assert "stack_trace" in node.meta 5981 assert "x_view.mul_(2)" in node.meta["stack_trace"] 5982 return gm.forward # return a python callable 5983 5984 x = torch.randn(128, 20) 5985 inputs = [x] 5986 5987 aot_module_simplified( 5988 mod, 5989 inputs, 5990 fw_compiler=assert_compiler, 5991 bw_compiler=assert_compiler, 5992 keep_inference_input_mutations=True, 5993 ) 5994 5995 def test_aot_module_simplified_fake_tensor_gm_raises(self): 5996 fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() 5997 real_x = torch.randn(4, requires_grad=True) 5998 
fake_x = fake_mode.from_tensor(real_x) 5999 real_z = torch.randn(4) 6000 fake_z = fake_mode.from_tensor(real_z) 6001 6002 class MockModule(torch.nn.Module): 6003 def forward(self, x): 6004 # Accessing a free variable fake tensor will look like a 6005 # constant to make_fx, and result in the tensor being traced 6006 # into the graph, which is an error condition. Make sure we 6007 # report adequately in this case. 6008 return (x + fake_z,) 6009 6010 with self.assertRaisesRegex(AssertionError, "Unexpected fake"): 6011 aot_module_simplified(MockModule(), (fake_x,), nop) 6012 6013 def test_aot_test_subclasses_with_tensor_factories(self): 6014 from torch.testing._internal.common_subclass import SubclassWithTensorFactory 6015 6016 inp = SubclassWithTensorFactory(torch.zeros(3, 5)) 6017 6018 def fn(x): 6019 return 2 * x 6020 6021 ref_out = fn(inp) 6022 out = torch.compile(fn, backend="aot_eager", fullgraph=True)(inp) 6023 self.assertEqual(ref_out, out) 6024 6025 6026# entries in here don't work and need to be fixed. 6027# Each one of these is a bug (or needs to be investigated) 6028aot_autograd_failures = { 6029 # data-dependent control flow 6030 xfail("cov"), 6031 xfail("nn.functional.gaussian_nll_loss"), 6032 xfail("tensor_split"), 6033 xfail("corrcoef"), 6034 xfail("quantile"), 6035 xfail("nanquantile"), 6036 xfail("narrow"), 6037 xfail("istft"), 6038 xfail("linalg.eig"), 6039 skip("as_strided_scatter"), 6040 skip("as_strided", "partial_views"), # flaky 6041 # Given input size: (s0xs1x2). Calculated output size: ... 6042 skip("max_pool2d_with_indices_backward"), 6043 skip("nn.functional.nll_loss", ""), # UBSAN failure! 6044 # Misc 6045 xfail("to_sparse"), 6046 xfail("corrcoef"), 6047 xfail("cov"), 6048 xfail("chalf"), # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' 6049 xfail("sparse.sampled_addmm"), 6050 xfail("sparse.mm", "reduce"), 6051 skip("nn.functional.binary_cross_entropy_with_logits"), # seems to fail sometimes? 6052 skip("nn.functional.margin_ranking_loss"), # seems flaky 6053 skip("linalg.lu_solve"), # flaky 6054 decorate("matmul", decorator=unittest.skipIf(IS_ARM64, "flaky")), 6055 decorate("__rmatmul__", decorator=unittest.skipIf(IS_ARM64, "flaky")), 6056 # overrides atol=1e-4, rtol=1e-5 would do as well 6057 decorate( 6058 "svd_lowrank", 6059 decorator=toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-05)}), 6060 ), 6061 decorate( 6062 "linalg.householder_product", 6063 decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"), 6064 ), 6065 decorate( 6066 "linalg.pinv", 6067 "singular", 6068 decorator=toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-05)}), 6069 ), 6070 decorate( 6071 "nn.functional.interpolate", 6072 "bicubic", 6073 decorator=toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-05)}), 6074 ), 6075 # conv2d sometimes nondeterministic in this config? 6076 decorate("nn.functional.conv2d", decorator=unittest.skipIf(IS_ARM64, "flaky")), 6077} 6078 6079symbolic_aot_autograd_failures = { 6080 xfail("combinations", ""), # aten.masked_select.default 6081 xfail( 6082 "index_fill", "" 6083 ), # Cannot call sizes() on tensor with symbolic sizes/strides 6084 xfail( 6085 "linalg.lstsq", "" 6086 ), # aten.linalg_lstsq.default - couldn't find symbolic meta function/decomposition 6087 xfail( 6088 "linalg.lstsq", "grad_oriented" 6089 ), # aten.linalg_lstsq.default - couldn't find symbolic meta funct... 6090 xfail( 6091 "linalg.lu_solve", "" 6092 ), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/deco... 
6093 skip( 6094 "nn.functional.batch_norm", "" 6095 ), # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te.. 6096 xfail( 6097 "nn.functional.binary_cross_entropy", "" 6098 ), # aten.fill_.Scalar - couldn't find symbolic meta funct... 6099 xfail( 6100 "nn.functional.cross_entropy", "" 6101 ), # Cannot call sizes() on tensor with symbolic sizes/strides 6102 xfail( 6103 "nn.functional.ctc_loss", "" 6104 ), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/deco... 6105 xfail( 6106 "nn.functional.fractional_max_pool3d", "" 6107 ), # rand() received an invalid combination of arguments - g... 6108 xfail( 6109 "nn.functional.group_norm", "" 6110 ), # Cannot call sizes() on tensor with symbolic sizes/strides 6111 xfail( 6112 "nn.functional.nll_loss", "" 6113 ), # Cannot call sizes() on tensor with symbolic sizes/strides 6114 xfail( 6115 "_segment_reduce", "lengths" 6116 ), # aten.segment_reduce.default - couldn't find symbolic meta functio... 6117 xfail( 6118 "_segment_reduce", "offsets" 6119 ), # aten.segment_reduce.default - couldn't find symbolic meta functio... 6120 xfail("trace", ""), # Cannot call sizes() on tensor with symbolic sizes/strides 6121 xfail( 6122 "_upsample_bilinear2d_aa" 6123 ), # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList 6124 decorate( 6125 "linalg.householder_product", 6126 decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"), 6127 ), 6128 # many complex operators incorrect striding, metadata 6129 xfail("fft.fft", ""), 6130 xfail("fft.hfft2", ""), 6131 xfail("fft.hfft", ""), 6132 xfail("fft.hfftn", ""), 6133 xfail("fft.ifft", ""), 6134 xfail("fft.ihfft2", ""), 6135 xfail("fft.ihfft", ""), 6136 xfail("fft.ihfftn", ""), 6137 xfail("fft.irfft2", ""), 6138 xfail("fft.irfft", ""), 6139 xfail("fft.irfftn", ""), 6140 xfail("fft.rfft2", ""), 6141 xfail("fft.rfft", ""), 6142 xfail("fft.rfftn", ""), 6143 xfail("stft", ""), # Cannot call sizes() on tensor with symbolic sizes/strides 6144} 6145 6146 6147def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False): 6148 if not op.supports_autograd: 6149 self.skipTest("Op does not support autograd") 6150 6151 # aot_autograd_check is able to check data specialization by 6152 # randomizing the inputs. Here's a list of ops that really do not 6153 # like random inputs for which we want to disable that. 
6154 cant_check_data_specialization = set( 6155 { 6156 "nn.functional.max_unpool1d", 6157 "nn.functional.max_unpool2d", 6158 "nn.functional.max_unpool3d", 6159 } 6160 ) 6161 try_check_data_specialization = op.name not in cant_check_data_specialization 6162 6163 sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True) 6164 for sample_input in sample_inputs_itr: 6165 t_args = [sample_input.input] + list(sample_input.args) 6166 t_kwargs = sample_input.kwargs 6167 try: 6168 aot_autograd_check( 6169 op.op, 6170 t_args, 6171 t_kwargs, 6172 dynamic, 6173 self.assertRaisesRegex, 6174 self.assertEqual, 6175 check_gradients=True, 6176 try_check_data_specialization=try_check_data_specialization, 6177 ) 6178 except DynamicOutputShapeException: 6179 self.skipTest("Dynamic output shape operation in trace") 6180 except GuardOnDataDependentSymNode: 6181 # Carveout for getitem; I don't want to xfail the entire test 6182 # because that will reject known to be good tests see 6183 # https://github.com/pytorch/pytorch/issues/94705 6184 if op.name == "__getitem__": 6185 self.skipTest("Dynamic output shape operation in trace") 6186 else: 6187 raise 6188 6189 6190def _test_aot_autograd_module_helper( 6191 self, device, dtype, training, module_info, *, dynamic=False 6192): 6193 module_cls = module_info.module_cls 6194 module_inputs = module_info.module_inputs_func( 6195 module_info, device=device, dtype=dtype, requires_grad=True, training=training 6196 ) 6197 for module_input in module_inputs: 6198 if module_input.forward_input is None: 6199 continue 6200 6201 args, kwargs = ( 6202 module_input.constructor_input.args, 6203 module_input.constructor_input.kwargs, 6204 ) 6205 m = module_cls(*args, **kwargs) 6206 m.to(device).to(dtype) 6207 m.train(training) 6208 6209 # Lazy modules need to see an input first to initialize params. 6210 args, kwargs = ( 6211 module_input.forward_input.args, 6212 module_input.forward_input.kwargs, 6213 ) 6214 flat_args, args_spec = pytree.tree_flatten((args, kwargs)) 6215 6216 # PackedSequence is only used for RNNs. 
It might be possible to fake-ify if they're pytrees but 6217 # torchdynamo already doesn't support RNNs 6218 if any(tuple(isinstance(flat_arg, PackedSequence) for flat_arg in flat_args)): 6219 continue 6220 6221 if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin): 6222 with torch.no_grad(): 6223 m(*args, **kwargs) 6224 6225 sentinel_val = -42 6226 is_tensor_spec = [ 6227 sentinel_val if isinstance(arg, torch.Tensor) else arg for arg in flat_args 6228 ] 6229 args = [arg for arg in flat_args if isinstance(arg, torch.Tensor)] 6230 6231 def f(params_buffers_args): 6232 named_params, named_buffers, args = params_buffers_args 6233 cur_flat_args = list(is_tensor_spec) 6234 args = iter(args) 6235 for idx, v in enumerate(cur_flat_args): 6236 if v == sentinel_val: 6237 cur_flat_args[idx] = next(args) 6238 c_args, c_kwargs = pytree.tree_unflatten(cur_flat_args, args_spec) 6239 params_and_buffers = {**named_params, **named_buffers} 6240 return torch.func.functional_call(m, params_and_buffers, c_args, c_kwargs) 6241 6242 named_params = dict(m.named_parameters(remove_duplicate=False)) 6243 named_buffers = dict(m.named_buffers(remove_duplicate=False)) 6244 num_params_buffers = len(named_params) + len(named_buffers) 6245 compiled_f = aot_function( 6246 f, nop, num_params_buffers=num_params_buffers, dynamic=dynamic 6247 ) 6248 params_buffers_args = [named_params, named_buffers, args] 6249 _test_aot_autograd_forwards_backwards_helper( 6250 f, 6251 compiled_f, 6252 params_buffers_args, 6253 self.assertRaisesRegex, 6254 self.assertEqual, 6255 True, 6256 ) 6257 6258 6259class TestEagerFusionOpInfo(AOTTestCase): 6260 @ops(op_db + hop_db, allowed_dtypes=(torch.float,)) 6261 @skipOps( 6262 "TestEagerFusionOpInfo", "test_aot_autograd_exhaustive", aot_autograd_failures 6263 ) 6264 def test_aot_autograd_exhaustive(self, device, dtype, op): 6265 _test_aot_autograd_helper(self, device, dtype, op) 6266 6267 @ops(op_db + hop_db, allowed_dtypes=(torch.float,)) 6268 @patch("functorch.compile.config.debug_assert", True) 6269 @skipOps( 6270 "TestEagerFusionOpInfo", 6271 "test_aot_autograd_symbolic_exhaustive", 6272 aot_autograd_failures | symbolic_aot_autograd_failures, 6273 ) 6274 def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op): 6275 _test_aot_autograd_helper(self, device, dtype, op, dynamic=True) 6276 6277 6278aot_autograd_module_failures = set( 6279 { 6280 torch.nn.CTCLoss, # torch._subclasses.fake_tensor.DynamicOutputShapeException: aten._ctc_loss.default 6281 torch.nn.GaussianNLLLoss, # RuntimeError: It appears that you're trying to get value out 6282 # of a tracing tensor with aten._local_scalar_dense.default - 6283 # erroring out! It's likely that this is caused by data-dependent 6284 # control flow or similar. 6285 torch.nn.MultiLabelMarginLoss, # AssertionError: The values for attribute 'shape' do not match: 6286 # torch.Size([1]) != torch.Size([]). Outputs of the operator are different in 6287 # eager-mode PyTorch vs AOTAutograd. This means the operator will have incorrect 6288 # output underneath torch.compile. This could be because the operator's 6289 # implementation not traceable or that there is a bug in AOTAutograd. 
        torch.nn.TransformerEncoder,  # DataDependentOutputException: aten.eq compares a mask input
        # to a causal mask tensor, to see if Boolean is_causal should be set
        # for TransformerEncoder layers, MHA and sdp custom kernels
        torch.nn.Transformer,  # DataDependentOutputException: aten.equal compares a mask input
        # to a causal mask tensor, to see if Boolean is_causal should be set
        # for TransformerEncoder layers, MHA and sdp custom kernels
        # (this bubbles up to Transformer)
    }
)


symbolic_aot_autograd_module_failures = {
    torch.nn.Transformer,  # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
    torch.nn.TransformerEncoder,  # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
    torch.nn.GaussianNLLLoss,  # NotImplementedError: local_scalar_dense/item NYI for torch.bool
    torch.nn.GroupNorm,  # in native_group_norm_backward cpg, _rem = divmod(C, group)
    # TypeError: unsupported operand type(s) for divmod(): 'SymInt' and 'int'
    torch.nn.FractionalMaxPool3d,  # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
    torch.nn.BCELoss,  # new_size = _infer_size(target.size(), weight.size())
    # RuntimeError: expected int at position 0, but got: SymInt
}


class TestEagerFusionModuleInfo(AOTTestCase):
    @modules(module_db, allowed_dtypes=(torch.float,))
    @decorateForModules(unittest.expectedFailure, aot_autograd_module_failures)
    def test_aot_autograd_module_exhaustive(self, device, dtype, training, module_info):
        _test_aot_autograd_module_helper(self, device, dtype, training, module_info)

    @modules(module_db, allowed_dtypes=(torch.float,))
    @decorateForModules(
        unittest.expectedFailure,
        aot_autograd_module_failures | symbolic_aot_autograd_module_failures,
    )
    def test_aot_autograd_symbolic_module_exhaustive(
        self, device, dtype, training, module_info
    ):
        _test_aot_autograd_module_helper(
            self, device, dtype, training, module_info, dynamic=True
        )


instantiate_parametrized_tests(TestAOTAutograd)
only_for = "cpu"
instantiate_device_type_tests(
    TestPythonKey,
    globals(),
    only_for=only_for,
)
instantiate_device_type_tests(TestEagerFusionOpInfo, globals(), only_for=only_for)
instantiate_device_type_tests(TestEagerFusionModuleInfo, globals(), only_for=only_for)


@xfail_inherited_tests(
    [
        "test_set__and_data_mutation_bad",
        "test_subclass_metadata_mutation_req_grad_True",
        "test_subclass_metadata_mutation_req_grad_False",
    ]
)
@skipIfTorchDynamo("This test suite already uses dynamo")
class TestAOTAutogradWithDynamo(TestAOTAutograd):
    """
    These are the same as TestAOTAutograd tests, but we run dynamo first to get a graph module.
    """

    def assertExpectedInline(self, *args, **kwargs):
        # These will have different outputs because dynamo returns a different graph module.
        # We don't really care about that assertion when testing with dynamo,
        # only that the outputs match, etc.
        pass

    def make_compiler(self, graph_cell):
        return make_boxed_compiler(partial(extract_graph, graph_cell=graph_cell))

    # Compiler to pass to dynamo
    def run_autograd(
        self,
        f: Callable,
        fw_graph_cell: List[Optional[Callable]],
        decompositions: Optional[Dict],
        keep_input_mutations: bool,
        dynamic: bool,
    ):
        """
        Runs dynamo and aot_autograd with the specified settings.
        """

        def dynamo_compiler(gm, inputs, **kwargs):
            result = aot_module_simplified(
                gm,
                inputs,
                fw_compiler=self.make_compiler(fw_graph_cell),
                bw_compiler=self.make_compiler([None]),
                decompositions=decompositions,
                keep_inference_input_mutations=keep_input_mutations,
                # Dynamic is calculated from whether the inputs have fake tensors
            )
            return result

        def torch_compile_wrapper(*args, **kwargs):
            torch._dynamo.reset()
            fn = torch.compile(f, backend=dynamo_compiler)
            try:
                result = fn(*args, **kwargs)
            except torch._dynamo.exc.BackendCompilerFailed as e:
                # Unwrap so that assertRaises sees the original exception
                raise e.inner_exception from e
            return result

        return torch_compile_wrapper


class MockFXGraphCache:
    """
    In-memory version of FXGraphCache so we can isolate testing of FXGraphCache.
    """

    def __init__(self) -> None:
        self.cache = {}

    def save(self, key, gm):
        self.cache[key] = gm

    def load(self, gm, inputs):
        key, _ = compiled_fx_graph_hash(gm, inputs, {}, {})
        if key in self.cache:
            gm = make_boxed_func(gm)
            gm._fx_graph_cache_key = key
            return gm
        else:
            self.save(key, gm)
            gm = make_boxed_func(gm)
            gm._fx_graph_cache_key = key
            return gm

    def _lookup_graph(self, key, inputs, local, remote_cache):
        gm = self.cache.get(key)
        if gm is not None:
            gm = make_boxed_func(gm)
        return gm

    def post_compile(self, gm, inputs, cudagraphs):
        pass


# The following tests fail in strict caching mode (i.e. they bypass the cache or
# miss instead of hitting it). They will be fixed in the PRs above this.
FAILING_CACHE_TESTS = (
    # BypassAOTAutogradCache: unsupported nodes
    "test_backward_mutation_data",  # Custom Autograd Function
    "test_backward_mutation_metadata",  # Custom Autograd Function
    "test_custom_autograd",  # Custom Autograd Function
    "test_input_output_aliase_custom_autograd_function",
)


@xfail_inherited_tests(FAILING_CACHE_TESTS)
class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
    """
    Runs the TestAOTAutogradWithDynamo tests with AOTAutogradCache enabled and with
    FxGraphCache replaced by an in-memory mock, so the caches can be tested in isolation.
    """

    def make_compiler(self, fw_graph_cell):
        mock_inductor_cache = self.inductor_cache

        def compiler(gm, inputs):
            nonlocal mock_inductor_cache, fw_graph_cell
            result = mock_inductor_cache.load(gm, inputs)
            fw_graph_cell[0] = gm
            return result

        return compiler

    def run_autograd(
        self,
        f: Callable,
        fw_graph_cell: List[Optional[Callable]],
        decompositions: Optional[Dict],
        keep_input_mutations: bool,
        dynamic: bool,
    ):
        return super().run_autograd(
            f,
            fw_graph_cell,
            decompositions,
            keep_input_mutations,
            dynamic,
        )

    @torch._functorch.config.patch(
        {
            "enable_autograd_cache": True,
            "strict_autograd_cache": True,
            "view_replay_for_aliased_outputs": False,
        }
    )
    @torch._inductor.config.patch("fx_graph_cache", True)
    def verify_aot_autograd(
        self,
        f,
        inp_: Union[Callable, List[Any]],
        *,
        test_mutation: bool = False,
        keep_inp_mutations: bool = False,
        decompositions: Optional[Dict] = None,
        dynamic: bool = False,
        # Only active when inp_ is Callable.
        # TODO: probably consolidate all tests to make inp a Callable.
        make_inputs_subclasses: bool = False,
    ):
        self.inductor_cache = MockFXGraphCache()
        AOTAutogradCache.clear()
        with patch(
            "torch._inductor.codecache.FxGraphCache._lookup_graph",
            new=self.inductor_cache._lookup_graph,
        ), patch(
            "torch._inductor.codecache.FxGraphCache.post_compile",
            new=self.inductor_cache.post_compile,
        ):
            return super().verify_aot_autograd(
                f,
                inp_,
                test_mutation=test_mutation,
                keep_inp_mutations=keep_inp_mutations,
                decompositions=decompositions,
                dynamic=dynamic,
                make_inputs_subclasses=make_inputs_subclasses,
            )

    def test_input_mutation_false_aliasing(self):
        # This test is disabled because it fails in strict cache mode, but it also
        # can't be xfailed because that causes undefined behavior under ASAN.
        self.skipTest("Skipping because it fails in strict cache mode")


if __name__ == "__main__":
    run_tests()