# Owner(s): ["module: dynamo"]
import copy
import re
import unittest
from textwrap import dedent
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._dynamo.testing import CompileCounter, expectedFailureDynamic, rand_strided
from torch._functorch.aot_autograd import _aot_export_function, create_functional_call
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.profiler import profile
from torch.testing import FileCheck
from torch.testing._internal.common_utils import compare_equal_outs_and_grads


def maybe_dupe_op(x):
    y = x + 1
    z = x + 2
    if x.numel() < 5:
        return y, y
    else:
        return y, z


def is_dynamic_shape_test(test_name):
    return test_name.endswith("_dynamic_shapes")


aten = torch.ops.aten
lib = torch.library.Library("custom", "DEF")  # noqa: TOR901
lib.define("maybe_dupe_op(Tensor a) -> (Tensor, Tensor)")
lib.impl("maybe_dupe_op", maybe_dupe_op, "CPU")
lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta")


class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
    def test_LSTM(self):
        # https://github.com/pytorch/torchdynamo/issues/1147
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.self_mod_model_lstm_lstm = torch.nn.LSTM(
                    64, 64, num_layers=2, bidirectional=True
                )

            def forward(self, permute: torch.Tensor):
                self_mod_model_lstm_lstm = self.self_mod_model_lstm_lstm(permute)
                return (self_mod_model_lstm_lstm,)

        mod = Repro()

        aot_mod = torch._dynamo.optimize("aot_eager")(mod)

        args = [((92, 4, 64), (1, 5888, 92), torch.float32, "cpu", False)]
        args = [
            rand_strided(sh, st, dt, dev).requires_grad_(rg)
            for (sh, st, dt, dev, rg) in args
        ]

        eager_result = mod(*args)
        aot_result = aot_mod(*args)
        self.assertTrue(torch._dynamo.testing.same(eager_result, aot_result))

    def test_mutation(self):
        # https://github.com/pytorch/torchdynamo/issues/1301
        def fn(param, y):
            prev_grad = torch.is_grad_enabled()
            try:
                torch.set_grad_enabled(False)
                param.add_(y)
            finally:
                torch.set_grad_enabled(prev_grad)
            return y

        y = torch.randn(4)
        x = torch.nn.Parameter(torch.randn(4))
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        # This should not error: we mutated an autograd leaf under no_grad mode.
        aot_fn(x, y)

    def test_mutation1(self):
        def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor):
            getitem = diagonal_chunked_attention_scores[
                (
                    slice(None, None, None),
                    slice(None, None, None),
                    slice(None, 256, None),
                    slice(None, 257, None),
                )
            ]
            _stack0[
                (
                    slice(None, None, None),
                    slice(None, -1, None),
                    slice(None, None, None),
                    slice(256, None, None),
                )
            ] = getitem
            view = _stack0.view(1, 12, 1024, 513)
            return (view,)

        x = torch.randn(torch.Size([12, 4, 256, 513]))
        y = torch.randn(torch.Size([12, 3, 512, 513]))
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        aot_fn(x, y)

    def test_negative_testing_mutation(self):
        def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor):
            getitem = diagonal_chunked_attention_scores[
                (
                    slice(None, None, None),
                    slice(None, None, None),
                    slice(None, 256, None),
                    slice(None, 257, None),
                )
            ]
            _stack0 = torch.sin(_stack0)
            _stack0[
                (
                    slice(None, None, None),
                    slice(None, -1, None),
                    slice(None, None, None),
                    slice(256, None, None),
                )
            ] = getitem
            view = _stack0.view(1, 12, 1024, 513)
            return (view,)

        x = torch.randn(torch.Size([12, 4, 256, 513]))
        y = torch.randn(torch.Size([12, 3, 512, 513]))
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        aot_fn(x, y)

    def test_negative_testing(self):
        def fn(x, y):
            return torch.sin(x).add_(y)

        y = torch.randn(4)
        x = torch.randn(4)
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        aot_fn(x, y)

    def test_call_fn_with_non_const_inputs_aot_safe(self):
        class ModuleSpecialFwd(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=20, kernel_size=(5, 5)
                )

            def _conv_forward(self, x):
                return self.conv._conv_forward(x, self.conv.weight, self.conv.bias)

            def forward(self, x):
                return self._conv_forward(x)

        # Init mod
        mod = ModuleSpecialFwd()
        rx = torch.randn([3, 10, 10])

        # Run it for real
        real = mod(rx)

        # Run it in export
        graph, _ = torch._dynamo.export(mod)(rx)

        # Run exported graph with AOT
        self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))

        aot_fn = torch._dynamo.optimize("aot_eager")(graph)
        aot_fn(rx)

    def test_call_fn_with_non_const_inputs_aot_unsafe(self):
        class ModuleSpecialFwd(torch.nn.Module):
            def _some_bad_fwd(self, param, y):
                prev_grad = torch.is_grad_enabled()
                try:
                    torch.set_grad_enabled(False)
                    param.add_(y)
                finally:
                    torch.set_grad_enabled(prev_grad)
                return y

            def forward(self, x, y):
                return self._some_bad_fwd(x, y)

        # Init mod
        mod = ModuleSpecialFwd()
        x = torch.nn.Parameter(torch.randn(4))
        y = torch.randn([4])

        # Run it for real
        real = mod(x, y)

        # Run it in export
        graph, _ = torch._dynamo.export(mod)(x, y)

        # Assert equal
        self.assertTrue(torch._dynamo.testing.same(real, graph(x, y)))

        # Run exported graph with AOT
        aot_fn = torch._dynamo.optimize("aot_eager")(graph)
        # This should not error: we mutated an autograd leaf under no_grad mode.
        aot_fn(x, y)

    def test_call_fn_with_non_const_inputs_aot_unsafe_control_flow(self):
        class ModuleSpecialFwd(torch.nn.Module):
            def _some_bad_fwd(self, param, y):
                if y[0][0] < 3:
                    return y + param
                return param * y

            def forward(self, x, y):
                a = x * y
                a = self._some_bad_fwd(a, a)
                b = x + y
                return a * b

        # Init mod
        mod = ModuleSpecialFwd()
        x = torch.nn.Parameter(torch.randn([2, 2]))
        y = torch.randn([2, 2])

        # Run it for real
        real = mod(x, y)

        # Run it through optimize, with our capturing fn

        gms = []
        counter = CompileCounter()

        def capturing_fn(gm, inputs):
            nonlocal gms
            gms.append(gm)
            return counter(gm, inputs)

        optimized_mod = torch._dynamo.optimize(capturing_fn)(mod)

        # Assert equal
        self.assertTrue(torch._dynamo.testing.same(real, optimized_mod(x, y)))

        # Uncomment to reproduce the commented-out graphs below.
        # for gm in gms:
        #     print("GM CODE", gm.code)

        self.assertEqual(counter.frame_count, 4)
        self.assertEqual(counter.op_count, 7)
        # Graph 1
        # def forward(self, x : torch.nn.parameter.Parameter, y : torch.Tensor):
        #     mul = x * y; x = y = None
        #     return (mul,)
        # BREAK
        # Graph 2
        # def forward(self, y : torch.Tensor):
        #     getitem = y[0]; y = None
        #     getitem_1 = getitem[0]; getitem = None
        #     lt = getitem_1 < 3; getitem_1 = None
        #     return (lt,)
        # BREAK
        # Graph 3
        # def forward(self, param : torch.Tensor, y : torch.Tensor):
        #     add = y + param; y = param = None
        #     return (add,)
        # BREAK
        # Graph 4
        # def forward(self, _stack0 : torch.Tensor, x : torch.nn.parameter.Parameter, y : torch.Tensor):
        #     add = x + y; x = y = None
        #     mul = _stack0 * add; _stack0 = add = None
        #     return (mul,)

        # Run fn with AOT
        torch._dynamo.reset()

        aot_fn = torch._dynamo.optimize("aot_eager")(optimized_mod)
        aot_fn(x, y)

    # Note: Dynamo recompilation guarding invalid grad
    #
    # This test is a spiritual equivalent to test_invalid_requires_grad_fake in test_autodispatch.py.
    # It invokes aot_autograd in a way that would normally trigger an assertion
    # (this is what test_invalid_requires_grad_fake does), and proves that we do not hit that
    # assertion, because dynamo recompiles correctly and guards against this condition.
    #
    # Subnote: test_invalid_requires_grad_fake uses fake tensors
    # because dynamo sends fake tensors down to aot_autograd.
    @patch("torch._functorch.config.debug_assert", True)
    def test_requires_grad_fake_via_dynamo_recompiles(self):
        class F(torch.nn.Module):
            def forward(self, x, y):
                return (x + y,)

        x = torch.randn(3, 3, requires_grad=True)
        y = torch.randn(3, 3, requires_grad=True)
        z = torch.randn(3, 3, requires_grad=False)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        fxy = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        compare_equal_outs_and_grads(self, F(), fxy, (x, y))
        compare_equal_outs_and_grads(self, F(), fxy, (x, z))
        self.assertIn(
            """tensor 'L['y']' requires_grad mismatch. expected requires_grad=1""",
            failure_reason,
        )

        # Reset failure reason
        failure_reason = None

        self.assertEqual(cc.frame_count, 2)

        torch._dynamo.reset()  # for new backend
        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        fxz = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        compare_equal_outs_and_grads(self, F(), fxz, (x, z))
        compare_equal_outs_and_grads(self, F(), fxz, (x, z))
        self.assertEqual(cc.frame_count, 1)
        self.assertTrue(failure_reason is None)

    def test_double_backward_errors(self):
        # Remove this test after we get double backward to actually work
        for grad_output in (torch.tensor(1.0, requires_grad=True), None):
            x = torch.tensor(1.0, requires_grad=True)
            err = "torch.compile with aot_autograd does not currently support double backward"

            # The following cases should be equivalent:

            # (1) double backward entirely inside compiled function
            def f1(x):
                y = x.sin().exp()
                (gx,) = torch.autograd.grad(
                    y, x, create_graph=True, grad_outputs=grad_output
                )
                torch.autograd.grad(gx, x)
                return gx

            compiled_f1 = torch.compile(backend="aot_eager")(f1)
            f1(x)
            with self.assertRaisesRegex(RuntimeError, err):
                compiled_f1(x)

            # (2) the second half of double backward outside compiled function
            def f2(x):
                y = x.sin().exp()
                (gx,) = torch.autograd.grad(
                    y, x, create_graph=True, grad_outputs=grad_output
                )
                return gx

            compiled_f2 = torch.compile(backend="aot_eager")(f2)
            gx = compiled_f2(x)
            with self.assertRaisesRegex(RuntimeError, err):
                torch.autograd.grad(gx, x)

            # (3) double backward entirely outside compiled function
            def f3(x):
                y = x.sin().exp()
                return y

            compiled_f3 = torch.compile(backend="aot_eager")(f3)
            y = compiled_f3(x)
            (gx,) = torch.autograd.grad(
                y, x, create_graph=True, grad_outputs=grad_output
            )
            with self.assertRaisesRegex(RuntimeError, err):
                torch.autograd.grad(gx, x)

            # create_graph=False
            def f4(x):
                y = x.sin().exp()
                return y

            compiled_f4 = torch.compile(backend="aot_eager")(f4)
            x = torch.tensor(1.0, requires_grad=True)
            y = compiled_f4(x)
            (gx,) = torch.autograd.grad(y, x, create_graph=False, grad_outputs=grad_output)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles(self):
        class F(torch.nn.Module):
            def forward(self, x, y):
                x = x.trunc_()
                y = y.trunc_()
                return (x + y,)

        x = torch.randn(3, 3, requires_grad=True)
        x1, x2, x3, x4 = x.clone(), x.clone(), x.clone(), x.clone()
        y = torch.randn(3, 3, requires_grad=True)
        y1, y2, y4 = y.clone(), y.clone(), y.clone()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        fxy = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        # Note: to prevent a recompilation between the two calls,
        # we need to clone x and y on each use.
        # fxy mutates the input's metadata, so otherwise dynamo will end up recompiling.
        fxy(x1, y1)
        fxy(x2, y2)

        self.assertTrue(failure_reason is None)

        # Reset failure reason
        failure_reason = None

        self.assertEqual(cc.frame_count, 1)

        torch._dynamo.reset()  # for new backend
        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        fxx = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        fxx(x3, x3)
        fxx(x4, y4)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn("""L['x'] is L['y']""", failure_reason)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg(self):
        class F(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mean = torch.nn.Parameter(torch.randn(3, 3))

            def forward(self, a, b, e, f):
                a.trunc_()
                b.trunc_()
                return (a + b + self.mean) * e * f

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        a1, a2 = a.clone(), a.clone()
        b1, b2 = b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a1, a1, 2, 2)
        f(a2, b2, 2, 2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

        torch._dynamo.reset()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        c = torch.randn(3, 3, requires_grad=True)
        d = torch.randn(3, 3, requires_grad=True)
        c3, c4 = c.clone(), c.clone()
        d3, d4 = d.clone(), d.clone()

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(c3, c3, 3, 3)
        f(c4, d4, 3, 3)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn("""L['a'] is L['b']""", failure_reason)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_with_global(self):
        z = None

        class F(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mean = torch.nn.Parameter(torch.randn(3, 3))

            def forward(self, a, b, e, f):
                a.trunc_()
                b.trunc_()
                return (a + b + z + self.mean) * e * f

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        z = a
        a1, a2 = a.clone(), a.clone()
        b1, b2 = b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a1, a1, 2, 2)
        f(a2, b2, 2, 2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg_list(self):
        class F(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mean = torch.nn.Parameter(torch.randn(3, 3))

            def forward(self, e, f, a, b):
                a.trunc_()
                b.trunc_()
                return (a + b + self.mean) * e[0] * f[0]

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        a1, a2 = a.clone(), a.clone()
        b1, b2 = b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f([3, 2, 1], [4, 5, 6], a1, a1)
        f([3, 2, 1], [4, 5, 6], a2, b2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

        torch._dynamo.reset()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        c = torch.randn(3, 3, requires_grad=True)
        d = torch.randn(3, 3, requires_grad=True)
        c3, c4 = c.clone(), c.clone()
        d3, d4 = d.clone(), d.clone()

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f([3, 2, 1], [4, 5, 6], c3, c3)
        f([3, 2, 1], [4, 5, 6], c4, d4)
        self.assertEqual(cc.frame_count, 2)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_args_param(self):
        class F(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mean = torch.nn.Parameter(torch.randn(3, 3))

            def forward(self, a, b):
                a.trunc_()
                b.trunc_()
                return a + b + self.mean

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        a1, a2 = a.clone(), a.clone()
        b1, b2 = b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a1, a1)
        f(a2, b2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

        torch._dynamo.reset()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        c = torch.randn(3, 3, requires_grad=True)
        d = torch.randn(3, 3, requires_grad=True)
        c3, c4 = c.clone(), c.clone()
        d3, d4 = d.clone(), d.clone()

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(c3, c3)
        f(c4, d4)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn("""L['a'] is L['b']""", failure_reason)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_args(self):
        class F(torch.nn.Module):
            def forward(self, a, b, c, d):
                a.trunc_()
                b.trunc_()
                c.trunc_()
                d.trunc_()
                return (a + b + c + d,)

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
        b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a1, a1, a1, a1)
        f(a2, b2, b2, b2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

        torch._dynamo.reset()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        c = torch.randn(3, 3, requires_grad=True)
        d = torch.randn(3, 3, requires_grad=True)
        c3, c4 = c.clone(), c.clone()
        d3, d4 = d.clone(), d.clone()

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a3, b3, c3, c3)
        f(a4, b4, c4, d4)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn("""L['c'] is L['d']""", failure_reason)

    def test_alias_inputs(self):
        def fn():
            a = torch.tensor([1])
            a = a[0:1]
            b = a.squeeze()
            a[0] = 0
            if a[0] < 1e5:
                pass
            a[0] = 2
            return b

        ref_output = fn()
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        actual_output = aot_fn()
        self.assertEqual(ref_output, actual_output)

    def test_grad_inputs_alias_inputs(self):
        class Test(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x)
                return y

            @staticmethod
            def backward(ctx, grad):
                (x,) = ctx.saved_tensors
                return x, grad

        def fn(x, y):
            return Test.apply(x, y)

        x = torch.ones(1, requires_grad=True)
        y = torch.ones(1, requires_grad=True)
        compiled_fn = torch.compile(fn, backend="aot_eager")
        out = compiled_fn(x, y)
        out.sum().backward()

    @expectedFailureDynamic  # https://github.com/pytorch/pytorch/issues/103539
    @torch._dynamo.config.patch(automatic_dynamic_shapes=False)
    @patch("torch._functorch.config.debug_assert", True)
    def test_multiple_aot_autograd_calls_dupe_args(self):
        # this is just dealing with the fact that
        # aot_module_simplified expects submods to always return tuples/lists
        class WrapperModule(torch.nn.Module):
            def __init__(self, mod):
                super().__init__()
                self.mod = mod

            def forward(self, *args):
                out = self.mod(*args)
                if isinstance(out, (list, tuple)):
                    return out
                return (out,)

        def compile_submod(input_mod, args):
            from functorch.compile import nop
            from torch._functorch.aot_autograd import aot_module_simplified

            class WrapperModule(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.original = input_mod
                    self.submod = aot_module_simplified(input_mod, args, nop)

                def forward(self, *args):
                    return self.submod(*args)

            return WrapperModule()

        def test_compile(fx_g, example_inps):
            split_gm = torch.fx.passes.split_module.split_module(
                fx_g, None, lambda node: 1 if "mul" in str(node) else 0
            )
            submod_1_inps = split_gm.submod_0(*example_inps)
            split_gm.submod_0 = compile_submod(
                WrapperModule(split_gm.submod_0), example_inps
            )
            split_gm.submod_1 = compile_submod(
                WrapperModule(split_gm.submod_1), submod_1_inps
            )
            return split_gm

        @torch._dynamo.optimize(test_compile)
        def f(a):
            b, c = torch.ops.custom.maybe_dupe_op(a)
            return (b.mul_(c),)

        f(torch.ones(4))
        f(torch.ones(6))

    def test_nn_parameter_construction(self):
        # https://github.com/pytorch/pytorch/issues/99569
        def fn(x):
            y = x.sin()
            z = torch.nn.Parameter(torch.ones(1))
            return y + z

        x = torch.rand((4, 4))

        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
        self.assertTrue(torch._dynamo.testing.same(fn(x), opt_fn(x)))

    def test_aot_sequence_nr(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    in_channels=16,
                    out_channels=16,
                    kernel_size=(1, 1),
                    stride=1,
                    padding="same",
                    bias=True,
                )
                self.bn1 = torch.nn.BatchNorm2d(num_features=16)
                self.relu1 = torch.nn.ReLU()
                self.fc1 = torch.nn.Linear(in_features=1638400, out_features=1)
                self.loss_fn = torch.nn.L1Loss()

            def forward(self, x, target):
                y = x
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu1(x)
                x = x + y
                x = torch.flatten(x)
                x = self.fc1(x)
                output = self.loss_fn(x, target)

                return (output,)

        mod = Model()
        mod.train()
        x = torch.rand(100, 16, 32, 32, requires_grad=True)
        target = torch.rand(1)

        # Use dynamo export to get the fx graph module
        g_mod, _ = torch._dynamo.export(mod, x, target)

        def _prepare_model_args():
            named_parameters = dict(g_mod.named_parameters(remove_duplicate=False))
            named_buffers = dict(g_mod.named_buffers(remove_duplicate=False))
            params_and_buffers = {
                **dict(named_parameters),
                **dict(named_buffers),
            }
            params_and_buffers_flat, params_spec = pytree.tree_flatten(
                params_and_buffers
            )
            params_len = len(params_and_buffers_flat)
            functional_call = create_functional_call(g_mod, params_spec, params_len)
            return params_and_buffers_flat, functional_call

        full_args, fn_to_trace = _prepare_model_args()
        param_and_buf_len = len(full_args)
        full_args.extend([x, target])

        # aot_export takes the fwd graph as a graph module input and
        # returns the full fwd/bwd graph as a graph module.
        with torch.enable_grad(), fx_traceback.preserve_node_meta():
            fx_g, _, _, _ = _aot_export_function(
                fn_to_trace,
                full_args,
                decompositions=None,
                num_params_buffers=param_and_buf_len,
                no_tangents=True,
            )

        # Walk all the nodes in the fx graph.
        # Write the resulting ops to a table.
        min_seq_nr = -1
        seq_table = "SeqNr|OrigAten|SrcFn|FwdSrcFn\n"
        for node in fx_g.graph.nodes:
            if "call_" in node.op and "getitem" not in str(node.target):
                seq_nr = node.meta.get("seq_nr", -1)
                if seq_nr < 0:
                    continue
                if min_seq_nr < 0:
                    min_seq_nr = seq_nr
                source_fn_stack = node.meta.get("source_fn_stack", [])
                orig_aten = node.meta.get("original_aten", "")
                mod_name = ""
                if len(source_fn_stack) > 0:
                    mod_name = source_fn_stack[-1][0]
                # Make all seq_nr relative so it starts at 0
                seq_nr = seq_nr - min_seq_nr
                # For backward nodes, also test that metadata from the corresponding
                # forward node is copied over.
                fwd_source_fn_stack = node.meta.get("fwd_source_fn_stack", [])
                fwd_mod_name = ""
                if len(fwd_source_fn_stack):
                    fwd_mod_name = fwd_source_fn_stack[-1][0]
                seq_table = (
                    seq_table + f"{seq_nr}|{orig_aten}|{mod_name}|{fwd_mod_name}\n"
                )

        self.maxDiff = None
        self.assertExpectedInline(
            seq_table,
            dedent(
                """\
SeqNr|OrigAten|SrcFn|FwdSrcFn
0|aten.convolution.default|l__self___conv1|
0|aten.add.Tensor|l__self___bn1|
1|aten._native_batch_norm_legit_functional.default|l__self___bn1|
2|aten.relu.default|l__self___relu1|
2|aten.detach.default|l__self___relu1|
2|aten.detach.default|l__self___relu1|
3|aten.add.Tensor|add|
4|aten.view.default|flatten|
5|aten.view.default|l__self___fc1|
6|aten.t.default|l__self___fc1|
7|aten.addmm.default|l__self___fc1|
8|aten.view.default|l__self___fc1|
9|aten.sub.Tensor|l__self___loss_fn|
10|aten.abs.default|l__self___loss_fn|
11|aten.mean.default|l__self___loss_fn|
11|aten.ones_like.default||l__self___loss_fn
11|aten.expand.default||l__self___loss_fn
11|aten.div.Scalar||l__self___loss_fn
10|aten.sgn.default||l__self___loss_fn
10|aten.mul.Tensor||l__self___loss_fn
8|aten.view.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.mm.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.mm.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.sum.dim_IntList||l__self___fc1
7|aten.view.default||l__self___fc1
6|aten.t.default||l__self___fc1
5|aten.view.default||l__self___fc1
4|aten.view.default||
2|aten.detach.default||l__self___relu1
2|aten.detach.default||l__self___relu1
2|aten.threshold_backward.default||l__self___relu1
1|aten.native_batch_norm_backward.default||l__self___bn1
0|aten.convolution_backward.default||l__self___conv1
11|aten.add.Tensor||l__self___loss_fn
"""
            ),
        )

    def test_split_with_sizes_aot_autograd_cleans_up_traceback_meta(self):
        from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks

        def fn(result, split_sizes):
            rs = torch.ops.aten.split_with_sizes(result, split_sizes.tolist())
            return rs

        example_inputs = (
            torch.randn(32, requires_grad=True),
            torch.tensor((7, 16, 9)),
        )
        outs = fn(*example_inputs)
        setup_stacktrace_preservation_hooks([out.grad_fn for out in outs])
        with fx_traceback.preserve_node_meta():
            (outs[0].sum() + outs[1].sum() + outs[2].sum()).backward()

        self.assertNotIn("grad_fn_seq_nr", fx_traceback.current_meta)
        self.assertNotIn("in_grad_fn", fx_traceback.current_meta)

    # https://github.com/pytorch/pytorch/issues/110121
    def test_aot_export_joint_simple_repro(self):
        class Mod(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.linear = torch.nn.Linear(5, 7)

            def forward(self, x):
                return self.linear(x)

        def mini_backend(gm, sample_inputs):
            from torch._functorch.aot_autograd import aot_export_joint_simple

            fake_mode = torch._dynamo.utils.detect_fake_mode(sample_inputs)

            with patch.object(fake_mode, "allow_non_fake_inputs", True), fake_mode:
                return aot_export_joint_simple(gm, sample_inputs, trace_joint=False)

        sample_inputs = [torch.rand((3, 4, 5))]
        model = Mod()
        m_compiled = torch.compile(model, backend=mini_backend)

        out_ref = model(*sample_inputs)
        out_test = m_compiled(*sample_inputs)
        self.assertEqual(out_ref, out_test)

    def test_eager_sequence_nr(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    in_channels=16,
                    out_channels=16,
                    kernel_size=(1, 1),
                    stride=1,
                    padding="same",
                    bias=True,
                )
                self.bn1 = torch.nn.BatchNorm2d(num_features=16)
                self.relu1 = torch.nn.ReLU()
                self.fc1 = torch.nn.Linear(in_features=1638400, out_features=1)
                self.loss_fn = torch.nn.L1Loss()

            def forward(self, x, target):
                y = x
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu1(x)
                x = x + y
                x = torch.flatten(x)
                x = self.fc1(x)
                output = self.loss_fn(x, target)

                return (output,)

        def grad_with_create_graph(mod, x, target):
            y = mod(x, target)
            # Set create_graph=True to ensure that the sequence_nr
            # for backward ops continues to count down.
            (gx,) = torch.autograd.grad(
                y[0], x, create_graph=True, grad_outputs=grad_output
            )
            return gx

        x = torch.rand(100, 16, 32, 32, requires_grad=True)
        target = torch.rand(1)
        mod = Model()
        args = [mod, x, target]
        grad_output = torch.tensor(1.0, requires_grad=True)
        compiled_f1 = torch.compile(backend="aot_eager")(grad_with_create_graph)
        model_instance = compiled_f1
        with profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            record_shapes=True,
        ) as kineto_prof:
            res = model_instance(*args)
        bwd_set = set()
        prof_str = "SeqNr|Thread|FwdThread|Name\n"
        for event in kineto_prof.events():
            if event.sequence_nr >= 0:
                prof_str = (
                    prof_str + f"{event.sequence_nr}|{event.thread}"
                    f"|{event.fwd_thread}|{event.name}|\n"
                )
                if re.search(r"Backward[01]", event.name):
                    bwd_set.add(event.sequence_nr)
        self.assertTrue(len(bwd_set), 13)

    def test_aot_grad_mode_mutation(self):
        for compiler in ["aot_eager", "inductor"]:

            def f(x):
                y = x * x
                torch.set_grad_enabled(False)
                return y.clone(), y

            f_compiled = torch.compile(f, backend=compiler, fullgraph=True)

            torch.set_grad_enabled(True)
            x = torch.ones(3, requires_grad=True) * 3
            y_ref = f(x)
            self.assertEqual(torch.is_grad_enabled(), False)
            torch.set_grad_enabled(True)
            y = f_compiled(x)
            self.assertEqual(torch.is_grad_enabled(), False)
            torch.set_grad_enabled(True)
            self.assertEqual(y_ref, y)

            self.assertIsNone(y_ref[0].grad_fn)
            self.assertIsNone(y[0].grad_fn)

            self.assertIsNotNone(y_ref[1].grad_fn)
            self.assertIsNotNone(y[1].grad_fn)

            # Check that the grads computed for the inputs, given the same tangent, are the same.
            # The tangent to `y[0]`, which has requires_grad=False, is irrelevant.
            self.assertEqual(
                sum(y_ref[1].grad_fn(torch.tensor([-1.0, 2.0, 0.0]))),
                sum(
                    x
                    for x in y[1].grad_fn.apply(None, torch.tensor([-1.0, 2.0, 0.0]))
                    if x is not None
                ),
            )

    def test_aot_autograd_raises_invalid_leaf_set(self):
        @torch.compile
        def f(x):
            x.set_(torch.ones(2))

        # We still want to make sure that this raises
        x = torch.ones(2, requires_grad=True)
        with self.assertRaisesRegex(
            RuntimeError, "is being used in an in-place operation"
        ):
            f(x)

    def test_aot_autograd_expand_mutation_functionalizes(self):
        def fn(x):
            y = x.expand(3, *x.shape)
            y[0, 0].add_(5)
            return y

        opt_fn = torch.compile(fn, backend="aot_eager")

        x = torch.arange(6)
        x_opt = x.clone().detach()
        self.assertEqual(fn(x), opt_fn(x_opt))
        self.assertEqual(x, x_opt)

    def test_aot_autograd_expand_mutation_backwards(self):
        def fn(x, z):
            y = x.expand(3, *x.shape)
            y[1, 1].mul_(5)
            ret = y * z
            return ret

        opt_fn = torch.compile(fn, backend="aot_eager")

        x = torch.arange(6, dtype=torch.float)
        z = x.clone().detach()
        x_opt = x.clone().detach()
        z_opt = x.clone().detach()

        z.requires_grad = True
        z_opt.requires_grad = True

        res = fn(x, z)
        opt_res = opt_fn(x_opt, z_opt)

        self.assertEqual(res, opt_res)

        res.sum().backward()
        opt_res.sum().backward()

        self.assertEqual(x, x_opt)
        self.assertEqual(z.grad, z_opt.grad)

    def test_data_ptr_access_copy(self):
        import torch._functorch.config as _config

        with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
            with FakeTensorMode():
                x = torch.randn(3)
                y = copy.copy(x)
        self.assertEqual(y.shape, x.shape)

    def test_data_ptr_access_fails_in_forward(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)

            @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
            def _(x):
                x.data_ptr()
                return x.clone()

            x = torch.randn(3)

            def data_ptr_graph_input(x):
                r0 = torch.ops.mylib.foo(x)
                return r0

            def data_ptr_graph_intermediate(x):
                y = x.clone()
                r0 = torch.ops.mylib.foo(y)
                return r0

            tests = [data_ptr_graph_input, data_ptr_graph_intermediate]

            def ctx():
                return self.assertRaisesRegex(
                    RuntimeError, "Cannot access data pointer"
                )

            for f in tests:
                with ctx():
                    make_fx(f, tracing_mode="fake")(x)
                with ctx():
                    make_fx(f, tracing_mode="symbolic")(x)
                with ctx():
                    torch.compile(f, backend="eager", fullgraph=True)(x)

    def test_data_ptr_access_fails_in_backward(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)

            backward_called = False

            class Foo(torch.autograd.Function):
                @staticmethod
                def forward(ctx, x):
                    return x.clone()

                @staticmethod
                def backward(ctx, grad):
                    nonlocal backward_called
                    backward_called = True
                    grad.data_ptr()
                    return grad.clone()

            @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
            def _(x):
                return Foo.apply(x)

            def f(x):
                return torch.ops.mylib.foo(x)

            x = torch.randn(3, requires_grad=True)
            with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer"):
                y = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
            self.assertTrue(backward_called)

    # We don't know how to catch multiple mutations to the same memory location
    @unittest.expectedFailure
    def test_aot_autograd_expand_mutation_error(self):
        def fn(x):
            y = x.expand(3, *x.shape)
            y[0:3, 0].add_(5)
            return y

        opt_fn = torch.compile(fn, backend="aot_eager")

        x = torch.arange(6)
        x_opt = x.clone().detach()
        with self.assertRaises(Exception):
            fn(x)
        with self.assertRaises(Exception):
            opt_fn(x_opt)

    @torch._functorch.config.patch(donated_buffer=True)
    def test_donated_buffer1(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        @torch.compile()
        def relu(x):
            return torch.nn.functional.relu(x)

        with self.assertLogs(logger_name, level="INFO") as captured:
            relu(torch.rand([3, 3], requires_grad=True)).sum().backward()

        if is_dynamic_shape_test(self._testMethodName):
            # an extra symint exists
            expected_msg = "bw_donated_idxs=[1]"
        else:
            expected_msg = "bw_donated_idxs=[0]"

        # le is a donated buffer from relu
        FileCheck().check(expected_msg).run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer2(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        # we will re-use the compiled graph for g across callers
        @torch.compile()
        def g(activation, param2):
            return torch.matmul(activation, param2)

        def f(inp, param1, param2):
            activation = inp + param1
            return g(activation, param2)

        inp = torch.ones(4, 4)
        param1 = torch.ones(4, 4, requires_grad=True)
        param2 = torch.ones(4, 4, requires_grad=True)

        with self.assertLogs(logger_name, level="INFO") as captured:
            f(inp, param1, param2).sum().backward()

        FileCheck().check("bw_donated_idxs=[]").run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer3(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        # we will re-use the compiled graph for g across callers
        @torch.compile()
        def g(activation, param2):
            return torch.matmul(activation, param2)

        def f(inp, param1, param2):
            # exp saves its output (the activation) for bw
            activation = torch.exp(inp + param1)
            return g(activation, param2)

        inp = torch.ones(4, 4)
        param1 = torch.ones(4, 4, requires_grad=True)
        param2 = torch.ones(4, 4, requires_grad=True)

        with self.assertLogs(logger_name, level="INFO") as captured:
            f(inp, param1, param2).sum().backward()

        FileCheck().check("bw_donated_idxs=[]").run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer4(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([2, 2]))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.relu(x) + self.param

        mod = Mod()
        mod = torch.compile(mod)

        inp = torch.ones([2, 2], requires_grad=True)

        with self.assertLogs(logger_name, level="INFO") as captured:
            mod(inp).sum().backward()

        # Forward graph:
        # %primals_1 : [num_users=1] = placeholder[target=primals_1]
        # %primals_2 : [num_users=1] = placeholder[target=primals_2]
        # %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%primals_2,), kwargs = {})
        # %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%relu, %primals_1), kwargs = {})
        # %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
        # return [add, le]
        #
        # `le` is a donated buffer
        FileCheck().check("bw_donated_idxs=[0]").run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer5(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        @torch.compile()
        def f(x, z):
            y = x.view(2, 3)
            z = torch.nn.functional.relu(z)
            return torch.mm(y, x) + z

        inp = [
            torch.rand([3, 2], requires_grad=True),
            torch.rand([2, 2], requires_grad=True),
        ]

        with self.assertLogs(logger_name, level="INFO") as captured:
            f(*inp).sum().backward()

        # Forward graph:
        # %primals_1 : [num_users=3] = placeholder[target=primals_1]
        # %primals_2 : [num_users=1] = placeholder[target=primals_2]
        # %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%primals_1, [2, 3]), kwargs = {})
        # %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%primals_2,), kwargs = {})
        # %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %primals_1), kwargs = {})
        # %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mm, %relu), kwargs = {})
        # %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
        # return [add, primals_1, le]
        #
        # `le` is a donated buffer but primals_1 is not.
        FileCheck().check("bw_donated_idxs=[1]").run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph1(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        for _ in range(5):
            mod(inp).sum().backward()

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph2(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        out = mod(inp).sum()
        for _ in range(5):
            out.backward(retain_graph=True)
        out.backward()

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph3(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        mod(inp).sum().backward(create_graph=True)
        out = mod(inp).sum()
        for _ in range(5):
            out.backward(retain_graph=True)
        out.backward()

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph4(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        mod(inp).sum().backward()
        out = mod(inp).sum()
        with self.assertRaisesRegex(
            RuntimeError,
            r"This backward function was compiled with non-empty donated "
            r"buffers which requires create_graph=False and retain_graph=False. "
            r"Please keep backward\(create_graph=False, retain_graph=False\) "
            r"across all backward\(\) function calls, or set "
            r"torch._functorch.config.donated_buffer=False to disable "
            r"donated buffer.",
        ):
            out.backward(retain_graph=True)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()