1# Owner(s): ["module: dynamo"] 2# flake8: noqa: E731, C405, F811, C418, C417 3import collections 4import functools 5import inspect 6import itertools 7import math 8import operator 9import random 10import sys 11import unittest 12from dataclasses import dataclass, field 13from typing import Any, Dict, List, NamedTuple 14from unittest.mock import patch 15 16import numpy as np 17 18import torch 19import torch._dynamo.test_case 20import torch._dynamo.testing 21from torch import sub 22from torch._dynamo.testing import ( 23 CompileCounterWithBackend, 24 EagerAndRecordGraphs, 25 normalize_gm, 26) 27from torch._dynamo.utils import ifdynstaticdefault, same 28from torch._dynamo.variables import ConstantVariable 29from torch._dynamo.variables.lists import RangeVariable 30from torch.nn import functional as F 31from torch.testing._internal.common_utils import ( 32 disable_translation_validation_if_dynamic_shapes, 33 instantiate_parametrized_tests, 34 parametrize, 35) 36 37# Defines all the kernels for tests 38from torch.testing._internal.triton_utils import * # noqa: F403 39 40 41d = torch.ones(10, 10) 42e = torch.nn.Linear(10, 10) 43flag = True 44 45 46class CustomDictSubclass(collections.OrderedDict): 47 pass 48 49 50clip01 = functools.partial(torch.clip, min=0.0, max=1.0) 51 52 53def constant3(a, b): 54 return a - b + (1.0 + 2) 55 56 57_variable = 0 58 59 60def update_global(x): 61 global _variable 62 _variable += 1 63 # Check that updated global variable value is picked up 64 return x * _variable 65 66 67def func_with_default(a, b, some_default_arg=True): 68 if some_default_arg: 69 return a - b 70 71 72def make_test(fn=None, expected_frame_count=1): 73 if fn is None: 74 return lambda fn: make_test(fn, expected_frame_count=expected_frame_count) 75 76 nargs = len(inspect.signature(fn).parameters) 77 78 def test_fn(self): 79 return torch._dynamo.testing.standard_test( 80 self, 81 fn=fn, 82 nargs=nargs, 83 expected_frame_count=expected_frame_count, 84 ) 85 86 return test_fn 87 88 89class MyCls: 90 a = 1 91 92 93@torch.jit.script_if_tracing 94def inline_script_if_tracing(x): 95 return x + 1.2 96 97 98@torch.jit.ignore 99def inline_ignore(x): 100 return x + 3.4 101 102 103@torch.jit.unused 104def inline_unused(x): 105 return x + 5.6 106 107 108@functools.lru_cache 109def inline_lru_cache_fn_with_default_args(x, y, _=None): 110 return torch.sin(x * y) 111 112 113@torch.jit.script_if_tracing 114def inline_script_if_tracing_fn_with_default_args(x, y, c=1.2): 115 return torch.cos(x * y) + c 116 117 118class FunctionTests(torch._dynamo.test_case.TestCase): 119 @make_test 120 def test_inline_jit_annotations(x): 121 x = inline_script_if_tracing(x) 122 x = inline_ignore(x) 123 x = inline_unused(x) 124 return 125 126 @make_test 127 def test_inline_script_if_tracing_fn_with_default_args(a, b): 128 return inline_script_if_tracing_fn_with_default_args(a, b) 129 130 @make_test 131 def test_inline_lru_cache_fn_with_default_args(a, b): 132 return inline_lru_cache_fn_with_default_args(a, 2, b) 133 134 @make_test 135 def test_add(a, b): 136 return a + b 137 138 @make_test 139 def test_add_(a, b): 140 a_copy = torch.tensor(a) 141 return a_copy.add_(b, alpha=5.0) 142 143 @make_test 144 def test_addcdiv(a, b, c): 145 # dynamo decomposes this to avoid a graph break when 146 # the value kwarg is populated 147 return torch.addcdiv(a, b, c, value=5.0) 148 149 @make_test 150 def test_addcdiv_(a, b, c): 151 a_copy = torch.tensor(a) 152 return a_copy.addcdiv_(b, c, value=5.0) 153 154 @make_test 155 def test_is_not_null(a, b): 156 if a is not None and b is not None: 157 return a + b 158 159 def test_foreach_lerp_(self): 160 def fn(x, y, s): 161 return torch._foreach_lerp_(x, y, s) 162 163 cnt = torch._dynamo.testing.CompileCounter() 164 165 fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn) 166 expected = fn( 167 [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14], 168 [torch.ones(2, 2), torch.ones(2, 2)], 169 torch.tensor(0.5), 170 ) 171 172 actual = fn_opt( 173 [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14], 174 [torch.ones(2, 2), torch.ones(2, 2)], 175 torch.tensor(0.5), 176 ) 177 self.assertTrue(same(expected, actual)) 178 179 def test_broadcast_foreach_pow(self): 180 from torch._dynamo.utils import same 181 182 def fn(x, y): 183 return torch._foreach_pow(x, y) 184 185 cnt = torch._dynamo.testing.CompileCounter() 186 187 fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn) 188 inps = (torch.tensor(0.80), [torch.tensor(3.4), torch.tensor(7.8)]) 189 190 actual = fn_opt(*inps) 191 expected = fn(*inps) 192 self.assertTrue(same(actual, expected)) 193 self.assertTrue(cnt.frame_count, 1) 194 195 def test_addcmul_(self): 196 from copy import deepcopy 197 198 from torch._dynamo.utils import same 199 200 def fn(x, y, z, s): 201 return x.addcmul_(y, z, value=s) 202 203 cnt = torch._dynamo.testing.CompileCounter() 204 fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn) 205 inps = ( 206 torch.ones(2, 2), 207 torch.ones(2, 2) + 1, 208 torch.rand(2, 2), 209 torch.tensor(0.3), 210 ) 211 inps_2 = deepcopy(inps) 212 actual = fn_opt(*inps) 213 expected = fn(*inps_2) 214 self.assertTrue(same(actual, expected)) 215 self.assertEqual(cnt.frame_count, 1) 216 217 @make_test 218 def test_functools_partial(a, b): 219 return clip01(a + b) 220 221 @make_test 222 def test_itertools_product(a, b): 223 v = a 224 for x, i in itertools.product([a, b], [1, 2]): 225 v = v + x * i 226 return v 227 228 @make_test 229 def test_itertools_chain(a, b): 230 v = a 231 for x in itertools.chain([a, b], [1, 2]): 232 v = v + x 233 return v 234 235 @make_test 236 def test_itertools_chain_from_iterable(a, b): 237 v = a 238 for x in itertools.chain.from_iterable([[a, b], [1, 2]]): 239 v = v + x 240 return v 241 242 def test_itertools_reconstruct(self): 243 def fn(a): 244 it1 = itertools.repeat(1) 245 it2 = itertools.count(2) 246 for _ in range(3): 247 a += next(it1) 248 a += next(it2) 249 return it1, it2, a 250 251 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 252 i1, i2, a = fn(torch.ones(3, 3)) 253 it1, it2, b = opt_fn(torch.ones(3, 3)) 254 self.assertEqual(next(i1), next(it1)) 255 self.assertEqual(next(i2), next(it2)) 256 self.assertEqual(a, b) 257 258 @make_test 259 def test_obj_eq(a, b): 260 v = a + b 261 if MyCls() == None: # noqa: E711 262 return -1 263 if MyCls() != None: # noqa: E711 264 v = v.sin() 265 if MyCls() == MyCls(): 266 return -2 267 if MyCls() != MyCls(): 268 return v + 1 269 return -3 270 271 @make_test 272 def test_cls_eq(a, b): 273 v = a + b 274 if MyCls == None: # noqa: E711 275 return -1 276 if MyCls != None: # noqa: E711 277 v = v.sin() 278 if MyCls != MyCls: 279 return -2 280 if MyCls == MyCls: 281 return v + 1 282 return -3 283 284 @make_test 285 def test_obj_is(a, b): 286 v = a + b 287 if MyCls() is None: # noqa: E711 288 return -1 289 if MyCls() is not None: # noqa: E711 290 v = v.sin() 291 if MyCls() is MyCls(): 292 return -2 293 if MyCls() is not MyCls(): 294 return v + 1 295 return -3 296 297 @make_test 298 def test_cls_is(a, b): 299 v = a + b 300 if MyCls is None: # noqa: E711 301 return -1 302 if MyCls is not None: # noqa: E711 303 v = v.sin() 304 if MyCls is not MyCls: 305 return -2 306 if MyCls is MyCls: 307 return v + 1 308 return -3 309 310 @make_test 311 def test_itertools_combinations(a, b): 312 combs = [] 313 for size in itertools.combinations((1, 2, 3, 4), 2): 314 combs.append(torch.ones(size)) 315 return combs 316 317 @make_test 318 def test_np_iinfo(a): 319 max_dim = np.iinfo(np.int16).max 320 return a + max_dim 321 322 @make_test 323 def test_np_finfo(a): 324 min_dim = np.finfo(np.float32).min 325 return a + min_dim 326 327 @make_test 328 def test_constant1(a, b, c): 329 return a - b * c + 1.0 330 331 @make_test 332 def test_constant2(a, b, c): 333 return a - b * c + 1 334 335 @make_test 336 def test_constant3(a): 337 b = 1 338 c = 2 339 d = 3 340 return b + c - d + a 341 342 @make_test 343 def test_constant4(a, b): 344 c = 2 345 d = 3 346 if c > d: 347 return a - b 348 return b - a 349 350 @make_test 351 def test_cls_hasattr(self, x): 352 if hasattr(MyCls, "a"): 353 x = x + 1 354 if hasattr(MyCls, "b"): 355 x = x + 2 356 return x 357 358 @make_test 359 def test_finfo(a, b): 360 if torch.iinfo(torch.int32).bits == 32: 361 return torch.finfo(a.dtype).min * b 362 363 @make_test 364 def test_globalfn(a, b): 365 return sub(a, b) 366 367 @make_test 368 def test_viatorch(a, b): 369 return torch.sub(a, b) 370 371 @make_test 372 def test_viamethod(a, b): 373 return a.sub(b) 374 375 @make_test 376 def test_indirect1(a, b): 377 t = a.sub 378 return t(b) 379 380 @make_test 381 def test_indirect2(a, b): 382 t = a.sub 383 args = (b,) 384 return t(*args) 385 386 @make_test 387 def test_indirect3(a, b): 388 t = a.sub 389 args = (b,) 390 kwargs = {} 391 return t(*args, **kwargs) 392 393 @make_test 394 def test_methodcall1(a, b, c): 395 return constant3(a, b) * c 396 397 @make_test 398 def test_methodcall2(a, b): 399 return constant3(a=b, b=a) + 1 400 401 @make_test 402 def test_methodcall3(a, b): 403 return constant3(a, b=1.0) + b 404 405 def test_is_integer(self): 406 @torch.compile(backend="eager", fullgraph=True) 407 def forward(t, m): 408 return 2 * t if m.is_integer() else t 409 410 t = torch.tensor([1]) 411 self.assertEqual(forward(t, 1.0).item(), 2) 412 self.assertEqual(forward(t, 1.5).item(), 1) 413 414 @parametrize( 415 "method, num_type", 416 ( 417 ("as_integer_ratio", int), 418 ("bit_length", int), 419 ("conjugate", int), 420 ("as_integer_ratio", float), 421 ("conjugate", float), 422 ("hex", float), 423 ("is_integer", float), 424 ), 425 ) 426 def test_number_method(self, method, num_type): 427 def forward(t, m): 428 return 2 * t if getattr(m, method)() else t 429 430 wrapped = torch.compile(backend="eager", fullgraph=True)(forward) 431 432 for i in (0, 1, 2.5): 433 m = num_type(i) 434 t = torch.tensor([1]) 435 actual = wrapped(t, m) 436 expected = forward(t, m) 437 self.assertEqual(actual, expected) 438 439 @make_test 440 def test_device_constant(a): 441 return a + torch.ones(1, device=torch.device("cpu")) 442 443 @make_test 444 def test_tuple1(a, b): 445 args = (a, b) 446 return sub(*args) 447 448 @make_test 449 def test_tuple2(a, b): 450 args = [a, b] 451 return sub(*args) 452 453 @make_test 454 def test_is_in_onnx_export(x, y): 455 if torch.onnx.is_in_onnx_export(): 456 return x - 1 457 else: 458 return y + 1 459 460 @make_test 461 def test_is_fx_tracing(x, y): 462 if torch.fx._symbolic_trace.is_fx_tracing(): 463 return x - 1 464 else: 465 return y + 1 466 467 @make_test 468 def test_listarg1(a, b): 469 return torch.cat([a, b]) 470 471 @make_test 472 def test_listarg2(a, b): 473 return torch.cat((a, b), dim=0) 474 475 @make_test 476 def test_listarg3(a, b): 477 kwargs = {"tensors": (a, b), "dim": 0} 478 return torch.cat(**kwargs) 479 480 @make_test 481 def test_listarg4(a, b): 482 return torch.cat(tensors=[a, b], dim=0) 483 484 @make_test 485 def test_listarg5(a, b): 486 args = [(a, b)] 487 kwargs = {"dim": 0} 488 return torch.cat(*args, **kwargs) 489 490 def test_list_slice(self): 491 class Mock: 492 def __init__(self): 493 self.ets = [] 494 self.counter = 0 495 496 @torch.compile(backend="eager") 497 def run(self, x): 498 self.ets = self.ets[-3:] 499 self.ets.append(x) 500 return torch.sin(x) 501 502 mock = Mock() 503 mock.run(torch.randn(4)) 504 self.assertEqual(len(mock.ets), 1) 505 506 @make_test 507 def test_deque(a, b): 508 d = collections.deque([a, b]) 509 d.append(a + 1) 510 d.extend([a, b]) 511 d.insert(0, "foo") 512 tmp = d.pop() 513 514 another_deque = collections.deque([tmp]) 515 d.extendleft(another_deque) 516 another_deque.clear() 517 d.extend(another_deque) 518 519 d[2] = "setitem" 520 d = d.copy() 521 d.append(d.popleft()) 522 523 empty = collections.deque() 524 d.extend(empty) 525 526 return d 527 528 @make_test 529 def test_slice1(a): 530 return a[5] 531 532 @make_test 533 def test_slice2(a): 534 return a[:5] 535 536 @make_test 537 def test_slice3(a): 538 return a[5:] 539 540 @make_test 541 def test_slice4(a): 542 return a[2:5] 543 544 @make_test 545 def test_slice5(a): 546 return a[::2] 547 548 @make_test 549 def test_slice6(a): 550 return torch.unsqueeze(a, 0)[:, 2:] 551 552 @make_test 553 def test_range1(a): 554 return torch.tensor(range(a.size(0))) 555 556 @make_test 557 def test_range2(x, y): 558 r = x + y 559 for i in range(x.size(0) + 2): 560 r = r / y 561 return r 562 563 @make_test 564 def test_unpack1(a): 565 a, b = a[:5], a[5:] 566 return a - b 567 568 @make_test 569 def test_unpack2(a): 570 packed = [a[:5], a[5:]] 571 a, b = packed 572 return a - b 573 574 @make_test 575 def test_unpack3(a): 576 packed = (a[:5], a[5:]) 577 a, b = packed 578 return a - b 579 580 @make_test 581 def test_fn_with_self_set(a, b): 582 # avg_pool2d is an odd one with __self__ set 583 return F.avg_pool2d( 584 torch.unsqueeze(a, 0) * torch.unsqueeze(b, 1), kernel_size=2, padding=1 585 ) 586 587 @make_test 588 def test_return_tuple1(a, b): 589 return (a - b, b - a, a, b) 590 591 @make_test 592 def test_globalvar(a, b): 593 return a - b + d 594 595 @make_test 596 def test_globalmodule(x): 597 return e(x) 598 599 @make_test 600 def test_inline_with_default(a, b, c): 601 return func_with_default(a, b) * c 602 603 @make_test 604 def test_inner_function(x): 605 def fn(x): 606 return torch.add(x, x) 607 608 return fn(x) 609 610 @make_test 611 def test_transpose_for_scores(x): 612 new_x_shape = x.size()[:-1] + (2, 5) 613 x = x.view(*new_x_shape) 614 return x.permute(0, 2, 1) 615 616 @make_test 617 def test_return_tuple2(x): 618 return (torch.add(x, x), x) 619 620 @make_test 621 def test_load_global_bool(x): 622 if flag: 623 return torch.add(x, x) 624 else: 625 return x 626 627 @make_test 628 def test_len_tensor(x): 629 z = len(x) 630 return torch.add(x, z) 631 632 @make_test 633 def test_len_constant_list(x): 634 z = len([1, 2, 3]) 635 return torch.add(x, z) 636 637 @make_test 638 def test_len_constant_dict(x): 639 z = len({"foo": "bar"}) 640 return torch.add(x, z) 641 642 @make_test 643 def test_dict_copy(x): 644 z = dict({"foo": x + 1}) 645 return z 646 647 @make_test 648 def test_dict_keys(x): 649 d = {3: x} 650 keys = d.keys() 651 d[4] = x + 1 652 d2 = {3: 2, 4: "aa"} 653 return 3 in keys, 4 in keys, 5 in keys, d2.keys() == keys 654 655 @make_test 656 def test_dict_values(x): 657 d = {3: x} 658 values = d.values() 659 d[3] = x + 1 660 d[4] = x + 2 661 return len(values) 662 663 @make_test 664 def test_dict_setdefault1(x): 665 d = {"a": 1, "b": 2} 666 d.setdefault("a", 10) 667 if d["a"] == 1: 668 return x + 1 669 else: 670 return x - 1 671 672 @make_test 673 def test_dict_setdefault2(x): 674 d = {"a": 1, "b": 2} 675 d.setdefault("c", 10) 676 if d["c"] == 10: 677 return x + 1 678 else: 679 return x - 1 680 681 @make_test 682 def test_dict_setdefault3(x): 683 d = {"a": 1, "b": 2} 684 d.setdefault("c") 685 if d["c"] is None: 686 return x + 1 687 else: 688 return x - 1 689 690 @make_test 691 def test_defaultdict_setdefault1(x): 692 d = collections.defaultdict.fromkeys("a", "b") 693 d["a"] = 1 694 d["b"] = 2 695 d.setdefault("a", 10) 696 if d["a"] == 1: 697 return x + 1 698 else: 699 return x - 1 700 701 @make_test 702 def test_defaultdict_setdefault2(x): 703 d = collections.defaultdict.fromkeys("a", "b") 704 d["a"] = 1 705 d["b"] = 2 706 d.setdefault("c", 10) 707 if d["c"] == 10: 708 return x + 1 709 else: 710 return x - 1 711 712 @make_test 713 def test_defaultdict_setdefault3(x): 714 d = collections.defaultdict.fromkeys("a", "b") 715 d["a"] = 1 716 d["b"] = 2 717 d.setdefault("c") 718 if d["c"] is None: 719 return x + 1 720 else: 721 return x - 1 722 723 def test_dict_id_guard(self): 724 d1 = collections.OrderedDict({"a": 2}) 725 d2 = d1 726 727 def fn(x): 728 # Iteration forces DictGuardManager 729 for k in d1: 730 x = x * d1[k] * d2[k] 731 return x 732 733 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 734 x = torch.randn(4) 735 self.assertEqual(fn(x), opt_fn(x)) 736 737 @make_test 738 def test_callable_lambda(x): 739 if callable(lambda x: True): 740 return x + 1 741 else: 742 return x - 1 743 744 @make_test 745 def test_callable_torch(x): 746 if callable(torch.abs): 747 return x + 1 748 else: 749 return x - 1 750 751 @make_test 752 def test_callable_builtin(x): 753 if callable(sum): 754 return x + 1 755 else: 756 return x - 1 757 758 def test_callable_class(self): 759 class CallableClass: 760 def __call__(): 761 pass 762 763 class NotCallableClass: 764 pass 765 766 @torch.compile(backend="eager", fullgraph=True) 767 def fn1(x, arg): 768 if callable(arg): 769 return x 770 return x + 1 771 772 @torch.compile(backend="eager", fullgraph=True) 773 def fn2(x, arg): 774 if callable(arg): 775 return x * 2 776 return x + 1 777 778 input = torch.randn(4) 779 780 for f in [fn1, fn2]: 781 self.assertEqual(f(input, NotCallableClass()), input + 1) 782 self.assertEqual( 783 f(input, CallableClass()), input if f is fn1 else input * 2 784 ) 785 786 # passing tensor and scalars 787 self.assertEqual(f(input, 1), input + 1) 788 self.assertEqual(f(input, 1.1), input + 1) 789 self.assertEqual(f(input, True), input + 1) 790 self.assertEqual(f(input, input), input + 1) 791 792 def test_callable_list(self): 793 @torch.compile(backend="eager", fullgraph=True) 794 def fn(x, arg): 795 if callable(arg): 796 return x 797 return x + 1 798 799 input = torch.randn(4) 800 self.assertEqual(fn(input, [1, 2, 3]), input + 1) 801 self.assertEqual(fn(input, (1, 2, 3)), input + 1) 802 803 @make_test 804 def test_len_constant_misc_iterables(x): 805 a = len((1, 2, 3)) 806 b = len("test str") 807 c = a + b 808 return torch.add(x, c) 809 810 @make_test 811 def test_dict_kwargs(x): 812 z = dict(text_embed=x + 1, other=x + 2) 813 return z 814 815 @make_test 816 def test_ordered_dict_kwargs(x): 817 z = collections.OrderedDict(sample=torch.ones(10)) 818 return z 819 820 @make_test 821 def test_custom_dict_kwargs(x): 822 z = CustomDictSubclass(sample=torch.ones(10)) 823 return z 824 825 @make_test 826 def test_float(x): 827 y = float(1.2) # noqa: UP018 828 y += float("1.2") 829 return torch.add(x, y) 830 831 @make_test 832 def test_is_floating_point(x): 833 y = x + 1 834 return torch.is_floating_point(y), torch.is_floating_point(input=y) 835 836 @make_test 837 def test_dtype(x): 838 if x.dtype == torch.float32: 839 return x + 1 840 841 @make_test 842 def test_get_default_dtype(x): 843 if x.dtype == torch.get_default_dtype(): 844 return x + 1 845 else: 846 return x - 1 847 848 @make_test 849 def test_get_autocast_gpu_dtype(x): 850 dtype = torch.get_autocast_gpu_dtype() 851 return x.type(dtype) 852 853 @make_test 854 def test_is_any_autocast_enabled(x): 855 if torch._C._is_any_autocast_enabled(): 856 return x + 1 857 else: 858 return x - 1 859 860 @make_test 861 def test_is_checkpoint_valid(x): 862 if torch.autograd._is_checkpoint_valid(): 863 return x + 1 864 else: 865 return x - 1 866 867 @make_test 868 def test_list_compare_polyfill(x): 869 for a, b, c in [ 870 [(1, 2, 3), (1, 2, 3), 7.77], 871 [(1, 4, 3), (1, 2, 3), 3.33], 872 [(1, 2), (1, 2, 3), 5.55], 873 [(1, 2, 3), (1, 2), 11.11], 874 [(1, -1, 3), (1, 2, 3), 13.33], 875 ]: 876 if a != b: 877 x += 1 * c 878 if a == b: 879 x += 2 * c 880 if a < b: 881 x += 4 * c 882 if a > b: 883 x += 8 * c 884 if a <= b: 885 x += 16 * c 886 if a >= b: 887 x += 32 * c 888 return x 889 890 @make_test 891 def test_promote_types(x): 892 if x.dtype == torch.promote_types(torch.int32, torch.float32): 893 return x + 1 894 else: 895 return x - 1 896 897 @make_test 898 def test_cublas_allow_tf32(x): 899 if torch.backends.cuda.matmul.allow_tf32: 900 return x.sin() + 1 901 902 return x.cos() - 1 903 904 @make_test 905 def test_get_calculate_correct_fan(x): 906 fan_in = torch.nn.init._calculate_correct_fan(x, "fan_in") 907 return x + fan_in 908 909 @make_test 910 def test_is_complex(x): 911 if torch.is_complex(x): 912 return x + 1 913 else: 914 return x - 1 915 916 @make_test 917 def test_tensor_is_complex(x): 918 if x.is_complex(): 919 return x + 1 920 else: 921 return x - 1 922 923 @make_test 924 def test_get_privateuse1_name(x): 925 if torch._C._get_privateuse1_backend_name() == "privateuseone": 926 return x + 1 927 else: 928 return x - 1 929 930 @make_test 931 def test_device(x): 932 if not x.is_cuda: 933 return x + 1 934 935 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") 936 @make_test 937 def test_get_device_properties_tensor_device(a): 938 x = a.to("cuda") 939 prop = torch.cuda.get_device_properties(x.device) 940 if prop.major == 8: 941 return x + prop.multi_processor_count 942 return x + prop.max_threads_per_multi_processor 943 944 @make_test 945 def test_tensor_type(a, b): 946 m = a.to(torch.float16) 947 return b.type(m.type()) 948 949 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") 950 @make_test 951 def test_tensor_type2(a, b): 952 m = a.to("cuda") 953 return m + b.type(m.type()) 954 955 @make_test 956 def test_tensor_type3(a, b): 957 m = a.type(torch.HalfTensor) 958 return b.type(m.type()) 959 960 @make_test 961 def test_tensor_type4(a, b): 962 m = a.type("torch.HalfTensor") 963 return b.type(m.type()) 964 965 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") 966 @make_test 967 def test_tensor_type5(a, b): 968 m = a.type(torch.cuda.HalfTensor) 969 return b.type(m.type()) 970 971 @make_test 972 def test_tensor_element_size(a): 973 if a.element_size() > 1: 974 return (a + a.element_size(), a - a.element_size()) 975 return (a - a.element_size(), a + a.element_size()) 976 977 @make_test 978 def test_ndim(x): 979 if x.ndim == 2 and x.ndimension() == 2 and x.dim() == 2: 980 return x + 1 981 982 @make_test 983 def test_T(x): 984 return torch.ones_like(x.T) 985 986 @make_test 987 def test_mT(x): 988 return torch.ones_like(x.mT) 989 990 @make_test 991 def test_is_sparse(x): 992 if not x.is_sparse: 993 return x + 1 994 995 @make_test 996 def test_shape1(x): 997 if x.shape[0] == 10: 998 return x + 1 999 1000 @make_test 1001 def test_shape2(x): 1002 if x.size(1) == 10: 1003 return x + 1 1004 1005 @make_test 1006 def test_del(a, b): 1007 c = a + 1 1008 d = c + 2 1009 del c, a 1010 return b + d 1011 1012 @make_test 1013 def test_chunks1(x): 1014 chunk_size = 5 1015 assert x.shape[0] % chunk_size == 0 1016 assert x.shape[0] // chunk_size == 2 1017 return x[:chunk_size] - x[chunk_size:] 1018 1019 @make_test 1020 def test_import1(x, y): 1021 import torch 1022 from torch import sub 1023 1024 return sub(torch.add(x, y), y) 1025 1026 @make_test 1027 def test_return_dict(x, y): 1028 z = [x + y, y, False] 1029 return {"x": x, "z": z, "a": x, "b": z, "c": x} 1030 1031 @make_test 1032 def test_return_dict2(x, y): 1033 tmp = {"x": x} 1034 tmp["z"] = [x + y, y] 1035 tmp["y"] = y 1036 tmp["z"].append(False) 1037 return tmp 1038 1039 @make_test 1040 def test_funcdef_closure(x, y): 1041 x = x + y + 1.0 1042 1043 def inner(z): 1044 nonlocal x, y 1045 y = x + z + 20.0 1046 x = y + z + 10.0 1047 1048 inner(2.0) 1049 inner(3.0) 1050 1051 return x, y 1052 1053 @make_test 1054 def test_module_constant(x, y): 1055 r = x + y 1056 for i in range(torch._dynamo.testing.three): 1057 r = r / y 1058 return r 1059 1060 @make_test 1061 def test_inline_softmax(x, y): 1062 # This is common in sme huggingface models 1063 return torch.nn.Softmax(dim=-1)(x + y * 2) 1064 1065 @make_test 1066 def test_dtype_compare(a, b): 1067 if a.dtype == torch.float16: 1068 return a + 10 1069 if a.dtype == torch.float32: 1070 return a - b * 32 1071 1072 @make_test 1073 def test_build_list_unpack(a, b): 1074 it1 = (x + 1 for x in (a, b)) 1075 it2 = (x - 1 for x in (a, b)) 1076 return torch.cat([*it1, *it2], dim=-1) 1077 1078 @make_test 1079 def test_tensor_len(a, b): 1080 return a + b + len(a) + b.__len__() 1081 1082 @make_test 1083 def test_pop(a, b): 1084 ll = [a, b] 1085 ll.append(a + 1) 1086 ll.extend( 1087 [ 1088 b + 2, 1089 a + b, 1090 ] 1091 ) 1092 ll.pop(-1) 1093 ll.pop(0) 1094 ll.pop() 1095 v1, v2 = ll 1096 return v1 - v2 1097 1098 @make_test 1099 def test_list_convert(a, b): 1100 ll = [a + 2, b] 1101 ll = tuple(ll) 1102 tmp = b + 3 1103 ll = list(ll) 1104 v1, v2 = ll 1105 return v1 - v2 + tmp 1106 1107 @make_test 1108 def test_list_add(a, b): 1109 l1 = (a, b) 1110 l2 = () # being a LOAD_CONST in the bytecode 1111 l3 = l1 + l2 1112 return l3[0] + l3[1] 1113 1114 @make_test 1115 def test_list_index_with_constant_tensor(a, b): 1116 l1 = [a, b, a + 1, b + 1] 1117 return l1[torch.as_tensor(2)] 1118 1119 @make_test 1120 def test_startswith(a, b): 1121 x = a + b 1122 if "foobar".startswith("foo") and "test" in constant3.__module__: 1123 x = x + 1 1124 return x 1125 1126 @make_test 1127 def test_dict_ops(a, b): 1128 tmp = {"a": a + 1, "b": b + 2} 1129 assert tmp.get("zzz") is None 1130 v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4) 1131 tmp.update({"d": 3}) 1132 tmp["c"] = v + tmp["d"] 1133 if "c" in tmp and "missing" not in tmp: 1134 return tmp["c"] - tmp["a"] + len(tmp) 1135 1136 @make_test 1137 def test_inline_jit__unwrap_optional(x): 1138 if torch.jit._unwrap_optional(x) is None: 1139 return torch.ones(2, 2) 1140 return x.sin() 1141 1142 @make_test 1143 def test_zip_longest(x): 1144 list1 = [1, 2, 3] 1145 list2 = ["a", "b"] 1146 list3 = [True, False, True, False] 1147 return torch.sin(x + 1), list( 1148 itertools.zip_longest(list1, list2, list3, fillvalue=None) 1149 ) 1150 1151 def test_torch_size_as_dict_key(self): 1152 def fn(x, cached): 1153 if x.shape not in cached: 1154 cached[x.shape] = x 1155 return x + cached[x.shape] 1156 1157 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1158 x1 = torch.randn(2, 3) 1159 x2 = torch.randn(2, 3) 1160 cached = {} 1161 ref1 = fn(x1, cached) 1162 ref2 = fn(x2, cached) 1163 cached = {} 1164 res1 = opt_fn(x1, cached) 1165 res2 = opt_fn(x2, cached) 1166 self.assertEqual(ref1, res1) 1167 self.assertEqual(ref2, res2) 1168 1169 def test_dict_param_keys(self): 1170 a_param = torch.nn.Parameter(torch.ones([4, 4])) 1171 1172 def fn(a): 1173 tmp = {"a": a, a_param: 3} 1174 return tmp["a"] + tmp[a_param] 1175 1176 test = make_test(fn) 1177 test(self) 1178 1179 def test_dict_mutable_map(self): 1180 from collections.abc import MutableMapping 1181 1182 class TensorDict(MutableMapping): 1183 def __init__(self) -> None: 1184 self._dict = {} 1185 1186 def add(self, key, value): 1187 self._dict[key] = value 1188 1189 def items(self): 1190 return self._dict.items() 1191 1192 def __delitem__(self, key): 1193 del self._dict[key] 1194 1195 def __getitem__(self, key): 1196 return self._dict[key] 1197 1198 def __iter__(self): 1199 return iter(self._dict) 1200 1201 def __len__(self): 1202 return len(self._dict) 1203 1204 def __setitem__(self, key, value): 1205 self._dict[key] = value 1206 1207 tensor_dict = TensorDict() 1208 tensor_dict.add("a", torch.ones(4) * 2) 1209 1210 def fn(x): 1211 copy_tensordict = dict(tensor_dict) 1212 return x * copy_tensordict["a"] 1213 1214 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1215 x = torch.randn(4) 1216 1217 ref = fn(x) 1218 res = opt_fn(x) 1219 self.assertEqual(ref, res) 1220 1221 def test_unpack_mutable_map(self): 1222 from collections.abc import MutableMapping 1223 1224 class TensorDict(MutableMapping): 1225 def __init__(self) -> None: 1226 self._dict = {} 1227 1228 def add(self, key, value): 1229 self._dict[key] = value 1230 1231 def items(self): 1232 return self._dict.items() 1233 1234 def __delitem__(self, key): 1235 del self._dict[key] 1236 1237 def __getitem__(self, key): 1238 return self._dict[key] 1239 1240 def __iter__(self): 1241 return iter(self._dict) 1242 1243 def __len__(self): 1244 return len(self._dict) 1245 1246 def __setitem__(self, key, value): 1247 self._dict[key] = value 1248 1249 tensor_dict = TensorDict() 1250 tensor_dict.add("a", torch.ones(4) * 2) 1251 1252 def gn(x, a=1): 1253 return x * a 1254 1255 def fn(x): 1256 return gn(x, **tensor_dict) 1257 1258 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1259 1260 x = torch.randn(4) 1261 1262 ref = fn(x) 1263 res = opt_fn(x) 1264 self.assertEqual(ref, res) 1265 1266 def _test_default_dict_helper(self, factory): 1267 dd = collections.defaultdict(factory) 1268 param = torch.nn.Parameter(torch.ones([2, 2])) 1269 1270 def fn(x): 1271 dd["a"] = x + 1 1272 dd[param] = 123 1273 dd["c"] = x * 2 1274 return dd["b"], dd 1275 1276 x = torch.randn(10, 10) 1277 ref = fn(x) 1278 opt_fn = torch._dynamo.optimize_assert("eager")(fn) 1279 res = opt_fn(x) 1280 1281 self.assertTrue(same(ref[0], res[0])) 1282 self.assertTrue(same(ref[1]["a"], res[1]["a"])) 1283 self.assertTrue(same(ref[1]["c"], res[1]["c"])) 1284 self.assertTrue(same(ref[1][param], res[1][param])) 1285 1286 def test_default_dict_dict(self): 1287 self._test_default_dict_helper(dict) 1288 1289 def test_default_dict_list(self): 1290 self._test_default_dict_helper(list) 1291 1292 def test_default_dict_tuple(self): 1293 self._test_default_dict_helper(tuple) 1294 1295 def test_default_dict_set(self): 1296 self._test_default_dict_helper(set) 1297 1298 def test_default_dict_lambda(self): 1299 self._test_default_dict_helper(lambda: dict()) # noqa: C408 1300 1301 def test_default_dict_closure(self): 1302 def factory(): 1303 return dict() # noqa: C408 1304 1305 self._test_default_dict_helper(factory) 1306 1307 def test_class_dict(self): 1308 class A: 1309 x = 4 1310 y = 5 1311 1312 def __init__(self) -> None: 1313 self.a = 6 1314 1315 a = A() 1316 1317 def fn(x): 1318 if "x" in type(a).__dict__: 1319 return x + 1 1320 return x + 2 1321 1322 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1323 x = torch.randn(4) 1324 self.assertEqual(fn(x), opt_fn(x)) 1325 1326 def test_default_dict_constr(self): 1327 param = torch.nn.Parameter(torch.ones([2, 2])) 1328 1329 def fn(x): 1330 dd = collections.defaultdict(lambda: dict()) # noqa: C408 1331 dd["a"] = x + 1 1332 dd[param] = 123 1333 dd["c"] = x * 2 1334 dd.update({"b": x * 3}) 1335 dd.update([["d", x - 2], ("e", x + 2)]) 1336 dd.update(zip("ab", [x + 3, x + 4])) 1337 return dd["b"], dd 1338 1339 x = torch.randn(10, 10) 1340 ref = fn(x) 1341 opt_fn = torch._dynamo.optimize_assert("eager")(fn) 1342 res = opt_fn(x) 1343 1344 self.assertTrue(same(ref[0], res[0])) 1345 self.assertTrue(same(ref[1]["a"], res[1]["a"])) 1346 self.assertTrue(same(ref[1]["b"], res[1]["b"])) 1347 self.assertTrue(same(ref[1]["c"], res[1]["c"])) 1348 self.assertTrue(same(ref[1]["d"], res[1]["d"])) 1349 self.assertTrue(same(ref[1]["e"], res[1]["e"])) 1350 self.assertTrue(same(ref[1][param], res[1][param])) 1351 1352 def test_dict_tuple_lazy_guard(self): 1353 @torch.compile(backend="eager") 1354 def fn(x, y): 1355 return torch.sin(x) * y[1] 1356 1357 fn(torch.randn(3), {1: 1, 2: 2}) 1358 # Changing the value of other key should not causing recompilation 1359 with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 1360 fn(torch.randn(3), {1: 1, 2: 3}) 1361 1362 fn(torch.randn(3), (1, 2, 3)) 1363 # Changing the value of index 0, 2 (not 1) should not cause recompilation 1364 with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 1365 fn(torch.randn(3), (11, 2, 13)) 1366 1367 @make_test 1368 def test_call_dict1(x): 1369 d1 = dict() # noqa: C408 1370 d1["x"] = x + 1 1371 d2 = collections.OrderedDict() 1372 d2["x"] = x + 2 1373 return d1["x"] + d2["x"] + 1 1374 1375 @make_test 1376 def test_call_dict2(x): 1377 d1 = dict() # noqa: C408 1378 d1["x"] = x 1379 d2 = collections.OrderedDict(d1) 1380 if isinstance(d2, collections.OrderedDict): 1381 return x + 1 1382 else: 1383 return x - 1 1384 1385 @make_test 1386 def test_call_dict3(x): 1387 my_list = [("a", x), ("b", x + 1), ("c", x + 2)] 1388 d1 = dict(my_list) 1389 d1["a"] = x + 10 1390 d2 = collections.OrderedDict(my_list) 1391 d2["c"] = x + 20 1392 return d1["a"] + d2["c"] + 1 1393 1394 @make_test 1395 def test_call_dict4(x): 1396 my_list = (("a", x), ("b", x + 1), ("c", x + 2)) 1397 d1 = dict(my_list) 1398 d1["a"] = x + 10 1399 d2 = collections.OrderedDict(my_list) 1400 d2["c"] = x + 20 1401 return d1["a"] + d2["c"] + 1 1402 1403 @make_test 1404 def test_call_dict5(x): 1405 my_list = iter([("a", x), ("b", x + 1), ("c", x + 2)]) 1406 d1 = dict(my_list) 1407 d1["a"] = x + 10 1408 d2 = collections.OrderedDict(my_list) 1409 d2["c"] = x + 20 1410 return d1["a"] + d2["c"] + 1 1411 1412 @make_test 1413 def test_dict_fromkeys(x, y): 1414 lst = ["a", "b"] 1415 d = dict.fromkeys(lst) 1416 d1 = dict.fromkeys(d, x + 1) 1417 d2 = collections.defaultdict.fromkeys(iter(d1), x - 2) 1418 d3 = collections.OrderedDict.fromkeys(tuple(lst), value=y) 1419 return d1["a"] * d2["b"] + d2["a"] + d1["b"] + d3["a"] + d3["b"] + 1 1420 1421 @make_test 1422 def test_dict_copy(x): 1423 my_list = [("a", x), ("b", x + 1), ("c", x + 2)] 1424 d1 = dict(my_list) 1425 d1["a"] = x + 10 1426 d2 = d1.copy() 1427 d2["a"] = x - 5 1428 d2["b"] = x + 3 1429 d3 = collections.OrderedDict(my_list) 1430 d3["c"] = x + 20 1431 d4 = d3.copy() 1432 d4["c"] = x - 10 1433 return d1["a"] * d2["a"] + d2["b"] + d3["c"] * d4["c"] + 1 1434 1435 @make_test 1436 def test_dict_update(x, y, z): 1437 d = {"a": x, "b": y} 1438 d.update({"a": y - 1}) 1439 d.update([("b", z + 1), ["c", z]]) 1440 d.update(zip("ab", [z + 3, y + 2])) 1441 1442 od = collections.OrderedDict(a=x * 3, b=y + 2) 1443 od.update({"a": y + 5}) 1444 od.update([["b", z + 6], ("c", z - 7)]) 1445 od.update(zip("ab", [z - 3, x + 2])) 1446 return d["a"] * od["a"] + od["c"] + d["b"] + od["b"] * d["c"] 1447 1448 @make_test 1449 def test_min_max(a, b): 1450 c = a + b 1451 a = a.sum() 1452 b = b.sum() 1453 a = min(max(a, 0), 1) 1454 b = max(0, min(1, b)) 1455 return max(a, b) - min(a, b) + c 1456 1457 @make_test 1458 def test_symbool_to_int(x): 1459 # this is roughly the pattern found in einops.unpack() 1460 if sum(s == -1 for s in x.size()) == 0: 1461 return x + 1 1462 else: 1463 return x - 1 1464 1465 @make_test 1466 def test_map_sum(a, b, c, d): 1467 return sum(map(lambda x: x + 1, [a, b, c, d])) 1468 1469 @make_test 1470 def test_sum(a, b, c, d): 1471 return sum([a, b, c, d]) 1472 1473 @make_test 1474 def test_sum_with_start_arg(a, b, c, d): 1475 return sum([b, c, d], a) 1476 1477 @make_test 1478 def test_sum_with_start_kwarg(a, b, c, d): 1479 return sum([b, c, d], start=a) 1480 1481 @make_test(expected_frame_count=0) 1482 def test_sum_shortcut(): 1483 return sum([0, 1.0, 2, 3.0]) 1484 1485 @make_test(expected_frame_count=0) 1486 def test_sum_shortcut_with_start_arg(): 1487 return sum([0, 1.0, 2, 3.0], -10) 1488 1489 @make_test(expected_frame_count=0) 1490 def test_sum_shortcut_with_start_kwarg(): 1491 return sum([0, 1.0, 2, 3.0], start=-10) 1492 1493 @make_test 1494 def test_reduce(a, b, c, d): 1495 return functools.reduce(operator.add, [a, b, c, d]) 1496 1497 @make_test 1498 def test_reduce_with_initial(a, b, c, d): 1499 return functools.reduce(operator.add, [b, c, d], a) 1500 1501 @make_test(expected_frame_count=0) 1502 def test_reduce_with_single(x): 1503 return functools.reduce(lambda a, b: (a, b), [x]) 1504 1505 @make_test(expected_frame_count=0) 1506 def test_reduce_with_single_with_initial(x, y): 1507 return functools.reduce(lambda a, b: (a, b), [y], x) 1508 1509 @make_test(expected_frame_count=0) 1510 def test_reduce_with_none_initial(x): 1511 return functools.reduce(lambda a, b: (a, b), [x], None) 1512 1513 @make_test 1514 def test_tuple_contains(a, b): 1515 v1 = "a" 1516 v2 = "b" 1517 v3 = "c" 1518 vals1 = (v1, v2, v3) 1519 vals2 = ("d", "e", "f") 1520 if "a" in vals1 and "b" not in vals2: 1521 return a + b 1522 return a - b 1523 1524 @unittest.skipIf( 1525 sys.version_info < (3, 9), 1526 "SET_UPDATE was added at Python 3.9", 1527 ) 1528 @make_test 1529 def test_set_update_bytecode(x): 1530 # This produces bytecode SET_UPDATE since python 3.9 1531 var = {"apple", "banana", "cherry"} 1532 if isinstance(var, set): 1533 return x + 1 1534 else: 1535 return x - 1 1536 1537 @unittest.skipIf( 1538 sys.version_info < (3, 9), 1539 "SET_UPDATE was added at Python 3.9", 1540 ) 1541 @make_test 1542 def test_set_update_list_with_duplicated_items(x): 1543 list1 = ["apple", "banana", "apple"] 1544 list2 = ["orange", "banana"] 1545 if len({*list1, *list2}) == 3: 1546 return x + 1 1547 else: 1548 return x - 1 1549 1550 @make_test 1551 def test_set_contains(a, b): 1552 vals = set(["a", "b", "c"]) 1553 if "a" in vals: 1554 x = a + b 1555 else: 1556 x = a - b 1557 if "d" in vals: 1558 y = a + b 1559 else: 1560 y = a - b 1561 return x, y 1562 1563 def test_set_isdisjoint(self): 1564 x = {"apple", "banana", "cherry"} 1565 y = {"google", "microsoft", "apple"} 1566 1567 def fn(a): 1568 if x.isdisjoint(y): 1569 return a + 1 1570 else: 1571 return a - 1 1572 1573 test = make_test(fn) 1574 test(self) 1575 1576 @make_test 1577 def test_set_intersection(a, b): 1578 set1 = {"apple", "banana", "cherry"} 1579 set2 = {"google", "microsoft", "apple"} 1580 intersection_set = set1.intersection(set2) 1581 if "apple" in intersection_set: 1582 x = a + b 1583 else: 1584 x = a - b 1585 if "banana" in intersection_set: 1586 y = a + b 1587 else: 1588 y = a - b 1589 return x, y 1590 1591 @make_test 1592 def test_set_union(a, b): 1593 set1 = {"apple", "banana", "cherry"} 1594 set2 = {"google", "microsoft", "apple"} 1595 union_set = set1.union(set2) 1596 if "apple" in union_set: 1597 x = a + b 1598 else: 1599 x = a - b 1600 if "banana" in union_set: 1601 y = a + b 1602 else: 1603 y = a - b 1604 return x, y 1605 1606 @make_test 1607 def test_set_difference(a, b): 1608 set1 = {"apple", "banana", "cherry"} 1609 set2 = {"google", "microsoft", "apple"} 1610 difference_set = set1.difference(set2) 1611 if "apple" in difference_set: 1612 x = a + b 1613 else: 1614 x = a - b 1615 if "banana" in difference_set: 1616 y = a + b 1617 else: 1618 y = a - b 1619 return x, y 1620 1621 def test_set_keys_view(self): 1622 from collections.abc import KeysView 1623 1624 class StringKeys(KeysView): 1625 def __init__(self, keys): 1626 self.keys = keys 1627 1628 def __getitem__(self, key): 1629 return self.keys.__getitem__(key) 1630 1631 def __iter__(self): 1632 yield from self.keys 1633 1634 def __repr__(self): 1635 return f"{type(self).__name__}({self.keys})" 1636 1637 def __len__(self): 1638 return len(self.keys) 1639 1640 def __contains__(self, item): 1641 return self.keys.__contains__(item) 1642 1643 a = StringKeys([1, 2, 3, 3]) 1644 1645 def fn(x): 1646 set_a = set(a) 1647 return len(set_a) * x 1648 1649 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1650 x = torch.rand(4) 1651 self.assertEqual(fn(x), opt_fn(x)) 1652 1653 def test_constant_set(self): 1654 s = set([1, 2]) 1655 1656 def fn(x): 1657 return torch.cos(x) * len(s) 1658 1659 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1660 1661 x = torch.rand(4) 1662 self.assertEqual(fn(x), opt_fn(x)) 1663 1664 # This should cause recompilation 1665 s.add(3) 1666 self.assertEqual(fn(x), opt_fn(x)) 1667 1668 def test_set_add(self): 1669 s = set([1, 2]) 1670 1671 def fn(x): 1672 s.add(3) 1673 return torch.cos(x) * len(x) 1674 1675 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1676 1677 x = torch.rand(4) 1678 self.assertEqual(fn(x), opt_fn(x)) 1679 self.assertEqual(len(s), 3) 1680 1681 @make_test 1682 def test_tuple_iadd(a, b): 1683 output = (a, b) 1684 output += (a + b, a - b) 1685 return output 1686 1687 @make_test 1688 def test_unpack_ex1(x): 1689 output = (x, x + 1, x + 2, x + 3) 1690 a, b, *cd = output 1691 return a - b / cd[0] 1692 1693 @make_test 1694 def test_unpack_ex2(x): 1695 output = (x, x + 1, x + 2, x + 3) 1696 *ab, c, d = output 1697 return c - d / ab[0] 1698 1699 @make_test 1700 def test_unpack_ex3(x): 1701 output = (x, x + 1, x + 2, x + 3) 1702 a, *bc, d = output 1703 return a - d / bc[0] 1704 1705 @make_test 1706 def test_const_tuple_add1(x): 1707 output = (x, x + 1, x + 2, x + 3) 1708 output = () + output + () 1709 return output[2] + output[3] 1710 1711 @make_test 1712 def test_const_tuple_add2(x): 1713 output = (x, x + 1, x + 2, x + 3) 1714 output = (None,) + output + (None,) 1715 return output[2] + output[3] 1716 1717 @make_test 1718 def test_list_truth(a, b): 1719 tmp = [1, 2, 3] 1720 if tmp: 1721 return a + b 1722 else: 1723 return a - b 1724 1725 @make_test 1726 def test_list_reversed(a, b): 1727 tmp = [a + 1, a + 2, a + 3] 1728 return a + b + next(iter(reversed(tmp))) 1729 1730 @make_test 1731 def test_list_sorted1(x): 1732 tmp = [1, 10, 3, 0] 1733 return x + 1, sorted(tmp), sorted(tmp, reverse=True) 1734 1735 @make_test 1736 def test_list_sorted2(x): 1737 y = [ 1738 ("john", "A", 8), 1739 ("jane", "B", 5), 1740 ("dave", "B", 10), 1741 ] 1742 return ( 1743 x + 1, 1744 sorted(y), 1745 sorted(y, key=lambda student: student[2]), 1746 sorted(y, key=lambda student: student[2], reverse=True), 1747 ) 1748 1749 @make_test 1750 def test_tuple_sorted(x): 1751 tmp = (1, 10, 3, 0) 1752 return x + 1, sorted(tmp), sorted(tmp, reverse=True) 1753 1754 @make_test 1755 def test_dict_sorted(x): 1756 tmp = {1: "D", 10: "B", 3: "E", 0: "F"} 1757 return x + 1, sorted(tmp), sorted(tmp, reverse=True) 1758 1759 def test_dict_hasattr(self): 1760 def fn(x): 1761 if hasattr(x, "to"): 1762 return x.to("cpu") 1763 if hasattr(x, "items"): 1764 return torch.cos(x["a"]) 1765 return x 1766 1767 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1768 1769 x = dict(a=torch.randn(3)) 1770 self.assertEqual(fn(x), opt_fn(x)) 1771 1772 x = torch.randn(4) 1773 self.assertEqual(fn(x), opt_fn(x)) 1774 1775 @make_test 1776 def test_list_clear(a, b): 1777 tmp = [a + 1, a + 2] 1778 tmp.clear() 1779 tmp.append(a + b) 1780 return tmp 1781 1782 @make_test 1783 def test_not_list(a): 1784 return not [a + 1] 1785 1786 @make_test 1787 def test_islice_chain(a, b): 1788 tmp1 = [a + 1, a + 2] 1789 tmp2 = [a + 3, a + 4] 1790 a, b = list(itertools.islice(itertools.chain(tmp1, tmp2), 1, 3)) 1791 c = next(itertools.islice(tmp1, 1, None)) 1792 return a - b / c 1793 1794 @make_test 1795 def test_namedtuple(a, b): 1796 mytuple = collections.namedtuple("mytuple", ["x", "y", "xy"]) 1797 tmp = mytuple(a, b, a + b) 1798 return mytuple(tmp.x, tmp[1], tmp.xy + b) 1799 1800 @make_test 1801 def test_namedtuple_defaults(a, b): 1802 mytuple = collections.namedtuple( 1803 "mytuple", ["x", "y", "xy"], defaults=(None, 1, None) 1804 ) 1805 tmp = mytuple(a, xy=b) 1806 return mytuple(tmp.x, tmp[1], tmp.xy + b) 1807 1808 class MyNamedTuple(NamedTuple): 1809 first: torch.Tensor 1810 second: torch.Tensor 1811 1812 def add(self) -> torch.Tensor: 1813 return self.first + self.second 1814 1815 @staticmethod 1816 def static_method() -> int: 1817 return 1 1818 1819 @classmethod 1820 def class_method(cls) -> str: 1821 return cls.__name__ 1822 1823 @make_test 1824 def test_namedtuple_user_methods(a, b): 1825 mytuple = FunctionTests.MyNamedTuple(a, b) 1826 return mytuple.add(), mytuple.static_method(), mytuple.class_method() 1827 1828 @make_test 1829 def test_namedtuple_hasattr(a, b): 1830 mytuple = FunctionTests.MyNamedTuple(a, b) 1831 1832 def isinstance_namedtuple(obj) -> bool: 1833 return ( 1834 isinstance(obj, tuple) 1835 and hasattr(obj, "_asdict") 1836 and hasattr(obj, "_fields") 1837 ) 1838 1839 if isinstance_namedtuple(mytuple): 1840 return a + b 1841 else: 1842 return a - b 1843 1844 @make_test 1845 def test_torch_size_hasattr(x): 1846 if hasattr(x.shape, "_fields"): 1847 return x + 1 1848 else: 1849 return x - 1 1850 1851 @make_test 1852 def test_is_quantized(a, b): 1853 if not a.is_quantized: 1854 return a + b 1855 1856 @make_test 1857 def test_fstrings1(a, b): 1858 x = 1.229 1859 tmp = f"{x:.2f} bar" 1860 if tmp.startswith("1.23"): 1861 return a + b 1862 1863 @make_test 1864 def test_fstrings2(x): 1865 tmp = f"{x.shape[0]} bar" 1866 if tmp.startswith("10"): 1867 return x + 1 1868 1869 @make_test 1870 def test_fstrings3(x): 1871 tmp = f"{x.__class__.__name__} foo" 1872 if tmp.startswith("Tensor"): 1873 return x + 1 1874 1875 @make_test 1876 def test_fstrings4(x): 1877 tmp = f"{x.shape[0]} bar" 1878 if "10" in tmp: 1879 return x + 1 1880 1881 @make_test 1882 def test_fstrings5(x): 1883 tmp = f"{x.shape[0]} bar" 1884 if "10" in (tmp + "haha"): 1885 return x + 1 1886 1887 @make_test 1888 def test_fstrings6(x): 1889 tmp = f"{x.shape[0] + x.shape[1]}" 1890 if "20" in tmp: 1891 return x + 1 1892 1893 @make_test 1894 def test_tensor_new_with_size(x): 1895 y = torch.rand(5, 8) 1896 z = x.new(y.size()) 1897 assert z.size() == y.size() 1898 1899 @make_test 1900 def test_tensor_new_with_shape(x): 1901 y = torch.rand(5, 8) 1902 z = x.new(y.shape) 1903 assert z.size() == y.size() 1904 1905 @make_test 1906 def test_jit_annotate(x): 1907 y = torch.jit.annotate(Any, x + 1) 1908 return y + 2 1909 1910 @make_test 1911 def test_is_contiguous_memory_format(tensor): 1912 if torch.jit.is_scripting(): 1913 return None 1914 elif tensor.is_contiguous(memory_format=torch.contiguous_format): 1915 return tensor + 1 1916 1917 def test_is_contiguous_frame_counts(self): 1918 data = [ 1919 torch.rand(10), 1920 torch.rand(2, 3, 32, 32), 1921 torch.rand(2, 3, 32, 32).contiguous(memory_format=torch.channels_last), 1922 torch.rand(10)[::2], 1923 torch.rand(12), 1924 torch.rand(2, 3, 24, 24).contiguous(memory_format=torch.channels_last), 1925 torch.rand(50)[::2], 1926 torch.rand(2, 3, 32, 32)[:, :, 2:-2, 3:-3], 1927 ] 1928 # dynamo should recompile for all inputs in static shapes mode 1929 expected_frame_counts_static = [1, 2, 3, 4, 5, 6, 7, 8] 1930 # dynamo should recompile for items 0, 1, 2, 6 in dynamic shapes mode 1931 expected_frame_counts_dynamic = [1, 2, 3, 4, 4, 4, 4, 5] 1932 expected_frame_counts = ifdynstaticdefault( 1933 expected_frame_counts_static, expected_frame_counts_dynamic 1934 ) 1935 dynamic = ifdynstaticdefault(False, True) 1936 1937 def func(x): 1938 if x.is_contiguous(): 1939 return x + 1 1940 elif x.is_contiguous(memory_format=torch.channels_last): 1941 return x + 2 1942 else: 1943 return x + 3 1944 1945 cnt = torch._dynamo.testing.CompileCounter() 1946 cfunc = torch._dynamo.optimize_assert(cnt, dynamic=dynamic)(func) 1947 1948 assert cnt.frame_count == 0 1949 for i, x in enumerate(data): 1950 expected = func(x) 1951 output = cfunc(x) 1952 self.assertTrue(same(output, expected)) 1953 assert cnt.frame_count == expected_frame_counts[i] 1954 1955 @make_test 1956 def test_list_slice_assignment(x): 1957 m = [1, 2, 3, 4] 1958 m[1:] = [6] * (len(m) - 1) 1959 return x + 1 1960 1961 @make_test 1962 def test_distributed_is_available(x): 1963 if torch.distributed.is_available(): 1964 return x + 1 1965 else: 1966 return x - 1 1967 1968 @unittest.skipIf( 1969 not torch.distributed.is_available(), "requires distributed package" 1970 ) 1971 @make_test 1972 def test_distributed_is_initialized(x): 1973 if torch.distributed.is_initialized(): 1974 return x + 1 1975 else: 1976 return x - 1 1977 1978 @disable_translation_validation_if_dynamic_shapes 1979 @make_test 1980 def test_torch_distributions_functions(x): 1981 normal = torch.distributions.Normal(x, torch.tensor(1)) 1982 independent = torch.distributions.Independent(normal, 1) 1983 return independent.log_prob(x) 1984 1985 @make_test 1986 def test_context_wrapping_nested_functions_no_closure(x): 1987 @torch.no_grad() 1988 def augment(x: torch.Tensor) -> torch.Tensor: 1989 return (x + 1) * 2 1990 1991 return augment(x) 1992 1993 # # This is to test the new syntax for pattern matching 1994 # # ("match ... case ...") added on python 3.10. 1995 # # Uncomment these test cases if you run on 3.10+ 1996 # @make_test 1997 # def test_match_sequence(a): 1998 # point = (5, 8) 1999 # match point: 2000 # case (0, 0): 2001 # return a 2002 # case (0, y): 2003 # return a - y 2004 # case (x, 0): 2005 # return a + x 2006 # case (x, y): 2007 # return a + x - y 2008 2009 # @make_test 2010 # def test_match_mapping_and_match_keys(x): 2011 # param = {"a": 0.5} 2012 # match param: 2013 # case {"a": param}: 2014 # return x * param 2015 # case {"b": param}: 2016 # return x / param 2017 2018 def test_math_radians(self): 2019 def func(x, a): 2020 return x + math.radians(a) 2021 2022 cnt = torch._dynamo.testing.CompileCounter() 2023 cfunc = torch._dynamo.optimize_assert(cnt)(func) 2024 2025 assert cnt.frame_count == 0 2026 x = torch.rand(10) 2027 expected = func(x, 12) 2028 output = cfunc(x, 12) 2029 self.assertTrue(same(output, expected)) 2030 assert cnt.frame_count == 1 2031 2032 @make_test 2033 def test_numpy_meshgrid(x, y): 2034 r1, r2 = np.meshgrid(x.numpy(), y.numpy()) 2035 return torch.from_numpy(r1), torch.from_numpy(r2) 2036 2037 @make_test 2038 def test_torch_from_numpy(x): 2039 a = x.numpy() 2040 b = torch.from_numpy(a) 2041 if b.size(0) == 1: 2042 return torch.tensor(True) 2043 else: 2044 return torch.tensor(False) 2045 2046 @make_test 2047 def test_numpy_size(x): 2048 a = x.numpy() 2049 return a.size 2050 2051 @make_test 2052 def test_numpy_attributes(x): 2053 a = x.numpy() 2054 return ( 2055 a.itemsize, 2056 a.strides, 2057 a.shape, 2058 a.ndim, 2059 a.size, 2060 torch.from_numpy(a.T), 2061 torch.from_numpy(a.real), 2062 torch.from_numpy(a.imag), 2063 ) 2064 2065 @make_test 2066 def test_mean_sum_np(x: torch.Tensor): 2067 x_mean = np.mean(x.numpy(), 1) 2068 x_sum = np.sum(x_mean) 2069 x_sum_array = np.asarray(x_sum) 2070 return torch.from_numpy(x_sum_array) 2071 2072 @make_test 2073 def test_return_numpy_ndarray(x): 2074 a = x.numpy() 2075 return a.T 2076 2077 @make_test 2078 def test_return_multiple_numpy_ndarray(x): 2079 a = x.numpy() 2080 return a.T, a.imag, a.real 2081 2082 @make_test 2083 def test_ndarray_method(x): 2084 a = x.numpy() 2085 return a.copy() 2086 2087 @make_test 2088 def test_ndarray_transpose(x): 2089 a = x.numpy() 2090 return a.transpose(0, 1) 2091 2092 @make_test 2093 def test_ndarray_reshape(x): 2094 a = x.numpy() 2095 return a.reshape([1, a.size]) 2096 2097 @make_test 2098 def test_ndarray_methods_returning_scalar(x): 2099 a = x.numpy() 2100 return a.max(axis=0), a.all(axis=0) 2101 2102 @make_test 2103 def test_ndarray_builtin_functions(x): 2104 a = x.numpy() 2105 return a + a, a - a 2106 2107 @make_test 2108 def test_numpy_dtype_argument_to_function(x): 2109 return np.ones_like(x, dtype=np.float64) 2110 2111 @make_test 2112 def test_numpy_dtype_call_in_function(x): 2113 dt = np.dtype("float") 2114 return np.full_like(x, 2.4, dtype=dt) 2115 2116 @make_test 2117 def test_numpy_linalg(x): 2118 return np.linalg.norm(x.numpy(), axis=0) 2119 2120 @make_test 2121 def test_numpy_fft(x): 2122 return np.fft.fftshift(x.numpy()) 2123 2124 @make_test 2125 def test_numpy_random(): 2126 x = np.random.randn(2, 2) 2127 return x - x 2128 2129 @make_test 2130 def test_partials_torch_op_kwarg(x): 2131 par_mul = functools.partial(torch.mul, other=torch.ones(10, 10)) 2132 return par_mul(x) 2133 2134 @make_test 2135 def test_partials_torch_op_arg(x): 2136 par_mul = functools.partial(torch.mul, torch.ones(10, 10)) 2137 return par_mul(x) 2138 2139 @make_test 2140 def test_partials_udf_arg(x): 2141 par_mul = functools.partial(udf_mul, torch.ones(10, 10)) 2142 return par_mul(x) 2143 2144 @make_test 2145 def test_list_add_then_mutate(x): 2146 my_list = [1, x] 2147 y = x / 4.0 2148 my_list = my_list + [x / 2.0, 4] 2149 my_list.append(y) 2150 return sum(my_list) 2151 2152 @make_test 2153 def test_list_expand_lhs(x): 2154 return sum(4 * [x]) 2155 2156 @make_test 2157 def test_in_not_in(x): 2158 mylist = [1, 2, 3, 4, 5, x] 2159 myotherlist = [1, 2, 3, 4, 5] 2160 assert 3 in mylist 2161 assert 6 not in myotherlist 2162 return sum(mylist) 2163 2164 @make_test 2165 def test_are_functorch_transforms_active(x): 2166 if torch._C._are_functorch_transforms_active(): 2167 return x + 1 2168 else: 2169 return x - 1 2170 2171 @make_test 2172 def test_partials_udf_kwarg(x): 2173 par_mul = functools.partial(udf_mul, y=torch.ones(10, 10)) 2174 return par_mul(x) 2175 2176 @make_test 2177 def test_partials_udf_kwarg_module(x, y): 2178 par_mod = functools.partial(udf_module, mod=SmallNN()) 2179 return par_mod(x=x, y=y) 2180 2181 @make_test 2182 def test_partials_udf_kwarg_method(x, y): 2183 par_mod = functools.partial(udf_module, mod=SmallNN().forward) 2184 return par_mod(x=x, y=y) 2185 2186 @make_test 2187 def test_partials_lambda(x): 2188 multiply = lambda x, y: x * y 2189 triple = functools.partial(multiply, y=3) 2190 return triple(x) 2191 2192 @unittest.skipUnless(torch.distributed.is_available(), "requires torch.distributed") 2193 @make_test 2194 def test_flat_param_same_storage_size(x, y): 2195 import torch.distributed.fsdp._flat_param as flat_param 2196 2197 if flat_param._same_storage_size(x, 100): 2198 x = x + 1 2199 else: 2200 x = x - 1 2201 if flat_param._same_storage_size(y, 123): 2202 y = y + 1 2203 else: 2204 y = y - 1 2205 return x, y 2206 2207 @parametrize( 2208 "attr", 2209 ( 2210 # True 2211 "__subclasshook__", 2212 "__lt__", 2213 "__hash__", 2214 "__ge__", 2215 "__le__", 2216 "__gt__", 2217 "__dict__", 2218 "__getattribute__", 2219 "__setattr__", 2220 "__doc__", 2221 "__repr__", 2222 "__dir__", 2223 "__init__", 2224 "__new__", 2225 "__class__", 2226 "__eq__", 2227 "__delattr__", 2228 "__reduce__", 2229 "__module__", 2230 "__format__", 2231 "__str__", 2232 "__sizeof__", 2233 "__ne__", 2234 "__call__", 2235 "__reduce_ex__", 2236 "__init_subclass__", 2237 "args", 2238 "keywords", 2239 "func", 2240 # False 2241 "__code__", 2242 "__kwdefaults__", 2243 "__defaults__", 2244 "__name__", 2245 "__annotations__", 2246 "__get__", 2247 "__builtins__", 2248 "__qualname__", 2249 "__globals__", 2250 "__closure__", 2251 ), 2252 ) 2253 def test_partials_hasattr(self, attr): 2254 def fn(t): 2255 f = lambda x, y: torch.sin(x) + torch.cos(y) 2256 p = functools.partial(f, y=t) 2257 if hasattr(p, attr): 2258 return p(t) 2259 else: 2260 return torch.zeros_like(t) 2261 2262 t = torch.randn(3, 4) 2263 counter = torch._dynamo.testing.CompileCounter() 2264 opt_fn = torch.compile(fullgraph=True, backend=counter)(fn) 2265 self.assertEqual(opt_fn(t), fn(t)) 2266 self.assertGreater(counter.frame_count, 0) 2267 2268 @unittest.expectedFailure 2269 def test_partials_hasattr_set_attr(self): 2270 def fn(t): 2271 f = lambda x, y: torch.sin(x) + torch.cos(y) 2272 p = functools.partial(f, y=t) 2273 p.__name__ = "test" 2274 if hasattr(p, "__name__"): 2275 return p(t) 2276 else: 2277 return torch.zeros_like(t) 2278 2279 t = torch.randn(3, 4) 2280 counter = torch._dynamo.testing.CompileCounter() 2281 opt_fn = torch.compile(fullgraph=True, backend=counter)(fn) 2282 self.assertEqual(opt_fn(t), fn(t)) 2283 2284 def test_filter(self): 2285 def fn(inputs): 2286 out = inputs[0] 2287 for inp in filter(lambda x: (x.requires_grad), inputs): 2288 out = out * inp 2289 return out 2290 2291 input1 = torch.arange(2, dtype=torch.bfloat16) 2292 input2 = torch.arange(2, dtype=torch.bfloat16).requires_grad_(True) 2293 inputs = [input1, input2] 2294 2295 opt_fn = torch.compile(fullgraph=True)(fn) 2296 self.assertEqual(opt_fn(inputs), fn(inputs)) 2297 2298 def test_filter_fallback(self): 2299 def fn(inputs): 2300 out = inputs[0] 2301 for inp in filter(lambda x: x[0] == 1, inputs): 2302 out = out * inp 2303 return out 2304 2305 input1 = torch.ones(2, dtype=torch.bfloat16) 2306 input2 = torch.arange(2, dtype=torch.bfloat16) 2307 inputs = [input1, input2] 2308 2309 opt_fn = torch.compile()(fn) 2310 self.assertEqual(opt_fn(inputs), fn(inputs)) 2311 2312 torch._dynamo.reset() 2313 2314 with self.assertRaises(torch._dynamo.exc.Unsupported): 2315 opt_fn = torch.compile(fullgraph=True)(fn) 2316 opt_fn(inputs) 2317 2318 def test_pow_int(self): 2319 def fn(a, b): 2320 return torch.pow(a, b) 2321 2322 x = torch.ones(2, 2) 2323 opt_fn = torch.compile(fullgraph=True, backend="eager", dynamic=True)(fn) 2324 self.assertEqual(opt_fn(x, 2), fn(x, 2)) 2325 2326 def test_tensor_size_indexed_by_symint(self): 2327 def fn(x, y): 2328 index = x.shape[-1] 2329 return x + y.shape[index] 2330 2331 x = torch.rand(10, 2) 2332 y = torch.rand(10, 8, 6) 2333 opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) 2334 self.assertEqual(opt_fn(x, y), fn(x, y)) 2335 2336 def test_partials_as_input_partials_lambda(self): 2337 def fn(f0, f1, x): 2338 return f0(x) * f1(x) 2339 2340 multiply = lambda x, y: x * y 2341 lambda0 = functools.partial(multiply, y=3) 2342 lambda1 = functools.partial(multiply, y=2) 2343 2344 cnts = torch._dynamo.testing.CompileCounter() 2345 torch._dynamo.optimize(cnts, nopython=True)(fn)( 2346 lambda0, lambda1, torch.randn(2, 2) 2347 ) 2348 self.assertEqual(cnts.frame_count, 1) 2349 2350 def test_partials_as_input_partials_mod(self): 2351 def fn(f0, f1, x): 2352 return f0(x) * f1(x) 2353 2354 lambda0 = functools.partial(SmallNN(), y=torch.randn(2, 2)) 2355 lambda1 = functools.partial(SmallNN(), y=torch.randn(2, 2)) 2356 2357 cnts = torch._dynamo.testing.CompileCounter() 2358 x = torch.randn(2, 2) 2359 dynamo_result = torch._dynamo.optimize(cnts, nopython=True)(fn)( 2360 lambda0, lambda1, x 2361 ) 2362 self.assertEqual(cnts.frame_count, 1) 2363 2364 eager_result = fn(lambda0, lambda1, x) 2365 self.assertEqual(eager_result, dynamo_result) 2366 2367 def test_partials_as_input_UDF(self): 2368 def fn(f0, f1, x): 2369 return f0(x) * f1(x) 2370 2371 lambda0 = functools.partial(udf_mul, y=torch.randn(2, 2)) 2372 lambda1 = functools.partial(udf_mul, y=torch.randn(2, 2)) 2373 2374 cnts = torch._dynamo.testing.CompileCounter() 2375 x = torch.randn(2, 2) 2376 dynamo_result = torch._dynamo.optimize(cnts, nopython=True)(fn)( 2377 lambda0, lambda1, x 2378 ) 2379 self.assertEqual(cnts.frame_count, 1) 2380 2381 eager_result = fn(lambda0, lambda1, x) 2382 self.assertEqual(eager_result, dynamo_result) 2383 2384 def test_partials_graph_break_reconstruct(self): 2385 def fn(udf_mul_0, udf_mul_1, x): 2386 lambda0 = functools.partial(udf_mul_0, y=x) 2387 lambda1 = functools.partial(udf_mul_1, y=x) 2388 2389 print("break") 2390 return torch.mul(lambda0(x), lambda1(x)) 2391 2392 backend = EagerAndRecordGraphs() 2393 cnts = CompileCounterWithBackend(backend) 2394 x = torch.randn(2, 2) 2395 dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, udf_mul, x) 2396 2397 eager_result = fn(udf_mul, udf_mul, x) 2398 gm = backend.graphs[0] 2399 self.assertEqual(eager_result, dynamo_result) 2400 if torch._dynamo.config.assume_static_by_default: 2401 self.assertExpectedInline( 2402 normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2403 """\ 2404class GraphModule(torch.nn.Module): 2405 def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"): 2406 l_lambda0_keywords_y_ = L_lambda0_keywords_y_ 2407 2408 mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ 2409 mul_1: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None 2410 2411 mul_2: "f32[2, 2]" = torch.mul(mul, mul_1); mul = mul_1 = None 2412 return (mul_2,) 2413""", 2414 ) 2415 else: 2416 self.assertExpectedInline( 2417 normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2418 """\ 2419class GraphModule(torch.nn.Module): 2420 def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"): 2421 l_lambda0_keywords_y_ = L_lambda0_keywords_y_ 2422 2423 mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ 2424 mul_1: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None 2425 2426 mul_2: "f32[s0, s0]" = torch.mul(mul, mul_1); mul = mul_1 = None 2427 return (mul_2,) 2428""", 2429 ) 2430 2431 def test_partials_graph_break_reconstruct_mix(self): 2432 def fn(udf_mul_0, udf_add_1, x): 2433 lambda0 = functools.partial(udf_mul_0, y=x) 2434 lambda1 = functools.partial(udf_add_1, x) 2435 2436 print("break") 2437 return torch.mul(lambda0(x), lambda1(x)) 2438 2439 backend = EagerAndRecordGraphs() 2440 cnts = CompileCounterWithBackend(backend) 2441 x = torch.randn(2, 2) 2442 dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, udf_add, x) 2443 2444 eager_result = fn(udf_mul, udf_add, x) 2445 gm = backend.graphs[0] 2446 self.assertEqual(eager_result, dynamo_result) 2447 if torch._dynamo.config.assume_static_by_default: 2448 self.assertExpectedInline( 2449 normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2450 """\ 2451class GraphModule(torch.nn.Module): 2452 def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"): 2453 l_lambda0_keywords_y_ = L_lambda0_keywords_y_ 2454 2455 mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ 2456 2457 add: "f32[2, 2]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None 2458 2459 mul_1: "f32[2, 2]" = torch.mul(mul, add); mul = add = None 2460 return (mul_1,) 2461""", 2462 ) 2463 else: 2464 self.assertExpectedInline( 2465 normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2466 """\ 2467class GraphModule(torch.nn.Module): 2468 def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"): 2469 l_lambda0_keywords_y_ = L_lambda0_keywords_y_ 2470 2471 mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ 2472 2473 add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None 2474 2475 mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None 2476 return (mul_1,) 2477""", 2478 ) 2479 2480 def test_partials_graph_break_reconstruct_mix_no_source(self): 2481 def fn(udf_mul_0, x): 2482 udf_add_1 = lambda x, y: x + y 2483 2484 lambda0 = functools.partial(udf_mul_0, y=x) 2485 lambda1 = functools.partial(udf_add_1, x) 2486 2487 print("break") 2488 return torch.mul(lambda0(x), lambda1(x)) 2489 2490 backend = EagerAndRecordGraphs() 2491 cnts = CompileCounterWithBackend(backend) 2492 x = torch.randn(2, 2) 2493 dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, x) 2494 2495 eager_result = fn(udf_mul, x) 2496 gm = backend.graphs[0] 2497 self.assertEqual(eager_result, dynamo_result) 2498 if torch._dynamo.config.assume_static_by_default: 2499 self.assertExpectedInline( 2500 normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2501 """\ 2502class GraphModule(torch.nn.Module): 2503 def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"): 2504 l_lambda0_keywords_y_ = L_lambda0_keywords_y_ 2505 2506 mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ 2507 2508 add: "f32[2, 2]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None 2509 2510 mul_1: "f32[2, 2]" = torch.mul(mul, add); mul = add = None 2511 return (mul_1,) 2512""", 2513 ) 2514 else: 2515 self.assertExpectedInline( 2516 normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2517 """\ 2518class GraphModule(torch.nn.Module): 2519 def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"): 2520 l_lambda0_keywords_y_ = L_lambda0_keywords_y_ 2521 2522 mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ 2523 2524 add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None 2525 2526 mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None 2527 return (mul_1,) 2528""", 2529 ) 2530 2531 def test_partials_graph_break_reconstruct_args_and_kwargs(self): 2532 def fn(udf_mul_0, x): 2533 lambda0 = functools.partial(udf_mul_0, x, 4, z=x) 2534 lambda1 = functools.partial(udf_mul_0, 4, z=x) 2535 2536 return torch.mul(lambda0(), lambda1(5)) 2537 2538 backend = EagerAndRecordGraphs() 2539 cnts = CompileCounterWithBackend(backend) 2540 x = torch.randn(2, 2) 2541 dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul2, x) 2542 2543 eager_result = fn(udf_mul2, x) 2544 gm = backend.graphs[0] 2545 self.assertEqual(eager_result, dynamo_result) 2546 if torch._dynamo.config.assume_static_by_default: 2547 self.assertExpectedInline( 2548 normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2549 """\ 2550class GraphModule(torch.nn.Module): 2551 def forward(self, L_x_: "f32[2, 2]"): 2552 l_x_ = L_x_ 2553 2554 mul: "f32[2, 2]" = l_x_ * 4 2555 mul_1: "f32[2, 2]" = mul * l_x_; mul = None 2556 mul_2: "f32[2, 2]" = 20 * l_x_; l_x_ = None 2557 2558 mul_3: "f32[2, 2]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None 2559 return (mul_3,) 2560""", 2561 ) 2562 else: 2563 self.assertExpectedInline( 2564 normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2565 """\ 2566class GraphModule(torch.nn.Module): 2567 def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, s0]"): 2568 l_x_ = L_x_ 2569 2570 mul: "f32[s0, s0]" = l_x_ * 4 2571 mul_1: "f32[s0, s0]" = mul * l_x_; mul = None 2572 mul_2: "f32[s0, s0]" = 20 * l_x_; l_x_ = None 2573 2574 mul_3: "f32[s0, s0]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None 2575 return (mul_3,) 2576""", 2577 ) 2578 2579 def test_partials_recompilation(self): 2580 def fn(f0, f1, x): 2581 return f0(x) * f1(x) 2582 2583 lambda0 = functools.partial(udf_mul, y=torch.randn(2, 2)) 2584 lambda1 = functools.partial(udf_mul, y=torch.randn(2, 2)) 2585 2586 cnts = torch._dynamo.testing.CompileCounter() 2587 2588 x = torch.randn(2, 2) 2589 fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 2590 dynamo_result = fn(lambda0, lambda1, x) 2591 self.assertEqual(cnts.frame_count, 1) 2592 2593 fn(lambda1, lambda0, x) 2594 self.assertEqual( 2595 cnts.frame_count, 1 2596 ) # No recompile! Tensor and udf_mul guarded 2597 2598 lambda2 = functools.partial(udf_mul, y=torch.randn(3, 3)) 2599 x = torch.randn(3, 3) 2600 fn(lambda2, lambda2, x) 2601 self.assertEqual(cnts.frame_count, 2) # Recompile! Tensor size changed 2602 2603 multiply = lambda x, y: x * y 2604 lambda3 = functools.partial(multiply, y=torch.randn(3, 3)) 2605 x = torch.randn(3, 3) 2606 fn(lambda3, lambda3, x) 2607 2608 self.assertEqual(cnts.frame_count, 3) # Recompile! func id changed 2609 2610 def fn2(f0, f1, args): 2611 return f0(*args) * f1(*args) 2612 2613 cnts = torch._dynamo.testing.CompileCounter() 2614 2615 x = torch.randn(2, 2) 2616 fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2) 2617 dynamo_result = fn2(lambda0, lambda1, [x]) 2618 self.assertEqual(cnts.frame_count, 1) # start over 2619 2620 lambda4 = functools.partial(multiply, y=3, x=torch.randn(3, 3)) 2621 fn2(lambda4, lambda4, []) 2622 2623 self.assertEqual(cnts.frame_count, 2) # Recompile! Different kwarg keys 2624 2625 lambda5 = functools.partial(multiply, 1) 2626 x = torch.randn(3, 3) 2627 fn2(lambda5, lambda5, [x]) 2628 2629 self.assertEqual(cnts.frame_count, 3) # Recompile! Different arg keys 2630 2631 lambda6 = lambda x: x + x 2632 fn2(lambda6, lambda6, [x]) 2633 self.assertEqual( 2634 cnts.frame_count, 4 2635 ) # Recompile! input is no longer a functools partial 2636 2637 def test_manual_seed(self): 2638 @torch.compile 2639 def foo(): 2640 torch.manual_seed(3) 2641 return torch.randint(0, 5, (5,)) 2642 2643 self.assertEqual(foo(), foo()) 2644 self.assertEqual(foo(), foo()) 2645 2646 def test_partial_across_graph_break_uninvoked(self): 2647 from functools import partial 2648 2649 def bar(x, **kwargs): 2650 return x + x 2651 2652 @torch.compile(backend="eager", dynamic=True) 2653 def foo(x, i): 2654 def inner(): 2655 print("this is a graph_break") 2656 return op(x) 2657 2658 op = partial(bar, dim=10) 2659 x = inner() 2660 op = partial(bar, other=10) 2661 return inner() + x 2662 2663 foo(torch.rand(1), 10) 2664 2665 def test_no_recompile_inner_function(self): 2666 def forward(inp): 2667 def g(y): 2668 return inp + y 2669 2670 print("graph break") 2671 return g(torch.rand([1])) 2672 2673 cnts = torch._dynamo.testing.CompileCounter() 2674 opt_fn = torch._dynamo.optimize(cnts)(forward) 2675 2676 input = torch.rand([2]) 2677 _ = opt_fn(input) 2678 _ = opt_fn(input) 2679 _ = opt_fn(input) 2680 # Should not have recompiled 2681 self.assertEqual(cnts.frame_count, 1) 2682 2683 def test_no_recompile_inner_lambda(self): 2684 def forward(inp): 2685 g = lambda y: inp + y 2686 print("graph break") 2687 return g(torch.rand([1])) 2688 2689 cnts = torch._dynamo.testing.CompileCounter() 2690 opt_fn = torch._dynamo.optimize(cnts)(forward) 2691 2692 input = torch.rand([2]) 2693 _ = opt_fn(input) 2694 _ = opt_fn(input) 2695 _ = opt_fn(input) 2696 # Should not have recompiled 2697 self.assertEqual(cnts.frame_count, 1) 2698 2699 def test_complex_closure(self): 2700 @torch.compile 2701 def forward(y): 2702 def a(): 2703 def x(z): 2704 return y + z 2705 2706 return x 2707 2708 return a() 2709 2710 input1 = torch.rand([2]) 2711 input2 = torch.rand([2]) 2712 res = forward(input1)(input2) 2713 self.assertTrue(same(res, input1 + input2)) 2714 2715 def test_non_inlined_closure(self): 2716 @torch.compile() 2717 def program(x, y): 2718 one = lambda x, y: x + y 2719 2720 def inner(): 2721 # Force no inlining 2722 torch._dynamo.graph_break() 2723 return one(x, y) 2724 2725 res = inner() 2726 one = lambda x, y: x - y 2727 res += inner() 2728 return res 2729 2730 input1 = torch.randn(1) 2731 input2 = torch.randn(1) 2732 2733 self.assertTrue(same(program(input1, input2), input1 + input1)) 2734 2735 @parametrize("int_or_float", ("int", "float")) 2736 def test_np_constant_collections_as_input(self, int_or_float): 2737 info_func = getattr(np, f"{int_or_float[0]}info") 2738 dt_string_arg = f"{int_or_float}16" 2739 np_dt_attr = getattr(np, dt_string_arg) 2740 2741 dt_args = [dt_string_arg, np_dt_attr] 2742 arg_variants_iter = itertools.chain( 2743 dt_args, map(np.dtype, dt_args), map(info_func, dt_args) 2744 ) 2745 2746 def func(a, b, info_or_dt): 2747 return a + info_func(info_or_dt).max 2748 2749 opt_fn = torch.compile(func) 2750 2751 a = torch.randn(2) 2752 b = torch.randn(2) 2753 eager_result = func(a, b, dt_args[0]) 2754 2755 for arg in arg_variants_iter: 2756 opt_result = opt_fn(a, b, arg) 2757 self.assertTrue(same(opt_result, eager_result)) 2758 2759 @parametrize( 2760 "typ, info_func", 2761 [ 2762 (int, np.iinfo), 2763 (float, np.finfo), 2764 ], 2765 name_fn=lambda t, _: t.__name__, 2766 ) 2767 def test_np_constant_collections_guards(self, typ, info_func): 2768 def func_info(a, info): 2769 return a + info.max 2770 2771 def func_dtype(a, dt): 2772 return a + info_func(dt).max 2773 2774 dt_args = [ 2775 np.dtype(typ), 2776 np.ones((1,), dtype=typ).dtype, 2777 np.dtype(np.dtype(typ).name), 2778 np.dtype(typ.__name__), 2779 ] 2780 cnts_1 = torch._dynamo.testing.CompileCounter() 2781 opt_fn_dtype = torch._dynamo.optimize(cnts_1)(func_dtype) 2782 a = torch.zeros(3, dtype=typ) 2783 for arg in dt_args: 2784 r = opt_fn_dtype(a, arg) 2785 # each should produce an identical arg 2786 self.assertEqual(cnts_1.frame_count, 1) 2787 2788 cnts_2 = torch._dynamo.testing.CompileCounter() 2789 opt_fn_info = torch._dynamo.optimize(cnts_2)(func_info) 2790 info_args = [info_func(dt) for dt in dt_args] 2791 for arg in info_args: 2792 r = opt_fn_info(a, arg) 2793 2794 # each should produce an identical arg 2795 self.assertEqual(cnts_2.frame_count, 1) 2796 2797 if typ is float: 2798 dt_extra = np.dtype(np.float16) 2799 else: 2800 dt_extra = np.dtype(np.int16) 2801 info_extra = info_func(dt_extra) 2802 2803 eager_result_dtype = func_dtype(a, dt_extra) 2804 compile_result_dtype = opt_fn_dtype(a, dt_extra) 2805 self.assertEqual(cnts_1.frame_count, 2) 2806 self.assertEqual(eager_result_dtype, compile_result_dtype) 2807 2808 eager_result_info = func_info(a, info_extra) 2809 compile_result_info = opt_fn_info(a, info_extra) 2810 self.assertEqual(cnts_2.frame_count, 2) 2811 self.assertEqual(eager_result_info, compile_result_info) 2812 2813 def test_compare_constant_and_tensor(self): 2814 for op in [ 2815 operator.lt, 2816 operator.le, 2817 operator.gt, 2818 operator.ge, 2819 operator.ne, 2820 operator.eq, 2821 operator.is_, 2822 operator.is_not, 2823 ]: 2824 with self.subTest(op=op): 2825 2826 def fn(x): 2827 return op(-10, x) 2828 2829 opt_fn = torch.compile(fullgraph=True)(fn) 2830 2831 x = torch.randn(10) 2832 self.assertEqual(opt_fn(x), fn(x)) 2833 2834 def test_pos(self): 2835 def fn(x, y): 2836 return operator.pos(x) * +y 2837 2838 opt_fn = torch.compile(fullgraph=True, dynamic=True)(fn) 2839 2840 def test(x, y): 2841 self.assertEqual(opt_fn(x, y), fn(x, y)) 2842 2843 test(torch.ones(4), 1) 2844 test(1, torch.ones(4)) 2845 test(-1, -1) 2846 test(-1.1, 1.1) 2847 test(True, False) 2848 test(torch.ones(4, dtype=torch.float32), 1.1) 2849 2850 def test_index(self): 2851 def fn(x, t): 2852 v = operator.index(x) 2853 torch.mul(t, v) 2854 2855 def test(a, b): 2856 self.assertEqual(opt_fn(a, b), fn(a, b)) 2857 2858 for dynamic in [True, False]: 2859 torch._dynamo.reset() 2860 opt_fn = torch._dynamo.optimize(dynamic=dynamic)(fn) 2861 t = torch.ones(1) 2862 test(10, t) 2863 test(-100, t) 2864 test(10, t) 2865 test(False, t) 2866 test(True, t) 2867 2868 def test_truth(self): 2869 def fn(x, y): 2870 return operator.truth(x) and bool(y) 2871 2872 opt_fn = torch.compile(fullgraph=True, dynamic=False)(fn) 2873 2874 def test(x, y): 2875 self.assertEqual(opt_fn(x, y), fn(x, y)) 2876 2877 test(1, 100) 2878 test(-1.1, True) 2879 test(-1.1, 1.1) 2880 test(True, False) 2881 test(torch.ones(1), 1) 2882 test(torch.zeros(1), 1) 2883 test(torch.ones(1), torch.ones(1)) 2884 2885 def test_unary_fold_op(self): 2886 for op in (operator.abs, abs, operator.neg, operator.pos, operator.truth): 2887 with self.subTest(op=op): 2888 2889 def fn(): 2890 a = range(-10, 10) 2891 return list(map(op, a)) 2892 2893 opt_fn = torch._dynamo.optimize(nopython=True)(fn) 2894 self.assertEqual(opt_fn(), fn()) 2895 2896 def test_unary_fold_op_seq(self): 2897 for op in (operator.length_hint,): 2898 with self.subTest(op=op): 2899 2900 def fn(): 2901 a = [tuple(range(-10, i)) for i in range(10)] 2902 return tuple(map(op, a)) 2903 2904 opt_fn = torch._dynamo.optimize(nopython=True)(fn) 2905 self.assertEqual(opt_fn(), fn()) 2906 2907 def gen_random_range_args(self): 2908 args_count = random.randint(1, 3) 2909 args = [random.randint(-10, 10) for _ in range(args_count)] 2910 if args_count == 3 and args[2] == 0: 2911 args[2] = 1 2912 return args 2913 2914 def test_range_length(self): 2915 def test(*args, expected=None): 2916 r = range(*args) 2917 range_variable = RangeVariable([ConstantVariable.create(v) for v in args]) 2918 2919 self.assertEqual(len(r), range_variable.range_length()) 2920 2921 if expected is not None: 2922 self.assertEqual(len(r), expected) 2923 2924 test(1, 1, 1, expected=0) 2925 test(1, 0, expected=0) 2926 test(-10, expected=0) 2927 2928 test(4, expected=4) 2929 test(10, expected=10) 2930 2931 # step >1 2932 test(1, 10, 2, expected=5) 2933 2934 # negative step 2935 test(10, 1, -1, expected=9) 2936 test(10, 1, -3) 2937 2938 # Fuzz testing 2939 for i in range(100): 2940 args = self.gen_random_range_args() 2941 print("testing :", args) 2942 test(*args) 2943 2944 def test_indexed_range(self): 2945 def test(range, index, expected=None): 2946 range_variable = RangeVariable( 2947 [ 2948 ConstantVariable.create(v) 2949 for v in [range.start, range.stop, range.step] 2950 ] 2951 ) 2952 2953 self.assertEqual( 2954 range[index], 2955 range_variable.apply_index(index).as_python_constant(), 2956 ) 2957 2958 if expected is not None: 2959 self.assertEqual(range[index], expected) 2960 2961 test(range(10), 1, expected=1) 2962 test(range(10, 20, 2), 1, expected=12) 2963 2964 # Fuzz testing 2965 for i in range(100): 2966 range_args = self.gen_random_range_args() 2967 r = range(*range_args) 2968 2969 if len(r) == 0: 2970 continue 2971 2972 index = random.randint(0, len(r) - 1) 2973 2974 print("testing:", r, index) 2975 test(r, index) 2976 2977 def test_sliced_range(self): 2978 def test(range, slice, expected=None): 2979 range_variable = RangeVariable( 2980 [ 2981 ConstantVariable.create(v) 2982 for v in [range.start, range.stop, range.step] 2983 ] 2984 ) 2985 2986 self.assertEqual( 2987 range[slice], 2988 range_variable.apply_slice(slice).as_python_constant(), 2989 ) 2990 2991 if expected is not None: 2992 self.assertEqual( 2993 range[slice], 2994 expected, 2995 ) 2996 2997 test(range(10), slice(1, 10, 2), expected=range(1, 10, 2)) 2998 test(range(10), slice(None, 10, None), expected=range(0, 10)) 2999 test(range(10), slice(-1, 7, None), expected=range(9, 7)) 3000 test(range(10), slice(-1, 7, 2), expected=range(9, 7, 2)) 3001 test(range(1, 10, 2), slice(3, 7, 2), expected=range(7, 11, 4)) 3002 test(range(1, 10, 2), slice(-3, 7, 2), expected=range(5, 11, 4)) 3003 test(range(-1, -5, -3), slice(5, None, -3), expected=range(-4, 2, 9)) 3004 3005 def rand_slice(): 3006 def flip_coin(): 3007 # 1 out of 10 3008 return random.randint(1, 10) == 5 3009 3010 def r_item(allow_zero=True): 3011 i = random.randint(-10, 10) 3012 if not allow_zero and i == 0: 3013 i = 1 3014 if flip_coin(): 3015 i = None 3016 return i 3017 3018 arg_count = random.randint(1, 3) 3019 3020 if arg_count == 1: 3021 return slice(r_item()) 3022 elif arg_count == 2: 3023 return slice(r_item(), r_item()) 3024 else: 3025 return slice(r_item(), r_item(), r_item(False)) 3026 3027 # Fuzz testing 3028 for i in range(100): 3029 range_args = self.gen_random_range_args() 3030 r = range(*range_args) 3031 # generate random slice 3032 s = rand_slice() 3033 3034 print("testing:", r, s) 3035 test(r, s) 3036 3037 def test_range_with_slice_index(self): 3038 def fn(x): 3039 acc = 1 3040 for k in range(2)[1::2]: 3041 acc *= acc * k 3042 return x * acc 3043 3044 opt_fn = torch.compile(fullgraph=True)(fn) 3045 x = torch.ones(1) 3046 self.assertEqual(opt_fn(x), fn(x)) 3047 3048 def test_range_with_index(self): 3049 def fn(x): 3050 acc = 1 3051 acc *= acc * range(10, 20, 2)[2] 3052 return x * acc 3053 3054 opt_fn = torch.compile(fullgraph=True)(fn) 3055 x = torch.ones(1) 3056 self.assertEqual(opt_fn(x), fn(x)) 3057 3058 def test_rand_inlined(self): 3059 @torch.compile(backend="eager", dynamic=True) 3060 def fn(): 3061 idx_size = [10] 3062 idx_size[random.randint(0, 0)] = random.randint(1, 8) 3063 t = tuple(idx_size) 3064 src_size = [random.randint(1, 5) + s for s in idx_size] 3065 idx = torch.empty(t) 3066 3067 fn() 3068 3069 def test_rand_tensor_partial(self): 3070 from collections import namedtuple 3071 from functools import partial 3072 3073 SdpaShape = namedtuple( 3074 "Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"] 3075 ) 3076 3077 @torch.compile(backend="eager") 3078 def func(): 3079 make_tensor = partial( 3080 torch.rand, device="cpu", dtype=torch.float16, requires_grad=True 3081 ) 3082 3083 bsz, num_heads, seq_len_q, seq_len_kv, head_dim = (16, 16, 128, 128, 16) 3084 make_q_tensor = partial( 3085 make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim) 3086 ) 3087 make_kv_tensor = partial( 3088 make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim) 3089 ) 3090 t1 = make_q_tensor() 3091 t2 = make_kv_tensor() 3092 t3 = t1 + t2 3093 3094 func() 3095 3096 def test_to(self): 3097 @torch.compile(backend="eager") 3098 def fn(): 3099 t = torch.ones(2) 3100 y = t.to("meta") 3101 3102 fn() 3103 3104 def test_elipsis(self): 3105 @torch.compile(backend="eager", fullgraph=True) 3106 def fn(a, ind, val): 3107 a[ind] = val 3108 return a 3109 3110 arr = np.zeros(4) 3111 self.assertEqual(fn(arr, np.s_[...], np.ones(4)), np.ones(4)) 3112 3113 arr = np.array([[1, 1], [2, 2]]) 3114 self.assertEqual( 3115 fn(arr, np.s_[0, ...], np.zeros(2)), np.array([[0, 0], [2, 2]]) 3116 ) 3117 3118 arr = np.array([[1, 1], [2, 2]]) 3119 self.assertEqual( 3120 fn(arr, np.s_[1, ...], np.zeros(2)), np.array([[1, 1], [0, 0]]) 3121 ) 3122 3123 arr = np.array([[1, 1], [2, 2]]) 3124 self.assertEqual( 3125 fn(arr, np.s_[..., 0], np.array([3, 3])), np.array([[3, 1], [3, 2]]) 3126 ) 3127 3128 arr = np.array([[1, 1], [2, 2]]) 3129 self.assertEqual( 3130 fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]]) 3131 ) 3132 3133 def test_map_return(self): 3134 def fn(a, b): 3135 return map(lambda x: x + 1, [a, b]) 3136 3137 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3138 m = opt_fn(torch.randn(3, 3), torch.randn(3, 3)) 3139 self.assertIsInstance(m, map) 3140 3141 @make_test 3142 def test_map_max(a, b): 3143 return max(map(lambda x: x.sum(), [a, b])) 3144 3145 # max(map(...)) graph breaks 3146 @unittest.expectedFailure 3147 @make_test 3148 def test_map_max_const(a): 3149 return max(map(lambda x: x, [1, 2, 3])), a + 1 3150 3151 @make_test 3152 def test_map_list(a, b): 3153 return list(map(lambda x: x + 1, [a, b])) 3154 3155 @make_test 3156 def test_map_tuple(a, b): 3157 return tuple(map(lambda x: x + 1, [a, b])) 3158 3159 @make_test 3160 def test_map_iter(a, b): 3161 it = iter(map(lambda x: x + 1, [a, b])) 3162 return next(it) 3163 3164 @make_test 3165 def test_map_zip_dict(a): 3166 d = dict( 3167 zip( 3168 map(lambda x: x + 1, [0, 1, 2]), 3169 [map(lambda x: x - 1, [y]) for y in [3, 4, 5]], 3170 ) 3171 ) 3172 return list(d[3])[0], a + 1 # noqa: RUF015 3173 3174 @make_test 3175 def test_map_dict_fromkeys(a): 3176 return dict.fromkeys(map(lambda x: x + 1, [0, 1])), a + 1 3177 3178 @make_test 3179 def test_map_set(a): 3180 return set(map(lambda x: x + 1, [0, 1])), a + 1 3181 3182 # test_map_sum defined earlier 3183 3184 @make_test 3185 def test_map_reduce(a, b): 3186 return functools.reduce(lambda x, y: x + y, map(lambda x: x + 1, [a, b])) 3187 3188 @make_test 3189 def test_map_sorted(a): 3190 return sorted(map(lambda x: x + 1, [0, 4, 3, 1, 2])), a + 1 3191 3192 @make_test 3193 def test_map_list_extend(a, b, c): 3194 l = [a] 3195 l.extend(map(lambda x: x + 1, [b, c])) 3196 return l 3197 3198 @make_test 3199 def test_map_list_slice_assign(a, b, c, d, e): 3200 l = [a, b, c] 3201 l[1:2] = map(lambda x: x + 1, [d, e]) 3202 return l 3203 3204 @make_test 3205 def test_map_deque_extendleft(a, b, c): 3206 d = collections.deque([a]) 3207 d.extendleft(map(lambda x: x + 1, [b, c])) 3208 return d 3209 3210 @make_test 3211 def test_map_str_join(a): 3212 return "".join(map(lambda x: x, ["a", "b", "c"])), a + 1 3213 3214 def test_map_with_graph_break(self): 3215 def f(a): 3216 a += 1 3217 3218 def g(x): 3219 nonlocal a 3220 a += 1 3221 return x + 1 3222 3223 m = map(g, [1, 2, 3, 4, 5]) 3224 a += next(m) # won't graph break 3225 torch._dynamo.graph_break() 3226 a += next(m) # will graph break 3227 return a 3228 3229 cnts = torch._dynamo.testing.CompileCounter() 3230 opt_f = torch.compile(f, backend=cnts) 3231 self.assertEqual(f(torch.ones(3, 3)), opt_f(torch.ones(3, 3))) 3232 self.assertEqual(cnts.frame_count, 3) 3233 3234 def test_map_reconstruct(self): 3235 def fn(a): 3236 return map(lambda x: x[0] + x[1], zip([1, 2, 3], [1, 2, 3])), a + 1 3237 3238 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3239 m = opt_fn(torch.ones(3, 3))[0] 3240 self.assertIsInstance(m, map) 3241 self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0])) 3242 3243 def test_zip_reconstruct(self): 3244 def fn(a): 3245 return zip([1, 2, 3], map(lambda x: x + 1, [1, 2, 3])), a + 1 3246 3247 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3248 m = opt_fn(torch.ones(3, 3))[0] 3249 self.assertIsInstance(m, zip) 3250 self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0])) 3251 3252 @make_test 3253 def test_map_partial_unpack(a, b): 3254 y = 1 3255 3256 def f(x): 3257 nonlocal y 3258 y += 1 3259 return x 3260 3261 l = list(zip([a, b], map(f, [1, 2, 3, 4]))) 3262 return a + y 3263 3264 @make_test 3265 def test_map_call_function_ex(a, b): 3266 def f(x, y): 3267 return x + y 3268 3269 return f(*map(lambda x: x + 1, [a, b])) 3270 3271 @make_test 3272 def test_map_unpack_twice(a, b): 3273 m = map(lambda x: x + 1, [a, b]) 3274 l1 = list(m) 3275 l2 = list(m) 3276 return l1, l2 3277 3278 @make_test 3279 def test_enumerate(a, b): 3280 return list(enumerate([a, b], start=1)), a + 1 3281 3282 @make_test 3283 def test_map_enumerate(a, b): 3284 return list(enumerate(map(lambda x: x + 1, [a, b]), start=1)), a + 1 3285 3286 @make_test 3287 def test_map_infinite(a, b): 3288 return list(map(lambda x, y: x + y, [a, b], itertools.count(3))) 3289 3290 @make_test 3291 def test_map_unpack_vars(a, b): 3292 x, y = map(lambda x: x + 1, [a, b]) 3293 return x + y 3294 3295 def test_enumerate_custom(self): 3296 class MyClass: 3297 def __iter__(self): 3298 self.a = 1 3299 return self 3300 3301 def __next__(self): 3302 if self.a > 3: 3303 raise StopIteration 3304 self.a += 1 3305 return self.a 3306 3307 def fn(x): 3308 for i, it in enumerate(MyClass()): 3309 x += i + it 3310 return x 3311 3312 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3313 self.assertEqual(fn(torch.ones(3, 3)), opt_fn(torch.ones(3, 3))) 3314 3315 def test_enumerate_reconstruct(self): 3316 def fn(a, b): 3317 return enumerate([a, b], start=1) 3318 3319 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3320 inps = (torch.randn(3, 3), torch.randn(3, 3)) 3321 it1 = fn(*inps) 3322 it2 = opt_fn(*inps) 3323 self.assertIsInstance(it2, enumerate) 3324 self.assertEqual(list(it1), list(it2)) 3325 3326 3327def udf_mul(x, y): 3328 return x * y 3329 3330 3331def udf_mul2(x, y, z): 3332 return x * y * z 3333 3334 3335def udf_add(x, y): 3336 return x + y 3337 3338 3339class SmallNN(torch.nn.Module): 3340 def forward(self, x, y): 3341 combined = torch.cat((x, y), dim=1) 3342 out = torch.nn.ReLU()(combined) 3343 out = torch.nn.ReLU()(out) 3344 return out 3345 3346 3347def udf_module(mod, x, y): 3348 return mod(x, y) 3349 3350 3351def global_func_with_default_tensor_args( 3352 x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2)) 3353): 3354 x.add_(1) 3355 kw_x.add_(1) 3356 return x, kw_x 3357 3358 3359class ModuleWithDefaultTensorArgsMethod(torch.nn.Module): 3360 def forward(self, x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))): 3361 x.add_(1) 3362 kw_x.add_(1) 3363 return x, kw_x 3364 3365 3366class WrapperModule(torch.nn.Module): 3367 def __init__(self) -> None: 3368 super().__init__() 3369 self.m = ModuleWithDefaultTensorArgsMethod() 3370 3371 def forward(self): 3372 return self.m() 3373 3374 3375class DefaultsTests(torch._dynamo.test_case.TestCase): 3376 def test_func_default_tensor_args(self): 3377 """ 3378 Tests that we indeed reference (and mutate) "the one" default tensor arg 3379 stored on the globally allocated function object, both from the orig and 3380 compiled function 3381 """ 3382 3383 def func(): 3384 return global_func_with_default_tensor_args() 3385 3386 cnts = torch._dynamo.testing.CompileCounter() 3387 compiled_func = torch.compile(func, backend=cnts) 3388 for i in range(4): 3389 if i % 2 == 0: 3390 x, kw_x = func() 3391 else: 3392 x, kw_x = compiled_func() 3393 # the inner func mutates += 1 each call 3394 self.assertTrue(same(x, torch.ones_like(x) + i)) 3395 self.assertTrue(same(kw_x, torch.ones_like(kw_x) + i)) 3396 # Calling compiled_func twice does not recompile 3397 self.assertEqual(cnts.frame_count, 1) 3398 self.assertEqual(cnts.op_count, 2) 3399 3400 # But with a change to the guarded default tensor, we do recompile 3401 with patch.object( 3402 global_func_with_default_tensor_args, 3403 "__defaults__", 3404 (torch.ones((3, 4, 5)),), 3405 ): 3406 x, kw_x = compiled_func() 3407 self.assertEqual(cnts.frame_count, 2) 3408 self.assertEqual(cnts.op_count, 4) 3409 3410 with patch.object( 3411 global_func_with_default_tensor_args, 3412 "__kwdefaults__", 3413 {"kw_x": torch.ones((3, 4, 5))}, 3414 ): 3415 x, kw_x = compiled_func() 3416 self.assertEqual(cnts.frame_count, 3) 3417 self.assertEqual(cnts.op_count, 6) 3418 3419 def test_meth_default_tensor_args(self): 3420 """ 3421 Tests that we indeed reference (and mutate) "the one" default tensor arg 3422 stored on the globally allocated function object, both from the orig and 3423 compiled function 3424 """ 3425 mod = WrapperModule() 3426 cnts = torch._dynamo.testing.CompileCounter() 3427 compiled_mod = torch.compile(mod, backend=cnts) 3428 for i in range(4): 3429 if i % 2 == 0: 3430 x, kw_x = mod() 3431 else: 3432 x, kw_x = compiled_mod() 3433 # the inner func mutates += 1 each call 3434 self.assertTrue(same(x, torch.ones_like(x) + i)) 3435 self.assertTrue(same(kw_x, torch.ones_like(kw_x) + i)) 3436 # Calling compiled_func twice does not recompile 3437 self.assertEqual(cnts.frame_count, 1) 3438 self.assertEqual(cnts.op_count, 2) 3439 3440 # But with a change to the guarded default tensor, we do recompile 3441 with patch.object( 3442 ModuleWithDefaultTensorArgsMethod.forward, 3443 "__defaults__", 3444 (torch.ones((3, 4, 5)),), 3445 ): 3446 x, kw_x = compiled_mod() 3447 self.assertEqual(cnts.frame_count, 2) 3448 self.assertEqual(cnts.op_count, 4) 3449 3450 with patch.object( 3451 ModuleWithDefaultTensorArgsMethod.forward, 3452 "__kwdefaults__", 3453 {"kw_x": torch.ones((3, 4, 5))}, 3454 ): 3455 x, kw_x = compiled_mod() 3456 self.assertEqual(cnts.frame_count, 3) 3457 self.assertEqual(cnts.op_count, 6) 3458 3459 def test_func_default_torch_args(self): 3460 """ 3461 Tests other types of torch types as function default (size, dtype, device) 3462 """ 3463 3464 def func_with_default_torch_args( 3465 dt=torch.float16, ds=torch.Size((1, 2, 3)), dd=torch.device("cpu") 3466 ): 3467 return torch.ones(ds, dtype=dt, device=dd) 3468 3469 def func(): 3470 return func_with_default_torch_args() 3471 3472 cnts = torch._dynamo.testing.CompileCounter() 3473 compiled_func = torch.compile(func, backend=cnts) 3474 out = func() 3475 compiled_out = compiled_func() 3476 self.assertEqual(out.dtype, compiled_out.dtype) 3477 self.assertEqual(out.device, compiled_out.device) 3478 self.assertEqual(out.size(), compiled_out.size()) 3479 self.assertEqual(cnts.frame_count, 1) 3480 self.assertEqual(cnts.op_count, 1) 3481 3482 def test_dataclass_factory(self): 3483 @dataclass 3484 class Output: 3485 scalar: int = 2 3486 named_tensors: Dict[str, torch.Tensor] = field(default_factory=dict) 3487 lists: List[torch.Tensor] = field(default_factory=list) 3488 3489 def scale(self): 3490 return self.scalar * 2 3491 3492 def fn(x): 3493 # Check default dict assignment 3494 a = Output(1) 3495 # Check that dataclass methods can be inlined 3496 scaled_value = a.scale() 3497 3498 # Check that normal assignment works 3499 b = Output(5, named_tensors={"x": x}) 3500 3501 # Check default int assignment 3502 c = Output() 3503 3504 # Check that the default members are properly initialized 3505 if isinstance(a.named_tensors, dict): 3506 x = torch.sin(x) 3507 3508 # Change dataclass 3509 c.scalar = 6 3510 c.named_tensors["x"] = x 3511 3512 # Return dataclaass as well to check reconstruction 3513 return c, torch.cos(x) * scaled_value + b.named_tensors["x"] + c.scalar 3514 3515 cnts = torch._dynamo.testing.CompileCounter() 3516 compiled_fn = torch.compile(fn, backend=cnts, fullgraph=True) 3517 x = torch.randn(4) 3518 eager_dataclass, out = fn(x) 3519 compiled_dataclass, compiled_out = compiled_fn(x) 3520 self.assertEqual(eager_dataclass.scalar, compiled_dataclass.scalar) 3521 self.assertEqual( 3522 eager_dataclass.named_tensors["x"], compiled_dataclass.named_tensors["x"] 3523 ) 3524 self.assertTrue(same(out, compiled_out)) 3525 self.assertEqual(cnts.frame_count, 1) 3526 self.assertEqual(cnts.op_count, 5) 3527 3528 def test_dataclass_nested(self): 3529 @dataclass 3530 class Base: 3531 outer_a: int 3532 outer_b: int 3533 3534 @dataclass 3535 class Derived(Base): 3536 inner_a: Any = field(default_factory=list) 3537 3538 def fn(x): 3539 l = Derived(1, 2) 3540 return l.outer_a * x 3541 3542 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3543 x = torch.randn(4) 3544 res = fn(x) 3545 ref = opt_fn(x) 3546 self.assertEqual(ref, res) 3547 3548 def test_listlike_of_tensors_contains_constant(self): 3549 for listlike in [set, list]: 3550 3551 def fn(x): 3552 x.add_(1) 3553 s = listlike([x]) 3554 res = 1 in s 3555 return res 3556 3557 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3558 x = torch.randn(1) 3559 ref = opt_fn(x) 3560 res = fn(x) 3561 self.assertEqual(ref, res) 3562 3563 def test_cast_tensor_single_elem(self): 3564 with torch._dynamo.config.patch({"capture_scalar_outputs": True}): 3565 for t, val in [ 3566 (float, 1.0), 3567 (float, 1), 3568 (float, True), 3569 (int, 1), 3570 (int, False), 3571 # (int, 1.0), # fails due to a >= 0 comparison in sym_int 3572 ]: # , bool, complex]: no casting for sym_bool, no sym_complex 3573 3574 def fn(x): 3575 x = x + 1 3576 return t(x) 3577 3578 opt_fn = torch.compile( 3579 fn, backend="eager", fullgraph=True, dynamic=False 3580 ) 3581 x = torch.tensor([val]) 3582 res = fn(x) 3583 ref = opt_fn(x) 3584 self.assertEqual(ref, res) 3585 3586 # Cannot handle non single-elem 3587 with self.assertRaises(ValueError): 3588 fn(torch.tensor([val] * 2)) 3589 with self.assertRaises(torch._dynamo.exc.TorchRuntimeError): 3590 opt_fn(torch.tensor([val] * 2)) 3591 3592 def test_set_construction(self): 3593 def fn(x): 3594 y = x.add_(1) 3595 s = set({x}) 3596 s.add(y) 3597 return len(s) 3598 3599 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3600 x = torch.randn(4) 3601 res = fn(x) 3602 ref = opt_fn(x) 3603 self.assertEqual(ref, res) 3604 3605 def test_frozenset_construction(self): 3606 def fn(x): 3607 s = frozenset({x}) 3608 t = frozenset(s) 3609 return len(t) 3610 3611 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3612 x = torch.randn(4) 3613 res = fn(x) 3614 ref = opt_fn(x) 3615 self.assertEqual(ref, res) 3616 3617 def test_frozenset_reconstruction(self): 3618 d = {} 3619 f = frozenset() 3620 d[f] = torch.randn(4) 3621 3622 def fn(x): 3623 k = frozenset() 3624 torch._dynamo.graph_break() 3625 return d[k] * x 3626 3627 opt_fn = torch.compile(fn, backend="eager") 3628 x = torch.randn(4) 3629 res = fn(x) 3630 ref = opt_fn(x) 3631 self.assertEqual(ref, res) 3632 3633 def test_frozenset_illegal_call_method(self): 3634 def fn_add(): 3635 s = frozenset((1, 2, 3)) 3636 s.add({2}) 3637 return len(s) 3638 3639 def fn_pop(): 3640 s = frozenset((1, 2, 3)) 3641 s.pop() 3642 return len(s) 3643 3644 def fn_update(): 3645 s = frozenset((1, 2, 3)) 3646 s.update({4, 5, 6}) 3647 return len(s) 3648 3649 def fn_remove(): 3650 s = frozenset((1, 2, 3)) 3651 s.remove(2) 3652 return len(s) 3653 3654 def fn_discard(): 3655 s = frozenset((1, 2, 3)) 3656 s.discard(2) 3657 return len(s) 3658 3659 def fn_clear(): 3660 s = frozenset((1, 2, 3)) 3661 s.clear() 3662 return len(s) 3663 3664 for fn in [fn_add, fn_pop, fn_update, fn_remove, fn_discard, fn_clear]: 3665 torch._dynamo.reset() 3666 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3667 with self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError): 3668 opt_fn() 3669 3670 def test_is_tensor_tensor(self): 3671 def fn(x, y): 3672 if x is y: 3673 return x * 2 3674 else: 3675 return x + y 3676 3677 fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3678 3679 x = torch.zeros(2) 3680 y = torch.ones(2) 3681 3682 self.assertEqual(fn(x, y), fn_opt(x, y)) 3683 self.assertEqual(fn(x, x), fn_opt(x, x)) 3684 3685 def test_is_not_tensor_tensor(self): 3686 def fn(x, y): 3687 if x is not y: 3688 return x * 2 3689 else: 3690 return x + y 3691 3692 fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3693 3694 x = torch.zeros(2) 3695 y = torch.ones(2) 3696 3697 self.assertEqual(fn(x, y), fn_opt(x, y)) 3698 self.assertEqual(fn(x, x), fn_opt(x, x)) 3699 3700 def test_is_mutated_tensor_tensor(self): 3701 def fn(x): 3702 y = x.add_(1) 3703 return x is y 3704 3705 fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3706 3707 z = torch.ones(4) 3708 3709 self.assertEqual(fn(z), fn_opt(z)) 3710 3711 def test_is_mutated_tensor_tensor_across_graph_break(self): 3712 def fn(x): 3713 y = x.add_(1) 3714 cond = x is y 3715 x.add_(1) 3716 # The real tensor values are recovered when graph breaking. 3717 # Hence we recover the invariant. 3718 torch._dynamo.graph_break() 3719 x.add_(1) 3720 return x is y, cond 3721 3722 fn_opt = torch.compile(backend="eager", dynamic=True)(fn) 3723 3724 z = torch.ones(4) 3725 3726 self.assertEqual(fn(z), fn_opt(z)) 3727 3728 def test_is_mutated_tensor_tensor(self): 3729 def fn(x): 3730 y = x.add_(1) 3731 return y is x 3732 3733 fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3734 3735 z = torch.ones(4, 1) 3736 3737 self.assertEqual(fn(z), fn_opt(z)) 3738 3739 def test_is_init_in_compile_mutated_tensor_tensor(self): 3740 def fn(x): 3741 z = x.clone() 3742 y = z.add_(1) 3743 return y is z 3744 3745 fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3746 3747 z = torch.ones(4, 1) 3748 3749 self.assertEqual(fn(z), fn_opt(z)) 3750 3751 def test_is_init_in_compile_vmapped_mutated_tensor_tensor(self): 3752 def fn(z): 3753 x = z.clone() 3754 y = torch.vmap(torch.Tensor.acos_)(x) 3755 _ = y is z 3756 return y is x 3757 3758 fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3759 3760 z = torch.ones(4, 1) 3761 3762 self.assertEqual(fn(z), fn_opt(z)) 3763 3764 def test_is_vmapped_mutated_tensor_tensor(self): 3765 def fn(x): 3766 y = torch.vmap(torch.Tensor.acos_)(x) 3767 return y is x 3768 3769 fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3770 3771 z = torch.ones(4, 1) 3772 3773 self.assertEqual(fn(z), fn_opt(z)) 3774 3775 def test_is_init_in_compile_vmapped_mutated_tensor_tensor_multi_arg(self): 3776 def fn(y, z): 3777 a = y.clone() 3778 b = z.clone() 3779 3780 def g(a, b): 3781 return a.acos_(), b.acos_() 3782 3783 c, d = torch.vmap(g)(a, b) 3784 return a is c is b is d 3785 3786 fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3787 3788 y = torch.ones(4, 2) 3789 z = torch.ones(4, 10) 3790 3791 self.assertEqual(fn(y, z), fn_opt(y, z)) 3792 self.assertEqual(fn(y, y), fn_opt(y, y)) 3793 3794 def test_in_set_would_fail_broadcast(self): 3795 param = torch.zeros(5) 3796 param2 = torch.zeros(5, 10) 3797 3798 tensor_list = set() 3799 tensor_list.add(param2) 3800 assert param not in tensor_list 3801 3802 def fn(param, param2): 3803 param.add_(1) 3804 tensor_list = set([param2]) 3805 return param in tensor_list 3806 3807 cnts = torch._dynamo.testing.CompileCounter() 3808 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3809 self.assertEqual(opt_fn(param, param2), fn(param, param2)) 3810 self.assertEqual(cnts.frame_count, 1) 3811 # Test aliased 3812 self.assertEqual(opt_fn(param, param), fn(param, param)) 3813 self.assertEqual(cnts.frame_count, 2) # Recompiles 3814 3815 def test_in_set_inplace(self): 3816 param = torch.zeros(5) 3817 param2 = torch.zeros(5, 10) 3818 3819 tensor_list = set() 3820 tensor_list.add(param2) 3821 assert param not in tensor_list 3822 3823 def fn(param, param2): 3824 y = param.add_(1) # Tensor method 3825 z = torch.Tensor.add_(y, 1) # torch function 3826 tensor_list = set([param2]) 3827 return y in tensor_list and z in tensor_list 3828 3829 cnts = torch._dynamo.testing.CompileCounter() 3830 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3831 self.assertEqual(opt_fn(param, param2), fn(param, param2)) 3832 self.assertEqual(cnts.frame_count, 1) 3833 # Test aliased 3834 self.assertEqual(opt_fn(param, param), fn(param, param)) 3835 self.assertEqual(cnts.frame_count, 2) # Recompiles 3836 3837 def test_reconstructed_name(self): 3838 lst = [] 3839 3840 @torch._dynamo.disable 3841 def disallowed(g): 3842 lst.append(g.__name__) 3843 3844 def f(): 3845 def g(): 3846 return () 3847 3848 disallowed(g) 3849 3850 f_opt = torch._dynamo 3851 opt_f = torch._dynamo.optimize(backend="eager")(f) 3852 opt_f() 3853 f() 3854 self.assertEqual(len(lst), 2) 3855 self.assertEqual(lst[0], lst[1]) 3856 3857 @unittest.skipIf( 3858 sys.version_info < (3, 10), 3859 "zip strict kwargs not implemented for Python < 3.10", 3860 ) 3861 def test_zip_strict(self): 3862 def fn(x, ys, zs): 3863 x = x.clone() 3864 for y, z in zip(ys, zs, strict=True): 3865 x += y * z 3866 return x 3867 3868 opt_fn = torch._dynamo.optimize(backend="eager")(fn) 3869 nopython_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 3870 3871 x = torch.ones(3) 3872 ys = [1.0, 2.0, 3.0] 3873 zs = [2.0, 5.0, 8.0] 3874 3875 self.assertEqual(opt_fn(x, ys, zs), fn(x, ys, zs)) 3876 3877 # If nopython, should raise UserError 3878 with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): 3879 nopython_fn(x, ys[:1], zs) 3880 3881 with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): 3882 nopython_fn(x, ys, zs[:1]) 3883 3884 # Should cause fallback if allow graph break 3885 with self.assertRaisesRegex(ValueError, "zip()"): 3886 opt_fn(x, ys[:1], zs) 3887 3888 with self.assertRaisesRegex(ValueError, "zip()"): 3889 opt_fn(x, ys, zs[:1]) 3890 3891 def test_fn_with_attr(self): 3892 def fn(x): 3893 if fn.pred: 3894 return torch.relu(x * 2) 3895 else: 3896 return torch.abs(x + 3) 3897 3898 t = torch.ones(3) 3899 counter = torch._dynamo.testing.CompileCounter() 3900 fn.pred = True 3901 opt_fn_0 = torch.compile(fullgraph=True, backend=counter)(fn) 3902 self.assertEqual(opt_fn_0(t), fn(t)) 3903 self.assertEqual(counter.frame_count, 1) 3904 fn.pred = False 3905 opt_fn_1 = torch.compile(fullgraph=True, backend=counter)(fn) 3906 self.assertEqual(opt_fn_1(t), fn(t)) 3907 self.assertEqual(counter.frame_count, 2) 3908 3909 def test_str_handler_for_user_defined_object(self): 3910 """ 3911 Confirms handler behaviour for `str` is the same between eager and dynamo. 3912 Compares a user defined object with custom `__str__` method and without. 3913 """ 3914 3915 class CustomStr: 3916 def __str__(self): 3917 return "ok" 3918 3919 def foo_custom_str(x): 3920 a = CustomStr() 3921 return x, str(a) 3922 3923 eager_custom_str = foo_custom_str(torch.ones(4)) 3924 dynamo_custom_str = torch.compile(foo_custom_str, fullgraph=True)(torch.ones(4)) 3925 3926 self.assertEqual(eager_custom_str[1], dynamo_custom_str[1]) 3927 self.assertEqual(eager_custom_str[1], "ok") 3928 3929 class DefaultStr: 3930 pass 3931 3932 def foo_default_str(x): 3933 a = DefaultStr() 3934 return x, str(a) 3935 3936 eager_default_str = foo_default_str(torch.ones(4)) 3937 dynamo_default_str = torch.compile(foo_default_str, fullgraph=True)( 3938 torch.ones(4) 3939 ) 3940 3941 # Check that the tensor output from eager and dynamo modes are the same 3942 self.assertEqual(eager_default_str[0], dynamo_default_str[0]) 3943 3944 # Check that the class name (without memory address) is the same in both modes 3945 eager_class_name = eager_default_str[1].split(" object at")[0] 3946 dynamo_class_name = dynamo_default_str[1].split(" object at")[0] 3947 self.assertEqual(eager_class_name, dynamo_class_name) 3948 3949 def test_pybind_object(self): 3950 def fn(x, pybind_obj): 3951 if pybind_obj.result: 3952 return torch.cos(x) 3953 return torch.sin(x) 3954 3955 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3956 3957 pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(True, ["a==1"], 0) 3958 x = torch.randn(4) 3959 self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj)) 3960 3961 pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(False, ["a==1"], 1) 3962 x = torch.randn(4) 3963 self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj)) 3964 3965 3966instantiate_parametrized_tests(FunctionTests) 3967 3968if __name__ == "__main__": 3969 from torch._dynamo.test_case import run_tests 3970 3971 run_tests() 3972