# Owner(s): ["module: dynamo"]
# flake8: noqa: B950
import copy
import math
from dataclasses import dataclass

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda


if HAS_CUDA:
    import triton

    from torch.testing._internal.triton_utils import add_kernel


class CustomFunc1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return foo + foo

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class CustomFunc3(torch.autograd.Function):
    # Test that there is a graph break in the forward function
    @staticmethod
    def forward(ctx, foo):
        result = foo + foo
        torch._dynamo.graph_break()
        result = result + foo
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * math.sqrt(result.numel())


class Module1(torch.nn.Module):
    def forward(self, foo):
        return CustomFunc1().apply(foo)


class Module2(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fn = CustomFunc1.apply

    def forward(self, foo):
        return self.fn(foo)


class Module3(torch.nn.Module):
    def forward(self, foo):
        return CustomFunc1().apply(foo)


class Module4(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fn = CustomFunc1.apply

    def forward(self, foo):
        return self.fn(foo)


class Module5(torch.nn.Module):
    def forward(self, foo):
        return CustomFunc3().apply(foo)


class Module6(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fn = CustomFunc3.apply

    def forward(self, foo):
        return self.fn(foo)


class LinearFunction(torch.autograd.Function):
    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias


class ModuleLinear(torch.nn.Module):
    def forward(self, input, weight, bias=None):
        return LinearFunction.apply(input, weight, bias)


class MaterializingGradFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.set_materialize_grads(False)
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, grad_out1, grad_out2):
        return grad_out1, grad_out2


class MaterializingGradModule(torch.nn.Module):
    def forward(self, x):
        return MaterializingGradFunction.apply(x)


class CustomFuncBwdPrintGraphBreak(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return torch.add(foo, foo)

    @staticmethod
    def backward(ctx, grad_output):
        print("graph break!")
        return grad_output


class CustomFuncBwdPrintModule(torch.nn.Module):
    def forward(self, x):
        return CustomFuncBwdPrintGraphBreak.apply(x)


class CustomFuncStrideBwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return torch.add(foo, foo)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.stride()


class CustomFuncStrideModule(torch.nn.Module):
    def forward(self, x):
        return CustomFuncStrideBwd.apply(x)


class CustomFuncSaveForBwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = foo + foo
        result = result + foo
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * math.sqrt(result.numel())


class SaveForBwdModule(torch.nn.Module):
    def forward(self, foo):
        return CustomFuncSaveForBwd().apply(foo)


class ContextSaveAndMark(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with torch.no_grad():
            ctx.save_for_backward(x)
            ctx.mark_non_differentiable(x)
            return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class ContextMarkAndSave(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with torch.no_grad():
            ctx.mark_non_differentiable(x)
            ctx.save_for_backward(x)
            return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class ModuleWithGradFunc(torch.nn.Module):
    def __init__(self, func):
        super().__init__()
        self.f = func.apply

    def forward(self, x):
        return self.f(x)


class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
    # Sound behaviors, tested for working capture
    def test_autograd_function_equivalence(self):
        for grad in [True, False]:
            for i in range(1, 5):
                torch._dynamo.reset()
                model = globals()[f"Module{i}"]()
                opt_model = torch._dynamo.optimize("eager")(model)
                self.assertTrue(
                    torch.allclose(
                        opt_model(torch.ones(2, 3, requires_grad=grad)),
                        torch.tensor([2.0], requires_grad=grad),
                    )
                )

    def test_autograd_function_has_graph_break(self):
        for grad in [True, False]:
            x = torch.randn(10, requires_grad=grad)
            for model in [Module5(), Module6()]:
                torch._dynamo.reset()
                cnts = torch._dynamo.testing.CompileCounter()
                opt_model = torch._dynamo.optimize(cnts)(model)
                for _ in range(3):
                    ref = model(x)
                    res = opt_model(x)
                    self.assertTrue(torch.allclose(ref, res))
                self.assertEqual(cnts.frame_count, 2)

    def test_linear_setup_context(self):
        model = ModuleLinear()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        input = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        weight = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
        eager_result = model(input, weight)
        optim_result = opt_model(input, weight)
        self.assertEqual(optim_result, eager_result)

    def test_materialize_grad(self):
        model = MaterializingGradModule()
        opt_model = torch._dynamo.optimize("eager")(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        optim_result = opt_model(x)
        eager_result = model(x)
        self.assertEqual(optim_result, eager_result)

    def test_print_in_bwd(self):
        model = CustomFuncBwdPrintModule()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: print"):
            opt_model(x)

    def test_stride_in_bwd(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()
        model = CustomFuncStrideModule()
        opt_model = torch.compile(backend=cnt)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        ref = model(x)
        res = opt_model(x)

        self.assertEqual(ref, res)
        self.assertEqual(cnt.frame_count, 1)
        # graph break: Illegal getattr invocation stride in strict mode.
        self.assertEqual(
            list(torch._dynamo.utils.counters["graph_break"].values()), [1]
        )

    def test_enum_arg(self):
        from enum import Enum

        class SomeEnum(Enum):
            A = 0
            B = 1

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, e):
                if e is SomeEnum.A:
                    return x.sin()
                else:
                    return x.cos()

            @staticmethod
            def backward(ctx, g):
                return g

        @torch.compile(backend="eager", fullgraph=True)
        def f(x, enum):
            output = Foo.apply(
                x,
                enum,
            )
            return output

        x = torch.tensor([[1.0, 2, 3], [4, 5, 6]], requires_grad=True)
        y = f(x, SomeEnum.A)
        self.assertEqual(y, x.sin())

    def test_save_for_bwd(self):
        model = SaveForBwdModule()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        opt_model(x)

    def test_allow_in_graph(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.allow_in_graph
        class AllowInGraphFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                torch._dynamo.graph_break()
                ctx.x0 = x.size(0)
                return x * 2

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out * ctx.x0

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            return AllowInGraphFunc.apply(x)

        x = torch.rand(2, 3, requires_grad=True)
        result = fn(x)

        self.assertEqual(result, AllowInGraphFunc.apply(x))
        self.assertEqual(cnt.frame_count, 1)

    def test_once_differentiable(self):
        from torch.autograd.function import once_differentiable

        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        class ScaleGradient(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            @once_differentiable
            def backward(ctx, grad):
                return grad * 0.5

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            return ScaleGradient.apply(x)

        x = torch.randn(3, requires_grad=True)
        result = fn(x)

        self.assertEqual(result, ScaleGradient.apply(x))
        self.assertEqual(cnt.frame_count, 1)

    def test_classmethod(self):
        class Shake(torch.autograd.Function):
            @classmethod
            def forward(cls, ctx, foo):
                return foo + foo

            @classmethod
            def backward(cls, ctx, grad_output):
                return grad_output

        def f(x):
            return Shake.apply(x)

        x = torch.randn(4, 4, 4, 4, requires_grad=True)
        opt_m = torch.compile(backend="eager")(f)
        opt_m(x)

    def test_function_context_save_and_mark(self):
        mod = ModuleWithGradFunc(ContextSaveAndMark)
        args, kwargs = ([torch.rand([1])], {})
        before = mod(*args, **kwargs)

        torch._dynamo.reset()
        compiled_model = torch._dynamo.optimize("eager")(mod)
        after = compiled_model(*args, **kwargs)
        self.assertEqual(before, after)

    def test_function_context_mark_and_save(self):
        mod = ModuleWithGradFunc(ContextMarkAndSave)
        args, kwargs = ([torch.rand([1])], {})
        before = mod(*args, **kwargs)

        torch._dynamo.reset()
        compiled_model = torch._dynamo.optimize("eager")(mod)
        after = compiled_model(*args, **kwargs)
        self.assertEqual(before, after)

    def test_multi_output(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone(), x.clone()

            @staticmethod
            def backward(ctx, grad1, grad2):
                return grad1 + grad2

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return Foo.apply(x)

        x = torch.randn(3, requires_grad=True)
        result = f(x)

        self.assertEqual(result, Foo.apply(x))
        self.assertEqual(cnt.frame_count, 1)

    def test_amp_custom_fwd_bwd(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        class MyMM(torch.autograd.Function):
            @staticmethod
            @torch.amp.custom_fwd(device_type="cuda")
            def forward(ctx, a, b):
                ctx.save_for_backward(a, b)
                return a.mm(b)

            @staticmethod
            @torch.amp.custom_bwd(device_type="cuda")
            def backward(ctx, grad):
                a, b = ctx.saved_tensors
                return grad.mm(b.t()), a.t().mm(grad)

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(a, b):
            return MyMM.apply(a, b)

        a = torch.randn([64, 64], dtype=torch.float32, requires_grad=True)
        grad = a.clone()
        res = fn(a, a)
        res.backward(grad)

        self.assertEqual(res, MyMM.apply(a, a))
        self.assertEqual(cnt.frame_count, 1)

    def test_set_materialize_grads_no_graph_break(self):
        class MulY(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.set_materialize_grads(True)
                return x * 3

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out * 3

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            return MulY.apply(x)

        x = torch.tensor(2.0, requires_grad=True)
        result = f(x)
        result.sum().backward()
        self.assertEqual(result, MulY.apply(x))
        self.assertEqual(x.grad, 3.0)

    def test_user_defined_object_as_input(self):
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        @dataclass
        class Weird:
            x: int
            b: torch.Tensor
            c: torch.Tensor

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x: torch.Tensor, weird: Weird, z: torch.Tensor):
                ctx.save_for_backward(weird.b, weird.c)
                return weird.b * weird.c * x.clone()

            @staticmethod
            def backward(ctx, grad):
                b, c = ctx.saved_tensors
                return grad * b * c, None, grad * 2

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x, weird, z):
            return Foo.apply(x, weird, z)

        x = torch.tensor(2.0, requires_grad=True)
        weird = Weird(1.2, torch.tensor(2.5, requires_grad=True), torch.tensor(3.5))
        z = torch.tensor(3.0, requires_grad=True)

        result = f(x, weird, z)
        result.sum().backward()

        self.assertEqual(result, Foo.apply(x, weird, z))
        self.assertEqual(x.grad, 2.5 * 3.5)
        self.assertEqual(z.grad, 2.0)
        self.assertEqual(weird.b.grad, None)

        # check Dynamo captured graph is correct!
        actual_graph = torch._dynamo.testing.normalize_gm(
            cnt.graphs[0].print_readable(print_output=False)
        )
        self.assertExpectedInline(
            actual_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[]", L_z_: "f32[]", L_weird_b: "f32[]", L_weird_c: "f32[]"):
        l_x_ = L_x_
        l_z_ = L_z_
        l_weird_b = L_weird_b
        l_weird_c = L_weird_c

        function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
        fwd_body_0 = self.fwd_body_0
        bwd_body_0 = self.bwd_body_0
        autograd_function_apply: "f32[]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None
        return (autograd_function_apply,)

    class fwd_body_0(torch.nn.Module):
        def forward(self, ctx, x: "f32[]", z: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
            mul: "f32[]" = l_weird_b * l_weird_c
            clone: "f32[]" = x.clone(); x = None
            mul_1: "f32[]" = mul * clone; mul = clone = None
            return (mul_1, [l_weird_b, l_weird_c])

    class bwd_body_0(torch.nn.Module):
        def forward(self, ctx, grad: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
            _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None

            mul: "f32[]" = grad * l_weird_b; l_weird_b = None
            mul_1: "f32[]" = mul * l_weird_c; mul = l_weird_c = None
            mul_2: "f32[]" = grad * 2; grad = None

            _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
            return (mul_1, mul_2)
""",
        )

    def test_tensor_list_as_input(self):
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, tl):
                ctx.save_for_backward(tl[0], tl[1])
                return x.clone() * (tl[0] + tl[1])

            @staticmethod
            def backward(ctx, grad):
                tl0, tl1 = ctx.saved_tensors
                return grad * (tl0 + tl1), None

        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, tl):
            return Foo.apply(x, tl)

        x = torch.tensor(2.0, requires_grad=True)
        tl = [
            torch.tensor(3.0, requires_grad=True),
            torch.tensor(4.0, requires_grad=True),
        ]

        result = f(x, tl)
        result.sum().backward()

        self.assertEqual(result, Foo.apply(x, tl))
        self.assertEqual(x.grad, 7.0)
        self.assertEqual(tl[0].grad, None)
        self.assertEqual(tl[1].grad, None)

    def test_multiple_different_non_tensor_inputs(self):
        @dataclass
        class Weird:
            x: int
            b: torch.Tensor
            c: torch.Tensor

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, weird, z, tl):
                ctx.save_for_backward(weird.b, weird.c, tl[0], tl[1])
                return x.clone() * weird.b * weird.c * tl[0]

            @staticmethod
            def backward(ctx, grad):
                b, c, tl0, _ = ctx.saved_tensors
                return grad * b * c * tl0, None, grad * 2, None

        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, weird, z, tl):
            return Foo.apply(x, weird, z, tl)

        x = torch.tensor(2.0, requires_grad=True)
        weird = Weird(
            1.2,
            torch.tensor(2.5, requires_grad=True),
            torch.tensor(3.5, requires_grad=True),
        )
        z = torch.tensor(3.0, requires_grad=True)
        tl = [
            torch.tensor(0.5, requires_grad=True),
            torch.tensor(0.6, requires_grad=True),
        ]

        result = f(x, weird, z, tl)
        result.sum().backward()

        self.assertEqual(result, Foo.apply(x, weird, z, tl))
        self.assertEqual(x.grad, 2.5 * 3.5 * 0.5)
        self.assertEqual(z.grad, 2.0)
        self.assertEqual(weird.b.grad, None)
        self.assertEqual(weird.c.grad, None)
        self.assertEqual(tl[0].grad, None)
        self.assertEqual(tl[1].grad, None)

    def test_backward_returns_none_for_tensor_input(self):
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(y)
                return x.clone() * y

            @staticmethod
            def backward(ctx, grad):
                (y,) = ctx.saved_tensors
                return grad * y, None

        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, y):
            return Foo.apply(x, y)

        x = torch.tensor(2.0, requires_grad=True)
        y = torch.tensor(3.0, requires_grad=True)

        result = f(x, y)
        result.sum().backward()

        self.assertEqual(result, Foo.apply(x, y))
        self.assertEqual(x.grad, 3.0)
        self.assertEqual(y.grad, None)

    def test_function_with_bound_free_variable(self):
        class LowerBound(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inputs, bound):
                ctx.save_for_backward(inputs, inputs.new_ones(1) * bound)
                return inputs.clamp(min=bound)

            @staticmethod
            def backward(ctx, grad_output):
                inputs, bound = ctx.saved_tensors
                return (inputs >= bound) * grad_output, None

        class MyMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.gamma = torch.nn.Parameter(torch.rand([4, 128, 32, 32]))

            def forward(self, x):
                gamma = LowerBound.apply(self.gamma, 1)
                return x + gamma

        mod = MyMod()
        args, kwargs = ([torch.rand([4, 128, 32, 32])], {})
        before = mod(*args, **kwargs)

        compiled_model = torch._dynamo.optimize("eager")(mod)
        after = compiled_model(*args, **kwargs)
        self.assertEqual(before, after)

    # I pulled all of these test cases from test_autograd.py
    # In the future, we should make the Dynamo test suite actually
    # run on test_autograd.py (it's disabled right now) and delete these.
    def test_smoke_from_test_autograd(self):
        def mult1(x):
            return x.prod(dim=-1).prod(dim=-1)

        class Mult(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = mult1(x)
                ctx.save_for_backward(x, y)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return (grad_output * y)[:, None, None] / x

        mult2 = Mult.apply

        class Double(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x**2
                ctx.save_for_backward(x, y)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                x, _ = ctx.saved_tensors
                return grad_output * 2 * x

        # this is equivalent, but uses the output of .forward() in .backward()
        class Double2(Double):
            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return grad_output * 2 * y / x

        double = Double.apply
        double2 = Double2.apply

        class Identity(torch.autograd.Function):
            @staticmethod
            def forward(ctx, a, b):
                return a, a + b

            @staticmethod
            def backward(ctx, grad_a, grad_b):
                return grad_a + grad_b, grad_b

        class MyFunc2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.clone()

            @staticmethod
            def backward(ctx, gO):
                return torch.tensor(float("nan")).expand(10, 10)

        def run_fn(a):
            out = MyFunc2.apply(a)
            return out.sum()

        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.view_as(inp)

            @staticmethod
            def backward(ctx, grad):
                return grad

        class MyAdder(torch.autograd.Function):
            @staticmethod
            def forward(ctx, a, b):
                a.add_(b)
                ctx.mark_dirty(a)
                return a

            @staticmethod
            def backward(ctx, grad):
                return grad, grad

        class InplaceMul(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                result = x.mul_(2)
                ctx.mark_dirty(result)
                return result

            @staticmethod
            def backward(ctx, grad_output):
                pass

            @staticmethod
            def jvp(ctx, x_t):
                if jvp_err:  # noqa: F821
                    return x_t
                else:
                    return x_t.mul_(2)

        class MyFn2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                return x + y, x

            @staticmethod
            def vjp(ctx, gO1, gO2):
                return gO1 + gO2, gO1

            @staticmethod
            def jvp(ctx, x_t, y_t):
                return x_t + y_t, fn(x_t)  # noqa: F821

        class MyFn3(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp, inplace):
                view = inp.clone()[:3]
                if inplace:
                    view += 2
                return view

            @staticmethod
            def backward(ctx, grad):
                return grad, None

        def test():
            x = torch.ones(2, 4, 4).requires_grad_()
            mult2(x)

            x = torch.tensor(2).double().requires_grad_()
            double(x)
            double2(x)

            x = torch.randn(5, 5, requires_grad=True)
            y = torch.randn(5, 5, requires_grad=True)
            q, p = Identity.apply(x, y)

            a = torch.rand(1, 2)
            b = torch.rand(1, requires_grad=True)
            view_a = MyFn.apply(a)

            a = torch.ones(2, requires_grad=True)
            b = torch.ones(2, requires_grad=True)
            c = MyAdder.apply(a.clone(), b)
            c.sum().backward()

            z = torch.tensor(1.0, requires_grad=True)
            x = z.clone()
            y = InplaceMul.apply(x)

            a = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
            b = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
            c = torch.tensor(1.0, dtype=torch.double)
            d = torch.tensor(1.0, dtype=torch.double)
            MyFn2.apply(a, b)
            MyFn2.apply(c, d)

            base = torch.rand(10, requires_grad=True)
            foo = MyFn3.apply(base, False)

        test()
        opt_test = torch._dynamo.optimize("eager")(test)
        opt_test()

    def test_tensor_subclass_intermediary_input(self):
        class FooTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, data, config, scale):
                self = torch.Tensor._make_wrapper_subclass(
                    cls,
                    config[0],
                    strides=config[1],
                    storage_offset=config[2],
                    dtype=config[3],
                    layout=config[4],
                    requires_grad=config[5],
                    device=data.device,
                )
                self._data = data
                self._config = config
                self._scale = scale
                return self

            def __repr__(self):
                return "FooTensor"

            def __tensor_flatten__(self):
                return ("_data",), (
                    self._config,
                    self._scale,
                )

            @staticmethod
            def __tensor_unflatten__(tensors, metadatas, outer_size, outer_stride):
                return FooTensor(tensors["_data"], metadatas[0], metadatas[1])

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs=None):
                # handling clone and view is only so that Dynamo fakefication
                # passes; it's not intended to handle user code
                if func == torch.ops.aten.clone.default:
                    return FooTensor(
                        args[0]._data.clone(), args[0]._config, args[0]._scale
                    )
                elif func == torch.ops.aten.view.default:
                    new_data = args[0]._data.view(*args[1:])
                    return FooTensor(new_data, args[0]._config, args[0]._scale)

                raise NotImplementedError

        class foo_autograd_fn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # access some data from `x`, where `x` is a tensor subclass
                x2 = x._data + 1.0
                # create and return a tensor subclass from within a torch.autograd.Function
                x3 = FooTensor(x2, x._config, x._scale)
                return x3._data

            @staticmethod
            def backward(ctx, g):
                return g

        x_ref = torch.randn(4, 4).requires_grad_(True)
        x = copy.deepcopy(x_ref)
        scale = torch.tensor(1.0)
        # Weird that this is needed, but not having this breaks a lot of things
        torch._dynamo.allow_in_graph(FooTensor)

        def foo(x, scale):
            config = (
                x.size(),
                x.stride(),
                x.storage_offset(),
                x.dtype,
                x.layout,
                x.requires_grad,
            )
            x = FooTensor(x, config, scale)
            x = foo_autograd_fn.apply(x)
            return x

        y_ref = foo(x_ref, scale)
        y_ref.sum().backward()

        foo_opt = torch.compile(foo, backend="eager")
        y = foo_opt(x, scale)
        y.sum().backward()

        self.assertEqual(y, y_ref)
        self.assertEqual(x.grad, x_ref.grad)

    def test_smuggle_symint_issue_111031(self):
        from torch.autograd import Function

        class Foo(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.x0 = x.size(0)
                return x * 2

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out * ctx.x0

        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True, dynamic=True)
        def foo(x):
            return Foo.apply(x)

        foo(torch.randn(2, requires_grad=True))
        self.assertEqual(cnts.frame_count, 1)

    def test_needs_input_grad(self):
        cnt = torch._dynamo.testing.CompileCounter()

        class NeedsInputGradFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, foo):
                result = foo + foo
                ctx.save_for_backward(result)
                return result

            @staticmethod
            @torch.compile(backend=cnt, fullgraph=True)
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                if ctx.needs_input_grad[0]:
                    return grad_output * result.sin()
                return None

        x = torch.randn(10, requires_grad=True)
        NeedsInputGradFunc.apply(x).sum().backward()
        self.assertEqual(x.grad.shape, x.shape)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

    def test_repeated_save_for_backward_calls(self):
        from torch.autograd import Function

        class Foo(Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x)
                ctx.save_for_backward(x, y)
                return x * y

            @staticmethod
            def backward(ctx, grad_out):
                x, y = ctx.saved_tensors
                return grad_out * x, grad_out * y

        cnts = torch._dynamo.testing.CompileCounter()

        def foo(x, y):
            return Foo.apply(x, y)

        x_ref = torch.randn(2, requires_grad=True)
        y_ref = torch.randn(2, requires_grad=True)
        x_test = x_ref.clone().detach().requires_grad_()
        y_test = y_ref.clone().detach().requires_grad_()

        out_ref = foo(x_ref, y_ref)
        out_ref.sum().backward()

        out_test = torch.compile(foo, backend=cnts)(x_test, y_test)
        out_test.sum().backward()

        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(out_ref, out_test)
        self.assertEqual(x_ref.grad, x_test.grad)
        self.assertEqual(y_ref.grad, y_test.grad)

    def test_smuggle_tensor_and_complex_structures(self):
        from torch.autograd import Function

        class Foo(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.x0 = x
                ctx.x1 = [1, 2, 3]
                return x * 2

            @staticmethod
            def backward(ctx, grad_out):
                x0mul = grad_out * ctx.x0
                for i in ctx.x1:
                    x0mul = (x0mul * i) + x0mul
                return x0mul

        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True, dynamic=True)
        def foo(x):
            return Foo.apply(x)

        foo(torch.randn(2, requires_grad=True))
        self.assertEqual(cnts.frame_count, 1)

    def test_mark_non_differentiable(self):
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        from torch.autograd import Function

        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x, y):
                out1 = x.sin()
                out2 = y * 2
                ctx.mark_non_differentiable(out2)
                return out1, out2

            @staticmethod
            def backward(ctx, grad1, grad2):
                return grad1.cos(), grad2 * 0.0

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x, y):
            return MyFunction.apply(x, y)

        x = torch.tensor(10.0, requires_grad=True)
        y = torch.tensor(20.0, requires_grad=True)
        ref1, ref2 = MyFunction.apply(x, y)
        res1, res2 = fn(x, y)
        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
        # Ensure out1 requires gradients, out2 does not.
        self.assertTrue(ref1.requires_grad)
        self.assertTrue(res1.requires_grad)
        self.assertFalse(ref2.requires_grad)
        self.assertFalse(res2.requires_grad)
        res1.sum().backward()

        # check Dynamo captured graph is correct!
        actual_graph = torch._dynamo.testing.normalize_gm(
            cnt.graphs[0].print_readable(print_output=False)
        )
        self.assertExpectedInline(
            actual_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[]", L_y_: "f32[]"):
        l_x_ = L_x_
        l_y_ = L_y_

        function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
        fwd_body_0 = self.fwd_body_0
        bwd_body_0 = self.bwd_body_0
        autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_y_, args_tensor_mask = [True, True], non_differentiable_idx = [1]); fwd_body_0 = bwd_body_0 = l_x_ = l_y_ = None
        getitem: "f32[]" = autograd_function_apply[0]
        getitem_1: "f32[]" = autograd_function_apply[1]; autograd_function_apply = None
        return (getitem, getitem_1)

    class fwd_body_0(torch.nn.Module):
        def forward(self, ctx, x: "f32[]", y: "f32[]"):
            out1: "f32[]" = x.sin(); x = None

            out2: "f32[]" = y * 2; y = None
            return ((out1, out2), [])

    class bwd_body_0(torch.nn.Module):
        def forward(self, ctx, grad1: "f32[]", grad2: "f32[]"):
            _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None

            cos: "f32[]" = grad1.cos(); grad1 = None
            mul: "f32[]" = grad2 * 0.0; grad2 = None

            _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
            return (cos, mul)
""",
        )

    def test_mark_multi_output_non_differentiable(self):
        from torch.autograd import Function

        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x, y, z):
                out1 = x.sin()
                out2 = y * 2
                out3 = z + 3
                ctx.mark_non_differentiable(out2, out3)
                return out1, out2, out3

            @staticmethod
            def backward(ctx, grad1, grad2, grad3):
                return grad1.cos(), grad2, grad3

        @torch.compile(backend="aot_eager", fullgraph=True)
        def fn(x, y, z):
            return MyFunction.apply(x, y, z)

        x = torch.tensor(10.0, requires_grad=True)
        y = torch.tensor(20.0, requires_grad=True)
        z = torch.tensor(30.0, requires_grad=True)
        ref1, ref2, ref3 = MyFunction.apply(x, y, z)
        res1, res2, res3 = fn(x, y, z)
        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
        self.assertEqual(ref3, res3)
        # Ensure out1 requires gradients; out2 and out3 do not.
        self.assertTrue(ref1.requires_grad)
        self.assertTrue(res1.requires_grad)
        self.assertFalse(ref2.requires_grad)
        self.assertFalse(res2.requires_grad)
        self.assertFalse(ref3.requires_grad)
        self.assertFalse(res3.requires_grad)
        res1.sum().backward()

    def test_default_values(self):
        from torch.autograd import Function

        class Foo(Function):
            @staticmethod
            def forward(ctx, x, alpha=0.99):
                return x

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out

        @torch.compile
        def foo(x):
            return Foo.apply(x)

        # Make sure guards for default values do not crash
        foo(torch.randn(2))
        foo(torch.randn(2, requires_grad=True))

    def test_tuple_arg(self):
        cnt = torch._dynamo.testing.CompileCounter()

        class TupleArgFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, shape):
                ctx.save_for_backward(torch.randn(shape))
                return x + 1

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return result, None

        @torch.compile(backend=cnt, fullgraph=True)
        def fn():
            return TupleArgFunc.apply(x, shape)

        shape = (10, 10)
        x = torch.randn(shape, requires_grad=True)
        out = fn()
        out.sum().backward()
        self.assertEqual(out, x + 1)
        self.assertEqual(x.grad.shape, shape)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

    @requires_cuda
    def test_triton_kernel_basic(self):
        class Add(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x, y)
                output = torch.zeros_like(x)
                n_elements = output.numel()
                grid = lambda meta: (  # noqa: E731
                    triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
                )
                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
                return output

            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return x * grad_output, y * grad_output

        @torch.compile(fullgraph=True, backend="inductor")
        def f(x, y):
            z = Add.apply(x, y)
            return z

        x = torch.randn(10, device="cuda", requires_grad=True)
        y = torch.randn(10, device="cuda", requires_grad=True)
        z = f(x, y)
        loss = z.sum()
        loss.backward()
        self.assertEqual(x + y, z)

    @requires_cuda
    def test_triton_kernel_multiple_out(self):
        class Add(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x, y)
                ctx.t1 = x
                ctx.t2 = y
                output = torch.zeros_like(x)
                n_elements = output.numel()
                grid = lambda meta: (  # noqa: E731
                    triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
                )
                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
                return output, x

            @staticmethod
            def backward(ctx, grad_output, old_x):
                x, y = ctx.saved_tensors
                x1 = ctx.t1
                y1 = ctx.t2
                return old_x * x * x1 * grad_output, y * y1 * grad_output

        @torch.compile(fullgraph=True, backend="inductor")
        def f(x, y):
            z = Add.apply(x, y)
            return z

        x = torch.randn(10, device="cuda", requires_grad=True)
        y = torch.randn(10, device="cuda", requires_grad=True)
        z, _ = f(x, y)
        loss = z.sum()
        loss.backward()
        self.assertEqual(x + y, z)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()