# Owner(s): ["module: dynamo"]
import abc
import collections
import copy
import dataclasses
import dis
import enum
import functools
import gc
import itertools
import logging
import math
import operator
import os
import random
import sys
import tempfile
import threading
import traceback
import typing
import unittest
import unittest.mock as mock
import warnings
import weakref
from unittest.mock import patch

import numpy as np

import torch
import torch._dynamo.testing

import torch._inductor.test_case
import torch.onnx.operators

import torch.utils._pytree as pytree
import torch.utils.cpp_extension
from torch import Tensor
from torch._C import FileCheck
from torch._dynamo import allow_in_graph
from torch._dynamo.eval_frame import _debug_get_cache_entry_list
from torch._dynamo.exc import Unsupported
from torch._dynamo.source import ConstantSource, GetItemSource, LocalSource
from torch._dynamo.testing import (
    CompileCounter,
    CompileCounterWithBackend,
    expectedFailureDynamic,
    same,
    skipIfNotPy311,
    unsupported,
    xfailIfPy312,
)
from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault
from torch._inductor.utils import run_and_get_code
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx
from torch.fx.experimental.recording import NotEqualError, replay_shape_env_events
from torch.fx.experimental.symbolic_shapes import (
    _constrain_range_for_size,
    constrain_range,
    constrain_unify,
    ConstraintViolationError,
    expect_true,
    guard_size_oblivious,
    ShapeEnv,
)
from torch.nn import functional as F
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    SM80OrLater,
    TEST_CUDA,
    TEST_MULTIGPU,
)
from torch.testing._internal.common_methods_invocations import (
    sample_inputs_take_along_dim,
)
from torch.testing._internal.common_utils import (
    freeze_rng_state,
    IS_FBCODE,
    set_default_dtype,
    wrapDeterministicFlagAPITest,
)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.logging_utils import logs_to_string

mytuple = collections.namedtuple("mytuple", ["a", "b", "ab"])
T = typing.TypeVar("T")


# Specializes a test to run only if translation validation is set.
def onlyIfTranslationValidation(fn: typing.Callable) -> typing.Callable:
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        import torch.fx.experimental.validator

        if torch.fx.experimental.validator.translation_validation_enabled():
            return fn(*args, **kwargs)
        raise unittest.SkipTest("only works when TV is True.")

    return wrapper


def cleanup_op(opname):
    ns, name = opname.split("::")
    if not hasattr(torch.ops, ns):
        return
    actual_ns = getattr(torch.ops, ns)
    if not hasattr(actual_ns, name):
        return
    delattr(actual_ns, name)


class MyPickledModule(torch.nn.Module):
    def __init__(self, z):
        super().__init__()
        self.z = z

    def forward(self, x, y):
        return x * x * x + y + self.z


# These are used for test_{cond/map}_with_quantization
default_symmetric_fake_quant = FakeQuantize.with_args(
    observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
)
default_weight_symmetric_fake_quant = FakeQuantize.with_args(
    observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
)
uniform_qconfig_8bit = QConfig(
    activation=default_symmetric_fake_quant,
    weight=default_weight_symmetric_fake_quant.with_args,
)
qconfig_dict = {"object_type": [(torch.nn.Linear, uniform_qconfig_8bit)]}


def closure_adder(val):
    def inner(x):
        return torch.sin(x + val)

    return inner


class UserDefineSetAttr:
    setup = False

    def __setattr__(self, key, value):
        assert torch.compiler.is_dynamo_compiling() or UserDefineSetAttr.setup
        super().__setattr__(f"pfx_{key}", value)

    def __getattr__(self, key, c=1):
        assert torch.compiler.is_dynamo_compiling() or UserDefineSetAttr.setup
        # c is added to force a guard on __defaults__ and checks the source for __getattr__
        if c:
            return self.__dict__[f"pfx_{key}"]
        else:
            return None


class MiscTests(torch._inductor.test_case.TestCase):
    def test_get_cache_entry(self):
        def f(x):
            return x + 1

        torch.compile(f)(torch.randn(5, 5, 5))
        entries = _debug_get_cache_entry_list(f)
        self.assertTrue(len(entries) > 0)

        def g(x):
            return x + 2

        entries = _debug_get_cache_entry_list(g)
        self.assertTrue(len(entries) == 0)

        try:
            _debug_get_cache_entry_list(1)
        except TypeError as e:
            self.assertIn("expected a code object!", str(e))

        # test get cache entry on skipped code object
        def h(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 1

        torch.compile(h)(torch.randn(3, 3))

        entries = _debug_get_cache_entry_list(torch._dynamo.graph_break)
        self.assertEqual(len(entries), 0)

    def test_boolarg(self):
        def boolarg(aa, bb, flag):
            if flag:
                return aa - bb
            else:
                return bb - aa

        a = torch.randn(10, 10)
        b = torch.randn(10, 10)
        correct1 = boolarg(a, b, True)
        correct2 = boolarg(a, b, False)
        correct3 = boolarg(a, b, None)
        counter = CompileCounter()
        opt_boolarg = torch._dynamo.optimize_assert(counter)(boolarg)
        val1 = opt_boolarg(a, b, True)
        val2 = opt_boolarg(a, b, False)
        val3 = opt_boolarg(a, b, None)
        val4 = opt_boolarg(a, b, True)
        self.assertTrue(same(val1, correct1))
        self.assertTrue(same(val2, correct2))
        self.assertTrue(same(val3, correct3))
        self.assertTrue(same(val4, correct1))
        self.assertEqual(counter.frame_count, 3)

    def test_invalid_args_builtin(self):
        @torch.compile(backend="eager")
        def fn(x):
            x = x.sin()
            if isinstance(x, torch.Tensor, invalid=True):
                x = x.sin()
            return x

        with self.assertRaises(TypeError):
            fn(torch.randn(16))

    def test_cpp_extension_recommends_custom_ops(self):
        cpp_source = """
        #include <torch/extension.h>
        at::Tensor foobar(const at::Tensor& x) {
            return x.clone();
        }
        """
        module = torch.utils.cpp_extension.load_inline(
            name="mylib",
            cpp_sources=cpp_source,
            functions="foobar",
            verbose=True,
        )

        x = torch.ones(2, 2, requires_grad=True)
        counters.clear()

        @torch.compile(backend="eager")
        def f(x):
            return module.foobar(x)

        with self.assertWarnsOnceRegex(
            UserWarning,
            ".*https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html.*",
        ):
            f(x)
        self.assertEqual(len(counters["graph_break"]), 1)
        first_graph_break = list(counters["graph_break"].keys())[0]
        self.assertExpectedInline(
            first_graph_break,
            """Graph break due to unsupported builtin mylib.PyCapsule.foobar. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.""",
        )

        cpp_source = """
        #include <torch/extension.h>
        at::Tensor baz(const at::Tensor& x) {
            return x.clone();
        }
        """
        module2 = torch.utils.cpp_extension.load_inline(
            name="mylib2",
            cpp_sources=cpp_source,
            functions="baz",
            verbose=True,
        )

        torch._dynamo.reset()

        # Test that each warning only happens once
        @torch.compile(backend="eager")
        def f(x):
            module2.baz(x)
            module.foobar(x)
            module.foobar(x)
            module2.baz(x)
            module.foobar(x)
            module2.baz(x)
            return x.clone()

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")
            f(x)
            f(x)
        self.assertEqual(len(ws), 2)

    def test_callpacked(self):
        def call_packed(args):
            a, b, c = args
            return a - b * c

        counter = CompileCounter()
        a = torch.randn(10, 10)
        b = torch.randn(10, 10)
        c = torch.randn(10, 10)
        correct = call_packed([a, b, c])
        opt_call_packed = torch._dynamo.optimize_assert(counter)(call_packed)
        val1 = opt_call_packed([a, b, c])
        val2 = opt_call_packed((a, b, c))
        val3 = opt_call_packed([a, b, c])
        val4 = opt_call_packed((a, b, c))
        self.assertTrue(same(val1, correct))
        self.assertTrue(same(val2, correct))
        self.assertTrue(same(val3, correct))
        self.assertTrue(same(val4, correct))
        self.assertEqual(counter.frame_count, 2)

    def test_raises(self):
        def fn(a, b, c, cls):
            x = a + b - c * 10
            raise cls(str(x))

        counter = CompileCounter()
        a = torch.randn(10, 10)
        b = torch.randn(10, 10)
        c = torch.randn(10, 10)
        opt_fn = torch._dynamo.optimize(counter)(fn)
        self.assertRaises(AssertionError, lambda: opt_fn(a, b, c, AssertionError))
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 3)

    def test_module_not_callable(self):
        def fn(x):
            return torch.fft(x)

        counter = CompileCounter()
        a = torch.randn(10, 10)
        opt_fn = torch._dynamo.optimize(counter)(fn)
        self.assertRaisesRegex(
            TypeError, "'module' object is not callable", lambda: opt_fn(a)
        )

    def test_inplace(self):
        def inplace1(a, b):
            o = torch.empty((10, 10))
            o.copy_(a)
            o -= b
            return o

        torch._dynamo.testing.standard_test(self, inplace1, 2, expected_ops=3)

    def test_inplace_desugaring(self):
        def inplace_on_literals(y):
            x0 = 1
            x0 += y
            x1 = 1
            x1 -= y
            return x0, x1

        torch._dynamo.testing.standard_test(
            self, inplace_on_literals, 1, expected_ops=2
        )

    def test_unpack4(self):
        def unpack4(a, b):
            a = a[:5, :]
            b = b[:5, :]
            x, y = a.size()
            o = torch.empty((x, y))
            o.copy_(a / b)
            return o

        torch._dynamo.testing.standard_test(
            self,
            unpack4,
            2,
            expected_ops=5,
            expected_ops_dynamic=ifdynstaticdefault(5, 7),
        )

    def test_unpack5(self):
        def unpack5(a, b):
            a = a[:5, :]
            b = b[:5, :]
            x, y = a.shape
            o = torch.empty((x, y))
            o.copy_(a / b)
            return o

        torch._dynamo.testing.standard_test(
            self,
            unpack5,
            2,
            expected_ops=5,
            expected_ops_dynamic=ifdynstaticdefault(5, 7),
        )

    def test_matmul1(self):
        def matmul_op1(a, b):
            return a @ b

        # TODO(jansel): FX doesn't support this, should add upstream support
        torch._dynamo.testing.standard_test(self, matmul_op1, 2, expected_ops=1)

    def test_int_shape_binops(self):
        def fn(x):
            # Test reversal by putting int arg first.
            y = 15 - x.shape[0]
            y = 4 + y
            y = 5 * y
            y = 2 % y
            y = 3**y
            y = 10 // y
            y = pow(2, y)
            y = 10 / y
            return x + y

        torch._dynamo.testing.standard_test(
            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 11)
        )

    @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True)
    def test_pt2_compliant_ops_are_allowed(self):
        lib = torch.library.Library("mylib", "FRAGMENT")
        try:
            torch.library.define(
                "mylib::bar",
                "(Tensor x) -> Tensor",
                lib=lib,
                tags=(torch.Tag.pt2_compliant_tag,),
            )
            torch.library.impl(
                "mylib::bar", "CompositeImplicitAutograd", torch.sin, lib=lib
            )
            assert torch.Tag.pt2_compliant_tag in torch.ops.mylib.bar.default.tags

            def f(x):
                return torch.ops.mylib.bar(x)

            overload = torch.ops.mylib.bar.default

            def g(x):
                return overload(x)

            x = torch.randn(3)

            counts = torch._dynamo.testing.CompileCounter()
            optimized_f = torch._dynamo.optimize(counts, nopython=True)(f)
            _ = optimized_f(x)

            optimized_g = torch._dynamo.optimize(counts, nopython=True)(g)
            _ = optimized_g(x)
        finally:
            cleanup_op("mylib::bar")
            del lib

    @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True)
    def test_non_pt2_compliant_ops_graph_break(self):
        lib = torch.library.Library("mylib", "FRAGMENT")
        try:
            torch.library.define("mylib::bar2", "(Tensor x) -> Tensor", lib=lib)
            torch.library.impl(
                "mylib::bar2", "CompositeImplicitAutograd", torch.sin, lib=lib
            )
            assert torch.Tag.pt2_compliant_tag not in torch.ops.mylib.bar2.default.tags

            def f(x):
                return torch.ops.mylib.bar2(x)

            overload = torch.ops.mylib.bar2.default

            def g(x):
                return overload(x)

            x = torch.randn(3)

            counts = torch._dynamo.testing.CompileCounter()
            with self.assertRaisesRegex(
                torch._dynamo.exc.Unsupported, "not PT2 compliant"
            ):
                optimized_f = torch._dynamo.optimize(counts, nopython=True)(f)
                y = optimized_f(x)

            with self.assertRaisesRegex(
                torch._dynamo.exc.Unsupported, "not PT2 compliant"
            ):
                optimized_g = torch._dynamo.optimize(counts, nopython=True)(g)
                y = optimized_g(x)
        finally:
            cleanup_op("mylib::bar2")
            del lib

    @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True)
    def test_pt2_compliant_overload(self):
        lib = torch.library.Library("mylib", "FRAGMENT")
        try:
            torch.library.define(
                "mylib::bar3.tensor",
                "(Tensor x) -> Tensor",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )
            torch.library.define(
                "mylib::bar3.int", "(Tensor x, int dim) -> Tensor", lib=lib
            )

            torch.library.impl(
                "mylib::bar3.tensor",
                "CompositeImplicitAutograd",
                torch.sin,
                lib=lib,
            )
            torch.library.impl(
                "mylib::bar3.int", "CompositeImplicitAutograd", torch.sum, lib=lib
            )

            def f(x):
                return torch.ops.mylib.bar3(x)

            def g(x):
                return torch.ops.mylib.bar3(x, 1)

            def h(x):
                return torch.ops.mylib.bar3(x, x, x)

            x = torch.randn(3)

            counts = torch._dynamo.testing.CompileCounter()
            optimized_f = torch._dynamo.optimize(counts, nopython=True)(f)
            optimized_g = torch._dynamo.optimize(counts, nopython=True)(g)
            optimized_h = torch._dynamo.optimize(counts, nopython=True)(h)

            # No error: the overload is PT2 compliant
            optimized_f(x)

            with self.assertRaisesRegex(
                torch._dynamo.exc.Unsupported, "not PT2 compliant"
            ):
                y = optimized_g(x)

            # graph break on incorrect parsing
            with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "failed to"):
                y = optimized_h(x)

        finally:
            cleanup_op("mylib::bar3")
            del lib

    def test_auto_functionalize_can_with_default(self):
        lib = torch.library.Library("mylib", "FRAGMENT")
        torch.library.define(
            "mylib::foo",
            "(Tensor a, int b, Tensor(d!)? c=None, Tensor? d=None, int e=-1) -> ()",
            tags=torch.Tag.pt2_compliant_tag,
            lib=lib,
        )

        @torch.library.impl("mylib::foo", "cpu", lib=lib)
        def foo_impl(a, b, c=None, d=None, e=-1):
            a + b
            return

        def f(a, mode):
            return torch.ops.mylib.foo(
                a,
                0,
            )

        a = torch.tensor([10, 10, 10], dtype=torch.int64)

        torch.compile(f)(a, 0)

        cleanup_op("mylib::foo")
        del lib

    def test_user_defined_setattr1(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(obj):
            obj.y = obj.x + 1

        obj = UserDefineSetAttr()
        with patch.object(UserDefineSetAttr, "setup", True):
            obj.x = torch.randn(8)
        fn(obj)
        with patch.object(UserDefineSetAttr, "setup", True):
            self.assertEqual(obj.y, obj.x + 1)
        self.assertEqual(obj.__dict__.keys(), {"pfx_x", "pfx_y"})

    def test_user_defined_setattr2(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            obj = UserDefineSetAttr()
            obj.x = x
            obj.y = obj.x + 1
            return obj

        x = torch.randn(8)
        obj = fn(x)
        with patch.object(UserDefineSetAttr, "setup", True):
            self.assertIs(obj.x, x)
            self.assertEqual(obj.y, x + 1)
        self.assertEqual(obj.__dict__.keys(), {"pfx_x", "pfx_y"})

    def test_closure_recompiles(self):
        cnt = CompileCounter()

        def fn(x, other_fn):
            return other_fn(x + 1) - 1

        opt = torch.compile(fn, backend=cnt, fullgraph=True)

        x = torch.randn(8)
        for f in (
            closure_adder(5),
            closure_adder(5),
            closure_adder(torch.randn(8)),
            closure_adder(torch.randn(8)),
        ):
            self.assertEqual(opt(x, f), fn(x, f))

        self.assertEqual(cnt.frame_count, 2)

    def test_generate_trivial_abstract_impl(self):
        try:
            lib = torch.library.Library("mylib", "FRAGMENT")
            torch.library.define(
                "mylib::foo",
                "(Tensor x, Tensor[] y, Tensor(a!)? z, SymInt w) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w):
                x + y[0] + w
                return

            def f(x, y, z, w):
                return torch.ops.mylib.foo(x, y, z, 2)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            w = torch.randn(3)
            args = (x, y, z, w)

            output = torch.compile(f, backend="eager", fullgraph=True)(*args)
            self.assertEqual(output, None)
        finally:
            cleanup_op("mylib::foo")
            del lib

    def test_can_auto_functionalize(self):
        from torch._higher_order_ops.auto_functionalize import can_auto_functionalize

        expected_true = [
            "(Tensor(a!) x) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
            "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)",
        ]
        expected_false = [
            "(Tensor x) -> ()",
            "(Tensor(a) x) -> Tensor(a)",
            "(Tensor(a!) x) -> Tensor(a!)",
            "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
            "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor[])",
        ]
        for schema in expected_true:
            try:
                lib = torch.library.Library("mylib", "FRAGMENT")
                torch.library.define("mylib::a", schema, lib=lib)
                self.assertTrue(
                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
                )
                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))
            finally:
                cleanup_op("mylib::a")
                del lib
        for schema in expected_false:
            try:
                lib = torch.library.Library("mylib", "FRAGMENT")
                torch.library.define("mylib::a", schema, lib=lib)
                self.assertFalse(
                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
                )
                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))
            finally:
                cleanup_op("mylib::a")
                del lib

    def test_auto_functionalize(self):
        try:
            lib = torch.library.Library("mylib", "FRAGMENT")
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)

            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)

            post_grad_graphs = "\n".join(
                log_stream.getvalue().strip().split("\n")[3:]
            ).strip()

            # Check the graph under static shapes
            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
        return ()""",
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
        finally:
            cleanup_op("mylib::foo")
            del lib

    def test_auto_functionalize_with_returns(self):
        try:
            lib = torch.library.Library("mylib", "FRAGMENT")
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return y[0] + w, y[1] + n

            @torch.library.impl_abstract("mylib::foo", lib=lib)
            def foo_abstract(x, y, z, w, n):
                return y[0] + w, y[1] + n

            def f(x, y, z, n):
                return torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                compiled_out = torch.compile(f, backend="inductor", fullgraph=True)(
                    *compiled_args
                )

            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
        getitem_4: "f32[3][1]cpu" = foo_default[0]
        getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
        return (getitem_4, getitem_5)""",
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            eager_out = f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
            self.assertEqual(compiled_out, eager_out)
        finally:
            cleanup_op("mylib::foo")
            del lib

    def test_auto_functionalize_on_view(self):
        try:
            lib = torch.library.Library("mylib", "FRAGMENT")
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x):
                x_np = x.detach().numpy()  # view
                np.sin(x_np, out=x_np)
                return

            x = torch.randn(3)
            expected = x.sin()
            torch.ops.mylib.foo(x)
            assert torch.allclose(x, expected)

            @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True)
            def f(x):
                x = x.clone()
                y = x[:]
                torch.ops.mylib.foo(y)
                return x

            y = f(x)
            self.assertEqual(y, x.sin())
        finally:
            cleanup_op("mylib::foo")
            del lib

    def test_auto_functionalize_optional(self):
        try:
            lib = torch.library.Library("mylib", "FRAGMENT")
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                if x is not None:
                    x.add_(y[0] + w)
                if z is not None:
                    z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = None
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)

            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
        return ()""",
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
        finally:
            cleanup_op("mylib::foo")
            del lib

    def test_shape_int_inplace_binops(self):
        def fn(x):
            p = x.shape[0]
            p += 2
            p -= 2
            p **= 2
            p /= 2
            p *= 2
            p //= 2
            p %= 2
            return x + p

        torch._dynamo.testing.standard_test(
            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 10)
        )

    def test_int_shape_inplace_binops(self):
        def fn(x):
            p = x.shape[0]
            # Test reversal by putting constant first
            y = 2
            y += p
            y = 2
            y -= p
            y = 2
            y **= p
            y = 2
            y /= p
            y = 2
            y *= p
            y = 2
            y //= p
            y = 2
            y %= p
            return x + y

        torch._dynamo.testing.standard_test(
            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 4)
        )

    def test_int_int_comparisons(self):
        def fn(x):
            if 2 != 2:
                out = 1
            elif 2 < 1:
                out = 1
            elif 1 > 2:
                out = 1
            elif 1 >= 2:
                out = 1
            elif 2 <= 1:
                out = 1
            elif 2 == 2:
                out = 2
            else:
                out = 1
            return x + out

        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)

    def test_shape_int_comparisons(self):
        def fn(x):
            a = x.shape[0]
            # Ensure support for constant on right side
            if a != 10:
                out = 1
            elif a < 2:
                out = 1
            elif a > 12:
                out = 1
            elif a >= 12:
                out = 1
            elif a <= 2:
                out = 1
            elif a == 10:
                out = 2
            else:
                out = 1
            return x + out

        # TODO: Test the guards maybe?
        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)

    def test_int_shape_comparisons(self):
        def fn(x):
            a = x.shape[0]
            # Ensure support for constant on left side
            if 10 != a:
                out = 1
            elif 12 < a:
                out = 1
            elif 2 > a:
                out = 1
            elif 2 >= a:
                out = 1
            elif 12 <= a:
                out = 1
            elif 10 == a:
                out = 2
            else:
                out = 1
            return x + out

        # TODO: Test the guards maybe?
        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)

    def test_param_shape_binops(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.randn(15))

            def forward(self, x):
                # Test reversal by putting param shape arg first.
                p = self.param.shape[0]
                y = p - x.shape[0]
                y = p + y
                y = p * y
                y = p % y
                y = p**y
                y = p // y
                y = pow(p, y)
                y = p / y
                return x + y

        counts = torch._dynamo.testing.CompileCounter()
        mod = MyModule()
        optimized_mod = torch._dynamo.optimize(counts, nopython=True)(mod)

        x = torch.randn(3)
        ref = mod(x)
        res = optimized_mod(x)

        self.assertTrue(same(ref, res))
        self.assertEqual(counts.frame_count, 1)

        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(counts.op_count, """1""")
        else:
            self.assertExpectedInline(counts.op_count, """11""")

    def test_user_defined_binop(self):
        class MyClass:
            def __init__(self, value):
                self.value = value

            def __radd__(self, other):
                return self.value + other

        def fn(x, c):
            y = x.shape[0] + c
            return x + y

        counts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(counts)(fn)

        x = torch.randn(3)
        c = MyClass(4)
        ref = fn(x, c)
        res = opt_fn(x, c)

        self.assertTrue(same(ref, res))
        self.assertEqual(counts.frame_count, 1)
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(counts.op_count, """1""")
        else:
            self.assertExpectedInline(counts.op_count, """4""")

    def test_user_defined_iter(self):
        class Mod:
            def __init__(self):
                self.a = [torch.randn(2, 2), torch.randn(2, 2)]

            def __iter__(self):
                return iter(self.a)

        def f(mod):
            ret = []
            for x in mod:
                ret.append(x + 1)
            return ret

        mod = Mod()
        counts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(counts, nopython=True)(f)
        ref = f(mod)
        res = opt_fn(mod)
        res = opt_fn(mod)
        res = opt_fn(mod)
        res = opt_fn(mod)
        self.assertTrue(same(ref, res))
        self.assertEqual(counts.frame_count, 1)

        mod.a.append(torch.randn(2, 2))
        # `for x in mod` is inlined, where iter(m.a) creates a guard on the list length of m.a
        # Mutating length of mod.a causes a re-compilation.
        ref2 = f(mod)
        res2 = opt_fn(mod)
        res2 = opt_fn(mod)
        res2 = opt_fn(mod)
        res2 = opt_fn(mod)
        self.assertTrue(same(ref2, res2))
        self.assertEqual(counts.frame_count, 2)

    def test_compare_shapes_eq(self):
        def compare_shapes(a, b, to_list):
            x = list(a.unsqueeze(-1).shape) if to_list else a.shape
            y = list(b.unsqueeze(-1).shape) if to_list else b.shape
            if x == y:
                return a + 1
            else:
                return a + 2

        # Test both ListVariable and ShapeVariable
        torch._dynamo.testing.standard_test(
            self, lambda a, b: compare_shapes(a, b, to_list=True), 2
        )
        torch._dynamo.testing.standard_test(
            self, lambda a, b: compare_shapes(a, b, to_list=False), 2
        )

    def test_compare_shapes_tuple_eq(self):
        def compare_shapes(a, b):
            x = tuple(a.unsqueeze(-1).shape)
            y = tuple(b.unsqueeze(-1).shape)
            if x == y:
                return a + 1
            else:
                return a + 2

        torch._dynamo.testing.standard_test(self, lambda a, b: compare_shapes(a, b), 2)

    def test_compare_shapes_tuple_neq(self):
        def compare_shapes(a, b):
            x = tuple(a.unsqueeze(-1).shape)
            y = tuple(b.unsqueeze(-1).shape)
            if x != y:
                return a + 1
            else:
                return a + 2

        torch._dynamo.testing.standard_test(self, lambda a, b: compare_shapes(a, b), 2)

    def test_compare_shapes_neq(self):
        def compare_shapes(a, b, to_list):
            x = list(a.unsqueeze(-1).shape) if to_list else a.shape
            y = list(b.unsqueeze(-1).shape) if to_list else b.shape
            if x != y:
                return a + 1
            else:
                return a + 2

        # Test both ListVariable and ShapeVariable
        torch._dynamo.testing.standard_test(
            self, lambda a, b: compare_shapes(a, b, to_list=True), 2
        )
        torch._dynamo.testing.standard_test(
            self, lambda a, b: compare_shapes(a, b, to_list=False), 2
        )

    def test_compare_shapes_with_constant(self):
        def compare_shapes(a):
            x = a.shape
            if x[0] != 3:
                return a * 4
            return a * 3

        guard_failure = None

        def guard_failures(failure):
            nonlocal guard_failure
            guard_failure = failure

        opt_fn = torch._dynamo.optimize(
            "eager", nopython=True, guard_fail_fn=guard_failures
        )(compare_shapes)
        opt_fn(torch.randn([3, 4]))
        opt_fn(torch.randn([4, 3]))
        self.assertIn(
            """tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""",
            guard_failure.reason,
        )

    def test_builtin_abs(self):
        def fn(x, y):
            return abs(x) + abs(y)

        sample = torch.randn(10, 10)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)

        for sample in [
            (torch.randn(10, 10), torch.randn(10, 10)),
            (-10, make_tensor(10, dtype=torch.int64, device="cpu")),
            (-0.1, torch.randn(10)),
        ]:
            expect = fn(*sample)
            actual = opt_fn(*sample)
            self.assertEqual(expect, actual)

    def test_builtin_isinstance(self):
        def fn(x):
            t = torch.arange(1, 3)
            a = isinstance(x, torch.Tensor)
            b = isinstance(t, torch.Tensor)
            c = isinstance(x, int)
            d = isinstance(3, int)
            e = isinstance([1, 2, 3], list)
            f = isinstance({"foo": 1, "bar": 2}, dict)
            res = [a, b, c, d, e, f]
            # Can't run yet due to other unimplemented instructions
            # res += [isinstance(torch.nn.LazyLinear(2, 3), torch.nn.Linear)]
            return res

        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)

    @unittest.skipIf(sys.version_info[:2] <= (3, 8), "Requires astunparse")
    def test_cse_dict_guards(self):
        def fn(x):
            ret = torch.zeros(3)
            for v in x.values():
                ret = ret + v
            return ret

        from torch._dynamo.guards import build_guard_function, CLOSURE_VARS

        x = {3: torch.randn(3), 2: torch.randn(3), 4: torch.randn(3)}
        _, guards = torch._dynamo.export(fn, x)

        code_lists = [c for g in guards for c in g.code_list or []]
        _, pycode = build_guard_function(code_lists, [])
        # Make sure we just call "list(dict.keys())" once
        self.assertEqual(pycode.count("keys"), 1)

    def test_sys_modules(self):
        def fn(x, y):
            mod_a = sys.modules.get("aaaaaaaa")
            assert mod_a is None
            assert "bbbbbbbb" not in sys.modules

            assert "operator" in sys.modules
            operator = sys.modules["operator"]
            builtins = sys.modules.get("builtins")
            operator2 = sys.modules.get("cccccccc", operator)

            return operator.add(x, y), operator2.neg(builtins.abs(x))

        torch._dynamo.testing.standard_test(self, fn, 2, expected_ops=3)

        x = torch.randn(10, 10)
        _, guards = torch._dynamo.export(fn, x, x)
        guard_code = []
        for guard in guards:
            if guard.code_list:
                guard_code += guard.code_list

        # Filter out id-matches that won't reproduce run to run
        guard_code = filter(
            lambda line: "id" not in line and "lookup_backend" not in line,
            sorted(guard_code),
        )
        guard_code_str = "\n".join(guard_code)

        for line in """\
2 <= L['x'].size()[0]
L['x'] is L['y']
L['x'].ndimension() == 2
L['x'].requires_grad == False
L['x'].size()[1] == L['x'].size()[0]
L['x'].storage_offset() == 0
___dict_contains('builtins', G['sys'].modules)
___dict_contains('operator', G['sys'].modules)
___dict_contains('operator', G['sys'].modules)
hasattr(L['x'], '_dynamo_dynamic_indices') == False
not ___dict_contains('aaaaaaaa', G['sys'].modules)
not ___dict_contains('bbbbbbbb', G['sys'].modules)
not ___dict_contains('cccccccc', G['sys'].modules)
str(L['x'].device) == 'cpu'
str(L['x'].dtype) == 'torch.float32'
utils_device.CURRENT_DEVICE == None""".split(
            "\n"
        ):
            self.assertIn(
                line,
                guard_code_str,
            )

    def test_fold(self):
        def fn(a):
            return a + math.sqrt(63)

        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)

    def test_getattr_dict(self):
        def fn(x):
            from torch.masked.maskedtensor._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE

            return x * len(_MASKEDTENSOR_FUNCTION_TABLE)

        i = torch.randn(5)
        r1 = fn(i)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        r2 = opt_fn(i)
        self.assertEqual(r1, r2)

    def test_shape_unpack(self):
        def fn(x):
            a, b = x.size()
            return x * b

        i = torch.randn(5, 10)
        r1 = fn(i)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        r2 = opt_fn(i)
        self.assertTrue(same(r1, r2))

    def test_typing_dict(self):
        def fn(d):
            return d[T]

        d = {T: torch.randn(3)}
        r1 = fn(d)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        r2 = opt_fn(d)
        self.assertEqual(r1, r2)

    def test_tensor_iter(self):
        def fn(x):
            for y in x:
                y.add_(1.0)
            return y

        torch._dynamo.testing.standard_test(
            self,
            fn,
            1,
            expected_ops=20,
        )

    def test_empty_list(self):
        def fn(x, ll):
            if len(ll) == 0 and not ll and ll is not None:
                return x + 1

        i = torch.randn(5, 10)
        r1 = fn(i, [])
        opt_fn = torch._dynamo.optimize("eager")(fn)
        r2 = opt_fn(i, [])
        r3 = opt_fn(i, tuple())
        self.assertTrue(same(r1, r2))
        self.assertTrue(same(r1, r3))

    def test_min_max_over_iterable(self):
        def get_test_fn(func):
            def _fn(a, b, func=func):
                # try all of list, iterator, tuple, vararg.
                lst = [a.shape[0] + 1, 8, a.shape[0]]
                x = func(lst)
                y = func(iter(lst))
                z = func(tuple(lst))
                w = func(*lst)
                return a + (x + y + z + w)

            return _fn

        torch._dynamo.testing.standard_test(
            self,
            get_test_fn(func=min),
            2,
            expected_ops=1,
            expected_ops_dynamic=ifdynstaticdefault(1, 14),
        )
        torch._dynamo.testing.standard_test(
            self,
            get_test_fn(func=max),
            2,
            expected_ops=1,
            expected_ops_dynamic=ifdynstaticdefault(1, 17),
        )

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_torch_check(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
        def f(x):
            y = x.item()
            torch._check(y >= 0)
            return torch.arange(0, y)

        f(torch.tensor([3]))
        f(torch.tensor([4]))
        self.assertEqual(cnts.frame_count, 1)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_torch_check_symbolic_shape_rel(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
        def f(x):
            y = x.item()
            torch._check(x.shape[0] == 1)
            torch._check(x.shape[0] != 2)
            torch._check(x.shape[0] >= 0)
            torch._check(x.shape[0] > 0)
            torch._check(x.shape[0] < 4)
            torch._check(x.shape[0] <= 3)
            return torch.arange(0, y)

        f(torch.tensor([3]))
        f(torch.tensor([4]))
        self.assertEqual(cnts.frame_count, 1)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    # Translation validation changes the exception type, don't run with it
    @torch.fx.experimental._config.patch(translation_validation=False)
    def test_torch_check_is_size(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
        def f(x):
            y = x.item()
            torch._check_is_size(y)
            # Cannot conditional on unbacked SymInt
            if y == 0:
                assert False
            else:
                return torch.arange(0, y)

        self.assertRaises(torch._dynamo.exc.UserError, lambda: f(torch.tensor([3])))

    def test_assert(self):
        @torch.compile
        def fn1(x):
            assert x.shape != x.shape

        with self.assertRaises(AssertionError):
            a = torch.randn(10)
            fn1(a)

        def fn2(x):
            assert x.shape == x.shape
            return x.abs()

        torch._dynamo.testing.standard_test(self, fn=fn2, nargs=1, expected_ops=1)

    def test_config_obj(self):
        class Cfg:
            def __init__(self):
                self.val = 0.5
                self.count = 3

        def fn(x, cfg):
            for i in range(cfg.count):
                x = x + cfg.val
            return x

        cfg1 = Cfg()
        cfg1.val = 1.0
        cfg2 = Cfg()
        v = torch.zeros(1)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        v = opt_fn(v, cfg1)  # 3
        v = opt_fn(v, cfg2)  # 4.5
        cfg2.count = 1
        v = opt_fn(v, cfg2)  # 5
        cfg2.val = 2.0
        v = opt_fn(v, cfg2)  # 7
        self.assertEqual(v[0], 7)
        self.assertEqual(cnts.op_count, 8)

    def test_config_getattr_default(self):
        class Cfg:
            def __init__(self):
                self.val = 0.5
                self.count = 10

        def fn(x, cfg):
            if getattr(cfg, "just_add_7", False):
                return x + 7
            for i in range(cfg.count):
                x = x + cfg.val
            return x

        cfg1 = Cfg()
        v = torch.zeros(1)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertEqual(opt_fn(v, cfg1)[0], 5)
        self.assertEqual(opt_fn(v, cfg1)[0], 5)
        cfg1.just_add_7 = True
        self.assertEqual(opt_fn(v, cfg1)[0], 7)
        self.assertEqual(opt_fn(v, cfg1)[0], 7)
        cfg1.just_add_7 = False
        self.assertEqual(opt_fn(v, cfg1)[0], 5)
        self.assertEqual(opt_fn(v, cfg1)[0], 5)
        self.assertEqual(cnts.frame_count, 3)

    def test_size_input(self):
        def fn(x, s):
            a, b = s
            return x + (a - b)

        v = torch.zeros(10, 20)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertEqual(opt_fn(v, v.size())[0, 0], -10)
        self.assertEqual(opt_fn(v, (10, 20))[0, 0], -10)
        self.assertEqual(opt_fn(v, [10, 20])[0, 0], -10)
        # One recompile per differing input type
        self.assertEqual(cnts.frame_count, 3)

    def test_cell_output1(self):
        out = None

        def fn(a, b):
            nonlocal out
            out = a + b * 10

        v = torch.Tensor([100])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertIsNone(opt_fn(v, v))
        self.assertEqual(out[0], 1100)
        self.assertEqual(cnts.op_count, 2)

    def test_cell_output2(self):
        out = None

        def fn(a, b):
            nonlocal out
            c = unsupported(a, b)
            out = a + b * 10 + c

        v = torch.Tensor([100])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertIsNone(opt_fn(v, v))
        self.assertEqual(out[0], 1200)
        self.assertEqual(cnts.op_count, 3)

    def test_return_nested_function(self):
        out = None

        def fn(a, b):
            nonlocal out
            c = a + b
            d = a + 1.0

            def fn2(f: int = 7, g: float = 9.0):
                nonlocal out
                out = a + b * 10
                return c * f - d * g

            return fn2

        v1 = torch.Tensor([100])
        v2 = torch.Tensor([200])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        opt_fn_ret = torch._dynamo.optimize(cnts)(opt_fn(v1, v2))
        self.assertEqual(opt_fn_ret(1.5)[0], -459)
        self.assertEqual(out[0], 2100)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 7)

    def test_tensor_dict1(self):
        def fn(inputs):
            return inputs["a"] - inputs["b"] * 1.5

        v1 = torch.Tensor([100])
        v2 = torch.Tensor([200])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        self.assertEqual(opt_fn({"a": v1, "b": v2})[0], -200)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    def test_tensor_dict3(self):
        def fn(inputs_a, inputs_b):
            total = torch.zeros(1)
            input_keys = inputs_a.keys() | inputs_b.keys()
            for k in input_keys:
                if k in inputs_a:
                    total += inputs_a[k]
                if k in inputs_b:
                    total += inputs_b[k]
            return total

        v1 = torch.Tensor([100])
        v2 = torch.Tensor([200])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        self.assertEqual(
            opt_fn({"a": v1, "b": v2}, {"b": v1, "c": v2}),
            fn({"a": v1, "b": v2}, {"b": v1, "c": v2}),
        )
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 5)

    def test_tensor_dict2(self):
        def fn1(inputs):
            total = torch.zeros(1)
            for k, v in inputs.items():
                total += v
            return total

        def fn2(inputs):
            total = torch.zeros(1)
            for v in inputs.values():
                total += v
            return total

        def fn3(inputs):
            total = torch.zeros(1)
            for k in inputs.keys():
                total += inputs[k]
            return total

        v1 = torch.Tensor([100])
        v2 = torch.Tensor([200])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn1 = torch._dynamo.optimize(cnts, nopython=True)(fn1)
        opt_fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2)
        opt_fn3 = torch._dynamo.optimize(cnts, nopython=True)(fn3)
        self.assertEqual(opt_fn1({"a": v1, "b": v2})[0], 300)
        self.assertEqual(opt_fn2({"a": v1, "b": v2})[0], 300)
        self.assertEqual(opt_fn3({"a": v1, "b": v2})[0], 300)
        self.assertEqual(cnts.frame_count, 3)
        self.assertEqual(cnts.op_count, 9)

    def test_dictcomp(self):
        def fn1(inputs):
            return {k: v + 1 for k, v in inputs.items()}

        v1 = torch.Tensor([100])
        v2 = torch.Tensor([200])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn1 = torch._dynamo.optimize(cnts)(fn1)
        self.assertEqual(opt_fn1({"a": v1, "b": v2})["a"], 101)
        self.assertEqual(opt_fn1({"a": v1, "b": v2})["b"], 201)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    def test_listcomp(self):
        def fn2(inputs):
            return torch.sum(torch.cat([v + 1 for k, v in inputs.items()], 0))

        v1 = torch.Tensor([100])
        v2 = torch.Tensor([200])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn2 = torch._dynamo.optimize(cnts)(fn2)
        self.assertEqual(opt_fn2({"a": v1, "b": v2}), 302)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 4)

    def test_is_floating_point(self):
        def fn(a, b):
            x = a + 1.0
            if torch.is_floating_point(b):
                x = x + b
            return x + 2.0

        return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)

    def test_is_floating_point2(self):
        def fn(a, b):
            x = a + 1.0
            if b.is_floating_point():
                x = x + b
            return x + 2.0

        return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)

    def test_is_tensor(self):
        def fn(a, b):
            x = a + 1.0
            if torch.is_tensor(b):
                x = x + b
            return x + 2.0

        return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)

    def test_is_tensor2(self):
        def fn(x):
            if torch.is_tensor(x):
                return x + 1
            else:
                return torch.ones([2, 3])

        x1 = {"input": torch.rand(2, 3)}
        x2 = torch.rand(2, 3)
        ref1 = fn(x1)
        ref2 = fn(x2)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        res1 = opt_fn(x1)
        res2 = opt_fn(x2)
        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)

    def test_numel(self):
        def fn(a):
            return (a + a.numel() + torch.numel(a), a + a.nelement())

        return torch._dynamo.testing.standard_test(
            self,
            fn=fn,
            nargs=1,
            expected_ops=3,
            expected_ops_dynamic=ifdynstaticdefault(3, 6),
        )

    def test_pair(self):
        def fn(a):
            return (
                torch.zeros(torch.nn.modules.utils._pair(a.size()))
                + a
                + torch.ones(torch.nn.modules.utils._ntuple(3)(3)).sum()
            )

        return torch._dynamo.testing.standard_test(
            self,
            fn=fn,
            nargs=1,
            expected_ops=5,
            expected_ops_dynamic=ifdynstaticdefault(5, 8),
        )

    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_tensor_item_capture(self):
        def fn(a, b):
            return (a + b).sum().item()

        v1 = torch.randn((10, 10))
        v2 = torch.randn((10, 10))
        correct = fn(v1, v2)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertEqual(opt_fn(v1, v2), correct)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 4)

    @patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
    def test_tensor_item_no_capture(self):
        def fn(a, b):
            return (a + b).sum().item()

        v1 = torch.randn((10, 10))
        v2 = torch.randn((10, 10))
        correct = fn(v1, v2)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertEqual(opt_fn(v1, v2), correct)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    def test_namedtuple1(self):
        def fn(a, b):
            tmp = mytuple(a, b, a + b)
            return mytuple(tmp.a, tmp[1], tmp.ab + b)

        v1 = torch.Tensor([10])
        v2 = torch.Tensor([20])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertEqual(opt_fn(v1, v2).ab, 50)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    def test_namedtuple2(self):
        def fn(packed):
            a, b, c = packed
            if hasattr(packed, "b"):
                b = packed.b + 1
            c = packed[2]
            return a + b + c

        v1 = torch.Tensor([1])
        v2 = torch.Tensor([2])
        v3 = torch.Tensor([3])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertEqual(opt_fn(mytuple(v1, v2, v3))[0], 7)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 3)

    def test_namedtuple3(self):
        def fn(x, packed):
            if isinstance(packed, mytuple):
                return x + 1
            else:
                return x - 1

        x = torch.rand([2, 3])
        packed = mytuple(1, 2, 3)
        ref = fn(x, packed)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        res = opt_fn(x, packed)
        self.assertTrue(same(ref, res))

    def test_range_input(self):
        def fn(a, rng):
            x = a
            for i in rng:
                x = x + i
            return x

        def fn1(a):
            return fn(a, rng=range(3))

        return torch._dynamo.testing.standard_test(
            self, fn=fn1, nargs=1, expected_ops=3
        )

    def test_range_with_shape(self):
        def fn(a):
            for i in range(1, a.shape[0]):
                a += 1
            return a

        return torch._dynamo.testing.standard_test(
            self,
            fn=fn,
            nargs=1,
            expected_ops=9,
        )

    def test_build_tuple_unpack(self):
        def fn1(a, b, c):
            return a - b / c

        def fn2(a, b, c):
            tmp1 = (a,)
            tmp2 = (b, c)
            args = (*tmp1, *tmp2)
            return fn1(*args)

        def fn3(a, *args):
            return fn1(a, *args)

        torch._dynamo.testing.standard_test(self, fn=fn2, nargs=3, expected_ops=2)
        torch._dynamo.testing.standard_test(self, fn=fn3, nargs=3, expected_ops=2)

    def test_list_mul(self):
        def fn(count):
            head_mask = count * [None] * count
            return head_mask

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertEqual(opt_fn(2), [None] * 4)
        # TODO: the captured frame here is a bit goofy, because we don't
        # output anything and none of the traced operations have side
        # effects. Probably need better heuristic for bailing on
        # dynamo if there are no outputs
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnts.frame_count, """0""")
            self.assertExpectedInline(cnts.op_count, """0""")
        else:
            self.assertExpectedInline(cnts.frame_count, """1""")
            self.assertExpectedInline(cnts.op_count, """2""")

    def test_list_slice_mul(self):
        def fn(count):
            a = [1, 2, 3]
            head_mask = count * a[1:] * count
            return head_mask

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertEqual(opt_fn(2), [2, 3] * 4)
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnts.frame_count, """0""")
            self.assertExpectedInline(cnts.op_count, """0""")
        else:
            self.assertExpectedInline(cnts.frame_count, """1""")
            self.assertExpectedInline(cnts.op_count, """2""")

    def test_tuple_mul(self):
        def fn(count):
            head_mask = count * (2, 3) * count
            return head_mask

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertEqual(opt_fn(2), (2, 3) * 4)
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnts.frame_count, """0""")
            self.assertExpectedInline(cnts.op_count, """0""")
        else:
            self.assertExpectedInline(cnts.frame_count, """1""")
            self.assertExpectedInline(cnts.op_count, """2""")

    def test_tuple_mul_with_shape(self):
        def fn(a):
            x = a.shape[0]
            y = 2 * (x, 3) * 2
            return a + y[4]

        # expect 3 ops post folding for dynamic case: size, index, add
        torch._dynamo.testing.standard_test(
            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 3)
        )

    def test_tuple_iadd_with_shape(self):
        def fn(a):
            output = (a + a.shape[0], a - a.shape[0])
            # tuple += tuple
            output += (a - a.shape[0], a + a.shape[0])
            # tuple += constant tuple
            output += (2, 3)
            return output

        # expect 4 add / subs for static, 4 * 3 (size, index, math op) for dynamic
        torch._dynamo.testing.standard_test(
            self, fn, 1, expected_ops=4, expected_ops_dynamic=ifdynstaticdefault(4, 12)
        )

    def test_list_iadd_with_shape(self):
        def fn(a):
            output = [a + a.shape[0], a - a.shape[0]]
            # list += list
            output += [a - a.shape[0], a + a.shape[0]]
            # list += tuple
            output += (a + a.shape[0], a - a.shape[0])
            return output

        # expect 6 add / subs for static, 6 * 3 (size, index, math op) for dynamic

        torch._dynamo.testing.standard_test(
            self, fn, 1, expected_ops=6, expected_ops_dynamic=ifdynstaticdefault(6, 18)
        )

    def test_list_iadd_side_effect(self):
        def fn(a, b):
            a += [b]
            torch._dynamo.graph_break()
            return a

        a = [1, 2, 3]
        b = torch.ones(2, 2)

        opt_fn = torch._dynamo.optimize("eager")(fn)

        exp = fn(a, b)

        a = [1, 2, 3]
        b = torch.ones(2, 2)
        act = opt_fn(a, b)

        self.assertEqual(exp, act)

    def test_user_getattr1(self):
        class MyConfig(dict):
            def __getattr__(self, name):
                return self[name]

        def fn(cfg, x, y):
            return x + y + cfg.offset

        x = torch.randn(10)
        cfg = MyConfig(offset=5)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    def test_user_getattr2(self):
        class MyConfig:
            defined_on_class = 1

            def __init__(self):
                self.defined_on_object = 2

            def __getattr__(self, name):
                return 3

        def fn(cfg, x):
            return x + cfg.defined_on_class - cfg.defined_on_object + cfg.not_defined

        x = torch.randn(10)
        cfg = MyConfig()
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertTrue(same(opt_fn(cfg, x), x + 1 - 2 + 3))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 3)

    def test_getset_descriptor(self):
        def fn(g, x):
            return g.__get__(x)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fullgraph=True, backend="eager")(fn)
        g = torch.Tensor.shape

        res = opt_fn(g, torch.ones(2, 2))
        exp_res = fn(g, torch.ones(2, 2))
        self.assertEqual(res, exp_res)

    def test_get_attr_function(self):
        def fn(g, x):
            return g(x)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        g = torch.Tensor.shape.__get__

        res = opt_fn(g, torch.ones(2, 2))
        exp_res = fn(g, torch.ones(2, 2))
        self.assertEqual(res, exp_res)

    def test_user_getattribute(self):
        class MyObject:
            def __init__(self):
                self.custom_dict = {"a": torch.rand((2, 2))}
                self.my_number = 42

            def __getattribute__(self, name):
                custom_dict = super().__getattribute__("custom_dict")
                if name in custom_dict:
                    return custom_dict[name]
                return super().__getattribute__(name)

            def run(self, x):
                return self.my_number * x + self.a * x

        def fn(obj, x):
            return obj.run(x)

        obj = MyObject()
        x = torch.rand((2, 2))
        cnts = torch._dynamo.testing.CompileCounter()
torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertTrue(same(opt_fn(obj, x), fn(obj, x)))

    def test_nn_module_getattr(self):
        class MyMod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
                self.other_attr = torch.rand((2, 2))

            def __getattr__(self, name):
                custom_dict = self.custom_dict
                if name in custom_dict:
                    return custom_dict[name]
                return super().__getattr__(name)

            def forward(self, x):
                return x @ self.other_attr + self.queue[-1]

        x = torch.rand((2, 2))
        mod = MyMod()
        cnts = torch._dynamo.testing.CompileCounter()
        opt_mod = torch._dynamo.optimize(cnts)(mod)
        self.assertTrue(same(opt_mod(x), mod(x)))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    def test_nn_module_getattribute(self):
        class MyMod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.my_number = 42

            def __getattribute__(self, name):
                if name == "special_attr":
                    return torch.tensor([[1, 2], [3, 4]])
                return super().__getattribute__(name)

            def forward(self, x):
                return self.my_number * x + self.special_attr * x

        def fn(mod, x):
            return mod(x)

        mod = MyMod()
        x = torch.rand((2, 2))
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertTrue(same(opt_fn(mod, x), fn(mod, x)))

    def test_constant_getattr(self):
        # https://github.com/pytorch/pytorch/issues/97480
        def fn():
            return getattr(None, "arg", 3)

        cnt = torch._dynamo.testing.CompileCounter()
        optimized_fn = torch._dynamo.optimize(cnt)(fn)
        res = optimized_fn()
        self.assertTrue(same(res, 3))

    def test_user_property(self):
        class MyConfig:
            @property
            def prop5(self):
                return 5

        def fn(cfg, x, y):
            return x + y + cfg.prop5

        x = torch.randn(10)
        cfg = MyConfig()
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    def test_dataclass_fields(self):
        @dataclasses.dataclass
        class MyDataClass:
            a: torch.Tensor
            b: torch.Tensor = None
            c: torch.Tensor = None
            d: torch.Tensor = None
            e: torch.Tensor = None

        def fn(obj):
            class_fields = dataclasses.fields(obj)
            assert len(class_fields)
            assert all(field.default is None for field in class_fields[1:])
            other_fields_are_none = all(
                getattr(obj, field.name) is None for field in class_fields[1:]
            )
            assert not other_fields_are_none

            if not hasattr(obj, "a"):
                return -1
            if hasattr(obj, "z"):
                return -2

            total = getattr(obj, class_fields[0].name)
            for field in class_fields[1:]:
                v = getattr(obj, field.name)
                if v is not None:
                    total += v

            return total

        obj1 = MyDataClass(torch.randn(10), torch.randn(10), torch.randn(10))
        obj2 = MyDataClass(torch.randn(10), e=torch.randn(10))
        correct1 = fn(obj1)
        correct2 = fn(obj2)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertTrue(same(opt_fn(obj1), correct1))
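        # obj1 has b and c populated (d and e stay None), so the traced graph
        # only needs the two tensor additions `total += b` and `total += c`;
        # the dataclasses.fields()/hasattr() bookkeeping is resolved at trace
        # time, which is what the frame/op counts below are checking.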
self.assertEqual(cnts.frame_count, 1) 2149 self.assertEqual(cnts.op_count, 2) 2150 2151 torch._dynamo.reset() 2152 cnts = torch._dynamo.testing.CompileCounter() 2153 opt_fn = torch._dynamo.optimize(cnts)(fn) 2154 self.assertTrue(same(opt_fn(obj2), correct2)) 2155 self.assertEqual(cnts.frame_count, 1) 2156 self.assertEqual(cnts.op_count, 1) 2157 2158 # guard failure 2159 obj2.z = True 2160 self.assertEqual(opt_fn(obj2), -2) 2161 2162 def test_dataclass_local_hasattr(self): 2163 cnt = CompileCounter() 2164 x = torch.randn(10) 2165 2166 @dataclasses.dataclass 2167 class MyDataClass: 2168 a: torch.Tensor 2169 b: torch.Tensor 2170 2171 @torch.compile(backend=cnt, fullgraph=True) 2172 def fn(): 2173 obj = MyDataClass(x + 1, x - 1) 2174 if not hasattr(obj, "a"): 2175 return -1 2176 if hasattr(obj, "z"): 2177 return -2 2178 return obj 2179 2180 result = fn() 2181 self.assertIsInstance(result, MyDataClass) 2182 self.assertEqual(result.a, x + 1) 2183 self.assertEqual(result.b, x - 1) 2184 self.assertEqual(cnt.frame_count, 1) 2185 self.assertEqual(cnt.op_count, 2) 2186 2187 def test_catch_watchings1(self): 2188 cnt = CompileCounter() 2189 2190 @torch.compile(backend=cnt, fullgraph=True) 2191 def fn(x): 2192 with warnings.catch_warnings(record=True): 2193 return x.sin() 2194 2195 x = torch.randn(8) 2196 self.assertEqual(fn(x), x.sin()) 2197 self.assertEqual(cnt.frame_count, 1) 2198 2199 def test_catch_watchings2(self): 2200 cnt = CompileCounter() 2201 2202 @torch.compile(backend=cnt, fullgraph=True) 2203 def fn(x): 2204 return x.sin(), warnings.catch_warnings(record=True) 2205 2206 x = torch.randn(8) 2207 _, a = fn(x) 2208 _, b = fn(x) 2209 self.assertEqual(cnt.frame_count, 1) 2210 self.assertIsInstance(a, warnings.catch_warnings) 2211 self.assertIsInstance(b, warnings.catch_warnings) 2212 self.assertIsNot(a, b) 2213 2214 def test_tensor_build_list_unpack(self): 2215 def fn(x): 2216 # seen in fastNLP_Bert 2217 return torch.cat([*x], dim=-1) 2218 2219 val = torch.randn([1, 1, 473, 768]) 2220 correct = fn(val) 2221 cnts = torch._dynamo.testing.CompileCounter() 2222 opt_fn = torch._dynamo.optimize(cnts)(fn) 2223 self.assertTrue(same(opt_fn(val), correct)) 2224 self.assertEqual(cnts.frame_count, 1) 2225 self.assertEqual(cnts.op_count, 2) 2226 2227 def test_numpy_int_constant(self): 2228 def fn(x, a, b): 2229 return x + (a % b) 2230 2231 args = [torch.randn(10), 4096, np.int64(8)] 2232 correct = fn(*args) 2233 cnts = torch._dynamo.testing.CompileCounter() 2234 opt_fn = torch._dynamo.optimize(cnts, dynamic=True, nopython=True)(fn) 2235 self.assertTrue(same(opt_fn(*args), correct)) 2236 self.assertTrue(same(opt_fn(*args), correct)) 2237 self.assertEqual(cnts.frame_count, 1) 2238 self.assertEqual(cnts.op_count, 2) 2239 2240 def test_numpy_subdtype(self): 2241 def fn(x, n): 2242 return np.issubdtype(type(n), np.integer) + x 2243 2244 args = [torch.randn(10), 4096] 2245 correct = fn(*args) 2246 cnts = torch._dynamo.testing.CompileCounter() 2247 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 2248 self.assertEqual(opt_fn(*args), correct) 2249 self.assertEqual(cnts.frame_count, 1) 2250 2251 def test_numpy_take_along_axis(self): 2252 def fn(x, i, a): 2253 return np.take_along_axis(x, i, a) 2254 2255 def sample_to_args(s): 2256 args = (s.input, *sample.args) 2257 return tuple(a.numpy() if isinstance(a, torch.Tensor) else a for a in args) 2258 2259 samples = list( 2260 sample_inputs_take_along_dim( 2261 None, "cpu", torch.float32, requires_grad=False 2262 ) 2263 ) 2264 cnts = 
torch._dynamo.testing.CompileCounter() 2265 opt_fn = torch._dynamo.optimize(cnts)(fn) 2266 i = 1 2267 for sample in samples: 2268 args = sample_to_args(sample) 2269 if len(args) < 3: 2270 # if axis is None, second argument is treated as 1d array 2271 args = (args[0], np.ravel(args[1]), None) 2272 self.assertEqual(fn(*args), opt_fn(*args)) 2273 self.assertEqual(cnts.frame_count, i) 2274 i += 1 2275 2276 def test_numpy_torch_operators(self): 2277 def fn(op, t1, t2): 2278 return op(t1, t2) 2279 2280 from torch._dynamo.variables.builtin import BuiltinVariable 2281 2282 operators = BuiltinVariable._fx_graph_functions() 2283 2284 for op, t1_np, t2_np in itertools.product( 2285 operators, (True, False), (True, False) 2286 ): 2287 if op in [operator.eq, operator.ne]: 2288 # returns equivalent of torch.eq/ne 2289 continue 2290 if op is operator.getitem: 2291 # skip 2292 # Did you know that tensor[ndarray_of_floats] works? 2293 continue 2294 if op is operator.imatmul and (t1_np or t2_np): 2295 # skip 2296 # in numpy, in place matmul does not work single 2297 # dimensional arrays 2298 continue 2299 t1 = torch.rand(5) 2300 if t1_np: 2301 t1 = t1.numpy() 2302 t2 = torch.rand(5) 2303 if t2_np: 2304 t2 = t2.numpy() 2305 try: 2306 # TODO try a bit harder 2307 result = op(t1, t2) 2308 except (RuntimeError, TypeError, IndexError): 2309 continue 2310 cnts = torch._dynamo.testing.CompileCounter() 2311 opt_fn = torch._dynamo.optimize(cnts)(fn) 2312 self.assertEqual(result, opt_fn(op, t1, t2), msg=f"{op=} {t1_np=} {t2_np=}") 2313 self.assertEqual(cnts.frame_count, 1, msg=f"{op=} {t1_np=} {t2_np=}") 2314 torch._dynamo.reset() 2315 2316 def test_numpy_ndarray_graph_break(self): 2317 def fn(x): 2318 a = x.numpy() 2319 b = a.real 2320 torch._dynamo.graph_break() 2321 c = np.multiply(b, 2.0) 2322 return c 2323 2324 cnts = torch._dynamo.testing.CompileCounter() 2325 opt_fn = torch._dynamo.optimize(cnts)(fn) 2326 for _ in range(10): 2327 x = torch.randn(3) 2328 ref = fn(x) 2329 res = opt_fn(x) 2330 self.assertEqual(ref, res) 2331 self.assertEqual(cnts.frame_count, 2) 2332 2333 def test_numpy_ndarray_graph_break_with_multiple_outputs(self): 2334 def fn(x, y): 2335 a = x.numpy() 2336 b = y.numpy() 2337 torch._dynamo.graph_break() 2338 return np.add(a, 1), np.add(b, 1) 2339 2340 cnts = torch._dynamo.testing.CompileCounter() 2341 opt_fn = torch._dynamo.optimize(cnts)(fn) 2342 for _ in range(10): 2343 x = torch.randn([1, 3]) 2344 y = torch.randn([1, 3]) 2345 ref = fn(x, y) 2346 res = opt_fn(x, y) 2347 self.assertEqual(ref, res) 2348 self.assertEqual(cnts.frame_count, 2) 2349 2350 def test_numpy_force(self): 2351 def fn(x): 2352 return x.numpy(force=False) 2353 2354 cnts = torch._dynamo.testing.CompileCounter() 2355 opt_fn = torch._dynamo.optimize(cnts)(fn) 2356 x = torch.randn(3) 2357 res = opt_fn(x) 2358 self.assertEqual(type(res), np.ndarray) 2359 self.assertEqual(cnts.frame_count, 1) 2360 2361 def fn(x): 2362 return x.numpy(force=True) 2363 2364 cnts = torch._dynamo.testing.CompileCounter() 2365 opt_fn = torch._dynamo.optimize(cnts)(fn) 2366 x = torch.randn(3, requires_grad=True) 2367 res = opt_fn(x) 2368 self.assertEqual(type(res), np.ndarray) 2369 self.assertEqual(cnts.frame_count, 1) 2370 2371 def test_numpy_recompilation_scalar(self): 2372 def fn(x, a): 2373 return np.where(x < 0.5, a, x) 2374 2375 x = np.random.randn(8) 2376 cnts = torch._dynamo.testing.CompileCounter() 2377 opt_fn = torch._dynamo.optimize(cnts, dynamic=True)(fn) 2378 2379 ref = fn(x, 3) 2380 res = opt_fn(x, 3) 2381 self.assertEqual(ref, res) 2382 
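        # With dynamic=True the Python int argument is traced as a symbolic
        # scalar rather than being burned in as the constant 3, so the second
        # call below with a=4 should reuse the same compiled frame; that is
        # what the final frame_count == 1 assertion verifies.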
2383 ref = fn(x, 4) 2384 res = opt_fn(x, 4) 2385 self.assertEqual(ref, res) 2386 2387 self.assertEqual(cnts.frame_count, 1) 2388 2389 def test_tensor_interacts_with_numpy_ndarray(self): 2390 def fn(x, y): 2391 a = x.numpy() 2392 b = y.numpy() 2393 c = np.ones_like(a) 2394 d = np.ones_like(b) 2395 torch._dynamo.graph_break() 2396 return np.add(a, c), np.add(b, d) 2397 2398 cnts = torch._dynamo.testing.CompileCounter() 2399 opt_fn = torch._dynamo.optimize(cnts)(fn) 2400 for _ in range(10): 2401 x = torch.randn([1, 3]) 2402 y = torch.randn([1, 3]) 2403 ref = fn(x, y) 2404 res = opt_fn(x, y) 2405 self.assertEqual(ref, res) 2406 self.assertEqual(cnts.frame_count, 2) 2407 2408 def test_numpy_ndarray_works_with_builtin_function(self): 2409 def fn(x): 2410 v = x.sum() / len(x) 2411 return v 2412 2413 cnts = torch._dynamo.testing.CompileCounter() 2414 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 2415 for _ in range(10): 2416 x = np.random.randn(2, 3) 2417 ref = fn(x) 2418 res = opt_fn(x) 2419 self.assertEqual(ref, res) 2420 self.assertEqual(cnts.frame_count, 1) 2421 2422 def test_numpy_array_of_arrays(self): 2423 def fn(x, y): 2424 return np.array([x, y]) 2425 2426 cnts = torch._dynamo.testing.CompileCounter() 2427 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 2428 2429 x, y = np.float64(1), np.float64(2) 2430 res = opt_fn(x, y) 2431 self.assertEqual(res, np.array([1, 2], dtype=float)) 2432 self.assertEqual(type(res), np.ndarray) 2433 self.assertEqual(cnts.frame_count, 1) 2434 2435 x, y = np.arange(2), np.arange(2) + 2 2436 res = opt_fn(x, y) 2437 self.assertEqual(res, np.array([[0, 1], [2, 3]])) 2438 self.assertEqual(type(res), np.ndarray) 2439 self.assertEqual(cnts.frame_count, 2) 2440 2441 def test_numpy_readonly(self): 2442 @torch.compile(fullgraph=True) 2443 def fn(x): 2444 return x 2445 2446 x = np.broadcast_to(np.arange(3), (2, 3)) 2447 self.assertFalse(x.flags.writeable) 2448 2449 with warnings.catch_warnings(): 2450 warnings.simplefilter("error") 2451 y = fn(x) 2452 self.assertTrue(y.flags.writeable) # XXX: differs from numpy 2453 2454 def test_numpy_tolist(self): 2455 def fn(x): 2456 return x.tolist() 2457 2458 cnts = torch._dynamo.testing.CompileCounter() 2459 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 2460 2461 x = np.arange(5) 2462 r = opt_fn(x) 2463 2464 self.assertEqual(r, [0, 1, 2, 3, 4]) 2465 self.assertEqual(type(r), list) 2466 self.assertEqual(cnts.frame_count, 1) 2467 2468 def test_numpy_size_attr(self): 2469 def fn(x): 2470 return x.size + x 2471 2472 cnts = torch._dynamo.testing.CompileCounter() 2473 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 2474 2475 x = np.arange(5) 2476 r = opt_fn(x) 2477 2478 self.assertEqual(r, fn(x)) 2479 self.assertEqual(type(r), np.ndarray) 2480 self.assertEqual(cnts.frame_count, 1) 2481 2482 def test_numpy_no_raise(self): 2483 def _inf_nan_preprocess(t, t_np): 2484 t_np = np.nan_to_num(t_np) 2485 return t, t_np 2486 2487 def fn(): 2488 # shape, dims format 2489 test_cases = ( 2490 (3, 3), 2491 (4, 4), 2492 (5, 5), 2493 ) 2494 2495 for shape in test_cases: 2496 t = torch.randn(shape, dtype=torch.complex64) 2497 t_np = np.random.randn(*shape).astype(np.complex64) 2498 2499 _, t_np = _inf_nan_preprocess(t, t_np) 2500 print(t, t_np) # Just a side effect so that compilation kicks in 2501 2502 cnt = CompileCounterWithBackend("inductor") 2503 fn = torch._dynamo.optimize(cnt)(fn) 2504 fn() 2505 self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1)) 2506 2507 def test_mandelbrot_numpy(self): 2508 def 
mandelbrot_numpy(max_iter): 2509 # Define the boundaries of the complex plane 2510 xn = 450 2511 yn = 375 2512 xmin = -2.25 2513 xmax = 0.75 2514 ymin = -1.25 2515 ymax = 1.25 2516 2517 # Create the grid of complex numbers 2518 x_values = np.linspace(xmin, xmax, xn, dtype=np.float64) 2519 y_values = np.linspace(ymin, ymax, yn, dtype=np.float64) 2520 rx, iy = np.meshgrid(x_values, y_values, indexing="xy") 2521 2522 x = rx.copy() 2523 y = iy.copy() 2524 mask = np.zeros_like(x) 2525 for i in range(max_iter): 2526 x_prev = x 2527 y_prev = y 2528 x = x_prev**2 - y_prev**2 + rx 2529 y = 2 * x_prev * y_prev + iy 2530 inside = np.sqrt(x**2 + y**2) <= 2 2531 mask += inside 2532 return mask 2533 2534 cnts = torch._dynamo.testing.CompileCounter() 2535 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(mandelbrot_numpy) 2536 n_iter = torch._dynamo.config.cache_size_limit - 2 2537 for i in range(n_iter): 2538 x = i + 3 2539 ref = mandelbrot_numpy(x) 2540 res = opt_fn(x) 2541 self.assertEqual(ref, res) 2542 # We need to specialise the number as it's in a forloop 2543 self.assertEqual(cnts.frame_count, n_iter) 2544 2545 def test_numpy_as_global(self): 2546 global x 2547 x = np.arange(10) 2548 2549 @torch.compile(fullgraph=True) 2550 def fn(y): 2551 return y + x + x 2552 2553 r = fn(np.arange(10)) 2554 self.assertEqual(type(r), np.ndarray) 2555 self.assertEqual(r, x * 3) 2556 del x 2557 2558 def test_numpy_gt(self): 2559 x = np.arange(10) 2560 2561 @torch.compile 2562 def fn(y): 2563 return y >= 3 2564 2565 r = fn(x) 2566 self.assertEqual(type(r), np.ndarray) 2567 self.assertEqual(r, x >= 3) 2568 2569 def test_numpy_min(self): 2570 x = np.arange(10) 2571 2572 @torch.compile 2573 def fn(y): 2574 return min(y, 3), min(y, y - 1) 2575 2576 r1, r2 = fn(x) 2577 self.assertEqual(type(r1), np.ndarray) 2578 self.assertEqual(type(r2), np.ndarray) 2579 self.assertEqual(r1, np.minimum(x, 3)) 2580 self.assertEqual(r2, np.minimum(x, x - 1)) 2581 2582 def test_graph_break_correctly_when_passing_numpy_ndarray_to_torch_function(self): 2583 # from transformers/models/big_bird/modeling_big_bird.py 2584 def fn(x: int, y: torch.Tensor): 2585 ndarray_list = [np.ones([2, x])] 2586 ndarray = np.stack(ndarray_list, axis=0) 2587 tensor = torch.tensor(ndarray, dtype=torch.long) 2588 tensor.unsqueeze_(0) 2589 return tensor + y 2590 2591 cnts = torch._dynamo.testing.CompileCounter() 2592 opt_fn = torch._dynamo.optimize(cnts)(fn) 2593 for x in range(1, 10): 2594 y = torch.randn([1, 2, x]) 2595 ref = fn(x, y) 2596 res = opt_fn(x, y) 2597 self.assertEqual(ref, res) 2598 # It's all traced once with x = 1, x = 2 and then x = ks0 2599 # For dynamic it's x=1 and x=ks0 2600 self.assertEqual(cnts.frame_count, ifdynstaticdefault(3, 2)) 2601 2602 def test_numpy_with_builtin_type(self): 2603 x = np.random.rand(5) 2604 2605 def fn(x): 2606 return (x * 5).astype(bool).astype(float).astype(int) + 8 2607 2608 cnts = torch._dynamo.testing.CompileCounter() 2609 opt_fn = torch._dynamo.optimize(cnts)(fn) 2610 2611 r = opt_fn(x) 2612 self.assertEqual(r.dtype, int) 2613 self.assertEqual(cnts.frame_count, 1) 2614 2615 def test_with_builtin_type(self): 2616 x = torch.randn(5) 2617 2618 def fn(x): 2619 return (x * 5).to(bool).to(float).to(int) + 8 2620 2621 cnts = torch._dynamo.testing.CompileCounter() 2622 opt_fn = torch._dynamo.optimize(cnts)(fn) 2623 2624 r = opt_fn(x) 2625 self.assertEqual(r.dtype, torch.int64) 2626 self.assertEqual(cnts.frame_count, 1) 2627 2628 def test_numpy_unique_f16(self): 2629 def fn(): 2630 x = np.asarray([1, 1, 2, 2, 3], 
dtype=np.float16) 2631 return np.unique(x) 2632 2633 cnts = torch._dynamo.testing.CompileCounter() 2634 opt_fn = torch._dynamo.optimize(cnts)(fn) 2635 2636 r = opt_fn() 2637 self.assertEqual(r.dtype, np.float16) 2638 self.assertEqual(cnts.frame_count, 1) 2639 2640 def test_numpy_fallback_on_eager(self): 2641 def fn(): 2642 return np.asarray(["L", "U"]) 2643 2644 cnts = torch._dynamo.testing.CompileCounter() 2645 opt_fn = torch._dynamo.optimize(cnts)(fn) 2646 2647 r = opt_fn() 2648 self.assertEqual(cnts.frame_count, 0) # graph break 2649 self.assertEqual(r, np.asarray(["L", "U"])) 2650 2651 # repeat with a different function 2652 def fn2(): 2653 return np.random.choice(["L", "U"]) 2654 2655 cnts2 = torch._dynamo.testing.CompileCounter() 2656 opt_fn2 = torch._dynamo.optimize(cnts2)(fn2) 2657 2658 r2 = fn2() 2659 self.assertEqual(cnts.frame_count, 0) 2660 assert r2 in ("L", "U") 2661 2662 def test_trace_ndarray_frame(self): 2663 def fn(x): 2664 x = x**2 2665 print("graph break.") 2666 return 2 * x 2667 2668 counter = CompileCounter() 2669 compiled_fn = torch._dynamo.optimize(counter)(fn) 2670 2671 x = np.arange(8) 2672 self.assertEqual(fn(x), compiled_fn(x)) 2673 self.assertEqual(counter.frame_count, 2) 2674 2675 def test_trace_ndarray_frame_2(self): 2676 # no tensors/ndarray as inputs in the frame 2677 def fn(x): 2678 print("graph break.") 2679 return 2 * np.arange(x) 2680 2681 counter = CompileCounter() 2682 compiled_fn = torch._dynamo.optimize(counter)(fn) 2683 2684 x = 8 2685 self.assertEqual(fn(x), compiled_fn(x)) 2686 self.assertEqual(counter.frame_count, 1) 2687 2688 def test_numpy_non_torch_dtype(self): 2689 # test that we gracefully graph break on dtypes 2690 # that do not have pytorch equivalents. 2691 def fn(x): 2692 return isinstance(x, torch.Tensor) 2693 2694 cnts = torch._dynamo.testing.CompileCounter() 2695 opt_fn = torch._dynamo.optimize(cnts)(fn) 2696 2697 # torch does not have the `uint16` dtype 2698 for x in [np.array([42], dtype=np.uint16), np.uint16(42), np.dtype("uint16")]: 2699 r = opt_fn(x) 2700 2701 self.assertEqual(r, False) 2702 self.assertEqual(cnts.frame_count, 0) # graph break 2703 2704 def test_numpy_iter(self): 2705 # test that iteration over an ndarray produces ndarrays not bare tensors 2706 def fn(x): 2707 return [bm for bm in x] 2708 2709 cnts = torch._dynamo.testing.CompileCounter() 2710 opt_fn = torch._dynamo.optimize(cnts)(fn) 2711 2712 proba_map = np.arange(3)[:, None] 2713 res = opt_fn(proba_map) 2714 2715 self.assertEqual([type(r) for r in res], [np.ndarray, np.ndarray, np.ndarray]) 2716 self.assertEqual(res, [np.array([0]), np.array([1]), np.array([2])]) 2717 self.assertEqual(cnts.frame_count, 1) 2718 2719 # cache size limit needs to be larger than the `dtypes` list size 2720 @torch._dynamo.config.patch(cache_size_limit=12) 2721 def test_dtypes_no_graphbreaks(self): 2722 dtypes = [ 2723 # floats 2724 float, 2725 np.float64, 2726 "float64", 2727 np.float32, 2728 "float32", 2729 # np.dtype('float64') # XXX: this is not supported, yet 2730 # integers 2731 int, 2732 "int", 2733 np.intp, 2734 np.int32, 2735 np.uint8 2736 # np.dtype('int') # XXX: as above 2737 ] 2738 2739 def fn(dt): 2740 return np.arange(5, dtype=dt) 2741 2742 for dtyp in dtypes: 2743 cnts = torch._dynamo.testing.CompileCounter() 2744 opt_fn = torch._dynamo.optimize(cnts)(fn) 2745 2746 val = fn(dtyp) 2747 opt_val = opt_fn(dtyp) 2748 2749 self.assertEqual(cnts.frame_count, 1) # no graph break 2750 2751 # setting the config value makes the PRNG identical to numpy's 2752 # NB this may involve 
a graph break 2753 @torch._dynamo.config.patch(use_numpy_random_stream=True) 2754 def test_numpy_random_config_to_numpy(self): 2755 @torch.compile 2756 def fn(): 2757 return np.random.uniform(size=13) 2758 2759 self.assertEqual(fn().shape, (13,)) 2760 2761 def test_inplace_view_on_graph_input(self): 2762 # graph break when calling methods with inplace_view tag on graph input 2763 func_args_map = { 2764 lambda x: x.resize_(6).mul_(2): torch.ones(4), 2765 lambda x: x.t_().mul_(2): torch.rand(2, 3), 2766 lambda x: x.transpose_(0, 1).mul_(2): torch.rand(2, 3), 2767 lambda x: x.squeeze_().mul_(2): torch.rand(1, 2, 3), 2768 lambda x: x.unsqueeze_(0).mul_(2): torch.rand(2, 3), 2769 lambda x: x.resize_as_(torch.rand(200, 300)): torch.rand(2, 3), 2770 lambda x: x.swapaxes_(0, 1).mul_(2): torch.rand(2, 3), 2771 lambda x: x.swapdims_(0, 1).mul_(2): torch.rand(2, 3), 2772 lambda x: x.rename_("N", "C").mul_(2): torch.zeros(2, 3), 2773 lambda x: x.as_strided_((3, 2), (2, 1)).mul_(2): torch.zeros(2, 3), 2774 lambda x: x.detach_().mul_(2): torch.zeros(2, 3), 2775 } 2776 for func, args in func_args_map.items(): 2777 args_clone = args.clone() 2778 cnts = torch._dynamo.testing.CompileCounter() 2779 opt_f = torch._dynamo.optimize(cnts)(func) 2780 self.assertTrue(same(func(args).shape, opt_f(args_clone).shape)) 2781 self.assertEqual(cnts.frame_count, 1) 2782 self.assertEqual(cnts.op_count, 1) # mul_ 2783 2784 def test_out_variants_with_resizing_on_graph_inputs(self): 2785 def fn(x, y): 2786 return torch.cosh(x, out=y) + 1 2787 2788 x = torch.rand(2, 3) 2789 y = torch.rand(4) 2790 2791 cnts = torch._dynamo.testing.CompileCounter() 2792 opt_fn = torch.compile(fn, backend=cnts) 2793 self.assertTrue(same(fn(x, y), opt_fn(x.clone(), y.clone()))) 2794 self.assertEqual(cnts.frame_count, 1) 2795 2796 def test_out_variants_with_resizing_on_graph_inputs_with_dynamic(self): 2797 # https://github.com/pytorch/pytorch/issues/120482 2798 class CustomModel(torch.nn.Module): 2799 def __init__(self): 2800 super().__init__() 2801 2802 def forward(self, inputs): 2803 return torch.outer(**inputs) 2804 2805 compile_fn = torch.compile(CustomModel(), fullgraph=True) 2806 2807 shapes = [(2, 1), (6, 1), (4, 1)] 2808 for shape in shapes: 2809 vec1, vec2 = shape 2810 input_tensor1 = torch.randn(vec1) 2811 input_tensor2 = torch.randn(vec2) 2812 out_tensor = torch.empty(shape) 2813 args = {"input": input_tensor1, "vec2": input_tensor2, "out": out_tensor} 2814 res = compile_fn(args) 2815 opt_res = res.clone() # cuz this is out and we mutate it 2816 res = CustomModel()(args) 2817 self.assertEqual(res, opt_res) 2818 2819 def test_dict_mutation_side_effect(self): 2820 def fn(d): 2821 d["c"] = d["a"] + d.pop("b") 2822 return d 2823 2824 args1 = {"a": torch.randn(10), "b": torch.randn(10)} 2825 args2 = dict(args1) 2826 assert fn(args1) is args1 2827 cnts = torch._dynamo.testing.CompileCounter() 2828 opt_fn = torch._dynamo.optimize(cnts)(fn) 2829 self.assertIs(opt_fn(args2), args2) 2830 self.assertTrue(same(args1, args2)) 2831 self.assertEqual(cnts.frame_count, 1) 2832 self.assertEqual(cnts.op_count, 1) 2833 2834 def test_dict_order_keys(self): 2835 def fn(d): 2836 c = 0 2837 for v in d.values(): 2838 c += v 2839 return c 2840 2841 args1 = {} 2842 args1["a"] = torch.rand(10) 2843 args1["b"] = torch.rand(10) 2844 cnts = torch._dynamo.testing.CompileCounter() 2845 opt_fn = torch._dynamo.optimize(cnts)(fn) 2846 self.assertEqual(fn(args1), opt_fn(args1)) 2847 self.assertEqual(cnts.frame_count, 1) 2848 self.assertEqual(cnts.op_count, 2) 2849 2850 # A 
different order of keys recompiles 2851 args2 = {} 2852 args2["b"] = args1["b"] 2853 args2["a"] = args1["a"] 2854 self.assertEqual(fn(args2), opt_fn(args2)) 2855 self.assertEqual(cnts.frame_count, 2) 2856 # Extra calls don't recompile 2857 self.assertEqual(cnts.frame_count, 2) 2858 2859 def test_dict_namedtuple(self): 2860 def fn(d): 2861 return d[3] * 2 2862 2863 args1 = {collections.namedtuple: None, 3: torch.randn(3)} 2864 cnts = torch._dynamo.testing.CompileCounter() 2865 opt_fn = torch._dynamo.optimize(cnts)(fn) 2866 self.assertEqual(fn(args1), opt_fn(args1)) 2867 self.assertEqual(cnts.frame_count, 1) 2868 # Test a failing namedtuple guard 2869 args2 = {2: None, 3: torch.randn(3)} 2870 self.assertEqual(fn(args2), opt_fn(args2)) 2871 self.assertEqual(cnts.frame_count, 2) 2872 2873 def test_dict_order_keys_tensors(self): 2874 def fn(d, x): 2875 return d[x] + 3 2876 2877 args1 = {} 2878 x = torch.randn(10) 2879 y = torch.randn(10) 2880 z = torch.randn(10) 2881 args1[x] = y 2882 args1[3] = z 2883 2884 cnts = torch._dynamo.testing.CompileCounter() 2885 opt_fn = torch._dynamo.optimize(cnts)(fn) 2886 self.assertEqual(fn(args1, x), opt_fn(args1, x)) 2887 self.assertEqual(cnts.frame_count, 1) 2888 2889 # Calling again doesn't recompile (same id and key order) 2890 opt_fn(args1, x) 2891 self.assertEqual(cnts.frame_count, 1) 2892 args2 = {} 2893 args2[3] = z 2894 args2[x] = y 2895 2896 # Different order recompiles 2897 self.assertEqual(fn(args2, x), opt_fn(args2, x)) 2898 self.assertEqual(cnts.frame_count, 2) 2899 2900 def test_dict_order_keys_modules(self): 2901 def fn(d, x): 2902 return d[x](torch.ones(2, 2)) 2903 2904 args1 = {} 2905 x = torch.nn.Linear(2, 2) 2906 y = torch.nn.Linear(2, 2) 2907 z = torch.nn.Linear(2, 2) 2908 args1[x] = y 2909 args1[3] = z 2910 2911 cnts = torch._dynamo.testing.CompileCounter() 2912 opt_fn = torch._dynamo.optimize(cnts)(fn) 2913 self.assertEqual(fn(args1, x), opt_fn(args1, x)) 2914 self.assertEqual(cnts.frame_count, 1) 2915 2916 # Calling again doesn't recompile (same id and key order) 2917 opt_fn(args1, x) 2918 self.assertEqual(cnts.frame_count, 1) 2919 args2 = {} 2920 args2[3] = z 2921 args2[x] = y 2922 2923 # Different order recompiles 2924 self.assertEqual(fn(args2, x), opt_fn(args2, x)) 2925 self.assertEqual(cnts.frame_count, 2) 2926 2927 def test_dunder_new_function_inlining(self): 2928 # https://github.com/pytorch/pytorch/issues/107460 2929 2930 counters.clear() 2931 2932 class ModelA(torch.nn.Module): 2933 def __init__(self): 2934 super().__init__() 2935 2936 def forward(self, x): 2937 return torch.tanh(x + 1) 2938 2939 class ModelB(torch.nn.Module): 2940 def __new__(cls): 2941 return ModelA() 2942 2943 class Model(torch.nn.Module): 2944 def __init__(self): 2945 super().__init__() 2946 self.layer = torch.nn.Linear(2, 2) 2947 2948 def forward(self, x): 2949 other = ModelB() 2950 return self.layer(x) + other(x) 2951 2952 x = torch.rand(2, 2) 2953 m = Model() 2954 2955 opt_m = torch.compile(backend="eager")(m) 2956 ref = m(x) 2957 res = opt_m(x) 2958 self.assertTrue(same(ref, res)) 2959 self.assertEqual(len(counters["graph_break"]), 1) 2960 self.assertFalse("super() nn.Module.__init__" in counters["graph_break"]) 2961 2962 def test_class_duner_mro(self): 2963 class ModuleA(torch.nn.Module): 2964 pass 2965 2966 class ModuleB(ModuleA): 2967 pass 2968 2969 def fn(x, mod): 2970 if ModuleA in type(mod).__mro__: 2971 return x + 1 2972 else: 2973 return x - 1 2974 2975 x = torch.rand(2, 3) 2976 mod = ModuleB() 2977 opt_fn = torch.compile(backend="eager", 
fullgraph=True)(fn) 2978 ref = fn(x, mod) 2979 res = opt_fn(x, mod) 2980 self.assertTrue(same(ref, res)) 2981 2982 def test_nested_wraps(self): 2983 def foo(x, y): 2984 def add(x, y): 2985 return x + y 2986 2987 @functools.wraps(add) 2988 def wrapped_call(x, y): 2989 return add(x, y) 2990 2991 return wrapped_call(x, y) 2992 2993 x = torch.randn(3, 3) 2994 y = torch.randn(3, 3) 2995 2996 o = torch.compile(foo, fullgraph=True, backend="eager")(x, y) 2997 self.assertEqual(o, x + y) 2998 2999 def foo(x, y): 3000 def nested_call(x, y): 3001 def mul(x, y): 3002 return x * y 3003 3004 @functools.wraps(mul) 3005 def double_nested_call(x, y): 3006 return mul(x, y) 3007 3008 return double_nested_call(x, y) 3009 3010 return nested_call(x, y) 3011 3012 o = torch.compile(foo, fullgraph=True, backend="eager")(x, y) 3013 self.assertEqual(o, x * y) 3014 3015 def test_module_deepcopy(self): 3016 m1 = torch.nn.Sequential( 3017 torch.nn.Linear(10, 10), 3018 torch.nn.ReLU(), 3019 torch.nn.Linear(10, 10), 3020 torch.nn.ReLU(), 3021 ) 3022 m2 = torch.nn.Sequential( 3023 torch.nn.Linear(10, 10), 3024 torch.nn.ReLU(), 3025 torch.nn.Linear(10, 10), 3026 torch.nn.ReLU(), 3027 ) 3028 3029 def fn(m, x): 3030 m_copy = copy.deepcopy(m) 3031 return m_copy(x) 3032 3033 v = torch.randn(10) 3034 correct1 = fn(m1, v) 3035 correct2 = fn(m2, v) 3036 cnts = torch._dynamo.testing.CompileCounter() 3037 opt_fn = torch._dynamo.optimize(cnts)(fn) 3038 for _ in range(10): 3039 self.assertTrue(same(opt_fn(m1, v), correct1)) 3040 for _ in range(10): 3041 self.assertTrue(same(opt_fn(m2, v), correct2)) 3042 self.assertEqual(cnts.frame_count, 1) 3043 self.assertEqual(cnts.op_count, 4) 3044 3045 def test_type_copy(self): 3046 def fn(seq): 3047 a, b = seq 3048 return type(seq)([a + 1, b + 2, a + b]) 3049 3050 args1 = [torch.randn(10), torch.randn(10)] 3051 args2 = (torch.randn(10), torch.randn(10)) 3052 correct1 = fn(args1) 3053 correct2 = fn(args2) 3054 cnts = torch._dynamo.testing.CompileCounter() 3055 opt_fn = torch._dynamo.optimize(cnts)(fn) 3056 self.assertTrue(same(opt_fn(args1), correct1)) 3057 self.assertTrue(same(opt_fn(args2), correct2)) 3058 self.assertIsInstance(opt_fn(args1), list) 3059 self.assertIsInstance(opt_fn(args2), tuple) 3060 self.assertEqual(cnts.frame_count, 2) 3061 self.assertEqual(cnts.op_count, 6) 3062 3063 def test_setattr_mutation1(self): 3064 class MyObj: # noqa: B903 3065 def __init__(self, a, b): 3066 self.a = a 3067 self.b = b 3068 3069 def fn(obj): 3070 obj.c = obj.a * obj.b + 1 3071 obj.b = obj.a * obj.c + 2 3072 obj.a = obj.b * obj.c + 3 3073 obj.c = obj.a * obj.b + 4 3074 obj.b = obj.a * obj.c + 5 3075 obj.a = obj.b * obj.c + 6 3076 return obj 3077 3078 x1 = torch.randn(10) 3079 x2 = torch.randn(10) 3080 obj1 = MyObj(x1, x2) 3081 obj2 = MyObj(x1, x2) 3082 fn(obj2) 3083 cnts = torch._dynamo.testing.CompileCounter() 3084 opt_fn = torch._dynamo.optimize(cnts)(fn) 3085 self.assertIs(opt_fn(obj1), obj1) 3086 self.assertTrue(same(obj1.a, obj2.a)) 3087 self.assertTrue(same(obj1.b, obj2.b)) 3088 self.assertTrue(same(obj1.c, obj2.c)) 3089 self.assertEqual(cnts.frame_count, 1) 3090 self.assertEqual(cnts.op_count, 12) 3091 3092 def test_setattr_mutation2(self): 3093 class MyObj: 3094 def __init__(self, x): 3095 self.a = x + 1 3096 self.b = x + 2 3097 3098 def fn(x): 3099 x = x / 3.0 3100 obj = MyObj(x) 3101 obj.c = obj.a * obj.b + 1 3102 obj.b = obj.a * obj.c + 2 3103 obj.a = obj.b * obj.c + 3 3104 return obj 3105 3106 x1 = torch.randn(10) 3107 obj2 = fn(x1) 3108 3109 cnts = torch._dynamo.testing.CompileCounter() 
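        # `obj` is constructed inside the compiled function, so Dynamo owns it
        # and can replay the attribute writes without guarding on an external
        # object; the nine tensor ops (one div, the two adds in __init__, and
        # three mul+add updates) are expected to land in a single graph, per
        # the frame/op counts asserted below.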
3110 opt_fn = torch._dynamo.optimize(cnts)(fn) 3111 obj1 = opt_fn(x1) 3112 self.assertTrue(same(obj1.a, obj2.a)) 3113 self.assertTrue(same(obj1.b, obj2.b)) 3114 self.assertTrue(same(obj1.c, obj2.c)) 3115 self.assertEqual(cnts.frame_count, 1) 3116 self.assertEqual(cnts.op_count, 9) 3117 3118 def test_setattr_mutation3(self): 3119 # TODO(jansel): dead code eliminate the object creation 3120 class MyObj: 3121 def __init__(self, x): 3122 super().__init__() 3123 self.a = x + 1 3124 self.b = x + 2 3125 3126 def fn(x): 3127 x = x / 3.0 3128 obj = MyObj(x) 3129 obj.c = obj.a * obj.b + 1 3130 obj.b = obj.a * obj.c + 2 3131 obj.a = obj.b * obj.c + 3 3132 return obj.a, obj.b, obj.c 3133 3134 x1 = torch.randn(10) 3135 obj2 = fn(x1) 3136 3137 cnts = torch._dynamo.testing.CompileCounter() 3138 opt_fn = torch._dynamo.optimize(cnts)(fn) 3139 obj1 = opt_fn(x1) 3140 self.assertTrue(same(obj1, obj2)) 3141 self.assertEqual(cnts.frame_count, 1) 3142 self.assertEqual(cnts.op_count, 9) 3143 3144 def test_object_setattr(self): 3145 @dataclasses.dataclass 3146 class A: 3147 x: torch.Tensor 3148 3149 def fn1(x) -> None: 3150 a = A(x) 3151 object.__setattr__(a, "x", x + 2) 3152 return a 3153 3154 x1 = torch.randn(10) 3155 obj11 = fn1(x1.clone()) 3156 3157 cnts = torch._dynamo.testing.CompileCounter() 3158 opt_fn1 = torch._dynamo.optimize(cnts, nopython=True)(fn1) 3159 obj12 = opt_fn1(x1.clone()) 3160 self.assertTrue(same(obj11.x, x1 + 2)) 3161 self.assertTrue(same(obj12.x, x1 + 2)) 3162 self.assertTrue(same(obj11.x, obj12.x)) 3163 self.assertEqual(cnts.frame_count, 1) 3164 3165 @dataclasses.dataclass(frozen=True) 3166 class B: 3167 x: torch.Tensor 3168 3169 def fn2(x) -> None: 3170 b = B(x) 3171 return b 3172 3173 x2 = torch.randn(10) 3174 obj21 = fn2(x2.clone()) 3175 3176 cnts = torch._dynamo.testing.CompileCounter() 3177 opt_fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2) 3178 obj22 = opt_fn2(x2.clone()) 3179 self.assertTrue(same(obj21.x, x2)) 3180 self.assertTrue(same(obj22.x, x2)) 3181 self.assertTrue(same(obj21.x, obj22.x)) 3182 self.assertEqual(cnts.frame_count, 0) 3183 3184 @dataclasses.dataclass(frozen=True) 3185 class C: 3186 x: torch.Tensor 3187 3188 def fn3(x) -> None: 3189 c = C(x) 3190 object.__setattr__(c, "x", x + 2) 3191 return c 3192 3193 x3 = torch.randn(10) 3194 obj31 = fn3(x3.clone()) 3195 3196 cnts = torch._dynamo.testing.CompileCounter() 3197 opt_fn3 = torch._dynamo.optimize(cnts, nopython=True)(fn3) 3198 obj32 = opt_fn3(x3.clone()) 3199 self.assertTrue(same(obj31.x, x3 + 2)) 3200 self.assertTrue(same(obj32.x, x3 + 2)) 3201 self.assertTrue(same(obj31.x, obj32.x)) 3202 self.assertEqual(cnts.frame_count, 1) 3203 3204 @dataclasses.dataclass(frozen=True) 3205 class D: 3206 x: torch.Tensor 3207 3208 def __post_init__(self): 3209 object.__setattr__(self, "y", self.x + 2) 3210 3211 def fn4(x) -> None: 3212 d = D(x) 3213 return d 3214 3215 x4 = torch.randn(10) 3216 obj41 = fn4(x4.clone()) 3217 3218 cnts = torch._dynamo.testing.CompileCounter() 3219 opt_fn4 = torch._dynamo.optimize(cnts, nopython=True)(fn4) 3220 obj42 = opt_fn4(x4.clone()) 3221 self.assertTrue(same(obj41.x, x4)) 3222 self.assertTrue(same(obj42.x, x4)) 3223 self.assertTrue(same(obj41.x, obj42.x)) 3224 self.assertTrue(same(obj41.y, x4 + 2)) 3225 self.assertTrue(same(obj42.y, x4 + 2)) 3226 self.assertTrue(same(obj41.y, obj42.y)) 3227 self.assertEqual(cnts.frame_count, 1) 3228 3229 def test_user_defined_class_name(self): 3230 class MyClassFoo: 3231 pass 3232 3233 def fn1(a, b, c): 3234 tmp = MyClassFoo() 3235 if 
tmp.__class__.__name__ == "MyClassFoo": 3236 return a - b / c 3237 3238 torch._dynamo.testing.standard_test(self, fn=fn1, nargs=3) 3239 3240 def test_user_defined_class_python_type(self): 3241 class MyClass1: 3242 pass 3243 3244 class ExampleMeta(type): 3245 pass 3246 3247 class MyClass2(metaclass=ExampleMeta): 3248 pass 3249 3250 def fn(x, c): 3251 if isinstance(c, MyClass1): 3252 return x + 1 3253 elif isinstance(c, MyClass2): 3254 return x + 2 3255 else: 3256 return x + 3 3257 3258 x = torch.rand(3) 3259 opt_fn = torch._dynamo.optimize("eager")(fn) 3260 for c in [MyClass1, MyClass2]: 3261 ref = fn(x, c) 3262 res = opt_fn(x, c) 3263 self.assertTrue(same(ref, res)) 3264 3265 def test_super_calling_with_metaclass(self): 3266 class ExampleMeta(type): 3267 pass 3268 3269 class MyClass1(metaclass=ExampleMeta): 3270 coeff = 4 # Force the constant guard to test source in guards 3271 3272 @classmethod 3273 def add(cls, x): 3274 return x + 1 3275 3276 class MyClass2(MyClass1): 3277 @classmethod 3278 def add(cls, x): 3279 torch._dynamo.graph_break() 3280 return x + super().add(x) + super().coeff 3281 3282 def fn(x, obj): 3283 return x + obj.add(x) 3284 3285 x = torch.rand(3) 3286 obj = MyClass2() 3287 opt_fn = torch._dynamo.optimize("eager")(fn) 3288 ref = fn(x, obj) 3289 res = opt_fn(x, obj) 3290 self.assertTrue(same(ref, res)) 3291 3292 def test_usr_cls_staticmethod(self): 3293 class Foo: 3294 @staticmethod 3295 def bar(a, b): 3296 return a + b 3297 3298 def fn(a, b): 3299 return Foo.bar(a, b) - 1 3300 3301 torch._dynamo.testing.standard_test(self, fn=fn, nargs=2) 3302 3303 def test_usr_cls_classmethod(self): 3304 class Foo: 3305 @classmethod 3306 def bar(cls, a, b): 3307 return a + b 3308 3309 def fn(a, b): 3310 return Foo.bar(a, b) - 1 3311 3312 torch._dynamo.testing.standard_test(self, fn=fn, nargs=2) 3313 3314 def test_dunder_methods(self): 3315 class Foo: 3316 def __init__(self, val): 3317 super().__init__() 3318 self.val = val 3319 3320 def __add__(self, other): 3321 return Foo(self.val + other.val) 3322 3323 def __mul__(self, other): 3324 return Foo(self.val * other.val) 3325 3326 def __truediv__(self, other): 3327 return Foo(self.val / other.val) 3328 3329 def __sub__(self, other): 3330 return Foo(self.val - other.val) 3331 3332 def fn(a, b, c): 3333 return Foo(a) + Foo(b) * Foo(c) / Foo(a) - Foo(b) 3334 3335 torch._dynamo.testing.standard_test(self, fn=fn, nargs=3, expected_ops=4) 3336 3337 def test_function_annotation(self): 3338 class Variable: 3339 pass 3340 3341 def fn(x): 3342 x = x / 3.0 3343 3344 def inner(y: typing.List[Variable]): 3345 return x + 1 3346 3347 return inner 3348 3349 x1 = torch.randn(10) 3350 obj2 = fn(x1)([]) 3351 3352 cnts = torch._dynamo.testing.CompileCounter() 3353 opt_fn = torch._dynamo.optimize_assert(cnts)(fn) 3354 opt_fn_inner = torch._dynamo.optimize_assert(cnts)(opt_fn(x1)) 3355 obj1 = opt_fn_inner([]) 3356 self.assertTrue(same(obj1, obj2)) 3357 self.assertEqual(cnts.frame_count, 2) 3358 self.assertEqual(cnts.op_count, 2) 3359 3360 def test_nested_closure(self): 3361 v0 = torch.randn(10) 3362 3363 def fn1(): 3364 v1 = torch.randn(10) 3365 3366 def fn2(*args, **kwargs): 3367 assert len(args) == 1 3368 assert len(kwargs) == 1 3369 v2 = torch.randn(10) + args[0] + kwargs["b"] 3370 3371 def fn3(v3=torch.randn(10)): 3372 def fn4(): 3373 return v0 + v1 + v2 + v3 + 1 3374 3375 return fn4 3376 3377 return fn3 3378 3379 return fn2(1, b=2)() 3380 3381 cnts = torch._dynamo.testing.CompileCounter() 3382 opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1) 3383 tmp1 = 
torch._dynamo.optimize_assert(cnts)(opt_fn1()) 3384 tmp2 = torch._dynamo.optimize_assert(cnts)(opt_fn1()) 3385 self.assertTrue(tmp1().shape, (10,)) 3386 self.assertTrue(same(tmp1(), tmp1())) 3387 self.assertFalse(same(tmp1(), tmp2())) 3388 self.assertEqual(cnts.frame_count, 2) 3389 self.assertEqual(cnts.op_count, 9) 3390 3391 def test_nested_closure_mutation(self): 3392 def fn1(): 3393 v1 = torch.randn(10) 3394 3395 def fn2(): 3396 v2 = torch.randn(10) 3397 3398 def fn3(): 3399 nonlocal v1, v2 3400 v1 += 1 3401 v2 += 2 3402 return v1 + v2 3403 3404 return fn3 3405 3406 rv = fn2() 3407 rv() 3408 rv() 3409 return rv 3410 3411 torch.manual_seed(9000) 3412 counter1 = fn1() 3413 result1 = [counter1(), counter1(), counter1()] 3414 3415 torch.manual_seed(9000) 3416 cnts = torch._dynamo.testing.CompileCounter() 3417 opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1) 3418 counter2 = torch._dynamo.optimize_assert(cnts)(opt_fn1()) 3419 result2 = [counter2(), counter2(), counter2()] 3420 result1.append(counter1()) 3421 result2.append(counter2()) 3422 3423 self.assertTrue(same(result1, result2)) 3424 self.assertEqual(cnts.frame_count, 2) 3425 self.assertEqual(cnts.op_count, 11) 3426 3427 def test_write_to_closures_in_inlining(self): 3428 out = [] 3429 for use_dynamo in [False, True]: 3430 3431 def make_counter(): 3432 x = torch.randn(10) 3433 3434 def counter(): 3435 nonlocal x 3436 x = x + 1 3437 return x 3438 3439 return counter 3440 3441 torch.manual_seed(0) 3442 counter = make_counter() 3443 if not use_dynamo: 3444 out.append(counter() + counter()) 3445 else: 3446 cnts = torch._dynamo.testing.CompileCounter() 3447 3448 @torch._dynamo.optimize(cnts, nopython=True) 3449 def fn(counter): 3450 return counter() + counter() 3451 3452 out.append(fn(counter)) 3453 self.assertEqual(cnts.frame_count, 1) 3454 self.assertEqual(cnts.op_count, 3) 3455 self.assertFalse(same(counter() + counter(), out[-1])) 3456 3457 self.assertTrue(same(out[0], out[1])) 3458 3459 def test_closure_out_of_scope_cell(self): 3460 cell1 = torch.rand(1).item() 3461 cell2 = torch.rand(3, 3) 3462 3463 def indirect(): 3464 return direct() 3465 3466 def direct(): 3467 def inner(): 3468 return cell1 + 1, cell2 + 3 3469 3470 return inner() 3471 3472 cnts = torch._dynamo.testing.CompileCounter() 3473 opt_fn = torch._dynamo.optimize(cnts)(indirect) 3474 result1, result2 = opt_fn() 3475 self.assertAlmostEqual(cell1 + 1, result1) 3476 self.assertTrue(torch.allclose(cell2 + 3, result2)) 3477 self.assertEqual(cnts.frame_count, 1) 3478 self.assertEqual(cnts.op_count, 1) 3479 3480 def test_closure_out_of_scope_cell_with_mutation(self): 3481 cell1 = torch.rand(1).item() 3482 orig1 = cell1 3483 cell2 = torch.rand(3, 3) 3484 orig2 = cell2.clone() 3485 3486 def indirect(): 3487 return direct() 3488 3489 def direct(): 3490 def inner(): 3491 nonlocal cell1, cell2 3492 x = cell2 + 1 3493 cell1 += 1 3494 cell2 += 10 3495 x = x + cell2 3496 return cell1, cell2, x 3497 3498 return inner() 3499 3500 cnts = torch._dynamo.testing.CompileCounter() 3501 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(indirect) 3502 for i in range(1, 4): 3503 result1, result2, _ = opt_fn() 3504 self.assertAlmostEqual(orig1 + 1 * i, result1) 3505 self.assertTrue(torch.allclose(orig2 + 10 * i, result2)) 3506 self.assertEqual(cnts.frame_count, 1) 3507 self.assertEqual(cnts.op_count, 3) 3508 cnts.clear() 3509 3510 def test_closure_with_mutation_and_graph_break(self): 3511 def fn(): 3512 x = torch.zeros(1) 3513 3514 def subfunc(): 3515 x[0] = backup 3516 3517 if x[0] >= -1e5: 3518 
pass 3519 3520 backup = 1 3521 subfunc() 3522 return x 3523 3524 cnts = torch._dynamo.testing.CompileCounter() 3525 opt_fn = torch._dynamo.optimize(cnts)(fn) 3526 expected = fn() 3527 actual = opt_fn() 3528 self.assertTrue(same(expected, actual)) 3529 self.assertEqual(cnts.frame_count, 2) 3530 3531 def test_closure_out_of_scope_cell_with_cond(self): 3532 # Test closure with out-of-scope cell variable, used in a cond 3533 # where the two branches read different closure variables 3534 from functorch.experimental.control_flow import cond 3535 3536 def g(x): 3537 return x 3538 3539 class ModuleCondDeep(torch.nn.Module): 3540 def forward(self, pred, x): 3541 return self._indirection(pred, x) 3542 3543 def _indirection(self, pred, x): 3544 return self.indirection(pred, x) 3545 3546 def indirection(self, pred, x): 3547 def true_fn(y): 3548 return y + 2 3549 3550 def false_fn(y): 3551 return y - 2 3552 3553 def shallow(x): 3554 return x * 2 3555 3556 def deep(x): 3557 # y = g(x) 3558 y = x 3559 return cond( 3560 x[0][0] > 0, 3561 true_fn, 3562 false_fn, 3563 [y], 3564 ) 3565 3566 return cond(pred, shallow, deep, [x]) 3567 3568 mod = ModuleCondDeep() 3569 opt_mod = torch._dynamo.optimize("eager")(mod) 3570 inp = torch.randn(3, 3) 3571 exp1 = mod(torch.tensor(False), inp) 3572 actual1 = opt_mod(torch.tensor(False), inp) 3573 exp2 = mod(torch.tensor(True), inp) 3574 actual2 = opt_mod(torch.tensor(True), inp) 3575 self.assertTrue(torch.allclose(exp1, actual1)) 3576 self.assertTrue(torch.allclose(exp2, actual2)) 3577 3578 def test_top_package_import(self): 3579 def fn(x): 3580 import torch.fx 3581 3582 assert not isinstance(x, torch.fx.Proxy) 3583 return torch.sin(x) 3584 3585 x = torch.randn(4, 5) 3586 ref = fn(x) 3587 cnts = torch._dynamo.testing.CompileCounter() 3588 opt_fn = torch._dynamo.optimize_assert(cnts)(fn) 3589 res = opt_fn(x) 3590 self.assertTrue(same(ref, res)) 3591 3592 def test_typing_typevar(self): 3593 def fn(x): 3594 def sumt(y: torch.Tensor) -> torch.Tensor: 3595 return torch.sum(y) 3596 3597 def foo(c: typing.Callable[[T], T], y: T) -> T: 3598 return c(y) 3599 3600 return foo(sumt, x) 3601 3602 x = torch.randn(3) 3603 ref = fn(x) 3604 cnts = torch._dynamo.testing.CompileCounter() 3605 opt_fn = torch._dynamo.optimize_assert(cnts)(fn) 3606 res = opt_fn(x) 3607 self.assertTrue(same(ref, res)) 3608 self.assertEqual(cnts.frame_count, 1) 3609 3610 def test_typing_union_and_optional(self): 3611 def fn(x): 3612 a = torch.jit.annotate(typing.Dict[str, typing.Optional[torch.Tensor]], {}) 3613 b = torch.jit.annotate( 3614 typing.Dict[str, typing.Union[torch.Tensor, None]], {} 3615 ) 3616 return a, b, x + 1 3617 3618 x = torch.randn(3) 3619 ref = fn(x) 3620 opt_fn = torch._dynamo.optimize("eager", nopython=False)(fn) 3621 res = opt_fn(x) 3622 self.assertTrue(same(ref, res)) 3623 3624 def test_optimize_on_module(self): 3625 class MockModule(torch.nn.Module): 3626 def __init__(self): 3627 super().__init__() 3628 self.relu = torch.nn.ReLU() 3629 3630 def custom_member(self): 3631 # Just for checking that Dynamo returned mod object can redirect 3632 # to this method 3633 pass 3634 3635 def forward(self, x): 3636 return self.relu(x) 3637 3638 cnts1 = torch._dynamo.testing.CompileCounter() 3639 mod = MockModule() 3640 optimized_mod = torch._dynamo.optimize(cnts1, nopython=True)(mod) 3641 3642 a = torch.randn(10) 3643 ref = mod(a) 3644 res = optimized_mod(a) 3645 3646 optimized_mod.custom_member() 3647 3648 self.assertTrue(same(ref, res)) 3649 3650 def test_nested_optimize_decorator(self): 3651 cnts2 
= torch._dynamo.testing.CompileCounter() 3652 cnts3 = torch._dynamo.testing.CompileCounter() 3653 3654 @torch._dynamo.run() 3655 def fn1(x): 3656 return torch.sin(x) * 10 3657 3658 @torch._dynamo.optimize(cnts2, nopython=True) 3659 def fn2(x): 3660 return fn1(x) + 1 3661 3662 @torch._dynamo.optimize(cnts3, nopython=True) 3663 def fn3(x): 3664 return torch.relu(fn2(x)) 3665 3666 fn3(torch.randn(4, 5)) 3667 self.assertEqual(cnts2.frame_count, 0) 3668 self.assertEqual(cnts3.frame_count, 1) 3669 self.assertEqual(cnts3.op_count, 4) 3670 3671 def test_nested_optimize_run(self): 3672 cnts = torch._dynamo.testing.CompileCounter() 3673 3674 @torch._dynamo.optimize(cnts, nopython=True) 3675 def fn(x): 3676 return torch.relu(torch.cos(x) + torch.sin(x)) 3677 3678 fn(torch.randn(4)) 3679 self.assertEqual(cnts.frame_count, 1) 3680 3681 fn(torch.randn(4, 4)) 3682 self.assertEqual(cnts.frame_count, 2) 3683 3684 # Test that run works on a decorated fn 3685 fn = torch._dynamo.run(fn) 3686 fn(torch.randn(4, 4, 4)) 3687 self.assertEqual(cnts.frame_count, 2) 3688 3689 def test_nested_optimize(self): 3690 cnts1 = torch._dynamo.testing.CompileCounter() 3691 cnts2 = torch._dynamo.testing.CompileCounter() 3692 3693 def fn(x): 3694 return torch.relu(torch.cos(x) + torch.sin(x)) 3695 3696 fn1 = torch._dynamo.optimize(cnts1, nopython=True)(fn) 3697 fn2 = torch._dynamo.optimize(cnts2, nopython=True)(fn1) 3698 3699 # The first optimize in the nesting should be ignored 3700 fn2(torch.randn(4)) 3701 self.assertEqual(cnts2.frame_count, 1) 3702 self.assertEqual(cnts1.frame_count, 0) 3703 3704 # Since the fn code object is already compiled, calling fn1 should 3705 # directly call the compiled_fn callable. 3706 torch._dynamo.run()(fn1)(torch.randn(4)) 3707 self.assertEqual(cnts1.frame_count, 0) 3708 3709 # Test same behavior by reversing the calls 3710 torch._dynamo.reset() 3711 cnts1 = torch._dynamo.testing.CompileCounter() 3712 cnts2 = torch._dynamo.testing.CompileCounter() 3713 fn1 = torch._dynamo.optimize(cnts1, nopython=True)(fn) 3714 fn2 = torch._dynamo.optimize(cnts2, nopython=True)(fn1) 3715 fn1(torch.randn(4)) 3716 self.assertEqual(cnts1.frame_count, 1) 3717 torch._dynamo.run()(fn2)(torch.randn(4)) 3718 self.assertEqual(cnts2.frame_count, 0) 3719 3720 def test_torch_size(self): 3721 cnts = torch._dynamo.testing.CompileCounter() 3722 3723 def fn(x): 3724 output_size = torch.Size([10, 10]) 3725 x = x.view(*output_size) 3726 return (x,) 3727 3728 x = torch.randn(100, requires_grad=True) 3729 x_clone = x.clone() 3730 ref = fn(x) 3731 3732 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3733 res = opt_fn(x_clone) 3734 3735 self.assertTrue(same(ref, res)) 3736 3737 def test_torch_size_numel(self): 3738 cnts = torch._dynamo.testing.CompileCounter() 3739 3740 def fn(): 3741 return torch.Size([10, 8]).numel() 3742 3743 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3744 num = torch.Size([10, 8]).numel() 3745 self.assertEqual(opt_fn(), num) 3746 3747 def test_torch_size_numel_dynamic(self): 3748 cnts = torch._dynamo.testing.CompileCounter() 3749 3750 def fn(x): 3751 return x.size().numel() 3752 3753 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3754 x = torch.rand(10, 1, 8, 1) 3755 expect = fn(x) 3756 self.assertEqual(opt_fn(x), expect) 3757 3758 def test_shape_type(self): 3759 cnts = torch._dynamo.testing.CompileCounter() 3760 3761 def fn(x): 3762 return x + (type(x.shape) == torch.Size) 3763 3764 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3765 x = torch.zeros(()) 3766 
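        # x.shape is a torch.Size (a tuple subclass), so the comparison
        # `type(x.shape) == torch.Size` is True and fn effectively computes
        # x + True, i.e. x + 1; the check below confirms the compiled result
        # matches eager.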
self.assertEqual(opt_fn(x), fn(x)) 3767 3768 def test_size_dim(self): 3769 cnts = torch._dynamo.testing.CompileCounter() 3770 3771 def fn(x, dim): 3772 return x.size(dim=dim) 3773 3774 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3775 x = torch.empty([4, 9, 8]) 3776 self.assertEqual(opt_fn(x, 1), 9) 3777 self.assertEqual(opt_fn(x, -2), 9) 3778 3779 def test_stride_dim(self): 3780 cnts = torch._dynamo.testing.CompileCounter() 3781 3782 def fn(x, dim): 3783 return x.stride(dim=dim) 3784 3785 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3786 x = torch.empty([4, 9, 8]) 3787 self.assertEqual(opt_fn(x, 0), 72) 3788 self.assertEqual(opt_fn(x, -2), 8) 3789 3790 def test_torch_seed(self): 3791 from torch._dynamo.utils import counters 3792 3793 cnts = torch._dynamo.testing.CompileCounter() 3794 counters.clear() 3795 3796 def fn(x): 3797 attention_seed = int(torch.seed() % sys.maxsize) 3798 torch.manual_seed(attention_seed) 3799 return (x,) 3800 3801 x = torch.randn(10, requires_grad=True) 3802 ref = fn(x) 3803 3804 # Python code is needed here, since torch.manual_seed graph-breaks. 3805 # Refs: https://github.com/pytorch/pytorch/issues/107187 3806 opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn) 3807 res = opt_fn(x) 3808 3809 self.assertTrue(same(ref, res)) 3810 # Only the torch.seed call is turned into an FX graph. 3811 self.assertEqual(cnts.op_count, 1) 3812 self.assertEqual(cnts.frame_count, 1) 3813 # Graph breaks at manual_seed. 3814 self.assertEqual(len(counters["graph_break"]), 1) 3815 3816 def test_is_tensor_like(self): 3817 cnts = torch._dynamo.testing.CompileCounter() 3818 3819 def f(x): 3820 if torch.overrides.is_tensor_like(x): 3821 return (x * 2,) 3822 return (torch.ones(10) + x,) 3823 3824 x = torch.randn(10) 3825 ref0 = f(x) 3826 ref1 = f(4) 3827 opt_f = torch._dynamo.optimize(cnts, nopython=True)(f) 3828 res0 = opt_f(x) 3829 res1 = opt_f(4) 3830 self.assertTrue(same(ref0, res0)) 3831 self.assertTrue(same(ref1, res1)) 3832 3833 def test_is_tensor_like2(self): 3834 class MyTensor: 3835 @classmethod 3836 def __torch_function__(cls, func, types, args=(), kwargs=None): 3837 if kwargs is None: 3838 kwargs = {} 3839 3840 if func is torch.max: 3841 return torch.tensor(123) 3842 return func(*args, **kwargs) 3843 3844 def fn(x): 3845 if torch.overrides.is_tensor_like(x): 3846 return torch.max(x) 3847 else: 3848 return torch.zeros(1) 3849 3850 x = MyTensor() 3851 ref0 = fn(x) 3852 ref1 = fn(4) 3853 opt_fn = torch._dynamo.optimize("eager")(fn) 3854 res0 = opt_fn(x) 3855 res1 = opt_fn(4) 3856 self.assertTrue(same(ref0, res0)) 3857 self.assertTrue(same(ref1, res1)) 3858 3859 def test_tensor_data(self): 3860 def fn(x, y): 3861 return x[y.data] 3862 3863 x = torch.rand(8) 3864 y = torch.ones(8).to(torch.int) 3865 ref = fn(x, y) 3866 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 3867 res = opt_fn(x, y) 3868 self.assertTrue(same(ref, res)) 3869 3870 def test_tensor_layout(self): 3871 def fn(x): 3872 return torch.zeros( 3873 [x.size()[0], x.size()[1]], 3874 dtype=x.dtype, 3875 layout=x.layout, 3876 device=x.device, 3877 ) 3878 3879 x = torch.rand(2, 3) 3880 ref = fn(x) 3881 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 3882 res = opt_fn(x) 3883 self.assertTrue(same(ref, res)) 3884 3885 def test_version_ci(self): 3886 # temporary test to check that the ci torch version is set correctly 3887 self.assertTrue(hasattr(torch, "_subclasses")) 3888 3889 @unittest.skipIf(not TEST_CUDA, "requires cuda") 3890 def test_rand(self): 3891 cnts = 
torch._dynamo.testing.CompileCounter() 3892 device = "cuda" 3893 3894 def fn(): 3895 return torch.randn(10, device=device) 3896 3897 torch.manual_seed(10) 3898 ref_run1 = fn() 3899 3900 torch.manual_seed(10) 3901 ref_run2 = fn() 3902 self.assertTrue(same(ref_run1, ref_run2)) 3903 3904 torch.manual_seed(10) 3905 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3906 res = opt_fn() 3907 3908 self.assertTrue(same(res, ref_run1)) 3909 3910 def test_slice_input(self): 3911 cnts = torch._dynamo.testing.CompileCounter() 3912 3913 def getitem(a, idx): 3914 if isinstance(idx, slice): 3915 return ( 3916 torch.zeros(1), 3917 a[idx] 3918 + [ 3919 100, 3920 ], 3921 ) 3922 else: 3923 return (torch.zeros(1), a[idx]) 3924 3925 layers = list(range(10)) 3926 ref0 = getitem(layers, slice(0, 2, 1)) 3927 ref1 = getitem(layers, 2) 3928 ref2 = getitem(layers, slice(3, 8, 2)) 3929 opt_getitem = torch._dynamo.optimize(cnts, nopython=True)(getitem) 3930 res0 = opt_getitem(layers, slice(0, 2, 1)) 3931 res1 = opt_getitem(layers, 2) 3932 res2 = opt_getitem(layers, slice(3, 8, 2)) 3933 3934 self.assertTrue(ref0 == res0) 3935 self.assertTrue(ref1 == res1) 3936 self.assertTrue(ref2 == res2) 3937 3938 def test_grad(self): 3939 cnts = torch._dynamo.testing.CompileCounter() 3940 3941 def fn(a, b): 3942 out = a * b 3943 out.sum().backward() 3944 real_out = torch.sigmoid(a.grad + b) 3945 return real_out 3946 3947 inps = [torch.randn(4, requires_grad=True) for _ in range(2)] 3948 for inp in inps: 3949 inp.grad = None 3950 ref = fn(*inps) 3951 3952 for inp in inps: 3953 inp.grad = None 3954 opt_fn = torch._dynamo.optimize(cnts)(fn) 3955 res = opt_fn(*inps) 3956 3957 self.assertTrue(same(ref, res)) 3958 3959 @torch._dynamo.config.patch(guard_nn_modules=True) 3960 def test_source_non_input_grad_access(self): 3961 # This test creates a model, and accesses the grads 3962 # from its parameter. This means that within dynamo, 3963 # the tensor we are reading the grad from HAS a source, 3964 # but is not known to graphargs. 3965 cnts = torch._dynamo.testing.CompileCounter() 3966 3967 class TrivialModel(torch.nn.Module): 3968 def __init__(self): 3969 super(TrivialModel, self).__init__() 3970 self.linear = torch.nn.Linear(2, 1) 3971 3972 def forward(self, x): 3973 return self.linear(x) 3974 3975 def fn(a, b): 3976 outs = [] 3977 for param in model.parameters(): 3978 outs.append(torch.ones(param.grad.size())) 3979 return outs, param.grad + 1 3980 3981 model = TrivialModel() 3982 # Eager 3983 a = torch.ones([2, 2], requires_grad=True) 3984 b = torch.ones([2, 2]) 3985 out = model(a) 3986 out_sum = out.sum() 3987 out_sum.backward() 3988 ref = fn(a, b) 3989 3990 # Compiled 3991 model = TrivialModel() 3992 a = torch.ones([2, 2], requires_grad=True) 3993 b = torch.ones([2, 2]) 3994 out = model(a) 3995 out_sum = out.sum() 3996 out_sum.backward() 3997 3998 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3999 res = opt_fn(a, b) 4000 4001 self.assertTrue(same(ref, res)) 4002 self.assertEqual(cnts.frame_count, 1) 4003 self.assertEqual(cnts.op_count, 3) 4004 4005 def test_intermediary_tensor_grad_access(self): 4006 # This test creates a model, and accesses the grads 4007 # from its parameters and an entirely intermediary tensor. 
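        # Unlike the previous test, the tensor whose .grad is read here is
        # created inside the compiled region, so it has no source at all (and
        # its .grad is simply None); Dynamo has to handle that access without
        # being able to guard on a graph input.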
4008 cnts = torch._dynamo.testing.CompileCounter() 4009 4010 def fn(a, b): 4011 intermediary = torch.ones(2, 2) 4012 c = a + intermediary 4013 outs = [] 4014 outs.append(intermediary.grad) 4015 return outs 4016 4017 # Eager 4018 a = torch.ones([2, 2], requires_grad=True) 4019 b = torch.ones([2, 2]) 4020 ref = fn(a, b) 4021 4022 # Compiled 4023 a = torch.ones([2, 2], requires_grad=True) 4024 b = torch.ones([2, 2]) 4025 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 4026 res = opt_fn(a, b) 4027 self.assertTrue(same(ref, res)) 4028 self.assertEqual(cnts.frame_count, 1) 4029 self.assertEqual(cnts.op_count, 2) 4030 4031 def test_clone_sparse_input(self): 4032 for layout in [ 4033 torch.sparse_coo, 4034 torch.sparse_csr, 4035 torch.sparse_csc, 4036 torch.sparse_bsr, 4037 torch.sparse_bsc, 4038 ]: 4039 for sparse_input in self.generate_simple_inputs( 4040 layout, 4041 device="cpu", 4042 dtype=torch.float64, 4043 index_dtype=torch.int64, 4044 ): 4045 # Invoke the dynamo clone input method directly. 4046 sparse_copy = torch._dynamo.utils.clone_input(sparse_input) 4047 # Make sure sparse clone is successful. 4048 self.assertEqual(sparse_input, sparse_copy) 4049 4050 def test_tensor_is_contiguous(self): 4051 def fn(x): 4052 input = torch.randn((1, 16, 1, 1)) 4053 weight = torch.randn((8, 16, 3, 3)) 4054 weight = weight.to(memory_format=x) 4055 output = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1) 4056 return output.is_contiguous(memory_format=x) 4057 4058 opt_fn = torch._dynamo.optimize("eager")(fn) 4059 for x in [torch.contiguous_format, torch.channels_last]: 4060 self.assertEqual(fn(x), opt_fn(x)) 4061 4062 def test_python_slice(self): 4063 def f1(input): 4064 y = 0 4065 for i, x in enumerate(input[2:], 1): 4066 y = y + x 4067 return y 4068 4069 def f2(input): 4070 y = 0 4071 for i, x in enumerate(input.shape[2:], 1): 4072 y = y + x 4073 return y 4074 4075 cnts = torch._dynamo.testing.CompileCounter() 4076 opt_f1 = torch._dynamo.optimize(cnts)(f1) 4077 opt_f2 = torch._dynamo.optimize(cnts)(f2) 4078 res1 = opt_f1([1, 2, 3, 5]) 4079 res2 = opt_f2(torch.rand([2, 3, 4, 5])) 4080 4081 self.assertEqual(res1, 8) 4082 self.assertEqual(res2, 9) 4083 4084 def test_enum_as_dict_key(self): 4085 class MyEnum(enum.Enum): 4086 FOO = 10 4087 BAR = 20 4088 4089 def fn(x): 4090 y = x + 2 4091 z = { 4092 MyEnum.FOO: torch.tensor(1), 4093 MyEnum.BAR: 10, 4094 "MyEnum.BAR": torch.tensor(8), 4095 5: torch.rand(3), 4096 } 4097 torch._dynamo.graph_break() 4098 a = z[MyEnum.FOO] + z["MyEnum.BAR"] 4099 b = y * 2 4100 return a, b 4101 4102 cnts = torch._dynamo.testing.CompileCounter() 4103 opt_fn = torch._dynamo.optimize(cnts)(fn) 4104 for _ in range(10): 4105 x = torch.rand(3) 4106 ref = fn(x) 4107 res = opt_fn(x) 4108 self.assertTrue(same(ref, res)) 4109 self.assertEqual(cnts.frame_count, 2) 4110 4111 def test_enum_as_dict_key_with_overloaded_str(self): 4112 class MyEnum(enum.Enum): 4113 FOO = 10 4114 BAR = 20 4115 4116 def __str__(self): 4117 return self.value 4118 4119 def fn(x): 4120 y = x + 2 4121 z = { 4122 MyEnum.FOO: torch.tensor(1), 4123 MyEnum.BAR: 10, 4124 "MyEnum.BAR": torch.tensor(8), 4125 5: torch.rand(3), 4126 } 4127 torch._dynamo.graph_break() 4128 a = z[MyEnum.FOO] + z["MyEnum.BAR"] 4129 b = y * 2 4130 return a, b 4131 4132 cnts = torch._dynamo.testing.CompileCounter() 4133 opt_fn = torch._dynamo.optimize(cnts)(fn) 4134 for _ in range(10): 4135 x = torch.rand(3) 4136 ref = fn(x) 4137 res = opt_fn(x) 4138 self.assertTrue(same(ref, res)) 4139 self.assertEqual(cnts.frame_count, 2) 
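
    # The two enum tests above lean on a basic Dynamo accounting rule: a
    # torch._dynamo.graph_break() inside a function splits tracing into two
    # compiled frames. A minimal sketch of that pattern (illustrative only,
    # not collected as a test; it reuses the CompileCounter imported at the
    # top of this file):
    def _sketch_graph_break_frame_accounting(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt)
        def fn(x):
            y = x + 1  # traced into the first graph
            torch._dynamo.graph_break()  # forces the split
            return y * 2  # traced into the second (resume) graph

        fn(torch.randn(3))
        # One explicit break -> two compiled frames.
        assert cnt.frame_count == 2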
    def test_const_dict_variable_python_type(self):
        from torch._dynamo.variables import ConstantVariable, ConstDictVariable

        make_key = ConstantVariable.create

        d1 = {
            make_key("a"): ConstantVariable.create(10),
            make_key("b"): ConstantVariable.create(20),
        }
        d2 = collections.OrderedDict(
            [
                (make_key("x"), ConstantVariable.create(12)),
                (make_key("y"), ConstantVariable.create(22)),
            ]
        )
        self.assertEqual(ConstDictVariable(d1).python_type(), dict)
        self.assertEqual(
            ConstDictVariable(d2, collections.OrderedDict).python_type(),
            collections.OrderedDict,
        )

    def test_builtin_subclasses_as_method_on_class_type(self):
        class Foo:
            def __init__(self, name):
                self.name_ = name

            def get_name(self):
                return "Foo " + self.name_

        class Bar(Foo):
            def __init__(self, name):
                self.name_ = name

            def get_name(self):
                return "Bar " + self.name_

        class Baz(Foo):
            def __init__(self, name):  # noqa: B903
                self.name_ = name

            def get_name(self):
                return "Baz " + self.name_

        subs_of_foo_reg = Foo.__subclasses__()

        counter = CompileCounter()

        @torch._dynamo.optimize_assert(counter)
        def fn():
            return Foo.__subclasses__()

        subs_of_foo_optim = fn()

        self.assertEqual(len(subs_of_foo_reg), 2)
        self.assertEqual(subs_of_foo_reg, subs_of_foo_optim)

    def test_builtin_subclasses_as_method_on_var(self):
        class Foo:
            def __init__(self, name):
                self.name_ = name

            def get_name(self):
                return "Foo " + self.name_

        class Bar(Foo):
            def __init__(self, name):
                self.name_ = name

            def get_name(self):
                return "Bar " + self.name_

        class Baz(Bar):
            def __init__(self, name):
                self.name_ = name

            def get_name(self):
                return "Baz " + self.name_

        subs_of_foo_reg = Foo.__subclasses__()
        sub_of_foo_subclass_var_reg = subs_of_foo_reg[0].__subclasses__()

        sub_of_foo_subclass_var_optim = list()
        counter = CompileCounter()

        @torch._dynamo.optimize_assert(counter)
        def fn():
            return Foo.__subclasses__()

        @torch._dynamo.optimize_assert(counter)
        def fn_single(subs_of_foo_optim):
            return subs_of_foo_optim[0].__subclasses__()

        subs_of_foo_optim = fn()
        sub_of_foo_subclass_var_optim = fn_single(subs_of_foo_optim)

        self.assertEqual(len(sub_of_foo_subclass_var_optim), 1)
        self.assertEqual(sub_of_foo_subclass_var_optim, sub_of_foo_subclass_var_reg)

    def test_builtin_str_on_user_defined_function(self):
        def another_fn():
            pass

        def fn():
            return "another_fn" in str(another_fn)

        opt_fn = torch._dynamo.optimize(nopython=True)(fn)
        self.assertTrue(opt_fn())

    def test_enum_no_graphbreaks(self):
        class Foo(enum.Enum):
            FOO = 0
            BAR = 1

        def fn(x, foo):
            if foo is Foo.FOO:
                x = torch.add(x, 1.0)
            x = torch.mul(x, 1.0)
            return x

        x = torch.randn(1)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        opt_fn(x, Foo.FOO)
        self.assertEqual(cnts.op_count, 2)

        torch._dynamo.reset()
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        opt_fn(x, Foo.BAR)
        self.assertEqual(cnts.op_count, 1)

    def
test_repeat_interleave_graphbreaks(self): 4273 def fn_no_breaks(x): 4274 # no breaks on self_int 4275 x += 1 4276 x = torch.repeat_interleave(x, 2, 3) 4277 x += 1 4278 return x 4279 4280 def fn_has_breaks(x): 4281 # breaks on self_Tensor 4282 x += 1 4283 x = torch.repeat_interleave(x, torch.tensor(2), 3) 4284 x += 1 4285 return x 4286 4287 x = torch.randn([4, 16, 1, 64]) 4288 4289 cnts = torch._dynamo.testing.CompileCounter() 4290 opt_fn = torch._dynamo.optimize(cnts)(fn_no_breaks) 4291 opt_fn(x) 4292 self.assertEqual(cnts.frame_count, 1) 4293 4294 torch._dynamo.reset() 4295 cnts = torch._dynamo.testing.CompileCounter() 4296 opt_fn = torch._dynamo.optimize(cnts)(fn_has_breaks) 4297 opt_fn(x) 4298 self.assertEqual(cnts.frame_count, 2) 4299 4300 def test_id_guarded_object(self): 4301 class UDO: 4302 @torch.compile(backend="eager") 4303 def call(self, x, ref_id): 4304 self_id = id(self) 4305 if self_id == ref_id: 4306 x = torch.mul(x, 1.0) 4307 else: 4308 x = torch.mul(x, 0) 4309 return x 4310 4311 # Make sure we do recompile when id(self) is executed on 4312 # different self objects. 4313 x = torch.ones(2) 4314 obj1 = UDO() 4315 obj1_id = id(obj1) 4316 self.assertEqual(obj1.call(x, obj1_id), torch.ones(2)) 4317 4318 obj2 = UDO() 4319 # if we do not install ID_MATCH: ___check_obj_id(L['self'], xxx) this fails. 4320 self.assertEqual(obj2.call(x, obj1_id), torch.zeros(2)) 4321 4322 def test_id_guarded_module(self): 4323 class M(torch.nn.Module): 4324 def forward(self, x, ref_id): 4325 self_id = id(self) 4326 if self_id == ref_id: 4327 x = torch.mul(x, 1.0) 4328 else: 4329 x = torch.mul(x, 0) 4330 return x 4331 4332 cnts = torch._dynamo.testing.CompileCounter() 4333 4334 # Make sure we do recompile when id(self) is executed on 4335 # different self objects. 4336 x = torch.ones(2) 4337 m1 = M() 4338 m1_id = id(m1) 4339 opt_m1 = torch._dynamo.optimize(cnts, nopython=True)(m1) 4340 self.assertEqual(opt_m1(x, m1_id), torch.ones(2)) 4341 self.assertEqual(opt_m1(x, m1_id), torch.ones(2)) 4342 4343 self.assertEqual(cnts.frame_count, 1) 4344 self.assertEqual(cnts.op_count, 1) 4345 4346 m2 = M() 4347 opt_m2 = torch._dynamo.optimize(cnts, nopython=True)(m2) 4348 # if we do not install ID_MATCH: ___check_obj_id(L['self'], xxx) this fails. 4349 self.assertEqual(opt_m2(x, m1_id), torch.zeros(2)) 4350 self.assertEqual(cnts.frame_count, 2) 4351 self.assertEqual(cnts.op_count, 2) 4352 4353 def test_id_of_nn_module(self): 4354 class M(torch.nn.Module): 4355 def forward(self, x, ref_id): 4356 self_id = id(self) 4357 if self_id == ref_id: 4358 x = torch.mul(x, 1.0) 4359 x = torch.add(x, 1.0) 4360 return x 4361 4362 m = M().eval() 4363 data = torch.randn(1) 4364 cnts = torch._dynamo.testing.CompileCounter() 4365 correct_ref_id = id(m) 4366 opt_m = torch._dynamo.optimize(cnts, nopython=True)(m) 4367 opt_m(data, correct_ref_id) 4368 # Extra op is the recorded equality test (although once 4369 # the trace is flattened this is dead!) 
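        # Both branches of the check below currently expect the same count; the
        # static/dynamic split is presumably kept so the expected values can
        # diverge if dynamic-shape handling changes.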
4370 if torch._dynamo.config.assume_static_by_default: 4371 self.assertExpectedInline(cnts.op_count, """2""") 4372 else: 4373 self.assertExpectedInline(cnts.op_count, """2""") 4374 4375 torch._dynamo.reset() 4376 cnts = torch._dynamo.testing.CompileCounter() 4377 incorrect_ref_id = id(m) + 1 4378 opt_m = torch._dynamo.optimize(cnts, nopython=True)(m) 4379 opt_m(data, incorrect_ref_id) 4380 if torch._dynamo.config.assume_static_by_default: 4381 self.assertExpectedInline(cnts.op_count, """1""") 4382 else: 4383 self.assertExpectedInline(cnts.op_count, """1""") 4384 4385 def test_inline_func_jump_on_tensor_condition(self): 4386 def f1(input): 4387 if input == 0: 4388 return input + 1 4389 else: 4390 return input + 2 4391 4392 def f2(input): 4393 return f1(input) 4394 4395 cnts = torch._dynamo.testing.CompileCounter() 4396 opt_f2 = torch._dynamo.optimize(cnts)(f2) 4397 res1 = opt_f2(torch.tensor([1.0])) 4398 res2 = opt_f2(torch.tensor([0.0])) 4399 4400 self.assertEqual(res1, 3) 4401 self.assertEqual(res2, 1) 4402 4403 def test_frozenset_torch_func_contains(self): 4404 funcs = frozenset([torch.add]) 4405 4406 def fn(x, func): 4407 if func in funcs: 4408 x = torch.add(x, 1.0) 4409 x = torch.mul(x, 1.0) 4410 return x 4411 4412 x = torch.randn(1) 4413 cnts = torch._dynamo.testing.CompileCounter() 4414 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 4415 opt_fn(x, torch.add) 4416 self.assertEqual(cnts.op_count, 2) 4417 4418 torch._dynamo.reset() 4419 cnts = torch._dynamo.testing.CompileCounter() 4420 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 4421 opt_fn(x, torch.mul) 4422 self.assertEqual(cnts.op_count, 1) 4423 4424 def test_inline_list_mutation(self): 4425 def f1(x): 4426 x.append(torch.ones(8)) 4427 return x 4428 4429 def f2(): 4430 x = [torch.ones(6)] 4431 f1(x) 4432 return x 4433 4434 res1 = f2() 4435 cnts = torch._dynamo.testing.CompileCounter() 4436 opt_f2 = torch._dynamo.optimize(cnts)(f2) 4437 res2 = opt_f2() 4438 self.assertTrue(same(res1, res2)) 4439 4440 def test_inline_dict_mutation(self): 4441 def f1(d): 4442 d["c"] = d["a"] + d.pop("b") 4443 return d 4444 4445 def f2(): 4446 d = {"a": torch.ones(5), "b": torch.ones(5)} 4447 f1(d) 4448 return d 4449 4450 res1 = f2() 4451 cnts = torch._dynamo.testing.CompileCounter() 4452 opt_f2 = torch._dynamo.optimize(cnts)(f2) 4453 res2 = opt_f2() 4454 self.assertTrue(same(res1, res2)) 4455 4456 def test_inline_local_dict_clear(self): 4457 def f(d): 4458 d.clear() 4459 return d 4460 4461 inp = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)} 4462 out = torch.compile(f, backend="eager", fullgraph=True)(inp) 4463 self.assertEqual(len(out), 0) 4464 self.assertEqual(len(inp), 0) 4465 4466 def test_inline_module_attr_dict_clear(self): 4467 class MyMod(torch.nn.Module): 4468 def __init__(self): 4469 super().__init__() 4470 self.a = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)} 4471 4472 def forward(self): 4473 self.a.clear() 4474 return self.a 4475 4476 m = MyMod() 4477 out = torch.compile(m, backend="eager", fullgraph=True)() 4478 self.assertEqual(len(out), 0) 4479 self.assertEqual(len(m.a), 0) 4480 4481 def test_inline_user_defined_dict_attr_clear(self): 4482 class MyMod: 4483 def __init__(self): 4484 self.a = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)} 4485 4486 def f(obj, inp): 4487 ret = len(obj.a) + inp 4488 obj.a.clear() 4489 return obj.a, ret 4490 4491 m = MyMod() 4492 before_len = len(m.a) 4493 t_inp = torch.ones(1) 4494 d, ret = torch.compile(f, backend="eager", fullgraph=True)(m, t_inp) 4495 
self.assertEqual(len(m.a), 0) 4496 self.assertEqual(len(d), 0) 4497 self.assertEqual(ret, t_inp + before_len) 4498 4499 def test_recursive_inline_list_mutation(self): 4500 def f1(x, y): 4501 x.append(torch.tensor([1.1])) 4502 y.append(torch.tensor([1.2])) 4503 return x, y 4504 4505 def f2(x, y): 4506 x.append(torch.tensor([2.1])) 4507 y.append(torch.tensor([2.2])) 4508 f1(x, y) 4509 return x, y 4510 4511 def f3(x): 4512 x.append(torch.tensor([3.1])) 4513 y = [torch.tensor([3.2])] 4514 f2(x, y) 4515 return x, y 4516 4517 def f4(): 4518 x = [torch.tensor([4.1])] 4519 return f3(x) 4520 4521 res1 = f4() 4522 cnts = torch._dynamo.testing.CompileCounter() 4523 opt_f4 = torch._dynamo.optimize(cnts)(f4) 4524 res2 = opt_f4() 4525 self.assertTrue(same(res1, res2)) 4526 4527 def test_sample_input(self): 4528 from torch.testing._internal.common_methods_invocations import SampleInput 4529 4530 def fn(sample): 4531 if isinstance(sample.input, torch.Tensor): 4532 return sample.input * 2 4533 return torch.zeros(()) 4534 4535 sample = SampleInput(torch.ones(2)) 4536 ref = fn(sample) 4537 4538 opt_fn = torch._dynamo.optimize("eager")(fn) 4539 res = opt_fn(sample) 4540 4541 self.assertTrue(same(ref, res)) 4542 4543 def test_release_input_memory(self): 4544 x = torch.rand([4]) 4545 x_ref = weakref.ref(x) 4546 4547 cnts = torch._dynamo.testing.CompileCounter() 4548 4549 @torch._dynamo.optimize(cnts) 4550 def foo(x): 4551 return x + x 4552 4553 out = foo(x) 4554 self.assertTrue(same(out, x + x)) 4555 del x 4556 self.assertIs(x_ref(), None) 4557 4558 def test_release_module_memory(self): 4559 mod = torch.nn.Linear(10, 10) 4560 x = torch.rand([10, 10]) 4561 mod_weight_ref = weakref.ref(mod.weight) 4562 mod_ref = weakref.ref(mod) 4563 4564 # Modules that are passed into torch._dynamo optimized functions 4565 # will normally be held onto through the generated GraphModule, 4566 # which contains the modules. remove the reference in this backend 4567 # and test that no additional references are being held. 
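        # A Dynamo backend is just a callable taking (gm, example_inputs) and
        # returning a callable to run in place of the graph.  The backend below
        # drops the GraphModule's reference to the wrapped module and returns a
        # trivial function, so nothing should keep `mod` alive once the local
        # names are deleted.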
4568 class NoLeakBackend: 4569 def __call__(self, gm: torch.fx.GraphModule, example_inputs): 4570 gm.mod = None 4571 4572 def foo(*args, **kwargs): 4573 return (1,) 4574 4575 return foo 4576 4577 no_leak_backend = NoLeakBackend() 4578 4579 @torch._dynamo.optimize(no_leak_backend) 4580 def foo(mod, x): 4581 return mod(x) 4582 4583 foo(mod, x) 4584 del mod 4585 del x 4586 self.assertIsNone(mod_ref(), None) 4587 self.assertIsNone(mod_weight_ref(), None) 4588 4589 def test_release_scope_memory(self): 4590 def inner(y): 4591 y 4592 4593 inner = torch._dynamo.optimize("eager")(inner) 4594 4595 p_ref = None 4596 4597 x = torch.randn((10, 10)) 4598 inner(x) 4599 4600 p_ref = weakref.ref(x) 4601 self.assertTrue(p_ref() is not None) 4602 del x 4603 self.assertTrue(p_ref() is None) 4604 4605 def test_update_locals_and_stack_uses_shared_cache(self): 4606 def fn(x): 4607 perm = [0, 3, 5] 4608 perm = list(range(min(perm))) + perm 4609 perm.extend(i for i in range(x.dim()) if i not in perm) 4610 return perm 4611 4612 x = torch.rand([2, 2, 2, 2, 2, 2]) 4613 res1 = fn(x) 4614 cnts = torch._dynamo.testing.CompileCounter() 4615 opt_fn = torch._dynamo.optimize(cnts)(fn) 4616 res2 = opt_fn(x) 4617 self.assertTrue(same(res1, res2)) 4618 4619 def test_dict_reconstruct_keeps_original_order(self): 4620 def fn(): 4621 modules = collections.OrderedDict([("act", torch.nn.ReLU())]) 4622 module_dict = torch.nn.ModuleDict(modules) 4623 4624 next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()} 4625 modules.update(next_modules.items()) 4626 module_dict.update(next_modules) 4627 return modules, module_dict 4628 4629 cnts = torch._dynamo.testing.CompileCounter() 4630 opt_fn = torch._dynamo.optimize(cnts)(fn) 4631 modules, module_dict = opt_fn() 4632 4633 self.assertEqual(len(module_dict), len(modules)) 4634 for k1, m2 in zip(modules, module_dict.children()): 4635 self.assertTrue(modules[k1] is m2) 4636 4637 def test_side_effects_codegen_update_mutated(self): 4638 # codegen to update mutated variables with side effect 4639 # should after stack value's codegen 4640 def f1(x): 4641 alist = [x] 4642 alist.append(x + 1) 4643 alist[0].sum().item() # graph break 4644 res = alist.pop() 4645 res.sum().item() # graph break 4646 return res 4647 4648 def f2(a, b): 4649 d = {"a": a + 1, "b": b + 2} 4650 x = d.pop("b") 4651 x.sum().item() # graph break 4652 y = d["a"] + x 4653 y.sum().item() # graph break 4654 d["c"] = y 4655 return d 4656 4657 x = torch.rand([2, 3]) 4658 a = torch.rand([5, 6]) 4659 b = torch.rand([5, 6]) 4660 res11 = f1(x) 4661 res21 = f2(a, b) 4662 cnts = torch._dynamo.testing.CompileCounter() 4663 opt_f1 = torch._dynamo.optimize(cnts)(f1) 4664 opt_f2 = torch._dynamo.optimize(cnts)(f2) 4665 res12 = opt_f1(x) 4666 res22 = opt_f2(a, b) 4667 self.assertTrue(same(res11, res12)) 4668 self.assertTrue(same(res21, res22)) 4669 4670 def test_list_append_return_none(self): 4671 def fn(x): 4672 alist = [] 4673 blist = alist.append(x + 1) 4674 return alist, blist 4675 4676 x = torch.tensor([2.3]) 4677 res = fn(x) 4678 cnts = torch._dynamo.testing.CompileCounter() 4679 opt_fn = torch._dynamo.optimize(cnts)(fn) 4680 res2 = opt_fn(x) 4681 self.assertEqual(res, res2) 4682 4683 @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) 4684 def test_tensor_ctor_list_of_tensor(self): 4685 def fn(x): 4686 return torch.tensor([x], dtype=torch.int64) 4687 4688 x = torch.tensor(20) 4689 res = fn(x) 4690 cnts = torch._dynamo.testing.CompileCounter() 4691 opt_fn = torch._dynamo.optimize(cnts)(fn) 4692 res2 = 
opt_fn(x) 4693 self.assertEqual(res, res2) 4694 self.assertEqual(cnts.frame_count, 1) 4695 4696 def test_tensor_types(self): 4697 def fn(dtype, tensor_type): 4698 x = torch.empty(4, dtype=dtype) 4699 assert isinstance(x, tensor_type) 4700 4701 opt_fn = torch._dynamo.optimize("eager")(fn) 4702 opt_fn(torch.float32, torch.FloatTensor) 4703 opt_fn(torch.float64, torch.DoubleTensor) 4704 opt_fn(torch.float16, torch.HalfTensor) 4705 opt_fn(torch.bfloat16, torch.BFloat16Tensor) 4706 opt_fn(torch.uint8, torch.ByteTensor) 4707 opt_fn(torch.int8, torch.CharTensor) 4708 opt_fn(torch.int64, torch.LongTensor) 4709 opt_fn(torch.int, torch.IntTensor) 4710 opt_fn(torch.int16, torch.ShortTensor) 4711 opt_fn(torch.bool, torch.BoolTensor) 4712 4713 def test_nan(self): 4714 def f(x, n): 4715 return x * 2 + n 4716 4717 x = torch.randn(4) 4718 n = float("nan") 4719 4720 cnts = torch._dynamo.testing.CompileCounter() 4721 opt_f = torch._dynamo.optimize(cnts)(f) 4722 opt_f(x, n) 4723 opt_f(x, n) 4724 self.assertEqual(cnts.frame_count, 1) 4725 4726 @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) 4727 def test_item(self): 4728 class MyMod(torch.nn.Module): 4729 def forward(self, x): 4730 z = torch.max(x) 4731 return z.int().item() 4732 4733 x = torch.tensor([[10.6763, 11.7445, -2.2369]]) 4734 model = MyMod() 4735 y = torch._dynamo.optimize("eager", nopython=True)(model)(x) 4736 4737 self.assertEqual(y, 11) 4738 4739 @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) 4740 def test_item_changes(self): 4741 class MyMod(torch.nn.Module): 4742 def forward(self, x): 4743 z = torch.max(x) 4744 return z.int().item() 4745 4746 x = torch.tensor([[10.6763, 11.7445, -2.2369]]) 4747 model = MyMod() 4748 opt_model = torch._dynamo.optimize("eager", nopython=True)(model) 4749 y = opt_model(x) 4750 z = opt_model(torch.tensor([[y - 5, y + 10, y + 50]])) 4751 4752 self.assertEqual(y, 11) 4753 self.assertEqual(z, 61) 4754 4755 @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) 4756 def test_item_changes_new_shape(self): 4757 class MyMod(torch.nn.Module): 4758 def forward(self, x): 4759 z = torch.max(x) 4760 return z.int().item() 4761 4762 x = torch.tensor([[10.6763, 11.7445, -2.2369]]) 4763 model = MyMod() 4764 opt_model = torch._dynamo.optimize("eager", nopython=True)(model) 4765 y = opt_model(x) 4766 z = opt_model(torch.tensor([[y - 5, y + 50], [y + 5, y - 50]])) 4767 4768 self.assertEqual(y, 11) 4769 self.assertEqual(z, 61) 4770 4771 @unittest.skip("https://github.com/pytorch/pytorch/issues/99726") 4772 def test_cross_entropy_loss_fancy_ctor1(self): 4773 rand_5 = torch.randn(5) 4774 rand_3_5 = torch.randn(3, 5) 4775 target = torch.empty(3, dtype=torch.long).random_(5) 4776 4777 loss = torch.nn.CrossEntropyLoss( 4778 weight=rand_5, reduce=False, label_smoothing=0.5 4779 ) 4780 opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss) 4781 input = rand_3_5 4782 dynamo_output = opt_loss(input, target) 4783 4784 loss = torch.nn.CrossEntropyLoss( 4785 weight=rand_5, reduce=False, label_smoothing=0.5 4786 ) 4787 input = rand_3_5 4788 output = loss(input, target) 4789 4790 self.assertTrue(torch.allclose(dynamo_output, output)) 4791 4792 def test_cross_entropy_loss_fancy_ctor2(self): 4793 rand_3_5 = torch.randn(3, 5) 4794 target = torch.empty(3, dtype=torch.long).random_(5) 4795 4796 loss = torch.nn.CrossEntropyLoss(reduce=False, label_smoothing=0.5) 4797 opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss) 4798 input = rand_3_5 4799 dynamo_output = opt_loss(input, 
target) 4800 4801 loss = torch.nn.CrossEntropyLoss(reduce=False, label_smoothing=0.5) 4802 input = rand_3_5 4803 output = loss(input, target) 4804 4805 self.assertTrue(torch.allclose(dynamo_output, output)) 4806 4807 def test_cross_entropy_loss_simple_ctor(self): 4808 output = None 4809 rand_3_5 = torch.randn(3, 5) 4810 target = torch.empty(3, dtype=torch.long).random_(5) 4811 4812 loss = torch.nn.CrossEntropyLoss() 4813 opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss) 4814 input = rand_3_5 4815 dynamo_output = opt_loss(input, target) 4816 4817 loss = torch.nn.CrossEntropyLoss() 4818 input = rand_3_5 4819 output = loss(input, target) 4820 4821 self.assertTrue(torch.allclose(dynamo_output, output)) 4822 4823 def test_nn_functional_reduction(self): 4824 def fn(loss, reduction): 4825 reduction_enum = F._Reduction.get_enum(reduction) 4826 if reduction_enum == 0: 4827 return loss 4828 elif reduction_enum == 1: 4829 return loss.mean() 4830 elif reduction_enum == 2: 4831 return loss.sum() 4832 4833 x = torch.rand([3, 5]) 4834 y = "mean" 4835 ref = fn(x, y) 4836 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 4837 res = opt_fn(x, y) 4838 self.assertTrue(torch.allclose(ref, res)) 4839 4840 def test_large_reduction_list(self): 4841 dtype = torch.float32 4842 device = "cpu" 4843 4844 def check_sum_all(tensor: torch.Tensor) -> None: 4845 pylist = tensor.reshape(-1).tolist() 4846 self.assertTrue(same(tensor.sum(), torch.tensor(sum(pylist)))) 4847 4848 check_sum_all(torch.randn(200000, dtype=dtype, device=device)) 4849 4850 def test_raise_on_backend_error(self): 4851 def my_compiler(gm, _): 4852 raise RuntimeError("duck!") 4853 4854 @torch._dynamo.optimize(my_compiler) 4855 def fn(a, b): 4856 return a + b / (a - b) 4857 4858 self.assertRaises( 4859 torch._dynamo.exc.BackendCompilerFailed, 4860 lambda: fn(torch.randn(10), torch.randn(10)), 4861 ) 4862 4863 def test_named_parameters(self): 4864 n_embd = 768 4865 block_size = 128 4866 vocab_size = 65 4867 embd_pdrop = 0.1 4868 4869 class MyModel2(torch.nn.Module): 4870 def __init__(self): 4871 super().__init__() 4872 self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) 4873 self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) 4874 self.drop = torch.nn.Dropout(embd_pdrop) 4875 4876 def forward(self, x): 4877 return x 4878 4879 class MyModel(torch.nn.Module): 4880 def __init__(self): 4881 super().__init__() 4882 self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) 4883 self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) 4884 self.drop = torch.nn.Dropout(embd_pdrop) 4885 self.submod2 = MyModel2() 4886 4887 def forward(self, x): 4888 return x 4889 4890 # Regular 4891 params = [] 4892 mod = MyModel() 4893 actual_params = list(mod.named_parameters()) 4894 4895 @torch._dynamo.optimize("eager", nopython=True) 4896 def fn(): 4897 return list(mod.named_parameters()) 4898 4899 params = fn() 4900 4901 self.assertEqual(len(actual_params), len(params)) 4902 for idx in range(len(params)): 4903 k_a, v_a = actual_params[idx] 4904 k, v = params[idx] 4905 self.assertEqual(k_a, k) 4906 self.assertTrue(torch.allclose(v_a, v)) 4907 4908 # Prefix 4909 params = [] 4910 mod = MyModel() 4911 actual_params = list(mod.named_parameters(prefix="foo")) 4912 4913 @torch._dynamo.optimize("eager", nopython=True) 4914 def fn1(): 4915 return list(mod.named_parameters(prefix="foo")) 4916 4917 params = fn1() 4918 4919 self.assertEqual(len(actual_params), len(params)) 4920 for idx in range(len(params)): 4921 k_a, v_a = 
actual_params[idx] 4922 k, v = params[idx] 4923 self.assertEqual(k_a, k) 4924 self.assertTrue(torch.allclose(v_a, v)) 4925 4926 @torch._dynamo.config.patch(guard_nn_modules=True) 4927 def test_module_complex_iter(self): 4928 n_embd = 768 4929 block_size = 128 4930 vocab_size = 65 4931 embd_pdrop = 0.1 4932 4933 class FakeGPT(torch.nn.Module): 4934 def __init__(self): 4935 super().__init__() 4936 self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) 4937 self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) 4938 self.drop = torch.nn.Dropout(embd_pdrop) 4939 self.ln_f = torch.nn.LayerNorm(n_embd) 4940 self.head = torch.nn.Linear(n_embd, vocab_size, bias=False) 4941 4942 self.block_size = block_size 4943 self.names = [] 4944 4945 def forward(self, idx, targets=None): 4946 b, t = idx.size() 4947 assert ( 4948 t <= self.block_size 4949 ), "Cannot forward, model block size is exhausted." 4950 4951 # forward the GPT model 4952 token_embeddings = self.tok_emb( 4953 idx 4954 ) # each index maps to a (learnable) vector 4955 position_embeddings = self.pos_emb[ 4956 :, :t, : 4957 ] # each position maps to a (learnable) vector 4958 x = self.drop(token_embeddings + position_embeddings) 4959 x = self.blocks(x) 4960 x = self.ln_f(x) 4961 logits = self.head(x) 4962 4963 # if we are given some desired targets also calculate the loss 4964 loss = None 4965 if targets is not None: 4966 loss = F.cross_entropy( 4967 logits.view(-1, logits.size(-1)), targets.view(-1) 4968 ) 4969 4970 return logits, loss 4971 4972 def foo(self, memo=None, prefix="", remove_duplicate=False): 4973 for mn, m in self.named_modules( 4974 memo=memo, prefix=prefix, remove_duplicate=remove_duplicate 4975 ): 4976 for pn, p in self.named_parameters(): 4977 fpn = f"{mn}.{pn}" if mn else pn 4978 self.names.append(fpn) 4979 4980 # Test plain recurse 4981 model_a = FakeGPT() 4982 model_a.foo() 4983 a_names = model_a.names 4984 4985 model_b = FakeGPT() 4986 opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b) 4987 opt_model_b.foo() 4988 4989 self.assertEqual(a_names, model_b.names) 4990 4991 # Test with prefix 4992 model_a = FakeGPT() 4993 model_a.foo(prefix="abc") 4994 a_names = model_a.names 4995 4996 model_b = FakeGPT() 4997 opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b) 4998 opt_model_b.foo(prefix="abc") 4999 5000 self.assertEqual(a_names, model_b.names) 5001 5002 def test_numpy_variable_isinstance(self): 5003 def fn(x, m): 5004 if isinstance(m, np.ndarray): 5005 return x + 1 5006 else: 5007 return x - 1 5008 5009 x = torch.tensor([2.3]) 5010 m = np.array([1, 2, 3]) 5011 ref = fn(x, m) 5012 cnts = torch._dynamo.testing.CompileCounter() 5013 opt_fn = torch._dynamo.optimize(cnts)(fn) 5014 res = opt_fn(x, m) 5015 self.assertEqual(ref, res) 5016 5017 # Test now the other path 5018 ref = fn(x, x) 5019 res = opt_fn(x, x) 5020 self.assertEqual(ref, res) 5021 5022 def test_tensor_dot_grad_no_graph_break(self): 5023 def fn(a, b): 5024 y = 3 * a**3 - b**2 5025 y.backward(gradient=torch.tensor([1.0, 1.0])) 5026 b.grad.zero_() 5027 return a.grad, b.grad 5028 5029 a = torch.tensor([2.0, 3.0], requires_grad=True) 5030 b = torch.tensor([6.0, 4.0], requires_grad=True) 5031 cnts = torch._dynamo.testing.CompileCounter() 5032 opt_fn = torch._dynamo.optimize(cnts)(fn) 5033 _, b_grad = opt_fn(a, b) 5034 self.assertTrue(same(b_grad, torch.tensor([0.0, 0.0]))) 5035 self.assertEqual(cnts.frame_count, 2) 5036 5037 def test_torch_nn_parameter_isinstance(self): 5038 def fn(x): 5039 a = 
torch.nn.Parameter(torch.rand(2, 3)) 5040 if isinstance(a, torch.Tensor): 5041 return x + 1 5042 else: 5043 return x - 1 5044 5045 x = torch.tensor([2.5]) 5046 ref = fn(x) 5047 opt_fn = torch._dynamo.optimize("eager")(fn) 5048 res = opt_fn(x) 5049 self.assertEqual(ref, res) 5050 5051 def _optimize_then_check_exp( 5052 self, foo, args, cnt, exp_out, exp_frame_count, exp_n_cached_backend 5053 ): 5054 opt_out = torch._dynamo.optimize(backend=cnt)(foo)(*args) 5055 self.assertEqual(exp_out, opt_out) 5056 self.assertEqual(cnt.frame_count, exp_frame_count) 5057 5058 def test_backend_match_guard(self): 5059 x = torch.randn([3, 4]) 5060 5061 def foo(x): 5062 return x.sin() + x.cos() 5063 5064 def foo_graph_break(x): 5065 a = x.sin() 5066 torch._dynamo.graph_break() 5067 b = x.cos() 5068 return a + b 5069 5070 eager_record_backend = torch._dynamo.testing.EagerAndRecordGraphs() 5071 backends = [eager_record_backend, "eager"] 5072 5073 # We intentionally don't reset dynamo for each backend so that we can test 5074 # 1. dynamo doesn't recompile when backend stays the same, i.e. frame_count doesn't increase 5075 # 2. dynamo recompiles when backend changes, i.e. frame_count is non-zero for next backend 5076 def test_recompile(foo, *, exp_frame_count): 5077 eager_result = foo(x) 5078 for i, backend in enumerate(backends): 5079 cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) 5080 # Run opt_f multiple times to make sure dynamo doesn't recompile. 5081 # Specifically, frame_count doesn't increase 5082 # the number of cached backends is i + 2 because we have the optimizing backend + None 5083 self._optimize_then_check_exp( 5084 foo, (x,), cnt, eager_result, exp_frame_count, i + 2 5085 ) 5086 self._optimize_then_check_exp( 5087 foo, (x,), cnt, eager_result, exp_frame_count, i + 2 5088 ) 5089 self._optimize_then_check_exp( 5090 foo, (x,), cnt, eager_result, exp_frame_count, i + 2 5091 ) 5092 5093 test_recompile(foo, exp_frame_count=1) 5094 torch._dynamo.reset() 5095 test_recompile(foo_graph_break, exp_frame_count=2) 5096 5097 def test_backend_match_guard_multi_threads(self): 5098 x = torch.randn([3, 4]) 5099 5100 def foo(x): 5101 return x.sin() + x.cos() 5102 5103 def compile_then_check_exp(foo, args, cnt, eager_result, exp_frame_count): 5104 for i in range(3): 5105 opt_out = torch._dynamo.optimize(backend=cnt)(foo)(*args) 5106 self.assertEqual(opt_out, eager_result) 5107 self.assertEqual(cnt.frame_count, exp_frame_count) 5108 thread_success[threading.current_thread()] = True 5109 5110 eager_record_backend = torch._dynamo.testing.EagerAndRecordGraphs() 5111 backends = [eager_record_backend, "eager"] 5112 5113 # Test dynamo recompiles but only caches a single backend for each thread 5114 eager_result = foo(x) 5115 # cnt and None 5116 exp_frame_count = 1 5117 threads = [] 5118 thread_success = {} 5119 for i, backend in enumerate(backends): 5120 cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) 5121 thread = threading.Thread( 5122 target=compile_then_check_exp, 5123 args=( 5124 foo, 5125 (x,), 5126 cnt, 5127 eager_result, 5128 exp_frame_count, 5129 ), 5130 ) 5131 threads.append(thread) 5132 thread.start() 5133 5134 # Wait for all threads to finish 5135 for thread in threads: 5136 thread.join() 5137 5138 self.assertEqual(len(thread_success), len(threads)) 5139 5140 def test_dynamo_min_operator_with_shape(self): 5141 @torch._dynamo.optimize("eager", nopython=True) 5142 def f(x, a): 5143 return min(x.shape[0], a) 5144 5145 result = f(torch.ones(6), 3) 5146 self.assertEqual(result, 3) 5147 
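    # `min` over a symbolic size and a Python int either constant-folds (with
    # static shapes, min(6, 3) == 3) or installs a guard on the comparison when
    # the dimension is dynamic.  A sketch of the same idea with `max`
    # (illustrative only, not part of the original suite):
    #
    #   @torch.compile(backend="eager", fullgraph=True)
    #   def g(x):
    #       return max(x.shape[0], 4)
    #
    #   assert g(torch.ones(6)) == 6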
5148 def test_onnx_shape_as_tensor(self): 5149 @torch._dynamo.optimize("eager", nopython=True) 5150 def f(x): 5151 return 1 + torch._shape_as_tensor(x)[0] 5152 5153 gm, _ = torch._dynamo.export(f)(torch.ones(6)) 5154 5155 input_one_dim = torch.ones(6) 5156 input_two_dims = torch.ones(7, 4) 5157 self.assertEqual(f(input_one_dim), 7) 5158 self.assertEqual(f(input_two_dims), 8) 5159 self.assertEqual(f(input_two_dims), 8) 5160 5161 @torch._dynamo.optimize("eager", nopython=True) 5162 def f_onnx(x): 5163 return 1 + torch.onnx.operators.shape_as_tensor(x)[0] 5164 5165 self.assertEqual(f_onnx(input_one_dim), 7) 5166 self.assertEqual(f_onnx(input_two_dims), 8) 5167 self.assertEqual(f_onnx(input_two_dims), 8) 5168 5169 def test_cond(self): 5170 from functorch.experimental.control_flow import cond 5171 5172 def true_fn(x): 5173 return x.sin() 5174 5175 def false_fn(x): 5176 return x.cos() 5177 5178 def f(pred, x): 5179 return cond(pred, true_fn, false_fn, [x]) 5180 5181 opt_fn = torch._dynamo.optimize("eager")(f) 5182 a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25])) 5183 self.assertTrue(same(torch.cos(torch.tensor([0.25, 0.25])), a)) 5184 b = opt_fn(torch.tensor(True), torch.tensor([0.25, 0.25])) 5185 self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), b)) 5186 5187 def test_nonzero_static(self): 5188 # invalid size 5189 with self.assertRaisesRegex( 5190 RuntimeError, "nonzero_static: 'size' must be an non-negative integer" 5191 ): 5192 torch.nonzero_static(torch.tensor([8]), size=-2) 5193 5194 with self.assertRaisesRegex( 5195 RuntimeError, "nonzero_static: 'size' must be an non-negative integer" 5196 ): 5197 torch.nonzero_static(torch.tensor([8]), size=-2, out=torch.tensor(0)) 5198 5199 # nonzero_static.out: out dtype mismatch 5200 input_tensor = torch.tensor([8]) 5201 static_size = 1 5202 out_tensor = torch.empty((static_size, input_tensor.dim()), dtype=torch.float) 5203 with self.assertRaisesRegex( 5204 RuntimeError, "nonzero_static: Expected out tensor to have scalar type Long" 5205 ): 5206 torch.nonzero_static(input_tensor, size=static_size, out=out_tensor) 5207 5208 # nonzero_static.out: out resize (shrink) 5209 input_tensor = torch.tensor([8]) 5210 static_size = 1 5211 out_tensor = torch.empty((10, 10, 10, 10), dtype=torch.long) 5212 self.assertTrue( 5213 same( 5214 torch.nonzero_static(input_tensor, size=static_size, out=out_tensor), 5215 torch.tensor([0]), 5216 ) 5217 ) 5218 self.assertTrue( 5219 same( 5220 out_tensor, 5221 torch.tensor([0]), 5222 ) 5223 ) 5224 5225 # nonzero_static.out: out resize (enlarge) 5226 input_tensor = torch.tensor([8]) 5227 static_size = 1 5228 out_tensor = torch.empty((0), dtype=torch.long) 5229 self.assertTrue( 5230 same( 5231 torch.nonzero_static(input_tensor, size=static_size, out=out_tensor), 5232 torch.tensor([0]), 5233 ) 5234 ) 5235 self.assertTrue( 5236 same( 5237 out_tensor, 5238 torch.tensor([0]), 5239 ) 5240 ) 5241 5242 # 0 rank 5243 input_tensor = torch.tensor(6) 5244 static_size = 2 5245 self.assertTrue( 5246 same( 5247 torch.nonzero_static(input_tensor, size=static_size), 5248 torch.empty((static_size, input_tensor.dim()), dtype=torch.long), 5249 ) 5250 ) 5251 5252 # 0 size 5253 input_tensor = torch.tensor([[[1]]]) 5254 static_size = 0 5255 self.assertTrue( 5256 same( 5257 torch.nonzero_static(input_tensor, size=static_size), 5258 torch.empty((static_size, input_tensor.dim()), dtype=torch.long), 5259 ) 5260 ) 5261 5262 # 1D input 5263 input_tensor = torch.tensor([0, 8]) 5264 static_size = 1 5265 self.assertTrue( 5266 same( 5267 
torch.nonzero_static(input_tensor, size=static_size), 5268 torch.tensor([1]), 5269 ) 5270 ) 5271 5272 input_tensor = torch.tensor([8, 0]) 5273 static_size = 2 5274 self.assertTrue( 5275 same( 5276 torch.nonzero_static(input_tensor, size=static_size), 5277 torch.tensor([[0], [-1]]), # padded with default fill_value "-1" 5278 ) 5279 ) 5280 5281 # 2D input 5282 input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]]) 5283 static_size = 5 5284 fill_value = -100 5285 self.assertTrue( 5286 torch._dynamo.utils.same( 5287 torch.nonzero_static( 5288 input_tensor, size=static_size, fill_value=fill_value 5289 ), 5290 torch.tensor( 5291 [ 5292 [0, 0], 5293 [1, 0], 5294 [1, 1], 5295 [fill_value, fill_value], 5296 [fill_value, fill_value], 5297 ] 5298 ), 5299 ) 5300 ) 5301 input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]]) 5302 static_size = 2 5303 fill_value = -100 5304 self.assertTrue( 5305 torch._dynamo.utils.same( 5306 torch.nonzero_static( 5307 input_tensor, size=static_size, fill_value=fill_value 5308 ), 5309 torch.tensor([[0, 0], [1, 0]]), 5310 ) 5311 ) 5312 5313 # 3D input 5314 input_tensor = torch.tensor([[[0, 0], [0, -3]], [[0, 0], [5, 0]]]) 5315 static_size = 4 5316 fill_value = -999 5317 self.assertTrue( 5318 torch._dynamo.utils.same( 5319 torch.nonzero_static( 5320 input_tensor, 5321 size=static_size, 5322 fill_value=fill_value, 5323 ), 5324 torch.tensor( 5325 [ 5326 [0, 1, 1], 5327 [1, 1, 0], 5328 [fill_value, fill_value, fill_value], 5329 [fill_value, fill_value, fill_value], 5330 ] 5331 ), 5332 ) 5333 ) 5334 5335 def test_cond_with_quantization(self): 5336 from functorch.experimental.control_flow import cond 5337 5338 class MyModule(torch.nn.Module): 5339 def __init__(self): 5340 super().__init__() 5341 example_inputs = (torch.randn(5, 5),) 5342 self.model = torch.nn.Linear(5, 5) 5343 self.quantized_model = prepare_qat_fx( 5344 self.model, qconfig_dict, example_inputs=example_inputs 5345 ) 5346 5347 def forward(self, pred, x): 5348 def true_fn(x): 5349 return x.sin() + self.quantized_model(x) 5350 5351 def false_fn(x): 5352 return x.cos() + self.model(x) 5353 5354 return cond(pred, true_fn, false_fn, [x]) 5355 5356 module = MyModule() 5357 opt_m = torch._dynamo.optimize("eager", nopython=True)(module) 5358 x = torch.rand((5, 5)) 5359 pred = torch.tensor(True) 5360 self.assertTrue(same(module(pred, x), opt_m(pred, x))) 5361 pred = torch.tensor(False) 5362 self.assertTrue(same(module(pred, x), opt_m(pred, x))) 5363 5364 def test_map_with_quantization(self): 5365 from functorch.experimental.control_flow import map 5366 5367 class MyModule(torch.nn.Module): 5368 def __init__(self): 5369 super().__init__() 5370 example_inputs = (torch.randn(5, 5),) 5371 self.model = torch.nn.Linear(5, 5) 5372 self.quantized_model = prepare_qat_fx( 5373 self.model, qconfig_dict, example_inputs=example_inputs 5374 ) 5375 5376 def forward(self, x): 5377 def body(x): 5378 return x.sin() + self.quantized_model(x) 5379 5380 return map(body, x) 5381 5382 module = MyModule() 5383 opt_m = torch._dynamo.optimize("eager", nopython=True)(module) 5384 x = torch.rand((5, 5)) 5385 self.assertTrue(same(module(x), opt_m(x))) 5386 5387 def test_cond_side_effects(self): 5388 from functorch.experimental.control_flow import cond 5389 5390 c = 0 5391 5392 def true_fn(x): 5393 return x - c 5394 5395 def false_fn(x): 5396 return x + c 5397 5398 def f(pred, x): 5399 nonlocal c 5400 c = 1 5401 return cond(pred, true_fn, false_fn, [x]) 5402 5403 opt_fn = torch._dynamo.optimize("eager")(f) 5404 c = 0 5405 a = opt_fn(torch.tensor(False), 
torch.tensor([0.25, 0.25])) 5406 self.assertTrue(same(torch.tensor([1.25, 1.25]), a)) 5407 5408 def test_map_side_effects(self): 5409 from functorch.experimental.control_flow import map 5410 5411 class Module(torch.nn.Module): 5412 def __init__(self): 5413 super().__init__() 5414 self.w = torch.tensor(1) 5415 5416 def forward(self, xs): 5417 def body(x): 5418 self.w += 1 5419 return x 5420 5421 return map(body, xs) 5422 5423 mod = Module() 5424 5425 error_message = "" 5426 if torch._dynamo.config.inline_inbuilt_nn_modules: 5427 error_message = r"HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)" 5428 else: 5429 error_message = "Can't inplace modify module params/buffers" 5430 5431 with self.assertRaisesRegex(Unsupported, error_message): 5432 opt_fn = torch._dynamo.optimize("eager", nopython=True)(mod) 5433 opt_fn(torch.randn(3, 2)) 5434 5435 def test_cond_nested(self): 5436 from functorch.experimental.control_flow import cond 5437 5438 def true_fn_nested(x): 5439 return x * 10 5440 5441 def false_fn_nested(x): 5442 return x * -1 5443 5444 def true_fn(pred2, x): 5445 return x.sin() 5446 5447 def false_fn(pred2, x): 5448 return x + cond(pred2, true_fn_nested, false_fn_nested, [x]) 5449 5450 def f(pred, pred2, x): 5451 return cond(pred, true_fn, false_fn, [pred2, x]) 5452 5453 cc = torch._dynamo.testing.CompileCounter() 5454 opt_fn = torch._dynamo.optimize(cc)(f) 5455 true_true_sin = opt_fn( 5456 torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25]) 5457 ) 5458 self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin)) 5459 5460 true_false_sin = opt_fn( 5461 torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25]) 5462 ) 5463 self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin)) 5464 5465 false_true_sum_mult = opt_fn( 5466 torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) 5467 ) 5468 self.assertTrue( 5469 same(torch.tensor([2.75, 2.75]), false_true_sum_mult) 5470 ) # * 10 then add x 5471 5472 false_false_sum_neg = opt_fn( 5473 torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25]) 5474 ) 5475 self.assertTrue( 5476 same(torch.tensor([0.0, 0.0]), false_false_sum_neg) 5477 ) # * -1 then add x 5478 self.assertTrue(cc.frame_count, 2) 5479 5480 def test_cond_export(self): 5481 from functorch.experimental.control_flow import cond 5482 5483 def true_fn_nested(x): 5484 return x * 10 5485 5486 def false_fn_nested(x): 5487 return x * -1 5488 5489 def true_fn(pred2, x): 5490 return x.sin() 5491 5492 def false_fn(pred2, x): 5493 return x + cond(pred2, true_fn_nested, false_fn_nested, [x]) 5494 5495 def f(pred, pred2, x): 5496 return cond(pred, true_fn, false_fn, [pred2, x]) 5497 5498 graph, guard = torch._dynamo.export(f)( 5499 torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) 5500 ) 5501 true_true_sin = graph( 5502 torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25]) 5503 ) 5504 self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin)) 5505 5506 true_false_sin = graph( 5507 torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25]) 5508 ) 5509 self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin)) 5510 5511 false_true_sum_mult = graph( 5512 torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) 5513 ) 5514 self.assertTrue( 5515 same(torch.tensor([2.75, 2.75]), false_true_sum_mult) 5516 ) # * 10 then add x 5517 5518 false_false_sum_neg = graph( 5519 torch.tensor(False), 
torch.tensor(False), torch.tensor([0.25, 0.25]) 5520 ) 5521 self.assertTrue( 5522 same(torch.tensor([0.0, 0.0]), false_false_sum_neg) 5523 ) # * -1 then add x 5524 5525 def test_cond_export_single_arg(self): 5526 from functorch.experimental.control_flow import cond 5527 5528 def true_fn(x): 5529 return x 5530 5531 def false_fn(x): 5532 return x.sin() 5533 5534 def f(pred, x): 5535 return cond(pred, true_fn, false_fn, [x]) 5536 5537 graph, guard = torch._dynamo.export(f)( 5538 torch.tensor(False), torch.tensor([0.25, 0.25]) 5539 ) 5540 true_mirror = graph(torch.tensor(True), torch.tensor([0.25, 0.25])) 5541 self.assertTrue(same(torch.tensor([0.25, 0.25]), true_mirror)) 5542 true_mirror_2 = graph(torch.tensor(True), torch.tensor([0.33, 0.33, 0.33])) 5543 self.assertTrue(same(torch.tensor([0.33, 0.33, 0.33]), true_mirror_2)) 5544 5545 false_sin = graph(torch.tensor(False), torch.tensor([0.5, 0.5])) 5546 self.assertTrue(same(torch.sin(torch.tensor([0.5, 0.5])), false_sin)) 5547 5548 def test_enum_guards(self): 5549 class MyEnum(enum.Enum): 5550 FOO = 10 5551 BAR = 20 5552 5553 def fn(x, y): 5554 if y == MyEnum.FOO: 5555 return x + 1 5556 else: 5557 return x - 1 5558 5559 x = torch.rand(3) 5560 y = MyEnum.BAR 5561 ref = fn(x, y) 5562 opt_fn = torch.compile(backend="eager")(fn) 5563 res = opt_fn(x, y) 5564 self.assertTrue(same(ref, res)) 5565 5566 def test_duplicate_graph_break_log(self): 5567 torch._logging.set_logs(graph_breaks=True) 5568 5569 @torch._dynamo.optimize("eager") 5570 def f1(a, b): 5571 f2(a, b) 5572 5573 def f2(a, b): 5574 c = a + b 5575 print("break") 5576 return a + b + c 5577 5578 @torch._dynamo.optimize("eager") 5579 def g1(a, b): 5580 g2(a, b) 5581 5582 def g2(a, b): 5583 c = a + b 5584 print("break") 5585 return a + b + c 5586 5587 def count_graph_break_msgs(msgs): 5588 return sum(msg.find("Graph break") != -1 for msg in msgs) 5589 5590 with self.assertLogs( 5591 logger="torch._dynamo", level=logging.DEBUG 5592 ) as log, torch._dynamo.config.patch(verbose=True): 5593 f1(torch.randn(10), torch.randn(10)) 5594 self.assertGreater(count_graph_break_msgs(log.output), 1) 5595 5596 with self.assertLogs( 5597 logger="torch._dynamo", level=logging.DEBUG 5598 ) as log, torch._dynamo.config.patch(verbose=False): 5599 g1(torch.randn(10), torch.randn(10)) 5600 self.assertEqual(count_graph_break_msgs(log.output), 1) 5601 5602 # reset logging state 5603 torch._logging.set_logs() 5604 5605 def test_inplace_param_update(self): 5606 def fn(param, y): 5607 prev_grad = torch.is_grad_enabled() 5608 try: 5609 torch.set_grad_enabled(False) 5610 torch.set_grad_enabled(True) 5611 torch.set_grad_enabled(False) 5612 param.add_(y) 5613 finally: 5614 torch.set_grad_enabled(prev_grad) 5615 5616 y = torch.randn(4) 5617 x = torch.nn.Parameter(torch.randn(4)) 5618 fn(x, y) 5619 5620 cnts = torch._dynamo.testing.CompileCounter() 5621 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 5622 opt_fn(x, y) 5623 self.assertEqual(cnts.frame_count, 1) 5624 self.assertEqual(cnts.op_count, 3) 5625 5626 @unittest.skipIf( 5627 not PLATFORM_SUPPORTS_FLASH_ATTENTION, 5628 "Can't run fused SDPA on this platform", 5629 ) 5630 def test_parsing_sdpa(self): 5631 class MyModule(torch.nn.Module): 5632 def forward(self, query, key, value): 5633 out = F.scaled_dot_product_attention(query, key, value, None, 0, True) 5634 out = F.scaled_dot_product_attention( 5635 query, key, value, None, 0, True, scale=8 5636 ) 5637 out = F.scaled_dot_product_attention( 5638 query=query, 5639 key=key, 5640 value=value, 5641 attn_mask=None, 
5642 dropout_p=0, 5643 is_causal=True, 5644 ) 5645 out = F.scaled_dot_product_attention( 5646 query, 5647 key=key, 5648 value=value, 5649 attn_mask=None, 5650 dropout_p=0, 5651 is_causal=True, 5652 ) 5653 out = F.scaled_dot_product_attention( 5654 query, key, value, None, dropout_p=0, is_causal=True 5655 ) 5656 out = F.scaled_dot_product_attention(query, key, value, None, scale=8) 5657 return out 5658 5659 device = "cuda" 5660 dtype = torch.float16 5661 seq_len_q = 1 5662 seq_len_k = 1 5663 head_dim = 8 5664 query = torch.ones( 5665 1, 8, seq_len_q, head_dim, device=device, dtype=dtype, requires_grad=True 5666 ) 5667 key = torch.ones( 5668 1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True 5669 ) 5670 value = torch.ones( 5671 1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True 5672 ) 5673 module = MyModule() 5674 opt_mod = torch._dynamo.optimize("inductor")(module) 5675 opt_mod(query, key, value) 5676 5677 def test_generate_tensor_from_list_of_numpy_primitive_type(self): 5678 # Test sth like torch.LongTensor(list(np.int64, np.int64, ...)) 5679 def fn(): 5680 x = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64) 5681 y = [x[0], x[2], x[4]] 5682 return torch.LongTensor(y) 5683 5684 ref = fn() 5685 res = torch.compile(fullgraph=True)(fn)() 5686 self.assertEqual(ref, res) 5687 5688 def test_object_classmethod(self): 5689 class C: 5690 @classmethod 5691 def fn(cls, x): 5692 return x + x 5693 5694 @torch._dynamo.optimize("eager", nopython=True) 5695 def f(): 5696 return C().fn(torch.ones(2, 3)) 5697 5698 self.assertTrue(torch.allclose(f(), torch.tensor([2.0]))) 5699 5700 def test_object_staticmethod(self): 5701 class C: 5702 @staticmethod 5703 def fn(x): 5704 return x + x 5705 5706 @torch._dynamo.optimize("eager", nopython=True) 5707 def f(): 5708 return C().fn(torch.ones(2, 3)) 5709 5710 self.assertTrue(torch.allclose(f(), torch.tensor([2.0]))) 5711 5712 def test_user_function_variable_supports_enum_argument(self): 5713 class Foo(enum.Enum): 5714 FOO = 0 5715 BAR = 1 5716 5717 def gn(x, y=Foo.FOO): 5718 if y is Foo.FOO: 5719 return x 5720 else: 5721 return x + 1 5722 5723 def fn(x): 5724 return gn(x) 5725 5726 x = torch.randn(2, 3) 5727 ref = fn(x) 5728 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 5729 res = opt_fn(x) 5730 self.assertTrue(torch.allclose(ref, res)) 5731 5732 def test_user_function_variable_supports_type_abcmeta_argument(self): 5733 class Foo(metaclass=abc.ABCMeta): 5734 @abc.abstractclassmethod 5735 def read(self): # noqa: B027 5736 pass 5737 5738 class Bar(Foo): 5739 def read(self): 5740 return "Hello World!" 
5741 5742 class Baz: 5743 pass 5744 5745 def gn(x, tys=(Bar, Baz)): 5746 if Bar in tys: 5747 return x - 1 5748 else: 5749 return x + 1 5750 5751 def fn(x): 5752 return gn(x) 5753 5754 x = torch.randn(2, 3) 5755 ref = fn(x) 5756 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 5757 res = opt_fn(x) 5758 self.assertTrue(torch.allclose(ref, res)) 5759 5760 def test_user_function_variable_supports_function_argument(self): 5761 # Test user defined function default arguments can be: 5762 # 1, user defined functions (e.g, add1) 5763 # 2, torch functions (e.g, torch.sin) 5764 # 3, python builtin functions (e.g, operator.neg) 5765 def add1(x): 5766 return x + 1 5767 5768 def gn(x, f1=add1, f2=torch.sin, f3=operator.neg): 5769 return f3(f2(f1(x))) 5770 5771 def fn(x): 5772 return gn(x) 5773 5774 x = torch.randn(2, 3) 5775 ref = fn(x) 5776 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 5777 res = opt_fn(x) 5778 self.assertTrue(torch.allclose(ref, res)) 5779 5780 def test_typing_variable_isinstance(self): 5781 def fn(x, m): 5782 if isinstance(m, typing.Mapping): 5783 return x + 1 5784 else: 5785 return x - 1 5786 5787 x = torch.randn(2, 3) 5788 m = {"x": torch.randn(3)} 5789 ref = fn(x, m) 5790 opt_fn = torch._dynamo.optimize("eager")(fn) 5791 res = opt_fn(x, m) 5792 self.assertTrue(torch.allclose(ref, res)) 5793 5794 @torch._dynamo.config.patch(guard_nn_modules=True) 5795 def test_repro_graph_breaks_in__get_item_by_idx(self): 5796 class Mod(torch.nn.Module): 5797 def __init__(self): 5798 super().__init__() 5799 self.mod = torch.nn.Sequential( 5800 torch.nn.Linear(3, 3), torch.nn.Linear(3, 3) 5801 ) 5802 5803 def forward(self, x): 5804 return self.mod[0](x) 5805 5806 m = Mod() 5807 graph, _ = torch._dynamo.export(m)(torch.randn(3, 3)) 5808 5809 @torch._dynamo.config.patch(guard_nn_modules=True) 5810 def test_nn_sequential_invocation(self): 5811 with freeze_rng_state(): 5812 5813 class TestModel(torch.nn.Module): 5814 def __init__(self) -> None: 5815 super().__init__() 5816 self.linears = torch.nn.Sequential( 5817 torch.nn.Linear(2, 2), 5818 torch.nn.Linear(2, 2), 5819 torch.nn.Linear(2, 2), 5820 torch.nn.Linear(2, 2), 5821 ) 5822 5823 def forward(self, x): 5824 all_but_last = self.linears[:-1] 5825 return all_but_last(x) 5826 5827 m = TestModel() 5828 x = torch.rand((2, 2)) 5829 real = m(x) 5830 graph, _ = torch._dynamo.export(m)(x) 5831 dynamo_result = graph(x) 5832 self.assertTrue(same(real, dynamo_result)) 5833 5834 @torch._dynamo.config.patch(guard_nn_modules=True) 5835 def test_nn_sequential_invocation_reposition_indices(self): 5836 with freeze_rng_state(): 5837 5838 class TestModel(torch.nn.Module): 5839 def __init__(self) -> None: 5840 super().__init__() 5841 self.linears = torch.nn.Sequential( 5842 torch.nn.Linear(2, 2), 5843 torch.nn.Linear(2, 2), 5844 torch.nn.Linear(2, 2), 5845 torch.nn.Linear(2, 2), 5846 ) 5847 5848 def forward(self, x): 5849 all_but_last = self.linears[1:3] 5850 return all_but_last(x) 5851 5852 m = TestModel() 5853 x = torch.rand((2, 2)) 5854 real = m(x) 5855 graph, _ = torch._dynamo.export(m)(x) 5856 dynamo_result = graph(x) 5857 self.assertTrue(same(real, dynamo_result)) 5858 5859 def test_error_on_nested_fx_trace(self): 5860 input = torch.rand(2, 3) 5861 5862 def f(x): 5863 x + x 5864 5865 real = f(input) 5866 5867 optimized = torch._dynamo.optimize("eager")(f) 5868 self.assertTrue(same(optimized(input), real)) 5869 5870 with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"): 5871 gm = torch.fx.symbolic_trace(optimized) 5872 
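    # The companion test below flips error_on_nested_fx_trace off; in that case
    # Dynamo steps aside during symbolic tracing and the resulting graph module
    # behaves like the eager function.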
5873 @patch.object(torch._dynamo.config, "error_on_nested_fx_trace", False) 5874 def test_no_error_on_nested_fx_trace(self): 5875 input = torch.rand(2, 3) 5876 5877 def f(x): 5878 x + x 5879 5880 real = f(input) 5881 5882 optimized = torch._dynamo.optimize("eager")(f) 5883 self.assertTrue(same(optimized(input), real)) 5884 5885 # should not error 5886 gm = torch.fx.symbolic_trace(optimized) 5887 self.assertTrue(same(gm(input), real)) 5888 5889 def test_not_dynamic_scope(self): 5890 def f(y): 5891 x = 1 5892 5893 def g(): 5894 x = 2 5895 return lambda: x 5896 5897 return y + g()() 5898 5899 input = torch.zeros(1) 5900 real = f(input) 5901 optimized = torch._dynamo.optimize("eager")(f) 5902 opt = optimized(input) 5903 self.assertTrue(same(opt, real)) 5904 5905 def test_inference_mode(self): 5906 @torch.inference_mode() 5907 def func(x, y): 5908 return x.add(1.0) + y 5909 5910 x = torch.ones(4, requires_grad=True) 5911 y = torch.ones(4, requires_grad=True) 5912 ref = func(x, y) 5913 opt_func = torch._dynamo.optimize("eager")(func) 5914 5915 x1 = torch.ones(4, requires_grad=True) 5916 res = opt_func(x1, y) 5917 self.assertTrue(same(ref, res)) 5918 self.assertTrue(same(x, x1)) 5919 5920 def test_if_cond_nn_mod1(self): 5921 class MockModule(torch.nn.Module): 5922 def __init__(self, output_relu=True): 5923 super().__init__() 5924 self.relu = torch.nn.ReLU() if output_relu else None 5925 5926 def forward(self, x): 5927 x = torch.sin(x) 5928 if self.relu: 5929 x = self.relu(x) 5930 return x 5931 5932 model = MockModule() 5933 opt_model = torch._dynamo.optimize("eager", nopython=True)(model) 5934 5935 x = torch.rand(4) 5936 ref = model(x) 5937 res = opt_model(x) 5938 self.assertTrue(same(ref, res)) 5939 5940 model = MockModule(output_relu=False) 5941 opt_model = torch._dynamo.optimize("eager", nopython=True)(model) 5942 5943 x = torch.rand(4) 5944 ref = model(x) 5945 res = opt_model(x) 5946 self.assertTrue(same(ref, res)) 5947 5948 def test_if_cond_nn_mod2(self): 5949 class MockModule(torch.nn.Module): 5950 def __init__(self): 5951 super().__init__() 5952 self.layer = torch.nn.Sequential() 5953 5954 def forward(self, x): 5955 if self.layer: 5956 return x + 1 5957 else: 5958 return x - 1 5959 5960 model = MockModule() 5961 x = torch.rand(4) 5962 ref = model(x) 5963 opt_model = torch.compile(backend="eager")(model) 5964 res = opt_model(x) 5965 self.assertTrue(same(ref, res)) 5966 5967 def test_if_cond_nn_mod3(self): 5968 def fn(x): 5969 if torch.nn.ModuleList(): 5970 return x + 1 5971 else: 5972 return x - 1 5973 5974 x = torch.rand(4) 5975 ref = fn(x) 5976 opt_fn = torch.compile(backend="eager")(fn) 5977 res = opt_fn(x) 5978 self.assertTrue(same(ref, res)) 5979 5980 def test_if_cond_user_defined_object(self): 5981 # obj.__bool__ is not existed 5982 class A: # noqa: B903 5983 def __init__(self, x): 5984 self.x = x 5985 5986 # obj.__bool__ is function and returns bool type 5987 class B: 5988 def __init__(self, x): 5989 self.x = x 5990 5991 def __bool__(self): 5992 return self.x > 0 5993 5994 # obj.__bool__ is non-function 5995 class C: 5996 def __init__(self, x): 5997 self.x = x 5998 self.__bool__ = False 5999 6000 def fn(x, obj): 6001 if not obj: 6002 return x + 1 6003 else: 6004 return x - 1 6005 6006 x = torch.rand(4) 6007 cnts = torch._dynamo.testing.CompileCounter() 6008 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 6009 obj1 = A(0.5) 6010 obj2 = B(0.5) 6011 obj3 = B(-0.5) 6012 obj4 = C(0.5) 6013 for obj in [obj1, obj2, obj3, obj4, obj3, obj2]: 6014 ref = fn(x, obj) 6015 res = opt_fn(x, 
obj) 6016 self.assertTrue(same(ref, res)) 6017 self.assertEqual(cnts.frame_count, 4) 6018 6019 def test_if_cond_user_defined_object2(self): 6020 # obj.__bool__ is function and returns non-bool type 6021 class MyObj: 6022 def __init__(self, x): 6023 self.x = x 6024 6025 def __bool__(self): 6026 self.x = 1.2 6027 return self.x 6028 6029 def fn(a, obj): 6030 if not obj: 6031 return a + obj.x 6032 else: 6033 return a - obj.x 6034 6035 x = torch.rand(4) 6036 obj = MyObj(0.5) 6037 opt_fn = torch._dynamo.optimize("eager")(fn) 6038 try: 6039 opt_fn(x, obj) 6040 self.assertFalse(True) 6041 except TypeError as e: 6042 self.assertIn("__bool__ should return bool, returned float", str(e)) 6043 6044 def test_if_cond_user_defined_object3(self): 6045 # obj.__bool__ is not existed, but obj.__len__ exists 6046 class A: # noqa: B903 6047 def __init__(self, x): 6048 self.x = x 6049 6050 def __len__(self): 6051 return len(self.x) 6052 6053 # obj.__bool__ takes precedence over obj.__len__ 6054 class B: 6055 def __init__(self, x): 6056 self.x = x 6057 6058 def __bool__(self): 6059 return False 6060 6061 def __len__(self): 6062 return len(self.x) 6063 6064 def fn(x, obj): 6065 if not obj: 6066 return x + 1 6067 else: 6068 return x - 1 6069 6070 x = torch.rand(4) 6071 opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) 6072 obj1 = A([1, 2, 3]) 6073 obj2 = A([]) 6074 obj3 = B([1, 2, 3]) 6075 obj4 = B([]) 6076 for obj in [obj1, obj2, obj3, obj4]: 6077 ref = fn(x, obj) 6078 res = opt_fn(x, obj) 6079 self.assertTrue(same(ref, res)) 6080 6081 def test_class_has_instancecheck_method(self): 6082 class A: 6083 pass 6084 6085 class ExampleMeta(type): 6086 def __instancecheck__(cls, instance): 6087 return True 6088 6089 class B(metaclass=ExampleMeta): 6090 pass 6091 6092 def fn(x, obj): 6093 if isinstance(obj, B): 6094 return x + 1 6095 else: 6096 return x - 1 6097 6098 x = torch.rand(4) 6099 obj = A() 6100 ref = fn(x, obj) 6101 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 6102 res = opt_fn(x, obj) 6103 self.assertTrue(same(ref, res)) 6104 6105 def test_torch_cuda_is_available(self): 6106 def fn(x): 6107 if torch.cuda.is_available(): 6108 return x + 1 6109 else: 6110 return x - 1 6111 6112 x = torch.rand(4) 6113 ref = fn(x) 6114 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 6115 res = opt_fn(x) 6116 self.assertTrue(same(ref, res)) 6117 6118 def test_variable_tracker_recursively_contains(self): 6119 # VariableTracker.recursively_contains should be updated correctly when mutation happens 6120 def fn(x): 6121 data = [[None] * 3] * 3 6122 for i in range(3): 6123 if i == 0: 6124 data[0][i] = x 6125 else: 6126 data[0][i] = data[0][i - 1] + 1 6127 return data[0][-1] 6128 6129 x = torch.rand(4) 6130 ref = fn(x) 6131 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 6132 res = opt_fn(x) 6133 self.assertTrue(same(ref, res)) 6134 6135 @unittest.skipIf(not TEST_CUDA, "requires cuda") 6136 @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") 6137 def test_torch_cudnn_is_acceptable(self): 6138 def fn(x): 6139 if torch.backends.cudnn.is_acceptable(tensor=x): 6140 return x + 1 6141 return x 6142 6143 x = torch.rand(4).cuda() 6144 ref = fn(x) 6145 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 6146 res = opt_fn(x) 6147 self.assertTrue(same(ref, res)) 6148 6149 @unittest.skipIf(not TEST_CUDA, "requires cuda") 6150 @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") 6151 def test_torch_cudnn_is_acceptable_bad_inputs(self): 6152 def 
fn1(x): 6153 if torch.backends.cudnn.is_acceptable("invalid"): 6154 return x + 1 6155 return x 6156 6157 def fn2(x): 6158 if torch.backends.cudnn.is_acceptable(x, 3.14): 6159 return x + 1 6160 return x 6161 6162 with self.assertRaisesRegex( 6163 AssertionError, "Expect input to cudnn.is_acceptable to be a tensor" 6164 ): 6165 x1 = torch.rand(4).cuda() 6166 opt_fn1 = torch._dynamo.optimize("eager", nopython=True)(fn1) 6167 res1 = opt_fn1(x1) 6168 6169 with self.assertRaisesRegex( 6170 AssertionError, "Expect 1 input to cudnn.is_acceptable" 6171 ): 6172 x2 = torch.rand(4).cuda() 6173 opt_fn2 = torch._dynamo.optimize("eager", nopython=True)(fn2) 6174 res = opt_fn2(x2) 6175 6176 @unittest.skipIf(not TEST_CUDA, "requires cuda") 6177 def test_get_device(self): 6178 def fn(x, y): 6179 x = x + 1 6180 y = y + 1 6181 return x.get_device(), y.get_device() 6182 6183 x = torch.rand(4, device="cuda") 6184 y = torch.rand(4, device="cpu") 6185 ref = fn(x, y) 6186 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 6187 res = opt_fn(x, y) 6188 self.assertTrue(same(ref, res)) 6189 6190 def test_disable_flag(self): 6191 cnt = torch._dynamo.testing.CompileCounter() 6192 6193 with patch.dict(os.environ, {"TORCH_COMPILE_DISABLE": "1"}): 6194 6195 def fn(x, y): 6196 x = x + 1 6197 y = y + 1 6198 6199 opt_fn = torch._dynamo.optimize(cnt) 6200 6201 self.assertEqual(cnt.frame_count, 0) 6202 6203 def test_is_compiling(self): 6204 def f1(): 6205 if torch._dynamo.is_compiling(): 6206 return torch.ones(2, 2) 6207 else: 6208 return torch.zeros(2, 2) 6209 6210 def f2(): 6211 if torch._utils.is_compiling(): 6212 return torch.ones(2, 2) 6213 else: 6214 return torch.zeros(2, 2) 6215 6216 def f3(): 6217 if torch.compiler.is_compiling(): 6218 return torch.ones(2, 2) 6219 else: 6220 return torch.zeros(2, 2) 6221 6222 def f4(): 6223 if torch.compiler.is_dynamo_compiling(): 6224 return torch.ones(2, 2) 6225 else: 6226 return torch.zeros(2, 2) 6227 6228 for f in [f1, f2, f3, f4]: 6229 opt_f = torch._dynamo.optimize("eager")(f) 6230 6231 self.assertEqual(f(), torch.zeros(2, 2)) 6232 self.assertEqual(opt_f(), torch.ones(2, 2)) 6233 6234 def test_torch_generator_set_state(self): 6235 def fn(): 6236 default_state = torch.default_generator.get_state() 6237 x = torch.rand([2, 3]) 6238 if default_state.dtype != "float32": 6239 x = x * 2 6240 torch._dynamo.graph_break() 6241 torch.default_generator.set_state(default_state) 6242 y = torch.rand([2, 3]) 6243 return x, y 6244 6245 opt_fn = torch._dynamo.optimize("eager")(fn) 6246 x, y = opt_fn() 6247 self.assertEqual(x, y * 2) 6248 6249 def test_torch_distributions_lazy_property(self): 6250 def fn(x): 6251 return torch.distributions.Categorical(probs=x).entropy() 6252 6253 opt_fn = torch._dynamo.optimize("eager")(fn) 6254 x = torch.rand([4, 4]) 6255 self.assertEqual(opt_fn(x), fn(x)) 6256 6257 def test_guard_failure_fn(self): 6258 def fn(x, y, k): 6259 x = x + 1 6260 y = y + 1 6261 return x * y * k 6262 6263 x = torch.tensor([0.5, 0.5]) 6264 y = torch.tensor([1.0, 1.0]) 6265 6266 guard_failure = None 6267 6268 def guard_failures(failure): 6269 nonlocal guard_failure 6270 guard_failure = failure 6271 6272 opt_fn = torch._dynamo.optimize( 6273 "eager", nopython=True, guard_fail_fn=guard_failures 6274 )(fn) 6275 6276 x2 = torch.tensor([0.5, 0.5, 1.0]) 6277 y2 = torch.tensor([0.5, 0.5, 0.5]) 6278 6279 opt_fn(x, y, 3) 6280 opt_fn(x2, y2, 5) 6281 6282 if ( 6283 not torch._dynamo.config.specialize_int 6284 and not torch._dynamo.config.assume_static_by_default 6285 ): 6286 # we didn't 
actually test guard_failure_fn here but whatever, 6287 # nice to see no guard failure on the test 6288 self.assertTrue(guard_failure is None) 6289 else: 6290 self.assertTrue(guard_failure is not None) 6291 6292 def test_guard_failure_fn_shape_control(self): 6293 def fn(x, y): 6294 if x.shape[0] < 3: 6295 if y.shape[0] < 3: 6296 return x * y 6297 else: 6298 return x + y 6299 else: 6300 return -1 6301 6302 x = torch.randn([2, 2]) 6303 y = torch.randn([2, 2]) 6304 6305 guard_failure = None 6306 6307 def guard_failures(failure): 6308 nonlocal guard_failure 6309 guard_failure = failure 6310 6311 opt_fn = torch._dynamo.optimize( 6312 "eager", nopython=True, guard_fail_fn=guard_failures 6313 )(fn) 6314 6315 x2 = torch.randn([5, 5]) 6316 y2 = torch.randn([5, 5]) 6317 6318 opt_fn(x, y) 6319 opt_fn(x2, y2) 6320 6321 self.assertTrue(guard_failure is not None) 6322 first_guard_failure = guard_failure[0].partition("\n")[0] 6323 if torch._dynamo.config.assume_static_by_default: 6324 self.assertIn( 6325 """tensor 'L['x']' size mismatch at index 0. expected 2, actual 5""", 6326 first_guard_failure, 6327 ) 6328 else: 6329 self.assertIn("""2 <= L['x'].size()[0] <= 2""", first_guard_failure) 6330 6331 def test_guard_failure_fn2(self): 6332 def fn(x, y): 6333 x = x + 1 6334 y = y + 1 6335 return x * y 6336 6337 x = torch.tensor([0.5, 0.5]) 6338 y = torch.tensor([1.0, 1.0]) 6339 6340 guard_failure = None 6341 6342 def guard_failures(failure): 6343 nonlocal guard_failure 6344 guard_failure = failure 6345 6346 opt_fn = torch._dynamo.optimize( 6347 "eager", nopython=True, guard_fail_fn=guard_failures 6348 )(fn) 6349 6350 x2 = torch.tensor([0.5, 0.5, 1.0]) 6351 y2 = torch.tensor([0.5, 0.5, 0.5]) 6352 6353 opt_fn(x, y) 6354 opt_fn(x2, y2) 6355 6356 if torch._dynamo.config.assume_static_by_default: 6357 self.assertIn( 6358 """tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""", 6359 guard_failure[0], 6360 ) 6361 else: 6362 self.assertTrue(guard_failure is None) 6363 6364 def test_guard_failure_fn_tensor_iter(self): 6365 def fn(x): 6366 for y in x: 6367 y.add_(1.0) 6368 return y 6369 6370 guard_failure = None 6371 6372 def guard_failures(failure): 6373 nonlocal guard_failure 6374 guard_failure = failure 6375 6376 opt_fn = torch._dynamo.optimize( 6377 "eager", nopython=True, guard_fail_fn=guard_failures 6378 )(fn) 6379 6380 args1 = torch.randn(10, 10) 6381 out = fn(args1) 6382 opt_out = opt_fn(args1) 6383 self.assertTrue(same(out, opt_out)) 6384 6385 args2 = torch.randn(9, 10) 6386 out = fn(args2) 6387 opt_out = opt_fn(args2) 6388 self.assertTrue(same(out, opt_out)) 6389 6390 # guard is expected for both static and dynamic shapes 6391 self.assertTrue(guard_failure is not None) 6392 self.assertIn( 6393 """len(L['x']) == 10""", 6394 guard_failure[0], 6395 ) 6396 6397 def test_restore_graphstate(self): 6398 # This function does some guard accumulation, 6399 # and then rolls back due to control flow. 6400 # The idea is that if one were printing guards as they appear, 6401 # they would see this insert a guard that does not show up in the final set of 6402 # guards as we rolled back from it. 
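 # Background (roughly): guard_export_fn receives an iterable of guard objects (torch._guards.Guard); the check below only relies on each guard's .name, which is a source path string such as "L['x']" or "nested_fn.__closure__[0].cell_contents", so no other guard fields matter for this test.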
6403 def nested_fn(s): 6404 if x[0] < 10: 6405 return s * s 6406 return s 6407 6408 def fn(x, y): 6409 x = x + 1 6410 y = nested_fn(y) 6411 y = y + 10 6412 return x * y 6413 6414 all_guards = [] 6415 6416 def guard_export_print(guards): 6417 nonlocal all_guards 6418 all_guards.extend(guards) 6419 6420 opt_fn = torch._dynamo.optimize("eager", guard_export_fn=guard_export_print)(fn) 6421 6422 x = torch.tensor([0.5, 0.5]) 6423 y = torch.tensor([1.0, 1.0]) 6424 opt_fn(x, y) 6425 6426 for guard in all_guards: 6427 # This guard was created 6428 self.assertTrue(guard.name != "nested_fn.__closure__[0].cell_contents") 6429 6430 def test_call_parent_non_class_methods_from_child(self): 6431 class A: 6432 a = 4 6433 6434 def add(self, x): 6435 return x + 10 6436 6437 def mul(self, x): 6438 return x * 0.1 6439 6440 class B(A): 6441 coeff = 4 6442 6443 def add(self, x): 6444 return x + 20 6445 6446 @classmethod 6447 def cube(cls, x): 6448 return cls.coeff * x * x * x 6449 6450 def mul(self, x): 6451 return super().mul(x) * x * 0.2 6452 6453 class C(B): 6454 def add(self, x): 6455 b = super().cube(x) 6456 c = A.add(self, x) 6457 d = B.mul(self, x) 6458 e = super(B, self).add(x) 6459 f = super().a * x 6460 return b + c + d + e + f 6461 6462 x = torch.rand(4) 6463 fn = C().add 6464 ref = fn(x) 6465 cnt = torch._dynamo.testing.CompileCounter() 6466 opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn) 6467 res = opt_fn(x) 6468 self.assertTrue(same(ref, res)) 6469 self.assertEqual(cnt.frame_count, 1) 6470 6471 # Check recompilation 6472 A.a = 5 6473 ref = fn(x) 6474 res = opt_fn(x) 6475 self.assertTrue(same(ref, res)) 6476 # Ensure that super guard checks are working as expected 6477 res = opt_fn(x) 6478 self.assertEqual(cnt.frame_count, 2) 6479 6480 def test_builder_for_class_with_metaclass(self): 6481 class ExampleMeta(type): 6482 pass 6483 6484 class MyClass(metaclass=ExampleMeta): 6485 pass 6486 6487 def fn(x, y): 6488 if isinstance(y, MyClass): 6489 return x + 1 6490 else: 6491 return x - 1 6492 6493 x = torch.rand([4, 4]) 6494 y = MyClass() 6495 ref = fn(x, y) 6496 opt_fn = torch._dynamo.optimize("eager")(fn) 6497 res = opt_fn(x, y) 6498 self.assertTrue(same(ref, res)) 6499 6500 def test_tuple_from_tuple_iter(self): 6501 def inner_fn(*args): 6502 acc = torch.ones(10, 10) 6503 for arg in args: 6504 acc.add_(arg) 6505 6506 return acc 6507 6508 @torch._dynamo.optimize("eager") 6509 def fn(inputs, params): 6510 y = tuple(inputs) + tuple(params) 6511 return inner_fn(*y) 6512 6513 inputs = [torch.randn(10, 10) for _ in range(3)] 6514 6515 fn(inputs, iter(tuple(inputs))) 6516 6517 def fn(params): 6518 y = tuple(params) 6519 return inner_fn(*y) 6520 6521 opt_fn = torch._dynamo.optimize("eager")(fn) 6522 inputs = [torch.randn(10, 10) for _ in range(3)] 6523 self.assertTrue(same(fn(iter(tuple(inputs))), opt_fn(iter(tuple(inputs))))) 6524 6525 # Force recompilation 6526 inputs = [torch.randn(10, 10) for _ in range(4)] 6527 self.assertTrue(same(fn(iter(tuple(inputs))), opt_fn(iter(tuple(inputs))))) 6528 6529 def test_torch_package_working_with_trace(self): 6530 # from torch._dynamo.test_case import run_tests 6531 6532 inputs = [torch.randn([2, 2]), torch.randn([2, 2])] 6533 6534 optimized_model = torch._dynamo.optimize(backend="eager")( 6535 MyPickledModule(torch.randn([2, 2])) 6536 ) 6537 from torch import package 6538 6539 path = "/tmp/MyPickledModule.pt" 6540 package_name = "MyPickledModule" 6541 resource_name = "MyPickledModule.pkl" 6542 6543 model = MyPickledModule(torch.randn([2, 2])) 6544 6545 with 
package.PackageExporter(path) as exp: 6546 exp.extern("**") 6547 exp.save_pickle(package_name, resource_name, model) 6548 6549 imp = package.PackageImporter(path) 6550 loaded_model = imp.load_pickle(package_name, resource_name) 6551 6552 optimized_loaded_model = torch._dynamo.optimize("eager")(loaded_model)(*inputs) 6553 6554 def test_shape_and_tuple_equality(self): 6555 def fn(x, y, t): 6556 z = x * y 6557 if x.size() == t: 6558 return z.cos() 6559 return z.sin() 6560 6561 torch._dynamo.optimize("eager", nopython=True)(fn)( 6562 torch.randn([4, 4]), torch.randn([4, 4]), (4, 4) 6563 ) 6564 6565 def test_int_list(self): 6566 # if assume_static_by_default == True: spec int list 6567 # otherwise: unspec int list 6568 def fn(x, y): 6569 return torch.sin(x + y[1] % 2) 6570 6571 x = torch.randn(6) 6572 cnt = torch._dynamo.testing.CompileCounter() 6573 opt_fn = torch._dynamo.optimize(cnt)(fn) 6574 for i in range(10, 25, 3): 6575 y = [i, i + 1, i + 2] 6576 ref = fn(x, y) 6577 res = opt_fn(x, y) 6578 self.assertTrue(same(ref, res)) 6579 if torch._dynamo.config.assume_static_by_default: 6580 if torch._dynamo.config.automatic_dynamic_shapes: 6581 self.assertExpectedInline(cnt.frame_count, """2""") 6582 else: 6583 self.assertExpectedInline(cnt.frame_count, """5""") 6584 else: 6585 self.assertExpectedInline(cnt.frame_count, """1""") 6586 6587 def test_patched_builtin_functions(self): 6588 import builtins 6589 6590 # Cache the original builtin function ids 6591 torch._dynamo.trace_rules._builtin_function_ids() 6592 6593 class MyClass: 6594 pass 6595 6596 builtin_isinstance = builtins.isinstance 6597 6598 def patched_isinstance(obj, classinfo) -> bool: 6599 if builtin_isinstance(obj, MyClass): 6600 return False 6601 else: 6602 return builtin_isinstance(obj, classinfo) 6603 6604 def fn(x, y): 6605 if isinstance(y, MyClass): 6606 return x + 1 6607 else: 6608 return x - 1 6609 6610 x = torch.ones(2, 3) 6611 y = MyClass() 6612 6613 try: 6614 ref = fn(x, y) 6615 # Monkey patch builtin function 6616 builtins.isinstance = patched_isinstance 6617 opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) 6618 res = opt_fn(x, y) 6619 self.assertTrue(same(ref, x + 1)) 6620 self.assertTrue(same(res, x - 1)) 6621 finally: 6622 builtins.isinstance = builtin_isinstance 6623 6624 # check recompilation because builtins is now unpatched 6625 opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) 6626 res = opt_fn(x, y) 6627 self.assertTrue(same(res, x + 1)) 6628 6629 # specifically test for tensor.attribute -> torch.something() 6630 def test_real_imag_tensor_attribute(self): 6631 def fn(x, y): 6632 a = x.real 6633 b = x.imag 6634 return torch.mul(torch.add(a, y), b) 6635 6636 x_real = torch.rand((4, 4)) 6637 x_imag = torch.rand((4, 4)) 6638 x = torch.complex(x_real, x_imag) 6639 y = torch.rand((4, 4)) 6640 6641 ref = fn(x, y) 6642 opt_fn = torch._dynamo.optimize("eager")(fn) 6643 res = opt_fn(x, y) 6644 self.assertTrue(same(ref, res)) 6645 6646 def test_cast(self): 6647 from typing import cast 6648 6649 def fn(x): 6650 return cast(torch.Tensor, torch.add(x, 1.0)) 6651 6652 opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) 6653 6654 ref = fn(torch.ones(2, 2)) 6655 res = opt_fn(torch.ones(2, 2)) 6656 6657 self.assertTrue(same(ref, res)) 6658 6659 def test_T_tensor_attribute(self): 6660 def fn(x, y): 6661 a = x.T 6662 return torch.add(a, y) 6663 6664 x = torch.rand((4, 4)) 6665 y = torch.rand((4, 4)) 6666 6667 ref = fn(x, y) 6668 opt_fn = torch._dynamo.optimize("eager")(fn) 6669 res = opt_fn(x, y) 6670 
self.assertTrue(same(ref, res)) 6671 6672 def test_recursive_tensor_attribute(self): 6673 def fn(x, y): 6674 a = x.real.T 6675 b = x.imag 6676 return torch.mul(torch.add(a, y), b) 6677 6678 x_real = torch.rand((4, 4)) 6679 x_imag = torch.rand((4, 4)) 6680 x = torch.complex(x_real, x_imag) 6681 y = torch.rand((4, 4)) 6682 6683 ref = fn(x, y) 6684 opt_fn = torch._dynamo.optimize("eager")(fn) 6685 res = opt_fn(x, y) 6686 self.assertTrue(same(ref, res)) 6687 6688 def test_assigning_function_to_object_attribute(self): 6689 # user-defined functions which are object's attributes are not converted to bound methods 6690 def my_add(*args): 6691 a, b = args 6692 return a + b 6693 6694 class MyClass: 6695 def __init__(self, func): 6696 self.add = func 6697 6698 obj = MyClass(my_add) 6699 6700 def fn(x): 6701 return obj.add(x, 2) 6702 6703 x = torch.rand(2, 3) 6704 ref = fn(x) 6705 opt_fn = torch.compile(backend="eager")(fn) 6706 res = opt_fn(x) 6707 self.assertTrue(same(ref, res)) 6708 6709 def test_assigning_function_to_class_attribute(self): 6710 # user-defined functions which are class's attributes are converted to bound methods 6711 def my_add(*args): 6712 obj, a, b = args 6713 return obj.x + a + b 6714 6715 class MyClass: 6716 add = my_add 6717 6718 def __init__(self, x): 6719 self.x = x 6720 6721 obj = MyClass(0.5) 6722 6723 def fn(x): 6724 return obj.add(x, 2) 6725 6726 x = torch.rand(2, 3) 6727 ref = fn(x) 6728 opt_fn = torch.compile(backend="eager")(fn) 6729 res = opt_fn(x) 6730 self.assertTrue(same(ref, res)) 6731 6732 def test_tagging_tensors_simple(self): 6733 def foo(x, y): 6734 return x * y, x, y 6735 6736 a = torch.randn([3, 3]) 6737 a.tag = "a" 6738 a.frog = "ribbity ribbit" 6739 b = torch.randn([3, 3]) 6740 b.tag = "b" 6741 b.frog = "ribbit" 6742 6743 exported = torch._dynamo.export(foo)(a, b) 6744 out_graph = exported[0] 6745 6746 nodes = list(out_graph.graph.nodes) 6747 placeholders = [node for node in nodes if node.op == "placeholder"] 6748 all_tags = [] 6749 all_frogs = [] 6750 for placeholder in placeholders: 6751 if "tensor_dict" in placeholder.meta: 6752 all_tags.append(placeholder.meta["tensor_dict"]["tag"]) 6753 all_frogs.append(placeholder.meta["tensor_dict"]["frog"]) 6754 6755 self.assertEqual(all_tags, ["a", "b"]) 6756 self.assertEqual(all_frogs, ["ribbity ribbit", "ribbit"]) 6757 6758 def test_tagging_tensors_mix_used_unused_structure(self): 6759 def pre_attention_state_ops(input, mems, state): 6760 lc_key = state[0] 6761 lc_val = state[1] 6762 bar = [] 6763 for i in range(0, 4): 6764 bar2 = [] 6765 for j in range(0, 3): 6766 bar2.append( 6767 lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1]) 6768 ) 6769 bar.append(bar2) 6770 6771 return bar 6772 6773 mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]]) 6774 state = [ 6775 torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), 6776 torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), 6777 ] 6778 i = torch.tensor( 6779 [ 6780 [0.0313, -0.1487, -0.3846, -0.5321], 6781 [-1.7073, 1.3331, -0.0890, -1.4935], 6782 [-0.8314, -0.1862, -0.5935, 1.5232], 6783 ] 6784 ) 6785 6786 mems.tag = "MEMS" 6787 i.tag = "FOO" 6788 state[0].tag = "STATE_0" 6789 state[1].tag = "HMMM" 6790 6791 exported = torch._dynamo.export(pre_attention_state_ops)(i, mems, state) 6792 out_graph = exported[0] 6793 6794 nodes = list(out_graph.graph.nodes) 6795 placeholders = [node for node in nodes if node.op == "placeholder"] 6796 all_tags = [] 6797 for placeholder in placeholders: 6798 if "tensor_dict" in placeholder.meta: 
6799 all_tags.append(placeholder.meta["tensor_dict"]["tag"]) 6800 6801 self.assertEqual(all_tags, ["STATE_0", "HMMM"]) 6802 6803 def test_get_custom_tensor_attribute(self): 6804 def fn(x): 6805 return x.custom_attr * x 6806 6807 x = torch.rand((2, 2)) 6808 x.custom_attr = 3.14 6809 ref = fn(x) 6810 opt_fn = torch._dynamo.optimize("eager")(fn) 6811 res = opt_fn(x) 6812 self.assertTrue(same(ref, res)) 6813 6814 def test_set_custom_tensor_attribute(self): 6815 def fn(x): 6816 x.custom_attr = 3.14 6817 return x.custom_attr * x 6818 6819 x = torch.rand((2, 2)) 6820 ref = fn(x) 6821 opt_fn = torch._dynamo.optimize("eager")(fn) 6822 res = opt_fn(x) 6823 self.assertTrue(same(ref, res)) 6824 6825 def test_unhandled_exception_in_dynamo(self): 6826 # traceback.format_exc() approximates an unhandled exception 6827 def f(a): 6828 a += 1 6829 raise RuntimeError("smoge") 6830 return a 6831 6832 opt_fn = torch._dynamo.optimize("eager")(f) 6833 try: 6834 opt_fn(torch.ones(2)) 6835 except RuntimeError as e: 6836 self.assertIn("smoge", traceback.format_exc()) 6837 6838 def test_unhandled_exception_in_dynamo2(self): 6839 # segfaults in python 3.11 if shadow frame is freed improperly 6840 from torch.testing import make_tensor 6841 6842 def fn(): 6843 # test that the errors are the same for dense and sparse versions 6844 def test1(*, is_sparse): 6845 # shapes must be compatible for matrix multiplication 6846 a = make_tensor((2, 3), dtype=torch.float32, device="cpu") 6847 if is_sparse: 6848 a_sparse = a.to_sparse_csr() 6849 return torch.addmm(a, a_sparse, a) 6850 else: 6851 return torch.addmm(a, a, a) 6852 6853 try: 6854 test1(is_sparse=False) 6855 except RuntimeError as msg: 6856 try: 6857 test1(is_sparse=True) 6858 except RuntimeError as msg2: 6859 raise RuntimeError("smoge") 6860 6861 opt_fn = torch._dynamo.optimize("eager")(fn) 6862 try: 6863 opt_fn() 6864 except RuntimeError: 6865 self.assertIn("smoge", traceback.format_exc()) 6866 6867 def test_variable_access_in_exception(self): 6868 def fn(): 6869 x = torch.ones(1) 6870 try: 6871 raise RuntimeError("bad") 6872 except RuntimeError: 6873 x += 1 6874 return x 6875 6876 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 6877 self.assertEqual(opt_fn(), torch.tensor([2.0])) 6878 6879 def test_nested_sequential_with(self): 6880 def fn(x): 6881 with torch.set_grad_enabled(True): 6882 with torch.set_grad_enabled(False): 6883 x = x + 1 6884 with torch.set_grad_enabled(True): 6885 x = x + 1 6886 return x 6887 6888 opt_fn = torch._dynamo.optimize("eager")(fn) 6889 self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0])) 6890 6891 def test_nested_sequential_try(self): 6892 def fn(x): 6893 try: 6894 try: 6895 x = x + 1 6896 except: 6897 pass 6898 try: 6899 try: 6900 x = x + 1 6901 except: 6902 pass 6903 except: 6904 pass 6905 except: 6906 pass 6907 return x 6908 6909 opt_fn = torch._dynamo.optimize("eager")(fn) 6910 self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0])) 6911 6912 def test_nested_sequential_try_with(self): 6913 def fn(x): 6914 with torch.set_grad_enabled(True): 6915 try: 6916 x = x + 1 6917 except: 6918 pass 6919 try: 6920 with torch.set_grad_enabled(False): 6921 x = x + 1 6922 except: 6923 pass 6924 return x 6925 6926 opt_fn = torch._dynamo.optimize("eager")(fn) 6927 self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0])) 6928 6929 def test_nested_sequential_try_with_graph_break(self): 6930 def fn(x, n): 6931 with torch.set_grad_enabled(True): 6932 with torch.set_grad_enabled(False): 6933 x = x + 1 6934 
torch._dynamo.graph_break() 6935 try: 6936 with torch.set_grad_enabled(False): 6937 x = x + 1 6938 if n == 0: 6939 torch._dynamo.graph_break() 6940 except: 6941 pass 6942 with torch.set_grad_enabled(False): 6943 x = x + 1 6944 torch._dynamo.graph_break() 6945 x = x + 1 6946 return x 6947 6948 counter = CompileCounter() 6949 opt_fn = torch._dynamo.optimize(counter)(fn) 6950 self.assertEqual(opt_fn(torch.ones(1), 0), torch.tensor([5.0])) 6951 self.assertEqual(counter.frame_count, 1) 6952 6953 torch._dynamo.reset() 6954 counter = CompileCounter() 6955 opt_fn = torch._dynamo.optimize(counter)(fn) 6956 self.assertEqual(opt_fn(torch.ones(1), 1), torch.tensor([5.0])) 6957 self.assertEqual(counter.frame_count, 3) 6958 6959 def test_ordered_dict_alias_reconstruct(self): 6960 od = collections.OrderedDict 6961 6962 def fn(): 6963 d1 = dict() 6964 d1["a"] = 1 6965 d2 = od(d1) 6966 d2["b"] = 2 6967 torch._dynamo.graph_break() 6968 if isinstance(d2, od): 6969 return d2["a"] + d2["b"] 6970 else: 6971 return 0 6972 6973 dis.dis(fn) 6974 self.assertEqual(torch._dynamo.optimize("eager")(fn)(), 3) 6975 6976 # NOTE this test can be removed once multiline errors are in Python. 6977 # See https://github.com/python/cpython/issues/106922 6978 @skipIfNotPy311 6979 def test_get_instruction_source_311(self): 6980 def f(): 6981 # flake8: noqa 6982 # fmt: off 6983 # test binary ops 6984 a = ( b ) + c 6985 a = (a + b) // (c - d) 6986 a = b \ 6987 +\ 6988 c # test 6989 a = ( 6990 (b # test + 6991 ) \ 6992 # + 6993 << ( 6994 6995 c # test 6996 \ 6997 ) # test 6998 ) 6999 7000 # test slice 7001 a = bbb [ ccc ] 7002 b = bbbbb \ 7003 [ ccc # test 7004 7005 + ddd \ 7006 7007 ] # test 7008 a = bbb[ccc][ddd][eee] 7009 7010 # test nested and multiline function calls 7011 a = g(g(g(b))) 7012 a = g(h( 7013 g(b), 7014 c 7015 )) 7016 7017 # test chained function calls 7018 a = (g(x).y)( 7019 z 7020 )(1)(2) 7021 7022 # test unicode (match traceback behavior) 7023 a = ("" + 7024 + "") + b 7025 7026 from torch._dynamo.utils import get_instruction_source_311 7027 7028 if sys.version_info >= (3, 12): 7029 # Offsets changed in 3.12, e.g. 
due to removal of PRECALL inst 7030 offsets = (3, 11, 15, 19, 23, 29, 35, 44, 53, 65) 7031 else: 7032 offsets = (3, 11, 15, 19, 23, 29, 35, 46, 58, 74) 7033 insts = list(dis.get_instructions(f)) 7034 # uncomment to determine offsets 7035 # print(*enumerate(insts), sep="\n") 7036 all_sources = "\n".join( 7037 get_instruction_source_311(f.__code__, insts[offset]) for offset in offsets 7038 ) 7039 self.assertExpectedInline( 7040 all_sources, 7041 """\ 7042 a = ( b ) + c 7043 ~~~~~~~~~~^~~~~ 7044 7045 a = (a + b) // (c - d) 7046 ~~~~~~~~^^~~~~~~~~ 7047 7048 a = b \\ 7049 ~~~~~~ 7050 +\\ 7051 ^~ 7052 c # test 7053 ~ 7054 7055 (b # test + 7056 ~~~~~~~~~~~~ 7057 ) \\ 7058 ~~~~ 7059 # + 7060 ~~~ 7061 << ( 7062 ^^~~ 7063 7064 7065 c # test 7066 ~~~~~~~~~ 7067 \\ 7068 ~ 7069 ) # test 7070 ~ 7071 7072 a = bbb [ ccc ] 7073 ~~~~~~^^^^^^^^^^^ 7074 7075 b = bbbbb \\ 7076 ~~~~~~~ 7077 [ ccc # test 7078 ^^^^^^^^^^^^^ 7079 7080 7081 + ddd \\ 7082 ^^^^^^^^ 7083 7084 7085 ] # test 7086 ^ 7087 7088 a = bbb[ccc][ddd][eee] 7089 ~~~~~~~~^^^^^ 7090 7091 a = g(g(g(b))) 7092 ~^^^^^^ 7093 7094 a = g(h( 7095 ~^ 7096 g(b), 7097 ^^^^^ 7098 c 7099 ^ 7100 )) 7101 ^ 7102 7103 a = (g(x).y)( 7104 ~~~~~~~~~ 7105 z 7106 ~ 7107 )(1)(2) 7108 ~^^^ 7109""", 7110 ) 7111 # test unicode (since assertExpectedInline doesn't support unicode) 7112 op_offset = 74 if sys.version_info >= (3, 12) else 84 7113 self.assertEqual( 7114 get_instruction_source_311(f.__code__, insts[op_offset]), 7115 """\ 7116 a = ("" + 7117 ~~~~~~~~ 7118 + "") + b 7119 ~~~~~~~~^~~ 7120""", 7121 ) 7122 7123 def test_raise_guard_full_constraint(self): 7124 y = torch.randn([3, 3, 3]) 7125 7126 def my_dyn_fn(x): 7127 if x.shape[0] == 3: 7128 return x.sin() 7129 return x.cos() 7130 7131 torch._dynamo.mark_dynamic(y, 0) 7132 with self.assertRaises(ConstraintViolationError): 7133 torch._dynamo.optimize("eager")(my_dyn_fn)(y) 7134 7135 # Translation validation changes the exception type, don't run with it 7136 @torch.fx.experimental._config.patch(translation_validation=False) 7137 def test_mark_dynamic_with_ranges(self): 7138 y = torch.randn([8, 3, 3]) 7139 7140 def my_dyn_fn(x): 7141 if x.shape[0] == 3: 7142 return x.sin() 7143 return x.cos() 7144 7145 torch._dynamo.mark_dynamic(y, 0, min=2, max=5) 7146 with self.assertRaises(ConstraintViolationError): 7147 torch._dynamo.optimize("eager")(my_dyn_fn)(y) 7148 7149 def test_mark_static(self): 7150 counter = CompileCounter() 7151 7152 def my_dyn_fn(x): 7153 return x.cos() 7154 7155 y = torch.randn([3]) 7156 torch._dynamo.mark_static(y, 0) 7157 torch._dynamo.optimize(counter)(my_dyn_fn)(y) 7158 7159 z = torch.randn([4]) 7160 torch._dynamo.optimize(counter)(my_dyn_fn)(z) 7161 7162 self.assertEqual(counter.frame_count, 2) 7163 7164 def test_no_raise_guard_partial_constraint(self): 7165 y = torch.randn([3, 3, 3]) 7166 7167 def my_dyn_fn(x): 7168 if x.shape[0] > 3: 7169 return x.sin() 7170 return x.cos() 7171 7172 torch._dynamo.optimize("eager")(my_dyn_fn)(y) 7173 torch._dynamo.mark_dynamic(y, 0) 7174 torch._dynamo.reset() 7175 torch._dynamo.optimize("eager")(my_dyn_fn)(y) 7176 7177 def test_no_raise_guard_partial_constraint_across_break(self): 7178 y = torch.randn([3, 3, 3]) 7179 7180 def my_dyn_fn(x, y): 7181 z = x * y 7182 7183 torch._dynamo.graph_break() 7184 if z.shape[0] > 2: 7185 return z.cos() 7186 7187 return x.cos() 7188 7189 torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) 7190 torch._dynamo.mark_dynamic(y, 0) 7191 torch._dynamo.reset() 7192 torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) 7193 7194 # Sadly, this does 
not throw - we do not propagate the constraint correctly across the graph break 7195 @unittest.expectedFailure 7196 def test_raise_guard_partial_constraint_across_break(self): 7197 y = torch.randn([3, 3, 3]) 7198 7199 def my_dyn_fn(x, y): 7200 z = x * y 7201 7202 torch._dynamo.graph_break() 7203 if z.shape[0] == 3: 7204 return z.cos() 7205 7206 return x.cos() 7207 7208 torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) 7209 torch._dynamo.mark_dynamic(y, 0) 7210 torch._dynamo.reset() 7211 with self.assertRaisesRegex( 7212 Exception, 7213 ): 7214 torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) 7215 7216 def test_raise_guard_partial_constraint_no_graph_break(self): 7217 y = torch.randn([3, 3, 3]) 7218 7219 def my_dyn_fn(x, y): 7220 z = x * y 7221 7222 if z.shape[0] == 3: 7223 return z.cos() 7224 7225 return x.cos() 7226 7227 torch._dynamo.mark_dynamic(y, 0) 7228 with self.assertRaises(ConstraintViolationError): 7229 torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) 7230 7231 def test_cannot_trace_mark_dynamic(self): 7232 y = torch.randn([3, 3, 3]) 7233 7234 def my_dyn_fn(x): 7235 torch._dynamo.mark_dynamic(x, 0) 7236 return x * x 7237 7238 with self.assertRaisesRegex( 7239 AssertionError, "Attempt to trace forbidden callable" 7240 ): 7241 torch._dynamo.optimize("eager")(my_dyn_fn)(y) 7242 7243 def test_cannot_trace_mark_dynamic_safe_unreached(self): 7244 y = torch.randn([3, 3, 3]) 7245 7246 def my_dyn_fn(x): 7247 if x.shape[0] == 3: 7248 return x 7249 print("Running", torch._dynamo.mark_dynamic(x, 0)) 7250 return x * x 7251 7252 torch._dynamo.optimize("eager")(my_dyn_fn)(y) 7253 7254 def test_anomaly_aot_autograd(self): 7255 def fail(): 7256 raise AssertionError("fail") 7257 7258 @allow_in_graph 7259 def h(a): 7260 r = a.sum() 7261 # Trigger an exception in backwards 7262 r.register_hook(lambda x: fail()) 7263 return r 7264 7265 @torch.compile(backend="aot_eager") 7266 def f(a): 7267 return h(a) 7268 7269 with warnings.catch_warnings(record=True) as w, self.assertRaises( 7270 torch._dynamo.exc.BackendCompilerFailed 7271 ): 7272 f(torch.randn(2, 2, requires_grad=True)) 7273 7274 # Suppress unrelated pkg_resources warnings 7275 self.assertIn("forward call that caused the error", str(w[-1].message)) 7276 7277 def test_py_guards_mark_dynamic(self): 7278 def my_dyn_fn(a): 7279 if a.shape[0] > 2: 7280 return a.cos() 7281 return a.sin() 7282 7283 counter = CompileCounter() 7284 7285 # Run with dynamic 7286 x0 = torch.randn([3, 3, 3]) 7287 torch._dynamo.mark_dynamic(x0, 0) 7288 torch._dynamo.optimize(counter)(my_dyn_fn)(x0) 7289 self.assertEqual(counter.frame_count, 1) 7290 7291 # Run without dynamic, no recompile 7292 x = torch.randn([3, 3, 3]) 7293 torch._dynamo.optimize(counter)(my_dyn_fn)(x) 7294 self.assertEqual(counter.frame_count, 1) 7295 7296 # Mark a new dim, 1, as dynamic 7297 x1 = torch.randn([3, 3, 3]) 7298 torch._dynamo.mark_dynamic(x1, 1) 7299 torch._dynamo.optimize(counter)(my_dyn_fn)(x1) 7300 # Recompile triggered because we marked a new dim as dynamic 7301 self.assertEqual(counter.frame_count, 2) 7302 7303 # Reset 7304 torch._dynamo.reset() 7305 # Reset counter 7306 counter = CompileCounter() 7307 7308 # Run with dynamic 1 7309 torch._dynamo.optimize(counter)(my_dyn_fn)(x1) 7310 self.assertEqual(counter.frame_count, 1) 7311 7312 # Run with dynamic 0, not subset 7313 torch._dynamo.optimize(counter)(my_dyn_fn)(x0) 7314 self.assertEqual(counter.frame_count, 2) 7315 7316 # Run with dynamic 0, 1, 2, not subset 7317 x012 = torch.randn([3, 3, 3]) 7318 torch._dynamo.mark_dynamic(x012, 0) 7319
torch._dynamo.mark_dynamic(x012, 1) 7320 torch._dynamo.mark_dynamic(x012, 2) 7321 torch._dynamo.optimize(counter)(my_dyn_fn)(x012) 7322 self.assertEqual(counter.frame_count, 3) 7323 7324 def test_recompile_on_global_state_change(self): 7325 last_state = [] 7326 cnt = 0 7327 7328 def my_compiler(gm, _): 7329 nonlocal cnt 7330 cnt += 1 7331 state = read_state() 7332 7333 def inner(*args): 7334 last_state[:] = state 7335 return gm(*args) 7336 7337 return inner 7338 7339 def read_state(): 7340 return [ 7341 torch.is_grad_enabled(), 7342 torch.are_deterministic_algorithms_enabled(), 7343 torch._C._get_cublas_allow_tf32(), 7344 ] 7345 7346 def write_state(state): 7347 torch.set_grad_enabled(state[0]), 7348 torch.use_deterministic_algorithms(state[1]) 7349 torch._C._set_cublas_allow_tf32(state[2]), 7350 7351 @torch.compile(backend=my_compiler) 7352 def fn(x): 7353 return x + 1 7354 7355 initial_state = read_state() 7356 y = torch.randn(10) 7357 try: 7358 for round in range(3): 7359 for i in range(len(initial_state)): 7360 new_state = [False] * len(initial_state) 7361 new_state[i] = True 7362 write_state(new_state) 7363 assert read_state() == new_state 7364 last_state.clear() 7365 fn(y) 7366 assert last_state == new_state 7367 if round == 0: 7368 assert cnt == i + 1 7369 else: 7370 assert cnt == len(initial_state) 7371 finally: 7372 write_state(initial_state) 7373 7374 def test_grad_state_mutated(self): 7375 prior = torch.is_grad_enabled() 7376 value = None 7377 cnt = CompileCounter() 7378 7379 @torch._dynamo.allow_in_graph 7380 def check_state(): 7381 nonlocal value 7382 value = torch.is_grad_enabled() 7383 7384 @torch.compile(backend=cnt, fullgraph=True) 7385 def fn(x): 7386 check_state() 7387 torch.set_grad_enabled(False) 7388 return x + 1 7389 7390 try: 7391 torch.set_grad_enabled(True) 7392 fn(torch.randn(10)) 7393 assert value is True 7394 assert torch.is_grad_enabled() is False 7395 7396 value = None 7397 torch.set_grad_enabled(True) 7398 fn(torch.randn(10)) 7399 assert value is True 7400 assert torch.is_grad_enabled() is False 7401 7402 assert cnt.frame_count == 1 7403 finally: 7404 torch.set_grad_enabled(prior) 7405 7406 def test_deterministic_algorithms_mutated(self): 7407 prior = torch.are_deterministic_algorithms_enabled() 7408 prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled() 7409 value = None 7410 warn_only = None 7411 cnt = CompileCounter() 7412 7413 @torch._dynamo.allow_in_graph 7414 def check_state(): 7415 nonlocal value 7416 nonlocal warn_only 7417 value = torch.are_deterministic_algorithms_enabled() 7418 warn_only = torch.is_deterministic_algorithms_warn_only_enabled() 7419 7420 @torch.compile(backend=cnt, fullgraph=True) 7421 def fn(x): 7422 check_state() 7423 torch.use_deterministic_algorithms(False, warn_only=False) 7424 return x + 1 7425 7426 def run_fn(): 7427 torch.use_deterministic_algorithms(True, warn_only=True) 7428 fn(torch.randn(10)) 7429 assert value is True 7430 assert warn_only is True 7431 assert torch.are_deterministic_algorithms_enabled() is False 7432 assert torch.is_deterministic_algorithms_warn_only_enabled() is False 7433 7434 try: 7435 run_fn() 7436 value, warn_only = None, None 7437 run_fn() 7438 assert cnt.frame_count == 1 7439 finally: 7440 torch.use_deterministic_algorithms(prior, warn_only=prior_warn_only) 7441 7442 def test_torch_compile_ctx_on_forward_and_training_step(self): 7443 class MyModel(torch.nn.Module): 7444 def forward(self): 7445 ... 
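 # training_step re-enters the module via self(), so wrapping both forward and training_step with the same dynamo_ctx below exercises nested use of one compile context.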
7446 7447 def training_step(self): 7448 self() 7449 7450 model = MyModel() 7451 compiled_model = torch.compile(model) 7452 7453 model.forward = compiled_model.dynamo_ctx(model.forward) 7454 model.training_step = compiled_model.dynamo_ctx(model.training_step) 7455 7456 model.training_step() 7457 7458 def test_torch_guards_stack_frame_register_inlining(self): 7459 x = torch.tensor([0.5, 0.5]) 7460 y = torch.tensor([0.75, 0.75, 0.75, 0.75]) 7461 z = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]) 7462 7463 def uwu_inline_me(x, y, z): 7464 r = torch.cat((x, x)) + y 7465 r2 = torch.cat((y, y)) + z 7466 return r, r2 7467 7468 def fn(x, y, z): 7469 r, r2 = uwu_inline_me(x, y, z) 7470 return torch.mul(r, r), torch.mul(r2, r2) 7471 7472 seen_frames = [] 7473 import contextlib 7474 7475 @contextlib.contextmanager 7476 def global_context_capture_fn(frame_summary): 7477 if frame_summary is not None: 7478 seen_frames.append(frame_summary) 7479 yield 7480 7481 with mock.patch( 7482 "torch._guards.TracingContext.current_frame", 7483 side_effect=global_context_capture_fn, 7484 ): 7485 torch._dynamo.optimize("eager")(fn)(x, y, z) 7486 7487 self.assertEqual(len(seen_frames), 1) 7488 self.assertEqual(seen_frames[0].name, "fn") 7489 self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)") 7490 7491 def test_torch_guards_stack_frame_register_inlining_deep(self): 7492 x = torch.tensor([0.5, 0.5]) 7493 y = torch.tensor([0.75, 0.75, 0.75, 0.75]) 7494 z = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]) 7495 7496 def uwu_inline_me_deep(x, y): 7497 return torch.cat((x, x)) + y 7498 7499 def uwu_inline_me(x, y, z): 7500 r = uwu_inline_me_deep(x, y) 7501 r2 = uwu_inline_me_deep(y, z) 7502 return r, r2 7503 7504 def fn(x, y, z): 7505 r, r2 = uwu_inline_me(x, y, z) 7506 return torch.mul(r, r), torch.mul(r2, r2) 7507 7508 seen_frames = [] 7509 import contextlib 7510 7511 @contextlib.contextmanager 7512 def global_context_capture_fn(frame_summary): 7513 if frame_summary is not None: 7514 seen_frames.append(frame_summary) 7515 yield 7516 7517 with mock.patch( 7518 "torch._guards.TracingContext.current_frame", 7519 side_effect=global_context_capture_fn, 7520 ): 7521 torch._dynamo.optimize("eager")(fn)(x, y, z) 7522 7523 self.assertEqual(len(seen_frames), 3) 7524 self.assertEqual(seen_frames[0].name, "fn") 7525 self.assertEqual(seen_frames[1].name, "uwu_inline_me") 7526 self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)") 7527 7528 def test_error_on_recompile(self): 7529 @torch._dynamo.optimize("eager") 7530 def fn(a, b): 7531 return a + b 7532 7533 with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 7534 with self.assertRaises(torch._dynamo.exc.RecompileError): 7535 fn(torch.rand(2, 3), torch.rand(2, 3)) 7536 fn(torch.rand(2, 3), (1, 2, 3)) 7537 7538 @expectedFailureDynamic 7539 @torch._dynamo.config.patch(automatic_dynamic_shapes=False) 7540 def test_compile_profiler(self): 7541 class Model(torch.nn.Module): 7542 def forward(self, input): 7543 return input + input 7544 7545 model = Model() 7546 prof = CompileProfiler() 7547 compiled = torch.compile(model, backend=prof) 7548 base_checker = ( 7549 lambda: FileCheck() 7550 .check("Torchdynamo Profiler Report") 7551 .check("Graph Breaks") 7552 .check("No graph breaks detected.") 7553 .check("Recompilation") 7554 ) 7555 input = torch.rand((2, 3, 4)) 7556 _ = compiled(input) 7557 base_checker().check("No recompilation detected.").run(prof.report()) 7558 7559 new_shape_input = torch.rand((3, 3, 
4)) 7560 _ = compiled(new_shape_input) 7561 7562 # Not an exhaustive test of dynamic shapes behavior, but some sanity 7563 if torch._dynamo.config.assume_static_by_default: 7564 base_checker().check("Recompile Reasons").check("'forward'").check( 7565 "cache_size_limit to 1" 7566 ).run(prof.report()) 7567 else: 7568 base_checker().check("No recompilation detected.").run(prof.report()) 7569 7570 new_shape_input = torch.rand((4, 3, 4)) 7571 _ = compiled(new_shape_input) 7572 7573 base_checker().check("Recompile Reasons").check("'forward'").check( 7574 "tensor 'L['input']' size mismatch at index 0. expected 2, actual 3" 7575 ).check( 7576 "tensor 'L['input']' size mismatch at index 0. expected 3, actual 4" 7577 ).run( 7578 prof.report() 7579 ) 7580 7581 def test_guards_strip_function_call(self): 7582 from torch._dynamo.guards import strip_function_call 7583 7584 test_case = [ 7585 ("___odict_getitem(a, 1)", "a"), 7586 ("a.layers[slice(2)][0]._xyz", "a"), 7587 ("getattr(a.layers[slice(2)][0]._abc, '0')", "a"), 7588 ("getattr(getattr(a.x[3], '0'), '3')", "a"), 7589 ("a.layers[slice(None, -1, None)][0]._xyz", "a"), 7590 ("a.layers[func('offset', -1, None)][0]._xyz", "a"), 7591 ] 7592 # strip_function_call should extract the object from the string. 7593 for name, expect_obj in test_case: 7594 self.assertEqual(strip_function_call(name), expect_obj) 7595 7596 def test_int_neg(self): 7597 def int_neg(a, b): 7598 x = a.shape[0] 7599 y = b.shape[0] 7600 return -x * -y * a * b 7601 7602 torch._dynamo.testing.standard_test(self, int_neg, 2) 7603 7604 def test_hash_getitem_slice(self): 7605 s = GetItemSource(LocalSource("foo"), slice(None, -1, None)) 7606 s2 = GetItemSource(LocalSource("foo"), slice(None, -1, None)) 7607 s3 = GetItemSource(LocalSource("foo"), slice(None, -1, 2)) 7608 some_set = set() 7609 7610 self.assertTrue(s not in some_set) 7611 self.assertTrue(s2 not in some_set) 7612 self.assertTrue(s3 not in some_set) 7613 7614 some_set.add(s) 7615 7616 self.assertTrue(s in some_set) 7617 # s and s2 should hash the same 7618 self.assertTrue(s2 in some_set) 7619 # s3 should be different 7620 self.assertTrue(s3 not in some_set) 7621 7622 self.assertTrue(s == s2) 7623 self.assertTrue(s != s3) 7624 7625 def test_inline_dict_function(self): 7626 def _result_type_dict(dtype): 7627 return {bool: torch.float32}[dtype] 7628 7629 @torch.compile 7630 def f(): 7631 return torch.ones(3, dtype=_result_type_dict(bool)) 7632 7633 self.assertEqual(f(), torch.ones(3, dtype=torch.float32)) 7634 7635 def test_inline_dict_function_passed_as_arg(self): 7636 @torch.compile 7637 def fn(d, x, y): 7638 if d[x] is torch.float32: 7639 return y.cos() 7640 else: 7641 return y.sin() 7642 7643 dd = {bool: torch.float32, int: torch.int64} 7644 self.assertEqual(fn(dd, bool, torch.ones(4)), torch.ones(4).cos()) 7645 self.assertEqual(fn(dd, int, torch.ones(4)), torch.ones(4).sin()) 7646 7647 def test_add_sizes(self): 7648 def func(x): 7649 y = x.size() 7650 return y + y 7651 7652 eager_out = func(torch.ones(10, 10, 3)) 7653 compile_out = torch._dynamo.optimize("eager")(func)(torch.ones(10, 10, 3)) 7654 self.assertTrue(isinstance(compile_out, torch.Size)) 7655 self.assertEqual(eager_out, compile_out) 7656 7657 @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU") 7658 def test_cuda_set_device(self): 7659 def fn(): 7660 a = torch.ones(2, device="cuda") 7661 torch.cuda.set_device(1) 7662 return a + 1 7663 7664 with torch.cuda.device(0): 7665 counter = CompileCounter() 7666 opt_fn = torch._dynamo.optimize(counter)(fn) 7667 res = 
opt_fn() 7668 self.assertEqual(res.device.type, "cuda") 7669 self.assertEqual(res.device.index, 0) 7670 self.assertEqual(counter.frame_count, 2) 7671 7672 def test_nested_function_resuming_with_correct_globals(self): 7673 # https://github.com/pytorch/pytorch/issues/99665 7674 try: 7675 from .utils import outer_func 7676 except ImportError: 7677 from utils import outer_func 7678 7679 def gn(x, y): 7680 return x + y 7681 7682 def fn(x, y): 7683 return outer_func(gn)(x, y) 7684 7685 x = torch.rand([3]) 7686 y = torch.rand([3]) 7687 opt_fn = torch.compile(backend="eager")(fn) 7688 ref = fn(x, y) 7689 res = opt_fn(x, y) 7690 self.assertTrue(same(ref, res)) 7691 7692 @dataclasses.dataclass 7693 class CSETestCase: 7694 expr: str 7695 preface: typing.List[str] = dataclasses.field(default_factory=list) 7696 expected: typing.Optional[str] = None 7697 expected_py38: typing.Optional[str] = None 7698 7699 def _is_py38(self) -> bool: 7700 return sys.version_info[:2] <= (3, 8) 7701 7702 def _has_ast_unparse(self) -> bool: 7703 from torch._dynamo.guards import HAS_UNPARSE_FUNCTIONS 7704 7705 return HAS_UNPARSE_FUNCTIONS 7706 7707 def test_guards_cse_pass_single(self): 7708 if not self._has_ast_unparse(): 7709 if IS_FBCODE: 7710 raise RuntimeError("Needs astunparse or Python-3.9+") 7711 raise unittest.SkipTest("Needs astunparse or Python-3.9+") 7712 from torch._dynamo.guards import PyExprCSEPass 7713 7714 testcase = self.CSETestCase 7715 testcases = [ 7716 # Nothing gets CSE-d, since the only repeated sub-expression is 'x', 7717 # i.e. not a node type we are interested in. 7718 testcase(expr="x[0].a"), 7719 testcase(expr="x[1].a"), 7720 testcase(expr="x[2].a"), 7721 # 'a.b.c' gets CSE-d, since it's a sub-expression used more than 'PyExprCSEPass.USE_THRESHOLD' times. 7722 testcase( 7723 expr="a.b.c[0].d.e", 7724 preface=["_var0 = a.b", "_var1 = _var0.c"], 7725 expected="_var1[0].d.e", 7726 ), 7727 testcase(expr="a.b.c[1].d.e", expected="_var1[1].d.e"), 7728 testcase(expr="a.b.c[2].d.e", expected="_var1[2].d.e"), 7729 # 'm.n[0]' gets CSE-d, since it is a sub-expression used more than 'PyExprCSEPass.USE_THRESHOLD' times. 7730 testcase( 7731 expr="f(m.n[0], '0').x.y.z", 7732 preface=["_var2 = m.n", "_var3 = _var2[0]"], 7733 expected="f(_var3, '0').x.y.z", 7734 ), 7735 testcase(expr="f(m.n[0], '1').x.y.z", expected="f(_var3, '1').x.y.z"), 7736 testcase(expr="f(m.n[0], '2').x.y.z", expected="f(_var3, '2').x.y.z"), 7737 # The whole expression gets CSE-d, as well as all of its sub-expressions.
7738 testcase( 7739 expr="self.g(a, b).k", 7740 preface=["_var4 = self.g", "_var5 = _var4(a, b)", "_var6 = _var5.k"], 7741 expected="_var6", 7742 ), 7743 testcase(expr="self.g(a, b).k", expected="_var6"), 7744 testcase(expr="self.g(a, b).k", expected="_var6"), 7745 ] 7746 csepass = PyExprCSEPass() 7747 csepass.count([t.expr for t in testcases]) 7748 7749 for t in testcases: 7750 preface, expr = csepass.replace(t.expr) 7751 self.assertEqual(preface, t.preface) 7752 expected = t.expected if t.expected is not None else t.expr 7753 self.assertEqual(expr, expected) 7754 7755 def test_guards_cse_pass_multiple(self): 7756 if not self._has_ast_unparse(): 7757 raise unittest.SkipTest("Needs astunparse or Python-3.9+") 7758 from torch._dynamo.guards import PyExprCSEPass 7759 7760 testcase = self.CSETestCase 7761 testcases = [ 7762 testcase( 7763 expr="x[0].a < x[1].a * (3 - x[2].a)", 7764 expected="x[0].a < x[1].a * (3 - x[2].a)", 7765 expected_py38="(x[0].a < (x[1].a * (3 - x[2].a)))", 7766 ), 7767 testcase( 7768 expr="a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0", 7769 preface=["_var0 = a.b", "_var1 = _var0.c"], 7770 expected="_var1[0].d.e + _var1[1].d.e * _var1[2].d.e > 0", 7771 expected_py38="((_var1[0].d.e + (_var1[1].d.e * _var1[2].d.e)) > 0)", 7772 ), 7773 testcase( 7774 expr="f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512", 7775 preface=["_var2 = m.n", "_var3 = _var2[0]"], 7776 expected="f(_var3, '0').x.y.z * f(_var3, '1').x.y.z * f(_var3, '2').x.y.z < 512", 7777 expected_py38="(((f(_var3, '0').x.y.z * f(_var3, '1').x.y.z) * f(_var3, '2').x.y.z) < 512)", 7778 ), 7779 testcase( 7780 expr="self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k", 7781 preface=["_var4 = self.g", "_var5 = _var4(a, b)", "_var6 = _var5.k"], 7782 expected="_var6 + (1 - _var6) <= m[0].a + _var6", 7783 expected_py38="((_var6 + (1 - _var6)) <= (m[0].a + _var6))", 7784 ), 7785 ] 7786 7787 csepass = PyExprCSEPass() 7788 csepass.count([t.expr for t in testcases]) 7789 7790 for t in testcases: 7791 preface, expr = csepass.replace(t.expr) 7792 self.assertEqual(preface, t.preface) 7793 expected = t.expected_py38 if self._is_py38() else t.expected 7794 expected = expected if expected is not None else t.expr 7795 self.assertEqual(expr, expected) 7796 7797 def test_guard_function_builder_with_cse(self): 7798 from torch._dynamo.guards import build_guard_function 7799 7800 exprs = [ 7801 "x[0].a < x[1].a * (3 - x[2].a)", 7802 "a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0", 7803 "f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512", 7804 "self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k", 7805 ] 7806 7807 _, pycode = build_guard_function(exprs, "") 7808 expected = """\ 7809def ___make_guard_fn(): 7810 def guard(L): 7811 if not (x[0].a < x[1].a * (3 - x[2].a)): 7812 return False 7813 _var0 = a.b 7814 _var1 = _var0.c 7815 if not (_var1[0].d.e + _var1[1].d.e * _var1[2].d.e > 0): 7816 return False 7817 _var2 = m.n 7818 _var3 = _var2[0] 7819 if not (f(_var3, '0').x.y.z * f(_var3, '1').x.y.z * f(_var3, '2').x.y.z < 512): 7820 return False 7821 _var4 = self.g 7822 _var5 = _var4(a, b) 7823 _var6 = _var5.k 7824 if not (_var6 + (1 - _var6) <= m[0].a + _var6): 7825 return False 7826 return True 7827 return guard 7828""" 7829 expected_38 = """\ 7830def ___make_guard_fn(): 7831 def guard(L): 7832 if not ((x[0].a < (x[1].a * (3 - x[2].a)))): 7833 return False 7834 _var0 = a.b 7835 _var1 = _var0.c 7836 if not (((_var1[0].d.e + (_var1[1].d.e * _var1[2].d.e)) > 0)): 7837 
return False 7838 _var2 = m.n 7839 _var3 = _var2[0] 7840 if not ((((f(_var3, '0').x.y.z * f(_var3, '1').x.y.z) * f(_var3, '2').x.y.z) < 512)): 7841 return False 7842 _var4 = self.g 7843 _var5 = _var4(a, b) 7844 _var6 = _var5.k 7845 if not (((_var6 + (1 - _var6)) <= (m[0].a + _var6))): 7846 return False 7847 return True 7848 return guard 7849""" 7850 expected_38_no_astunparse = """\ 7851def ___make_guard_fn(): 7852 def guard(L): 7853 if not (x[0].a < x[1].a * (3 - x[2].a)): 7854 return False 7855 if not (a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0): 7856 return False 7857 if not (f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512): 7858 return False 7859 if not (self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k): 7860 return False 7861 return True 7862 return guard 7863""" 7864 7865 if self._is_py38(): 7866 expected = ( 7867 expected_38 if self._has_ast_unparse() else expected_38_no_astunparse 7868 ) 7869 self.assertEqual(expected, pycode) 7870 7871 def test_dynamo_compiling_fake_tensor_to_vararg_int(self): 7872 class MyModule(torch.nn.Module): 7873 def __init__(self): 7874 super().__init__() 7875 7876 def forward(self, x): 7877 # use numpy int so it's wrapped as fake tensor in dynamo 7878 shape = np.int_(16) 7879 # test shape as fake tensor, which param type is 7880 # Sequence[Union[_int, SymInt]] 7881 return x.reshape(shape) 7882 7883 x = torch.rand([4, 4]) 7884 model = MyModule() 7885 orig_out = model(x) 7886 opt_model = torch._dynamo.optimize("eager")(MyModule()) 7887 opt_out = opt_model(x) 7888 self.assertTrue(same(orig_out, opt_out)) 7889 7890 def test_scalar_tensor_is_equivalent_to_symint_argument(self): 7891 class GumbelTopKSampler(torch.nn.Module): 7892 def __init__(self, T, k): 7893 super().__init__() 7894 self.T = torch.nn.Parameter( 7895 torch.tensor(T, dtype=torch.float32), requires_grad=False 7896 ) 7897 self.k = torch.nn.Parameter( 7898 torch.tensor(k, dtype=torch.int32), requires_grad=False 7899 ) 7900 7901 def sample_discrete(self, logits): 7902 threshold = torch.topk(logits, self.k, sorted=True)[0][..., -1] 7903 samples = torch.ge(logits.squeeze(1), threshold).float() 7904 return samples 7905 7906 def forward(self, logits): 7907 dsamples = self.sample_discrete(logits) 7908 return dsamples 7909 7910 x = torch.rand([4, 4, 4, 4]) 7911 m = GumbelTopKSampler(T=4, k=4) 7912 orig_out = m(x) 7913 opt_m = torch.compile(backend="eager")(m) 7914 opt_out = opt_m(x) 7915 self.assertTrue(same(orig_out, opt_out)) 7916 7917 def test_scalar_tensor_is_equivalent_to_symint_list_argument(self): 7918 class Jitter(torch.nn.Module): 7919 def __init__(self, jitter_val): 7920 super().__init__() 7921 self.jitter_val = jitter_val 7922 7923 def roll_tensor(self, input): 7924 h_shift = self.jitter_val - 1 7925 w_shift = self.jitter_val + 1 7926 return torch.roll( 7927 torch.roll(input, shifts=h_shift, dims=2), shifts=w_shift, dims=3 7928 ) 7929 7930 def forward(self, input): 7931 return self.roll_tensor(input) 7932 7933 x = torch.rand([4, 4, 4, 4]) 7934 m = Jitter(jitter_val=4) 7935 orig_out = m(x) 7936 opt_m = torch.compile(backend="eager")(m) 7937 opt_out = opt_m(x) 7938 self.assertTrue(same(orig_out, opt_out)) 7939 7940 def test_scalar_tensor_is_equivalent_to_int_list_argument(self): 7941 class MyModel(torch.nn.Module): 7942 def forward(self, input): 7943 permute = torch.tensor([0, 2, 1]) 7944 x = input.permute(*permute) 7945 return x 7946 7947 x = torch.randn(2, 3, 4) 7948 m = MyModel() 7949 orig_out = m(x) 7950 opt_m = torch.compile(backend="eager")(m) 7951 
opt_out = opt_m(x) 7952 self.assertTrue(same(orig_out, opt_out)) 7953 7954 def test_torch_variable_hasattr(self): 7955 def fn(x): 7956 if hasattr(torch.nn, "Module"): 7957 return x * x 7958 return x + 1 7959 7960 compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) 7961 7962 x = torch.rand([4, 4]) 7963 fn_out = fn(x) 7964 compiled_out = compiled_fn(x) 7965 self.assertTrue(same(fn_out, compiled_out)) 7966 7967 def test_list_hasattr1(self): 7968 def fn(x): 7969 if hasattr(x, "foo"): 7970 return x[0] + 1 7971 return x[0] - 1 7972 7973 compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) 7974 7975 x = [torch.randn(3)] 7976 fn_out = fn(x) 7977 compiled_out = compiled_fn(x) 7978 self.assertTrue(same(fn_out, compiled_out)) 7979 7980 def test_list_hasattr2(self): 7981 def fn(): 7982 x = [torch.zeros(3)] 7983 if hasattr(x, "__len__"): 7984 return x[0] + 1 7985 return x[0] - 1 7986 7987 compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) 7988 7989 fn_out = fn() 7990 compiled_out = compiled_fn() 7991 self.assertTrue(same(fn_out, compiled_out)) 7992 7993 def test_tuple_hasattr(self): 7994 def fn(x): 7995 if hasattr(x, "foo"): 7996 return x[0] + 1 7997 return x[1] - 1 7998 7999 compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) 8000 8001 x = (torch.randn(3), torch.randn(3)) 8002 fn_out = fn(x) 8003 compiled_out = compiled_fn(x) 8004 self.assertTrue(same(fn_out, compiled_out)) 8005 8006 def test_fn_hasattr__name__1(self): 8007 def fn(): 8008 foo = lambda x: x + 1 8009 return hasattr(foo, "__name__") 8010 8011 compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) 8012 8013 fn_out = fn() 8014 compiled_out = compiled_fn() 8015 self.assertEqual(fn_out, compiled_out) 8016 self.assertTrue(fn_out) 8017 8018 def test_fn_hasattr__name__2(self): 8019 def bar(x): 8020 return torch.sin(x) 8021 8022 def fn(): 8023 return hasattr(bar, "__name__") 8024 8025 compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) 8026 8027 fn_out = fn() 8028 compiled_out = compiled_fn() 8029 self.assertEqual(fn_out, compiled_out) 8030 self.assertTrue(fn_out) 8031 8032 def test_fn_hasattr__name__3(self): 8033 def bar(x, y): 8034 return torch.sin(x) + torch.cos(y) 8035 8036 baz = functools.partial(bar, y=4) 8037 8038 def fn(): 8039 return hasattr(baz, "__name__") 8040 8041 compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) 8042 8043 fn_out = fn() 8044 compiled_out = compiled_fn() 8045 self.assertEqual(fn_out, compiled_out) 8046 self.assertFalse(fn_out) 8047 8048 def test_torch_objects_as_keys(self): 8049 remap = {torch.float16: torch.float32} 8050 8051 def fn(): 8052 return torch.randn(3, dtype=remap[torch.float16]) 8053 8054 opt = torch._dynamo.optimize("eager")(fn) 8055 opt() 8056 8057 def test_tracing_py_tree(self): 8058 def fn(xs): 8059 flat_xs, spec = pytree.tree_flatten(xs) 8060 res = [x.clone() for x in flat_xs] 8061 return pytree.tree_unflatten(res, spec) 8062 8063 xs = [torch.tensor(i) for i in range(3)] 8064 8065 counter = CompileCounter() 8066 torch._dynamo.optimize(counter, nopython=True)(fn)(xs) 8067 self.assertEqual(counter.frame_count, 1) 8068 self.assertEqual(counter.op_count, 3) 8069 8070 def test_tracing_nested_py_tree(self): 8071 import torch.utils._pytree as pytree 8072 8073 def fn(xs): 8074 flat_xs, spec = pytree.tree_flatten(xs) 8075 res = [x.clone() for x in flat_xs] 8076 return pytree.tree_unflatten(res, spec) 8077 8078 xs = [torch.tensor(i) for i in range(3)] 8079 xsl = [xs, xs, xs, xs] 8080 8081 counter = CompileCounter() 8082 
comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl) 8083 real_out = fn(xsl) 8084 self.assertEqual(comp_out, real_out) 8085 self.assertEqual(counter.frame_count, 1) 8086 self.assertEqual(counter.op_count, 12) 8087 8088 def test_tracing_nested_py_tree_tuples(self): 8089 import torch.utils._pytree as pytree 8090 8091 def fn(xs): 8092 flat_xs, spec = pytree.tree_flatten(xs) 8093 res = [x.clone() for x in flat_xs] 8094 return pytree.tree_unflatten(res, spec) 8095 8096 xs = [torch.tensor(i) for i in range(3)] 8097 xsl = (xs, xs, xs, xs) 8098 8099 counter = CompileCounter() 8100 comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl) 8101 real_out = fn(xsl) 8102 self.assertEqual(comp_out, real_out) 8103 self.assertEqual(counter.frame_count, 1) 8104 self.assertEqual(counter.op_count, 12) 8105 8106 def test_tracing_nested_py_tree_dicts(self): 8107 import torch.utils._pytree as pytree 8108 8109 def fn(xs): 8110 flat_xs, spec = pytree.tree_flatten(xs) 8111 res = [x.clone() for x in flat_xs] 8112 return pytree.tree_unflatten(res, spec) 8113 8114 xs = [torch.tensor(i) for i in range(3)] 8115 xsl = { 8116 "a": xs, 8117 "b": xs, 8118 "c": xs, 8119 } 8120 8121 counter = CompileCounter() 8122 comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl) 8123 real_out = fn(xsl) 8124 self.assertEqual(comp_out, real_out) 8125 self.assertEqual(counter.frame_count, 1) 8126 self.assertEqual(counter.op_count, 9) 8127 8128 def test_dynamic_one_hot(self): 8129 def fn(x): 8130 x = x + 1 8131 # graph break from data-dependent output shape 8132 x = torch.nn.functional.one_hot(x) 8133 x = x + 1 8134 return x 8135 8136 inp = torch.arange(20) % 4 8137 counter = CompileCounter() 8138 real_out = fn(inp) 8139 comp_out = torch.compile(fn, backend=counter)(inp) 8140 self.assertEqual(comp_out, real_out) 8141 self.assertEqual(counter.frame_count, 2) 8142 self.assertEqual(counter.op_count, 2) 8143 8144 def test_tracing_nested_py_tree_mixed_all(self): 8145 import torch.utils._pytree as pytree 8146 8147 def fn(xs): 8148 flat_xs, spec = pytree.tree_flatten(xs) 8149 res = [x.clone() for x in flat_xs] 8150 return pytree.tree_unflatten(res, spec) 8151 8152 xs = [torch.tensor(i) for i in range(3)] 8153 xsa = (xs, xs) 8154 xsb = {"aa": xsa, "ab": xs} 8155 xsl = { 8156 "a": xs, 8157 "b": xsa, 8158 "c": xsb, 8159 } 8160 8161 counter = CompileCounter() 8162 comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl) 8163 real_out = fn(xsl) 8164 self.assertEqual(comp_out, real_out) 8165 self.assertEqual(counter.frame_count, 1) 8166 self.assertEqual(counter.op_count, 18) 8167 8168 def test_any_all_symnode(self): 8169 cnt = CompileCounter() 8170 8171 @torch.compile(backend=cnt, fullgraph=True, dynamic=True) 8172 def fn(x): 8173 t = x.size(0) >= 10 8174 f = x.size(0) >= 100 8175 if any([]) or any([f]) or any([f, f]): 8176 return x - 1 8177 if all([f]) or all([t, f]) or all([f, t]) or all([f, f]): 8178 return x - 2 8179 if not (all([]) and all([t]) and all([t, t])): 8180 return x - 3 8181 if not (any([t]) and any([t, f]) and any([f, t])): 8182 return x - 4 8183 return x + 1 8184 8185 y1 = torch.randn(16) 8186 y2 = torch.randn(18) 8187 self.assertEqual(fn(y1), y1 + 1) 8188 self.assertEqual(fn(y2), y2 + 1) 8189 self.assertEqual(cnt.frame_count, 1) 8190 y3 = torch.randn(5) 8191 self.assertEqual(fn(y3), y3 - 3) 8192 self.assertEqual(cnt.frame_count, 2) 8193 8194 def test_tracing_py_tree_tensor_subclass(self): 8195 import torch.utils._pytree as pytree 8196 from torch.testing._internal.two_tensor import TwoTensor 
8197 from torch.utils.checkpoint import checkpoint 8198 8199 def fn(xs): 8200 nested_xs = [[xs]] 8201 flat_xs, spec = pytree.tree_flatten(xs) 8202 return flat_xs[0].clone() 8203 8204 # use checkpoint to trigger a "sourceless" tensor subclass 8205 def checkpoint_fn(xs): 8206 return checkpoint(fn, xs, use_reentrant=True) 8207 8208 xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2)) 8209 8210 counter = CompileCounter() 8211 torch._dynamo.optimize(counter, nopython=True)(checkpoint_fn)(xs) 8212 self.assertEqual(counter.frame_count, 1) 8213 self.assertEqual(counter.op_count, 2) 8214 8215 def test_tracing_tree_map_only(self): 8216 import torch.utils._pytree as pytree 8217 8218 def fn(xs): 8219 def mapper(x): 8220 return x.clone() 8221 8222 y = pytree.tree_map_only(torch.Tensor, mapper, xs) 8223 return y 8224 8225 xs = [torch.tensor(i) for i in range(3)] + ["hi"] 8226 xsa = (xs, xs) 8227 xsb = {"aa": xsa, "ab": xs} 8228 8229 counter = CompileCounter() 8230 comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsb) 8231 real_out = fn(xsb) 8232 8233 self.assertEqual(comp_out, real_out) 8234 self.assertEqual(counter.frame_count, 1) 8235 self.assertEqual(counter.op_count, 9) 8236 8237 @torch._dynamo.config.patch( 8238 capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True 8239 ) 8240 def test_unbacked_symint(self): 8241 @torch.compile(backend="eager") 8242 def f(lengths, values): 8243 sizes = lengths.tolist() 8244 for s in sizes: 8245 torch._check_is_size(s) 8246 torch._check(s >= 2) 8247 torch._check(s <= 100) 8248 return torch.split(values, sizes) 8249 8250 f(torch.tensor([2, 3, 4]), torch.randn(9)) 8251 8252 @torch._dynamo.config.patch( 8253 capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True 8254 ) 8255 def test_unbacked_auto_functionalize_op(self): 8256 @torch.library.custom_op( 8257 "mylib::mk_image", mutates_args=("decoder",), device_types=["cpu"] 8258 ) 8259 def mk_image(decoder: Tensor) -> Tensor: 8260 return torch.randn(2, 3, 4, 5) 8261 8262 @torch.library.register_fake("mylib::mk_image") 8263 def _(decoder: Tensor) -> Tensor: 8264 image_size = [torch.library.get_ctx().new_dynamic_size() for _ in range(4)] 8265 return torch.empty(image_size) 8266 8267 @torch.compile(fullgraph=True) 8268 def f(x): 8269 return torch.ops.mylib.mk_image.default(x) 8270 8271 x = torch.zeros(100, dtype=torch.int64) 8272 f(x) 8273 8274 @torch._dynamo.config.patch(capture_scalar_outputs=True) 8275 def test_runtime_assert_replacement(self): 8276 @torch.compile(backend="aot_eager") 8277 def fn(x, y): 8278 z = y.item() 8279 torch._check(z == 3) 8280 return x + z 8281 8282 fn(torch.randn(4), torch.tensor([3])) 8283 self.assertRaises(RuntimeError, lambda: fn(torch.randn(4), torch.tensor([4]))) 8284 8285 @torch._dynamo.config.patch(capture_scalar_outputs=True) 8286 def test_cat_unbacked(self): 8287 @torch.compile(backend="eager") 8288 def fn(x, y): 8289 z = y.item() 8290 return torch.cat([x, torch.ones(z)]) 8291 8292 fn(torch.randn(2, 3), torch.tensor([0])) 8293 self.assertRaises( 8294 RuntimeError, lambda: fn(torch.randn(2, 3), torch.tensor([1])) 8295 ) 8296 8297 @torch._dynamo.config.patch( 8298 capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True 8299 ) 8300 def test_aot_autograd_propagate_unbacked_symints_shape(self): 8301 @torch.compile(backend="aot_eager") 8302 def f(x): 8303 return torch.nonzero(x) 8304 8305 f(torch.tensor([1, 0, 3, 2, 0])) 8306 8307 def test_simple_set_usage(self): 8308 def foo(x, y): 8309 setty = {x, y} 8310 return setty.pop() * setty.pop() 8311 
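        # Note: pop() order on a two-element set is not defined by Python, but the
        # product is commutative, so the assertions below only need to check that the
        # whole function compiles as a single frame.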
8312 counter = CompileCounter() 8313 foo = torch._dynamo.optimize(counter, nopython=True)(foo) 8314 x = torch.randn(10, 10) 8315 y = torch.randn(10, 10) 8316 foo(x, y) 8317 self.assertEqual(counter.frame_count, 1) 8318 8319 def test_add_to_set(self): 8320 def foo(x, y): 8321 setty = set() 8322 setty.add(x[0]) 8323 setty.add(x[1]) 8324 setty.add(x[2]) 8325 setty.add(y) 8326 return y * len(setty) 8327 8328 x = torch.randn(10, 10) 8329 y = torch.randn(2, 2) 8330 eager_result = foo([x, x, x, x, y], y) 8331 8332 counter = CompileCounter() 8333 foo = torch._dynamo.optimize(counter, nopython=True)(foo) 8334 result = foo([x, x, x, x, y], y) 8335 self.assertEqual(counter.frame_count, 1) 8336 self.assertEqual(result, eager_result) 8337 8338 def test_iter_set(self): 8339 def foo(x, y): 8340 setty = set() 8341 for t in x: 8342 setty.add(t) 8343 return y * len(setty) 8344 8345 x = torch.randn(10, 10) 8346 y = torch.randn(2, 2) 8347 eager_result = foo([x, x, x, x, y], y) 8348 8349 counter = CompileCounter() 8350 foo = torch._dynamo.optimize(counter, nopython=True)(foo) 8351 result = foo([x, x, x, x, y], y) 8352 self.assertEqual(counter.frame_count, 1) 8353 self.assertEqual(result, eager_result) 8354 8355 def test_input_set_graph_break(self): 8356 def foo(x): 8357 return x.pop() * x.pop() 8358 8359 x = torch.randn(10, 10) 8360 y = torch.randn(10, 10) 8361 8362 counter = CompileCounter() 8363 8364 inp = {x, x, x, x, y, y} 8365 foo = torch._dynamo.optimize(counter, nopython=True)(foo) 8366 8367 # There's a lot of stuff about sets that cannot work without a good deal of exertion on our part. 8368 # Specifically, getting a set as input won't ever work with how GetItemSource works (Can't arbitrary access set contents) 8369 # and so the guard story for the objects passed into input just isn't there atm. 
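        # Rough sketch of the limitation described above: GetItemSource-style guards
        # replay "container[index]", but a set exposes no positional access at all,
        # e.g.
        #
        #     s = {torch.randn(2), torch.randn(2)}
        #     s[0]  # TypeError: 'set' object is not subscriptable
        #
        # so with nopython=True the set input is expected to raise Unsupported below.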
8370 with self.assertRaisesRegex( 8371 torch._dynamo.exc.Unsupported, 8372 "^call_method UserDefinedObjectVariable\\(set\\).*", 8373 ): 8374 foo(inp) 8375 8376 foo = torch._dynamo.optimize(counter, nopython=False)(foo) 8377 foo(inp) 8378 self.assertEqual(counter.frame_count, 1) 8379 8380 def test_reconstruct_set_across_graph_break(self): 8381 def foo(x, y): 8382 setty = set() 8383 for t in x: 8384 setty.add(t) 8385 print("Break!") 8386 return y * len(setty) 8387 8388 x = torch.randn(10, 10) 8389 y = torch.randn(2, 2) 8390 8391 counter = CompileCounter() 8392 foo = torch._dynamo.optimize(counter)(foo) 8393 result = foo([x, x, x, x, y], y) 8394 8395 def test_set_aliasing_recompiles(self): 8396 g1 = torch.randn(10) 8397 g2 = torch.randn(10) 8398 g3 = torch.randn(10) 8399 g4 = torch.randn(10) 8400 8401 def foo(a, b, c): 8402 myset = {g1, a, b, c} 8403 return a + len(myset) 8404 8405 counter = CompileCounter() 8406 foo = torch._dynamo.optimize(counter)(foo) 8407 # first call with no aliasing 8408 foo(g2, g3, g4) 8409 self.assertEqual(counter.frame_count, 1) 8410 8411 # no aliasing again 8412 foo(g3, g2, g4) 8413 # assert no recompile 8414 self.assertEqual(counter.frame_count, 1) 8415 8416 # aliasing changes, we should recompile 8417 foo(g2, g2, g2) 8418 self.assertEqual(counter.frame_count, 2) 8419 8420 # same aliasing, different tensor 8421 foo(g3, g3, g3) 8422 self.assertEqual(counter.frame_count, 2) 8423 8424 # aliasing between global and arg, should recompile again 8425 foo(g1, g1, g1) 8426 self.assertEqual(counter.frame_count, 3) 8427 8428 # Reset 8429 torch._dynamo.reset() 8430 8431 # aliasing between global and arg, first call 8432 foo(g1, g1, g1) 8433 self.assertEqual(counter.frame_count, 4) 8434 8435 # same aliasing, different tensor, all local, recompile 8436 foo(g3, g3, g3) 8437 self.assertEqual(counter.frame_count, 5) 8438 8439 # aliasing same tensor, we shouldn't recompile 8440 foo(g2, g2, g2) 8441 self.assertEqual(counter.frame_count, 5) 8442 8443 # No aliasing 8444 foo(g2, g3, g4) 8445 self.assertEqual(counter.frame_count, 6) 8446 8447 # No aliasing again 8448 foo(g3, g2, g4) 8449 # assert no recompile 8450 self.assertEqual(counter.frame_count, 6) 8451 8452 def test_str_format_return1(self): 8453 @torch.compile(backend="eager", fullgraph=True) 8454 def fn(img): 8455 x = torch.sin(img) 8456 y = f"shape {img.shape[-2:]} batch size {img.shape[0]}" 8457 return img + x, y 8458 8459 img1 = torch.randn(1, 1, 8, 8) 8460 res, msg = fn(img1) 8461 self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1") 8462 self.assertEqual(res, img1 + torch.sin(img1)) 8463 8464 def test_str_format_return2(self): 8465 @torch.compile(backend="eager", fullgraph=True) 8466 def fn(img): 8467 x = torch.sin(img) 8468 y = "shape {} batch size {y:.2f}".format(img.shape[-2:], y=img.shape[0]) 8469 return img + x, y 8470 8471 img1 = torch.randn(1, 1, 8, 8) 8472 res, msg = fn(img1) 8473 self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1.00") 8474 self.assertEqual(res, img1 + torch.sin(img1)) 8475 8476 @torch._dynamo.config.patch(capture_scalar_outputs=True) 8477 def test_validate_outputs_unbacked(self): 8478 class SillyCat(torch.autograd.Function): 8479 @staticmethod 8480 def forward(ctx, x0, x1, i): 8481 ctx.save_for_backward(i) 8482 return torch.cat([x0, x1]) 8483 8484 @staticmethod 8485 def backward(ctx, grad_out): 8486 (i,) = ctx.saved_tensors 8487 i0, i1 = i.tolist() 8488 g_x0, g_x1 = grad_out.split([i0, i1]) 8489 return g_x0, g_x1, None 8490 8491 @torch.compile(backend="aot_eager", 
fullgraph=True) 8492 def f(x, i): 8493 i0, i1 = i.tolist() 8494 x0, x1 = x.split([i0, i1]) 8495 return SillyCat.apply(x0, x1, i) 8496 8497 f(torch.randn(9, requires_grad=True), torch.tensor([3, 6])) 8498 8499 def test_str_format_assert1(self): 8500 @torch.compile(backend="eager", fullgraph=True) 8501 def fn(img): 8502 x = torch.sin(img) 8503 val = x.shape[-2:] 8504 torch._assert(len(val) == 2, f"shape {img.shape}") 8505 return img + x 8506 8507 img1 = torch.randn(1, 1, 8, 8) 8508 res = fn(img1) 8509 self.assertEqual(res, img1 + torch.sin(img1)) 8510 8511 def test_str_format_assert2(self): 8512 cnt = CompileCounter() 8513 8514 @torch.compile(backend=cnt) 8515 def fn(img): 8516 x = torch.sin(img) 8517 torch._assert( 8518 img.shape[-2] == 8 and img.shape[-1] == 16, f"shape {img.shape}" 8519 ) 8520 return img + x 8521 8522 img1 = torch.randn(1, 3, 8, 16) 8523 res = fn(img1) 8524 self.assertEqual(res, img1 + torch.sin(img1)) 8525 self.assertEqual(cnt.frame_count, 1) 8526 8527 # trigger a recompile and graph break 8528 img2 = torch.randn(1, 3, 8, 15) 8529 self.assertRaises(AssertionError, lambda: fn(img2)) 8530 8531 def test_tolist_scalar(self): 8532 def fn(x): 8533 new_list = [] 8534 for i in x.tolist(): 8535 new_list.append(i * 4) 8536 return new_list 8537 8538 x = torch.tensor([3]) 8539 eager = fn(x) 8540 counter = CompileCounter() 8541 compiled = torch._dynamo.optimize(counter, nopython=True)(fn)(x) 8542 self.assertEqual(eager, compiled) 8543 self.assertEqual(counter.frame_count, 1) 8544 8545 def test_tolist_1d(self): 8546 def fn(x): 8547 new_list = [] 8548 for i in x.tolist(): 8549 new_list.append(i * 4) 8550 return new_list 8551 8552 x = torch.tensor([2, 1]) 8553 eager = fn(x) 8554 counter = CompileCounter() 8555 compiled = torch._dynamo.optimize(counter, nopython=True)(fn)(x) 8556 self.assertEqual(eager, compiled) 8557 self.assertEqual(counter.frame_count, 1) 8558 8559 def test_tolist_kd(self): 8560 def fn(x): 8561 new_list = [] 8562 for i in x.tolist(): 8563 new_list.append(i * 4) 8564 return new_list 8565 8566 x = torch.tensor([[[2, 1], [2, 1], [2, 1]], [[2, 1], [2, 1], [2, 1]]]) 8567 eager = fn(x) 8568 counter = CompileCounter() 8569 compiled = torch._dynamo.optimize(counter, nopython=True)(fn)(x) 8570 self.assertEqual(eager, compiled) 8571 self.assertEqual(counter.frame_count, 1) 8572 8573 @patch.object(torch._dynamo.config, "specialize_int", True) 8574 def test_tolist_0d(self): 8575 def fn(x): 8576 new_list = [] 8577 i = x.tolist() 8578 new_list.append(i * 4) 8579 return new_list 8580 8581 x = torch.tensor(42) 8582 eager = fn(x) 8583 counter = CompileCounter() 8584 compiled = torch._dynamo.optimize(counter, nopython=True)(fn)(x) 8585 self.assertEqual(eager, compiled) 8586 self.assertEqual(counter.frame_count, 1) 8587 8588 @patch.object(torch._dynamo.config, "assume_static_by_default", False) 8589 @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) 8590 def test_tolist_kd_dynamic(self): 8591 def fn(x): 8592 new_list = [] 8593 i = x.tolist() 8594 new_list.append(i * 4) 8595 return new_list 8596 8597 x = torch.randint(3, 5, [5, 5]) 8598 eager = fn(x) 8599 counter = CompileCounter() 8600 compiled_fn = torch._dynamo.optimize(counter, nopython=True)(fn) 8601 compiled = compiled_fn(x) 8602 self.assertEqual(eager, compiled) 8603 self.assertEqual(counter.frame_count, 1) 8604 8605 # Value change, no recompiles 8606 x = torch.randint(7, 9, [5, 5]) 8607 compiled_fn(x) 8608 self.assertEqual(counter.frame_count, 1) 8609 8610 # Size change, forced recompiles 8611 x = 
torch.randint(3, 5, [3, 3]) 8612 compiled_fn(x) 8613 self.assertEqual(counter.frame_count, 2) 8614 8615 def test_tolist_float(self): 8616 def fn(x): 8617 new_list = [] 8618 for i in x.tolist(): 8619 new_list.append(i * 4) 8620 return new_list 8621 8622 x = torch.tensor( 8623 [[[2.0, 1.0], [2.0, 1.0], [2.0, 1.0]], [[2.0, 1.0], [2.0, 1.0], [2.0, 1.0]]] 8624 ) 8625 eager = fn(x) 8626 counter = CompileCounter() 8627 compiled = torch._dynamo.optimize(counter)(fn)(x) 8628 self.assertEqual(eager, compiled) 8629 # Nothing to compile here 8630 self.assertEqual(counter.frame_count, 0) 8631 8632 def test_inline_closure_not_loaded_by_parent(self): 8633 def outer(a): 8634 return a + 1 8635 8636 def indirect(x): 8637 return direct(x) 8638 8639 def direct(x): 8640 def deep2(c): 8641 return outer(c) 8642 8643 def deep(c): 8644 return deep2(c) 8645 8646 return deep(x) 8647 8648 x = torch.randn(3) 8649 eager = indirect(x) 8650 counter = CompileCounter() 8651 compiled = torch._dynamo.optimize(counter)(indirect)(x) 8652 self.assertEqual(eager, compiled) 8653 self.assertEqual(counter.frame_count, 1) 8654 8655 def test_deque_input(self): 8656 a = torch.randn([2, 3]) 8657 b = torch.randn([2, 3]) 8658 d1 = collections.deque([a, b]) 8659 d1.insert(0, "foo") 8660 8661 d2 = collections.deque([a, b]) 8662 d2.insert(0, "foo") 8663 8664 def fn(q): 8665 a = q.pop() 8666 b = q.pop() 8667 return a * b 8668 8669 eager = fn(d1) 8670 counter = CompileCounter() 8671 compiled = torch._dynamo.optimize(counter)(fn)(d2) 8672 self.assertEqual(eager, compiled) 8673 self.assertEqual(counter.frame_count, 1) 8674 8675 def test_deque_append_left(self): 8676 d1 = collections.deque([10, 10]) 8677 d1.insert(0, "foo") 8678 8679 d2 = collections.deque([10, 10]) 8680 d2.insert(0, "foo") 8681 8682 def fn(q, a, b): 8683 q.appendleft(a) 8684 q.appendleft(b) 8685 return q.popleft() * q.popleft() 8686 8687 a = torch.randn([3, 3]) 8688 b = torch.randn([3, 3]) 8689 eager = fn(d1, a, b) 8690 counter = CompileCounter() 8691 compiled = torch._dynamo.optimize(counter)(fn)(d2, a, b) 8692 self.assertEqual(eager, compiled) 8693 self.assertEqual(counter.frame_count, 1) 8694 self.assertTrue(isinstance(compiled, torch.Tensor)) 8695 8696 def test_yield_from(self): 8697 def yield_from_fn(t_list, k): 8698 def yield_from_gen(l): 8699 l2 = [t * k for t in l] 8700 yield from l2 8701 8702 return [t * k for t in yield_from_gen(t_list)] 8703 8704 t_list = [torch.randn([2, 3]) for _ in range(3)] 8705 eager = yield_from_fn(t_list, 2) 8706 counter = CompileCounter() 8707 compiled = torch._dynamo.optimize(counter)(yield_from_fn)(t_list, 2) 8708 self.assertEqual(eager, compiled) 8709 self.assertEqual(counter.frame_count, 1) 8710 8711 def test_yield_from_in_a_loop(self): 8712 def gen2(): 8713 yield 1 8714 8715 def gen1(): 8716 for value in range(5): 8717 yield from gen2() 8718 8719 def fn(x): 8720 c = 0 8721 for i in gen1(): 8722 c = c + i 8723 return x + c 8724 8725 opt_fn = torch.compile(fn, backend="eager") 8726 x = torch.zeros(4) 8727 self.assertEqual(fn(x), opt_fn(x)) 8728 8729 def test_yield_gen_and_from(self): 8730 def populate_and_multiply_sequence(n, multiplier): 8731 # Inline generator 8732 def tensor_generator(): 8733 for i in range(n): 8734 yield torch.tensor([i]) 8735 8736 # Use 'yield from' to iterate over tensors and multiply 8737 t_list = [tensor * multiplier for tensor in tensor_generator()] 8738 8739 def yield_from_gen(): 8740 yield from t_list 8741 8742 return [t for t in yield_from_gen()] 8743 8744 multiplier = torch.tensor([10]) 8745 eager = 
populate_and_multiply_sequence(5, multiplier) 8746 counter = CompileCounter() 8747 compiled = torch._dynamo.optimize(counter)(populate_and_multiply_sequence)( 8748 5, multiplier 8749 ) 8750 self.assertEqual(eager, compiled) 8751 self.assertEqual(counter.frame_count, 1) 8752 8753 def test_yield_from_user_stop_iteration(self): 8754 class MyIter: 8755 def __init__(self, seq): 8756 self.seq = seq 8757 self.index = 0 8758 8759 def __iter__(self): 8760 return self 8761 8762 def __next__(self): 8763 self.index += 1 8764 if self.index <= len(self.seq): 8765 return self.seq[self.index - 1] 8766 raise StopIteration(self.index) 8767 8768 def yield_from_iter_fn(seq): 8769 def gen(seq): 8770 yield from MyIter(seq) 8771 8772 return [i for i in gen(seq)] 8773 8774 seq = [torch.randn([2, 3]) for _ in range(3)] 8775 eager = yield_from_iter_fn(seq) 8776 counter = CompileCounter() 8777 compiled = torch._dynamo.optimize(counter)(yield_from_iter_fn)(seq) 8778 self.assertEqual(eager, compiled) 8779 self.assertEqual(counter.frame_count, 0) 8780 8781 def test_yield_send_to_subgenerator_graph_break(self): 8782 def subgenerator(tensor): 8783 multiplier = yield 8784 yield tensor * multiplier 8785 8786 def main_generator(t_list): 8787 for tensor in t_list: 8788 subgen = subgenerator(tensor) 8789 next(subgen) 8790 yield from subgen.send(torch.tensor([10])) 8791 8792 t_list = [torch.tensor([i]) for i in range(5)] 8793 eager = list(main_generator(t_list)) 8794 8795 counter = CompileCounter() 8796 compiled_fn = torch._dynamo.optimize(counter)(main_generator) 8797 compiled = list(compiled_fn(t_list)) 8798 8799 self.assertEqual(eager, compiled) 8800 self.assertEqual(counter.frame_count, 0) 8801 8802 def test_derpy_nn_module_usage(self): 8803 def ff1(x): 8804 self = mod1 8805 return torch.sigmoid(self.mod2(x) + self.param1) 8806 8807 def ff2(x): 8808 self = mod2 8809 return torch.cos(torch.sin(x) * self.param2 + 10) 8810 8811 mod1 = torch.nn.Module() 8812 mod2 = torch.nn.Module() 8813 mod1.register_module("mod2", mod2) 8814 mod1.register_parameter("param1", torch.nn.Parameter(torch.randn(10))) 8815 mod1.forward = ff1 8816 mod2.register_parameter("param2", torch.nn.Parameter(torch.randn(10))) 8817 mod2.forward = ff2 8818 mod1.eval() 8819 8820 x = torch.randn(10) 8821 expected = mod1(x) 8822 counter = CompileCounter() 8823 actual = torch.compile(mod1, backend=counter, fullgraph=True)(x) 8824 self.assertEqual(actual, expected) 8825 self.assertEqual(counter.op_count, 6) 8826 8827 def test_default_args_device_dtype(self): 8828 class Foo: 8829 def __init__( 8830 self, 8831 dtype: torch.dtype = torch.float16, 8832 device: torch.device = torch.device("cpu"), 8833 ) -> None: 8834 self.value = torch.tensor(10, dtype=dtype, device=device) 8835 8836 def fn(): 8837 return Foo().value + 1 8838 8839 opt_func = torch._dynamo.optimize("eager", nopython=True)(fn) 8840 ref = fn() 8841 res = opt_func() 8842 self.assertEqual(ref, res) 8843 8844 def test_torch_device_python_type(self): 8845 for device, device_type, index in [ 8846 ("cpu", "cpu", None), 8847 ("cuda:0", "cuda", 0), 8848 ]: 8849 if device == "cuda:0" and not TEST_CUDA: 8850 continue 8851 8852 def fn(target): 8853 target_device = target.device 8854 a = torch.zeros(2, 3, device=target_device) 8855 # Constant assert at trace time 8856 assert isinstance(target_device, torch.device) 8857 assert target_device.type == device_type 8858 assert target_device.index == index 8859 b = torch.zeros(2, 3, device=target_device) 8860 c = torch.zeros(2, 3, device=target_device) 8861 return a + b + c 
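            # Note: target.device is a plain torch.device constant at trace time, so the
            # asserts above should be checked while tracing rather than becoming runtime
            # checks in the compiled graph, e.g.
            #
            #     d = torch.device("cuda:0")
            #     (d.type, d.index)  # ("cuda", 0) -- ordinary Python values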
8862 8863 from torch._dynamo.variables import ConstantVariable 8864 8865 device = torch.device(device) 8866 expected_variable = ConstantVariable(device) 8867 self.assertEqual(expected_variable.python_type(), type(device)) 8868 8869 opt_func = torch._dynamo.optimize("eager", nopython=True)(fn) 8870 a = torch.tensor([2, 3], device=device) 8871 res = opt_func(a) 8872 self.assertIsInstance(res, torch.Tensor) 8873 8874 def test_torch_dtype_python_type(self): 8875 def fn(target): 8876 target_dtype = target.dtype 8877 a = torch.zeros(2, 3, dtype=target_dtype) 8878 # Constant assert at trace time 8879 assert isinstance(target_dtype, torch.dtype) 8880 b = torch.zeros(2, 3, dtype=target_dtype) 8881 c = torch.zeros(2, 3, dtype=target_dtype) 8882 return a + b + c 8883 8884 from torch._dynamo.variables import ConstantVariable 8885 8886 dtype = torch.float16 8887 expected_variable = ConstantVariable(dtype) 8888 self.assertEqual(expected_variable.python_type(), type(dtype)) 8889 8890 opt_func = torch._dynamo.optimize("eager", nopython=True)(fn) 8891 a = torch.tensor([2, 3], dtype=dtype) 8892 res = opt_func(a) 8893 self.assertIsInstance(res, torch.Tensor) 8894 8895 def test_itertools_repeat(self): 8896 counters.clear() 8897 8898 def fn(x): 8899 r = itertools.repeat(100.0, 5) 8900 for i in r: 8901 x += i 8902 return x 8903 8904 x = torch.randn([2, 5]) 8905 eager = fn(x) 8906 8907 compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 8908 compiled = compiled_fn(x) 8909 8910 self.assertEqual(list(eager), list(compiled)) 8911 self.assertEqual(len(counters["graph_break"]), 0) 8912 8913 def test_itertools_infinite_repeat(self): 8914 counters.clear() 8915 8916 def fn(x): 8917 r = itertools.repeat(100.0) 8918 idx = 0 8919 for i in r: 8920 x += i 8921 idx += 1 8922 if idx > 10: 8923 break 8924 return x 8925 8926 x = torch.randn([2, 5]) 8927 eager = fn(x) 8928 8929 compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 8930 compiled = compiled_fn(x) 8931 8932 self.assertEqual(list(eager), list(compiled)) 8933 self.assertEqual(len(counters["graph_break"]), 0) 8934 8935 def test_itertools_infinite_repeat_mutation(self): 8936 counters.clear() 8937 8938 def fn(x): 8939 r = itertools.repeat(x) 8940 idx = 0 8941 for i in r: 8942 x += i 8943 i += 1 8944 idx += 1 8945 if idx > 10: 8946 break 8947 return x 8948 8949 x = torch.randn([2, 5]) 8950 eager = fn(x) 8951 8952 compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 8953 compiled = compiled_fn(x) 8954 8955 self.assertEqual(list(eager), list(compiled)) 8956 self.assertEqual(len(counters["graph_break"]), 0) 8957 8958 def test_itertools_infinite_count(self): 8959 for args in ([], [10], [5, -1]): 8960 counters.clear() 8961 8962 def fn(x): 8963 r = itertools.count(*args) 8964 idx = 0 8965 for i in r: 8966 x += i 8967 idx += 1 8968 if idx > 10: 8969 break 8970 return x 8971 8972 x = torch.randn([2, 5]) 8973 eager = fn(x) 8974 8975 compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 8976 compiled = compiled_fn(x) 8977 8978 self.assertEqual(list(eager), list(compiled)) 8979 self.assertEqual(len(counters["graph_break"]), 0) 8980 8981 def test_itertools_infinite_cycle(self): 8982 counters.clear() 8983 8984 def fn(x): 8985 for iterator in ( 8986 iter([]), 8987 iter([10, 11.0]), 8988 itertools.repeat(-1, 3), 8989 itertools.count(10), 8990 ): 8991 r = itertools.cycle(iterator) 8992 idx = 0 8993 x += 1 8994 for i in r: 8995 x += i 8996 idx += 1 8997 if idx > 10: 8998 break 8999 return x 9000 9001 x = 
torch.randn([2, 5]) 9002 eager = fn(x) 9003 9004 compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9005 compiled = compiled_fn(x) 9006 9007 self.assertEqual(list(eager), list(compiled)) 9008 self.assertEqual(len(counters["graph_break"]), 0) 9009 9010 def test_itertools_accumulate_symint_default_sum(self): 9011 # https://github.com/pytorch/pytorch/issues/110287 9012 counters.clear() 9013 9014 def fn(x): 9015 r = itertools.accumulate([x.size(0), x.size(1)]) 9016 for i in r: 9017 x *= i 9018 return x 9019 9020 x = torch.randn(2, 3) 9021 eager = fn(x) 9022 9023 compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9024 compiled = compiled_fn(x) 9025 9026 self.assertEqual(list(eager), list(compiled)) 9027 self.assertEqual(len(counters["graph_break"]), 0) 9028 9029 def test_itertools_accumulate_tensors_default_sum(self): 9030 counters.clear() 9031 9032 def fn(a, b, c, d, x): 9033 l = [a, b, c, d, x] 9034 for i, t in enumerate(l): 9035 l[i] = t * x 9036 return itertools.accumulate(l) 9037 9038 t_list = [torch.tensor([i + 1]) for i in range(4)] 9039 x = torch.tensor([[1, 2], [3, 4]]) 9040 eager = fn(*t_list, x) 9041 9042 compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9043 compiled = compiled_fn(*t_list, x) 9044 9045 self.assertEqual(list(eager), list(compiled)) 9046 self.assertEqual(len(counters["graph_break"]), 0) 9047 9048 def test_itertools_accumulate_tensors_builtins(self): 9049 for builtin_op in [operator.mul, operator.sub, operator.pow]: 9050 counters.clear() 9051 9052 def fn(a, b, c, d, x): 9053 l = [a, b, c, d, x] 9054 for i, t in enumerate(l): 9055 l[i] = t * x 9056 return itertools.accumulate(l, builtin_op) 9057 9058 t_list = [torch.tensor([i + 1]) for i in range(4)] 9059 x = torch.tensor([[1, 2], [3, 4]]) 9060 eager = fn(*t_list, x) 9061 9062 compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9063 compiled = compiled_fn(*t_list, x) 9064 9065 self.assertEqual(list(eager), list(compiled)) 9066 self.assertEqual(len(counters["graph_break"]), 0) 9067 9068 def test_itertools_accumulate_tensors_kwargs(self): 9069 from torch._dynamo.utils import counters 9070 9071 for kwargs in [ 9072 {"func": operator.mul}, 9073 {"initial": 100}, 9074 {"func": operator.sub, "initial": -1}, 9075 ]: 9076 counters.clear() 9077 9078 def fn(a, b, c, d, x): 9079 l = [a, b, c, d, x] 9080 for i, t in enumerate(l): 9081 l[i] = t * x 9082 return itertools.accumulate(l, **kwargs) 9083 9084 t_list = [torch.tensor([i + 1]) for i in range(4)] 9085 x = torch.tensor([[1, 2], [3, 4]]) 9086 9087 compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9088 compiled = compiled_fn(*t_list, x) 9089 eager = fn(*t_list, x) 9090 9091 self.assertEqual(list(eager), list(compiled)) 9092 self.assertEqual(len(counters["graph_break"]), 0) 9093 9094 def test_packaging_version_parse(self): 9095 from packaging import version 9096 9097 @torch.compile(backend="eager", fullgraph=True) 9098 def fn(): 9099 x = torch.zeros(1) 9100 if version.parse(torch.__version__) >= version.parse("2.0.0"): 9101 return x + 1 9102 return x 9103 9104 self.assertEqual(fn().item(), 1) 9105 9106 def test_itertools_accumulate_tensors_user_defined(self): 9107 def udo_fn_0(a, b): 9108 return -1 9109 9110 rando = random.randint(0, 1) 9111 9112 def udo_fn_1(a, b): 9113 return a * rando + b * rando 9114 9115 seen = [] 9116 9117 def udo_fn_2(a, b): 9118 seen.append(a) 9119 seen.append(b) 9120 return a * len(seen) 9121 9122 for udo_fn in [udo_fn_0, udo_fn_1, udo_fn_2]: 
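            # Note: udo_fn_1 closes over a random value and udo_fn_2 mutates outer
            # state, so each callable gets a fresh torch._dynamo.reset() below to keep
            # the three compilations independent.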
9123 counters.clear() 9124 torch._dynamo.reset() 9125 9126 def fn(a, b, c, d, x): 9127 l = [a, b, c, d, x] 9128 for i, t in enumerate(l): 9129 l[i] = t * x 9130 return itertools.accumulate(l, udo_fn) 9131 9132 t_list = [torch.tensor([i]) for i in range(4)] 9133 x = torch.tensor([[1, 2], [3, 4]]) 9134 eager = fn(*t_list, x) 9135 9136 compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9137 compiled = compiled_fn(*t_list, x) 9138 9139 self.assertEqual(list(eager), list(compiled)) 9140 self.assertEqual(len(counters["graph_break"]), 0) 9141 9142 def test_pure_python_accumulate(self): 9143 def accumulate(iterable, func=lambda x, y: x + y): 9144 it = iter(iterable) 9145 try: 9146 # Initialize the accumulator with the first value from the iterable 9147 accumulator = next(it) 9148 except StopIteration: 9149 # If the iterable is empty, return an empty generator 9150 return 9151 yield accumulator 9152 9153 for element in it: 9154 accumulator = func(accumulator, element) 9155 yield accumulator 9156 9157 def fn(it): 9158 return accumulate(it) 9159 9160 t_list = [torch.tensor([i]) for i in range(4)] 9161 eager = fn(t_list) 9162 9163 counter = CompileCounter() 9164 compiled_fn = torch._dynamo.optimize(counter)(fn) 9165 compiled = compiled_fn(t_list) 9166 9167 self.assertEqual(list(eager), list(compiled)) 9168 self.assertEqual(counter.frame_count, 1) 9169 9170 def test_itertools_groupby_pure_python_default_identify_func(self): 9171 counters.clear() 9172 9173 def fn(l): 9174 return [(k, list(g)) for k, g in itertools.groupby(l)] 9175 9176 l = [1, 2, 2, 3, 4, 4, 4, 1, 2] 9177 eager = fn(l) 9178 9179 compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9180 compiled = compiled_fn(l) 9181 9182 self.assertEqual(eager, compiled) 9183 self.assertEqual(len(counters["graph_break"]), 0) 9184 9185 def test_itertools_groupby_pure_python_key_func(self): 9186 counters.clear() 9187 9188 def fn(l): 9189 return [(k, list(g)) for k, g in itertools.groupby(l, key=operator.neg)] 9190 9191 l = [1, 2, -2, 3, 4, 4, -4, 0, -2] 9192 eager = fn(l) 9193 9194 compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9195 compiled = compiled_fn(l) 9196 9197 self.assertEqual(eager, compiled) 9198 self.assertEqual(len(counters["graph_break"]), 0) 9199 9200 def test_list_iterator_contains(self): 9201 def fn(x): 9202 it = iter(["my_weight", "not_my_weight"]) 9203 next(it) 9204 if "my_weight" in it: 9205 return x + 2 9206 return x + 1 9207 9208 x = torch.zeros(3) 9209 compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9210 9211 self.assertEqual(fn(x), compiled_fn(x)) 9212 9213 def test_storage_return(self): 9214 @torch.compile(backend="eager", fullgraph=True) 9215 def fn(x): 9216 y = torch.sin(x + 1) 9217 storage = x.untyped_storage() 9218 storage.resize_(0) 9219 y = torch.cos(y) 9220 return y, storage 9221 9222 x = torch.randn(10) 9223 expected = torch.cos(torch.sin(x + 1)) 9224 y, s = fn(x) 9225 self.assertEqual(y, expected) 9226 self.assertEqual(x.untyped_storage().size(), 0) 9227 self.assertIs(s, x.untyped_storage()) 9228 9229 def test_flat_name_to_original_fqn(self): 9230 class FooBarModule(torch.nn.Module): 9231 def __init__(self): 9232 super().__init__() 9233 self.register_parameter("0", torch.nn.Parameter(torch.randn(3, 4))) 9234 self.register_buffer("test_buf", torch.randn(3, 4)) 9235 self.register_parameter( 9236 "test_param", torch.nn.Parameter(torch.randn(3, 4)) 9237 ) 9238 9239 def forward(self, x): 9240 return ((x + self.test_buf) * getattr(self, 
"0")) / self.test_param 9241 9242 class TestModule(torch.nn.Module): 9243 def __init__(self): 9244 super().__init__() 9245 self.foo_bar = FooBarModule() 9246 self.register_parameter( 9247 "test_param", torch.nn.Parameter(torch.randn(3, 4)) 9248 ) 9249 self.register_buffer("test_buf", torch.randn(3, 4)) 9250 9251 def forward(self, x): 9252 return (self.foo_bar(x) + self.test_param) * self.test_buf 9253 9254 gm, _ = torch._dynamo.export(TestModule(), torch.randn(3, 4)) 9255 self.assertIn("dynamo_flat_name_to_original_fqn", gm.meta) 9256 expected_fqn = { 9257 "L__self___test_param": "test_param", 9258 "L__self___test_buf": "test_buf", 9259 "getattr_L__self___foo_bar___0__": "foo_bar.0", 9260 "L__self___foo_bar_test_param": "foo_bar.test_param", 9261 "L__self___foo_bar_test_buf": "foo_bar.test_buf", 9262 } 9263 self.assertEqual(expected_fqn, gm.meta["dynamo_flat_name_to_original_fqn"]) 9264 9265 def test_shape_env_no_recording(self): 9266 main = ShapeEnv(should_record_events=False) 9267 9268 # The main ShapeEnv should have no event recorded. 9269 self.assertEqual(len(main.events), 0) 9270 9271 # Call create_symbolic_sizes_strides_storage_offset on both of them. 9272 r = main.create_symbolic_sizes_strides_storage_offset( 9273 torch.randn(3, 2), ConstantSource("x") 9274 ) 9275 9276 # Create a guard: size[0] == 3 (call evaluate_expr) 9277 # - +1 guard entry 9278 # - +1 replacement entry 9279 size = r[0] 9280 bool(size[0] == 3) 9281 9282 # The main ShapeEnv should remain with no event recorded. 9283 self.assertEqual(len(main.events), 0) 9284 9285 if torch.fx.experimental.validator.translation_validation_enabled(): 9286 from torch.fx.experimental.symbolic_shapes import ( 9287 CURRENT_NODE_KEY, 9288 SHAPEENV_EVENT_KEY, 9289 ) 9290 9291 # Check that we don't store any recording metadata on nodes 9292 # from the symbolic shape FX graph. 9293 for n in main.graph.nodes: 9294 self.assertFalse(SHAPEENV_EVENT_KEY in n.meta) 9295 self.assertFalse(CURRENT_NODE_KEY in n.meta) 9296 9297 def _replay_and_check(self, shape_env: ShapeEnv): 9298 if shape_env.should_record_events: 9299 replayed = replay_shape_env_events(shape_env.events) 9300 shape_env.check_equal(replayed) 9301 9302 def test_shape_env_equal_empty(self): 9303 main, other = ShapeEnv(), ShapeEnv() 9304 main.check_equal(other) 9305 self._replay_and_check(main) 9306 9307 @onlyIfTranslationValidation 9308 def test_shape_env_equal_constructor(self): 9309 main, other = ShapeEnv(allow_scalar_outputs=False), ShapeEnv() 9310 self.assertExpectedRaisesInline( 9311 NotEqualError, 9312 lambda: main.check_equal(other), 9313 """\ 9314ShapeEnv not equal: field values don't match: 9315 9316==> settings: values don't match. 
9317 > Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, _allow_complex_guards_as_runtime_asserts=False) 9318 > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, _allow_complex_guards_as_runtime_asserts=False) 9319""", 9320 ) 9321 self._replay_and_check(main) 9322 9323 @onlyIfTranslationValidation 9324 def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self): 9325 main, other = ShapeEnv(), ShapeEnv() 9326 main.create_symbolic_sizes_strides_storage_offset( 9327 torch.randn(3, 2), ConstantSource("x") 9328 ) 9329 self.assertExpectedRaisesInline( 9330 NotEqualError, 9331 lambda: main.check_equal(other), 9332 """\ 9333ShapeEnv not equal: field values don't match: 9334 9335==> name_to_node: values don't match. 9336 > Left: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} 9337 > Right: {} 9338==> source_to_symbol: values don't match. 9339 > Left: {x.size()[0]: x.size()[0], x.size()[1]: x.size()[1], x.storage_offset(): x.storage_offset(), x.stride()[0]: x.stride()[0], x.stride()[1]: x.stride()[1]} 9340 > Right: {} 9341==> val_to_var: values don't match. 9342 > Left: {0: 0, 1: 1, 2: s1, 3: s0} 9343 > Right: {0: 0, 1: 1} 9344==> var_to_range: values don't match. 9345 > Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]} 9346 > Right: {} 9347==> var_to_sources: values don't match. 9348 > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]} 9349 > Right: {} 9350==> var_to_val: values don't match. 9351 > Left: {s0: 3, s1: 2} 9352 > Right: {} 9353""", 9354 ) 9355 self._replay_and_check(main) 9356 9357 @onlyIfTranslationValidation 9358 def test_shape_env_equal_unbacked(self): 9359 main, other = ShapeEnv(), ShapeEnv() 9360 main.create_unbacked_symint() 9361 main.create_unbacked_symfloat() 9362 main.create_unbacked_symbool() 9363 self.assertExpectedRaisesInline( 9364 NotEqualError, 9365 lambda: main.check_equal(other), 9366 """\ 9367ShapeEnv not equal: field values don't match: 9368 9369==> name_to_node: values don't match. 9370 > Left: {u0, u1, zuf0} 9371 > Right: {} 9372==> unbacked_symfloat_counter: values don't match. 9373 > Left: 1 9374 > Right: 0 9375==> unbacked_symint_counter: values don't match. 9376 > Left: 2 9377 > Right: 0 9378==> var_to_range: values don't match. 9379 > Left: {u0: VR[-int_oo, int_oo], u1: VR[0, 1], zuf0: VR[-oo, oo]} 9380 > Right: {} 9381""", 9382 ) 9383 self._replay_and_check(main) 9384 9385 @onlyIfTranslationValidation 9386 def test_shape_env_equal_evaluate_expr_divisible(self): 9387 main, other = ShapeEnv(), ShapeEnv() 9388 9389 # Call create_symbolic_sizes_strides_storage_offset on both of them. 
9390 r = main.create_symbolic_sizes_strides_storage_offset( 9391 torch.randn(3, 2), ConstantSource("x") 9392 ) 9393 other.create_symbolic_sizes_strides_storage_offset( 9394 torch.randn(3, 2), ConstantSource("x") 9395 ) 9396 9397 # Create a guard: size[0] % 3 == 0 (only in the main ShapeEnv) 9398 # - +1 guard entry 9399 # - +1 divisible entry 9400 size = r[0] 9401 bool(size[0] % 3 == 0) 9402 9403 self.assertExpectedRaisesInline( 9404 NotEqualError, 9405 lambda: main.check_equal(other), 9406 """\ 9407ShapeEnv not equal: field values don't match: 9408 9409==> divisible: values don't match. 9410 > Left: {Mod(s0, 3)} 9411 > Right: {} 9412==> guards: values don't match. 9413 > Left: [Eq(Mod(s0, 3), 0)] 9414 > Right: [] 9415==> name_to_node: values don't match. 9416 > Left: {_assert, eq, mod, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} 9417 > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} 9418""", 9419 ) 9420 self._replay_and_check(main) 9421 9422 @onlyIfTranslationValidation 9423 def test_shape_env_equal_evaluate_expr_replacement(self): 9424 main, other = ShapeEnv(), ShapeEnv() 9425 9426 # Call create_symbolic_sizes_strides_storage_offset on both of them. 9427 r = main.create_symbolic_sizes_strides_storage_offset( 9428 torch.randn(3, 2), ConstantSource("x") 9429 ) 9430 other.create_symbolic_sizes_strides_storage_offset( 9431 torch.randn(3, 2), ConstantSource("x") 9432 ) 9433 9434 # Create a guard: size[0] == 3 (only in the main ShapeEnv) 9435 # - +1 guard entry 9436 # - +1 replacement entry 9437 size = r[0] 9438 bool(size[0] == 3) 9439 9440 self.assertExpectedRaisesInline( 9441 NotEqualError, 9442 lambda: main.check_equal(other), 9443 """\ 9444ShapeEnv not equal: field values don't match: 9445 9446==> guards: values don't match. 9447 > Left: [Eq(s0, 3)] 9448 > Right: [] 9449==> name_to_node: values don't match. 9450 > Left: {_assert, eq, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} 9451 > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} 9452==> replacements: values don't match. 9453 > Left: {s0: 3} 9454 > Right: {} 9455==> var_to_range: values don't match. 9456 > Left: {s0: VR[3, 3], s1: VR[2, int_oo]} 9457 > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]} 9458""", 9459 ) 9460 self._replay_and_check(main) 9461 9462 @onlyIfTranslationValidation 9463 def test_shape_env_equal_evaluate_expr_refinement(self): 9464 main, other = ShapeEnv(), ShapeEnv() 9465 9466 # Call create_symbolic_sizes_strides_storage_offset on both of them. 9467 r = main.create_symbolic_sizes_strides_storage_offset( 9468 torch.randn(3, 2), ConstantSource("x") 9469 ) 9470 other.create_symbolic_sizes_strides_storage_offset( 9471 torch.randn(3, 2), ConstantSource("x") 9472 ) 9473 9474 # Create a guard: size[0] >= 3 (only in the main ShapeEnv) 9475 # - +1 guard entry 9476 # - +1 var_to_guard entry 9477 # - Change: var_to_range 9478 size = r[0] 9479 bool(size[0] >= 3) 9480 9481 self.assertExpectedRaisesInline( 9482 NotEqualError, 9483 lambda: main.check_equal(other), 9484 """\ 9485ShapeEnv not equal: field values don't match: 9486 9487==> guards: values don't match. 9488 > Left: [s0 >= 3] 9489 > Right: [] 9490==> name_to_node: values don't match. 9491 > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} 9492 > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} 9493==> var_to_range: values don't match. 
9494 > Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]} 9495 > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]} 9496""", 9497 ) 9498 self._replay_and_check(main) 9499 9500 @onlyIfTranslationValidation 9501 def test_shape_env_equal_runtime_assert(self): 9502 main, other = ShapeEnv(), ShapeEnv() 9503 9504 # Call create_unbacked_symint on both of them. 9505 r = main.create_unbacked_symint() 9506 other.create_unbacked_symint() 9507 9508 # Create a runtime assert: r % 3 == 0 (only in the main ShapeEnv) 9509 # - +1 deferred_runtime_asserts entry 9510 # - Change: num_deferred_runtime_asserts 9511 expect_true(r % 3 == 0) 9512 9513 self.assertExpectedRaisesInline( 9514 NotEqualError, 9515 lambda: main.check_equal(other), 9516 """\ 9517ShapeEnv not equal: field values don't match: 9518 9519==> deferred_runtime_asserts: values don't match. 9520 > Left: {u0: [Eq(PythonMod(u0, 3), 0)]} 9521 > Right: {} 9522==> name_to_node: values don't match. 9523 > Left: {_assert, eq, mod, u0} 9524 > Right: {u0} 9525==> num_deferred_runtime_asserts: values don't match. 9526 > Left: 1 9527 > Right: 0 9528""", 9529 ) 9530 self._replay_and_check(main) 9531 9532 def test_shape_env_recorded_function_fallback(self): 9533 # Make sure the record/replay mechanism for ShapeEnv will fallback 9534 # if no ShapeEnv instance is found. 9535 constrain_range(5, min=2, max=10) 9536 constrain_unify(5, 5) 9537 9538 self.assertExpectedRaisesInline( 9539 AssertionError, 9540 lambda: _constrain_range_for_size(5, min=2, max=10), 9541 """can only constrain range for SymInt""", 9542 ) 9543 9544 def test_default_dtype_change(self): 9545 @torch.compile 9546 def foo(): 9547 def inner(a, b, res_dtype): 9548 print(a, b, res_dtype) 9549 self.assertEqual(torch.result_type(a, b), res_dtype) 9550 9551 inner(torch.tensor(1, device="cpu"), 1.0, torch.get_default_dtype()) 9552 9553 with set_default_dtype(torch.float): 9554 foo() 9555 with set_default_dtype(torch.double): 9556 foo() 9557 9558 def test_numpy_ufunc_out(self): 9559 @torch.compile(backend="eager") 9560 def foo(): 9561 x = np.arange(5) 9562 out = np.empty((x.shape[0], x.shape[0])) 9563 res_out = np.sin(x, out=out) 9564 assert res_out is out 9565 9566 foo() 9567 9568 # Unfortunately, we don't currently preserve the ids of 9569 # res_out and out correctly across the graph break 9570 @unittest.expectedFailure 9571 def test_numpy_ufunc_out_graph_break(self): 9572 @torch.compile(backend="eager") 9573 def foo(): 9574 x = np.arange(5) 9575 out = np.empty((x.shape[0], x.shape[0])) 9576 res_out = np.sin(x, out=out) 9577 torch._dynamo.graph_break() 9578 assert res_out is out 9579 9580 foo() 9581 9582 def test_dict_subclass_cannot_be_initialized_in_graph(self): 9583 for super_class in ( 9584 collections.OrderedDict, 9585 dict, 9586 ): 9587 9588 class CustomDict(super_class): 9589 def __init__(self, *args, **kwargs): 9590 super().__init__(*args, **kwargs) 9591 9592 def fn(x): 9593 c = CustomDict() 9594 c["key"] = x 9595 assert "key" in c 9596 return c["key"] + 1 9597 9598 fn_opt = torch.compile(fn, backend="eager", fullgraph=True) 9599 with self.assertRaisesRegex( 9600 torch._dynamo.exc.Unsupported, "call_function UserDefinedClassVariable" 9601 ): 9602 print(fn_opt(torch.zeros(1))) 9603 9604 @wrapDeterministicFlagAPITest 9605 def test_backward_deterministic_mode_mismatch_warning(self): 9606 @torch.compile 9607 def func(a, b): 9608 return a + b 9609 9610 for forward_deterministic, backward_deterministic in itertools.product( 9611 [True, False], [True, False] 9612 ): 9613 
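            # Note: forward runs under one determinism setting and backward under the
            # other; the error below is only expected when the backward pass is
            # deterministic but the forward compile was not.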
torch.use_deterministic_algorithms(forward_deterministic) 9614 a = torch.randn(10, requires_grad=True) 9615 res = func(a, 1) 9616 grad = torch.ones_like(res) 9617 torch.use_deterministic_algorithms(backward_deterministic) 9618 9619 if not forward_deterministic and backward_deterministic: 9620 with self.assertRaisesRegex( 9621 RuntimeError, 9622 "^This compiled backward function is being run with torch\.use_deterministic_algorithms", 9623 ): 9624 res.backward(grad) 9625 9626 else: 9627 res.backward(grad) 9628 9629 def test_torch_dynamo_codegen_pow(self): 9630 def pow(x): 9631 return x**2 9632 9633 x = np.arange(8) 9634 pow_opt = torch.compile(pow) 9635 9636 actual, source_code = run_and_get_code(pow_opt, x) 9637 expect = pow(x) 9638 9639 self.assertEqual(expect, actual) 9640 9641 self.assertTrue( 9642 all("aten.pow" not in code for code in source_code), 9643 msg="Encountered an unexpected fallback to 'aten pow' in dynamo compiled code", 9644 ) 9645 9646 def test_graph_break_compilation_metrics(self): 9647 def fn(x): 9648 x.cos() 9649 torch._dynamo.graph_break() 9650 x.sin() 9651 torch._dynamo.graph_break() 9652 return x.cos() 9653 9654 torch._dynamo.utils.clear_compilation_metrics() 9655 x = torch.rand((4, 4)) 9656 f = torch.compile(fn, backend="eager") 9657 f(x) 9658 metrics = torch._dynamo.utils.get_compilation_metrics() 9659 # Should only be one restart per event 9660 (restart_reason,) = metrics[0].restart_reasons 9661 self.assertTrue( 9662 "skip function graph_break" in restart_reason, 9663 "Should have logged graph break reason", 9664 ) 9665 self.assertTrue( 9666 metrics[0].dynamo_time_before_restart_s 9667 <= metrics[0].entire_frame_compile_time_s 9668 ) 9669 9670 (restart_reason,) = metrics[1].restart_reasons 9671 self.assertTrue( 9672 "skip function graph_break" in restart_reason, 9673 "Should have logged graph break reason", 9674 ) 9675 self.assertTrue( 9676 metrics[1].dynamo_time_before_restart_s 9677 <= metrics[1].entire_frame_compile_time_s 9678 ) 9679 9680 # No restarts 9681 self.assertTrue( 9682 len(metrics[2].restart_reasons) == 0, "Last compile has no graph break" 9683 ) 9684 self.assertTrue(metrics[2].dynamo_time_before_restart_s == 0) 9685 9686 def test_graph_break_compilation_metrics_on_failure(self): 9687 def fn(x): 9688 return x.sin() 9689 9690 def broken_backend(gm, example_inputs): 9691 raise RuntimeError("broken backend") 9692 9693 x = torch.rand((4, 4)) 9694 f = torch.compile(fn, backend=broken_backend) 9695 with unittest.mock.patch("torch._dynamo.config.suppress_errors", True): 9696 torch._dynamo.utils.clear_compilation_metrics() 9697 f(x) 9698 metrics = torch._dynamo.utils.get_compilation_metrics() 9699 for metric in metrics: 9700 self.assertTrue(metric.dynamo_time_before_restart_s > 0) 9701 self.assertTrue( 9702 "RuntimeError: broken backend" in metric.fail_reason, 9703 "Should have logged fail reason", 9704 ) 9705 9706 def test_compilation_metrics_size_limit(self): 9707 def fn1(x): 9708 return x.relu() 9709 9710 def fn2(x): 9711 return x.cos() 9712 9713 def fn3(x): 9714 return x.sin() 9715 9716 def fn4(x): 9717 return x.exp() 9718 9719 import contextlib 9720 9721 @contextlib.contextmanager 9722 def metrics_limit_ctx(): 9723 try: 9724 torch._dynamo.utils.set_compilation_metrics_limit(3) 9725 yield 9726 finally: 9727 torch._dynamo.utils.set_compilation_metrics_limit( 9728 torch._dynamo.utils.DEFAULT_COMPILATION_METRICS_LIMIT 9729 ) 9730 9731 x = torch.rand((4, 4)) 9732 torch._dynamo.reset() 9733 torch.compile(fn1, backend="eager")(x) 9734 torch.compile(fn2, 
backend="eager")(x) 9735 torch.compile(fn3, backend="eager")(x) 9736 torch.compile(fn4, backend="eager")(x) 9737 9738 with metrics_limit_ctx(): 9739 torch._dynamo.utils.clear_compilation_metrics() 9740 torch._dynamo.reset() 9741 self.assertEqual(0, len(torch._dynamo.utils.get_compilation_metrics())) 9742 torch.compile(fn1, backend="eager")(x) 9743 self.assertEqual(1, len(torch._dynamo.utils.get_compilation_metrics())) 9744 torch.compile(fn2, backend="eager")(x) 9745 self.assertEqual(2, len(torch._dynamo.utils.get_compilation_metrics())) 9746 torch.compile(fn3, backend="eager")(x) 9747 self.assertEqual(3, len(torch._dynamo.utils.get_compilation_metrics())) 9748 torch.compile(fn4, backend="eager")(x) 9749 self.assertEqual(3, len(torch._dynamo.utils.get_compilation_metrics())) 9750 9751 def test_funcname_cache(self): 9752 src = """\ 9753import torch 9754if True: 9755 test = 3 9756 9757class AAA: 9758 class DUMMY: 9759 class DUMMY2: 9760 pass 9761 9762 def dummy(self): 9763 def dummy2(): 9764 pass 9765 class BBB: 9766 @staticmethod 9767 def CCC(): 9768 class DDD: 9769 if True: 9770 @staticmethod 9771 def EEE(): 9772 x = [torch.ones(3, 3) for _ in range(5)] 9773 return x 9774 return DDD 9775def fn(): 9776 return 3 9777""" 9778 with tempfile.NamedTemporaryFile(mode="w") as f: 9779 f.write(src) 9780 f.flush() 9781 from torch._dynamo.funcname_cache import get_funcname 9782 9783 names = [get_funcname(f.name, i + 1) for i in range(src.count("\n") + 1)] 9784 9785 self.assertExpectedInline( 9786 "\n".join(names), 9787 """\ 9788 9789 9790 9791 9792AAA 9793AAA.DUMMY 9794AAA.DUMMY.DUMMY2 9795AAA.DUMMY.DUMMY2 9796AAA.DUMMY.DUMMY2 9797AAA.dummy 9798AAA.dummy.dummy2 9799AAA.dummy.dummy2 9800AAA.BBB 9801AAA.BBB 9802AAA.BBB.CCC 9803AAA.BBB.CCC.DDD 9804AAA.BBB.CCC.DDD 9805AAA.BBB.CCC.DDD 9806AAA.BBB.CCC.DDD.EEE 9807AAA.BBB.CCC.DDD.EEE 9808AAA.BBB.CCC.DDD.EEE 9809AAA.BBB.CCC 9810fn 9811fn 9812""", 9813 ) 9814 9815 def test_return_dict_with_graph_break_and_update(self): 9816 def create(): 9817 torch._dynamo.graph_break() 9818 return {0: torch.tensor(3)} 9819 9820 def fn(): 9821 return {**create()} 9822 9823 opt_fn = torch.compile(backend="eager")(fn) 9824 result = opt_fn() 9825 self.assertIn(0, result) 9826 self.assertTrue(same(result[0], torch.tensor(3))) 9827 9828 def test_dynamo_reset_clears_cache(self): 9829 """Test that dynamo bytecode cache is freed 9830 when dynamo reset is called 9831 """ 9832 9833 def fn(x): 9834 return torch.sin(x) 9835 9836 opt_fn = torch.compile(backend="eager")(fn) 9837 opt_fn(torch.randn(3, 3)) 9838 9839 c1 = _debug_get_cache_entry_list(fn.__code__) 9840 self.assertEqual(len(c1), 1) 9841 9842 torch._dynamo.reset() 9843 c2 = _debug_get_cache_entry_list(fn.__code__) 9844 self.assertEqual(len(c2), 0) 9845 9846 @torch._dynamo.config.patch(capture_scalar_outputs=True) 9847 def test_guard_size_oblivious(self): 9848 # This code, in fact, does NOT work in eager 9849 @torch.compile(backend="eager", fullgraph=True) 9850 def fn(x): 9851 y = torch.zeros(x.item()) 9852 if guard_size_oblivious(y.size(0) == 0): 9853 assert False 9854 return y 9855 9856 self.assertEqual(fn(torch.tensor([0])), torch.zeros(0)) 9857 9858 def test_guard_size_oblivious_backed(self): 9859 @torch.compile(backend="eager", fullgraph=True) 9860 def f(x): 9861 y = x.size(0) 9862 # This doesn't actually do anything 9863 if guard_size_oblivious(y == 0): 9864 return torch.randn(1) 9865 else: 9866 return torch.randn(2) 9867 9868 # Should not fail in either case 9869 self.assertEqual(f(torch.randn(0)).shape, (1,)) 9870 
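        # Note: a size-0 and a size-2 input are both checked so the size-oblivious
        # guard on a backed size takes each branch without erroring.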
self.assertEqual(f(torch.randn(2)).shape, (2,)) 9871 9872 def _test_compile_model_free(self, model_inp_ctr, weakref_watch): 9873 """ 9874 Args: 9875 model_inp_ctr 9876 - constructor that returns a new model and inputs to that model 9877 weakref_watch 9878 - function that returns a layer of the model for weakref to 9879 finalize on, so we can check that the layer is freed after 9880 the model goes out of scope 9881 """ 9882 cleared = False 9883 9884 def finalize(): 9885 nonlocal cleared 9886 cleared = True 9887 9888 def run(): 9889 mod, inp = model_inp_ctr() 9890 weakref.finalize(weakref_watch(mod), finalize) 9891 torch.compile(mod, backend="eager")(inp) 9892 9893 run() 9894 gc.collect() 9895 self.assertTrue(cleared) 9896 9897 def test_custom_module_free(self): 9898 """Test that a model is freed when it goes out of scope""" 9899 9900 class Mod(torch.nn.Module): 9901 def __init__(self): 9902 super(Mod, self).__init__() 9903 self.fc = torch.nn.Linear(100, 100) 9904 9905 def forward(self, out): 9906 return self.fc(out) 9907 9908 self._test_compile_model_free( 9909 lambda: (Mod(), torch.randn(100, 100)), 9910 lambda mod: mod.fc, 9911 ) 9912 9913 def test_sequential_module_free(self): 9914 self._test_compile_model_free( 9915 lambda: ( 9916 torch.nn.Sequential( 9917 torch.nn.Linear(100, 100), 9918 torch.nn.ReLU(), 9919 ), 9920 torch.randn(100, 100), 9921 ), 9922 lambda mod: mod[0], 9923 ) 9924 9925 def test_linear_module_free(self): 9926 self._test_compile_model_free( 9927 lambda: (torch.nn.Linear(100, 100), torch.randn(100, 100)), 9928 lambda mod: mod, 9929 ) 9930 9931 # The following 2 tests fail due to https://github.com/python/cpython/issues/118013. 9932 # Tracked by https://github.com/pytorch/pytorch/issues/124302. 9933 # The xfails can be removed once Python 3.12 is updated on CI. 9934 @xfailIfPy312 9935 @unittest.skipIf(True, "Skipping this test for release/2.4") 9936 def test_outside_linear_module_free(self): 9937 # Compared to test_linear_module_free, the linear 9938 # layer is not the code object that is directly compiled. 9939 9940 # This test does not use _test_compile_model_free because of difficulty 9941 # in handling variable fc. 
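        # Note: the code below reproduces that helper's pattern by hand -- register
        # weakref.finalize on the Linear, compile and run the wrapping module, let the
        # references go out of scope, gc.collect(), then verify the finalizer fired.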
9942 9943 cleared = False 9944 9945 def finalize(): 9946 nonlocal cleared 9947 cleared = True 9948 9949 def run(): 9950 fc = torch.nn.Linear(100, 100) 9951 9952 class Mod(torch.nn.Module): 9953 def __init__(self): 9954 super().__init__() 9955 self.fc_ref = fc 9956 9957 def forward(self, x): 9958 return self.fc_ref(x) 9959 9960 mod = Mod() 9961 inp = torch.randn(100, 100) 9962 weakref.finalize(fc, finalize) 9963 torch.compile(mod, backend="eager")(inp) 9964 9965 run() 9966 # del fc # This should delete all the references 9967 gc.collect() 9968 self.assertTrue(cleared) 9969 9970 @xfailIfPy312 9971 def test_parameter_free(self): 9972 def model_inp_ctr(): 9973 param = torch.nn.Parameter(torch.randn(100, 100)) 9974 9975 class Mod(torch.nn.Module): 9976 def __init__(self): 9977 super().__init__() 9978 self.param = param 9979 9980 def forward(self, x): 9981 return self.param * x[0] 9982 9983 # return param to keep it alive in _test_compile_model_free 9984 return Mod(), (torch.randn(100, 100), param) 9985 9986 self._test_compile_model_free(model_inp_ctr, lambda mod: mod.param) 9987 9988 def test_conditional_list_comp_in_context(self): 9989 def fn(inp): 9990 try: 9991 return [torch.sin(x) for x in inp if x is not None] 9992 except Exception: 9993 pass 9994 9995 inp = [torch.randn(3, 3) for _ in range(3)] + [None] 9996 opt_fn = torch.compile(fn, backend="eager") 9997 opt_fn(inp) 9998 9999 def test_312_binary_slice_with_graph_break1(self): 10000 l1 = torch.nn.Linear(5, 5) 10001 l2 = torch.nn.Linear(5, 5) 10002 10003 def fn(x): 10004 # causes a graph break with items in the stack 10005 n = torch.nn.Sequential(l1, l2) 10006 out = n[1:](x) 10007 return out 10008 10009 opt_fn = torch.compile(fn, backend="eager") 10010 opt_fn(torch.randn(5, 5)) 10011 10012 def test_312_binary_slice_with_graph_break2(self): 10013 class Foo: 10014 def __setitem__(self, key, val): 10015 pass 10016 10017 def __getitem__(self, key): 10018 torch._dynamo.graph_break() 10019 return 1 10020 10021 foo = Foo() 10022 10023 def fn(x): 10024 # graph break in a STORE_SLICE instruction 10025 foo[:] = x 10026 # graph break in BINARY_SLICE with has_backedge check 10027 x = x + foo[:] 10028 if x is None: 10029 x = x + 1 10030 else: 10031 x = x + 1 10032 return x 10033 10034 opt_fn = torch.compile(fn, backend="eager") 10035 opt_fn(torch.randn(5, 5)) 10036 10037 def test_super_after_graph_break(self): 10038 class Foo(torch.nn.Sequential): 10039 def __init__(self, layers): 10040 torch._dynamo.graph_break() 10041 super().__init__(*layers) 10042 10043 def fn(x): 10044 layers = [torch.nn.Linear(3, 3) for _ in range(3)] 10045 mod = Foo(layers) 10046 return mod(x) 10047 10048 opt_fn = torch.compile(fn, backend="eager") 10049 opt_fn(torch.randn(3, 3)) 10050 10051 def test_load_fast_and_clear_graph_break(self): 10052 # Can result in a segfault in 3.12+ if LOAD_FAST_AND_CLEAR 10053 # is not handled properly in a graph break 10054 def fn(): 10055 out = torch.cat([torch.randn(r, 5) for r in range(3)]) 10056 torch._dynamo.graph_break() 10057 out = torch.cat([torch.randn(r, 5) for r in range(3)]) 10058 return out 10059 10060 self.assertEqual(torch._dynamo.optimize("eager")(fn)().shape, (3, 5)) 10061 10062 def test_raises_importerror1(self): 10063 @torch.compile(backend="eager") 10064 def fn(x): 10065 try: 10066 import some_module_that_surely_does_not_exist 10067 10068 return 10069 except ImportError: 10070 pass 10071 return x.sin() 10072 10073 x = torch.randn(8) 10074 self.assertEqual(fn(x), x.sin()) 10075 10076 def test_raises_importerror2(self): 10077 
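        # Note: unlike test_raises_importerror1 above, the failing import is not
        # wrapped in try/except, so the ImportError is expected to escape the compiled
        # function (see the assertRaises below).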
@torch.compile(backend="eager") 10078 def fn(x): 10079 import some_module_that_surely_does_not_exist 10080 10081 return x + 1 10082 10083 x = torch.randn(8) 10084 with self.assertRaises(ImportError): 10085 fn(x) 10086 10087 def test_dynamo_cache_move_to_front(self): 10088 def fn(x, const): 10089 return x + const 10090 10091 # dynamic=False forces Dynamo to recompile 10092 opt_fn = torch.compile(fn, backend="eager", dynamic=False) 10093 10094 inp = torch.randn(3, 3) 10095 10096 # NOTE: assumes that each cache entry is guarded 10097 # on unique Mod instance 10098 opt_fn(inp, 1) 10099 opt_fn(inp, 2) 10100 opt_fn(inp, 3) 10101 10102 c1 = _debug_get_cache_entry_list(fn.__code__) 10103 self.assertEqual(len(c1), 3) 10104 10105 # move cache entry to front 10106 opt_fn(inp, 2) 10107 c2 = _debug_get_cache_entry_list(fn.__code__) 10108 self.assertIs(c1[1], c2[0]) 10109 10110 @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) 10111 def test_dynamo_cache_invalidate(self): 10112 class Mod(torch.nn.Module): 10113 def __init__(self): 10114 super(Mod, self).__init__() 10115 self.fc = torch.nn.Linear(3, 3) 10116 10117 def forward(self, out): 10118 return self.fc(out) 10119 10120 def fn(x, mod): 10121 return mod(x) 10122 10123 opt_fn = torch.compile(fn, backend="eager") 10124 10125 m1 = Mod() 10126 m2 = Mod() 10127 m3 = Mod() 10128 inp = torch.randn(3, 3) 10129 10130 # NOTE: assumes that each cache entry is guarded 10131 # on unique Mod instance 10132 opt_fn(inp, m1) 10133 opt_fn(inp, m2) 10134 opt_fn(inp, m3) 10135 10136 c1 = _debug_get_cache_entry_list(fn.__code__) 10137 self.assertEqual(len(c1), 3) 10138 10139 # move cache entry to front 10140 opt_fn(inp, m2) 10141 c2 = _debug_get_cache_entry_list(fn.__code__) 10142 self.assertIs(c1[1], c2[0]) 10143 10144 # delete center of cache 10145 del m3 10146 c3 = _debug_get_cache_entry_list(fn.__code__) 10147 self.assertEqual(len(c3), 2) 10148 self.assertIs(c3[0], c2[0]) 10149 self.assertIs(c3[1], c2[2]) 10150 10151 # delete end of cache 10152 del m1 10153 c4 = _debug_get_cache_entry_list(fn.__code__) 10154 self.assertEqual(len(c4), 1) 10155 self.assertIs(c4[0], c3[0]) 10156 10157 del m2 10158 c5 = _debug_get_cache_entry_list(fn.__code__) 10159 self.assertEqual(len(c5), 0) 10160 10161 def test_grad_none(self): 10162 def fn(x, y): 10163 x.grad = torch.abs(y) 10164 x.grad.add_(y) 10165 return torch.abs(y) 10166 10167 y = torch.arange(4).reshape(2, 2).to(torch.float) 10168 x = torch.randn(2, 2) 10169 x.grad = None 10170 10171 z = fn(x, y) 10172 ref_y = torch.clone(z).detach() 10173 ref_x_grad = torch.clone(x.grad).detach() 10174 10175 y = torch.arange(4).reshape(2, 2).to(torch.float) 10176 x = torch.randn(2, 2) 10177 x.grad = None 10178 10179 opt_fn = torch.compile(fn, backend="eager") 10180 z = opt_fn(x, y) 10181 self.assertEqual(z, ref_y) 10182 self.assertEqual(x.grad, ref_x_grad) 10183 10184 def test_grad_non_none(self): 10185 def fn(x, y): 10186 x.grad.add_(y) 10187 return torch.abs(y) 10188 10189 y = torch.ones(2, 2) 10190 x = torch.randn(2, 2) 10191 x.grad = torch.arange(4).reshape(2, 2).to(torch.float) 10192 10193 z = fn(x, y) 10194 ref_y = torch.clone(z).detach() 10195 ref_x_grad = torch.clone(x.grad).detach() 10196 10197 y = torch.ones(2, 2) 10198 x = torch.randn(2, 2) 10199 x.grad = torch.arange(4).reshape(2, 2).to(torch.float) 10200 10201 cnt = torch._dynamo.testing.CompileCounterWithBackend("eager") 10202 opt_fn = torch.compile(fn, backend=cnt) 10203 z = opt_fn(x, y) 10204 10205 # Ensure that the generated graph returns only one output. 

    def test_grad_none(self):
        def fn(x, y):
            x.grad = torch.abs(y)
            x.grad.add_(y)
            return torch.abs(y)

        y = torch.arange(4).reshape(2, 2).to(torch.float)
        x = torch.randn(2, 2)
        x.grad = None

        z = fn(x, y)
        ref_y = torch.clone(z).detach()
        ref_x_grad = torch.clone(x.grad).detach()

        y = torch.arange(4).reshape(2, 2).to(torch.float)
        x = torch.randn(2, 2)
        x.grad = None

        opt_fn = torch.compile(fn, backend="eager")
        z = opt_fn(x, y)
        self.assertEqual(z, ref_y)
        self.assertEqual(x.grad, ref_x_grad)

    def test_grad_non_none(self):
        def fn(x, y):
            x.grad.add_(y)
            return torch.abs(y)

        y = torch.ones(2, 2)
        x = torch.randn(2, 2)
        x.grad = torch.arange(4).reshape(2, 2).to(torch.float)

        z = fn(x, y)
        ref_y = torch.clone(z).detach()
        ref_x_grad = torch.clone(x.grad).detach()

        y = torch.ones(2, 2)
        x = torch.randn(2, 2)
        x.grad = torch.arange(4).reshape(2, 2).to(torch.float)

        cnt = torch._dynamo.testing.CompileCounterWithBackend("eager")
        opt_fn = torch.compile(fn, backend=cnt)
        z = opt_fn(x, y)

        # Ensure that the generated graph returns only one output. We want the
        # add_ on the grad to be part of the graph itself, so that inductor can
        # theoretically move the add_ and resulting copy_ nodes to the right
        # place to free memory.
        self.assertEqual(len(list(cnt.graphs[0].graph.nodes)[-1].all_input_nodes), 1)
        self.assertEqual(z, ref_y)
        self.assertEqual(x.grad, ref_x_grad)

    def test_new_with_int_list(self):
        # Make sure torch.Tensor.new(int argument list) behaves the same under Dynamo.
        def fn(x):
            return x.new(*x.size()) + 5

        opt_fn = torch.compile(backend="eager")(fn)

        x = torch.arange(10).view(2, 5)

        expected = fn(x)
        actual = opt_fn(x)

        self.assertEqual(expected.dtype, actual.dtype)
        self.assertEqual(expected.shape, actual.shape)
        self.assertEqual(expected.stride(), actual.stride())
        self.assertEqual(expected.storage_offset(), actual.storage_offset())

    @torch._dynamo.config.patch(guard_nn_modules=True)
    def test_hasattr_nn_module_guard(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Linear(3, 3)

            def forward(self, x):
                if hasattr(self, "a"):
                    return self.a(x)
                else:
                    return x

        m = M()
        x = torch.randn(3, 3)
        ref = m(x)

        opt_m = torch.compile(backend="eager")(m)
        res = opt_m(x)
        self.assertEqual(ref, res)

    def test_ordered_dict_move_to_end(self):
        d = {
            "foo": 1,
            "bar": 2,
        }

        d = collections.OrderedDict(d)
        d.move_to_end("foo")

        @torch.compile(backend="eager")
        def fn(x, d):
            return x * d["foo"] * d["bar"]

        fn(torch.randn(4), d)
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(4), d)

    def test_defaultdict(self):
        d = collections.defaultdict()
        d["foo"] = 1
        d["bar"] = 2

        @torch.compile(backend="eager")
        def fn(x, d):
            return x * d["foo"] * d["bar"]

        fn(torch.randn(4), d)
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(4), d)

    def test_custom_dict(self):
        class MyDict(dict):
            pass

        d = {
            "foo": 1,
            "bar": 2,
        }

        d = MyDict(d)

        @torch.compile(backend="eager")
        def fn(x, d):
            return x * d["foo"] * d["bar"]

        fn(torch.randn(4), d)
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(4), d)

    @unittest.skipIf(not TEST_CUDA, "requires cuda")
    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
    def test_interpolate_propagate_real_tensors(self):
        @torch.compile(backend="eager", fullgraph=True)
        def f(mask, box):
            # u0, u1 = mask.tolist()
            mask = torch.randn(1, 1, 30, 30, device="cuda")
            h, w = box.tolist()
            return torch.nn.functional.interpolate(
                mask, (h, w), mode="bilinear", align_corners=False
            )

        f(torch.tensor([30, 30], device="cuda"), torch.tensor([68, 32], device="cuda"))
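
    # Editor's note: an illustrative sketch (not one of the original tests) of
    # the `error_on_recompile` pattern that the dict tests above and below use:
    # the second call must hit the existing cache entry, otherwise the patched
    # config turns the recompile into an error. `_example_error_on_recompile`
    # is a hypothetical, underscore-prefixed name so unittest does not collect
    # it as a test.
    def _example_error_on_recompile(self):
        @torch.compile(backend="eager")
        def fn(x, d):
            return x * d["foo"] * d["bar"]

        d = {"foo": 1, "bar": 2}
        fn(torch.randn(4), d)
        # Same dict contents and key order -> guards pass, no recompile.
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(4), d)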

    def test_custom_iter_dict(self):
        class ReversedDict(dict):
            def __iter__(self):
                return reversed(list(self.keys()))

        d = {
            "foo": 1,
            "bar": 2,
        }

        d = ReversedDict(d)

        @torch.compile(backend="eager")
        def fn(x, d):
            return x * d["foo"] * d["bar"]

        fn(torch.randn(4), d)
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(4), d)

    def test_custom_keys_iter_dict(self):
        class ReversedDict(dict):
            def keys(self):
                return ["bar", "foo"]

        d = {
            "foo": 1,
            "bar": 2,
        }

        d = ReversedDict(d)

        @torch.compile(backend="eager")
        def fn(x, d):
            return x * d["foo"] * d["bar"]

        fn(torch.randn(4), d)
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(4), d)

    def test_dict_guard_on_keys_order(self):
        d = {
            2: 4,
            3: 5,
        }

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            for key, value in d.items():
                x = x * key + value
            return x

        opt_fn = torch.compile(fn, backend=cnts)
        opt_fn(torch.randn(4), d)
        opt_fn(torch.randn(4), d)
        # No recompilation
        self.assertEqual(cnts.frame_count, 1)

        # move 2 to the end
        d[2] = d.pop(2)

        x = torch.randn(4)
        res = opt_fn(x, d)
        # Check recompilation
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(res, fn(x, d))

    def test_dict_guard_on_keys_order2(self):
        d = {
            2: 4,
            3: 5,
        }

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            for key in d:
                value = d[key]
                x = x * key + value
            return x

        opt_fn = torch.compile(fn, backend=cnts)
        opt_fn(torch.randn(4), d)
        opt_fn(torch.randn(4), d)
        # No recompilation
        self.assertEqual(cnts.frame_count, 1)

        # move 2 to the end
        d[2] = d.pop(2)

        x = torch.randn(4)
        res = opt_fn(x, d)
        # Check recompilation
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(res, fn(x, d))

    def test_contains_dunder_dict(self):
        class UserDefined:
            def __init__(self):
                self.a = 3
                self.b = 5

            def run(self, x):
                if "a" in self.__dict__:
                    x = x * self.a
                if "b" in self.__dict__:
                    x = x * self.b
                self.c = 7
                if "c" in self.__dict__:
                    x = x * self.c
                return x * self.__dict__.get("a") * self.__dict__.get("z", 2)

        obj = UserDefined()

        def fn(x):
            return obj.run(x)

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_module_dunder_dict(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.foo = 1
                self.bar = 2
                self.baz = 3

            def forward(self, x):
                if "foo" in self.__dict__:
                    return x * self.bar
                return x * self.baz

        mod = MyModule()
        x = torch.randn(10)
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        self.assertEqual(mod(x), opt_mod(x))
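

# Editor's note: a standalone sketch (not part of the original test classes)
# restating what test_dict_guard_on_keys_order checks: iterating a dict inside
# a compiled function makes Dynamo guard on the key order, so reordering the
# keys of the same dict triggers a recompile. `_example_dict_key_order_guard`
# is a hypothetical helper name and is not collected by unittest.
def _example_dict_key_order_guard():
    cnts = torch._dynamo.testing.CompileCounter()

    def fn(x, d):
        for key, value in d.items():
            x = x * key + value
        return x

    opt_fn = torch.compile(fn, backend=cnts)
    d = {2: 4, 3: 5}
    opt_fn(torch.randn(4), d)
    assert cnts.frame_count == 1

    d[2] = d.pop(2)  # same keys, new order
    opt_fn(torch.randn(4), d)
    assert cnts.frame_count == 2  # assumed: the key-order guard forces a recompile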


class TestTracer(JitTestCase):
    def test_jit_save(self):
        def fn():
            class Foo(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.a = 3

                @torch.jit.export
                def __getstate__(self):
                    return (3, self.training)

                @torch.jit.export
                def __setstate__(self, state):
                    self.a = state[0]
                    self.training = state[1]

                def forward(self, x):
                    return x + self.a

            f = Foo()

            return torch.jit.trace(f, (torch.rand(3, 4),))

        fn()
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn()


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()