# Owner(s): ["module: dynamo"]
import enum
import functools
import pprint
import re
import unittest
import warnings

import functorch.experimental.control_flow as control_flow
import torch
import torch._dynamo.config as config
import torch._dynamo.test_case
import torch._functorch.config
import torch.nn as nn
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
    CompileCounter,
    CompileCounterWithBackend,
    EagerAndRecordGraphs,
    empty_line_normalizer,
    normalize_gm,
)
from torch._dynamo.utils import counters, ifdynstaticdefault
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._higher_order_ops.wrap import wrap
from torch.testing._internal.common_utils import (
    munge_exc,
    TEST_WITH_TORCHDYNAMO,
    xfailIfTorchDynamo,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")


def check_dynamic_shape_capture():
    # This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls`
    return not config.assume_static_by_default


def count_ops(gm, args, freq, op):
    actual = [node.target for node in gm.graph.nodes].count(op)
    assert actual == freq, f"expected={freq}, actual={actual}"
    return gm


class Obj:
    pass


class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.existing = torch.nn.Parameter(torch.ones([]))

    def forward(self, x):
        return self.existing * x


global_obj = Obj()
global_module = MyModule()
global_var = torch.randn(3)
global_num = 3.14
global_list = []


def find_first_node(gm, func):
    for node in gm.graph.nodes:
        if node.target is func:
            return node
    return None


def op_count(gm):
    result = 0
    for node in gm.graph.nodes:
        if "call" in node.op:
            result += 1
    return result


# Checks that a dict matches a dict with "regex keys". That is,
# the keys are regex expressions.
def assert_dict_matches_regex(self, dct, dct_with_regex_keys):
    regex_keys = dct_with_regex_keys.keys()
    regex_key_to_actual_key = {}
    for regex_key in regex_keys:
        for key in dct:
            if re.match(regex_key, key):
                if regex_key in regex_key_to_actual_key:
                    raise AssertionError(
                        f"Single key regex mapped to multiple keys. Please improve your "
                        f"regex. Got: regex='{regex_key}' "
                        f"keys='{regex_key_to_actual_key[regex_key]}',"
                        f"'{key}'"
                    )
                regex_key_to_actual_key[regex_key] = key
    new_dct = {}
    for regex_key in regex_keys:
        if regex_key not in regex_key_to_actual_key:
            raise AssertionError(
                f"Got regex '{regex_key}' but could not match any key in dict with "
                f"keys {dct.keys()}"
            )
        new_dct[regex_key_to_actual_key[regex_key]] = dct_with_regex_keys[regex_key]
    self.assertEqual(dct, new_dct)


def default_args_generator(seed_value):
    flat_args, args_spec = pytree.tree_flatten(seed_value)
    for i in range(3):
        new_flat_arg = []
        for val in flat_args:
            if isinstance(val, torch.Tensor):
                new_val = val + 0.1 * i
            elif isinstance(val, int):
                new_val = val + 1 * i
            elif isinstance(val, float):
                new_val = val + 0.1 * i
            elif isinstance(val, enum.Enum):
                new_val = val
            else:
                raise AssertionError("unexpected arg type")

            new_flat_arg.append(new_val)
        new_args = pytree.tree_unflatten(new_flat_arg, args_spec)
        yield new_args
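

# Illustrative usage sketch (not exercised by the tests below; the function name
# and seed values here are made up for the example). `default_args_generator`
# yields a few perturbed copies of a seed pytree so that automatic dynamic shapes
# get exercised, and `assert_dict_matches_regex` treats the keys of the expected
# dict as regexes that must each match exactly one key of the actual dict.
def _example_default_args_generator_usage():
    seed = (torch.randn(3), 2, 0.5)
    # Three variants of the seed: tensors and floats shift by 0.1 * i, ints by i.
    variants = list(default_args_generator(seed))
    assert len(variants) == 3
    return variants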


class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
    def _assert_wrap_fallback(self, func, args, setup=lambda: None):
        counters.clear()
        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        setup()
        expected = func(*args)
        setup()
        result = torch.compile(func, backend=cnt, fullgraph=False)(*args)
        num_graph_breaks = len(counters["graph_break"].keys())
        self.assertGreater(num_graph_breaks, 0)

        for gm in backend.graphs:
            for node in gm.graph.nodes:
                self.assertFalse(node.target is wrap)

        self.assertEqual(result, expected)

    def _test_wrap_simple(
        self,
        func,
        args_generator,
        expected_num_wrap_args,
        expected_opcount=2,
        return_graph=False,
    ):
        # Given a `func` that has a single call to `wrap`,
        # we check that:
        # - there are no graph breaks
        # - eager vs torch.compile has the same result (correctness)
        # - other compilation metrics, e.g., # of ops in the dynamo captured graph,
        #   the wrap has the expected number of args, etc.
        #
        # We run through each of the args from args_generator,
        # and we will check:
        # - correctness and no graph breaks for every run
        # - other compilation metrics only for the first run, since automatic_dynamic_shapes
        #   may compile another dynamic version graph for the later runs
        graph = None
        for i, args in enumerate(args_generator):
            backend = EagerAndRecordGraphs()
            cnt = CompileCounterWithBackend(backend)
            expected = func(*args)
            result = torch.compile(func, fullgraph=True, backend=cnt)(*args)
            # check correctness and no graph breaks
            self.assertEqual(result, expected)
            self.assertEqual(cnt.frame_count, 1)
            self.assertEqual(len(backend.graphs), 1)
            # check other compilation metrics
            if i == 0:
                self.assertEqual(cnt.op_count, expected_opcount)
                graph = backend.graphs[0]
                wrap_node = find_first_node(graph, wrap)
                self.assertEqual(len(wrap_node.args), expected_num_wrap_args)
        # We always return/check the graph from the first run if return_graph = True
        if return_graph:
            return normalize_gm(graph.print_readable(print_output=False))
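
    # Usage sketch (illustrative only; `f` here mirrors test_no_freevars below):
    # a function with a single `wrap` call and no free variables is checked as
    #
    #     def f(x):
    #         return wrap(lambda x: torch.sin(x), x)
    #
    #     self._test_wrap_simple(f, default_args_generator((x,)), 2)
    #
    # where 2 is the expected number of args on the captured wrap node
    # (the wrap body submodule plus the lifted tensor input).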
    def test_error_message_sane(self):
        foo = []

        def inner(x):
            foo.append(x)
            return x.clone()

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            return wrap(inner, x)

        x = torch.randn(3)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            r"HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)",
        ):
            f(x)

    def test_no_freevars(self):
        def f(x):
            return wrap(lambda x: torch.sin(x), x)

        x = torch.randn(3)
        self._test_wrap_simple(f, default_args_generator((x,)), 2)

    def test_enum_arg(self):
        class SomeEnum(enum.Enum):
            A = 0
            B = 1

        def g(x, val):
            if val == SomeEnum.A:
                return torch.sin(x)
            return torch.cos(x)

        def f(x, val):
            return wrap(g, x, val)

        x = torch.randn(3)
        self._test_wrap_simple(f, default_args_generator((x, SomeEnum.A)), 2)

    def test_return_captured_var(self):
        freevar = torch.randn(3)

        def test(x):
            return freevar

        def fn(x):
            return wrap(test, x)

        x = torch.randn(3)

        # Since `x` is unused, we don't lift it to be the input.
        self._test_wrap_simple(fn, default_args_generator((x,)), 2)

    def test_return_captured_vars(self):
        freevar1 = torch.randn(3)
        freevar2 = torch.randn(3)

        def test(x):
            return freevar1, freevar2, freevar1

        def fn(x):
            return wrap(test, x)

        x = torch.randn(3)

        # Since `x` is unused, we don't lift it to be the input.
        self._test_wrap_simple(fn, default_args_generator((x,)), 3, 4)

    def test_return_captured_var_used_multiple_times(self):
        freevar = torch.randn(3)

        def test(x):
            y = x + freevar
            return y, freevar

        def fn(x):
            return wrap(test, x)

        x = torch.randn(3)
        self._test_wrap_simple(fn, default_args_generator((x,)), 3, 3)

    def test_capture_untracked_global(self):
        def f(x):
            return wrap(lambda x: x + global_var, x)

        x = torch.randn(3)
        self._test_wrap_simple(f, default_args_generator((x,)), 3)

    def test_symint_input(self):
        def f(x):
            i = x.size(0)
            return wrap(lambda x, i: x.view(i), x, i)

        x = torch.randn(3, 1)
        self._test_wrap_simple(
            f,
            default_args_generator((x,)),
            ifdynstaticdefault(2, 3),
            expected_opcount=2,
        )

    def test_wrap_pytree_args_nested(self):
        def f(x, y, z):
            def fn(d):
                return d["x"].sin() + d["y"][0].cos() - d["y"][1][2].sin()

            return wrap(fn, d)

        x = torch.tensor(1.5)
        y = torch.tensor(2.0)
        z = torch.tensor(3.0)
        d = {"x": x, "y": (y, [x, y, z])}

        def my_args_generator(t):
            yield t
            yield t[0] + 0.1, t[1], t[2]
            yield t[0], t[1] + 0.1, t[2]

        actual_graph = self._test_wrap_simple(
            f,
            my_args_generator((x, y, z)),
            4,
            return_graph=True,
        )
        self.assertExpectedInline(
            actual_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_d_x_: "f32[]", L_d_y_0_: "f32[]", L_d_y_1_2_: "f32[]"):
        l_d_x_ = L_d_x_
        l_d_y_0_ = L_d_y_0_
        l_d_y_1_2_ = L_d_y_1_2_

        wrap_body_0 = self.wrap_body_0
        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_d_x_, l_d_y_0_, l_d_y_1_2_); wrap_body_0 = l_d_x_ = l_d_y_0_ = l_d_y_1_2_ = None
        getitem: "f32[]" = wrap[0]; wrap = None
        return (getitem,)

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_d_x_: "f32[]", l_d_y_0_: "f32[]", l_d_y_1_2_: "f32[]"):
            sin: "f32[]" = l_d_x_.sin(); l_d_x_ = None
            cos: "f32[]" = l_d_y_0_.cos(); l_d_y_0_ = None
            add: "f32[]" = sin + cos; sin = cos = None
            sin_1: "f32[]" = l_d_y_1_2_.sin(); l_d_y_1_2_ = None
            sub: "f32[]" = add - sin_1; add = sin_1 = None
            return (sub,)
""",  # NOQA: B950
        )

    def test_wrap_pytree_args_with_symint_constant(self):
        def f(x, y):
            i = x.size(0)
            return wrap(lambda t: t[0].view(t[2]) + t[1], (x, y, i))

        x = torch.randn(3, 1)
        y = 0.5
        actual_graph = self._test_wrap_simple(
            f,
            default_args_generator((x, y)),
            ifdynstaticdefault(2, 3),
            expected_opcount=2,
            return_graph=True,
        )
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(
                actual_graph,
                """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 1]"):
        l_x_ = L_x_

        wrap_body_0 = self.wrap_body_0
        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
        getitem: "f32[3]" = wrap[0]; wrap = None
        return (getitem,)

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_x_: "f32[3, 1]"):
            view: "f32[3]" = l_x_.view(3); l_x_ = None
            add: "f32[3]" = view + 0.5; view = None
            return (add,)
""",
            )
        else:
            self.assertExpectedInline(
                actual_graph,
                """\
class GraphModule(torch.nn.Module):
    def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]"):
        l_x_ = L_x_

        wrap_body_0 = self.wrap_body_0
        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, s0); wrap_body_0 = l_x_ = s0 = None
        getitem: "f32[s0]" = wrap[0]; wrap = None
        return (getitem,)

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_x_: "f32[s0, 1]", size: "Sym(s0)"):
            view: "f32[s0]" = l_x_.view(size); l_x_ = size = None
            add: "f32[s0]" = view + 0.5; view = None
            return (add,)
""",
            )

    def test_wrap_pytree_kwargs(self):
        def f(x, y, z):
            def fn(*, x, y, z):
                z1, z2 = z
                return (x * 2) + y + z1

            return wrap(fn, x=x, y=y, z=z)

        x = torch.randn(3)
        y = torch.randn(3, 3)

        def my_args_generator(t):
            yield t
            x1 = t[0] + 0.1
            y1 = t[1] + 0.1
            yield (x1, y1, (x1, y1))
            x2 = t[0] + 0.2
            y2 = t[0] + 0.2
            yield (x2, y2, (x2, y2))

        self._test_wrap_simple(f, my_args_generator((x, y, (x, y))), 3)

    def test_wrap_pytree_args_not_const_symint_tensor(self):
        class MyClass:
            def __init__(self, x):
                self.val = x

        def f(x, y):
            return wrap(lambda z: z[0].sin() * z[1].val.cos(), (x, y))

        x = torch.tensor(1.2)
        y = MyClass(torch.tensor(3.4))
        self._test_wrap_simple(f, [(x, y)], 3)

    def test_capture_constants(self):
        x = torch.randn(3, 3)
        y = 4.0

        def fn(x, y, z):
            if z:
                return x + y
            return x * y

        def f(x, y, z):
            return wrap(fn, x, y, z)

        args = (x, 4.0, None)
        opt_f = torch.compile(f, fullgraph=True, backend=CompileCounter())
        expected = f(*args)
        result = opt_f(*args)
        self.assertEqual(result, expected)

        # Ensure that we recompile here
        args = (x, 5.0, None)
        expected = f(*args)
        result = opt_f(*args)
        self.assertEqual(result, expected)

    def test_capture_untracked_global_nested(self):
        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return wrap(lambda x: wrap(lambda x: x + global_var, x), x)

        x = torch.randn(3)
        result = f(x)

        self.assertEqual(result, x + global_var)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

        self.assertEqual(len(backend.graphs), 1)
        wrap_node = find_first_node(backend.graphs[0], wrap)
        self.assertTrue(len(wrap_node.args), 3)

        body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
        self.assertEqual(op_count(body_function), 2)
        inner_wrap_node = find_first_node(body_function, wrap)
        self.assertTrue(len(inner_wrap_node.args), 3)

    def test_capture_untracked_nonlocal(self):
        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def f(x, y):
            def g(x):
                return wrap(lambda x: x + y, x)

            self._test_wrap_simple(g, default_args_generator((x,)), 3)
            return g(x)

        f(x, y)

    def test_capture_tracked(self):
        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def f(x, y):
            return wrap(lambda x: x + y, x)

        self._test_wrap_simple(f, default_args_generator((x, y)), 3)

    def test_capture_tracked_nested(self):
        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def f(x, y):
            return wrap(lambda x: wrap(lambda x: x + y, x), x)

        self._test_wrap_simple(f, default_args_generator((x, y)), 3)

    def test_inlined_functions(self):
        def g(x, y):
            return x + y

        def f(x, y):
            return wrap(lambda x: g(x, y), x)

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)
        self._test_wrap_simple(f, default_args_generator((x, y)), 3)

    def test_same_freevar_twice(self):
        free = torch.randn(3)

        def g(x):
            y = free.sin()
            z = free.cos()
            return y, z

        def f(x):
            return wrap(g, x)

        x = torch.randn(3)

        # Since `x` is unused, we don't lift it to be the input.
        self._test_wrap_simple(f, default_args_generator((x,)), 2, 3)

    def test_register_subclass(self):
        from torch._higher_order_ops.cond import cond_op
        from torch.testing._internal.two_tensor import TwoTensor

        a = torch.tensor([1.0, 0.0, 1.0])
        b = torch.randn(3)
        t = TwoTensor(a, b)
        with self.assertRaisesRegex(
            NotImplementedError,
            "no rule registered for HOP cond and subclass .*TwoTensor'>",
        ):
            res = cond_op(a.sum() > 0, torch.sin, torch.cos, (t,))

        called = 0

        # Using cond.py_impl
        @cond_op.py_impl(TwoTensor)
        def _(pred, true_fn, false_fn, operands):
            nonlocal called
            called += 1
            assert len(operands) == 1
            a = cond_op(pred, true_fn, false_fn, (operands[0].a,))
            b = cond_op(pred, true_fn, false_fn, (operands[0].b,))
            return TwoTensor(a, b)

        res = cond_op(a.sum() > 0, torch.sin, torch.cos, (t,))
        self.assertEqual(res.a, torch.sin(a))
        self.assertEqual(res.b, torch.sin(b))
        self.assertEqual(called, 1)

    def test_register_mode(self):
        from torch._higher_order_ops.cond import cond_op

        torch_dispatch_called = 0

        class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                nonlocal torch_dispatch_called
                torch_dispatch_called += 1
                return func(*args, **kwargs)

        a = torch.tensor([1.0, 0.1, 1.0])
        pred = a.sum() > 0
        with self.assertRaisesRegex(
            NotImplementedError,
            "no rule registered for HOP cond and mode .*MyMode",
        ):
            with MyMode():
                res = cond_op(pred, torch.sin, torch.cos, (a,))

        py_impl_called = 0

        # Using cond.py_impl
        @cond_op.py_impl(MyMode)
        def _(mode, pred, true_fn, false_fn, operands):
            nonlocal py_impl_called
            py_impl_called += 1
            return cond_op(pred, true_fn, false_fn, operands)

        a = torch.tensor([1.0, 0.1, 1.0])
        pred = a.sum() > 0
        with MyMode():
            res = cond_op(pred, torch.sin, torch.cos, (a,))
        self.assertEqual(res, a.sin())

    def test_capture_value_created_in_subgraph(self):
        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def inner(x, y):
            z = x + y
            return wrap(lambda x: wrap(lambda x: x + z, x), x)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x, y):
            return wrap(inner, x, y)

        result = f(x, y)

        self.assertEqual(result, x + y + x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)
        self.assertEqual(len(backend.graphs), 1)

        # No changes to args of outer wrap
        gm = backend.graphs[0]
        wrap_node = find_first_node(gm, wrap)
        self.assertTrue(len(wrap_node.args), 3)

        # z was lifted to arg of inner wrap
        body_function = getattr(gm, wrap_node.args[0].name)
        # addition + wrap + getitem
        self.assertEqual(op_count(body_function), 3)
        inner_wrap_node = find_first_node(body_function, wrap)
        self.assertTrue(len(inner_wrap_node.args), 3)

        # Innermost body function: z was also lifted to arg
        body_function = getattr(body_function, inner_wrap_node.args[0].name)
        self.assertEqual(op_count(body_function), 2)
        inner_wrap_node = find_first_node(body_function, wrap)
        self.assertTrue(len(inner_wrap_node.args), 3)

    def test_side_effect_set_new_attr_global_obj(self):
        def setup():
            global global_obj
            global_obj = Obj()

        def f(x):
            def h(x):
                def g(x):
                    global_obj.foo = x + 1
                    return x.clone()

                y = wrap(g, x)
                return y + global_obj.foo

            return h(x)

        x = torch.zeros([])
        self._assert_wrap_fallback(f, (x,), setup=setup)

    def test_side_effect_set_existing_attr_global_obj(self):
        def setup():
            global global_obj
            global_obj = Obj()
            global_obj.foo = nn.Parameter(torch.tensor(4.0))

        def f(x):
            def h(x):
                def g(x):
                    global_obj.foo = x + 1
                    return x.clone()

                y = wrap(g, x)
                return y + global_obj.foo

            return h(x)

        x = torch.zeros([])
        self._assert_wrap_fallback(f, (x,), setup=setup)

    def test_side_effect_del_existing_attr_global_obj(self):
        def setup():
            global global_obj
            global_obj = Obj()
            global_obj.foo = torch.tensor(4.0)

        def f(x):
            def h(x):
                def g(x):
                    del global_obj.foo
                    return x.clone()

                y = wrap(g, x)
                return y

            return h(x)

        x = torch.zeros([])
        self._assert_wrap_fallback(f, (x,), setup=setup)

    def test_side_effect_set_new_attr_global_module(self):
        def setup():
            global global_module
            global_module = MyModule()

        def h(x):
            def g(x):
                global_module.foo = nn.Parameter(x + 1)
                return x.clone()

            y = wrap(g, x)
            return y + global_module.foo

        x = torch.zeros([])
        self._assert_wrap_fallback(h, (x,), setup=setup)

    def test_side_effect_set_existing_attr_global_module(self):
        def setup():
            global global_module
            global_module = MyModule()

        def h(x):
            def g(x):
                global_module.existing = nn.Parameter(torch.tensor(4.0))
                return global_module(x)

            y = wrap(g, x)
            return y

        x = torch.zeros([])
        self._assert_wrap_fallback(h, (x,), setup=setup)

    def test_side_effect_del_existing_attr_global_module(self):
        def setup():
            global global_module
            global_module = MyModule()

        def h(x):
            def g(x):
                del global_module.existing
                return x.clone()

            y = wrap(g, x)
return y 750 751 x = torch.zeros([]) 752 self._assert_wrap_fallback(h, (x,), setup=setup) 753 754 def test_side_effect_mutate_global_num(self): 755 def setup(): 756 global global_num 757 global_num = 3.14 758 759 def f(x): 760 def g(x): 761 global global_num 762 global_num = global_num + 1 763 return x + global_num 764 765 y = wrap(g, x) 766 return y + global_num 767 768 x = torch.zeros([]) 769 self._assert_wrap_fallback(f, (x,), setup=setup) 770 771 def test_side_effect_mutate_global_num_builtin(self): 772 def setup(): 773 global global_num 774 global_num = 3.14 775 776 def f(x): 777 def g(x): 778 global global_num 779 global_num += 1 780 return x + global_num 781 782 y = wrap(g, x) 783 return y + global_num 784 785 x = torch.zeros([]) 786 self._assert_wrap_fallback(f, (x,), setup=setup) 787 788 def test_side_effect_mutate_global_tensor(self): 789 def setup(): 790 global global_var 791 global_var = torch.ones(3) 792 793 def f(x): 794 def g(x): 795 global global_var 796 global_var = global_var + 1 797 return x + global_var 798 799 y = wrap(g, x) 800 return y + global_var 801 802 x = torch.zeros([]) 803 self._assert_wrap_fallback(f, (x,), setup=setup) 804 805 def test_side_effect_mutate_global_tensor_builtin(self): 806 def setup(): 807 global global_var 808 global_var = torch.ones(3) 809 810 def f(x): 811 def g(x): 812 global global_var 813 global_var += 1 814 return x + global_var 815 816 y = wrap(g, x) 817 return y + global_var 818 819 x = torch.zeros([]) 820 self._assert_wrap_fallback(f, (x,), setup=setup) 821 822 def test_side_effect_mutate_global_list(self): 823 def setup(): 824 global global_list 825 global_list = [] 826 827 def f(x): 828 def g(x): 829 val = x + 1 830 global_list.append(val) 831 return global_list[-1] 832 833 y = wrap(g, x) 834 z = y + global_list[-1] 835 return z 836 837 x = torch.zeros([]) 838 self._assert_wrap_fallback(f, (x,), setup=setup) 839 840 def test_side_effect_mutate_nonlocal_num(self): 841 def f(x): 842 def h(x): 843 val = 1 844 845 def g(x): 846 nonlocal val 847 val = val + 1 848 return x + val 849 850 y = wrap(g, x) 851 z = y + val 852 return z 853 854 return h(x) 855 856 x = torch.zeros([]) 857 self._assert_wrap_fallback(f, (x,)) 858 859 def test_side_effect_set_new_attr_nonlocal_obj(self): 860 def f(x): 861 def h(x): 862 obj = Obj() 863 864 def g(x): 865 obj.val = x.dim() 866 return x.clone() 867 868 y = wrap(g, x) 869 z = y + obj.val 870 return z 871 872 return h(x) 873 874 x = torch.zeros([]) 875 self._assert_wrap_fallback(f, (x,)) 876 877 def test_side_effect_set_existing_attr_nonlocal_obj(self): 878 def f(x): 879 def h(x): 880 obj = Obj() 881 obj.val = 3 882 883 def g(x): 884 obj.val = x.dim() 885 return x.clone() 886 887 y = wrap(g, x) 888 z = y + obj.val 889 return z 890 891 return h(x) 892 893 x = torch.zeros([]) 894 self._assert_wrap_fallback(f, (x,)) 895 896 def test_side_effect_del_existing_attr_nonlocal_obj(self): 897 def f(x): 898 def h(x): 899 obj = Obj() 900 obj.val = 3 901 902 def g(x): 903 del obj.val 904 return x.clone() 905 906 y = wrap(g, x) 907 return y 908 909 return h(x) 910 911 x = torch.zeros([]) 912 self._assert_wrap_fallback(f, (x,)) 913 914 def test_side_effect_set_new_attr_nonlocal_module(self): 915 def h(x): 916 obj = MyModule() 917 918 def g(x): 919 obj.val = x.dim() 920 return x.clone() 921 922 y = wrap(g, x) 923 z = y + obj.val 924 return z 925 926 x = torch.zeros([]) 927 self._assert_wrap_fallback(h, (x,)) 928 929 def test_side_effect_set_existing_attr_nonlocal_module(self): 930 def h(x): 931 obj = MyModule() 932 933 
def g(x): 934 obj.existing = nn.Parameter(torch.tensor(3.14)) 935 return obj(x) 936 937 y = wrap(g, x) 938 return y 939 940 x = torch.zeros([]) 941 self._assert_wrap_fallback(h, (x,)) 942 943 def test_side_effect_del_existing_attr_nonlocal_module(self): 944 def h(x): 945 obj = MyModule() 946 947 def g(x): 948 del obj.existing 949 return x.clone() 950 951 y = wrap(g, x) 952 return y 953 954 x = torch.zeros([]) 955 self._assert_wrap_fallback(h, (x,)) 956 957 def test_side_effect_mutate_nonlocal_tensor(self): 958 def f(x): 959 def h(x): 960 val = torch.tensor(1.0) 961 962 def g(x): 963 nonlocal val 964 val = val + 1 965 return x + val 966 967 y = wrap(g, x) 968 z = y + val 969 return z 970 971 return h(x) 972 973 x = torch.zeros([]) 974 self._assert_wrap_fallback(f, (x,)) 975 976 def test_side_effect_mutate_nonlocal_num_builtin(self): 977 def f(x): 978 def h(x): 979 val = 1 980 981 def g(x): 982 nonlocal val 983 val += 1 984 return x + val 985 986 y = wrap(g, x) 987 z = y + val 988 return z 989 990 return h(x) 991 992 x = torch.zeros([]) 993 self._assert_wrap_fallback(f, (x,)) 994 995 def test_side_effect_mutate_nonlocal_tensor_builtin(self): 996 def f(x): 997 def h(x): 998 val = torch.tensor(1.0) 999 1000 def g(x): 1001 nonlocal val 1002 val += 1 1003 return x + val 1004 1005 y = wrap(g, x) 1006 z = y + val 1007 return z 1008 1009 return h(x) 1010 1011 x = torch.zeros([]) 1012 self._assert_wrap_fallback(f, (x,)) 1013 1014 def test_side_effect_nonlocal_list_append_graph_break(self): 1015 def g(x): 1016 y = [] 1017 1018 def f(k): 1019 m = k + 1 1020 y.append(m) 1021 return k 1022 1023 wrap(f, x) 1024 return y[0] 1025 1026 x = torch.randn(3, 3) 1027 self._assert_wrap_fallback(g, (x,)) 1028 1029 def test_side_effect_nested_nonlocal_list_append_graph_break(self): 1030 def g(x): 1031 def h(x): 1032 y = [] 1033 1034 def f(k): 1035 m = k + 1 1036 y.append(m) 1037 return k 1038 1039 wrap(f, x) 1040 return y[0] 1041 1042 return h(x) 1043 1044 x = torch.randn(3, 3) 1045 self._assert_wrap_fallback(g, (x,)) 1046 1047 def test_side_effect_local_list_append_no_graph_break(self): 1048 def g(x): 1049 def f(k): 1050 y = [] 1051 y.append(k + 1) 1052 return y[0] 1053 1054 return wrap(f, x) 1055 1056 x = torch.randn(3, 3) 1057 self._test_wrap_simple(g, default_args_generator((x,)), 2) 1058 1059 def test_wrap_kwarg(self): 1060 def f(x, y): 1061 return wrap(lambda x, y: x + y, x, y=y) 1062 1063 x = torch.randn(3) 1064 y = torch.randn(3, 3) 1065 self._test_wrap_simple(f, default_args_generator((x, y)), 3) 1066 1067 def test_wrap_kwarg_int(self): 1068 def f(x, y): 1069 return wrap(lambda x, y: x + y, x, y=y) 1070 1071 x = torch.randn(3) 1072 y = 8 1073 1074 self._test_wrap_simple( 1075 f, default_args_generator((x, y)), ifdynstaticdefault(2, 3) 1076 ) 1077 1078 def test_wrap_all_kwarg(self): 1079 def f(y, x): 1080 return wrap(lambda x, y: (x * 2) + y, x=x, y=y) 1081 1082 x = torch.randn(3) 1083 y = torch.randn(3, 3) 1084 1085 self._test_wrap_simple(f, default_args_generator((x, y)), 3) 1086 1087 def test_wrap_kwarg_only(self): 1088 def f(x, y): 1089 def fn(*, x, y): 1090 return (x * 2) + y 1091 1092 return wrap(fn, x=x, y=y) 1093 1094 x = torch.randn(3) 1095 y = torch.randn(3, 3) 1096 1097 self._test_wrap_simple(f, default_args_generator((x, y)), 3) 1098 1099 def test_wrap_kwarg_default(self): 1100 def f(x, y): 1101 def fn(*, x, y, z=8): 1102 return (x * 2) + y + z 1103 1104 return wrap(fn, x=x, y=y) 1105 1106 x = torch.randn(3) 1107 y = torch.randn(3, 3) 1108 1109 self._test_wrap_simple(f, 
default_args_generator((x, y)), 3) 1110 1111 def test_wrap_kwarg_default_if_branch(self): 1112 def f(x, y): 1113 def fn(*, x, y, z=None): 1114 if z is None: 1115 return (x * 2) + y 1116 else: 1117 return 2 * x 1118 1119 return wrap(fn, x=x, y=y) 1120 1121 x = torch.randn(3) 1122 y = torch.randn(3, 3) 1123 1124 self._test_wrap_simple(f, default_args_generator((x, y)), 3) 1125 1126 def test_wrap_kwarg_recompile(self): 1127 def f(x, y, z=None): 1128 def fn(*, x, y, z=None): 1129 if z is None: 1130 return (x * 2) + y 1131 else: 1132 return 2 * x 1133 1134 return wrap(fn, x=x, y=y, z=z) 1135 1136 x = torch.randn(3) 1137 y = torch.randn(3, 3) 1138 1139 counters.clear() 1140 opt = torch.compile(f, backend="eager", fullgraph=True) 1141 opt(x, y) 1142 self.assertEqual(counters["stats"]["calls_captured"], 2) 1143 1144 # verify that we `don't` recompile 1145 opt(x, y) 1146 self.assertEqual(counters["stats"]["calls_captured"], 2) 1147 1148 output = opt(x, y, 8) 1149 self.assertEqual(counters["stats"]["calls_captured"], 4) 1150 self.assertEqual(output, 2 * x) 1151 1152 def test_wrap_kwarg_default_else_branch(self): 1153 def f(x, y, z): 1154 def fn(*, x, y, z=None): 1155 if z is None: 1156 return (x * 2) + y 1157 else: 1158 return 2 * x 1159 1160 return wrap(fn, x=x, y=y, z=z) 1161 1162 x = torch.randn(3) 1163 y = torch.randn(3, 3) 1164 1165 self._test_wrap_simple(f, default_args_generator((x, y, 8)), 2) 1166 1167 def test_map_subgraph_name_is_valid(self): 1168 backend = EagerAndRecordGraphs() 1169 cnt = CompileCounterWithBackend(backend) 1170 1171 xs = torch.randn(2, 3, 3) 1172 y = torch.randn(3) 1173 1174 def map_f(xs, y): 1175 def inner(x, y): 1176 def inner2(x, y): 1177 return x + y 1178 1179 return control_flow.map(inner2, x, y) 1180 1181 return control_flow.map(inner, xs, y) 1182 1183 graphs = self._check_map_graph_and_extract(map_f, (xs, y)) 1184 if graphs: 1185 graph, body_graph = graphs 1186 self.assertExpectedInline( 1187 graph, 1188 """\ 1189def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor): 1190 l_xs_ = L_xs_ 1191 l_y_ = L_y_ 1192 map_body_1 = self.map_body_1 1193 map_impl = torch.ops.higher_order.map_impl(map_body_1, [l_xs_], [l_y_]); map_body_1 = l_xs_ = l_y_ = None 1194 getitem_1 = map_impl[0]; map_impl = None 1195 return (getitem_1,)""", 1196 ) 1197 self.assertExpectedInline( 1198 body_graph, 1199 """\ 1200def forward(self, child, l_y_): 1201 child_1 = child[0]; child_1 = None 1202 map_body_0 = self.map_body_0 1203 map_impl = torch.ops.higher_order.map_impl(map_body_0, [child], [l_y_]); map_body_0 = child = l_y_ = None 1204 getitem_1 = map_impl[0]; map_impl = None 1205 return (getitem_1,)""", 1206 ) 1207 1208 def test_map_multi_return(self): 1209 cnt = CompileCounter() 1210 1211 def f(x): 1212 return control_flow.map(lambda x: (x.sin(), x.sin()), x) 1213 1214 x = torch.randn(3) 1215 graphs = self._check_map_graph_and_extract(f, (x,)) 1216 if graphs: 1217 graph, body_graph = graphs 1218 self.assertExpectedInline( 1219 graph, 1220 """\ 1221def forward(self, L_x_ : torch.Tensor): 1222 l_x_ = L_x_ 1223 map_body_0 = self.map_body_0 1224 map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None 1225 getitem_1 = map_impl[0] 1226 getitem_2 = map_impl[1]; map_impl = None 1227 return (getitem_1, getitem_2)""", 1228 ) 1229 self.assertExpectedInline( 1230 body_graph, 1231 """\ 1232def forward(self, child): 1233 child_1 = child.sin() 1234 child_2 = child.sin(); child = None 1235 return (child_1, child_2)""", 1236 ) 1237 1238 def 
test_map_pytree_return(self): 1239 cnt = CompileCounter() 1240 1241 def _construct_pytree(a): 1242 return (a, [[[a]]], a, (a, (a,), a), {"a": a}) 1243 1244 def f(x): 1245 def inner_f(xs): 1246 return _construct_pytree(xs) 1247 1248 return control_flow.map(inner_f, x) 1249 1250 x = torch.randn(3) 1251 graphs = self._check_map_graph_and_extract(f, (x,)) 1252 if graphs: 1253 graph, body_graph = graphs 1254 self.assertExpectedInline( 1255 graph, 1256 """\ 1257def forward(self, L_x_ : torch.Tensor): 1258 l_x_ = L_x_ 1259 map_body_0 = self.map_body_0 1260 map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None 1261 getitem_1 = map_impl[0] 1262 getitem_2 = map_impl[1] 1263 getitem_3 = map_impl[2] 1264 getitem_4 = map_impl[3] 1265 getitem_5 = map_impl[4] 1266 getitem_6 = map_impl[5] 1267 getitem_7 = map_impl[6]; map_impl = None 1268 return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, getitem_7)""", 1269 ) 1270 self.assertExpectedInline( 1271 body_graph, 1272 """\ 1273def forward(self, child): 1274 return (child, child, child, child, child, child, child)""", 1275 ) 1276 1277 def test_map_kwargs(self): 1278 cnt = CompileCounter() 1279 1280 @torch.compile(backend=cnt) 1281 def f(x): 1282 return control_flow.map(lambda x: x.sin(), x=x) 1283 1284 x = torch.randn(3) 1285 self.assertRaises(TypeError, lambda: f(x)) 1286 self.assertEqual(cnt.frame_count, 0) 1287 1288 def test_map_symint_input(self): 1289 backend = EagerAndRecordGraphs() 1290 cnt = CompileCounterWithBackend(backend) 1291 1292 def fn(x, y): 1293 def inner(x, y): 1294 return torch.sin(x + y) 1295 1296 return control_flow.map(inner, x, y.size(0)) 1297 1298 x = torch.randn(3, 1) 1299 y = torch.randn(3, 1) 1300 graphs = self._check_map_graph_and_extract(fn, (x, y)) 1301 if graphs: 1302 graph, body_graph = graphs 1303 self.assertExpectedInline( 1304 graph, 1305 """\ 1306def forward(self, L_x_ : torch.Tensor): 1307 l_x_ = L_x_ 1308 map_body_0 = self.map_body_0 1309 map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None 1310 getitem_1 = map_impl[0]; map_impl = None 1311 return (getitem_1,)""", 1312 ) 1313 self.assertExpectedInline( 1314 body_graph, 1315 """\ 1316def forward(self, child, const_unused): 1317 add = child + 3; child = None 1318 sin = torch.sin(add); add = None 1319 return (sin,)""", 1320 ) 1321 1322 def test_map_lowers_to_graph(self): 1323 backend = EagerAndRecordGraphs() 1324 cnt = CompileCounterWithBackend(backend) 1325 1326 def fn(x, y): 1327 def inner(x, y): 1328 return torch.sin(x + y) 1329 1330 return control_flow.map(inner, x, y.size(0)) 1331 1332 x = torch.randn(3, 1) 1333 y = torch.randn(3, 1) 1334 graphs = self._check_map_graph_and_extract(fn, (x, y)) 1335 if graphs: 1336 graph, body_graph = graphs 1337 self.assertExpectedInline( 1338 graph, 1339 """\ 1340def forward(self, L_x_ : torch.Tensor): 1341 l_x_ = L_x_ 1342 map_body_0 = self.map_body_0 1343 map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None 1344 getitem_1 = map_impl[0]; map_impl = None 1345 return (getitem_1,)""", 1346 ) 1347 self.assertExpectedInline( 1348 body_graph, 1349 """\ 1350def forward(self, child, const_unused): 1351 add = child + 3; child = None 1352 sin = torch.sin(add); add = None 1353 return (sin,)""", 1354 ) 1355 1356 def test_map_example_value_metadata_consistent_with_eager(self): 1357 from torch._higher_order_ops.map import map_dense 1358 1359 backend = EagerAndRecordGraphs() 1360 1361 def inner(x): 1362 return 
x.sin(), x.cos().T, x.sin().view(-1) 1363 1364 rand_44 = torch.randn(4, 4) 1365 inps = [ 1366 torch.randn(3), 1367 torch.randn(3, 4), 1368 torch.randn(3, 4, 5, requires_grad=True), 1369 torch.randn(3, 4, 5, requires_grad=True).permute((2, 0, 1)), 1370 torch.randn(3, 4, 5, requires_grad=True).detach(), 1371 torch.randn(3, 4, 5, requires_grad=True).narrow(1, 1, 2), 1372 rand_44.T, 1373 rand_44[::2], 1374 rand_44[::2, ::2], 1375 rand_44[1::3, 1::3], 1376 rand_44[1::3, 1::2].T, 1377 rand_44.unsqueeze(1), 1378 rand_44.squeeze(0), 1379 rand_44.reshape(2, 8), 1380 ] 1381 for x in inps: 1382 compiled_ret = torch.compile( 1383 control_flow.map, backend=backend, fullgraph=True 1384 )(inner, x) 1385 eager_sin, eager_transpose, eager_view = map_dense(inner, (x,), ()) 1386 1387 map_node = next( 1388 node 1389 for node in backend.graphs[0].graph.nodes 1390 if node.op == "call_function" and "map" in node.name 1391 ) 1392 1393 fake_sin, fake_transpose, fake_view = map_node.meta["example_value"] 1394 1395 def _check_size_stride_contiguous(x, y): 1396 self.assertEqual(y.size(), x.size()) 1397 self.assertEqual(y.stride(), x.stride()) 1398 self.assertEqual(y.requires_grad, x.requires_grad) 1399 self.assertEqual(x.is_contiguous(), True) 1400 self.assertEqual(y.is_contiguous(), True) 1401 1402 _check_size_stride_contiguous(eager_sin, fake_sin) 1403 _check_size_stride_contiguous(eager_transpose, fake_transpose) 1404 _check_size_stride_contiguous(eager_view, fake_view) 1405 1406 torch._dynamo.reset() 1407 backend.graphs.clear() 1408 1409 def test_cond_subgraph_name_is_valid(self): 1410 backend = EagerAndRecordGraphs() 1411 cnt = CompileCounterWithBackend(backend) 1412 1413 pred = torch.tensor(True) 1414 pred2 = torch.tensor(False) 1415 xs = torch.randn(2, 3, 3) 1416 y = torch.randn(3, 3) 1417 1418 @torch.compile(backend=cnt, fullgraph=True) 1419 def cond_f(pred, pred2, x, y): 1420 def true_fn(pred2, x, y): 1421 return x + y 1422 1423 def false_fn(pred2, x, y): 1424 def true_fn2(x, y): 1425 return x.sin() - y.cos() 1426 1427 def false_fn2(x, y): 1428 return x.cos() - y.sin() 1429 1430 return control_flow.cond(pred2, true_fn2, false_fn2, [x, y]) 1431 1432 return control_flow.cond(pred, true_fn, false_fn, [pred2, x, y]) 1433 1434 result = cond_f(pred, pred2, xs, y) 1435 self.assertEqual(result, xs + y) 1436 1437 cond_gm = backend.graphs[0] 1438 name_set = set() 1439 name_set.update(name for name, _ in cond_gm.named_modules()) 1440 self.assertEqual( 1441 name_set, 1442 { 1443 "", 1444 "cond_true_1", 1445 "cond_false_1", 1446 "cond_false_1.cond_false_0", 1447 "cond_false_1.cond_true_0", 1448 }, 1449 ) 1450 1451 @torch._dynamo.config.patch( 1452 assume_static_by_default=True, 1453 dynamic_shapes=True, 1454 ) 1455 def test_cond_graph_break_in_one_branch(self): 1456 backend = EagerAndRecordGraphs() 1457 cnt = CompileCounterWithBackend(backend) 1458 1459 class Foo(torch.nn.Module): 1460 def __init__(self) -> None: 1461 super().__init__() 1462 self.buffer = torch.nn.Buffer(torch.ones(6, 4)) 1463 1464 def forward(self, x): 1465 def true_fn(x): 1466 self.buffer += 1 1467 return self.buffer.sum() + x.sum() 1468 1469 def false_fn(x): 1470 return (x - 1).sum() 1471 1472 return control_flow.cond(x.sum() > 4, true_fn, false_fn, [x]) 1473 1474 mod_for_compile = torch.compile(Foo(), backend=cnt, dynamic=True) 1475 mod_for_eager = Foo() 1476 1477 with self.assertRaisesRegex( 1478 torch._dynamo.exc.UncapturedHigherOrderOpError, 1479 r"Cond doesn't work unless it is captured completely with torch.compile", 1480 ): 1481 
mod_for_eager(torch.ones(6, 4)) 1482 1483 with self.assertRaisesRegex( 1484 torch._dynamo.exc.UncapturedHigherOrderOpError, 1485 r"Cond doesn't work unless it is captured completely with torch.compile", 1486 ): 1487 mod_for_compile(torch.ones(3, 4)) 1488 1489 def test_cond_free_variable_in_both_branches(self): 1490 backend = EagerAndRecordGraphs() 1491 cnt = CompileCounterWithBackend(backend) 1492 1493 z = torch.ones(4, 4) 1494 1495 class Foo(torch.nn.Module): 1496 def __init__(self) -> None: 1497 super().__init__() 1498 self.buffer = torch.nn.Buffer(torch.ones(6, 4)) 1499 1500 def forward(self, x, y): 1501 def true_fn(x): 1502 return x.sum() + self.buffer.sum() + z.sum() 1503 1504 def false_fn(x): 1505 return x.sum() - z.sum() - self.buffer.sum() 1506 1507 return control_flow.cond(y, true_fn, false_fn, [x]) 1508 1509 mod_for_compile = torch.compile( 1510 Foo(), backend=cnt, dynamic=True, fullgraph=True 1511 ) 1512 mod_for_eager = Foo() 1513 1514 self.assertEqual( 1515 mod_for_compile(torch.tensor(True), torch.tensor(5)), 1516 mod_for_eager(torch.tensor(True), torch.tensor(5)), 1517 ) 1518 1519 for node in backend.graphs[0].graph.nodes: 1520 if ( 1521 node.op == "call_function" 1522 and node.target == torch.ops.higher_order.cond 1523 ): 1524 _, _, _, operands = node.args 1525 # Each branch takes 3 inputs (buffer, x, z) 1526 self.assertEqual(len(operands), 3) 1527 if node.op == "get_attr": 1528 if str(node.target) in ("cond_true_0, cond_false_0"): 1529 num_placeholders = len( 1530 [ 1531 node 1532 for node in getattr( 1533 backend.graphs[0], str(node.target) 1534 ).graph.nodes 1535 if node.op == "placeholder" 1536 ] 1537 ) 1538 self.assertEqual(num_placeholders, 3) 1539 1540 def _check_cond_graph_and_extract(self, fn, args): 1541 backend = EagerAndRecordGraphs() 1542 cnt = CompileCounterWithBackend(backend) 1543 out = torch.compile(fn, backend=cnt, fullgraph=True)(*args) 1544 self.assertEqual(out, fn(*args)) 1545 self.assertEqual(cnt.frame_count, 1) 1546 self.assertEqual(len(backend.graphs), 1) 1547 1548 # Dynamic shapes produce a slightly different graph. 1549 if check_dynamic_shape_capture(): 1550 return 1551 1552 gm = backend.graphs[0] 1553 graph = gm.code.strip() 1554 true_graph = gm.cond_true_0.code.strip() 1555 false_graph = gm.cond_false_0.code.strip() 1556 return (graph, true_graph, false_graph) 1557 1558 def _check_map_graph_and_extract(self, fn, args): 1559 backend = EagerAndRecordGraphs() 1560 cnt = CompileCounterWithBackend(backend) 1561 out = torch.compile(fn, backend=cnt, fullgraph=True)(*args) 1562 self.assertEqual(out, fn(*args)) 1563 self.assertEqual(cnt.frame_count, 1) 1564 self.assertEqual(len(backend.graphs), 1) 1565 1566 # Dynamic shapes produce a slightly different graph. 
1567 if check_dynamic_shape_capture(): 1568 return 1569 1570 gm = backend.graphs[0] 1571 graph = gm.code.strip() 1572 subgraphs = [] 1573 for module_name in gm._modules.keys(): 1574 subgraphs.append(getattr(gm, module_name).code.strip()) 1575 return (graph, *subgraphs) 1576 1577 def test_cond_branches_no_arguments(self): 1578 def fn(x): 1579 def true_fn(): 1580 return torch.sin(x) 1581 1582 def false_fn(): 1583 return torch.cos(x) 1584 1585 return control_flow.cond(x.sum() > 0, true_fn, false_fn, ()) 1586 1587 graphs = self._check_cond_graph_and_extract(fn, (torch.randn(4, 5),)) 1588 if graphs is not None: 1589 graph, true_graph, false_graph = graphs 1590 self.assertExpectedInline( 1591 graph, 1592 """\ 1593def forward(self, L_x_ : torch.Tensor): 1594 l_x_ = L_x_ 1595 sum_1 = l_x_.sum() 1596 gt = sum_1 > 0; sum_1 = None 1597 cond_true_0 = self.cond_true_0 1598 cond_false_0 = self.cond_false_0 1599 cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, [l_x_]); gt = cond_true_0 = cond_false_0 = l_x_ = None 1600 getitem = cond[0]; cond = None 1601 return (getitem,)""", 1602 ) 1603 self.assertExpectedInline( 1604 true_graph, 1605 """\ 1606def forward(self, l_x_): 1607 l_x__1 = l_x_ 1608 sin = torch.sin(l_x__1); l_x__1 = None 1609 return (sin,)""", 1610 ) 1611 self.assertExpectedInline( 1612 false_graph, 1613 """\ 1614def forward(self, l_x_): 1615 l_x__1 = l_x_ 1616 cos = torch.cos(l_x__1); l_x__1 = None 1617 return (cos,)""", 1618 ) 1619 1620 def test_cond_branches_no_arguments_no_closure(self): 1621 def fn(x): 1622 def true_fn(): 1623 return torch.ones(3, 4) 1624 1625 def false_fn(): 1626 return torch.ones(3, 4).sin() 1627 1628 return control_flow.cond(x.sum() > 0, true_fn, false_fn, ()) 1629 1630 self._check_cond_graph_and_extract(fn, (torch.randn(4, 5),)) 1631 graphs = self._check_cond_graph_and_extract(fn, (torch.randn(4, 5),)) 1632 if graphs is not None: 1633 graph, true_graph, false_graph = graphs 1634 self.assertExpectedInline( 1635 graph, 1636 """\ 1637def forward(self, L_x_ : torch.Tensor): 1638 l_x_ = L_x_ 1639 sum_1 = l_x_.sum(); l_x_ = None 1640 gt = sum_1 > 0; sum_1 = None 1641 cond_true_0 = self.cond_true_0 1642 cond_false_0 = self.cond_false_0 1643 cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, []); gt = cond_true_0 = cond_false_0 = None 1644 getitem = cond[0]; cond = None 1645 return (getitem,)""", 1646 ) 1647 self.assertExpectedInline( 1648 true_graph, 1649 """\ 1650def forward(self): 1651 ones = torch.ones(3, 4) 1652 return (ones,)""", 1653 ) 1654 self.assertExpectedInline( 1655 false_graph, 1656 """\ 1657def forward(self): 1658 ones = torch.ones(3, 4) 1659 sin = ones.sin(); ones = None 1660 return (sin,)""", 1661 ) 1662 1663 def test_cond_side_effect_in_one_branches(self): 1664 backend = EagerAndRecordGraphs() 1665 cnt = CompileCounterWithBackend(backend) 1666 1667 z = [torch.ones(4, 4)] 1668 1669 class Foo(torch.nn.Module): 1670 def __init__(self) -> None: 1671 super().__init__() 1672 1673 def forward(self, y, x): 1674 def true_fn(x): 1675 z.append(x) 1676 z.append(x) 1677 z.pop() 1678 return x.sum() + z[-1].sum() 1679 1680 def false_fn(x): 1681 return x.sum() - z[0].sum() 1682 1683 return control_flow.cond(y, true_fn, false_fn, [x]) 1684 1685 mod_for_eager = Foo() 1686 mod_for_compile = torch.compile( 1687 Foo(), backend=cnt, dynamic=True, fullgraph=False 1688 ) 1689 with self.assertRaisesRegex( 1690 torch._dynamo.exc.UncapturedHigherOrderOpError, 1691 r"Cond doesn't work unless it is captured completely with torch.compile", 1692 ): 1693 
mod_for_eager(torch.tensor(True), torch.tensor(5)) 1694 1695 with self.assertRaisesRegex( 1696 torch._dynamo.exc.UncapturedHigherOrderOpError, 1697 r"Cond doesn't work unless it is captured completely with torch.compile", 1698 ): 1699 mod_for_compile(torch.tensor(True), torch.tensor(5)) 1700 1701 def test_cond_with_constant_pred(self): 1702 def test(pred, x): 1703 def true_fn(x): 1704 return x 1705 1706 def false_fn(x): 1707 return -x 1708 1709 return control_flow.cond(pred, true_fn, false_fn, [x]) 1710 1711 opt_test = torch.compile(test, backend="eager") 1712 inp = torch.ones(3, 3) 1713 self.assertTrue(torch.allclose(test(True, inp), opt_test(True, inp))) 1714 self.assertTrue(torch.allclose(test(False, inp), opt_test(False, inp))) 1715 1716 def test_map_graph_break(self): 1717 backend = EagerAndRecordGraphs() 1718 cnt = CompileCounterWithBackend(backend) 1719 1720 class Module(torch.nn.Module): 1721 def __init__(self) -> None: 1722 super().__init__() 1723 self.w = torch.nn.Buffer(torch.ones(6, 4)) 1724 1725 def forward(self, xs): 1726 def body(x): 1727 self.w += 1 1728 return x 1729 1730 return control_flow.map(body, xs) 1731 1732 mod = Module() 1733 1734 mod_for_compile = torch.compile(mod, backend=cnt, dynamic=True, fullgraph=False) 1735 mod_for_eager = Module() 1736 1737 res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]])) 1738 # There is graph break right when we enter body of map 1739 self.assertEqual(len(backend.graphs), 0) 1740 self.assertEqual( 1741 res, mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]])) 1742 ) 1743 1744 def test_map_side_effect(self): 1745 backend = EagerAndRecordGraphs() 1746 cnt = CompileCounterWithBackend(backend) 1747 1748 z = [torch.ones(6, 4)] 1749 1750 class Module(torch.nn.Module): 1751 def __init__(self) -> None: 1752 super().__init__() 1753 self.w = torch.nn.Buffer(torch.ones(6, 4)) 1754 1755 def forward(self, xs): 1756 def body(x): 1757 z.append(x) 1758 z.append(x) 1759 z.pop() 1760 return x + z[-1].sum() 1761 1762 return control_flow.map(body, xs) 1763 1764 mod = Module() 1765 1766 mod_for_compile = torch.compile(mod, backend=cnt, dynamic=True, fullgraph=False) 1767 mod_for_eager = Module() 1768 1769 res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]])) 1770 res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]])) 1771 1772 eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]])) 1773 eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]])) 1774 1775 self.assertEqual(len(backend.graphs), 0) 1776 self.assertEqual(res, eager) 1777 1778 def test_wrap_subgraph_name_is_valid(self): 1779 backend = EagerAndRecordGraphs() 1780 cnt = CompileCounterWithBackend(backend) 1781 1782 x = torch.randn(3, 3) 1783 y = torch.randn(3, 3) 1784 1785 def inner(x, y): 1786 z = x + y 1787 return wrap(lambda x: wrap(lambda x: x + z, x), x) 1788 1789 @torch.compile(backend=cnt, fullgraph=True) 1790 def f(x, y): 1791 return wrap(inner, x, y) 1792 1793 result = f(x, y) 1794 1795 self.assertEqual(result, x + y + x) 1796 wrap_gm = backend.graphs[0] 1797 names = set() 1798 names.update(mod_name for mod_name, _ in wrap_gm.named_modules()) 1799 self.assertEqual( 1800 names, 1801 { 1802 "", 1803 "wrap_body_2", 1804 "wrap_body_2.wrap_body_1", 1805 "wrap_body_2.wrap_body_1.wrap_body_0", 1806 }, 1807 ) 1808 1809 def test_wrap_allow_local_assign_in_body_fn(self): 1810 def f(arg1, arg2): 1811 def inner_f(arg1, arg2): 1812 a = arg1 1813 b = arg2 1814 ret = [] 1815 for x in a: 1816 ret.append(x + 1) 1817 
for x in b: 1818 ret.append(x + 1) 1819 return ret 1820 1821 return wrap(inner_f, arg1, arg2) 1822 1823 x = torch.ones(3) 1824 1825 def my_args_generator(): 1826 yield [x], [x.sin()] 1827 yield (x,), (x.sin(),) 1828 1829 actual_graph = self._test_wrap_simple( 1830 f, 1831 my_args_generator(), 1832 3, 1833 3, 1834 return_graph=True, 1835 ) 1836 1837 # Dynamic shapes produce a slightly different graph. 1838 if check_dynamic_shape_capture(): 1839 return 1840 1841 self.assertExpectedInline( 1842 actual_graph, 1843 """\ 1844class GraphModule(torch.nn.Module): 1845 def forward(self, L_arg1_0_: "f32[3]", L_arg2_0_: "f32[3]"): 1846 l_arg1_0_ = L_arg1_0_ 1847 l_arg2_0_ = L_arg2_0_ 1848 1849 wrap_body_0 = self.wrap_body_0 1850 wrap = torch.ops.higher_order.wrap(wrap_body_0, l_arg1_0_, l_arg2_0_); wrap_body_0 = l_arg1_0_ = l_arg2_0_ = None 1851 getitem: "f32[3]" = wrap[0] 1852 getitem_1: "f32[3]" = wrap[1]; wrap = None 1853 return (getitem, getitem_1) 1854 1855 class wrap_body_0(torch.nn.Module): 1856 def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"): 1857 child: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None 1858 1859 child_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None 1860 return (child, child_1) 1861""", 1862 ) 1863 1864 def test_capture_global_num(self): 1865 def f(x): 1866 return wrap(lambda x: x + global_num, x) 1867 1868 x = torch.zeros([]) 1869 # Numbers don't get lifted, so args is still 2. 1870 self._test_wrap_simple(f, default_args_generator((x,)), 2) 1871 1872 def test_capture_global_num_adds_guard(self): 1873 @torch.compile(backend="eager", fullgraph=True) 1874 def f(x): 1875 return wrap(lambda x: x + global_num, x) 1876 1877 global global_num 1878 x = torch.zeros([]) 1879 result = f(x) 1880 self.assertEqual(result, x + global_num) 1881 1882 global_num = torch.randn([]).item() 1883 result = f(x) 1884 self.assertEqual(result, x + global_num) 1885 1886 def test_capture_input_num(self): 1887 def f(x, y): 1888 return wrap(lambda x: x + y, x) 1889 1890 x = torch.zeros([]) 1891 y = 3.14 1892 # Numbers don't get lifted, so args is still 2. 1893 self._test_wrap_simple(f, default_args_generator((x, y)), 2) 1894 1895 def test_side_effect_in_body(self): 1896 counters.clear() 1897 backend = EagerAndRecordGraphs() 1898 1899 x = torch.randn([]) 1900 y = torch.randn([]) 1901 1902 def inner(x): 1903 nonlocal y 1904 y = x 1905 return x.clone() 1906 1907 @torch.compile(backend=backend) 1908 def f(x): 1909 return wrap(inner, x) 1910 1911 f(x) 1912 self.assertEqual(y, x) 1913 assert_dict_matches_regex( 1914 self, 1915 dict(counters["graph_break"]), 1916 { 1917 r".*HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)": 1 1918 }, 1919 ) 1920 1921 def test_fallback_on_graph_break_simple(self): 1922 # In the future, there should be a per-HigherOrderOperator switch 1923 # on whether or not to fallback or raise a loud error. 1924 # For now we just fallback by default. 
1925 cnt = CompileCounter() 1926 x = torch.randn([]) 1927 1928 def inner(x): 1929 y = x.sin() 1930 torch._dynamo.graph_break() 1931 z = y.sin() 1932 return z 1933 1934 @torch.compile(backend=cnt) 1935 def f(x): 1936 return wrap(inner, x) 1937 1938 result = f(x) 1939 self.assertEqual(result, inner(x)) 1940 self.assertEqual(cnt.frame_count, 0) 1941 1942 def test_fallback_on_graph_break_complicated(self): 1943 cnt = CompileCounter() 1944 x = torch.randn([]) 1945 1946 def inner(x): 1947 y = x.sin() 1948 y = y * global_var 1949 torch._dynamo.graph_break() 1950 z = y.sin() 1951 return z 1952 1953 @torch.compile(backend=cnt) 1954 def f(x): 1955 x = x.clone() 1956 result = wrap(inner, x) 1957 return result.clone() 1958 1959 result = f(x) 1960 self.assertEqual(result, inner(x)) 1961 self.assertEqual(cnt.frame_count, 2) 1962 1963 def test_modules(self): 1964 counters.clear() 1965 backend = EagerAndRecordGraphs() 1966 cnt = CompileCounterWithBackend(backend) 1967 mod = torch.nn.Linear(3, 3) 1968 x = torch.randn(3, 3) 1969 1970 @torch.compile(backend=cnt, fullgraph=True) 1971 def f(x): 1972 return wrap(lambda x: mod(x), x) 1973 1974 result = f(x) 1975 1976 self.assertEqual(result, mod(x)) 1977 self.assertEqual(cnt.frame_count, 1) 1978 1979 self.assertEqual(len(backend.graphs), 1) 1980 wrap_node = find_first_node(backend.graphs[0], wrap) 1981 # 3 args - 1 for input, and other 2 for the weight and bias 1982 self.assertTrue(len(wrap_node.args), 3) 1983 1984 # Check that the linear bias and weight are getattr in the outer graph 1985 if not torch._dynamo.config.inline_inbuilt_nn_modules: 1986 self.assertTrue(len(dict(backend.graphs[0].named_parameters())) == 2) 1987 1988 # Check that the inner function has one op and its a linear op 1989 body_function = getattr(backend.graphs[0], wrap_node.args[0].name) 1990 self.assertEqual(op_count(body_function), 1) 1991 linear_node = find_first_node(body_function, torch._C._nn.linear) 1992 self.assertTrue(linear_node is not None) 1993 1994 # Check that the innermost graph does not have any params 1995 self.assertTrue(len(dict(body_function.named_parameters())) == 0) 1996 self.assertTrue(len(dict(body_function.named_children())) == 0) 1997 1998 def test_flat_list_output(self): 1999 def f(x): 2000 return wrap(lambda x: [torch.sin(x), torch.cos(x)], x) 2001 2002 x = torch.randn(3) 2003 self._test_wrap_simple(f, default_args_generator((x,)), 2, expected_opcount=3) 2004 2005 def test_fallback_on_python_primitives_output(self): 2006 counters.clear() 2007 cnt = CompileCounter() 2008 2009 @torch.compile(backend=cnt) 2010 def f(x): 2011 return wrap(lambda x: [1, torch.sin(x), 2.0], x) 2012 2013 x = torch.randn(3) 2014 result = f(x) 2015 self.assertEqual(result, [1, torch.sin(x), 2.0]) 2016 self.assertEqual(cnt.frame_count, 0) 2017 assert_dict_matches_regex( 2018 self, 2019 dict(counters["graph_break"]), 2020 {".*HigherOrderOperator body's output must consist of tensors only": 1}, 2021 ) 2022 2023 def test_nested_tuple_output(self): 2024 def f(x): 2025 ((a, b),) = wrap(lambda x: ((x.sin(), x.cos()),), x) 2026 return a + b 2027 2028 x = torch.randn(2, 3) 2029 2030 counters.clear() 2031 graph = self._test_wrap_simple( 2032 f, default_args_generator((x,)), 2, 4, return_graph=True 2033 ) 2034 self.assertEqual(len(counters["graph_break"]), 0) 2035 2036 if check_dynamic_shape_capture(): 2037 return 2038 2039 self.assertExpectedInline( 2040 graph, 2041 """\ 2042class GraphModule(torch.nn.Module): 2043 def forward(self, L_x_: "f32[2, 3]"): 2044 l_x_ = L_x_ 2045 2046 wrap_body_0 = 
self.wrap_body_0 2047 wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None 2048 a: "f32[2, 3]" = wrap[0] 2049 b: "f32[2, 3]" = wrap[1]; wrap = None 2050 2051 add: "f32[2, 3]" = a + b; a = b = None 2052 return (add,) 2053 2054 class wrap_body_0(torch.nn.Module): 2055 def forward(self, l_x_: "f32[2, 3]"): 2056 child: "f32[2, 3]" = l_x_.sin() 2057 child_1: "f32[2, 3]" = l_x_.cos(); l_x_ = None 2058 return (child, child_1) 2059""", 2060 ) 2061 2062 def test_output_with_dict(self): 2063 def f(x): 2064 return wrap(lambda x: [{"a": -x}], x) 2065 2066 x = torch.randn(3) 2067 2068 counters.clear() 2069 graph = self._test_wrap_simple( 2070 f, default_args_generator((x,)), 2, 2, return_graph=True 2071 ) 2072 self.assertEqual(len(counters["graph_break"]), 0) 2073 2074 if check_dynamic_shape_capture(): 2075 return 2076 2077 self.assertExpectedInline( 2078 graph, 2079 """\ 2080class GraphModule(torch.nn.Module): 2081 def forward(self, L_x_: "f32[3]"): 2082 l_x_ = L_x_ 2083 2084 wrap_body_0 = self.wrap_body_0 2085 wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None 2086 getitem: "f32[3]" = wrap[0]; wrap = None 2087 return (getitem,) 2088 2089 class wrap_body_0(torch.nn.Module): 2090 def forward(self, l_x_: "f32[3]"): 2091 child: "f32[3]" = -l_x_; l_x_ = None 2092 return (child,) 2093""", 2094 ) 2095 2096 def test_access_module_attr(self): 2097 counters.clear() 2098 backend = EagerAndRecordGraphs() 2099 cnt = CompileCounterWithBackend(backend) 2100 mod = torch.nn.Linear(3, 3) 2101 x = torch.randn(3, 3) 2102 2103 @torch.compile(backend=cnt, fullgraph=True) 2104 def f(x): 2105 y = mod(x) 2106 return wrap(lambda y: y - mod.bias, y) 2107 2108 result = f(x) 2109 self.assertEqual(result, mod(x) - mod.bias) 2110 self.assertEqual(cnt.frame_count, 1) 2111 2112 self.assertEqual(len(backend.graphs), 1) 2113 wrap_node = find_first_node(backend.graphs[0], wrap) 2114 self.assertTrue(len(wrap_node.args), 3) 2115 2116 # Check that the linear bias and weight are getattr in the outer graph 2117 if not torch._dynamo.config.inline_inbuilt_nn_modules: 2118 self.assertTrue(len(dict(backend.graphs[0].named_parameters())) == 2) 2119 2120 # Check that the inner function has one op and its a linear op 2121 body_function = getattr(backend.graphs[0], wrap_node.args[0].name) 2122 self.assertEqual(op_count(body_function), 1) 2123 2124 # Check that the innermost graph does not have any params 2125 self.assertTrue(len(dict(body_function.named_parameters())) == 0) 2126 self.assertTrue(len(dict(body_function.named_children())) == 0) 2127 2128 def test_make_closure(self): 2129 def f(x, y): 2130 def g(x): 2131 return x + y 2132 2133 return g(x) 2134 2135 def h(x, y): 2136 return wrap(f, x, y) 2137 2138 x = torch.randn(3, 3) 2139 y = torch.randn(3, 3) 2140 self._test_wrap_simple(h, default_args_generator((x, y)), 3) 2141 2142 def test_internal_nonlocal(self): 2143 def f(x, y): 2144 w = 1 2145 2146 def g(x): 2147 nonlocal w 2148 w = x 2149 return x 2150 2151 def h(x): 2152 nonlocal w 2153 w = w + 1 2154 return x 2155 2156 g(x) 2157 h(x) 2158 return w + y 2159 2160 def h(x, y): 2161 return wrap(f, x, y) 2162 2163 x = torch.randn(3, 3) 2164 y = torch.randn(3, 3) 2165 self._test_wrap_simple(h, default_args_generator((x, y)), 3) 2166 2167 def test_capture_numpy_number(self): 2168 import numpy as np 2169 2170 y = np.float32(1.0) 2171 2172 def f(x): 2173 return wrap(lambda x: x + y, x) 2174 2175 x = torch.randn(3) 2176 # np.number are lifted to graph inputs 2177 self._test_wrap_simple(f, 
default_args_generator((x,)), 3) 2178 2179 def test_freevars_as_inputs_to_wrap(self): 2180 y = torch.randn(3) 2181 2182 def f(x): 2183 return wrap(lambda x, y: x + y, x, y) 2184 2185 x = torch.randn(3) 2186 self._test_wrap_simple(f, default_args_generator((x,)), 3) 2187 2188 def test_lift_tensor_constant(self): 2189 def f(x): 2190 y = torch.tensor(1.0) 2191 return wrap(lambda x: x + y, x) 2192 2193 x = torch.randn(3) 2194 self._test_wrap_simple(f, default_args_generator((x,)), 3, expected_opcount=3) 2195 2196 def test_nested_wrap(self): 2197 class MockModule(torch.nn.Module): 2198 def __init__(self) -> None: 2199 super().__init__() 2200 self.linear = torch.nn.Linear(10, 10) 2201 2202 def forward(self, x): 2203 return self.linear(x) 2204 2205 mod = MockModule() 2206 2207 # Two levels of wrap ops 2208 def gn(x): 2209 return torch.cos(x) + wrap(mod, x) 2210 2211 def fn(x): 2212 return wrap(gn, x) 2213 2214 self._test_wrap_simple(fn, default_args_generator((torch.randn(10, 10),)), 4) 2215 2216 def test_fn_with_kwargs_in_torch_ops(self): 2217 def fn(x): 2218 return wrap(lambda z: torch.cos(input=z), x) 2219 2220 x = torch.randn(3) 2221 self._test_wrap_simple(fn, default_args_generator((x,)), 2) 2222 2223 def test_hooks(self): 2224 class ToyModel(torch.nn.Module): 2225 def __init__(self) -> None: 2226 super().__init__() 2227 self.net = torch.nn.Linear(10, 10) 2228 2229 def forward(self, x): 2230 return self.net(x) 2231 2232 model = ToyModel() 2233 forward_handles = {} 2234 activations = {} 2235 2236 def save_activations(mod, inp, out): 2237 activations[name] = inp 2238 2239 for name, module in model.named_children(): 2240 forward_handles[name] = module.register_forward_hook(save_activations) 2241 2242 @torch.compile(backend="eager") 2243 def fn(x): 2244 return wrap(lambda x: model(x), x) 2245 2246 for i in range(2): 2247 # second iteration is key, hooks would have fired during aot trace 2248 # on first iter 2249 activations.clear() 2250 x = torch.randn((10, 10)) 2251 pred = fn(x) 2252 loss = pred.sum() 2253 loss.backward() 2254 2255 self.assertTrue(activations.keys() == forward_handles.keys()) 2256 2257 def _get_source_fn_stack(self, gm, node_names): 2258 ret = {} 2259 for mod in gm.modules(): 2260 for node in mod.graph.nodes: 2261 if node.name in node_names: 2262 actual_stack = [ 2263 name for name, _ in node.meta.get("source_fn_stack", []) 2264 ] 2265 ret[node.name] = actual_stack 2266 return ret 2267 2268 def test_wrap_source_fn_stack(self): 2269 class MockModule(torch.nn.Module): 2270 def __init__(self) -> None: 2271 super().__init__() 2272 self.linear = torch.nn.Linear(4, 4) 2273 2274 def forward(self, x): 2275 return self.linear(x) 2276 2277 mod = MockModule() 2278 2279 def gn(x): 2280 return torch.cos(x) + wrap(mod, x) 2281 2282 def fn(x): 2283 return wrap(gn, x) 2284 2285 backend = EagerAndRecordGraphs() 2286 inp = torch.randn((4, 4)) 2287 torch.compile(fn, backend=backend, fullgraph=True)(inp) 2288 2289 gm = backend.graphs[0] 2290 actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "linear"}) 2291 self.assertExpectedInline( 2292 pprint.pformat(actual_stack), 2293 """\ 2294{'add': ['wrap', 'add'], 2295 'cos': ['wrap', 'cos'], 2296 'linear': ['wrap', 'wrap', 'linear']}""", 2297 ) 2298 2299 def test_cond_source_fn_stack(self): 2300 backend = EagerAndRecordGraphs() 2301 2302 @torch.compile(backend=backend, fullgraph=True) 2303 def cond_f(pred, pred2, x, y): 2304 def true_fn(pred2, x, y): 2305 return x + y 2306 2307 def false_fn(pred2, x, y): 2308 def true_fn2(x, y): 2309 return 
x.sin() - y.cos() 2310 2311 def false_fn2(x, y): 2312 return x.cos() - y.sin() 2313 2314 return control_flow.cond(pred2, true_fn2, false_fn2, [x, y]) 2315 2316 return control_flow.cond(pred, true_fn, false_fn, [pred2, x, y]) 2317 2318 pred = torch.tensor(True) 2319 pred2 = torch.tensor(False) 2320 xs = torch.randn(2, 3, 3) 2321 y = torch.randn(3, 3) 2322 cond_f(pred, pred2, xs, y) 2323 2324 gm = backend.graphs[0] 2325 actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin", "sub"}) 2326 self.assertExpectedInline( 2327 pprint.pformat(actual_stack), 2328 """\ 2329{'add': ['cond', 'add'], 2330 'cos': ['cond', 'cond', 'cos'], 2331 'sin': ['cond', 'cond', 'sin'], 2332 'sub': ['cond', 'cond', 'sub']}""", 2333 ) 2334 2335 def test_map_source_fn_stack(self): 2336 backend = EagerAndRecordGraphs() 2337 2338 xs = torch.randn(2, 3, 3) 2339 y = torch.randn(3) 2340 2341 @torch.compile(backend=backend, fullgraph=True) 2342 def map_f(xs, y): 2343 def inner(x, y): 2344 def inner2(x, y): 2345 return x + y 2346 2347 return control_flow.map(inner2, x, y) * y.cos() 2348 2349 return control_flow.map(inner, xs, y).sin() 2350 2351 result = map_f(xs, y) 2352 2353 gm = backend.graphs[0] 2354 actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin"}) 2355 self.assertExpectedInline( 2356 pprint.pformat(actual_stack), 2357 """{'add': ['map', 'map', 'add'], 'cos': ['map', 'cos'], 'sin': ['sin']}""", 2358 ) 2359 2360 def test_grad_source_fn_stack(self): 2361 backend = EagerAndRecordGraphs() 2362 2363 def fn(x): 2364 return x.sin().sum() 2365 2366 @torch.compile(backend=backend, fullgraph=False) 2367 def wrapper_fn(x): 2368 return torch.func.grad(torch.func.grad(fn))(x) 2369 2370 x = torch.randn(()) 2371 2372 wrapper_fn(x) 2373 gm = backend.graphs[0] 2374 actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sin"}) 2375 self.assertExpectedInline( 2376 pprint.pformat(actual_stack), 2377 """{'sin': ['sin']}""", 2378 ) 2379 2380 def test_vmap_multiply_scalar(self): 2381 @torch.compile(backend="inductor", fullgraph=True) 2382 def g(x): 2383 return torch.vmap(torch.mul, in_dims=(0, None))(x, 3.14) 2384 2385 x = torch.randn(3) 2386 y = g(x) 2387 self.assertEqual(y, x * 3.14) 2388 2389 @torch.compile(backend="inductor", fullgraph=True) 2390 def f(x): 2391 return torch.vmap(torch.mul, in_dims=(0, None))(x, 314) 2392 2393 x = torch.randn(3) 2394 y = f(x) 2395 self.assertEqual(y, x * 314) 2396 2397 def test_vmap_source_fn_stack(self): 2398 backend = EagerAndRecordGraphs() 2399 2400 def inner_fn(x): 2401 return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x) 2402 2403 @torch.compile(backend=backend, fullgraph=True) 2404 def fn(x): 2405 return torch.func.vmap(lambda x: inner_fn(x.cos()))(x) 2406 2407 x = torch.randn(3, 3, 3, 3) 2408 fn(x) 2409 gm = backend.graphs[0] 2410 actual_stack = self._get_source_fn_stack( 2411 gm, {"sum_1", "sum_2", "batched_output"} 2412 ) 2413 self.assertExpectedInline( 2414 pprint.pformat(actual_stack), 2415 """{'sum_1': ['sum_1'], 'sum_2': ['sum_2']}""", 2416 ) 2417 2418 def test_cond_pytree_operands(self): 2419 def _construct_pytree(): 2420 a = torch.randn(3, 3) 2421 b = torch.randn(3, 3) 2422 c = torch.randn(3, 3) 2423 d = torch.randn(3, 3) 2424 e = torch.randn(3, 3) 2425 f = torch.randn(3, 3) 2426 g = torch.randn(3, 3) 2427 return (a, [[[b]]], c, (d, (e,), f), {"g": g}) 2428 2429 pred = torch.tensor(True) 2430 inp = _construct_pytree() 2431 2432 def _reduce_sum(flattened): 2433 init = 0 2434 for val in flattened: 2435 init += val 2436 return init 2437 2438 def 
_reduce_max(flattened): 2439 init = flattened[0] 2440 for val in flattened: 2441 init = max(val, init) 2442 return init 2443 2444 def true_fn(pytree_in): 2445 flattened, spec = pytree.tree_flatten(pytree_in) 2446 return _reduce_sum(flattened) 2447 2448 def false_fn(pytree_in): 2449 flattened, spec = pytree.tree_flatten(pytree_in) 2450 return _reduce_max(flattened) 2451 2452 def fn(pred, pytree_in): 2453 return torch.cond(pred, true_fn, false_fn, [pytree_in]) 2454 2455 backend = EagerAndRecordGraphs() 2456 cnt = CompileCounterWithBackend(backend) 2457 compiled_res = torch.compile(fn, backend=backend)(pred, inp) 2458 eager_res = fn(pred, inp) 2459 self.assertEqual(compiled_res, eager_res) 2460 graph = backend.graphs[0] 2461 2462 # Dynamic shapes produce a slightly different graph. 2463 if check_dynamic_shape_capture(): 2464 return 2465 2466 self.assertExpectedInline( 2467 graph.code.strip(), 2468 """\ 2469def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytree_in_1_0_0_0_ : torch.Tensor, L_pytree_in_2_ : torch.Tensor, L_pytree_in_3_0_ : torch.Tensor, L_pytree_in_3_1_0_ : torch.Tensor, L_pytree_in_3_2_ : torch.Tensor, L_pytree_in_4_g_ : torch.Tensor): 2470 l_pred_ = L_pred_ 2471 l_pytree_in_0_ = L_pytree_in_0_ 2472 l_pytree_in_1_0_0_0_ = L_pytree_in_1_0_0_0_ 2473 l_pytree_in_2_ = L_pytree_in_2_ 2474 l_pytree_in_3_0_ = L_pytree_in_3_0_ 2475 l_pytree_in_3_1_0_ = L_pytree_in_3_1_0_ 2476 l_pytree_in_3_2_ = L_pytree_in_3_2_ 2477 l_pytree_in_4_g_ = L_pytree_in_4_g_ 2478 cond_true_0 = self.cond_true_0 2479 cond_false_0 = self.cond_false_0 2480 cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_pytree_in_0_, l_pytree_in_1_0_0_0_, l_pytree_in_2_, l_pytree_in_3_0_, l_pytree_in_3_1_0_, l_pytree_in_3_2_, l_pytree_in_4_g_]); l_pred_ = cond_true_0 = cond_false_0 = l_pytree_in_0_ = l_pytree_in_1_0_0_0_ = l_pytree_in_2_ = l_pytree_in_3_0_ = l_pytree_in_3_1_0_ = l_pytree_in_3_2_ = l_pytree_in_4_g_ = None 2481 getitem = cond[0]; cond = None 2482 return (getitem,)""", # noqa: B950 2483 ) 2484 2485 def test_cond_pytree_operands_with_non_tensor_leaves(self): 2486 def fn(pred, pytree_in): 2487 return torch.cond( 2488 pred, lambda x: x[0] + 1, lambda x: x[0] * 2, (pytree_in,) 2489 ) 2490 2491 pred = torch.tensor(True) 2492 for pytree_in in [(1,), ("string",), (1.0,)]: 2493 with self.assertRaisesRegex( 2494 RuntimeError, 2495 r"Expect operands to be a tuple of possibly nested dict/list/tuple", 2496 ): 2497 fn(pred, pytree_in) 2498 2499 for pytree_in in [(1,), ("string",), (1.0,)]: 2500 with self.assertRaisesRegex( 2501 torch._dynamo.exc.UncapturedHigherOrderOpError, 2502 r"Cond doesn't work unless it is captured completely with torch.compile", 2503 ): 2504 torch.compile(fn, backend="eager")(pred, pytree_in) 2505 2506 def test_hints_wrapper(self): 2507 def ref_fn(x, y): 2508 x = x + y 2509 x = torch.relu(x) 2510 x = x + y 2511 return torch.abs(x) 2512 2513 def fn_with_hints(x, y): 2514 x = x + y 2515 2516 def inner_body_fn(x, y): 2517 x = torch.relu(x) 2518 x = x + y 2519 return x 2520 2521 def outer_body_fn(x, y): 2522 x = hints_wrapper(inner_body_fn, (x, y), {}, hints={"inner_body": True}) 2523 x = torch.abs(x) 2524 return x 2525 2526 res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"outer_body": True}) 2527 return res 2528 2529 backend = EagerAndRecordGraphs() 2530 cnt = CompileCounterWithBackend(backend) 2531 2532 x = torch.randn(2, 4) 2533 y = torch.ones(4) 2534 2535 eager_res = fn_with_hints(x, y) 2536 compiled_res = torch.compile(fn_with_hints, 
backend=cnt)(x, y) 2537 ref_res = ref_fn(x, y) 2538 self.assertEqual(eager_res, ref_res) 2539 self.assertEqual(compiled_res, ref_res) 2540 self.assertEqual(len(cnt.graphs), 1) 2541 2542 # Dynamic shapes produce a slightly different graph. 2543 if check_dynamic_shape_capture(): 2544 return 2545 2546 graph = backend.graphs[0] 2547 self.assertExpectedInline( 2548 normalize_gm(graph.print_readable(print_output=False)), 2549 """\ 2550class GraphModule(torch.nn.Module): 2551 def forward(self, L_x_: "f32[2, 4]", L_y_: "f32[4]"): 2552 l_x_ = L_x_ 2553 l_y_ = L_y_ 2554 2555 x: "f32[2, 4]" = l_x_ + l_y_; l_x_ = None 2556 2557 hints_wrapper_body_1 = self.hints_wrapper_body_1 2558 hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True}); hints_wrapper_body_1 = x = l_y_ = None 2559 res: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None 2560 return (res,) 2561 2562 class hints_wrapper_body_1(torch.nn.Module): 2563 def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"): 2564 hints_wrapper_body_0 = self.hints_wrapper_body_0 2565 hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True}); hints_wrapper_body_0 = x = l_y_ = None 2566 x_1: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None 2567 2568 x_2: "f32[2, 4]" = torch.abs(x_1); x_1 = None 2569 return (x_2,) 2570 2571 class hints_wrapper_body_0(torch.nn.Module): 2572 def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"): 2573 x_1: "f32[2, 4]" = torch.relu(x); x = None 2574 2575 x_2: "f32[2, 4]" = x_1 + l_y_; x_1 = l_y_ = None 2576 return (x_2,) 2577""", 2578 ) 2579 2580 def test_hints_wrapper_no_hints(self): 2581 def fn_with_hints(x, y): 2582 def outer_body_fn(x, y): 2583 x = torch.add(x, y) 2584 return x 2585 2586 res = hints_wrapper(outer_body_fn, (x, y), {}) 2587 return res 2588 2589 backend = EagerAndRecordGraphs() 2590 cnt = CompileCounterWithBackend(backend) 2591 2592 x = torch.randn(2, 4) 2593 y = torch.ones(4) 2594 2595 msg = "hints_wrapper - key hints not provided" 2596 with self.assertRaisesRegex(RuntimeError, msg): 2597 compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y) 2598 2599 def test_hints_wrapper_incorrect_type(self): 2600 def fn_with_hints(x, y): 2601 def outer_body_fn(x, y): 2602 x = torch.add(x, y) 2603 return x 2604 2605 res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"test": (True,)}) 2606 return res 2607 2608 backend = EagerAndRecordGraphs() 2609 cnt = CompileCounterWithBackend(backend) 2610 2611 x = torch.randn(2, 4) 2612 y = torch.ones(4) 2613 2614 msg = r"hints must be a dict containing int, float, bool or str value," 2615 with self.assertRaisesRegex(RuntimeError, msg): 2616 compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y) 2617 2618 def test_hints_wrapper_pytree_inputs(self): 2619 def fn_with_hints(x, y): 2620 def outer_body_fn(x): 2621 res = torch.add(x[0], x[1]["test"]) 2622 return res 2623 2624 res = hints_wrapper( 2625 outer_body_fn, ((x, {"test": y}),), {}, hints={"test": True} 2626 ) 2627 return res 2628 2629 backend = EagerAndRecordGraphs() 2630 cnt = CompileCounterWithBackend(backend) 2631 2632 x = torch.randn(2, 4) 2633 y = torch.ones(4) 2634 2635 msg = r"args must be a tuple of tensors, ints, floats, or bools," 2636 with self.assertRaisesRegex(RuntimeError, msg): 2637 fn_with_hints(x, y) 2638 2639 2640class HigherOrderOpVmapGuardTests(LoggingTestCase): 2641 @make_logging_test(recompiles=True) 2642 def test_vmap_grad_guard_ok(self, records): 2643 vmap = torch.vmap 2644 
grad = torch.func.grad 2645 2646 def g(x): 2647 return vmap(grad(torch.sin))(x) 2648 2649 @torch.compile(backend="eager") 2650 def fn(x): 2651 return vmap(g)(x) 2652 2653 x = torch.randn(4, 5) 2654 y = fn(x) 2655 # sanity check 2656 self.assertEqual(len(records), 0) 2657 self.assertEqual(x.cos(), y) 2658 2659 # Calling the same function again won't have any effect on guards 2660 fn(x) 2661 self.assertEqual(len(records), 0) 2662 2663 @xfailIfTorchDynamo 2664 @make_logging_test(recompiles=True) 2665 def test_grad_guard_fail(self, records): 2666 grad = torch.func.grad 2667 2668 @torch.compile(backend="eager") 2669 def fn(x): 2670 return grad(torch.sin)(x.sum()) 2671 2672 x = torch.randn([]) 2673 fn(x) 2674 self.assertEqual(len(records), 0) 2675 2676 # calling again should not invalidate the graph 2677 fn(x) 2678 self.assertEqual(len(records), 0) 2679 2680 # call grad should retrigger compilation 2681 x = torch.randn(3) 2682 grad(fn)(x) 2683 self.assertGreater(len(records), 0) 2684 record = self.getRecord(records, "pyfunctorch") 2685 self.assertIn( 2686 """torch._functorch.pyfunctorch.compare_functorch_state([])""", 2687 munge_exc(record.getMessage()), 2688 ) 2689 2690 @make_logging_test(recompiles=True) 2691 def test_dual_level_guard(self, records): 2692 fwAD = torch.autograd.forward_ad 2693 2694 @torch.compile(backend="eager", fullgraph=True) 2695 def fn(foo, tangent): 2696 with fwAD.dual_level(): 2697 dual = fwAD.make_dual(foo, tangent[1:]) 2698 return dual 2699 2700 foo = torch.rand(2) 2701 tangent = torch.rand(3) 2702 fn(foo, tangent) 2703 self.assertEqual(len(records), 0) 2704 2705 # calling again should not invalidate the graph 2706 fn(foo, tangent) 2707 self.assertEqual(len(records), 0) 2708 2709 # assertRaises is only here because Nested forward mode AD is not supported 2710 with self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError): 2711 with fwAD.dual_level(): 2712 fn(foo, tangent) 2713 self.assertGreater(len(records), 0) 2714 record = self.getRecord(records, "forward_ad") 2715 self.assertIn( 2716 """torch.autograd.forward_ad._current_level == -1""", 2717 munge_exc(record.getMessage()), 2718 ) 2719 2720 @xfailIfTorchDynamo 2721 @make_logging_test(recompiles=True) 2722 def test_jvp_guard_fail(self, records): 2723 jvp = torch.func.jvp 2724 vmap = torch.func.vmap 2725 2726 @torch.compile(backend="eager") 2727 def fn(x): 2728 return jvp(torch.sin, (x,), (x,)) 2729 2730 x = torch.randn(3, 4) 2731 fn(x) 2732 self.assertEqual(len(records), 0) 2733 2734 # calling again should not invalidate the graph 2735 fn(x) 2736 self.assertEqual(len(records), 0) 2737 2738 # call jvp should retrigger compilation 2739 x = torch.randn(3, 4, 5) 2740 jvp(vmap(fn), (x,), (x,)) 2741 2742 self.assertGreater(len(records), 0) 2743 if self.hasRecord(records, "pyfunctorch"): 2744 record = self.getRecord(records, "pyfunctorch") 2745 self.assertIn( 2746 """torch._functorch.pyfunctorch.compare_functorch_state([])""", 2747 munge_exc(record.getMessage()), 2748 ) 2749 elif self.hasRecord(records, "forward_ad"): 2750 record = self.getRecord(records, "forward_ad") 2751 self.assertIn( 2752 """torch.autograd.forward_ad._current_level == -1""", 2753 munge_exc(record.getMessage()), 2754 ) 2755 2756 @make_logging_test(recompiles=True) 2757 def test_vmap_guard_ok(self, records): 2758 @torch.compile(backend="eager") 2759 def fn(x): 2760 return torch.vmap(lambda x: x.sin())(x) 2761 2762 x = torch.randn(3, 3, 4, 5) 2763 y = fn(x) 2764 # sanity check 2765 self.assertEqual(len(records), 0) 2766 self.assertEqual(x.sin(), y) 
2767 2768 # Calling the same function again won't have any effect on guards 2769 z = fn(x) 2770 self.assertEqual(len(records), 0) 2771 self.assertEqual(x.sin(), z) 2772 2773 # calling with a different object will also not affect guards 2774 w = fn(z) 2775 self.assertEqual(len(records), 0) 2776 self.assertEqual(z.sin(), w) 2777 2778 @xfailIfTorchDynamo 2779 @make_logging_test(recompiles=True) 2780 def test_vmap_guard_fail_different_state(self, records): 2781 @torch.compile(backend="eager") 2782 def fn(x): 2783 return torch.vmap(lambda x: x.sin())(x) 2784 2785 x = torch.zeros(3, 4) 2786 y = torch.vmap(fn, randomness="same")(x) 2787 self.assertEqual(x.sin(), y) 2788 self.assertEqual(len(records), 0) 2789 2790 # call vmap(vmap(fn))(x) should retrigger compilation 2791 y = torch.vmap(fn, randomness="different")(x) 2792 self.assertEqual(x.sin(), y) 2793 self.assertGreater(len(records), 0) 2794 record = self.getRecord(records, "pyfunctorch") 2795 self.assertIn( 2796 """torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""", 2797 record.getMessage(), 2798 ) 2799 2800 @xfailIfTorchDynamo 2801 @make_logging_test(recompiles=True) 2802 def test_vmap_guard_fail(self, records): 2803 @torch.compile(backend="eager") 2804 def fn(x): 2805 return torch.vmap(lambda x: x.sin())(x) 2806 2807 x = torch.zeros(3, 3, 4, 5) 2808 y = torch.vmap(fn)(x) 2809 self.assertEqual(x.sin(), y) 2810 self.assertEqual(len(records), 0) 2811 2812 # call vmap(vmap(fn))(x) should retrigger compilation as 2813 # _functorch.current_level() is not the same 2814 x = torch.zeros(3, 3, 3, 4, 5) 2815 y = torch.vmap(torch.vmap(fn))(x) 2816 self.assertEqual(x.sin(), y) 2817 self.assertGreater(len(records), 0) 2818 record = self.getRecord(records, "pyfunctorch") 2819 self.assertIn( 2820 """torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""", 2821 record.getMessage(), 2822 ) 2823 2824 @xfailIfTorchDynamo 2825 @make_logging_test(recompiles=True) 2826 def test_vmap_grad_vmap_guard_fail(self, records): 2827 vmap = torch.vmap 2828 grad = torch.func.grad 2829 2830 def g(x): 2831 y = vmap(torch.sin, randomness="same")(x) 2832 return y.sum(0) 2833 2834 @torch.compile(backend="eager") 2835 def fn(x): 2836 return grad(g)(x) 2837 2838 x = torch.randn(3, 3) 2839 y = vmap(fn, randomness="error")(x) 2840 self.assertEqual(x.cos(), y) 2841 2842 # previous FX graph should be invalidated 2843 x = torch.randn(3, 3, 4) 2844 y = vmap(vmap(fn, randomness="different"))(x) 2845 self.assertGreater(len(records), 0) 2846 record = self.getRecord(records, "pyfunctorch") 2847 self.assertIn( 2848 """torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""", 2849 munge_exc(record.getMessage()), 2850 ) 2851 2852 @xfailIfTorchDynamo 2853 @make_logging_test(recompiles=True) 2854 def test_vmap_recompile_different_states(self, records): 2855 @torch.compile(backend="eager") 2856 def fn(x): 2857 return torch.vmap(lambda x: x.sin())(x) 2858 2859 x = torch.zeros(3, 3, 4, 5) 2860 y = torch.vmap(fn, randomness="same")(x) 2861 self.assertEqual(len(records), 0) # sanity check 2862 2863 y = torch.vmap(fn, randomness="different")(x) 2864 self.assertGreater(len(records), 0) 2865 record = self.getRecord(records, "pyfunctorch") 2866 self.assertIn( 2867 """torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""", 2868 munge_exc(record.getMessage()), 2869 ) 2870 2871 @config.patch(capture_func_transforms=True) 2872 @make_logging_test(guards=True) 2873 def test_emit_functorch_guard_if_active(self, 
records): 2874 @torch.compile(backend="eager") 2875 def fn(x): 2876 return torch.sin(x) 2877 2878 x = torch.randn(3, 4) 2879 _ = fn(x) 2880 self.assertFalse(self.hasRecord(records, "pyfunctorch")) # sanity check 2881 2882 _ = torch.vmap(fn)(x) 2883 self.assertTrue(self.hasRecord(records, "pyfunctorch")) 2884 record = self.getRecord(records, "pyfunctorch") 2885 self.assertIn( 2886 """torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""", 2887 munge_exc(record.getMessage()), 2888 ) 2889 2890 @make_logging_test(recompiles=True) 2891 def test_linearize_recompiles(self, records): 2892 @torch.compile(backend="eager") 2893 def fn(x): 2894 out, jvp_fn = torch.func.linearize(torch.sin, x) 2895 return out, jvp_fn(x) 2896 2897 x = torch.randn(2, 3) 2898 fn(x) 2899 self.assertEqual(len(records), 0) 2900 2901 z = torch.randn(2, 3) 2902 fn(z) 2903 self.assertEqual(len(records), 0) 2904 2905 y = torch.randn(3, 4) 2906 fn(y) 2907 self.assertGreater(len(records), 0) 2908 2909 2910class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase): 2911 def tearDown(self): 2912 # Ensure that in the case of a test failure, the next test won't fail 2913 # because of a previous call to _vmap_increment_nesting that wasn't undone 2914 # i.e. test_vmap_free_tensor fails when PYTORCH_TEST_WITH_DYNAMO=1 2915 # and the call to increment nesting is not undone 2916 if not TEST_WITH_TORCHDYNAMO: 2917 return 2918 2919 warn = False 2920 while ci := torch._C._functorch.peek_interpreter_stack(): 2921 if ci.key() == torch._C._functorch.TransformType.Vmap: 2922 warn = True 2923 torch._C._functorch._vmap_decrement_nesting() 2924 else: 2925 break 2926 2927 if warn: 2928 msg = ( 2929 "Interpreter stack is not empty. Test should have called " 2930 "'torch._C._functorch._vmap_decrement_nesting()'" 2931 ) 2932 warnings.warn(msg) 2933 2934 def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0): 2935 backend = EagerAndRecordGraphs() 2936 actual = fn(*inputs) 2937 expected = torch.compile(fn, backend=backend, fullgraph=fullgraph)(*inputs) 2938 2939 self.assertEqual(actual, expected) 2940 2941 wrapped_gm = backend.graphs[graph_idx] 2942 return wrapped_gm 2943 2944 def test_hessian(self): 2945 counters.clear() 2946 2947 def wrapper_fn(x): 2948 return torch.func.hessian(torch.sin)(x) 2949 2950 x = torch.randn(4, 3) 2951 wrapped_gm = self._compile_check(wrapper_fn, (x,)) 2952 # Dynamic shapes produce a slightly different graph. 
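# (_compile_check above already asserted eager/compiled parity, so when dynamic-shape
# capture is on we only skip the golden-graph comparison below.)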
2953 if check_dynamic_shape_capture(): 2954 return 2955 2956 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 2957 self.assertExpectedInline( 2958 actual, 2959 """\ 2960class GraphModule(torch.nn.Module): 2961 def forward(self, L_x_: "f32[4, 3]"): 2962 l_x_ = L_x_ 2963 2964 tensor: "i64[1]" = torch.tensor((12,)) 2965 cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None 2966 getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None 2967 neg: "i64[0]" = getitem.neg(); getitem = None 2968 unbind = neg.unbind(); neg = unbind = None 2969 2970 chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12) 2971 2972 diagonal: "f32[12]" = chunk.diagonal(0) 2973 fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None 2974 2975 child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None 2976 2977 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 2978 2979 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None 2980 2981 child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None 2982 2983 _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)); _jvp_treespec_compare = None 2984 2985 _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None 2986 _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None 2987 _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None 2988 2989 _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None 2990 2991 child_2: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None 2992 2993 _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None 2994 2995 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 2996 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 2997 2998 diff_primals: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None 2999 3000 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 3001 3002 _set_tensor_requires_grad: "f32[4, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None 3003 3004 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 3005 3006 o: "f32[4, 3]" = torch.sin(diff_primals) 3007 3008 results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 3) 3009 3010 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 3011 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 3012 3013 tensor_1: "i64[1]" = torch.tensor((12,)) 3014 cumsum_1: "i64[1]" = tensor_1.cumsum(dim = 0); tensor_1 = None 3015 getitem_1: "i64[0]" = cumsum_1[slice(None, -1, None)]; cumsum_1 = None 3016 neg_1: "i64[0]" = getitem_1.neg(); getitem_1 = None 3017 unbind_1 = neg_1.unbind(); neg_1 = unbind_1 = None 3018 3019 chunk_1: "f32[12, 12]" = results.new_zeros(12, 12); results = None 3020 3021 diagonal_1: "f32[12]" = chunk_1.diagonal(0) 3022 fill__1: "f32[12]" = diagonal_1.fill_(1); diagonal_1 = fill__1 = None 3023 3024 basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3); chunk_1 = None 3025 3026 lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None 3027 3028 _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None 3029 3030 _add_batch_dim_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 3); basis = None 3031 3032 _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1); _vjp_treespec_compare = None 3033 3034 _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim_1 = None 3035 batched_outputs: "f32[4, 3]" = _autograd_grad[0]; _autograd_grad = None 3036 3037 chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0); batched_outputs = None 3038 3039 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 3040 3041 split = chunked_result.split((12,), dim = 0); chunked_result = None 3042 split_1: "f32[12, 4, 3]" = split[0]; split = None 3043 3044 output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3)); split_1 = None 3045 3046 _unpack_dual = torch._unpack_dual(output_input, level = 0); output_input = None 3047 primal: "f32[4, 3, 4, 3]" = _unpack_dual[0] 3048 dual: "f32[4, 3, 4, 3]" = _unpack_dual[1]; _unpack_dual = None 3049 3050 primals_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None 3051 3052 tangents_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None 3053 3054 _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None 3055 _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); 
_set_fwd_grad_enabled_1 = None 3056 _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None 3057 3058 results_1: "f32[12, 4, 3, 4, 3]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None 3059 3060 _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None 3061 3062 movedim: "f32[4, 3, 4, 3, 12]" = results_1.movedim(0, -1); results_1 = None 3063 split_2 = movedim.split((12,), dim = -1); movedim = None 3064 jac_out_in: "f32[4, 3, 4, 3, 12]" = split_2[0]; split_2 = None 3065 3066 unflatten: "f32[4, 3, 4, 3, 4, 3]" = jac_out_in.unflatten(-1, (4, 3)); jac_out_in = None 3067 return (unflatten,) 3068""", 3069 ) 3070 3071 def test_hessian_argnums(self): 3072 counters.clear() 3073 3074 def fn(x, y): 3075 return x.sin() 3076 3077 def wrapper_fn(x, y): 3078 return torch.func.hessian(fn, argnums=(1,))(x, y) 3079 3080 x = torch.randn(4, 3) 3081 y = torch.randn(3, 4) 3082 wrapped_gm = self._compile_check(wrapper_fn, (x, y)) 3083 # Dynamic shapes produce a slightly different graph. 3084 if check_dynamic_shape_capture(): 3085 return 3086 3087 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 3088 self.assertExpectedInline( 3089 "\n".join(actual.split("\n")[:-2]), 3090 """\ 3091class GraphModule(torch.nn.Module): 3092 def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): 3093 l_x_ = L_x_ 3094 l_y_ = L_y_ 3095 3096 tensor: "i64[1]" = torch.tensor((12,)) 3097 cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None 3098 getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None 3099 neg: "i64[0]" = getitem.neg(); getitem = None 3100 unbind = neg.unbind(); neg = unbind = None 3101 3102 chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12) 3103 3104 diagonal: "f32[12]" = chunk.diagonal(0) 3105 fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None 3106 3107 child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None 3108 3109 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 3110 3111 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None 3112 3113 child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None 3114 3115 _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)); _jvp_treespec_compare = None 3116 3117 _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None 3118 _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None 3119 _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None 3120 3121 _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None 3122 3123 child_3: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None 3124 3125 child_2: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None 3126 _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None 3127 3128 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 3129 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 3130 3131 _wrap_for_grad_2: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None 3132 child_4: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(child_3, 3); child_3 = None 3133 3134 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 3135 3136 _set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_4); _set_tensor_requires_grad = None 3137 3138 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 3139 3140 o: "f32[4, 3]" = _wrap_for_grad_2.sin(); _wrap_for_grad_2 = None 3141 3142 results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 3) 3143 3144 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 3145 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 3146 3147 tensor_1: "i64[1]" = torch.tensor((12,)) 3148 cumsum_1: "i64[1]" = tensor_1.cumsum(dim = 0); tensor_1 = None 3149 getitem_1: "i64[0]" = cumsum_1[slice(None, -1, None)]; cumsum_1 = None 3150 neg_1: "i64[0]" = getitem_1.neg(); getitem_1 = None 3151 unbind_1 = neg_1.unbind(); neg_1 = unbind_1 = None 3152 3153 chunk_1: "f32[12, 12]" = results.new_zeros(12, 12); results = None 3154 3155 diagonal_1: "f32[12]" = chunk_1.diagonal(0) 3156 fill__1: "f32[12]" = diagonal_1.fill_(1); diagonal_1 = fill__1 = None 3157 3158 basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3); chunk_1 = None 3159 3160 lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None 3161 3162 _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None 3163 3164 _add_batch_dim_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 3); basis = None 3165 3166 _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1); _vjp_treespec_compare = None 3167 3168 _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = child_4 = _add_batch_dim_1 = None 3169 child_5: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None 3170 3171 child_6: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0); child_5 = None 3172 3173 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 3174 3175 split = child_6.split((12,), dim = 0); child_6 = None 3176 split_1: "f32[12, 3, 4]" = split[0]; split = None 3177 3178 child_7: "f32[4, 3, 3, 4]" = split_1.view((4, 3, 3, 4)); split_1 = None 3179 3180 _unpack_dual = torch._unpack_dual(child_7, level = 0); child_7 = None 3181 primal: "f32[4, 3, 3, 4]" = _unpack_dual[0]; _unpack_dual = None 3182 3183 tangent: "f32[4, 3, 3, 4]" = torch.zeros_like(primal) 3184 3185 child_8: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_8 = None 3186 3187 child_9: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None 3188 3189 _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None 3190 _set_fwd_grad_enabled_1 = 
torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None 3191 _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None 3192 3193 child_10: "f32[12, 4, 3, 3, 4]" = torch._C._functorch._remove_batch_dim(child_9, 1, 12, 0); child_9 = None 3194 3195 _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None 3196 3197 movedim: "f32[4, 3, 3, 4, 12]" = child_10.movedim(0, -1); child_10 = None 3198 split_2 = movedim.split((12,), dim = -1); movedim = None 3199 jac_out_in: "f32[4, 3, 3, 4, 12]" = split_2[0]; split_2 = None 3200 3201 unflatten: "f32[4, 3, 3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4)); jac_out_in = None""", 3202 ) 3203 3204 self.assertExpectedInline( 3205 actual.split("\n")[-2], 3206 """ return (unflatten,)""", 3207 ) 3208 3209 def test_hessian_disable_capture(self): 3210 counters.clear() 3211 3212 with config.patch(capture_func_transforms=False): 3213 # We have verified above that this 3214 # function compiles 3215 def wrapper_fn(x): 3216 return torch.func.hessian(torch.sin)(x) 3217 3218 x = torch.randn(3, 3, 3) 3219 actual = wrapper_fn(x) 3220 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)( 3221 x 3222 ) 3223 self.assertEqual(len(counters["graph_break"]), 2) 3224 self.assertEqual( 3225 { 3226 "torch.func.vmap capture is disabled, it can be " 3227 "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 2, 3228 "torch.func.hessian capture is disabled, it can be " 3229 "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1, 3230 }, 3231 dict(counters["graph_break"]), 3232 ) 3233 self.assertEqual(actual, expected) 3234 3235 def test_jacrev(self): 3236 counters.clear() 3237 3238 def wrapper_fn(x): 3239 return torch.func.jacrev(torch.sin)(x) 3240 3241 x = torch.randn(4, 3) 3242 wrapped_gm = self._compile_check(wrapper_fn, (x,)) 3243 # Dynamic shapes produce a slightly different graph. 3244 if check_dynamic_shape_capture(): 3245 return 3246 3247 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 3248 self.assertExpectedInline( 3249 actual, 3250 """\ 3251class GraphModule(torch.nn.Module): 3252 def forward(self, L_x_: "f32[4, 3]"): 3253 l_x_ = L_x_ 3254 3255 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 3256 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 3257 3258 diff_primals: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 3259 3260 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 3261 3262 _set_tensor_requires_grad: "f32[4, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None 3263 3264 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 3265 3266 o: "f32[4, 3]" = torch.sin(diff_primals) 3267 3268 results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 1) 3269 3270 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 3271 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 3272 3273 tensor: "i64[1]" = torch.tensor((12,)) 3274 cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None 3275 getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None 3276 neg: "i64[0]" = getitem.neg(); getitem = None 3277 unbind = neg.unbind(); neg = unbind = None 3278 3279 chunk: "f32[12, 12]" = results.new_zeros(12, 12); results = None 3280 3281 diagonal: "f32[12]" = chunk.diagonal(0) 3282 fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None 3283 3284 basis: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None 3285 3286 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 3287 3288 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None 3289 3290 _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None 3291 3292 _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim); _vjp_treespec_compare = None 3293 3294 _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None 3295 batched_outputs: "f32[4, 3]" = _autograd_grad[0]; _autograd_grad = None 3296 3297 chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None 3298 3299 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 3300 3301 split = chunked_result.split((12,), dim = 0); chunked_result = None 3302 split_1: "f32[12, 4, 3]" = split[0]; split = None 3303 3304 output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3)); split_1 = None 3305 return (output_input,) 3306""", 3307 ) 3308 3309 def test_jacrev_two_tensors_argnums(self): 3310 counters.clear() 3311 3312 def fn(x, y): 3313 return y.sin() 3314 3315 def wrapper_fn(x, y): 3316 return torch.func.jacrev(fn, argnums=1)(x, y) 3317 3318 x = torch.randn(4, 3) 3319 y = torch.randn(3, 4) 3320 wrapped_gm = self._compile_check(wrapper_fn, (x, y)) 3321 # Dynamic shapes produce a slightly different graph. 
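# As in the tests above, the expected-graph string is recorded for the static-shape configuration only.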
3322 if check_dynamic_shape_capture(): 3323 return 3324 3325 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 3326 self.assertExpectedInline( 3327 actual, 3328 """\ 3329class GraphModule(torch.nn.Module): 3330 def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): 3331 l_x_ = L_x_ 3332 l_y_ = L_y_ 3333 3334 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None 3335 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 3336 3337 _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = _wrap_for_grad = None 3338 diff_primals: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None 3339 3340 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 3341 3342 _set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None 3343 3344 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 3345 3346 o: "f32[3, 4]" = diff_primals.sin() 3347 3348 results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(o, 1) 3349 3350 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 3351 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 3352 3353 tensor: "i64[1]" = torch.tensor((12,)) 3354 cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None 3355 getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None 3356 neg: "i64[0]" = getitem.neg(); getitem = None 3357 unbind = neg.unbind(); neg = unbind = None 3358 3359 chunk: "f32[12, 12]" = results.new_zeros(12, 12); results = None 3360 3361 diagonal: "f32[12]" = chunk.diagonal(0) 3362 fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None 3363 3364 basis: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None 3365 3366 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 3367 3368 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None 3369 3370 _add_batch_dim: "f32[3, 4]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None 3371 3372 _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim); _vjp_treespec_compare = None 3373 3374 _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None 3375 batched_outputs: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None 3376 3377 chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None 3378 3379 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 3380 3381 split = chunked_result.split((12,), dim = 0); chunked_result = None 3382 split_1: "f32[12, 3, 4]" = split[0]; split = None 3383 3384 output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4)); split_1 = None 3385 return (output_input,) 3386""", 3387 ) 3388 3389 def test_jacrev_has_aux(self): 
3390 counters.clear() 3391 3392 def fn(x, y): 3393 return y.sin(), x 3394 3395 def wrapper_fn(x, y): 3396 return torch.func.jacrev(fn, argnums=1, has_aux=True)(x, y) 3397 3398 x = torch.randn(4, 3) 3399 y = torch.randn(3, 4) 3400 wrapped_gm = self._compile_check(wrapper_fn, (x, y)) 3401 # Dynamic shapes produce a slightly different graph. 3402 if check_dynamic_shape_capture(): 3403 return 3404 3405 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 3406 self.assertExpectedInline( 3407 actual, 3408 """\ 3409class GraphModule(torch.nn.Module): 3410 def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): 3411 l_x_ = L_x_ 3412 l_y_ = L_y_ 3413 3414 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None 3415 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 3416 3417 aux: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 3418 diff_primals: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None 3419 3420 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 3421 3422 _set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None 3423 3424 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 3425 3426 o: "f32[3, 4]" = diff_primals.sin() 3427 3428 aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None 3429 3430 results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(o, 1) 3431 3432 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 3433 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 3434 3435 tensor: "i64[1]" = torch.tensor((12,)) 3436 cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None 3437 getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None 3438 neg: "i64[0]" = getitem.neg(); getitem = None 3439 unbind = neg.unbind(); neg = unbind = None 3440 3441 chunk: "f32[12, 12]" = results.new_zeros(12, 12); results = None 3442 3443 diagonal: "f32[12]" = chunk.diagonal(0) 3444 fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None 3445 3446 basis: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None 3447 3448 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 3449 3450 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None 3451 3452 _add_batch_dim: "f32[3, 4]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None 3453 3454 _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim); _vjp_treespec_compare = None 3455 3456 _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None 3457 batched_outputs: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None 3458 3459 chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None 3460 3461 
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 3462 3463 split = chunked_result.split((12,), dim = 0); chunked_result = None 3464 split_1: "f32[12, 3, 4]" = split[0]; split = None 3465 3466 output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4)); split_1 = None 3467 return (output_input, aux_1) 3468""", 3469 ) 3470 3471 def test_jacrev_disable_capture(self): 3472 counters.clear() 3473 3474 with config.patch(capture_func_transforms=False): 3475 # We have verified above that this 3476 # function compiles 3477 def wrapper_fn(x): 3478 return torch.func.jacrev(torch.sin)(x) 3479 3480 x = torch.randn(3, 3, 3) 3481 actual = wrapper_fn(x) 3482 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)( 3483 x 3484 ) 3485 self.assertEqual(len(counters["graph_break"]), 2) 3486 self.assertEqual( 3487 dict(counters["graph_break"]), 3488 { 3489 "torch.func.vmap capture is disabled, it can be " 3490 "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 2, 3491 "torch.func.jacrev capture is disabled, it can be " 3492 "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1, 3493 }, 3494 ) 3495 self.assertEqual(actual, expected) 3496 3497 def test_vjp(self): 3498 counters.clear() 3499 3500 def fn(x): 3501 return x.sin().sum() 3502 3503 def wrapper_fn(x, v): 3504 (out, vjpfunc) = torch.func.vjp(fn, x) 3505 return out 3506 3507 x = torch.randn([5]) 3508 v = torch.randn(5) 3509 wrapped_gm = self._compile_check(wrapper_fn, (x, v)) 3510 3511 # Dynamic shapes produce a slightly different graph. 3512 if check_dynamic_shape_capture(): 3513 return 3514 3515 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 3516 self.assertExpectedInline( 3517 actual, 3518 """\ 3519class GraphModule(torch.nn.Module): 3520 def forward(self, L_x_: "f32[5]"): 3521 l_x_ = L_x_ 3522 3523 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 3524 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 3525 3526 child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 3527 3528 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 3529 3530 child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None 3531 3532 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 3533 3534 sin: "f32[5]" = child.sin(); child = None 3535 o: "f32[]" = sin.sum(); sin = None 3536 3537 results: "f32[]" = torch._C._functorch._unwrap_for_grad(o, 1); o = None 3538 3539 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 3540 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 3541 return (results,) 3542""", 3543 ) 3544 3545 def test_vjp_multiple_outputs(self): 3546 counters.clear() 3547 3548 def wrapper_fn(x, v): 3549 fn = lambda x: (x.sin(), x.cos()) # noqa: E731 3550 (out, vjpfunc) = torch.func.vjp(fn, x) 3551 vjps = vjpfunc((v, v)) 3552 return out, vjps 3553 3554 x = torch.randn([5]) 3555 v = torch.randn(5) 3556 wrapped_gm = self._compile_check(wrapper_fn, (x, v)) 3557 3558 # Dynamic shapes produce a slightly different graph. 3559 if check_dynamic_shape_capture(): 3560 return 3561 3562 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 3563 self.assertExpectedInline( 3564 actual, 3565 """\ 3566class GraphModule(torch.nn.Module): 3567 def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"): 3568 l_x_ = L_x_ 3569 l_v_ = L_v_ 3570 3571 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 3572 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 3573 3574 child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 3575 3576 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 3577 3578 child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child) 3579 3580 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 3581 3582 child_1: "f32[5]" = child.sin() 3583 child_2: "f32[5]" = child.cos(); child = None 3584 3585 _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) 3586 _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1) 3587 3588 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 3589 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 3590 3591 _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare((child_1, child_2), (l_v_, l_v_)); _vjp_treespec_compare = None 3592 3593 _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, l_v_], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = None 3594 getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None 3595 return (_unwrap_for_grad, _unwrap_for_grad_1, getitem) 3596""", 3597 ) 3598 3599 def test_vjp_multiple_outputs_python_struct(self): 3600 counters.clear() 3601 3602 def wrapper_fn(x, v): 3603 fn = lambda x: {"first": x.sin(), "second": x.cos()} # noqa: E731 3604 (out, vjpfunc) = torch.func.vjp(fn, x) 3605 vjps = vjpfunc({"first": v, "second": v.sin()}) 3606 return out, vjps 3607 3608 x = torch.randn([5]) 3609 v = torch.randn(5) 3610 wrapped_gm = self._compile_check(wrapper_fn, (x, v)) 3611 3612 # Dynamic shapes produce a slightly different graph. 3613 if check_dynamic_shape_capture(): 3614 return 3615 3616 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 3617 self.assertExpectedInline( 3618 actual, 3619 """\ 3620class GraphModule(torch.nn.Module): 3621 def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"): 3622 l_x_ = L_x_ 3623 l_v_ = L_v_ 3624 3625 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 3626 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 3627 3628 child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 3629 3630 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 3631 3632 child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child) 3633 3634 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 3635 3636 child_1: "f32[5]" = child.sin() 3637 child_2: "f32[5]" = child.cos(); child = None 3638 3639 _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) 3640 _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1) 3641 3642 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 3643 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 3644 3645 child_4: "f32[5]" = l_v_.sin() 3646 3647 _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare({'first': child_1, 'second': child_2}, {'first': l_v_, 'second': child_4}); _vjp_treespec_compare = None 3648 3649 _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, child_4], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = child_4 = None 3650 getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None 3651 return (_unwrap_for_grad, _unwrap_for_grad_1, getitem) 3652""", 3653 ) 3654 3655 def test_vjp_has_aux(self): 3656 counters.clear() 3657 3658 def fn(x): 3659 return x.sin().sum(), x 3660 3661 def wrapper_fn(x, v): 3662 (out, vjpfunc, _) = torch.func.vjp(fn, x, has_aux=True) 3663 return out 3664 3665 x = torch.randn([5]) 3666 v = torch.randn(5) 3667 wrapped_gm = self._compile_check(wrapper_fn, (x, v)) 3668 3669 # Dynamic shapes produce a slightly different graph. 3670 if check_dynamic_shape_capture(): 3671 return 3672 3673 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 3674 self.assertExpectedInline( 3675 actual, 3676 """\ 3677class GraphModule(torch.nn.Module): 3678 def forward(self, L_x_: "f32[5]"): 3679 l_x_ = L_x_ 3680 3681 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 3682 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 3683 3684 child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 3685 3686 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 3687 3688 child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None 3689 3690 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 3691 3692 sin: "f32[5]" = child.sin() 3693 o: "f32[]" = sin.sum(); sin = None 3694 3695 aux: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1); child = aux = None 3696 3697 results: "f32[]" = torch._C._functorch._unwrap_for_grad(o, 1); o = None 3698 3699 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 3700 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 3701 return (results,) 3702""", 3703 ) 3704 3705 def test_vjp_disable_capture(self): 3706 counters.clear() 3707 3708 with config.patch(capture_func_transforms=False): 3709 # We have verified above that this 3710 # function compiles 3711 def wrapper_fn(x): 3712 (out, vjpfunc) = torch.func.vjp(torch.sin, x) 3713 return out 3714 3715 x = torch.randn(3, 3, 3) 3716 actual = wrapper_fn(x) 3717 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)( 3718 x 3719 ) 3720 self.assertEqual(len(counters["graph_break"]), 1) 3721 self.assertEqual( 3722 dict(counters["graph_break"]), 3723 { 3724 "torch.func.vjp capture is disabled, it can be " 3725 "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1 3726 }, 3727 ) 3728 self.assertEqual(actual, expected) 3729 3730 @config.patch(inline_inbuilt_nn_modules=True) 3731 def test_functional_call(self): 3732 def wrapper_fn(model, params, inputs, targets): 3733 prediction = torch.func.functional_call(model, params, (inputs,)) 3734 return torch.nn.functional.mse_loss(prediction, targets) 3735 3736 model = torch.nn.Linear(3, 3) 3737 params = dict(model.named_parameters()) 3738 inputs = torch.randn(64, 3) 3739 targets = torch.randn(64, 3) 3740 3741 wrapped_gm = self._compile_check(wrapper_fn, (model, params, inputs, targets)) 3742 # Dynamic shapes produce a slightly different graph. 
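        # Context for the expected graphs below: with inline_inbuilt_nn_modules=True,
        # Dynamo lifts the module's parameters into graph inputs and traces the
        # functional call down to torch._C._nn.linear; without inlining, the module
        # stays opaque and the graph simply calls self.model(l_inputs_).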
3743 if check_dynamic_shape_capture(): 3744 return 3745 3746 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 3747 if torch._dynamo.config.inline_inbuilt_nn_modules: 3748 self.assertExpectedInline( 3749 actual, 3750 """\ 3751class GraphModule(torch.nn.Module): 3752 def forward(self, L_model_parameters_weight_: "f32[3, 3]", L_model_parameters_bias_: "f32[3]", L_inputs_: "f32[64, 3]", L_targets_: "f32[64, 3]"): 3753 l_model_parameters_weight_ = L_model_parameters_weight_ 3754 l_model_parameters_bias_ = L_model_parameters_bias_ 3755 l_inputs_ = L_inputs_ 3756 l_targets_ = L_targets_ 3757 3758 prediction: "f32[64, 3]" = torch._C._nn.linear(l_inputs_, l_model_parameters_weight_, l_model_parameters_bias_); l_inputs_ = l_model_parameters_weight_ = l_model_parameters_bias_ = None 3759 3760 mse_loss: "f32[]" = torch.nn.functional.mse_loss(prediction, l_targets_); prediction = l_targets_ = None 3761 return (mse_loss,) 3762""", 3763 ) 3764 else: 3765 self.assertExpectedInline( 3766 actual, 3767 """\ 3768class GraphModule(torch.nn.Module): 3769 def forward(self, L_inputs_: "f32[64, 3]", L_targets_: "f32[64, 3]"): 3770 l_inputs_ = L_inputs_ 3771 l_targets_ = L_targets_ 3772 3773 prediction: "f32[64, 3]" = self.model(l_inputs_); l_inputs_ = None 3774 3775 mse_loss: "f32[]" = torch.nn.functional.mse_loss(prediction, l_targets_); prediction = l_targets_ = None 3776 return (mse_loss,) 3777""", 3778 ) 3779 3780 @config.patch(inline_inbuilt_nn_modules=True) 3781 def test_functional_call_sequential_params_and_buffers(self): 3782 # copied from test/test_stateless.py 3783 class MockModule(torch.nn.Module): 3784 def __init__(self) -> None: 3785 super().__init__() 3786 self.l1 = torch.nn.Linear(1, 1) 3787 self.register_buffer("buffer", torch.ones(1)) 3788 self.foo = 0.0 3789 3790 def forward(self, x): 3791 return self.l1(x) + self.buffer 3792 3793 def wrapper_fn(model, params, buffers, inputs): 3794 # two separate dictionaries 3795 return torch.func.functional_call(model, (params, buffers), inputs) 3796 3797 model = MockModule() 3798 params = dict(model.named_parameters()) 3799 buffers = dict(model.named_buffers()) 3800 inputs = torch.tensor([[1.5]]) 3801 3802 wrapped_gm = self._compile_check( 3803 wrapper_fn, (model, params, buffers, inputs), fullgraph=False 3804 ) 3805 # Dynamic shapes produce a slightly different graph. 3806 if check_dynamic_shape_capture(): 3807 return 3808 3809 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 3810 if torch._dynamo.config.inline_inbuilt_nn_modules: 3811 expected = """\ 3812class GraphModule(torch.nn.Module): 3813 def forward(self, L_params_l1_weight_: "f32[1, 1]", L_params_l1_bias_: "f32[1]", L_buffers_buffer_: "f32[1]", L_inputs_: "f32[1, 1]"): 3814 l_params_l1_weight_ = L_params_l1_weight_ 3815 l_params_l1_bias_ = L_params_l1_bias_ 3816 l_buffers_buffer_ = L_buffers_buffer_ 3817 l_inputs_ = L_inputs_ 3818 3819 linear: "f32[1, 1]" = torch._C._nn.linear(l_inputs_, l_params_l1_weight_, l_params_l1_bias_); l_inputs_ = l_params_l1_weight_ = l_params_l1_bias_ = None 3820 3821 add: "f32[1, 1]" = linear + l_buffers_buffer_; linear = l_buffers_buffer_ = None 3822 return (add,) 3823""" 3824 # We found Windows/Linux have some empty line difference, empty_line_normalizer will help fix it. 
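            # A minimal sketch (illustration only, not the real implementation) of
            # the kind of normalization assumed here; the actual empty_line_normalizer
            # helper may differ in detail:
            #
            #     def empty_line_normalizer(text: str) -> str:
            #         # Collapse whitespace-only lines so Windows/Linux blank-line
            #         # differences don't break the literal comparison.
            #         return "\n".join(
            #             "" if line.strip() == "" else line
            #             for line in text.splitlines()
            #         )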
3825 self.assertExpectedInline( 3826 empty_line_normalizer(actual), 3827 empty_line_normalizer(normalize_gm(expected)), 3828 ) 3829 else: 3830 self.assertExpectedInline( 3831 actual, 3832 """\ 3833class GraphModule(torch.nn.Module): 3834 def forward(self, L_x_: "f32[1, 1]"): 3835 l_x_ = L_x_ 3836 3837 l__self___l1: "f32[1, 1]" = self.L__self___l1(l_x_); l_x_ = None 3838 l__self___buffer: "f32[1]" = self.L__self___buffer 3839 add: "f32[1, 1]" = l__self___l1 + l__self___buffer; l__self___l1 = l__self___buffer = None 3840 return (add,) 3841""", 3842 ) 3843 3844 @config.patch(inline_inbuilt_nn_modules=True) 3845 def test_functional_call_disable_capture(self): 3846 counters.clear() 3847 3848 with config.patch(capture_func_transforms=False): 3849 # We have verified above that this 3850 # function compiles 3851 def wrapper_fn(model, params, inputs, targets): 3852 prediction = torch.func.functional_call(model, params, (inputs,)) 3853 return torch.nn.functional.mse_loss(prediction, targets) 3854 3855 model = torch.nn.Linear(3, 3) 3856 params = dict(model.named_parameters()) 3857 inputs = torch.randn(64, 3) 3858 targets = torch.randn(64, 3) 3859 3860 actual = wrapper_fn(model, params, inputs, targets) 3861 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)( 3862 model, params, inputs, targets 3863 ) 3864 self.assertEqual(len(counters["graph_break"]), 1) 3865 self.assertEqual( 3866 { 3867 "torch.func.functional_call capture is disabled, it can be " 3868 "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1, 3869 }, 3870 dict(counters["graph_break"]), 3871 ) 3872 self.assertEqual(actual, expected) 3873 3874 @config.patch(inline_inbuilt_nn_modules=False) 3875 def test_functional_call_disable_inline_nn_module(self): 3876 counters.clear() 3877 3878 def wrapper_fn(model, params, inputs, targets): 3879 prediction = torch.func.functional_call(model, params, (inputs,)) 3880 return torch.nn.functional.mse_loss(prediction, targets) 3881 3882 model = torch.nn.Linear(3, 3) 3883 params = dict(model.named_parameters()) 3884 inputs = torch.randn(64, 3) 3885 targets = torch.randn(64, 3) 3886 3887 actual = wrapper_fn(model, params, inputs, targets) 3888 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)( 3889 model, params, inputs, targets 3890 ) 3891 self.assertEqual(len(counters["graph_break"]), 1) 3892 self.assertEqual( 3893 { 3894 "torch.func.functional_call capture is disabled, it can be " 3895 "turned on by setting `torch._dynamo.config.inline_inbuilt_nn_modules=True`": 1, 3896 }, 3897 dict(counters["graph_break"]), 3898 ) 3899 self.assertEqual(actual, expected) 3900 3901 def test_grad(self): 3902 counters.clear() 3903 3904 def fn(x): 3905 return x.sin().sum() 3906 3907 def wrapper_fn(x): 3908 return torch.func.grad(fn)(x) 3909 3910 x = torch.randn(3, 3, 3) 3911 wrapped_gm = self._compile_check(wrapper_fn, (x,)) 3912 3913 # Dynamic shapes produce a slightly different graph. 3914 if check_dynamic_shape_capture(): 3915 return 3916 3917 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 3918 self.assertExpectedInline( 3919 actual, 3920 """\ 3921class GraphModule(torch.nn.Module): 3922 def forward(self, L_x_: "f32[3, 3, 3]"): 3923 l_x_ = L_x_ 3924 3925 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 3926 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 3927 3928 diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 3929 3930 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 3931 3932 _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None 3933 3934 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 3935 3936 sin: "f32[3, 3, 3]" = diff_args.sin() 3937 output: "f32[]" = sin.sum(); sin = None 3938 3939 _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None 3940 grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None 3941 3942 grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None 3943 3944 output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None 3945 3946 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 3947 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 3948 return (grad_input_1,) 3949""", 3950 ) 3951 3952 def test_grad_freevar_tensor(self): 3953 counters.clear() 3954 y = torch.randn(3, 3) 3955 3956 def fn(x): 3957 return (x.sin() + y).sum() 3958 3959 def wrapper_fn(x): 3960 return torch.func.grad(fn)(x) 3961 3962 x = torch.randn(3, 3, 3) 3963 expected = wrapper_fn(x) 3964 actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x) 3965 self.assertEqual(actual, expected) 3966 3967 def test_grad_freevar_python_scalar(self): 3968 counters.clear() 3969 y = 3 3970 3971 def fn(x): 3972 return (x.sin() + y).sum() 3973 3974 def wrapper_fn(x): 3975 return torch.func.grad(fn)(x) 3976 3977 x = torch.randn(3, 3, 3) 3978 wrapped_gm = self._compile_check(wrapper_fn, (x,)) 3979 3980 # Dynamic shapes produce a slightly different graph. 3981 if check_dynamic_shape_capture(): 3982 return 3983 3984 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 3985 self.assertExpectedInline( 3986 actual, 3987 """\ 3988class GraphModule(torch.nn.Module): 3989 def forward(self, L_x_: "f32[3, 3, 3]"): 3990 l_x_ = L_x_ 3991 3992 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 3993 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 3994 3995 diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 3996 3997 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 3998 3999 _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None 4000 4001 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 4002 4003 sin: "f32[3, 3, 3]" = diff_args.sin() 4004 add: "f32[3, 3, 3]" = sin + 3; sin = None 4005 output: "f32[]" = add.sum(); add = None 4006 4007 _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None 4008 grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None 4009 4010 grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None 4011 4012 output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None 4013 4014 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 4015 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 4016 return (grad_input_1,) 4017""", 4018 ) 4019 4020 def test_grad_capture_tensor(self): 4021 counters.clear() 4022 4023 def wrapper_fn(x): 4024 y = torch.randn(3) 4025 4026 def fn(x): 4027 return (x.sin() + y).sum() 4028 4029 return torch.func.grad(fn)(x) 4030 4031 x = torch.randn(3, 3, 3) 4032 4033 wrapped_gm = self._compile_check(wrapper_fn, (x,)) 4034 4035 # Dynamic shapes produce a slightly different graph. 4036 if check_dynamic_shape_capture(): 4037 return 4038 4039 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 4040 self.assertExpectedInline( 4041 actual, 4042 """\ 4043class GraphModule(torch.nn.Module): 4044 def forward(self, L_x_: "f32[3, 3, 3]"): 4045 l_x_ = L_x_ 4046 4047 y: "f32[3]" = torch.randn(3) 4048 4049 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 4050 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 4051 4052 diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 4053 4054 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 4055 4056 _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None 4057 4058 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 4059 4060 sin: "f32[3, 3, 3]" = diff_args.sin() 4061 add: "f32[3, 3, 3]" = sin + y; sin = None 4062 output: "f32[]" = add.sum(); add = None 4063 4064 _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None 4065 grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None 4066 4067 grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None 4068 4069 output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None 4070 4071 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 4072 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 4073 return (y, grad_input_1) 4074""", 4075 ) 4076 4077 def test_grad_closure_scalar(self): 4078 counters.clear() 4079 4080 def wrapper_fn(x): 4081 y = 3.14 4082 4083 def fn(x): 4084 return (x.sin() + y).sum() 4085 4086 return torch.func.grad(fn)(x) 4087 4088 x = torch.randn(3, 3, 3) 4089 4090 # Graph break because dynamo is unable to get source `fn` and 4091 # functools.wraps in `grad` leads to graph-break 4092 wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False) 4093 4094 # Dynamic shapes produce a slightly different graph. 4095 if check_dynamic_shape_capture(): 4096 return 4097 4098 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 4099 self.assertExpectedInline( 4100 actual, 4101 """\ 4102class GraphModule(torch.nn.Module): 4103 def forward(self, L_x_: "f32[3, 3, 3]"): 4104 l_x_ = L_x_ 4105 4106 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 4107 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 4108 4109 diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 4110 4111 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 4112 4113 _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None 4114 4115 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 4116 4117 sin: "f32[3, 3, 3]" = diff_args.sin() 4118 add: "f32[3, 3, 3]" = sin + 3.14; sin = None 4119 output: "f32[]" = add.sum(); add = None 4120 4121 _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None 4122 grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None 4123 4124 grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None 4125 4126 output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None 4127 4128 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 4129 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 4130 return (grad_input_1,) 4131""", 4132 ) 4133 4134 def test_grad_has_aux(self): 4135 counters.clear() 4136 4137 y = 3.14 4138 4139 def fn(x): 4140 return ((x.sin() + y).sum(), x.cos()) 4141 4142 def wrapper_fn(x): 4143 return torch.func.grad(fn, has_aux=True)(x) 4144 4145 x = torch.randn(3, 3, 3) 4146 wrapped_gm = self._compile_check(wrapper_fn, (x,)) 4147 4148 # Dynamic shapes produce a slightly different graph. 4149 if check_dynamic_shape_capture(): 4150 return 4151 4152 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 4153 self.assertExpectedInline( 4154 actual, 4155 """\ 4156class GraphModule(torch.nn.Module): 4157 def forward(self, L_x_: "f32[3, 3, 3]"): 4158 l_x_ = L_x_ 4159 4160 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 4161 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 4162 4163 diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 4164 4165 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 4166 4167 _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None 4168 4169 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 4170 4171 sin: "f32[3, 3, 3]" = diff_args.sin() 4172 add: "f32[3, 3, 3]" = sin + 3.14; sin = None 4173 output: "f32[]" = add.sum(); add = None 4174 aux: "f32[3, 3, 3]" = diff_args.cos() 4175 4176 _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None 4177 grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None 4178 4179 grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None 4180 4181 output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None 4182 4183 aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None 4184 4185 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 4186 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 4187 return (grad_input_1, aux_1) 4188""", 4189 ) 4190 4191 def test_grad_two_tensor_has_aux(self): 4192 counters.clear() 4193 4194 def fn(x, y): 4195 return ((x.sin() + y).sum(), x.cos()) 4196 4197 def wrapper_fn(x, y): 4198 return torch.func.grad(fn, has_aux=True)(x, y) 4199 4200 y = torch.randn(3, 3, 3) 4201 x = torch.randn(3, 3, 3) 4202 wrapped_gm = self._compile_check(wrapper_fn, (x, y)) 4203 4204 # Dynamic shapes produce a slightly different graph. 4205 if check_dynamic_shape_capture(): 4206 return 4207 4208 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 4209 self.assertExpectedInline( 4210 actual, 4211 """\ 4212class GraphModule(torch.nn.Module): 4213 def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"): 4214 l_x_ = L_x_ 4215 l_y_ = L_y_ 4216 4217 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 4218 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 4219 4220 diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 4221 _wrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None 4222 4223 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 4224 4225 _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None 4226 4227 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 4228 4229 sin: "f32[3, 3, 3]" = diff_args.sin() 4230 add: "f32[3, 3, 3]" = sin + _wrap_for_grad_1; sin = _wrap_for_grad_1 = None 4231 output: "f32[]" = add.sum(); add = None 4232 aux: "f32[3, 3, 3]" = diff_args.cos() 4233 4234 _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None 4235 grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None 4236 4237 grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None 4238 4239 output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None 4240 4241 aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None 4242 4243 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 4244 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 4245 return (grad_input_1, aux_1) 4246""", 4247 ) 4248 4249 def test_grad_two_tensor_all_grad_has_aux(self): 4250 counters.clear() 4251 4252 nums = (0, 1) 4253 4254 def fn(x, y): 4255 return ((x.sin() + y).sum(), x.cos()) 4256 4257 def wrapper_fn_const_var(x, y): 4258 return torch.func.grad(fn, argnums=(0, 1), has_aux=True)(x, y) 4259 4260 def wrapper_fn_tuple_var(x, y): 4261 return torch.func.grad(fn, argnums=nums, has_aux=True)(x, y) 4262 4263 y = torch.randn(3, 3, 3) 4264 x = torch.randn(3, 3, 3) 4265 wrapped_gm_const_var = self._compile_check(wrapper_fn_const_var, (x, y)) 4266 wrapped_gm_tuple_var = self._compile_check(wrapper_fn_tuple_var, (x, y)) 4267 4268 # Dynamic shapes produce a slightly different graph. 4269 if check_dynamic_shape_capture(): 4270 return 4271 4272 actual_const_var = normalize_gm( 4273 wrapped_gm_const_var.print_readable(print_output=False) 4274 ) 4275 actual_tuple_var = normalize_gm( 4276 wrapped_gm_tuple_var.print_readable(print_output=False) 4277 ) 4278 self.assertExpectedInline( 4279 actual_const_var, 4280 """\ 4281class GraphModule(torch.nn.Module): 4282 def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"): 4283 l_x_ = L_x_ 4284 l_y_ = L_y_ 4285 4286 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 4287 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 4288 4289 child: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 4290 child_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None 4291 4292 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 4293 4294 _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); _set_tensor_requires_grad = None 4295 4296 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 4297 set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None 4298 4299 _set_tensor_requires_grad_1: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1); _set_tensor_requires_grad_1 = None 4300 4301 set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None 4302 4303 sin: "f32[3, 3, 3]" = child.sin() 4304 add: "f32[3, 3, 3]" = sin + child_1; sin = None 4305 output: "f32[]" = add.sum(); add = None 4306 aux: "f32[3, 3, 3]" = child.cos() 4307 4308 _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True); child = child_1 = None 4309 child_2: "f32[3, 3, 3]" = _autograd_grad[0] 4310 child_3: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None 4311 4312 _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None 4313 _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None 4314 4315 output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None 4316 4317 aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None 4318 4319 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 4320 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 4321 return (_unwrap_for_grad, _unwrap_for_grad_1, aux_1) 4322""", 4323 ) 4324 self.assertExpectedInline( 4325 actual_tuple_var, 4326 """\ 4327class GraphModule(torch.nn.Module): 4328 def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"): 4329 l_x_ = L_x_ 4330 l_y_ = L_y_ 4331 4332 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 4333 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 4334 4335 child: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 4336 child_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None 4337 4338 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 4339 4340 _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); _set_tensor_requires_grad = None 4341 4342 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 4343 set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None 4344 4345 _set_tensor_requires_grad_1: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1); _set_tensor_requires_grad_1 = None 4346 4347 set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None 4348 4349 sin: "f32[3, 3, 3]" = child.sin() 4350 add: "f32[3, 3, 3]" = sin + child_1; sin = None 4351 output: "f32[]" = add.sum(); add = None 4352 aux: "f32[3, 3, 3]" = child.cos() 4353 4354 _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True); child = child_1 = None 4355 child_2: "f32[3, 3, 3]" = _autograd_grad[0] 4356 child_3: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None 4357 4358 _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None 4359 _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None 4360 4361 output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None 4362 4363 aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None 4364 4365 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 4366 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 4367 return (_unwrap_for_grad, _unwrap_for_grad_1, aux_1) 4368""", 4369 ) 4370 4371 def test_grad_over_grad(self): 4372 counters.clear() 4373 4374 def fn(x): 4375 return x.sin().sum() 4376 4377 def wrapper_fn(x): 4378 return torch.func.grad(torch.func.grad(fn))(x) 4379 4380 x = torch.randn(()) 4381 wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False) 4382 4383 if check_dynamic_shape_capture(): 4384 return 4385 4386 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 4387 self.assertExpectedInline( 4388 actual, 4389 """\ 4390class GraphModule(torch.nn.Module): 4391 def forward(self, L_x_: "f32[]"): 4392 l_x_ = L_x_ 4393 4394 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 4395 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 4396 4397 diff_args: "f32[]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 4398 4399 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 4400 4401 _set_tensor_requires_grad: "f32[]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None 4402 4403 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 4404 _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable_1 = None 4405 _grad_increment_nesting_1 = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting_1 = None 4406 4407 diff_args_1: "f32[]" = torch._C._functorch._wrap_for_grad(diff_args, 2) 4408 4409 set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None 4410 4411 _set_tensor_requires_grad_1: "f32[]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args_1); _set_tensor_requires_grad_1 = None 4412 4413 set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None 4414 4415 sin: "f32[]" = diff_args_1.sin() 4416 output: "f32[]" = sin.sum(); sin = None 4417 4418 _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args_1], create_graph = True); diff_args_1 = None 4419 grad_input: "f32[]" = _autograd_grad[0]; _autograd_grad = None 4420 4421 grad_input_1: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input, 2); grad_input = None 4422 4423 output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 2); output = output_1 = None 4424 4425 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 4426 _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable_2 = None 4427 4428 _autograd_grad_1 = torch._functorch.eager_transforms._autograd_grad((grad_input_1,), [diff_args], create_graph = True); diff_args = None 4429 grad_input_2: "f32[]" = _autograd_grad_1[0]; _autograd_grad_1 = None 4430 4431 grad_input_3: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_2, 1); grad_input_2 = None 4432 4433 output_2: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_1, 1); grad_input_1 = output_2 = None 4434 4435 _grad_decrement_nesting_1 = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting_1 = None 4436 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 4437 return (grad_input_3,) 4438""", 4439 ) 4440 4441 def test_grad_with_graph_break(self): 4442 counters.clear() 4443 4444 def fn(x): 4445 torch._dynamo.graph_break() 4446 return x.sin().sum() 4447 4448 def wrapper_fn(x): 4449 return torch.func.grad(fn)(x) 4450 4451 x = torch.randn(3, 3, 3) 4452 actual = wrapper_fn(x) 4453 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) 4454 self.assertEqual(len(counters["graph_break"]), 1) 4455 self.assertEqual(actual, expected) 4456 4457 def test_grad_with_side_effect(self): 4458 counters.clear() 4459 4460 foo = [1, 2] 4461 4462 def fn(x): 4463 foo.append(3) 4464 return x.sin().sum() 4465 4466 def wrapper_fn(x): 4467 return torch.func.grad(fn)(x) 4468 4469 x = torch.randn(3, 3, 3) 4470 actual = wrapper_fn(x) 4471 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) 4472 self.assertEqual(len(counters["graph_break"]), 0) 4473 self.assertEqual(actual, expected) 4474 4475 def test_grad_pytree(self): 4476 counters.clear() 4477 4478 def fn(x): 4479 x1, x2 = x 4480 return x1.sin().sum() + x2 4481 4482 def wrapper_fn(x): 4483 return torch.func.grad(fn)(x) 4484 4485 x1 = torch.randn(3, 3, 3) 4486 x2 = torch.randn(()) 4487 actual = wrapper_fn((x1, x2)) 4488 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)( 4489 (x1, x2) 4490 ) 4491 self.assertEqual(len(counters["graph_break"]), 0) 4492 self.assertEqual(actual, expected) 4493 4494 def test_grad_non_tensor_input(self): 4495 counters.clear() 4496 4497 def fn(x, y): 4498 return x.sin().sum() + y 4499 4500 def wrapper_fn(x, y): 4501 return torch.func.grad(fn)(x, y) 4502 4503 x = torch.randn(3, 3, 3) 4504 y = 3.0 4505 wrapped_gm = self._compile_check(wrapper_fn, (x, y)) 4506 4507 # Dynamic shapes produce a slightly different graph. 4508 if check_dynamic_shape_capture(): 4509 return 4510 4511 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 4512 self.assertExpectedInline( 4513 actual, 4514 """\ 4515class GraphModule(torch.nn.Module): 4516 def forward(self, L_x_: "f32[3, 3, 3]"): 4517 l_x_ = L_x_ 4518 4519 _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None 4520 _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None 4521 4522 diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None 4523 4524 set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None 4525 4526 _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None 4527 4528 set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None 4529 4530 sin: "f32[3, 3, 3]" = diff_args.sin() 4531 sum_1: "f32[]" = sin.sum(); sin = None 4532 output: "f32[]" = sum_1 + 3.0; sum_1 = None 4533 4534 _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None 4535 grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None 4536 4537 grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None 4538 4539 output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None 4540 4541 _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None 4542 _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None 4543 return (grad_input_1,) 4544""", 4545 ) 4546 4547 def test_grad_disable_capture(self): 4548 counters.clear() 4549 4550 with config.patch(capture_func_transforms=False): 4551 # We have verified above that this 4552 # function compiles 4553 def fn(x): 4554 return x.sin().sum() 4555 4556 def wrapper_fn(x): 4557 return torch.func.grad(fn)(x) 4558 4559 x = torch.randn(3, 3) 4560 actual = wrapper_fn(x) 4561 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)( 4562 x 4563 ) 4564 self.assertEqual(len(counters["graph_break"]), 1) 4565 self.assertEqual( 4566 dict(counters["graph_break"]), 4567 { 4568 "torch.func.grad capture is disabled, it can be turned " 4569 "on by setting `torch._dynamo.config.capture_func_transforms=True`": 2 4570 }, 4571 ) 4572 self.assertEqual(actual, expected) 4573 4574 def test_grad_fn_with_kwargs(self): 4575 def fn(x, y): 4576 return (x + y).sum() 4577 4578 def wrapper_fn(x, y): 4579 return torch.func.grad(fn)(x, y=y) 4580 4581 x = torch.randn(3, 3) 4582 y = torch.randn(3, 3) 4583 actual = wrapper_fn(x, y) 4584 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y) 4585 self.assertEqual(len(counters["graph_break"]), 0) 4586 self.assertEqual(actual, expected) 4587 4588 def test_jacfwd(self): 4589 counters.clear() 4590 4591 def wrapper_fn(x): 4592 return torch.func.jacfwd(torch.sin)(x) 4593 4594 x = torch.randn(4, 3) 4595 wrapped_gm = self._compile_check(wrapper_fn, (x,)) 4596 # Dynamic shapes produce a slightly different graph. 
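        # Reading the expected graph below: jacfwd is computed as a vmap of jvp over
        # one-hot basis tangents, so the capture builds a 12x12 identity via
        # new_zeros(12, 12) + diagonal().fill_(1), enters a vmap level of size 12
        # (the numel of the 4x3 input), and pushes a dual-number (jvp) level before
        # evaluating torch.sin.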
4597 if check_dynamic_shape_capture(): 4598 return 4599 4600 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 4601 self.assertExpectedInline( 4602 actual, 4603 """\ 4604class GraphModule(torch.nn.Module): 4605 def forward(self, L_x_: "f32[4, 3]"): 4606 l_x_ = L_x_ 4607 4608 tensor: "i64[1]" = torch.tensor((12,)) 4609 cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None 4610 getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None 4611 neg: "i64[0]" = getitem.neg(); getitem = None 4612 unbind = neg.unbind(); neg = unbind = None 4613 4614 chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12) 4615 4616 diagonal: "f32[12]" = chunk.diagonal(0) 4617 fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None 4618 4619 child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None 4620 4621 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 4622 4623 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None 4624 4625 child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None 4626 4627 _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)); _jvp_treespec_compare = None 4628 4629 _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None 4630 _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None 4631 _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None 4632 4633 _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None 4634 4635 _make_dual: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None 4636 4637 _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None 4638 4639 result_duals: "f32[4, 3]" = torch.sin(_make_dual); _make_dual = None 4640 4641 _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None 4642 primal: "f32[4, 3]" = _unpack_dual[0] 4643 dual: "f32[4, 3]" = _unpack_dual[1]; _unpack_dual = None 4644 4645 primals_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None 4646 4647 tangents_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None 4648 4649 _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None 4650 _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None 4651 _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None 4652 4653 results: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None 4654 4655 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 4656 4657 movedim: "f32[4, 3, 12]" = results.movedim(0, -1); results = None 4658 split = movedim.split((12,), dim = -1); movedim = None 4659 jac_out_in: "f32[4, 3, 12]" = split[0]; split = None 4660 4661 unflatten: "f32[4, 3, 4, 3]" = jac_out_in.unflatten(-1, (4, 3)); jac_out_in = None 4662 return (unflatten,) 4663""", 4664 ) 4665 4666 def test_jacfwd_two_tensors_argnums(self): 4667 counters.clear() 4668 4669 def fn(x, y): 4670 return y.sin() 4671 4672 def wrapper_fn(x, y): 4673 return torch.func.jacfwd(fn, argnums=1)(x, y) 4674 4675 x = torch.randn(4, 3) 4676 y = 
torch.randn(3, 4) 4677 wrapped_gm = self._compile_check(wrapper_fn, (x, y)) 4678 # Dynamic shapes produce a slightly different graph. 4679 if check_dynamic_shape_capture(): 4680 return 4681 4682 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 4683 self.assertExpectedInline( 4684 actual, 4685 """\ 4686class GraphModule(torch.nn.Module): 4687 def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): 4688 l_x_ = L_x_ 4689 l_y_ = L_y_ 4690 4691 tensor: "i64[1]" = torch.tensor((12,)) 4692 cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None 4693 getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None 4694 neg: "i64[0]" = getitem.neg(); getitem = None 4695 unbind = neg.unbind(); neg = unbind = None 4696 4697 chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12) 4698 4699 diagonal: "f32[12]" = chunk.diagonal(0) 4700 fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None 4701 4702 child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None 4703 4704 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 4705 4706 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None 4707 4708 child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None 4709 4710 _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)); _jvp_treespec_compare = None 4711 4712 _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None 4713 _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None 4714 _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None 4715 4716 _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None 4717 4718 _make_dual: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None 4719 4720 _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None 4721 _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None 4722 4723 result_duals: "f32[3, 4]" = _make_dual.sin(); _make_dual = None 4724 4725 _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None 4726 primal: "f32[3, 4]" = _unpack_dual[0] 4727 dual: "f32[3, 4]" = _unpack_dual[1]; _unpack_dual = None 4728 4729 primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None 4730 4731 tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None 4732 4733 _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None 4734 _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None 4735 _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None 4736 4737 results: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None 4738 4739 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 4740 4741 movedim: "f32[3, 4, 12]" = results.movedim(0, -1); results = None 4742 split = movedim.split((12,), dim = -1); movedim = None 4743 jac_out_in: "f32[3, 4, 12]" = split[0]; split = None 4744 4745 unflatten: "f32[3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4)); jac_out_in = None 4746 return 
(unflatten,) 4747""", 4748 ) 4749 4750 def test_jacfwd_has_aux(self): 4751 counters.clear() 4752 4753 def fn(x, y): 4754 return y.sin(), x 4755 4756 def wrapper_fn(x, y): 4757 return torch.func.jacfwd(fn, argnums=1, has_aux=True)(x, y) 4758 4759 x = torch.randn(4, 3) 4760 y = torch.randn(3, 4) 4761 wrapped_gm = self._compile_check(wrapper_fn, (x, y)) 4762 # Dynamic shapes produce a slightly different graph. 4763 if check_dynamic_shape_capture(): 4764 return 4765 4766 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 4767 self.assertExpectedInline( 4768 actual, 4769 """\ 4770class GraphModule(torch.nn.Module): 4771 def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): 4772 l_x_ = L_x_ 4773 l_y_ = L_y_ 4774 4775 tensor: "i64[1]" = torch.tensor((12,)) 4776 cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None 4777 getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None 4778 neg: "i64[0]" = getitem.neg(); getitem = None 4779 unbind = neg.unbind(); neg = unbind = None 4780 4781 chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12) 4782 4783 diagonal: "f32[12]" = chunk.diagonal(0) 4784 fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None 4785 4786 child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None 4787 4788 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 4789 4790 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None 4791 4792 child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None 4793 4794 _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)); _jvp_treespec_compare = None 4795 4796 _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None 4797 _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None 4798 _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None 4799 4800 _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None 4801 4802 _make_dual: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None 4803 4804 aux: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None 4805 _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None 4806 4807 result_duals: "f32[3, 4]" = _make_dual.sin(); _make_dual = None 4808 4809 aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 2); aux = None 4810 4811 _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None 4812 primal: "f32[3, 4]" = _unpack_dual[0] 4813 dual: "f32[3, 4]" = _unpack_dual[1]; _unpack_dual = None 4814 4815 primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None 4816 4817 tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None 4818 4819 _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None 4820 _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None 4821 _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None 4822 4823 results: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None 4824 aux_2: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(aux_1, 1, 12, 
0); aux_1 = None 4825 4826 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 4827 4828 aux_3: "f32[4, 3]" = aux_2[0]; aux_2 = None 4829 4830 movedim: "f32[3, 4, 12]" = results.movedim(0, -1); results = None 4831 split = movedim.split((12,), dim = -1); movedim = None 4832 jac_out_in: "f32[3, 4, 12]" = split[0]; split = None 4833 4834 unflatten: "f32[3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4)); jac_out_in = None 4835 return (unflatten, aux_3) 4836""", 4837 ) 4838 4839 def test_jacfwd_randomness(self): 4840 counters.clear() 4841 4842 def fn(x, y): 4843 return y.sin(), x 4844 4845 def wrapper_fn(x, y): 4846 return torch.func.jacfwd(fn, randomness="same")(x, y) 4847 4848 x = torch.randn(4, 3) 4849 y = torch.randn(3, 4) 4850 wrapped_gm = self._compile_check(wrapper_fn, (x, y)) 4851 # Dynamic shapes produce a slightly different graph. 4852 if check_dynamic_shape_capture(): 4853 return 4854 4855 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 4856 self.assertExpectedInline( 4857 actual, 4858 """\ 4859class GraphModule(torch.nn.Module): 4860 def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): 4861 l_x_ = L_x_ 4862 l_y_ = L_y_ 4863 4864 tensor: "i64[1]" = torch.tensor((12,)) 4865 cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None 4866 getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None 4867 neg: "i64[0]" = getitem.neg(); getitem = None 4868 unbind = neg.unbind(); neg = unbind = None 4869 4870 chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12) 4871 4872 diagonal: "f32[12]" = chunk.diagonal(0) 4873 fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None 4874 4875 child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None 4876 4877 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 4878 4879 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'same'); _vmap_increment_nesting = None 4880 4881 child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None 4882 4883 _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)); _jvp_treespec_compare = None 4884 4885 _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None 4886 _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None 4887 _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None 4888 4889 _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None 4890 4891 child_3: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None 4892 4893 _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None 4894 _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = None 4895 4896 child_2: "f32[3, 4]" = _wrap_for_grad_1.sin(); _wrap_for_grad_1 = None 4897 4898 _unpack_dual = torch._unpack_dual(child_2, level = 0); child_2 = None 4899 primal: "f32[3, 4]" = _unpack_dual[0]; _unpack_dual = None 4900 4901 tangent: "f32[3, 4]" = torch.zeros_like(primal) 4902 4903 _unpack_dual_1 = torch._unpack_dual(child_3, level = 0); child_3 = None 4904 primal_1: "f32[4, 3]" = _unpack_dual_1[0] 4905 dual: "f32[4, 3]" = _unpack_dual_1[1]; _unpack_dual_1 = None 4906 4907 child_4: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_4 = None 4908 child_5: "f32[4, 3]" = 
torch._C._functorch._unwrap_for_grad(primal_1, 2); primal_1 = child_5 = None 4909 4910 child_6: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None 4911 child_7: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None 4912 4913 _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None 4914 _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None 4915 _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None 4916 4917 child_8: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(child_6, 1, 12, 0); child_6 = None 4918 child_9: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(child_7, 1, 12, 0); child_7 = None 4919 4920 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 4921 4922 movedim: "f32[3, 4, 12]" = child_8.movedim(0, -1); child_8 = None 4923 split = movedim.split((12,), dim = -1); movedim = None 4924 jac_out_in: "f32[3, 4, 12]" = split[0]; split = None 4925 4926 unflatten: "f32[3, 4, 4, 3]" = jac_out_in.unflatten(-1, (4, 3)); jac_out_in = None 4927 4928 movedim_1: "f32[4, 3, 12]" = child_9.movedim(0, -1); child_9 = None 4929 split_1 = movedim_1.split((12,), dim = -1); movedim_1 = None 4930 jac_out_in_1: "f32[4, 3, 12]" = split_1[0]; split_1 = None 4931 4932 unflatten_1: "f32[4, 3, 4, 3]" = jac_out_in_1.unflatten(-1, (4, 3)); jac_out_in_1 = None 4933 return (unflatten, unflatten_1) 4934""", 4935 ) 4936 4937 def test_jacfwd_disable_capture(self): 4938 counters.clear() 4939 4940 with config.patch(capture_func_transforms=False): 4941 # We have verified above that this 4942 # function compiles 4943 def wrapper_fn(x): 4944 return torch.func.jacfwd(torch.sin)(x) 4945 4946 x = torch.randn(3, 3, 3) 4947 actual = wrapper_fn(x) 4948 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)( 4949 x 4950 ) 4951 self.assertEqual(len(counters["graph_break"]), 2) 4952 self.assertEqual( 4953 dict(counters["graph_break"]), 4954 { 4955 "torch.func.vmap capture is disabled, it can be " 4956 "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 2, 4957 "torch.func.jacfwd capture is disabled, it can be " 4958 "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1, 4959 }, 4960 ) 4961 self.assertEqual(actual, expected) 4962 4963 def test_jvp_simple(self): 4964 counters.clear() 4965 4966 def fn(x): 4967 return x.sin().sum() 4968 4969 def wrapper_fn(x, v): 4970 return torch.func.jvp(fn, (x,), (v,)) 4971 4972 x = torch.randn(3, 3) 4973 v = torch.randn(3, 3) 4974 wrapped_gm = self._compile_check(wrapper_fn, (x, v)) 4975 4976 # Dynamic shapes produce a slightly different graph. 
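        # The expected graph below records the forward-mode AD plumbing directly:
        # a dual level is entered, l_x_ and l_v_ are packed into a dual tensor with
        # torch._make_dual, fn runs on the dual, and torch._unpack_dual splits the
        # result back into its primal and tangent outputs.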
4977 if check_dynamic_shape_capture(): 4978 return 4979 4980 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 4981 self.assertExpectedInline( 4982 actual, 4983 """\ 4984class GraphModule(torch.nn.Module): 4985 def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"): 4986 l_x_ = L_x_ 4987 l_v_ = L_v_ 4988 4989 _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)); _jvp_treespec_compare = None 4990 4991 _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None 4992 _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None 4993 _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None 4994 4995 _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None 4996 4997 _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None 4998 4999 sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None 5000 result_duals: "f32[]" = sin.sum(); sin = None 5001 5002 _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None 5003 primal: "f32[]" = _unpack_dual[0] 5004 dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None 5005 5006 primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None 5007 5008 tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None 5009 5010 _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None 5011 _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None 5012 _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None 5013 return (primals_out_unflatten, tangents_out_unflatten) 5014""", 5015 ) 5016 5017 def test_jvp_has_aux(self): 5018 counters.clear() 5019 5020 def fn(x): 5021 return x.sin().sum(), x 5022 5023 def wrapper_fn(x, v): 5024 return torch.func.jvp(fn, (x,), (v,), has_aux=True) 5025 5026 x = torch.randn(3, 3) 5027 v = torch.randn(3, 3) 5028 wrapped_gm = self._compile_check(wrapper_fn, (x, v)) 5029 5030 # Dynamic shapes produce a slightly different graph. 
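        # Same shape as test_jvp_simple, plus the aux output: fn returns x itself as
        # aux, so the graph below also unwraps the dual input with _unwrap_for_grad
        # and returns (primal, tangent, aux).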
5031 if check_dynamic_shape_capture(): 5032 return 5033 5034 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5035 self.assertExpectedInline( 5036 actual, 5037 """\ 5038class GraphModule(torch.nn.Module): 5039 def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"): 5040 l_x_ = L_x_ 5041 l_v_ = L_v_ 5042 5043 _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)); _jvp_treespec_compare = None 5044 5045 _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None 5046 _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None 5047 _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None 5048 5049 _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None 5050 5051 aux: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None 5052 5053 sin: "f32[3, 3]" = aux.sin() 5054 result_duals: "f32[]" = sin.sum(); sin = None 5055 5056 aux_1: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None 5057 5058 _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None 5059 primal: "f32[]" = _unpack_dual[0] 5060 dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None 5061 5062 primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None 5063 5064 tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None 5065 5066 _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None 5067 _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None 5068 _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None 5069 return (primals_out_unflatten, tangents_out_unflatten, aux_1) 5070""", 5071 ) 5072 5073 def test_jvp_two_tensors_has_aux(self): 5074 counters.clear() 5075 5076 def fn(x, y): 5077 return (x.sin().sum() + y.cos()), x 5078 5079 def wrapper_fn(x, y, v): 5080 return torch.func.jvp(fn, (x, y), (v, v), has_aux=True) 5081 5082 x = torch.randn(3, 3) 5083 y = torch.randn(3, 3) 5084 v = torch.randn(3, 3) 5085 wrapped_gm = self._compile_check(wrapper_fn, (x, y, v)) 5086 5087 # Dynamic shapes produce a slightly different graph. 
5088 if check_dynamic_shape_capture(): 5089 return 5090 5091 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5092 self.assertExpectedInline( 5093 actual, 5094 """\ 5095class GraphModule(torch.nn.Module): 5096 def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]", L_v_: "f32[3, 3]"): 5097 l_x_ = L_x_ 5098 l_y_ = L_y_ 5099 l_v_ = L_v_ 5100 5101 _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_, l_y_), (l_v_, l_v_)); _jvp_treespec_compare = None 5102 5103 _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None 5104 _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None 5105 _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None 5106 5107 _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None 5108 5109 aux: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = None 5110 5111 _maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions_1 = None 5112 5113 _make_dual_1: "f32[3, 3]" = torch._make_dual(l_y_, l_v_, level = 0); l_y_ = l_v_ = None 5114 5115 sin: "f32[3, 3]" = aux.sin() 5116 sum_1: "f32[]" = sin.sum(); sin = None 5117 cos: "f32[3, 3]" = _make_dual_1.cos(); _make_dual_1 = None 5118 result_duals: "f32[3, 3]" = sum_1 + cos; sum_1 = cos = None 5119 5120 aux_1: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None 5121 5122 _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None 5123 primal: "f32[3, 3]" = _unpack_dual[0] 5124 dual: "f32[3, 3]" = _unpack_dual[1]; _unpack_dual = None 5125 5126 primals_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None 5127 5128 tangents_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None 5129 5130 _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None 5131 _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None 5132 _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None 5133 return (primals_out_unflatten, tangents_out_unflatten, aux_1) 5134""", 5135 ) 5136 5137 def test_jvp_two_tensors_disable_grad(self): 5138 counters.clear() 5139 5140 def fn(x): 5141 return x.sin().sum() 5142 5143 def wrapper_fn(x, v): 5144 with torch.autograd.forward_ad._set_fwd_grad_enabled(False): 5145 return torch.func.jvp(fn, (x,), (v,)) 5146 5147 x = torch.randn(3, 3) 5148 v = torch.randn(3, 3) 5149 wrapped_gm = self._compile_check(wrapper_fn, (x, v)) 5150 5151 # Dynamic shapes produce a slightly different graph. 
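        # The surrounding _set_fwd_grad_enabled(False) context manager shows up in
        # the graph below as explicit enable/disable calls around the jvp region,
        # with the prior state restored on exit.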
5152 if check_dynamic_shape_capture(): 5153 return 5154 5155 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5156 self.assertExpectedInline( 5157 actual, 5158 """\ 5159class GraphModule(torch.nn.Module): 5160 def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"): 5161 l_x_ = L_x_ 5162 l_v_ = L_v_ 5163 5164 _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled = None 5165 5166 _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)); _jvp_treespec_compare = None 5167 5168 _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None 5169 _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None 5170 _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None 5171 5172 _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None 5173 5174 _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None 5175 5176 sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None 5177 result_duals: "f32[]" = sin.sum(); sin = None 5178 5179 _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None 5180 primal: "f32[]" = _unpack_dual[0] 5181 dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None 5182 5183 primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None 5184 5185 tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None 5186 5187 _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None 5188 _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_2 = None 5189 _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None 5190 _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None 5191 return (primals_out_unflatten, tangents_out_unflatten) 5192""", 5193 ) 5194 5195 def test_jvp_two_tensors_disable_enable_disable_grad(self): 5196 counters.clear() 5197 5198 def fn(x): 5199 return x.sin().sum() 5200 5201 def wrapper_fn(x, v): 5202 with torch.autograd.forward_ad._set_fwd_grad_enabled(False): # (1) 5203 with torch.autograd.forward_ad._set_fwd_grad_enabled(True): # (2) 5204 with torch.autograd.forward_ad._set_fwd_grad_enabled(False): # (3) 5205 return torch.func.jvp(fn, (x,), (v,)) # (4) 5206 5207 # Start True 5208 # False (1) 5209 # True (2) 5210 # False (3) 5211 # True (4) 5212 # True (undo 3) 5213 # False (undo 2) 5214 # True (undo 1) 5215 5216 x = torch.randn(3, 3) 5217 v = torch.randn(3, 3) 5218 wrapped_gm = self._compile_check(wrapper_fn, (x, v)) 5219 5220 # Dynamic shapes produce a slightly different graph. 
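        # Each transition in the sequence above is recorded in the graph below as an
        # explicit _set_fwd_grad_enabled call, including the unwinds emitted when the
        # nested context managers exit.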
5221 if check_dynamic_shape_capture(): 5222 return 5223 5224 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5225 self.assertExpectedInline( 5226 actual, 5227 """\ 5228class GraphModule(torch.nn.Module): 5229 def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"): 5230 l_x_ = L_x_ 5231 l_v_ = L_v_ 5232 5233 _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled = None 5234 _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None 5235 _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_2 = None 5236 5237 _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)); _jvp_treespec_compare = None 5238 5239 _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None 5240 _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None 5241 _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None 5242 5243 _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None 5244 5245 _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None 5246 5247 sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None 5248 result_duals: "f32[]" = sin.sum(); sin = None 5249 5250 _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None 5251 primal: "f32[]" = _unpack_dual[0] 5252 dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None 5253 5254 primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None 5255 5256 tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None 5257 5258 _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None 5259 _set_fwd_grad_enabled_4 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_4 = None 5260 _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None 5261 _set_fwd_grad_enabled_5 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_5 = None 5262 _set_fwd_grad_enabled_6 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_6 = None 5263 _set_fwd_grad_enabled_7 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_7 = None 5264 return (primals_out_unflatten, tangents_out_unflatten) 5265""", 5266 ) 5267 5268 def test_jvp_freevar_tensor(self): 5269 counters.clear() 5270 y = torch.randn(3, 3) 5271 5272 def fn(x): 5273 return (x.sin() + y).sum() 5274 5275 def wrapper_fn(x): 5276 return torch.func.jvp(fn, (x,), (x,)) 5277 5278 x = torch.randn(3, 3) 5279 expected = wrapper_fn(x) 5280 actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x) 5281 self.assertEqual(actual, expected) 5282 5283 def test_jvp_jvp(self): 5284 counters.clear() 5285 5286 if check_dynamic_shape_capture(): 5287 self.skipTest("test fails with dynamic shapes") 5288 5289 def fn(x): 5290 return torch.func.jvp(torch.sin, (x,), (x,)) 5291 5292 def wrapper_fn(x): 5293 return torch.func.jvp(fn, (x,), (x,)) 5294 5295 x = torch.randn(3, 3, 3) 5296 wrapped_gm = self._compile_check(wrapper_fn, (x,)) 5297 5298 # Dynamic shapes produce a slightly different graph. 
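        # Nested jvp: the inner call raises the JVP nesting level, so intermediate
        # results are unwrapped at level 2 before the outer jvp unwraps at level 1.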
5299 if check_dynamic_shape_capture(): 5300 return 5301 5302 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5303 self.assertExpectedInline( 5304 actual, 5305 """\ 5306class GraphModule(torch.nn.Module): 5307 def forward(self, L_x_: "f32[3, 3, 3]"): 5308 l_x_ = L_x_ 5309 5310 _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_x_,)); _jvp_treespec_compare = None 5311 5312 _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None 5313 _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None 5314 _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None 5315 5316 _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None 5317 5318 child: "f32[3, 3, 3]" = torch._make_dual(l_x_, l_x_, level = 0); l_x_ = None 5319 5320 _jvp_treespec_compare_1 = torch._functorch.eager_transforms._jvp_treespec_compare((child,), (child,)); _jvp_treespec_compare_1 = None 5321 5322 _jvp_increment_nesting_1 = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting_1 = None 5323 _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None 5324 5325 _maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions_1 = None 5326 5327 _make_dual_1: "f32[3, 3, 3]" = torch._make_dual(child, child, level = 0); child = None 5328 5329 result_duals: "f32[3, 3, 3]" = torch.sin(_make_dual_1); _make_dual_1 = None 5330 5331 _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None 5332 primal: "f32[3, 3, 3]" = _unpack_dual[0] 5333 dual: "f32[3, 3, 3]" = _unpack_dual[1]; _unpack_dual = None 5334 5335 primals_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = None 5336 5337 tangents_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None 5338 5339 _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_2 = None 5340 _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None 5341 5342 _unpack_dual_1 = torch._unpack_dual(primals_out_unflatten, level = 0); primals_out_unflatten = None 5343 primal_1: "f32[3, 3, 3]" = _unpack_dual_1[0] 5344 dual_1: "f32[3, 3, 3]" = _unpack_dual_1[1]; _unpack_dual_1 = None 5345 _unpack_dual_2 = torch._unpack_dual(tangents_out_unflatten, level = 0); tangents_out_unflatten = None 5346 primal_2: "f32[3, 3, 3]" = _unpack_dual_2[0] 5347 dual_2: "f32[3, 3, 3]" = _unpack_dual_2[1]; _unpack_dual_2 = None 5348 5349 _unwrap_for_grad_2: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 1); primal_1 = None 5350 _unwrap_for_grad_3: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_2, 1); primal_2 = None 5351 5352 _unwrap_for_grad_4: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_1, 1); dual_1 = None 5353 _unwrap_for_grad_5: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_2, 1); dual_2 = None 5354 5355 _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None 5356 _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None 5357 _jvp_decrement_nesting_1 = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting_1 = None 5358 return (_unwrap_for_grad_2, _unwrap_for_grad_3, _unwrap_for_grad_4, _unwrap_for_grad_5) 5359""", 5360 ) 5361 5362 def 
test_jvp_freevar_python_scalar(self): 5363 counters.clear() 5364 y = 3 5365 5366 def fn(x): 5367 return (x.sin() + y).sum() 5368 5369 def wrapper_fn(x): 5370 return torch.func.jvp(fn, (x,), (x,)) 5371 5372 x = torch.randn(3, 3, 3) 5373 expected = wrapper_fn(x) 5374 actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x) 5375 self.assertEqual(actual, expected) 5376 5377 def test_jvp_disable_capture(self): 5378 counters.clear() 5379 5380 with config.patch(capture_func_transforms=False): 5381 # We have verified above that this 5382 # function compiles 5383 def wrapper_fn(x): 5384 return torch.func.jvp(torch.sin, (x,), (x,)) 5385 5386 x = torch.randn(3, 3, 3) 5387 actual = wrapper_fn(x) 5388 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)( 5389 x 5390 ) 5391 self.assertEqual(len(counters["graph_break"]), 1) 5392 self.assertEqual( 5393 dict(counters["graph_break"]), 5394 { 5395 "torch.func.jvp capture is disabled, it can be " 5396 "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1 5397 }, 5398 ) 5399 self.assertEqual(actual, expected) 5400 5401 @config.patch(capture_func_transforms=True) 5402 def test_linearize_jvp_fn(self): 5403 counters.clear() 5404 5405 def wrapper_fn(x): 5406 output, jvp_fn = torch.func.linearize(torch.sin, x) 5407 return output, jvp_fn(x) 5408 5409 x = torch.randn(3, 3, 3) 5410 wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False, graph_idx=0) 5411 5412 # Dynamic shapes produce a slightly different graph. 5413 if check_dynamic_shape_capture(): 5414 return 5415 5416 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5417 self.assertExpectedInline( 5418 actual, 5419 """\ 5420class GraphModule(torch.nn.Module): 5421 def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"): 5422 l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_ 5423 5424 alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_); l_self_buffers_tensor_constant0_ = None 5425 5426 sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default) 5427 5428 alias_default_1: "f32[3, 3, 3]" = torch.ops.aten.alias.default(alias_default) 5429 5430 cos_default: "f32[3, 3, 3]" = torch.ops.aten.cos.default(alias_default_1); alias_default_1 = None 5431 5432 alias_default_2: "f32[3, 3, 3]" = torch.ops.aten.alias.default(sin_default); alias_default_2 = None 5433 return (alias_default, cos_default, sin_default) 5434""", 5435 ) 5436 5437 wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False, graph_idx=1) 5438 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5439 self.assertExpectedInline( 5440 actual, 5441 """\ 5442class GraphModule(torch.nn.Module): 5443 def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): 5444 l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_ 5445 l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_ 5446 l_flat_tangents_1_ = L_flat_tangents_1_ 5447 5448 _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_); l_self_modules_fx_const_folded_attrs_parameters_0_ = None 5449 5450 copy__default: "f32[3, 3, 3]" = 
torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_); _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None 5451 5452 mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_); copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None 5453 return (mul_tensor,) 5454""", 5455 ) 5456 5457 def test_linearize_disable_capture(self): 5458 counters.clear() 5459 with config.patch(capture_func_transforms=False): 5460 # We have verified above that this 5461 # function compiles 5462 def wrapper_fn(x): 5463 out, _ = torch.func.linearize(torch.sin, x) 5464 return out 5465 5466 x = torch.randn(2, 3) 5467 actual = wrapper_fn(x) 5468 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)( 5469 x 5470 ) 5471 self.assertEqual(len(counters["graph_break"]), 1) 5472 self.assertEqual( 5473 { 5474 "torch.func.linearize capture is disabled, it can be " 5475 "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1, 5476 }, 5477 dict(counters["graph_break"]), 5478 ) 5479 self.assertEqual(actual, expected) 5480 5481 @config.patch(capture_func_transforms=True) 5482 @config.patch(error_on_recompile=True) 5483 def test_vmap_recompile(self): 5484 @torch.compile(backend="eager") 5485 def fn(x): 5486 return torch.vmap(lambda x: x.sin())(x) 5487 5488 x = torch.zeros(3, 3, 4, 5) 5489 y = torch.vmap(fn)(x) 5490 # should not recompile on second call. See Pytorch issue #118493 5491 y = torch.vmap(fn)(x) 5492 5493 @xfailIfTorchDynamo 5494 @config.patch(error_on_recompile=True) 5495 def test_vmap_recompile_different_config(self): 5496 @torch.compile(backend="eager") 5497 def fn(x): 5498 return torch.vmap(lambda x: x.sin())(x) 5499 5500 x = torch.zeros(3, 3, 4, 5) 5501 y = torch.vmap(fn)(x) 5502 with self.assertRaises(torch._dynamo.exc.RecompileError): 5503 fn(x) 5504 5505 @config.patch(error_on_recompile=True) 5506 def test_vmap_recompile_same_config(self): 5507 @torch.compile(backend="eager") 5508 def fn(x): 5509 return torch.vmap(lambda x: x.sin())(x) 5510 5511 x = torch.zeros(3, 3, 4, 5) 5512 torch.vmap(torch.vmap(fn, randomness="same"), randomness="same")(x) 5513 with self.assertRaises(torch._dynamo.exc.RecompileError): 5514 torch.vmap(torch.vmap(fn, randomness="same"), randomness="error")(x) 5515 5516 @config.patch(error_on_recompile=True) 5517 def test_vmap_recompile_with_randomness(self): 5518 @torch.compile(backend="eager") 5519 def fn(x): 5520 return torch.vmap(lambda x: x.sin())(x) 5521 5522 x = torch.zeros(3, 3, 4, 5) 5523 torch.vmap(fn, randomness="same")(x) 5524 with self.assertRaises(torch._dynamo.exc.RecompileError): 5525 torch.vmap(fn, randomness="different")(x) 5526 5527 def test_vmap_call_torch_compile_fn(self): 5528 def wrapped_fn(x): 5529 return x.sin() 5530 5531 x = torch.randn(3, 4) 5532 fn = torch.compile(backend="aot_eager", fullgraph=True)(wrapped_fn) 5533 5534 with self.assertRaisesRegex( 5535 torch._dynamo.exc.Unsupported, 5536 "Calling torch.func.vmap\\(compiled_fn\\) function from eager mode is not supported", 5537 ): 5538 torch.func.vmap(fn)(x) 5539 5540 def test_grad_call_torch_compile_fn(self): 5541 def wrapped_fn(x): 5542 return x.sin().sum() 5543 5544 x = torch.randn(3, 4) 5545 fn = torch.compile(backend="aot_eager", fullgraph=True)(wrapped_fn) 5546 5547 with self.assertRaisesRegex( 5548 torch._dynamo.exc.Unsupported, 5549 "Calling torch.func.grad\\(compiled_fn\\) function from eager mode is not supported", 5550 ): 5551 
torch.func.grad(fn)(x) 5552 5553 def test_jvp_call_torch_compile_fn(self): 5554 def wrapped_fn(x): 5555 return x.sin().sum() 5556 5557 x = torch.randn(3, 4) 5558 fn = torch.compile(backend="aot_eager", fullgraph=True)(wrapped_fn) 5559 5560 with self.assertRaisesRegex( 5561 torch._dynamo.exc.Unsupported, 5562 "Calling torch.func.jvp\\(compiled_fn\\) function from eager mode is not supported", 5563 ): 5564 torch.func.jvp(fn, (x,), (x,)) 5565 5566 @config.patch(error_on_recompile=True) 5567 def test_grad_recompile(self): 5568 @torch.compile(backend="eager") 5569 def fn(x): 5570 return torch.func.grad(torch.sin)(x) 5571 5572 x = torch.randn([]) 5573 torch.func.grad(fn)(x) 5574 # should not recompile on second call 5575 torch.func.grad(fn)(x) 5576 5577 def test_vmap_get_wrapped(self): 5578 counters.clear() 5579 5580 def g(x): 5581 return x.sin() 5582 5583 @torch.compile(backend="aot_eager", fullgraph=True) 5584 def fn(): 5585 return torch.vmap(g) 5586 5587 x = torch.randn(3, 4) 5588 expected = torch.vmap(g)(x) 5589 wrapper = fn() 5590 got = wrapper(x) 5591 self.assertEqual(expected, got) 5592 5593 def test_vmap_with_conditional_graph_break(self): 5594 def g(x): 5595 if len(x.shape) < 2: 5596 torch._dynamo.graph_break() 5597 return x.sin() 5598 else: 5599 return x.cos() 5600 5601 @torch.compile(backend="aot_eager") 5602 def fn(x): 5603 return torch.vmap(g)(x) 5604 5605 counters.clear() 5606 x = torch.randn(2, 3) 5607 expected = x.sin() 5608 got = fn(x) 5609 self.assertEqual(expected, got) 5610 self.assertEqual(len(counters["graph_break"]), 1) 5611 5612 counters.clear() 5613 y = torch.randn(2, 3, 4) 5614 expected = y.cos() 5615 got = fn(y) 5616 self.assertEqual(expected, got) 5617 self.assertEqual(len(counters["graph_break"]), 0) 5618 5619 def test_vmap_with_graph_break(self): 5620 counters.clear() 5621 5622 def g(x): 5623 y = x.cos() 5624 print("hi") 5625 return y.sin() 5626 5627 def fn(x): 5628 return torch.vmap(g)(x) 5629 5630 x = torch.randn(3, 4) 5631 opt = torch.compile(fn, backend="aot_eager", fullgraph=False) 5632 expected = fn(x) 5633 got = opt(x) 5634 self.assertEqual(len(counters["graph_break"]), 1) 5635 self.assertEqual(expected, got) 5636 5637 def test_vmap_with_graph_break_2(self): 5638 counters.clear() 5639 5640 def cos(x): 5641 print("cos") 5642 return x.cos() 5643 5644 def sin(x): 5645 print("sin") 5646 return x.sin() 5647 5648 def g(x): 5649 y = cos(x) 5650 return sin(y) 5651 5652 def fn(x): 5653 return torch.vmap(g, randomness="same")(x) 5654 5655 x = torch.randn(3, 4) 5656 opt = torch.compile(fn, backend="aot_eager", fullgraph=False) 5657 expected = fn(x) 5658 got = opt(x) 5659 self.assertEqual(len(counters["graph_break"]), 1) 5660 self.assertEqual(expected, got) 5661 5662 def test_vmap_with_graph_break_lambda(self): 5663 counters.clear() 5664 5665 def sin(x): 5666 print("sin") 5667 return x.sin() 5668 5669 def fn(x): 5670 return torch.vmap(lambda x: sin(x))(x) 5671 5672 x = torch.randn(3, 4) 5673 opt = torch.compile(fn, backend="aot_eager", fullgraph=False) 5674 expected = fn(x) 5675 got = opt(x) 5676 self.assertEqual(len(counters["graph_break"]), 1) 5677 self.assertEqual(expected, got) 5678 5679 def test_vmap(self): 5680 def fn(x): 5681 return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x) 5682 5683 x = torch.randn(3, 3, 3) 5684 wrapped_gm = self._compile_check(fn, (x,)) 5685 5686 # Dynamic shapes produce a slightly different graph. 
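        # vmap is traced into explicit batching calls: _vmap_increment_nesting and
        # _add_batch_dim around the per-sample computation, followed by
        # _remove_batch_dim and _vmap_decrement_nesting.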
5687 if check_dynamic_shape_capture(): 5688 return 5689 5690 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5691 self.assertExpectedInline( 5692 actual, 5693 """\ 5694class GraphModule(torch.nn.Module): 5695 def forward(self, L_x_: "f32[3, 3, 3]"): 5696 l_x_ = L_x_ 5697 5698 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 5699 5700 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None 5701 5702 _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None 5703 5704 sum_1: "f32[3]" = _add_batch_dim.sum(0) 5705 sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None 5706 batched_outputs: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None 5707 5708 _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None 5709 5710 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 5711 return (_remove_batch_dim,) 5712""", 5713 ) 5714 5715 def test_vmap_free_const(self): 5716 y = 3 5717 5718 def fn(x): 5719 return torch.func.vmap(lambda x: x.sum(0) + x.sum(1) + y)(x) 5720 5721 x = torch.randn(3, 3, 3) 5722 wrapped_gm = self._compile_check(fn, (x,)) 5723 5724 # Dynamic shapes produce a slightly different graph. 5725 if check_dynamic_shape_capture(): 5726 return 5727 5728 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5729 self.assertExpectedInline( 5730 actual, 5731 """\ 5732class GraphModule(torch.nn.Module): 5733 def forward(self, L_x_: "f32[3, 3, 3]"): 5734 l_x_ = L_x_ 5735 5736 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 5737 5738 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None 5739 5740 _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None 5741 5742 sum_1: "f32[3]" = _add_batch_dim.sum(0) 5743 sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None 5744 add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None 5745 batched_outputs: "f32[3]" = add + 3; add = None 5746 5747 _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None 5748 5749 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 5750 return (_remove_batch_dim,) 5751""", 5752 ) 5753 5754 def test_vmap_free_tensor(self): 5755 y = torch.randn(3, 3) 5756 5757 def fn(x): 5758 return torch.func.vmap(lambda x: x.sum(0) + x.sum(1) + y)(x) 5759 5760 x = torch.randn(3, 3, 3) 5761 wrapped_gm = self._compile_check(fn, (x,)) 5762 5763 # Dynamic shapes produce a slightly different graph. 
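        # Unlike the Python int in the previous test, which is baked into the graph
        # as a constant, the closed-over tensor y is lifted to a graph input (L_y_).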
5764 if check_dynamic_shape_capture(): 5765 return 5766 5767 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5768 self.assertExpectedInline( 5769 actual, 5770 """\ 5771class GraphModule(torch.nn.Module): 5772 def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"): 5773 l_x_ = L_x_ 5774 l_y_ = L_y_ 5775 5776 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 5777 5778 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None 5779 5780 _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None 5781 5782 sum_1: "f32[3]" = _add_batch_dim.sum(0) 5783 sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None 5784 add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None 5785 batched_outputs: "f32[3, 3]" = add + l_y_; add = l_y_ = None 5786 5787 _remove_batch_dim: "f32[3, 3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None 5788 5789 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 5790 return (_remove_batch_dim,) 5791""", 5792 ) 5793 5794 def test_vmap_two_inputs(self): 5795 def fn(x, y): 5796 return torch.func.vmap( 5797 lambda x, y: x.sum(0) + x.sum(1) + y, in_dims=(0, 1) 5798 )(x, y) 5799 5800 x = torch.randn(3, 3, 3) 5801 y = torch.randn(3, 3) 5802 wrapped_gm = self._compile_check(fn, (x, y)) 5803 5804 # Dynamic shapes produce a slightly different graph. 5805 if check_dynamic_shape_capture(): 5806 return 5807 5808 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5809 self.assertExpectedInline( 5810 actual, 5811 """\ 5812class GraphModule(torch.nn.Module): 5813 def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"): 5814 l_x_ = L_x_ 5815 l_y_ = L_y_ 5816 5817 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 5818 5819 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None 5820 5821 _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None 5822 _add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 1, 1); l_y_ = None 5823 5824 sum_1: "f32[3]" = _add_batch_dim.sum(0) 5825 sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None 5826 add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None 5827 batched_outputs: "f32[3]" = add + _add_batch_dim_1; add = _add_batch_dim_1 = None 5828 5829 _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None 5830 5831 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 5832 return (_remove_batch_dim,) 5833""", 5834 ) 5835 5836 def test_vmap_two_inputs_tuple_in_dims(self): 5837 in_dims = (0, 1) 5838 5839 def fn(x, y): 5840 return torch.func.vmap( 5841 lambda x, y: x.sum(0) + x.sum(1) + y, in_dims=in_dims 5842 )(x, y) 5843 5844 x = torch.randn(3, 3, 3) 5845 y = torch.randn(3, 3) 5846 wrapped_gm = self._compile_check(fn, (x, y)) 5847 5848 # Dynamic shapes produce a slightly different graph. 
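        # Supplying in_dims as a closed-over tuple should produce the same graph as
        # passing it inline in the previous test.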
5849 if check_dynamic_shape_capture(): 5850 return 5851 5852 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5853 self.assertExpectedInline( 5854 actual, 5855 """\ 5856class GraphModule(torch.nn.Module): 5857 def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"): 5858 l_x_ = L_x_ 5859 l_y_ = L_y_ 5860 5861 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 5862 5863 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None 5864 5865 _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None 5866 _add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 1, 1); l_y_ = None 5867 5868 sum_1: "f32[3]" = _add_batch_dim.sum(0) 5869 sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None 5870 add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None 5871 batched_outputs: "f32[3]" = add + _add_batch_dim_1; add = _add_batch_dim_1 = None 5872 5873 _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None 5874 5875 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 5876 return (_remove_batch_dim,) 5877""", 5878 ) 5879 5880 def test_vmap_over_vmap_two_inputs(self): 5881 def fn(x, y): 5882 return torch.func.vmap(torch.func.vmap(lambda x, y: x + y, in_dims=1))(x, y) 5883 5884 x = torch.randn(3, 3, 3) 5885 y = torch.randn(3, 3, 3) 5886 wrapped_gm = self._compile_check(fn, (x, y)) 5887 5888 # Dynamic shapes produce a slightly different graph. 5889 if check_dynamic_shape_capture(): 5890 return 5891 5892 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5893 self.assertExpectedInline( 5894 actual, 5895 """\ 5896class GraphModule(torch.nn.Module): 5897 def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"): 5898 l_x_ = L_x_ 5899 l_y_ = L_y_ 5900 5901 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 5902 5903 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None 5904 5905 child: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None 5906 child_1: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_y_, 0, 1); l_y_ = None 5907 5908 lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None 5909 5910 _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None 5911 5912 _add_batch_dim_2: "f32[3]" = torch._C._functorch._add_batch_dim(child, 1, 2); child = None 5913 _add_batch_dim_3: "f32[3]" = torch._C._functorch._add_batch_dim(child_1, 1, 2); child_1 = None 5914 5915 batched_outputs: "f32[3]" = _add_batch_dim_2 + _add_batch_dim_3; _add_batch_dim_2 = _add_batch_dim_3 = None 5916 5917 batched_outputs_1: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None 5918 5919 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 5920 5921 _remove_batch_dim_1: "f32[3, 3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 3, 0); batched_outputs_1 = None 5922 5923 _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None 5924 return (_remove_batch_dim_1,) 5925""", 5926 ) 5927 5928 def 
test_vmap_over_vmap_captured(self): 5929 x = torch.ones(2, 3) 5930 y = torch.ones(5, 3) 5931 5932 def fn(x): 5933 return torch.func.vmap(torch.func.vmap(lambda y: x * y))(y) 5934 5935 wrapped_gm = self._compile_check(fn, (x,)) 5936 5937 # Dynamic shapes produce a slightly different graph. 5938 if check_dynamic_shape_capture(): 5939 return 5940 5941 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5942 self.assertExpectedInline( 5943 actual, 5944 """\ 5945class GraphModule(torch.nn.Module): 5946 def forward(self, L_y_: "f32[5, 3]", L_x_: "f32[2, 3]"): 5947 l_y_ = L_y_ 5948 l_x_ = L_x_ 5949 5950 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 5951 5952 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(5, 'error'); _vmap_increment_nesting = None 5953 5954 child: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 0, 1); l_y_ = None 5955 5956 lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None 5957 5958 _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None 5959 5960 _add_batch_dim_1: "f32[]" = torch._C._functorch._add_batch_dim(child, 0, 2); child = None 5961 5962 batched_outputs: "f32[2, 3]" = l_x_ * _add_batch_dim_1; l_x_ = _add_batch_dim_1 = None 5963 5964 batched_outputs_1: "f32[3, 2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None 5965 5966 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 5967 5968 _remove_batch_dim_1: "f32[5, 3, 2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 5, 0); batched_outputs_1 = None 5969 5970 _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None 5971 return (_remove_batch_dim_1,) 5972""", 5973 ) 5974 5975 def test_vmap_multiple_outputs(self): 5976 x = torch.ones(2, 4, 3) 5977 5978 def fn(x): 5979 return torch.vmap(lambda x: (x.sum(0), x.sum(1)))(x) 5980 5981 wrapped_gm = self._compile_check(fn, (x,)) 5982 5983 # Dynamic shapes produce a slightly different graph. 
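        # Each output of the mapped function gets its own _remove_batch_dim before
        # being returned from the graph.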
5984 if check_dynamic_shape_capture(): 5985 return 5986 5987 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 5988 self.assertExpectedInline( 5989 actual, 5990 """\ 5991class GraphModule(torch.nn.Module): 5992 def forward(self, L_x_: "f32[2, 4, 3]"): 5993 l_x_ = L_x_ 5994 5995 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 5996 5997 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None 5998 5999 _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None 6000 6001 child: "f32[3]" = _add_batch_dim.sum(0) 6002 child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None 6003 6004 _remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 0); child = None 6005 _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None 6006 6007 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 6008 return (_remove_batch_dim, _remove_batch_dim_1) 6009""", 6010 ) 6011 6012 def test_vmap_multiple_outputs_diff_dims(self): 6013 x = torch.ones(2, 4, 3) 6014 6015 def fn(x): 6016 return torch.vmap(lambda x: (x.sum(0), x.sum(1)), out_dims=(1, 0))(x) 6017 6018 wrapped_gm = self._compile_check(fn, (x,)) 6019 6020 # Dynamic shapes produce a slightly different graph. 6021 if check_dynamic_shape_capture(): 6022 return 6023 6024 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 6025 self.assertExpectedInline( 6026 actual, 6027 """\ 6028class GraphModule(torch.nn.Module): 6029 def forward(self, L_x_: "f32[2, 4, 3]"): 6030 l_x_ = L_x_ 6031 6032 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 6033 6034 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None 6035 6036 _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None 6037 6038 child: "f32[3]" = _add_batch_dim.sum(0) 6039 child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None 6040 6041 _remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1); child = None 6042 _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None 6043 6044 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 6045 return (_remove_batch_dim, _remove_batch_dim_1) 6046""", 6047 ) 6048 6049 def test_vmap_multiple_outputs_out_dims_tuple(self): 6050 x = torch.ones(2, 4, 3) 6051 out_dims = (1, 0) 6052 6053 def fn(x): 6054 return torch.vmap(lambda x: (x.sum(0), x.sum(1)), out_dims=out_dims)(x) 6055 6056 wrapped_gm = self._compile_check(fn, (x,)) 6057 6058 # Dynamic shapes produce a slightly different graph. 
6059 if check_dynamic_shape_capture(): 6060 return 6061 6062 actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) 6063 self.assertExpectedInline( 6064 actual, 6065 """\ 6066class GraphModule(torch.nn.Module): 6067 def forward(self, L_x_: "f32[2, 4, 3]"): 6068 l_x_ = L_x_ 6069 6070 lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None 6071 6072 _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None 6073 6074 _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None 6075 6076 child: "f32[3]" = _add_batch_dim.sum(0) 6077 child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None 6078 6079 _remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1); child = None 6080 _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None 6081 6082 _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None 6083 return (_remove_batch_dim, _remove_batch_dim_1) 6084""", 6085 ) 6086 6087 def test_vmap_kwargs(self): 6088 counters.clear() 6089 x = torch.ones(2, 3) 6090 y = torch.randn(2, 3) 6091 6092 def fn(x, y): 6093 return torch.func.vmap(lambda x, y: x + y)(x, y=y) 6094 6095 actual = fn(x, y) 6096 expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y) 6097 self.assertEqual(len(counters["graph_break"]), 0) 6098 self.assertEqual(actual, expected) 6099 6100 def test_vmap_pytree_inputs(self): 6101 counters.clear() 6102 x = torch.ones(2, 3) 6103 y = torch.randn(2, 3) 6104 6105 def vmap_fn(inps): 6106 x = inps["x"] 6107 y = inps["y"] 6108 return x + y 6109 6110 def fn(x, y): 6111 return torch.func.vmap(vmap_fn)({"x": x, "y": y}) 6112 6113 actual = fn(x, y) 6114 expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y) 6115 self.assertEqual(len(counters["graph_break"]), 0) 6116 self.assertEqual(actual, expected) 6117 6118 def test_vmap_side_effects(self): 6119 counters.clear() 6120 x = torch.ones(2, 3) 6121 y = torch.randn(2, 3) 6122 6123 some_list = [] 6124 6125 def f(x, y): 6126 some_list.append(1) 6127 return x + y 6128 6129 def wrapper_fn(x, y): 6130 return torch.func.vmap(f)(x, y) 6131 6132 actual = wrapper_fn(x, y) 6133 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y) 6134 self.assertEqual(len(counters["graph_break"]), 0) 6135 self.assertEqual(actual, expected) 6136 self.assertEqual(some_list, [1, 1]) 6137 6138 @unittest.expectedFailure 6139 def test_vmap_side_effects_append_input(self): 6140 counters.clear() 6141 x = torch.ones(2, 3) 6142 y = torch.randn(2, 3) 6143 6144 some_list = [] 6145 6146 def f(x, y): 6147 some_list.append(x) 6148 return x + y 6149 6150 def wrapper_fn(x, y): 6151 return torch.func.vmap(f)(x, y) 6152 6153 actual = wrapper_fn(x, y) 6154 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y) 6155 self.assertEqual(len(counters["graph_break"]), 0) 6156 self.assertEqual(actual, expected) 6157 6158 def test_vmap_previous_illegal_op_no_graph_break(self): 6159 counters.clear() 6160 6161 # calling .stride() would previously graph break 6162 def bad_fn(x): 6163 y = x.view((4, 3)) 6164 y.stride() 6165 return y 6166 6167 def wrapper_fn(x): 6168 return torch.func.vmap(bad_fn)(x) 6169 6170 x = torch.randn(2, 3, 4) 6171 actual = wrapper_fn(x) 6172 expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) 6173 
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)

    def test_vmap_disable_capture(self):
        counters.clear()

        with config.patch(capture_func_transforms=False):
            # We have verified above that this function compiles.
            def wrapper_fn(x):
                return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x)

            x = torch.randn(3, 3, 3)
            actual = wrapper_fn(x)
            expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
                x
            )
            self.assertEqual(len(counters["graph_break"]), 1)
            self.assertEqual(
                dict(counters["graph_break"]),
                {
                    "torch.func.vmap capture is disabled, it can be "
                    "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 2
                },
            )
            self.assertEqual(actual, expected)

    def test_vmap_multiple_invocation_in_dims(self):
        counters.clear()

        def wrapper_fn(x, in_dims):
            return torch.func.vmap(torch.sum, in_dims)(x)

        x = torch.randn(3, 3, 3, 3)
        cnt = CompileCounter()
        opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
        expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
        # Third invocation of `opt` makes `in_dims` a SymInt.
        actual = opt(x, 0), opt(x, 1), opt(x, 2)
        self.assertEqual(expected, actual)
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 21)

    def test_vmap_multiple_invocation_out_dims(self):
        counters.clear()

        def wrapper_fn(x, out_dims):
            return torch.func.vmap(lambda x: torch.sum(x, 0), out_dims=out_dims)(x)

        x = torch.randn(3, 3, 3, 3)
        cnt = CompileCounter()
        opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
        expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
        # Third invocation of `opt` makes `out_dims` a SymInt.
6227 actual = opt(x, 0), opt(x, 1), opt(x, 2) 6228 self.assertEqual(expected, actual) 6229 self.assertEqual(cnt.frame_count, 3) 6230 self.assertEqual(cnt.op_count, 21) 6231 6232 def test_vmap_new_tensor_in_body(self): 6233 def fn(x): 6234 return x + torch.ones(3) 6235 6236 def wrapper_fn(x): 6237 return torch.func.vmap(fn)(x) 6238 6239 x = torch.randn( 6240 3, 6241 ) 6242 opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True) 6243 expected = wrapper_fn(x) 6244 actual = opt(x) 6245 self.assertEqual(expected, actual) 6246 6247 def test_vmap_new_tensor_unused_in_body(self): 6248 def fn(x): 6249 return torch.tensor(0.5) 6250 6251 def wrapper_fn(x): 6252 return torch.func.vmap(fn)(x) 6253 6254 x = torch.randn(3) 6255 opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True) 6256 expected = wrapper_fn(x) 6257 actual = opt(x) 6258 self.assertEqual(expected, actual) 6259 6260 def test_vmap_new_tensor_implicit_via_op(self): 6261 def wrapper_fn(x): 6262 return torch.func.vmap(lambda t: torch.add(t, 0.5))(x) 6263 6264 x = torch.randn(3) 6265 opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True) 6266 expected = wrapper_fn(x) 6267 actual = opt(x) 6268 self.assertEqual(expected, actual) 6269 6270 6271class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase): 6272 def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True): 6273 cloned_args = [] 6274 for arg in args: 6275 cloned_args.append(arg.clone().detach().requires_grad_(arg.requires_grad)) 6276 6277 torch.manual_seed(0) 6278 expected = fn(*args) 6279 expected.sum().backward() 6280 6281 opt_fn = torch.compile(fn, fullgraph=fullgraph, backend=backend) 6282 torch.manual_seed(0) 6283 result = opt_fn(*cloned_args) 6284 result.sum().backward() 6285 6286 if not skip_check: 6287 self.assertEqual(result, expected) 6288 for arg, cloned_arg in zip(args, cloned_args): 6289 self.assertEqual(arg.grad, cloned_arg.grad) 6290 6291 @requires_cuda 6292 @torch._functorch.config.patch(functionalize_rng_ops=True) 6293 def test_function(self): 6294 def gn(x, y): 6295 return torch.sigmoid(torch.matmul(x, y)) 6296 6297 def fn(x, y): 6298 return torch.utils.checkpoint.checkpoint( 6299 gn, torch.sin(x), y, use_reentrant=True 6300 ) 6301 6302 x = torch.randn(4, 4, requires_grad=True) 6303 y = torch.randn(4, 4, requires_grad=True) 6304 6305 fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default) 6306 bw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default) 6307 backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) 6308 self._validate(fn, backend, x, y) 6309 6310 @requires_cuda 6311 @torch._functorch.config.patch(functionalize_rng_ops=True) 6312 def test_function_with_kwargs(self): 6313 def gn(x, y): 6314 return torch.sigmoid(torch.matmul(x, y)) 6315 6316 def fn(x, y): 6317 return torch.utils.checkpoint.checkpoint( 6318 gn, 6319 torch.sin(x), 6320 y, 6321 use_reentrant=True, 6322 preserve_rng_state=False, 6323 ) 6324 6325 x = torch.randn(4, 4, requires_grad=True) 6326 y = torch.randn(4, 4, requires_grad=True) 6327 6328 fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default) 6329 bw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default) 6330 backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) 6331 self._validate(fn, backend, x, y) 6332 6333 @requires_cuda 6334 @torch._functorch.config.patch(functionalize_rng_ops=True) 6335 def test_dropout(self): 6336 def gn(x, y): 6337 return 
torch.nn.functional.dropout(torch.matmul(x, y), p=0.2)

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=True
            )

        x = torch.randn(4, 4, device="cuda", requires_grad=True)
        y = torch.randn(4, 4, device="cuda", requires_grad=True)

        fw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.rngprims.philox_rand.default
        )
        # philox_rand is passed from fwd
        bw_compiler = functools.partial(
            count_ops, freq=0, op=torch.ops.rngprims.philox_rand.default
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(
            fn, backend, x, y, skip_check=True
        )  # dropout decomp is known to diverge with eager

    @requires_cuda
    @torch._functorch.config.patch(functionalize_rng_ops=True)
    def test_dropout_inductor(self):
        def gn(x, y):
            return torch.nn.functional.dropout(torch.matmul(x, y), p=0.2)

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=True
            )

        x = torch.randn(4, 4, device="cuda", requires_grad=True)
        y = torch.randn(4, 4, device="cuda", requires_grad=True)

        backend = "inductor"
        self._validate(
            fn, backend, x, y, skip_check=True
        )  # dropout decomp is known to diverge with eager

    @requires_cuda
    @torch._functorch.config.patch(functionalize_rng_ops=True)
    def test_fallback(self):
        def gn(x, y):
            torch._dynamo.graph_break()
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            return torch.cos(
                torch.utils.checkpoint.checkpoint(
                    gn, torch.sin(x), y, use_reentrant=True
                ),
            )

        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)
        args = (x, y)

        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        expected = fn(*args)
        result = torch.compile(fn, backend=cnt)(*args)

        self.assertEqual(result, expected)

        # One graph for torch.sin on the input, and the other for torch.cos.
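        # The checkpointed gn itself falls back to eager because of the graph break
        # inside it, so only the surrounding sin and cos computations are compiled.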
6405 self.assertEqual(cnt.frame_count, 2) 6406 self.assertEqual(cnt.op_count, 2) 6407 self.assertEqual(len(backend.graphs), 2) 6408 6409 @requires_cuda 6410 @torch._functorch.config.patch(functionalize_rng_ops=True) 6411 def test_module(self): 6412 class MockModule(torch.nn.Module): 6413 def __init__(self) -> None: 6414 super().__init__() 6415 self.linear = torch.nn.Linear(10, 10) 6416 6417 def forward(self, x): 6418 return torch.sigmoid(self.linear(x)) 6419 6420 mod = MockModule() 6421 6422 def fn(x): 6423 return torch.utils.checkpoint.checkpoint( 6424 mod, torch.sin(x), use_reentrant=True 6425 ) 6426 6427 x = torch.randn(10, 10, requires_grad=True) 6428 6429 fw_compiler = functools.partial( 6430 count_ops, freq=1, op=torch.ops.aten.sigmoid.default 6431 ) 6432 # sigmoid passed from fwd 6433 bw_compiler = functools.partial( 6434 count_ops, freq=0, op=torch.ops.aten.sigmoid.default 6435 ) 6436 backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) 6437 self._validate(fn, backend, x) 6438 6439 def test_override_fallthrough_dispatch_key(self): 6440 class _FallthroughTestOnly(torch._ops.HigherOrderOperator): 6441 def __init__(self): 6442 super().__init__("_fallthrough_test_only") 6443 6444 def __call__(self, *args, **kwargs): 6445 return super().__call__(*args, **kwargs) 6446 6447 test_op = _FallthroughTestOnly() 6448 default_keys = torch._ops._HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS 6449 self.assertTrue( 6450 not any(test_op.non_fallthrough_keys.has(key) for key in default_keys) 6451 ) 6452 6453 foos = [lambda x=i: x for i, k in enumerate(default_keys)] 6454 for foo, fallthrough_key in zip(foos, default_keys): 6455 test_op.py_impl(fallthrough_key)(foo) 6456 6457 self.assertTrue( 6458 all(test_op.non_fallthrough_keys.has(key) for key in default_keys) 6459 ) 6460 self.assertEqual( 6461 list(range(len(default_keys))), 6462 [test_op.py_kernels[key]() for key in default_keys], 6463 ) 6464 6465 def test_cond_with_kwargs(self): 6466 from torch._higher_order_ops.cond import cond_op 6467 6468 def test(pred, x): 6469 def true_fn(x): 6470 return x 6471 6472 def false_fn(x): 6473 return -x 6474 6475 return cond_op(pred=pred, true_fn=true_fn, false_fn=false_fn, operands=[x]) 6476 6477 cnt = CompileCounter() 6478 opt_test = torch.compile(test, backend=cnt, fullgraph=True) 6479 inp = torch.ones(3, 3) 6480 true_pred = torch.Tensor([True]) 6481 false_pred = torch.Tensor([False]) 6482 self.assertTrue(torch.allclose(test(true_pred, inp), opt_test(true_pred, inp))) 6483 self.assertEqual(cnt.frame_count, 1) 6484 self.assertTrue( 6485 torch.allclose(test(false_pred, inp), opt_test(false_pred, inp)) 6486 ) 6487 self.assertEqual(cnt.frame_count, 1) 6488 6489 def test_cond_with_invalid_kwargs(self): 6490 from torch._higher_order_ops.cond import cond_op 6491 6492 def test(pred, mode, x): 6493 def true_fn(x): 6494 return x 6495 6496 def false_fn(x): 6497 return -x 6498 6499 if mode: 6500 return cond_op( 6501 pred=pred, 6502 true_fn=true_fn, 6503 false_fn=false_fn, 6504 operands=[x], 6505 invalid=True, 6506 ) 6507 else: 6508 return cond_op( 6509 pred, 6510 pred=pred, 6511 true_fn=true_fn, 6512 false_fn=false_fn, 6513 operands=[x], 6514 ) 6515 6516 cnt = CompileCounter() 6517 opt_test = torch.compile(test, backend=cnt) 6518 inp = torch.ones(3, 3) 6519 with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError): 6520 opt_test(True, True, inp) 6521 6522 with self.assertRaises(AssertionError): 6523 opt_test(True, False, inp) 6524 6525 def test_non_aliasing_util(self): 6526 from 
torch._dynamo.variables.higher_order_ops import _assert_tensors_nonaliasing 6527 6528 a = [torch.tensor(1), {"a": torch.tensor(1)}] 6529 b = (torch.tensor(1),) 6530 _assert_tensors_nonaliasing(a, b) 6531 6532 with self.assertRaisesRegex( 6533 AssertionError, "inputs to function body cannot alias outputs" 6534 ): 6535 _assert_tensors_nonaliasing(a, a) 6536 6537 6538if __name__ == "__main__": 6539 from torch._dynamo.test_case import run_tests 6540 6541 run_tests() 6542