1# Owner(s): ["module: dynamo"] 2""" 3PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes 4with test_export_persist_assert) 5""" 6import copy 7import functools 8import inspect 9import io 10import operator 11import unittest 12from enum import Enum 13from typing import Dict, List, Sequence 14from unittest.mock import patch 15 16import torch 17import torch._dynamo 18import torch._dynamo.test_case 19import torch._dynamo.testing 20from functorch.experimental.control_flow import cond 21from torch._dynamo import config 22from torch._dynamo.exc import UserError 23from torch._dynamo.testing import normalize_gm 24from torch._higher_order_ops.out_dtype import out_dtype 25from torch._subclasses import fake_tensor 26from torch.fx.experimental.proxy_tensor import make_fx 27from torch.fx.experimental.symbolic_shapes import ( 28 ConstraintViolationError, 29 DimDynamic, 30 ShapeEnv, 31 StatelessSymbolicContext, 32) 33from torch.testing._internal import common_utils 34from torch.testing._internal.common_cuda import TEST_CUDA 35 36 37class ExportTests(torch._dynamo.test_case.TestCase): 38 # TODO(voz): Refactor to a shared test function. 
39 # The tests in this file are a little redundant, 40 # They all take a func, run it with eager, then export it, then compare 41 def test_export(self): 42 def pre_attention_state_ops(input, mems, state): 43 lc_key = state[0] 44 lc_val = state[1] 45 bar = [] 46 for i in range(0, 4): 47 bar2 = [] 48 for j in range(0, 3): 49 bar2.append( 50 lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1]) 51 ) 52 bar.append(bar2) 53 54 return bar 55 56 def func(): 57 mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]]) 58 state = [ 59 torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), 60 torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), 61 ] 62 i = torch.tensor( 63 [ 64 [0.0313, -0.1487, -0.3846, -0.5321], 65 [-1.7073, 1.3331, -0.0890, -1.4935], 66 [-0.8314, -0.1862, -0.5935, 1.5232], 67 ] 68 ) 69 return pre_attention_state_ops(i, mems, state) 70 71 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 72 real_result = opt_func() 73 74 torch._dynamo.reset() 75 76 exported = torch._dynamo.export(func)() 77 out_graph = exported[0] 78 79 dynamo_result = out_graph() 80 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 81 82 def test_no_tensor_computation_fail(self): 83 with self.assertRaisesRegex( 84 AssertionError, 85 "Failed to produce a graph", 86 ): 87 inp = [torch.randn(3)] 88 inp2 = 2 89 inps = [inp, inp2] 90 91 def func(x, y): 92 return x 93 94 exported = torch._dynamo.export(func, same_signature=False)(*inps) 95 96 def test_no_tensor_computation(self): 97 inp = [torch.randn(3)] 98 inp2 = 2 99 inps = [inp, inp2] 100 101 def func(x, y): 102 return x 103 104 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 105 real_result = opt_func(*inps) 106 107 torch._dynamo.reset() 108 109 exported = torch._dynamo.export(func)(*inps) 110 out_graph = exported[0] 111 112 dynamo_result = out_graph(*inps) 113 114 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 
115 self.assertExpectedInline( 116 out_graph.code.strip(), 117 """\ 118def forward(self, x, y): 119 arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec) 120 x = arg0 121 return pytree.tree_unflatten([x], self._out_spec)""", 122 ) 123 124 def test_no_tensor_computation_2(self): 125 inp = torch.randn(3) 126 inp2 = 2 127 inps = [inp, inp2] 128 129 def func(x, y): 130 return y 131 132 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 133 real_result = opt_func(*inps) 134 135 torch._dynamo.reset() 136 137 exported = torch._dynamo.export(func)(*inps) 138 out_graph = exported[0] 139 140 dynamo_result = out_graph(*inps) 141 142 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 143 self.assertExpectedInline( 144 out_graph.code.strip(), 145 """\ 146def forward(self, x, y): 147 arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec) 148 x = arg0 149 return pytree.tree_unflatten([2], self._out_spec)""", 150 ) 151 152 def test_export_mismatched_out(self): 153 def func(x): 154 y = x + 1 155 return ([x, x], (y, y)) 156 157 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 158 real_result = opt_func(torch.tensor([[[1.3737, 0.1]]])) 159 160 torch._dynamo.reset() 161 162 exported = torch._dynamo.export(func)(torch.tensor([[[1.3737, 0.1]]])) 163 out_graph = exported[0] 164 165 dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]])) 166 167 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 168 169 def test_export_shape_control_flow_1(self): 170 def func(x): 171 if x.shape[0] > 10: 172 return x.cos() 173 return x.sin() 174 175 opt_func = torch._dynamo.optimize("eager")(func) 176 real_result = opt_func(torch.ones(6, 4)) 177 178 torch._dynamo.reset() 179 180 exported = torch._dynamo.export(func)(torch.ones(6, 4)) 181 out_graph, out_guards = exported 182 183 dynamo_result = out_graph(torch.ones(6, 4)) 184 185 from torch._guards import GuardSource 186 187 
self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 188 hit = False 189 for guard in out_guards: 190 if guard.source == GuardSource.SHAPE_ENV: 191 hit = True 192 self.assertExpectedInline( 193 guard.code_list, 194 """["L['x'].stride()[0] == L['x'].size()[1]", "L['x'].stride()[1] == 1", "L['x'].storage_offset() == 0", "2 <= L['x'].size()[0] <= 10", "2 <= L['x'].size()[1]"]""", # noqa: B950 195 ) 196 break 197 198 self.assertTrue(hit) 199 200 def test_export_control_flow_with_getattr(self): 201 class Animal(Enum): 202 COW = "moo" 203 204 class MyModule(torch.nn.Module): 205 def __init__(self, a): 206 super().__init__() 207 self.a = a 208 209 def forward(self, x): 210 if self.a == Animal.COW.value: 211 return x * x 212 else: 213 raise ValueError("bad") 214 215 module = MyModule("moo") 216 input = (torch.ones(4, 3),) 217 resA = module(*input) 218 graph, _ = torch._dynamo.export(module)(*input) 219 resB = graph(*input) 220 self.assertTrue(torch._dynamo.utils.same(resA, resB)) 221 222 def test_export_graph_bypass(self): 223 inp = [ 224 torch.tensor([0.1, 0.1]), 225 torch.tensor([0.2, 0.2]), 226 torch.tensor([0.3, 0.3]), 227 ] 228 229 def func(x): 230 first = x[2] 231 second = x[2] 232 return first * second 233 234 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 235 real_result = opt_func(inp) 236 237 torch._dynamo.reset() 238 239 exported = torch._dynamo.export(func)(inp) 240 out_graph = exported[0] 241 242 dynamo_result = out_graph(inp) 243 244 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 245 246 def test_list_unpack(self): 247 inp = [ 248 torch.tensor([0.1, 0.1]), 249 torch.tensor([0.2, 0.2]), 250 torch.tensor([0.3, 0.3]), 251 ] 252 253 def func(x): 254 first = x[2] 255 second = x[2] 256 return x[0], first * second, x[1], x[2] 257 258 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 259 real_result = opt_func(inp) 260 261 torch._dynamo.reset() 262 263 exported = 
torch._dynamo.export(func)(inp) 264 out_graph = exported[0] 265 266 dynamo_result = out_graph(inp) 267 268 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 269 270 def test_export_with_shallow_list_copy_wo_side_effects(self): 271 def f(x): 272 y = x.copy() 273 return y[0] + y[1] 274 275 inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])] 276 gm = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")( 277 inp 278 ).graph_module 279 self.assertTrue(torch._dynamo.utils.same(gm(inp), f(inp))) 280 281 def test_export_with_shallow_list_copy_with_side_effects(self): 282 def f(x): 283 y = x.copy() 284 x[0] = x[1] 285 y.append(torch.tensor([[100]])) 286 return x[0] + x[1], y[0] + y[1], y[2] 287 288 inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])] 289 gm = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")( 290 inp 291 ).graph_module 292 res = gm(inp) 293 ref = f(inp) 294 self.assertTrue(torch._dynamo.utils.same(res, ref)) 295 self.assertEqual(res[0], res[1]) 296 297 def test_export_mismatched_out_2(self): 298 def func(x): 299 y = x + 1 300 return ([x, x], (y, y)) 301 302 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 303 real_result = opt_func(torch.tensor([[[1.3737, 0.1]]])) 304 305 torch._dynamo.reset() 306 307 exported = torch._dynamo.export(func)(torch.tensor([[[1.3737, 0.1]]])) 308 out_graph = exported[0] 309 310 dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]])) 311 312 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 313 314 def test_export_graph_with_list(self): 315 inp = [ 316 torch.tensor([0.1, 0.1]), 317 torch.tensor([0.2, 0.2]), 318 torch.tensor([0.3, 0.3]), 319 torch.tensor([0.4, 0.4]), 320 ] 321 322 def func(x): 323 first = x[2] 324 second = x[2] 325 return first * second, x 326 327 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 328 real_result = opt_func(inp) 329 330 
torch._dynamo.reset() 331 332 exported = torch._dynamo.export(func)(inp) 333 out_graph = exported[0] 334 335 dynamo_result = out_graph(inp) 336 337 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 338 339 def test_export_graph_with_complex_reorder(self): 340 inp = [ 341 torch.tensor([0.1, 0.1]), 342 torch.tensor([0.2, 0.2]), 343 torch.tensor([0.3, 0.3]), 344 torch.tensor([0.4, 0.4]), 345 ] 346 347 def func(x): 348 first = x[0] 349 second = x[1] 350 third = x[2] 351 return third, first, second, first * second, first * third 352 353 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 354 real_result = opt_func(inp) 355 356 torch._dynamo.reset() 357 358 exported = torch._dynamo.export(func)(inp) 359 out_graph = exported[0] 360 361 dynamo_result = out_graph(inp) 362 363 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 364 365 def test_dupes(self): 366 inp = torch.tensor([0.1, 0.1]) 367 368 def func(x): 369 y = x + 1 370 return y, y 371 372 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 373 real_result = opt_func(inp) 374 375 torch._dynamo.reset() 376 377 exported = torch._dynamo.export(func)(inp) 378 out_graph = exported[0] 379 380 dynamo_result = out_graph(inp) 381 382 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 383 384 def test_dupes_2(self): 385 inp = torch.tensor([0.1, 0.1]) 386 387 def func(x): 388 y = x + 1 389 return y, y 390 391 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 392 real_result = opt_func(inp) 393 394 torch._dynamo.reset() 395 396 exported = torch._dynamo.export(func)(inp) 397 out_graph = exported[0] 398 399 dynamo_result = out_graph(inp) 400 401 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 402 403 def test_dupes_and_bypass(self): 404 inp = torch.tensor([0.1, 0.1]) 405 inp2 = torch.tensor([0.4, 0.4]) 406 inps = [inp, inp2] 407 408 def func(x, z): 409 y = x + 1 410 return 
y, y, z 411 412 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 413 real_result = opt_func(*inps) 414 415 torch._dynamo.reset() 416 417 exported = torch._dynamo.export(func)(*inps) 418 out_graph = exported[0] 419 420 dynamo_result = out_graph(*inps) 421 422 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 423 424 def test_dupes_and_bypass_with_non_tensor_arg(self): 425 inp = torch.tensor([0.1, 0.1]) 426 inp2 = torch.tensor([0.1, 0.1]) 427 inp3 = 4 428 inps = [inp, inp2, inp3] 429 430 def func(x, z, k): 431 y = x + k 432 return y, y, z 433 434 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 435 real_result = opt_func(*inps) 436 437 torch._dynamo.reset() 438 439 exported = torch._dynamo.export(func)(*inps) 440 out_graph = exported[0] 441 442 dynamo_result = out_graph(*inps) 443 444 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 445 446 def test_dupes_and_bypass_reorder_with_non_tensor_arg(self): 447 inp = torch.tensor([0.1, 0.1]) 448 inp2 = torch.tensor([0.1, 0.1]) 449 inp3 = 4 450 inps = [inp, inp2, inp3] 451 452 def func(x, z, k): 453 y = x + k 454 return z, y, y 455 456 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 457 real_result = opt_func(*inps) 458 459 torch._dynamo.reset() 460 461 exported = torch._dynamo.export(func)(*inps) 462 out_graph = exported[0] 463 464 dynamo_result = out_graph(*inps) 465 466 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 467 468 @config.patch(capture_scalar_outputs=True) 469 def test_dupes_and_bypass_with_non_tensor_output(self): 470 inp = torch.tensor([0.1, 0.1]) 471 inp2 = torch.tensor([0.1, 0.1]) 472 inp3 = 4 473 inps = [inp, inp2, inp3] 474 475 def func(x, z, k): 476 y = x + k 477 return y[0].item(), y, z 478 479 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 480 real_result = opt_func(*inps) 481 482 torch._dynamo.reset() 483 484 exported = 
torch._dynamo.export(func)(*inps) 485 out_graph = exported[0] 486 487 dynamo_result = out_graph(*inps) 488 489 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 490 491 def test_zeroes_in_and_out_different_shape_on_test(self): 492 inp = torch.zeros(10) 493 inp2 = torch.zeros(10) 494 inp3 = torch.zeros(10) 495 inps = [inp, inp2, inp3] 496 497 inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] 498 499 def func(a, b, c): 500 return [[a], [b, c], [a + b], [[c + c]]] 501 502 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 503 real_result = opt_func(*inps_rand) 504 505 torch._dynamo.reset() 506 507 exported = torch._dynamo.export(func)(*inps) 508 out_graph = exported[0] 509 510 dynamo_result = out_graph(*inps_rand) 511 512 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 513 514 @config.patch(capture_scalar_outputs=True) 515 def test_zeroes_in_new_shape_scalar_out(self): 516 inp = torch.zeros(10) 517 inp2 = torch.zeros(10) 518 inp3 = torch.zeros(10) 519 inps = [inp, inp2, inp3] 520 521 inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] 522 523 def func(a, b, c): 524 return a[0].item() + b[0].item() + c[0].item() 525 526 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 527 real_result = opt_func(*inps_rand) 528 529 torch._dynamo.reset() 530 531 exported = torch._dynamo.export(func)(*inps) 532 out_graph = exported[0] 533 534 dynamo_result = out_graph(*inps_rand) 535 536 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 537 538 @config.patch(capture_scalar_outputs=True) 539 def test_zeroes_in_new_shape_scalar_out_permute(self): 540 inp = torch.zeros(10) 541 inp2 = torch.zeros(10) 542 inp3 = torch.zeros(10) 543 inps = [inp, inp2, inp3] 544 545 inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] 546 547 def func(a, b, c): 548 return b[0].item() + c[0].item() + a[0].item() + a[0].item() 549 550 opt_func = 
torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 551 real_result = opt_func(*inps_rand) 552 553 torch._dynamo.reset() 554 555 exported = torch._dynamo.export(func)(*inps) 556 out_graph = exported[0] 557 558 dynamo_result = out_graph(*inps_rand) 559 560 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 561 562 @config.patch(capture_scalar_outputs=True) 563 def test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass(self): 564 inp = torch.zeros(10) 565 inp2 = torch.zeros(10) 566 inp3 = torch.zeros(10) 567 inps = [inp, inp2, inp3] 568 569 inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] 570 571 def func(a, b, c): 572 return a, b[0].item() + c[0].item() + a[0].item() + a[0].item(), a 573 574 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 575 real_result = opt_func(*inps_rand) 576 577 torch._dynamo.reset() 578 579 exported = torch._dynamo.export(func)(*inps) 580 out_graph = exported[0] 581 582 dynamo_result = out_graph(*inps_rand) 583 584 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 585 586 def test_func_return(self): 587 inp = torch.zeros(10) 588 inp2 = torch.zeros(10) 589 inp3 = torch.zeros(10) 590 inps = [inp, inp2, inp3] 591 592 inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] 593 594 def func(a, b, c): 595 x = a + b + c 596 597 def func2(y): 598 return x * y 599 600 return func2(x) 601 602 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 603 real_result = opt_func(*inps_rand) 604 605 torch._dynamo.reset() 606 607 exported = torch._dynamo.export(func)(*inps) 608 out_graph = exported[0] 609 610 dynamo_result = out_graph(*inps_rand) 611 612 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 613 614 def test_dict_return(self): 615 inp = torch.zeros(10) 616 inp2 = torch.zeros(10) 617 inp3 = torch.zeros(10) 618 inps = [inp, inp2, inp3] 619 620 inps_rand = [torch.randn(10), torch.randn(10), 
torch.randn(10)] 621 622 def func(a, b, c): 623 x = a + b + c 624 return {"a": x} 625 626 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 627 real_result = opt_func(*inps_rand) 628 629 torch._dynamo.reset() 630 631 exported = torch._dynamo.export(func)(*inps) 632 out_graph = exported[0] 633 634 dynamo_result = out_graph(*inps_rand) 635 636 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 637 638 def test_export_with_aten_graph(self): 639 def pre_attention_state_ops(input, mems, state): 640 lc_key = state[0] 641 lc_val = state[1] 642 bar = [] 643 for i in range(0, 4): 644 bar2 = [] 645 for j in range(0, 3): 646 bar2.append( 647 lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1]) 648 ) 649 bar.append(bar2) 650 651 return bar 652 653 def func(): 654 mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]]) 655 state = [ 656 torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), 657 torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), 658 ] 659 i = torch.tensor( 660 [ 661 [0.0313, -0.1487, -0.3846, -0.5321], 662 [-1.7073, 1.3331, -0.0890, -1.4935], 663 [-0.8314, -0.1862, -0.5935, 1.5232], 664 ] 665 ) 666 return pre_attention_state_ops(i, mems, state) 667 668 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 669 real_result = opt_func() 670 671 torch._dynamo.reset() 672 673 exported = torch._dynamo.export(func, aten_graph=True)() 674 out_graph = exported[0] 675 676 dynamo_result = out_graph() 677 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 678 679 def test_export_no_tensor_computation_with_aten_graph(self): 680 inp = [torch.randn(3)] 681 inp2 = 2 682 inps = [inp, inp2] 683 684 def func(x, y): 685 return x 686 687 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 688 real_result = opt_func(*inps) 689 690 torch._dynamo.reset() 691 692 exported = torch._dynamo.export(func, aten_graph=True)(*inps) 693 out_graph = 
exported[0] 694 695 dynamo_result = out_graph(*inps) 696 697 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 698 self.assertExpectedInline( 699 out_graph.code.strip(), 700 """\ 701def forward(self, x, y): 702 arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec) 703 arg0_1 = arg0 704 return pytree.tree_unflatten([arg0_1], self._out_spec)""", 705 ) 706 707 def test_no_tensor_computation_2_with_aten_graph(self): 708 inp = torch.randn(3) 709 inp2 = 2 710 inps = [inp, inp2] 711 712 def func(x, y): 713 return y 714 715 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 716 real_result = opt_func(*inps) 717 718 torch._dynamo.reset() 719 720 exported = torch._dynamo.export(func, aten_graph=True)(*inps) 721 out_graph = exported[0] 722 723 dynamo_result = out_graph(*inps) 724 725 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 726 self.assertExpectedInline( 727 out_graph.code.strip(), 728 """\ 729def forward(self, x, y): 730 arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec) 731 arg0_1 = arg0 732 return pytree.tree_unflatten([2], self._out_spec)""", 733 ) 734 735 def test_export_mismatched_out_with_aten_graph(self): 736 def func(x): 737 y = x + 1 738 return ([x, x], (y, y)) 739 740 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 741 real_result = opt_func(torch.tensor([[[1.3737, 0.1]]])) 742 743 torch._dynamo.reset() 744 745 exported = torch._dynamo.export(func, aten_graph=True)( 746 torch.tensor([[[1.3737, 0.1]]]) 747 ) 748 out_graph = exported[0] 749 750 dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]])) 751 752 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 753 754 def test_export_graph_bypass_with_aten_graph(self): 755 inp = [ 756 torch.tensor([0.1, 0.1]), 757 torch.tensor([0.2, 0.2]), 758 torch.tensor([0.3, 0.3]), 759 ] 760 761 def func(x): 762 first = x[2] 763 second = x[2] 764 return first * second 
765 766 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 767 real_result = opt_func(inp) 768 769 torch._dynamo.reset() 770 771 exported = torch._dynamo.export(func, aten_graph=True)(inp) 772 out_graph = exported[0] 773 774 dynamo_result = out_graph(inp) 775 776 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 777 778 def test_list_unpack_with_aten_graph(self): 779 inp = [ 780 torch.tensor([0.1, 0.1]), 781 torch.tensor([0.2, 0.2]), 782 torch.tensor([0.3, 0.3]), 783 ] 784 785 def func(x): 786 first = x[2] 787 second = x[2] 788 return x[0], first * second, x[1], x[2] 789 790 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 791 real_result = opt_func(inp) 792 793 torch._dynamo.reset() 794 795 exported = torch._dynamo.export(func, aten_graph=True)(inp) 796 out_graph = exported[0] 797 798 dynamo_result = out_graph(inp) 799 800 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 801 802 def test_export_mismatched_out_2_with_aten_graph(self): 803 def func(x): 804 y = x + 1 805 return ([x, x], (y, y)) 806 807 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 808 real_result = opt_func(torch.tensor([[[1.3737, 0.1]]])) 809 810 torch._dynamo.reset() 811 812 exported = torch._dynamo.export(func, aten_graph=True)( 813 torch.tensor([[[1.3737, 0.1]]]) 814 ) 815 out_graph = exported[0] 816 817 dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]])) 818 819 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 820 821 def test_export_graph_with_list_with_aten_graph(self): 822 inp = [ 823 torch.tensor([0.1, 0.1]), 824 torch.tensor([0.2, 0.2]), 825 torch.tensor([0.3, 0.3]), 826 torch.tensor([0.4, 0.4]), 827 ] 828 829 def func(x): 830 first = x[2] 831 second = x[2] 832 return first * second, x 833 834 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 835 real_result = opt_func(inp) 836 837 torch._dynamo.reset() 838 
839 exported = torch._dynamo.export(func, aten_graph=True)(inp) 840 out_graph = exported[0] 841 842 dynamo_result = out_graph(inp) 843 844 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 845 846 def test_export_graph_with_complex_reorder_with_aten_graph(self): 847 inp = [ 848 torch.tensor([0.1, 0.1]), 849 torch.tensor([0.2, 0.2]), 850 torch.tensor([0.3, 0.3]), 851 torch.tensor([0.4, 0.4]), 852 ] 853 854 def func(x): 855 first = x[0] 856 second = x[1] 857 third = x[2] 858 return third, first, second, first * second, first * third 859 860 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 861 real_result = opt_func(inp) 862 863 torch._dynamo.reset() 864 865 exported = torch._dynamo.export(func, aten_graph=True)(inp) 866 out_graph = exported[0] 867 868 dynamo_result = out_graph(inp) 869 870 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 871 872 def test_dupes_with_aten_graph(self): 873 inp = torch.tensor([0.1, 0.1]) 874 875 def func(x): 876 y = x + 1 877 return y, y 878 879 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 880 real_result = opt_func(inp) 881 882 torch._dynamo.reset() 883 884 exported = torch._dynamo.export(func, aten_graph=True)(inp) 885 out_graph = exported[0] 886 887 dynamo_result = out_graph(inp) 888 889 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 890 891 def test_dupes_2_with_aten_graph(self): 892 inp = torch.tensor([0.1, 0.1]) 893 894 def func(x): 895 y = x + 1 896 return y, y 897 898 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 899 real_result = opt_func(inp) 900 901 torch._dynamo.reset() 902 903 exported = torch._dynamo.export(func, aten_graph=True)(inp) 904 out_graph = exported[0] 905 906 dynamo_result = out_graph(inp) 907 908 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 909 910 def test_dupes_and_bypass_with_aten_graph(self): 911 inp = torch.tensor([0.1, 0.1]) 
912 inp2 = torch.tensor([0.4, 0.4]) 913 inps = [inp, inp2] 914 915 def func(x, z): 916 y = x + 1 917 return y, y, z 918 919 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 920 real_result = opt_func(*inps) 921 922 torch._dynamo.reset() 923 924 exported = torch._dynamo.export(func, aten_graph=True)(*inps) 925 out_graph = exported[0] 926 927 dynamo_result = out_graph(*inps) 928 929 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 930 931 def test_dupes_and_bypass_with_non_tensor_arg_with_aten_graph(self): 932 inp = torch.tensor([0.1, 0.1]) 933 inp2 = torch.tensor([0.1, 0.1]) 934 inp3 = 4 935 inps = [inp, inp2, inp3] 936 937 def func(x, z, k): 938 y = x + k 939 return y, y, z 940 941 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 942 real_result = opt_func(*inps) 943 944 torch._dynamo.reset() 945 946 exported = torch._dynamo.export(func, aten_graph=True)(*inps) 947 out_graph = exported[0] 948 949 dynamo_result = out_graph(*inps) 950 951 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 952 953 def test_dupes_and_bypass_reorder_with_non_tensor_arg_with_aten_graph(self): 954 inp = torch.tensor([0.1, 0.1]) 955 inp2 = torch.tensor([0.1, 0.1]) 956 inp3 = 4 957 inps = [inp, inp2, inp3] 958 959 def func(x, z, k): 960 y = x + k 961 return z, y, y 962 963 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 964 real_result = opt_func(*inps) 965 966 torch._dynamo.reset() 967 968 exported = torch._dynamo.export(func, aten_graph=True)(*inps) 969 out_graph = exported[0] 970 971 dynamo_result = out_graph(*inps) 972 973 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 974 975 @config.patch(capture_scalar_outputs=True) 976 def test_dupes_and_bypass_with_non_tensor_output_with_aten_graph(self): 977 inp = torch.tensor([0.1, 0.1]) 978 inp2 = torch.tensor([0.1, 0.1]) 979 inp3 = 4 980 inps = [inp, inp2, inp3] 981 982 def func(x, z, k): 983 y = 
x + k 984 return y[0].item(), y, z 985 986 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 987 real_result = opt_func(*inps) 988 989 torch._dynamo.reset() 990 991 exported = torch._dynamo.export(func)(*inps) 992 out_graph = exported[0] 993 994 dynamo_result = out_graph(*inps) 995 996 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 997 998 def test_zeroes_in_and_out_different_shape_on_test_with_aten_graph(self): 999 inp = torch.zeros(10) 1000 inp2 = torch.zeros(10) 1001 inp3 = torch.zeros(10) 1002 inps = [inp, inp2, inp3] 1003 1004 inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] 1005 1006 def func(a, b, c): 1007 return [[a], [b, c], [a + b], [[c + c]]] 1008 1009 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 1010 real_result = opt_func(*inps_rand) 1011 1012 torch._dynamo.reset() 1013 1014 exported = torch._dynamo.export(func, aten_graph=True)(*inps) 1015 out_graph = exported[0] 1016 1017 dynamo_result = out_graph(*inps_rand) 1018 1019 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 1020 1021 def test_func_return_with_aten_graph(self): 1022 inp = torch.zeros(10) 1023 inp2 = torch.zeros(10) 1024 inp3 = torch.zeros(10) 1025 inps = [inp, inp2, inp3] 1026 1027 inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] 1028 1029 def func(a, b, c): 1030 x = a + b + c 1031 1032 def func2(y): 1033 return x * y 1034 1035 return func2(x) 1036 1037 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 1038 real_result = opt_func(*inps_rand) 1039 1040 torch._dynamo.reset() 1041 1042 exported = torch._dynamo.export(func, aten_graph=True)(*inps) 1043 out_graph = exported[0] 1044 1045 dynamo_result = out_graph(*inps_rand) 1046 1047 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 1048 1049 def test_dict_return_with_aten_graph(self): 1050 inp = torch.zeros(10) 1051 inp2 = torch.zeros(10) 1052 inp3 = torch.zeros(10) 
1053 inps = [inp, inp2, inp3] 1054 1055 inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] 1056 1057 def func(a, b, c): 1058 x = a + b + c 1059 return {"a": x} 1060 1061 opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) 1062 real_result = opt_func(*inps_rand) 1063 1064 torch._dynamo.reset() 1065 1066 exported = torch._dynamo.export(func, aten_graph=True)(*inps) 1067 out_graph = exported[0] 1068 1069 dynamo_result = out_graph(*inps_rand) 1070 1071 self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) 1072 1073 def test_export_with_stack_trace(self): 1074 inp = torch.randn(4, 4) 1075 1076 class MyBlock(torch.nn.Module): 1077 def forward(self, x): 1078 x = torch.nn.functional.linear(x, torch.randn(4, 4)) 1079 return torch.cos(x).relu() + 1 1080 1081 class MyModule(torch.nn.Module): 1082 def __init__(self) -> None: 1083 super().__init__() 1084 self.block = MyBlock() 1085 1086 def forward(self, x): 1087 out = self.block(x) 1088 return out 1089 1090 exported = torch._dynamo.export(MyModule(), aten_graph=False)(inp) 1091 out_graph = exported[0] 1092 1093 for node in out_graph.graph.nodes: 1094 if node.op not in {"placeholder", "output"}: 1095 self.assertTrue(node.stack_trace is not None) 1096 self.assertTrue(node.meta["nn_module_stack"] is not None) 1097 self.assertTrue(node.meta["source_fn_stack"] is not None) 1098 1099 torch._dynamo.reset() 1100 1101 exported = torch._dynamo.export(MyModule(), aten_graph=True)(inp) 1102 out_graph = exported[0] 1103 for node in out_graph.graph.nodes: 1104 if node.op == "call_function": 1105 self.assertTrue(node.stack_trace is not None) 1106 self.assertTrue(node.meta["nn_module_stack"] is not None) 1107 self.assertTrue(node.meta["source_fn_stack"] is not None) 1108 self.assertTrue(node.meta["val"] is not None) 1109 self.assertTrue(node.meta["original_aten"] is not None) 1110 1111 def test_export_preserves_nn_module_stack_for_get_attr(self): 1112 inp = torch.randn(4, 4) 1113 1114 
        class MyBlock(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.ones(1, 1))
                self.buffer = torch.nn.Buffer(torch.ones(1, 1))

            def forward(self, x):
                x = torch.nn.functional.linear(x, torch.randn(4, 4))
                return torch.cos(x).relu() + self.weight + self.buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.block = MyBlock()

            def forward(self, x):
                out = self.block(x)
                return out

        m = MyModule()
        exported = torch._dynamo.export(m, aten_graph=False)(inp)
        out_graph = exported[0]

        # Count get_attr nodes and require module-stack meta on each one.
        attr_access_count = 0
        for node in out_graph.graph.nodes:
            if node.op == "get_attr":
                attr_access_count += 1
                self.assertTrue(node.meta["nn_module_stack"] is not None)
        # Presumably one get_attr each for MyBlock.weight and MyBlock.buffer.
        self.assertEqual(attr_access_count, 2)

        torch._dynamo.reset()

        exported = torch._dynamo.export(m, aten_graph=True)(inp)
        out_graph = exported[0]

        # Same check on the aten-level graph.
        attr_access_count = 0
        for node in out_graph.graph.nodes:
            if node.op == "get_attr":
                attr_access_count += 1
                self.assertTrue(node.meta["nn_module_stack"] is not None)
        self.assertEqual(attr_access_count, 2)

    # The aten-level export result must match the result of make_fx, both when
    # make_fx runs inside a custom dynamo backend and when it is applied to the
    # function directly.
    def test_export_compare_optimize_with_make_fx(self):
        inp = torch.tensor([0.1, 0.1])
        linear = torch.nn.Linear(2, 2)

        def func(x):
            x = x + 1
            y = x.t()
            y = y.relu()
            y = linear(y)
            return y

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]
        export_result = out_graph(inp)

        torch._dynamo.reset()

        # Backend that re-traces the dynamo graph with make_fx on every call and
        # runs the traced graph.
        def compiler(gm, sample_inputs):
            def fw(*args):
                aten_gm = make_fx(gm)(*args)
                return aten_gm(*args)

            return fw

        opt_func = torch._dynamo.optimize(compiler, nopython=True, dynamic=True)(func)
        make_fx_result_through_backend = opt_func(inp)

        fx_g = make_fx(func)(inp)
        make_fx_result_through_direct = fx_g(inp)

        self.assertTrue(
            torch._dynamo.utils.same(make_fx_result_through_backend, export_result)
        )
        self.assertTrue(
            torch._dynamo.utils.same(make_fx_result_through_direct, export_result)
        )

    # Exports a module whose forward returns the result of a method decorated with
    # assume_constant_result, then runs the exported graph on two different inputs
    # and compares each result with one eager result.
    def test_export_with_constant_method_on_module(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return torch.nonzero(x)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = self.helper_fn(x)
                return y

        module = MyModule()
        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
        # A fresh instance is exported; the eager result above came from a
        # different instance.
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # Same as the single-call variant, but forward invokes the
    # assume_constant_result method twice and adds the two results.
    def test_export_with_constant_method_on_module_invoke_twice(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return torch.nonzero(x)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = self.helper_fn(x) + self.helper_fn(x)
                return y

        module = MyModule()
        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # Exports a module whose forward adds the result of a free function and of a
    # method, both decorated with assume_constant_result.
    # NOTE(review): despite the name, this test exercises the free function AND the
    # class method; the "_and_class_method" test below uses only the free function.
    def test_export_with_constant_free_function(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return torch.nonzero(x)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = helper_fn(x) + self.helper_fn(x)
                return y

        module = MyModule()
        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # Exports a module whose forward returns the result of a free function
    # decorated with assume_constant_result.
    # NOTE(review): despite the name, MyModule here defines no helper method; only
    # the free function is called.
    def test_export_with_constant_free_function_and_class_method(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = helper_fn(x)
                return y

        module = MyModule()
        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # Two-input variant: forward adds the assume_constant_result free function
    # applied to the linear output of x and to the second input z.
    def test_export_with_constant_free_function_and_class_method_multiarg(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x, z):
                y = torch.sin(x)
                x = self.linear(x)
                y = helper_fn(x) + helper_fn(z)
                return y

        module = MyModule()
        real_result = module(
            torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
        )
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(
            torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
        )
        result = graph(
            torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[1.0, 0.0], [0, 0]])
        )
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(
            torch.tensor([[1, 0], [0.25, 0.25]]), torch.tensor([[1, 0], [0.25, 0.25]])
        )
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # Two-input variant without a linear layer: forward adds the
    # assume_constant_result free function applied to each raw input; the graph is
    # exported and then run with inputs that differ from the eager call.
    def test_export_with_constant_free_function_and_class_method_multiarg_diff(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def forward(self, x, z):
                y = helper_fn(x) + helper_fn(z)
                return y

        module = MyModule()
        real_result = module(
            torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
        )
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(
            torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[0.0, 0], [0.5, 0]])
        )
        result = graph(
            torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[0.0, 1.0], [0, 0]])
        )
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(
            torch.tensor([[1, 0], [0.25, 0.25]]),
            torch.tensor([[0.33, 0.33], [0.25, 0.25]]),
        )
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # The assume_constant_result method returns a tuple of tensors; forward
    # iterates over the tuple and over each tensor's rows.
    def test_export_with_constant_tuple_nonzero(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return (torch.nonzero(x), torch.nonzero(x))

            def forward(self, x):
                y = torch.tensor([0.5])
                elements = self.helper_fn(x)
                all_y = []
                for element in elements:
                    for item in element:
                        all_y.append(y * item)
                return all_y

        module = MyModule()
        real_result = module(torch.tensor([1.0, 1.0]))
        graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))

        # Tensor input can be almost anything here, and the result will capture what we
        # made constant at compile time.
        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # Same as the tuple variant, but the assume_constant_result method returns a
    # list of tensors.
    def test_export_with_constant_list_nonzero(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return [torch.nonzero(x), torch.nonzero(x)]

            def forward(self, x):
                y = torch.tensor([0.5])
                elements = self.helper_fn(x)
                all_y = []
                for element in elements:
                    for item in element:
                        all_y.append(y * item)
                return all_y

        module = MyModule()
        real_result = module(torch.tensor([1.0, 1.0]))
        graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))

        # Tensor input can be almost anything here, and the result will capture what we
        # made constant at compile time.
        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # A free function decorated with assume_constant_result returns a list of
    # tensors; forward iterates over the list and over each tensor's rows.
    def test_export_with_constant_list_nonzero_free_function(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return [torch.nonzero(x), torch.nonzero(x)]

        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor([0.5])
                elements = helper_fn(x)
                all_y = []
                for element in elements:
                    for item in element:
                        all_y.append(y * item)
                return all_y

        module = MyModule()
        real_result = module(torch.tensor([1.0, 1.0]))
        graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))

        # Tensor input can be almost anything here, and the result will capture what we
        # made constant at compile time.
        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # The assume_constant_result method returns a dict of tensors; forward
    # multiplies by the values looked up under both keys.
    def test_export_with_constant_dict_values(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return {"x": x, "x^2": x * x}

            def forward(self, x):
                y = torch.tensor([0.5])
                elements = self.helper_fn(x)
                y = y * elements["x"]
                y = y * elements["x^2"]
                return y

        module = MyModule()
        real_result = module(torch.tensor([2.0, 2.0]))
        graph, guards = torch._dynamo.export(module)(torch.tensor([2.0, 2.0]))

        # Tensor input can be almost anything here, and the result will capture what we
        # made constant at compile time.
        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # The assume_constant_result method returns None for a negative input; the
    # graph is exported with a negative input and then run with a positive one.
    def test_export_with_constant_none_control_flow(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                if x.item() < 0:
                    return None
                else:
                    return x

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([-1]))

        # X is negative, so .item() < 0, which means we return y
        self.assertEqual(real_result, torch.tensor([0.5]))

        graph, guards = torch._dynamo.export(module)(torch.tensor([-1]))
        result = graph(torch.tensor([2]))
        # X is positive, but we compiled helper_fn to return None, so it will still return y
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # Mirror case: the graph is exported with a positive input (helper returns the
    # tensor) and then run with a negative one.
    def test_export_with_constant_not_none_control_flow(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                if x.item() < 0:
                    return None
                else:
                    return x

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([2]))

        # X is positive, so .item() > 0, which means we return y * x
        self.assertEqual(real_result, torch.tensor([1.0]))

        graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
        result = graph(torch.tensor([-0.5]))
        # X is negative, but we compiled helper_fn to return x, so it will still return y * x
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # Free-function variant of the "returns None" case.
    def test_export_with_constant_none_control_flow_free_func(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            if x.item() < 0:
                return None
            else:
                return x

        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor([0.5])
                x = helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([-1]))

        # X is negative, so .item() < 0, which means we return y
        self.assertEqual(real_result, torch.tensor([0.5]))

        graph, guards = torch._dynamo.export(module)(torch.tensor([-1]))
        result = graph(torch.tensor([2]))
        # X is positive, but we compiled helper_fn to return None, so it will still return y
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # Export with a positive input (helper returns the tensor), then run with a
    # negative one.
    # NOTE(review): this body appears identical to
    # test_export_with_constant_not_none_control_flow — confirm whether both are
    # needed.
    def test_export_with_constant_not_none_control_flow_pos(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                if x.item() < 0:
                    return None
                else:
                    return x

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([2]))

        # X is positive, so .item() > 0, which means we return y * x
        self.assertEqual(real_result, torch.tensor([1.0]))

        graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
        result = graph(torch.tensor([-0.5]))
        # X is negative, but we compiled helper_fn to return x, so it will still return y * x
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # Free-function variant of the "returns the tensor" case.
    def test_export_with_constant_not_none_control_flow_free_func(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            if x.item() < 0:
                return None
            else:
                return x

        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor([0.5])
                x = helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([2]))

        # X is positive, so .item() > 0, which means we return y * x
        self.assertEqual(real_result, torch.tensor([1.0]))

        graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
        result = graph(torch.tensor([-0.5]))
        # X is negative, but we compiled helper_fn to return x, so it will still return y * x
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    # The assume_constant_result method returns a module attribute (a string);
    # changing that attribute after export must not change the graph's result.
    def test_export_with_constant_not_return_const(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return self.val

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x == "A":
                    return y
                return -1

        module = MyModule()
        module.val = "A"
        resA = module(torch.tensor([2]))
        graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
        # Mutate the attribute after export; the graph is expected to still
        # produce the result computed with val == "A".
        module.val = "B"
        resB = graph(torch.tensor([2]))
        self.assertTrue(torch._dynamo.utils.same(resA, resB))

    # A comparison (<) is applied to the result of an assume_constant_result
    # function and used as a Python branch condition inside forward.
    def test_export_with_builtin_op_on_assume_constant(self):
        @torch._dynamo.assume_constant_result
        def get_y(y) -> torch.Tensor:
            return y

        class Bob(torch.nn.Module):
            def __init__(self, p, val) -> None:
                super().__init__()
                self.p = p
                self.y = torch.nn.Parameter(torch.tensor(val))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # This only looks dynamic but it's actually a constant value
                if get_y(self.y) < self.p:
                    return torch.cat([x, x])
                else:
                    return x

        model = Bob(0.5, 0.3)
        inp = torch.ones(3, 4)
        graph, guards = torch._dynamo.export(model)(inp)
        self.assertEqual(model(inp), graph(inp))

    # assume_constant_result on a method of a module that also assigns an
    # attribute inside forward.
    def test_export_with_constant_in_unspecialized_nn_module(self):
        class Module(torch.nn.Module):
            def __init__(self, y):
                super().__init__()
                self.y = y

            @torch._dynamo.assume_constant_result
            def check(self):
                return self.y[0].item() == 1

            def forward(self, x):
                # This line leads to module obj being tracked as UnspecializedNNModuleVariable in dynamo
                self.device = x.device

                if self.check():
                    return x + 1
                else:
                    return x + 2

        model = Module(torch.tensor([1]))
        inp = torch.ones(3, 4)
        graph, _ = torch._dynamo.export(model)(inp)
        self.assertEqual(model(inp), graph(inp))

    # With a decomposition_table entry for aten.t.default, the aten-level graph
    # must contain no aten.t.default nodes; with decomposition_table=None it must
    # contain two (one per x.t() call in f).
    def test_export_decomp(self):
        def f(x):
            return x.t() + x.t()

        # Replacement used as the "decomposition" of aten.t.default.
        def nop(x):
            return x.cos()

        graph, _ = torch._dynamo.export(
            f,
            aten_graph=True,
            decomposition_table={torch.ops.aten.t.default: nop},
        )(torch.randn(5))
        self.assertEqual(
            len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]),
            0,
        )

        graph, _ = torch._dynamo.export(f, aten_graph=True, decomposition_table=None)(
            torch.randn(5)
        )
        self.assertEqual(
            len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]),
            2,
        )

    # Calling export with a decomposition_table while aten_graph=False must raise
    # AssertionError.
    # NOTE(review): (torch.randn(5)) is a parenthesized tensor, not a 1-tuple, and
    # is passed as export's second positional argument — confirm this is intended
    # and that the AssertionError comes from the decomposition_table/aten_graph
    # combination rather than from that argument.
    def test_export_decomp_asserts_bad_args(self):
        def f(x):
            return x.t() + x.t()

        def nop(x):
            return x.cos()

        with self.assertRaises(AssertionError):
            graph, _ = torch._dynamo.export(
                f,
                (torch.randn(5)),
                aten_graph=False,
                decomposition_table={torch.ops.aten.t.default: nop},
            )

    # cond whose branches are closures calling a sub-module (self.linear); the
    # exported graph is compared with eager for two different predicate values.
    @config.patch(capture_scalar_outputs=True)
    def test_export_with_module_layer(self):
        from functorch.experimental.control_flow import cond

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, pred, x):
                def true_fn(val):
                    return self.linear(val) * torch.tensor(2)

                def false_fn(val):
                    return self.linear(val) * torch.tensor(-1)

                return cond(pred, true_fn, false_fn, [x])

        mod = Module()
        x = torch.randn([3, 3])
        pred = torch.tensor(x[0][0].item() < 0)
        real_result = mod.forward(pred, x)

        torch._dynamo.reset()

        exported = torch._dynamo.export(mod.forward)(pred, x)
        out_graph = exported[0]

        dynamo_result = out_graph(pred, x)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

        # New X, just to show we did not specialize
        x = x * -1
        pred = torch.tensor(x[0][0].item() < 0)
        real_result_2 = mod.forward(pred, x)
        dynamo_result_2 = out_graph(pred, x)
        self.assertTrue(torch._dynamo.utils.same(real_result_2, dynamo_result_2))

    # cond whose branches are bound methods that themselves call other methods and
    # a sub-module of the same nn.Module.
    @config.patch(capture_scalar_outputs=True)
    def test_export_with_cond_branches_calling_methods(self):
        from functorch.experimental.control_flow import cond

        class Module(torch.nn.Module):
            # ok
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def t(self, val):
                return val + 1

            def f(self, val):
                return val - 1

            def true_fn(self, val):
                return self.linear(val) + self.t(val)

            def false_fn(self, val):
                return self.linear(val) - self.f(val)

            def forward(self, pred, x):
                return cond(pred, self.true_fn, self.false_fn, [x])

        mod = Module()
        x = torch.randn([3, 3])
        pred = torch.tensor(x[0][0].item() < 0)
        real_result = mod.forward(pred, x)
        out_graph, _ = torch._dynamo.export(mod.forward)(pred, x)
        dynamo_result = out_graph(pred, x)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    # cond with locally defined branch functions, in three flavours: operand passed
    # as-is (Foo), operand computed inline (Bar), and branches that take two
    # operands and capture self.linear (FooBar).
    @config.patch(capture_scalar_outputs=True)
    def test_export_with_cond_closure(self):
        from functorch.experimental.control_flow import cond

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, pred, x):
                def true_fn(x):
                    return x * 2

                def false_fn(x):
                    return x - 2

                return cond(pred, true_fn, false_fn, [x])

        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, pred, x):
                def true_fn(x):
                    return x * 2

                def false_fn(x):
                    return x - 2

                return cond(pred, true_fn, false_fn, [x + 1])

        class FooBar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, pred, x):
                y = x + x

                def true_fn(x, y):
                    return self.linear(x) * (x + y)

                def false_fn(x, y):
                    return x * (y - x)

                return cond(pred, true_fn, false_fn, [x, y])

        # Export each module's forward and compare with eager on the same inputs.
        for Module in [Foo, Bar, FooBar]:
            mod = Module()
            x = torch.randn([3, 3], requires_grad=True)
            pred = torch.tensor(x[0][0].item() < 0)
            real_result = mod.forward(pred, x)
            out_graph, _ = torch._dynamo.export(mod.forward)(pred, x)
            dynamo_result = out_graph(pred, x)
            self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    # cond whose branch functions call other functions (hello / hi) defined in the
    # enclosing test scope.
    def test_export_with_cond_with_closed_function(self):
        def hello(x):
            return x + 1

        def hi(x):
            return x + 2

        def foo(pred, x):
            def true_fn(x):
                return hello(x)

            def false_fn(x):
                return hi(x)

            return cond(pred, true_fn, false_fn, [x])

        x = torch.randn(5)
        pred = x[0] > 0
        real_result = foo(pred, x)
        out_graph, _ = torch._dynamo.export(foo)(pred, x)
        dynamo_result = out_graph(pred, x)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    # cond whose predicate is computed from the input's shape (x.shape[0] <= 2);
    # Module passes the operands as a list, Module2 as a tuple.
    def test_export_with_cond_dynamic_shape_pred(self):
        from functorch.experimental.control_flow import cond

        class Module(torch.nn.Module):
            def forward(self, x):
                def true_fn(x):
                    return x + x

                def false_fn(x):
                    return x[:2]

                return cond(x.shape[0] <= 2, true_fn, false_fn, [x])

        class Module2(torch.nn.Module):
            def forward(self, x):
def true_fn(x): 1865 return x + x 1866 1867 def false_fn(x): 1868 return x[:2] 1869 1870 return cond(x.shape[0] <= 2, true_fn, false_fn, (x,)) 1871 1872 mods = [Module(), Module2()] 1873 for mod in mods: 1874 x = torch.randn(2, 2) 1875 out_graph, guards = torch._dynamo.export(mod)(x) 1876 self.assertExpectedInline( 1877 out_graph.code.strip(), 1878 """\ 1879def forward(self, x): 1880 arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) 1881 l_x_ = arg0 1882 size = l_x_.size() 1883 getitem = size[0]; size = None 1884 le = getitem <= 2; getitem = None 1885 cond_true_0 = self.cond_true_0 1886 cond_false_0 = self.cond_false_0 1887 cond = torch.ops.higher_order.cond(le, cond_true_0, cond_false_0, [l_x_]); le = cond_true_0 = cond_false_0 = l_x_ = None 1888 getitem_2 = cond[0]; cond = None 1889 return pytree.tree_unflatten([getitem_2], self._out_spec)""", 1890 ) 1891 self.assertExpectedInline( 1892 out_graph.cond_true_0.code.strip(), 1893 """\ 1894def forward(self, l_x_): 1895 l_x__1 = l_x_ 1896 add = l_x__1 + l_x__1; l_x__1 = None 1897 return (add,)""", 1898 ) 1899 self.assertExpectedInline( 1900 out_graph.cond_false_0.code.strip(), 1901 """\ 1902def forward(self, l_x_): 1903 l_x__1 = l_x_ 1904 getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None 1905 return (getitem,)""", 1906 ) 1907 with self.assertRaisesRegex( 1908 torch._dynamo.exc.UncapturedHigherOrderOpError, 1909 "Cond doesn't work unless it is captured completely with torch.compile", 1910 ): 1911 # True branch and false branch return tensors of different shape 1912 torch._dynamo.export(mod)(torch.randn(3, 2)) 1913 1914 # We specialize into one of the branches since predicate is a python boolean. 
            test_x = torch.randn(3, 2)
            mod(test_x)

    # cond nested inside map: the graph is exported with one predicate value and
    # input shape, then run with the other predicate value and a different shape
    # and compared with eager.
    def test_export_with_map_cond(self):
        from functorch.experimental.control_flow import cond, map

        class Module(torch.nn.Module):
            def inner(self, x, pred):
                def true_fn(x):
                    return x + x

                def false_fn(x):
                    return x * x

                return cond(pred, true_fn, false_fn, [x])

            def forward(self, pred, xs):
                def body(x, pred):
                    return self.inner(x, pred)

                return map(body, xs, pred)

        mod = Module()
        # Example inputs used for export.
        x = torch.randn(3, 2, 1)
        pred_x = torch.tensor(True)

        # Different inputs used for the eager reference and for running the graph.
        y = torch.randn(4, 3, 2)
        pred_y = torch.tensor(False)
        real_result = mod(pred_y, y)

        out_graph, _ = torch._dynamo.export(mod)(pred_x, x)
        self.assertEqual(real_result, out_graph(pred_y, y))

    # Exporting map over a tensor whose leading dimension is 0 must raise
    # Unsupported with a "zero-sized tensor" message.
    def test_export_with_map_zero_sized_tensor(self):
        from functorch.experimental.control_flow import map

        class Module(torch.nn.Module):
            def forward(self, xs):
                def body(x):
                    return x + 1

                return map(body, xs)

        mod = Module()
        xs = torch.randn(0, 2)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "zero-sized tensor",
        ):
            out_graph, _ = torch._dynamo.export(mod)(xs)

    # Every placeholder node of an aten-level export must have a "val" meta entry.
    def test_export_meta_val(self):
        def f(x, y, z):
            return x * y + z

        gm, _ = torch._dynamo.export(
            f,
            aten_graph=True,
        )(
            torch.ones(3, 2),
            torch.zeros(3, 2),
            torch.ones(3, 2),
        )
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                self.assertIn("val", node.meta)

    # Export of a function that takes a list of tensors and returns a dict; the
    # exported graph must accept and return the same container structure.
    def test_input_container_type(self):
        def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
            return {"a": x.sum() + sum(y).sum()}

        inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)])

        gm, _ = torch._dynamo.export(f, aten_graph=True)(*inp)

        self.assertEqual(gm(*inp), f(*inp))

    @config.patch(assume_static_by_default=False)
    def test_export_symbolic_shape(self):
        """The aten graph of a shape-dependent function must contain `aten.sym_size.int`."""
        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.empty(x.shape[0] * 2)

        inp = (torch.randn(6, 5),)
        gm, _ = torch._dynamo.export(f, aten_graph=True)(*inp)

        has_sym_size = False
        for node in gm.graph.nodes:
            if node.target is torch.ops.aten.sym_size.int:
                has_sym_size = True

        self.assertTrue(has_sym_size)

    @config.patch(assume_static_by_default=False)
    def test_dynamic_slicing(self):
        """Export a shape-dependent slice in aten mode and in torch mode.

        Both graphs are traced with a (4, 5) input and then run on a (6, 7)
        input; the output shapes must match eager. The number of slice /
        getitem nodes in each graph is also pinned.
        """
        def f(x):
            return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]

        gm_aten_mode, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))

        inp = torch.randn(6, 7)
        self.assertEqual(gm_aten_mode(inp).shape, f(inp).shape)

        count = 0
        # aten graph should flatten getitem calls to actual
        # slice kernel call.
        for node in gm_aten_mode.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.slice.Tensor
            ):
                count += 1

        self.assertEqual(count, 2)

        gm_torch_mode, _ = torch._dynamo.export(f, aten_graph=False)(torch.randn(4, 5))

        # In torch mode the assertion below expects exactly one
        # operator.getitem call_function node in the graph.
        # NOTE(review): this comment previously said the graph should contain
        # 3 getitem calls (x.shape[0]-2, x.shape[1]-1 and the slice itself),
        # which disagrees with the asserted count of 1; presumably only the
        # tensor slice is recorded as getitem now -- confirm.
        count = 0
        for node in gm_torch_mode.graph.nodes:
            if node.op == "call_function" and node.target == operator.getitem:
                count += 1

        self.assertEqual(count, 1)
        self.assertEqual(gm_torch_mode(inp).shape, f(inp).shape)

    def test_dynamic_slicing_invalid(self):
        """Slicing with a tensor-valued start index must raise `Unsupported` on export."""
        def g(x, y):
            return x[y : x.shape[0]]

        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "Dynamic slicing on data-dependent value is not supported",
        ):
            torch._dynamo.export(
                g,
                aten_graph=True,
            )(
                torch.randn(4, 5),
                torch.tensor(2),
            )

    @config.patch(capture_scalar_outputs=True)
    def test_dynamic_slicing_simple(self):
        """Export `x[slice(None, None, None)]` and compare with eager on a new shape."""
        def f(x):
            return x[slice(None, None, None)]

        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))

        inp = torch.randn(6, 7)
        self.assertEqual(gm(inp), f(inp))

    def test_pre_dispatch_simple(self):
        """Export with `pre_dispatch=True` and pin the generated graph code.

        The graph is traced on a (5, 5) input, checked against eager on a
        (6, 6) input, and its code is compared verbatim to the string below.
        """
        def f(x):
            y = torch.ones_like(x)
            return torch.matmul(x, y)

        gm, _ = torch._dynamo.export(
            f,
            aten_graph=True,
            pre_dispatch=True,
            tracing_mode="fake",
        )(
            torch.randn(5, 5),
        )

        inp = torch.randn(6, 6)
        self.assertEqual(gm(inp), f(inp))
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    arg0_1 = arg0
    ones_like = torch.ops.aten.ones_like.default(arg0_1, pin_memory = False)
    matmul = torch.ops.aten.matmul.default(arg0_1, ones_like);  arg0_1 = ones_like = None
    return pytree.tree_unflatten([matmul], self._out_spec)""",
        )

    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_export_cond_in_aten_symbolic(self):
        """Export a module whose forward is a single `cond` over bound methods.

        The exported aten graph must produce the same result as the eager
        module on the tracing inputs.
        """
        class ConditionOp(torch.nn.Module):
            def true_fn(self, x, y):
                return x * y

            def false_fn(self, x, y):
                return x + y

            def forward(self, pred, x, y):
                return cond(pred, self.true_fn, self.false_fn, [x, y])

        model = ConditionOp()
        inp = (
            torch.tensor(False),
            torch.randn(4, 4),
            torch.randn(4, 4),
        )
        gm, _ = torch._dynamo.export(model, aten_graph=True)(*inp)

        gm.print_readable()

        self.assertEqual(gm(*inp), model(*inp))

    def test_export_with_kwargs(self):
        """Export a function mixing positional, tuple, *args, keyword and **kwargs inputs.

        Delegates to `_test_export_preserving_original_signature`, which checks
        both the numerical result and the argument names of the exported graph.
        """
        def fn_with_kwargs(pos0, tuple0, *myargs, mykw0=None, **mykwargs):
            out = pos0
            for arg in tuple0:
                out *= arg
            for arg in myargs:
                out *= arg
            out *= mykw0
            out *= mykwargs["input0"] * mykwargs["input1"]
            return out

        mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
        tuple0 = (torch.randn(4), torch.randn(4))
        mykw0 = torch.randn(4)
        pos0 = torch.randn(4)
        myargs = [torch.randn(4), torch.randn(4)]

        # *myargs entries are expected to appear as "myargs_<index>".
        expected_argument_names = [
            "pos0",
            "tuple0",
            "myargs_0",
            "myargs_1",
            "mykw0",
            "input0",
            "input1",
        ]
        self._test_export_preserving_original_signature(
            fn_with_kwargs,
            expected_argument_names,
            pos0,
            tuple0,
            *myargs,
            mykw0=mykw0,
            **mykwargs,
        )

    def test_export_with_kwargs_and_empty_args(self):
        """Export a function that declares only a defaulted parameter and **kwargs."""
        def fn_with_kwargs(mykw0=None, **mykwargs):
            out = mykw0
            out *= mykwargs["input0"] * mykwargs["input1"]
            return out

        mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
        mykw0 = torch.randn(4)

        expected_argument_names = ["mykw0"] + list(mykwargs.keys())
        self._test_export_preserving_original_signature(
            fn_with_kwargs, expected_argument_names, mykw0, **mykwargs
        )

    def test_export_with_args_and_empty_kwargs(self):
        """Export a function that takes positional, tuple and *args inputs only."""
        def fn_with_kwargs(pos0, tuple0, *myargs):
            out = pos0
            for arg in tuple0:
                out *= arg
            for arg in myargs:
                out *= arg
            return out

        tuple0 = (torch.randn(4), torch.randn(4))
        pos0 = torch.randn(4)
        myargs = [torch.randn(4), torch.randn(4)]

        expected_argument_names = ["pos0", "tuple0", "myargs_0", "myargs_1"]
        self._test_export_preserving_original_signature(
            fn_with_kwargs, expected_argument_names, pos0, tuple0, *myargs
        )

    @common_utils.parametrize(
        "default_value",
        [
            common_utils.subtest(None, name="None"),
            common_utils.subtest(42.0, name="float"),
            common_utils.subtest(
                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
                torch.randn(4),
                name="tensor",
                decorators=[unittest.expectedFailure],
            ),
            common_utils.subtest(
                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
                (torch.randn(4),),
                name="tuple",
                decorators=[unittest.expectedFailure],
            ),
        ],
    )
    def test_export_with_args_with_default(self, default_value):
        """Export a function whose second positional parameter is left at its default.

        Parametrized over the default value; the tensor and tuple-of-tensor
        defaults are marked as expected failures (see the FIXME notes above).
        """
        def fn(pos0, pos1_default=default_value):
            out = pos0
            if pos1_default is None:
                pos1_default = torch.randn(4)
            if isinstance(pos1_default, tuple):
                pos1_default = pos1_default[0]
            out *= pos1_default
            return out

        pos0 = torch.randn(4)
        # Only the explicitly passed argument is expected in the signature.
        expected_argument_names = ["pos0"]
        self._test_export_preserving_original_signature(
            fn, expected_argument_names, pos0
        )

    @common_utils.parametrize(
        "default_value",
        [
            common_utils.subtest(None, name="None"),
            common_utils.subtest(42.0, name="float"),
            common_utils.subtest(
                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
                torch.randn(4),
                name="tensor",
                decorators=[unittest.expectedFailure],
            ),
            common_utils.subtest(
                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
                (torch.randn(4),),
                name="tuple",
                decorators=[unittest.expectedFailure],
            ),
        ],
    )
    def test_export_with_kwargs_with_default(self, default_value):
        """Export a function whose keyword-only parameter `kw1_default` is left at its default.

        Parametrized over the default value; the tensor and tuple-of-tensor
        defaults are marked as expected failures (see the FIXME notes above).
        """
        def fn(pos0, *, kw0, kw1_default=default_value, **kwargs):
            out = pos0
            out += kw0
            if kw1_default is None:
                kw1_default = torch.randn(4)
            elif isinstance(kw1_default, tuple):
                kw1_default = kw1_default[0]
            out += kw1_default
            out += kwargs["kw2"]
            return out

        pos0 = torch.randn(4)
        kw0 = torch.randn(4)
        kw2 = torch.randn(4)

        args = (pos0,)
        kwargs = {"kw0": kw0, "kw2": kw2}
        # kw1_default is not passed, so it is not expected in the signature.
        expected_argument_names = ["pos0", "kw0", "kw2"]
        self._test_export_preserving_original_signature(
            fn, expected_argument_names, *args, **kwargs
        )

    def test_export_with_wrapped_fn(self):
        """Export a plain `(*args, **kwargs)` wrapper around a richer signature."""
        # To ensure dynamo.export is robust to wrapped functions
        # when it cannot use `inspect` to retrieve original signature
        # info.
        def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
            out = pos0
            out += pos1
            out += kw0
            out += kw1
            for arg in args:
                out += arg
            for kwarg in kwargs.values():
                out += kwarg
            return out

        def wrapped_fn(*args, **kwargs):
            return _fn(*args, **kwargs)

        pos0 = torch.randn(4)
        kw0 = torch.randn(4)
        args = (pos0, torch.randn(4), torch.randn(4))
        kwargs = {"kw0": kw0, "kw2": torch.randn(4)}
        # The wrapper only exposes *args, so positionals are named "args_<index>".
        expected_argument_names = [f"args_{i}" for i in range(len(args))] + list(
            kwargs.keys()
        )

        self._test_export_preserving_original_signature(
            wrapped_fn, expected_argument_names, *args, **kwargs
        )

    def test_export_with_functools_wrapped_method(self):
        """Export a bound method decorated with a `functools.wraps` wrapper.

        The expected argument names are those of the wrapped method
        ("pos0", "pos1"), followed by "args_0" for the extra positional.
        """
        def test_decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x

            @test_decorator
            def method_to_test(self, pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
                out = pos0
                out += pos1
                out += kw0
                out += kw1
                for arg in args:
                    out += arg
                for kwarg in kwargs.values():
                    out += kwarg
                return out

        pos0 = torch.randn(4)
        pos1 = torch.randn(4)
        unnamed_pos = torch.randn(4)
        kw0 = torch.randn(4)
        args = (pos0, pos1, unnamed_pos)
        kwargs = {"kw0": kw0, "kw2": torch.randn(4), "unnamed_kw": torch.randn(4)}
        expected_argument_names = [
            "pos0",
            "pos1",
            "args_0",  # 3rd unnamed positional argument
        ] + list(kwargs.keys())
        m = MyModule()

        self._test_export_preserving_original_signature(
            m.method_to_test, expected_argument_names, *args, **kwargs
        )

    def test_export_with_functools_wrapped_fn(self):
        """Export a plain wrapper that calls a `functools.wraps`-decorated function.

        The exported entry point is the undecorated `wrapped_fn`, so the
        positionals are expected to be named "args_<index>".
        """
        def test_decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

        @test_decorator
        def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
            out = pos0
            out += pos1
            out += kw0
            out += kw1
            for arg in args:
                out += arg
            for kwarg in kwargs.values():
                out += kwarg
            return out

        def wrapped_fn(*args, **kwargs):
            return _fn(*args, **kwargs)

        pos0 = torch.randn(4)
        kw0 = torch.randn(4)
        args = (pos0, torch.randn(4), torch.randn(4))
        kwargs = {"kw0": kw0, "kw2": torch.randn(4)}
        expected_argument_names = [f"args_{i}" for i in range(len(args))] + list(
            kwargs.keys()
        )

        self._test_export_preserving_original_signature(
            wrapped_fn, expected_argument_names, *args, **kwargs
        )

    def _test_export_preserving_original_signature(
        self, fn, expected_argument_names: Sequence[str], *args, **kwargs
    ):
        """Export `fn` and check both its result and its argument names.

        Args:
            fn: callable to export.
            expected_argument_names: names the exported graph's `forward` is
                expected to declare after `self`, in order.
            *args, **kwargs: inputs forwarded to the export call, to the
                exported graph and to `fn` itself.
        """
        torch._dynamo.reset()
        # The example inputs are passed directly to `export` together with
        # `aten_graph=False` rather than to a returned callable.
        exported = torch._dynamo.export(
            fn,
            *args,
            **kwargs,
            aten_graph=False,
        )

        out_graph = exported[0]
        dynamo_result = out_graph(*args, **kwargs)
        real_result = fn(*args, **kwargs)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

        # Check that the exported graph preserves same argument names.
        # `[1:]` drops the leading `self` of `forward`.
        self.assertEqual(
            inspect.getfullargspec(out_graph.forward).args[1:], expected_argument_names
        )

    def test_dataclass_input_output(self):
        """A plain dataclass as export input or output must raise `UserError`."""
        from dataclasses import dataclass

        @dataclass
        class Tensors:
            x: torch.Tensor
            y: torch.Tensor

        # Case 1: dataclass instance as an input.
        def f(t):
            return t.x + t.y

        with self.assertRaisesRegex(
            UserError,
            "It looks like one of the inputs with type .*Tensors.* "
            "is not supported or pytree-flattenable",
        ):
            torch._dynamo.export(f, aten_graph=False)(
                Tensors(x=torch.randn(10), y=torch.randn(10))
            )

        # Case 2: dataclass instance as the output.
        def f(x, y):
            return Tensors(x=x.sin(), y=y.cos())

        with self.assertRaisesRegex(
            UserError,
            "It looks like one of the outputs with type .*Tensors.* "
            "is not supported or pytree-flattenable",
        ):
            torch._dynamo.export(f, aten_graph=False)(torch.randn(10), torch.randn(10))

    def test_empty(self):
        """Export callables that do no tensor computation.

        Covers an identity function and a module whose forward returns a
        tensor attribute without taking any input.
        """
        def f(x):
            return x

        exported = torch._dynamo.export(f)(torch.randn(3, 3))
        out_graph = exported[0]
        inp = torch.randn(3, 3)
        self.assertTrue(torch._dynamo.utils.same(inp, out_graph(inp)))

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(3, 3)

            def forward(self):
                return self.a

        exported = torch._dynamo.export(M())()
        out_graph = exported[0]
        self.assertTrue(torch._dynamo.utils.same(torch.ones(3, 3), out_graph()))

    @unittest.skipIf(not TEST_CUDA, "No CUDA available.")
    def test_export_with_parameters(self):
        """Save and reload an exported CUDA model; its weight must stay an `nn.Parameter`."""
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.features = torch.nn.Sequential(
                    torch.nn.Conv2d(
                        3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
                    ),
                    torch.nn.ReLU(inplace=True),
                )

            def forward(self, x):
                return self.features(x)

        model = MyModule().eval().cuda()
        random_inputs = (torch.rand([32, 3, 32, 32]).to("cuda"),)
        dim_x = torch.export.Dim("dim_x", min=1, max=32)
        exp_program = torch.export.export(
            model, random_inputs, dynamic_shapes={"x": {0: dim_x}}
        )
        output_buffer = io.BytesIO()
        # Tests if we can restore saved nn.Parameters when we load them again
        torch.export.save(exp_program, output_buffer)
        loaded_model = torch.export.load(output_buffer)
        self.assertTrue(
            isinstance(
                loaded_model.module().get_parameter("features.0.weight"),
                torch.nn.Parameter,
            )
        )

    def test_export_fast_binary_broadcast_check(self):
        """Export a broadcasting add with a dynamic batch dim; it must not raise."""
        # This test looks at the case where we erroneously create a guard
        # when checking the equality of the operands' shape and the output
        # shape during FakeTensor's binary op fast path.

        class MyModel(torch.nn.Module):
            def forward(self, a, b):
                # final shape is (dim0, 4, 8)
                # order matters since a & the output have the same shape
                return b + a

        a = torch.randn(100, 4, 8)
        b = torch.randn(4, 8)
        # NOTE(review): MyModel defines no parameters or buffers here and the
        # inputs are created without a device argument, so `.cuda()` presumably
        # moves nothing -- confirm whether a CUDA skip guard is intended.
        model = MyModel().eval().cuda()
        batchsize = torch.export.Dim("dim0", min=3, max=1024)
        dynamic_shape_spec = {"a": [batchsize, None, None], "b": [None, None]}

        torch.export.export(model, (a, b), dynamic_shapes=dynamic_shape_spec)

    def test_export_fast_binary_broadcast_check_unbacked(self):
        """Export `scalar - ones(u0 + 1)` where `u0` comes from `.item()`; it must not raise."""
        class MyModel(torch.nn.Module):
            def forward(self, numel, scalar):
                u0 = numel.item()
                torch._check_is_size(u0)
                x = torch.ones(u0 + 1)
                return scalar - x

        # NOTE(review): as in the test above, the module has no parameters or
        # buffers here, so `.cuda()` presumably moves nothing -- confirm.
        model = MyModel().eval().cuda()
        numel = torch.tensor(10)
        scalar = torch.randn(1)
        torch.export.export(model, (numel, scalar))

    def test_export_meta(self):
        """Export a module created under `torch.device("meta")` and compare with eager."""
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.ones(2, 3))

            def forward(self, x):
                return self.p + x

        with torch.device("meta"):
            m = MyModule()

        inp = torch.ones(2, 3, device="meta")
        exported = torch._dynamo.export(m)(inp)
        out_graph = exported[0]
        dynamo_result = out_graph(inp)
        self.assertEqual(dynamo_result, m(inp))

    def test_constraint_violation_error_messages(self):
        """Check the `UserError` text for three kinds of violated dynamic-shape specs.

        Foo branches on a relation between two dims, Bar on equality with a
        constant, Qux on a range; each export is expected to fail with a
        message matching the regex given.
        """
        class Foo(torch.nn.Module):
            def forward(self, x):
                if x.shape[0] == x.shape[1] * 2:
                    return x + 1
                else:
                    return x + 2

        foo = Foo()

        t = torch.zeros([8, 4])
        dim0 = torch.export.Dim("dim0", min=3, max=10)
        dim1 = torch.export.Dim("dim1")
        dynamic_shapes = {"x": (dim0, dim1)}

        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            "Constraints violated .*!(.*\n)*.*"
            "by dim0 = 2\\*dim1(.*\n)*.*"
            "Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*",
        ):
            torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes)

        class Bar(torch.nn.Module):
            def forward(self, x):
                if x.shape[0] == 5:
                    return x + 1
                else:
                    return x + 2

        bar = Bar()

        t = torch.zeros([5])
        dim0 = torch.export.Dim("dim0", min=3, max=8)
        dynamic_shapes = {"x": (dim0,)}
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            "Not all values.*valid.*inferred to be a constant",
        ):
            torch.export.export(bar, (t,), dynamic_shapes=dynamic_shapes)

        class Qux(torch.nn.Module):
            def forward(self, x):
                if x.shape[0] > 5 and x.shape[0] < 10:
                    return x + 1
                else:
                    return x + 2

        qux = Qux()

        t = torch.zeros([7])
        dim0 = torch.export.Dim("dim0", min=3, max=8)
        dynamic_shapes = {"x": (dim0,)}
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            "Not all values.*satisfy the generated guard",
        ):
            torch.export.export(qux, (t,), dynamic_shapes=dynamic_shapes)

    def test_untracked_inputs_in_constraints(self):
        """A dynamic dim on an input the forward ignores must not cause specialization."""
        from copy import copy

        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return y + 1

        foo = Foo()

        x = torch.randn(2)
        y = torch.randn(5, 4)

        dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y")
        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}

        example_inputs = (copy(x), y)
        ep = torch.export.export(foo, example_inputs, dynamic_shapes=dynamic_shapes)
        ep.module()(torch.randn(3), y)  # no specialization error

    def test_export_raise_guard_full_constraint(self):
        """Marking dim 0 dynamic must fail when the function branches on `shape[0] == 3`."""
        y = torch.randn([3, 3, 3])

        def my_dyn_fn(x):
            if x.shape[0] == 3:
                return x.sin()
            return x.cos()

        # Without dynamic shapes the export is expected to succeed.
        torch._dynamo.export(my_dyn_fn)(y)

        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(
                my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},)
            )(y)

    def test_export_module_specify_constraints_signature(self):
        """Same as the full-constraint case but for a module; the error must mention "dimx = 3"."""
        y = torch.randn([3, 3, 3])

        class Mod(torch.nn.Module):
            def forward(self, x):
                if x.shape[0] == 3:
                    return x.sin()
                return x.cos()

        mod = Mod()
        torch._dynamo.export(mod)(y)

        with self.assertRaisesRegex(ConstraintViolationError, "dimx = 3"):
            torch._dynamo.export(mod, dynamic_shapes=({0: torch.export.Dim("dimx")},))(
                y
            )

    def test_export_raise_guard_partial_constraint(self):
        """Marking dim 0 dynamic must fail when the function branches on `shape[0] > 3`."""
        y = torch.randn([3, 3, 3])

        def my_dyn_fn(x):
            if x.shape[0] > 3:
                return x.sin()
            return x.cos()

        torch._dynamo.export(my_dyn_fn)(y)

        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(
                my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},)
            )(y)

    def test_export_raise_on_relationship(self):
        """Sharing one Dim must match the dims the function actually compares.

        The function compares a.shape[0], b.shape[1] and c.shape[2]; tying
        dim 0 of all three inputs raises, tying exactly those dims succeeds.
        """
        y = torch.randn([3, 3, 3])

        def my_dyn_fn(a, b, c):
            if a.shape[0] == b.shape[1] == c.shape[2]:
                return a.sin()

            return a.cos()

        torch._dynamo.export(my_dyn_fn)(y, y, y)
        dim = torch.export.Dim("dim")
        dynamic_shapes = ({0: dim}, {0: dim}, {0: dim})
        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y)
        dynamic_shapes = ({0: dim}, {1: dim}, {2: dim})
        torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y)

    def test_export_no_raise(self):
        """A dynamic dim 0 is accepted when the function only branches on `shape[1]`."""
        y = torch.randn([3, 3, 3])

        def my_dyn_fn(a, b, c):
            if a.shape[1] == 3:
                return a.cos()
            return a * b * c

        torch._dynamo.export(my_dyn_fn)(y, y, y)
        dim = torch.export.Dim("dim")
        dynamic_shapes = ({0: dim}, {0: dim}, {0: dim})
        torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y)

    def test_export_multi_dynamic_dim_unsafe_relationship(self):
        """Independent Dims on compared dims raise; reusing the same Dim succeeds."""
        x = torch.randn([3, 3, 3])
        y = torch.randn([2, 2, 2])
        z = torch.randn([3, 3, 3])

        def my_dyn_fn(a, b, c):
            if a.shape[0] == c.shape[0]:
                return a.cos()
            return a * c, b

        torch._dynamo.export(my_dyn_fn)(x, y, z)
        dimx, dimy, dimz = torch.export.dims("dimx", "dimy", "dimz")
        dynamic_shapes = ({0: dimx}, {0: dimy}, {0: dimz})
        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z)
        # Reuse dimx for the third input so the compared dims share one Dim.
        dimz = dimx
        dynamic_shapes = ({0: dimx}, {0: dimy}, {0: dimz})
        torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z)

    def test_remove_redundant_dynamic_dim_in_error_message(self):
        """The `UserError` for two compared dynamic dims must contain "dim0_b = dim0_a"."""
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                if x.shape[0] == y["k"].shape[0]:
                    return x + 1
                else:
                    return x - 1

        foo = Foo()

        a = torch.randn(3)
        b = torch.randn(3)
        dim0_a, dim0_b = torch.export.dims("dim0_a", "dim0_b")
        with self.assertRaisesRegex(torch._dynamo.exc.UserError, "dim0_b = dim0_a"):
            torch.export.export(
                foo,
                (a, {"k": b}),
                dynamic_shapes={"x": {0: dim0_a}, "y": {"k": {0: dim0_b}}},
            )

    def test_enforce_equalities(self):
        """Dims declared equal in the spec must be equal in the example inputs.

        Inputs that break the declared equality raise `UserError`; with
        matching inputs the placeholder shapes use shared symbols.
        """
        class Bar(torch.nn.Module):
            def forward(self, x, y):
                return torch.matmul(x, y)

        bar = Bar()

        batch, size = torch.export.dims("batch", "size")
        dynamic_shapes = {"x": (batch, size, size), "y": (batch, size, size)}

        x = torch.randn(10, 3, 3)
        y = torch.randn(10, 3, 4)
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            ".*y.*size.*2.* = 4 is not equal to .*x.*size.*1.* = 3",
        ):
            torch.export.export(
                bar,
                (x, y),
                dynamic_shapes=dynamic_shapes,
            )
        y = torch.randn(10, 3, 3)
        ebar = torch.export.export(
            bar,
            (x, y),
            dynamic_shapes=dynamic_shapes,
        )
        self.assertEqual(
            [
                str(node.meta["val"].shape)
                for node in ebar.graph_module.graph.nodes
                if node.op == "placeholder"
            ],
            ["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
        )

    @torch._dynamo.config.patch(
        capture_dynamic_output_shape_ops=True,
        specialize_int=True,
        capture_scalar_outputs=True,
    )
    def test_export_preserve_constraints_as_metadata_tensor(self):
        """Export a function that bounds the size of `nonzero()` with `torch._check`.

        The test makes no assertion on the result; it only requires the
        export call to complete.
        """
        def f(x):
            b = x.nonzero()
            torch._check(b.shape[0] >= 2)
            torch._check(b.shape[0] <= 5)
            return b

        y = torch.tensor([8, 8, 6])
        gm, _ = torch._dynamo.export(
            f,
            aten_graph=True,
            tracing_mode="symbolic",
        )(y)

    @config.patch(
        capture_dynamic_output_shape_ops=True,
        specialize_int=True,
        capture_scalar_outputs=True,
    )
    def test_exported_graph_serialization(self):
        """An exported graph module must be writable with `torch.save`."""
        def f(x, y):
            b = x.item()
            torch._check_is_size(b)
            return torch.empty((b, y.shape[0]))

        x = torch.tensor([3])
        y = torch.randn([8, 8, 6])
        example_inputs = [x, y]
        dynamic_shapes = (None, {0: torch.export.Dim("dimy", min=6, max=10)})
        gm, _ = torch._dynamo.export(
            f,
            dynamic_shapes=dynamic_shapes,
            aten_graph=True,
            tracing_mode="symbolic",
        )(*example_inputs)

        # Ensure the exported graph module with metadata is serializable,
        # metadata won't be saved in the serialized module
        buffer = io.BytesIO()
        torch.save(gm, buffer)

    def test_export_dynamic_dim_not_1(self):
        """Marking dim 0 dynamic must fail when the function branches on `shape[0] != 1`."""
        x = torch.randn([1, 1, 1])

        def my_dyn_fn(a):
            if a.shape[0] != 1:
                return a.cos()
            return a * a

        torch._dynamo.export(my_dyn_fn)(x)
        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(
                my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},)
            )(x)

    def test_symbool(self):
        """Export a function that turns a shape comparison into a scalar tensor.

        The graph is traced on a (6, 4) input and compared with eager on a
        (3, 4) input.
        """
        def f(x):
            a = torch.scalar_tensor(x.shape[0] > 4)
            return x.sin().sum() + a.sum()

        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4))
        self.assertEqual(gm(torch.ones(3, 4)), f(torch.ones(3, 4)))

    def test_export_multi_dynamic_dim_constraint(self):
        """Dynamic dims on one compared input raise unless the other shares the Dim."""
        x = torch.randn([3, 3, 3])
        y = torch.randn([2, 2, 2])
        z = torch.randn([3, 3, 3])

        def my_dyn_fn(a, b, c):
            if a.shape[0] == c.shape[0]:
                return a.cos()
            return a * c, b

        torch._dynamo.export(my_dyn_fn)(x, y, z)
        dimx_0, dimx_1, dimx_2 = torch.export.dims("dimx_0", "dimx_1", "dimx_2")
        dynamic_shapes = ({0: dimx_0, 1: dimx_1, 2: dimx_2}, None, None)
        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z)
        dynamic_shapes = ({0: dimx_0, 1: dimx_1, 2: dimx_2}, None, {0: dimx_0})
        torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z)

    def test_export_dynamic_dim_range_constraint(self):
        """A branch condition must hold for the whole declared Dim range [5, 6].

        `shape[0] > 3` is accepted; `shape[0] > 5` raises
        `ConstraintViolationError`.
        """
        x = torch.ones(6, 4, 4)
        dynamic_shapes = ({0: torch.export.Dim("dimx", min=5, max=6)},)

        def foo(x):
            if x.shape[0] > 3:  # ok
                return x.sin()
            return x.cos()

        torch._dynamo.export(
            foo,
            dynamic_shapes=dynamic_shapes,
            aten_graph=True,
        )(x)

        def bar(x):
            if x.shape[0] > 5:  # error
                return x.sin()
            return x.cos()

        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(
                bar,
                dynamic_shapes=dynamic_shapes,
                aten_graph=True,
            )(x)

    def test_trivial_constraint(self):
        """Divisibility guards: only the trivially true one may export with a dynamic dim."""
        class Foo(torch.nn.Module):
            def forward(self, x):
                # complex divisibility condition
                if (2 * x.shape[0] + 3) % (x.shape[0] - 3) == 0:
                    return x + 1
                else:
                    return x - 1

        foo = Foo()

        class Bar(torch.nn.Module):
            def forward(self, x):
                # trivially true
                if (2 * x.shape[0] + 2) % (x.shape[0] + 1) == 0:
                    return x + 1
                else:
                    return x - 1

        bar = Bar()

        class Qux(torch.nn.Module):
            def forward(self, x):
                # simple divisibility condition (not trivially true)
                if (3 * x.shape[0]) % 2 == 0:
                    return x + 1
                else:
                    return x - 1

        qux = Qux()

        x = torch.randn(12)
        dim0 = torch.export.Dim("dim0", max=100)
        dynamic_shapes = {"x": (dim0,)}
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            r"Constraints violated \(dim0\)",
        ):
            torch.export.export(foo, (x,), dynamic_shapes=dynamic_shapes)

        torch.export.export(bar, (x,), dynamic_shapes=dynamic_shapes)

        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            r"Constraints violated \(dim0\)",
        ):
            torch.export.export(qux, (x,), dynamic_shapes=dynamic_shapes)

    def test_list_contains(self):
        """Export a function asserting `x.size(-1) in [...]`; result must match the eager backend."""
        def func(x):
            assert x.size(-1) in [4, 5, 6], "bad"
            return x + x

        inps = (torch.randn(1, 5),)
        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_list_not_contains(self):
        """Export a function using `not in` on int and str lists; result must match the eager backend."""
        def func(x):
            assert x.size(0) not in [4, 5, 6], "bad1"
            assert "monkey" not in ["cow", "pig"], "bad2"
            return x + x

        inps = (torch.randn(1, 5),)
        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_identity(self):
        """Exporting an identity function must return the input unchanged."""
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            return x

        torch._dynamo.reset()
        exported, _ = torch._dynamo.export(func)(inp)
        dynamo_result = exported(inp)
        self.assertTrue(torch._dynamo.utils.same(inp, dynamo_result))

    def test_export_specialized_int(self):
        """Int and float module attributes must not become graph placeholders."""
        class Foo(torch.nn.Module):
            def __init__(
                self,
                input_dim,
            ):
                super().__init__()
                self.torch_module = torch.nn.LayerNorm(
                    input_dim, eps=1e-5, elementwise_affine=True
                )
                self.int_val = 100

            def forward(self, input):
                return input.cos() * self.int_val * self.torch_module.eps

        mod = Foo(128)
        inp = torch.randn(3, 128)

        # In export, int & float in forward should always be specialized
        gm, _ = torch._dynamo.export(mod, aten_graph=True)(inp)
        count = 0
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                count += 1
        # Only the tensor input is expected as a placeholder.
        self.assertEqual(count, 1)

    def test_export_with_nonzero_static(self):
        """Export `torch.nonzero_static` with a fixed size and compare with eager."""
        class BasicModule(torch.nn.Module):
            def __init__(self, static_size):
                super().__init__()
                self.static_size = static_size

            def forward(self, x):
                return torch.nonzero_static(x, size=self.static_size)

        input_tensors = torch.tensor([6, 8]), torch.zeros(2, 3)
        static_sizes = 3, 4
        for input_tensor, static_size in zip(input_tensors, static_sizes):
            m = BasicModule(static_size)
            gm, _ = torch._dynamo.export(m, aten_graph=True)(input_tensor)
            res = gm(input_tensor)
            self.assertEqual(res.size(0), static_size)
            self.assertTrue(
                torch._dynamo.utils.same(
                    res, torch.nonzero_static(input_tensor, size=static_size)
                )
            )

    def test_export_pass_arg_by_name(self):
        """The exported graph must accept its input by keyword (`x=`) like the module."""
        class BasicModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.my_lin = torch.nn.Linear(3, 4, bias=True)

            def forward(self, x):
                return self.my_lin(x)

        mod, input_tensor = BasicModule(), torch.randn(2, 3)
        gm, guard = torch._dynamo.export(mod, aten_graph=True)(input_tensor)
        ref = mod(x=input_tensor)
        res = gm(x=input_tensor)
        self.assertTrue(torch._dynamo.utils.same(ref, res))

    def test_export_pass_arg_by_name_star_args(self):
        """Export a module whose forward takes `*args` and compare with eager."""
        class BasicModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.my_lin = torch.nn.Linear(3, 4, bias=True)

            def forward(self, *args):
                return self.my_lin(args[0]) * self.my_lin(args[1])

        mod, input_tensor, input_tensor2 = (
            BasicModule(),
            torch.randn(2, 3),
            torch.randn(2, 3),
        )
        gm, guard = torch._dynamo.export(mod, aten_graph=True)(
            input_tensor, input_tensor2
        )
        ref = mod(input_tensor, input_tensor2)
        res = gm(input_tensor, input_tensor2)
        self.assertTrue(torch._dynamo.utils.same(ref, res))

    def test_export_mark_dynamic_conflict_dynamic_dim(self):
        """`mark_dynamic` plus a dynamic Dim on a guarded dim must raise "Constraints violated"."""
        y = torch.randn([3, 3, 3])

        def my_dyn_fn(x):
            if x.shape[0] > 3:
                return x.sin()
            return x.cos()

        torch._dynamo.mark_dynamic(y, 0)
        with self.assertRaisesRegex(
            RuntimeError,
            "Constraints violated",
        ):
            torch._dynamo.export(
                my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dim")},)
            )(y)

    def test_export_dynamic_dim_cleanup(self):
        """Export a shape-agnostic function with a dynamic dim 0; it must not raise."""
        y = torch.randn([3, 3, 3])

        def my_dyn_fn(x):
            return x.cos()

        torch._dynamo.export(my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dim")},))(
            y
        )

    @config.patch(capture_dynamic_output_shape_ops=True)
    def test_export_dynamic_control_flow_error(self):
        """Branching on a `nonzero()` result must raise `UserError` on export."""
        def f(x):
            if x.nonzero() > 3:
                return x.cos()
            return x.sin()

        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            "Dynamic control flow is not supported at the moment",
        ):
            gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(5, 6))

    @config.patch(assume_static_by_default=False)
    def test_export_persist_assert(self):
        """A Python `assert` on tensor data must persist in the exported graph.

        The `_assert_async.msg` node must survive dead-code elimination, and
        running the graph on an input that fails the assert must raise.
        """
        def f(x):
            assert x[0].sum() > 4, "Shape must be more than 4"
            return x.cos() + x.sin()

        gm, guard = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
            torch.ones(5, 4, 6)
        )

        def has_aten_op(gm, op):
            # True if any node in the graph targets `op`.
            for node in gm.graph.nodes:
                if node.target == op:
                    return True
            return False

        self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg))

        gm.graph.eliminate_dead_code()
        gm.recompile()
        self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg))

        with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"):
            gm(torch.zeros(3, 4, 5))

    @common_utils.parametrize(
        "type_fn",
        [
            common_utils.subtest(type, name="builtin"),
            common_utils.subtest(lambda obj: obj.__class__, name="attr"),
        ],
    )
    def test_access_class_method_from_user_class(self, type_fn):
        """Call a classmethod via `type(obj)` or `obj.__class__` inside an exported function."""
        class A:
            @classmethod
            def func(cls):
                return torch.Tensor([4, 5])

        def f(x):
            a = A()
            return x.sum() + type_fn(a).func().sum()

        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4))
        self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4)))

    def test_not_functionalize(self):
        """The in-place `add_` must stay as one `aten.add_.Tensor` node in the graph."""
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.ones(6, 2))

            def forward(self, x):
                x.add_(2)
                return x.sum() + self.buffer1.sum()

        example_inputs = (torch.ones(1, 2, 3),)
        gm, _ = torch._dynamo.export(
            Foo(),
            aten_graph=True,
            tracing_mode="symbolic",
        )(*example_inputs)
        count = 0
        for node in gm.graph.nodes:
            if node.target == torch.ops.aten.add_.Tensor:
                count += 1
        self.assertEqual(count, 1)
        # forward mutates its input, so each call gets its own fresh tensor.
        test_inp = (torch.ones(1, 2, 3),)
        test_inp_v2 = (torch.ones(1, 2, 3),)
        self.assertEqual(gm(*test_inp), Foo()(*test_inp_v2))

    def test_round_dynamic_shapes(self):
        """Export a slice whose bound is `round(x.shape[0] / 2)` and compare with eager."""
        def f(x):
            return x[: round(x.shape[0] / 2)]

        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4))

        self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4)))

    def test_cond_supported_pred_types(self):
        """Export `cond` with three predicate forms and compare each with eager.

        The predicates are a shape comparison, a tensor reduction (`x.all()`)
        and a compound boolean expression over the shape.
        """
        def true_fn(x):
            return x.cos()

        def false_fn(x):
            return x.sin()

        def f_pred_traced_as_symnode_var(x):
            return cond(x.shape[0] > 2, true_fn, false_fn, [x])

        def f_pred_traced_as_tensor_var(x):
            return cond(x.all(), true_fn, false_fn, [x])

        def f_pred_complex_expression_traced_as_symnode_var(x):
            return cond(
                x.dim() > 1 and x.shape[1] > 5 and x.shape[1] <= 10,
                true_fn,
                false_fn,
                [x],
            )

        example_inputs = (torch.rand(5, 8),)
        for f in [
            f_pred_traced_as_symnode_var,
            f_pred_traced_as_tensor_var,
            f_pred_complex_expression_traced_as_symnode_var,
        ]:
            gm, _ = torch._dynamo.export(f, aten_graph=True)(*example_inputs)
            self.assertEqual(gm(*example_inputs), f(*example_inputs))

    @unittest.expectedFailure  # TODO: Not sure why dynamo creates a new inputs for self.a
    def test_sum_param(self):
        """Export a module that sets a new attribute in forward (currently an expected failure)."""
        # Setting a new attribute inside forward()
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.randn(3, 2)

            def forward(self, x):
                self.b = 2
                return x.sum() + self.a.sum() + self.b

        torch._dynamo.export(Foo())(torch.randn(3, 2))

    def test_mixed_real_and_fake_inputs(self):
        """Export a Conv2d + BatchNorm2d folding pattern to an aten graph; it must not raise."""
        class _TestPattern(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.bn = torch.nn.BatchNorm2d(1)

            def forward(self, input):
                running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
                scale_factor = self.bn.weight / running_std
                # Shapes that are 1 everywhere except a single -1 axis
                # (axis 0 for the weight, axis 1 for the bias).
                weight_shape = [1] * len(self.conv.weight.shape)
                weight_shape[0] = -1
                bias_shape = [1] * len(self.conv.weight.shape)
                bias_shape[1] = -1
                scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
                zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype)
                conv = self.conv._conv_forward(input, scaled_weight, zero_bias)
                conv_orig = conv / scale_factor.reshape(bias_shape)
                conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
                conv = self.bn(conv_orig)
                return conv

        example_inputs = (torch.randn(1, 1, 3, 3),)
        torch._dynamo.export(
            _TestPattern(),
            aten_graph=True,
        )(*example_inputs)

    @config.patch(
        capture_dynamic_output_shape_ops=True,
        capture_scalar_outputs=True,
        assume_static_by_default=False,
    )
    def test_sym_contains(self):
        def f(x, y):
            return x.size(0) in y

        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(2), torch.ones(3))

        true_inp = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(5))
        false_inp = (torch.Tensor([6, 4,
5]), torch.ones(6, 4).add_(2)) 3254 self.assertEqual(gm(*true_inp), f(*true_inp)) 3255 self.assertEqual(gm(*false_inp), f(*false_inp)) 3256 3257 def test_cond_raise_user_error_on_missing_args(self): 3258 def true_fn(x): 3259 return x.cos() 3260 3261 def false_fn(x): 3262 return x.sin() 3263 3264 def f(x): 3265 return cond(x.shape[0] > 10, true_fn, false_fn) 3266 3267 example_inputs = (torch.rand(5),) 3268 with self.assertRaisesRegex( 3269 TypeError, 3270 r"cond\(\) missing 1 required positional argument: 'operands'", 3271 ): 3272 f(*example_inputs) 3273 3274 def test_cond_raise_user_error_on_unsupported_pred(self): 3275 def f_unsupported_pred(x): 3276 pred = torch.nn.Module() 3277 return cond(pred, lambda x: x.sin(), lambda x: x.cos(), [x]) 3278 3279 example_inputs = (torch.rand(5),) 3280 with self.assertRaisesRegex( 3281 RuntimeError, 3282 "Expected pred to be bool or tensor, but got Module()", 3283 ): 3284 f_unsupported_pred(*example_inputs) 3285 3286 def test_cond_raise_user_error_on_non_list_operands(self): 3287 def f_non_list_operands(x): 3288 return cond(torch.tensor(True), lambda x: x.sin(), lambda x: x.cos(), x) 3289 3290 example_inputs = (torch.rand(5),) 3291 with self.assertRaisesRegex( 3292 RuntimeError, 3293 r"Expect operands to be a tuple of possibly nested dict/list/tuple", 3294 ): 3295 f_non_list_operands(*example_inputs) 3296 3297 def test_cond_raise_user_error_on_non_tensor_operands(self): 3298 def f_non_tensor_operands(x): 3299 a: float = 3.14 3300 return cond( 3301 torch.tensor(1234), lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a] 3302 ) 3303 3304 example_inputs = (torch.rand(5),) 3305 with self.assertRaisesRegex( 3306 RuntimeError, 3307 r"Expect operands to be a tuple of possibly nested dict/list/tuple", 3308 ): 3309 f_non_tensor_operands(*example_inputs) 3310 3311 def test_cond_raise_user_error_on_branch_args_mismatch(self): 3312 def true_fn(x, y): 3313 return x.sin() 3314 3315 def false_fn(x): 3316 return x.cos() 3317 3318 def 
f_branch_args_mismatch(x, y): 3319 return cond(torch.tensor([[[[True]]]]), true_fn, false_fn, [x, y]) 3320 3321 example_inputs = (torch.rand(5), torch.rand(2)) 3322 with self.assertRaisesRegex( 3323 torch._dynamo.exc.UncapturedHigherOrderOpError, 3324 "Cond doesn't work unless it is captured completely with torch.compil", 3325 ): 3326 torch._dynamo.export( 3327 f_branch_args_mismatch, 3328 aten_graph=True, 3329 )( 3330 *example_inputs, 3331 ) 3332 3333 @config.patch(suppress_errors=True) 3334 def test_uncaptured_higher_order_op_error_not_suppresed(self): 3335 def true_fn(x, y): 3336 return x.sin() 3337 3338 def false_fn(x): 3339 return x.cos() 3340 3341 def f_branch_args_mismatch(x, y): 3342 return cond(torch.tensor([[[[100]]]]), true_fn, false_fn, [x, y]) 3343 3344 example_inputs = (torch.rand(5), torch.rand(2)) 3345 with self.assertRaisesRegex( 3346 torch._dynamo.exc.UncapturedHigherOrderOpError, 3347 "Cond doesn't work unless it is captured completely with torch.compile", 3348 ): 3349 torch._dynamo.export( 3350 f_branch_args_mismatch, 3351 aten_graph=True, 3352 )( 3353 *example_inputs, 3354 ) 3355 3356 def test_cond_raise_user_error_on_branch_return_non_tensor(self): 3357 def f_branch_return_non_tensor(x): 3358 return cond(x.shape[0] <= 5, lambda x: 3.14, lambda x: 3.14, [x]) 3359 3360 example_inputs = (torch.rand(5),) 3361 with self.assertRaisesRegex( 3362 torch._dynamo.exc.UncapturedHigherOrderOpError, 3363 "Cond doesn't work unless it is captured completely with torch.compile", 3364 ): 3365 torch._dynamo.export( 3366 f_branch_return_non_tensor, 3367 aten_graph=True, 3368 )(*example_inputs) 3369 3370 def test_cond_raise_user_error_on_branch_return_multiple_tensors(self): 3371 def f_branch_return_multiple_tensors(pred, x, y): 3372 return cond(pred, lambda x: (x, x), lambda x: (x, x), [y]) 3373 3374 example_inputs = (torch.tensor(True), torch.randn(4), torch.randn(2)) 3375 gm, _ = torch._dynamo.export( 3376 f_branch_return_multiple_tensors, 3377 aten_graph=True, 
3378 )(*example_inputs) 3379 self.assertEqual( 3380 gm(*example_inputs), f_branch_return_multiple_tensors(*example_inputs) 3381 ) 3382 3383 def test_multiple_outputs_op_with_evaluator(self): 3384 class TopKModel(torch.nn.Module): 3385 def forward(self, x): 3386 values, _ = torch.topk(x, 3) 3387 return torch.sum(values) 3388 3389 x = torch.arange(1.0, 6.0, requires_grad=True) 3390 torch._dynamo.export(TopKModel())(x) 3391 3392 def test_cond_raise_user_error_on_mismatch_return_length(self): 3393 def true_fn(x): 3394 return x 3395 3396 def false_fn(x): 3397 return (x, x) 3398 3399 def f_mismatch_return_length(x): 3400 return cond(torch.tensor(100), true_fn, false_fn, [x]) 3401 3402 example_inputs = (torch.rand(5),) 3403 with self.assertRaisesRegex( 3404 RuntimeError, "Unmatched number of outputs from cond" 3405 ): 3406 torch._dynamo.export( 3407 f_mismatch_return_length, 3408 aten_graph=True, 3409 )(*example_inputs) 3410 3411 def test_cond_raise_user_error_on_mismatch_return_tensor_meta(self): 3412 def true_fn(x): 3413 return torch.tensor([[3], [2]]) 3414 3415 def false_fn(x): 3416 return torch.tensor([3.14]) 3417 3418 def f_return_tensor_mismatch(x): 3419 return cond(x.shape[0] < 3, true_fn, false_fn, [x]) 3420 3421 example_inputs = (torch.rand(5),) 3422 with self.assertRaisesRegex( 3423 torch._dynamo.exc.UncapturedHigherOrderOpError, 3424 "Cond doesn't work unless it is captured completely with torch.compile", 3425 ): 3426 torch._dynamo.export(f_return_tensor_mismatch, aten_graph=True)( 3427 *example_inputs, 3428 ) 3429 3430 def test_byte_tensor_does_not_crash(self): 3431 # See https://github.com/pytorch/pytorch/issues/100455 3432 def func(text): 3433 tensor = torch.ByteTensor(list(bytes(text, "utf8"))) 3434 return tensor + tensor 3435 3436 text = "".join(chr(a % 90 + 40) for a in range(111)) 3437 opt_func = torch._dynamo.optimize("eager", dynamic=True)(func) 3438 for i in [99, 100]: 3439 input = text[:i] 3440 opt_func(input) 3441 3442 def 
test_export_defaults_ok(self): 3443 class DynamicSliceExportMod(torch.nn.Module): 3444 def forward(self, x): 3445 results = [] 3446 for i in range(4): 3447 results.append(x[: x.size(0) - i, i : x.size(2), i:3]) 3448 return tuple(results) 3449 3450 gm, _ = torch._dynamo.export(DynamicSliceExportMod(), aten_graph=True)( 3451 torch.randn(5, 5, 5), 3452 ) 3453 3454 self.assertExpectedInline( 3455 gm.code.strip(), 3456 """\ 3457def forward(self, x): 3458 arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) 3459 arg0_1 = arg0 3460 sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0) 3461 slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3) 3462 sub = sym_size_int - 1 3463 slice_2 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub); sub = None 3464 slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int); slice_2 = None 3465 slice_4 = torch.ops.aten.slice.Tensor(slice_3, 2, 1, 3); slice_3 = None 3466 sub_1 = sym_size_int - 2 3467 slice_5 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_1); sub_1 = None 3468 slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 2, sym_size_int); slice_5 = None 3469 slice_7 = torch.ops.aten.slice.Tensor(slice_6, 2, 2, 3); slice_6 = None 3470 sub_2 = sym_size_int - 3 3471 slice_8 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_2); arg0_1 = sub_2 = None 3472 slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 3, sym_size_int); slice_8 = sym_size_int = None 3473 slice_10 = torch.ops.aten.slice.Tensor(slice_9, 2, 3, 3); slice_9 = None 3474 return pytree.tree_unflatten([slice_1, slice_4, slice_7, slice_10], self._out_spec)""", 3475 ) 3476 3477 def test_capture_symbolic_tracing_simple_within_fake_mode(self): 3478 from torch._dynamo.output_graph import config 3479 3480 def f(x): 3481 y = torch.randn(3) 3482 return x + x * y 3483 3484 with fake_tensor.FakeTensorMode( 3485 shape_env=ShapeEnv( 3486 allow_scalar_outputs=config.capture_scalar_outputs, 3487 allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, 3488 ), 3489 
): 3490 x = torch.randn(3) 3491 3492 for aten_graph in [True, False]: 3493 gm, _ = torch._dynamo.export(f, aten_graph=aten_graph)(x) 3494 self.assertTrue( 3495 isinstance(gm, torch.fx.GraphModule), 3496 msg="test_capture_symbolic_tracing_simple_within_fake_mode_aten_graph_" 3497 + str(aten_graph), 3498 ) 3499 3500 def test_export_with_symbool_inputs(self): 3501 def f(pred: bool, x: torch.Tensor): 3502 if pred: 3503 return x.sin() 3504 else: 3505 return x.cos() 3506 3507 x = torch.randn([3, 4]) 3508 3509 def test_symbool_guards( 3510 f, size_tests, exp_graph, exp_guard_code, exp_shape_env_guards 3511 ): 3512 shape_env = ShapeEnv() 3513 with fake_tensor.FakeTensorMode( 3514 shape_env=shape_env, 3515 ) as fake_mode: 3516 fake_x = fake_mode.from_tensor( 3517 x, 3518 symbolic_context=StatelessSymbolicContext( 3519 dynamic_sizes=[DimDynamic.DYNAMIC for _ in range(x.dim())], 3520 ), 3521 ) 3522 for i, size in enumerate(size_tests): 3523 pred = fake_x.size(0) == size 3524 gm, guards = torch._dynamo.export(f)(pred, x) 3525 actual = normalize_gm(gm.print_readable(print_output=False)) 3526 # TODO: This is naughty, EXPECTTEST_ACCEPT=1 doesn't work 3527 self.assertExpectedInline(actual, exp_graph[i]) 3528 dynamo_shape_env_guards = [ 3529 guard 3530 for guard in guards 3531 if guard.guard_types is not None 3532 and "SHAPE_ENV" in guard.guard_types 3533 ] 3534 self.assertEqual(len(dynamo_shape_env_guards), 1) 3535 guard_code_on_predicate = [ 3536 code 3537 for code in dynamo_shape_env_guards[0].code_list 3538 if "L['pred']" in code 3539 ] 3540 self.assertEqual(guard_code_on_predicate, exp_guard_code[i]) 3541 outter_shape_env_guards = [ 3542 str(guard.expr) for guard in shape_env.guards 3543 ] 3544 self.assertEqual(outter_shape_env_guards, exp_shape_env_guards[i]) 3545 3546 true_graph = """\ 3547class GraphModule(torch.nn.Module): 3548 def forward(self, pred, x): 3549 arg1: "f32[s1, s2]"; 3550 3551 arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec) 3552 l_x_ 
= arg1 3553 3554 sin: "f32[s1, s2]" = l_x_.sin(); l_x_ = None 3555 return pytree.tree_unflatten([sin], self._out_spec) 3556""" 3557 false_graph = """\ 3558class GraphModule(torch.nn.Module): 3559 def forward(self, pred, x): 3560 arg1: "f32[s1, s2]"; 3561 3562 arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec) 3563 l_x_ = arg1 3564 3565 cos: "f32[s1, s2]" = l_x_.cos(); l_x_ = None 3566 return pytree.tree_unflatten([cos], self._out_spec) 3567""" 3568 true_guard_code = [ 3569 "cast_symbool_to_symint_guardless(L['pred']) == 1", 3570 ] 3571 false_guard_code = [ 3572 "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)", 3573 ] 3574 test_symbool_guards( 3575 f, 3576 [3, 3, 4, 5], 3577 [true_graph, true_graph, false_graph, false_graph], 3578 [true_guard_code, true_guard_code, false_guard_code, false_guard_code], 3579 # Outter shape env should have no guards in it because we never specialize on the outter symbool. 3580 [[], [], [], []], 3581 ) 3582 3583 def test_invalid_input_global(self) -> None: 3584 global bulbous_bouffant 3585 bulbous_bouffant = torch.randn(3) 3586 3587 def f(y): 3588 return bulbous_bouffant + y 3589 3590 self.assertExpectedInlineMunged( 3591 UserError, 3592 lambda: torch._dynamo.export(f)(torch.randn(3)), 3593 """\ 3594G['bulbous_bouffant'], accessed at: 3595 File "test_export.py", line N, in f 3596 return bulbous_bouffant + y 3597""", 3598 ) 3599 3600 def test_invalid_input_global_multiple_access(self) -> None: 3601 global macademia 3602 macademia = torch.randn(3) 3603 3604 def g(y): 3605 global macademia 3606 y = macademia + y 3607 return y 3608 3609 def f(y): 3610 global macademia 3611 y = g(y) 3612 return macademia + y 3613 3614 # NB: This doesn't actually work (it only reports the first usage), 3615 # but I'm leaving the test here in case we fix it later 3616 self.assertExpectedInlineMunged( 3617 UserError, 3618 lambda: torch._dynamo.export(f)(torch.randn(3)), 3619 """\ 3620G['macademia'], accessed at: 3621 File 
"test_export.py", line N, in f 3622 y = g(y) 3623 File "test_export.py", line N, in g 3624 y = macademia + y 3625""", 3626 ) 3627 3628 def test_invalid_input_nonlocal(self) -> None: 3629 arglebargle = torch.randn(3) 3630 3631 def f(y): 3632 return arglebargle + y 3633 3634 self.assertExpectedInlineMunged( 3635 UserError, 3636 lambda: torch._dynamo.export(f)(torch.randn(3)), 3637 """L['arglebargle'], a closed over free variable""", 3638 ) 3639 3640 def test_invalid_input_unused_nonlocal_ok(self) -> None: 3641 arglebargle = torch.randn(3) 3642 3643 def f(y): 3644 x = arglebargle 3645 return y 3646 3647 torch._dynamo.export(f)(torch.randn(3)) 3648 3649 def test_symbolic_tracing_within_fake_mode_with_constraints(self): 3650 from torch._subclasses import fake_tensor 3651 3652 fake_mode = fake_tensor.FakeTensorMode() 3653 3654 class DynamicShapeSimpleModel(torch.nn.Module): 3655 def __init__(self) -> None: 3656 super().__init__() 3657 3658 def forward(self, a, b, c) -> torch.Tensor: 3659 d = (torch.matmul(a, b) + c) / 2 3660 d_s0 = d.shape[0] 3661 d_s1 = d.shape[1] 3662 d_s3 = d_s0 * d_s1 3663 e = d.view(d_s3) 3664 return torch.cat([e, e]) 3665 3666 with fake_mode: 3667 model = DynamicShapeSimpleModel() 3668 inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7)) 3669 dim = torch.export.Dim("dim") 3670 dynamic_shapes = ({0: dim}, None, {0: dim}) 3671 for aten_graph in [True, False]: 3672 gm = torch._dynamo.export( 3673 model, 3674 dynamic_shapes=dynamic_shapes, 3675 aten_graph=aten_graph, 3676 )(*inputs).graph_module 3677 3678 # Since there are no parameters we can do this 3679 inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7)) 3680 self.assertEqual(model(*inputs), gm(*inputs)) 3681 3682 def test_symbolic_tracing_within_fake_mode_with_constraints_with_parameters(self): 3683 from torch._subclasses import fake_tensor 3684 3685 fake_mode = fake_tensor.FakeTensorMode() 3686 3687 # TODO: Seems to choke if you don't make a fresh model and 3688 # just 
try to export Linear directly... 3689 class Model(torch.nn.Module): 3690 def __init__(self) -> None: 3691 super().__init__() 3692 self.linear = torch.nn.Linear(2, 2) 3693 3694 def forward(self, x): 3695 out = self.linear(x) 3696 return out 3697 3698 with fake_mode: 3699 model = Model() 3700 inputs = (torch.randn(10, 2, 2),) 3701 dynamic_shapes = ({0: torch.export.Dim("dim")},) 3702 for aten_graph in [True, False]: 3703 gm = torch._dynamo.export( 3704 model, 3705 dynamic_shapes=dynamic_shapes, 3706 aten_graph=aten_graph, 3707 )(*inputs).graph_module 3708 3709 def test_capture_symbolic_tracing_within_fake_mode(self): 3710 from torch._dynamo.output_graph import config 3711 from torch._subclasses import fake_tensor 3712 from torch.fx.experimental.symbolic_shapes import ShapeEnv 3713 3714 class Model(torch.nn.Module): 3715 def __init__(self) -> None: 3716 super().__init__() 3717 self.linear = torch.nn.Linear(2, 2) 3718 self.linear2 = torch.nn.Linear(2, 2) 3719 3720 def forward(self, x): 3721 out = self.linear(x) 3722 out = self.linear2(out) 3723 return out 3724 3725 # User-instantiated FakeTensorMode 3726 fake_mode = fake_tensor.FakeTensorMode( 3727 allow_non_fake_inputs=False, 3728 allow_fallback_kernels=True, 3729 shape_env=ShapeEnv( 3730 allow_scalar_outputs=config.capture_scalar_outputs, 3731 allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, 3732 ), 3733 ) 3734 # Fakefy input+model before exporting it 3735 with fake_mode: 3736 x = torch.rand(5, 2, 2) 3737 model = Model() 3738 3739 # Export the model with fake inputs and parameters 3740 for aten_graph in [True, False]: 3741 graph_module, _ = torch._dynamo.export(model, aten_graph=aten_graph)(x) 3742 self.assertTrue( 3743 isinstance(graph_module, torch.fx.GraphModule), 3744 msg="test_capture_symbolic_tracing_within_fake_mode_aten_graph_" 3745 + str(aten_graph), 3746 ) 3747 3748 def test_cond_op_param_buffer_lifted(self): 3749 class A(torch.nn.Module): 3750 def __init__(self) -> None: 3751 
super().__init__() 3752 self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4)) 3753 3754 def forward(self): 3755 return self.buffer1.sum() 3756 3757 class B(torch.nn.Module): 3758 def __init__(self) -> None: 3759 super().__init__() 3760 self.buffer2 = torch.nn.Buffer(torch.ones(6, 4)) 3761 3762 def forward(self): 3763 return self.buffer2.sum() 3764 3765 class M(torch.nn.Module): 3766 def __init__(self) -> None: 3767 super().__init__() 3768 self.a = A() 3769 self.b = B() 3770 3771 def forward(self, x): 3772 def true_fn(x): 3773 return x.cos() + self.a() 3774 3775 def false_fn(x): 3776 return x.sin() + self.b() 3777 3778 return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),) 3779 3780 gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4)) 3781 self.assertEqual(gm(torch.ones(6, 4)), M()(torch.ones(6, 4))) 3782 self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4))) 3783 3784 def test_nested_cond_op_param_buffer_lifted(self): 3785 class A(torch.nn.Module): 3786 def __init__(self) -> None: 3787 super().__init__() 3788 self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4)) 3789 3790 def forward(self): 3791 return self.buffer1.sum() 3792 3793 class B(torch.nn.Module): 3794 def __init__(self) -> None: 3795 super().__init__() 3796 self.buffer2 = torch.nn.Buffer(torch.ones(6, 4)) 3797 3798 def forward(self): 3799 return self.buffer2.sum() 3800 3801 class M(torch.nn.Module): 3802 def __init__(self) -> None: 3803 super().__init__() 3804 self.a = A() 3805 self.b = B() 3806 3807 def forward(self, x): 3808 def true_true_fn(x): 3809 return x.cos() + self.a() 3810 3811 def true_false_fn(x): 3812 return x.cos() + self.a() + 1 3813 3814 def true_fn(x): 3815 return cond(x.shape[0] > 5, true_true_fn, true_false_fn, [x]) 3816 3817 def false_fn(x): 3818 return x.sin() + self.b() 3819 3820 return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),) 3821 3822 gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4)) 3823 self.assertEqual(gm(torch.ones(6, 4)), 
M()(torch.ones(6, 4))) 3824 self.assertEqual(gm(torch.ones(5, 4)), M()(torch.ones(5, 4))) 3825 self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4))) 3826 3827 def test_map_cond_param_buffer_lifted(self): 3828 from functorch.experimental.control_flow import cond, map 3829 3830 class A(torch.nn.Module): 3831 def __init__(self) -> None: 3832 super().__init__() 3833 self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4)) 3834 3835 def forward(self): 3836 return self.buffer1.sum() 3837 3838 class B(torch.nn.Module): 3839 def __init__(self) -> None: 3840 super().__init__() 3841 self.buffer2 = torch.nn.Buffer(torch.ones(6, 4)) 3842 3843 def forward(self): 3844 return self.buffer2.sum() 3845 3846 class Module(torch.nn.Module): 3847 def __init__(self) -> None: 3848 super().__init__() 3849 self.a = A() 3850 self.b = B() 3851 3852 def inner(self, x, pred): 3853 def true_fn(x): 3854 return x + x + self.a() 3855 3856 def false_fn(x): 3857 return x * x + self.b() 3858 3859 return cond(pred, true_fn, false_fn, [x]) 3860 3861 def forward(self, pred, xs): 3862 def body(x, pred): 3863 return self.inner(x, pred) + self.b() 3864 3865 return map(body, xs, pred) 3866 3867 mod = Module() 3868 x = torch.randn(3, 2, 1) 3869 pred_x = torch.tensor(True) 3870 3871 y = torch.randn(4, 3, 2) 3872 pred_y = torch.tensor(False) 3873 real_result = mod(pred_y, y) 3874 3875 out_graph, _ = torch._dynamo.export(mod)(pred_x, x) 3876 self.assertEqual(real_result, out_graph(pred_y, y)) 3877 3878 def test_cond_free_variables_overlapping(self): 3879 from functorch.experimental.control_flow import cond 3880 3881 class Module(torch.nn.Module): 3882 def __init__(self) -> None: 3883 super().__init__() 3884 3885 def forward(self, pred, x): 3886 a = torch.ones(6, 4) 3887 b = torch.ones(6, 4) 3888 c = torch.ones(6, 4) 3889 d = torch.ones(6, 4) 3890 3891 def true_fn(x): 3892 return x + x + a.cos() + b.cos() + d.cos() 3893 3894 def false_fn(x): 3895 return x * x + a.sin() + b.sin() + c.sin() 3896 3897 return 
cond(pred, true_fn, false_fn, [x]) 3898 3899 mod = Module() 3900 x = torch.ones(6, 4) 3901 pred_x = torch.tensor(True) 3902 3903 out_graph, _ = torch._dynamo.export(mod)(pred_x, x) 3904 self.assertExpectedInline( 3905 out_graph.code.strip(), 3906 """\ 3907def forward(self, pred, x): 3908 arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec) 3909 l_pred_ = arg0 3910 l_x_ = arg1 3911 a = torch.ones(6, 4) 3912 b = torch.ones(6, 4) 3913 c = torch.ones(6, 4) 3914 d = torch.ones(6, 4) 3915 cond_true_0 = self.cond_true_0 3916 cond_false_0 = self.cond_false_0 3917 cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [a, b, l_x_, d, c]); l_pred_ = cond_true_0 = cond_false_0 = a = b = l_x_ = d = c = None 3918 getitem = cond[0]; cond = None 3919 return pytree.tree_unflatten([getitem], self._out_spec)""", # noqa: B950,E122 3920 ) 3921 3922 self.assertExpectedInline( 3923 out_graph.cond_true_0.code.strip(), 3924 """\ 3925def forward(self, a, b, l_x_, d_true_branch, c_false_branch): 3926 a_1 = a 3927 b_1 = b 3928 l_x__1 = l_x_ 3929 add = l_x__1 + l_x__1; l_x__1 = None 3930 cos = a_1.cos(); a_1 = None 3931 add_1 = add + cos; add = cos = None 3932 cos_1 = b_1.cos(); b_1 = None 3933 add_2 = add_1 + cos_1; add_1 = cos_1 = None 3934 cos_2 = d_true_branch.cos(); d_true_branch = None 3935 add_3 = add_2 + cos_2; add_2 = cos_2 = None 3936 return (add_3,)""", 3937 ) 3938 3939 self.assertExpectedInline( 3940 out_graph.cond_false_0.code.strip(), 3941 """\ 3942def forward(self, a, b, l_x_, d_true_branch, c_false_branch): 3943 a_1 = a 3944 b_1 = b 3945 l_x__1 = l_x_ 3946 mul = l_x__1 * l_x__1; l_x__1 = None 3947 sin = a_1.sin(); a_1 = None 3948 add = mul + sin; mul = sin = None 3949 sin_1 = b_1.sin(); b_1 = None 3950 add_1 = add + sin_1; add = sin_1 = None 3951 sin_2 = c_false_branch.sin(); c_false_branch = None 3952 add_2 = add_1 + sin_2; add_1 = sin_2 = None 3953 return (add_2,)""", 3954 ) 3955 3956 @unittest.skipIf( 3957 common_utils.TEST_WITH_ASAN, 
3958 "Times out with ASAN, see https://github.com/pytorch/pytorch/issues/110416", 3959 ) 3960 def test_retracibility(self): 3961 class MyLinear(torch.nn.Module): 3962 def __init__(self) -> None: 3963 super().__init__() 3964 self.weight = torch.randn(20, 98) 3965 self.bias = torch.randn(20) 3966 3967 def forward(self, x): 3968 return torch.nn.functional.linear(x, self.weight, self.bias) 3969 3970 class Foo(torch.nn.Module): 3971 def __init__(self) -> None: 3972 super().__init__() 3973 self.conv = torch.nn.Conv2d(16, 33, 3) 3974 self.linear = MyLinear() 3975 3976 def forward(self, x): 3977 a, b = x 3978 a_conv = self.conv(a) 3979 a_linear = self.linear(a_conv) 3980 b_conv = self.conv(b) 3981 b_linear = self.linear(b_conv) 3982 return ( 3983 a_linear.cos() + b_linear.sin(), 3984 a_linear.sin() + b_linear.cos(), 3985 ) 3986 3987 inp_container = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)) 3988 3989 gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True) 3990 gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True) 3991 3992 inp_test = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)) 3993 3994 self.assertTrue(torch.allclose(gm(inp_test)[0], gm2(inp_test)[0])) 3995 self.assertTrue(torch.allclose(gm(inp_test)[1], gm2(inp_test)[1])) 3996 3997 def test_retracibility_dict_container_inp_out(self): 3998 class MyLinear(torch.nn.Module): 3999 def __init__(self) -> None: 4000 super().__init__() 4001 self.weight = torch.randn(20, 98) 4002 self.bias = torch.randn(20) 4003 4004 def forward(self, x): 4005 return torch.nn.functional.linear(x, self.weight, self.bias) 4006 4007 class Foo(torch.nn.Module): 4008 def __init__(self) -> None: 4009 super().__init__() 4010 self.conv = torch.nn.Conv2d(16, 33, 3) 4011 self.linear = MyLinear() 4012 4013 def forward(self, x): 4014 a1, a2 = x["a"] 4015 b = x["b"] 4016 a1_conv = self.conv(a1) 4017 a1_linear = self.linear(a1_conv) 4018 a2_conv = self.conv(a2) 4019 a2_linear = self.linear(a2_conv) 4020 
b_conv = self.conv(b) 4021 b_linear = self.linear(b_conv) 4022 return { 4023 "a": [ 4024 a1_linear.cos() + b_linear.sin(), 4025 a1_linear.cos() + b_linear.sin(), 4026 ], 4027 "b": a2_linear.sin() + b_linear.cos(), 4028 } 4029 4030 inp_container = { 4031 "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), 4032 "b": torch.randn(20, 16, 50, 100), 4033 } 4034 4035 gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True) 4036 gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True) 4037 4038 inp_test = { 4039 "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), 4040 "b": torch.randn(20, 16, 50, 100), 4041 } 4042 4043 self.assertTrue(torch.allclose(gm(inp_test)["a"][0], gm2(inp_test)["a"][0])) 4044 self.assertTrue(torch.allclose(gm(inp_test)["a"][1], gm2(inp_test)["a"][1])) 4045 self.assertTrue(torch.allclose(gm(inp_test)["b"], gm2(inp_test)["b"])) 4046 4047 def test_retracibility_nested_list_out(self): 4048 class MyLinear(torch.nn.Module): 4049 def __init__(self) -> None: 4050 super().__init__() 4051 self.weight = torch.randn(20, 98) 4052 self.bias = torch.randn(20) 4053 4054 def forward(self, x): 4055 return torch.nn.functional.linear(x, self.weight, self.bias) 4056 4057 class Foo(torch.nn.Module): 4058 def __init__(self) -> None: 4059 super().__init__() 4060 self.conv = torch.nn.Conv2d(16, 33, 3) 4061 self.linear = MyLinear() 4062 4063 def forward(self, x): 4064 a1, a2 = x["a"] 4065 b = x["b"] 4066 a1_conv = self.conv(a1) 4067 a1_linear = self.linear(a1_conv) 4068 a2_conv = self.conv(a2) 4069 a2_linear = self.linear(a2_conv) 4070 b_conv = self.conv(b) 4071 b_linear = self.linear(b_conv) 4072 return [ 4073 [ 4074 a1_linear.cos() + b_linear.sin(), 4075 a1_linear.cos() + b_linear.sin(), 4076 ], 4077 [ 4078 a2_linear.sin() + b_linear.cos(), 4079 a2_linear.sin() + b_linear.cos(), 4080 ], 4081 ] 4082 4083 inp_container = { 4084 "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), 4085 "b": torch.randn(20, 16, 
50, 100),
        }

        gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True)
        gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True)

        inp_test = {
            "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
            "b": torch.randn(20, 16, 50, 100),
        }

        self.assertTrue(torch.allclose(gm(inp_test)[0][0], gm2(inp_test)[0][0]))
        self.assertTrue(torch.allclose(gm(inp_test)[0][1], gm2(inp_test)[0][1]))
        self.assertTrue(torch.allclose(gm(inp_test)[1][0], gm2(inp_test)[1][0]))
        self.assertTrue(torch.allclose(gm(inp_test)[1][1], gm2(inp_test)[1][1]))

    def test_fx_pytree(self):
        # Exports (aten_graph=True) a function that flattens its argument with
        # both torch.utils._pytree.tree_flatten and
        # torch.fx._pytree.tree_flatten_spec, then checks that the exported
        # graph produces the same result as calling the function eagerly.
        def foo(args):
            flat_args, spec = torch.utils._pytree.tree_flatten(args)
            flat_args_fx = torch.fx._pytree.tree_flatten_spec(args, spec)
            return flat_args_fx[0] + flat_args[0]

        inp_container = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100))

        gm, _ = torch._dynamo.export(foo, inp_container, aten_graph=True)

        self.assertTrue(torch.allclose(foo(inp_container), gm(inp_container)))

    @config.patch(suppress_errors=True)
    @config.patch(verbose=True)
    def test_export_with_map_zero_sized_tensor_suppress_errors(self):
        # With config.suppress_errors patched on, exporting a module that
        # applies control_flow.map over a tensor whose leading dimension is 0
        # is still expected to raise torch._dynamo.exc.Unsupported.
        # Local import: the name `map` shadows the builtin inside this test only.
        from functorch.experimental.control_flow import map

        class Module(torch.nn.Module):
            def forward(self, xs):
                def body(x):
                    return x + 1

                return map(body, xs)

        mod = Module()
        xs = torch.randn(0, 2)
        with self.assertRaises(
            torch._dynamo.exc.Unsupported,
        ):
            out_graph, _ = torch._dynamo.export(mod, xs)

    def test_param_buffer_safe_from_mutation_simple(self):
        # The module's forward mutates its buffer in place (add_). After
        # export, the exported module must expose exactly one buffer, named
        # "L__self___buffer1", whose values are still all zeros.
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.zeros(5, 5))

            def forward(self, x):
                self.buffer1.add_(1)
                return x + self.buffer1

        gm, _ = torch._dynamo.export(Module(), torch.ones(5, 5), aten_graph=False)
        buffers = list(gm.named_buffers())
        self.assertEqual(len(buffers), 1)

        name, buffer = buffers[0]
        self.assertEqual(name, "L__self___buffer1")

        # NOTE(review): the buffer is created as 5x5 but is compared against
        # torch.zeros(5); the check appears to rely on broadcasting - confirm
        # this is intentional.
        self.assertTrue(torch.allclose(buffer, torch.zeros(5)))

    def test_param_buffer_safe_from_mutation_recurse(self):
        # Same idea as the "simple" variant, but the parent forward mutates
        # both its own buffer and a buffer owned by a child module. Every
        # buffer reported by the exported module must still be all zeros.
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer2 = torch.nn.Buffer(torch.zeros(5))

            def forward(self, x):
                return x.sum() + self.buffer2.sum()

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.zeros(5))
                self.child = Child()

            def forward(self, x):
                self.buffer1.add_(1)
                self.child.buffer2.add_(2)
                return x.sum() + self.buffer1.sum() + self.child(x)

        gm, _ = torch._dynamo.export(Module(), torch.ones(5), aten_graph=False)
        # NOTE(review): unlike the "simple" variant, the number of buffers is
        # not asserted here, so this loop would pass if there were none.
        for name, buffer in gm.named_buffers():
            self.assertTrue(torch.allclose(buffer, torch.zeros(5)))

    def test_predispatch_with_higher_order(self):
        # Exports (aten_graph=True, pre_dispatch=True) a function whose result
        # is selected by cond() on x.shape[0] > 4, then compares eager and
        # exported results for inputs with 4 rows and with 6 rows.
        def f(x):
            return cond(x.shape[0] > 4, lambda x: x + 5, lambda x: x - 3, [x])

        gm, _ = torch._dynamo.export(f, aten_graph=True, pre_dispatch=True)(
            torch.randn(4, 4)
        )
        inp1 = torch.randn(4, 4)
        inp2 = torch.randn(6, 4)
        self.assertTrue(torch.allclose(f(inp1), gm(inp1)))
        self.assertTrue(torch.allclose(f(inp2), gm(inp2)))

    def test_predispatch_with_higher_order_nested(self):
        # Like test_predispatch_with_higher_order, but the true branch itself
        # contains a second cond() (on x.shape[0] > 6). Inputs with 4, 6 and 8
        # rows are compared between eager and the exported graph.
        def f(x):
            def true_fn(x):
                return cond(x.shape[0] > 6, lambda x: x + 10, lambda x: x - 10, [x])

            return cond(x.shape[0] > 4, true_fn, lambda x: x - 3, [x])

        gm, _ = torch._dynamo.export(f, aten_graph=True, pre_dispatch=True)(
            torch.randn(4, 4)
        )
        inp1 = torch.randn(4, 4)
        inp2 = torch.randn(6, 4)
        inp3 = torch.randn(8, 4)
        self.assertTrue(torch.allclose(f(inp1), gm(inp1)))
        self.assertTrue(torch.allclose(f(inp2), gm(inp2)))
        self.assertTrue(torch.allclose(f(inp3), gm(inp3)))

    def test_predispatch_with_for_out_dtype(self):
        # Exports (aten_graph=True, pre_dispatch=True) a module whose forward
        # calls the out_dtype higher-order op (aten.mm with an int32 output) on
        # int8 inputs, and checks eager and exported results agree.
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return out_dtype(torch.ops.aten.mm.default, torch.int32, x, self.weight)

        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        m = M(weight)
        x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        gm, _ = torch._dynamo.export(m, x, aten_graph=True, pre_dispatch=True)

        self.assertTrue(torch.allclose(m(x), gm(x)))

    def test_predispatch_with_for_out_dtype_nested(self):
        # out_dtype used inside both branches of a cond(): the true branch uses
        # aten.mm, the false branch aten.mul. The predicate is x.sum() != 0, so
        # an all-ones input and an all-zeros input are both compared against
        # eager. The generated code of each branch sub-graph is also pinned.
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def true_fn(self, x):
                return out_dtype(
                    torch.ops.aten.mm.default, torch.int32, x, self.weight
                ).sum()

            def false_fn(self, x):
                return out_dtype(
                    torch.ops.aten.mul.Tensor, torch.int32, x, self.weight
                ).sum()

            def forward(self, x):
                return cond(x.sum() != 0, self.true_fn, self.false_fn, [x])

        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        m = M(weight)
        x = torch.ones((5, 5), dtype=torch.int8)
        gm, _ = torch._dynamo.export(m, x, aten_graph=True, pre_dispatch=True)

        self.assertTrue(torch.allclose(m(x), gm(x)))
        y = torch.zeros((5, 5), dtype=torch.int8)
        self.assertTrue(torch.allclose(m(y), gm(y)))

        self.assertExpectedInline(
            gm.true_graph_0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1):
    out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mm.default, torch.int32, arg1_1, arg0_1);  arg1_1 = arg0_1 = None
    sum_1 = torch.ops.aten.sum.default(out_dtype);  out_dtype = None
    return (sum_1,)""",
        )

        self.assertExpectedInline(
            gm.false_graph_0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1):
    out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mul.Tensor, torch.int32, arg1_1, arg0_1);  arg1_1 = arg0_1 = None
    sum_1 = torch.ops.aten.sum.default(out_dtype);  out_dtype = None
    return (sum_1,)""",
        )

    def test_export_nn_module_stack_patched_module(self):
        # A submodule's forward is replaced on the instance with a free
        # function bound through __get__. The exported graph must compute the
        # patched forward (x * y) and every call_function node must carry
        # "nn_module_stack" metadata.
        def forward(self, x, y):
            return x * y

        class Toplevel(torch.nn.Module):
            def __init__(self, m):
                super().__init__()
                self.m = m

            def forward(self, x, y):
                return self.m(x, y)

        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        t = Toplevel(M())
        # Instance-level patch: M.forward (x + y) is shadowed by x * y.
        t.m.forward = forward.__get__(t.m, M)
        x, y = torch.rand(3), torch.rand(3)
        gm, _ = torch._dynamo.export(t, x, y)

        self.assertTrue(torch.allclose(forward(None, x, y), gm(x, y)))
        for node in gm.graph.nodes:
            if node.op == "call_function":
                self.assertIn("nn_module_stack", node.meta)

    def test_preserve_fx_node_metadata(self):
        # Exports a module, edits a deep copy of the resulting graph (relu is
        # swapped for a Python function calling torch.abs and that node's meta
        # is cleared), re-exports the edited graph, and checks that:
        #   * the original and re-exported code match the pinned strings, and
        #   * the untouched sin/cos nodes agree on "nn_module_stack" and
        #     "stack_trace" metadata between the first and second export.
        class Module1(torch.nn.Module):
            def forward(self, x):
                return torch.sin(x)

        class Module2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod1 = Module1()

            def forward(self, x):
                x = torch.cos(x)
                x = self.mod1(x)
                x = torch.relu(x)
                return x

        def fn(x):
            return torch.abs(x)

        mod = Module2()
        inp = torch.randn(3, 3)

        gm, _ = torch._dynamo.export(mod)(inp)

        # replace relu with fn
        gm_edit = copy.deepcopy(gm)
        for nd in gm_edit.graph.nodes:
            if nd.target == torch.relu:
                nd.target = fn
                nd.meta.clear()
                break
        gm_edit.recompile()

        gm2, _ = torch._dynamo.export(gm_edit)(inp)

        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    x = torch.cos(l_x_);  l_x_ = None
    x_1 = torch.sin(x);  x = None
    x_2 = torch.relu(x_1);  x_1 = None
    return pytree.tree_unflatten([x_2], self._out_spec)""",
        )

        # Helper: True if any node in the graph module has the given target.
        def _constais_op(gm, target):
            for nd in gm.graph.nodes:
                if nd.target == target:
                    return True
            return False

        self.assertTrue(_constais_op(gm_edit, torch.cos))
        self.assertTrue(_constais_op(gm_edit, torch.sin))
        self.assertTrue(not _constais_op(gm_edit, torch.relu))

        self.assertExpectedInline(
            gm2.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    x = torch.cos(l_x_);  l_x_ = None
    x_1 = torch.sin(x);  x = None
    x_2 = torch.abs(x_1);  x_1 = None
    return pytree.tree_unflatten([x_2], self._out_spec)""",
        )

        # check for other metadata
        for op in (torch.sin, torch.cos):
            nd1 = next(filter(lambda nd: nd.target == op, gm.graph.nodes))
            nd2 = next(filter(lambda nd: nd.target == op, gm2.graph.nodes))
            self.assertTrue(
                ("nn_module_stack" in nd1.meta) == ("nn_module_stack" in nd2.meta)
            )
            if "nn_module_stack" in nd1.meta:
                self.assertEqual(
                    nd1.meta["nn_module_stack"], nd2.meta["nn_module_stack"]
                )
            self.assertEqual(nd1.meta["stack_trace"], nd2.meta["stack_trace"])

    def test_preserve_fx_node_metadata_recompile(self):
        # Re-exports an already exported graph module twice, with inputs of
        # shape (3, 3) and (5, 3), after an unrelated
        # torch._dynamo.optimize("eager") run of the original function in
        # between. Both re-exports must produce the same pinned code.
        def fn(x):
            return torch.sin(x)

        gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3))
        do_export = torch._dynamo.export(gm)
        torch._dynamo.optimize("eager")(fn)(torch.randn(3, 3))
        gm1, _ = do_export(torch.randn(3, 3))
        gm2, _ = do_export(torch.randn(5, 3))

        self.assertExpectedInline(
            gm1.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    sin = torch.sin(l_x_);  l_x_ = None
    return pytree.tree_unflatten([sin], self._out_spec)""",
        )
        self.assertExpectedInline(
            gm2.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    sin = torch.sin(l_x_);  l_x_ = None
    return pytree.tree_unflatten([sin], self._out_spec)""",
        )

    def test_preserve_fx_node_metadata_inline(self):
        # An exported graph module (sin) is called from inside another
        # function (cos, then gm). Exporting the outer function must yield a
        # single graph with cos followed by sin, as pinned below.
        def f1(x):
            return torch.sin(x)

        gm, _ = torch._dynamo.export(f1)(torch.randn(3, 3))

        def f2(x):
            x = torch.cos(x)
            return gm(x)

        gm2, _ = torch._dynamo.export(f2)(torch.randn(3, 3))

        self.assertExpectedInline(
            gm2.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    x = torch.cos(l_x_);  l_x_ = None
    sin = torch.sin(x);  x = None
    return pytree.tree_unflatten([sin], self._out_spec)""",
        )

    def test_preserve_fx_node_metadata_graph_break(self):
        # Exports sin -> abs -> cos, then replaces the abs node in a deep copy
        # with a function that calls torch._dynamo.graph_break(). The edited
        # module is compiled with a custom backend that checks, in order, that
        # each graph it receives contains the next expected snippet.
        def fn(x):
            x = torch.sin(x)
            x = torch.abs(x)
            return torch.cos(x)

        def bad_fn(x):
            torch._dynamo.graph_break()
            return x

        gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3))

        # replace abs with graph break
        gm_edit = copy.deepcopy(gm)
        for nd in gm_edit.graph.nodes:
            if nd.target == torch.abs:
                nd.target = bad_fn
                nd.meta.clear()
                break
        gm_edit.recompile()

        # Snippets expected in the graphs handed to the backend, in call order.
        expected = [
            """x = torch.sin(l_x_)""",
            """cos = torch.cos(l_stack0_)""",
        ]

        def test_backend(gm: torch.fx.GraphModule, example_inputs):
            # Fails if the backend is invoked more times than there are
            # expected snippets.
            self.assertTrue(expected)
            # Normalize output for dynamic and not
            for nd in gm.graph.nodes:
                if "example_value" in nd.meta:
                    del nd.meta["example_value"]
            self.assertIn(expected[0], gm.print_readable(print_output=False))
            expected.pop(0)
            return gm.forward

        torch._dynamo.reset()
        opt_gm_edit = torch.compile(gm_edit, backend=test_backend)
        # NOTE(review): nothing after this call asserts that `expected` was
        # fully consumed, so fewer backend invocations than expected would not
        # fail the test - confirm whether that is intended.
        opt_gm_edit(torch.randn(3, 3))

    def test_torch_inference_mode_ctx(self):
        # Covers three ways of using torch.inference_mode with export:
        #   1. as a decorator with the default mode,
        #   2. as a decorator with mode False,
        #   3. as a context manager inside the function body.
        # In each case the generated code is pinned and the requires_grad
        # behaviour of the exported graph's output is checked.
        @torch.inference_mode()
        def fn(x):
            return x + 1

        gm, _ = torch._dynamo.export(fn, torch.rand(2, 2))

        inp = torch.randn(2, 2)
        out = gm(inp)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_args_0_ = arg0
    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
    add = l_args_0_ + 1;  l_args_0_ = None
    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode);  _enter_inference_mode = _exit_inference_mode = None
    return pytree.tree_unflatten([add], self._out_spec)""",  # NOQA: B950
        )
        self.assertEqual(out.requires_grad, False)
        # The output was produced under inference mode, so enabling
        # requires_grad on it afterwards is expected to raise.
        with self.assertRaisesRegex(
            RuntimeError,
            "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.",
        ):
            out.requires_grad = True

        @torch.inference_mode(False)
        def fn_no_inference(x):
            return x + 1

        gm_no_inference, _ = torch._dynamo.export(fn_no_inference, torch.rand(2, 2))
        self.assertExpectedInline(
            gm_no_inference.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_args_0_ = arg0
    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(False)
    add = l_args_0_ + 1;  l_args_0_ = None
    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode);  _enter_inference_mode = _exit_inference_mode = None
    return pytree.tree_unflatten([add], self._out_spec)""",  # NOQA: B950
        )

        inp = torch.randn(2, 2)
        out = gm_no_inference(inp)
        self.assertEqual(out.requires_grad, False)
        # With inference mode disabled this assignment must not raise.
        out.requires_grad = True

        def fn(x):
            with torch.inference_mode():
                return x + 1

        gm, _ = torch._dynamo.export(fn)(torch.rand(2, 2))
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
    add = l_x_ + 1;  l_x_ = None
    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode);  _enter_inference_mode = _exit_inference_mode = None
    return pytree.tree_unflatten([add], self._out_spec)""",  # NOQA: B950
        )
        # Even when the input requires grad, the output must not.
        inp = torch.randn(2, 2, requires_grad=True)
        out = gm(inp)
        self.assertEqual(out.requires_grad, False)

    def test_export_masking_with_no_grad(self):
        # Boolean-mask assignment (x[b] = y) on tensors that require grad:
        # export is expected to succeed, with pinned code, when the assignment
        # runs under torch.no_grad() or torch.inference_mode(), and to raise
        # Unsupported when it runs without either.
        def fn(x, b, y):
            x = x.clone()
            x[b] = y
            return x

        def fn_no_grad(x, b, y):
            with torch.no_grad():
                return fn(x, b, y)

        def fn_inference_mode(x, b, y):
            with torch.inference_mode():
                return fn(x, b, y)

        x = torch.randn(4, requires_grad=True)
        b = torch.tensor([True, False, True, False])
        y = torch.randn(2, requires_grad=True)

        gm, _ = torch._dynamo.export(fn_no_grad)(x, b, y)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x, b, y):
    arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec)
    l_x_ = arg0
    l_b_ = arg1
    l_y_ = arg2
    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
    x = l_x_.clone();  l_x_ = None
    x[l_b_] = l_y_;  setitem = x;  l_b_ = l_y_ = setitem = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
    return pytree.tree_unflatten([x], self._out_spec)""",
        )

        gm, _ = torch._dynamo.export(fn_inference_mode)(x, b, y)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x, b, y):
    arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec)
    l_x_ = arg0
    l_b_ = arg1
    l_y_ = arg2
    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
    x = l_x_.clone();  l_x_ = None
    x[l_b_] = l_y_;  setitem = x;  l_b_ = l_y_ = setitem = None
    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode);  _enter_inference_mode = _exit_inference_mode = None
    return pytree.tree_unflatten([x], self._out_spec)""",  # NOQA: B950
        )

        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported, "boolean masking setitem backwards"
        ):
            gm, _ = torch._dynamo.export(fn)(x, b, y)

    def test_dynamo_list_index(self):
        # list.index on a Python list input: in_list.index(2) is 1 for [1, 2],
        # so the exported graph must return the tensor input plus 1.
        def fn(x, in_list):
            return x + in_list.index(2)

        inputs = (torch.ones(2, 2), [1, 2])
        graph, _ = torch._dynamo.export(fn)(*inputs)
        out = graph(*inputs)
        self.assertEqual(out, torch.ones(2, 2) + 1)


# Expand any parametrized tests declared on ExportTests.
common_utils.instantiate_parametrized_tests(ExportTests)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()