1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Workerimport abc 3*da0073e9SAndroid Build Coastguard Workerimport collections 4*da0073e9SAndroid Build Coastguard Workerimport copy 5*da0073e9SAndroid Build Coastguard Workerimport dataclasses 6*da0073e9SAndroid Build Coastguard Workerimport dis 7*da0073e9SAndroid Build Coastguard Workerimport enum 8*da0073e9SAndroid Build Coastguard Workerimport functools 9*da0073e9SAndroid Build Coastguard Workerimport gc 10*da0073e9SAndroid Build Coastguard Workerimport itertools 11*da0073e9SAndroid Build Coastguard Workerimport logging 12*da0073e9SAndroid Build Coastguard Workerimport math 13*da0073e9SAndroid Build Coastguard Workerimport operator 14*da0073e9SAndroid Build Coastguard Workerimport os 15*da0073e9SAndroid Build Coastguard Workerimport random 16*da0073e9SAndroid Build Coastguard Workerimport sys 17*da0073e9SAndroid Build Coastguard Workerimport tempfile 18*da0073e9SAndroid Build Coastguard Workerimport threading 19*da0073e9SAndroid Build Coastguard Workerimport traceback 20*da0073e9SAndroid Build Coastguard Workerimport typing 21*da0073e9SAndroid Build Coastguard Workerimport unittest 22*da0073e9SAndroid Build Coastguard Workerimport unittest.mock as mock 23*da0073e9SAndroid Build Coastguard Workerimport warnings 24*da0073e9SAndroid Build Coastguard Workerimport weakref 25*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Workerimport numpy as np 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Workerimport torch 30*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Workerimport torch._inductor.test_case 33*da0073e9SAndroid Build Coastguard Workerimport torch.onnx.operators 34*da0073e9SAndroid Build Coastguard Worker 
35*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree 36*da0073e9SAndroid Build Coastguard Workerimport torch.utils.cpp_extension 37*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor 38*da0073e9SAndroid Build Coastguard Workerfrom torch._C import FileCheck 39*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo import allow_in_graph 40*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.eval_frame import _debug_get_cache_entry_list 41*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.exc import Unsupported 42*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.source import ConstantSource, GetItemSource, LocalSource 43*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import ( 44*da0073e9SAndroid Build Coastguard Worker CompileCounter, 45*da0073e9SAndroid Build Coastguard Worker CompileCounterWithBackend, 46*da0073e9SAndroid Build Coastguard Worker expectedFailureDynamic, 47*da0073e9SAndroid Build Coastguard Worker same, 48*da0073e9SAndroid Build Coastguard Worker skipIfNotPy311, 49*da0073e9SAndroid Build Coastguard Worker unsupported, 50*da0073e9SAndroid Build Coastguard Worker xfailIfPy312, 51*da0073e9SAndroid Build Coastguard Worker) 52*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault 53*da0073e9SAndroid Build Coastguard Workerfrom torch._inductor.utils import run_and_get_code 54*da0073e9SAndroid Build Coastguard Workerfrom torch.ao.quantization import MinMaxObserver 55*da0073e9SAndroid Build Coastguard Workerfrom torch.ao.quantization.fake_quantize import FakeQuantize 56*da0073e9SAndroid Build Coastguard Workerfrom torch.ao.quantization.qconfig import QConfig 57*da0073e9SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantize_fx import prepare_qat_fx 58*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.recording import NotEqualError, replay_shape_env_events 59*da0073e9SAndroid 
Build Coastguard Workerfrom torch.fx.experimental.symbolic_shapes import ( 60*da0073e9SAndroid Build Coastguard Worker _constrain_range_for_size, 61*da0073e9SAndroid Build Coastguard Worker constrain_range, 62*da0073e9SAndroid Build Coastguard Worker constrain_unify, 63*da0073e9SAndroid Build Coastguard Worker ConstraintViolationError, 64*da0073e9SAndroid Build Coastguard Worker expect_true, 65*da0073e9SAndroid Build Coastguard Worker guard_size_oblivious, 66*da0073e9SAndroid Build Coastguard Worker ShapeEnv, 67*da0073e9SAndroid Build Coastguard Worker) 68*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import functional as F 69*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor 70*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import ( 71*da0073e9SAndroid Build Coastguard Worker PLATFORM_SUPPORTS_FLASH_ATTENTION, 72*da0073e9SAndroid Build Coastguard Worker SM80OrLater, 73*da0073e9SAndroid Build Coastguard Worker TEST_CUDA, 74*da0073e9SAndroid Build Coastguard Worker TEST_MULTIGPU, 75*da0073e9SAndroid Build Coastguard Worker) 76*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import ( 77*da0073e9SAndroid Build Coastguard Worker sample_inputs_take_along_dim, 78*da0073e9SAndroid Build Coastguard Worker) 79*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 80*da0073e9SAndroid Build Coastguard Worker freeze_rng_state, 81*da0073e9SAndroid Build Coastguard Worker IS_FBCODE, 82*da0073e9SAndroid Build Coastguard Worker set_default_dtype, 83*da0073e9SAndroid Build Coastguard Worker wrapDeterministicFlagAPITest, 84*da0073e9SAndroid Build Coastguard Worker) 85*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase 86*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.logging_utils import logs_to_string 87*da0073e9SAndroid Build Coastguard Worker 
88*da0073e9SAndroid Build Coastguard Workermytuple = collections.namedtuple("mytuple", ["a", "b", "ab"]) 89*da0073e9SAndroid Build Coastguard WorkerT = typing.TypeVar("T") 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker# Specializes a test to run only if translation validation is set. 93*da0073e9SAndroid Build Coastguard Workerdef onlyIfTranslationValidation(fn: typing.Callable) -> typing.Callable: 94*da0073e9SAndroid Build Coastguard Worker @functools.wraps(fn) 95*da0073e9SAndroid Build Coastguard Worker def wrapper(*args, **kwargs): 96*da0073e9SAndroid Build Coastguard Worker import torch.fx.experimental.validator 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker if torch.fx.experimental.validator.translation_validation_enabled(): 99*da0073e9SAndroid Build Coastguard Worker return fn(*args, **kwargs) 100*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest(f"only works when TV is True.") 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker return wrapper 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Workerdef cleanup_op(opname): 106*da0073e9SAndroid Build Coastguard Worker ns, name = opname.split("::") 107*da0073e9SAndroid Build Coastguard Worker if not hasattr(torch.ops, ns): 108*da0073e9SAndroid Build Coastguard Worker return 109*da0073e9SAndroid Build Coastguard Worker actual_ns = getattr(torch.ops, ns) 110*da0073e9SAndroid Build Coastguard Worker if not hasattr(actual_ns, name): 111*da0073e9SAndroid Build Coastguard Worker return 112*da0073e9SAndroid Build Coastguard Worker delattr(actual_ns, name) 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Workerclass MyPickledModule(torch.nn.Module): 116*da0073e9SAndroid Build Coastguard 
Worker def __init__(self, z): 117*da0073e9SAndroid Build Coastguard Worker super().__init__() 118*da0073e9SAndroid Build Coastguard Worker self.z = z 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 121*da0073e9SAndroid Build Coastguard Worker return x * x * x + y + self.z 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker# These are used for test_{cond/map}_with_quantization 125*da0073e9SAndroid Build Coastguard Workerdefault_symmetric_fake_quant = FakeQuantize.with_args( 126*da0073e9SAndroid Build Coastguard Worker observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.quint8 127*da0073e9SAndroid Build Coastguard Worker) 128*da0073e9SAndroid Build Coastguard Workerdefault_weight_symmetric_fake_quant = FakeQuantize.with_args( 129*da0073e9SAndroid Build Coastguard Worker observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 130*da0073e9SAndroid Build Coastguard Worker) 131*da0073e9SAndroid Build Coastguard Workeruniform_qconfig_8bit = QConfig( 132*da0073e9SAndroid Build Coastguard Worker activation=default_symmetric_fake_quant, 133*da0073e9SAndroid Build Coastguard Worker weight=default_weight_symmetric_fake_quant.with_args, 134*da0073e9SAndroid Build Coastguard Worker) 135*da0073e9SAndroid Build Coastguard Workerqconfig_dict = {"object_type": [(torch.nn.Linear, uniform_qconfig_8bit)]} 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Workerdef closure_adder(val): 139*da0073e9SAndroid Build Coastguard Worker def inner(x): 140*da0073e9SAndroid Build Coastguard Worker return torch.sin(x + val) 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker return inner 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker 
145*da0073e9SAndroid Build Coastguard Workerclass UserDefineSetAttr: 146*da0073e9SAndroid Build Coastguard Worker setup = False 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker def __setattr__(self, key, value): 149*da0073e9SAndroid Build Coastguard Worker assert torch.compiler.is_dynamo_compiling() or UserDefineSetAttr.setup 150*da0073e9SAndroid Build Coastguard Worker super().__setattr__(f"pfx_{key}", value) 151*da0073e9SAndroid Build Coastguard Worker 152*da0073e9SAndroid Build Coastguard Worker def __getattr__(self, key, c=1): 153*da0073e9SAndroid Build Coastguard Worker assert torch.compiler.is_dynamo_compiling() or UserDefineSetAttr.setup 154*da0073e9SAndroid Build Coastguard Worker # c is added to force a guard on __defaults__ and checks the source for __getattr__ 155*da0073e9SAndroid Build Coastguard Worker if c: 156*da0073e9SAndroid Build Coastguard Worker return self.__dict__[f"pfx_{key}"] 157*da0073e9SAndroid Build Coastguard Worker else: 158*da0073e9SAndroid Build Coastguard Worker return None 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Workerclass MiscTests(torch._inductor.test_case.TestCase): 162*da0073e9SAndroid Build Coastguard Worker def test_get_cache_entry(self): 163*da0073e9SAndroid Build Coastguard Worker def f(x): 164*da0073e9SAndroid Build Coastguard Worker return x + 1 165*da0073e9SAndroid Build Coastguard Worker 166*da0073e9SAndroid Build Coastguard Worker torch.compile(f)(torch.randn(5, 5, 5)) 167*da0073e9SAndroid Build Coastguard Worker entries = _debug_get_cache_entry_list(f) 168*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(entries) > 0) 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker def g(x): 171*da0073e9SAndroid Build Coastguard Worker return x + 2 172*da0073e9SAndroid Build Coastguard Worker 173*da0073e9SAndroid Build Coastguard Worker entries = 
_debug_get_cache_entry_list(g) 174*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(entries) == 0) 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker try: 177*da0073e9SAndroid Build Coastguard Worker _debug_get_cache_entry_list(1) 178*da0073e9SAndroid Build Coastguard Worker except TypeError as e: 179*da0073e9SAndroid Build Coastguard Worker self.assertIn("expected a code object!", str(e)) 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker # test get cache entry on skipped code object 182*da0073e9SAndroid Build Coastguard Worker def h(x): 183*da0073e9SAndroid Build Coastguard Worker x = x + 1 184*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 185*da0073e9SAndroid Build Coastguard Worker return x + 1 186*da0073e9SAndroid Build Coastguard Worker 187*da0073e9SAndroid Build Coastguard Worker torch.compile(h)(torch.randn(3, 3)) 188*da0073e9SAndroid Build Coastguard Worker 189*da0073e9SAndroid Build Coastguard Worker entries = _debug_get_cache_entry_list(torch._dynamo.graph_break) 190*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(entries), 0) 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker def test_boolarg(self): 193*da0073e9SAndroid Build Coastguard Worker def boolarg(aa, bb, flag): 194*da0073e9SAndroid Build Coastguard Worker if flag: 195*da0073e9SAndroid Build Coastguard Worker return aa - bb 196*da0073e9SAndroid Build Coastguard Worker else: 197*da0073e9SAndroid Build Coastguard Worker return bb - aa 198*da0073e9SAndroid Build Coastguard Worker 199*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10, 10) 200*da0073e9SAndroid Build Coastguard Worker b = torch.randn(10, 10) 201*da0073e9SAndroid Build Coastguard Worker correct1 = boolarg(a, b, True) 202*da0073e9SAndroid Build Coastguard Worker correct2 = boolarg(a, b, False) 203*da0073e9SAndroid Build Coastguard Worker correct3 = boolarg(a, b, 
None) 204*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 205*da0073e9SAndroid Build Coastguard Worker opt_boolarg = torch._dynamo.optimize_assert(counter)(boolarg) 206*da0073e9SAndroid Build Coastguard Worker val1 = opt_boolarg(a, b, True) 207*da0073e9SAndroid Build Coastguard Worker val2 = opt_boolarg(a, b, False) 208*da0073e9SAndroid Build Coastguard Worker val3 = opt_boolarg(a, b, None) 209*da0073e9SAndroid Build Coastguard Worker val4 = opt_boolarg(a, b, True) 210*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(val1, correct1)) 211*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(val2, correct2)) 212*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(val3, correct3)) 213*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(val4, correct1)) 214*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 3) 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker def test_invalid_args_builtin(self): 217*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 218*da0073e9SAndroid Build Coastguard Worker def fn(x): 219*da0073e9SAndroid Build Coastguard Worker x = x.sin() 220*da0073e9SAndroid Build Coastguard Worker if isinstance(x, torch.Tensor, invalid=True): 221*da0073e9SAndroid Build Coastguard Worker x = x.sin() 222*da0073e9SAndroid Build Coastguard Worker return x 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 225*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(16)) 226*da0073e9SAndroid Build Coastguard Worker 227*da0073e9SAndroid Build Coastguard Worker def test_cpp_extension_recommends_custom_ops(self): 228*da0073e9SAndroid Build Coastguard Worker cpp_source = """ 229*da0073e9SAndroid Build Coastguard Worker #include <torch/extension.h> 230*da0073e9SAndroid Build Coastguard Worker at::Tensor foobar(const at::Tensor& x) { 
231*da0073e9SAndroid Build Coastguard Worker return x.clone(); 232*da0073e9SAndroid Build Coastguard Worker } 233*da0073e9SAndroid Build Coastguard Worker """ 234*da0073e9SAndroid Build Coastguard Worker module = torch.utils.cpp_extension.load_inline( 235*da0073e9SAndroid Build Coastguard Worker name="mylib", 236*da0073e9SAndroid Build Coastguard Worker cpp_sources=cpp_source, 237*da0073e9SAndroid Build Coastguard Worker functions="foobar", 238*da0073e9SAndroid Build Coastguard Worker verbose=True, 239*da0073e9SAndroid Build Coastguard Worker ) 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2, requires_grad=True) 242*da0073e9SAndroid Build Coastguard Worker counters.clear() 243*da0073e9SAndroid Build Coastguard Worker 244*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 245*da0073e9SAndroid Build Coastguard Worker def f(x): 246*da0073e9SAndroid Build Coastguard Worker return module.foobar(x) 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex( 249*da0073e9SAndroid Build Coastguard Worker UserWarning, 250*da0073e9SAndroid Build Coastguard Worker ".*https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html.*", 251*da0073e9SAndroid Build Coastguard Worker ): 252*da0073e9SAndroid Build Coastguard Worker f(x) 253*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 1) 254*da0073e9SAndroid Build Coastguard Worker first_graph_break = list(counters["graph_break"].keys())[0] 255*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 256*da0073e9SAndroid Build Coastguard Worker first_graph_break, 257*da0073e9SAndroid Build Coastguard Worker """Graph break due to unsupported builtin mylib.PyCapsule.foobar. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). 
If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.""", 258*da0073e9SAndroid Build Coastguard Worker ) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker cpp_source = """ 261*da0073e9SAndroid Build Coastguard Worker #include <torch/extension.h> 262*da0073e9SAndroid Build Coastguard Worker at::Tensor baz(const at::Tensor& x) { 263*da0073e9SAndroid Build Coastguard Worker return x.clone(); 264*da0073e9SAndroid Build Coastguard Worker } 265*da0073e9SAndroid Build Coastguard Worker """ 266*da0073e9SAndroid Build Coastguard Worker module2 = torch.utils.cpp_extension.load_inline( 267*da0073e9SAndroid Build Coastguard Worker name="mylib2", 268*da0073e9SAndroid Build Coastguard Worker cpp_sources=cpp_source, 269*da0073e9SAndroid Build Coastguard Worker functions="baz", 270*da0073e9SAndroid Build Coastguard Worker verbose=True, 271*da0073e9SAndroid Build Coastguard Worker ) 272*da0073e9SAndroid Build Coastguard Worker 273*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker # Test that each warning only happens once 276*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 277*da0073e9SAndroid Build Coastguard Worker def f(x): 278*da0073e9SAndroid Build Coastguard Worker module2.baz(x) 279*da0073e9SAndroid Build Coastguard Worker module.foobar(x) 280*da0073e9SAndroid Build Coastguard Worker module.foobar(x) 281*da0073e9SAndroid Build Coastguard Worker module2.baz(x) 282*da0073e9SAndroid Build Coastguard Worker module.foobar(x) 283*da0073e9SAndroid Build Coastguard Worker 
module2.baz(x) 284*da0073e9SAndroid Build Coastguard Worker return x.clone() 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as ws: 287*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") 288*da0073e9SAndroid Build Coastguard Worker f(x) 289*da0073e9SAndroid Build Coastguard Worker f(x) 290*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(ws), 2) 291*da0073e9SAndroid Build Coastguard Worker 292*da0073e9SAndroid Build Coastguard Worker def test_callpacked(self): 293*da0073e9SAndroid Build Coastguard Worker def call_packed(args): 294*da0073e9SAndroid Build Coastguard Worker a, b, c = args 295*da0073e9SAndroid Build Coastguard Worker return a - b * c 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 298*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10, 10) 299*da0073e9SAndroid Build Coastguard Worker b = torch.randn(10, 10) 300*da0073e9SAndroid Build Coastguard Worker c = torch.randn(10, 10) 301*da0073e9SAndroid Build Coastguard Worker correct = call_packed([a, b, c]) 302*da0073e9SAndroid Build Coastguard Worker opt_call_packed = torch._dynamo.optimize_assert(counter)(call_packed) 303*da0073e9SAndroid Build Coastguard Worker val1 = opt_call_packed([a, b, c]) 304*da0073e9SAndroid Build Coastguard Worker val2 = opt_call_packed((a, b, c)) 305*da0073e9SAndroid Build Coastguard Worker val3 = opt_call_packed([a, b, c]) 306*da0073e9SAndroid Build Coastguard Worker val4 = opt_call_packed((a, b, c)) 307*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(val1, correct)) 308*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(val2, correct)) 309*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(val3, correct)) 310*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(val4, correct)) 311*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(counter.frame_count, 2) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker def test_raises(self): 314*da0073e9SAndroid Build Coastguard Worker def fn(a, b, c, cls): 315*da0073e9SAndroid Build Coastguard Worker x = a + b - c * 10 316*da0073e9SAndroid Build Coastguard Worker raise cls(str(x)) 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 319*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10, 10) 320*da0073e9SAndroid Build Coastguard Worker b = torch.randn(10, 10) 321*da0073e9SAndroid Build Coastguard Worker c = torch.randn(10, 10) 322*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(counter)(fn) 323*da0073e9SAndroid Build Coastguard Worker self.assertRaises(AssertionError, lambda: opt_fn(a, b, c, AssertionError)) 324*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 325*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.op_count, 3) 326*da0073e9SAndroid Build Coastguard Worker 327*da0073e9SAndroid Build Coastguard Worker def test_module_not_callable(self): 328*da0073e9SAndroid Build Coastguard Worker def fn(x): 329*da0073e9SAndroid Build Coastguard Worker return torch.fft(x) 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 332*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10, 10) 333*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(counter)(fn) 334*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 335*da0073e9SAndroid Build Coastguard Worker TypeError, "'module' object is not callable", lambda: opt_fn(a) 336*da0073e9SAndroid Build Coastguard Worker ) 337*da0073e9SAndroid Build Coastguard Worker 338*da0073e9SAndroid Build Coastguard Worker def test_inplace(self): 339*da0073e9SAndroid Build Coastguard Worker def inplace1(a, b): 340*da0073e9SAndroid 
Build Coastguard Worker o = torch.empty((10, 10)) 341*da0073e9SAndroid Build Coastguard Worker o.copy_(a) 342*da0073e9SAndroid Build Coastguard Worker o -= b 343*da0073e9SAndroid Build Coastguard Worker return o 344*da0073e9SAndroid Build Coastguard Worker 345*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test(self, inplace1, 2, expected_ops=3) 346*da0073e9SAndroid Build Coastguard Worker 347*da0073e9SAndroid Build Coastguard Worker def test_inplace_desugaring(self): 348*da0073e9SAndroid Build Coastguard Worker def inplace_on_literals(y): 349*da0073e9SAndroid Build Coastguard Worker x0 = 1 350*da0073e9SAndroid Build Coastguard Worker x0 += y 351*da0073e9SAndroid Build Coastguard Worker x1 = 1 352*da0073e9SAndroid Build Coastguard Worker x1 -= y 353*da0073e9SAndroid Build Coastguard Worker return x0, x1 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 356*da0073e9SAndroid Build Coastguard Worker self, inplace_on_literals, 1, expected_ops=2 357*da0073e9SAndroid Build Coastguard Worker ) 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker def test_unpack4(self): 360*da0073e9SAndroid Build Coastguard Worker def unpack4(a, b): 361*da0073e9SAndroid Build Coastguard Worker a = a[:5, :] 362*da0073e9SAndroid Build Coastguard Worker b = b[:5, :] 363*da0073e9SAndroid Build Coastguard Worker x, y = a.size() 364*da0073e9SAndroid Build Coastguard Worker o = torch.empty((x, y)) 365*da0073e9SAndroid Build Coastguard Worker o.copy_(a / b) 366*da0073e9SAndroid Build Coastguard Worker return o 367*da0073e9SAndroid Build Coastguard Worker 368*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 369*da0073e9SAndroid Build Coastguard Worker self, 370*da0073e9SAndroid Build Coastguard Worker unpack4, 371*da0073e9SAndroid Build Coastguard Worker 2, 372*da0073e9SAndroid Build Coastguard Worker expected_ops=5, 
373*da0073e9SAndroid Build Coastguard Worker expected_ops_dynamic=ifdynstaticdefault(5, 7), 374*da0073e9SAndroid Build Coastguard Worker ) 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker def test_unpack5(self): 377*da0073e9SAndroid Build Coastguard Worker def unpack5(a, b): 378*da0073e9SAndroid Build Coastguard Worker a = a[:5, :] 379*da0073e9SAndroid Build Coastguard Worker b = b[:5, :] 380*da0073e9SAndroid Build Coastguard Worker x, y = a.shape 381*da0073e9SAndroid Build Coastguard Worker o = torch.empty((x, y)) 382*da0073e9SAndroid Build Coastguard Worker o.copy_(a / b) 383*da0073e9SAndroid Build Coastguard Worker return o 384*da0073e9SAndroid Build Coastguard Worker 385*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 386*da0073e9SAndroid Build Coastguard Worker self, 387*da0073e9SAndroid Build Coastguard Worker unpack5, 388*da0073e9SAndroid Build Coastguard Worker 2, 389*da0073e9SAndroid Build Coastguard Worker expected_ops=5, 390*da0073e9SAndroid Build Coastguard Worker expected_ops_dynamic=ifdynstaticdefault(5, 7), 391*da0073e9SAndroid Build Coastguard Worker ) 392*da0073e9SAndroid Build Coastguard Worker 393*da0073e9SAndroid Build Coastguard Worker def test_matmul1(self): 394*da0073e9SAndroid Build Coastguard Worker def matmul_op1(a, b): 395*da0073e9SAndroid Build Coastguard Worker return a @ b 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Worker # TODO(jansel): FX doesn't support this, should add upstream support 398*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test(self, matmul_op1, 2, expected_ops=1) 399*da0073e9SAndroid Build Coastguard Worker 400*da0073e9SAndroid Build Coastguard Worker def test_int_shape_binops(self): 401*da0073e9SAndroid Build Coastguard Worker def fn(x): 402*da0073e9SAndroid Build Coastguard Worker # Test reversal by putting int arg first. 
403*da0073e9SAndroid Build Coastguard Worker y = 15 - x.shape[0] 404*da0073e9SAndroid Build Coastguard Worker y = 4 + y 405*da0073e9SAndroid Build Coastguard Worker y = 5 * y 406*da0073e9SAndroid Build Coastguard Worker y = 2 % y 407*da0073e9SAndroid Build Coastguard Worker y = 3**y 408*da0073e9SAndroid Build Coastguard Worker y = 10 // y 409*da0073e9SAndroid Build Coastguard Worker y = pow(2, y) 410*da0073e9SAndroid Build Coastguard Worker y = 10 / y 411*da0073e9SAndroid Build Coastguard Worker return x + y 412*da0073e9SAndroid Build Coastguard Worker 413*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 414*da0073e9SAndroid Build Coastguard Worker self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 11) 415*da0073e9SAndroid Build Coastguard Worker ) 416*da0073e9SAndroid Build Coastguard Worker 417*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True) 418*da0073e9SAndroid Build Coastguard Worker def test_pt2_compliant_ops_are_allowed(self): 419*da0073e9SAndroid Build Coastguard Worker lib = torch.library.Library("mylib", "FRAGMENT") 420*da0073e9SAndroid Build Coastguard Worker try: 421*da0073e9SAndroid Build Coastguard Worker torch.library.define( 422*da0073e9SAndroid Build Coastguard Worker "mylib::bar", 423*da0073e9SAndroid Build Coastguard Worker "(Tensor x) -> Tensor", 424*da0073e9SAndroid Build Coastguard Worker lib=lib, 425*da0073e9SAndroid Build Coastguard Worker tags=(torch.Tag.pt2_compliant_tag,), 426*da0073e9SAndroid Build Coastguard Worker ) 427*da0073e9SAndroid Build Coastguard Worker torch.library.impl( 428*da0073e9SAndroid Build Coastguard Worker "mylib::bar", "CompositeImplicitAutograd", torch.sin, lib=lib 429*da0073e9SAndroid Build Coastguard Worker ) 430*da0073e9SAndroid Build Coastguard Worker assert torch.Tag.pt2_compliant_tag in torch.ops.mylib.bar.default.tags 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker 
def f(x): 433*da0073e9SAndroid Build Coastguard Worker return torch.ops.mylib.bar(x) 434*da0073e9SAndroid Build Coastguard Worker 435*da0073e9SAndroid Build Coastguard Worker overload = torch.ops.mylib.bar.default 436*da0073e9SAndroid Build Coastguard Worker 437*da0073e9SAndroid Build Coastguard Worker def g(x): 438*da0073e9SAndroid Build Coastguard Worker return overload(x) 439*da0073e9SAndroid Build Coastguard Worker 440*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 441*da0073e9SAndroid Build Coastguard Worker 442*da0073e9SAndroid Build Coastguard Worker counts = torch._dynamo.testing.CompileCounter() 443*da0073e9SAndroid Build Coastguard Worker optimized_f = torch._dynamo.optimize(counts, nopython=True)(f) 444*da0073e9SAndroid Build Coastguard Worker _ = optimized_f(x) 445*da0073e9SAndroid Build Coastguard Worker 446*da0073e9SAndroid Build Coastguard Worker optimized_g = torch._dynamo.optimize(counts, nopython=True)(f) 447*da0073e9SAndroid Build Coastguard Worker _ = optimized_g(x) 448*da0073e9SAndroid Build Coastguard Worker finally: 449*da0073e9SAndroid Build Coastguard Worker cleanup_op("mylib::bar") 450*da0073e9SAndroid Build Coastguard Worker del lib 451*da0073e9SAndroid Build Coastguard Worker 452*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True) 453*da0073e9SAndroid Build Coastguard Worker def test_non_pt2_compliant_ops_graph_break(self): 454*da0073e9SAndroid Build Coastguard Worker lib = torch.library.Library("mylib", "FRAGMENT") 455*da0073e9SAndroid Build Coastguard Worker try: 456*da0073e9SAndroid Build Coastguard Worker torch.library.define("mylib::bar2", "(Tensor x) -> Tensor", lib=lib) 457*da0073e9SAndroid Build Coastguard Worker torch.library.impl( 458*da0073e9SAndroid Build Coastguard Worker "mylib::bar2", "CompositeImplicitAutograd", torch.sin, lib=lib 459*da0073e9SAndroid Build Coastguard Worker ) 460*da0073e9SAndroid Build Coastguard Worker assert torch.Tag.pt2_compliant_tag 
not in torch.ops.mylib.bar2.default.tags 461*da0073e9SAndroid Build Coastguard Worker 462*da0073e9SAndroid Build Coastguard Worker def f(x): 463*da0073e9SAndroid Build Coastguard Worker return torch.ops.mylib.bar2(x) 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker overload = torch.ops.mylib.bar2.default 466*da0073e9SAndroid Build Coastguard Worker 467*da0073e9SAndroid Build Coastguard Worker def g(x): 468*da0073e9SAndroid Build Coastguard Worker return overload(x) 469*da0073e9SAndroid Build Coastguard Worker 470*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker counts = torch._dynamo.testing.CompileCounter() 473*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 474*da0073e9SAndroid Build Coastguard Worker torch._dynamo.exc.Unsupported, "not PT2 compliant" 475*da0073e9SAndroid Build Coastguard Worker ): 476*da0073e9SAndroid Build Coastguard Worker optimized_f = torch._dynamo.optimize(counts, nopython=True)(f) 477*da0073e9SAndroid Build Coastguard Worker y = optimized_f(x) 478*da0073e9SAndroid Build Coastguard Worker 479*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 480*da0073e9SAndroid Build Coastguard Worker torch._dynamo.exc.Unsupported, "not PT2 compliant" 481*da0073e9SAndroid Build Coastguard Worker ): 482*da0073e9SAndroid Build Coastguard Worker optimized_g = torch._dynamo.optimize(counts, nopython=True)(f) 483*da0073e9SAndroid Build Coastguard Worker y = optimized_g(x) 484*da0073e9SAndroid Build Coastguard Worker finally: 485*da0073e9SAndroid Build Coastguard Worker cleanup_op("mylib::bar2") 486*da0073e9SAndroid Build Coastguard Worker del lib 487*da0073e9SAndroid Build Coastguard Worker 488*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True) 489*da0073e9SAndroid Build Coastguard Worker def test_pt2_compliant_overload(self): 
490*da0073e9SAndroid Build Coastguard Worker lib = torch.library.Library("mylib", "FRAGMENT") 491*da0073e9SAndroid Build Coastguard Worker try: 492*da0073e9SAndroid Build Coastguard Worker torch.library.define( 493*da0073e9SAndroid Build Coastguard Worker "mylib::bar3.tensor", 494*da0073e9SAndroid Build Coastguard Worker "(Tensor x) -> Tensor", 495*da0073e9SAndroid Build Coastguard Worker tags=torch.Tag.pt2_compliant_tag, 496*da0073e9SAndroid Build Coastguard Worker lib=lib, 497*da0073e9SAndroid Build Coastguard Worker ) 498*da0073e9SAndroid Build Coastguard Worker torch.library.define( 499*da0073e9SAndroid Build Coastguard Worker "mylib::bar3.int", "(Tensor x, int dim) -> Tensor", lib=lib 500*da0073e9SAndroid Build Coastguard Worker ) 501*da0073e9SAndroid Build Coastguard Worker 502*da0073e9SAndroid Build Coastguard Worker torch.library.impl( 503*da0073e9SAndroid Build Coastguard Worker "mylib::bar3.tensor", 504*da0073e9SAndroid Build Coastguard Worker "CompositeImplicitAutograd", 505*da0073e9SAndroid Build Coastguard Worker torch.sin, 506*da0073e9SAndroid Build Coastguard Worker lib=lib, 507*da0073e9SAndroid Build Coastguard Worker ) 508*da0073e9SAndroid Build Coastguard Worker torch.library.impl( 509*da0073e9SAndroid Build Coastguard Worker "mylib::bar3.int", "CompositeImplicitAutograd", torch.sum, lib=lib 510*da0073e9SAndroid Build Coastguard Worker ) 511*da0073e9SAndroid Build Coastguard Worker 512*da0073e9SAndroid Build Coastguard Worker def f(x): 513*da0073e9SAndroid Build Coastguard Worker return torch.ops.mylib.bar3(x) 514*da0073e9SAndroid Build Coastguard Worker 515*da0073e9SAndroid Build Coastguard Worker def g(x): 516*da0073e9SAndroid Build Coastguard Worker return torch.ops.mylib.bar3(x, 1) 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Worker def h(x): 519*da0073e9SAndroid Build Coastguard Worker return torch.ops.mylib.bar3(x, x, x) 520*da0073e9SAndroid Build Coastguard Worker 521*da0073e9SAndroid Build Coastguard 
Worker x = torch.randn(3) 522*da0073e9SAndroid Build Coastguard Worker 523*da0073e9SAndroid Build Coastguard Worker counts = torch._dynamo.testing.CompileCounter() 524*da0073e9SAndroid Build Coastguard Worker optimized_f = torch._dynamo.optimize(counts, nopython=True)(f) 525*da0073e9SAndroid Build Coastguard Worker optimized_g = torch._dynamo.optimize(counts, nopython=True)(g) 526*da0073e9SAndroid Build Coastguard Worker optimized_h = torch._dynamo.optimize(counts, nopython=True)(h) 527*da0073e9SAndroid Build Coastguard Worker 528*da0073e9SAndroid Build Coastguard Worker # No error: the overload is PT2 compliant 529*da0073e9SAndroid Build Coastguard Worker optimized_f(x) 530*da0073e9SAndroid Build Coastguard Worker 531*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 532*da0073e9SAndroid Build Coastguard Worker torch._dynamo.exc.Unsupported, "not PT2 compliant" 533*da0073e9SAndroid Build Coastguard Worker ): 534*da0073e9SAndroid Build Coastguard Worker y = optimized_g(x) 535*da0073e9SAndroid Build Coastguard Worker 536*da0073e9SAndroid Build Coastguard Worker # graph break on incorrect parsing 537*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "failed to"): 538*da0073e9SAndroid Build Coastguard Worker y = optimized_h(x) 539*da0073e9SAndroid Build Coastguard Worker 540*da0073e9SAndroid Build Coastguard Worker finally: 541*da0073e9SAndroid Build Coastguard Worker cleanup_op("mylib::bar3") 542*da0073e9SAndroid Build Coastguard Worker del lib 543*da0073e9SAndroid Build Coastguard Worker 544*da0073e9SAndroid Build Coastguard Worker def test_auto_functionalize_can_with_default(self): 545*da0073e9SAndroid Build Coastguard Worker lib = torch.library.Library("mylib", "FRAGMENT") 546*da0073e9SAndroid Build Coastguard Worker torch.library.define( 547*da0073e9SAndroid Build Coastguard Worker "mylib::foo", 548*da0073e9SAndroid Build Coastguard Worker "(Tensor a, int b, Tensor(d!)? c=None, Tensor? 
d=None, int e=-1) -> ()", 549*da0073e9SAndroid Build Coastguard Worker tags=torch.Tag.pt2_compliant_tag, 550*da0073e9SAndroid Build Coastguard Worker lib=lib, 551*da0073e9SAndroid Build Coastguard Worker ) 552*da0073e9SAndroid Build Coastguard Worker 553*da0073e9SAndroid Build Coastguard Worker @torch.library.impl("mylib::foo", "cpu", lib=lib) 554*da0073e9SAndroid Build Coastguard Worker def foo_impl(a, b, c=None, d=None, e=-1): 555*da0073e9SAndroid Build Coastguard Worker a + b 556*da0073e9SAndroid Build Coastguard Worker return 557*da0073e9SAndroid Build Coastguard Worker 558*da0073e9SAndroid Build Coastguard Worker def f(a, mode): 559*da0073e9SAndroid Build Coastguard Worker return torch.ops.mylib.foo( 560*da0073e9SAndroid Build Coastguard Worker a, 561*da0073e9SAndroid Build Coastguard Worker 0, 562*da0073e9SAndroid Build Coastguard Worker ) 563*da0073e9SAndroid Build Coastguard Worker 564*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([10, 10, 10], dtype=torch.int64) 565*da0073e9SAndroid Build Coastguard Worker 566*da0073e9SAndroid Build Coastguard Worker torch.compile(f)(a, 0) 567*da0073e9SAndroid Build Coastguard Worker 568*da0073e9SAndroid Build Coastguard Worker cleanup_op("mylib::foo") 569*da0073e9SAndroid Build Coastguard Worker del lib 570*da0073e9SAndroid Build Coastguard Worker 571*da0073e9SAndroid Build Coastguard Worker def test_user_defined_setattr1(self): 572*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 573*da0073e9SAndroid Build Coastguard Worker def fn(obj): 574*da0073e9SAndroid Build Coastguard Worker obj.y = obj.x + 1 575*da0073e9SAndroid Build Coastguard Worker 576*da0073e9SAndroid Build Coastguard Worker obj = UserDefineSetAttr() 577*da0073e9SAndroid Build Coastguard Worker with patch.object(UserDefineSetAttr, "setup", True): 578*da0073e9SAndroid Build Coastguard Worker obj.x = torch.randn(8) 579*da0073e9SAndroid Build Coastguard Worker fn(obj) 580*da0073e9SAndroid Build Coastguard 
Worker with patch.object(UserDefineSetAttr, "setup", True): 581*da0073e9SAndroid Build Coastguard Worker self.assertEqual(obj.y, obj.x + 1) 582*da0073e9SAndroid Build Coastguard Worker self.assertEqual(obj.__dict__.keys(), {"pfx_x", "pfx_y"}) 583*da0073e9SAndroid Build Coastguard Worker 584*da0073e9SAndroid Build Coastguard Worker def test_user_defined_setattr2(self): 585*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 586*da0073e9SAndroid Build Coastguard Worker def fn(x): 587*da0073e9SAndroid Build Coastguard Worker obj = UserDefineSetAttr() 588*da0073e9SAndroid Build Coastguard Worker obj.x = x 589*da0073e9SAndroid Build Coastguard Worker obj.y = obj.x + 1 590*da0073e9SAndroid Build Coastguard Worker return obj 591*da0073e9SAndroid Build Coastguard Worker 592*da0073e9SAndroid Build Coastguard Worker x = torch.randn(8) 593*da0073e9SAndroid Build Coastguard Worker obj = fn(x) 594*da0073e9SAndroid Build Coastguard Worker with patch.object(UserDefineSetAttr, "setup", True): 595*da0073e9SAndroid Build Coastguard Worker self.assertIs(obj.x, x) 596*da0073e9SAndroid Build Coastguard Worker self.assertEqual(obj.y, x + 1) 597*da0073e9SAndroid Build Coastguard Worker self.assertEqual(obj.__dict__.keys(), {"pfx_x", "pfx_y"}) 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker def test_closure_recompiles(self): 600*da0073e9SAndroid Build Coastguard Worker cnt = CompileCounter() 601*da0073e9SAndroid Build Coastguard Worker 602*da0073e9SAndroid Build Coastguard Worker def fn(x, other_fn): 603*da0073e9SAndroid Build Coastguard Worker return other_fn(x + 1) - 1 604*da0073e9SAndroid Build Coastguard Worker 605*da0073e9SAndroid Build Coastguard Worker opt = torch.compile(fn, backend=cnt, fullgraph=True) 606*da0073e9SAndroid Build Coastguard Worker 607*da0073e9SAndroid Build Coastguard Worker x = torch.randn(8) 608*da0073e9SAndroid Build Coastguard Worker for f in ( 609*da0073e9SAndroid Build Coastguard 
Worker closure_adder(5), 610*da0073e9SAndroid Build Coastguard Worker closure_adder(5), 611*da0073e9SAndroid Build Coastguard Worker closure_adder(torch.randn(8)), 612*da0073e9SAndroid Build Coastguard Worker closure_adder(torch.randn(8)), 613*da0073e9SAndroid Build Coastguard Worker ): 614*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt(x, f), fn(x, f)) 615*da0073e9SAndroid Build Coastguard Worker 616*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 2) 617*da0073e9SAndroid Build Coastguard Worker 618*da0073e9SAndroid Build Coastguard Worker def test_generate_trivial_abstract_impl(self): 619*da0073e9SAndroid Build Coastguard Worker try: 620*da0073e9SAndroid Build Coastguard Worker lib = torch.library.Library("mylib", "FRAGMENT") 621*da0073e9SAndroid Build Coastguard Worker torch.library.define( 622*da0073e9SAndroid Build Coastguard Worker "mylib::foo", 623*da0073e9SAndroid Build Coastguard Worker "(Tensor x, Tensor[] y, Tensor(a!)? z, SymInt w) -> ()", 624*da0073e9SAndroid Build Coastguard Worker tags=torch.Tag.pt2_compliant_tag, 625*da0073e9SAndroid Build Coastguard Worker lib=lib, 626*da0073e9SAndroid Build Coastguard Worker ) 627*da0073e9SAndroid Build Coastguard Worker 628*da0073e9SAndroid Build Coastguard Worker @torch.library.impl("mylib::foo", "cpu", lib=lib) 629*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.disable 630*da0073e9SAndroid Build Coastguard Worker def foo_impl(x, y, z, w): 631*da0073e9SAndroid Build Coastguard Worker x + y[0] + w 632*da0073e9SAndroid Build Coastguard Worker return 633*da0073e9SAndroid Build Coastguard Worker 634*da0073e9SAndroid Build Coastguard Worker def f(x, y, z, w): 635*da0073e9SAndroid Build Coastguard Worker return torch.ops.mylib.foo(x, y, z, 2) 636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 638*da0073e9SAndroid Build Coastguard Worker y = (torch.randn(3), torch.randn(3)) 639*da0073e9SAndroid Build Coastguard 
Worker z = torch.randn(3) 640*da0073e9SAndroid Build Coastguard Worker w = torch.randn(3) 641*da0073e9SAndroid Build Coastguard Worker args = (x, y, z, w) 642*da0073e9SAndroid Build Coastguard Worker 643*da0073e9SAndroid Build Coastguard Worker output = torch.compile(f, backend="eager", fullgraph=True)(*args) 644*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, None) 645*da0073e9SAndroid Build Coastguard Worker finally: 646*da0073e9SAndroid Build Coastguard Worker cleanup_op("mylib::foo") 647*da0073e9SAndroid Build Coastguard Worker del lib 648*da0073e9SAndroid Build Coastguard Worker 649*da0073e9SAndroid Build Coastguard Worker def test_can_auto_functionalize(self): 650*da0073e9SAndroid Build Coastguard Worker from torch._higher_order_ops.auto_functionalize import can_auto_functionalize 651*da0073e9SAndroid Build Coastguard Worker 652*da0073e9SAndroid Build Coastguard Worker expected_true = [ 653*da0073e9SAndroid Build Coastguard Worker "(Tensor(a!) x) -> ()", 654*da0073e9SAndroid Build Coastguard Worker "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()", 655*da0073e9SAndroid Build Coastguard Worker "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()", 656*da0073e9SAndroid Build Coastguard Worker "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor", 657*da0073e9SAndroid Build Coastguard Worker "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)", 658*da0073e9SAndroid Build Coastguard Worker ] 659*da0073e9SAndroid Build Coastguard Worker expected_false = [ 660*da0073e9SAndroid Build Coastguard Worker "(Tensor x) -> ()", 661*da0073e9SAndroid Build Coastguard Worker "(Tensor(a) x) -> Tensor(a)", 662*da0073e9SAndroid Build Coastguard Worker "(Tensor(a!) x) -> Tensor(a!)", 663*da0073e9SAndroid Build Coastguard Worker "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()", 664*da0073e9SAndroid Build Coastguard Worker "(Tensor(a!) x, Tensor y, Tensor(b!) 
z, SymInt w, Tensor(c!)? n) -> Tensor(a)", 665*da0073e9SAndroid Build Coastguard Worker "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))", 666*da0073e9SAndroid Build Coastguard Worker "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))", 667*da0073e9SAndroid Build Coastguard Worker "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor[])", 668*da0073e9SAndroid Build Coastguard Worker ] 669*da0073e9SAndroid Build Coastguard Worker for schema in expected_true: 670*da0073e9SAndroid Build Coastguard Worker try: 671*da0073e9SAndroid Build Coastguard Worker lib = torch.library.Library("mylib", "FRAGMENT") 672*da0073e9SAndroid Build Coastguard Worker torch.library.define("mylib::a", schema, lib=lib) 673*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 674*da0073e9SAndroid Build Coastguard Worker can_auto_functionalize(torch.ops.mylib.a.default), msg=schema 675*da0073e9SAndroid Build Coastguard Worker ) 676*da0073e9SAndroid Build Coastguard Worker self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) 677*da0073e9SAndroid Build Coastguard Worker finally: 678*da0073e9SAndroid Build Coastguard Worker cleanup_op("mylib::a") 679*da0073e9SAndroid Build Coastguard Worker del lib 680*da0073e9SAndroid Build Coastguard Worker for schema in expected_false: 681*da0073e9SAndroid Build Coastguard Worker try: 682*da0073e9SAndroid Build Coastguard Worker lib = torch.library.Library("mylib", "FRAGMENT") 683*da0073e9SAndroid Build Coastguard Worker torch.library.define("mylib::a", schema, lib=lib) 684*da0073e9SAndroid Build Coastguard Worker self.assertFalse( 685*da0073e9SAndroid Build Coastguard Worker can_auto_functionalize(torch.ops.mylib.a.default), msg=schema 686*da0073e9SAndroid Build Coastguard Worker ) 687*da0073e9SAndroid Build Coastguard Worker self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) 688*da0073e9SAndroid Build Coastguard Worker finally: 
689*da0073e9SAndroid Build Coastguard Worker cleanup_op("mylib::a") 690*da0073e9SAndroid Build Coastguard Worker del lib 691*da0073e9SAndroid Build Coastguard Worker 692*da0073e9SAndroid Build Coastguard Worker def test_auto_functionalize(self): 693*da0073e9SAndroid Build Coastguard Worker try: 694*da0073e9SAndroid Build Coastguard Worker lib = torch.library.Library("mylib", "FRAGMENT") 695*da0073e9SAndroid Build Coastguard Worker torch.library.define( 696*da0073e9SAndroid Build Coastguard Worker "mylib::foo", 697*da0073e9SAndroid Build Coastguard Worker "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()", 698*da0073e9SAndroid Build Coastguard Worker tags=torch.Tag.pt2_compliant_tag, 699*da0073e9SAndroid Build Coastguard Worker lib=lib, 700*da0073e9SAndroid Build Coastguard Worker ) 701*da0073e9SAndroid Build Coastguard Worker 702*da0073e9SAndroid Build Coastguard Worker @torch.library.impl("mylib::foo", "cpu", lib=lib) 703*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.disable 704*da0073e9SAndroid Build Coastguard Worker def foo_impl(x, y, z, w, n): 705*da0073e9SAndroid Build Coastguard Worker x.add_(y[0] + w) 706*da0073e9SAndroid Build Coastguard Worker z.add_(y[1] + n) 707*da0073e9SAndroid Build Coastguard Worker 708*da0073e9SAndroid Build Coastguard Worker def f(x, y, z, n): 709*da0073e9SAndroid Build Coastguard Worker torch.ops.mylib.foo(x, y, z, 2, n) 710*da0073e9SAndroid Build Coastguard Worker 711*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 712*da0073e9SAndroid Build Coastguard Worker y = (torch.randn(3), torch.randn(3)) 713*da0073e9SAndroid Build Coastguard Worker z = torch.randn(3) 714*da0073e9SAndroid Build Coastguard Worker n = torch.randn(3) 715*da0073e9SAndroid Build Coastguard Worker orig_args = (x, y, z, n) 716*da0073e9SAndroid Build Coastguard Worker 717*da0073e9SAndroid Build Coastguard Worker compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 718*da0073e9SAndroid Build Coastguard 
Worker 719*da0073e9SAndroid Build Coastguard Worker log_stream, ctx = logs_to_string( 720*da0073e9SAndroid Build Coastguard Worker "torch._inductor.compile_fx", "post_grad_graphs" 721*da0073e9SAndroid Build Coastguard Worker ) 722*da0073e9SAndroid Build Coastguard Worker with ctx(): 723*da0073e9SAndroid Build Coastguard Worker torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) 724*da0073e9SAndroid Build Coastguard Worker 725*da0073e9SAndroid Build Coastguard Worker post_grad_graphs = "\n".join( 726*da0073e9SAndroid Build Coastguard Worker log_stream.getvalue().strip().split("\n")[3:] 727*da0073e9SAndroid Build Coastguard Worker ).strip() 728*da0073e9SAndroid Build Coastguard Worker 729*da0073e9SAndroid Build Coastguard Worker # Check the graph under static shapes 730*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 731*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 732*da0073e9SAndroid Build Coastguard Worker post_grad_graphs, 733*da0073e9SAndroid Build Coastguard Worker """\ 734*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): 735*da0073e9SAndroid Build Coastguard Worker # No stacktrace found for following nodes 736*da0073e9SAndroid Build Coastguard Worker foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None 737*da0073e9SAndroid Build Coastguard Worker return ()""", 738*da0073e9SAndroid Build Coastguard Worker ) 739*da0073e9SAndroid Build Coastguard Worker 740*da0073e9SAndroid Build Coastguard Worker eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 741*da0073e9SAndroid Build Coastguard Worker f(*eager_args) 742*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_args, eager_args) 743*da0073e9SAndroid Build Coastguard Worker finally: 
744*da0073e9SAndroid Build Coastguard Worker cleanup_op("mylib::foo") 745*da0073e9SAndroid Build Coastguard Worker del lib 746*da0073e9SAndroid Build Coastguard Worker 747*da0073e9SAndroid Build Coastguard Worker def test_auto_functionalize_with_returns(self): 748*da0073e9SAndroid Build Coastguard Worker try: 749*da0073e9SAndroid Build Coastguard Worker lib = torch.library.Library("mylib", "FRAGMENT") 750*da0073e9SAndroid Build Coastguard Worker torch.library.define( 751*da0073e9SAndroid Build Coastguard Worker "mylib::foo", 752*da0073e9SAndroid Build Coastguard Worker "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)", 753*da0073e9SAndroid Build Coastguard Worker tags=torch.Tag.pt2_compliant_tag, 754*da0073e9SAndroid Build Coastguard Worker lib=lib, 755*da0073e9SAndroid Build Coastguard Worker ) 756*da0073e9SAndroid Build Coastguard Worker 757*da0073e9SAndroid Build Coastguard Worker @torch.library.impl("mylib::foo", "cpu", lib=lib) 758*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.disable 759*da0073e9SAndroid Build Coastguard Worker def foo_impl(x, y, z, w, n): 760*da0073e9SAndroid Build Coastguard Worker x.add_(y[0] + w) 761*da0073e9SAndroid Build Coastguard Worker z.add_(y[1] + n) 762*da0073e9SAndroid Build Coastguard Worker return y[0] + w, y[1] + n 763*da0073e9SAndroid Build Coastguard Worker 764*da0073e9SAndroid Build Coastguard Worker @torch.library.impl_abstract("mylib::foo", lib=lib) 765*da0073e9SAndroid Build Coastguard Worker def foo_abstract(x, y, z, w, n): 766*da0073e9SAndroid Build Coastguard Worker return y[0] + w, y[1] + n 767*da0073e9SAndroid Build Coastguard Worker 768*da0073e9SAndroid Build Coastguard Worker def f(x, y, z, n): 769*da0073e9SAndroid Build Coastguard Worker return torch.ops.mylib.foo(x, y, z, 2, n) 770*da0073e9SAndroid Build Coastguard Worker 771*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 772*da0073e9SAndroid Build Coastguard Worker y = (torch.randn(3), torch.randn(3)) 
773*da0073e9SAndroid Build Coastguard Worker z = torch.randn(3) 774*da0073e9SAndroid Build Coastguard Worker n = torch.randn(3) 775*da0073e9SAndroid Build Coastguard Worker orig_args = (x, y, z, n) 776*da0073e9SAndroid Build Coastguard Worker 777*da0073e9SAndroid Build Coastguard Worker compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 778*da0073e9SAndroid Build Coastguard Worker log_stream, ctx = logs_to_string( 779*da0073e9SAndroid Build Coastguard Worker "torch._inductor.compile_fx", "post_grad_graphs" 780*da0073e9SAndroid Build Coastguard Worker ) 781*da0073e9SAndroid Build Coastguard Worker with ctx(): 782*da0073e9SAndroid Build Coastguard Worker compiled_out = torch.compile(f, backend="inductor", fullgraph=True)( 783*da0073e9SAndroid Build Coastguard Worker *compiled_args 784*da0073e9SAndroid Build Coastguard Worker ) 785*da0073e9SAndroid Build Coastguard Worker 786*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 787*da0073e9SAndroid Build Coastguard Worker post_grad_graphs = "\n".join( 788*da0073e9SAndroid Build Coastguard Worker log_stream.getvalue().strip().split("\n")[3:] 789*da0073e9SAndroid Build Coastguard Worker ).strip() 790*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 791*da0073e9SAndroid Build Coastguard Worker post_grad_graphs, 792*da0073e9SAndroid Build Coastguard Worker """\ 793*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): 794*da0073e9SAndroid Build Coastguard Worker # No stacktrace found for following nodes 795*da0073e9SAndroid Build Coastguard Worker foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None 796*da0073e9SAndroid Build Coastguard Worker getitem_4: "f32[3][1]cpu" = foo_default[0] 797*da0073e9SAndroid Build Coastguard Worker 
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None 798*da0073e9SAndroid Build Coastguard Worker return (getitem_4, getitem_5)""", 799*da0073e9SAndroid Build Coastguard Worker ) 800*da0073e9SAndroid Build Coastguard Worker 801*da0073e9SAndroid Build Coastguard Worker eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 802*da0073e9SAndroid Build Coastguard Worker eager_out = f(*eager_args) 803*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_args, eager_args) 804*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_out, eager_out) 805*da0073e9SAndroid Build Coastguard Worker finally: 806*da0073e9SAndroid Build Coastguard Worker cleanup_op("mylib::foo") 807*da0073e9SAndroid Build Coastguard Worker del lib 808*da0073e9SAndroid Build Coastguard Worker 809*da0073e9SAndroid Build Coastguard Worker def test_auto_functionalize_on_view(self): 810*da0073e9SAndroid Build Coastguard Worker try: 811*da0073e9SAndroid Build Coastguard Worker lib = torch.library.Library("mylib", "FRAGMENT") 812*da0073e9SAndroid Build Coastguard Worker torch.library.define( 813*da0073e9SAndroid Build Coastguard Worker "mylib::foo", 814*da0073e9SAndroid Build Coastguard Worker "(Tensor(a!) 
x) -> ()", 815*da0073e9SAndroid Build Coastguard Worker tags=torch.Tag.pt2_compliant_tag, 816*da0073e9SAndroid Build Coastguard Worker lib=lib, 817*da0073e9SAndroid Build Coastguard Worker ) 818*da0073e9SAndroid Build Coastguard Worker 819*da0073e9SAndroid Build Coastguard Worker @torch.library.impl("mylib::foo", "cpu", lib=lib) 820*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.disable 821*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 822*da0073e9SAndroid Build Coastguard Worker x_np = x.detach().numpy() # view 823*da0073e9SAndroid Build Coastguard Worker np.sin(x_np, out=x_np) 824*da0073e9SAndroid Build Coastguard Worker return 825*da0073e9SAndroid Build Coastguard Worker 826*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 827*da0073e9SAndroid Build Coastguard Worker expected = x.sin() 828*da0073e9SAndroid Build Coastguard Worker torch.ops.mylib.foo(x) 829*da0073e9SAndroid Build Coastguard Worker assert torch.allclose(x, expected) 830*da0073e9SAndroid Build Coastguard Worker 831*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True) 832*da0073e9SAndroid Build Coastguard Worker def f(x): 833*da0073e9SAndroid Build Coastguard Worker x = x.clone() 834*da0073e9SAndroid Build Coastguard Worker y = x[:] 835*da0073e9SAndroid Build Coastguard Worker torch.ops.mylib.foo(y) 836*da0073e9SAndroid Build Coastguard Worker return x 837*da0073e9SAndroid Build Coastguard Worker 838*da0073e9SAndroid Build Coastguard Worker y = f(x) 839*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, x.sin()) 840*da0073e9SAndroid Build Coastguard Worker finally: 841*da0073e9SAndroid Build Coastguard Worker cleanup_op("mylib::foo") 842*da0073e9SAndroid Build Coastguard Worker del lib 843*da0073e9SAndroid Build Coastguard Worker 844*da0073e9SAndroid Build Coastguard Worker def test_auto_functionalize_optional(self): 845*da0073e9SAndroid Build Coastguard Worker try: 846*da0073e9SAndroid Build 
Coastguard Worker lib = torch.library.Library("mylib", "FRAGMENT") 847*da0073e9SAndroid Build Coastguard Worker torch.library.define( 848*da0073e9SAndroid Build Coastguard Worker "mylib::foo", 849*da0073e9SAndroid Build Coastguard Worker "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()", 850*da0073e9SAndroid Build Coastguard Worker tags=torch.Tag.pt2_compliant_tag, 851*da0073e9SAndroid Build Coastguard Worker lib=lib, 852*da0073e9SAndroid Build Coastguard Worker ) 853*da0073e9SAndroid Build Coastguard Worker 854*da0073e9SAndroid Build Coastguard Worker @torch.library.impl("mylib::foo", "cpu", lib=lib) 855*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.disable 856*da0073e9SAndroid Build Coastguard Worker def foo_impl(x, y, z, w, n): 857*da0073e9SAndroid Build Coastguard Worker if x is not None: 858*da0073e9SAndroid Build Coastguard Worker x.add_(y[0] + w) 859*da0073e9SAndroid Build Coastguard Worker if z is not None: 860*da0073e9SAndroid Build Coastguard Worker z.add_(y[1] + n) 861*da0073e9SAndroid Build Coastguard Worker 862*da0073e9SAndroid Build Coastguard Worker def f(x, y, z, n): 863*da0073e9SAndroid Build Coastguard Worker torch.ops.mylib.foo(x, y, z, 2, n) 864*da0073e9SAndroid Build Coastguard Worker 865*da0073e9SAndroid Build Coastguard Worker x = None 866*da0073e9SAndroid Build Coastguard Worker y = (torch.randn(3), torch.randn(3)) 867*da0073e9SAndroid Build Coastguard Worker z = torch.randn(3) 868*da0073e9SAndroid Build Coastguard Worker n = torch.randn(3) 869*da0073e9SAndroid Build Coastguard Worker orig_args = (x, y, z, n) 870*da0073e9SAndroid Build Coastguard Worker 871*da0073e9SAndroid Build Coastguard Worker compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 872*da0073e9SAndroid Build Coastguard Worker log_stream, ctx = logs_to_string( 873*da0073e9SAndroid Build Coastguard Worker "torch._inductor.compile_fx", "post_grad_graphs" 874*da0073e9SAndroid Build Coastguard Worker ) 875*da0073e9SAndroid 
Build Coastguard Worker with ctx(): 876*da0073e9SAndroid Build Coastguard Worker torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) 877*da0073e9SAndroid Build Coastguard Worker 878*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 879*da0073e9SAndroid Build Coastguard Worker post_grad_graphs = "\n".join( 880*da0073e9SAndroid Build Coastguard Worker log_stream.getvalue().strip().split("\n")[3:] 881*da0073e9SAndroid Build Coastguard Worker ).strip() 882*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 883*da0073e9SAndroid Build Coastguard Worker post_grad_graphs, 884*da0073e9SAndroid Build Coastguard Worker """\ 885*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): 886*da0073e9SAndroid Build Coastguard Worker # No stacktrace found for following nodes 887*da0073e9SAndroid Build Coastguard Worker foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg1_1 = arg0_1 = None 888*da0073e9SAndroid Build Coastguard Worker return ()""", 889*da0073e9SAndroid Build Coastguard Worker ) 890*da0073e9SAndroid Build Coastguard Worker 891*da0073e9SAndroid Build Coastguard Worker eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 892*da0073e9SAndroid Build Coastguard Worker f(*eager_args) 893*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_args, eager_args) 894*da0073e9SAndroid Build Coastguard Worker finally: 895*da0073e9SAndroid Build Coastguard Worker cleanup_op("mylib::foo") 896*da0073e9SAndroid Build Coastguard Worker del lib 897*da0073e9SAndroid Build Coastguard Worker 898*da0073e9SAndroid Build Coastguard Worker def test_shape_int_inplace_binops(self): 899*da0073e9SAndroid Build Coastguard Worker def fn(x): 900*da0073e9SAndroid Build Coastguard Worker p = x.shape[0] 901*da0073e9SAndroid Build Coastguard Worker 
p += 2 902*da0073e9SAndroid Build Coastguard Worker p -= 2 903*da0073e9SAndroid Build Coastguard Worker p **= 2 904*da0073e9SAndroid Build Coastguard Worker p /= 2 905*da0073e9SAndroid Build Coastguard Worker p *= 2 906*da0073e9SAndroid Build Coastguard Worker p //= 2 907*da0073e9SAndroid Build Coastguard Worker p %= 2 908*da0073e9SAndroid Build Coastguard Worker return x + p 909*da0073e9SAndroid Build Coastguard Worker 910*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 911*da0073e9SAndroid Build Coastguard Worker self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 10) 912*da0073e9SAndroid Build Coastguard Worker ) 913*da0073e9SAndroid Build Coastguard Worker 914*da0073e9SAndroid Build Coastguard Worker def test_int_shape_inplace_binops(self): 915*da0073e9SAndroid Build Coastguard Worker def fn(x): 916*da0073e9SAndroid Build Coastguard Worker p = x.shape[0] 917*da0073e9SAndroid Build Coastguard Worker # Test reversal by putting constant first 918*da0073e9SAndroid Build Coastguard Worker y = 2 919*da0073e9SAndroid Build Coastguard Worker y += p 920*da0073e9SAndroid Build Coastguard Worker y = 2 921*da0073e9SAndroid Build Coastguard Worker y -= p 922*da0073e9SAndroid Build Coastguard Worker y = 2 923*da0073e9SAndroid Build Coastguard Worker y **= p 924*da0073e9SAndroid Build Coastguard Worker y = 2 925*da0073e9SAndroid Build Coastguard Worker y /= p 926*da0073e9SAndroid Build Coastguard Worker y = 2 927*da0073e9SAndroid Build Coastguard Worker y *= p 928*da0073e9SAndroid Build Coastguard Worker y = 2 929*da0073e9SAndroid Build Coastguard Worker y //= p 930*da0073e9SAndroid Build Coastguard Worker y = 2 931*da0073e9SAndroid Build Coastguard Worker y %= p 932*da0073e9SAndroid Build Coastguard Worker return x + y 933*da0073e9SAndroid Build Coastguard Worker 934*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 935*da0073e9SAndroid Build Coastguard Worker self, fn, 1, expected_ops=1, 
expected_ops_dynamic=ifdynstaticdefault(1, 4) 936*da0073e9SAndroid Build Coastguard Worker ) 937*da0073e9SAndroid Build Coastguard Worker 938*da0073e9SAndroid Build Coastguard Worker def test_int_int_comparisons(self): 939*da0073e9SAndroid Build Coastguard Worker def fn(x): 940*da0073e9SAndroid Build Coastguard Worker if 2 != 2: 941*da0073e9SAndroid Build Coastguard Worker out = 1 942*da0073e9SAndroid Build Coastguard Worker elif 2 < 1: 943*da0073e9SAndroid Build Coastguard Worker out = 1 944*da0073e9SAndroid Build Coastguard Worker elif 1 > 2: 945*da0073e9SAndroid Build Coastguard Worker out = 1 946*da0073e9SAndroid Build Coastguard Worker elif 1 >= 2: 947*da0073e9SAndroid Build Coastguard Worker out = 1 948*da0073e9SAndroid Build Coastguard Worker elif 2 <= 1: 949*da0073e9SAndroid Build Coastguard Worker out = 1 950*da0073e9SAndroid Build Coastguard Worker elif 2 == 2: 951*da0073e9SAndroid Build Coastguard Worker out = 2 952*da0073e9SAndroid Build Coastguard Worker else: 953*da0073e9SAndroid Build Coastguard Worker out = 1 954*da0073e9SAndroid Build Coastguard Worker return x + out 955*da0073e9SAndroid Build Coastguard Worker 956*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1) 957*da0073e9SAndroid Build Coastguard Worker 958*da0073e9SAndroid Build Coastguard Worker def test_shape_int_comparisons(self): 959*da0073e9SAndroid Build Coastguard Worker def fn(x): 960*da0073e9SAndroid Build Coastguard Worker a = x.shape[0] 961*da0073e9SAndroid Build Coastguard Worker # Ensure support for constant on right side 962*da0073e9SAndroid Build Coastguard Worker if a != 10: 963*da0073e9SAndroid Build Coastguard Worker out = 1 964*da0073e9SAndroid Build Coastguard Worker elif a < 2: 965*da0073e9SAndroid Build Coastguard Worker out = 1 966*da0073e9SAndroid Build Coastguard Worker elif a > 12: 967*da0073e9SAndroid Build Coastguard Worker out = 1 968*da0073e9SAndroid Build Coastguard Worker elif a >= 12: 969*da0073e9SAndroid 
Build Coastguard Worker out = 1 970*da0073e9SAndroid Build Coastguard Worker elif a <= 2: 971*da0073e9SAndroid Build Coastguard Worker out = 1 972*da0073e9SAndroid Build Coastguard Worker elif a == 10: 973*da0073e9SAndroid Build Coastguard Worker out = 2 974*da0073e9SAndroid Build Coastguard Worker else: 975*da0073e9SAndroid Build Coastguard Worker out = 1 976*da0073e9SAndroid Build Coastguard Worker return x + out 977*da0073e9SAndroid Build Coastguard Worker 978*da0073e9SAndroid Build Coastguard Worker # TODO: Test the guards maybe? 979*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1) 980*da0073e9SAndroid Build Coastguard Worker 981*da0073e9SAndroid Build Coastguard Worker def test_int_shape_comparisons(self): 982*da0073e9SAndroid Build Coastguard Worker def fn(x): 983*da0073e9SAndroid Build Coastguard Worker a = x.shape[0] 984*da0073e9SAndroid Build Coastguard Worker # Ensure support for constant on left side 985*da0073e9SAndroid Build Coastguard Worker if 10 != a: 986*da0073e9SAndroid Build Coastguard Worker out = 1 987*da0073e9SAndroid Build Coastguard Worker elif 12 < a: 988*da0073e9SAndroid Build Coastguard Worker out = 1 989*da0073e9SAndroid Build Coastguard Worker elif 2 > a: 990*da0073e9SAndroid Build Coastguard Worker out = 1 991*da0073e9SAndroid Build Coastguard Worker elif 2 >= a: 992*da0073e9SAndroid Build Coastguard Worker out = 1 993*da0073e9SAndroid Build Coastguard Worker elif 12 <= a: 994*da0073e9SAndroid Build Coastguard Worker out = 1 995*da0073e9SAndroid Build Coastguard Worker elif 10 == a: 996*da0073e9SAndroid Build Coastguard Worker out = 2 997*da0073e9SAndroid Build Coastguard Worker else: 998*da0073e9SAndroid Build Coastguard Worker out = 1 999*da0073e9SAndroid Build Coastguard Worker return x + out 1000*da0073e9SAndroid Build Coastguard Worker 1001*da0073e9SAndroid Build Coastguard Worker # TODO: Test the guards maybe? 
1002*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1) 1003*da0073e9SAndroid Build Coastguard Worker 1004*da0073e9SAndroid Build Coastguard Worker def test_param_shape_binops(self): 1005*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 1006*da0073e9SAndroid Build Coastguard Worker def __init__(self): 1007*da0073e9SAndroid Build Coastguard Worker super().__init__() 1008*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.randn(15)) 1009*da0073e9SAndroid Build Coastguard Worker 1010*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1011*da0073e9SAndroid Build Coastguard Worker # Test reversal by putting param shape arg first. 1012*da0073e9SAndroid Build Coastguard Worker p = self.param.shape[0] 1013*da0073e9SAndroid Build Coastguard Worker y = p - x.shape[0] 1014*da0073e9SAndroid Build Coastguard Worker y = p + y 1015*da0073e9SAndroid Build Coastguard Worker y = p * y 1016*da0073e9SAndroid Build Coastguard Worker y = p % y 1017*da0073e9SAndroid Build Coastguard Worker y = p**y 1018*da0073e9SAndroid Build Coastguard Worker y = p // y 1019*da0073e9SAndroid Build Coastguard Worker y = pow(p, y) 1020*da0073e9SAndroid Build Coastguard Worker y = p / y 1021*da0073e9SAndroid Build Coastguard Worker return x + y 1022*da0073e9SAndroid Build Coastguard Worker 1023*da0073e9SAndroid Build Coastguard Worker counts = torch._dynamo.testing.CompileCounter() 1024*da0073e9SAndroid Build Coastguard Worker mod = MyModule() 1025*da0073e9SAndroid Build Coastguard Worker optimized_mod = torch._dynamo.optimize(counts, nopython=True)(mod) 1026*da0073e9SAndroid Build Coastguard Worker 1027*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1028*da0073e9SAndroid Build Coastguard Worker ref = mod(x) 1029*da0073e9SAndroid Build Coastguard Worker res = optimized_mod(x) 1030*da0073e9SAndroid Build Coastguard Worker 1031*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(same(ref, res)) 1032*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counts.frame_count, 1) 1033*da0073e9SAndroid Build Coastguard Worker 1034*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 1035*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(counts.op_count, """1""") 1036*da0073e9SAndroid Build Coastguard Worker else: 1037*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(counts.op_count, """11""") 1038*da0073e9SAndroid Build Coastguard Worker 1039*da0073e9SAndroid Build Coastguard Worker def test_user_defined_binop(self): 1040*da0073e9SAndroid Build Coastguard Worker class MyClass: 1041*da0073e9SAndroid Build Coastguard Worker def __init__(self, value): 1042*da0073e9SAndroid Build Coastguard Worker self.value = value 1043*da0073e9SAndroid Build Coastguard Worker 1044*da0073e9SAndroid Build Coastguard Worker def __radd__(self, other): 1045*da0073e9SAndroid Build Coastguard Worker return self.value + other 1046*da0073e9SAndroid Build Coastguard Worker 1047*da0073e9SAndroid Build Coastguard Worker def fn(x, c): 1048*da0073e9SAndroid Build Coastguard Worker y = x.shape[0] + c 1049*da0073e9SAndroid Build Coastguard Worker return x + y 1050*da0073e9SAndroid Build Coastguard Worker 1051*da0073e9SAndroid Build Coastguard Worker counts = torch._dynamo.testing.CompileCounter() 1052*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(counts)(fn) 1053*da0073e9SAndroid Build Coastguard Worker 1054*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1055*da0073e9SAndroid Build Coastguard Worker c = MyClass(4) 1056*da0073e9SAndroid Build Coastguard Worker ref = fn(x, c) 1057*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, c) 1058*da0073e9SAndroid Build Coastguard Worker 1059*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 1060*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counts.frame_count, 1) 
1061*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 1062*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(counts.op_count, """1""") 1063*da0073e9SAndroid Build Coastguard Worker else: 1064*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(counts.op_count, """4""") 1065*da0073e9SAndroid Build Coastguard Worker 1066*da0073e9SAndroid Build Coastguard Worker def test_user_defined_iter(self): 1067*da0073e9SAndroid Build Coastguard Worker class Mod: 1068*da0073e9SAndroid Build Coastguard Worker def __init__(self): 1069*da0073e9SAndroid Build Coastguard Worker self.a = [torch.randn(2, 2), torch.randn(2, 2)] 1070*da0073e9SAndroid Build Coastguard Worker 1071*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 1072*da0073e9SAndroid Build Coastguard Worker return iter(self.a) 1073*da0073e9SAndroid Build Coastguard Worker 1074*da0073e9SAndroid Build Coastguard Worker def f(mod): 1075*da0073e9SAndroid Build Coastguard Worker ret = [] 1076*da0073e9SAndroid Build Coastguard Worker for x in mod: 1077*da0073e9SAndroid Build Coastguard Worker ret.append(x + 1) 1078*da0073e9SAndroid Build Coastguard Worker return ret 1079*da0073e9SAndroid Build Coastguard Worker 1080*da0073e9SAndroid Build Coastguard Worker mod = Mod() 1081*da0073e9SAndroid Build Coastguard Worker counts = torch._dynamo.testing.CompileCounter() 1082*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(counts, nopython=True)(f) 1083*da0073e9SAndroid Build Coastguard Worker ref = f(mod) 1084*da0073e9SAndroid Build Coastguard Worker res = opt_fn(mod) 1085*da0073e9SAndroid Build Coastguard Worker res = opt_fn(mod) 1086*da0073e9SAndroid Build Coastguard Worker res = opt_fn(mod) 1087*da0073e9SAndroid Build Coastguard Worker res = opt_fn(mod) 1088*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 1089*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counts.frame_count, 1) 
1090*da0073e9SAndroid Build Coastguard Worker 1091*da0073e9SAndroid Build Coastguard Worker mod.a.append(torch.randn(2, 2)) 1092*da0073e9SAndroid Build Coastguard Worker # `for x in mod` is inlined, where iter(m.a) creates a guard on the list length of m.a 1093*da0073e9SAndroid Build Coastguard Worker # Mutating length of mod.a causes a re-compilation. 1094*da0073e9SAndroid Build Coastguard Worker ref2 = f(mod) 1095*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(mod) 1096*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(mod) 1097*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(mod) 1098*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(mod) 1099*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref2, res2)) 1100*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counts.frame_count, 2) 1101*da0073e9SAndroid Build Coastguard Worker 1102*da0073e9SAndroid Build Coastguard Worker def test_compare_shapes_eq(self): 1103*da0073e9SAndroid Build Coastguard Worker def compare_shapes(a, b, to_list): 1104*da0073e9SAndroid Build Coastguard Worker x = list(a.unsqueeze(-1).shape) if to_list else a.shape 1105*da0073e9SAndroid Build Coastguard Worker y = list(b.unsqueeze(-1).shape) if to_list else b.shape 1106*da0073e9SAndroid Build Coastguard Worker if x == y: 1107*da0073e9SAndroid Build Coastguard Worker return a + 1 1108*da0073e9SAndroid Build Coastguard Worker else: 1109*da0073e9SAndroid Build Coastguard Worker return a + 2 1110*da0073e9SAndroid Build Coastguard Worker 1111*da0073e9SAndroid Build Coastguard Worker # Test both ListVariable and ShapeVariable 1112*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 1113*da0073e9SAndroid Build Coastguard Worker self, lambda a, b: compare_shapes(a, b, to_list=True), 2 1114*da0073e9SAndroid Build Coastguard Worker ) 1115*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 1116*da0073e9SAndroid Build Coastguard Worker self, lambda a, b: 
compare_shapes(a, b, to_list=False), 2 1117*da0073e9SAndroid Build Coastguard Worker ) 1118*da0073e9SAndroid Build Coastguard Worker 1119*da0073e9SAndroid Build Coastguard Worker def test_compare_shapes_tuple_eq(self): 1120*da0073e9SAndroid Build Coastguard Worker def compare_shapes(a, b): 1121*da0073e9SAndroid Build Coastguard Worker x = tuple(a.unsqueeze(-1).shape) 1122*da0073e9SAndroid Build Coastguard Worker y = tuple(b.unsqueeze(-1).shape) 1123*da0073e9SAndroid Build Coastguard Worker if x == y: 1124*da0073e9SAndroid Build Coastguard Worker return a + 1 1125*da0073e9SAndroid Build Coastguard Worker else: 1126*da0073e9SAndroid Build Coastguard Worker return a + 2 1127*da0073e9SAndroid Build Coastguard Worker 1128*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test(self, lambda a, b: compare_shapes(a, b), 2) 1129*da0073e9SAndroid Build Coastguard Worker 1130*da0073e9SAndroid Build Coastguard Worker def test_compare_shapes_tuple_neq(self): 1131*da0073e9SAndroid Build Coastguard Worker def compare_shapes(a, b): 1132*da0073e9SAndroid Build Coastguard Worker x = tuple(a.unsqueeze(-1).shape) 1133*da0073e9SAndroid Build Coastguard Worker y = tuple(b.unsqueeze(-1).shape) 1134*da0073e9SAndroid Build Coastguard Worker if x != y: 1135*da0073e9SAndroid Build Coastguard Worker return a + 1 1136*da0073e9SAndroid Build Coastguard Worker else: 1137*da0073e9SAndroid Build Coastguard Worker return a + 2 1138*da0073e9SAndroid Build Coastguard Worker 1139*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test(self, lambda a, b: compare_shapes(a, b), 2) 1140*da0073e9SAndroid Build Coastguard Worker 1141*da0073e9SAndroid Build Coastguard Worker def test_compare_shapes_neq(self): 1142*da0073e9SAndroid Build Coastguard Worker def compare_shapes(a, b, to_list): 1143*da0073e9SAndroid Build Coastguard Worker x = list(a.unsqueeze(-1).shape) if to_list else a.shape 1144*da0073e9SAndroid Build Coastguard Worker y = list(b.unsqueeze(-1).shape) 
if to_list else b.shape 1145*da0073e9SAndroid Build Coastguard Worker if x != y: 1146*da0073e9SAndroid Build Coastguard Worker return a + 1 1147*da0073e9SAndroid Build Coastguard Worker else: 1148*da0073e9SAndroid Build Coastguard Worker return a + 2 1149*da0073e9SAndroid Build Coastguard Worker 1150*da0073e9SAndroid Build Coastguard Worker # Test both ListVariable and ShapeVariable 1151*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 1152*da0073e9SAndroid Build Coastguard Worker self, lambda a, b: compare_shapes(a, b, to_list=True), 2 1153*da0073e9SAndroid Build Coastguard Worker ) 1154*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 1155*da0073e9SAndroid Build Coastguard Worker self, lambda a, b: compare_shapes(a, b, to_list=False), 2 1156*da0073e9SAndroid Build Coastguard Worker ) 1157*da0073e9SAndroid Build Coastguard Worker 1158*da0073e9SAndroid Build Coastguard Worker def test_compare_shapes_with_constant(self): 1159*da0073e9SAndroid Build Coastguard Worker def compare_shapes(a): 1160*da0073e9SAndroid Build Coastguard Worker x = a.shape 1161*da0073e9SAndroid Build Coastguard Worker if x[0] != 3: 1162*da0073e9SAndroid Build Coastguard Worker return a * 4 1163*da0073e9SAndroid Build Coastguard Worker return a * 3 1164*da0073e9SAndroid Build Coastguard Worker 1165*da0073e9SAndroid Build Coastguard Worker guard_failure = None 1166*da0073e9SAndroid Build Coastguard Worker 1167*da0073e9SAndroid Build Coastguard Worker def guard_failures(failure): 1168*da0073e9SAndroid Build Coastguard Worker nonlocal guard_failure 1169*da0073e9SAndroid Build Coastguard Worker guard_failure = failure 1170*da0073e9SAndroid Build Coastguard Worker 1171*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize( 1172*da0073e9SAndroid Build Coastguard Worker "eager", nopython=True, guard_fail_fn=guard_failures 1173*da0073e9SAndroid Build Coastguard Worker )(compare_shapes) 1174*da0073e9SAndroid Build Coastguard 
Worker opt_fn(torch.randn([3, 4])) 1175*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.randn([4, 3])) 1176*da0073e9SAndroid Build Coastguard Worker self.assertIn( 1177*da0073e9SAndroid Build Coastguard Worker """tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""", 1178*da0073e9SAndroid Build Coastguard Worker guard_failure.reason, 1179*da0073e9SAndroid Build Coastguard Worker ) 1180*da0073e9SAndroid Build Coastguard Worker 1181*da0073e9SAndroid Build Coastguard Worker def test_builtin_abs(self): 1182*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 1183*da0073e9SAndroid Build Coastguard Worker return abs(x) + abs(y) 1184*da0073e9SAndroid Build Coastguard Worker 1185*da0073e9SAndroid Build Coastguard Worker sample = torch.randn(10, 10) 1186*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 1187*da0073e9SAndroid Build Coastguard Worker 1188*da0073e9SAndroid Build Coastguard Worker for sample in [ 1189*da0073e9SAndroid Build Coastguard Worker (torch.randn(10, 10), torch.randn(10, 10)), 1190*da0073e9SAndroid Build Coastguard Worker (-10, make_tensor(10, dtype=torch.int64, device="cpu")), 1191*da0073e9SAndroid Build Coastguard Worker (-0.1, torch.randn(10)), 1192*da0073e9SAndroid Build Coastguard Worker ]: 1193*da0073e9SAndroid Build Coastguard Worker expect = fn(*sample) 1194*da0073e9SAndroid Build Coastguard Worker actual = opt_fn(*sample) 1195*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expect, actual) 1196*da0073e9SAndroid Build Coastguard Worker 1197*da0073e9SAndroid Build Coastguard Worker def test_builtin_isinstance(self): 1198*da0073e9SAndroid Build Coastguard Worker def fn(x): 1199*da0073e9SAndroid Build Coastguard Worker t = torch.arange(1, 3) 1200*da0073e9SAndroid Build Coastguard Worker a = isinstance(x, torch.Tensor) 1201*da0073e9SAndroid Build Coastguard Worker b = isinstance(t, torch.Tensor) 1202*da0073e9SAndroid Build Coastguard Worker c = isinstance(x, int) 
1203*da0073e9SAndroid Build Coastguard Worker d = isinstance(3, int) 1204*da0073e9SAndroid Build Coastguard Worker e = isinstance([1, 2, 3], list) 1205*da0073e9SAndroid Build Coastguard Worker f = isinstance({"foo": 1, "bar": 2}, dict) 1206*da0073e9SAndroid Build Coastguard Worker res = [a, b, c, d, e, f] 1207*da0073e9SAndroid Build Coastguard Worker # Can't run yet due to other unimplemented instructions 1208*da0073e9SAndroid Build Coastguard Worker # res += [isinstance(torch.nn.LazyLinear(2, 3), torch.nn.Linear)] 1209*da0073e9SAndroid Build Coastguard Worker return res 1210*da0073e9SAndroid Build Coastguard Worker 1211*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1) 1212*da0073e9SAndroid Build Coastguard Worker 1213*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(sys.version_info[:2] <= (3, 8), "Requires astunparse") 1214*da0073e9SAndroid Build Coastguard Worker def test_cse_dict_guards(self): 1215*da0073e9SAndroid Build Coastguard Worker def fn(x): 1216*da0073e9SAndroid Build Coastguard Worker ret = torch.zeros(3) 1217*da0073e9SAndroid Build Coastguard Worker for v in x.values(): 1218*da0073e9SAndroid Build Coastguard Worker ret = ret + v 1219*da0073e9SAndroid Build Coastguard Worker return ret 1220*da0073e9SAndroid Build Coastguard Worker 1221*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.guards import build_guard_function, CLOSURE_VARS 1222*da0073e9SAndroid Build Coastguard Worker 1223*da0073e9SAndroid Build Coastguard Worker x = {3: torch.randn(3), 2: torch.randn(3), 4: torch.randn(3)} 1224*da0073e9SAndroid Build Coastguard Worker _, guards = torch._dynamo.export(fn, x) 1225*da0073e9SAndroid Build Coastguard Worker 1226*da0073e9SAndroid Build Coastguard Worker code_lists = [c for g in guards for c in g.code_list or []] 1227*da0073e9SAndroid Build Coastguard Worker _, pycode = build_guard_function(code_lists, []) 1228*da0073e9SAndroid Build Coastguard Worker # Make sure we just 
call "list(dict.keys())" once 1229*da0073e9SAndroid Build Coastguard Worker self.assertEqual(pycode.count("keys"), 1) 1230*da0073e9SAndroid Build Coastguard Worker 1231*da0073e9SAndroid Build Coastguard Worker def test_sys_modules(self): 1232*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 1233*da0073e9SAndroid Build Coastguard Worker mod_a = sys.modules.get("aaaaaaaa") 1234*da0073e9SAndroid Build Coastguard Worker assert mod_a is None 1235*da0073e9SAndroid Build Coastguard Worker assert "bbbbbbbb" not in sys.modules 1236*da0073e9SAndroid Build Coastguard Worker 1237*da0073e9SAndroid Build Coastguard Worker assert "operator" in sys.modules 1238*da0073e9SAndroid Build Coastguard Worker operator = sys.modules["operator"] 1239*da0073e9SAndroid Build Coastguard Worker builtins = sys.modules.get("builtins") 1240*da0073e9SAndroid Build Coastguard Worker operator2 = sys.modules.get("cccccccc", operator) 1241*da0073e9SAndroid Build Coastguard Worker 1242*da0073e9SAndroid Build Coastguard Worker return operator.add(x, y), operator2.neg(builtins.abs(x)) 1243*da0073e9SAndroid Build Coastguard Worker 1244*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test(self, fn, 2, expected_ops=3) 1245*da0073e9SAndroid Build Coastguard Worker 1246*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10) 1247*da0073e9SAndroid Build Coastguard Worker _, guards = torch._dynamo.export(fn, x, x) 1248*da0073e9SAndroid Build Coastguard Worker guard_code = [] 1249*da0073e9SAndroid Build Coastguard Worker for guard in guards: 1250*da0073e9SAndroid Build Coastguard Worker if guard.code_list: 1251*da0073e9SAndroid Build Coastguard Worker guard_code += guard.code_list 1252*da0073e9SAndroid Build Coastguard Worker 1253*da0073e9SAndroid Build Coastguard Worker # Filter out id-matches that won't reproduce run to run 1254*da0073e9SAndroid Build Coastguard Worker guard_code = filter( 1255*da0073e9SAndroid Build Coastguard Worker lambda line: "id" not in line and 
"lookup_backend" not in line, 1256*da0073e9SAndroid Build Coastguard Worker sorted(guard_code), 1257*da0073e9SAndroid Build Coastguard Worker ) 1258*da0073e9SAndroid Build Coastguard Worker guard_code_str = "\n".join(guard_code) 1259*da0073e9SAndroid Build Coastguard Worker 1260*da0073e9SAndroid Build Coastguard Worker for line in """\ 1261*da0073e9SAndroid Build Coastguard Worker2 <= L['x'].size()[0] 1262*da0073e9SAndroid Build Coastguard WorkerL['x'] is L['y'] 1263*da0073e9SAndroid Build Coastguard WorkerL['x'].ndimension() == 2 1264*da0073e9SAndroid Build Coastguard WorkerL['x'].requires_grad == False 1265*da0073e9SAndroid Build Coastguard WorkerL['x'].size()[1] == L['x'].size()[0] 1266*da0073e9SAndroid Build Coastguard WorkerL['x'].storage_offset() == 0 1267*da0073e9SAndroid Build Coastguard Worker___dict_contains('builtins', G['sys'].modules) 1268*da0073e9SAndroid Build Coastguard Worker___dict_contains('operator', G['sys'].modules) 1269*da0073e9SAndroid Build Coastguard Worker___dict_contains('operator', G['sys'].modules) 1270*da0073e9SAndroid Build Coastguard Workerhasattr(L['x'], '_dynamo_dynamic_indices') == False 1271*da0073e9SAndroid Build Coastguard Workernot ___dict_contains('aaaaaaaa', G['sys'].modules) 1272*da0073e9SAndroid Build Coastguard Workernot ___dict_contains('bbbbbbbb', G['sys'].modules) 1273*da0073e9SAndroid Build Coastguard Workernot ___dict_contains('cccccccc', G['sys'].modules) 1274*da0073e9SAndroid Build Coastguard Workerstr(L['x'].device) == 'cpu' 1275*da0073e9SAndroid Build Coastguard Workerstr(L['x'].dtype) == 'torch.float32' 1276*da0073e9SAndroid Build Coastguard Workerutils_device.CURRENT_DEVICE == None""".split( 1277*da0073e9SAndroid Build Coastguard Worker "\n" 1278*da0073e9SAndroid Build Coastguard Worker ): 1279*da0073e9SAndroid Build Coastguard Worker self.assertIn( 1280*da0073e9SAndroid Build Coastguard Worker line, 1281*da0073e9SAndroid Build Coastguard Worker guard_code_str, 1282*da0073e9SAndroid Build Coastguard Worker ) 
1283*da0073e9SAndroid Build Coastguard Worker 1284*da0073e9SAndroid Build Coastguard Worker def test_fold(self): 1285*da0073e9SAndroid Build Coastguard Worker def fn(a): 1286*da0073e9SAndroid Build Coastguard Worker return a + math.sqrt(63) 1287*da0073e9SAndroid Build Coastguard Worker 1288*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1) 1289*da0073e9SAndroid Build Coastguard Worker 1290*da0073e9SAndroid Build Coastguard Worker def test_getattr_dict(self): 1291*da0073e9SAndroid Build Coastguard Worker def fn(x): 1292*da0073e9SAndroid Build Coastguard Worker from torch.masked.maskedtensor._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE 1293*da0073e9SAndroid Build Coastguard Worker 1294*da0073e9SAndroid Build Coastguard Worker return x * len(_MASKEDTENSOR_FUNCTION_TABLE) 1295*da0073e9SAndroid Build Coastguard Worker 1296*da0073e9SAndroid Build Coastguard Worker i = torch.randn(5) 1297*da0073e9SAndroid Build Coastguard Worker r1 = fn(i) 1298*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1299*da0073e9SAndroid Build Coastguard Worker r2 = opt_fn(i) 1300*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1, r2) 1301*da0073e9SAndroid Build Coastguard Worker 1302*da0073e9SAndroid Build Coastguard Worker def test_shape_unpack(self): 1303*da0073e9SAndroid Build Coastguard Worker def fn(x): 1304*da0073e9SAndroid Build Coastguard Worker a, b = x.size() 1305*da0073e9SAndroid Build Coastguard Worker return x * b 1306*da0073e9SAndroid Build Coastguard Worker 1307*da0073e9SAndroid Build Coastguard Worker i = torch.randn(5, 10) 1308*da0073e9SAndroid Build Coastguard Worker r1 = fn(i) 1309*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 1310*da0073e9SAndroid Build Coastguard Worker r2 = opt_fn(i) 1311*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(r1, r2)) 1312*da0073e9SAndroid Build Coastguard Worker 
1313*da0073e9SAndroid Build Coastguard Worker def test_typing_dict(self): 1314*da0073e9SAndroid Build Coastguard Worker def fn(d): 1315*da0073e9SAndroid Build Coastguard Worker return d[T] 1316*da0073e9SAndroid Build Coastguard Worker 1317*da0073e9SAndroid Build Coastguard Worker d = {T: torch.randn(3)} 1318*da0073e9SAndroid Build Coastguard Worker r1 = fn(d) 1319*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1320*da0073e9SAndroid Build Coastguard Worker r2 = opt_fn(d) 1321*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1, r2) 1322*da0073e9SAndroid Build Coastguard Worker 1323*da0073e9SAndroid Build Coastguard Worker def test_tensor_iter(self): 1324*da0073e9SAndroid Build Coastguard Worker def fn(x): 1325*da0073e9SAndroid Build Coastguard Worker for y in x: 1326*da0073e9SAndroid Build Coastguard Worker y.add_(1.0) 1327*da0073e9SAndroid Build Coastguard Worker return y 1328*da0073e9SAndroid Build Coastguard Worker 1329*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 1330*da0073e9SAndroid Build Coastguard Worker self, 1331*da0073e9SAndroid Build Coastguard Worker fn, 1332*da0073e9SAndroid Build Coastguard Worker 1, 1333*da0073e9SAndroid Build Coastguard Worker expected_ops=20, 1334*da0073e9SAndroid Build Coastguard Worker ) 1335*da0073e9SAndroid Build Coastguard Worker 1336*da0073e9SAndroid Build Coastguard Worker def test_empty_list(self): 1337*da0073e9SAndroid Build Coastguard Worker def fn(x, ll): 1338*da0073e9SAndroid Build Coastguard Worker if len(ll) == 0 and not ll and ll is not None: 1339*da0073e9SAndroid Build Coastguard Worker return x + 1 1340*da0073e9SAndroid Build Coastguard Worker 1341*da0073e9SAndroid Build Coastguard Worker i = torch.randn(5, 10) 1342*da0073e9SAndroid Build Coastguard Worker r1 = fn(i, []) 1343*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 1344*da0073e9SAndroid Build Coastguard Worker r2 = opt_fn(i, 
[]) 1345*da0073e9SAndroid Build Coastguard Worker r3 = opt_fn(i, tuple()) 1346*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(r1, r2)) 1347*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(r1, r3)) 1348*da0073e9SAndroid Build Coastguard Worker 1349*da0073e9SAndroid Build Coastguard Worker def test_min_max_over_iterable(self): 1350*da0073e9SAndroid Build Coastguard Worker def get_test_fn(func): 1351*da0073e9SAndroid Build Coastguard Worker def _fn(a, b, func=func): 1352*da0073e9SAndroid Build Coastguard Worker # try all of list, iterator, tuple, vararg. 1353*da0073e9SAndroid Build Coastguard Worker lst = [a.shape[0] + 1, 8, a.shape[0]] 1354*da0073e9SAndroid Build Coastguard Worker x = func(lst) 1355*da0073e9SAndroid Build Coastguard Worker y = func(iter(lst)) 1356*da0073e9SAndroid Build Coastguard Worker z = func(tuple(lst)) 1357*da0073e9SAndroid Build Coastguard Worker w = func(*lst) 1358*da0073e9SAndroid Build Coastguard Worker return a + (x + y + z + w) 1359*da0073e9SAndroid Build Coastguard Worker 1360*da0073e9SAndroid Build Coastguard Worker return _fn 1361*da0073e9SAndroid Build Coastguard Worker 1362*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 1363*da0073e9SAndroid Build Coastguard Worker self, 1364*da0073e9SAndroid Build Coastguard Worker get_test_fn(func=min), 1365*da0073e9SAndroid Build Coastguard Worker 2, 1366*da0073e9SAndroid Build Coastguard Worker expected_ops=1, 1367*da0073e9SAndroid Build Coastguard Worker expected_ops_dynamic=ifdynstaticdefault(1, 14), 1368*da0073e9SAndroid Build Coastguard Worker ) 1369*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 1370*da0073e9SAndroid Build Coastguard Worker self, 1371*da0073e9SAndroid Build Coastguard Worker get_test_fn(func=max), 1372*da0073e9SAndroid Build Coastguard Worker 2, 1373*da0073e9SAndroid Build Coastguard Worker expected_ops=1, 1374*da0073e9SAndroid Build Coastguard Worker 
expected_ops_dynamic=ifdynstaticdefault(1, 17), 1375*da0073e9SAndroid Build Coastguard Worker ) 1376*da0073e9SAndroid Build Coastguard Worker 1377*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(capture_scalar_outputs=True) 1378*da0073e9SAndroid Build Coastguard Worker def test_torch_check(self): 1379*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1380*da0073e9SAndroid Build Coastguard Worker 1381*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnts, fullgraph=True) 1382*da0073e9SAndroid Build Coastguard Worker def f(x): 1383*da0073e9SAndroid Build Coastguard Worker y = x.item() 1384*da0073e9SAndroid Build Coastguard Worker torch._check(y >= 0) 1385*da0073e9SAndroid Build Coastguard Worker return torch.arange(0, y) 1386*da0073e9SAndroid Build Coastguard Worker 1387*da0073e9SAndroid Build Coastguard Worker f(torch.tensor([3])) 1388*da0073e9SAndroid Build Coastguard Worker f(torch.tensor([4])) 1389*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 1390*da0073e9SAndroid Build Coastguard Worker 1391*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(capture_scalar_outputs=True) 1392*da0073e9SAndroid Build Coastguard Worker def test_torch_check_symbolic_shape_rel(self): 1393*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1394*da0073e9SAndroid Build Coastguard Worker 1395*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnts, fullgraph=True) 1396*da0073e9SAndroid Build Coastguard Worker def f(x): 1397*da0073e9SAndroid Build Coastguard Worker y = x.item() 1398*da0073e9SAndroid Build Coastguard Worker torch._check(x.shape[0] == 1) 1399*da0073e9SAndroid Build Coastguard Worker torch._check(x.shape[0] != 2) 1400*da0073e9SAndroid Build Coastguard Worker torch._check(x.shape[0] >= 0) 1401*da0073e9SAndroid Build Coastguard Worker torch._check(x.shape[0] > 0) 1402*da0073e9SAndroid Build Coastguard 
Worker torch._check(x.shape[0] < 4) 1403*da0073e9SAndroid Build Coastguard Worker torch._check(x.shape[0] <= 3) 1404*da0073e9SAndroid Build Coastguard Worker return torch.arange(0, y) 1405*da0073e9SAndroid Build Coastguard Worker 1406*da0073e9SAndroid Build Coastguard Worker f(torch.tensor([3])) 1407*da0073e9SAndroid Build Coastguard Worker f(torch.tensor([4])) 1408*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 1409*da0073e9SAndroid Build Coastguard Worker 1410*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(capture_scalar_outputs=True) 1411*da0073e9SAndroid Build Coastguard Worker # Translation validation changes the exception type, don't run with it 1412*da0073e9SAndroid Build Coastguard Worker @torch.fx.experimental._config.patch(translation_validation=False) 1413*da0073e9SAndroid Build Coastguard Worker def test_torch_check_is_size(self): 1414*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1415*da0073e9SAndroid Build Coastguard Worker 1416*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnts, fullgraph=True) 1417*da0073e9SAndroid Build Coastguard Worker def f(x): 1418*da0073e9SAndroid Build Coastguard Worker y = x.item() 1419*da0073e9SAndroid Build Coastguard Worker torch._check_is_size(y) 1420*da0073e9SAndroid Build Coastguard Worker # Cannot conditional on unbacked SymInt 1421*da0073e9SAndroid Build Coastguard Worker if y == 0: 1422*da0073e9SAndroid Build Coastguard Worker assert False 1423*da0073e9SAndroid Build Coastguard Worker else: 1424*da0073e9SAndroid Build Coastguard Worker return torch.arange(0, y) 1425*da0073e9SAndroid Build Coastguard Worker 1426*da0073e9SAndroid Build Coastguard Worker self.assertRaises(torch._dynamo.exc.UserError, lambda: f(torch.tensor([3]))) 1427*da0073e9SAndroid Build Coastguard Worker 1428*da0073e9SAndroid Build Coastguard Worker def test_assert(self): 1429*da0073e9SAndroid Build Coastguard Worker @torch.compile 
1430*da0073e9SAndroid Build Coastguard Worker def fn1(x): 1431*da0073e9SAndroid Build Coastguard Worker assert x.shape != x.shape 1432*da0073e9SAndroid Build Coastguard Worker 1433*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 1434*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10) 1435*da0073e9SAndroid Build Coastguard Worker fn1(a) 1436*da0073e9SAndroid Build Coastguard Worker 1437*da0073e9SAndroid Build Coastguard Worker def fn2(x): 1438*da0073e9SAndroid Build Coastguard Worker assert x.shape == x.shape 1439*da0073e9SAndroid Build Coastguard Worker return x.abs() 1440*da0073e9SAndroid Build Coastguard Worker 1441*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test(self, fn=fn2, nargs=1, expected_ops=1) 1442*da0073e9SAndroid Build Coastguard Worker 1443*da0073e9SAndroid Build Coastguard Worker def test_config_obj(self): 1444*da0073e9SAndroid Build Coastguard Worker class Cfg: 1445*da0073e9SAndroid Build Coastguard Worker def __init__(self): 1446*da0073e9SAndroid Build Coastguard Worker self.val = 0.5 1447*da0073e9SAndroid Build Coastguard Worker self.count = 3 1448*da0073e9SAndroid Build Coastguard Worker 1449*da0073e9SAndroid Build Coastguard Worker def fn(x, cfg): 1450*da0073e9SAndroid Build Coastguard Worker for i in range(cfg.count): 1451*da0073e9SAndroid Build Coastguard Worker x = x + cfg.val 1452*da0073e9SAndroid Build Coastguard Worker return x 1453*da0073e9SAndroid Build Coastguard Worker 1454*da0073e9SAndroid Build Coastguard Worker cfg1 = Cfg() 1455*da0073e9SAndroid Build Coastguard Worker cfg1.val = 1.0 1456*da0073e9SAndroid Build Coastguard Worker cfg2 = Cfg() 1457*da0073e9SAndroid Build Coastguard Worker v = torch.zeros(1) 1458*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1459*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1460*da0073e9SAndroid Build Coastguard Worker v = opt_fn(v, cfg1) # 3 
1461*da0073e9SAndroid Build Coastguard Worker v = opt_fn(v, cfg2) # 4.5 1462*da0073e9SAndroid Build Coastguard Worker cfg2.count = 1 1463*da0073e9SAndroid Build Coastguard Worker v = opt_fn(v, cfg2) # 5 1464*da0073e9SAndroid Build Coastguard Worker cfg2.val = 2.0 1465*da0073e9SAndroid Build Coastguard Worker v = opt_fn(v, cfg2) # 7 1466*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[0], 7) 1467*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 8) 1468*da0073e9SAndroid Build Coastguard Worker 1469*da0073e9SAndroid Build Coastguard Worker def test_config_getattr_default(self): 1470*da0073e9SAndroid Build Coastguard Worker class Cfg: 1471*da0073e9SAndroid Build Coastguard Worker def __init__(self): 1472*da0073e9SAndroid Build Coastguard Worker self.val = 0.5 1473*da0073e9SAndroid Build Coastguard Worker self.count = 10 1474*da0073e9SAndroid Build Coastguard Worker 1475*da0073e9SAndroid Build Coastguard Worker def fn(x, cfg): 1476*da0073e9SAndroid Build Coastguard Worker if getattr(cfg, "just_add_7", False): 1477*da0073e9SAndroid Build Coastguard Worker return x + 7 1478*da0073e9SAndroid Build Coastguard Worker for i in range(cfg.count): 1479*da0073e9SAndroid Build Coastguard Worker x = x + cfg.val 1480*da0073e9SAndroid Build Coastguard Worker return x 1481*da0073e9SAndroid Build Coastguard Worker 1482*da0073e9SAndroid Build Coastguard Worker cfg1 = Cfg() 1483*da0073e9SAndroid Build Coastguard Worker v = torch.zeros(1) 1484*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1485*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1486*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(v, cfg1)[0], 5) 1487*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(v, cfg1)[0], 5) 1488*da0073e9SAndroid Build Coastguard Worker cfg1.just_add_7 = True 1489*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(v, cfg1)[0], 7) 
1490*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(v, cfg1)[0], 7) 1491*da0073e9SAndroid Build Coastguard Worker cfg1.just_add_7 = False 1492*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(v, cfg1)[0], 5) 1493*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(v, cfg1)[0], 5) 1494*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 3) 1495*da0073e9SAndroid Build Coastguard Worker 1496*da0073e9SAndroid Build Coastguard Worker def test_size_input(self): 1497*da0073e9SAndroid Build Coastguard Worker def fn(x, s): 1498*da0073e9SAndroid Build Coastguard Worker a, b = s 1499*da0073e9SAndroid Build Coastguard Worker return x + (a - b) 1500*da0073e9SAndroid Build Coastguard Worker 1501*da0073e9SAndroid Build Coastguard Worker v = torch.zeros(10, 20) 1502*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1503*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1504*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(v, v.size())[0, 0], -10) 1505*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(v, (10, 20))[0, 0], -10) 1506*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(v, [10, 20])[0, 0], -10) 1507*da0073e9SAndroid Build Coastguard Worker # One recompile per differing input type 1508*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 3) 1509*da0073e9SAndroid Build Coastguard Worker 1510*da0073e9SAndroid Build Coastguard Worker def test_cell_output1(self): 1511*da0073e9SAndroid Build Coastguard Worker out = None 1512*da0073e9SAndroid Build Coastguard Worker 1513*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 1514*da0073e9SAndroid Build Coastguard Worker nonlocal out 1515*da0073e9SAndroid Build Coastguard Worker out = a + b * 10 1516*da0073e9SAndroid Build Coastguard Worker 1517*da0073e9SAndroid Build Coastguard Worker v = torch.Tensor([100]) 
1518*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1519*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1520*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(opt_fn(v, v)) 1521*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out[0], 1100) 1522*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 1523*da0073e9SAndroid Build Coastguard Worker 1524*da0073e9SAndroid Build Coastguard Worker def test_cell_output2(self): 1525*da0073e9SAndroid Build Coastguard Worker out = None 1526*da0073e9SAndroid Build Coastguard Worker 1527*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 1528*da0073e9SAndroid Build Coastguard Worker nonlocal out 1529*da0073e9SAndroid Build Coastguard Worker c = unsupported(a, b) 1530*da0073e9SAndroid Build Coastguard Worker out = a + b * 10 + c 1531*da0073e9SAndroid Build Coastguard Worker 1532*da0073e9SAndroid Build Coastguard Worker v = torch.Tensor([100]) 1533*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1534*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1535*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(opt_fn(v, v)) 1536*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out[0], 1200) 1537*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 3) 1538*da0073e9SAndroid Build Coastguard Worker 1539*da0073e9SAndroid Build Coastguard Worker def test_return_nested_function(self): 1540*da0073e9SAndroid Build Coastguard Worker out = None 1541*da0073e9SAndroid Build Coastguard Worker 1542*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 1543*da0073e9SAndroid Build Coastguard Worker nonlocal out 1544*da0073e9SAndroid Build Coastguard Worker c = a + b 1545*da0073e9SAndroid Build Coastguard Worker d = a + 1.0 1546*da0073e9SAndroid Build Coastguard Worker 1547*da0073e9SAndroid Build Coastguard Worker def fn2(f: int 
= 7, g: float = 9.0): 1548*da0073e9SAndroid Build Coastguard Worker nonlocal out 1549*da0073e9SAndroid Build Coastguard Worker out = a + b * 10 1550*da0073e9SAndroid Build Coastguard Worker return c * f - d * g 1551*da0073e9SAndroid Build Coastguard Worker 1552*da0073e9SAndroid Build Coastguard Worker return fn2 1553*da0073e9SAndroid Build Coastguard Worker 1554*da0073e9SAndroid Build Coastguard Worker v1 = torch.Tensor([100]) 1555*da0073e9SAndroid Build Coastguard Worker v2 = torch.Tensor([200]) 1556*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1557*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1558*da0073e9SAndroid Build Coastguard Worker opt_fn_ret = torch._dynamo.optimize(cnts)(opt_fn(v1, v2)) 1559*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn_ret(1.5)[0], -459) 1560*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out[0], 2100) 1561*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 1562*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 7) 1563*da0073e9SAndroid Build Coastguard Worker 1564*da0073e9SAndroid Build Coastguard Worker def test_tensor_dict1(self): 1565*da0073e9SAndroid Build Coastguard Worker def fn(inputs): 1566*da0073e9SAndroid Build Coastguard Worker return inputs["a"] - inputs["b"] * 1.5 1567*da0073e9SAndroid Build Coastguard Worker 1568*da0073e9SAndroid Build Coastguard Worker v1 = torch.Tensor([100]) 1569*da0073e9SAndroid Build Coastguard Worker v2 = torch.Tensor([200]) 1570*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1571*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 1572*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn({"a": v1, "b": v2})[0], -200) 1573*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 1574*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(cnts.op_count, 2) 1575*da0073e9SAndroid Build Coastguard Worker 1576*da0073e9SAndroid Build Coastguard Worker def test_tensor_dict3(self): 1577*da0073e9SAndroid Build Coastguard Worker def fn(inputs_a, inputs_b): 1578*da0073e9SAndroid Build Coastguard Worker total = torch.zeros(1) 1579*da0073e9SAndroid Build Coastguard Worker input_keys = inputs_a.keys() | inputs_b.keys() 1580*da0073e9SAndroid Build Coastguard Worker for k in input_keys: 1581*da0073e9SAndroid Build Coastguard Worker if k in inputs_a: 1582*da0073e9SAndroid Build Coastguard Worker total += inputs_a[k] 1583*da0073e9SAndroid Build Coastguard Worker if k in inputs_b: 1584*da0073e9SAndroid Build Coastguard Worker total += inputs_b[k] 1585*da0073e9SAndroid Build Coastguard Worker return total 1586*da0073e9SAndroid Build Coastguard Worker 1587*da0073e9SAndroid Build Coastguard Worker v1 = torch.Tensor([100]) 1588*da0073e9SAndroid Build Coastguard Worker v2 = torch.Tensor([200]) 1589*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1590*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 1591*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1592*da0073e9SAndroid Build Coastguard Worker opt_fn({"a": v1, "b": v2}, {"b": v1, "c": v2}), 1593*da0073e9SAndroid Build Coastguard Worker fn({"a": v1, "b": v2}, {"b": v1, "c": v2}), 1594*da0073e9SAndroid Build Coastguard Worker ) 1595*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 1596*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 5) 1597*da0073e9SAndroid Build Coastguard Worker 1598*da0073e9SAndroid Build Coastguard Worker def test_tensor_dict2(self): 1599*da0073e9SAndroid Build Coastguard Worker def fn1(inputs): 1600*da0073e9SAndroid Build Coastguard Worker total = torch.zeros(1) 1601*da0073e9SAndroid Build Coastguard Worker for k, v in inputs.items(): 1602*da0073e9SAndroid Build Coastguard Worker 
total += v 1603*da0073e9SAndroid Build Coastguard Worker return total 1604*da0073e9SAndroid Build Coastguard Worker 1605*da0073e9SAndroid Build Coastguard Worker def fn2(inputs): 1606*da0073e9SAndroid Build Coastguard Worker total = torch.zeros(1) 1607*da0073e9SAndroid Build Coastguard Worker for v in inputs.values(): 1608*da0073e9SAndroid Build Coastguard Worker total += v 1609*da0073e9SAndroid Build Coastguard Worker return total 1610*da0073e9SAndroid Build Coastguard Worker 1611*da0073e9SAndroid Build Coastguard Worker def fn3(inputs): 1612*da0073e9SAndroid Build Coastguard Worker total = torch.zeros(1) 1613*da0073e9SAndroid Build Coastguard Worker for k in inputs.keys(): 1614*da0073e9SAndroid Build Coastguard Worker total += inputs[k] 1615*da0073e9SAndroid Build Coastguard Worker return total 1616*da0073e9SAndroid Build Coastguard Worker 1617*da0073e9SAndroid Build Coastguard Worker v1 = torch.Tensor([100]) 1618*da0073e9SAndroid Build Coastguard Worker v2 = torch.Tensor([200]) 1619*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1620*da0073e9SAndroid Build Coastguard Worker opt_fn1 = torch._dynamo.optimize(cnts, nopython=True)(fn1) 1621*da0073e9SAndroid Build Coastguard Worker opt_fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2) 1622*da0073e9SAndroid Build Coastguard Worker opt_fn3 = torch._dynamo.optimize(cnts, nopython=True)(fn3) 1623*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn1({"a": v1, "b": v2})[0], 300) 1624*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn2({"a": v1, "b": v2})[0], 300) 1625*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn3({"a": v1, "b": v2})[0], 300) 1626*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 3) 1627*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 9) 1628*da0073e9SAndroid Build Coastguard Worker 1629*da0073e9SAndroid Build Coastguard Worker def test_dictcomp(self): 
1630*da0073e9SAndroid Build Coastguard Worker def fn1(inputs): 1631*da0073e9SAndroid Build Coastguard Worker return {k: v + 1 for k, v in inputs.items()} 1632*da0073e9SAndroid Build Coastguard Worker 1633*da0073e9SAndroid Build Coastguard Worker v1 = torch.Tensor([100]) 1634*da0073e9SAndroid Build Coastguard Worker v2 = torch.Tensor([200]) 1635*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1636*da0073e9SAndroid Build Coastguard Worker opt_fn1 = torch._dynamo.optimize(cnts)(fn1) 1637*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn1({"a": v1, "b": v2})["a"], 101) 1638*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn1({"a": v1, "b": v2})["b"], 201) 1639*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 1640*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 1641*da0073e9SAndroid Build Coastguard Worker 1642*da0073e9SAndroid Build Coastguard Worker def test_listcomp(self): 1643*da0073e9SAndroid Build Coastguard Worker def fn2(inputs): 1644*da0073e9SAndroid Build Coastguard Worker return torch.sum(torch.cat([v + 1 for k, v in inputs.items()], 0)) 1645*da0073e9SAndroid Build Coastguard Worker 1646*da0073e9SAndroid Build Coastguard Worker v1 = torch.Tensor([100]) 1647*da0073e9SAndroid Build Coastguard Worker v2 = torch.Tensor([200]) 1648*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1649*da0073e9SAndroid Build Coastguard Worker opt_fn2 = torch._dynamo.optimize(cnts)(fn2) 1650*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn2({"a": v1, "b": v2}), 302) 1651*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 1652*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 4) 1653*da0073e9SAndroid Build Coastguard Worker 1654*da0073e9SAndroid Build Coastguard Worker def test_is_floating_point(self): 1655*da0073e9SAndroid Build Coastguard Worker def fn(a, 
b): 1656*da0073e9SAndroid Build Coastguard Worker x = a + 1.0 1657*da0073e9SAndroid Build Coastguard Worker if torch.is_floating_point(b): 1658*da0073e9SAndroid Build Coastguard Worker x = x + b 1659*da0073e9SAndroid Build Coastguard Worker return x + 2.0 1660*da0073e9SAndroid Build Coastguard Worker 1661*da0073e9SAndroid Build Coastguard Worker return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) 1662*da0073e9SAndroid Build Coastguard Worker 1663*da0073e9SAndroid Build Coastguard Worker def test_is_floating_point2(self): 1664*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 1665*da0073e9SAndroid Build Coastguard Worker x = a + 1.0 1666*da0073e9SAndroid Build Coastguard Worker if b.is_floating_point(): 1667*da0073e9SAndroid Build Coastguard Worker x = x + b 1668*da0073e9SAndroid Build Coastguard Worker return x + 2.0 1669*da0073e9SAndroid Build Coastguard Worker 1670*da0073e9SAndroid Build Coastguard Worker return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) 1671*da0073e9SAndroid Build Coastguard Worker 1672*da0073e9SAndroid Build Coastguard Worker def test_is_tensor(self): 1673*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 1674*da0073e9SAndroid Build Coastguard Worker x = a + 1.0 1675*da0073e9SAndroid Build Coastguard Worker if torch.is_tensor(b): 1676*da0073e9SAndroid Build Coastguard Worker x = x + b 1677*da0073e9SAndroid Build Coastguard Worker return x + 2.0 1678*da0073e9SAndroid Build Coastguard Worker 1679*da0073e9SAndroid Build Coastguard Worker return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) 1680*da0073e9SAndroid Build Coastguard Worker 1681*da0073e9SAndroid Build Coastguard Worker def test_is_tensor2(self): 1682*da0073e9SAndroid Build Coastguard Worker def fn(x): 1683*da0073e9SAndroid Build Coastguard Worker if torch.is_tensor(x): 1684*da0073e9SAndroid Build Coastguard Worker return x + 1 1685*da0073e9SAndroid Build Coastguard Worker else: 
1686*da0073e9SAndroid Build Coastguard Worker return torch.ones([2, 3]) 1687*da0073e9SAndroid Build Coastguard Worker 1688*da0073e9SAndroid Build Coastguard Worker x1 = {"input": torch.rand(2, 3)} 1689*da0073e9SAndroid Build Coastguard Worker x2 = torch.rand(2, 3) 1690*da0073e9SAndroid Build Coastguard Worker ref1 = fn(x1) 1691*da0073e9SAndroid Build Coastguard Worker ref2 = fn(x2) 1692*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 1693*da0073e9SAndroid Build Coastguard Worker res1 = opt_fn(x1) 1694*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(x2) 1695*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref1, res1) 1696*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref2, res2) 1697*da0073e9SAndroid Build Coastguard Worker 1698*da0073e9SAndroid Build Coastguard Worker def test_numel(self): 1699*da0073e9SAndroid Build Coastguard Worker def fn(a): 1700*da0073e9SAndroid Build Coastguard Worker return (a + a.numel() + torch.numel(a), a + a.nelement()) 1701*da0073e9SAndroid Build Coastguard Worker 1702*da0073e9SAndroid Build Coastguard Worker return torch._dynamo.testing.standard_test( 1703*da0073e9SAndroid Build Coastguard Worker self, 1704*da0073e9SAndroid Build Coastguard Worker fn=fn, 1705*da0073e9SAndroid Build Coastguard Worker nargs=1, 1706*da0073e9SAndroid Build Coastguard Worker expected_ops=3, 1707*da0073e9SAndroid Build Coastguard Worker expected_ops_dynamic=ifdynstaticdefault(3, 6), 1708*da0073e9SAndroid Build Coastguard Worker ) 1709*da0073e9SAndroid Build Coastguard Worker 1710*da0073e9SAndroid Build Coastguard Worker def test_pair(self): 1711*da0073e9SAndroid Build Coastguard Worker def fn(a): 1712*da0073e9SAndroid Build Coastguard Worker return ( 1713*da0073e9SAndroid Build Coastguard Worker torch.zeros(torch.nn.modules.utils._pair(a.size())) 1714*da0073e9SAndroid Build Coastguard Worker + a 1715*da0073e9SAndroid Build Coastguard Worker + 
torch.ones(torch.nn.modules.utils._ntuple(3)(3)).sum() 1716*da0073e9SAndroid Build Coastguard Worker ) 1717*da0073e9SAndroid Build Coastguard Worker 1718*da0073e9SAndroid Build Coastguard Worker return torch._dynamo.testing.standard_test( 1719*da0073e9SAndroid Build Coastguard Worker self, 1720*da0073e9SAndroid Build Coastguard Worker fn=fn, 1721*da0073e9SAndroid Build Coastguard Worker nargs=1, 1722*da0073e9SAndroid Build Coastguard Worker expected_ops=5, 1723*da0073e9SAndroid Build Coastguard Worker expected_ops_dynamic=ifdynstaticdefault(5, 8), 1724*da0073e9SAndroid Build Coastguard Worker ) 1725*da0073e9SAndroid Build Coastguard Worker 1726*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) 1727*da0073e9SAndroid Build Coastguard Worker def test_tensor_item_capture(self): 1728*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 1729*da0073e9SAndroid Build Coastguard Worker return (a + b).sum().item() 1730*da0073e9SAndroid Build Coastguard Worker 1731*da0073e9SAndroid Build Coastguard Worker v1 = torch.randn((10, 10)) 1732*da0073e9SAndroid Build Coastguard Worker v2 = torch.randn((10, 10)) 1733*da0073e9SAndroid Build Coastguard Worker correct = fn(v1, v2) 1734*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1735*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1736*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(v1, v2), correct) 1737*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 1738*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 4) 1739*da0073e9SAndroid Build Coastguard Worker 1740*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "capture_scalar_outputs", False) 1741*da0073e9SAndroid Build Coastguard Worker def test_tensor_item_no_capture(self): 1742*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 1743*da0073e9SAndroid 
Build Coastguard Worker return (a + b).sum().item() 1744*da0073e9SAndroid Build Coastguard Worker 1745*da0073e9SAndroid Build Coastguard Worker v1 = torch.randn((10, 10)) 1746*da0073e9SAndroid Build Coastguard Worker v2 = torch.randn((10, 10)) 1747*da0073e9SAndroid Build Coastguard Worker correct = fn(v1, v2) 1748*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1749*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1750*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(v1, v2), correct) 1751*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 1752*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 1753*da0073e9SAndroid Build Coastguard Worker 1754*da0073e9SAndroid Build Coastguard Worker def test_namedtuple1(self): 1755*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 1756*da0073e9SAndroid Build Coastguard Worker tmp = mytuple(a, b, a + b) 1757*da0073e9SAndroid Build Coastguard Worker return mytuple(tmp.a, tmp[1], tmp.ab + b) 1758*da0073e9SAndroid Build Coastguard Worker 1759*da0073e9SAndroid Build Coastguard Worker v1 = torch.Tensor([10]) 1760*da0073e9SAndroid Build Coastguard Worker v2 = torch.Tensor([20]) 1761*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1762*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1763*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(v1, v2).ab, 50) 1764*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 1765*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 1766*da0073e9SAndroid Build Coastguard Worker 1767*da0073e9SAndroid Build Coastguard Worker def test_namedtuple2(self): 1768*da0073e9SAndroid Build Coastguard Worker def fn(packed): 1769*da0073e9SAndroid Build Coastguard Worker a, b, c = packed 1770*da0073e9SAndroid Build Coastguard Worker if 
hasattr(packed, "b"): 1771*da0073e9SAndroid Build Coastguard Worker b = packed.b + 1 1772*da0073e9SAndroid Build Coastguard Worker c = packed[2] 1773*da0073e9SAndroid Build Coastguard Worker return a + b + c 1774*da0073e9SAndroid Build Coastguard Worker 1775*da0073e9SAndroid Build Coastguard Worker v1 = torch.Tensor([1]) 1776*da0073e9SAndroid Build Coastguard Worker v2 = torch.Tensor([2]) 1777*da0073e9SAndroid Build Coastguard Worker v3 = torch.Tensor([3]) 1778*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1779*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1780*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(mytuple(v1, v2, v3))[0], 7) 1781*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 1782*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 3) 1783*da0073e9SAndroid Build Coastguard Worker 1784*da0073e9SAndroid Build Coastguard Worker def test_namedtuple3(self): 1785*da0073e9SAndroid Build Coastguard Worker def fn(x, packed): 1786*da0073e9SAndroid Build Coastguard Worker if isinstance(packed, mytuple): 1787*da0073e9SAndroid Build Coastguard Worker return x + 1 1788*da0073e9SAndroid Build Coastguard Worker else: 1789*da0073e9SAndroid Build Coastguard Worker return x - 1 1790*da0073e9SAndroid Build Coastguard Worker 1791*da0073e9SAndroid Build Coastguard Worker x = torch.rand([2, 3]) 1792*da0073e9SAndroid Build Coastguard Worker packed = mytuple(1, 2, 3) 1793*da0073e9SAndroid Build Coastguard Worker ref = fn(x, packed) 1794*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 1795*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, packed) 1796*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 1797*da0073e9SAndroid Build Coastguard Worker 1798*da0073e9SAndroid Build Coastguard Worker def test_range_input(self): 1799*da0073e9SAndroid Build Coastguard Worker def 
fn(a, rng): 1800*da0073e9SAndroid Build Coastguard Worker x = a 1801*da0073e9SAndroid Build Coastguard Worker for i in rng: 1802*da0073e9SAndroid Build Coastguard Worker x = x + i 1803*da0073e9SAndroid Build Coastguard Worker return x 1804*da0073e9SAndroid Build Coastguard Worker 1805*da0073e9SAndroid Build Coastguard Worker def fn1(a): 1806*da0073e9SAndroid Build Coastguard Worker return fn(a, rng=range(3)) 1807*da0073e9SAndroid Build Coastguard Worker 1808*da0073e9SAndroid Build Coastguard Worker return torch._dynamo.testing.standard_test( 1809*da0073e9SAndroid Build Coastguard Worker self, fn=fn1, nargs=1, expected_ops=3 1810*da0073e9SAndroid Build Coastguard Worker ) 1811*da0073e9SAndroid Build Coastguard Worker 1812*da0073e9SAndroid Build Coastguard Worker def test_range_with_shape(self): 1813*da0073e9SAndroid Build Coastguard Worker def fn(a): 1814*da0073e9SAndroid Build Coastguard Worker for i in range(1, a.shape[0]): 1815*da0073e9SAndroid Build Coastguard Worker a += 1 1816*da0073e9SAndroid Build Coastguard Worker return a 1817*da0073e9SAndroid Build Coastguard Worker 1818*da0073e9SAndroid Build Coastguard Worker return torch._dynamo.testing.standard_test( 1819*da0073e9SAndroid Build Coastguard Worker self, 1820*da0073e9SAndroid Build Coastguard Worker fn=fn, 1821*da0073e9SAndroid Build Coastguard Worker nargs=1, 1822*da0073e9SAndroid Build Coastguard Worker expected_ops=9, 1823*da0073e9SAndroid Build Coastguard Worker ) 1824*da0073e9SAndroid Build Coastguard Worker 1825*da0073e9SAndroid Build Coastguard Worker def test_build_tuple_unpack(self): 1826*da0073e9SAndroid Build Coastguard Worker def fn1(a, b, c): 1827*da0073e9SAndroid Build Coastguard Worker return a - b / c 1828*da0073e9SAndroid Build Coastguard Worker 1829*da0073e9SAndroid Build Coastguard Worker def fn2(a, b, c): 1830*da0073e9SAndroid Build Coastguard Worker tmp1 = (a,) 1831*da0073e9SAndroid Build Coastguard Worker tmp2 = (b, c) 1832*da0073e9SAndroid Build Coastguard Worker args = (*tmp1, 
*tmp2) 1833*da0073e9SAndroid Build Coastguard Worker return fn1(*args) 1834*da0073e9SAndroid Build Coastguard Worker 1835*da0073e9SAndroid Build Coastguard Worker def fn3(a, *args): 1836*da0073e9SAndroid Build Coastguard Worker return fn1(a, *args) 1837*da0073e9SAndroid Build Coastguard Worker 1838*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test(self, fn=fn2, nargs=3, expected_ops=2) 1839*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test(self, fn=fn3, nargs=3, expected_ops=2) 1840*da0073e9SAndroid Build Coastguard Worker 1841*da0073e9SAndroid Build Coastguard Worker def test_list_mul(self): 1842*da0073e9SAndroid Build Coastguard Worker def fn(count): 1843*da0073e9SAndroid Build Coastguard Worker head_mask = count * [None] * count 1844*da0073e9SAndroid Build Coastguard Worker return head_mask 1845*da0073e9SAndroid Build Coastguard Worker 1846*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1847*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1848*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(2), [None] * 4) 1849*da0073e9SAndroid Build Coastguard Worker # TODO: the captured frame here is a bit goofy, because we don't 1850*da0073e9SAndroid Build Coastguard Worker # output anything and none of the traced operations have side 1851*da0073e9SAndroid Build Coastguard Worker # effects. 
Probably need better heuristic for bailing on 1852*da0073e9SAndroid Build Coastguard Worker # dynamo if there are no outputs 1853*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 1854*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.frame_count, """0""") 1855*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.op_count, """0""") 1856*da0073e9SAndroid Build Coastguard Worker else: 1857*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.frame_count, """1""") 1858*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.op_count, """2""") 1859*da0073e9SAndroid Build Coastguard Worker 1860*da0073e9SAndroid Build Coastguard Worker def test_list_slice_mul(self): 1861*da0073e9SAndroid Build Coastguard Worker def fn(count): 1862*da0073e9SAndroid Build Coastguard Worker a = [1, 2, 3] 1863*da0073e9SAndroid Build Coastguard Worker head_mask = count * a[1:] * count 1864*da0073e9SAndroid Build Coastguard Worker return head_mask 1865*da0073e9SAndroid Build Coastguard Worker 1866*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1867*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1868*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(2), [2, 3] * 4) 1869*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 1870*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.frame_count, """0""") 1871*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.op_count, """0""") 1872*da0073e9SAndroid Build Coastguard Worker else: 1873*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.frame_count, """1""") 1874*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.op_count, """2""") 1875*da0073e9SAndroid Build Coastguard Worker 1876*da0073e9SAndroid Build Coastguard Worker def 
test_tuple_mul(self): 1877*da0073e9SAndroid Build Coastguard Worker def fn(count): 1878*da0073e9SAndroid Build Coastguard Worker head_mask = count * (2, 3) * count 1879*da0073e9SAndroid Build Coastguard Worker return head_mask 1880*da0073e9SAndroid Build Coastguard Worker 1881*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1882*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1883*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(2), (2, 3) * 4) 1884*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 1885*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.frame_count, """0""") 1886*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.op_count, """0""") 1887*da0073e9SAndroid Build Coastguard Worker else: 1888*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.frame_count, """1""") 1889*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.op_count, """2""") 1890*da0073e9SAndroid Build Coastguard Worker 1891*da0073e9SAndroid Build Coastguard Worker def test_tuple_mul_with_shape(self): 1892*da0073e9SAndroid Build Coastguard Worker def fn(a): 1893*da0073e9SAndroid Build Coastguard Worker x = a.shape[0] 1894*da0073e9SAndroid Build Coastguard Worker y = 2 * (x, 3) * 2 1895*da0073e9SAndroid Build Coastguard Worker return a + y[4] 1896*da0073e9SAndroid Build Coastguard Worker 1897*da0073e9SAndroid Build Coastguard Worker # expect 3 ops post folding for dynamic case: size, index, add 1898*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 1899*da0073e9SAndroid Build Coastguard Worker self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 3) 1900*da0073e9SAndroid Build Coastguard Worker ) 1901*da0073e9SAndroid Build Coastguard Worker 1902*da0073e9SAndroid Build Coastguard Worker def test_tuple_iadd_with_shape(self): 
1903*da0073e9SAndroid Build Coastguard Worker def fn(a): 1904*da0073e9SAndroid Build Coastguard Worker output = (a + a.shape[0], a - a.shape[0]) 1905*da0073e9SAndroid Build Coastguard Worker # tuple += tuple 1906*da0073e9SAndroid Build Coastguard Worker output += (a - a.shape[0], a + a.shape[0]) 1907*da0073e9SAndroid Build Coastguard Worker # tuple += constant tuple 1908*da0073e9SAndroid Build Coastguard Worker output += (2, 3) 1909*da0073e9SAndroid Build Coastguard Worker return output 1910*da0073e9SAndroid Build Coastguard Worker 1911*da0073e9SAndroid Build Coastguard Worker # expect 4 add / subs for static, 4 * 3 (size, index, math op) for dynamic 1912*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 1913*da0073e9SAndroid Build Coastguard Worker self, fn, 1, expected_ops=4, expected_ops_dynamic=ifdynstaticdefault(4, 12) 1914*da0073e9SAndroid Build Coastguard Worker ) 1915*da0073e9SAndroid Build Coastguard Worker 1916*da0073e9SAndroid Build Coastguard Worker def test_list_iadd_with_shape(self): 1917*da0073e9SAndroid Build Coastguard Worker def fn(a): 1918*da0073e9SAndroid Build Coastguard Worker output = [a + a.shape[0], a - a.shape[0]] 1919*da0073e9SAndroid Build Coastguard Worker # list += list 1920*da0073e9SAndroid Build Coastguard Worker output += [a - a.shape[0], a + a.shape[0]] 1921*da0073e9SAndroid Build Coastguard Worker # list += tuple 1922*da0073e9SAndroid Build Coastguard Worker output += (a + a.shape[0], a - a.shape[0]) 1923*da0073e9SAndroid Build Coastguard Worker return output 1924*da0073e9SAndroid Build Coastguard Worker 1925*da0073e9SAndroid Build Coastguard Worker # expect 6 add / subs for static, 6 * 3 (size, index, math op) for dynamic 1926*da0073e9SAndroid Build Coastguard Worker 1927*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test( 1928*da0073e9SAndroid Build Coastguard Worker self, fn, 1, expected_ops=6, expected_ops_dynamic=ifdynstaticdefault(6, 18) 1929*da0073e9SAndroid Build 
Coastguard Worker ) 1930*da0073e9SAndroid Build Coastguard Worker 1931*da0073e9SAndroid Build Coastguard Worker def test_list_iadd_side_effect(self): 1932*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 1933*da0073e9SAndroid Build Coastguard Worker a += [b] 1934*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 1935*da0073e9SAndroid Build Coastguard Worker return a 1936*da0073e9SAndroid Build Coastguard Worker 1937*da0073e9SAndroid Build Coastguard Worker a = [1, 2, 3] 1938*da0073e9SAndroid Build Coastguard Worker b = torch.ones(2, 2) 1939*da0073e9SAndroid Build Coastguard Worker 1940*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 1941*da0073e9SAndroid Build Coastguard Worker 1942*da0073e9SAndroid Build Coastguard Worker exp = fn(a, b) 1943*da0073e9SAndroid Build Coastguard Worker 1944*da0073e9SAndroid Build Coastguard Worker a = [1, 2, 3] 1945*da0073e9SAndroid Build Coastguard Worker b = torch.ones(2, 2) 1946*da0073e9SAndroid Build Coastguard Worker act = opt_fn(a, b) 1947*da0073e9SAndroid Build Coastguard Worker 1948*da0073e9SAndroid Build Coastguard Worker self.assertEqual(exp, act) 1949*da0073e9SAndroid Build Coastguard Worker 1950*da0073e9SAndroid Build Coastguard Worker def test_user_getattr1(self): 1951*da0073e9SAndroid Build Coastguard Worker class MyConfig(dict): 1952*da0073e9SAndroid Build Coastguard Worker def __getattr__(self, name): 1953*da0073e9SAndroid Build Coastguard Worker return self[name] 1954*da0073e9SAndroid Build Coastguard Worker 1955*da0073e9SAndroid Build Coastguard Worker def fn(cfg, x, y): 1956*da0073e9SAndroid Build Coastguard Worker return x + y + cfg.offset 1957*da0073e9SAndroid Build Coastguard Worker 1958*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10) 1959*da0073e9SAndroid Build Coastguard Worker cfg = MyConfig(offset=5) 1960*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1961*da0073e9SAndroid Build Coastguard 
Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1962*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5)) 1963*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 1964*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 1965*da0073e9SAndroid Build Coastguard Worker 1966*da0073e9SAndroid Build Coastguard Worker def test_user_getattr2(self): 1967*da0073e9SAndroid Build Coastguard Worker class MyConfig: 1968*da0073e9SAndroid Build Coastguard Worker defined_on_class = 1 1969*da0073e9SAndroid Build Coastguard Worker 1970*da0073e9SAndroid Build Coastguard Worker def __init__(self): 1971*da0073e9SAndroid Build Coastguard Worker self.defined_on_object = 2 1972*da0073e9SAndroid Build Coastguard Worker 1973*da0073e9SAndroid Build Coastguard Worker def __getattr__(self, name): 1974*da0073e9SAndroid Build Coastguard Worker return 3 1975*da0073e9SAndroid Build Coastguard Worker 1976*da0073e9SAndroid Build Coastguard Worker def fn(cfg, x): 1977*da0073e9SAndroid Build Coastguard Worker return x + cfg.defined_on_class - cfg.defined_on_object + cfg.not_defined 1978*da0073e9SAndroid Build Coastguard Worker 1979*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10) 1980*da0073e9SAndroid Build Coastguard Worker cfg = MyConfig() 1981*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1982*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 1983*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(opt_fn(cfg, x), x + 1 - 2 + 3)) 1984*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 1985*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 3) 1986*da0073e9SAndroid Build Coastguard Worker 1987*da0073e9SAndroid Build Coastguard Worker def test_getset_descriptor(self): 1988*da0073e9SAndroid Build Coastguard Worker def fn(g, x): 1989*da0073e9SAndroid Build Coastguard 
Worker return g.__get__(x) 1990*da0073e9SAndroid Build Coastguard Worker 1991*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 1992*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fullgraph=True, backend="eager")(fn) 1993*da0073e9SAndroid Build Coastguard Worker g = torch.Tensor.shape 1994*da0073e9SAndroid Build Coastguard Worker 1995*da0073e9SAndroid Build Coastguard Worker res = opt_fn(g, torch.ones(2, 2)) 1996*da0073e9SAndroid Build Coastguard Worker exp_res = fn(g, torch.ones(2, 2)) 1997*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, exp_res) 1998*da0073e9SAndroid Build Coastguard Worker 1999*da0073e9SAndroid Build Coastguard Worker def test_get_attr_function(self): 2000*da0073e9SAndroid Build Coastguard Worker def fn(g, x): 2001*da0073e9SAndroid Build Coastguard Worker return g(x) 2002*da0073e9SAndroid Build Coastguard Worker 2003*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2004*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2005*da0073e9SAndroid Build Coastguard Worker g = torch.Tensor.shape.__get__ 2006*da0073e9SAndroid Build Coastguard Worker 2007*da0073e9SAndroid Build Coastguard Worker res = opt_fn(g, torch.ones(2, 2)) 2008*da0073e9SAndroid Build Coastguard Worker exp_res = fn(g, torch.ones(2, 2)) 2009*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, exp_res) 2010*da0073e9SAndroid Build Coastguard Worker 2011*da0073e9SAndroid Build Coastguard Worker def test_user_getattribute(self): 2012*da0073e9SAndroid Build Coastguard Worker class MyObject: 2013*da0073e9SAndroid Build Coastguard Worker def __init__(self): 2014*da0073e9SAndroid Build Coastguard Worker self.custom_dict = {"a": torch.rand((2, 2))} 2015*da0073e9SAndroid Build Coastguard Worker self.my_number = 42 2016*da0073e9SAndroid Build Coastguard Worker 2017*da0073e9SAndroid Build Coastguard Worker def __getattribute__(self, name): 
2018*da0073e9SAndroid Build Coastguard Worker custom_dict = super().__getattribute__("custom_dict") 2019*da0073e9SAndroid Build Coastguard Worker if name in custom_dict: 2020*da0073e9SAndroid Build Coastguard Worker return custom_dict[name] 2021*da0073e9SAndroid Build Coastguard Worker return super().__getattribute__(name) 2022*da0073e9SAndroid Build Coastguard Worker 2023*da0073e9SAndroid Build Coastguard Worker def run(self, x): 2024*da0073e9SAndroid Build Coastguard Worker return self.my_number * x + self.a * x 2025*da0073e9SAndroid Build Coastguard Worker 2026*da0073e9SAndroid Build Coastguard Worker def fn(obj, x): 2027*da0073e9SAndroid Build Coastguard Worker return obj.run(x) 2028*da0073e9SAndroid Build Coastguard Worker 2029*da0073e9SAndroid Build Coastguard Worker obj = MyObject() 2030*da0073e9SAndroid Build Coastguard Worker x = torch.rand((2, 2)) 2031*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2032*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2033*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(opt_fn(obj, x), fn(obj, x))) 2034*da0073e9SAndroid Build Coastguard Worker 2035*da0073e9SAndroid Build Coastguard Worker def test_nn_module_getattr(self): 2036*da0073e9SAndroid Build Coastguard Worker class MyMod(torch.nn.Module): 2037*da0073e9SAndroid Build Coastguard Worker def __init__(self): 2038*da0073e9SAndroid Build Coastguard Worker super().__init__() 2039*da0073e9SAndroid Build Coastguard Worker self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]} 2040*da0073e9SAndroid Build Coastguard Worker self.other_attr = torch.rand((2, 2)) 2041*da0073e9SAndroid Build Coastguard Worker 2042*da0073e9SAndroid Build Coastguard Worker def __getattr__(self, name): 2043*da0073e9SAndroid Build Coastguard Worker custom_dict = self.custom_dict 2044*da0073e9SAndroid Build Coastguard Worker if name in custom_dict: 2045*da0073e9SAndroid Build Coastguard Worker return 
custom_dict[name] 2046*da0073e9SAndroid Build Coastguard Worker return super().__getattr__(name) 2047*da0073e9SAndroid Build Coastguard Worker 2048*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2049*da0073e9SAndroid Build Coastguard Worker return x @ self.other_attr + self.queue[-1] 2050*da0073e9SAndroid Build Coastguard Worker 2051*da0073e9SAndroid Build Coastguard Worker x = torch.rand((2, 2)) 2052*da0073e9SAndroid Build Coastguard Worker mod = MyMod() 2053*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2054*da0073e9SAndroid Build Coastguard Worker opt_mod = torch._dynamo.optimize(cnts)(mod) 2055*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(opt_mod(x), mod(x))) 2056*da0073e9SAndroid Build Coastguard Worker self.assertTrue(cnts.frame_count, 1) 2057*da0073e9SAndroid Build Coastguard Worker self.assertTrue(cnts.op_count, 2) 2058*da0073e9SAndroid Build Coastguard Worker 2059*da0073e9SAndroid Build Coastguard Worker def test_nn_module_getattribute(self): 2060*da0073e9SAndroid Build Coastguard Worker class MyMod(torch.nn.Module): 2061*da0073e9SAndroid Build Coastguard Worker def __init__(self): 2062*da0073e9SAndroid Build Coastguard Worker super().__init__() 2063*da0073e9SAndroid Build Coastguard Worker self.my_number = 42 2064*da0073e9SAndroid Build Coastguard Worker 2065*da0073e9SAndroid Build Coastguard Worker def __getattribute__(self, name): 2066*da0073e9SAndroid Build Coastguard Worker if name == "special_attr": 2067*da0073e9SAndroid Build Coastguard Worker return torch.tensor([[1, 2], [3, 4]]) 2068*da0073e9SAndroid Build Coastguard Worker return super().__getattribute__(name) 2069*da0073e9SAndroid Build Coastguard Worker 2070*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2071*da0073e9SAndroid Build Coastguard Worker return self.my_number * x + self.special_attr * x 2072*da0073e9SAndroid Build Coastguard Worker 2073*da0073e9SAndroid Build Coastguard Worker def fn(mod, x): 
2074*da0073e9SAndroid Build Coastguard Worker return mod(x) 2075*da0073e9SAndroid Build Coastguard Worker 2076*da0073e9SAndroid Build Coastguard Worker mod = MyMod() 2077*da0073e9SAndroid Build Coastguard Worker x = torch.rand((2, 2)) 2078*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2079*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2080*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(opt_fn(mod, x), fn(mod, x))) 2081*da0073e9SAndroid Build Coastguard Worker 2082*da0073e9SAndroid Build Coastguard Worker def test_constant_getattr(self): 2083*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/97480 2084*da0073e9SAndroid Build Coastguard Worker def fn(): 2085*da0073e9SAndroid Build Coastguard Worker return getattr(None, "arg", 3) 2086*da0073e9SAndroid Build Coastguard Worker 2087*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 2088*da0073e9SAndroid Build Coastguard Worker optimized_fn = torch._dynamo.optimize(cnt)(fn) 2089*da0073e9SAndroid Build Coastguard Worker res = optimized_fn() 2090*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res, 3)) 2091*da0073e9SAndroid Build Coastguard Worker 2092*da0073e9SAndroid Build Coastguard Worker def test_user_property(self): 2093*da0073e9SAndroid Build Coastguard Worker class MyConfig: 2094*da0073e9SAndroid Build Coastguard Worker @property 2095*da0073e9SAndroid Build Coastguard Worker def prop5(self): 2096*da0073e9SAndroid Build Coastguard Worker return 5 2097*da0073e9SAndroid Build Coastguard Worker 2098*da0073e9SAndroid Build Coastguard Worker def fn(cfg, x, y): 2099*da0073e9SAndroid Build Coastguard Worker return x + y + cfg.prop5 2100*da0073e9SAndroid Build Coastguard Worker 2101*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10) 2102*da0073e9SAndroid Build Coastguard Worker cfg = MyConfig() 2103*da0073e9SAndroid Build Coastguard Worker 
cnts = torch._dynamo.testing.CompileCounter() 2104*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2105*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5)) 2106*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2107*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 2108*da0073e9SAndroid Build Coastguard Worker 2109*da0073e9SAndroid Build Coastguard Worker def test_dataclass_fields(self): 2110*da0073e9SAndroid Build Coastguard Worker @dataclasses.dataclass 2111*da0073e9SAndroid Build Coastguard Worker class MyDataClass: 2112*da0073e9SAndroid Build Coastguard Worker a: torch.Tensor 2113*da0073e9SAndroid Build Coastguard Worker b: torch.Tensor = None 2114*da0073e9SAndroid Build Coastguard Worker c: torch.Tensor = None 2115*da0073e9SAndroid Build Coastguard Worker d: torch.Tensor = None 2116*da0073e9SAndroid Build Coastguard Worker e: torch.Tensor = None 2117*da0073e9SAndroid Build Coastguard Worker 2118*da0073e9SAndroid Build Coastguard Worker def fn(obj): 2119*da0073e9SAndroid Build Coastguard Worker class_fields = dataclasses.fields(obj) 2120*da0073e9SAndroid Build Coastguard Worker assert len(class_fields) 2121*da0073e9SAndroid Build Coastguard Worker assert all(field.default is None for field in class_fields[1:]) 2122*da0073e9SAndroid Build Coastguard Worker other_fields_are_none = all( 2123*da0073e9SAndroid Build Coastguard Worker getattr(obj, field.name) is None for field in class_fields[1:] 2124*da0073e9SAndroid Build Coastguard Worker ) 2125*da0073e9SAndroid Build Coastguard Worker assert not other_fields_are_none 2126*da0073e9SAndroid Build Coastguard Worker 2127*da0073e9SAndroid Build Coastguard Worker if not hasattr(obj, "a"): 2128*da0073e9SAndroid Build Coastguard Worker return -1 2129*da0073e9SAndroid Build Coastguard Worker if hasattr(obj, "z"): 2130*da0073e9SAndroid Build Coastguard Worker return -2 2131*da0073e9SAndroid 
Build Coastguard Worker 2132*da0073e9SAndroid Build Coastguard Worker total = getattr(obj, class_fields[0].name) 2133*da0073e9SAndroid Build Coastguard Worker for field in class_fields[1:]: 2134*da0073e9SAndroid Build Coastguard Worker v = getattr(obj, field.name) 2135*da0073e9SAndroid Build Coastguard Worker if v is not None: 2136*da0073e9SAndroid Build Coastguard Worker total += v 2137*da0073e9SAndroid Build Coastguard Worker 2138*da0073e9SAndroid Build Coastguard Worker return total 2139*da0073e9SAndroid Build Coastguard Worker 2140*da0073e9SAndroid Build Coastguard Worker obj1 = MyDataClass(torch.randn(10), torch.randn(10), torch.randn(10)) 2141*da0073e9SAndroid Build Coastguard Worker obj2 = MyDataClass(torch.randn(10), e=torch.randn(10)) 2142*da0073e9SAndroid Build Coastguard Worker correct1 = fn(obj1) 2143*da0073e9SAndroid Build Coastguard Worker correct2 = fn(obj2) 2144*da0073e9SAndroid Build Coastguard Worker 2145*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2146*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2147*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(opt_fn(obj1), correct1)) 2148*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2149*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 2150*da0073e9SAndroid Build Coastguard Worker 2151*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 2152*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2153*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2154*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(opt_fn(obj2), correct2)) 2155*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2156*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 1) 2157*da0073e9SAndroid Build Coastguard Worker 2158*da0073e9SAndroid Build 
Coastguard Worker # guard failure 2159*da0073e9SAndroid Build Coastguard Worker obj2.z = True 2160*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(obj2), -2) 2161*da0073e9SAndroid Build Coastguard Worker 2162*da0073e9SAndroid Build Coastguard Worker def test_dataclass_local_hasattr(self): 2163*da0073e9SAndroid Build Coastguard Worker cnt = CompileCounter() 2164*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10) 2165*da0073e9SAndroid Build Coastguard Worker 2166*da0073e9SAndroid Build Coastguard Worker @dataclasses.dataclass 2167*da0073e9SAndroid Build Coastguard Worker class MyDataClass: 2168*da0073e9SAndroid Build Coastguard Worker a: torch.Tensor 2169*da0073e9SAndroid Build Coastguard Worker b: torch.Tensor 2170*da0073e9SAndroid Build Coastguard Worker 2171*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt, fullgraph=True) 2172*da0073e9SAndroid Build Coastguard Worker def fn(): 2173*da0073e9SAndroid Build Coastguard Worker obj = MyDataClass(x + 1, x - 1) 2174*da0073e9SAndroid Build Coastguard Worker if not hasattr(obj, "a"): 2175*da0073e9SAndroid Build Coastguard Worker return -1 2176*da0073e9SAndroid Build Coastguard Worker if hasattr(obj, "z"): 2177*da0073e9SAndroid Build Coastguard Worker return -2 2178*da0073e9SAndroid Build Coastguard Worker return obj 2179*da0073e9SAndroid Build Coastguard Worker 2180*da0073e9SAndroid Build Coastguard Worker result = fn() 2181*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(result, MyDataClass) 2182*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.a, x + 1) 2183*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.b, x - 1) 2184*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 2185*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, 2) 2186*da0073e9SAndroid Build Coastguard Worker 2187*da0073e9SAndroid Build Coastguard Worker def test_catch_watchings1(self): 2188*da0073e9SAndroid Build 
Coastguard Worker cnt = CompileCounter() 2189*da0073e9SAndroid Build Coastguard Worker 2190*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt, fullgraph=True) 2191*da0073e9SAndroid Build Coastguard Worker def fn(x): 2192*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True): 2193*da0073e9SAndroid Build Coastguard Worker return x.sin() 2194*da0073e9SAndroid Build Coastguard Worker 2195*da0073e9SAndroid Build Coastguard Worker x = torch.randn(8) 2196*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), x.sin()) 2197*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 2198*da0073e9SAndroid Build Coastguard Worker 2199*da0073e9SAndroid Build Coastguard Worker def test_catch_watchings2(self): 2200*da0073e9SAndroid Build Coastguard Worker cnt = CompileCounter() 2201*da0073e9SAndroid Build Coastguard Worker 2202*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt, fullgraph=True) 2203*da0073e9SAndroid Build Coastguard Worker def fn(x): 2204*da0073e9SAndroid Build Coastguard Worker return x.sin(), warnings.catch_warnings(record=True) 2205*da0073e9SAndroid Build Coastguard Worker 2206*da0073e9SAndroid Build Coastguard Worker x = torch.randn(8) 2207*da0073e9SAndroid Build Coastguard Worker _, a = fn(x) 2208*da0073e9SAndroid Build Coastguard Worker _, b = fn(x) 2209*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 2210*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(a, warnings.catch_warnings) 2211*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(b, warnings.catch_warnings) 2212*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(a, b) 2213*da0073e9SAndroid Build Coastguard Worker 2214*da0073e9SAndroid Build Coastguard Worker def test_tensor_build_list_unpack(self): 2215*da0073e9SAndroid Build Coastguard Worker def fn(x): 2216*da0073e9SAndroid Build Coastguard Worker # seen in fastNLP_Bert 2217*da0073e9SAndroid 
Build Coastguard Worker return torch.cat([*x], dim=-1) 2218*da0073e9SAndroid Build Coastguard Worker 2219*da0073e9SAndroid Build Coastguard Worker val = torch.randn([1, 1, 473, 768]) 2220*da0073e9SAndroid Build Coastguard Worker correct = fn(val) 2221*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2222*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2223*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(opt_fn(val), correct)) 2224*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2225*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 2226*da0073e9SAndroid Build Coastguard Worker 2227*da0073e9SAndroid Build Coastguard Worker def test_numpy_int_constant(self): 2228*da0073e9SAndroid Build Coastguard Worker def fn(x, a, b): 2229*da0073e9SAndroid Build Coastguard Worker return x + (a % b) 2230*da0073e9SAndroid Build Coastguard Worker 2231*da0073e9SAndroid Build Coastguard Worker args = [torch.randn(10), 4096, np.int64(8)] 2232*da0073e9SAndroid Build Coastguard Worker correct = fn(*args) 2233*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2234*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, dynamic=True, nopython=True)(fn) 2235*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(opt_fn(*args), correct)) 2236*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(opt_fn(*args), correct)) 2237*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2238*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 2239*da0073e9SAndroid Build Coastguard Worker 2240*da0073e9SAndroid Build Coastguard Worker def test_numpy_subdtype(self): 2241*da0073e9SAndroid Build Coastguard Worker def fn(x, n): 2242*da0073e9SAndroid Build Coastguard Worker return np.issubdtype(type(n), np.integer) + x 2243*da0073e9SAndroid Build 
Coastguard Worker 2244*da0073e9SAndroid Build Coastguard Worker args = [torch.randn(10), 4096] 2245*da0073e9SAndroid Build Coastguard Worker correct = fn(*args) 2246*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2247*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 2248*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(*args), correct) 2249*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2250*da0073e9SAndroid Build Coastguard Worker 2251*da0073e9SAndroid Build Coastguard Worker def test_numpy_take_along_axis(self): 2252*da0073e9SAndroid Build Coastguard Worker def fn(x, i, a): 2253*da0073e9SAndroid Build Coastguard Worker return np.take_along_axis(x, i, a) 2254*da0073e9SAndroid Build Coastguard Worker 2255*da0073e9SAndroid Build Coastguard Worker def sample_to_args(s): 2256*da0073e9SAndroid Build Coastguard Worker args = (s.input, *sample.args) 2257*da0073e9SAndroid Build Coastguard Worker return tuple(a.numpy() if isinstance(a, torch.Tensor) else a for a in args) 2258*da0073e9SAndroid Build Coastguard Worker 2259*da0073e9SAndroid Build Coastguard Worker samples = list( 2260*da0073e9SAndroid Build Coastguard Worker sample_inputs_take_along_dim( 2261*da0073e9SAndroid Build Coastguard Worker None, "cpu", torch.float32, requires_grad=False 2262*da0073e9SAndroid Build Coastguard Worker ) 2263*da0073e9SAndroid Build Coastguard Worker ) 2264*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2265*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2266*da0073e9SAndroid Build Coastguard Worker i = 1 2267*da0073e9SAndroid Build Coastguard Worker for sample in samples: 2268*da0073e9SAndroid Build Coastguard Worker args = sample_to_args(sample) 2269*da0073e9SAndroid Build Coastguard Worker if len(args) < 3: 2270*da0073e9SAndroid Build Coastguard Worker # if axis is None, 
second argument is treated as 1d array 2271*da0073e9SAndroid Build Coastguard Worker args = (args[0], np.ravel(args[1]), None) 2272*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(*args), opt_fn(*args)) 2273*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, i) 2274*da0073e9SAndroid Build Coastguard Worker i += 1 2275*da0073e9SAndroid Build Coastguard Worker 2276*da0073e9SAndroid Build Coastguard Worker def test_numpy_torch_operators(self): 2277*da0073e9SAndroid Build Coastguard Worker def fn(op, t1, t2): 2278*da0073e9SAndroid Build Coastguard Worker return op(t1, t2) 2279*da0073e9SAndroid Build Coastguard Worker 2280*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.variables.builtin import BuiltinVariable 2281*da0073e9SAndroid Build Coastguard Worker 2282*da0073e9SAndroid Build Coastguard Worker operators = BuiltinVariable._fx_graph_functions() 2283*da0073e9SAndroid Build Coastguard Worker 2284*da0073e9SAndroid Build Coastguard Worker for op, t1_np, t2_np in itertools.product( 2285*da0073e9SAndroid Build Coastguard Worker operators, (True, False), (True, False) 2286*da0073e9SAndroid Build Coastguard Worker ): 2287*da0073e9SAndroid Build Coastguard Worker if op in [operator.eq, operator.ne]: 2288*da0073e9SAndroid Build Coastguard Worker # returns equivalent of torch.eq/ne 2289*da0073e9SAndroid Build Coastguard Worker continue 2290*da0073e9SAndroid Build Coastguard Worker if op is operator.getitem: 2291*da0073e9SAndroid Build Coastguard Worker # skip 2292*da0073e9SAndroid Build Coastguard Worker # Did you know that tensor[ndarray_of_floats] works? 
2293*da0073e9SAndroid Build Coastguard Worker continue 2294*da0073e9SAndroid Build Coastguard Worker if op is operator.imatmul and (t1_np or t2_np): 2295*da0073e9SAndroid Build Coastguard Worker # skip 2296*da0073e9SAndroid Build Coastguard Worker # in numpy, in place matmul does not work single 2297*da0073e9SAndroid Build Coastguard Worker # dimensional arrays 2298*da0073e9SAndroid Build Coastguard Worker continue 2299*da0073e9SAndroid Build Coastguard Worker t1 = torch.rand(5) 2300*da0073e9SAndroid Build Coastguard Worker if t1_np: 2301*da0073e9SAndroid Build Coastguard Worker t1 = t1.numpy() 2302*da0073e9SAndroid Build Coastguard Worker t2 = torch.rand(5) 2303*da0073e9SAndroid Build Coastguard Worker if t2_np: 2304*da0073e9SAndroid Build Coastguard Worker t2 = t2.numpy() 2305*da0073e9SAndroid Build Coastguard Worker try: 2306*da0073e9SAndroid Build Coastguard Worker # TODO try a bit harder 2307*da0073e9SAndroid Build Coastguard Worker result = op(t1, t2) 2308*da0073e9SAndroid Build Coastguard Worker except (RuntimeError, TypeError, IndexError): 2309*da0073e9SAndroid Build Coastguard Worker continue 2310*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2311*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2312*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, opt_fn(op, t1, t2), msg=f"{op=} {t1_np=} {t2_np=}") 2313*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1, msg=f"{op=} {t1_np=} {t2_np=}") 2314*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 2315*da0073e9SAndroid Build Coastguard Worker 2316*da0073e9SAndroid Build Coastguard Worker def test_numpy_ndarray_graph_break(self): 2317*da0073e9SAndroid Build Coastguard Worker def fn(x): 2318*da0073e9SAndroid Build Coastguard Worker a = x.numpy() 2319*da0073e9SAndroid Build Coastguard Worker b = a.real 2320*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 
2321*da0073e9SAndroid Build Coastguard Worker c = np.multiply(b, 2.0) 2322*da0073e9SAndroid Build Coastguard Worker return c 2323*da0073e9SAndroid Build Coastguard Worker 2324*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2325*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2326*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 2327*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 2328*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 2329*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 2330*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 2331*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 2332*da0073e9SAndroid Build Coastguard Worker 2333*da0073e9SAndroid Build Coastguard Worker def test_numpy_ndarray_graph_break_with_multiple_outputs(self): 2334*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 2335*da0073e9SAndroid Build Coastguard Worker a = x.numpy() 2336*da0073e9SAndroid Build Coastguard Worker b = y.numpy() 2337*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 2338*da0073e9SAndroid Build Coastguard Worker return np.add(a, 1), np.add(b, 1) 2339*da0073e9SAndroid Build Coastguard Worker 2340*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2341*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2342*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 2343*da0073e9SAndroid Build Coastguard Worker x = torch.randn([1, 3]) 2344*da0073e9SAndroid Build Coastguard Worker y = torch.randn([1, 3]) 2345*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 2346*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 2347*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 2348*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 2349*da0073e9SAndroid Build 
Coastguard Worker 2350*da0073e9SAndroid Build Coastguard Worker def test_numpy_force(self): 2351*da0073e9SAndroid Build Coastguard Worker def fn(x): 2352*da0073e9SAndroid Build Coastguard Worker return x.numpy(force=False) 2353*da0073e9SAndroid Build Coastguard Worker 2354*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2355*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2356*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 2357*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 2358*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(res), np.ndarray) 2359*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2360*da0073e9SAndroid Build Coastguard Worker 2361*da0073e9SAndroid Build Coastguard Worker def fn(x): 2362*da0073e9SAndroid Build Coastguard Worker return x.numpy(force=True) 2363*da0073e9SAndroid Build Coastguard Worker 2364*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2365*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2366*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 2367*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 2368*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(res), np.ndarray) 2369*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2370*da0073e9SAndroid Build Coastguard Worker 2371*da0073e9SAndroid Build Coastguard Worker def test_numpy_recompilation_scalar(self): 2372*da0073e9SAndroid Build Coastguard Worker def fn(x, a): 2373*da0073e9SAndroid Build Coastguard Worker return np.where(x < 0.5, a, x) 2374*da0073e9SAndroid Build Coastguard Worker 2375*da0073e9SAndroid Build Coastguard Worker x = np.random.randn(8) 2376*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2377*da0073e9SAndroid Build Coastguard Worker opt_fn = 
torch._dynamo.optimize(cnts, dynamic=True)(fn) 2378*da0073e9SAndroid Build Coastguard Worker 2379*da0073e9SAndroid Build Coastguard Worker ref = fn(x, 3) 2380*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, 3) 2381*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 2382*da0073e9SAndroid Build Coastguard Worker 2383*da0073e9SAndroid Build Coastguard Worker ref = fn(x, 4) 2384*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, 4) 2385*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 2386*da0073e9SAndroid Build Coastguard Worker 2387*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2388*da0073e9SAndroid Build Coastguard Worker 2389*da0073e9SAndroid Build Coastguard Worker def test_tensor_interacts_with_numpy_ndarray(self): 2390*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 2391*da0073e9SAndroid Build Coastguard Worker a = x.numpy() 2392*da0073e9SAndroid Build Coastguard Worker b = y.numpy() 2393*da0073e9SAndroid Build Coastguard Worker c = np.ones_like(a) 2394*da0073e9SAndroid Build Coastguard Worker d = np.ones_like(b) 2395*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 2396*da0073e9SAndroid Build Coastguard Worker return np.add(a, c), np.add(b, d) 2397*da0073e9SAndroid Build Coastguard Worker 2398*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2399*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2400*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 2401*da0073e9SAndroid Build Coastguard Worker x = torch.randn([1, 3]) 2402*da0073e9SAndroid Build Coastguard Worker y = torch.randn([1, 3]) 2403*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 2404*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 2405*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 2406*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 
2407*da0073e9SAndroid Build Coastguard Worker 2408*da0073e9SAndroid Build Coastguard Worker def test_numpy_ndarray_works_with_builtin_function(self): 2409*da0073e9SAndroid Build Coastguard Worker def fn(x): 2410*da0073e9SAndroid Build Coastguard Worker v = x.sum() / len(x) 2411*da0073e9SAndroid Build Coastguard Worker return v 2412*da0073e9SAndroid Build Coastguard Worker 2413*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2414*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 2415*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 2416*da0073e9SAndroid Build Coastguard Worker x = np.random.randn(2, 3) 2417*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 2418*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 2419*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 2420*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2421*da0073e9SAndroid Build Coastguard Worker 2422*da0073e9SAndroid Build Coastguard Worker def test_numpy_array_of_arrays(self): 2423*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 2424*da0073e9SAndroid Build Coastguard Worker return np.array([x, y]) 2425*da0073e9SAndroid Build Coastguard Worker 2426*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2427*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 2428*da0073e9SAndroid Build Coastguard Worker 2429*da0073e9SAndroid Build Coastguard Worker x, y = np.float64(1), np.float64(2) 2430*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 2431*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, np.array([1, 2], dtype=float)) 2432*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(res), np.ndarray) 2433*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2434*da0073e9SAndroid Build Coastguard Worker 
2435*da0073e9SAndroid Build Coastguard Worker x, y = np.arange(2), np.arange(2) + 2 2436*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 2437*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, np.array([[0, 1], [2, 3]])) 2438*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(res), np.ndarray) 2439*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 2440*da0073e9SAndroid Build Coastguard Worker 2441*da0073e9SAndroid Build Coastguard Worker def test_numpy_readonly(self): 2442*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True) 2443*da0073e9SAndroid Build Coastguard Worker def fn(x): 2444*da0073e9SAndroid Build Coastguard Worker return x 2445*da0073e9SAndroid Build Coastguard Worker 2446*da0073e9SAndroid Build Coastguard Worker x = np.broadcast_to(np.arange(3), (2, 3)) 2447*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x.flags.writeable) 2448*da0073e9SAndroid Build Coastguard Worker 2449*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(): 2450*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("error") 2451*da0073e9SAndroid Build Coastguard Worker y = fn(x) 2452*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.flags.writeable) # XXX: differs from numpy 2453*da0073e9SAndroid Build Coastguard Worker 2454*da0073e9SAndroid Build Coastguard Worker def test_numpy_tolist(self): 2455*da0073e9SAndroid Build Coastguard Worker def fn(x): 2456*da0073e9SAndroid Build Coastguard Worker return x.tolist() 2457*da0073e9SAndroid Build Coastguard Worker 2458*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2459*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 2460*da0073e9SAndroid Build Coastguard Worker 2461*da0073e9SAndroid Build Coastguard Worker x = np.arange(5) 2462*da0073e9SAndroid Build Coastguard Worker r = opt_fn(x) 2463*da0073e9SAndroid Build Coastguard 
Worker 2464*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, [0, 1, 2, 3, 4]) 2465*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(r), list) 2466*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2467*da0073e9SAndroid Build Coastguard Worker 2468*da0073e9SAndroid Build Coastguard Worker def test_numpy_size_attr(self): 2469*da0073e9SAndroid Build Coastguard Worker def fn(x): 2470*da0073e9SAndroid Build Coastguard Worker return x.size + x 2471*da0073e9SAndroid Build Coastguard Worker 2472*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2473*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 2474*da0073e9SAndroid Build Coastguard Worker 2475*da0073e9SAndroid Build Coastguard Worker x = np.arange(5) 2476*da0073e9SAndroid Build Coastguard Worker r = opt_fn(x) 2477*da0073e9SAndroid Build Coastguard Worker 2478*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, fn(x)) 2479*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(r), np.ndarray) 2480*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2481*da0073e9SAndroid Build Coastguard Worker 2482*da0073e9SAndroid Build Coastguard Worker def test_numpy_no_raise(self): 2483*da0073e9SAndroid Build Coastguard Worker def _inf_nan_preprocess(t, t_np): 2484*da0073e9SAndroid Build Coastguard Worker t_np = np.nan_to_num(t_np) 2485*da0073e9SAndroid Build Coastguard Worker return t, t_np 2486*da0073e9SAndroid Build Coastguard Worker 2487*da0073e9SAndroid Build Coastguard Worker def fn(): 2488*da0073e9SAndroid Build Coastguard Worker # shape, dims format 2489*da0073e9SAndroid Build Coastguard Worker test_cases = ( 2490*da0073e9SAndroid Build Coastguard Worker (3, 3), 2491*da0073e9SAndroid Build Coastguard Worker (4, 4), 2492*da0073e9SAndroid Build Coastguard Worker (5, 5), 2493*da0073e9SAndroid Build Coastguard Worker ) 2494*da0073e9SAndroid Build 
Coastguard Worker 2495*da0073e9SAndroid Build Coastguard Worker for shape in test_cases: 2496*da0073e9SAndroid Build Coastguard Worker t = torch.randn(shape, dtype=torch.complex64) 2497*da0073e9SAndroid Build Coastguard Worker t_np = np.random.randn(*shape).astype(np.complex64) 2498*da0073e9SAndroid Build Coastguard Worker 2499*da0073e9SAndroid Build Coastguard Worker _, t_np = _inf_nan_preprocess(t, t_np) 2500*da0073e9SAndroid Build Coastguard Worker print(t, t_np) # Just a side effect so that compilation kicks in 2501*da0073e9SAndroid Build Coastguard Worker 2502*da0073e9SAndroid Build Coastguard Worker cnt = CompileCounterWithBackend("inductor") 2503*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnt)(fn) 2504*da0073e9SAndroid Build Coastguard Worker fn() 2505*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1)) 2506*da0073e9SAndroid Build Coastguard Worker 2507*da0073e9SAndroid Build Coastguard Worker def test_mandelbrot_numpy(self): 2508*da0073e9SAndroid Build Coastguard Worker def mandelbrot_numpy(max_iter): 2509*da0073e9SAndroid Build Coastguard Worker # Define the boundaries of the complex plane 2510*da0073e9SAndroid Build Coastguard Worker xn = 450 2511*da0073e9SAndroid Build Coastguard Worker yn = 375 2512*da0073e9SAndroid Build Coastguard Worker xmin = -2.25 2513*da0073e9SAndroid Build Coastguard Worker xmax = 0.75 2514*da0073e9SAndroid Build Coastguard Worker ymin = -1.25 2515*da0073e9SAndroid Build Coastguard Worker ymax = 1.25 2516*da0073e9SAndroid Build Coastguard Worker 2517*da0073e9SAndroid Build Coastguard Worker # Create the grid of complex numbers 2518*da0073e9SAndroid Build Coastguard Worker x_values = np.linspace(xmin, xmax, xn, dtype=np.float64) 2519*da0073e9SAndroid Build Coastguard Worker y_values = np.linspace(ymin, ymax, yn, dtype=np.float64) 2520*da0073e9SAndroid Build Coastguard Worker rx, iy = np.meshgrid(x_values, y_values, indexing="xy") 2521*da0073e9SAndroid Build 
Coastguard Worker 2522*da0073e9SAndroid Build Coastguard Worker x = rx.copy() 2523*da0073e9SAndroid Build Coastguard Worker y = iy.copy() 2524*da0073e9SAndroid Build Coastguard Worker mask = np.zeros_like(x) 2525*da0073e9SAndroid Build Coastguard Worker for i in range(max_iter): 2526*da0073e9SAndroid Build Coastguard Worker x_prev = x 2527*da0073e9SAndroid Build Coastguard Worker y_prev = y 2528*da0073e9SAndroid Build Coastguard Worker x = x_prev**2 - y_prev**2 + rx 2529*da0073e9SAndroid Build Coastguard Worker y = 2 * x_prev * y_prev + iy 2530*da0073e9SAndroid Build Coastguard Worker inside = np.sqrt(x**2 + y**2) <= 2 2531*da0073e9SAndroid Build Coastguard Worker mask += inside 2532*da0073e9SAndroid Build Coastguard Worker return mask 2533*da0073e9SAndroid Build Coastguard Worker 2534*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2535*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(mandelbrot_numpy) 2536*da0073e9SAndroid Build Coastguard Worker n_iter = torch._dynamo.config.cache_size_limit - 2 2537*da0073e9SAndroid Build Coastguard Worker for i in range(n_iter): 2538*da0073e9SAndroid Build Coastguard Worker x = i + 3 2539*da0073e9SAndroid Build Coastguard Worker ref = mandelbrot_numpy(x) 2540*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 2541*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 2542*da0073e9SAndroid Build Coastguard Worker # We need to specialise the number as it's in a forloop 2543*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, n_iter) 2544*da0073e9SAndroid Build Coastguard Worker 2545*da0073e9SAndroid Build Coastguard Worker def test_numpy_as_global(self): 2546*da0073e9SAndroid Build Coastguard Worker global x 2547*da0073e9SAndroid Build Coastguard Worker x = np.arange(10) 2548*da0073e9SAndroid Build Coastguard Worker 2549*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True) 
2550*da0073e9SAndroid Build Coastguard Worker def fn(y): 2551*da0073e9SAndroid Build Coastguard Worker return y + x + x 2552*da0073e9SAndroid Build Coastguard Worker 2553*da0073e9SAndroid Build Coastguard Worker r = fn(np.arange(10)) 2554*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(r), np.ndarray) 2555*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, x * 3) 2556*da0073e9SAndroid Build Coastguard Worker del x 2557*da0073e9SAndroid Build Coastguard Worker 2558*da0073e9SAndroid Build Coastguard Worker def test_numpy_gt(self): 2559*da0073e9SAndroid Build Coastguard Worker x = np.arange(10) 2560*da0073e9SAndroid Build Coastguard Worker 2561*da0073e9SAndroid Build Coastguard Worker @torch.compile 2562*da0073e9SAndroid Build Coastguard Worker def fn(y): 2563*da0073e9SAndroid Build Coastguard Worker return y >= 3 2564*da0073e9SAndroid Build Coastguard Worker 2565*da0073e9SAndroid Build Coastguard Worker r = fn(x) 2566*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(r), np.ndarray) 2567*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, x >= 3) 2568*da0073e9SAndroid Build Coastguard Worker 2569*da0073e9SAndroid Build Coastguard Worker def test_numpy_min(self): 2570*da0073e9SAndroid Build Coastguard Worker x = np.arange(10) 2571*da0073e9SAndroid Build Coastguard Worker 2572*da0073e9SAndroid Build Coastguard Worker @torch.compile 2573*da0073e9SAndroid Build Coastguard Worker def fn(y): 2574*da0073e9SAndroid Build Coastguard Worker return min(y, 3), min(y, y - 1) 2575*da0073e9SAndroid Build Coastguard Worker 2576*da0073e9SAndroid Build Coastguard Worker r1, r2 = fn(x) 2577*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(r1), np.ndarray) 2578*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(r2), np.ndarray) 2579*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1, np.minimum(x, 3)) 2580*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r2, np.minimum(x, x - 1)) 
2581*da0073e9SAndroid Build Coastguard Worker 2582*da0073e9SAndroid Build Coastguard Worker def test_graph_break_correctly_when_passing_numpy_ndarray_to_torch_function(self): 2583*da0073e9SAndroid Build Coastguard Worker # from transformers/models/big_bird/modeling_big_bird.py 2584*da0073e9SAndroid Build Coastguard Worker def fn(x: int, y: torch.Tensor): 2585*da0073e9SAndroid Build Coastguard Worker ndarray_list = [np.ones([2, x])] 2586*da0073e9SAndroid Build Coastguard Worker ndarray = np.stack(ndarray_list, axis=0) 2587*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor(ndarray, dtype=torch.long) 2588*da0073e9SAndroid Build Coastguard Worker tensor.unsqueeze_(0) 2589*da0073e9SAndroid Build Coastguard Worker return tensor + y 2590*da0073e9SAndroid Build Coastguard Worker 2591*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2592*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2593*da0073e9SAndroid Build Coastguard Worker for x in range(1, 10): 2594*da0073e9SAndroid Build Coastguard Worker y = torch.randn([1, 2, x]) 2595*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 2596*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 2597*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 2598*da0073e9SAndroid Build Coastguard Worker # It's all traced once with x = 1, x = 2 and then x = ks0 2599*da0073e9SAndroid Build Coastguard Worker # For dynamic it's x=1 and x=ks0 2600*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, ifdynstaticdefault(3, 2)) 2601*da0073e9SAndroid Build Coastguard Worker 2602*da0073e9SAndroid Build Coastguard Worker def test_numpy_with_builtin_type(self): 2603*da0073e9SAndroid Build Coastguard Worker x = np.random.rand(5) 2604*da0073e9SAndroid Build Coastguard Worker 2605*da0073e9SAndroid Build Coastguard Worker def fn(x): 2606*da0073e9SAndroid Build Coastguard Worker return (x * 
5).astype(bool).astype(float).astype(int) + 8 2607*da0073e9SAndroid Build Coastguard Worker 2608*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2609*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2610*da0073e9SAndroid Build Coastguard Worker 2611*da0073e9SAndroid Build Coastguard Worker r = opt_fn(x) 2612*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.dtype, int) 2613*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2614*da0073e9SAndroid Build Coastguard Worker 2615*da0073e9SAndroid Build Coastguard Worker def test_with_builtin_type(self): 2616*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5) 2617*da0073e9SAndroid Build Coastguard Worker 2618*da0073e9SAndroid Build Coastguard Worker def fn(x): 2619*da0073e9SAndroid Build Coastguard Worker return (x * 5).to(bool).to(float).to(int) + 8 2620*da0073e9SAndroid Build Coastguard Worker 2621*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2622*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2623*da0073e9SAndroid Build Coastguard Worker 2624*da0073e9SAndroid Build Coastguard Worker r = opt_fn(x) 2625*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.dtype, torch.int64) 2626*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2627*da0073e9SAndroid Build Coastguard Worker 2628*da0073e9SAndroid Build Coastguard Worker def test_numpy_unique_f16(self): 2629*da0073e9SAndroid Build Coastguard Worker def fn(): 2630*da0073e9SAndroid Build Coastguard Worker x = np.asarray([1, 1, 2, 2, 3], dtype=np.float16) 2631*da0073e9SAndroid Build Coastguard Worker return np.unique(x) 2632*da0073e9SAndroid Build Coastguard Worker 2633*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2634*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 
2635*da0073e9SAndroid Build Coastguard Worker 2636*da0073e9SAndroid Build Coastguard Worker r = opt_fn() 2637*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.dtype, np.float16) 2638*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2639*da0073e9SAndroid Build Coastguard Worker 2640*da0073e9SAndroid Build Coastguard Worker def test_numpy_fallback_on_eager(self): 2641*da0073e9SAndroid Build Coastguard Worker def fn(): 2642*da0073e9SAndroid Build Coastguard Worker return np.asarray(["L", "U"]) 2643*da0073e9SAndroid Build Coastguard Worker 2644*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2645*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2646*da0073e9SAndroid Build Coastguard Worker 2647*da0073e9SAndroid Build Coastguard Worker r = opt_fn() 2648*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 0) # graph break 2649*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, np.asarray(["L", "U"])) 2650*da0073e9SAndroid Build Coastguard Worker 2651*da0073e9SAndroid Build Coastguard Worker # repeat with a different function 2652*da0073e9SAndroid Build Coastguard Worker def fn2(): 2653*da0073e9SAndroid Build Coastguard Worker return np.random.choice(["L", "U"]) 2654*da0073e9SAndroid Build Coastguard Worker 2655*da0073e9SAndroid Build Coastguard Worker cnts2 = torch._dynamo.testing.CompileCounter() 2656*da0073e9SAndroid Build Coastguard Worker opt_fn2 = torch._dynamo.optimize(cnts2)(fn2) 2657*da0073e9SAndroid Build Coastguard Worker 2658*da0073e9SAndroid Build Coastguard Worker r2 = fn2() 2659*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 0) 2660*da0073e9SAndroid Build Coastguard Worker assert r2 in ("L", "U") 2661*da0073e9SAndroid Build Coastguard Worker 2662*da0073e9SAndroid Build Coastguard Worker def test_trace_ndarray_frame(self): 2663*da0073e9SAndroid Build Coastguard Worker def fn(x): 
2664*da0073e9SAndroid Build Coastguard Worker x = x**2 2665*da0073e9SAndroid Build Coastguard Worker print("graph break.") 2666*da0073e9SAndroid Build Coastguard Worker return 2 * x 2667*da0073e9SAndroid Build Coastguard Worker 2668*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 2669*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(counter)(fn) 2670*da0073e9SAndroid Build Coastguard Worker 2671*da0073e9SAndroid Build Coastguard Worker x = np.arange(8) 2672*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), compiled_fn(x)) 2673*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 2) 2674*da0073e9SAndroid Build Coastguard Worker 2675*da0073e9SAndroid Build Coastguard Worker def test_trace_ndarray_frame_2(self): 2676*da0073e9SAndroid Build Coastguard Worker # no tensors/ndarray as inputs in the frame 2677*da0073e9SAndroid Build Coastguard Worker def fn(x): 2678*da0073e9SAndroid Build Coastguard Worker print("graph break.") 2679*da0073e9SAndroid Build Coastguard Worker return 2 * np.arange(x) 2680*da0073e9SAndroid Build Coastguard Worker 2681*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 2682*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(counter)(fn) 2683*da0073e9SAndroid Build Coastguard Worker 2684*da0073e9SAndroid Build Coastguard Worker x = 8 2685*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), compiled_fn(x)) 2686*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 2687*da0073e9SAndroid Build Coastguard Worker 2688*da0073e9SAndroid Build Coastguard Worker def test_numpy_non_torch_dtype(self): 2689*da0073e9SAndroid Build Coastguard Worker # test that we gracefully graph break on dtypes 2690*da0073e9SAndroid Build Coastguard Worker # that do not have pytorch equivalents. 
2691*da0073e9SAndroid Build Coastguard Worker def fn(x): 2692*da0073e9SAndroid Build Coastguard Worker return isinstance(x, torch.Tensor) 2693*da0073e9SAndroid Build Coastguard Worker 2694*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2695*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2696*da0073e9SAndroid Build Coastguard Worker 2697*da0073e9SAndroid Build Coastguard Worker # torch does not have the `uint16` dtype 2698*da0073e9SAndroid Build Coastguard Worker for x in [np.array([42], dtype=np.uint16), np.uint16(42), np.dtype("uint16")]: 2699*da0073e9SAndroid Build Coastguard Worker r = opt_fn(x) 2700*da0073e9SAndroid Build Coastguard Worker 2701*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, False) 2702*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 0) # graph break 2703*da0073e9SAndroid Build Coastguard Worker 2704*da0073e9SAndroid Build Coastguard Worker def test_numpy_iter(self): 2705*da0073e9SAndroid Build Coastguard Worker # test that iteration over an ndarray produces ndarrays not bare tensors 2706*da0073e9SAndroid Build Coastguard Worker def fn(x): 2707*da0073e9SAndroid Build Coastguard Worker return [bm for bm in x] 2708*da0073e9SAndroid Build Coastguard Worker 2709*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2710*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2711*da0073e9SAndroid Build Coastguard Worker 2712*da0073e9SAndroid Build Coastguard Worker proba_map = np.arange(3)[:, None] 2713*da0073e9SAndroid Build Coastguard Worker res = opt_fn(proba_map) 2714*da0073e9SAndroid Build Coastguard Worker 2715*da0073e9SAndroid Build Coastguard Worker self.assertEqual([type(r) for r in res], [np.ndarray, np.ndarray, np.ndarray]) 2716*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, [np.array([0]), np.array([1]), np.array([2])]) 2717*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2718*da0073e9SAndroid Build Coastguard Worker 2719*da0073e9SAndroid Build Coastguard Worker # cache size limit needs to be larger than the `dtypes` list size 2720*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(cache_size_limit=12) 2721*da0073e9SAndroid Build Coastguard Worker def test_dtypes_no_graphbreaks(self): 2722*da0073e9SAndroid Build Coastguard Worker dtypes = [ 2723*da0073e9SAndroid Build Coastguard Worker # floats 2724*da0073e9SAndroid Build Coastguard Worker float, 2725*da0073e9SAndroid Build Coastguard Worker np.float64, 2726*da0073e9SAndroid Build Coastguard Worker "float64", 2727*da0073e9SAndroid Build Coastguard Worker np.float32, 2728*da0073e9SAndroid Build Coastguard Worker "float32", 2729*da0073e9SAndroid Build Coastguard Worker # np.dtype('float64') # XXX: this is not supported, yet 2730*da0073e9SAndroid Build Coastguard Worker # integers 2731*da0073e9SAndroid Build Coastguard Worker int, 2732*da0073e9SAndroid Build Coastguard Worker "int", 2733*da0073e9SAndroid Build Coastguard Worker np.intp, 2734*da0073e9SAndroid Build Coastguard Worker np.int32, 2735*da0073e9SAndroid Build Coastguard Worker np.uint8 2736*da0073e9SAndroid Build Coastguard Worker # np.dtype('int') # XXX: as above 2737*da0073e9SAndroid Build Coastguard Worker ] 2738*da0073e9SAndroid Build Coastguard Worker 2739*da0073e9SAndroid Build Coastguard Worker def fn(dt): 2740*da0073e9SAndroid Build Coastguard Worker return np.arange(5, dtype=dt) 2741*da0073e9SAndroid Build Coastguard Worker 2742*da0073e9SAndroid Build Coastguard Worker for dtyp in dtypes: 2743*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2744*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2745*da0073e9SAndroid Build Coastguard Worker 2746*da0073e9SAndroid Build Coastguard Worker val = fn(dtyp) 2747*da0073e9SAndroid Build Coastguard Worker opt_val = opt_fn(dtyp) 
2748*da0073e9SAndroid Build Coastguard Worker 2749*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) # no graph break 2750*da0073e9SAndroid Build Coastguard Worker 2751*da0073e9SAndroid Build Coastguard Worker # setting the config value makes the PRNG identical to numpy's 2752*da0073e9SAndroid Build Coastguard Worker # NB this may involve a graph break 2753*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(use_numpy_random_stream=True) 2754*da0073e9SAndroid Build Coastguard Worker def test_numpy_random_config_to_numpy(self): 2755*da0073e9SAndroid Build Coastguard Worker @torch.compile 2756*da0073e9SAndroid Build Coastguard Worker def fn(): 2757*da0073e9SAndroid Build Coastguard Worker return np.random.uniform(size=13) 2758*da0073e9SAndroid Build Coastguard Worker 2759*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn().shape, (13,)) 2760*da0073e9SAndroid Build Coastguard Worker 2761*da0073e9SAndroid Build Coastguard Worker def test_inplace_view_on_graph_input(self): 2762*da0073e9SAndroid Build Coastguard Worker # graph break when calling methods with inplace_view tag on graph input 2763*da0073e9SAndroid Build Coastguard Worker func_args_map = { 2764*da0073e9SAndroid Build Coastguard Worker lambda x: x.resize_(6).mul_(2): torch.ones(4), 2765*da0073e9SAndroid Build Coastguard Worker lambda x: x.t_().mul_(2): torch.rand(2, 3), 2766*da0073e9SAndroid Build Coastguard Worker lambda x: x.transpose_(0, 1).mul_(2): torch.rand(2, 3), 2767*da0073e9SAndroid Build Coastguard Worker lambda x: x.squeeze_().mul_(2): torch.rand(1, 2, 3), 2768*da0073e9SAndroid Build Coastguard Worker lambda x: x.unsqueeze_(0).mul_(2): torch.rand(2, 3), 2769*da0073e9SAndroid Build Coastguard Worker lambda x: x.resize_as_(torch.rand(200, 300)): torch.rand(2, 3), 2770*da0073e9SAndroid Build Coastguard Worker lambda x: x.swapaxes_(0, 1).mul_(2): torch.rand(2, 3), 2771*da0073e9SAndroid Build Coastguard Worker lambda x: x.swapdims_(0, 1).mul_(2): 
torch.rand(2, 3), 2772*da0073e9SAndroid Build Coastguard Worker lambda x: x.rename_("N", "C").mul_(2): torch.zeros(2, 3), 2773*da0073e9SAndroid Build Coastguard Worker lambda x: x.as_strided_((3, 2), (2, 1)).mul_(2): torch.zeros(2, 3), 2774*da0073e9SAndroid Build Coastguard Worker lambda x: x.detach_().mul_(2): torch.zeros(2, 3), 2775*da0073e9SAndroid Build Coastguard Worker } 2776*da0073e9SAndroid Build Coastguard Worker for func, args in func_args_map.items(): 2777*da0073e9SAndroid Build Coastguard Worker args_clone = args.clone() 2778*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2779*da0073e9SAndroid Build Coastguard Worker opt_f = torch._dynamo.optimize(cnts)(func) 2780*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(func(args).shape, opt_f(args_clone).shape)) 2781*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2782*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 1) # mul_ 2783*da0073e9SAndroid Build Coastguard Worker 2784*da0073e9SAndroid Build Coastguard Worker def test_out_variants_with_resizing_on_graph_inputs(self): 2785*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 2786*da0073e9SAndroid Build Coastguard Worker return torch.cosh(x, out=y) + 1 2787*da0073e9SAndroid Build Coastguard Worker 2788*da0073e9SAndroid Build Coastguard Worker x = torch.rand(2, 3) 2789*da0073e9SAndroid Build Coastguard Worker y = torch.rand(4) 2790*da0073e9SAndroid Build Coastguard Worker 2791*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2792*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend=cnts) 2793*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(fn(x, y), opt_fn(x.clone(), y.clone()))) 2794*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2795*da0073e9SAndroid Build Coastguard Worker 2796*da0073e9SAndroid Build Coastguard Worker def 
test_out_variants_with_resizing_on_graph_inputs_with_dynamic(self): 2797*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/120482 2798*da0073e9SAndroid Build Coastguard Worker class CustomModel(torch.nn.Module): 2799*da0073e9SAndroid Build Coastguard Worker def __init__(self): 2800*da0073e9SAndroid Build Coastguard Worker super().__init__() 2801*da0073e9SAndroid Build Coastguard Worker 2802*da0073e9SAndroid Build Coastguard Worker def forward(self, inputs): 2803*da0073e9SAndroid Build Coastguard Worker return torch.outer(**inputs) 2804*da0073e9SAndroid Build Coastguard Worker 2805*da0073e9SAndroid Build Coastguard Worker compile_fn = torch.compile(CustomModel(), fullgraph=True) 2806*da0073e9SAndroid Build Coastguard Worker 2807*da0073e9SAndroid Build Coastguard Worker shapes = [(2, 1), (6, 1), (4, 1)] 2808*da0073e9SAndroid Build Coastguard Worker for shape in shapes: 2809*da0073e9SAndroid Build Coastguard Worker vec1, vec2 = shape 2810*da0073e9SAndroid Build Coastguard Worker input_tensor1 = torch.randn(vec1) 2811*da0073e9SAndroid Build Coastguard Worker input_tensor2 = torch.randn(vec2) 2812*da0073e9SAndroid Build Coastguard Worker out_tensor = torch.empty(shape) 2813*da0073e9SAndroid Build Coastguard Worker args = {"input": input_tensor1, "vec2": input_tensor2, "out": out_tensor} 2814*da0073e9SAndroid Build Coastguard Worker res = compile_fn(args) 2815*da0073e9SAndroid Build Coastguard Worker opt_res = res.clone() # cuz this is out and we mutate it 2816*da0073e9SAndroid Build Coastguard Worker res = CustomModel()(args) 2817*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, opt_res) 2818*da0073e9SAndroid Build Coastguard Worker 2819*da0073e9SAndroid Build Coastguard Worker def test_dict_mutation_side_effect(self): 2820*da0073e9SAndroid Build Coastguard Worker def fn(d): 2821*da0073e9SAndroid Build Coastguard Worker d["c"] = d["a"] + d.pop("b") 2822*da0073e9SAndroid Build Coastguard Worker return d 
2823*da0073e9SAndroid Build Coastguard Worker 2824*da0073e9SAndroid Build Coastguard Worker args1 = {"a": torch.randn(10), "b": torch.randn(10)} 2825*da0073e9SAndroid Build Coastguard Worker args2 = dict(args1) 2826*da0073e9SAndroid Build Coastguard Worker assert fn(args1) is args1 2827*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2828*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2829*da0073e9SAndroid Build Coastguard Worker self.assertIs(opt_fn(args2), args2) 2830*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(args1, args2)) 2831*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2832*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 1) 2833*da0073e9SAndroid Build Coastguard Worker 2834*da0073e9SAndroid Build Coastguard Worker def test_dict_order_keys(self): 2835*da0073e9SAndroid Build Coastguard Worker def fn(d): 2836*da0073e9SAndroid Build Coastguard Worker c = 0 2837*da0073e9SAndroid Build Coastguard Worker for v in d.values(): 2838*da0073e9SAndroid Build Coastguard Worker c += v 2839*da0073e9SAndroid Build Coastguard Worker return c 2840*da0073e9SAndroid Build Coastguard Worker 2841*da0073e9SAndroid Build Coastguard Worker args1 = {} 2842*da0073e9SAndroid Build Coastguard Worker args1["a"] = torch.rand(10) 2843*da0073e9SAndroid Build Coastguard Worker args1["b"] = torch.rand(10) 2844*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2845*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2846*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(args1), opt_fn(args1)) 2847*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2848*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 2849*da0073e9SAndroid Build Coastguard Worker 2850*da0073e9SAndroid Build Coastguard Worker # A different order of 
keys recompiles 2851*da0073e9SAndroid Build Coastguard Worker args2 = {} 2852*da0073e9SAndroid Build Coastguard Worker args2["b"] = args1["b"] 2853*da0073e9SAndroid Build Coastguard Worker args2["a"] = args1["a"] 2854*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(args2), opt_fn(args2)) 2855*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 2856*da0073e9SAndroid Build Coastguard Worker # Extra calls don't recompile 2857*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 2858*da0073e9SAndroid Build Coastguard Worker 2859*da0073e9SAndroid Build Coastguard Worker def test_dict_namedtuple(self): 2860*da0073e9SAndroid Build Coastguard Worker def fn(d): 2861*da0073e9SAndroid Build Coastguard Worker return d[3] * 2 2862*da0073e9SAndroid Build Coastguard Worker 2863*da0073e9SAndroid Build Coastguard Worker args1 = {collections.namedtuple: None, 3: torch.randn(3)} 2864*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2865*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2866*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(args1), opt_fn(args1)) 2867*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2868*da0073e9SAndroid Build Coastguard Worker # Test a failing namedtuple guard 2869*da0073e9SAndroid Build Coastguard Worker args2 = {2: None, 3: torch.randn(3)} 2870*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(args2), opt_fn(args2)) 2871*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 2872*da0073e9SAndroid Build Coastguard Worker 2873*da0073e9SAndroid Build Coastguard Worker def test_dict_order_keys_tensors(self): 2874*da0073e9SAndroid Build Coastguard Worker def fn(d, x): 2875*da0073e9SAndroid Build Coastguard Worker return d[x] + 3 2876*da0073e9SAndroid Build Coastguard Worker 2877*da0073e9SAndroid Build Coastguard Worker args1 = {} 
2878*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10) 2879*da0073e9SAndroid Build Coastguard Worker y = torch.randn(10) 2880*da0073e9SAndroid Build Coastguard Worker z = torch.randn(10) 2881*da0073e9SAndroid Build Coastguard Worker args1[x] = y 2882*da0073e9SAndroid Build Coastguard Worker args1[3] = z 2883*da0073e9SAndroid Build Coastguard Worker 2884*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2885*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2886*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(args1, x), opt_fn(args1, x)) 2887*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2888*da0073e9SAndroid Build Coastguard Worker 2889*da0073e9SAndroid Build Coastguard Worker # Calling again doesn't recompile (same id and key order) 2890*da0073e9SAndroid Build Coastguard Worker opt_fn(args1, x) 2891*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2892*da0073e9SAndroid Build Coastguard Worker args2 = {} 2893*da0073e9SAndroid Build Coastguard Worker args2[3] = z 2894*da0073e9SAndroid Build Coastguard Worker args2[x] = y 2895*da0073e9SAndroid Build Coastguard Worker 2896*da0073e9SAndroid Build Coastguard Worker # Different order recompiles 2897*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(args2, x), opt_fn(args2, x)) 2898*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 2899*da0073e9SAndroid Build Coastguard Worker 2900*da0073e9SAndroid Build Coastguard Worker def test_dict_order_keys_modules(self): 2901*da0073e9SAndroid Build Coastguard Worker def fn(d, x): 2902*da0073e9SAndroid Build Coastguard Worker return d[x](torch.ones(2, 2)) 2903*da0073e9SAndroid Build Coastguard Worker 2904*da0073e9SAndroid Build Coastguard Worker args1 = {} 2905*da0073e9SAndroid Build Coastguard Worker x = torch.nn.Linear(2, 2) 2906*da0073e9SAndroid Build Coastguard Worker y = 
torch.nn.Linear(2, 2) 2907*da0073e9SAndroid Build Coastguard Worker z = torch.nn.Linear(2, 2) 2908*da0073e9SAndroid Build Coastguard Worker args1[x] = y 2909*da0073e9SAndroid Build Coastguard Worker args1[3] = z 2910*da0073e9SAndroid Build Coastguard Worker 2911*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2912*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 2913*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(args1, x), opt_fn(args1, x)) 2914*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2915*da0073e9SAndroid Build Coastguard Worker 2916*da0073e9SAndroid Build Coastguard Worker # Calling again doesn't recompile (same id and key order) 2917*da0073e9SAndroid Build Coastguard Worker opt_fn(args1, x) 2918*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2919*da0073e9SAndroid Build Coastguard Worker args2 = {} 2920*da0073e9SAndroid Build Coastguard Worker args2[3] = z 2921*da0073e9SAndroid Build Coastguard Worker args2[x] = y 2922*da0073e9SAndroid Build Coastguard Worker 2923*da0073e9SAndroid Build Coastguard Worker # Different order recompiles 2924*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(args2, x), opt_fn(args2, x)) 2925*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 2926*da0073e9SAndroid Build Coastguard Worker 2927*da0073e9SAndroid Build Coastguard Worker def test_dunder_new_function_inlining(self): 2928*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/107460 2929*da0073e9SAndroid Build Coastguard Worker 2930*da0073e9SAndroid Build Coastguard Worker counters.clear() 2931*da0073e9SAndroid Build Coastguard Worker 2932*da0073e9SAndroid Build Coastguard Worker class ModelA(torch.nn.Module): 2933*da0073e9SAndroid Build Coastguard Worker def __init__(self): 2934*da0073e9SAndroid Build Coastguard Worker super().__init__() 

            def forward(self, x):
                return torch.tanh(x + 1)

        # Calling ModelB() never yields a ModelB: __new__ returns a fresh ModelA.
        class ModelB(torch.nn.Module):
            def __new__(cls):
                return ModelA()

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer = torch.nn.Linear(2, 2)

            def forward(self, x):
                other = ModelB()
                return self.layer(x) + other(x)

        x = torch.rand(2, 2)
        m = Model()

        opt_m = torch.compile(backend="eager")(m)
        ref = m(x)
        res = opt_m(x)
        self.assertTrue(same(ref, res))
        # Exactly one graph break is expected, and it must not be the
        # "super() nn.Module.__init__" break.
        self.assertEqual(len(counters["graph_break"]), 1)
        self.assertFalse("super() nn.Module.__init__" in counters["graph_break"])

    # Membership test of a class in type(mod).__mro__ must trace with
    # fullgraph=True and match eager.
    # NOTE(review): "duner" is presumably a typo for "dunder"; renaming changes
    # the test ID, so confirm nothing references this name before fixing it.
    def test_class_duner_mro(self):
        class ModuleA(torch.nn.Module):
            pass

        class ModuleB(ModuleA):
            pass

        def fn(x, mod):
            if ModuleA in type(mod).__mro__:
                return x + 1
            else:
                return x - 1

        x = torch.rand(2, 3)
        mod = ModuleB()
        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
        ref = fn(x, mod)
        res = opt_fn(x, mod)
        self.assertTrue(same(ref, res))

    # Closures decorated with functools.wraps, at one and at two levels of
    # nesting, must compile with fullgraph=True.
    def test_nested_wraps(self):
        def foo(x, y):
            def add(x, y):
                return x + y

            @functools.wraps(add)
            def wrapped_call(x, y):
                return add(x, y)

            return wrapped_call(x, y)

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        o = torch.compile(foo, fullgraph=True, backend="eager")(x, y)
        self.assertEqual(o, x + y)

        def foo(x, y):
            def nested_call(x, y):
                def mul(x, y):
                    return x * y

                @functools.wraps(mul)
                def double_nested_call(x, y):
                    return mul(x, y)

                return double_nested_call(x, y)

            return nested_call(x, y)

        o = torch.compile(foo, fullgraph=True, backend="eager")(x, y)
        self.assertEqual(o, x * y)

    # copy.deepcopy of an nn.Module inside the compiled function: one compiled
    # frame serves two structurally identical modules, recording 4 ops.
    def test_module_deepcopy(self):
        m1 = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
        )
        m2 = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
        )

        def fn(m, x):
            m_copy = copy.deepcopy(m)
            return m_copy(x)

        v = torch.randn(10)
        correct1 = fn(m1, v)
        correct2 = fn(m2, v)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        for _ in range(10):
            self.assertTrue(same(opt_fn(m1, v), correct1))
        for _ in range(10):
            self.assertTrue(same(opt_fn(m2, v), correct2))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 4)

    # type(seq)(...) must rebuild the input's container type (list or tuple);
    # the list and tuple inputs each get their own compiled frame.
    def test_type_copy(self):
        def fn(seq):
            a, b = seq
            return type(seq)([a + 1, b + 2, a + b])

        args1 = [torch.randn(10), torch.randn(10)]
        args2 = (torch.randn(10), torch.randn(10))
        correct1 = fn(args1)
        correct2 = fn(args2)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertTrue(same(opt_fn(args1), correct1))
        self.assertTrue(same(opt_fn(args2), correct2))
        self.assertIsInstance(opt_fn(args1), list)
        self.assertIsInstance(opt_fn(args2), tuple)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 6)

    # Chained attribute writes on a caller-owned object must be applied to that
    # object; results are compared against an eager run on a twin object.
    def test_setattr_mutation1(self):
        class MyObj:  # noqa: B903
            def __init__(self, a, b):
                self.a = a
                self.b = b

        def fn(obj):
            obj.c = obj.a * obj.b + 1
            obj.b = obj.a * obj.c + 2
            obj.a = obj.b * obj.c + 3
            obj.c = obj.a * obj.b + 4
            obj.b = obj.a * obj.c + 5
            obj.a = obj.b * obj.c + 6
            return obj

        x1 = torch.randn(10)
        x2 = torch.randn(10)
        obj1 = MyObj(x1, x2)
        obj2 = MyObj(x1, x2)
        fn(obj2)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertIs(opt_fn(obj1), obj1)
        self.assertTrue(same(obj1.a, obj2.a))
        self.assertTrue(same(obj1.b, obj2.b))
        self.assertTrue(same(obj1.c, obj2.c))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 12)

    # Attribute writes on an object created inside the compiled function and
    # returned to the caller.
    def test_setattr_mutation2(self):
        class MyObj:
            def __init__(self, x):
                self.a = x + 1
                self.b = x + 2

        def fn(x):
            x = x / 3.0
            obj = MyObj(x)
            obj.c = obj.a * obj.b + 1
            obj.b = obj.a * obj.c + 2
            obj.a = obj.b * obj.c + 3
            return obj

        x1 = torch.randn(10)
        obj2 = fn(x1)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        obj1 = opt_fn(x1)
        self.assertTrue(same(obj1.a, obj2.a))
        self.assertTrue(same(obj1.b, obj2.b))
        self.assertTrue(same(obj1.c, obj2.c))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 9)

    # Attribute writes on a locally created object that never escapes: only the
    # attribute tensors are returned.
    def test_setattr_mutation3(self):
        # TODO(jansel): dead code eliminate the object creation
        class MyObj:
            def __init__(self, x):
                super().__init__()
                self.a = x + 1
                self.b = x + 2

        def fn(x):
            x = x / 3.0
            obj = MyObj(x)
            obj.c = obj.a * obj.b + 1
            obj.b = obj.a * obj.c + 2
            obj.a = obj.b * obj.c + 3
            return obj.a, obj.b, obj.c

        x1 = torch.randn(10)
        obj2 = fn(x1)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        obj1 = opt_fn(x1)
        self.assertTrue(same(obj1, obj2))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 9)

    # object.__setattr__ on dataclass instances (plain and frozen, including
    # inside __post_init__), compiled with nopython=True.
    def test_object_setattr(self):
        @dataclasses.dataclass
        class A:
            x: torch.Tensor

        def fn1(x) -> A:
            a = A(x)
            object.__setattr__(a, "x", x + 2)
            return a

        x1 = torch.randn(10)
        obj11 = fn1(x1.clone())

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn1 = torch._dynamo.optimize(cnts, nopython=True)(fn1)
        obj12 = opt_fn1(x1.clone())
        self.assertTrue(same(obj11.x, x1 + 2))
        self.assertTrue(same(obj12.x, x1 + 2))
        self.assertTrue(same(obj11.x, obj12.x))
        self.assertEqual(cnts.frame_count, 1)

        @dataclasses.dataclass(frozen=True)
        class B:
            x: torch.Tensor

        def fn2(x) -> B:
            b = B(x)
            return b

        x2 = torch.randn(10)
        obj21 = fn2(x2.clone())

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2)
        obj22 = opt_fn2(x2.clone())
        self.assertTrue(same(obj21.x, x2))
        self.assertTrue(same(obj22.x, x2))
        self.assertTrue(same(obj21.x, obj22.x))
        # fn2 contains no tensor ops; the test expects no graph to be compiled.
        self.assertEqual(cnts.frame_count, 0)

        @dataclasses.dataclass(frozen=True)
        class C:
            x: torch.Tensor

        # object.__setattr__ bypasses the frozen dataclass's own __setattr__.
        def fn3(x) -> C:
            c = C(x)
            object.__setattr__(c, "x", x + 2)
            return c

        x3 = torch.randn(10)
        obj31 = fn3(x3.clone())

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn3 = torch._dynamo.optimize(cnts, nopython=True)(fn3)
        obj32 = opt_fn3(x3.clone())
        self.assertTrue(same(obj31.x, x3 + 2))
        self.assertTrue(same(obj32.x, x3 + 2))
        self.assertTrue(same(obj31.x, obj32.x))
        self.assertEqual(cnts.frame_count, 1)

        # Frozen dataclass whose __post_init__ attaches a derived attribute `y`.
        @dataclasses.dataclass(frozen=True)
        class D:
            x: torch.Tensor

            def __post_init__(self):
                object.__setattr__(self, "y", self.x + 2)

        def fn4(x) -> D:
            d = D(x)
            return d

        x4 = torch.randn(10)
        obj41 = fn4(x4.clone())

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn4 = torch._dynamo.optimize(cnts, nopython=True)(fn4)
        obj42 = opt_fn4(x4.clone())
        self.assertTrue(same(obj41.x, x4))
        self.assertTrue(same(obj42.x, x4))
        self.assertTrue(same(obj41.x, obj42.x))
        self.assertTrue(same(obj41.y, x4 + 2))
        self.assertTrue(same(obj42.y, x4 + 2))
        self.assertTrue(same(obj41.y, obj42.y))
        self.assertEqual(cnts.frame_count, 1)

    # Branching on __class__.__name__ of a user-defined instance.
    def test_user_defined_class_name(self):
        class MyClassFoo:
            pass

        def fn1(a, b, c):
            tmp = MyClassFoo()
            if tmp.__class__.__name__ == "MyClassFoo":
                return a - b / c

        torch._dynamo.testing.standard_test(self, fn=fn1, nargs=3)

    # isinstance() where the checked value is a class object, including a class
    # built by a custom metaclass.
    def test_user_defined_class_python_type(self):
        class MyClass1:
            pass

        class ExampleMeta(type):
            pass

        class MyClass2(metaclass=ExampleMeta):
            pass

        def fn(x, c):
            if isinstance(c, MyClass1):
                return x + 1
            elif isinstance(c, MyClass2):
                return x + 2
            else:
                return x + 3

        x = torch.rand(3)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        # NOTE(review): the classes themselves (not instances) are passed, so
        # both isinstance() checks are False and fn always takes the `x + 3`
        # branch; confirm this is intended.
        for c in [MyClass1, MyClass2]:
            ref = fn(x, c)
            res = opt_fn(x, c)
            self.assertTrue(same(ref, res))

    # super() calls from a classmethod of a class that has a custom metaclass,
    # with a graph break inside the overriding method.
    def test_super_calling_with_metaclass(self):
        class ExampleMeta(type):
            pass

        class MyClass1(metaclass=ExampleMeta):
            coeff = 4  # Force the constant guard to test source in guards

            @classmethod
            def add(cls, x):
                return x + 1

        class MyClass2(MyClass1):
            @classmethod
            def add(cls, x):
                torch._dynamo.graph_break()
                return x + super().add(x) + super().coeff

        def fn(x, obj):
            return x + obj.add(x)

        x = torch.rand(3)
        obj = MyClass2()
        opt_fn = torch._dynamo.optimize("eager")(fn)
        ref = fn(x, obj)
        res = opt_fn(x, obj)
        self.assertTrue(same(ref, res))

    # Calling a @staticmethod through its class inside the compiled function.
    def test_usr_cls_staticmethod(self):
        class Foo:
            @staticmethod
            def bar(a, b):
                return a + b

        def fn(a, b):
            return Foo.bar(a, b) - 1

        torch._dynamo.testing.standard_test(self, fn=fn, nargs=2)

    # Calling a @classmethod through its class inside the compiled function.
    def test_usr_cls_classmethod(self):
        class Foo:
            @classmethod
            def bar(cls, a, b):
                return a + b

        def fn(a, b):
            return Foo.bar(a, b) - 1

        torch._dynamo.testing.standard_test(self, fn=fn, nargs=3 - 1)

    # User-defined arithmetic dunders wrapping tensors; 4 ops are expected.
    def test_dunder_methods(self):
        class Foo:
            def __init__(self, val):
                super().__init__()
                self.val = val

            def __add__(self, other):
                return Foo(self.val + other.val)

            def __mul__(self, other):
                return Foo(self.val * other.val)

            def __truediv__(self, other):
                return Foo(self.val / other.val)

            def __sub__(self, other):
                return Foo(self.val - other.val)

        def fn(a, b, c):
            return Foo(a) + Foo(b) * Foo(c) / Foo(a) - Foo(b)

        torch._dynamo.testing.standard_test(self, fn=fn, nargs=3, expected_ops=4)

    # A closure annotated with a locally defined class is returned from one
    # compiled frame and then compiled itself as a second frame.
    def test_function_annotation(self):
        class Variable:
            pass

        def fn(x):
            x = x / 3.0

            def inner(y: typing.List[Variable]):
                return x + 1

            return inner

        x1 = torch.randn(10)
        obj2 = fn(x1)([])

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        opt_fn_inner = torch._dynamo.optimize_assert(cnts)(opt_fn(x1))
        obj1 = opt_fn_inner([])
        self.assertTrue(same(obj1, obj2))
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 2)

    def test_nested_closure(self):
        v0 = torch.randn(10)

        def fn1():
            v1 = torch.randn(10)

            def fn2(*args, **kwargs):
                assert len(args) == 1
                assert len(kwargs) == 1
                v2 = torch.randn(10) + args[0] + kwargs["b"]

                def fn3(v3=torch.randn(10)):
                    def fn4():
                        return v0 + v1 + v2 + v3 + 1

                    return fn4

                return fn3

            return fn2(1, b=2)()

        cnts =
torch._dynamo.testing.CompileCounter() 3382*da0073e9SAndroid Build Coastguard Worker opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1) 3383*da0073e9SAndroid Build Coastguard Worker tmp1 = torch._dynamo.optimize_assert(cnts)(opt_fn1()) 3384*da0073e9SAndroid Build Coastguard Worker tmp2 = torch._dynamo.optimize_assert(cnts)(opt_fn1()) 3385*da0073e9SAndroid Build Coastguard Worker self.assertTrue(tmp1().shape, (10,)) 3386*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(tmp1(), tmp1())) 3387*da0073e9SAndroid Build Coastguard Worker self.assertFalse(same(tmp1(), tmp2())) 3388*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 3389*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 9) 3390*da0073e9SAndroid Build Coastguard Worker 3391*da0073e9SAndroid Build Coastguard Worker def test_nested_closure_mutation(self): 3392*da0073e9SAndroid Build Coastguard Worker def fn1(): 3393*da0073e9SAndroid Build Coastguard Worker v1 = torch.randn(10) 3394*da0073e9SAndroid Build Coastguard Worker 3395*da0073e9SAndroid Build Coastguard Worker def fn2(): 3396*da0073e9SAndroid Build Coastguard Worker v2 = torch.randn(10) 3397*da0073e9SAndroid Build Coastguard Worker 3398*da0073e9SAndroid Build Coastguard Worker def fn3(): 3399*da0073e9SAndroid Build Coastguard Worker nonlocal v1, v2 3400*da0073e9SAndroid Build Coastguard Worker v1 += 1 3401*da0073e9SAndroid Build Coastguard Worker v2 += 2 3402*da0073e9SAndroid Build Coastguard Worker return v1 + v2 3403*da0073e9SAndroid Build Coastguard Worker 3404*da0073e9SAndroid Build Coastguard Worker return fn3 3405*da0073e9SAndroid Build Coastguard Worker 3406*da0073e9SAndroid Build Coastguard Worker rv = fn2() 3407*da0073e9SAndroid Build Coastguard Worker rv() 3408*da0073e9SAndroid Build Coastguard Worker rv() 3409*da0073e9SAndroid Build Coastguard Worker return rv 3410*da0073e9SAndroid Build Coastguard Worker 3411*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(9000) 
3412*da0073e9SAndroid Build Coastguard Worker counter1 = fn1() 3413*da0073e9SAndroid Build Coastguard Worker result1 = [counter1(), counter1(), counter1()] 3414*da0073e9SAndroid Build Coastguard Worker 3415*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(9000) 3416*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3417*da0073e9SAndroid Build Coastguard Worker opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1) 3418*da0073e9SAndroid Build Coastguard Worker counter2 = torch._dynamo.optimize_assert(cnts)(opt_fn1()) 3419*da0073e9SAndroid Build Coastguard Worker result2 = [counter2(), counter2(), counter2()] 3420*da0073e9SAndroid Build Coastguard Worker result1.append(counter1()) 3421*da0073e9SAndroid Build Coastguard Worker result2.append(counter2()) 3422*da0073e9SAndroid Build Coastguard Worker 3423*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(result1, result2)) 3424*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 3425*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 11) 3426*da0073e9SAndroid Build Coastguard Worker 3427*da0073e9SAndroid Build Coastguard Worker def test_write_to_closures_in_inlining(self): 3428*da0073e9SAndroid Build Coastguard Worker out = [] 3429*da0073e9SAndroid Build Coastguard Worker for use_dynamo in [False, True]: 3430*da0073e9SAndroid Build Coastguard Worker 3431*da0073e9SAndroid Build Coastguard Worker def make_counter(): 3432*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10) 3433*da0073e9SAndroid Build Coastguard Worker 3434*da0073e9SAndroid Build Coastguard Worker def counter(): 3435*da0073e9SAndroid Build Coastguard Worker nonlocal x 3436*da0073e9SAndroid Build Coastguard Worker x = x + 1 3437*da0073e9SAndroid Build Coastguard Worker return x 3438*da0073e9SAndroid Build Coastguard Worker 3439*da0073e9SAndroid Build Coastguard Worker return counter 3440*da0073e9SAndroid Build Coastguard Worker 
3441*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(0) 3442*da0073e9SAndroid Build Coastguard Worker counter = make_counter() 3443*da0073e9SAndroid Build Coastguard Worker if not use_dynamo: 3444*da0073e9SAndroid Build Coastguard Worker out.append(counter() + counter()) 3445*da0073e9SAndroid Build Coastguard Worker else: 3446*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3447*da0073e9SAndroid Build Coastguard Worker 3448*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnts, nopython=True) 3449*da0073e9SAndroid Build Coastguard Worker def fn(counter): 3450*da0073e9SAndroid Build Coastguard Worker return counter() + counter() 3451*da0073e9SAndroid Build Coastguard Worker 3452*da0073e9SAndroid Build Coastguard Worker out.append(fn(counter)) 3453*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 3454*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 3) 3455*da0073e9SAndroid Build Coastguard Worker self.assertFalse(same(counter() + counter(), out[-1])) 3456*da0073e9SAndroid Build Coastguard Worker 3457*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(out[0], out[1])) 3458*da0073e9SAndroid Build Coastguard Worker 3459*da0073e9SAndroid Build Coastguard Worker def test_closure_out_of_scope_cell(self): 3460*da0073e9SAndroid Build Coastguard Worker cell1 = torch.rand(1).item() 3461*da0073e9SAndroid Build Coastguard Worker cell2 = torch.rand(3, 3) 3462*da0073e9SAndroid Build Coastguard Worker 3463*da0073e9SAndroid Build Coastguard Worker def indirect(): 3464*da0073e9SAndroid Build Coastguard Worker return direct() 3465*da0073e9SAndroid Build Coastguard Worker 3466*da0073e9SAndroid Build Coastguard Worker def direct(): 3467*da0073e9SAndroid Build Coastguard Worker def inner(): 3468*da0073e9SAndroid Build Coastguard Worker return cell1 + 1, cell2 + 3 3469*da0073e9SAndroid Build Coastguard Worker 3470*da0073e9SAndroid Build Coastguard Worker 
return inner() 3471*da0073e9SAndroid Build Coastguard Worker 3472*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3473*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(indirect) 3474*da0073e9SAndroid Build Coastguard Worker result1, result2 = opt_fn() 3475*da0073e9SAndroid Build Coastguard Worker self.assertAlmostEqual(cell1 + 1, result1) 3476*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(cell2 + 3, result2)) 3477*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 3478*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 1) 3479*da0073e9SAndroid Build Coastguard Worker 3480*da0073e9SAndroid Build Coastguard Worker def test_closure_out_of_scope_cell_with_mutation(self): 3481*da0073e9SAndroid Build Coastguard Worker cell1 = torch.rand(1).item() 3482*da0073e9SAndroid Build Coastguard Worker orig1 = cell1 3483*da0073e9SAndroid Build Coastguard Worker cell2 = torch.rand(3, 3) 3484*da0073e9SAndroid Build Coastguard Worker orig2 = cell2.clone() 3485*da0073e9SAndroid Build Coastguard Worker 3486*da0073e9SAndroid Build Coastguard Worker def indirect(): 3487*da0073e9SAndroid Build Coastguard Worker return direct() 3488*da0073e9SAndroid Build Coastguard Worker 3489*da0073e9SAndroid Build Coastguard Worker def direct(): 3490*da0073e9SAndroid Build Coastguard Worker def inner(): 3491*da0073e9SAndroid Build Coastguard Worker nonlocal cell1, cell2 3492*da0073e9SAndroid Build Coastguard Worker x = cell2 + 1 3493*da0073e9SAndroid Build Coastguard Worker cell1 += 1 3494*da0073e9SAndroid Build Coastguard Worker cell2 += 10 3495*da0073e9SAndroid Build Coastguard Worker x = x + cell2 3496*da0073e9SAndroid Build Coastguard Worker return cell1, cell2, x 3497*da0073e9SAndroid Build Coastguard Worker 3498*da0073e9SAndroid Build Coastguard Worker return inner() 3499*da0073e9SAndroid Build Coastguard Worker 3500*da0073e9SAndroid Build Coastguard 
Worker cnts = torch._dynamo.testing.CompileCounter() 3501*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(indirect) 3502*da0073e9SAndroid Build Coastguard Worker for i in range(1, 4): 3503*da0073e9SAndroid Build Coastguard Worker result1, result2, _ = opt_fn() 3504*da0073e9SAndroid Build Coastguard Worker self.assertAlmostEqual(orig1 + 1 * i, result1) 3505*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(orig2 + 10 * i, result2)) 3506*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 3507*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 3) 3508*da0073e9SAndroid Build Coastguard Worker cnts.clear() 3509*da0073e9SAndroid Build Coastguard Worker 3510*da0073e9SAndroid Build Coastguard Worker def test_closure_with_mutation_and_graph_break(self): 3511*da0073e9SAndroid Build Coastguard Worker def fn(): 3512*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(1) 3513*da0073e9SAndroid Build Coastguard Worker 3514*da0073e9SAndroid Build Coastguard Worker def subfunc(): 3515*da0073e9SAndroid Build Coastguard Worker x[0] = backup 3516*da0073e9SAndroid Build Coastguard Worker 3517*da0073e9SAndroid Build Coastguard Worker if x[0] >= -1e5: 3518*da0073e9SAndroid Build Coastguard Worker pass 3519*da0073e9SAndroid Build Coastguard Worker 3520*da0073e9SAndroid Build Coastguard Worker backup = 1 3521*da0073e9SAndroid Build Coastguard Worker subfunc() 3522*da0073e9SAndroid Build Coastguard Worker return x 3523*da0073e9SAndroid Build Coastguard Worker 3524*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3525*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 3526*da0073e9SAndroid Build Coastguard Worker expected = fn() 3527*da0073e9SAndroid Build Coastguard Worker actual = opt_fn() 3528*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(expected, actual)) 3529*da0073e9SAndroid 
Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 3530*da0073e9SAndroid Build Coastguard Worker 3531*da0073e9SAndroid Build Coastguard Worker def test_closure_out_of_scope_cell_with_cond(self): 3532*da0073e9SAndroid Build Coastguard Worker # Test closure with out-of-scope cell variable, used in a cond 3533*da0073e9SAndroid Build Coastguard Worker # where the two branches read different closure variables 3534*da0073e9SAndroid Build Coastguard Worker from functorch.experimental.control_flow import cond 3535*da0073e9SAndroid Build Coastguard Worker 3536*da0073e9SAndroid Build Coastguard Worker def g(x): 3537*da0073e9SAndroid Build Coastguard Worker return x 3538*da0073e9SAndroid Build Coastguard Worker 3539*da0073e9SAndroid Build Coastguard Worker class ModuleCondDeep(torch.nn.Module): 3540*da0073e9SAndroid Build Coastguard Worker def forward(self, pred, x): 3541*da0073e9SAndroid Build Coastguard Worker return self._indirection(pred, x) 3542*da0073e9SAndroid Build Coastguard Worker 3543*da0073e9SAndroid Build Coastguard Worker def _indirection(self, pred, x): 3544*da0073e9SAndroid Build Coastguard Worker return self.indirection(pred, x) 3545*da0073e9SAndroid Build Coastguard Worker 3546*da0073e9SAndroid Build Coastguard Worker def indirection(self, pred, x): 3547*da0073e9SAndroid Build Coastguard Worker def true_fn(y): 3548*da0073e9SAndroid Build Coastguard Worker return y + 2 3549*da0073e9SAndroid Build Coastguard Worker 3550*da0073e9SAndroid Build Coastguard Worker def false_fn(y): 3551*da0073e9SAndroid Build Coastguard Worker return y - 2 3552*da0073e9SAndroid Build Coastguard Worker 3553*da0073e9SAndroid Build Coastguard Worker def shallow(x): 3554*da0073e9SAndroid Build Coastguard Worker return x * 2 3555*da0073e9SAndroid Build Coastguard Worker 3556*da0073e9SAndroid Build Coastguard Worker def deep(x): 3557*da0073e9SAndroid Build Coastguard Worker # y = g(x) 3558*da0073e9SAndroid Build Coastguard Worker y = x 3559*da0073e9SAndroid Build Coastguard 
Worker return cond( 3560*da0073e9SAndroid Build Coastguard Worker x[0][0] > 0, 3561*da0073e9SAndroid Build Coastguard Worker true_fn, 3562*da0073e9SAndroid Build Coastguard Worker false_fn, 3563*da0073e9SAndroid Build Coastguard Worker [y], 3564*da0073e9SAndroid Build Coastguard Worker ) 3565*da0073e9SAndroid Build Coastguard Worker 3566*da0073e9SAndroid Build Coastguard Worker return cond(pred, shallow, deep, [x]) 3567*da0073e9SAndroid Build Coastguard Worker 3568*da0073e9SAndroid Build Coastguard Worker mod = ModuleCondDeep() 3569*da0073e9SAndroid Build Coastguard Worker opt_mod = torch._dynamo.optimize("eager")(mod) 3570*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(3, 3) 3571*da0073e9SAndroid Build Coastguard Worker exp1 = mod(torch.tensor(False), inp) 3572*da0073e9SAndroid Build Coastguard Worker actual1 = opt_mod(torch.tensor(False), inp) 3573*da0073e9SAndroid Build Coastguard Worker exp2 = mod(torch.tensor(True), inp) 3574*da0073e9SAndroid Build Coastguard Worker actual2 = opt_mod(torch.tensor(True), inp) 3575*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(exp1, actual1)) 3576*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(exp2, actual2)) 3577*da0073e9SAndroid Build Coastguard Worker 3578*da0073e9SAndroid Build Coastguard Worker def test_top_package_import(self): 3579*da0073e9SAndroid Build Coastguard Worker def fn(x): 3580*da0073e9SAndroid Build Coastguard Worker import torch.fx 3581*da0073e9SAndroid Build Coastguard Worker 3582*da0073e9SAndroid Build Coastguard Worker assert not isinstance(x, torch.fx.Proxy) 3583*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) 3584*da0073e9SAndroid Build Coastguard Worker 3585*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 5) 3586*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 3587*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3588*da0073e9SAndroid Build Coastguard Worker opt_fn = 
torch._dynamo.optimize_assert(cnts)(fn) 3589*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 3590*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 3591*da0073e9SAndroid Build Coastguard Worker 3592*da0073e9SAndroid Build Coastguard Worker def test_typing_typevar(self): 3593*da0073e9SAndroid Build Coastguard Worker def fn(x): 3594*da0073e9SAndroid Build Coastguard Worker def sumt(y: torch.Tensor) -> torch.Tensor: 3595*da0073e9SAndroid Build Coastguard Worker return torch.sum(y) 3596*da0073e9SAndroid Build Coastguard Worker 3597*da0073e9SAndroid Build Coastguard Worker def foo(c: typing.Callable[[T], T], y: T) -> T: 3598*da0073e9SAndroid Build Coastguard Worker return c(y) 3599*da0073e9SAndroid Build Coastguard Worker 3600*da0073e9SAndroid Build Coastguard Worker return foo(sumt, x) 3601*da0073e9SAndroid Build Coastguard Worker 3602*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 3603*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 3604*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3605*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize_assert(cnts)(fn) 3606*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 3607*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 3608*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 3609*da0073e9SAndroid Build Coastguard Worker 3610*da0073e9SAndroid Build Coastguard Worker def test_typing_union_and_optional(self): 3611*da0073e9SAndroid Build Coastguard Worker def fn(x): 3612*da0073e9SAndroid Build Coastguard Worker a = torch.jit.annotate(typing.Dict[str, typing.Optional[torch.Tensor]], {}) 3613*da0073e9SAndroid Build Coastguard Worker b = torch.jit.annotate( 3614*da0073e9SAndroid Build Coastguard Worker typing.Dict[str, typing.Union[torch.Tensor, None]], {} 3615*da0073e9SAndroid Build Coastguard Worker ) 3616*da0073e9SAndroid Build Coastguard Worker return a, b, x + 
1 3617*da0073e9SAndroid Build Coastguard Worker 3618*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 3619*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 3620*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=False)(fn) 3621*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 3622*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 3623*da0073e9SAndroid Build Coastguard Worker 3624*da0073e9SAndroid Build Coastguard Worker def test_optimize_on_module(self): 3625*da0073e9SAndroid Build Coastguard Worker class MockModule(torch.nn.Module): 3626*da0073e9SAndroid Build Coastguard Worker def __init__(self): 3627*da0073e9SAndroid Build Coastguard Worker super().__init__() 3628*da0073e9SAndroid Build Coastguard Worker self.relu = torch.nn.ReLU() 3629*da0073e9SAndroid Build Coastguard Worker 3630*da0073e9SAndroid Build Coastguard Worker def custom_member(self): 3631*da0073e9SAndroid Build Coastguard Worker # Just for checking that Dynamo returned mod object can redirect 3632*da0073e9SAndroid Build Coastguard Worker # to this method 3633*da0073e9SAndroid Build Coastguard Worker pass 3634*da0073e9SAndroid Build Coastguard Worker 3635*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3636*da0073e9SAndroid Build Coastguard Worker return self.relu(x) 3637*da0073e9SAndroid Build Coastguard Worker 3638*da0073e9SAndroid Build Coastguard Worker cnts1 = torch._dynamo.testing.CompileCounter() 3639*da0073e9SAndroid Build Coastguard Worker mod = MockModule() 3640*da0073e9SAndroid Build Coastguard Worker optimized_mod = torch._dynamo.optimize(cnts1, nopython=True)(mod) 3641*da0073e9SAndroid Build Coastguard Worker 3642*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10) 3643*da0073e9SAndroid Build Coastguard Worker ref = mod(a) 3644*da0073e9SAndroid Build Coastguard Worker res = optimized_mod(a) 3645*da0073e9SAndroid Build Coastguard Worker 3646*da0073e9SAndroid Build Coastguard 
Worker optimized_mod.custom_member() 3647*da0073e9SAndroid Build Coastguard Worker 3648*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 3649*da0073e9SAndroid Build Coastguard Worker 3650*da0073e9SAndroid Build Coastguard Worker def test_nested_optimize_decorator(self): 3651*da0073e9SAndroid Build Coastguard Worker cnts2 = torch._dynamo.testing.CompileCounter() 3652*da0073e9SAndroid Build Coastguard Worker cnts3 = torch._dynamo.testing.CompileCounter() 3653*da0073e9SAndroid Build Coastguard Worker 3654*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.run() 3655*da0073e9SAndroid Build Coastguard Worker def fn1(x): 3656*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) * 10 3657*da0073e9SAndroid Build Coastguard Worker 3658*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnts2, nopython=True) 3659*da0073e9SAndroid Build Coastguard Worker def fn2(x): 3660*da0073e9SAndroid Build Coastguard Worker return fn1(x) + 1 3661*da0073e9SAndroid Build Coastguard Worker 3662*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnts3, nopython=True) 3663*da0073e9SAndroid Build Coastguard Worker def fn3(x): 3664*da0073e9SAndroid Build Coastguard Worker return torch.relu(fn2(x)) 3665*da0073e9SAndroid Build Coastguard Worker 3666*da0073e9SAndroid Build Coastguard Worker fn3(torch.randn(4, 5)) 3667*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts2.frame_count, 0) 3668*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts3.frame_count, 1) 3669*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts3.op_count, 4) 3670*da0073e9SAndroid Build Coastguard Worker 3671*da0073e9SAndroid Build Coastguard Worker def test_nested_optimize_run(self): 3672*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3673*da0073e9SAndroid Build Coastguard Worker 3674*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnts, nopython=True) 3675*da0073e9SAndroid 
Build Coastguard Worker def fn(x): 3676*da0073e9SAndroid Build Coastguard Worker return torch.relu(torch.cos(x) + torch.sin(x)) 3677*da0073e9SAndroid Build Coastguard Worker 3678*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(4)) 3679*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 3680*da0073e9SAndroid Build Coastguard Worker 3681*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(4, 4)) 3682*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 3683*da0073e9SAndroid Build Coastguard Worker 3684*da0073e9SAndroid Build Coastguard Worker # Test that run works on a decorated fn 3685*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.run(fn) 3686*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(4, 4, 4)) 3687*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 3688*da0073e9SAndroid Build Coastguard Worker 3689*da0073e9SAndroid Build Coastguard Worker def test_nested_optimize(self): 3690*da0073e9SAndroid Build Coastguard Worker cnts1 = torch._dynamo.testing.CompileCounter() 3691*da0073e9SAndroid Build Coastguard Worker cnts2 = torch._dynamo.testing.CompileCounter() 3692*da0073e9SAndroid Build Coastguard Worker 3693*da0073e9SAndroid Build Coastguard Worker def fn(x): 3694*da0073e9SAndroid Build Coastguard Worker return torch.relu(torch.cos(x) + torch.sin(x)) 3695*da0073e9SAndroid Build Coastguard Worker 3696*da0073e9SAndroid Build Coastguard Worker fn1 = torch._dynamo.optimize(cnts1, nopython=True)(fn) 3697*da0073e9SAndroid Build Coastguard Worker fn2 = torch._dynamo.optimize(cnts2, nopython=True)(fn1) 3698*da0073e9SAndroid Build Coastguard Worker 3699*da0073e9SAndroid Build Coastguard Worker # The first optimize in the nesting should be ignored 3700*da0073e9SAndroid Build Coastguard Worker fn2(torch.randn(4)) 3701*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts2.frame_count, 1) 3702*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(cnts1.frame_count, 0) 3703*da0073e9SAndroid Build Coastguard Worker 3704*da0073e9SAndroid Build Coastguard Worker # Since the fn code object is already compiled, calling fn1 should 3705*da0073e9SAndroid Build Coastguard Worker # directly call the compiled_fn callable. 3706*da0073e9SAndroid Build Coastguard Worker torch._dynamo.run()(fn1)(torch.randn(4)) 3707*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts1.frame_count, 0) 3708*da0073e9SAndroid Build Coastguard Worker 3709*da0073e9SAndroid Build Coastguard Worker # Test same behavior by reversing the calls 3710*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 3711*da0073e9SAndroid Build Coastguard Worker cnts1 = torch._dynamo.testing.CompileCounter() 3712*da0073e9SAndroid Build Coastguard Worker cnts2 = torch._dynamo.testing.CompileCounter() 3713*da0073e9SAndroid Build Coastguard Worker fn1 = torch._dynamo.optimize(cnts1, nopython=True)(fn) 3714*da0073e9SAndroid Build Coastguard Worker fn2 = torch._dynamo.optimize(cnts2, nopython=True)(fn1) 3715*da0073e9SAndroid Build Coastguard Worker fn1(torch.randn(4)) 3716*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts1.frame_count, 1) 3717*da0073e9SAndroid Build Coastguard Worker torch._dynamo.run()(fn2)(torch.randn(4)) 3718*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts2.frame_count, 0) 3719*da0073e9SAndroid Build Coastguard Worker 3720*da0073e9SAndroid Build Coastguard Worker def test_torch_size(self): 3721*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3722*da0073e9SAndroid Build Coastguard Worker 3723*da0073e9SAndroid Build Coastguard Worker def fn(x): 3724*da0073e9SAndroid Build Coastguard Worker output_size = torch.Size([10, 10]) 3725*da0073e9SAndroid Build Coastguard Worker x = x.view(*output_size) 3726*da0073e9SAndroid Build Coastguard Worker return (x,) 3727*da0073e9SAndroid Build Coastguard Worker 3728*da0073e9SAndroid Build Coastguard Worker x = 
torch.randn(100, requires_grad=True) 3729*da0073e9SAndroid Build Coastguard Worker x_clone = x.clone() 3730*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 3731*da0073e9SAndroid Build Coastguard Worker 3732*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3733*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x_clone) 3734*da0073e9SAndroid Build Coastguard Worker 3735*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 3736*da0073e9SAndroid Build Coastguard Worker 3737*da0073e9SAndroid Build Coastguard Worker def test_torch_size_numel(self): 3738*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3739*da0073e9SAndroid Build Coastguard Worker 3740*da0073e9SAndroid Build Coastguard Worker def fn(): 3741*da0073e9SAndroid Build Coastguard Worker return torch.Size([10, 8]).numel() 3742*da0073e9SAndroid Build Coastguard Worker 3743*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3744*da0073e9SAndroid Build Coastguard Worker num = torch.Size([10, 8]).numel() 3745*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(), num) 3746*da0073e9SAndroid Build Coastguard Worker 3747*da0073e9SAndroid Build Coastguard Worker def test_torch_size_numel_dynamic(self): 3748*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3749*da0073e9SAndroid Build Coastguard Worker 3750*da0073e9SAndroid Build Coastguard Worker def fn(x): 3751*da0073e9SAndroid Build Coastguard Worker return x.size().numel() 3752*da0073e9SAndroid Build Coastguard Worker 3753*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3754*da0073e9SAndroid Build Coastguard Worker x = torch.rand(10, 1, 8, 1) 3755*da0073e9SAndroid Build Coastguard Worker expect = fn(x) 3756*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x), expect) 3757*da0073e9SAndroid 
Build Coastguard Worker 3758*da0073e9SAndroid Build Coastguard Worker def test_shape_type(self): 3759*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3760*da0073e9SAndroid Build Coastguard Worker 3761*da0073e9SAndroid Build Coastguard Worker def fn(x): 3762*da0073e9SAndroid Build Coastguard Worker return x + (type(x.shape) == torch.Size) 3763*da0073e9SAndroid Build Coastguard Worker 3764*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3765*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(()) 3766*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x), fn(x)) 3767*da0073e9SAndroid Build Coastguard Worker 3768*da0073e9SAndroid Build Coastguard Worker def test_size_dim(self): 3769*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3770*da0073e9SAndroid Build Coastguard Worker 3771*da0073e9SAndroid Build Coastguard Worker def fn(x, dim): 3772*da0073e9SAndroid Build Coastguard Worker return x.size(dim=dim) 3773*da0073e9SAndroid Build Coastguard Worker 3774*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3775*da0073e9SAndroid Build Coastguard Worker x = torch.empty([4, 9, 8]) 3776*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x, 1), 9) 3777*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x, -2), 9) 3778*da0073e9SAndroid Build Coastguard Worker 3779*da0073e9SAndroid Build Coastguard Worker def test_stride_dim(self): 3780*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3781*da0073e9SAndroid Build Coastguard Worker 3782*da0073e9SAndroid Build Coastguard Worker def fn(x, dim): 3783*da0073e9SAndroid Build Coastguard Worker return x.stride(dim=dim) 3784*da0073e9SAndroid Build Coastguard Worker 3785*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 
3786*da0073e9SAndroid Build Coastguard Worker x = torch.empty([4, 9, 8]) 3787*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x, 0), 72) 3788*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x, -2), 8) 3789*da0073e9SAndroid Build Coastguard Worker 3790*da0073e9SAndroid Build Coastguard Worker def test_torch_seed(self): 3791*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.utils import counters 3792*da0073e9SAndroid Build Coastguard Worker 3793*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3794*da0073e9SAndroid Build Coastguard Worker counters.clear() 3795*da0073e9SAndroid Build Coastguard Worker 3796*da0073e9SAndroid Build Coastguard Worker def fn(x): 3797*da0073e9SAndroid Build Coastguard Worker attention_seed = int(torch.seed() % sys.maxsize) 3798*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(attention_seed) 3799*da0073e9SAndroid Build Coastguard Worker return (x,) 3800*da0073e9SAndroid Build Coastguard Worker 3801*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, requires_grad=True) 3802*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 3803*da0073e9SAndroid Build Coastguard Worker 3804*da0073e9SAndroid Build Coastguard Worker # Python code is needed here, since torch.manual_seed graph-breaks. 3805*da0073e9SAndroid Build Coastguard Worker # Refs: https://github.com/pytorch/pytorch/issues/107187 3806*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn) 3807*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 3808*da0073e9SAndroid Build Coastguard Worker 3809*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 3810*da0073e9SAndroid Build Coastguard Worker # Only the torch.seed call is turned into an FX graph. 
3811*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 1) 3812*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 3813*da0073e9SAndroid Build Coastguard Worker # Graph breaks at manual_seed. 3814*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 1) 3815*da0073e9SAndroid Build Coastguard Worker 3816*da0073e9SAndroid Build Coastguard Worker def test_is_tensor_like(self): 3817*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3818*da0073e9SAndroid Build Coastguard Worker 3819*da0073e9SAndroid Build Coastguard Worker def f(x): 3820*da0073e9SAndroid Build Coastguard Worker if torch.overrides.is_tensor_like(x): 3821*da0073e9SAndroid Build Coastguard Worker return (x * 2,) 3822*da0073e9SAndroid Build Coastguard Worker return (torch.ones(10) + x,) 3823*da0073e9SAndroid Build Coastguard Worker 3824*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10) 3825*da0073e9SAndroid Build Coastguard Worker ref0 = f(x) 3826*da0073e9SAndroid Build Coastguard Worker ref1 = f(4) 3827*da0073e9SAndroid Build Coastguard Worker opt_f = torch._dynamo.optimize(cnts, nopython=True)(f) 3828*da0073e9SAndroid Build Coastguard Worker res0 = opt_f(x) 3829*da0073e9SAndroid Build Coastguard Worker res1 = opt_f(4) 3830*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref0, res0)) 3831*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref1, res1)) 3832*da0073e9SAndroid Build Coastguard Worker 3833*da0073e9SAndroid Build Coastguard Worker def test_is_tensor_like2(self): 3834*da0073e9SAndroid Build Coastguard Worker class MyTensor: 3835*da0073e9SAndroid Build Coastguard Worker @classmethod 3836*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 3837*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 3838*da0073e9SAndroid Build Coastguard Worker kwargs = {} 3839*da0073e9SAndroid Build 
Coastguard Worker 3840*da0073e9SAndroid Build Coastguard Worker if func is torch.max: 3841*da0073e9SAndroid Build Coastguard Worker return torch.tensor(123) 3842*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 3843*da0073e9SAndroid Build Coastguard Worker 3844*da0073e9SAndroid Build Coastguard Worker def fn(x): 3845*da0073e9SAndroid Build Coastguard Worker if torch.overrides.is_tensor_like(x): 3846*da0073e9SAndroid Build Coastguard Worker return torch.max(x) 3847*da0073e9SAndroid Build Coastguard Worker else: 3848*da0073e9SAndroid Build Coastguard Worker return torch.zeros(1) 3849*da0073e9SAndroid Build Coastguard Worker 3850*da0073e9SAndroid Build Coastguard Worker x = MyTensor() 3851*da0073e9SAndroid Build Coastguard Worker ref0 = fn(x) 3852*da0073e9SAndroid Build Coastguard Worker ref1 = fn(4) 3853*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 3854*da0073e9SAndroid Build Coastguard Worker res0 = opt_fn(x) 3855*da0073e9SAndroid Build Coastguard Worker res1 = opt_fn(4) 3856*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref0, res0)) 3857*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref1, res1)) 3858*da0073e9SAndroid Build Coastguard Worker 3859*da0073e9SAndroid Build Coastguard Worker def test_tensor_data(self): 3860*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 3861*da0073e9SAndroid Build Coastguard Worker return x[y.data] 3862*da0073e9SAndroid Build Coastguard Worker 3863*da0073e9SAndroid Build Coastguard Worker x = torch.rand(8) 3864*da0073e9SAndroid Build Coastguard Worker y = torch.ones(8).to(torch.int) 3865*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 3866*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 3867*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 3868*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 3869*da0073e9SAndroid Build Coastguard Worker 
3870*da0073e9SAndroid Build Coastguard Worker def test_tensor_layout(self): 3871*da0073e9SAndroid Build Coastguard Worker def fn(x): 3872*da0073e9SAndroid Build Coastguard Worker return torch.zeros( 3873*da0073e9SAndroid Build Coastguard Worker [x.size()[0], x.size()[1]], 3874*da0073e9SAndroid Build Coastguard Worker dtype=x.dtype, 3875*da0073e9SAndroid Build Coastguard Worker layout=x.layout, 3876*da0073e9SAndroid Build Coastguard Worker device=x.device, 3877*da0073e9SAndroid Build Coastguard Worker ) 3878*da0073e9SAndroid Build Coastguard Worker 3879*da0073e9SAndroid Build Coastguard Worker x = torch.rand(2, 3) 3880*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 3881*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 3882*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 3883*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 3884*da0073e9SAndroid Build Coastguard Worker 3885*da0073e9SAndroid Build Coastguard Worker def test_version_ci(self): 3886*da0073e9SAndroid Build Coastguard Worker # temporary test to check that the ci torch version is set correctly 3887*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(torch, "_subclasses")) 3888*da0073e9SAndroid Build Coastguard Worker 3889*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "requires cuda") 3890*da0073e9SAndroid Build Coastguard Worker def test_rand(self): 3891*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3892*da0073e9SAndroid Build Coastguard Worker device = "cuda" 3893*da0073e9SAndroid Build Coastguard Worker 3894*da0073e9SAndroid Build Coastguard Worker def fn(): 3895*da0073e9SAndroid Build Coastguard Worker return torch.randn(10, device=device) 3896*da0073e9SAndroid Build Coastguard Worker 3897*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(10) 3898*da0073e9SAndroid Build Coastguard Worker ref_run1 = fn() 3899*da0073e9SAndroid 
Build Coastguard Worker 3900*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(10) 3901*da0073e9SAndroid Build Coastguard Worker ref_run2 = fn() 3902*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref_run1, ref_run2)) 3903*da0073e9SAndroid Build Coastguard Worker 3904*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(10) 3905*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3906*da0073e9SAndroid Build Coastguard Worker res = opt_fn() 3907*da0073e9SAndroid Build Coastguard Worker 3908*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res, ref_run1)) 3909*da0073e9SAndroid Build Coastguard Worker 3910*da0073e9SAndroid Build Coastguard Worker def test_slice_input(self): 3911*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3912*da0073e9SAndroid Build Coastguard Worker 3913*da0073e9SAndroid Build Coastguard Worker def getitem(a, idx): 3914*da0073e9SAndroid Build Coastguard Worker if isinstance(idx, slice): 3915*da0073e9SAndroid Build Coastguard Worker return ( 3916*da0073e9SAndroid Build Coastguard Worker torch.zeros(1), 3917*da0073e9SAndroid Build Coastguard Worker a[idx] 3918*da0073e9SAndroid Build Coastguard Worker + [ 3919*da0073e9SAndroid Build Coastguard Worker 100, 3920*da0073e9SAndroid Build Coastguard Worker ], 3921*da0073e9SAndroid Build Coastguard Worker ) 3922*da0073e9SAndroid Build Coastguard Worker else: 3923*da0073e9SAndroid Build Coastguard Worker return (torch.zeros(1), a[idx]) 3924*da0073e9SAndroid Build Coastguard Worker 3925*da0073e9SAndroid Build Coastguard Worker layers = list(range(10)) 3926*da0073e9SAndroid Build Coastguard Worker ref0 = getitem(layers, slice(0, 2, 1)) 3927*da0073e9SAndroid Build Coastguard Worker ref1 = getitem(layers, 2) 3928*da0073e9SAndroid Build Coastguard Worker ref2 = getitem(layers, slice(3, 8, 2)) 3929*da0073e9SAndroid Build Coastguard Worker opt_getitem = torch._dynamo.optimize(cnts, 
nopython=True)(getitem) 3930*da0073e9SAndroid Build Coastguard Worker res0 = opt_getitem(layers, slice(0, 2, 1)) 3931*da0073e9SAndroid Build Coastguard Worker res1 = opt_getitem(layers, 2) 3932*da0073e9SAndroid Build Coastguard Worker res2 = opt_getitem(layers, slice(3, 8, 2)) 3933*da0073e9SAndroid Build Coastguard Worker 3934*da0073e9SAndroid Build Coastguard Worker self.assertTrue(ref0 == res0) 3935*da0073e9SAndroid Build Coastguard Worker self.assertTrue(ref1 == res1) 3936*da0073e9SAndroid Build Coastguard Worker self.assertTrue(ref2 == res2) 3937*da0073e9SAndroid Build Coastguard Worker 3938*da0073e9SAndroid Build Coastguard Worker def test_grad(self): 3939*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3940*da0073e9SAndroid Build Coastguard Worker 3941*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 3942*da0073e9SAndroid Build Coastguard Worker out = a * b 3943*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 3944*da0073e9SAndroid Build Coastguard Worker real_out = torch.sigmoid(a.grad + b) 3945*da0073e9SAndroid Build Coastguard Worker return real_out 3946*da0073e9SAndroid Build Coastguard Worker 3947*da0073e9SAndroid Build Coastguard Worker inps = [torch.randn(4, requires_grad=True) for _ in range(2)] 3948*da0073e9SAndroid Build Coastguard Worker for inp in inps: 3949*da0073e9SAndroid Build Coastguard Worker inp.grad = None 3950*da0073e9SAndroid Build Coastguard Worker ref = fn(*inps) 3951*da0073e9SAndroid Build Coastguard Worker 3952*da0073e9SAndroid Build Coastguard Worker for inp in inps: 3953*da0073e9SAndroid Build Coastguard Worker inp.grad = None 3954*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 3955*da0073e9SAndroid Build Coastguard Worker res = opt_fn(*inps) 3956*da0073e9SAndroid Build Coastguard Worker 3957*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 3958*da0073e9SAndroid Build Coastguard Worker 3959*da0073e9SAndroid Build 
Coastguard Worker @torch._dynamo.config.patch(guard_nn_modules=True) 3960*da0073e9SAndroid Build Coastguard Worker def test_source_non_input_grad_access(self): 3961*da0073e9SAndroid Build Coastguard Worker # This test creates a model, and accesses the grads 3962*da0073e9SAndroid Build Coastguard Worker # from its parameter. This means that within dynamo, 3963*da0073e9SAndroid Build Coastguard Worker # the tensor we are reading the grad from HAS a source, 3964*da0073e9SAndroid Build Coastguard Worker # but is not known to graphargs. 3965*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3966*da0073e9SAndroid Build Coastguard Worker 3967*da0073e9SAndroid Build Coastguard Worker class TrivialModel(torch.nn.Module): 3968*da0073e9SAndroid Build Coastguard Worker def __init__(self): 3969*da0073e9SAndroid Build Coastguard Worker super(TrivialModel, self).__init__() 3970*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(2, 1) 3971*da0073e9SAndroid Build Coastguard Worker 3972*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3973*da0073e9SAndroid Build Coastguard Worker return self.linear(x) 3974*da0073e9SAndroid Build Coastguard Worker 3975*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 3976*da0073e9SAndroid Build Coastguard Worker outs = [] 3977*da0073e9SAndroid Build Coastguard Worker for param in model.parameters(): 3978*da0073e9SAndroid Build Coastguard Worker outs.append(torch.ones(param.grad.size())) 3979*da0073e9SAndroid Build Coastguard Worker return outs, param.grad + 1 3980*da0073e9SAndroid Build Coastguard Worker 3981*da0073e9SAndroid Build Coastguard Worker model = TrivialModel() 3982*da0073e9SAndroid Build Coastguard Worker # Eager 3983*da0073e9SAndroid Build Coastguard Worker a = torch.ones([2, 2], requires_grad=True) 3984*da0073e9SAndroid Build Coastguard Worker b = torch.ones([2, 2]) 3985*da0073e9SAndroid Build Coastguard Worker out = model(a) 3986*da0073e9SAndroid Build 
Coastguard Worker out_sum = out.sum() 3987*da0073e9SAndroid Build Coastguard Worker out_sum.backward() 3988*da0073e9SAndroid Build Coastguard Worker ref = fn(a, b) 3989*da0073e9SAndroid Build Coastguard Worker 3990*da0073e9SAndroid Build Coastguard Worker # Compiled 3991*da0073e9SAndroid Build Coastguard Worker model = TrivialModel() 3992*da0073e9SAndroid Build Coastguard Worker a = torch.ones([2, 2], requires_grad=True) 3993*da0073e9SAndroid Build Coastguard Worker b = torch.ones([2, 2]) 3994*da0073e9SAndroid Build Coastguard Worker out = model(a) 3995*da0073e9SAndroid Build Coastguard Worker out_sum = out.sum() 3996*da0073e9SAndroid Build Coastguard Worker out_sum.backward() 3997*da0073e9SAndroid Build Coastguard Worker 3998*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3999*da0073e9SAndroid Build Coastguard Worker res = opt_fn(a, b) 4000*da0073e9SAndroid Build Coastguard Worker 4001*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 4002*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 4003*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 3) 4004*da0073e9SAndroid Build Coastguard Worker 4005*da0073e9SAndroid Build Coastguard Worker def test_intermediary_tensor_grad_access(self): 4006*da0073e9SAndroid Build Coastguard Worker # This test creates a model, and accesses the grads 4007*da0073e9SAndroid Build Coastguard Worker # from its parameters and an entirely intermediary tensor. 
4008*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4009*da0073e9SAndroid Build Coastguard Worker 4010*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 4011*da0073e9SAndroid Build Coastguard Worker intermediary = torch.ones(2, 2) 4012*da0073e9SAndroid Build Coastguard Worker c = a + intermediary 4013*da0073e9SAndroid Build Coastguard Worker outs = [] 4014*da0073e9SAndroid Build Coastguard Worker outs.append(intermediary.grad) 4015*da0073e9SAndroid Build Coastguard Worker return outs 4016*da0073e9SAndroid Build Coastguard Worker 4017*da0073e9SAndroid Build Coastguard Worker # Eager 4018*da0073e9SAndroid Build Coastguard Worker a = torch.ones([2, 2], requires_grad=True) 4019*da0073e9SAndroid Build Coastguard Worker b = torch.ones([2, 2]) 4020*da0073e9SAndroid Build Coastguard Worker ref = fn(a, b) 4021*da0073e9SAndroid Build Coastguard Worker 4022*da0073e9SAndroid Build Coastguard Worker # Compiled 4023*da0073e9SAndroid Build Coastguard Worker a = torch.ones([2, 2], requires_grad=True) 4024*da0073e9SAndroid Build Coastguard Worker b = torch.ones([2, 2]) 4025*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 4026*da0073e9SAndroid Build Coastguard Worker res = opt_fn(a, b) 4027*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 4028*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 4029*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 4030*da0073e9SAndroid Build Coastguard Worker 4031*da0073e9SAndroid Build Coastguard Worker def test_clone_sparse_input(self): 4032*da0073e9SAndroid Build Coastguard Worker for layout in [ 4033*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo, 4034*da0073e9SAndroid Build Coastguard Worker torch.sparse_csr, 4035*da0073e9SAndroid Build Coastguard Worker torch.sparse_csc, 4036*da0073e9SAndroid Build Coastguard Worker torch.sparse_bsr, 4037*da0073e9SAndroid Build 
Coastguard Worker torch.sparse_bsc, 4038*da0073e9SAndroid Build Coastguard Worker ]: 4039*da0073e9SAndroid Build Coastguard Worker for sparse_input in self.generate_simple_inputs( 4040*da0073e9SAndroid Build Coastguard Worker layout, 4041*da0073e9SAndroid Build Coastguard Worker device="cpu", 4042*da0073e9SAndroid Build Coastguard Worker dtype=torch.float64, 4043*da0073e9SAndroid Build Coastguard Worker index_dtype=torch.int64, 4044*da0073e9SAndroid Build Coastguard Worker ): 4045*da0073e9SAndroid Build Coastguard Worker # Invoke the dynamo clone input method directly. 4046*da0073e9SAndroid Build Coastguard Worker sparse_copy = torch._dynamo.utils.clone_input(sparse_input) 4047*da0073e9SAndroid Build Coastguard Worker # Make sure sparse clone is successful. 4048*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sparse_input, sparse_copy) 4049*da0073e9SAndroid Build Coastguard Worker 4050*da0073e9SAndroid Build Coastguard Worker def test_tensor_is_contiguous(self): 4051*da0073e9SAndroid Build Coastguard Worker def fn(x): 4052*da0073e9SAndroid Build Coastguard Worker input = torch.randn((1, 16, 1, 1)) 4053*da0073e9SAndroid Build Coastguard Worker weight = torch.randn((8, 16, 3, 3)) 4054*da0073e9SAndroid Build Coastguard Worker weight = weight.to(memory_format=x) 4055*da0073e9SAndroid Build Coastguard Worker output = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1) 4056*da0073e9SAndroid Build Coastguard Worker return output.is_contiguous(memory_format=x) 4057*da0073e9SAndroid Build Coastguard Worker 4058*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 4059*da0073e9SAndroid Build Coastguard Worker for x in [torch.contiguous_format, torch.channels_last]: 4060*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), opt_fn(x)) 4061*da0073e9SAndroid Build Coastguard Worker 4062*da0073e9SAndroid Build Coastguard Worker def test_python_slice(self): 4063*da0073e9SAndroid Build Coastguard Worker def f1(input): 
4064*da0073e9SAndroid Build Coastguard Worker y = 0 4065*da0073e9SAndroid Build Coastguard Worker for i, x in enumerate(input[2:], 1): 4066*da0073e9SAndroid Build Coastguard Worker y = y + x 4067*da0073e9SAndroid Build Coastguard Worker return y 4068*da0073e9SAndroid Build Coastguard Worker 4069*da0073e9SAndroid Build Coastguard Worker def f2(input): 4070*da0073e9SAndroid Build Coastguard Worker y = 0 4071*da0073e9SAndroid Build Coastguard Worker for i, x in enumerate(input.shape[2:], 1): 4072*da0073e9SAndroid Build Coastguard Worker y = y + x 4073*da0073e9SAndroid Build Coastguard Worker return y 4074*da0073e9SAndroid Build Coastguard Worker 4075*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4076*da0073e9SAndroid Build Coastguard Worker opt_f1 = torch._dynamo.optimize(cnts)(f1) 4077*da0073e9SAndroid Build Coastguard Worker opt_f2 = torch._dynamo.optimize(cnts)(f2) 4078*da0073e9SAndroid Build Coastguard Worker res1 = opt_f1([1, 2, 3, 5]) 4079*da0073e9SAndroid Build Coastguard Worker res2 = opt_f2(torch.rand([2, 3, 4, 5])) 4080*da0073e9SAndroid Build Coastguard Worker 4081*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, 8) 4082*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res2, 9) 4083*da0073e9SAndroid Build Coastguard Worker 4084*da0073e9SAndroid Build Coastguard Worker def test_enum_as_dict_key(self): 4085*da0073e9SAndroid Build Coastguard Worker class MyEnum(enum.Enum): 4086*da0073e9SAndroid Build Coastguard Worker FOO = 10 4087*da0073e9SAndroid Build Coastguard Worker BAR = 20 4088*da0073e9SAndroid Build Coastguard Worker 4089*da0073e9SAndroid Build Coastguard Worker def fn(x): 4090*da0073e9SAndroid Build Coastguard Worker y = x + 2 4091*da0073e9SAndroid Build Coastguard Worker z = { 4092*da0073e9SAndroid Build Coastguard Worker MyEnum.FOO: torch.tensor(1), 4093*da0073e9SAndroid Build Coastguard Worker MyEnum.BAR: 10, 4094*da0073e9SAndroid Build Coastguard Worker "MyEnum.BAR": torch.tensor(8), 
4095*da0073e9SAndroid Build Coastguard Worker 5: torch.rand(3), 4096*da0073e9SAndroid Build Coastguard Worker } 4097*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 4098*da0073e9SAndroid Build Coastguard Worker a = z[MyEnum.FOO] + z["MyEnum.BAR"] 4099*da0073e9SAndroid Build Coastguard Worker b = y * 2 4100*da0073e9SAndroid Build Coastguard Worker return a, b 4101*da0073e9SAndroid Build Coastguard Worker 4102*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4103*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 4104*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 4105*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3) 4106*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 4107*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 4108*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 4109*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 4110*da0073e9SAndroid Build Coastguard Worker 4111*da0073e9SAndroid Build Coastguard Worker def test_enum_as_dict_key_with_overloaded_str(self): 4112*da0073e9SAndroid Build Coastguard Worker class MyEnum(enum.Enum): 4113*da0073e9SAndroid Build Coastguard Worker FOO = 10 4114*da0073e9SAndroid Build Coastguard Worker BAR = 20 4115*da0073e9SAndroid Build Coastguard Worker 4116*da0073e9SAndroid Build Coastguard Worker def __str__(self): 4117*da0073e9SAndroid Build Coastguard Worker return self.value 4118*da0073e9SAndroid Build Coastguard Worker 4119*da0073e9SAndroid Build Coastguard Worker def fn(x): 4120*da0073e9SAndroid Build Coastguard Worker y = x + 2 4121*da0073e9SAndroid Build Coastguard Worker z = { 4122*da0073e9SAndroid Build Coastguard Worker MyEnum.FOO: torch.tensor(1), 4123*da0073e9SAndroid Build Coastguard Worker MyEnum.BAR: 10, 4124*da0073e9SAndroid Build Coastguard Worker "MyEnum.BAR": torch.tensor(8), 4125*da0073e9SAndroid Build Coastguard Worker 5: 
torch.rand(3), 4126*da0073e9SAndroid Build Coastguard Worker } 4127*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 4128*da0073e9SAndroid Build Coastguard Worker a = z[MyEnum.FOO] + z["MyEnum.BAR"] 4129*da0073e9SAndroid Build Coastguard Worker b = y * 2 4130*da0073e9SAndroid Build Coastguard Worker return a, b 4131*da0073e9SAndroid Build Coastguard Worker 4132*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4133*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 4134*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 4135*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3) 4136*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 4137*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 4138*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 4139*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 4140*da0073e9SAndroid Build Coastguard Worker 4141*da0073e9SAndroid Build Coastguard Worker def test_const_dict_variable_python_type(self): 4142*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.variables import ConstantVariable, ConstDictVariable 4143*da0073e9SAndroid Build Coastguard Worker 4144*da0073e9SAndroid Build Coastguard Worker make_key = ConstantVariable.create 4145*da0073e9SAndroid Build Coastguard Worker 4146*da0073e9SAndroid Build Coastguard Worker d1 = { 4147*da0073e9SAndroid Build Coastguard Worker make_key("a"): ConstantVariable.create(10), 4148*da0073e9SAndroid Build Coastguard Worker make_key("b"): ConstantVariable.create(20), 4149*da0073e9SAndroid Build Coastguard Worker } 4150*da0073e9SAndroid Build Coastguard Worker d2 = collections.OrderedDict( 4151*da0073e9SAndroid Build Coastguard Worker [ 4152*da0073e9SAndroid Build Coastguard Worker (make_key("x"), ConstantVariable.create(12)), 4153*da0073e9SAndroid Build Coastguard Worker (make_key("y"), ConstantVariable.create(22)), 
4154*da0073e9SAndroid Build Coastguard Worker ] 4155*da0073e9SAndroid Build Coastguard Worker ) 4156*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ConstDictVariable(d1).python_type(), dict) 4157*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 4158*da0073e9SAndroid Build Coastguard Worker ConstDictVariable(d2, collections.OrderedDict).python_type(), 4159*da0073e9SAndroid Build Coastguard Worker collections.OrderedDict, 4160*da0073e9SAndroid Build Coastguard Worker ) 4161*da0073e9SAndroid Build Coastguard Worker 4162*da0073e9SAndroid Build Coastguard Worker def test_builtin_subclasses_as_method_on_class_type(self): 4163*da0073e9SAndroid Build Coastguard Worker class Foo: 4164*da0073e9SAndroid Build Coastguard Worker def __init__(self, name): 4165*da0073e9SAndroid Build Coastguard Worker self.ame_ = name 4166*da0073e9SAndroid Build Coastguard Worker 4167*da0073e9SAndroid Build Coastguard Worker def get_name(self): 4168*da0073e9SAndroid Build Coastguard Worker return "Foo " + self.name_ 4169*da0073e9SAndroid Build Coastguard Worker 4170*da0073e9SAndroid Build Coastguard Worker class Bar(Foo): 4171*da0073e9SAndroid Build Coastguard Worker def __init__(self, name): 4172*da0073e9SAndroid Build Coastguard Worker self.name_ = name 4173*da0073e9SAndroid Build Coastguard Worker 4174*da0073e9SAndroid Build Coastguard Worker def get_name(self): 4175*da0073e9SAndroid Build Coastguard Worker return "Bar " + self.name_ 4176*da0073e9SAndroid Build Coastguard Worker 4177*da0073e9SAndroid Build Coastguard Worker class Baz(Foo): 4178*da0073e9SAndroid Build Coastguard Worker def __init__(self, name): # noqa: B903 4179*da0073e9SAndroid Build Coastguard Worker self.name_ = name 4180*da0073e9SAndroid Build Coastguard Worker 4181*da0073e9SAndroid Build Coastguard Worker def get_name(self): 4182*da0073e9SAndroid Build Coastguard Worker return "Baz " + self.name_ 4183*da0073e9SAndroid Build Coastguard Worker 4184*da0073e9SAndroid Build Coastguard Worker subs_of_foo_reg 
= Foo.__subclasses__() 4185*da0073e9SAndroid Build Coastguard Worker 4186*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 4187*da0073e9SAndroid Build Coastguard Worker 4188*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize_assert(counter) 4189*da0073e9SAndroid Build Coastguard Worker def fn(): 4190*da0073e9SAndroid Build Coastguard Worker return Foo.__subclasses__() 4191*da0073e9SAndroid Build Coastguard Worker 4192*da0073e9SAndroid Build Coastguard Worker subs_of_foo_optim = fn() 4193*da0073e9SAndroid Build Coastguard Worker 4194*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(subs_of_foo_reg), 2) 4195*da0073e9SAndroid Build Coastguard Worker self.assertEqual(subs_of_foo_reg, subs_of_foo_optim) 4196*da0073e9SAndroid Build Coastguard Worker 4197*da0073e9SAndroid Build Coastguard Worker def test_builtin_subclasses_as_method_on_var(self): 4198*da0073e9SAndroid Build Coastguard Worker class Foo: 4199*da0073e9SAndroid Build Coastguard Worker def __init__(self, name): 4200*da0073e9SAndroid Build Coastguard Worker self.name_ = name 4201*da0073e9SAndroid Build Coastguard Worker 4202*da0073e9SAndroid Build Coastguard Worker def get_name(self): 4203*da0073e9SAndroid Build Coastguard Worker return "Foo " + self.name_ 4204*da0073e9SAndroid Build Coastguard Worker 4205*da0073e9SAndroid Build Coastguard Worker class Bar(Foo): 4206*da0073e9SAndroid Build Coastguard Worker def __init__(self, name): 4207*da0073e9SAndroid Build Coastguard Worker self.name_ = name 4208*da0073e9SAndroid Build Coastguard Worker 4209*da0073e9SAndroid Build Coastguard Worker def get_name(self): 4210*da0073e9SAndroid Build Coastguard Worker return "Bar " + self.name_ 4211*da0073e9SAndroid Build Coastguard Worker 4212*da0073e9SAndroid Build Coastguard Worker class Baz(Bar): 4213*da0073e9SAndroid Build Coastguard Worker def __init__(self, name): 4214*da0073e9SAndroid Build Coastguard Worker self.name_ = name 4215*da0073e9SAndroid Build Coastguard Worker 
4216*da0073e9SAndroid Build Coastguard Worker def get_name(self): 4217*da0073e9SAndroid Build Coastguard Worker return "Baz " + self.name_ 4218*da0073e9SAndroid Build Coastguard Worker 4219*da0073e9SAndroid Build Coastguard Worker subs_of_foo_reg = Foo.__subclasses__() 4220*da0073e9SAndroid Build Coastguard Worker sub_of_foo_subclass_var_reg = subs_of_foo_reg[0].__subclasses__() 4221*da0073e9SAndroid Build Coastguard Worker 4222*da0073e9SAndroid Build Coastguard Worker sub_of_foo_subclass_var_optim = list() 4223*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 4224*da0073e9SAndroid Build Coastguard Worker 4225*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize_assert(counter) 4226*da0073e9SAndroid Build Coastguard Worker def fn(): 4227*da0073e9SAndroid Build Coastguard Worker return Foo.__subclasses__() 4228*da0073e9SAndroid Build Coastguard Worker 4229*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize_assert(counter) 4230*da0073e9SAndroid Build Coastguard Worker def fn_single(subs_of_foo_optim): 4231*da0073e9SAndroid Build Coastguard Worker return subs_of_foo_optim[0].__subclasses__() 4232*da0073e9SAndroid Build Coastguard Worker 4233*da0073e9SAndroid Build Coastguard Worker subs_of_foo_optim = fn() 4234*da0073e9SAndroid Build Coastguard Worker sub_of_foo_subclass_var_optim = fn_single(subs_of_foo_optim) 4235*da0073e9SAndroid Build Coastguard Worker 4236*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(sub_of_foo_subclass_var_optim), 1) 4237*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sub_of_foo_subclass_var_optim, sub_of_foo_subclass_var_reg) 4238*da0073e9SAndroid Build Coastguard Worker 4239*da0073e9SAndroid Build Coastguard Worker def test_builtin_str_on_user_defined_function(self): 4240*da0073e9SAndroid Build Coastguard Worker def another_fn(): 4241*da0073e9SAndroid Build Coastguard Worker pass 4242*da0073e9SAndroid Build Coastguard Worker 4243*da0073e9SAndroid Build Coastguard Worker 
def fn(): 4244*da0073e9SAndroid Build Coastguard Worker return "another_fn" in str(another_fn) 4245*da0073e9SAndroid Build Coastguard Worker 4246*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(nopython=True)(fn) 4247*da0073e9SAndroid Build Coastguard Worker self.assertTrue(opt_fn()) 4248*da0073e9SAndroid Build Coastguard Worker 4249*da0073e9SAndroid Build Coastguard Worker def test_enum_no_graphbreaks(self): 4250*da0073e9SAndroid Build Coastguard Worker class Foo(enum.Enum): 4251*da0073e9SAndroid Build Coastguard Worker FOO = 0 4252*da0073e9SAndroid Build Coastguard Worker BAR = 1 4253*da0073e9SAndroid Build Coastguard Worker 4254*da0073e9SAndroid Build Coastguard Worker def fn(x, foo): 4255*da0073e9SAndroid Build Coastguard Worker if foo is Foo.FOO: 4256*da0073e9SAndroid Build Coastguard Worker x = torch.add(x, 1.0) 4257*da0073e9SAndroid Build Coastguard Worker x = torch.mul(x, 1.0) 4258*da0073e9SAndroid Build Coastguard Worker return x 4259*da0073e9SAndroid Build Coastguard Worker 4260*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1) 4261*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4262*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 4263*da0073e9SAndroid Build Coastguard Worker opt_fn(x, Foo.FOO) 4264*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 4265*da0073e9SAndroid Build Coastguard Worker 4266*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 4267*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4268*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 4269*da0073e9SAndroid Build Coastguard Worker opt_fn(x, Foo.BAR) 4270*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 1) 4271*da0073e9SAndroid Build Coastguard Worker 4272*da0073e9SAndroid Build Coastguard Worker def 
test_repeat_interleave_graphbreaks(self): 4273*da0073e9SAndroid Build Coastguard Worker def fn_no_breaks(x): 4274*da0073e9SAndroid Build Coastguard Worker # no breaks on self_int 4275*da0073e9SAndroid Build Coastguard Worker x += 1 4276*da0073e9SAndroid Build Coastguard Worker x = torch.repeat_interleave(x, 2, 3) 4277*da0073e9SAndroid Build Coastguard Worker x += 1 4278*da0073e9SAndroid Build Coastguard Worker return x 4279*da0073e9SAndroid Build Coastguard Worker 4280*da0073e9SAndroid Build Coastguard Worker def fn_has_breaks(x): 4281*da0073e9SAndroid Build Coastguard Worker # breaks on self_Tensor 4282*da0073e9SAndroid Build Coastguard Worker x += 1 4283*da0073e9SAndroid Build Coastguard Worker x = torch.repeat_interleave(x, torch.tensor(2), 3) 4284*da0073e9SAndroid Build Coastguard Worker x += 1 4285*da0073e9SAndroid Build Coastguard Worker return x 4286*da0073e9SAndroid Build Coastguard Worker 4287*da0073e9SAndroid Build Coastguard Worker x = torch.randn([4, 16, 1, 64]) 4288*da0073e9SAndroid Build Coastguard Worker 4289*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4290*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn_no_breaks) 4291*da0073e9SAndroid Build Coastguard Worker opt_fn(x) 4292*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 4293*da0073e9SAndroid Build Coastguard Worker 4294*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 4295*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4296*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn_has_breaks) 4297*da0073e9SAndroid Build Coastguard Worker opt_fn(x) 4298*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 4299*da0073e9SAndroid Build Coastguard Worker 4300*da0073e9SAndroid Build Coastguard Worker def test_id_guarded_object(self): 4301*da0073e9SAndroid Build Coastguard Worker class UDO: 
4302*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 4303*da0073e9SAndroid Build Coastguard Worker def call(self, x, ref_id): 4304*da0073e9SAndroid Build Coastguard Worker self_id = id(self) 4305*da0073e9SAndroid Build Coastguard Worker if self_id == ref_id: 4306*da0073e9SAndroid Build Coastguard Worker x = torch.mul(x, 1.0) 4307*da0073e9SAndroid Build Coastguard Worker else: 4308*da0073e9SAndroid Build Coastguard Worker x = torch.mul(x, 0) 4309*da0073e9SAndroid Build Coastguard Worker return x 4310*da0073e9SAndroid Build Coastguard Worker 4311*da0073e9SAndroid Build Coastguard Worker # Make sure we do recompile when id(self) is executed on 4312*da0073e9SAndroid Build Coastguard Worker # different self objects. 4313*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2) 4314*da0073e9SAndroid Build Coastguard Worker obj1 = UDO() 4315*da0073e9SAndroid Build Coastguard Worker obj1_id = id(obj1) 4316*da0073e9SAndroid Build Coastguard Worker self.assertEqual(obj1.call(x, obj1_id), torch.ones(2)) 4317*da0073e9SAndroid Build Coastguard Worker 4318*da0073e9SAndroid Build Coastguard Worker obj2 = UDO() 4319*da0073e9SAndroid Build Coastguard Worker # if we do not install ID_MATCH: ___check_obj_id(L['self'], xxx) this fails. 
4320*da0073e9SAndroid Build Coastguard Worker self.assertEqual(obj2.call(x, obj1_id), torch.zeros(2)) 4321*da0073e9SAndroid Build Coastguard Worker 4322*da0073e9SAndroid Build Coastguard Worker def test_id_guarded_module(self): 4323*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 4324*da0073e9SAndroid Build Coastguard Worker def forward(self, x, ref_id): 4325*da0073e9SAndroid Build Coastguard Worker self_id = id(self) 4326*da0073e9SAndroid Build Coastguard Worker if self_id == ref_id: 4327*da0073e9SAndroid Build Coastguard Worker x = torch.mul(x, 1.0) 4328*da0073e9SAndroid Build Coastguard Worker else: 4329*da0073e9SAndroid Build Coastguard Worker x = torch.mul(x, 0) 4330*da0073e9SAndroid Build Coastguard Worker return x 4331*da0073e9SAndroid Build Coastguard Worker 4332*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4333*da0073e9SAndroid Build Coastguard Worker 4334*da0073e9SAndroid Build Coastguard Worker # Make sure we do recompile when id(self) is executed on 4335*da0073e9SAndroid Build Coastguard Worker # different self objects. 
4336*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2) 4337*da0073e9SAndroid Build Coastguard Worker m1 = M() 4338*da0073e9SAndroid Build Coastguard Worker m1_id = id(m1) 4339*da0073e9SAndroid Build Coastguard Worker opt_m1 = torch._dynamo.optimize(cnts, nopython=True)(m1) 4340*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_m1(x, m1_id), torch.ones(2)) 4341*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_m1(x, m1_id), torch.ones(2)) 4342*da0073e9SAndroid Build Coastguard Worker 4343*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 4344*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 1) 4345*da0073e9SAndroid Build Coastguard Worker 4346*da0073e9SAndroid Build Coastguard Worker m2 = M() 4347*da0073e9SAndroid Build Coastguard Worker opt_m2 = torch._dynamo.optimize(cnts, nopython=True)(m2) 4348*da0073e9SAndroid Build Coastguard Worker # if we do not install ID_MATCH: ___check_obj_id(L['self'], xxx) this fails. 
4349*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_m2(x, m1_id), torch.zeros(2)) 4350*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 4351*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 4352*da0073e9SAndroid Build Coastguard Worker 4353*da0073e9SAndroid Build Coastguard Worker def test_id_of_nn_module(self): 4354*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 4355*da0073e9SAndroid Build Coastguard Worker def forward(self, x, ref_id): 4356*da0073e9SAndroid Build Coastguard Worker self_id = id(self) 4357*da0073e9SAndroid Build Coastguard Worker if self_id == ref_id: 4358*da0073e9SAndroid Build Coastguard Worker x = torch.mul(x, 1.0) 4359*da0073e9SAndroid Build Coastguard Worker x = torch.add(x, 1.0) 4360*da0073e9SAndroid Build Coastguard Worker return x 4361*da0073e9SAndroid Build Coastguard Worker 4362*da0073e9SAndroid Build Coastguard Worker m = M().eval() 4363*da0073e9SAndroid Build Coastguard Worker data = torch.randn(1) 4364*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4365*da0073e9SAndroid Build Coastguard Worker correct_ref_id = id(m) 4366*da0073e9SAndroid Build Coastguard Worker opt_m = torch._dynamo.optimize(cnts, nopython=True)(m) 4367*da0073e9SAndroid Build Coastguard Worker opt_m(data, correct_ref_id) 4368*da0073e9SAndroid Build Coastguard Worker # Extra op is the recorded equality test (although once 4369*da0073e9SAndroid Build Coastguard Worker # the trace is flattened this is dead!) 
4370*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 4371*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.op_count, """2""") 4372*da0073e9SAndroid Build Coastguard Worker else: 4373*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.op_count, """2""") 4374*da0073e9SAndroid Build Coastguard Worker 4375*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 4376*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4377*da0073e9SAndroid Build Coastguard Worker incorrect_ref_id = id(m) + 1 4378*da0073e9SAndroid Build Coastguard Worker opt_m = torch._dynamo.optimize(cnts, nopython=True)(m) 4379*da0073e9SAndroid Build Coastguard Worker opt_m(data, incorrect_ref_id) 4380*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 4381*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.op_count, """1""") 4382*da0073e9SAndroid Build Coastguard Worker else: 4383*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnts.op_count, """1""") 4384*da0073e9SAndroid Build Coastguard Worker 4385*da0073e9SAndroid Build Coastguard Worker def test_inline_func_jump_on_tensor_condition(self): 4386*da0073e9SAndroid Build Coastguard Worker def f1(input): 4387*da0073e9SAndroid Build Coastguard Worker if input == 0: 4388*da0073e9SAndroid Build Coastguard Worker return input + 1 4389*da0073e9SAndroid Build Coastguard Worker else: 4390*da0073e9SAndroid Build Coastguard Worker return input + 2 4391*da0073e9SAndroid Build Coastguard Worker 4392*da0073e9SAndroid Build Coastguard Worker def f2(input): 4393*da0073e9SAndroid Build Coastguard Worker return f1(input) 4394*da0073e9SAndroid Build Coastguard Worker 4395*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4396*da0073e9SAndroid Build Coastguard Worker opt_f2 = torch._dynamo.optimize(cnts)(f2) 
4397*da0073e9SAndroid Build Coastguard Worker res1 = opt_f2(torch.tensor([1.0])) 4398*da0073e9SAndroid Build Coastguard Worker res2 = opt_f2(torch.tensor([0.0])) 4399*da0073e9SAndroid Build Coastguard Worker 4400*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, 3) 4401*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res2, 1) 4402*da0073e9SAndroid Build Coastguard Worker 4403*da0073e9SAndroid Build Coastguard Worker def test_frozenset_torch_func_contains(self): 4404*da0073e9SAndroid Build Coastguard Worker funcs = frozenset([torch.add]) 4405*da0073e9SAndroid Build Coastguard Worker 4406*da0073e9SAndroid Build Coastguard Worker def fn(x, func): 4407*da0073e9SAndroid Build Coastguard Worker if func in funcs: 4408*da0073e9SAndroid Build Coastguard Worker x = torch.add(x, 1.0) 4409*da0073e9SAndroid Build Coastguard Worker x = torch.mul(x, 1.0) 4410*da0073e9SAndroid Build Coastguard Worker return x 4411*da0073e9SAndroid Build Coastguard Worker 4412*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1) 4413*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4414*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 4415*da0073e9SAndroid Build Coastguard Worker opt_fn(x, torch.add) 4416*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 4417*da0073e9SAndroid Build Coastguard Worker 4418*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 4419*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4420*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 4421*da0073e9SAndroid Build Coastguard Worker opt_fn(x, torch.mul) 4422*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 1) 4423*da0073e9SAndroid Build Coastguard Worker 4424*da0073e9SAndroid Build Coastguard Worker def test_inline_list_mutation(self): 4425*da0073e9SAndroid Build 
Coastguard Worker def f1(x): 4426*da0073e9SAndroid Build Coastguard Worker x.append(torch.ones(8)) 4427*da0073e9SAndroid Build Coastguard Worker return x 4428*da0073e9SAndroid Build Coastguard Worker 4429*da0073e9SAndroid Build Coastguard Worker def f2(): 4430*da0073e9SAndroid Build Coastguard Worker x = [torch.ones(6)] 4431*da0073e9SAndroid Build Coastguard Worker f1(x) 4432*da0073e9SAndroid Build Coastguard Worker return x 4433*da0073e9SAndroid Build Coastguard Worker 4434*da0073e9SAndroid Build Coastguard Worker res1 = f2() 4435*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4436*da0073e9SAndroid Build Coastguard Worker opt_f2 = torch._dynamo.optimize(cnts)(f2) 4437*da0073e9SAndroid Build Coastguard Worker res2 = opt_f2() 4438*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res1, res2)) 4439*da0073e9SAndroid Build Coastguard Worker 4440*da0073e9SAndroid Build Coastguard Worker def test_inline_dict_mutation(self): 4441*da0073e9SAndroid Build Coastguard Worker def f1(d): 4442*da0073e9SAndroid Build Coastguard Worker d["c"] = d["a"] + d.pop("b") 4443*da0073e9SAndroid Build Coastguard Worker return d 4444*da0073e9SAndroid Build Coastguard Worker 4445*da0073e9SAndroid Build Coastguard Worker def f2(): 4446*da0073e9SAndroid Build Coastguard Worker d = {"a": torch.ones(5), "b": torch.ones(5)} 4447*da0073e9SAndroid Build Coastguard Worker f1(d) 4448*da0073e9SAndroid Build Coastguard Worker return d 4449*da0073e9SAndroid Build Coastguard Worker 4450*da0073e9SAndroid Build Coastguard Worker res1 = f2() 4451*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4452*da0073e9SAndroid Build Coastguard Worker opt_f2 = torch._dynamo.optimize(cnts)(f2) 4453*da0073e9SAndroid Build Coastguard Worker res2 = opt_f2() 4454*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res1, res2)) 4455*da0073e9SAndroid Build Coastguard Worker 4456*da0073e9SAndroid Build Coastguard Worker def 
test_inline_local_dict_clear(self): 4457*da0073e9SAndroid Build Coastguard Worker def f(d): 4458*da0073e9SAndroid Build Coastguard Worker d.clear() 4459*da0073e9SAndroid Build Coastguard Worker return d 4460*da0073e9SAndroid Build Coastguard Worker 4461*da0073e9SAndroid Build Coastguard Worker inp = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)} 4462*da0073e9SAndroid Build Coastguard Worker out = torch.compile(f, backend="eager", fullgraph=True)(inp) 4463*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(out), 0) 4464*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(inp), 0) 4465*da0073e9SAndroid Build Coastguard Worker 4466*da0073e9SAndroid Build Coastguard Worker def test_inline_module_attr_dict_clear(self): 4467*da0073e9SAndroid Build Coastguard Worker class MyMod(torch.nn.Module): 4468*da0073e9SAndroid Build Coastguard Worker def __init__(self): 4469*da0073e9SAndroid Build Coastguard Worker super().__init__() 4470*da0073e9SAndroid Build Coastguard Worker self.a = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)} 4471*da0073e9SAndroid Build Coastguard Worker 4472*da0073e9SAndroid Build Coastguard Worker def forward(self): 4473*da0073e9SAndroid Build Coastguard Worker self.a.clear() 4474*da0073e9SAndroid Build Coastguard Worker return self.a 4475*da0073e9SAndroid Build Coastguard Worker 4476*da0073e9SAndroid Build Coastguard Worker m = MyMod() 4477*da0073e9SAndroid Build Coastguard Worker out = torch.compile(m, backend="eager", fullgraph=True)() 4478*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(out), 0) 4479*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(m.a), 0) 4480*da0073e9SAndroid Build Coastguard Worker 4481*da0073e9SAndroid Build Coastguard Worker def test_inline_user_defined_dict_attr_clear(self): 4482*da0073e9SAndroid Build Coastguard Worker class MyMod: 4483*da0073e9SAndroid Build Coastguard Worker def __init__(self): 4484*da0073e9SAndroid Build Coastguard Worker self.a = {"a": torch.randn(2, 2), 
"b": torch.randn(2, 2)} 4485*da0073e9SAndroid Build Coastguard Worker 4486*da0073e9SAndroid Build Coastguard Worker def f(obj, inp): 4487*da0073e9SAndroid Build Coastguard Worker ret = len(obj.a) + inp 4488*da0073e9SAndroid Build Coastguard Worker obj.a.clear() 4489*da0073e9SAndroid Build Coastguard Worker return obj.a, ret 4490*da0073e9SAndroid Build Coastguard Worker 4491*da0073e9SAndroid Build Coastguard Worker m = MyMod() 4492*da0073e9SAndroid Build Coastguard Worker before_len = len(m.a) 4493*da0073e9SAndroid Build Coastguard Worker t_inp = torch.ones(1) 4494*da0073e9SAndroid Build Coastguard Worker d, ret = torch.compile(f, backend="eager", fullgraph=True)(m, t_inp) 4495*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(m.a), 0) 4496*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(d), 0) 4497*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ret, t_inp + before_len) 4498*da0073e9SAndroid Build Coastguard Worker 4499*da0073e9SAndroid Build Coastguard Worker def test_recursive_inline_list_mutation(self): 4500*da0073e9SAndroid Build Coastguard Worker def f1(x, y): 4501*da0073e9SAndroid Build Coastguard Worker x.append(torch.tensor([1.1])) 4502*da0073e9SAndroid Build Coastguard Worker y.append(torch.tensor([1.2])) 4503*da0073e9SAndroid Build Coastguard Worker return x, y 4504*da0073e9SAndroid Build Coastguard Worker 4505*da0073e9SAndroid Build Coastguard Worker def f2(x, y): 4506*da0073e9SAndroid Build Coastguard Worker x.append(torch.tensor([2.1])) 4507*da0073e9SAndroid Build Coastguard Worker y.append(torch.tensor([2.2])) 4508*da0073e9SAndroid Build Coastguard Worker f1(x, y) 4509*da0073e9SAndroid Build Coastguard Worker return x, y 4510*da0073e9SAndroid Build Coastguard Worker 4511*da0073e9SAndroid Build Coastguard Worker def f3(x): 4512*da0073e9SAndroid Build Coastguard Worker x.append(torch.tensor([3.1])) 4513*da0073e9SAndroid Build Coastguard Worker y = [torch.tensor([3.2])] 4514*da0073e9SAndroid Build Coastguard Worker 
f2(x, y) 4515*da0073e9SAndroid Build Coastguard Worker return x, y 4516*da0073e9SAndroid Build Coastguard Worker 4517*da0073e9SAndroid Build Coastguard Worker def f4(): 4518*da0073e9SAndroid Build Coastguard Worker x = [torch.tensor([4.1])] 4519*da0073e9SAndroid Build Coastguard Worker return f3(x) 4520*da0073e9SAndroid Build Coastguard Worker 4521*da0073e9SAndroid Build Coastguard Worker res1 = f4() 4522*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4523*da0073e9SAndroid Build Coastguard Worker opt_f4 = torch._dynamo.optimize(cnts)(f4) 4524*da0073e9SAndroid Build Coastguard Worker res2 = opt_f4() 4525*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res1, res2)) 4526*da0073e9SAndroid Build Coastguard Worker 4527*da0073e9SAndroid Build Coastguard Worker def test_sample_input(self): 4528*da0073e9SAndroid Build Coastguard Worker from torch.testing._internal.common_methods_invocations import SampleInput 4529*da0073e9SAndroid Build Coastguard Worker 4530*da0073e9SAndroid Build Coastguard Worker def fn(sample): 4531*da0073e9SAndroid Build Coastguard Worker if isinstance(sample.input, torch.Tensor): 4532*da0073e9SAndroid Build Coastguard Worker return sample.input * 2 4533*da0073e9SAndroid Build Coastguard Worker return torch.zeros(()) 4534*da0073e9SAndroid Build Coastguard Worker 4535*da0073e9SAndroid Build Coastguard Worker sample = SampleInput(torch.ones(2)) 4536*da0073e9SAndroid Build Coastguard Worker ref = fn(sample) 4537*da0073e9SAndroid Build Coastguard Worker 4538*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 4539*da0073e9SAndroid Build Coastguard Worker res = opt_fn(sample) 4540*da0073e9SAndroid Build Coastguard Worker 4541*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 4542*da0073e9SAndroid Build Coastguard Worker 4543*da0073e9SAndroid Build Coastguard Worker def test_release_input_memory(self): 4544*da0073e9SAndroid Build Coastguard Worker x 
= torch.rand([4]) 4545*da0073e9SAndroid Build Coastguard Worker x_ref = weakref.ref(x) 4546*da0073e9SAndroid Build Coastguard Worker 4547*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4548*da0073e9SAndroid Build Coastguard Worker 4549*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnts) 4550*da0073e9SAndroid Build Coastguard Worker def foo(x): 4551*da0073e9SAndroid Build Coastguard Worker return x + x 4552*da0073e9SAndroid Build Coastguard Worker 4553*da0073e9SAndroid Build Coastguard Worker out = foo(x) 4554*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(out, x + x)) 4555*da0073e9SAndroid Build Coastguard Worker del x 4556*da0073e9SAndroid Build Coastguard Worker self.assertIs(x_ref(), None) 4557*da0073e9SAndroid Build Coastguard Worker 4558*da0073e9SAndroid Build Coastguard Worker def test_release_module_memory(self): 4559*da0073e9SAndroid Build Coastguard Worker mod = torch.nn.Linear(10, 10) 4560*da0073e9SAndroid Build Coastguard Worker x = torch.rand([10, 10]) 4561*da0073e9SAndroid Build Coastguard Worker mod_weight_ref = weakref.ref(mod.weight) 4562*da0073e9SAndroid Build Coastguard Worker mod_ref = weakref.ref(mod) 4563*da0073e9SAndroid Build Coastguard Worker 4564*da0073e9SAndroid Build Coastguard Worker # Modules that are passed into torch._dynamo optimized functions 4565*da0073e9SAndroid Build Coastguard Worker # will normally be held onto through the generated GraphModule, 4566*da0073e9SAndroid Build Coastguard Worker # which contains the modules. remove the reference in this backend 4567*da0073e9SAndroid Build Coastguard Worker # and test that no additional references are being held. 
4568*da0073e9SAndroid Build Coastguard Worker class NoLeakBackend: 4569*da0073e9SAndroid Build Coastguard Worker def __call__(self, gm: torch.fx.GraphModule, example_inputs): 4570*da0073e9SAndroid Build Coastguard Worker gm.mod = None 4571*da0073e9SAndroid Build Coastguard Worker 4572*da0073e9SAndroid Build Coastguard Worker def foo(*args, **kwargs): 4573*da0073e9SAndroid Build Coastguard Worker return (1,) 4574*da0073e9SAndroid Build Coastguard Worker 4575*da0073e9SAndroid Build Coastguard Worker return foo 4576*da0073e9SAndroid Build Coastguard Worker 4577*da0073e9SAndroid Build Coastguard Worker no_leak_backend = NoLeakBackend() 4578*da0073e9SAndroid Build Coastguard Worker 4579*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(no_leak_backend) 4580*da0073e9SAndroid Build Coastguard Worker def foo(mod, x): 4581*da0073e9SAndroid Build Coastguard Worker return mod(x) 4582*da0073e9SAndroid Build Coastguard Worker 4583*da0073e9SAndroid Build Coastguard Worker foo(mod, x) 4584*da0073e9SAndroid Build Coastguard Worker del mod 4585*da0073e9SAndroid Build Coastguard Worker del x 4586*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(mod_ref(), None) 4587*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(mod_weight_ref(), None) 4588*da0073e9SAndroid Build Coastguard Worker 4589*da0073e9SAndroid Build Coastguard Worker def test_release_scope_memory(self): 4590*da0073e9SAndroid Build Coastguard Worker def inner(y): 4591*da0073e9SAndroid Build Coastguard Worker y 4592*da0073e9SAndroid Build Coastguard Worker 4593*da0073e9SAndroid Build Coastguard Worker inner = torch._dynamo.optimize("eager")(inner) 4594*da0073e9SAndroid Build Coastguard Worker 4595*da0073e9SAndroid Build Coastguard Worker p_ref = None 4596*da0073e9SAndroid Build Coastguard Worker 4597*da0073e9SAndroid Build Coastguard Worker x = torch.randn((10, 10)) 4598*da0073e9SAndroid Build Coastguard Worker inner(x) 4599*da0073e9SAndroid Build Coastguard Worker 4600*da0073e9SAndroid 
Build Coastguard Worker p_ref = weakref.ref(x) 4601*da0073e9SAndroid Build Coastguard Worker self.assertTrue(p_ref() is not None) 4602*da0073e9SAndroid Build Coastguard Worker del x 4603*da0073e9SAndroid Build Coastguard Worker self.assertTrue(p_ref() is None) 4604*da0073e9SAndroid Build Coastguard Worker 4605*da0073e9SAndroid Build Coastguard Worker def test_update_locals_and_stack_uses_shared_cache(self): 4606*da0073e9SAndroid Build Coastguard Worker def fn(x): 4607*da0073e9SAndroid Build Coastguard Worker perm = [0, 3, 5] 4608*da0073e9SAndroid Build Coastguard Worker perm = list(range(min(perm))) + perm 4609*da0073e9SAndroid Build Coastguard Worker perm.extend(i for i in range(x.dim()) if i not in perm) 4610*da0073e9SAndroid Build Coastguard Worker return perm 4611*da0073e9SAndroid Build Coastguard Worker 4612*da0073e9SAndroid Build Coastguard Worker x = torch.rand([2, 2, 2, 2, 2, 2]) 4613*da0073e9SAndroid Build Coastguard Worker res1 = fn(x) 4614*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4615*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 4616*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(x) 4617*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res1, res2)) 4618*da0073e9SAndroid Build Coastguard Worker 4619*da0073e9SAndroid Build Coastguard Worker def test_dict_reconstruct_keeps_original_order(self): 4620*da0073e9SAndroid Build Coastguard Worker def fn(): 4621*da0073e9SAndroid Build Coastguard Worker modules = collections.OrderedDict([("act", torch.nn.ReLU())]) 4622*da0073e9SAndroid Build Coastguard Worker module_dict = torch.nn.ModuleDict(modules) 4623*da0073e9SAndroid Build Coastguard Worker 4624*da0073e9SAndroid Build Coastguard Worker next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()} 4625*da0073e9SAndroid Build Coastguard Worker modules.update(next_modules.items()) 4626*da0073e9SAndroid Build Coastguard Worker 
module_dict.update(next_modules) 4627*da0073e9SAndroid Build Coastguard Worker return modules, module_dict 4628*da0073e9SAndroid Build Coastguard Worker 4629*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4630*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 4631*da0073e9SAndroid Build Coastguard Worker modules, module_dict = opt_fn() 4632*da0073e9SAndroid Build Coastguard Worker 4633*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(module_dict), len(modules)) 4634*da0073e9SAndroid Build Coastguard Worker for k1, m2 in zip(modules, module_dict.children()): 4635*da0073e9SAndroid Build Coastguard Worker self.assertTrue(modules[k1] is m2) 4636*da0073e9SAndroid Build Coastguard Worker 4637*da0073e9SAndroid Build Coastguard Worker def test_side_effects_codegen_update_mutated(self): 4638*da0073e9SAndroid Build Coastguard Worker # codegen to update mutated variables with side effect 4639*da0073e9SAndroid Build Coastguard Worker # should after stack value's codegen 4640*da0073e9SAndroid Build Coastguard Worker def f1(x): 4641*da0073e9SAndroid Build Coastguard Worker alist = [x] 4642*da0073e9SAndroid Build Coastguard Worker alist.append(x + 1) 4643*da0073e9SAndroid Build Coastguard Worker alist[0].sum().item() # graph break 4644*da0073e9SAndroid Build Coastguard Worker res = alist.pop() 4645*da0073e9SAndroid Build Coastguard Worker res.sum().item() # graph break 4646*da0073e9SAndroid Build Coastguard Worker return res 4647*da0073e9SAndroid Build Coastguard Worker 4648*da0073e9SAndroid Build Coastguard Worker def f2(a, b): 4649*da0073e9SAndroid Build Coastguard Worker d = {"a": a + 1, "b": b + 2} 4650*da0073e9SAndroid Build Coastguard Worker x = d.pop("b") 4651*da0073e9SAndroid Build Coastguard Worker x.sum().item() # graph break 4652*da0073e9SAndroid Build Coastguard Worker y = d["a"] + x 4653*da0073e9SAndroid Build Coastguard Worker y.sum().item() # graph break 4654*da0073e9SAndroid 
Build Coastguard Worker d["c"] = y 4655*da0073e9SAndroid Build Coastguard Worker return d 4656*da0073e9SAndroid Build Coastguard Worker 4657*da0073e9SAndroid Build Coastguard Worker x = torch.rand([2, 3]) 4658*da0073e9SAndroid Build Coastguard Worker a = torch.rand([5, 6]) 4659*da0073e9SAndroid Build Coastguard Worker b = torch.rand([5, 6]) 4660*da0073e9SAndroid Build Coastguard Worker res11 = f1(x) 4661*da0073e9SAndroid Build Coastguard Worker res21 = f2(a, b) 4662*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4663*da0073e9SAndroid Build Coastguard Worker opt_f1 = torch._dynamo.optimize(cnts)(f1) 4664*da0073e9SAndroid Build Coastguard Worker opt_f2 = torch._dynamo.optimize(cnts)(f2) 4665*da0073e9SAndroid Build Coastguard Worker res12 = opt_f1(x) 4666*da0073e9SAndroid Build Coastguard Worker res22 = opt_f2(a, b) 4667*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res11, res12)) 4668*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res21, res22)) 4669*da0073e9SAndroid Build Coastguard Worker 4670*da0073e9SAndroid Build Coastguard Worker def test_list_append_return_none(self): 4671*da0073e9SAndroid Build Coastguard Worker def fn(x): 4672*da0073e9SAndroid Build Coastguard Worker alist = [] 4673*da0073e9SAndroid Build Coastguard Worker blist = alist.append(x + 1) 4674*da0073e9SAndroid Build Coastguard Worker return alist, blist 4675*da0073e9SAndroid Build Coastguard Worker 4676*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([2.3]) 4677*da0073e9SAndroid Build Coastguard Worker res = fn(x) 4678*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 4679*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 4680*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(x) 4681*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, res2) 4682*da0073e9SAndroid Build Coastguard Worker 4683*da0073e9SAndroid Build Coastguard Worker 
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_tensor_ctor_list_of_tensor(self):
    """Check that ``torch.tensor([t])`` on a 0-d tensor compiles in one frame."""

    def fn(x):
        return torch.tensor([x], dtype=torch.int64)

    inp = torch.tensor(20)
    expected = fn(inp)
    counter = torch._dynamo.testing.CompileCounter()
    actual = torch._dynamo.optimize(counter)(fn)(inp)
    self.assertEqual(expected, actual)
    self.assertEqual(counter.frame_count, 1)

def test_tensor_types(self):
    """Check that a tensor created with a dtype is an instance of its legacy type."""

    def fn(dtype, tensor_type):
        x = torch.empty(4, dtype=dtype)
        assert isinstance(x, tensor_type)

    opt_fn = torch._dynamo.optimize("eager")(fn)
    dtype_to_legacy_type = (
        (torch.float32, torch.FloatTensor),
        (torch.float64, torch.DoubleTensor),
        (torch.float16, torch.HalfTensor),
        (torch.bfloat16, torch.BFloat16Tensor),
        (torch.uint8, torch.ByteTensor),
        (torch.int8, torch.CharTensor),
        (torch.int64, torch.LongTensor),
        (torch.int, torch.IntTensor),
        (torch.int16, torch.ShortTensor),
        (torch.bool, torch.BoolTensor),
    )
    for dtype, tensor_type in dtype_to_legacy_type:
        opt_fn(dtype, tensor_type)

def test_nan(self):
    """Check that a NaN float argument does not cause a recompile on a repeat call."""

    def f(x, n):
        return x * 2 + n

    x = torch.randn(4)
    # NaN != NaN, so a naive equality guard would never match.
    nan = float("nan")

    counter = torch._dynamo.testing.CompileCounter()
    opt_f = torch._dynamo.optimize(counter)(f)
    for _ in range(2):
        opt_f(x, nan)
    self.assertEqual(counter.frame_count, 1)

@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_item(self):
    """Check that ``.item()`` on an int tensor works in nopython mode."""

    class MyMod(torch.nn.Module):
        def forward(self, x):
            return torch.max(x).int().item()

    inp = torch.tensor([[10.6763, 11.7445, -2.2369]])
    compiled = torch._dynamo.optimize("eager", nopython=True)(MyMod())
    self.assertEqual(compiled(inp), 11)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_item_changes(self):
    """Check that a compiled ``.item()`` result follows new input values (same shape)."""

    class MyMod(torch.nn.Module):
        def forward(self, x):
            return torch.max(x).int().item()

    compiled = torch._dynamo.optimize("eager", nopython=True)(MyMod())
    first = compiled(torch.tensor([[10.6763, 11.7445, -2.2369]]))
    second = compiled(torch.tensor([[first - 5, first + 10, first + 50]]))

    self.assertEqual(first, 11)
    self.assertEqual(second, 61)

@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_item_changes_new_shape(self):
    """Check that a compiled ``.item()`` result follows an input of a new shape."""

    class MyMod(torch.nn.Module):
        def forward(self, x):
            return torch.max(x).int().item()

    compiled = torch._dynamo.optimize("eager", nopython=True)(MyMod())
    first = compiled(torch.tensor([[10.6763, 11.7445, -2.2369]]))
    # 1x3 input first, then a 2x2 input.
    second = compiled(
        torch.tensor([[first - 5, first + 50], [first + 5, first - 50]])
    )

    self.assertEqual(first, 11)
    self.assertEqual(second, 61)

@unittest.skip("https://github.com/pytorch/pytorch/issues/99726")
def test_cross_entropy_loss_fancy_ctor1(self):
    """Check CrossEntropyLoss(weight, reduce=False, label_smoothing) against eager."""
    rand_5 = torch.randn(5)
    rand_3_5 = torch.randn(3, 5)
    target = torch.empty(3, dtype=torch.long).random_(5)

    def make_loss():
        return torch.nn.CrossEntropyLoss(
            weight=rand_5, reduce=False, label_smoothing=0.5
        )

    opt_loss = torch._dynamo.optimize("eager", nopython=True)(make_loss())
    dynamo_output = opt_loss(rand_3_5, target)

    # Compute the eager reference with a separately constructed module.
    output = make_loss()(rand_3_5, target)

    self.assertTrue(torch.allclose(dynamo_output, output))
Worker 4792*da0073e9SAndroid Build Coastguard Worker def test_cross_entropy_loss_fancy_ctor2(self): 4793*da0073e9SAndroid Build Coastguard Worker rand_3_5 = torch.randn(3, 5) 4794*da0073e9SAndroid Build Coastguard Worker target = torch.empty(3, dtype=torch.long).random_(5) 4795*da0073e9SAndroid Build Coastguard Worker 4796*da0073e9SAndroid Build Coastguard Worker loss = torch.nn.CrossEntropyLoss(reduce=False, label_smoothing=0.5) 4797*da0073e9SAndroid Build Coastguard Worker opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss) 4798*da0073e9SAndroid Build Coastguard Worker input = rand_3_5 4799*da0073e9SAndroid Build Coastguard Worker dynamo_output = opt_loss(input, target) 4800*da0073e9SAndroid Build Coastguard Worker 4801*da0073e9SAndroid Build Coastguard Worker loss = torch.nn.CrossEntropyLoss(reduce=False, label_smoothing=0.5) 4802*da0073e9SAndroid Build Coastguard Worker input = rand_3_5 4803*da0073e9SAndroid Build Coastguard Worker output = loss(input, target) 4804*da0073e9SAndroid Build Coastguard Worker 4805*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(dynamo_output, output)) 4806*da0073e9SAndroid Build Coastguard Worker 4807*da0073e9SAndroid Build Coastguard Worker def test_cross_entropy_loss_simple_ctor(self): 4808*da0073e9SAndroid Build Coastguard Worker output = None 4809*da0073e9SAndroid Build Coastguard Worker rand_3_5 = torch.randn(3, 5) 4810*da0073e9SAndroid Build Coastguard Worker target = torch.empty(3, dtype=torch.long).random_(5) 4811*da0073e9SAndroid Build Coastguard Worker 4812*da0073e9SAndroid Build Coastguard Worker loss = torch.nn.CrossEntropyLoss() 4813*da0073e9SAndroid Build Coastguard Worker opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss) 4814*da0073e9SAndroid Build Coastguard Worker input = rand_3_5 4815*da0073e9SAndroid Build Coastguard Worker dynamo_output = opt_loss(input, target) 4816*da0073e9SAndroid Build Coastguard Worker 4817*da0073e9SAndroid Build Coastguard Worker loss = 
torch.nn.CrossEntropyLoss() 4818*da0073e9SAndroid Build Coastguard Worker input = rand_3_5 4819*da0073e9SAndroid Build Coastguard Worker output = loss(input, target) 4820*da0073e9SAndroid Build Coastguard Worker 4821*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(dynamo_output, output)) 4822*da0073e9SAndroid Build Coastguard Worker 4823*da0073e9SAndroid Build Coastguard Worker def test_nn_functional_reduction(self): 4824*da0073e9SAndroid Build Coastguard Worker def fn(loss, reduction): 4825*da0073e9SAndroid Build Coastguard Worker reduction_enum = F._Reduction.get_enum(reduction) 4826*da0073e9SAndroid Build Coastguard Worker if reduction_enum == 0: 4827*da0073e9SAndroid Build Coastguard Worker return loss 4828*da0073e9SAndroid Build Coastguard Worker elif reduction_enum == 1: 4829*da0073e9SAndroid Build Coastguard Worker return loss.mean() 4830*da0073e9SAndroid Build Coastguard Worker elif reduction_enum == 2: 4831*da0073e9SAndroid Build Coastguard Worker return loss.sum() 4832*da0073e9SAndroid Build Coastguard Worker 4833*da0073e9SAndroid Build Coastguard Worker x = torch.rand([3, 5]) 4834*da0073e9SAndroid Build Coastguard Worker y = "mean" 4835*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 4836*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 4837*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 4838*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(ref, res)) 4839*da0073e9SAndroid Build Coastguard Worker 4840*da0073e9SAndroid Build Coastguard Worker def test_large_reduction_list(self): 4841*da0073e9SAndroid Build Coastguard Worker dtype = torch.float32 4842*da0073e9SAndroid Build Coastguard Worker device = "cpu" 4843*da0073e9SAndroid Build Coastguard Worker 4844*da0073e9SAndroid Build Coastguard Worker def check_sum_all(tensor: torch.Tensor) -> None: 4845*da0073e9SAndroid Build Coastguard Worker pylist = tensor.reshape(-1).tolist() 
4846*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(tensor.sum(), torch.tensor(sum(pylist)))) 4847*da0073e9SAndroid Build Coastguard Worker 4848*da0073e9SAndroid Build Coastguard Worker check_sum_all(torch.randn(200000, dtype=dtype, device=device)) 4849*da0073e9SAndroid Build Coastguard Worker 4850*da0073e9SAndroid Build Coastguard Worker def test_raise_on_backend_error(self): 4851*da0073e9SAndroid Build Coastguard Worker def my_compiler(gm, _): 4852*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("duck!") 4853*da0073e9SAndroid Build Coastguard Worker 4854*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(my_compiler) 4855*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 4856*da0073e9SAndroid Build Coastguard Worker return a + b / (a - b) 4857*da0073e9SAndroid Build Coastguard Worker 4858*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 4859*da0073e9SAndroid Build Coastguard Worker torch._dynamo.exc.BackendCompilerFailed, 4860*da0073e9SAndroid Build Coastguard Worker lambda: fn(torch.randn(10), torch.randn(10)), 4861*da0073e9SAndroid Build Coastguard Worker ) 4862*da0073e9SAndroid Build Coastguard Worker 4863*da0073e9SAndroid Build Coastguard Worker def test_named_parameters(self): 4864*da0073e9SAndroid Build Coastguard Worker n_embd = 768 4865*da0073e9SAndroid Build Coastguard Worker block_size = 128 4866*da0073e9SAndroid Build Coastguard Worker vocab_size = 65 4867*da0073e9SAndroid Build Coastguard Worker embd_pdrop = 0.1 4868*da0073e9SAndroid Build Coastguard Worker 4869*da0073e9SAndroid Build Coastguard Worker class MyModel2(torch.nn.Module): 4870*da0073e9SAndroid Build Coastguard Worker def __init__(self): 4871*da0073e9SAndroid Build Coastguard Worker super().__init__() 4872*da0073e9SAndroid Build Coastguard Worker self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) 4873*da0073e9SAndroid Build Coastguard Worker self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) 
4874*da0073e9SAndroid Build Coastguard Worker self.drop = torch.nn.Dropout(embd_pdrop) 4875*da0073e9SAndroid Build Coastguard Worker 4876*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 4877*da0073e9SAndroid Build Coastguard Worker return x 4878*da0073e9SAndroid Build Coastguard Worker 4879*da0073e9SAndroid Build Coastguard Worker class MyModel(torch.nn.Module): 4880*da0073e9SAndroid Build Coastguard Worker def __init__(self): 4881*da0073e9SAndroid Build Coastguard Worker super().__init__() 4882*da0073e9SAndroid Build Coastguard Worker self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) 4883*da0073e9SAndroid Build Coastguard Worker self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) 4884*da0073e9SAndroid Build Coastguard Worker self.drop = torch.nn.Dropout(embd_pdrop) 4885*da0073e9SAndroid Build Coastguard Worker self.submod2 = MyModel2() 4886*da0073e9SAndroid Build Coastguard Worker 4887*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 4888*da0073e9SAndroid Build Coastguard Worker return x 4889*da0073e9SAndroid Build Coastguard Worker 4890*da0073e9SAndroid Build Coastguard Worker # Regular 4891*da0073e9SAndroid Build Coastguard Worker params = [] 4892*da0073e9SAndroid Build Coastguard Worker mod = MyModel() 4893*da0073e9SAndroid Build Coastguard Worker actual_params = list(mod.named_parameters()) 4894*da0073e9SAndroid Build Coastguard Worker 4895*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("eager", nopython=True) 4896*da0073e9SAndroid Build Coastguard Worker def fn(): 4897*da0073e9SAndroid Build Coastguard Worker return list(mod.named_parameters()) 4898*da0073e9SAndroid Build Coastguard Worker 4899*da0073e9SAndroid Build Coastguard Worker params = fn() 4900*da0073e9SAndroid Build Coastguard Worker 4901*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(actual_params), len(params)) 4902*da0073e9SAndroid Build Coastguard Worker for idx in range(len(params)): 4903*da0073e9SAndroid Build 
Coastguard Worker k_a, v_a = actual_params[idx] 4904*da0073e9SAndroid Build Coastguard Worker k, v = params[idx] 4905*da0073e9SAndroid Build Coastguard Worker self.assertEqual(k_a, k) 4906*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(v_a, v)) 4907*da0073e9SAndroid Build Coastguard Worker 4908*da0073e9SAndroid Build Coastguard Worker # Prefix 4909*da0073e9SAndroid Build Coastguard Worker params = [] 4910*da0073e9SAndroid Build Coastguard Worker mod = MyModel() 4911*da0073e9SAndroid Build Coastguard Worker actual_params = list(mod.named_parameters(prefix="foo")) 4912*da0073e9SAndroid Build Coastguard Worker 4913*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("eager", nopython=True) 4914*da0073e9SAndroid Build Coastguard Worker def fn1(): 4915*da0073e9SAndroid Build Coastguard Worker return list(mod.named_parameters(prefix="foo")) 4916*da0073e9SAndroid Build Coastguard Worker 4917*da0073e9SAndroid Build Coastguard Worker params = fn1() 4918*da0073e9SAndroid Build Coastguard Worker 4919*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(actual_params), len(params)) 4920*da0073e9SAndroid Build Coastguard Worker for idx in range(len(params)): 4921*da0073e9SAndroid Build Coastguard Worker k_a, v_a = actual_params[idx] 4922*da0073e9SAndroid Build Coastguard Worker k, v = params[idx] 4923*da0073e9SAndroid Build Coastguard Worker self.assertEqual(k_a, k) 4924*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(v_a, v)) 4925*da0073e9SAndroid Build Coastguard Worker 4926*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(guard_nn_modules=True) 4927*da0073e9SAndroid Build Coastguard Worker def test_module_complex_iter(self): 4928*da0073e9SAndroid Build Coastguard Worker n_embd = 768 4929*da0073e9SAndroid Build Coastguard Worker block_size = 128 4930*da0073e9SAndroid Build Coastguard Worker vocab_size = 65 4931*da0073e9SAndroid Build Coastguard Worker embd_pdrop = 0.1 
4932*da0073e9SAndroid Build Coastguard Worker 4933*da0073e9SAndroid Build Coastguard Worker class FakeGPT(torch.nn.Module): 4934*da0073e9SAndroid Build Coastguard Worker def __init__(self): 4935*da0073e9SAndroid Build Coastguard Worker super().__init__() 4936*da0073e9SAndroid Build Coastguard Worker self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) 4937*da0073e9SAndroid Build Coastguard Worker self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) 4938*da0073e9SAndroid Build Coastguard Worker self.drop = torch.nn.Dropout(embd_pdrop) 4939*da0073e9SAndroid Build Coastguard Worker self.ln_f = torch.nn.LayerNorm(n_embd) 4940*da0073e9SAndroid Build Coastguard Worker self.head = torch.nn.Linear(n_embd, vocab_size, bias=False) 4941*da0073e9SAndroid Build Coastguard Worker 4942*da0073e9SAndroid Build Coastguard Worker self.block_size = block_size 4943*da0073e9SAndroid Build Coastguard Worker self.names = [] 4944*da0073e9SAndroid Build Coastguard Worker 4945*da0073e9SAndroid Build Coastguard Worker def forward(self, idx, targets=None): 4946*da0073e9SAndroid Build Coastguard Worker b, t = idx.size() 4947*da0073e9SAndroid Build Coastguard Worker assert ( 4948*da0073e9SAndroid Build Coastguard Worker t <= self.block_size 4949*da0073e9SAndroid Build Coastguard Worker ), "Cannot forward, model block size is exhausted." 
4950*da0073e9SAndroid Build Coastguard Worker 4951*da0073e9SAndroid Build Coastguard Worker # forward the GPT model 4952*da0073e9SAndroid Build Coastguard Worker token_embeddings = self.tok_emb( 4953*da0073e9SAndroid Build Coastguard Worker idx 4954*da0073e9SAndroid Build Coastguard Worker ) # each index maps to a (learnable) vector 4955*da0073e9SAndroid Build Coastguard Worker position_embeddings = self.pos_emb[ 4956*da0073e9SAndroid Build Coastguard Worker :, :t, : 4957*da0073e9SAndroid Build Coastguard Worker ] # each position maps to a (learnable) vector 4958*da0073e9SAndroid Build Coastguard Worker x = self.drop(token_embeddings + position_embeddings) 4959*da0073e9SAndroid Build Coastguard Worker x = self.blocks(x) 4960*da0073e9SAndroid Build Coastguard Worker x = self.ln_f(x) 4961*da0073e9SAndroid Build Coastguard Worker logits = self.head(x) 4962*da0073e9SAndroid Build Coastguard Worker 4963*da0073e9SAndroid Build Coastguard Worker # if we are given some desired targets also calculate the loss 4964*da0073e9SAndroid Build Coastguard Worker loss = None 4965*da0073e9SAndroid Build Coastguard Worker if targets is not None: 4966*da0073e9SAndroid Build Coastguard Worker loss = F.cross_entropy( 4967*da0073e9SAndroid Build Coastguard Worker logits.view(-1, logits.size(-1)), targets.view(-1) 4968*da0073e9SAndroid Build Coastguard Worker ) 4969*da0073e9SAndroid Build Coastguard Worker 4970*da0073e9SAndroid Build Coastguard Worker return logits, loss 4971*da0073e9SAndroid Build Coastguard Worker 4972*da0073e9SAndroid Build Coastguard Worker def foo(self, memo=None, prefix="", remove_duplicate=False): 4973*da0073e9SAndroid Build Coastguard Worker for mn, m in self.named_modules( 4974*da0073e9SAndroid Build Coastguard Worker memo=memo, prefix=prefix, remove_duplicate=remove_duplicate 4975*da0073e9SAndroid Build Coastguard Worker ): 4976*da0073e9SAndroid Build Coastguard Worker for pn, p in self.named_parameters(): 4977*da0073e9SAndroid Build Coastguard Worker fpn = 
f"{mn}.{pn}" if mn else pn 4978*da0073e9SAndroid Build Coastguard Worker self.names.append(fpn) 4979*da0073e9SAndroid Build Coastguard Worker 4980*da0073e9SAndroid Build Coastguard Worker # Test plain recurse 4981*da0073e9SAndroid Build Coastguard Worker model_a = FakeGPT() 4982*da0073e9SAndroid Build Coastguard Worker model_a.foo() 4983*da0073e9SAndroid Build Coastguard Worker a_names = model_a.names 4984*da0073e9SAndroid Build Coastguard Worker 4985*da0073e9SAndroid Build Coastguard Worker model_b = FakeGPT() 4986*da0073e9SAndroid Build Coastguard Worker opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b) 4987*da0073e9SAndroid Build Coastguard Worker opt_model_b.foo() 4988*da0073e9SAndroid Build Coastguard Worker 4989*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a_names, model_b.names) 4990*da0073e9SAndroid Build Coastguard Worker 4991*da0073e9SAndroid Build Coastguard Worker # Test with prefix 4992*da0073e9SAndroid Build Coastguard Worker model_a = FakeGPT() 4993*da0073e9SAndroid Build Coastguard Worker model_a.foo(prefix="abc") 4994*da0073e9SAndroid Build Coastguard Worker a_names = model_a.names 4995*da0073e9SAndroid Build Coastguard Worker 4996*da0073e9SAndroid Build Coastguard Worker model_b = FakeGPT() 4997*da0073e9SAndroid Build Coastguard Worker opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b) 4998*da0073e9SAndroid Build Coastguard Worker opt_model_b.foo(prefix="abc") 4999*da0073e9SAndroid Build Coastguard Worker 5000*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a_names, model_b.names) 5001*da0073e9SAndroid Build Coastguard Worker 5002*da0073e9SAndroid Build Coastguard Worker def test_numpy_variable_isinstance(self): 5003*da0073e9SAndroid Build Coastguard Worker def fn(x, m): 5004*da0073e9SAndroid Build Coastguard Worker if isinstance(m, np.ndarray): 5005*da0073e9SAndroid Build Coastguard Worker return x + 1 5006*da0073e9SAndroid Build Coastguard Worker else: 5007*da0073e9SAndroid Build 
Coastguard Worker return x - 1 5008*da0073e9SAndroid Build Coastguard Worker 5009*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([2.3]) 5010*da0073e9SAndroid Build Coastguard Worker m = np.array([1, 2, 3]) 5011*da0073e9SAndroid Build Coastguard Worker ref = fn(x, m) 5012*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 5013*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 5014*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, m) 5015*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 5016*da0073e9SAndroid Build Coastguard Worker 5017*da0073e9SAndroid Build Coastguard Worker # Test now the other path 5018*da0073e9SAndroid Build Coastguard Worker ref = fn(x, x) 5019*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, x) 5020*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 5021*da0073e9SAndroid Build Coastguard Worker 5022*da0073e9SAndroid Build Coastguard Worker def test_tensor_dot_grad_no_graph_break(self): 5023*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 5024*da0073e9SAndroid Build Coastguard Worker y = 3 * a**3 - b**2 5025*da0073e9SAndroid Build Coastguard Worker y.backward(gradient=torch.tensor([1.0, 1.0])) 5026*da0073e9SAndroid Build Coastguard Worker b.grad.zero_() 5027*da0073e9SAndroid Build Coastguard Worker return a.grad, b.grad 5028*da0073e9SAndroid Build Coastguard Worker 5029*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([2.0, 3.0], requires_grad=True) 5030*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([6.0, 4.0], requires_grad=True) 5031*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 5032*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 5033*da0073e9SAndroid Build Coastguard Worker _, b_grad = opt_fn(a, b) 5034*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(b_grad, torch.tensor([0.0, 0.0]))) 
5035*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 5036*da0073e9SAndroid Build Coastguard Worker 5037*da0073e9SAndroid Build Coastguard Worker def test_torch_nn_parameter_isinstance(self): 5038*da0073e9SAndroid Build Coastguard Worker def fn(x): 5039*da0073e9SAndroid Build Coastguard Worker a = torch.nn.Parameter(torch.rand(2, 3)) 5040*da0073e9SAndroid Build Coastguard Worker if isinstance(a, torch.Tensor): 5041*da0073e9SAndroid Build Coastguard Worker return x + 1 5042*da0073e9SAndroid Build Coastguard Worker else: 5043*da0073e9SAndroid Build Coastguard Worker return x - 1 5044*da0073e9SAndroid Build Coastguard Worker 5045*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([2.5]) 5046*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 5047*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 5048*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 5049*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 5050*da0073e9SAndroid Build Coastguard Worker 5051*da0073e9SAndroid Build Coastguard Worker def _optimize_then_check_exp( 5052*da0073e9SAndroid Build Coastguard Worker self, foo, args, cnt, exp_out, exp_frame_count, exp_n_cached_backend 5053*da0073e9SAndroid Build Coastguard Worker ): 5054*da0073e9SAndroid Build Coastguard Worker opt_out = torch._dynamo.optimize(backend=cnt)(foo)(*args) 5055*da0073e9SAndroid Build Coastguard Worker self.assertEqual(exp_out, opt_out) 5056*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, exp_frame_count) 5057*da0073e9SAndroid Build Coastguard Worker 5058*da0073e9SAndroid Build Coastguard Worker def test_backend_match_guard(self): 5059*da0073e9SAndroid Build Coastguard Worker x = torch.randn([3, 4]) 5060*da0073e9SAndroid Build Coastguard Worker 5061*da0073e9SAndroid Build Coastguard Worker def foo(x): 5062*da0073e9SAndroid Build Coastguard Worker return x.sin() + x.cos() 5063*da0073e9SAndroid Build Coastguard 
Worker 5064*da0073e9SAndroid Build Coastguard Worker def foo_graph_break(x): 5065*da0073e9SAndroid Build Coastguard Worker a = x.sin() 5066*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 5067*da0073e9SAndroid Build Coastguard Worker b = x.cos() 5068*da0073e9SAndroid Build Coastguard Worker return a + b 5069*da0073e9SAndroid Build Coastguard Worker 5070*da0073e9SAndroid Build Coastguard Worker eager_record_backend = torch._dynamo.testing.EagerAndRecordGraphs() 5071*da0073e9SAndroid Build Coastguard Worker backends = [eager_record_backend, "eager"] 5072*da0073e9SAndroid Build Coastguard Worker 5073*da0073e9SAndroid Build Coastguard Worker # We intentionally don't reset dynamo for each backend so that we can test 5074*da0073e9SAndroid Build Coastguard Worker # 1. dynamo doesn't recompile when backend stays the same, i.e. frame_count doesn't increase 5075*da0073e9SAndroid Build Coastguard Worker # 2. dynamo recompiles when backend changes, i.e. frame_count is non-zero for next backend 5076*da0073e9SAndroid Build Coastguard Worker def test_recompile(foo, *, exp_frame_count): 5077*da0073e9SAndroid Build Coastguard Worker eager_result = foo(x) 5078*da0073e9SAndroid Build Coastguard Worker for i, backend in enumerate(backends): 5079*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) 5080*da0073e9SAndroid Build Coastguard Worker # Run opt_f multiple times to make sure dynamo doesn't recompile. 
5081*da0073e9SAndroid Build Coastguard Worker # Specifically, frame_count doesn't increase 5082*da0073e9SAndroid Build Coastguard Worker # the number of cached backends is i + 2 because we have the optimizing backend + None 5083*da0073e9SAndroid Build Coastguard Worker self._optimize_then_check_exp( 5084*da0073e9SAndroid Build Coastguard Worker foo, (x,), cnt, eager_result, exp_frame_count, i + 2 5085*da0073e9SAndroid Build Coastguard Worker ) 5086*da0073e9SAndroid Build Coastguard Worker self._optimize_then_check_exp( 5087*da0073e9SAndroid Build Coastguard Worker foo, (x,), cnt, eager_result, exp_frame_count, i + 2 5088*da0073e9SAndroid Build Coastguard Worker ) 5089*da0073e9SAndroid Build Coastguard Worker self._optimize_then_check_exp( 5090*da0073e9SAndroid Build Coastguard Worker foo, (x,), cnt, eager_result, exp_frame_count, i + 2 5091*da0073e9SAndroid Build Coastguard Worker ) 5092*da0073e9SAndroid Build Coastguard Worker 5093*da0073e9SAndroid Build Coastguard Worker test_recompile(foo, exp_frame_count=1) 5094*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 5095*da0073e9SAndroid Build Coastguard Worker test_recompile(foo_graph_break, exp_frame_count=2) 5096*da0073e9SAndroid Build Coastguard Worker 5097*da0073e9SAndroid Build Coastguard Worker def test_backend_match_guard_multi_threads(self): 5098*da0073e9SAndroid Build Coastguard Worker x = torch.randn([3, 4]) 5099*da0073e9SAndroid Build Coastguard Worker 5100*da0073e9SAndroid Build Coastguard Worker def foo(x): 5101*da0073e9SAndroid Build Coastguard Worker return x.sin() + x.cos() 5102*da0073e9SAndroid Build Coastguard Worker 5103*da0073e9SAndroid Build Coastguard Worker def compile_then_check_exp(foo, args, cnt, eager_result, exp_frame_count): 5104*da0073e9SAndroid Build Coastguard Worker for i in range(3): 5105*da0073e9SAndroid Build Coastguard Worker opt_out = torch._dynamo.optimize(backend=cnt)(foo)(*args) 5106*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_out, 
eager_result) 5107*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, exp_frame_count) 5108*da0073e9SAndroid Build Coastguard Worker thread_success[threading.current_thread()] = True 5109*da0073e9SAndroid Build Coastguard Worker 5110*da0073e9SAndroid Build Coastguard Worker eager_record_backend = torch._dynamo.testing.EagerAndRecordGraphs() 5111*da0073e9SAndroid Build Coastguard Worker backends = [eager_record_backend, "eager"] 5112*da0073e9SAndroid Build Coastguard Worker 5113*da0073e9SAndroid Build Coastguard Worker # Test dynamo recompiles but only caches a single backend for each thread 5114*da0073e9SAndroid Build Coastguard Worker eager_result = foo(x) 5115*da0073e9SAndroid Build Coastguard Worker # cnt and None 5116*da0073e9SAndroid Build Coastguard Worker exp_frame_count = 1 5117*da0073e9SAndroid Build Coastguard Worker threads = [] 5118*da0073e9SAndroid Build Coastguard Worker thread_success = {} 5119*da0073e9SAndroid Build Coastguard Worker for i, backend in enumerate(backends): 5120*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) 5121*da0073e9SAndroid Build Coastguard Worker thread = threading.Thread( 5122*da0073e9SAndroid Build Coastguard Worker target=compile_then_check_exp, 5123*da0073e9SAndroid Build Coastguard Worker args=( 5124*da0073e9SAndroid Build Coastguard Worker foo, 5125*da0073e9SAndroid Build Coastguard Worker (x,), 5126*da0073e9SAndroid Build Coastguard Worker cnt, 5127*da0073e9SAndroid Build Coastguard Worker eager_result, 5128*da0073e9SAndroid Build Coastguard Worker exp_frame_count, 5129*da0073e9SAndroid Build Coastguard Worker ), 5130*da0073e9SAndroid Build Coastguard Worker ) 5131*da0073e9SAndroid Build Coastguard Worker threads.append(thread) 5132*da0073e9SAndroid Build Coastguard Worker thread.start() 5133*da0073e9SAndroid Build Coastguard Worker 5134*da0073e9SAndroid Build Coastguard Worker # Wait for all threads to finish 5135*da0073e9SAndroid Build 
Coastguard Worker for thread in threads: 5136*da0073e9SAndroid Build Coastguard Worker thread.join() 5137*da0073e9SAndroid Build Coastguard Worker 5138*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(thread_success), len(threads)) 5139*da0073e9SAndroid Build Coastguard Worker 5140*da0073e9SAndroid Build Coastguard Worker def test_dynamo_min_operator_with_shape(self): 5141*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("eager", nopython=True) 5142*da0073e9SAndroid Build Coastguard Worker def f(x, a): 5143*da0073e9SAndroid Build Coastguard Worker return min(x.shape[0], a) 5144*da0073e9SAndroid Build Coastguard Worker 5145*da0073e9SAndroid Build Coastguard Worker result = f(torch.ones(6), 3) 5146*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, 3) 5147*da0073e9SAndroid Build Coastguard Worker 5148*da0073e9SAndroid Build Coastguard Worker def test_onnx_shape_as_tensor(self): 5149*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("eager", nopython=True) 5150*da0073e9SAndroid Build Coastguard Worker def f(x): 5151*da0073e9SAndroid Build Coastguard Worker return 1 + torch._shape_as_tensor(x)[0] 5152*da0073e9SAndroid Build Coastguard Worker 5153*da0073e9SAndroid Build Coastguard Worker gm, _ = torch._dynamo.export(f)(torch.ones(6)) 5154*da0073e9SAndroid Build Coastguard Worker 5155*da0073e9SAndroid Build Coastguard Worker input_one_dim = torch.ones(6) 5156*da0073e9SAndroid Build Coastguard Worker input_two_dims = torch.ones(7, 4) 5157*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(input_one_dim), 7) 5158*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(input_two_dims), 8) 5159*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(input_two_dims), 8) 5160*da0073e9SAndroid Build Coastguard Worker 5161*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("eager", nopython=True) 5162*da0073e9SAndroid Build Coastguard Worker def f_onnx(x): 5163*da0073e9SAndroid Build 
Coastguard Worker return 1 + torch.onnx.operators.shape_as_tensor(x)[0] 5164*da0073e9SAndroid Build Coastguard Worker 5165*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f_onnx(input_one_dim), 7) 5166*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f_onnx(input_two_dims), 8) 5167*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f_onnx(input_two_dims), 8) 5168*da0073e9SAndroid Build Coastguard Worker 5169*da0073e9SAndroid Build Coastguard Worker def test_cond(self): 5170*da0073e9SAndroid Build Coastguard Worker from functorch.experimental.control_flow import cond 5171*da0073e9SAndroid Build Coastguard Worker 5172*da0073e9SAndroid Build Coastguard Worker def true_fn(x): 5173*da0073e9SAndroid Build Coastguard Worker return x.sin() 5174*da0073e9SAndroid Build Coastguard Worker 5175*da0073e9SAndroid Build Coastguard Worker def false_fn(x): 5176*da0073e9SAndroid Build Coastguard Worker return x.cos() 5177*da0073e9SAndroid Build Coastguard Worker 5178*da0073e9SAndroid Build Coastguard Worker def f(pred, x): 5179*da0073e9SAndroid Build Coastguard Worker return cond(pred, true_fn, false_fn, [x]) 5180*da0073e9SAndroid Build Coastguard Worker 5181*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(f) 5182*da0073e9SAndroid Build Coastguard Worker a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25])) 5183*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(torch.cos(torch.tensor([0.25, 0.25])), a)) 5184*da0073e9SAndroid Build Coastguard Worker b = opt_fn(torch.tensor(True), torch.tensor([0.25, 0.25])) 5185*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), b)) 5186*da0073e9SAndroid Build Coastguard Worker 5187*da0073e9SAndroid Build Coastguard Worker def test_nonzero_static(self): 5188*da0073e9SAndroid Build Coastguard Worker # invalid size 5189*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 5190*da0073e9SAndroid Build Coastguard 
Worker RuntimeError, "nonzero_static: 'size' must be an non-negative integer" 5191*da0073e9SAndroid Build Coastguard Worker ): 5192*da0073e9SAndroid Build Coastguard Worker torch.nonzero_static(torch.tensor([8]), size=-2) 5193*da0073e9SAndroid Build Coastguard Worker 5194*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 5195*da0073e9SAndroid Build Coastguard Worker RuntimeError, "nonzero_static: 'size' must be an non-negative integer" 5196*da0073e9SAndroid Build Coastguard Worker ): 5197*da0073e9SAndroid Build Coastguard Worker torch.nonzero_static(torch.tensor([8]), size=-2, out=torch.tensor(0)) 5198*da0073e9SAndroid Build Coastguard Worker 5199*da0073e9SAndroid Build Coastguard Worker # nonzero_static.out: out dtype mismatch 5200*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.tensor([8]) 5201*da0073e9SAndroid Build Coastguard Worker static_size = 1 5202*da0073e9SAndroid Build Coastguard Worker out_tensor = torch.empty((static_size, input_tensor.dim()), dtype=torch.float) 5203*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 5204*da0073e9SAndroid Build Coastguard Worker RuntimeError, "nonzero_static: Expected out tensor to have scalar type Long" 5205*da0073e9SAndroid Build Coastguard Worker ): 5206*da0073e9SAndroid Build Coastguard Worker torch.nonzero_static(input_tensor, size=static_size, out=out_tensor) 5207*da0073e9SAndroid Build Coastguard Worker 5208*da0073e9SAndroid Build Coastguard Worker # nonzero_static.out: out resize (shrink) 5209*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.tensor([8]) 5210*da0073e9SAndroid Build Coastguard Worker static_size = 1 5211*da0073e9SAndroid Build Coastguard Worker out_tensor = torch.empty((10, 10, 10, 10), dtype=torch.long) 5212*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5213*da0073e9SAndroid Build Coastguard Worker same( 5214*da0073e9SAndroid Build Coastguard Worker torch.nonzero_static(input_tensor, size=static_size, out=out_tensor), 
5215*da0073e9SAndroid Build Coastguard Worker torch.tensor([0]), 5216*da0073e9SAndroid Build Coastguard Worker ) 5217*da0073e9SAndroid Build Coastguard Worker ) 5218*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5219*da0073e9SAndroid Build Coastguard Worker same( 5220*da0073e9SAndroid Build Coastguard Worker out_tensor, 5221*da0073e9SAndroid Build Coastguard Worker torch.tensor([0]), 5222*da0073e9SAndroid Build Coastguard Worker ) 5223*da0073e9SAndroid Build Coastguard Worker ) 5224*da0073e9SAndroid Build Coastguard Worker 5225*da0073e9SAndroid Build Coastguard Worker # nonzero_static.out: out resize (enlarge) 5226*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.tensor([8]) 5227*da0073e9SAndroid Build Coastguard Worker static_size = 1 5228*da0073e9SAndroid Build Coastguard Worker out_tensor = torch.empty((0), dtype=torch.long) 5229*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5230*da0073e9SAndroid Build Coastguard Worker same( 5231*da0073e9SAndroid Build Coastguard Worker torch.nonzero_static(input_tensor, size=static_size, out=out_tensor), 5232*da0073e9SAndroid Build Coastguard Worker torch.tensor([0]), 5233*da0073e9SAndroid Build Coastguard Worker ) 5234*da0073e9SAndroid Build Coastguard Worker ) 5235*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5236*da0073e9SAndroid Build Coastguard Worker same( 5237*da0073e9SAndroid Build Coastguard Worker out_tensor, 5238*da0073e9SAndroid Build Coastguard Worker torch.tensor([0]), 5239*da0073e9SAndroid Build Coastguard Worker ) 5240*da0073e9SAndroid Build Coastguard Worker ) 5241*da0073e9SAndroid Build Coastguard Worker 5242*da0073e9SAndroid Build Coastguard Worker # 0 rank 5243*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.tensor(6) 5244*da0073e9SAndroid Build Coastguard Worker static_size = 2 5245*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5246*da0073e9SAndroid Build Coastguard Worker same( 5247*da0073e9SAndroid Build Coastguard Worker 
torch.nonzero_static(input_tensor, size=static_size), 5248*da0073e9SAndroid Build Coastguard Worker torch.empty((static_size, input_tensor.dim()), dtype=torch.long), 5249*da0073e9SAndroid Build Coastguard Worker ) 5250*da0073e9SAndroid Build Coastguard Worker ) 5251*da0073e9SAndroid Build Coastguard Worker 5252*da0073e9SAndroid Build Coastguard Worker # 0 size 5253*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.tensor([[[1]]]) 5254*da0073e9SAndroid Build Coastguard Worker static_size = 0 5255*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5256*da0073e9SAndroid Build Coastguard Worker same( 5257*da0073e9SAndroid Build Coastguard Worker torch.nonzero_static(input_tensor, size=static_size), 5258*da0073e9SAndroid Build Coastguard Worker torch.empty((static_size, input_tensor.dim()), dtype=torch.long), 5259*da0073e9SAndroid Build Coastguard Worker ) 5260*da0073e9SAndroid Build Coastguard Worker ) 5261*da0073e9SAndroid Build Coastguard Worker 5262*da0073e9SAndroid Build Coastguard Worker # 1D input 5263*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.tensor([0, 8]) 5264*da0073e9SAndroid Build Coastguard Worker static_size = 1 5265*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5266*da0073e9SAndroid Build Coastguard Worker same( 5267*da0073e9SAndroid Build Coastguard Worker torch.nonzero_static(input_tensor, size=static_size), 5268*da0073e9SAndroid Build Coastguard Worker torch.tensor([1]), 5269*da0073e9SAndroid Build Coastguard Worker ) 5270*da0073e9SAndroid Build Coastguard Worker ) 5271*da0073e9SAndroid Build Coastguard Worker 5272*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.tensor([8, 0]) 5273*da0073e9SAndroid Build Coastguard Worker static_size = 2 5274*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5275*da0073e9SAndroid Build Coastguard Worker same( 5276*da0073e9SAndroid Build Coastguard Worker torch.nonzero_static(input_tensor, size=static_size), 5277*da0073e9SAndroid Build Coastguard 
Worker torch.tensor([[0], [-1]]), # padded with default fill_value "-1" 5278*da0073e9SAndroid Build Coastguard Worker ) 5279*da0073e9SAndroid Build Coastguard Worker ) 5280*da0073e9SAndroid Build Coastguard Worker 5281*da0073e9SAndroid Build Coastguard Worker # 2D input 5282*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]]) 5283*da0073e9SAndroid Build Coastguard Worker static_size = 5 5284*da0073e9SAndroid Build Coastguard Worker fill_value = -100 5285*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5286*da0073e9SAndroid Build Coastguard Worker torch._dynamo.utils.same( 5287*da0073e9SAndroid Build Coastguard Worker torch.nonzero_static( 5288*da0073e9SAndroid Build Coastguard Worker input_tensor, size=static_size, fill_value=fill_value 5289*da0073e9SAndroid Build Coastguard Worker ), 5290*da0073e9SAndroid Build Coastguard Worker torch.tensor( 5291*da0073e9SAndroid Build Coastguard Worker [ 5292*da0073e9SAndroid Build Coastguard Worker [0, 0], 5293*da0073e9SAndroid Build Coastguard Worker [1, 0], 5294*da0073e9SAndroid Build Coastguard Worker [1, 1], 5295*da0073e9SAndroid Build Coastguard Worker [fill_value, fill_value], 5296*da0073e9SAndroid Build Coastguard Worker [fill_value, fill_value], 5297*da0073e9SAndroid Build Coastguard Worker ] 5298*da0073e9SAndroid Build Coastguard Worker ), 5299*da0073e9SAndroid Build Coastguard Worker ) 5300*da0073e9SAndroid Build Coastguard Worker ) 5301*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]]) 5302*da0073e9SAndroid Build Coastguard Worker static_size = 2 5303*da0073e9SAndroid Build Coastguard Worker fill_value = -100 5304*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5305*da0073e9SAndroid Build Coastguard Worker torch._dynamo.utils.same( 5306*da0073e9SAndroid Build Coastguard Worker torch.nonzero_static( 5307*da0073e9SAndroid Build Coastguard Worker input_tensor, size=static_size, fill_value=fill_value 
5308*da0073e9SAndroid Build Coastguard Worker ), 5309*da0073e9SAndroid Build Coastguard Worker torch.tensor([[0, 0], [1, 0]]), 5310*da0073e9SAndroid Build Coastguard Worker ) 5311*da0073e9SAndroid Build Coastguard Worker ) 5312*da0073e9SAndroid Build Coastguard Worker 5313*da0073e9SAndroid Build Coastguard Worker # 3D input 5314*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.tensor([[[0, 0], [0, -3]], [[0, 0], [5, 0]]]) 5315*da0073e9SAndroid Build Coastguard Worker static_size = 4 5316*da0073e9SAndroid Build Coastguard Worker fill_value = -999 5317*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5318*da0073e9SAndroid Build Coastguard Worker torch._dynamo.utils.same( 5319*da0073e9SAndroid Build Coastguard Worker torch.nonzero_static( 5320*da0073e9SAndroid Build Coastguard Worker input_tensor, 5321*da0073e9SAndroid Build Coastguard Worker size=static_size, 5322*da0073e9SAndroid Build Coastguard Worker fill_value=fill_value, 5323*da0073e9SAndroid Build Coastguard Worker ), 5324*da0073e9SAndroid Build Coastguard Worker torch.tensor( 5325*da0073e9SAndroid Build Coastguard Worker [ 5326*da0073e9SAndroid Build Coastguard Worker [0, 1, 1], 5327*da0073e9SAndroid Build Coastguard Worker [1, 1, 0], 5328*da0073e9SAndroid Build Coastguard Worker [fill_value, fill_value, fill_value], 5329*da0073e9SAndroid Build Coastguard Worker [fill_value, fill_value, fill_value], 5330*da0073e9SAndroid Build Coastguard Worker ] 5331*da0073e9SAndroid Build Coastguard Worker ), 5332*da0073e9SAndroid Build Coastguard Worker ) 5333*da0073e9SAndroid Build Coastguard Worker ) 5334*da0073e9SAndroid Build Coastguard Worker 5335*da0073e9SAndroid Build Coastguard Worker def test_cond_with_quantization(self): 5336*da0073e9SAndroid Build Coastguard Worker from functorch.experimental.control_flow import cond 5337*da0073e9SAndroid Build Coastguard Worker 5338*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 5339*da0073e9SAndroid Build Coastguard Worker def 
__init__(self): 5340*da0073e9SAndroid Build Coastguard Worker super().__init__() 5341*da0073e9SAndroid Build Coastguard Worker example_inputs = (torch.randn(5, 5),) 5342*da0073e9SAndroid Build Coastguard Worker self.model = torch.nn.Linear(5, 5) 5343*da0073e9SAndroid Build Coastguard Worker self.quantized_model = prepare_qat_fx( 5344*da0073e9SAndroid Build Coastguard Worker self.model, qconfig_dict, example_inputs=example_inputs 5345*da0073e9SAndroid Build Coastguard Worker ) 5346*da0073e9SAndroid Build Coastguard Worker 5347*da0073e9SAndroid Build Coastguard Worker def forward(self, pred, x): 5348*da0073e9SAndroid Build Coastguard Worker def true_fn(x): 5349*da0073e9SAndroid Build Coastguard Worker return x.sin() + self.quantized_model(x) 5350*da0073e9SAndroid Build Coastguard Worker 5351*da0073e9SAndroid Build Coastguard Worker def false_fn(x): 5352*da0073e9SAndroid Build Coastguard Worker return x.cos() + self.model(x) 5353*da0073e9SAndroid Build Coastguard Worker 5354*da0073e9SAndroid Build Coastguard Worker return cond(pred, true_fn, false_fn, [x]) 5355*da0073e9SAndroid Build Coastguard Worker 5356*da0073e9SAndroid Build Coastguard Worker module = MyModule() 5357*da0073e9SAndroid Build Coastguard Worker opt_m = torch._dynamo.optimize("eager", nopython=True)(module) 5358*da0073e9SAndroid Build Coastguard Worker x = torch.rand((5, 5)) 5359*da0073e9SAndroid Build Coastguard Worker pred = torch.tensor(True) 5360*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(module(pred, x), opt_m(pred, x))) 5361*da0073e9SAndroid Build Coastguard Worker pred = torch.tensor(False) 5362*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(module(pred, x), opt_m(pred, x))) 5363*da0073e9SAndroid Build Coastguard Worker 5364*da0073e9SAndroid Build Coastguard Worker def test_map_with_quantization(self): 5365*da0073e9SAndroid Build Coastguard Worker from functorch.experimental.control_flow import map 5366*da0073e9SAndroid Build Coastguard Worker 
5367*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 5368*da0073e9SAndroid Build Coastguard Worker def __init__(self): 5369*da0073e9SAndroid Build Coastguard Worker super().__init__() 5370*da0073e9SAndroid Build Coastguard Worker example_inputs = (torch.randn(5, 5),) 5371*da0073e9SAndroid Build Coastguard Worker self.model = torch.nn.Linear(5, 5) 5372*da0073e9SAndroid Build Coastguard Worker self.quantized_model = prepare_qat_fx( 5373*da0073e9SAndroid Build Coastguard Worker self.model, qconfig_dict, example_inputs=example_inputs 5374*da0073e9SAndroid Build Coastguard Worker ) 5375*da0073e9SAndroid Build Coastguard Worker 5376*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 5377*da0073e9SAndroid Build Coastguard Worker def body(x): 5378*da0073e9SAndroid Build Coastguard Worker return x.sin() + self.quantized_model(x) 5379*da0073e9SAndroid Build Coastguard Worker 5380*da0073e9SAndroid Build Coastguard Worker return map(body, x) 5381*da0073e9SAndroid Build Coastguard Worker 5382*da0073e9SAndroid Build Coastguard Worker module = MyModule() 5383*da0073e9SAndroid Build Coastguard Worker opt_m = torch._dynamo.optimize("eager", nopython=True)(module) 5384*da0073e9SAndroid Build Coastguard Worker x = torch.rand((5, 5)) 5385*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(module(x), opt_m(x))) 5386*da0073e9SAndroid Build Coastguard Worker 5387*da0073e9SAndroid Build Coastguard Worker def test_cond_side_effects(self): 5388*da0073e9SAndroid Build Coastguard Worker from functorch.experimental.control_flow import cond 5389*da0073e9SAndroid Build Coastguard Worker 5390*da0073e9SAndroid Build Coastguard Worker c = 0 5391*da0073e9SAndroid Build Coastguard Worker 5392*da0073e9SAndroid Build Coastguard Worker def true_fn(x): 5393*da0073e9SAndroid Build Coastguard Worker return x - c 5394*da0073e9SAndroid Build Coastguard Worker 5395*da0073e9SAndroid Build Coastguard Worker def false_fn(x): 5396*da0073e9SAndroid Build Coastguard 
Worker return x + c 5397*da0073e9SAndroid Build Coastguard Worker 5398*da0073e9SAndroid Build Coastguard Worker def f(pred, x): 5399*da0073e9SAndroid Build Coastguard Worker nonlocal c 5400*da0073e9SAndroid Build Coastguard Worker c = 1 5401*da0073e9SAndroid Build Coastguard Worker return cond(pred, true_fn, false_fn, [x]) 5402*da0073e9SAndroid Build Coastguard Worker 5403*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(f) 5404*da0073e9SAndroid Build Coastguard Worker c = 0 5405*da0073e9SAndroid Build Coastguard Worker a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25])) 5406*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(torch.tensor([1.25, 1.25]), a)) 5407*da0073e9SAndroid Build Coastguard Worker 5408*da0073e9SAndroid Build Coastguard Worker def test_map_side_effects(self): 5409*da0073e9SAndroid Build Coastguard Worker from functorch.experimental.control_flow import map 5410*da0073e9SAndroid Build Coastguard Worker 5411*da0073e9SAndroid Build Coastguard Worker class Module(torch.nn.Module): 5412*da0073e9SAndroid Build Coastguard Worker def __init__(self): 5413*da0073e9SAndroid Build Coastguard Worker super().__init__() 5414*da0073e9SAndroid Build Coastguard Worker self.w = torch.tensor(1) 5415*da0073e9SAndroid Build Coastguard Worker 5416*da0073e9SAndroid Build Coastguard Worker def forward(self, xs): 5417*da0073e9SAndroid Build Coastguard Worker def body(x): 5418*da0073e9SAndroid Build Coastguard Worker self.w += 1 5419*da0073e9SAndroid Build Coastguard Worker return x 5420*da0073e9SAndroid Build Coastguard Worker 5421*da0073e9SAndroid Build Coastguard Worker return map(body, xs) 5422*da0073e9SAndroid Build Coastguard Worker 5423*da0073e9SAndroid Build Coastguard Worker mod = Module() 5424*da0073e9SAndroid Build Coastguard Worker 5425*da0073e9SAndroid Build Coastguard Worker error_message = "" 5426*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.inline_inbuilt_nn_modules: 
5427*da0073e9SAndroid Build Coastguard Worker error_message = r"HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)" 5428*da0073e9SAndroid Build Coastguard Worker else: 5429*da0073e9SAndroid Build Coastguard Worker error_message = "Can't inplace modify module params/buffers" 5430*da0073e9SAndroid Build Coastguard Worker 5431*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Unsupported, error_message): 5432*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(mod) 5433*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.randn(3, 2)) 5434*da0073e9SAndroid Build Coastguard Worker 5435*da0073e9SAndroid Build Coastguard Worker def test_cond_nested(self): 5436*da0073e9SAndroid Build Coastguard Worker from functorch.experimental.control_flow import cond 5437*da0073e9SAndroid Build Coastguard Worker 5438*da0073e9SAndroid Build Coastguard Worker def true_fn_nested(x): 5439*da0073e9SAndroid Build Coastguard Worker return x * 10 5440*da0073e9SAndroid Build Coastguard Worker 5441*da0073e9SAndroid Build Coastguard Worker def false_fn_nested(x): 5442*da0073e9SAndroid Build Coastguard Worker return x * -1 5443*da0073e9SAndroid Build Coastguard Worker 5444*da0073e9SAndroid Build Coastguard Worker def true_fn(pred2, x): 5445*da0073e9SAndroid Build Coastguard Worker return x.sin() 5446*da0073e9SAndroid Build Coastguard Worker 5447*da0073e9SAndroid Build Coastguard Worker def false_fn(pred2, x): 5448*da0073e9SAndroid Build Coastguard Worker return x + cond(pred2, true_fn_nested, false_fn_nested, [x]) 5449*da0073e9SAndroid Build Coastguard Worker 5450*da0073e9SAndroid Build Coastguard Worker def f(pred, pred2, x): 5451*da0073e9SAndroid Build Coastguard Worker return cond(pred, true_fn, false_fn, [pred2, x]) 5452*da0073e9SAndroid Build Coastguard Worker 5453*da0073e9SAndroid Build Coastguard Worker cc = torch._dynamo.testing.CompileCounter() 5454*da0073e9SAndroid Build Coastguard Worker 
opt_fn = torch._dynamo.optimize(cc)(f) 5455*da0073e9SAndroid Build Coastguard Worker true_true_sin = opt_fn( 5456*da0073e9SAndroid Build Coastguard Worker torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25]) 5457*da0073e9SAndroid Build Coastguard Worker ) 5458*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin)) 5459*da0073e9SAndroid Build Coastguard Worker 5460*da0073e9SAndroid Build Coastguard Worker true_false_sin = opt_fn( 5461*da0073e9SAndroid Build Coastguard Worker torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25]) 5462*da0073e9SAndroid Build Coastguard Worker ) 5463*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin)) 5464*da0073e9SAndroid Build Coastguard Worker 5465*da0073e9SAndroid Build Coastguard Worker false_true_sum_mult = opt_fn( 5466*da0073e9SAndroid Build Coastguard Worker torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) 5467*da0073e9SAndroid Build Coastguard Worker ) 5468*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5469*da0073e9SAndroid Build Coastguard Worker same(torch.tensor([2.75, 2.75]), false_true_sum_mult) 5470*da0073e9SAndroid Build Coastguard Worker ) # * 10 then add x 5471*da0073e9SAndroid Build Coastguard Worker 5472*da0073e9SAndroid Build Coastguard Worker false_false_sum_neg = opt_fn( 5473*da0073e9SAndroid Build Coastguard Worker torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25]) 5474*da0073e9SAndroid Build Coastguard Worker ) 5475*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5476*da0073e9SAndroid Build Coastguard Worker same(torch.tensor([0.0, 0.0]), false_false_sum_neg) 5477*da0073e9SAndroid Build Coastguard Worker ) # * -1 then add x 5478*da0073e9SAndroid Build Coastguard Worker self.assertTrue(cc.frame_count, 2) 5479*da0073e9SAndroid Build Coastguard Worker 5480*da0073e9SAndroid Build Coastguard Worker def 
test_cond_export(self): 5481*da0073e9SAndroid Build Coastguard Worker from functorch.experimental.control_flow import cond 5482*da0073e9SAndroid Build Coastguard Worker 5483*da0073e9SAndroid Build Coastguard Worker def true_fn_nested(x): 5484*da0073e9SAndroid Build Coastguard Worker return x * 10 5485*da0073e9SAndroid Build Coastguard Worker 5486*da0073e9SAndroid Build Coastguard Worker def false_fn_nested(x): 5487*da0073e9SAndroid Build Coastguard Worker return x * -1 5488*da0073e9SAndroid Build Coastguard Worker 5489*da0073e9SAndroid Build Coastguard Worker def true_fn(pred2, x): 5490*da0073e9SAndroid Build Coastguard Worker return x.sin() 5491*da0073e9SAndroid Build Coastguard Worker 5492*da0073e9SAndroid Build Coastguard Worker def false_fn(pred2, x): 5493*da0073e9SAndroid Build Coastguard Worker return x + cond(pred2, true_fn_nested, false_fn_nested, [x]) 5494*da0073e9SAndroid Build Coastguard Worker 5495*da0073e9SAndroid Build Coastguard Worker def f(pred, pred2, x): 5496*da0073e9SAndroid Build Coastguard Worker return cond(pred, true_fn, false_fn, [pred2, x]) 5497*da0073e9SAndroid Build Coastguard Worker 5498*da0073e9SAndroid Build Coastguard Worker graph, guard = torch._dynamo.export(f)( 5499*da0073e9SAndroid Build Coastguard Worker torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) 5500*da0073e9SAndroid Build Coastguard Worker ) 5501*da0073e9SAndroid Build Coastguard Worker true_true_sin = graph( 5502*da0073e9SAndroid Build Coastguard Worker torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25]) 5503*da0073e9SAndroid Build Coastguard Worker ) 5504*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin)) 5505*da0073e9SAndroid Build Coastguard Worker 5506*da0073e9SAndroid Build Coastguard Worker true_false_sin = graph( 5507*da0073e9SAndroid Build Coastguard Worker torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25]) 5508*da0073e9SAndroid Build Coastguard 
Worker ) 5509*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin)) 5510*da0073e9SAndroid Build Coastguard Worker 5511*da0073e9SAndroid Build Coastguard Worker false_true_sum_mult = graph( 5512*da0073e9SAndroid Build Coastguard Worker torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) 5513*da0073e9SAndroid Build Coastguard Worker ) 5514*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5515*da0073e9SAndroid Build Coastguard Worker same(torch.tensor([2.75, 2.75]), false_true_sum_mult) 5516*da0073e9SAndroid Build Coastguard Worker ) # * 10 then add x 5517*da0073e9SAndroid Build Coastguard Worker 5518*da0073e9SAndroid Build Coastguard Worker false_false_sum_neg = graph( 5519*da0073e9SAndroid Build Coastguard Worker torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25]) 5520*da0073e9SAndroid Build Coastguard Worker ) 5521*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5522*da0073e9SAndroid Build Coastguard Worker same(torch.tensor([0.0, 0.0]), false_false_sum_neg) 5523*da0073e9SAndroid Build Coastguard Worker ) # * -1 then add x 5524*da0073e9SAndroid Build Coastguard Worker 5525*da0073e9SAndroid Build Coastguard Worker def test_cond_export_single_arg(self): 5526*da0073e9SAndroid Build Coastguard Worker from functorch.experimental.control_flow import cond 5527*da0073e9SAndroid Build Coastguard Worker 5528*da0073e9SAndroid Build Coastguard Worker def true_fn(x): 5529*da0073e9SAndroid Build Coastguard Worker return x 5530*da0073e9SAndroid Build Coastguard Worker 5531*da0073e9SAndroid Build Coastguard Worker def false_fn(x): 5532*da0073e9SAndroid Build Coastguard Worker return x.sin() 5533*da0073e9SAndroid Build Coastguard Worker 5534*da0073e9SAndroid Build Coastguard Worker def f(pred, x): 5535*da0073e9SAndroid Build Coastguard Worker return cond(pred, true_fn, false_fn, [x]) 5536*da0073e9SAndroid Build Coastguard Worker 5537*da0073e9SAndroid Build Coastguard 
Worker graph, guard = torch._dynamo.export(f)( 5538*da0073e9SAndroid Build Coastguard Worker torch.tensor(False), torch.tensor([0.25, 0.25]) 5539*da0073e9SAndroid Build Coastguard Worker ) 5540*da0073e9SAndroid Build Coastguard Worker true_mirror = graph(torch.tensor(True), torch.tensor([0.25, 0.25])) 5541*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(torch.tensor([0.25, 0.25]), true_mirror)) 5542*da0073e9SAndroid Build Coastguard Worker true_mirror_2 = graph(torch.tensor(True), torch.tensor([0.33, 0.33, 0.33])) 5543*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(torch.tensor([0.33, 0.33, 0.33]), true_mirror_2)) 5544*da0073e9SAndroid Build Coastguard Worker 5545*da0073e9SAndroid Build Coastguard Worker false_sin = graph(torch.tensor(False), torch.tensor([0.5, 0.5])) 5546*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(torch.sin(torch.tensor([0.5, 0.5])), false_sin)) 5547*da0073e9SAndroid Build Coastguard Worker 5548*da0073e9SAndroid Build Coastguard Worker def test_enum_guards(self): 5549*da0073e9SAndroid Build Coastguard Worker class MyEnum(enum.Enum): 5550*da0073e9SAndroid Build Coastguard Worker FOO = 10 5551*da0073e9SAndroid Build Coastguard Worker BAR = 20 5552*da0073e9SAndroid Build Coastguard Worker 5553*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 5554*da0073e9SAndroid Build Coastguard Worker if y == MyEnum.FOO: 5555*da0073e9SAndroid Build Coastguard Worker return x + 1 5556*da0073e9SAndroid Build Coastguard Worker else: 5557*da0073e9SAndroid Build Coastguard Worker return x - 1 5558*da0073e9SAndroid Build Coastguard Worker 5559*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3) 5560*da0073e9SAndroid Build Coastguard Worker y = MyEnum.BAR 5561*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 5562*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(backend="eager")(fn) 5563*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 5564*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(same(ref, res)) 5565*da0073e9SAndroid Build Coastguard Worker 5566*da0073e9SAndroid Build Coastguard Worker def test_duplicate_graph_break_log(self): 5567*da0073e9SAndroid Build Coastguard Worker torch._logging.set_logs(graph_breaks=True) 5568*da0073e9SAndroid Build Coastguard Worker 5569*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("eager") 5570*da0073e9SAndroid Build Coastguard Worker def f1(a, b): 5571*da0073e9SAndroid Build Coastguard Worker f2(a, b) 5572*da0073e9SAndroid Build Coastguard Worker 5573*da0073e9SAndroid Build Coastguard Worker def f2(a, b): 5574*da0073e9SAndroid Build Coastguard Worker c = a + b 5575*da0073e9SAndroid Build Coastguard Worker print("break") 5576*da0073e9SAndroid Build Coastguard Worker return a + b + c 5577*da0073e9SAndroid Build Coastguard Worker 5578*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("eager") 5579*da0073e9SAndroid Build Coastguard Worker def g1(a, b): 5580*da0073e9SAndroid Build Coastguard Worker g2(a, b) 5581*da0073e9SAndroid Build Coastguard Worker 5582*da0073e9SAndroid Build Coastguard Worker def g2(a, b): 5583*da0073e9SAndroid Build Coastguard Worker c = a + b 5584*da0073e9SAndroid Build Coastguard Worker print("break") 5585*da0073e9SAndroid Build Coastguard Worker return a + b + c 5586*da0073e9SAndroid Build Coastguard Worker 5587*da0073e9SAndroid Build Coastguard Worker def count_graph_break_msgs(msgs): 5588*da0073e9SAndroid Build Coastguard Worker return sum(msg.find("Graph break") != -1 for msg in msgs) 5589*da0073e9SAndroid Build Coastguard Worker 5590*da0073e9SAndroid Build Coastguard Worker with self.assertLogs( 5591*da0073e9SAndroid Build Coastguard Worker logger="torch._dynamo", level=logging.DEBUG 5592*da0073e9SAndroid Build Coastguard Worker ) as log, torch._dynamo.config.patch(verbose=True): 5593*da0073e9SAndroid Build Coastguard Worker f1(torch.randn(10), torch.randn(10)) 5594*da0073e9SAndroid Build Coastguard Worker 
self.assertGreater(count_graph_break_msgs(log.output), 1) 5595*da0073e9SAndroid Build Coastguard Worker 5596*da0073e9SAndroid Build Coastguard Worker with self.assertLogs( 5597*da0073e9SAndroid Build Coastguard Worker logger="torch._dynamo", level=logging.DEBUG 5598*da0073e9SAndroid Build Coastguard Worker ) as log, torch._dynamo.config.patch(verbose=False): 5599*da0073e9SAndroid Build Coastguard Worker g1(torch.randn(10), torch.randn(10)) 5600*da0073e9SAndroid Build Coastguard Worker self.assertEqual(count_graph_break_msgs(log.output), 1) 5601*da0073e9SAndroid Build Coastguard Worker 5602*da0073e9SAndroid Build Coastguard Worker # reset logging state 5603*da0073e9SAndroid Build Coastguard Worker torch._logging.set_logs() 5604*da0073e9SAndroid Build Coastguard Worker 5605*da0073e9SAndroid Build Coastguard Worker def test_inplace_param_update(self): 5606*da0073e9SAndroid Build Coastguard Worker def fn(param, y): 5607*da0073e9SAndroid Build Coastguard Worker prev_grad = torch.is_grad_enabled() 5608*da0073e9SAndroid Build Coastguard Worker try: 5609*da0073e9SAndroid Build Coastguard Worker torch.set_grad_enabled(False) 5610*da0073e9SAndroid Build Coastguard Worker torch.set_grad_enabled(True) 5611*da0073e9SAndroid Build Coastguard Worker torch.set_grad_enabled(False) 5612*da0073e9SAndroid Build Coastguard Worker param.add_(y) 5613*da0073e9SAndroid Build Coastguard Worker finally: 5614*da0073e9SAndroid Build Coastguard Worker torch.set_grad_enabled(prev_grad) 5615*da0073e9SAndroid Build Coastguard Worker 5616*da0073e9SAndroid Build Coastguard Worker y = torch.randn(4) 5617*da0073e9SAndroid Build Coastguard Worker x = torch.nn.Parameter(torch.randn(4)) 5618*da0073e9SAndroid Build Coastguard Worker fn(x, y) 5619*da0073e9SAndroid Build Coastguard Worker 5620*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 5621*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 5622*da0073e9SAndroid Build 
Coastguard Worker opt_fn(x, y) 5623*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 5624*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 3) 5625*da0073e9SAndroid Build Coastguard Worker 5626*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 5627*da0073e9SAndroid Build Coastguard Worker not PLATFORM_SUPPORTS_FLASH_ATTENTION, 5628*da0073e9SAndroid Build Coastguard Worker "Can't run fused SDPA on this platform", 5629*da0073e9SAndroid Build Coastguard Worker ) 5630*da0073e9SAndroid Build Coastguard Worker def test_parsing_sdpa(self): 5631*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 5632*da0073e9SAndroid Build Coastguard Worker def forward(self, query, key, value): 5633*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention(query, key, value, None, 0, True) 5634*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention( 5635*da0073e9SAndroid Build Coastguard Worker query, key, value, None, 0, True, scale=8 5636*da0073e9SAndroid Build Coastguard Worker ) 5637*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention( 5638*da0073e9SAndroid Build Coastguard Worker query=query, 5639*da0073e9SAndroid Build Coastguard Worker key=key, 5640*da0073e9SAndroid Build Coastguard Worker value=value, 5641*da0073e9SAndroid Build Coastguard Worker attn_mask=None, 5642*da0073e9SAndroid Build Coastguard Worker dropout_p=0, 5643*da0073e9SAndroid Build Coastguard Worker is_causal=True, 5644*da0073e9SAndroid Build Coastguard Worker ) 5645*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention( 5646*da0073e9SAndroid Build Coastguard Worker query, 5647*da0073e9SAndroid Build Coastguard Worker key=key, 5648*da0073e9SAndroid Build Coastguard Worker value=value, 5649*da0073e9SAndroid Build Coastguard Worker attn_mask=None, 5650*da0073e9SAndroid Build Coastguard Worker dropout_p=0, 5651*da0073e9SAndroid Build Coastguard 
Worker is_causal=True, 5652*da0073e9SAndroid Build Coastguard Worker ) 5653*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention( 5654*da0073e9SAndroid Build Coastguard Worker query, key, value, None, dropout_p=0, is_causal=True 5655*da0073e9SAndroid Build Coastguard Worker ) 5656*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention(query, key, value, None, scale=8) 5657*da0073e9SAndroid Build Coastguard Worker return out 5658*da0073e9SAndroid Build Coastguard Worker 5659*da0073e9SAndroid Build Coastguard Worker device = "cuda" 5660*da0073e9SAndroid Build Coastguard Worker dtype = torch.float16 5661*da0073e9SAndroid Build Coastguard Worker seq_len_q = 1 5662*da0073e9SAndroid Build Coastguard Worker seq_len_k = 1 5663*da0073e9SAndroid Build Coastguard Worker head_dim = 8 5664*da0073e9SAndroid Build Coastguard Worker query = torch.ones( 5665*da0073e9SAndroid Build Coastguard Worker 1, 8, seq_len_q, head_dim, device=device, dtype=dtype, requires_grad=True 5666*da0073e9SAndroid Build Coastguard Worker ) 5667*da0073e9SAndroid Build Coastguard Worker key = torch.ones( 5668*da0073e9SAndroid Build Coastguard Worker 1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True 5669*da0073e9SAndroid Build Coastguard Worker ) 5670*da0073e9SAndroid Build Coastguard Worker value = torch.ones( 5671*da0073e9SAndroid Build Coastguard Worker 1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True 5672*da0073e9SAndroid Build Coastguard Worker ) 5673*da0073e9SAndroid Build Coastguard Worker module = MyModule() 5674*da0073e9SAndroid Build Coastguard Worker opt_mod = torch._dynamo.optimize("inductor")(module) 5675*da0073e9SAndroid Build Coastguard Worker opt_mod(query, key, value) 5676*da0073e9SAndroid Build Coastguard Worker 5677*da0073e9SAndroid Build Coastguard Worker def test_generate_tensor_from_list_of_numpy_primitive_type(self): 5678*da0073e9SAndroid Build Coastguard Worker # Test sth like 
torch.LongTensor(list(np.int64, np.int64, ...)) 5679*da0073e9SAndroid Build Coastguard Worker def fn(): 5680*da0073e9SAndroid Build Coastguard Worker x = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64) 5681*da0073e9SAndroid Build Coastguard Worker y = [x[0], x[2], x[4]] 5682*da0073e9SAndroid Build Coastguard Worker return torch.LongTensor(y) 5683*da0073e9SAndroid Build Coastguard Worker 5684*da0073e9SAndroid Build Coastguard Worker ref = fn() 5685*da0073e9SAndroid Build Coastguard Worker res = torch.compile(fullgraph=True)(fn)() 5686*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 5687*da0073e9SAndroid Build Coastguard Worker 5688*da0073e9SAndroid Build Coastguard Worker def test_object_classmethod(self): 5689*da0073e9SAndroid Build Coastguard Worker class C: 5690*da0073e9SAndroid Build Coastguard Worker @classmethod 5691*da0073e9SAndroid Build Coastguard Worker def fn(cls, x): 5692*da0073e9SAndroid Build Coastguard Worker return x + x 5693*da0073e9SAndroid Build Coastguard Worker 5694*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("eager", nopython=True) 5695*da0073e9SAndroid Build Coastguard Worker def f(): 5696*da0073e9SAndroid Build Coastguard Worker return C().fn(torch.ones(2, 3)) 5697*da0073e9SAndroid Build Coastguard Worker 5698*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(f(), torch.tensor([2.0]))) 5699*da0073e9SAndroid Build Coastguard Worker 5700*da0073e9SAndroid Build Coastguard Worker def test_object_staticmethod(self): 5701*da0073e9SAndroid Build Coastguard Worker class C: 5702*da0073e9SAndroid Build Coastguard Worker @staticmethod 5703*da0073e9SAndroid Build Coastguard Worker def fn(x): 5704*da0073e9SAndroid Build Coastguard Worker return x + x 5705*da0073e9SAndroid Build Coastguard Worker 5706*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("eager", nopython=True) 5707*da0073e9SAndroid Build Coastguard Worker def f(): 5708*da0073e9SAndroid Build Coastguard Worker return 
C().fn(torch.ones(2, 3)) 5709*da0073e9SAndroid Build Coastguard Worker 5710*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(f(), torch.tensor([2.0]))) 5711*da0073e9SAndroid Build Coastguard Worker 5712*da0073e9SAndroid Build Coastguard Worker def test_user_function_variable_supports_enum_argument(self): 5713*da0073e9SAndroid Build Coastguard Worker class Foo(enum.Enum): 5714*da0073e9SAndroid Build Coastguard Worker FOO = 0 5715*da0073e9SAndroid Build Coastguard Worker BAR = 1 5716*da0073e9SAndroid Build Coastguard Worker 5717*da0073e9SAndroid Build Coastguard Worker def gn(x, y=Foo.FOO): 5718*da0073e9SAndroid Build Coastguard Worker if y is Foo.FOO: 5719*da0073e9SAndroid Build Coastguard Worker return x 5720*da0073e9SAndroid Build Coastguard Worker else: 5721*da0073e9SAndroid Build Coastguard Worker return x + 1 5722*da0073e9SAndroid Build Coastguard Worker 5723*da0073e9SAndroid Build Coastguard Worker def fn(x): 5724*da0073e9SAndroid Build Coastguard Worker return gn(x) 5725*da0073e9SAndroid Build Coastguard Worker 5726*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3) 5727*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 5728*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 5729*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 5730*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(ref, res)) 5731*da0073e9SAndroid Build Coastguard Worker 5732*da0073e9SAndroid Build Coastguard Worker def test_user_function_variable_supports_type_abcmeta_argument(self): 5733*da0073e9SAndroid Build Coastguard Worker class Foo(metaclass=abc.ABCMeta): 5734*da0073e9SAndroid Build Coastguard Worker @abc.abstractclassmethod 5735*da0073e9SAndroid Build Coastguard Worker def read(self): # noqa: B027 5736*da0073e9SAndroid Build Coastguard Worker pass 5737*da0073e9SAndroid Build Coastguard Worker 5738*da0073e9SAndroid Build Coastguard Worker class Bar(Foo): 
5739*da0073e9SAndroid Build Coastguard Worker def read(self): 5740*da0073e9SAndroid Build Coastguard Worker return "Hello World!" 5741*da0073e9SAndroid Build Coastguard Worker 5742*da0073e9SAndroid Build Coastguard Worker class Baz: 5743*da0073e9SAndroid Build Coastguard Worker pass 5744*da0073e9SAndroid Build Coastguard Worker 5745*da0073e9SAndroid Build Coastguard Worker def gn(x, tys=(Bar, Baz)): 5746*da0073e9SAndroid Build Coastguard Worker if Bar in tys: 5747*da0073e9SAndroid Build Coastguard Worker return x - 1 5748*da0073e9SAndroid Build Coastguard Worker else: 5749*da0073e9SAndroid Build Coastguard Worker return x + 1 5750*da0073e9SAndroid Build Coastguard Worker 5751*da0073e9SAndroid Build Coastguard Worker def fn(x): 5752*da0073e9SAndroid Build Coastguard Worker return gn(x) 5753*da0073e9SAndroid Build Coastguard Worker 5754*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3) 5755*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 5756*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 5757*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 5758*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(ref, res)) 5759*da0073e9SAndroid Build Coastguard Worker 5760*da0073e9SAndroid Build Coastguard Worker def test_user_function_variable_supports_function_argument(self): 5761*da0073e9SAndroid Build Coastguard Worker # Test user defined function default arguments can be: 5762*da0073e9SAndroid Build Coastguard Worker # 1, user defined functions (e.g, add1) 5763*da0073e9SAndroid Build Coastguard Worker # 2, torch functions (e.g, torch.sin) 5764*da0073e9SAndroid Build Coastguard Worker # 3, python builtin functions (e.g, operator.neg) 5765*da0073e9SAndroid Build Coastguard Worker def add1(x): 5766*da0073e9SAndroid Build Coastguard Worker return x + 1 5767*da0073e9SAndroid Build Coastguard Worker 5768*da0073e9SAndroid Build Coastguard Worker def gn(x, f1=add1, f2=torch.sin, 
f3=operator.neg): 5769*da0073e9SAndroid Build Coastguard Worker return f3(f2(f1(x))) 5770*da0073e9SAndroid Build Coastguard Worker 5771*da0073e9SAndroid Build Coastguard Worker def fn(x): 5772*da0073e9SAndroid Build Coastguard Worker return gn(x) 5773*da0073e9SAndroid Build Coastguard Worker 5774*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3) 5775*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 5776*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 5777*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 5778*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(ref, res)) 5779*da0073e9SAndroid Build Coastguard Worker 5780*da0073e9SAndroid Build Coastguard Worker def test_typing_variable_isinstance(self): 5781*da0073e9SAndroid Build Coastguard Worker def fn(x, m): 5782*da0073e9SAndroid Build Coastguard Worker if isinstance(m, typing.Mapping): 5783*da0073e9SAndroid Build Coastguard Worker return x + 1 5784*da0073e9SAndroid Build Coastguard Worker else: 5785*da0073e9SAndroid Build Coastguard Worker return x - 1 5786*da0073e9SAndroid Build Coastguard Worker 5787*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3) 5788*da0073e9SAndroid Build Coastguard Worker m = {"x": torch.randn(3)} 5789*da0073e9SAndroid Build Coastguard Worker ref = fn(x, m) 5790*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 5791*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, m) 5792*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(ref, res)) 5793*da0073e9SAndroid Build Coastguard Worker 5794*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(guard_nn_modules=True) 5795*da0073e9SAndroid Build Coastguard Worker def test_repro_graph_breaks_in__get_item_by_idx(self): 5796*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 5797*da0073e9SAndroid Build Coastguard Worker def __init__(self): 
5798*da0073e9SAndroid Build Coastguard Worker super().__init__() 5799*da0073e9SAndroid Build Coastguard Worker self.mod = torch.nn.Sequential( 5800*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(3, 3), torch.nn.Linear(3, 3) 5801*da0073e9SAndroid Build Coastguard Worker ) 5802*da0073e9SAndroid Build Coastguard Worker 5803*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 5804*da0073e9SAndroid Build Coastguard Worker return self.mod[0](x) 5805*da0073e9SAndroid Build Coastguard Worker 5806*da0073e9SAndroid Build Coastguard Worker m = Mod() 5807*da0073e9SAndroid Build Coastguard Worker graph, _ = torch._dynamo.export(m)(torch.randn(3, 3)) 5808*da0073e9SAndroid Build Coastguard Worker 5809*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(guard_nn_modules=True) 5810*da0073e9SAndroid Build Coastguard Worker def test_nn_sequential_invocation(self): 5811*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 5812*da0073e9SAndroid Build Coastguard Worker 5813*da0073e9SAndroid Build Coastguard Worker class TestModel(torch.nn.Module): 5814*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 5815*da0073e9SAndroid Build Coastguard Worker super().__init__() 5816*da0073e9SAndroid Build Coastguard Worker self.linears = torch.nn.Sequential( 5817*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(2, 2), 5818*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(2, 2), 5819*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(2, 2), 5820*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(2, 2), 5821*da0073e9SAndroid Build Coastguard Worker ) 5822*da0073e9SAndroid Build Coastguard Worker 5823*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 5824*da0073e9SAndroid Build Coastguard Worker all_but_last = self.linears[:-1] 5825*da0073e9SAndroid Build Coastguard Worker return all_but_last(x) 5826*da0073e9SAndroid Build Coastguard Worker 5827*da0073e9SAndroid Build Coastguard Worker m = 
TestModel() 5828*da0073e9SAndroid Build Coastguard Worker x = torch.rand((2, 2)) 5829*da0073e9SAndroid Build Coastguard Worker real = m(x) 5830*da0073e9SAndroid Build Coastguard Worker graph, _ = torch._dynamo.export(m)(x) 5831*da0073e9SAndroid Build Coastguard Worker dynamo_result = graph(x) 5832*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(real, dynamo_result)) 5833*da0073e9SAndroid Build Coastguard Worker 5834*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(guard_nn_modules=True) 5835*da0073e9SAndroid Build Coastguard Worker def test_nn_sequential_invocation_reposition_indices(self): 5836*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 5837*da0073e9SAndroid Build Coastguard Worker 5838*da0073e9SAndroid Build Coastguard Worker class TestModel(torch.nn.Module): 5839*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 5840*da0073e9SAndroid Build Coastguard Worker super().__init__() 5841*da0073e9SAndroid Build Coastguard Worker self.linears = torch.nn.Sequential( 5842*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(2, 2), 5843*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(2, 2), 5844*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(2, 2), 5845*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(2, 2), 5846*da0073e9SAndroid Build Coastguard Worker ) 5847*da0073e9SAndroid Build Coastguard Worker 5848*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 5849*da0073e9SAndroid Build Coastguard Worker all_but_last = self.linears[1:3] 5850*da0073e9SAndroid Build Coastguard Worker return all_but_last(x) 5851*da0073e9SAndroid Build Coastguard Worker 5852*da0073e9SAndroid Build Coastguard Worker m = TestModel() 5853*da0073e9SAndroid Build Coastguard Worker x = torch.rand((2, 2)) 5854*da0073e9SAndroid Build Coastguard Worker real = m(x) 5855*da0073e9SAndroid Build Coastguard Worker graph, _ = torch._dynamo.export(m)(x) 5856*da0073e9SAndroid Build Coastguard 
Worker dynamo_result = graph(x) 5857*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(real, dynamo_result)) 5858*da0073e9SAndroid Build Coastguard Worker 5859*da0073e9SAndroid Build Coastguard Worker def test_error_on_nested_fx_trace(self): 5860*da0073e9SAndroid Build Coastguard Worker input = torch.rand(2, 3) 5861*da0073e9SAndroid Build Coastguard Worker 5862*da0073e9SAndroid Build Coastguard Worker def f(x): 5863*da0073e9SAndroid Build Coastguard Worker x + x 5864*da0073e9SAndroid Build Coastguard Worker 5865*da0073e9SAndroid Build Coastguard Worker real = f(input) 5866*da0073e9SAndroid Build Coastguard Worker 5867*da0073e9SAndroid Build Coastguard Worker optimized = torch._dynamo.optimize("eager")(f) 5868*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(optimized(input), real)) 5869*da0073e9SAndroid Build Coastguard Worker 5870*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"): 5871*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(optimized) 5872*da0073e9SAndroid Build Coastguard Worker 5873*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "error_on_nested_fx_trace", False) 5874*da0073e9SAndroid Build Coastguard Worker def test_no_error_on_nested_fx_trace(self): 5875*da0073e9SAndroid Build Coastguard Worker input = torch.rand(2, 3) 5876*da0073e9SAndroid Build Coastguard Worker 5877*da0073e9SAndroid Build Coastguard Worker def f(x): 5878*da0073e9SAndroid Build Coastguard Worker x + x 5879*da0073e9SAndroid Build Coastguard Worker 5880*da0073e9SAndroid Build Coastguard Worker real = f(input) 5881*da0073e9SAndroid Build Coastguard Worker 5882*da0073e9SAndroid Build Coastguard Worker optimized = torch._dynamo.optimize("eager")(f) 5883*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(optimized(input), real)) 5884*da0073e9SAndroid Build Coastguard Worker 5885*da0073e9SAndroid Build Coastguard Worker # should not 
error 5886*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(optimized) 5887*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(gm(input), real)) 5888*da0073e9SAndroid Build Coastguard Worker 5889*da0073e9SAndroid Build Coastguard Worker def test_not_dynamic_scope(self): 5890*da0073e9SAndroid Build Coastguard Worker def f(y): 5891*da0073e9SAndroid Build Coastguard Worker x = 1 5892*da0073e9SAndroid Build Coastguard Worker 5893*da0073e9SAndroid Build Coastguard Worker def g(): 5894*da0073e9SAndroid Build Coastguard Worker x = 2 5895*da0073e9SAndroid Build Coastguard Worker return lambda: x 5896*da0073e9SAndroid Build Coastguard Worker 5897*da0073e9SAndroid Build Coastguard Worker return y + g()() 5898*da0073e9SAndroid Build Coastguard Worker 5899*da0073e9SAndroid Build Coastguard Worker input = torch.zeros(1) 5900*da0073e9SAndroid Build Coastguard Worker real = f(input) 5901*da0073e9SAndroid Build Coastguard Worker optimized = torch._dynamo.optimize("eager")(f) 5902*da0073e9SAndroid Build Coastguard Worker opt = optimized(input) 5903*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(opt, real)) 5904*da0073e9SAndroid Build Coastguard Worker 5905*da0073e9SAndroid Build Coastguard Worker def test_inference_mode(self): 5906*da0073e9SAndroid Build Coastguard Worker @torch.inference_mode() 5907*da0073e9SAndroid Build Coastguard Worker def func(x, y): 5908*da0073e9SAndroid Build Coastguard Worker return x.add(1.0) + y 5909*da0073e9SAndroid Build Coastguard Worker 5910*da0073e9SAndroid Build Coastguard Worker x = torch.ones(4, requires_grad=True) 5911*da0073e9SAndroid Build Coastguard Worker y = torch.ones(4, requires_grad=True) 5912*da0073e9SAndroid Build Coastguard Worker ref = func(x, y) 5913*da0073e9SAndroid Build Coastguard Worker opt_func = torch._dynamo.optimize("eager")(func) 5914*da0073e9SAndroid Build Coastguard Worker 5915*da0073e9SAndroid Build Coastguard Worker x1 = torch.ones(4, requires_grad=True) 
5916*da0073e9SAndroid Build Coastguard Worker res = opt_func(x1, y) 5917*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 5918*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(x, x1)) 5919*da0073e9SAndroid Build Coastguard Worker 5920*da0073e9SAndroid Build Coastguard Worker def test_if_cond_nn_mod1(self): 5921*da0073e9SAndroid Build Coastguard Worker class MockModule(torch.nn.Module): 5922*da0073e9SAndroid Build Coastguard Worker def __init__(self, output_relu=True): 5923*da0073e9SAndroid Build Coastguard Worker super().__init__() 5924*da0073e9SAndroid Build Coastguard Worker self.relu = torch.nn.ReLU() if output_relu else None 5925*da0073e9SAndroid Build Coastguard Worker 5926*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 5927*da0073e9SAndroid Build Coastguard Worker x = torch.sin(x) 5928*da0073e9SAndroid Build Coastguard Worker if self.relu: 5929*da0073e9SAndroid Build Coastguard Worker x = self.relu(x) 5930*da0073e9SAndroid Build Coastguard Worker return x 5931*da0073e9SAndroid Build Coastguard Worker 5932*da0073e9SAndroid Build Coastguard Worker model = MockModule() 5933*da0073e9SAndroid Build Coastguard Worker opt_model = torch._dynamo.optimize("eager", nopython=True)(model) 5934*da0073e9SAndroid Build Coastguard Worker 5935*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4) 5936*da0073e9SAndroid Build Coastguard Worker ref = model(x) 5937*da0073e9SAndroid Build Coastguard Worker res = opt_model(x) 5938*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 5939*da0073e9SAndroid Build Coastguard Worker 5940*da0073e9SAndroid Build Coastguard Worker model = MockModule(output_relu=False) 5941*da0073e9SAndroid Build Coastguard Worker opt_model = torch._dynamo.optimize("eager", nopython=True)(model) 5942*da0073e9SAndroid Build Coastguard Worker 5943*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4) 5944*da0073e9SAndroid Build Coastguard Worker ref = model(x) 
5945*da0073e9SAndroid Build Coastguard Worker res = opt_model(x) 5946*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 5947*da0073e9SAndroid Build Coastguard Worker 5948*da0073e9SAndroid Build Coastguard Worker def test_if_cond_nn_mod2(self): 5949*da0073e9SAndroid Build Coastguard Worker class MockModule(torch.nn.Module): 5950*da0073e9SAndroid Build Coastguard Worker def __init__(self): 5951*da0073e9SAndroid Build Coastguard Worker super().__init__() 5952*da0073e9SAndroid Build Coastguard Worker self.layer = torch.nn.Sequential() 5953*da0073e9SAndroid Build Coastguard Worker 5954*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 5955*da0073e9SAndroid Build Coastguard Worker if self.layer: 5956*da0073e9SAndroid Build Coastguard Worker return x + 1 5957*da0073e9SAndroid Build Coastguard Worker else: 5958*da0073e9SAndroid Build Coastguard Worker return x - 1 5959*da0073e9SAndroid Build Coastguard Worker 5960*da0073e9SAndroid Build Coastguard Worker model = MockModule() 5961*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4) 5962*da0073e9SAndroid Build Coastguard Worker ref = model(x) 5963*da0073e9SAndroid Build Coastguard Worker opt_model = torch.compile(backend="eager")(model) 5964*da0073e9SAndroid Build Coastguard Worker res = opt_model(x) 5965*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 5966*da0073e9SAndroid Build Coastguard Worker 5967*da0073e9SAndroid Build Coastguard Worker def test_if_cond_nn_mod3(self): 5968*da0073e9SAndroid Build Coastguard Worker def fn(x): 5969*da0073e9SAndroid Build Coastguard Worker if torch.nn.ModuleList(): 5970*da0073e9SAndroid Build Coastguard Worker return x + 1 5971*da0073e9SAndroid Build Coastguard Worker else: 5972*da0073e9SAndroid Build Coastguard Worker return x - 1 5973*da0073e9SAndroid Build Coastguard Worker 5974*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4) 5975*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 5976*da0073e9SAndroid 
Build Coastguard Worker opt_fn = torch.compile(backend="eager")(fn) 5977*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 5978*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 5979*da0073e9SAndroid Build Coastguard Worker 5980*da0073e9SAndroid Build Coastguard Worker def test_if_cond_user_defined_object(self): 5981*da0073e9SAndroid Build Coastguard Worker # obj.__bool__ is not existed 5982*da0073e9SAndroid Build Coastguard Worker class A: # noqa: B903 5983*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 5984*da0073e9SAndroid Build Coastguard Worker self.x = x 5985*da0073e9SAndroid Build Coastguard Worker 5986*da0073e9SAndroid Build Coastguard Worker # obj.__bool__ is function and returns bool type 5987*da0073e9SAndroid Build Coastguard Worker class B: 5988*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 5989*da0073e9SAndroid Build Coastguard Worker self.x = x 5990*da0073e9SAndroid Build Coastguard Worker 5991*da0073e9SAndroid Build Coastguard Worker def __bool__(self): 5992*da0073e9SAndroid Build Coastguard Worker return self.x > 0 5993*da0073e9SAndroid Build Coastguard Worker 5994*da0073e9SAndroid Build Coastguard Worker # obj.__bool__ is non-function 5995*da0073e9SAndroid Build Coastguard Worker class C: 5996*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 5997*da0073e9SAndroid Build Coastguard Worker self.x = x 5998*da0073e9SAndroid Build Coastguard Worker self.__bool__ = False 5999*da0073e9SAndroid Build Coastguard Worker 6000*da0073e9SAndroid Build Coastguard Worker def fn(x, obj): 6001*da0073e9SAndroid Build Coastguard Worker if not obj: 6002*da0073e9SAndroid Build Coastguard Worker return x + 1 6003*da0073e9SAndroid Build Coastguard Worker else: 6004*da0073e9SAndroid Build Coastguard Worker return x - 1 6005*da0073e9SAndroid Build Coastguard Worker 6006*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4) 6007*da0073e9SAndroid Build Coastguard Worker cnts = 
torch._dynamo.testing.CompileCounter() 6008*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 6009*da0073e9SAndroid Build Coastguard Worker obj1 = A(0.5) 6010*da0073e9SAndroid Build Coastguard Worker obj2 = B(0.5) 6011*da0073e9SAndroid Build Coastguard Worker obj3 = B(-0.5) 6012*da0073e9SAndroid Build Coastguard Worker obj4 = C(0.5) 6013*da0073e9SAndroid Build Coastguard Worker for obj in [obj1, obj2, obj3, obj4, obj3, obj2]: 6014*da0073e9SAndroid Build Coastguard Worker ref = fn(x, obj) 6015*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, obj) 6016*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6017*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 4) 6018*da0073e9SAndroid Build Coastguard Worker 6019*da0073e9SAndroid Build Coastguard Worker def test_if_cond_user_defined_object2(self): 6020*da0073e9SAndroid Build Coastguard Worker # obj.__bool__ is function and returns non-bool type 6021*da0073e9SAndroid Build Coastguard Worker class MyObj: 6022*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 6023*da0073e9SAndroid Build Coastguard Worker self.x = x 6024*da0073e9SAndroid Build Coastguard Worker 6025*da0073e9SAndroid Build Coastguard Worker def __bool__(self): 6026*da0073e9SAndroid Build Coastguard Worker self.x = 1.2 6027*da0073e9SAndroid Build Coastguard Worker return self.x 6028*da0073e9SAndroid Build Coastguard Worker 6029*da0073e9SAndroid Build Coastguard Worker def fn(a, obj): 6030*da0073e9SAndroid Build Coastguard Worker if not obj: 6031*da0073e9SAndroid Build Coastguard Worker return a + obj.x 6032*da0073e9SAndroid Build Coastguard Worker else: 6033*da0073e9SAndroid Build Coastguard Worker return a - obj.x 6034*da0073e9SAndroid Build Coastguard Worker 6035*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4) 6036*da0073e9SAndroid Build Coastguard Worker obj = MyObj(0.5) 6037*da0073e9SAndroid Build Coastguard Worker opt_fn = 
torch._dynamo.optimize("eager")(fn) 6038*da0073e9SAndroid Build Coastguard Worker try: 6039*da0073e9SAndroid Build Coastguard Worker opt_fn(x, obj) 6040*da0073e9SAndroid Build Coastguard Worker self.assertFalse(True) 6041*da0073e9SAndroid Build Coastguard Worker except TypeError as e: 6042*da0073e9SAndroid Build Coastguard Worker self.assertIn("__bool__ should return bool, returned float", str(e)) 6043*da0073e9SAndroid Build Coastguard Worker 6044*da0073e9SAndroid Build Coastguard Worker def test_if_cond_user_defined_object3(self): 6045*da0073e9SAndroid Build Coastguard Worker # obj.__bool__ is not existed, but obj.__len__ exists 6046*da0073e9SAndroid Build Coastguard Worker class A: # noqa: B903 6047*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 6048*da0073e9SAndroid Build Coastguard Worker self.x = x 6049*da0073e9SAndroid Build Coastguard Worker 6050*da0073e9SAndroid Build Coastguard Worker def __len__(self): 6051*da0073e9SAndroid Build Coastguard Worker return len(self.x) 6052*da0073e9SAndroid Build Coastguard Worker 6053*da0073e9SAndroid Build Coastguard Worker # obj.__bool__ takes precedence over obj.__len__ 6054*da0073e9SAndroid Build Coastguard Worker class B: 6055*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 6056*da0073e9SAndroid Build Coastguard Worker self.x = x 6057*da0073e9SAndroid Build Coastguard Worker 6058*da0073e9SAndroid Build Coastguard Worker def __bool__(self): 6059*da0073e9SAndroid Build Coastguard Worker return False 6060*da0073e9SAndroid Build Coastguard Worker 6061*da0073e9SAndroid Build Coastguard Worker def __len__(self): 6062*da0073e9SAndroid Build Coastguard Worker return len(self.x) 6063*da0073e9SAndroid Build Coastguard Worker 6064*da0073e9SAndroid Build Coastguard Worker def fn(x, obj): 6065*da0073e9SAndroid Build Coastguard Worker if not obj: 6066*da0073e9SAndroid Build Coastguard Worker return x + 1 6067*da0073e9SAndroid Build Coastguard Worker else: 6068*da0073e9SAndroid Build Coastguard Worker 
return x - 1 6069*da0073e9SAndroid Build Coastguard Worker 6070*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4) 6071*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) 6072*da0073e9SAndroid Build Coastguard Worker obj1 = A([1, 2, 3]) 6073*da0073e9SAndroid Build Coastguard Worker obj2 = A([]) 6074*da0073e9SAndroid Build Coastguard Worker obj3 = B([1, 2, 3]) 6075*da0073e9SAndroid Build Coastguard Worker obj4 = B([]) 6076*da0073e9SAndroid Build Coastguard Worker for obj in [obj1, obj2, obj3, obj4]: 6077*da0073e9SAndroid Build Coastguard Worker ref = fn(x, obj) 6078*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, obj) 6079*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6080*da0073e9SAndroid Build Coastguard Worker 6081*da0073e9SAndroid Build Coastguard Worker def test_class_has_instancecheck_method(self): 6082*da0073e9SAndroid Build Coastguard Worker class A: 6083*da0073e9SAndroid Build Coastguard Worker pass 6084*da0073e9SAndroid Build Coastguard Worker 6085*da0073e9SAndroid Build Coastguard Worker class ExampleMeta(type): 6086*da0073e9SAndroid Build Coastguard Worker def __instancecheck__(cls, instance): 6087*da0073e9SAndroid Build Coastguard Worker return True 6088*da0073e9SAndroid Build Coastguard Worker 6089*da0073e9SAndroid Build Coastguard Worker class B(metaclass=ExampleMeta): 6090*da0073e9SAndroid Build Coastguard Worker pass 6091*da0073e9SAndroid Build Coastguard Worker 6092*da0073e9SAndroid Build Coastguard Worker def fn(x, obj): 6093*da0073e9SAndroid Build Coastguard Worker if isinstance(obj, B): 6094*da0073e9SAndroid Build Coastguard Worker return x + 1 6095*da0073e9SAndroid Build Coastguard Worker else: 6096*da0073e9SAndroid Build Coastguard Worker return x - 1 6097*da0073e9SAndroid Build Coastguard Worker 6098*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4) 6099*da0073e9SAndroid Build Coastguard Worker obj = A() 6100*da0073e9SAndroid Build 
Coastguard Worker ref = fn(x, obj) 6101*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 6102*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, obj) 6103*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6104*da0073e9SAndroid Build Coastguard Worker 6105*da0073e9SAndroid Build Coastguard Worker def test_torch_cuda_is_available(self): 6106*da0073e9SAndroid Build Coastguard Worker def fn(x): 6107*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 6108*da0073e9SAndroid Build Coastguard Worker return x + 1 6109*da0073e9SAndroid Build Coastguard Worker else: 6110*da0073e9SAndroid Build Coastguard Worker return x - 1 6111*da0073e9SAndroid Build Coastguard Worker 6112*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4) 6113*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 6114*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 6115*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 6116*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6117*da0073e9SAndroid Build Coastguard Worker 6118*da0073e9SAndroid Build Coastguard Worker def test_variable_tracker_recursively_contains(self): 6119*da0073e9SAndroid Build Coastguard Worker # VariableTracker.recursively_contains should be updated correctly when mutation happens 6120*da0073e9SAndroid Build Coastguard Worker def fn(x): 6121*da0073e9SAndroid Build Coastguard Worker data = [[None] * 3] * 3 6122*da0073e9SAndroid Build Coastguard Worker for i in range(3): 6123*da0073e9SAndroid Build Coastguard Worker if i == 0: 6124*da0073e9SAndroid Build Coastguard Worker data[0][i] = x 6125*da0073e9SAndroid Build Coastguard Worker else: 6126*da0073e9SAndroid Build Coastguard Worker data[0][i] = data[0][i - 1] + 1 6127*da0073e9SAndroid Build Coastguard Worker return data[0][-1] 6128*da0073e9SAndroid Build Coastguard Worker 6129*da0073e9SAndroid Build 
Coastguard Worker x = torch.rand(4) 6130*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 6131*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 6132*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 6133*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6134*da0073e9SAndroid Build Coastguard Worker 6135*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "requires cuda") 6136*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") 6137*da0073e9SAndroid Build Coastguard Worker def test_torch_cudnn_is_acceptable(self): 6138*da0073e9SAndroid Build Coastguard Worker def fn(x): 6139*da0073e9SAndroid Build Coastguard Worker if torch.backends.cudnn.is_acceptable(tensor=x): 6140*da0073e9SAndroid Build Coastguard Worker return x + 1 6141*da0073e9SAndroid Build Coastguard Worker return x 6142*da0073e9SAndroid Build Coastguard Worker 6143*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4).cuda() 6144*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 6145*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 6146*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 6147*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6148*da0073e9SAndroid Build Coastguard Worker 6149*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "requires cuda") 6150*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") 6151*da0073e9SAndroid Build Coastguard Worker def test_torch_cudnn_is_acceptable_bad_inputs(self): 6152*da0073e9SAndroid Build Coastguard Worker def fn1(x): 6153*da0073e9SAndroid Build Coastguard Worker if torch.backends.cudnn.is_acceptable("invalid"): 6154*da0073e9SAndroid Build Coastguard Worker return x + 1 6155*da0073e9SAndroid Build Coastguard Worker 
return x 6156*da0073e9SAndroid Build Coastguard Worker 6157*da0073e9SAndroid Build Coastguard Worker def fn2(x): 6158*da0073e9SAndroid Build Coastguard Worker if torch.backends.cudnn.is_acceptable(x, 3.14): 6159*da0073e9SAndroid Build Coastguard Worker return x + 1 6160*da0073e9SAndroid Build Coastguard Worker return x 6161*da0073e9SAndroid Build Coastguard Worker 6162*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 6163*da0073e9SAndroid Build Coastguard Worker AssertionError, "Expect input to cudnn.is_acceptable to be a tensor" 6164*da0073e9SAndroid Build Coastguard Worker ): 6165*da0073e9SAndroid Build Coastguard Worker x1 = torch.rand(4).cuda() 6166*da0073e9SAndroid Build Coastguard Worker opt_fn1 = torch._dynamo.optimize("eager", nopython=True)(fn1) 6167*da0073e9SAndroid Build Coastguard Worker res1 = opt_fn1(x1) 6168*da0073e9SAndroid Build Coastguard Worker 6169*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 6170*da0073e9SAndroid Build Coastguard Worker AssertionError, "Expect 1 input to cudnn.is_acceptable" 6171*da0073e9SAndroid Build Coastguard Worker ): 6172*da0073e9SAndroid Build Coastguard Worker x2 = torch.rand(4).cuda() 6173*da0073e9SAndroid Build Coastguard Worker opt_fn2 = torch._dynamo.optimize("eager", nopython=True)(fn2) 6174*da0073e9SAndroid Build Coastguard Worker res = opt_fn2(x2) 6175*da0073e9SAndroid Build Coastguard Worker 6176*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "requires cuda") 6177*da0073e9SAndroid Build Coastguard Worker def test_get_device(self): 6178*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 6179*da0073e9SAndroid Build Coastguard Worker x = x + 1 6180*da0073e9SAndroid Build Coastguard Worker y = y + 1 6181*da0073e9SAndroid Build Coastguard Worker return x.get_device(), y.get_device() 6182*da0073e9SAndroid Build Coastguard Worker 6183*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4, device="cuda") 6184*da0073e9SAndroid Build 
Coastguard Worker y = torch.rand(4, device="cpu") 6185*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 6186*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 6187*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 6188*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6189*da0073e9SAndroid Build Coastguard Worker 6190*da0073e9SAndroid Build Coastguard Worker def test_disable_flag(self): 6191*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 6192*da0073e9SAndroid Build Coastguard Worker 6193*da0073e9SAndroid Build Coastguard Worker with patch.dict(os.environ, {"TORCH_COMPILE_DISABLE": "1"}): 6194*da0073e9SAndroid Build Coastguard Worker 6195*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 6196*da0073e9SAndroid Build Coastguard Worker x = x + 1 6197*da0073e9SAndroid Build Coastguard Worker y = y + 1 6198*da0073e9SAndroid Build Coastguard Worker 6199*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnt) 6200*da0073e9SAndroid Build Coastguard Worker 6201*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 0) 6202*da0073e9SAndroid Build Coastguard Worker 6203*da0073e9SAndroid Build Coastguard Worker def test_is_compiling(self): 6204*da0073e9SAndroid Build Coastguard Worker def f1(): 6205*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.is_compiling(): 6206*da0073e9SAndroid Build Coastguard Worker return torch.ones(2, 2) 6207*da0073e9SAndroid Build Coastguard Worker else: 6208*da0073e9SAndroid Build Coastguard Worker return torch.zeros(2, 2) 6209*da0073e9SAndroid Build Coastguard Worker 6210*da0073e9SAndroid Build Coastguard Worker def f2(): 6211*da0073e9SAndroid Build Coastguard Worker if torch._utils.is_compiling(): 6212*da0073e9SAndroid Build Coastguard Worker return torch.ones(2, 2) 6213*da0073e9SAndroid Build Coastguard Worker else: 6214*da0073e9SAndroid Build Coastguard Worker 
return torch.zeros(2, 2) 6215*da0073e9SAndroid Build Coastguard Worker 6216*da0073e9SAndroid Build Coastguard Worker def f3(): 6217*da0073e9SAndroid Build Coastguard Worker if torch.compiler.is_compiling(): 6218*da0073e9SAndroid Build Coastguard Worker return torch.ones(2, 2) 6219*da0073e9SAndroid Build Coastguard Worker else: 6220*da0073e9SAndroid Build Coastguard Worker return torch.zeros(2, 2) 6221*da0073e9SAndroid Build Coastguard Worker 6222*da0073e9SAndroid Build Coastguard Worker def f4(): 6223*da0073e9SAndroid Build Coastguard Worker if torch.compiler.is_dynamo_compiling(): 6224*da0073e9SAndroid Build Coastguard Worker return torch.ones(2, 2) 6225*da0073e9SAndroid Build Coastguard Worker else: 6226*da0073e9SAndroid Build Coastguard Worker return torch.zeros(2, 2) 6227*da0073e9SAndroid Build Coastguard Worker 6228*da0073e9SAndroid Build Coastguard Worker for f in [f1, f2, f3, f4]: 6229*da0073e9SAndroid Build Coastguard Worker opt_f = torch._dynamo.optimize("eager")(f) 6230*da0073e9SAndroid Build Coastguard Worker 6231*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(), torch.zeros(2, 2)) 6232*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_f(), torch.ones(2, 2)) 6233*da0073e9SAndroid Build Coastguard Worker 6234*da0073e9SAndroid Build Coastguard Worker def test_torch_generator_set_state(self): 6235*da0073e9SAndroid Build Coastguard Worker def fn(): 6236*da0073e9SAndroid Build Coastguard Worker default_state = torch.default_generator.get_state() 6237*da0073e9SAndroid Build Coastguard Worker x = torch.rand([2, 3]) 6238*da0073e9SAndroid Build Coastguard Worker if default_state.dtype != "float32": 6239*da0073e9SAndroid Build Coastguard Worker x = x * 2 6240*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 6241*da0073e9SAndroid Build Coastguard Worker torch.default_generator.set_state(default_state) 6242*da0073e9SAndroid Build Coastguard Worker y = torch.rand([2, 3]) 6243*da0073e9SAndroid Build Coastguard Worker return 
x, y 6244*da0073e9SAndroid Build Coastguard Worker 6245*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 6246*da0073e9SAndroid Build Coastguard Worker x, y = opt_fn() 6247*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, y * 2) 6248*da0073e9SAndroid Build Coastguard Worker 6249*da0073e9SAndroid Build Coastguard Worker def test_torch_distributions_lazy_property(self): 6250*da0073e9SAndroid Build Coastguard Worker def fn(x): 6251*da0073e9SAndroid Build Coastguard Worker return torch.distributions.Categorical(probs=x).entropy() 6252*da0073e9SAndroid Build Coastguard Worker 6253*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 6254*da0073e9SAndroid Build Coastguard Worker x = torch.rand([4, 4]) 6255*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x), fn(x)) 6256*da0073e9SAndroid Build Coastguard Worker 6257*da0073e9SAndroid Build Coastguard Worker def test_guard_failure_fn(self): 6258*da0073e9SAndroid Build Coastguard Worker def fn(x, y, k): 6259*da0073e9SAndroid Build Coastguard Worker x = x + 1 6260*da0073e9SAndroid Build Coastguard Worker y = y + 1 6261*da0073e9SAndroid Build Coastguard Worker return x * y * k 6262*da0073e9SAndroid Build Coastguard Worker 6263*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.5, 0.5]) 6264*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([1.0, 1.0]) 6265*da0073e9SAndroid Build Coastguard Worker 6266*da0073e9SAndroid Build Coastguard Worker guard_failure = None 6267*da0073e9SAndroid Build Coastguard Worker 6268*da0073e9SAndroid Build Coastguard Worker def guard_failures(failure): 6269*da0073e9SAndroid Build Coastguard Worker nonlocal guard_failure 6270*da0073e9SAndroid Build Coastguard Worker guard_failure = failure 6271*da0073e9SAndroid Build Coastguard Worker 6272*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize( 6273*da0073e9SAndroid Build Coastguard Worker "eager", nopython=True, 
guard_fail_fn=guard_failures 6274*da0073e9SAndroid Build Coastguard Worker )(fn) 6275*da0073e9SAndroid Build Coastguard Worker 6276*da0073e9SAndroid Build Coastguard Worker x2 = torch.tensor([0.5, 0.5, 1.0]) 6277*da0073e9SAndroid Build Coastguard Worker y2 = torch.tensor([0.5, 0.5, 0.5]) 6278*da0073e9SAndroid Build Coastguard Worker 6279*da0073e9SAndroid Build Coastguard Worker opt_fn(x, y, 3) 6280*da0073e9SAndroid Build Coastguard Worker opt_fn(x2, y2, 5) 6281*da0073e9SAndroid Build Coastguard Worker 6282*da0073e9SAndroid Build Coastguard Worker if ( 6283*da0073e9SAndroid Build Coastguard Worker not torch._dynamo.config.specialize_int 6284*da0073e9SAndroid Build Coastguard Worker and not torch._dynamo.config.assume_static_by_default 6285*da0073e9SAndroid Build Coastguard Worker ): 6286*da0073e9SAndroid Build Coastguard Worker # we didn't actually test guard_failure_fn here but whatever, 6287*da0073e9SAndroid Build Coastguard Worker # nice to see no guard failure on the test 6288*da0073e9SAndroid Build Coastguard Worker self.assertTrue(guard_failure is None) 6289*da0073e9SAndroid Build Coastguard Worker else: 6290*da0073e9SAndroid Build Coastguard Worker self.assertTrue(guard_failure is not None) 6291*da0073e9SAndroid Build Coastguard Worker 6292*da0073e9SAndroid Build Coastguard Worker def test_guard_failure_fn_shape_control(self): 6293*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 6294*da0073e9SAndroid Build Coastguard Worker if x.shape[0] < 3: 6295*da0073e9SAndroid Build Coastguard Worker if y.shape[0] < 3: 6296*da0073e9SAndroid Build Coastguard Worker return x * y 6297*da0073e9SAndroid Build Coastguard Worker else: 6298*da0073e9SAndroid Build Coastguard Worker return x + y 6299*da0073e9SAndroid Build Coastguard Worker else: 6300*da0073e9SAndroid Build Coastguard Worker return -1 6301*da0073e9SAndroid Build Coastguard Worker 6302*da0073e9SAndroid Build Coastguard Worker x = torch.randn([2, 2]) 6303*da0073e9SAndroid Build Coastguard Worker y = 
torch.randn([2, 2]) 6304*da0073e9SAndroid Build Coastguard Worker 6305*da0073e9SAndroid Build Coastguard Worker guard_failure = None 6306*da0073e9SAndroid Build Coastguard Worker 6307*da0073e9SAndroid Build Coastguard Worker def guard_failures(failure): 6308*da0073e9SAndroid Build Coastguard Worker nonlocal guard_failure 6309*da0073e9SAndroid Build Coastguard Worker guard_failure = failure 6310*da0073e9SAndroid Build Coastguard Worker 6311*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize( 6312*da0073e9SAndroid Build Coastguard Worker "eager", nopython=True, guard_fail_fn=guard_failures 6313*da0073e9SAndroid Build Coastguard Worker )(fn) 6314*da0073e9SAndroid Build Coastguard Worker 6315*da0073e9SAndroid Build Coastguard Worker x2 = torch.randn([5, 5]) 6316*da0073e9SAndroid Build Coastguard Worker y2 = torch.randn([5, 5]) 6317*da0073e9SAndroid Build Coastguard Worker 6318*da0073e9SAndroid Build Coastguard Worker opt_fn(x, y) 6319*da0073e9SAndroid Build Coastguard Worker opt_fn(x2, y2) 6320*da0073e9SAndroid Build Coastguard Worker 6321*da0073e9SAndroid Build Coastguard Worker self.assertTrue(guard_failure is not None) 6322*da0073e9SAndroid Build Coastguard Worker first_guard_failure = guard_failure[0].partition("\n")[0] 6323*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 6324*da0073e9SAndroid Build Coastguard Worker self.assertIn( 6325*da0073e9SAndroid Build Coastguard Worker """tensor 'L['x']' size mismatch at index 0. 
expected 2, actual 5""", 6326*da0073e9SAndroid Build Coastguard Worker first_guard_failure, 6327*da0073e9SAndroid Build Coastguard Worker ) 6328*da0073e9SAndroid Build Coastguard Worker else: 6329*da0073e9SAndroid Build Coastguard Worker self.assertIn("""2 <= L['x'].size()[0] <= 2""", first_guard_failure) 6330*da0073e9SAndroid Build Coastguard Worker 6331*da0073e9SAndroid Build Coastguard Worker def test_guard_failure_fn2(self): 6332*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 6333*da0073e9SAndroid Build Coastguard Worker x = x + 1 6334*da0073e9SAndroid Build Coastguard Worker y = y + 1 6335*da0073e9SAndroid Build Coastguard Worker return x * y 6336*da0073e9SAndroid Build Coastguard Worker 6337*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.5, 0.5]) 6338*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([1.0, 1.0]) 6339*da0073e9SAndroid Build Coastguard Worker 6340*da0073e9SAndroid Build Coastguard Worker guard_failure = None 6341*da0073e9SAndroid Build Coastguard Worker 6342*da0073e9SAndroid Build Coastguard Worker def guard_failures(failure): 6343*da0073e9SAndroid Build Coastguard Worker nonlocal guard_failure 6344*da0073e9SAndroid Build Coastguard Worker guard_failure = failure 6345*da0073e9SAndroid Build Coastguard Worker 6346*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize( 6347*da0073e9SAndroid Build Coastguard Worker "eager", nopython=True, guard_fail_fn=guard_failures 6348*da0073e9SAndroid Build Coastguard Worker )(fn) 6349*da0073e9SAndroid Build Coastguard Worker 6350*da0073e9SAndroid Build Coastguard Worker x2 = torch.tensor([0.5, 0.5, 1.0]) 6351*da0073e9SAndroid Build Coastguard Worker y2 = torch.tensor([0.5, 0.5, 0.5]) 6352*da0073e9SAndroid Build Coastguard Worker 6353*da0073e9SAndroid Build Coastguard Worker opt_fn(x, y) 6354*da0073e9SAndroid Build Coastguard Worker opt_fn(x2, y2) 6355*da0073e9SAndroid Build Coastguard Worker 6356*da0073e9SAndroid Build Coastguard Worker if 
torch._dynamo.config.assume_static_by_default: 6357*da0073e9SAndroid Build Coastguard Worker self.assertIn( 6358*da0073e9SAndroid Build Coastguard Worker """tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""", 6359*da0073e9SAndroid Build Coastguard Worker guard_failure[0], 6360*da0073e9SAndroid Build Coastguard Worker ) 6361*da0073e9SAndroid Build Coastguard Worker else: 6362*da0073e9SAndroid Build Coastguard Worker self.assertTrue(guard_failure is None) 6363*da0073e9SAndroid Build Coastguard Worker 6364*da0073e9SAndroid Build Coastguard Worker def test_guard_failure_fn_tensor_iter(self): 6365*da0073e9SAndroid Build Coastguard Worker def fn(x): 6366*da0073e9SAndroid Build Coastguard Worker for y in x: 6367*da0073e9SAndroid Build Coastguard Worker y.add_(1.0) 6368*da0073e9SAndroid Build Coastguard Worker return y 6369*da0073e9SAndroid Build Coastguard Worker 6370*da0073e9SAndroid Build Coastguard Worker guard_failure = None 6371*da0073e9SAndroid Build Coastguard Worker 6372*da0073e9SAndroid Build Coastguard Worker def guard_failures(failure): 6373*da0073e9SAndroid Build Coastguard Worker nonlocal guard_failure 6374*da0073e9SAndroid Build Coastguard Worker guard_failure = failure 6375*da0073e9SAndroid Build Coastguard Worker 6376*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize( 6377*da0073e9SAndroid Build Coastguard Worker "eager", nopython=True, guard_fail_fn=guard_failures 6378*da0073e9SAndroid Build Coastguard Worker )(fn) 6379*da0073e9SAndroid Build Coastguard Worker 6380*da0073e9SAndroid Build Coastguard Worker args1 = torch.randn(10, 10) 6381*da0073e9SAndroid Build Coastguard Worker out = fn(args1) 6382*da0073e9SAndroid Build Coastguard Worker opt_out = opt_fn(args1) 6383*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(out, opt_out)) 6384*da0073e9SAndroid Build Coastguard Worker 6385*da0073e9SAndroid Build Coastguard Worker args2 = torch.randn(9, 10) 6386*da0073e9SAndroid Build Coastguard Worker out = 
fn(args2) 6387*da0073e9SAndroid Build Coastguard Worker opt_out = opt_fn(args2) 6388*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(out, opt_out)) 6389*da0073e9SAndroid Build Coastguard Worker 6390*da0073e9SAndroid Build Coastguard Worker # guard is expected for both static and dynamic shapes 6391*da0073e9SAndroid Build Coastguard Worker self.assertTrue(guard_failure is not None) 6392*da0073e9SAndroid Build Coastguard Worker self.assertIn( 6393*da0073e9SAndroid Build Coastguard Worker """len(L['x']) == 10""", 6394*da0073e9SAndroid Build Coastguard Worker guard_failure[0], 6395*da0073e9SAndroid Build Coastguard Worker ) 6396*da0073e9SAndroid Build Coastguard Worker 6397*da0073e9SAndroid Build Coastguard Worker def test_restore_graphstate(self): 6398*da0073e9SAndroid Build Coastguard Worker # This function does some guard accumulation, 6399*da0073e9SAndroid Build Coastguard Worker # and then rolls back due to control flow. 6400*da0073e9SAndroid Build Coastguard Worker # The idea is that if one were printing guards as they appear, 6401*da0073e9SAndroid Build Coastguard Worker # they would see this insert a guard that does not show up in the final set of 6402*da0073e9SAndroid Build Coastguard Worker # guards as we rolled back from it. 
6403*da0073e9SAndroid Build Coastguard Worker def nested_fn(s): 6404*da0073e9SAndroid Build Coastguard Worker if x[0] < 10: 6405*da0073e9SAndroid Build Coastguard Worker return s * s 6406*da0073e9SAndroid Build Coastguard Worker return s 6407*da0073e9SAndroid Build Coastguard Worker 6408*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 6409*da0073e9SAndroid Build Coastguard Worker x = x + 1 6410*da0073e9SAndroid Build Coastguard Worker y = nested_fn(y) 6411*da0073e9SAndroid Build Coastguard Worker y = y + 10 6412*da0073e9SAndroid Build Coastguard Worker return x * y 6413*da0073e9SAndroid Build Coastguard Worker 6414*da0073e9SAndroid Build Coastguard Worker all_guards = [] 6415*da0073e9SAndroid Build Coastguard Worker 6416*da0073e9SAndroid Build Coastguard Worker def guard_export_print(guards): 6417*da0073e9SAndroid Build Coastguard Worker nonlocal all_guards 6418*da0073e9SAndroid Build Coastguard Worker all_guards.extend(guards) 6419*da0073e9SAndroid Build Coastguard Worker 6420*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", guard_export_fn=guard_export_print)(fn) 6421*da0073e9SAndroid Build Coastguard Worker 6422*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.5, 0.5]) 6423*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([1.0, 1.0]) 6424*da0073e9SAndroid Build Coastguard Worker opt_fn(x, y) 6425*da0073e9SAndroid Build Coastguard Worker 6426*da0073e9SAndroid Build Coastguard Worker for guard in all_guards: 6427*da0073e9SAndroid Build Coastguard Worker # This guard was created 6428*da0073e9SAndroid Build Coastguard Worker self.assertTrue(guard.name != "nested_fn.__closure__[0].cell_contents") 6429*da0073e9SAndroid Build Coastguard Worker 6430*da0073e9SAndroid Build Coastguard Worker def test_call_parent_non_class_methods_from_child(self): 6431*da0073e9SAndroid Build Coastguard Worker class A: 6432*da0073e9SAndroid Build Coastguard Worker a = 4 6433*da0073e9SAndroid Build Coastguard Worker 
6434*da0073e9SAndroid Build Coastguard Worker def add(self, x): 6435*da0073e9SAndroid Build Coastguard Worker return x + 10 6436*da0073e9SAndroid Build Coastguard Worker 6437*da0073e9SAndroid Build Coastguard Worker def mul(self, x): 6438*da0073e9SAndroid Build Coastguard Worker return x * 0.1 6439*da0073e9SAndroid Build Coastguard Worker 6440*da0073e9SAndroid Build Coastguard Worker class B(A): 6441*da0073e9SAndroid Build Coastguard Worker coeff = 4 6442*da0073e9SAndroid Build Coastguard Worker 6443*da0073e9SAndroid Build Coastguard Worker def add(self, x): 6444*da0073e9SAndroid Build Coastguard Worker return x + 20 6445*da0073e9SAndroid Build Coastguard Worker 6446*da0073e9SAndroid Build Coastguard Worker @classmethod 6447*da0073e9SAndroid Build Coastguard Worker def cube(cls, x): 6448*da0073e9SAndroid Build Coastguard Worker return cls.coeff * x * x * x 6449*da0073e9SAndroid Build Coastguard Worker 6450*da0073e9SAndroid Build Coastguard Worker def mul(self, x): 6451*da0073e9SAndroid Build Coastguard Worker return super().mul(x) * x * 0.2 6452*da0073e9SAndroid Build Coastguard Worker 6453*da0073e9SAndroid Build Coastguard Worker class C(B): 6454*da0073e9SAndroid Build Coastguard Worker def add(self, x): 6455*da0073e9SAndroid Build Coastguard Worker b = super().cube(x) 6456*da0073e9SAndroid Build Coastguard Worker c = A.add(self, x) 6457*da0073e9SAndroid Build Coastguard Worker d = B.mul(self, x) 6458*da0073e9SAndroid Build Coastguard Worker e = super(B, self).add(x) 6459*da0073e9SAndroid Build Coastguard Worker f = super().a * x 6460*da0073e9SAndroid Build Coastguard Worker return b + c + d + e + f 6461*da0073e9SAndroid Build Coastguard Worker 6462*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4) 6463*da0073e9SAndroid Build Coastguard Worker fn = C().add 6464*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 6465*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 6466*da0073e9SAndroid Build Coastguard Worker opt_fn 
= torch._dynamo.optimize(cnt, nopython=True)(fn) 6467*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 6468*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6469*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 6470*da0073e9SAndroid Build Coastguard Worker 6471*da0073e9SAndroid Build Coastguard Worker # Check recompilation 6472*da0073e9SAndroid Build Coastguard Worker A.a = 5 6473*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 6474*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 6475*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6476*da0073e9SAndroid Build Coastguard Worker # Ensure that super guard checks are working as expected 6477*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 6478*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 2) 6479*da0073e9SAndroid Build Coastguard Worker 6480*da0073e9SAndroid Build Coastguard Worker def test_builder_for_class_with_metaclass(self): 6481*da0073e9SAndroid Build Coastguard Worker class ExampleMeta(type): 6482*da0073e9SAndroid Build Coastguard Worker pass 6483*da0073e9SAndroid Build Coastguard Worker 6484*da0073e9SAndroid Build Coastguard Worker class MyClass(metaclass=ExampleMeta): 6485*da0073e9SAndroid Build Coastguard Worker pass 6486*da0073e9SAndroid Build Coastguard Worker 6487*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 6488*da0073e9SAndroid Build Coastguard Worker if isinstance(y, MyClass): 6489*da0073e9SAndroid Build Coastguard Worker return x + 1 6490*da0073e9SAndroid Build Coastguard Worker else: 6491*da0073e9SAndroid Build Coastguard Worker return x - 1 6492*da0073e9SAndroid Build Coastguard Worker 6493*da0073e9SAndroid Build Coastguard Worker x = torch.rand([4, 4]) 6494*da0073e9SAndroid Build Coastguard Worker y = MyClass() 6495*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 6496*da0073e9SAndroid Build Coastguard Worker opt_fn = 
torch._dynamo.optimize("eager")(fn) 6497*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 6498*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6499*da0073e9SAndroid Build Coastguard Worker 6500*da0073e9SAndroid Build Coastguard Worker def test_tuple_from_tuple_iter(self): 6501*da0073e9SAndroid Build Coastguard Worker def inner_fn(*args): 6502*da0073e9SAndroid Build Coastguard Worker acc = torch.ones(10, 10) 6503*da0073e9SAndroid Build Coastguard Worker for arg in args: 6504*da0073e9SAndroid Build Coastguard Worker acc.add_(arg) 6505*da0073e9SAndroid Build Coastguard Worker 6506*da0073e9SAndroid Build Coastguard Worker return acc 6507*da0073e9SAndroid Build Coastguard Worker 6508*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("eager") 6509*da0073e9SAndroid Build Coastguard Worker def fn(inputs, params): 6510*da0073e9SAndroid Build Coastguard Worker y = tuple(inputs) + tuple(params) 6511*da0073e9SAndroid Build Coastguard Worker return inner_fn(*y) 6512*da0073e9SAndroid Build Coastguard Worker 6513*da0073e9SAndroid Build Coastguard Worker inputs = [torch.randn(10, 10) for _ in range(3)] 6514*da0073e9SAndroid Build Coastguard Worker 6515*da0073e9SAndroid Build Coastguard Worker fn(inputs, iter(tuple(inputs))) 6516*da0073e9SAndroid Build Coastguard Worker 6517*da0073e9SAndroid Build Coastguard Worker def fn(params): 6518*da0073e9SAndroid Build Coastguard Worker y = tuple(params) 6519*da0073e9SAndroid Build Coastguard Worker return inner_fn(*y) 6520*da0073e9SAndroid Build Coastguard Worker 6521*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 6522*da0073e9SAndroid Build Coastguard Worker inputs = [torch.randn(10, 10) for _ in range(3)] 6523*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(fn(iter(tuple(inputs))), opt_fn(iter(tuple(inputs))))) 6524*da0073e9SAndroid Build Coastguard Worker 6525*da0073e9SAndroid Build Coastguard Worker # Force recompilation 
6526*da0073e9SAndroid Build Coastguard Worker inputs = [torch.randn(10, 10) for _ in range(4)] 6527*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(fn(iter(tuple(inputs))), opt_fn(iter(tuple(inputs))))) 6528*da0073e9SAndroid Build Coastguard Worker 6529*da0073e9SAndroid Build Coastguard Worker def test_torch_package_working_with_trace(self): 6530*da0073e9SAndroid Build Coastguard Worker # from torch._dynamo.test_case import run_tests 6531*da0073e9SAndroid Build Coastguard Worker 6532*da0073e9SAndroid Build Coastguard Worker inputs = [torch.randn([2, 2]), torch.randn([2, 2])] 6533*da0073e9SAndroid Build Coastguard Worker 6534*da0073e9SAndroid Build Coastguard Worker optimized_model = torch._dynamo.optimize(backend="eager")( 6535*da0073e9SAndroid Build Coastguard Worker MyPickledModule(torch.randn([2, 2])) 6536*da0073e9SAndroid Build Coastguard Worker ) 6537*da0073e9SAndroid Build Coastguard Worker from torch import package 6538*da0073e9SAndroid Build Coastguard Worker 6539*da0073e9SAndroid Build Coastguard Worker path = "/tmp/MyPickledModule.pt" 6540*da0073e9SAndroid Build Coastguard Worker package_name = "MyPickledModule" 6541*da0073e9SAndroid Build Coastguard Worker resource_name = "MyPickledModule.pkl" 6542*da0073e9SAndroid Build Coastguard Worker 6543*da0073e9SAndroid Build Coastguard Worker model = MyPickledModule(torch.randn([2, 2])) 6544*da0073e9SAndroid Build Coastguard Worker 6545*da0073e9SAndroid Build Coastguard Worker with package.PackageExporter(path) as exp: 6546*da0073e9SAndroid Build Coastguard Worker exp.extern("**") 6547*da0073e9SAndroid Build Coastguard Worker exp.save_pickle(package_name, resource_name, model) 6548*da0073e9SAndroid Build Coastguard Worker 6549*da0073e9SAndroid Build Coastguard Worker imp = package.PackageImporter(path) 6550*da0073e9SAndroid Build Coastguard Worker loaded_model = imp.load_pickle(package_name, resource_name) 6551*da0073e9SAndroid Build Coastguard Worker 6552*da0073e9SAndroid Build Coastguard Worker 
optimized_loaded_model = torch._dynamo.optimize("eager")(loaded_model)(*inputs) 6553*da0073e9SAndroid Build Coastguard Worker 6554*da0073e9SAndroid Build Coastguard Worker def test_shape_and_tuple_equality(self): 6555*da0073e9SAndroid Build Coastguard Worker def fn(x, y, t): 6556*da0073e9SAndroid Build Coastguard Worker z = x * y 6557*da0073e9SAndroid Build Coastguard Worker if x.size() == t: 6558*da0073e9SAndroid Build Coastguard Worker return z.cos() 6559*da0073e9SAndroid Build Coastguard Worker return z.sin() 6560*da0073e9SAndroid Build Coastguard Worker 6561*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("eager", nopython=True)(fn)( 6562*da0073e9SAndroid Build Coastguard Worker torch.randn([4, 4]), torch.randn([4, 4]), (4, 4) 6563*da0073e9SAndroid Build Coastguard Worker ) 6564*da0073e9SAndroid Build Coastguard Worker 6565*da0073e9SAndroid Build Coastguard Worker def test_int_list(self): 6566*da0073e9SAndroid Build Coastguard Worker # if assume_static_by_default == True: spec int list 6567*da0073e9SAndroid Build Coastguard Worker # otherwise: unspec int list 6568*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 6569*da0073e9SAndroid Build Coastguard Worker return torch.sin(x + y[1] % 2) 6570*da0073e9SAndroid Build Coastguard Worker 6571*da0073e9SAndroid Build Coastguard Worker x = torch.randn(6) 6572*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 6573*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnt)(fn) 6574*da0073e9SAndroid Build Coastguard Worker for i in range(10, 25, 3): 6575*da0073e9SAndroid Build Coastguard Worker y = [i, i + 1, i + 2] 6576*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 6577*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 6578*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6579*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 6580*da0073e9SAndroid Build 
Coastguard Worker if torch._dynamo.config.automatic_dynamic_shapes: 6581*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnt.frame_count, """2""") 6582*da0073e9SAndroid Build Coastguard Worker else: 6583*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnt.frame_count, """5""") 6584*da0073e9SAndroid Build Coastguard Worker else: 6585*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnt.frame_count, """1""") 6586*da0073e9SAndroid Build Coastguard Worker 6587*da0073e9SAndroid Build Coastguard Worker def test_patched_builtin_functions(self): 6588*da0073e9SAndroid Build Coastguard Worker import builtins 6589*da0073e9SAndroid Build Coastguard Worker 6590*da0073e9SAndroid Build Coastguard Worker # Cache the original builtin function ids 6591*da0073e9SAndroid Build Coastguard Worker torch._dynamo.trace_rules._builtin_function_ids() 6592*da0073e9SAndroid Build Coastguard Worker 6593*da0073e9SAndroid Build Coastguard Worker class MyClass: 6594*da0073e9SAndroid Build Coastguard Worker pass 6595*da0073e9SAndroid Build Coastguard Worker 6596*da0073e9SAndroid Build Coastguard Worker builtin_isinstance = builtins.isinstance 6597*da0073e9SAndroid Build Coastguard Worker 6598*da0073e9SAndroid Build Coastguard Worker def patched_isinstance(obj, classinfo) -> bool: 6599*da0073e9SAndroid Build Coastguard Worker if builtin_isinstance(obj, MyClass): 6600*da0073e9SAndroid Build Coastguard Worker return False 6601*da0073e9SAndroid Build Coastguard Worker else: 6602*da0073e9SAndroid Build Coastguard Worker return builtin_isinstance(obj, classinfo) 6603*da0073e9SAndroid Build Coastguard Worker 6604*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 6605*da0073e9SAndroid Build Coastguard Worker if isinstance(y, MyClass): 6606*da0073e9SAndroid Build Coastguard Worker return x + 1 6607*da0073e9SAndroid Build Coastguard Worker else: 6608*da0073e9SAndroid Build Coastguard Worker return x - 1 6609*da0073e9SAndroid Build Coastguard Worker 
6610*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 3) 6611*da0073e9SAndroid Build Coastguard Worker y = MyClass() 6612*da0073e9SAndroid Build Coastguard Worker 6613*da0073e9SAndroid Build Coastguard Worker try: 6614*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 6615*da0073e9SAndroid Build Coastguard Worker # Monkey patch builtin function 6616*da0073e9SAndroid Build Coastguard Worker builtins.isinstance = patched_isinstance 6617*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) 6618*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 6619*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, x + 1)) 6620*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res, x - 1)) 6621*da0073e9SAndroid Build Coastguard Worker finally: 6622*da0073e9SAndroid Build Coastguard Worker builtins.isinstance = builtin_isinstance 6623*da0073e9SAndroid Build Coastguard Worker 6624*da0073e9SAndroid Build Coastguard Worker # check recompilation because builtins is now unpatched 6625*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) 6626*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 6627*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res, x + 1)) 6628*da0073e9SAndroid Build Coastguard Worker 6629*da0073e9SAndroid Build Coastguard Worker # specifically test for tensor.attribute -> torch.something() 6630*da0073e9SAndroid Build Coastguard Worker def test_real_imag_tensor_attribute(self): 6631*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 6632*da0073e9SAndroid Build Coastguard Worker a = x.real 6633*da0073e9SAndroid Build Coastguard Worker b = x.imag 6634*da0073e9SAndroid Build Coastguard Worker return torch.mul(torch.add(a, y), b) 6635*da0073e9SAndroid Build Coastguard Worker 6636*da0073e9SAndroid Build Coastguard Worker x_real = torch.rand((4, 4)) 6637*da0073e9SAndroid Build Coastguard Worker x_imag = 
torch.rand((4, 4)) 6638*da0073e9SAndroid Build Coastguard Worker x = torch.complex(x_real, x_imag) 6639*da0073e9SAndroid Build Coastguard Worker y = torch.rand((4, 4)) 6640*da0073e9SAndroid Build Coastguard Worker 6641*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 6642*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 6643*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 6644*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6645*da0073e9SAndroid Build Coastguard Worker 6646*da0073e9SAndroid Build Coastguard Worker def test_cast(self): 6647*da0073e9SAndroid Build Coastguard Worker from typing import cast 6648*da0073e9SAndroid Build Coastguard Worker 6649*da0073e9SAndroid Build Coastguard Worker def fn(x): 6650*da0073e9SAndroid Build Coastguard Worker return cast(torch.Tensor, torch.add(x, 1.0)) 6651*da0073e9SAndroid Build Coastguard Worker 6652*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) 6653*da0073e9SAndroid Build Coastguard Worker 6654*da0073e9SAndroid Build Coastguard Worker ref = fn(torch.ones(2, 2)) 6655*da0073e9SAndroid Build Coastguard Worker res = opt_fn(torch.ones(2, 2)) 6656*da0073e9SAndroid Build Coastguard Worker 6657*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6658*da0073e9SAndroid Build Coastguard Worker 6659*da0073e9SAndroid Build Coastguard Worker def test_T_tensor_attribute(self): 6660*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 6661*da0073e9SAndroid Build Coastguard Worker a = x.T 6662*da0073e9SAndroid Build Coastguard Worker return torch.add(a, y) 6663*da0073e9SAndroid Build Coastguard Worker 6664*da0073e9SAndroid Build Coastguard Worker x = torch.rand((4, 4)) 6665*da0073e9SAndroid Build Coastguard Worker y = torch.rand((4, 4)) 6666*da0073e9SAndroid Build Coastguard Worker 6667*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 6668*da0073e9SAndroid Build Coastguard 
Worker opt_fn = torch._dynamo.optimize("eager")(fn) 6669*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 6670*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6671*da0073e9SAndroid Build Coastguard Worker 6672*da0073e9SAndroid Build Coastguard Worker def test_recursive_tensor_attribute(self): 6673*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 6674*da0073e9SAndroid Build Coastguard Worker a = x.real.T 6675*da0073e9SAndroid Build Coastguard Worker b = x.imag 6676*da0073e9SAndroid Build Coastguard Worker return torch.mul(torch.add(a, y), b) 6677*da0073e9SAndroid Build Coastguard Worker 6678*da0073e9SAndroid Build Coastguard Worker x_real = torch.rand((4, 4)) 6679*da0073e9SAndroid Build Coastguard Worker x_imag = torch.rand((4, 4)) 6680*da0073e9SAndroid Build Coastguard Worker x = torch.complex(x_real, x_imag) 6681*da0073e9SAndroid Build Coastguard Worker y = torch.rand((4, 4)) 6682*da0073e9SAndroid Build Coastguard Worker 6683*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 6684*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 6685*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 6686*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6687*da0073e9SAndroid Build Coastguard Worker 6688*da0073e9SAndroid Build Coastguard Worker def test_assigning_function_to_object_attribute(self): 6689*da0073e9SAndroid Build Coastguard Worker # user-defined functions which are object's attributes are not converted to bound methods 6690*da0073e9SAndroid Build Coastguard Worker def my_add(*args): 6691*da0073e9SAndroid Build Coastguard Worker a, b = args 6692*da0073e9SAndroid Build Coastguard Worker return a + b 6693*da0073e9SAndroid Build Coastguard Worker 6694*da0073e9SAndroid Build Coastguard Worker class MyClass: 6695*da0073e9SAndroid Build Coastguard Worker def __init__(self, func): 6696*da0073e9SAndroid Build Coastguard Worker self.add = func 
6697*da0073e9SAndroid Build Coastguard Worker 6698*da0073e9SAndroid Build Coastguard Worker obj = MyClass(my_add) 6699*da0073e9SAndroid Build Coastguard Worker 6700*da0073e9SAndroid Build Coastguard Worker def fn(x): 6701*da0073e9SAndroid Build Coastguard Worker return obj.add(x, 2) 6702*da0073e9SAndroid Build Coastguard Worker 6703*da0073e9SAndroid Build Coastguard Worker x = torch.rand(2, 3) 6704*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 6705*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(backend="eager")(fn) 6706*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 6707*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6708*da0073e9SAndroid Build Coastguard Worker 6709*da0073e9SAndroid Build Coastguard Worker def test_assigning_function_to_class_attribute(self): 6710*da0073e9SAndroid Build Coastguard Worker # user-defined functions which are class's attributes are converted to bound methods 6711*da0073e9SAndroid Build Coastguard Worker def my_add(*args): 6712*da0073e9SAndroid Build Coastguard Worker obj, a, b = args 6713*da0073e9SAndroid Build Coastguard Worker return obj.x + a + b 6714*da0073e9SAndroid Build Coastguard Worker 6715*da0073e9SAndroid Build Coastguard Worker class MyClass: 6716*da0073e9SAndroid Build Coastguard Worker add = my_add 6717*da0073e9SAndroid Build Coastguard Worker 6718*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 6719*da0073e9SAndroid Build Coastguard Worker self.x = x 6720*da0073e9SAndroid Build Coastguard Worker 6721*da0073e9SAndroid Build Coastguard Worker obj = MyClass(0.5) 6722*da0073e9SAndroid Build Coastguard Worker 6723*da0073e9SAndroid Build Coastguard Worker def fn(x): 6724*da0073e9SAndroid Build Coastguard Worker return obj.add(x, 2) 6725*da0073e9SAndroid Build Coastguard Worker 6726*da0073e9SAndroid Build Coastguard Worker x = torch.rand(2, 3) 6727*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 6728*da0073e9SAndroid Build Coastguard Worker opt_fn 
= torch.compile(backend="eager")(fn) 6729*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 6730*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6731*da0073e9SAndroid Build Coastguard Worker 6732*da0073e9SAndroid Build Coastguard Worker def test_tagging_tensors_simple(self): 6733*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 6734*da0073e9SAndroid Build Coastguard Worker return x * y, x, y 6735*da0073e9SAndroid Build Coastguard Worker 6736*da0073e9SAndroid Build Coastguard Worker a = torch.randn([3, 3]) 6737*da0073e9SAndroid Build Coastguard Worker a.tag = "a" 6738*da0073e9SAndroid Build Coastguard Worker a.frog = "ribbity ribbit" 6739*da0073e9SAndroid Build Coastguard Worker b = torch.randn([3, 3]) 6740*da0073e9SAndroid Build Coastguard Worker b.tag = "b" 6741*da0073e9SAndroid Build Coastguard Worker b.frog = "ribbit" 6742*da0073e9SAndroid Build Coastguard Worker 6743*da0073e9SAndroid Build Coastguard Worker exported = torch._dynamo.export(foo)(a, b) 6744*da0073e9SAndroid Build Coastguard Worker out_graph = exported[0] 6745*da0073e9SAndroid Build Coastguard Worker 6746*da0073e9SAndroid Build Coastguard Worker nodes = list(out_graph.graph.nodes) 6747*da0073e9SAndroid Build Coastguard Worker placeholders = [node for node in nodes if node.op == "placeholder"] 6748*da0073e9SAndroid Build Coastguard Worker all_tags = [] 6749*da0073e9SAndroid Build Coastguard Worker all_frogs = [] 6750*da0073e9SAndroid Build Coastguard Worker for placeholder in placeholders: 6751*da0073e9SAndroid Build Coastguard Worker if "tensor_dict" in placeholder.meta: 6752*da0073e9SAndroid Build Coastguard Worker all_tags.append(placeholder.meta["tensor_dict"]["tag"]) 6753*da0073e9SAndroid Build Coastguard Worker all_frogs.append(placeholder.meta["tensor_dict"]["frog"]) 6754*da0073e9SAndroid Build Coastguard Worker 6755*da0073e9SAndroid Build Coastguard Worker self.assertEqual(all_tags, ["a", "b"]) 6756*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(all_frogs, ["ribbity ribbit", "ribbit"]) 6757*da0073e9SAndroid Build Coastguard Worker 6758*da0073e9SAndroid Build Coastguard Worker def test_tagging_tensors_mix_used_unused_structure(self): 6759*da0073e9SAndroid Build Coastguard Worker def pre_attention_state_ops(input, mems, state): 6760*da0073e9SAndroid Build Coastguard Worker lc_key = state[0] 6761*da0073e9SAndroid Build Coastguard Worker lc_val = state[1] 6762*da0073e9SAndroid Build Coastguard Worker bar = [] 6763*da0073e9SAndroid Build Coastguard Worker for i in range(0, 4): 6764*da0073e9SAndroid Build Coastguard Worker bar2 = [] 6765*da0073e9SAndroid Build Coastguard Worker for j in range(0, 3): 6766*da0073e9SAndroid Build Coastguard Worker bar2.append( 6767*da0073e9SAndroid Build Coastguard Worker lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1]) 6768*da0073e9SAndroid Build Coastguard Worker ) 6769*da0073e9SAndroid Build Coastguard Worker bar.append(bar2) 6770*da0073e9SAndroid Build Coastguard Worker 6771*da0073e9SAndroid Build Coastguard Worker return bar 6772*da0073e9SAndroid Build Coastguard Worker 6773*da0073e9SAndroid Build Coastguard Worker mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]]) 6774*da0073e9SAndroid Build Coastguard Worker state = [ 6775*da0073e9SAndroid Build Coastguard Worker torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), 6776*da0073e9SAndroid Build Coastguard Worker torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), 6777*da0073e9SAndroid Build Coastguard Worker ] 6778*da0073e9SAndroid Build Coastguard Worker i = torch.tensor( 6779*da0073e9SAndroid Build Coastguard Worker [ 6780*da0073e9SAndroid Build Coastguard Worker [0.0313, -0.1487, -0.3846, -0.5321], 6781*da0073e9SAndroid Build Coastguard Worker [-1.7073, 1.3331, -0.0890, -1.4935], 6782*da0073e9SAndroid Build Coastguard Worker [-0.8314, -0.1862, -0.5935, 1.5232], 6783*da0073e9SAndroid Build Coastguard Worker ] 6784*da0073e9SAndroid Build Coastguard Worker ) 
6785*da0073e9SAndroid Build Coastguard Worker 6786*da0073e9SAndroid Build Coastguard Worker mems.tag = "MEMS" 6787*da0073e9SAndroid Build Coastguard Worker i.tag = "FOO" 6788*da0073e9SAndroid Build Coastguard Worker state[0].tag = "STATE_0" 6789*da0073e9SAndroid Build Coastguard Worker state[1].tag = "HMMM" 6790*da0073e9SAndroid Build Coastguard Worker 6791*da0073e9SAndroid Build Coastguard Worker exported = torch._dynamo.export(pre_attention_state_ops)(i, mems, state) 6792*da0073e9SAndroid Build Coastguard Worker out_graph = exported[0] 6793*da0073e9SAndroid Build Coastguard Worker 6794*da0073e9SAndroid Build Coastguard Worker nodes = list(out_graph.graph.nodes) 6795*da0073e9SAndroid Build Coastguard Worker placeholders = [node for node in nodes if node.op == "placeholder"] 6796*da0073e9SAndroid Build Coastguard Worker all_tags = [] 6797*da0073e9SAndroid Build Coastguard Worker for placeholder in placeholders: 6798*da0073e9SAndroid Build Coastguard Worker if "tensor_dict" in placeholder.meta: 6799*da0073e9SAndroid Build Coastguard Worker all_tags.append(placeholder.meta["tensor_dict"]["tag"]) 6800*da0073e9SAndroid Build Coastguard Worker 6801*da0073e9SAndroid Build Coastguard Worker self.assertEqual(all_tags, ["STATE_0", "HMMM"]) 6802*da0073e9SAndroid Build Coastguard Worker 6803*da0073e9SAndroid Build Coastguard Worker def test_get_custom_tensor_attribute(self): 6804*da0073e9SAndroid Build Coastguard Worker def fn(x): 6805*da0073e9SAndroid Build Coastguard Worker return x.custom_attr * x 6806*da0073e9SAndroid Build Coastguard Worker 6807*da0073e9SAndroid Build Coastguard Worker x = torch.rand((2, 2)) 6808*da0073e9SAndroid Build Coastguard Worker x.custom_attr = 3.14 6809*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 6810*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 6811*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 6812*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 
6813*da0073e9SAndroid Build Coastguard Worker 6814*da0073e9SAndroid Build Coastguard Worker def test_set_custom_tensor_attribute(self): 6815*da0073e9SAndroid Build Coastguard Worker def fn(x): 6816*da0073e9SAndroid Build Coastguard Worker x.custom_attr = 3.14 6817*da0073e9SAndroid Build Coastguard Worker return x.custom_attr * x 6818*da0073e9SAndroid Build Coastguard Worker 6819*da0073e9SAndroid Build Coastguard Worker x = torch.rand((2, 2)) 6820*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 6821*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 6822*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 6823*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 6824*da0073e9SAndroid Build Coastguard Worker 6825*da0073e9SAndroid Build Coastguard Worker def test_unhandled_exception_in_dynamo(self): 6826*da0073e9SAndroid Build Coastguard Worker # traceback.format_exc() approximates an unhandled exception 6827*da0073e9SAndroid Build Coastguard Worker def f(a): 6828*da0073e9SAndroid Build Coastguard Worker a += 1 6829*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("smoge") 6830*da0073e9SAndroid Build Coastguard Worker return a 6831*da0073e9SAndroid Build Coastguard Worker 6832*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(f) 6833*da0073e9SAndroid Build Coastguard Worker try: 6834*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.ones(2)) 6835*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 6836*da0073e9SAndroid Build Coastguard Worker self.assertIn("smoge", traceback.format_exc()) 6837*da0073e9SAndroid Build Coastguard Worker 6838*da0073e9SAndroid Build Coastguard Worker def test_unhandled_exception_in_dynamo2(self): 6839*da0073e9SAndroid Build Coastguard Worker # segfaults in python 3.11 if shadow frame is freed improperly 6840*da0073e9SAndroid Build Coastguard Worker from torch.testing import make_tensor 6841*da0073e9SAndroid Build 
Coastguard Worker 6842*da0073e9SAndroid Build Coastguard Worker def fn(): 6843*da0073e9SAndroid Build Coastguard Worker # test that the errors are the same for dense and sparse versions 6844*da0073e9SAndroid Build Coastguard Worker def test1(*, is_sparse): 6845*da0073e9SAndroid Build Coastguard Worker # shapes must be compatible for matrix multiplication 6846*da0073e9SAndroid Build Coastguard Worker a = make_tensor((2, 3), dtype=torch.float32, device="cpu") 6847*da0073e9SAndroid Build Coastguard Worker if is_sparse: 6848*da0073e9SAndroid Build Coastguard Worker a_sparse = a.to_sparse_csr() 6849*da0073e9SAndroid Build Coastguard Worker return torch.addmm(a, a_sparse, a) 6850*da0073e9SAndroid Build Coastguard Worker else: 6851*da0073e9SAndroid Build Coastguard Worker return torch.addmm(a, a, a) 6852*da0073e9SAndroid Build Coastguard Worker 6853*da0073e9SAndroid Build Coastguard Worker try: 6854*da0073e9SAndroid Build Coastguard Worker test1(is_sparse=False) 6855*da0073e9SAndroid Build Coastguard Worker except RuntimeError as msg: 6856*da0073e9SAndroid Build Coastguard Worker try: 6857*da0073e9SAndroid Build Coastguard Worker test1(is_sparse=True) 6858*da0073e9SAndroid Build Coastguard Worker except RuntimeError as msg2: 6859*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("smoge") 6860*da0073e9SAndroid Build Coastguard Worker 6861*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 6862*da0073e9SAndroid Build Coastguard Worker try: 6863*da0073e9SAndroid Build Coastguard Worker opt_fn() 6864*da0073e9SAndroid Build Coastguard Worker except RuntimeError: 6865*da0073e9SAndroid Build Coastguard Worker self.assertIn("smoge", traceback.format_exc()) 6866*da0073e9SAndroid Build Coastguard Worker 6867*da0073e9SAndroid Build Coastguard Worker def test_variable_access_in_exception(self): 6868*da0073e9SAndroid Build Coastguard Worker def fn(): 6869*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1) 6870*da0073e9SAndroid Build 
Coastguard Worker try: 6871*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("bad") 6872*da0073e9SAndroid Build Coastguard Worker except RuntimeError: 6873*da0073e9SAndroid Build Coastguard Worker x += 1 6874*da0073e9SAndroid Build Coastguard Worker return x 6875*da0073e9SAndroid Build Coastguard Worker 6876*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 6877*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(), torch.tensor([2.0])) 6878*da0073e9SAndroid Build Coastguard Worker 6879*da0073e9SAndroid Build Coastguard Worker def test_nested_sequential_with(self): 6880*da0073e9SAndroid Build Coastguard Worker def fn(x): 6881*da0073e9SAndroid Build Coastguard Worker with torch.set_grad_enabled(True): 6882*da0073e9SAndroid Build Coastguard Worker with torch.set_grad_enabled(False): 6883*da0073e9SAndroid Build Coastguard Worker x = x + 1 6884*da0073e9SAndroid Build Coastguard Worker with torch.set_grad_enabled(True): 6885*da0073e9SAndroid Build Coastguard Worker x = x + 1 6886*da0073e9SAndroid Build Coastguard Worker return x 6887*da0073e9SAndroid Build Coastguard Worker 6888*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 6889*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0])) 6890*da0073e9SAndroid Build Coastguard Worker 6891*da0073e9SAndroid Build Coastguard Worker def test_nested_sequential_try(self): 6892*da0073e9SAndroid Build Coastguard Worker def fn(x): 6893*da0073e9SAndroid Build Coastguard Worker try: 6894*da0073e9SAndroid Build Coastguard Worker try: 6895*da0073e9SAndroid Build Coastguard Worker x = x + 1 6896*da0073e9SAndroid Build Coastguard Worker except: 6897*da0073e9SAndroid Build Coastguard Worker pass 6898*da0073e9SAndroid Build Coastguard Worker try: 6899*da0073e9SAndroid Build Coastguard Worker try: 6900*da0073e9SAndroid Build Coastguard Worker x = x + 1 6901*da0073e9SAndroid Build 
Coastguard Worker except: 6902*da0073e9SAndroid Build Coastguard Worker pass 6903*da0073e9SAndroid Build Coastguard Worker except: 6904*da0073e9SAndroid Build Coastguard Worker pass 6905*da0073e9SAndroid Build Coastguard Worker except: 6906*da0073e9SAndroid Build Coastguard Worker pass 6907*da0073e9SAndroid Build Coastguard Worker return x 6908*da0073e9SAndroid Build Coastguard Worker 6909*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 6910*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0])) 6911*da0073e9SAndroid Build Coastguard Worker 6912*da0073e9SAndroid Build Coastguard Worker def test_nested_sequential_try_with(self): 6913*da0073e9SAndroid Build Coastguard Worker def fn(x): 6914*da0073e9SAndroid Build Coastguard Worker with torch.set_grad_enabled(True): 6915*da0073e9SAndroid Build Coastguard Worker try: 6916*da0073e9SAndroid Build Coastguard Worker x = x + 1 6917*da0073e9SAndroid Build Coastguard Worker except: 6918*da0073e9SAndroid Build Coastguard Worker pass 6919*da0073e9SAndroid Build Coastguard Worker try: 6920*da0073e9SAndroid Build Coastguard Worker with torch.set_grad_enabled(False): 6921*da0073e9SAndroid Build Coastguard Worker x = x + 1 6922*da0073e9SAndroid Build Coastguard Worker except: 6923*da0073e9SAndroid Build Coastguard Worker pass 6924*da0073e9SAndroid Build Coastguard Worker return x 6925*da0073e9SAndroid Build Coastguard Worker 6926*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 6927*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0])) 6928*da0073e9SAndroid Build Coastguard Worker 6929*da0073e9SAndroid Build Coastguard Worker def test_nested_sequential_try_with_graph_break(self): 6930*da0073e9SAndroid Build Coastguard Worker def fn(x, n): 6931*da0073e9SAndroid Build Coastguard Worker with torch.set_grad_enabled(True): 6932*da0073e9SAndroid Build Coastguard Worker 
with torch.set_grad_enabled(False): 6933*da0073e9SAndroid Build Coastguard Worker x = x + 1 6934*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 6935*da0073e9SAndroid Build Coastguard Worker try: 6936*da0073e9SAndroid Build Coastguard Worker with torch.set_grad_enabled(False): 6937*da0073e9SAndroid Build Coastguard Worker x = x + 1 6938*da0073e9SAndroid Build Coastguard Worker if n == 0: 6939*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 6940*da0073e9SAndroid Build Coastguard Worker except: 6941*da0073e9SAndroid Build Coastguard Worker pass 6942*da0073e9SAndroid Build Coastguard Worker with torch.set_grad_enabled(False): 6943*da0073e9SAndroid Build Coastguard Worker x = x + 1 6944*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 6945*da0073e9SAndroid Build Coastguard Worker x = x + 1 6946*da0073e9SAndroid Build Coastguard Worker return x 6947*da0073e9SAndroid Build Coastguard Worker 6948*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 6949*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(counter)(fn) 6950*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(torch.ones(1), 0), torch.tensor([5.0])) 6951*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 6952*da0073e9SAndroid Build Coastguard Worker 6953*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 6954*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 6955*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(counter)(fn) 6956*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(torch.ones(1), 1), torch.tensor([5.0])) 6957*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 3) 6958*da0073e9SAndroid Build Coastguard Worker 6959*da0073e9SAndroid Build Coastguard Worker def test_ordered_dict_alias_reconstruct(self): 6960*da0073e9SAndroid Build Coastguard Worker od = 
collections.OrderedDict 6961*da0073e9SAndroid Build Coastguard Worker 6962*da0073e9SAndroid Build Coastguard Worker def fn(): 6963*da0073e9SAndroid Build Coastguard Worker d1 = dict() 6964*da0073e9SAndroid Build Coastguard Worker d1["a"] = 1 6965*da0073e9SAndroid Build Coastguard Worker d2 = od(d1) 6966*da0073e9SAndroid Build Coastguard Worker d2["b"] = 2 6967*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 6968*da0073e9SAndroid Build Coastguard Worker if isinstance(d2, od): 6969*da0073e9SAndroid Build Coastguard Worker return d2["a"] + d2["b"] 6970*da0073e9SAndroid Build Coastguard Worker else: 6971*da0073e9SAndroid Build Coastguard Worker return 0 6972*da0073e9SAndroid Build Coastguard Worker 6973*da0073e9SAndroid Build Coastguard Worker dis.dis(fn) 6974*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch._dynamo.optimize("eager")(fn)(), 3) 6975*da0073e9SAndroid Build Coastguard Worker 6976*da0073e9SAndroid Build Coastguard Worker # NOTE this test can be removed once multiline errors are in Python. 
# NOTE(review): in this copy, runs of whitespace and non-ASCII characters look
# stripped from test_get_instruction_source_311. The "test unicode" case reads
# `a = ("" +` with no non-ASCII text, and several expected marker lines are
# wider than the source spans they underline (e.g. `~~~~~~~~` under `("" +`,
# `~~~~~~~~~~^~~~~` under `( b ) + c`). The expected strings are
# alignment-sensitive, so this test is unlikely to pass as shown. Restore f()
# and the expected output from upstream before relying on it.
6977*da0073e9SAndroid Build Coastguard Worker # See https://github.com/python/cpython/issues/106922 6978*da0073e9SAndroid Build Coastguard Worker @skipIfNotPy311 6979*da0073e9SAndroid Build Coastguard Worker def test_get_instruction_source_311(self): 6980*da0073e9SAndroid Build Coastguard Worker def f(): 6981*da0073e9SAndroid Build Coastguard Worker # flake8: noqa 6982*da0073e9SAndroid Build Coastguard Worker # fmt: off 6983*da0073e9SAndroid Build Coastguard Worker # test binary ops 6984*da0073e9SAndroid Build Coastguard Worker a = ( b ) + c 6985*da0073e9SAndroid Build Coastguard Worker a = (a + b) // (c - d) 6986*da0073e9SAndroid Build Coastguard Worker a = b \ 6987*da0073e9SAndroid Build Coastguard Worker +\ 6988*da0073e9SAndroid Build Coastguard Worker c # test 6989*da0073e9SAndroid Build Coastguard Worker a = ( 6990*da0073e9SAndroid Build Coastguard Worker (b # test + 6991*da0073e9SAndroid Build Coastguard Worker ) \ 6992*da0073e9SAndroid Build Coastguard Worker # + 6993*da0073e9SAndroid Build Coastguard Worker << ( 6994*da0073e9SAndroid Build Coastguard Worker 6995*da0073e9SAndroid Build Coastguard Worker c # test 6996*da0073e9SAndroid Build Coastguard Worker \ 6997*da0073e9SAndroid Build Coastguard Worker ) # test 6998*da0073e9SAndroid Build Coastguard Worker ) 6999*da0073e9SAndroid Build Coastguard Worker 7000*da0073e9SAndroid Build Coastguard Worker # test slice 7001*da0073e9SAndroid Build Coastguard Worker a = bbb [ ccc ] 7002*da0073e9SAndroid Build Coastguard Worker b = bbbbb \ 7003*da0073e9SAndroid Build Coastguard Worker [ ccc # test 7004*da0073e9SAndroid Build Coastguard Worker 7005*da0073e9SAndroid Build Coastguard Worker + ddd \ 7006*da0073e9SAndroid Build Coastguard Worker 7007*da0073e9SAndroid Build Coastguard Worker ] # test 7008*da0073e9SAndroid Build Coastguard Worker a = bbb[ccc][ddd][eee] 7009*da0073e9SAndroid Build Coastguard Worker 7010*da0073e9SAndroid Build Coastguard Worker # test nested and multiline function calls 7011*da0073e9SAndroid
Build Coastguard Worker a = g(g(g(b))) 7012*da0073e9SAndroid Build Coastguard Worker a = g(h( 7013*da0073e9SAndroid Build Coastguard Worker g(b), 7014*da0073e9SAndroid Build Coastguard Worker c 7015*da0073e9SAndroid Build Coastguard Worker )) 7016*da0073e9SAndroid Build Coastguard Worker 7017*da0073e9SAndroid Build Coastguard Worker # test chained function calls 7018*da0073e9SAndroid Build Coastguard Worker a = (g(x).y)( 7019*da0073e9SAndroid Build Coastguard Worker z 7020*da0073e9SAndroid Build Coastguard Worker )(1)(2) 7021*da0073e9SAndroid Build Coastguard Worker 7022*da0073e9SAndroid Build Coastguard Worker # test unicode (match traceback behavior) 7023*da0073e9SAndroid Build Coastguard Worker a = ("" + 7024*da0073e9SAndroid Build Coastguard Worker + "") + b 7025*da0073e9SAndroid Build Coastguard Worker 7026*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.utils import get_instruction_source_311 7027*da0073e9SAndroid Build Coastguard Worker 7028*da0073e9SAndroid Build Coastguard Worker if sys.version_info >= (3, 12): 7029*da0073e9SAndroid Build Coastguard Worker # Offsets changed in 3.12, e.g.
due to removal of PRECALL inst 7030*da0073e9SAndroid Build Coastguard Worker offsets = (3, 11, 15, 19, 23, 29, 35, 44, 53, 65) 7031*da0073e9SAndroid Build Coastguard Worker else: 7032*da0073e9SAndroid Build Coastguard Worker offsets = (3, 11, 15, 19, 23, 29, 35, 46, 58, 74) 7033*da0073e9SAndroid Build Coastguard Worker insts = list(dis.get_instructions(f)) 7034*da0073e9SAndroid Build Coastguard Worker # uncomment to determine offsets 7035*da0073e9SAndroid Build Coastguard Worker # print(*enumerate(insts), sep="\n") 7036*da0073e9SAndroid Build Coastguard Worker all_sources = "\n".join( 7037*da0073e9SAndroid Build Coastguard Worker get_instruction_source_311(f.__code__, insts[offset]) for offset in offsets 7038*da0073e9SAndroid Build Coastguard Worker ) 7039*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 7040*da0073e9SAndroid Build Coastguard Worker all_sources, 7041*da0073e9SAndroid Build Coastguard Worker """\ 7042*da0073e9SAndroid Build Coastguard Worker a = ( b ) + c 7043*da0073e9SAndroid Build Coastguard Worker ~~~~~~~~~~^~~~~ 7044*da0073e9SAndroid Build Coastguard Worker 7045*da0073e9SAndroid Build Coastguard Worker a = (a + b) // (c - d) 7046*da0073e9SAndroid Build Coastguard Worker ~~~~~~~~^^~~~~~~~ 7047*da0073e9SAndroid Build Coastguard Worker 7048*da0073e9SAndroid Build Coastguard Worker a = b \\ 7049*da0073e9SAndroid Build Coastguard Worker ~~~~~~ 7050*da0073e9SAndroid Build Coastguard Worker +\\ 7051*da0073e9SAndroid Build Coastguard Worker ^~ 7052*da0073e9SAndroid Build Coastguard Worker c # test 7053*da0073e9SAndroid Build Coastguard Worker ~ 7054*da0073e9SAndroid Build Coastguard Worker 7055*da0073e9SAndroid Build Coastguard Worker (b # test + 7056*da0073e9SAndroid Build Coastguard Worker ~~~~~~~~~~~~ 7057*da0073e9SAndroid Build Coastguard Worker ) \\ 7058*da0073e9SAndroid Build Coastguard Worker ~~~~ 7059*da0073e9SAndroid Build Coastguard Worker # + 7060*da0073e9SAndroid Build Coastguard Worker ~~~ 7061*da0073e9SAndroid Build
Coastguard Worker << ( 7062*da0073e9SAndroid Build Coastguard Worker ^^~~ 7063*da0073e9SAndroid Build Coastguard Worker 7064*da0073e9SAndroid Build Coastguard Worker 7065*da0073e9SAndroid Build Coastguard Worker c # test 7066*da0073e9SAndroid Build Coastguard Worker ~~~~~~~~~ 7067*da0073e9SAndroid Build Coastguard Worker \\ 7068*da0073e9SAndroid Build Coastguard Worker ~ 7069*da0073e9SAndroid Build Coastguard Worker ) # test 7070*da0073e9SAndroid Build Coastguard Worker ~ 7071*da0073e9SAndroid Build Coastguard Worker 7072*da0073e9SAndroid Build Coastguard Worker a = bbb [ ccc ] 7073*da0073e9SAndroid Build Coastguard Worker ~~~~~~^^^^^^^^^^^ 7074*da0073e9SAndroid Build Coastguard Worker 7075*da0073e9SAndroid Build Coastguard Worker b = bbbbb \\ 7076*da0073e9SAndroid Build Coastguard Worker ~~~~~~~ 7077*da0073e9SAndroid Build Coastguard Worker [ ccc # test 7078*da0073e9SAndroid Build Coastguard Worker ^^^^^^^^^^^^^ 7079*da0073e9SAndroid Build Coastguard Worker 7080*da0073e9SAndroid Build Coastguard Worker 7081*da0073e9SAndroid Build Coastguard Worker + ddd \\ 7082*da0073e9SAndroid Build Coastguard Worker ^^^^^^^^ 7083*da0073e9SAndroid Build Coastguard Worker 7084*da0073e9SAndroid Build Coastguard Worker 7085*da0073e9SAndroid Build Coastguard Worker ] # test 7086*da0073e9SAndroid Build Coastguard Worker ^ 7087*da0073e9SAndroid Build Coastguard Worker 7088*da0073e9SAndroid Build Coastguard Worker a = bbb[ccc][ddd][eee] 7089*da0073e9SAndroid Build Coastguard Worker ~~~~~~~~^^^^^ 7090*da0073e9SAndroid Build Coastguard Worker 7091*da0073e9SAndroid Build Coastguard Worker a = g(g(g(b))) 7092*da0073e9SAndroid Build Coastguard Worker ~^^^^^^ 7093*da0073e9SAndroid Build Coastguard Worker 7094*da0073e9SAndroid Build Coastguard Worker a = g(h( 7095*da0073e9SAndroid Build Coastguard Worker ~^ 7096*da0073e9SAndroid Build Coastguard Worker g(b), 7097*da0073e9SAndroid Build Coastguard Worker ^^^^^ 7098*da0073e9SAndroid Build Coastguard Worker c 7099*da0073e9SAndroid Build
Coastguard Worker ^ 7100*da0073e9SAndroid Build Coastguard Worker )) 7101*da0073e9SAndroid Build Coastguard Worker ^ 7102*da0073e9SAndroid Build Coastguard Worker 7103*da0073e9SAndroid Build Coastguard Worker a = (g(x).y)( 7104*da0073e9SAndroid Build Coastguard Worker ~~~~~~~~~ 7105*da0073e9SAndroid Build Coastguard Worker z 7106*da0073e9SAndroid Build Coastguard Worker ~ 7107*da0073e9SAndroid Build Coastguard Worker )(1)(2) 7108*da0073e9SAndroid Build Coastguard Worker ~^^^ 7109*da0073e9SAndroid Build Coastguard Worker""", 7110*da0073e9SAndroid Build Coastguard Worker ) 7111*da0073e9SAndroid Build Coastguard Worker # test unicode (since assertExpectedInline doesn't support unicode) 7112*da0073e9SAndroid Build Coastguard Worker op_offset = 74 if sys.version_info >= (3, 12) else 84 7113*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 7114*da0073e9SAndroid Build Coastguard Worker get_instruction_source_311(f.__code__, insts[op_offset]), 7115*da0073e9SAndroid Build Coastguard Worker """\ 7116*da0073e9SAndroid Build Coastguard Worker a = ("" + 7117*da0073e9SAndroid Build Coastguard Worker ~~~~~~~~ 7118*da0073e9SAndroid Build Coastguard Worker + "") + b 7119*da0073e9SAndroid Build Coastguard Worker ~~~~~~~~^~~ 7120*da0073e9SAndroid Build Coastguard Worker""", 7121*da0073e9SAndroid Build Coastguard Worker ) 7122*da0073e9SAndroid Build Coastguard Worker 7123*da0073e9SAndroid Build Coastguard Worker def test_raise_guard_full_constraint(self): 7124*da0073e9SAndroid Build Coastguard Worker y = torch.randn([3, 3, 3]) 7125*da0073e9SAndroid Build Coastguard Worker 7126*da0073e9SAndroid Build Coastguard Worker def my_dyn_fn(x): 7127*da0073e9SAndroid Build Coastguard Worker if x.shape[0] == 3: 7128*da0073e9SAndroid Build Coastguard Worker return x.sin() 7129*da0073e9SAndroid Build Coastguard Worker return x.cos() 7130*da0073e9SAndroid Build Coastguard Worker 7131*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(y, 0) 7132*da0073e9SAndroid Build
Coastguard Worker with self.assertRaises(ConstraintViolationError): 7133*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("eager")(my_dyn_fn)(y) 7134*da0073e9SAndroid Build Coastguard Worker 7135*da0073e9SAndroid Build Coastguard Worker # Translation validation changes the exception type, don't run with it 7136*da0073e9SAndroid Build Coastguard Worker @torch.fx.experimental._config.patch(translation_validation=False) 7137*da0073e9SAndroid Build Coastguard Worker def test_mark_dynamic_with_ranges(self): 7138*da0073e9SAndroid Build Coastguard Worker y = torch.randn([8, 3, 3]) 7139*da0073e9SAndroid Build Coastguard Worker 7140*da0073e9SAndroid Build Coastguard Worker def my_dyn_fn(x): 7141*da0073e9SAndroid Build Coastguard Worker if x.shape[0] == 3: 7142*da0073e9SAndroid Build Coastguard Worker return x.sin() 7143*da0073e9SAndroid Build Coastguard Worker return x.cos() 7144*da0073e9SAndroid Build Coastguard Worker 7145*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(y, 0, min=2, max=5) 7146*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ConstraintViolationError): 7147*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("eager")(my_dyn_fn)(y) 7148*da0073e9SAndroid Build Coastguard Worker 7149*da0073e9SAndroid Build Coastguard Worker def test_mark_static(self): 7150*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 7151*da0073e9SAndroid Build Coastguard Worker 7152*da0073e9SAndroid Build Coastguard Worker def my_dyn_fn(x): 7153*da0073e9SAndroid Build Coastguard Worker return x.cos() 7154*da0073e9SAndroid Build Coastguard Worker 7155*da0073e9SAndroid Build Coastguard Worker y = torch.randn([3]) 7156*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_static(y, 0) 7157*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize(counter)(my_dyn_fn)(y) 7158*da0073e9SAndroid Build Coastguard Worker 7159*da0073e9SAndroid Build Coastguard Worker z = torch.randn([4])
7160*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize(counter)(my_dyn_fn)(z) 7161*da0073e9SAndroid Build Coastguard Worker 7162*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 2) 7163*da0073e9SAndroid Build Coastguard Worker 7164*da0073e9SAndroid Build Coastguard Worker def test_no_raise_guard_partial_constraint(self): 7165*da0073e9SAndroid Build Coastguard Worker y = torch.randn([3, 3, 3]) 7166*da0073e9SAndroid Build Coastguard Worker 7167*da0073e9SAndroid Build Coastguard Worker def my_dyn_fn(x): 7168*da0073e9SAndroid Build Coastguard Worker if x.shape[0] > 3: 7169*da0073e9SAndroid Build Coastguard Worker return x.sin() 7170*da0073e9SAndroid Build Coastguard Worker return x.cos() 7171*da0073e9SAndroid Build Coastguard Worker 7172*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("eager")(my_dyn_fn)(y) 7173*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(y, 0) 7174*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 7175*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("eager")(my_dyn_fn)(y) 7176*da0073e9SAndroid Build Coastguard Worker 7177*da0073e9SAndroid Build Coastguard Worker def test_no_raise_guard_partial_constraint_across_break(self): 7178*da0073e9SAndroid Build Coastguard Worker y = torch.randn([3, 3, 3]) 7179*da0073e9SAndroid Build Coastguard Worker 7180*da0073e9SAndroid Build Coastguard Worker def my_dyn_fn(x, y): 7181*da0073e9SAndroid Build Coastguard Worker z = x * y 7182*da0073e9SAndroid Build Coastguard Worker 7183*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 7184*da0073e9SAndroid Build Coastguard Worker if z.shape[0] > 2: 7185*da0073e9SAndroid Build Coastguard Worker return z.cos() 7186*da0073e9SAndroid Build Coastguard Worker 7187*da0073e9SAndroid Build Coastguard Worker return x.cos() 7188*da0073e9SAndroid Build Coastguard Worker 7189*da0073e9SAndroid Build Coastguard Worker 
torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) 7190*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(y, 0) 7191*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 7192*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) 7193*da0073e9SAndroid Build Coastguard Worker 7194*da0073e9SAndroid Build Coastguard Worker # Sadly, this does not throw - we do not prop correctly across the graph break 7195*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure 7196*da0073e9SAndroid Build Coastguard Worker def test_raise_guard_partial_constraint_across_break(self): 7197*da0073e9SAndroid Build Coastguard Worker y = torch.randn([3, 3, 3]) 7198*da0073e9SAndroid Build Coastguard Worker 7199*da0073e9SAndroid Build Coastguard Worker def my_dyn_fn(x, y): 7200*da0073e9SAndroid Build Coastguard Worker z = x * y 7201*da0073e9SAndroid Build Coastguard Worker 7202*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 7203*da0073e9SAndroid Build Coastguard Worker if z.shape[0] == 3: 7204*da0073e9SAndroid Build Coastguard Worker return z.cos() 7205*da0073e9SAndroid Build Coastguard Worker 7206*da0073e9SAndroid Build Coastguard Worker return x.cos() 7207*da0073e9SAndroid Build Coastguard Worker 7208*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) 7209*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(y, 0) 7210*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 7211*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 7212*da0073e9SAndroid Build Coastguard Worker Exception, 7213*da0073e9SAndroid Build Coastguard Worker ): 7214*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) 7215*da0073e9SAndroid Build Coastguard Worker 7216*da0073e9SAndroid Build Coastguard Worker def test_raise_guard_partial_constraint_no_graph_break(self): 7217*da0073e9SAndroid Build Coastguard Worker 
y = torch.randn([3, 3, 3]) 7218*da0073e9SAndroid Build Coastguard Worker 7219*da0073e9SAndroid Build Coastguard Worker def my_dyn_fn(x, y): 7220*da0073e9SAndroid Build Coastguard Worker z = x * y 7221*da0073e9SAndroid Build Coastguard Worker 7222*da0073e9SAndroid Build Coastguard Worker if z.shape[0] == 3: 7223*da0073e9SAndroid Build Coastguard Worker return z.cos() 7224*da0073e9SAndroid Build Coastguard Worker 7225*da0073e9SAndroid Build Coastguard Worker return x.cos() 7226*da0073e9SAndroid Build Coastguard Worker 7227*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(y, 0) 7228*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ConstraintViolationError): 7229*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) 7230*da0073e9SAndroid Build Coastguard Worker 7231*da0073e9SAndroid Build Coastguard Worker def test_cannot_trace_mark_dynamic(self): 7232*da0073e9SAndroid Build Coastguard Worker y = torch.randn([3, 3, 3]) 7233*da0073e9SAndroid Build Coastguard Worker 7234*da0073e9SAndroid Build Coastguard Worker def my_dyn_fn(x): 7235*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(x, 0) 7236*da0073e9SAndroid Build Coastguard Worker return x * x 7237*da0073e9SAndroid Build Coastguard Worker 7238*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 7239*da0073e9SAndroid Build Coastguard Worker AssertionError, "Attempt to trace forbidden callable" 7240*da0073e9SAndroid Build Coastguard Worker ): 7241*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("eager")(my_dyn_fn)(y) 7242*da0073e9SAndroid Build Coastguard Worker 7243*da0073e9SAndroid Build Coastguard Worker def test_cannot_trace_mark_dynamic_safe_unreached(self): 7244*da0073e9SAndroid Build Coastguard Worker y = torch.randn([3, 3, 3]) 7245*da0073e9SAndroid Build Coastguard Worker 7246*da0073e9SAndroid Build Coastguard Worker def my_dyn_fn(x): 7247*da0073e9SAndroid Build Coastguard Worker if 
x.shape[0] == 3: 7248*da0073e9SAndroid Build Coastguard Worker return x 7249*da0073e9SAndroid Build Coastguard Worker print("Running", torch._dynamo.mark_dynamic(x, 0)) 7250*da0073e9SAndroid Build Coastguard Worker return x * x 7251*da0073e9SAndroid Build Coastguard Worker 7252*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("eager")(my_dyn_fn)(y) 7253*da0073e9SAndroid Build Coastguard Worker 7254*da0073e9SAndroid Build Coastguard Worker def test_anomaly_aot_autograd(self): 7255*da0073e9SAndroid Build Coastguard Worker def fail(): 7256*da0073e9SAndroid Build Coastguard Worker raise AssertionError("fail") 7257*da0073e9SAndroid Build Coastguard Worker 7258*da0073e9SAndroid Build Coastguard Worker @allow_in_graph 7259*da0073e9SAndroid Build Coastguard Worker def h(a): 7260*da0073e9SAndroid Build Coastguard Worker r = a.sum() 7261*da0073e9SAndroid Build Coastguard Worker # Trigger an exception in backwards 7262*da0073e9SAndroid Build Coastguard Worker r.register_hook(lambda x: fail()) 7263*da0073e9SAndroid Build Coastguard Worker return r 7264*da0073e9SAndroid Build Coastguard Worker 7265*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="aot_eager") 7266*da0073e9SAndroid Build Coastguard Worker def f(a): 7267*da0073e9SAndroid Build Coastguard Worker return h(a) 7268*da0073e9SAndroid Build Coastguard Worker 7269*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w, self.assertRaises( 7270*da0073e9SAndroid Build Coastguard Worker torch._dynamo.exc.BackendCompilerFailed 7271*da0073e9SAndroid Build Coastguard Worker ): 7272*da0073e9SAndroid Build Coastguard Worker f(torch.randn(2, 2, requires_grad=True)) 7273*da0073e9SAndroid Build Coastguard Worker 7274*da0073e9SAndroid Build Coastguard Worker # Suppress unrelated pkg_resources warnings 7275*da0073e9SAndroid Build Coastguard Worker self.assertIn("forward call that caused the error", str(w[-1].message)) 7276*da0073e9SAndroid Build Coastguard Worker 
7277*da0073e9SAndroid Build Coastguard Worker def test_py_guards_mark_dynamic(self): 7278*da0073e9SAndroid Build Coastguard Worker def my_dyn_fn(a): 7279*da0073e9SAndroid Build Coastguard Worker if a.shape[0] > 2: 7280*da0073e9SAndroid Build Coastguard Worker return a.cos() 7281*da0073e9SAndroid Build Coastguard Worker return a.sin() 7282*da0073e9SAndroid Build Coastguard Worker 7283*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 7284*da0073e9SAndroid Build Coastguard Worker 7285*da0073e9SAndroid Build Coastguard Worker # Run with dynamic 7286*da0073e9SAndroid Build Coastguard Worker x0 = torch.randn([3, 3, 3]) 7287*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(x0, 0) 7288*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize(counter)(my_dyn_fn)(x0) 7289*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 7290*da0073e9SAndroid Build Coastguard Worker 7291*da0073e9SAndroid Build Coastguard Worker # Run without dynamic, no recompile 7292*da0073e9SAndroid Build Coastguard Worker x = torch.randn([3, 3, 3]) 7293*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize(counter)(my_dyn_fn)(x) 7294*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 7295*da0073e9SAndroid Build Coastguard Worker 7296*da0073e9SAndroid Build Coastguard Worker # Mark a new dim, 1, as dynamic 7297*da0073e9SAndroid Build Coastguard Worker x1 = torch.randn([3, 3, 3]) 7298*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(x1, 1) 7299*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize(counter)(my_dyn_fn)(x1) 7300*da0073e9SAndroid Build Coastguard Worker # Recompile triggered because we marked a new dym as dynamic 7301*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 2) 7302*da0073e9SAndroid Build Coastguard Worker 7303*da0073e9SAndroid Build Coastguard Worker # Reset 7304*da0073e9SAndroid Build Coastguard Worker 
torch._dynamo.reset() 7305*da0073e9SAndroid Build Coastguard Worker # Reset counter 7306*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 7307*da0073e9SAndroid Build Coastguard Worker 7308*da0073e9SAndroid Build Coastguard Worker # Run with dynamic 1 7309*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize(counter)(my_dyn_fn)(x1) 7310*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 7311*da0073e9SAndroid Build Coastguard Worker 7312*da0073e9SAndroid Build Coastguard Worker # Run with dynamic 0, not subset 7313*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize(counter)(my_dyn_fn)(x0) 7314*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 2) 7315*da0073e9SAndroid Build Coastguard Worker 7316*da0073e9SAndroid Build Coastguard Worker # Run with dynamic 0, 1, 2, not subset 7317*da0073e9SAndroid Build Coastguard Worker x012 = torch.randn([3, 3, 3]) 7318*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(x012, 0) 7319*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(x012, 1) 7320*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(x012, 2) 7321*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize(counter)(my_dyn_fn)(x012) 7322*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 3) 7323*da0073e9SAndroid Build Coastguard Worker 7324*da0073e9SAndroid Build Coastguard Worker def test_recompile_on_global_state_change(self): 7325*da0073e9SAndroid Build Coastguard Worker last_state = [] 7326*da0073e9SAndroid Build Coastguard Worker cnt = 0 7327*da0073e9SAndroid Build Coastguard Worker 7328*da0073e9SAndroid Build Coastguard Worker def my_compiler(gm, _): 7329*da0073e9SAndroid Build Coastguard Worker nonlocal cnt 7330*da0073e9SAndroid Build Coastguard Worker cnt += 1 7331*da0073e9SAndroid Build Coastguard Worker state = read_state() 7332*da0073e9SAndroid Build Coastguard Worker 
7333*da0073e9SAndroid Build Coastguard Worker def inner(*args): 7334*da0073e9SAndroid Build Coastguard Worker last_state[:] = state 7335*da0073e9SAndroid Build Coastguard Worker return gm(*args) 7336*da0073e9SAndroid Build Coastguard Worker 7337*da0073e9SAndroid Build Coastguard Worker return inner 7338*da0073e9SAndroid Build Coastguard Worker 7339*da0073e9SAndroid Build Coastguard Worker def read_state(): 7340*da0073e9SAndroid Build Coastguard Worker return [ 7341*da0073e9SAndroid Build Coastguard Worker torch.is_grad_enabled(), 7342*da0073e9SAndroid Build Coastguard Worker torch.are_deterministic_algorithms_enabled(), 7343*da0073e9SAndroid Build Coastguard Worker torch._C._get_cublas_allow_tf32(), 7344*da0073e9SAndroid Build Coastguard Worker ] 7345*da0073e9SAndroid Build Coastguard Worker 7346*da0073e9SAndroid Build Coastguard Worker def write_state(state): 7347*da0073e9SAndroid Build Coastguard Worker torch.set_grad_enabled(state[0]), 7348*da0073e9SAndroid Build Coastguard Worker torch.use_deterministic_algorithms(state[1]) 7349*da0073e9SAndroid Build Coastguard Worker torch._C._set_cublas_allow_tf32(state[2]), 7350*da0073e9SAndroid Build Coastguard Worker 7351*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=my_compiler) 7352*da0073e9SAndroid Build Coastguard Worker def fn(x): 7353*da0073e9SAndroid Build Coastguard Worker return x + 1 7354*da0073e9SAndroid Build Coastguard Worker 7355*da0073e9SAndroid Build Coastguard Worker initial_state = read_state() 7356*da0073e9SAndroid Build Coastguard Worker y = torch.randn(10) 7357*da0073e9SAndroid Build Coastguard Worker try: 7358*da0073e9SAndroid Build Coastguard Worker for round in range(3): 7359*da0073e9SAndroid Build Coastguard Worker for i in range(len(initial_state)): 7360*da0073e9SAndroid Build Coastguard Worker new_state = [False] * len(initial_state) 7361*da0073e9SAndroid Build Coastguard Worker new_state[i] = True 7362*da0073e9SAndroid Build Coastguard Worker write_state(new_state) 
7363*da0073e9SAndroid Build Coastguard Worker assert read_state() == new_state 7364*da0073e9SAndroid Build Coastguard Worker last_state.clear() 7365*da0073e9SAndroid Build Coastguard Worker fn(y) 7366*da0073e9SAndroid Build Coastguard Worker assert last_state == new_state 7367*da0073e9SAndroid Build Coastguard Worker if round == 0: 7368*da0073e9SAndroid Build Coastguard Worker assert cnt == i + 1 7369*da0073e9SAndroid Build Coastguard Worker else: 7370*da0073e9SAndroid Build Coastguard Worker assert cnt == len(initial_state) 7371*da0073e9SAndroid Build Coastguard Worker finally: 7372*da0073e9SAndroid Build Coastguard Worker write_state(initial_state) 7373*da0073e9SAndroid Build Coastguard Worker 7374*da0073e9SAndroid Build Coastguard Worker def test_grad_state_mutated(self): 7375*da0073e9SAndroid Build Coastguard Worker prior = torch.is_grad_enabled() 7376*da0073e9SAndroid Build Coastguard Worker value = None 7377*da0073e9SAndroid Build Coastguard Worker cnt = CompileCounter() 7378*da0073e9SAndroid Build Coastguard Worker 7379*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.allow_in_graph 7380*da0073e9SAndroid Build Coastguard Worker def check_state(): 7381*da0073e9SAndroid Build Coastguard Worker nonlocal value 7382*da0073e9SAndroid Build Coastguard Worker value = torch.is_grad_enabled() 7383*da0073e9SAndroid Build Coastguard Worker 7384*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt, fullgraph=True) 7385*da0073e9SAndroid Build Coastguard Worker def fn(x): 7386*da0073e9SAndroid Build Coastguard Worker check_state() 7387*da0073e9SAndroid Build Coastguard Worker torch.set_grad_enabled(False) 7388*da0073e9SAndroid Build Coastguard Worker return x + 1 7389*da0073e9SAndroid Build Coastguard Worker 7390*da0073e9SAndroid Build Coastguard Worker try: 7391*da0073e9SAndroid Build Coastguard Worker torch.set_grad_enabled(True) 7392*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(10)) 7393*da0073e9SAndroid Build Coastguard Worker assert 
value is True 7394*da0073e9SAndroid Build Coastguard Worker assert torch.is_grad_enabled() is False 7395*da0073e9SAndroid Build Coastguard Worker 7396*da0073e9SAndroid Build Coastguard Worker value = None 7397*da0073e9SAndroid Build Coastguard Worker torch.set_grad_enabled(True) 7398*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(10)) 7399*da0073e9SAndroid Build Coastguard Worker assert value is True 7400*da0073e9SAndroid Build Coastguard Worker assert torch.is_grad_enabled() is False 7401*da0073e9SAndroid Build Coastguard Worker 7402*da0073e9SAndroid Build Coastguard Worker assert cnt.frame_count == 1 7403*da0073e9SAndroid Build Coastguard Worker finally: 7404*da0073e9SAndroid Build Coastguard Worker torch.set_grad_enabled(prior) 7405*da0073e9SAndroid Build Coastguard Worker 7406*da0073e9SAndroid Build Coastguard Worker def test_deterministic_algorithms_mutated(self): 7407*da0073e9SAndroid Build Coastguard Worker prior = torch.are_deterministic_algorithms_enabled() 7408*da0073e9SAndroid Build Coastguard Worker prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled() 7409*da0073e9SAndroid Build Coastguard Worker value = None 7410*da0073e9SAndroid Build Coastguard Worker warn_only = None 7411*da0073e9SAndroid Build Coastguard Worker cnt = CompileCounter() 7412*da0073e9SAndroid Build Coastguard Worker 7413*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.allow_in_graph 7414*da0073e9SAndroid Build Coastguard Worker def check_state(): 7415*da0073e9SAndroid Build Coastguard Worker nonlocal value 7416*da0073e9SAndroid Build Coastguard Worker nonlocal warn_only 7417*da0073e9SAndroid Build Coastguard Worker value = torch.are_deterministic_algorithms_enabled() 7418*da0073e9SAndroid Build Coastguard Worker warn_only = torch.is_deterministic_algorithms_warn_only_enabled() 7419*da0073e9SAndroid Build Coastguard Worker 7420*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt, fullgraph=True) 7421*da0073e9SAndroid Build Coastguard 
Worker def fn(x): 7422*da0073e9SAndroid Build Coastguard Worker check_state() 7423*da0073e9SAndroid Build Coastguard Worker torch.use_deterministic_algorithms(False, warn_only=False) 7424*da0073e9SAndroid Build Coastguard Worker return x + 1 7425*da0073e9SAndroid Build Coastguard Worker 7426*da0073e9SAndroid Build Coastguard Worker def run_fn(): 7427*da0073e9SAndroid Build Coastguard Worker torch.use_deterministic_algorithms(True, warn_only=True) 7428*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(10)) 7429*da0073e9SAndroid Build Coastguard Worker assert value is True 7430*da0073e9SAndroid Build Coastguard Worker assert warn_only is True 7431*da0073e9SAndroid Build Coastguard Worker assert torch.are_deterministic_algorithms_enabled() is False 7432*da0073e9SAndroid Build Coastguard Worker assert torch.is_deterministic_algorithms_warn_only_enabled() is False 7433*da0073e9SAndroid Build Coastguard Worker 7434*da0073e9SAndroid Build Coastguard Worker try: 7435*da0073e9SAndroid Build Coastguard Worker run_fn() 7436*da0073e9SAndroid Build Coastguard Worker value, warn_only = None, None 7437*da0073e9SAndroid Build Coastguard Worker run_fn() 7438*da0073e9SAndroid Build Coastguard Worker assert cnt.frame_count == 1 7439*da0073e9SAndroid Build Coastguard Worker finally: 7440*da0073e9SAndroid Build Coastguard Worker torch.use_deterministic_algorithms(prior, warn_only=prior_warn_only) 7441*da0073e9SAndroid Build Coastguard Worker 7442*da0073e9SAndroid Build Coastguard Worker def test_torch_compile_ctx_on_forward_and_training_step(self): 7443*da0073e9SAndroid Build Coastguard Worker class MyModel(torch.nn.Module): 7444*da0073e9SAndroid Build Coastguard Worker def forward(self): 7445*da0073e9SAndroid Build Coastguard Worker ... 
7446*da0073e9SAndroid Build Coastguard Worker 7447*da0073e9SAndroid Build Coastguard Worker def training_step(self): 7448*da0073e9SAndroid Build Coastguard Worker self() 7449*da0073e9SAndroid Build Coastguard Worker 7450*da0073e9SAndroid Build Coastguard Worker model = MyModel() 7451*da0073e9SAndroid Build Coastguard Worker compiled_model = torch.compile(model) 7452*da0073e9SAndroid Build Coastguard Worker 7453*da0073e9SAndroid Build Coastguard Worker model.forward = compiled_model.dynamo_ctx(model.forward) 7454*da0073e9SAndroid Build Coastguard Worker model.training_step = compiled_model.dynamo_ctx(model.training_step) 7455*da0073e9SAndroid Build Coastguard Worker 7456*da0073e9SAndroid Build Coastguard Worker model.training_step() 7457*da0073e9SAndroid Build Coastguard Worker 7458*da0073e9SAndroid Build Coastguard Worker def test_torch_guards_stack_frame_register_inlining(self): 7459*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.5, 0.5]) 7460*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([0.75, 0.75, 0.75, 0.75]) 7461*da0073e9SAndroid Build Coastguard Worker z = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]) 7462*da0073e9SAndroid Build Coastguard Worker 7463*da0073e9SAndroid Build Coastguard Worker def uwu_inline_me(x, y, z): 7464*da0073e9SAndroid Build Coastguard Worker r = torch.cat((x, x)) + y 7465*da0073e9SAndroid Build Coastguard Worker r2 = torch.cat((y, y)) + z 7466*da0073e9SAndroid Build Coastguard Worker return r, r2 7467*da0073e9SAndroid Build Coastguard Worker 7468*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 7469*da0073e9SAndroid Build Coastguard Worker r, r2 = uwu_inline_me(x, y, z) 7470*da0073e9SAndroid Build Coastguard Worker return torch.mul(r, r), torch.mul(r2, r2) 7471*da0073e9SAndroid Build Coastguard Worker 7472*da0073e9SAndroid Build Coastguard Worker seen_frames = [] 7473*da0073e9SAndroid Build Coastguard Worker import contextlib 7474*da0073e9SAndroid Build Coastguard Worker 
7475*da0073e9SAndroid Build Coastguard Worker @contextlib.contextmanager 7476*da0073e9SAndroid Build Coastguard Worker def global_context_capture_fn(frame_summary): 7477*da0073e9SAndroid Build Coastguard Worker if frame_summary is not None: 7478*da0073e9SAndroid Build Coastguard Worker seen_frames.append(frame_summary) 7479*da0073e9SAndroid Build Coastguard Worker yield 7480*da0073e9SAndroid Build Coastguard Worker 7481*da0073e9SAndroid Build Coastguard Worker with mock.patch( 7482*da0073e9SAndroid Build Coastguard Worker "torch._guards.TracingContext.current_frame", 7483*da0073e9SAndroid Build Coastguard Worker side_effect=global_context_capture_fn, 7484*da0073e9SAndroid Build Coastguard Worker ): 7485*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("eager")(fn)(x, y, z) 7486*da0073e9SAndroid Build Coastguard Worker 7487*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(seen_frames), 1) 7488*da0073e9SAndroid Build Coastguard Worker self.assertEqual(seen_frames[0].name, "fn") 7489*da0073e9SAndroid Build Coastguard Worker self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)") 7490*da0073e9SAndroid Build Coastguard Worker 7491*da0073e9SAndroid Build Coastguard Worker def test_torch_guards_stack_frame_register_inlining_deep(self): 7492*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.5, 0.5]) 7493*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([0.75, 0.75, 0.75, 0.75]) 7494*da0073e9SAndroid Build Coastguard Worker z = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]) 7495*da0073e9SAndroid Build Coastguard Worker 7496*da0073e9SAndroid Build Coastguard Worker def uwu_inline_me_deep(x, y): 7497*da0073e9SAndroid Build Coastguard Worker return torch.cat((x, x)) + y 7498*da0073e9SAndroid Build Coastguard Worker 7499*da0073e9SAndroid Build Coastguard Worker def uwu_inline_me(x, y, z): 7500*da0073e9SAndroid Build Coastguard Worker r = uwu_inline_me_deep(x, y) 7501*da0073e9SAndroid Build 
Coastguard Worker r2 = uwu_inline_me_deep(y, z) 7502*da0073e9SAndroid Build Coastguard Worker return r, r2 7503*da0073e9SAndroid Build Coastguard Worker 7504*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 7505*da0073e9SAndroid Build Coastguard Worker r, r2 = uwu_inline_me(x, y, z) 7506*da0073e9SAndroid Build Coastguard Worker return torch.mul(r, r), torch.mul(r2, r2) 7507*da0073e9SAndroid Build Coastguard Worker 7508*da0073e9SAndroid Build Coastguard Worker seen_frames = [] 7509*da0073e9SAndroid Build Coastguard Worker import contextlib 7510*da0073e9SAndroid Build Coastguard Worker 7511*da0073e9SAndroid Build Coastguard Worker @contextlib.contextmanager 7512*da0073e9SAndroid Build Coastguard Worker def global_context_capture_fn(frame_summary): 7513*da0073e9SAndroid Build Coastguard Worker if frame_summary is not None: 7514*da0073e9SAndroid Build Coastguard Worker seen_frames.append(frame_summary) 7515*da0073e9SAndroid Build Coastguard Worker yield 7516*da0073e9SAndroid Build Coastguard Worker 7517*da0073e9SAndroid Build Coastguard Worker with mock.patch( 7518*da0073e9SAndroid Build Coastguard Worker "torch._guards.TracingContext.current_frame", 7519*da0073e9SAndroid Build Coastguard Worker side_effect=global_context_capture_fn, 7520*da0073e9SAndroid Build Coastguard Worker ): 7521*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("eager")(fn)(x, y, z) 7522*da0073e9SAndroid Build Coastguard Worker 7523*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(seen_frames), 3) 7524*da0073e9SAndroid Build Coastguard Worker self.assertEqual(seen_frames[0].name, "fn") 7525*da0073e9SAndroid Build Coastguard Worker self.assertEqual(seen_frames[1].name, "uwu_inline_me") 7526*da0073e9SAndroid Build Coastguard Worker self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)") 7527*da0073e9SAndroid Build Coastguard Worker 7528*da0073e9SAndroid Build Coastguard Worker def test_error_on_recompile(self): 7529*da0073e9SAndroid Build 
Coastguard Worker @torch._dynamo.optimize("eager") 7530*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 7531*da0073e9SAndroid Build Coastguard Worker return a + b 7532*da0073e9SAndroid Build Coastguard Worker 7533*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 7534*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(torch._dynamo.exc.RecompileError): 7535*da0073e9SAndroid Build Coastguard Worker fn(torch.rand(2, 3), torch.rand(2, 3)) 7536*da0073e9SAndroid Build Coastguard Worker fn(torch.rand(2, 3), (1, 2, 3)) 7537*da0073e9SAndroid Build Coastguard Worker 7538*da0073e9SAndroid Build Coastguard Worker @expectedFailureDynamic 7539*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(automatic_dynamic_shapes=False) 7540*da0073e9SAndroid Build Coastguard Worker def test_compile_profiler(self): 7541*da0073e9SAndroid Build Coastguard Worker class Model(torch.nn.Module): 7542*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 7543*da0073e9SAndroid Build Coastguard Worker return input + input 7544*da0073e9SAndroid Build Coastguard Worker 7545*da0073e9SAndroid Build Coastguard Worker model = Model() 7546*da0073e9SAndroid Build Coastguard Worker prof = CompileProfiler() 7547*da0073e9SAndroid Build Coastguard Worker compiled = torch.compile(model, backend=prof) 7548*da0073e9SAndroid Build Coastguard Worker base_checker = ( 7549*da0073e9SAndroid Build Coastguard Worker lambda: FileCheck() 7550*da0073e9SAndroid Build Coastguard Worker .check("Torchdynamo Profiler Report") 7551*da0073e9SAndroid Build Coastguard Worker .check("Graph Breaks") 7552*da0073e9SAndroid Build Coastguard Worker .check("No graph breaks detected.") 7553*da0073e9SAndroid Build Coastguard Worker .check("Recompilation") 7554*da0073e9SAndroid Build Coastguard Worker ) 7555*da0073e9SAndroid Build Coastguard Worker input = torch.rand((2, 3, 4)) 7556*da0073e9SAndroid Build Coastguard Worker _ = 
compiled(input) 7557*da0073e9SAndroid Build Coastguard Worker base_checker().check("No recompilation detected.").run(prof.report()) 7558*da0073e9SAndroid Build Coastguard Worker 7559*da0073e9SAndroid Build Coastguard Worker new_shape_input = torch.rand((3, 3, 4)) 7560*da0073e9SAndroid Build Coastguard Worker _ = compiled(new_shape_input) 7561*da0073e9SAndroid Build Coastguard Worker 7562*da0073e9SAndroid Build Coastguard Worker # Not an exhaustive test of dynamic shapes behavior, but some sanity 7563*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 7564*da0073e9SAndroid Build Coastguard Worker base_checker().check("Recompile Reasons").check("'forward'").check( 7565*da0073e9SAndroid Build Coastguard Worker "cache_size_limit to 1" 7566*da0073e9SAndroid Build Coastguard Worker ).run(prof.report()) 7567*da0073e9SAndroid Build Coastguard Worker else: 7568*da0073e9SAndroid Build Coastguard Worker base_checker().check("No recompilation detected.").run(prof.report()) 7569*da0073e9SAndroid Build Coastguard Worker 7570*da0073e9SAndroid Build Coastguard Worker new_shape_input = torch.rand((4, 3, 4)) 7571*da0073e9SAndroid Build Coastguard Worker _ = compiled(new_shape_input) 7572*da0073e9SAndroid Build Coastguard Worker 7573*da0073e9SAndroid Build Coastguard Worker base_checker().check("Recompile Reasons").check("'forward'").check( 7574*da0073e9SAndroid Build Coastguard Worker "tensor 'L['input']' size mismatch at index 0. expected 2, actual 3" 7575*da0073e9SAndroid Build Coastguard Worker ).check( 7576*da0073e9SAndroid Build Coastguard Worker "tensor 'L['input']' size mismatch at index 0. 
expected 3, actual 4" 7577*da0073e9SAndroid Build Coastguard Worker ).run( 7578*da0073e9SAndroid Build Coastguard Worker prof.report() 7579*da0073e9SAndroid Build Coastguard Worker ) 7580*da0073e9SAndroid Build Coastguard Worker 7581*da0073e9SAndroid Build Coastguard Worker def test_guards_strip_function_call(self): 7582*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.guards import strip_function_call 7583*da0073e9SAndroid Build Coastguard Worker 7584*da0073e9SAndroid Build Coastguard Worker test_case = [ 7585*da0073e9SAndroid Build Coastguard Worker ("___odict_getitem(a, 1)", "a"), 7586*da0073e9SAndroid Build Coastguard Worker ("a.layers[slice(2)][0]._xyz", "a"), 7587*da0073e9SAndroid Build Coastguard Worker ("getattr(a.layers[slice(2)][0]._abc, '0')", "a"), 7588*da0073e9SAndroid Build Coastguard Worker ("getattr(getattr(a.x[3], '0'), '3')", "a"), 7589*da0073e9SAndroid Build Coastguard Worker ("a.layers[slice(None, -1, None)][0]._xyz", "a"), 7590*da0073e9SAndroid Build Coastguard Worker ("a.layers[func('offset', -1, None)][0]._xyz", "a"), 7591*da0073e9SAndroid Build Coastguard Worker ] 7592*da0073e9SAndroid Build Coastguard Worker # strip_function_call should extract the object from the string. 
7593*da0073e9SAndroid Build Coastguard Worker for name, expect_obj in test_case: 7594*da0073e9SAndroid Build Coastguard Worker self.assertEqual(strip_function_call(name), expect_obj) 7595*da0073e9SAndroid Build Coastguard Worker 7596*da0073e9SAndroid Build Coastguard Worker def test_int_neg(self): 7597*da0073e9SAndroid Build Coastguard Worker def int_neg(a, b): 7598*da0073e9SAndroid Build Coastguard Worker x = a.shape[0] 7599*da0073e9SAndroid Build Coastguard Worker y = b.shape[0] 7600*da0073e9SAndroid Build Coastguard Worker return -x * -y * a * b 7601*da0073e9SAndroid Build Coastguard Worker 7602*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.standard_test(self, int_neg, 2) 7603*da0073e9SAndroid Build Coastguard Worker 7604*da0073e9SAndroid Build Coastguard Worker def test_hash_getitem_slice(self): 7605*da0073e9SAndroid Build Coastguard Worker s = GetItemSource(LocalSource("foo"), slice(None, -1, None)) 7606*da0073e9SAndroid Build Coastguard Worker s2 = GetItemSource(LocalSource("foo"), slice(None, -1, None)) 7607*da0073e9SAndroid Build Coastguard Worker s3 = GetItemSource(LocalSource("foo"), slice(None, -1, 2)) 7608*da0073e9SAndroid Build Coastguard Worker some_set = set() 7609*da0073e9SAndroid Build Coastguard Worker 7610*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s not in some_set) 7611*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s2 not in some_set) 7612*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s3 not in some_set) 7613*da0073e9SAndroid Build Coastguard Worker 7614*da0073e9SAndroid Build Coastguard Worker some_set.add(s) 7615*da0073e9SAndroid Build Coastguard Worker 7616*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s in some_set) 7617*da0073e9SAndroid Build Coastguard Worker # s and s2 should hash the same 7618*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s2 in some_set) 7619*da0073e9SAndroid Build Coastguard Worker # s3 should be different 7620*da0073e9SAndroid Build Coastguard 
Worker self.assertTrue(s3 not in some_set) 7621*da0073e9SAndroid Build Coastguard Worker 7622*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s == s2) 7623*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s != s3) 7624*da0073e9SAndroid Build Coastguard Worker 7625*da0073e9SAndroid Build Coastguard Worker def test_inline_dict_function(self): 7626*da0073e9SAndroid Build Coastguard Worker def _result_type_dict(dtype): 7627*da0073e9SAndroid Build Coastguard Worker return {bool: torch.float32}[dtype] 7628*da0073e9SAndroid Build Coastguard Worker 7629*da0073e9SAndroid Build Coastguard Worker @torch.compile 7630*da0073e9SAndroid Build Coastguard Worker def f(): 7631*da0073e9SAndroid Build Coastguard Worker return torch.ones(3, dtype=_result_type_dict(bool)) 7632*da0073e9SAndroid Build Coastguard Worker 7633*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(), torch.ones(3, dtype=torch.float32)) 7634*da0073e9SAndroid Build Coastguard Worker 7635*da0073e9SAndroid Build Coastguard Worker def test_inline_dict_function_passed_as_arg(self): 7636*da0073e9SAndroid Build Coastguard Worker @torch.compile 7637*da0073e9SAndroid Build Coastguard Worker def fn(d, x, y): 7638*da0073e9SAndroid Build Coastguard Worker if d[x] is torch.float32: 7639*da0073e9SAndroid Build Coastguard Worker return y.cos() 7640*da0073e9SAndroid Build Coastguard Worker else: 7641*da0073e9SAndroid Build Coastguard Worker return y.sin() 7642*da0073e9SAndroid Build Coastguard Worker 7643*da0073e9SAndroid Build Coastguard Worker dd = {bool: torch.float32, int: torch.int64} 7644*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(dd, bool, torch.ones(4)), torch.ones(4).cos()) 7645*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(dd, int, torch.ones(4)), torch.ones(4).sin()) 7646*da0073e9SAndroid Build Coastguard Worker 7647*da0073e9SAndroid Build Coastguard Worker def test_add_sizes(self): 7648*da0073e9SAndroid Build Coastguard Worker def func(x): 
7649*da0073e9SAndroid Build Coastguard Worker y = x.size() 7650*da0073e9SAndroid Build Coastguard Worker return y + y 7651*da0073e9SAndroid Build Coastguard Worker 7652*da0073e9SAndroid Build Coastguard Worker eager_out = func(torch.ones(10, 10, 3)) 7653*da0073e9SAndroid Build Coastguard Worker compile_out = torch._dynamo.optimize("eager")(func)(torch.ones(10, 10, 3)) 7654*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(compile_out, torch.Size)) 7655*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_out, compile_out) 7656*da0073e9SAndroid Build Coastguard Worker 7657*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU") 7658*da0073e9SAndroid Build Coastguard Worker def test_cuda_set_device(self): 7659*da0073e9SAndroid Build Coastguard Worker def fn(): 7660*da0073e9SAndroid Build Coastguard Worker a = torch.ones(2, device="cuda") 7661*da0073e9SAndroid Build Coastguard Worker torch.cuda.set_device(1) 7662*da0073e9SAndroid Build Coastguard Worker return a + 1 7663*da0073e9SAndroid Build Coastguard Worker 7664*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(0): 7665*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 7666*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(counter)(fn) 7667*da0073e9SAndroid Build Coastguard Worker res = opt_fn() 7668*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.device.type, "cuda") 7669*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.device.index, 0) 7670*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 2) 7671*da0073e9SAndroid Build Coastguard Worker 7672*da0073e9SAndroid Build Coastguard Worker def test_nested_function_resuming_with_correct_globals(self): 7673*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/99665 7674*da0073e9SAndroid Build Coastguard Worker try: 7675*da0073e9SAndroid Build Coastguard Worker 
from .utils import outer_func 7676*da0073e9SAndroid Build Coastguard Worker except ImportError: 7677*da0073e9SAndroid Build Coastguard Worker from utils import outer_func 7678*da0073e9SAndroid Build Coastguard Worker 7679*da0073e9SAndroid Build Coastguard Worker def gn(x, y): 7680*da0073e9SAndroid Build Coastguard Worker return x + y 7681*da0073e9SAndroid Build Coastguard Worker 7682*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 7683*da0073e9SAndroid Build Coastguard Worker return outer_func(gn)(x, y) 7684*da0073e9SAndroid Build Coastguard Worker 7685*da0073e9SAndroid Build Coastguard Worker x = torch.rand([3]) 7686*da0073e9SAndroid Build Coastguard Worker y = torch.rand([3]) 7687*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(backend="eager")(fn) 7688*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 7689*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 7690*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 7691*da0073e9SAndroid Build Coastguard Worker 7692*da0073e9SAndroid Build Coastguard Worker @dataclasses.dataclass 7693*da0073e9SAndroid Build Coastguard Worker class CSETestCase: 7694*da0073e9SAndroid Build Coastguard Worker expr: str 7695*da0073e9SAndroid Build Coastguard Worker preface: typing.List[str] = dataclasses.field(default_factory=list) 7696*da0073e9SAndroid Build Coastguard Worker expected: typing.Optional[str] = None 7697*da0073e9SAndroid Build Coastguard Worker expected_py38: typing.Optional[str] = None 7698*da0073e9SAndroid Build Coastguard Worker 7699*da0073e9SAndroid Build Coastguard Worker def _is_py38(self) -> bool: 7700*da0073e9SAndroid Build Coastguard Worker return sys.version_info[:2] <= (3, 8) 7701*da0073e9SAndroid Build Coastguard Worker 7702*da0073e9SAndroid Build Coastguard Worker def _has_ast_unparse(self) -> bool: 7703*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.guards import HAS_UNPARSE_FUNCTIONS 7704*da0073e9SAndroid Build Coastguard Worker 
7705*da0073e9SAndroid Build Coastguard Worker return HAS_UNPARSE_FUNCTIONS 7706*da0073e9SAndroid Build Coastguard Worker 7707*da0073e9SAndroid Build Coastguard Worker def test_guards_cse_pass_single(self): 7708*da0073e9SAndroid Build Coastguard Worker if not self._has_ast_unparse(): 7709*da0073e9SAndroid Build Coastguard Worker if IS_FBCODE: 7710*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Needs astunparse or Python-3.9+") 7711*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("Needs astunparse or Python-3.9+") 7712*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.guards import PyExprCSEPass 7713*da0073e9SAndroid Build Coastguard Worker 7714*da0073e9SAndroid Build Coastguard Worker testcase = self.CSETestCase 7715*da0073e9SAndroid Build Coastguard Worker testcases = [ 7716*da0073e9SAndroid Build Coastguard Worker # Nothing gets CSE-d, since the only repeated sub-expression is 'x'. 7717*da0073e9SAndroid Build Coastguard Worker # i.e. not a node type we are interested on. 7718*da0073e9SAndroid Build Coastguard Worker testcase(expr="x[0].a"), 7719*da0073e9SAndroid Build Coastguard Worker testcase(expr="x[1].a"), 7720*da0073e9SAndroid Build Coastguard Worker testcase(expr="x[2].a"), 7721*da0073e9SAndroid Build Coastguard Worker # 'a.b.c' gets CSE-d, since it's a sub-expression used more than 'PyExprCSEPass.USE_THRESHOLD'. 
7722*da0073e9SAndroid Build Coastguard Worker testcase( 7723*da0073e9SAndroid Build Coastguard Worker expr="a.b.c[0].d.e", 7724*da0073e9SAndroid Build Coastguard Worker preface=["_var0 = a.b", "_var1 = _var0.c"], 7725*da0073e9SAndroid Build Coastguard Worker expected="_var1[0].d.e", 7726*da0073e9SAndroid Build Coastguard Worker ), 7727*da0073e9SAndroid Build Coastguard Worker testcase(expr="a.b.c[1].d.e", expected="_var1[1].d.e"), 7728*da0073e9SAndroid Build Coastguard Worker testcase(expr="a.b.c[2].d.e", expected="_var1[2].d.e"), 7729*da0073e9SAndroid Build Coastguard Worker # 'm.n[0]' gets CSE-d, since it is a sub-expression used more than 'PyExprCSEPass.USE_THRESHOLD'. 7730*da0073e9SAndroid Build Coastguard Worker testcase( 7731*da0073e9SAndroid Build Coastguard Worker expr="f(m.n[0], '0').x.y.z", 7732*da0073e9SAndroid Build Coastguard Worker preface=["_var2 = m.n", "_var3 = _var2[0]"], 7733*da0073e9SAndroid Build Coastguard Worker expected="f(_var3, '0').x.y.z", 7734*da0073e9SAndroid Build Coastguard Worker ), 7735*da0073e9SAndroid Build Coastguard Worker testcase(expr="f(m.n[0], '1').x.y.z", expected="f(_var3, '1').x.y.z"), 7736*da0073e9SAndroid Build Coastguard Worker testcase(expr="f(m.n[0], '2').x.y.z", expected="f(_var3, '2').x.y.z"), 7737*da0073e9SAndroid Build Coastguard Worker # The whole expressiong gets CSE-d, as well as all of its sub-expressions. 
7738*da0073e9SAndroid Build Coastguard Worker testcase( 7739*da0073e9SAndroid Build Coastguard Worker expr="self.g(a, b).k", 7740*da0073e9SAndroid Build Coastguard Worker preface=["_var4 = self.g", "_var5 = _var4(a, b)", "_var6 = _var5.k"], 7741*da0073e9SAndroid Build Coastguard Worker expected="_var6", 7742*da0073e9SAndroid Build Coastguard Worker ), 7743*da0073e9SAndroid Build Coastguard Worker testcase(expr="self.g(a, b).k", expected="_var6"), 7744*da0073e9SAndroid Build Coastguard Worker testcase(expr="self.g(a, b).k", expected="_var6"), 7745*da0073e9SAndroid Build Coastguard Worker ] 7746*da0073e9SAndroid Build Coastguard Worker csepass = PyExprCSEPass() 7747*da0073e9SAndroid Build Coastguard Worker csepass.count([t.expr for t in testcases]) 7748*da0073e9SAndroid Build Coastguard Worker 7749*da0073e9SAndroid Build Coastguard Worker for t in testcases: 7750*da0073e9SAndroid Build Coastguard Worker preface, expr = csepass.replace(t.expr) 7751*da0073e9SAndroid Build Coastguard Worker self.assertEqual(preface, t.preface) 7752*da0073e9SAndroid Build Coastguard Worker expected = t.expected if t.expected is not None else t.expr 7753*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expr, expected) 7754*da0073e9SAndroid Build Coastguard Worker 7755*da0073e9SAndroid Build Coastguard Worker def test_guards_cse_pass_multiple(self): 7756*da0073e9SAndroid Build Coastguard Worker if not self._has_ast_unparse(): 7757*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("Needs astunparse or Python-3.9+") 7758*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.guards import PyExprCSEPass 7759*da0073e9SAndroid Build Coastguard Worker 7760*da0073e9SAndroid Build Coastguard Worker testcase = self.CSETestCase 7761*da0073e9SAndroid Build Coastguard Worker testcases = [ 7762*da0073e9SAndroid Build Coastguard Worker testcase( 7763*da0073e9SAndroid Build Coastguard Worker expr="x[0].a < x[1].a * (3 - x[2].a)", 7764*da0073e9SAndroid Build Coastguard Worker 
expected="x[0].a < x[1].a * (3 - x[2].a)", 7765*da0073e9SAndroid Build Coastguard Worker expected_py38="(x[0].a < (x[1].a * (3 - x[2].a)))", 7766*da0073e9SAndroid Build Coastguard Worker ), 7767*da0073e9SAndroid Build Coastguard Worker testcase( 7768*da0073e9SAndroid Build Coastguard Worker expr="a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0", 7769*da0073e9SAndroid Build Coastguard Worker preface=["_var0 = a.b", "_var1 = _var0.c"], 7770*da0073e9SAndroid Build Coastguard Worker expected="_var1[0].d.e + _var1[1].d.e * _var1[2].d.e > 0", 7771*da0073e9SAndroid Build Coastguard Worker expected_py38="((_var1[0].d.e + (_var1[1].d.e * _var1[2].d.e)) > 0)", 7772*da0073e9SAndroid Build Coastguard Worker ), 7773*da0073e9SAndroid Build Coastguard Worker testcase( 7774*da0073e9SAndroid Build Coastguard Worker expr="f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512", 7775*da0073e9SAndroid Build Coastguard Worker preface=["_var2 = m.n", "_var3 = _var2[0]"], 7776*da0073e9SAndroid Build Coastguard Worker expected="f(_var3, '0').x.y.z * f(_var3, '1').x.y.z * f(_var3, '2').x.y.z < 512", 7777*da0073e9SAndroid Build Coastguard Worker expected_py38="(((f(_var3, '0').x.y.z * f(_var3, '1').x.y.z) * f(_var3, '2').x.y.z) < 512)", 7778*da0073e9SAndroid Build Coastguard Worker ), 7779*da0073e9SAndroid Build Coastguard Worker testcase( 7780*da0073e9SAndroid Build Coastguard Worker expr="self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k", 7781*da0073e9SAndroid Build Coastguard Worker preface=["_var4 = self.g", "_var5 = _var4(a, b)", "_var6 = _var5.k"], 7782*da0073e9SAndroid Build Coastguard Worker expected="_var6 + (1 - _var6) <= m[0].a + _var6", 7783*da0073e9SAndroid Build Coastguard Worker expected_py38="((_var6 + (1 - _var6)) <= (m[0].a + _var6))", 7784*da0073e9SAndroid Build Coastguard Worker ), 7785*da0073e9SAndroid Build Coastguard Worker ] 7786*da0073e9SAndroid Build Coastguard Worker 7787*da0073e9SAndroid Build Coastguard Worker csepass = 
PyExprCSEPass() 7788*da0073e9SAndroid Build Coastguard Worker csepass.count([t.expr for t in testcases]) 7789*da0073e9SAndroid Build Coastguard Worker 7790*da0073e9SAndroid Build Coastguard Worker for t in testcases: 7791*da0073e9SAndroid Build Coastguard Worker preface, expr = csepass.replace(t.expr) 7792*da0073e9SAndroid Build Coastguard Worker self.assertEqual(preface, t.preface) 7793*da0073e9SAndroid Build Coastguard Worker expected = t.expected_py38 if self._is_py38() else t.expected 7794*da0073e9SAndroid Build Coastguard Worker expected = expected if expected is not None else t.expr 7795*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expr, expected) 7796*da0073e9SAndroid Build Coastguard Worker 7797*da0073e9SAndroid Build Coastguard Worker def test_guard_function_builder_with_cse(self): 7798*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.guards import build_guard_function 7799*da0073e9SAndroid Build Coastguard Worker 7800*da0073e9SAndroid Build Coastguard Worker exprs = [ 7801*da0073e9SAndroid Build Coastguard Worker "x[0].a < x[1].a * (3 - x[2].a)", 7802*da0073e9SAndroid Build Coastguard Worker "a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0", 7803*da0073e9SAndroid Build Coastguard Worker "f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512", 7804*da0073e9SAndroid Build Coastguard Worker "self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k", 7805*da0073e9SAndroid Build Coastguard Worker ] 7806*da0073e9SAndroid Build Coastguard Worker 7807*da0073e9SAndroid Build Coastguard Worker _, pycode = build_guard_function(exprs, "") 7808*da0073e9SAndroid Build Coastguard Worker expected = """\ 7809*da0073e9SAndroid Build Coastguard Workerdef ___make_guard_fn(): 7810*da0073e9SAndroid Build Coastguard Worker def guard(L): 7811*da0073e9SAndroid Build Coastguard Worker if not (x[0].a < x[1].a * (3 - x[2].a)): 7812*da0073e9SAndroid Build Coastguard Worker return False 7813*da0073e9SAndroid Build Coastguard Worker 
_var0 = a.b 7814*da0073e9SAndroid Build Coastguard Worker _var1 = _var0.c 7815*da0073e9SAndroid Build Coastguard Worker if not (_var1[0].d.e + _var1[1].d.e * _var1[2].d.e > 0): 7816*da0073e9SAndroid Build Coastguard Worker return False 7817*da0073e9SAndroid Build Coastguard Worker _var2 = m.n 7818*da0073e9SAndroid Build Coastguard Worker _var3 = _var2[0] 7819*da0073e9SAndroid Build Coastguard Worker if not (f(_var3, '0').x.y.z * f(_var3, '1').x.y.z * f(_var3, '2').x.y.z < 512): 7820*da0073e9SAndroid Build Coastguard Worker return False 7821*da0073e9SAndroid Build Coastguard Worker _var4 = self.g 7822*da0073e9SAndroid Build Coastguard Worker _var5 = _var4(a, b) 7823*da0073e9SAndroid Build Coastguard Worker _var6 = _var5.k 7824*da0073e9SAndroid Build Coastguard Worker if not (_var6 + (1 - _var6) <= m[0].a + _var6): 7825*da0073e9SAndroid Build Coastguard Worker return False 7826*da0073e9SAndroid Build Coastguard Worker return True 7827*da0073e9SAndroid Build Coastguard Worker return guard 7828*da0073e9SAndroid Build Coastguard Worker""" 7829*da0073e9SAndroid Build Coastguard Worker expected_38 = """\ 7830*da0073e9SAndroid Build Coastguard Workerdef ___make_guard_fn(): 7831*da0073e9SAndroid Build Coastguard Worker def guard(L): 7832*da0073e9SAndroid Build Coastguard Worker if not ((x[0].a < (x[1].a * (3 - x[2].a)))): 7833*da0073e9SAndroid Build Coastguard Worker return False 7834*da0073e9SAndroid Build Coastguard Worker _var0 = a.b 7835*da0073e9SAndroid Build Coastguard Worker _var1 = _var0.c 7836*da0073e9SAndroid Build Coastguard Worker if not (((_var1[0].d.e + (_var1[1].d.e * _var1[2].d.e)) > 0)): 7837*da0073e9SAndroid Build Coastguard Worker return False 7838*da0073e9SAndroid Build Coastguard Worker _var2 = m.n 7839*da0073e9SAndroid Build Coastguard Worker _var3 = _var2[0] 7840*da0073e9SAndroid Build Coastguard Worker if not ((((f(_var3, '0').x.y.z * f(_var3, '1').x.y.z) * f(_var3, '2').x.y.z) < 512)): 7841*da0073e9SAndroid Build Coastguard Worker return False 
7842*da0073e9SAndroid Build Coastguard Worker _var4 = self.g 7843*da0073e9SAndroid Build Coastguard Worker _var5 = _var4(a, b) 7844*da0073e9SAndroid Build Coastguard Worker _var6 = _var5.k 7845*da0073e9SAndroid Build Coastguard Worker if not (((_var6 + (1 - _var6)) <= (m[0].a + _var6))): 7846*da0073e9SAndroid Build Coastguard Worker return False 7847*da0073e9SAndroid Build Coastguard Worker return True 7848*da0073e9SAndroid Build Coastguard Worker return guard 7849*da0073e9SAndroid Build Coastguard Worker""" 7850*da0073e9SAndroid Build Coastguard Worker expected_38_no_astunparse = """\ 7851*da0073e9SAndroid Build Coastguard Workerdef ___make_guard_fn(): 7852*da0073e9SAndroid Build Coastguard Worker def guard(L): 7853*da0073e9SAndroid Build Coastguard Worker if not (x[0].a < x[1].a * (3 - x[2].a)): 7854*da0073e9SAndroid Build Coastguard Worker return False 7855*da0073e9SAndroid Build Coastguard Worker if not (a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0): 7856*da0073e9SAndroid Build Coastguard Worker return False 7857*da0073e9SAndroid Build Coastguard Worker if not (f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512): 7858*da0073e9SAndroid Build Coastguard Worker return False 7859*da0073e9SAndroid Build Coastguard Worker if not (self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k): 7860*da0073e9SAndroid Build Coastguard Worker return False 7861*da0073e9SAndroid Build Coastguard Worker return True 7862*da0073e9SAndroid Build Coastguard Worker return guard 7863*da0073e9SAndroid Build Coastguard Worker""" 7864*da0073e9SAndroid Build Coastguard Worker 7865*da0073e9SAndroid Build Coastguard Worker if self._is_py38(): 7866*da0073e9SAndroid Build Coastguard Worker expected = ( 7867*da0073e9SAndroid Build Coastguard Worker expected_38 if self._has_ast_unparse() else expected_38_no_astunparse 7868*da0073e9SAndroid Build Coastguard Worker ) 7869*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, pycode) 
7870*da0073e9SAndroid Build Coastguard Worker 7871*da0073e9SAndroid Build Coastguard Worker def test_dynamo_compiling_fake_tensor_to_vararg_int(self): 7872*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 7873*da0073e9SAndroid Build Coastguard Worker def __init__(self): 7874*da0073e9SAndroid Build Coastguard Worker super().__init__() 7875*da0073e9SAndroid Build Coastguard Worker 7876*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 7877*da0073e9SAndroid Build Coastguard Worker # use numpy int so it's wrapped as fake tensor in dynamo 7878*da0073e9SAndroid Build Coastguard Worker shape = np.int_(16) 7879*da0073e9SAndroid Build Coastguard Worker # test shape as fake tensor, which param type is 7880*da0073e9SAndroid Build Coastguard Worker # Sequence[Union[_int, SymInt]] 7881*da0073e9SAndroid Build Coastguard Worker return x.reshape(shape) 7882*da0073e9SAndroid Build Coastguard Worker 7883*da0073e9SAndroid Build Coastguard Worker x = torch.rand([4, 4]) 7884*da0073e9SAndroid Build Coastguard Worker model = MyModule() 7885*da0073e9SAndroid Build Coastguard Worker orig_out = model(x) 7886*da0073e9SAndroid Build Coastguard Worker opt_model = torch._dynamo.optimize("eager")(MyModule()) 7887*da0073e9SAndroid Build Coastguard Worker opt_out = opt_model(x) 7888*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(orig_out, opt_out)) 7889*da0073e9SAndroid Build Coastguard Worker 7890*da0073e9SAndroid Build Coastguard Worker def test_scalar_tensor_is_equivalent_to_symint_argument(self): 7891*da0073e9SAndroid Build Coastguard Worker class GumbelTopKSampler(torch.nn.Module): 7892*da0073e9SAndroid Build Coastguard Worker def __init__(self, T, k): 7893*da0073e9SAndroid Build Coastguard Worker super().__init__() 7894*da0073e9SAndroid Build Coastguard Worker self.T = torch.nn.Parameter( 7895*da0073e9SAndroid Build Coastguard Worker torch.tensor(T, dtype=torch.float32), requires_grad=False 7896*da0073e9SAndroid Build Coastguard Worker ) 
7897*da0073e9SAndroid Build Coastguard Worker self.k = torch.nn.Parameter( 7898*da0073e9SAndroid Build Coastguard Worker torch.tensor(k, dtype=torch.int32), requires_grad=False 7899*da0073e9SAndroid Build Coastguard Worker ) 7900*da0073e9SAndroid Build Coastguard Worker 7901*da0073e9SAndroid Build Coastguard Worker def sample_discrete(self, logits): 7902*da0073e9SAndroid Build Coastguard Worker threshold = torch.topk(logits, self.k, sorted=True)[0][..., -1] 7903*da0073e9SAndroid Build Coastguard Worker samples = torch.ge(logits.squeeze(1), threshold).float() 7904*da0073e9SAndroid Build Coastguard Worker return samples 7905*da0073e9SAndroid Build Coastguard Worker 7906*da0073e9SAndroid Build Coastguard Worker def forward(self, logits): 7907*da0073e9SAndroid Build Coastguard Worker dsamples = self.sample_discrete(logits) 7908*da0073e9SAndroid Build Coastguard Worker return dsamples 7909*da0073e9SAndroid Build Coastguard Worker 7910*da0073e9SAndroid Build Coastguard Worker x = torch.rand([4, 4, 4, 4]) 7911*da0073e9SAndroid Build Coastguard Worker m = GumbelTopKSampler(T=4, k=4) 7912*da0073e9SAndroid Build Coastguard Worker orig_out = m(x) 7913*da0073e9SAndroid Build Coastguard Worker opt_m = torch.compile(backend="eager")(m) 7914*da0073e9SAndroid Build Coastguard Worker opt_out = opt_m(x) 7915*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(orig_out, opt_out)) 7916*da0073e9SAndroid Build Coastguard Worker 7917*da0073e9SAndroid Build Coastguard Worker def test_scalar_tensor_is_equivalent_to_symint_list_argument(self): 7918*da0073e9SAndroid Build Coastguard Worker class Jitter(torch.nn.Module): 7919*da0073e9SAndroid Build Coastguard Worker def __init__(self, jitter_val): 7920*da0073e9SAndroid Build Coastguard Worker super().__init__() 7921*da0073e9SAndroid Build Coastguard Worker self.jitter_val = jitter_val 7922*da0073e9SAndroid Build Coastguard Worker 7923*da0073e9SAndroid Build Coastguard Worker def roll_tensor(self, input): 7924*da0073e9SAndroid Build 
Coastguard Worker h_shift = self.jitter_val - 1 7925*da0073e9SAndroid Build Coastguard Worker w_shift = self.jitter_val + 1 7926*da0073e9SAndroid Build Coastguard Worker return torch.roll( 7927*da0073e9SAndroid Build Coastguard Worker torch.roll(input, shifts=h_shift, dims=2), shifts=w_shift, dims=3 7928*da0073e9SAndroid Build Coastguard Worker ) 7929*da0073e9SAndroid Build Coastguard Worker 7930*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 7931*da0073e9SAndroid Build Coastguard Worker return self.roll_tensor(input) 7932*da0073e9SAndroid Build Coastguard Worker 7933*da0073e9SAndroid Build Coastguard Worker x = torch.rand([4, 4, 4, 4]) 7934*da0073e9SAndroid Build Coastguard Worker m = Jitter(jitter_val=4) 7935*da0073e9SAndroid Build Coastguard Worker orig_out = m(x) 7936*da0073e9SAndroid Build Coastguard Worker opt_m = torch.compile(backend="eager")(m) 7937*da0073e9SAndroid Build Coastguard Worker opt_out = opt_m(x) 7938*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(orig_out, opt_out)) 7939*da0073e9SAndroid Build Coastguard Worker 7940*da0073e9SAndroid Build Coastguard Worker def test_scalar_tensor_is_equivalent_to_int_list_argument(self): 7941*da0073e9SAndroid Build Coastguard Worker class MyModel(torch.nn.Module): 7942*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 7943*da0073e9SAndroid Build Coastguard Worker permute = torch.tensor([0, 2, 1]) 7944*da0073e9SAndroid Build Coastguard Worker x = input.permute(*permute) 7945*da0073e9SAndroid Build Coastguard Worker return x 7946*da0073e9SAndroid Build Coastguard Worker 7947*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, 4) 7948*da0073e9SAndroid Build Coastguard Worker m = MyModel() 7949*da0073e9SAndroid Build Coastguard Worker orig_out = m(x) 7950*da0073e9SAndroid Build Coastguard Worker opt_m = torch.compile(backend="eager")(m) 7951*da0073e9SAndroid Build Coastguard Worker opt_out = opt_m(x) 7952*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(same(orig_out, opt_out)) 7953*da0073e9SAndroid Build Coastguard Worker 7954*da0073e9SAndroid Build Coastguard Worker def test_torch_variable_hasattr(self): 7955*da0073e9SAndroid Build Coastguard Worker def fn(x): 7956*da0073e9SAndroid Build Coastguard Worker if hasattr(torch.nn, "Module"): 7957*da0073e9SAndroid Build Coastguard Worker return x * x 7958*da0073e9SAndroid Build Coastguard Worker return x + 1 7959*da0073e9SAndroid Build Coastguard Worker 7960*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) 7961*da0073e9SAndroid Build Coastguard Worker 7962*da0073e9SAndroid Build Coastguard Worker x = torch.rand([4, 4]) 7963*da0073e9SAndroid Build Coastguard Worker fn_out = fn(x) 7964*da0073e9SAndroid Build Coastguard Worker compiled_out = compiled_fn(x) 7965*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(fn_out, compiled_out)) 7966*da0073e9SAndroid Build Coastguard Worker 7967*da0073e9SAndroid Build Coastguard Worker def test_list_hasattr1(self): 7968*da0073e9SAndroid Build Coastguard Worker def fn(x): 7969*da0073e9SAndroid Build Coastguard Worker if hasattr(x, "foo"): 7970*da0073e9SAndroid Build Coastguard Worker return x[0] + 1 7971*da0073e9SAndroid Build Coastguard Worker return x[0] - 1 7972*da0073e9SAndroid Build Coastguard Worker 7973*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) 7974*da0073e9SAndroid Build Coastguard Worker 7975*da0073e9SAndroid Build Coastguard Worker x = [torch.randn(3)] 7976*da0073e9SAndroid Build Coastguard Worker fn_out = fn(x) 7977*da0073e9SAndroid Build Coastguard Worker compiled_out = compiled_fn(x) 7978*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(fn_out, compiled_out)) 7979*da0073e9SAndroid Build Coastguard Worker 7980*da0073e9SAndroid Build Coastguard Worker def test_list_hasattr2(self): 7981*da0073e9SAndroid Build Coastguard Worker def fn(): 7982*da0073e9SAndroid Build 
Coastguard Worker x = [torch.zeros(3)] 7983*da0073e9SAndroid Build Coastguard Worker if hasattr(x, "__len__"): 7984*da0073e9SAndroid Build Coastguard Worker return x[0] + 1 7985*da0073e9SAndroid Build Coastguard Worker return x[0] - 1 7986*da0073e9SAndroid Build Coastguard Worker 7987*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) 7988*da0073e9SAndroid Build Coastguard Worker 7989*da0073e9SAndroid Build Coastguard Worker fn_out = fn() 7990*da0073e9SAndroid Build Coastguard Worker compiled_out = compiled_fn() 7991*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(fn_out, compiled_out)) 7992*da0073e9SAndroid Build Coastguard Worker 7993*da0073e9SAndroid Build Coastguard Worker def test_tuple_hasattr(self): 7994*da0073e9SAndroid Build Coastguard Worker def fn(x): 7995*da0073e9SAndroid Build Coastguard Worker if hasattr(x, "foo"): 7996*da0073e9SAndroid Build Coastguard Worker return x[0] + 1 7997*da0073e9SAndroid Build Coastguard Worker return x[1] - 1 7998*da0073e9SAndroid Build Coastguard Worker 7999*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) 8000*da0073e9SAndroid Build Coastguard Worker 8001*da0073e9SAndroid Build Coastguard Worker x = (torch.randn(3), torch.randn(3)) 8002*da0073e9SAndroid Build Coastguard Worker fn_out = fn(x) 8003*da0073e9SAndroid Build Coastguard Worker compiled_out = compiled_fn(x) 8004*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(fn_out, compiled_out)) 8005*da0073e9SAndroid Build Coastguard Worker 8006*da0073e9SAndroid Build Coastguard Worker def test_fn_hasattr__name__1(self): 8007*da0073e9SAndroid Build Coastguard Worker def fn(): 8008*da0073e9SAndroid Build Coastguard Worker foo = lambda x: x + 1 8009*da0073e9SAndroid Build Coastguard Worker return hasattr(foo, "__name__") 8010*da0073e9SAndroid Build Coastguard Worker 8011*da0073e9SAndroid Build Coastguard Worker compiled_fn = 
torch.compile(backend="eager", fullgraph=True)(fn) 8012*da0073e9SAndroid Build Coastguard Worker 8013*da0073e9SAndroid Build Coastguard Worker fn_out = fn() 8014*da0073e9SAndroid Build Coastguard Worker compiled_out = compiled_fn() 8015*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn_out, compiled_out) 8016*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fn_out) 8017*da0073e9SAndroid Build Coastguard Worker 8018*da0073e9SAndroid Build Coastguard Worker def test_fn_hasattr__name__2(self): 8019*da0073e9SAndroid Build Coastguard Worker def bar(x): 8020*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) 8021*da0073e9SAndroid Build Coastguard Worker 8022*da0073e9SAndroid Build Coastguard Worker def fn(): 8023*da0073e9SAndroid Build Coastguard Worker return hasattr(bar, "__name__") 8024*da0073e9SAndroid Build Coastguard Worker 8025*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) 8026*da0073e9SAndroid Build Coastguard Worker 8027*da0073e9SAndroid Build Coastguard Worker fn_out = fn() 8028*da0073e9SAndroid Build Coastguard Worker compiled_out = compiled_fn() 8029*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn_out, compiled_out) 8030*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fn_out) 8031*da0073e9SAndroid Build Coastguard Worker 8032*da0073e9SAndroid Build Coastguard Worker def test_fn_hasattr__name__3(self): 8033*da0073e9SAndroid Build Coastguard Worker def bar(x, y): 8034*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) + torch.cos(y) 8035*da0073e9SAndroid Build Coastguard Worker 8036*da0073e9SAndroid Build Coastguard Worker baz = functools.partial(bar, y=4) 8037*da0073e9SAndroid Build Coastguard Worker 8038*da0073e9SAndroid Build Coastguard Worker def fn(): 8039*da0073e9SAndroid Build Coastguard Worker return hasattr(baz, "__name__") 8040*da0073e9SAndroid Build Coastguard Worker 8041*da0073e9SAndroid Build Coastguard Worker compiled_fn = 
torch.compile(backend="eager", fullgraph=True)(fn) 8042*da0073e9SAndroid Build Coastguard Worker 8043*da0073e9SAndroid Build Coastguard Worker fn_out = fn() 8044*da0073e9SAndroid Build Coastguard Worker compiled_out = compiled_fn() 8045*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn_out, compiled_out) 8046*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fn_out) 8047*da0073e9SAndroid Build Coastguard Worker 8048*da0073e9SAndroid Build Coastguard Worker def test_torch_objects_as_keys(self): 8049*da0073e9SAndroid Build Coastguard Worker remap = {torch.float16: torch.float32} 8050*da0073e9SAndroid Build Coastguard Worker 8051*da0073e9SAndroid Build Coastguard Worker def fn(): 8052*da0073e9SAndroid Build Coastguard Worker return torch.randn(3, dtype=remap[torch.float16]) 8053*da0073e9SAndroid Build Coastguard Worker 8054*da0073e9SAndroid Build Coastguard Worker opt = torch._dynamo.optimize("eager")(fn) 8055*da0073e9SAndroid Build Coastguard Worker opt() 8056*da0073e9SAndroid Build Coastguard Worker 8057*da0073e9SAndroid Build Coastguard Worker def test_tracing_py_tree(self): 8058*da0073e9SAndroid Build Coastguard Worker def fn(xs): 8059*da0073e9SAndroid Build Coastguard Worker flat_xs, spec = pytree.tree_flatten(xs) 8060*da0073e9SAndroid Build Coastguard Worker res = [x.clone() for x in flat_xs] 8061*da0073e9SAndroid Build Coastguard Worker return pytree.tree_unflatten(res, spec) 8062*da0073e9SAndroid Build Coastguard Worker 8063*da0073e9SAndroid Build Coastguard Worker xs = [torch.tensor(i) for i in range(3)] 8064*da0073e9SAndroid Build Coastguard Worker 8065*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8066*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize(counter, nopython=True)(fn)(xs) 8067*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8068*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.op_count, 3) 8069*da0073e9SAndroid Build Coastguard Worker 
# ---- dynamo tracing tests: nested pytrees, one_hot, any/all on SymNodes ----
    def test_tracing_nested_py_tree(self):
        # List of 4 identical sublists of 3 tensors; pytree does not dedupe
        # leaves, so there is one clone per flattened leaf (4 x 3 = 12 ops).
        import torch.utils._pytree as pytree

        def fn(xs):
            flat_xs, spec = pytree.tree_flatten(xs)
            res = [x.clone() for x in flat_xs]
            return pytree.tree_unflatten(res, spec)

        xs = [torch.tensor(i) for i in range(3)]
        xsl = [xs, xs, xs, xs]

        counter = CompileCounter()
        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl)
        real_out = fn(xsl)
        self.assertEqual(comp_out, real_out)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 12)

    def test_tracing_nested_py_tree_tuples(self):
        # Same as above with a tuple as the outer container (4 x 3 = 12 clones).
        import torch.utils._pytree as pytree

        def fn(xs):
            flat_xs, spec = pytree.tree_flatten(xs)
            res = [x.clone() for x in flat_xs]
            return pytree.tree_unflatten(res, spec)

        xs = [torch.tensor(i) for i in range(3)]
        xsl = (xs, xs, xs, xs)

        counter = CompileCounter()
        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl)
        real_out = fn(xsl)
        self.assertEqual(comp_out, real_out)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 12)

    def test_tracing_nested_py_tree_dicts(self):
        # Dict with 3 keys, each mapping to the same 3-tensor list (3 x 3 = 9 clones).
        import torch.utils._pytree as pytree

        def fn(xs):
            flat_xs, spec = pytree.tree_flatten(xs)
            res = [x.clone() for x in flat_xs]
            return pytree.tree_unflatten(res, spec)

        xs = [torch.tensor(i) for i in range(3)]
        xsl = {
            "a": xs,
            "b": xs,
            "c": xs,
        }

        counter = CompileCounter()
        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl)
        real_out = fn(xsl)
        self.assertEqual(comp_out, real_out)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 9)

    def test_dynamic_one_hot(self):
        # one_hot is called without num_classes, so its output width depends
        # on the input values; that splits fn into two frames, each holding
        # one `x + 1` (frame_count == 2, op_count == 2).
        def fn(x):
            x = x + 1
            # graph break from data-dependent output shape
            x = torch.nn.functional.one_hot(x)
            x = x + 1
            return x

        inp = torch.arange(20) % 4
        counter = CompileCounter()
        real_out = fn(inp)
        comp_out = torch.compile(fn, backend=counter)(inp)
        self.assertEqual(comp_out, real_out)
        self.assertEqual(counter.frame_count, 2)
        self.assertEqual(counter.op_count, 2)

    def test_tracing_nested_py_tree_mixed_all(self):
        # Mixed list/tuple/dict nesting: "a" holds 3 leaves, "b" 6, and
        # "c" 6 + 3, for 18 clones in a single frame.
        import torch.utils._pytree as pytree

        def fn(xs):
            flat_xs, spec = pytree.tree_flatten(xs)
            res = [x.clone() for x in flat_xs]
            return pytree.tree_unflatten(res, spec)

        xs = [torch.tensor(i) for i in range(3)]
        xsa = (xs, xs)
        xsb = {"aa": xsa, "ab": xs}
        xsl = {
            "a": xs,
            "b": xsa,
            "c": xsb,
        }

        counter = CompileCounter()
        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl)
        real_out = fn(xsl)
        self.assertEqual(comp_out, real_out)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 18)

    def test_any_all_symnode(self):
        # any()/all() over booleans derived from a dynamic size. For sizes 16
        # and 18, t is True and f is False, so only the final branch runs and
        # one graph serves both. Size 5 flips t to False, which selects the
        # `x - 3` branch and forces a recompile.
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(x):
            t = x.size(0) >= 10
            f = x.size(0) >= 100
            if any([]) or any([f]) or any([f, f]):
                return x - 1
            if all([f]) or all([t, f]) or all([f, t]) or all([f, f]):
                return x - 2
            if not (all([]) and all([t]) and all([t, t])):
                return x - 3
            if not (any([t]) and any([t, f]) and any([f, t])):
                return x - 4
            return x + 1

        y1 = torch.randn(16)
        y2 = torch.randn(18)
        self.assertEqual(fn(y1), y1 + 1)
        self.assertEqual(fn(y2), y2 + 1)
        self.assertEqual(cnt.frame_count, 1)
        y3 = torch.randn(5)
        self.assertEqual(fn(y3), y3 - 3)
        self.assertEqual(cnt.frame_count, 2)

    def test_tracing_py_tree_tensor_subclass(self):
        # pytree.tree_flatten on a TwoTensor subclass reached through
        # checkpoint(use_reentrant=True), expected as one frame with 2 ops.
        import torch.utils._pytree as pytree
        from torch.testing._internal.two_tensor import TwoTensor
        from torch.utils.checkpoint import checkpoint

        def fn(xs):
            # NOTE(review): nested_xs is unused; it may deliberately exercise
            # building a list around the subclass. Confirm before removing.
            nested_xs = [[xs]]
            flat_xs, spec = pytree.tree_flatten(xs)
            return flat_xs[0].clone()

        # use checkpoint to trigger a "sourceless" tensor subclass
        def checkpoint_fn(xs):
            return checkpoint(fn, xs, use_reentrant=True)

        xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2))

        counter = CompileCounter()
        torch._dynamo.optimize(counter, nopython=True)(checkpoint_fn)(xs)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 2)

    def test_tracing_tree_map_only(self):
        # tree_map_only clones only torch.Tensor leaves; the "hi" string
        # leaves pass through untouched. 9 tensor leaves -> 9 clones.
        import torch.utils._pytree as pytree

        def fn(xs):
            def mapper(x):
                return x.clone()

            y = pytree.tree_map_only(torch.Tensor, mapper, xs)
            return y

        xs = [torch.tensor(i) for i in range(3)] + ["hi"]
        xsa = (xs, xs)
        xsb = {"aa": xsa, "ab": xs}

        counter = CompileCounter()
        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsb)
        real_out = fn(xsb)

        self.assertEqual(comp_out, real_out)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 9)
# ---- unbacked SymInt tests and basic set usage ----
    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_unbacked_symint(self):
        # lengths.tolist() yields data-dependent sizes under
        # capture_scalar_outputs. The torch._check* calls record facts about
        # each size before the sizes are passed to torch.split.
        @torch.compile(backend="eager")
        def f(lengths, values):
            sizes = lengths.tolist()
            for s in sizes:
                torch._check_is_size(s)
                torch._check(s >= 2)
                torch._check(s <= 100)
            return torch.split(values, sizes)

        f(torch.tensor([2, 3, 4]), torch.randn(9))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_unbacked_auto_functionalize_op(self):
        # A custom op declared as mutating "decoder", whose fake impl returns
        # a tensor with 4 data-dependent dims. Compiling it with fullgraph
        # (default backend) exercises auto-functionalization with unbacked
        # output sizes. The real kernel ignores `decoder`.
        @torch.library.custom_op(
            "mylib::mk_image", mutates_args=("decoder",), device_types=["cpu"]
        )
        def mk_image(decoder: Tensor) -> Tensor:
            return torch.randn(2, 3, 4, 5)

        @torch.library.register_fake("mylib::mk_image")
        def _(decoder: Tensor) -> Tensor:
            image_size = [torch.library.get_ctx().new_dynamic_size() for _ in range(4)]
            return torch.empty(image_size)

        @torch.compile(fullgraph=True)
        def f(x):
            return torch.ops.mylib.mk_image.default(x)

        x = torch.zeros(100, dtype=torch.int64)
        f(x)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_runtime_assert_replacement(self):
        # z comes from .item(), so torch._check(z == 3) must be enforced at
        # runtime; calling with y == 4 has to raise RuntimeError.
        @torch.compile(backend="aot_eager")
        def fn(x, y):
            z = y.item()
            torch._check(z == 3)
            return x + z

        fn(torch.randn(4), torch.tensor([3]))
        self.assertRaises(RuntimeError, lambda: fn(torch.randn(4), torch.tensor([4])))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_cat_unbacked(self):
        # cat of the 2-D x with a 1-D tensor whose length is data-dependent.
        # z == 0 succeeds, presumably via torch.cat's special case for empty
        # 1-D inputs. z == 1 is a genuine rank mismatch and must raise.
        @torch.compile(backend="eager")
        def fn(x, y):
            z = y.item()
            return torch.cat([x, torch.ones(z)])

        fn(torch.randn(2, 3), torch.tensor([0]))
        self.assertRaises(
            RuntimeError, lambda: fn(torch.randn(2, 3), torch.tensor([1]))
        )

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_aot_autograd_propagate_unbacked_symints_shape(self):
        # nonzero's output shape depends on the data (enabled by
        # capture_dynamic_output_shape_ops); compile it through AOTAutograd.
        @torch.compile(backend="aot_eager")
        def f(x):
            return torch.nonzero(x)

        f(torch.tensor([1, 0, 3, 2, 0]))

    def test_simple_set_usage(self):
        # A set literal of two tensor arguments, popped twice, traced in a
        # single nopython frame.
        def foo(x, y):
            setty = {x, y}
            return setty.pop() * setty.pop()

        counter = CompileCounter()
        foo = torch._dynamo.optimize(counter, nopython=True)(foo)
        x = torch.randn(10, 10)
        y = torch.randn(10, 10)
        foo(x, y)
        self.assertEqual(counter.frame_count, 1)
Worker def test_add_to_set(self): 8320*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 8321*da0073e9SAndroid Build Coastguard Worker setty = set() 8322*da0073e9SAndroid Build Coastguard Worker setty.add(x[0]) 8323*da0073e9SAndroid Build Coastguard Worker setty.add(x[1]) 8324*da0073e9SAndroid Build Coastguard Worker setty.add(x[2]) 8325*da0073e9SAndroid Build Coastguard Worker setty.add(y) 8326*da0073e9SAndroid Build Coastguard Worker return y * len(setty) 8327*da0073e9SAndroid Build Coastguard Worker 8328*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10) 8329*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2, 2) 8330*da0073e9SAndroid Build Coastguard Worker eager_result = foo([x, x, x, x, y], y) 8331*da0073e9SAndroid Build Coastguard Worker 8332*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8333*da0073e9SAndroid Build Coastguard Worker foo = torch._dynamo.optimize(counter, nopython=True)(foo) 8334*da0073e9SAndroid Build Coastguard Worker result = foo([x, x, x, x, y], y) 8335*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8336*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, eager_result) 8337*da0073e9SAndroid Build Coastguard Worker 8338*da0073e9SAndroid Build Coastguard Worker def test_iter_set(self): 8339*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 8340*da0073e9SAndroid Build Coastguard Worker setty = set() 8341*da0073e9SAndroid Build Coastguard Worker for t in x: 8342*da0073e9SAndroid Build Coastguard Worker setty.add(t) 8343*da0073e9SAndroid Build Coastguard Worker return y * len(setty) 8344*da0073e9SAndroid Build Coastguard Worker 8345*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10) 8346*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2, 2) 8347*da0073e9SAndroid Build Coastguard Worker eager_result = foo([x, x, x, x, y], y) 8348*da0073e9SAndroid Build Coastguard Worker 8349*da0073e9SAndroid Build Coastguard Worker 
counter = CompileCounter() 8350*da0073e9SAndroid Build Coastguard Worker foo = torch._dynamo.optimize(counter, nopython=True)(foo) 8351*da0073e9SAndroid Build Coastguard Worker result = foo([x, x, x, x, y], y) 8352*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8353*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, eager_result) 8354*da0073e9SAndroid Build Coastguard Worker 8355*da0073e9SAndroid Build Coastguard Worker def test_input_set_graph_break(self): 8356*da0073e9SAndroid Build Coastguard Worker def foo(x): 8357*da0073e9SAndroid Build Coastguard Worker return x.pop() * x.pop() 8358*da0073e9SAndroid Build Coastguard Worker 8359*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10) 8360*da0073e9SAndroid Build Coastguard Worker y = torch.randn(10, 10) 8361*da0073e9SAndroid Build Coastguard Worker 8362*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8363*da0073e9SAndroid Build Coastguard Worker 8364*da0073e9SAndroid Build Coastguard Worker inp = {x, x, x, x, y, y} 8365*da0073e9SAndroid Build Coastguard Worker foo = torch._dynamo.optimize(counter, nopython=True)(foo) 8366*da0073e9SAndroid Build Coastguard Worker 8367*da0073e9SAndroid Build Coastguard Worker # There's a lot of stuff about sets that cannot work without a good deal of exertion on our part. 8368*da0073e9SAndroid Build Coastguard Worker # Specifically, getting a set as input won't ever work with how GetItemSource works (Can't arbitrary access set contents) 8369*da0073e9SAndroid Build Coastguard Worker # and so the guard story for the objects passed into input just isn't there atm. 
8370*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 8371*da0073e9SAndroid Build Coastguard Worker torch._dynamo.exc.Unsupported, 8372*da0073e9SAndroid Build Coastguard Worker "^call_method UserDefinedObjectVariable\\(set\\).*", 8373*da0073e9SAndroid Build Coastguard Worker ): 8374*da0073e9SAndroid Build Coastguard Worker foo(inp) 8375*da0073e9SAndroid Build Coastguard Worker 8376*da0073e9SAndroid Build Coastguard Worker foo = torch._dynamo.optimize(counter, nopython=False)(foo) 8377*da0073e9SAndroid Build Coastguard Worker foo(inp) 8378*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8379*da0073e9SAndroid Build Coastguard Worker 8380*da0073e9SAndroid Build Coastguard Worker def test_reconstruct_set_across_graph_break(self): 8381*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 8382*da0073e9SAndroid Build Coastguard Worker setty = set() 8383*da0073e9SAndroid Build Coastguard Worker for t in x: 8384*da0073e9SAndroid Build Coastguard Worker setty.add(t) 8385*da0073e9SAndroid Build Coastguard Worker print("Break!") 8386*da0073e9SAndroid Build Coastguard Worker return y * len(setty) 8387*da0073e9SAndroid Build Coastguard Worker 8388*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10) 8389*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2, 2) 8390*da0073e9SAndroid Build Coastguard Worker 8391*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8392*da0073e9SAndroid Build Coastguard Worker foo = torch._dynamo.optimize(counter)(foo) 8393*da0073e9SAndroid Build Coastguard Worker result = foo([x, x, x, x, y], y) 8394*da0073e9SAndroid Build Coastguard Worker 8395*da0073e9SAndroid Build Coastguard Worker def test_set_aliasing_recompiles(self): 8396*da0073e9SAndroid Build Coastguard Worker g1 = torch.randn(10) 8397*da0073e9SAndroid Build Coastguard Worker g2 = torch.randn(10) 8398*da0073e9SAndroid Build Coastguard Worker g3 = torch.randn(10) 8399*da0073e9SAndroid Build 
Coastguard Worker g4 = torch.randn(10) 8400*da0073e9SAndroid Build Coastguard Worker 8401*da0073e9SAndroid Build Coastguard Worker def foo(a, b, c): 8402*da0073e9SAndroid Build Coastguard Worker myset = {g1, a, b, c} 8403*da0073e9SAndroid Build Coastguard Worker return a + len(myset) 8404*da0073e9SAndroid Build Coastguard Worker 8405*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8406*da0073e9SAndroid Build Coastguard Worker foo = torch._dynamo.optimize(counter)(foo) 8407*da0073e9SAndroid Build Coastguard Worker # first call with no aliasing 8408*da0073e9SAndroid Build Coastguard Worker foo(g2, g3, g4) 8409*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8410*da0073e9SAndroid Build Coastguard Worker 8411*da0073e9SAndroid Build Coastguard Worker # no aliasing again 8412*da0073e9SAndroid Build Coastguard Worker foo(g3, g2, g4) 8413*da0073e9SAndroid Build Coastguard Worker # assert no recompile 8414*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8415*da0073e9SAndroid Build Coastguard Worker 8416*da0073e9SAndroid Build Coastguard Worker # aliasing changes, we should recompile 8417*da0073e9SAndroid Build Coastguard Worker foo(g2, g2, g2) 8418*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 2) 8419*da0073e9SAndroid Build Coastguard Worker 8420*da0073e9SAndroid Build Coastguard Worker # same aliasing, different tensor 8421*da0073e9SAndroid Build Coastguard Worker foo(g3, g3, g3) 8422*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 2) 8423*da0073e9SAndroid Build Coastguard Worker 8424*da0073e9SAndroid Build Coastguard Worker # aliasing between global and arg, should recompile again 8425*da0073e9SAndroid Build Coastguard Worker foo(g1, g1, g1) 8426*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 3) 8427*da0073e9SAndroid Build Coastguard Worker 8428*da0073e9SAndroid Build Coastguard Worker # 
Reset 8429*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 8430*da0073e9SAndroid Build Coastguard Worker 8431*da0073e9SAndroid Build Coastguard Worker # aliasing between global and arg, first call 8432*da0073e9SAndroid Build Coastguard Worker foo(g1, g1, g1) 8433*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 4) 8434*da0073e9SAndroid Build Coastguard Worker 8435*da0073e9SAndroid Build Coastguard Worker # same aliasing, different tensor, all local, recompile 8436*da0073e9SAndroid Build Coastguard Worker foo(g3, g3, g3) 8437*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 5) 8438*da0073e9SAndroid Build Coastguard Worker 8439*da0073e9SAndroid Build Coastguard Worker # aliasing same tensor, we shouldn't recompile 8440*da0073e9SAndroid Build Coastguard Worker foo(g2, g2, g2) 8441*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 5) 8442*da0073e9SAndroid Build Coastguard Worker 8443*da0073e9SAndroid Build Coastguard Worker # No aliasing 8444*da0073e9SAndroid Build Coastguard Worker foo(g2, g3, g4) 8445*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 6) 8446*da0073e9SAndroid Build Coastguard Worker 8447*da0073e9SAndroid Build Coastguard Worker # No aliasing again 8448*da0073e9SAndroid Build Coastguard Worker foo(g3, g2, g4) 8449*da0073e9SAndroid Build Coastguard Worker # assert no recompile 8450*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 6) 8451*da0073e9SAndroid Build Coastguard Worker 8452*da0073e9SAndroid Build Coastguard Worker def test_str_format_return1(self): 8453*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 8454*da0073e9SAndroid Build Coastguard Worker def fn(img): 8455*da0073e9SAndroid Build Coastguard Worker x = torch.sin(img) 8456*da0073e9SAndroid Build Coastguard Worker y = f"shape {img.shape[-2:]} batch size {img.shape[0]}" 8457*da0073e9SAndroid 
Build Coastguard Worker return img + x, y 8458*da0073e9SAndroid Build Coastguard Worker 8459*da0073e9SAndroid Build Coastguard Worker img1 = torch.randn(1, 1, 8, 8) 8460*da0073e9SAndroid Build Coastguard Worker res, msg = fn(img1) 8461*da0073e9SAndroid Build Coastguard Worker self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1") 8462*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, img1 + torch.sin(img1)) 8463*da0073e9SAndroid Build Coastguard Worker 8464*da0073e9SAndroid Build Coastguard Worker def test_str_format_return2(self): 8465*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 8466*da0073e9SAndroid Build Coastguard Worker def fn(img): 8467*da0073e9SAndroid Build Coastguard Worker x = torch.sin(img) 8468*da0073e9SAndroid Build Coastguard Worker y = "shape {} batch size {y:.2f}".format(img.shape[-2:], y=img.shape[0]) 8469*da0073e9SAndroid Build Coastguard Worker return img + x, y 8470*da0073e9SAndroid Build Coastguard Worker 8471*da0073e9SAndroid Build Coastguard Worker img1 = torch.randn(1, 1, 8, 8) 8472*da0073e9SAndroid Build Coastguard Worker res, msg = fn(img1) 8473*da0073e9SAndroid Build Coastguard Worker self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1.00") 8474*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, img1 + torch.sin(img1)) 8475*da0073e9SAndroid Build Coastguard Worker 8476*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(capture_scalar_outputs=True) 8477*da0073e9SAndroid Build Coastguard Worker def test_validate_outputs_unbacked(self): 8478*da0073e9SAndroid Build Coastguard Worker class SillyCat(torch.autograd.Function): 8479*da0073e9SAndroid Build Coastguard Worker @staticmethod 8480*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x0, x1, i): 8481*da0073e9SAndroid Build Coastguard Worker ctx.save_for_backward(i) 8482*da0073e9SAndroid Build Coastguard Worker return torch.cat([x0, x1]) 8483*da0073e9SAndroid Build Coastguard 
Worker 8484*da0073e9SAndroid Build Coastguard Worker @staticmethod 8485*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad_out): 8486*da0073e9SAndroid Build Coastguard Worker (i,) = ctx.saved_tensors 8487*da0073e9SAndroid Build Coastguard Worker i0, i1 = i.tolist() 8488*da0073e9SAndroid Build Coastguard Worker g_x0, g_x1 = grad_out.split([i0, i1]) 8489*da0073e9SAndroid Build Coastguard Worker return g_x0, g_x1, None 8490*da0073e9SAndroid Build Coastguard Worker 8491*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="aot_eager", fullgraph=True) 8492*da0073e9SAndroid Build Coastguard Worker def f(x, i): 8493*da0073e9SAndroid Build Coastguard Worker i0, i1 = i.tolist() 8494*da0073e9SAndroid Build Coastguard Worker x0, x1 = x.split([i0, i1]) 8495*da0073e9SAndroid Build Coastguard Worker return SillyCat.apply(x0, x1, i) 8496*da0073e9SAndroid Build Coastguard Worker 8497*da0073e9SAndroid Build Coastguard Worker f(torch.randn(9, requires_grad=True), torch.tensor([3, 6])) 8498*da0073e9SAndroid Build Coastguard Worker 8499*da0073e9SAndroid Build Coastguard Worker def test_str_format_assert1(self): 8500*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 8501*da0073e9SAndroid Build Coastguard Worker def fn(img): 8502*da0073e9SAndroid Build Coastguard Worker x = torch.sin(img) 8503*da0073e9SAndroid Build Coastguard Worker val = x.shape[-2:] 8504*da0073e9SAndroid Build Coastguard Worker torch._assert(len(val) == 2, f"shape {img.shape}") 8505*da0073e9SAndroid Build Coastguard Worker return img + x 8506*da0073e9SAndroid Build Coastguard Worker 8507*da0073e9SAndroid Build Coastguard Worker img1 = torch.randn(1, 1, 8, 8) 8508*da0073e9SAndroid Build Coastguard Worker res = fn(img1) 8509*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, img1 + torch.sin(img1)) 8510*da0073e9SAndroid Build Coastguard Worker 8511*da0073e9SAndroid Build Coastguard Worker def test_str_format_assert2(self): 8512*da0073e9SAndroid 
Build Coastguard Worker cnt = CompileCounter() 8513*da0073e9SAndroid Build Coastguard Worker 8514*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt) 8515*da0073e9SAndroid Build Coastguard Worker def fn(img): 8516*da0073e9SAndroid Build Coastguard Worker x = torch.sin(img) 8517*da0073e9SAndroid Build Coastguard Worker torch._assert( 8518*da0073e9SAndroid Build Coastguard Worker img.shape[-2] == 8 and img.shape[-1] == 16, f"shape {img.shape}" 8519*da0073e9SAndroid Build Coastguard Worker ) 8520*da0073e9SAndroid Build Coastguard Worker return img + x 8521*da0073e9SAndroid Build Coastguard Worker 8522*da0073e9SAndroid Build Coastguard Worker img1 = torch.randn(1, 3, 8, 16) 8523*da0073e9SAndroid Build Coastguard Worker res = fn(img1) 8524*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, img1 + torch.sin(img1)) 8525*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 8526*da0073e9SAndroid Build Coastguard Worker 8527*da0073e9SAndroid Build Coastguard Worker # trigger a recompile and graph break 8528*da0073e9SAndroid Build Coastguard Worker img2 = torch.randn(1, 3, 8, 15) 8529*da0073e9SAndroid Build Coastguard Worker self.assertRaises(AssertionError, lambda: fn(img2)) 8530*da0073e9SAndroid Build Coastguard Worker 8531*da0073e9SAndroid Build Coastguard Worker def test_tolist_scalar(self): 8532*da0073e9SAndroid Build Coastguard Worker def fn(x): 8533*da0073e9SAndroid Build Coastguard Worker new_list = [] 8534*da0073e9SAndroid Build Coastguard Worker for i in x.tolist(): 8535*da0073e9SAndroid Build Coastguard Worker new_list.append(i * 4) 8536*da0073e9SAndroid Build Coastguard Worker return new_list 8537*da0073e9SAndroid Build Coastguard Worker 8538*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([3]) 8539*da0073e9SAndroid Build Coastguard Worker eager = fn(x) 8540*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8541*da0073e9SAndroid Build Coastguard Worker compiled = 
torch._dynamo.optimize(counter, nopython=True)(fn)(x) 8542*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 8543*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8544*da0073e9SAndroid Build Coastguard Worker 8545*da0073e9SAndroid Build Coastguard Worker def test_tolist_1d(self): 8546*da0073e9SAndroid Build Coastguard Worker def fn(x): 8547*da0073e9SAndroid Build Coastguard Worker new_list = [] 8548*da0073e9SAndroid Build Coastguard Worker for i in x.tolist(): 8549*da0073e9SAndroid Build Coastguard Worker new_list.append(i * 4) 8550*da0073e9SAndroid Build Coastguard Worker return new_list 8551*da0073e9SAndroid Build Coastguard Worker 8552*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([2, 1]) 8553*da0073e9SAndroid Build Coastguard Worker eager = fn(x) 8554*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8555*da0073e9SAndroid Build Coastguard Worker compiled = torch._dynamo.optimize(counter, nopython=True)(fn)(x) 8556*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 8557*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8558*da0073e9SAndroid Build Coastguard Worker 8559*da0073e9SAndroid Build Coastguard Worker def test_tolist_kd(self): 8560*da0073e9SAndroid Build Coastguard Worker def fn(x): 8561*da0073e9SAndroid Build Coastguard Worker new_list = [] 8562*da0073e9SAndroid Build Coastguard Worker for i in x.tolist(): 8563*da0073e9SAndroid Build Coastguard Worker new_list.append(i * 4) 8564*da0073e9SAndroid Build Coastguard Worker return new_list 8565*da0073e9SAndroid Build Coastguard Worker 8566*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[[2, 1], [2, 1], [2, 1]], [[2, 1], [2, 1], [2, 1]]]) 8567*da0073e9SAndroid Build Coastguard Worker eager = fn(x) 8568*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8569*da0073e9SAndroid Build Coastguard Worker compiled = torch._dynamo.optimize(counter, 
nopython=True)(fn)(x) 8570*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 8571*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8572*da0073e9SAndroid Build Coastguard Worker 8573*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "specialize_int", True) 8574*da0073e9SAndroid Build Coastguard Worker def test_tolist_0d(self): 8575*da0073e9SAndroid Build Coastguard Worker def fn(x): 8576*da0073e9SAndroid Build Coastguard Worker new_list = [] 8577*da0073e9SAndroid Build Coastguard Worker i = x.tolist() 8578*da0073e9SAndroid Build Coastguard Worker new_list.append(i * 4) 8579*da0073e9SAndroid Build Coastguard Worker return new_list 8580*da0073e9SAndroid Build Coastguard Worker 8581*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(42) 8582*da0073e9SAndroid Build Coastguard Worker eager = fn(x) 8583*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8584*da0073e9SAndroid Build Coastguard Worker compiled = torch._dynamo.optimize(counter, nopython=True)(fn)(x) 8585*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 8586*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8587*da0073e9SAndroid Build Coastguard Worker 8588*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", False) 8589*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) 8590*da0073e9SAndroid Build Coastguard Worker def test_tolist_kd_dynamic(self): 8591*da0073e9SAndroid Build Coastguard Worker def fn(x): 8592*da0073e9SAndroid Build Coastguard Worker new_list = [] 8593*da0073e9SAndroid Build Coastguard Worker i = x.tolist() 8594*da0073e9SAndroid Build Coastguard Worker new_list.append(i * 4) 8595*da0073e9SAndroid Build Coastguard Worker return new_list 8596*da0073e9SAndroid Build Coastguard Worker 8597*da0073e9SAndroid Build Coastguard 
Worker x = torch.randint(3, 5, [5, 5]) 8598*da0073e9SAndroid Build Coastguard Worker eager = fn(x) 8599*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8600*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(counter, nopython=True)(fn) 8601*da0073e9SAndroid Build Coastguard Worker compiled = compiled_fn(x) 8602*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 8603*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8604*da0073e9SAndroid Build Coastguard Worker 8605*da0073e9SAndroid Build Coastguard Worker # Value change, no recompiles 8606*da0073e9SAndroid Build Coastguard Worker x = torch.randint(7, 9, [5, 5]) 8607*da0073e9SAndroid Build Coastguard Worker compiled_fn(x) 8608*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8609*da0073e9SAndroid Build Coastguard Worker 8610*da0073e9SAndroid Build Coastguard Worker # Size change, forced recompiles 8611*da0073e9SAndroid Build Coastguard Worker x = torch.randint(3, 5, [3, 3]) 8612*da0073e9SAndroid Build Coastguard Worker compiled_fn(x) 8613*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 2) 8614*da0073e9SAndroid Build Coastguard Worker 8615*da0073e9SAndroid Build Coastguard Worker def test_tolist_float(self): 8616*da0073e9SAndroid Build Coastguard Worker def fn(x): 8617*da0073e9SAndroid Build Coastguard Worker new_list = [] 8618*da0073e9SAndroid Build Coastguard Worker for i in x.tolist(): 8619*da0073e9SAndroid Build Coastguard Worker new_list.append(i * 4) 8620*da0073e9SAndroid Build Coastguard Worker return new_list 8621*da0073e9SAndroid Build Coastguard Worker 8622*da0073e9SAndroid Build Coastguard Worker x = torch.tensor( 8623*da0073e9SAndroid Build Coastguard Worker [[[2.0, 1.0], [2.0, 1.0], [2.0, 1.0]], [[2.0, 1.0], [2.0, 1.0], [2.0, 1.0]]] 8624*da0073e9SAndroid Build Coastguard Worker ) 8625*da0073e9SAndroid Build Coastguard Worker eager = fn(x) 
8626*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8627*da0073e9SAndroid Build Coastguard Worker compiled = torch._dynamo.optimize(counter)(fn)(x) 8628*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 8629*da0073e9SAndroid Build Coastguard Worker # Nothing to compile here 8630*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 0) 8631*da0073e9SAndroid Build Coastguard Worker 8632*da0073e9SAndroid Build Coastguard Worker def test_inline_closure_not_loaded_by_parent(self): 8633*da0073e9SAndroid Build Coastguard Worker def outer(a): 8634*da0073e9SAndroid Build Coastguard Worker return a + 1 8635*da0073e9SAndroid Build Coastguard Worker 8636*da0073e9SAndroid Build Coastguard Worker def indirect(x): 8637*da0073e9SAndroid Build Coastguard Worker return direct(x) 8638*da0073e9SAndroid Build Coastguard Worker 8639*da0073e9SAndroid Build Coastguard Worker def direct(x): 8640*da0073e9SAndroid Build Coastguard Worker def deep2(c): 8641*da0073e9SAndroid Build Coastguard Worker return outer(c) 8642*da0073e9SAndroid Build Coastguard Worker 8643*da0073e9SAndroid Build Coastguard Worker def deep(c): 8644*da0073e9SAndroid Build Coastguard Worker return deep2(c) 8645*da0073e9SAndroid Build Coastguard Worker 8646*da0073e9SAndroid Build Coastguard Worker return deep(x) 8647*da0073e9SAndroid Build Coastguard Worker 8648*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 8649*da0073e9SAndroid Build Coastguard Worker eager = indirect(x) 8650*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8651*da0073e9SAndroid Build Coastguard Worker compiled = torch._dynamo.optimize(counter)(indirect)(x) 8652*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 8653*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8654*da0073e9SAndroid Build Coastguard Worker 8655*da0073e9SAndroid Build Coastguard Worker def test_deque_input(self): 
8656*da0073e9SAndroid Build Coastguard Worker a = torch.randn([2, 3]) 8657*da0073e9SAndroid Build Coastguard Worker b = torch.randn([2, 3]) 8658*da0073e9SAndroid Build Coastguard Worker d1 = collections.deque([a, b]) 8659*da0073e9SAndroid Build Coastguard Worker d1.insert(0, "foo") 8660*da0073e9SAndroid Build Coastguard Worker 8661*da0073e9SAndroid Build Coastguard Worker d2 = collections.deque([a, b]) 8662*da0073e9SAndroid Build Coastguard Worker d2.insert(0, "foo") 8663*da0073e9SAndroid Build Coastguard Worker 8664*da0073e9SAndroid Build Coastguard Worker def fn(q): 8665*da0073e9SAndroid Build Coastguard Worker a = q.pop() 8666*da0073e9SAndroid Build Coastguard Worker b = q.pop() 8667*da0073e9SAndroid Build Coastguard Worker return a * b 8668*da0073e9SAndroid Build Coastguard Worker 8669*da0073e9SAndroid Build Coastguard Worker eager = fn(d1) 8670*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8671*da0073e9SAndroid Build Coastguard Worker compiled = torch._dynamo.optimize(counter)(fn)(d2) 8672*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 8673*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8674*da0073e9SAndroid Build Coastguard Worker 8675*da0073e9SAndroid Build Coastguard Worker def test_deque_append_left(self): 8676*da0073e9SAndroid Build Coastguard Worker d1 = collections.deque([10, 10]) 8677*da0073e9SAndroid Build Coastguard Worker d1.insert(0, "foo") 8678*da0073e9SAndroid Build Coastguard Worker 8679*da0073e9SAndroid Build Coastguard Worker d2 = collections.deque([10, 10]) 8680*da0073e9SAndroid Build Coastguard Worker d2.insert(0, "foo") 8681*da0073e9SAndroid Build Coastguard Worker 8682*da0073e9SAndroid Build Coastguard Worker def fn(q, a, b): 8683*da0073e9SAndroid Build Coastguard Worker q.appendleft(a) 8684*da0073e9SAndroid Build Coastguard Worker q.appendleft(b) 8685*da0073e9SAndroid Build Coastguard Worker return q.popleft() * q.popleft() 8686*da0073e9SAndroid Build 
Coastguard Worker 8687*da0073e9SAndroid Build Coastguard Worker a = torch.randn([3, 3]) 8688*da0073e9SAndroid Build Coastguard Worker b = torch.randn([3, 3]) 8689*da0073e9SAndroid Build Coastguard Worker eager = fn(d1, a, b) 8690*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8691*da0073e9SAndroid Build Coastguard Worker compiled = torch._dynamo.optimize(counter)(fn)(d2, a, b) 8692*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 8693*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8694*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(compiled, torch.Tensor)) 8695*da0073e9SAndroid Build Coastguard Worker 8696*da0073e9SAndroid Build Coastguard Worker def test_yield_from(self): 8697*da0073e9SAndroid Build Coastguard Worker def yield_from_fn(t_list, k): 8698*da0073e9SAndroid Build Coastguard Worker def yield_from_gen(l): 8699*da0073e9SAndroid Build Coastguard Worker l2 = [t * k for t in l] 8700*da0073e9SAndroid Build Coastguard Worker yield from l2 8701*da0073e9SAndroid Build Coastguard Worker 8702*da0073e9SAndroid Build Coastguard Worker return [t * k for t in yield_from_gen(t_list)] 8703*da0073e9SAndroid Build Coastguard Worker 8704*da0073e9SAndroid Build Coastguard Worker t_list = [torch.randn([2, 3]) for _ in range(3)] 8705*da0073e9SAndroid Build Coastguard Worker eager = yield_from_fn(t_list, 2) 8706*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8707*da0073e9SAndroid Build Coastguard Worker compiled = torch._dynamo.optimize(counter)(yield_from_fn)(t_list, 2) 8708*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 8709*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8710*da0073e9SAndroid Build Coastguard Worker 8711*da0073e9SAndroid Build Coastguard Worker def test_yield_from_in_a_loop(self): 8712*da0073e9SAndroid Build Coastguard Worker def gen2(): 8713*da0073e9SAndroid Build Coastguard 
Worker yield 1 8714*da0073e9SAndroid Build Coastguard Worker 8715*da0073e9SAndroid Build Coastguard Worker def gen1(): 8716*da0073e9SAndroid Build Coastguard Worker for value in range(5): 8717*da0073e9SAndroid Build Coastguard Worker yield from gen2() 8718*da0073e9SAndroid Build Coastguard Worker 8719*da0073e9SAndroid Build Coastguard Worker def fn(x): 8720*da0073e9SAndroid Build Coastguard Worker c = 0 8721*da0073e9SAndroid Build Coastguard Worker for i in gen1(): 8722*da0073e9SAndroid Build Coastguard Worker c = c + i 8723*da0073e9SAndroid Build Coastguard Worker return x + c 8724*da0073e9SAndroid Build Coastguard Worker 8725*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager") 8726*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(4) 8727*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), opt_fn(x)) 8728*da0073e9SAndroid Build Coastguard Worker 8729*da0073e9SAndroid Build Coastguard Worker def test_yield_gen_and_from(self): 8730*da0073e9SAndroid Build Coastguard Worker def populate_and_multiply_sequence(n, multiplier): 8731*da0073e9SAndroid Build Coastguard Worker # Inline generator 8732*da0073e9SAndroid Build Coastguard Worker def tensor_generator(): 8733*da0073e9SAndroid Build Coastguard Worker for i in range(n): 8734*da0073e9SAndroid Build Coastguard Worker yield torch.tensor([i]) 8735*da0073e9SAndroid Build Coastguard Worker 8736*da0073e9SAndroid Build Coastguard Worker # Use 'yield from' to iterate over tensors and multiply 8737*da0073e9SAndroid Build Coastguard Worker t_list = [tensor * multiplier for tensor in tensor_generator()] 8738*da0073e9SAndroid Build Coastguard Worker 8739*da0073e9SAndroid Build Coastguard Worker def yield_from_gen(): 8740*da0073e9SAndroid Build Coastguard Worker yield from t_list 8741*da0073e9SAndroid Build Coastguard Worker 8742*da0073e9SAndroid Build Coastguard Worker return [t for t in yield_from_gen()] 8743*da0073e9SAndroid Build Coastguard Worker 8744*da0073e9SAndroid Build 
Coastguard Worker multiplier = torch.tensor([10]) 8745*da0073e9SAndroid Build Coastguard Worker eager = populate_and_multiply_sequence(5, multiplier) 8746*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8747*da0073e9SAndroid Build Coastguard Worker compiled = torch._dynamo.optimize(counter)(populate_and_multiply_sequence)( 8748*da0073e9SAndroid Build Coastguard Worker 5, multiplier 8749*da0073e9SAndroid Build Coastguard Worker ) 8750*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 8751*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 8752*da0073e9SAndroid Build Coastguard Worker 8753*da0073e9SAndroid Build Coastguard Worker def test_yield_from_user_stop_iteration(self): 8754*da0073e9SAndroid Build Coastguard Worker class MyIter: 8755*da0073e9SAndroid Build Coastguard Worker def __init__(self, seq): 8756*da0073e9SAndroid Build Coastguard Worker self.seq = seq 8757*da0073e9SAndroid Build Coastguard Worker self.index = 0 8758*da0073e9SAndroid Build Coastguard Worker 8759*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 8760*da0073e9SAndroid Build Coastguard Worker return self 8761*da0073e9SAndroid Build Coastguard Worker 8762*da0073e9SAndroid Build Coastguard Worker def __next__(self): 8763*da0073e9SAndroid Build Coastguard Worker self.index += 1 8764*da0073e9SAndroid Build Coastguard Worker if self.index <= len(self.seq): 8765*da0073e9SAndroid Build Coastguard Worker return self.seq[self.index - 1] 8766*da0073e9SAndroid Build Coastguard Worker raise StopIteration(self.index) 8767*da0073e9SAndroid Build Coastguard Worker 8768*da0073e9SAndroid Build Coastguard Worker def yield_from_iter_fn(seq): 8769*da0073e9SAndroid Build Coastguard Worker def gen(seq): 8770*da0073e9SAndroid Build Coastguard Worker yield from MyIter(seq) 8771*da0073e9SAndroid Build Coastguard Worker 8772*da0073e9SAndroid Build Coastguard Worker return [i for i in gen(seq)] 8773*da0073e9SAndroid Build Coastguard 
Worker 8774*da0073e9SAndroid Build Coastguard Worker seq = [torch.randn([2, 3]) for _ in range(3)] 8775*da0073e9SAndroid Build Coastguard Worker eager = yield_from_iter_fn(seq) 8776*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8777*da0073e9SAndroid Build Coastguard Worker compiled = torch._dynamo.optimize(counter)(yield_from_iter_fn)(seq) 8778*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 8779*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 0) 8780*da0073e9SAndroid Build Coastguard Worker 8781*da0073e9SAndroid Build Coastguard Worker def test_yield_send_to_subgenerator_graph_break(self): 8782*da0073e9SAndroid Build Coastguard Worker def subgenerator(tensor): 8783*da0073e9SAndroid Build Coastguard Worker multiplier = yield 8784*da0073e9SAndroid Build Coastguard Worker yield tensor * multiplier 8785*da0073e9SAndroid Build Coastguard Worker 8786*da0073e9SAndroid Build Coastguard Worker def main_generator(t_list): 8787*da0073e9SAndroid Build Coastguard Worker for tensor in t_list: 8788*da0073e9SAndroid Build Coastguard Worker subgen = subgenerator(tensor) 8789*da0073e9SAndroid Build Coastguard Worker next(subgen) 8790*da0073e9SAndroid Build Coastguard Worker yield from subgen.send(torch.tensor([10])) 8791*da0073e9SAndroid Build Coastguard Worker 8792*da0073e9SAndroid Build Coastguard Worker t_list = [torch.tensor([i]) for i in range(5)] 8793*da0073e9SAndroid Build Coastguard Worker eager = list(main_generator(t_list)) 8794*da0073e9SAndroid Build Coastguard Worker 8795*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8796*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(counter)(main_generator) 8797*da0073e9SAndroid Build Coastguard Worker compiled = list(compiled_fn(t_list)) 8798*da0073e9SAndroid Build Coastguard Worker 8799*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 8800*da0073e9SAndroid Build Coastguard 
Worker self.assertEqual(counter.frame_count, 0) 8801*da0073e9SAndroid Build Coastguard Worker 8802*da0073e9SAndroid Build Coastguard Worker def test_derpy_nn_module_usage(self): 8803*da0073e9SAndroid Build Coastguard Worker def ff1(x): 8804*da0073e9SAndroid Build Coastguard Worker self = mod1 8805*da0073e9SAndroid Build Coastguard Worker return torch.sigmoid(self.mod2(x) + self.param1) 8806*da0073e9SAndroid Build Coastguard Worker 8807*da0073e9SAndroid Build Coastguard Worker def ff2(x): 8808*da0073e9SAndroid Build Coastguard Worker self = mod2 8809*da0073e9SAndroid Build Coastguard Worker return torch.cos(torch.sin(x) * self.param2 + 10) 8810*da0073e9SAndroid Build Coastguard Worker 8811*da0073e9SAndroid Build Coastguard Worker mod1 = torch.nn.Module() 8812*da0073e9SAndroid Build Coastguard Worker mod2 = torch.nn.Module() 8813*da0073e9SAndroid Build Coastguard Worker mod1.register_module("mod2", mod2) 8814*da0073e9SAndroid Build Coastguard Worker mod1.register_parameter("param1", torch.nn.Parameter(torch.randn(10))) 8815*da0073e9SAndroid Build Coastguard Worker mod1.forward = ff1 8816*da0073e9SAndroid Build Coastguard Worker mod2.register_parameter("param2", torch.nn.Parameter(torch.randn(10))) 8817*da0073e9SAndroid Build Coastguard Worker mod2.forward = ff2 8818*da0073e9SAndroid Build Coastguard Worker mod1.eval() 8819*da0073e9SAndroid Build Coastguard Worker 8820*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10) 8821*da0073e9SAndroid Build Coastguard Worker expected = mod1(x) 8822*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 8823*da0073e9SAndroid Build Coastguard Worker actual = torch.compile(mod1, backend=counter, fullgraph=True)(x) 8824*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 8825*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.op_count, 6) 8826*da0073e9SAndroid Build Coastguard Worker 8827*da0073e9SAndroid Build Coastguard Worker def test_default_args_device_dtype(self): 
8828*da0073e9SAndroid Build Coastguard Worker class Foo: 8829*da0073e9SAndroid Build Coastguard Worker def __init__( 8830*da0073e9SAndroid Build Coastguard Worker self, 8831*da0073e9SAndroid Build Coastguard Worker dtype: torch.dtype = torch.float16, 8832*da0073e9SAndroid Build Coastguard Worker device: torch.device = torch.device("cpu"), 8833*da0073e9SAndroid Build Coastguard Worker ) -> None: 8834*da0073e9SAndroid Build Coastguard Worker self.value = torch.tensor(10, dtype=dtype, device=device) 8835*da0073e9SAndroid Build Coastguard Worker 8836*da0073e9SAndroid Build Coastguard Worker def fn(): 8837*da0073e9SAndroid Build Coastguard Worker return Foo().value + 1 8838*da0073e9SAndroid Build Coastguard Worker 8839*da0073e9SAndroid Build Coastguard Worker opt_func = torch._dynamo.optimize("eager", nopython=True)(fn) 8840*da0073e9SAndroid Build Coastguard Worker ref = fn() 8841*da0073e9SAndroid Build Coastguard Worker res = opt_func() 8842*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 8843*da0073e9SAndroid Build Coastguard Worker 8844*da0073e9SAndroid Build Coastguard Worker def test_torch_device_python_type(self): 8845*da0073e9SAndroid Build Coastguard Worker for device, device_type, index in [ 8846*da0073e9SAndroid Build Coastguard Worker ("cpu", "cpu", None), 8847*da0073e9SAndroid Build Coastguard Worker ("cuda:0", "cuda", 0), 8848*da0073e9SAndroid Build Coastguard Worker ]: 8849*da0073e9SAndroid Build Coastguard Worker if device == "cuda:0" and not TEST_CUDA: 8850*da0073e9SAndroid Build Coastguard Worker continue 8851*da0073e9SAndroid Build Coastguard Worker 8852*da0073e9SAndroid Build Coastguard Worker def fn(target): 8853*da0073e9SAndroid Build Coastguard Worker target_device = target.device 8854*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(2, 3, device=target_device) 8855*da0073e9SAndroid Build Coastguard Worker # Constant assert at trace time 8856*da0073e9SAndroid Build Coastguard Worker assert isinstance(target_device, 
torch.device) 8857*da0073e9SAndroid Build Coastguard Worker assert target_device.type == device_type 8858*da0073e9SAndroid Build Coastguard Worker assert target_device.index == index 8859*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(2, 3, device=target_device) 8860*da0073e9SAndroid Build Coastguard Worker c = torch.zeros(2, 3, device=target_device) 8861*da0073e9SAndroid Build Coastguard Worker return a + b + c 8862*da0073e9SAndroid Build Coastguard Worker 8863*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.variables import ConstantVariable 8864*da0073e9SAndroid Build Coastguard Worker 8865*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 8866*da0073e9SAndroid Build Coastguard Worker expected_variable = ConstantVariable(device) 8867*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_variable.python_type(), type(device)) 8868*da0073e9SAndroid Build Coastguard Worker 8869*da0073e9SAndroid Build Coastguard Worker opt_func = torch._dynamo.optimize("eager", nopython=True)(fn) 8870*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([2, 3], device=device) 8871*da0073e9SAndroid Build Coastguard Worker res = opt_func(a) 8872*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(res, torch.Tensor) 8873*da0073e9SAndroid Build Coastguard Worker 8874*da0073e9SAndroid Build Coastguard Worker def test_torch_dtype_python_type(self): 8875*da0073e9SAndroid Build Coastguard Worker def fn(target): 8876*da0073e9SAndroid Build Coastguard Worker target_dtype = target.dtype 8877*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(2, 3, dtype=target_dtype) 8878*da0073e9SAndroid Build Coastguard Worker # Constant assert at trace time 8879*da0073e9SAndroid Build Coastguard Worker assert isinstance(target_dtype, torch.dtype) 8880*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(2, 3, dtype=target_dtype) 8881*da0073e9SAndroid Build Coastguard Worker c = torch.zeros(2, 3, dtype=target_dtype) 
8882*da0073e9SAndroid Build Coastguard Worker return a + b + c 8883*da0073e9SAndroid Build Coastguard Worker 8884*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.variables import ConstantVariable 8885*da0073e9SAndroid Build Coastguard Worker 8886*da0073e9SAndroid Build Coastguard Worker dtype = torch.float16 8887*da0073e9SAndroid Build Coastguard Worker expected_variable = ConstantVariable(dtype) 8888*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_variable.python_type(), type(dtype)) 8889*da0073e9SAndroid Build Coastguard Worker 8890*da0073e9SAndroid Build Coastguard Worker opt_func = torch._dynamo.optimize("eager", nopython=True)(fn) 8891*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([2, 3], dtype=dtype) 8892*da0073e9SAndroid Build Coastguard Worker res = opt_func(a) 8893*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(res, torch.Tensor) 8894*da0073e9SAndroid Build Coastguard Worker 8895*da0073e9SAndroid Build Coastguard Worker def test_itertools_repeat(self): 8896*da0073e9SAndroid Build Coastguard Worker counters.clear() 8897*da0073e9SAndroid Build Coastguard Worker 8898*da0073e9SAndroid Build Coastguard Worker def fn(x): 8899*da0073e9SAndroid Build Coastguard Worker r = itertools.repeat(100.0, 5) 8900*da0073e9SAndroid Build Coastguard Worker for i in r: 8901*da0073e9SAndroid Build Coastguard Worker x += i 8902*da0073e9SAndroid Build Coastguard Worker return x 8903*da0073e9SAndroid Build Coastguard Worker 8904*da0073e9SAndroid Build Coastguard Worker x = torch.randn([2, 5]) 8905*da0073e9SAndroid Build Coastguard Worker eager = fn(x) 8906*da0073e9SAndroid Build Coastguard Worker 8907*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 8908*da0073e9SAndroid Build Coastguard Worker compiled = compiled_fn(x) 8909*da0073e9SAndroid Build Coastguard Worker 8910*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(eager), list(compiled)) 
8911*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 0) 8912*da0073e9SAndroid Build Coastguard Worker 8913*da0073e9SAndroid Build Coastguard Worker def test_itertools_infinite_repeat(self): 8914*da0073e9SAndroid Build Coastguard Worker counters.clear() 8915*da0073e9SAndroid Build Coastguard Worker 8916*da0073e9SAndroid Build Coastguard Worker def fn(x): 8917*da0073e9SAndroid Build Coastguard Worker r = itertools.repeat(100.0) 8918*da0073e9SAndroid Build Coastguard Worker idx = 0 8919*da0073e9SAndroid Build Coastguard Worker for i in r: 8920*da0073e9SAndroid Build Coastguard Worker x += i 8921*da0073e9SAndroid Build Coastguard Worker idx += 1 8922*da0073e9SAndroid Build Coastguard Worker if idx > 10: 8923*da0073e9SAndroid Build Coastguard Worker break 8924*da0073e9SAndroid Build Coastguard Worker return x 8925*da0073e9SAndroid Build Coastguard Worker 8926*da0073e9SAndroid Build Coastguard Worker x = torch.randn([2, 5]) 8927*da0073e9SAndroid Build Coastguard Worker eager = fn(x) 8928*da0073e9SAndroid Build Coastguard Worker 8929*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 8930*da0073e9SAndroid Build Coastguard Worker compiled = compiled_fn(x) 8931*da0073e9SAndroid Build Coastguard Worker 8932*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(eager), list(compiled)) 8933*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 0) 8934*da0073e9SAndroid Build Coastguard Worker 8935*da0073e9SAndroid Build Coastguard Worker def test_itertools_infinite_repeat_mutation(self): 8936*da0073e9SAndroid Build Coastguard Worker counters.clear() 8937*da0073e9SAndroid Build Coastguard Worker 8938*da0073e9SAndroid Build Coastguard Worker def fn(x): 8939*da0073e9SAndroid Build Coastguard Worker r = itertools.repeat(x) 8940*da0073e9SAndroid Build Coastguard Worker idx = 0 8941*da0073e9SAndroid Build Coastguard Worker for i in r: 
8942*da0073e9SAndroid Build Coastguard Worker x += i 8943*da0073e9SAndroid Build Coastguard Worker i += 1 8944*da0073e9SAndroid Build Coastguard Worker idx += 1 8945*da0073e9SAndroid Build Coastguard Worker if idx > 10: 8946*da0073e9SAndroid Build Coastguard Worker break 8947*da0073e9SAndroid Build Coastguard Worker return x 8948*da0073e9SAndroid Build Coastguard Worker 8949*da0073e9SAndroid Build Coastguard Worker x = torch.randn([2, 5]) 8950*da0073e9SAndroid Build Coastguard Worker eager = fn(x) 8951*da0073e9SAndroid Build Coastguard Worker 8952*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 8953*da0073e9SAndroid Build Coastguard Worker compiled = compiled_fn(x) 8954*da0073e9SAndroid Build Coastguard Worker 8955*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(eager), list(compiled)) 8956*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 0) 8957*da0073e9SAndroid Build Coastguard Worker 8958*da0073e9SAndroid Build Coastguard Worker def test_itertools_infinite_count(self): 8959*da0073e9SAndroid Build Coastguard Worker for args in ([], [10], [5, -1]): 8960*da0073e9SAndroid Build Coastguard Worker counters.clear() 8961*da0073e9SAndroid Build Coastguard Worker 8962*da0073e9SAndroid Build Coastguard Worker def fn(x): 8963*da0073e9SAndroid Build Coastguard Worker r = itertools.count(*args) 8964*da0073e9SAndroid Build Coastguard Worker idx = 0 8965*da0073e9SAndroid Build Coastguard Worker for i in r: 8966*da0073e9SAndroid Build Coastguard Worker x += i 8967*da0073e9SAndroid Build Coastguard Worker idx += 1 8968*da0073e9SAndroid Build Coastguard Worker if idx > 10: 8969*da0073e9SAndroid Build Coastguard Worker break 8970*da0073e9SAndroid Build Coastguard Worker return x 8971*da0073e9SAndroid Build Coastguard Worker 8972*da0073e9SAndroid Build Coastguard Worker x = torch.randn([2, 5]) 8973*da0073e9SAndroid Build Coastguard Worker eager = fn(x) 
8974*da0073e9SAndroid Build Coastguard Worker 8975*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 8976*da0073e9SAndroid Build Coastguard Worker compiled = compiled_fn(x) 8977*da0073e9SAndroid Build Coastguard Worker 8978*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(eager), list(compiled)) 8979*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 0) 8980*da0073e9SAndroid Build Coastguard Worker 8981*da0073e9SAndroid Build Coastguard Worker def test_itertools_infinite_cycle(self): 8982*da0073e9SAndroid Build Coastguard Worker counters.clear() 8983*da0073e9SAndroid Build Coastguard Worker 8984*da0073e9SAndroid Build Coastguard Worker def fn(x): 8985*da0073e9SAndroid Build Coastguard Worker for iterator in ( 8986*da0073e9SAndroid Build Coastguard Worker iter([]), 8987*da0073e9SAndroid Build Coastguard Worker iter([10, 11.0]), 8988*da0073e9SAndroid Build Coastguard Worker itertools.repeat(-1, 3), 8989*da0073e9SAndroid Build Coastguard Worker itertools.count(10), 8990*da0073e9SAndroid Build Coastguard Worker ): 8991*da0073e9SAndroid Build Coastguard Worker r = itertools.cycle(iterator) 8992*da0073e9SAndroid Build Coastguard Worker idx = 0 8993*da0073e9SAndroid Build Coastguard Worker x += 1 8994*da0073e9SAndroid Build Coastguard Worker for i in r: 8995*da0073e9SAndroid Build Coastguard Worker x += i 8996*da0073e9SAndroid Build Coastguard Worker idx += 1 8997*da0073e9SAndroid Build Coastguard Worker if idx > 10: 8998*da0073e9SAndroid Build Coastguard Worker break 8999*da0073e9SAndroid Build Coastguard Worker return x 9000*da0073e9SAndroid Build Coastguard Worker 9001*da0073e9SAndroid Build Coastguard Worker x = torch.randn([2, 5]) 9002*da0073e9SAndroid Build Coastguard Worker eager = fn(x) 9003*da0073e9SAndroid Build Coastguard Worker 9004*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 
9005*da0073e9SAndroid Build Coastguard Worker compiled = compiled_fn(x) 9006*da0073e9SAndroid Build Coastguard Worker 9007*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(eager), list(compiled)) 9008*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 0) 9009*da0073e9SAndroid Build Coastguard Worker 9010*da0073e9SAndroid Build Coastguard Worker def test_itertools_accumulate_symint_default_sum(self): 9011*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/110287 9012*da0073e9SAndroid Build Coastguard Worker counters.clear() 9013*da0073e9SAndroid Build Coastguard Worker 9014*da0073e9SAndroid Build Coastguard Worker def fn(x): 9015*da0073e9SAndroid Build Coastguard Worker r = itertools.accumulate([x.size(0), x.size(1)]) 9016*da0073e9SAndroid Build Coastguard Worker for i in r: 9017*da0073e9SAndroid Build Coastguard Worker x *= i 9018*da0073e9SAndroid Build Coastguard Worker return x 9019*da0073e9SAndroid Build Coastguard Worker 9020*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3) 9021*da0073e9SAndroid Build Coastguard Worker eager = fn(x) 9022*da0073e9SAndroid Build Coastguard Worker 9023*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9024*da0073e9SAndroid Build Coastguard Worker compiled = compiled_fn(x) 9025*da0073e9SAndroid Build Coastguard Worker 9026*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(eager), list(compiled)) 9027*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 0) 9028*da0073e9SAndroid Build Coastguard Worker 9029*da0073e9SAndroid Build Coastguard Worker def test_itertools_accumulate_tensors_default_sum(self): 9030*da0073e9SAndroid Build Coastguard Worker counters.clear() 9031*da0073e9SAndroid Build Coastguard Worker 9032*da0073e9SAndroid Build Coastguard Worker def fn(a, b, c, d, x): 9033*da0073e9SAndroid Build Coastguard Worker l = 
[a, b, c, d, x] 9034*da0073e9SAndroid Build Coastguard Worker for i, t in enumerate(l): 9035*da0073e9SAndroid Build Coastguard Worker l[i] = t * x 9036*da0073e9SAndroid Build Coastguard Worker return itertools.accumulate(l) 9037*da0073e9SAndroid Build Coastguard Worker 9038*da0073e9SAndroid Build Coastguard Worker t_list = [torch.tensor([i + 1]) for i in range(4)] 9039*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[1, 2], [3, 4]]) 9040*da0073e9SAndroid Build Coastguard Worker eager = fn(*t_list, x) 9041*da0073e9SAndroid Build Coastguard Worker 9042*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9043*da0073e9SAndroid Build Coastguard Worker compiled = compiled_fn(*t_list, x) 9044*da0073e9SAndroid Build Coastguard Worker 9045*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(eager), list(compiled)) 9046*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 0) 9047*da0073e9SAndroid Build Coastguard Worker 9048*da0073e9SAndroid Build Coastguard Worker def test_itertools_accumulate_tensors_builtins(self): 9049*da0073e9SAndroid Build Coastguard Worker for builtin_op in [operator.mul, operator.sub, operator.pow]: 9050*da0073e9SAndroid Build Coastguard Worker counters.clear() 9051*da0073e9SAndroid Build Coastguard Worker 9052*da0073e9SAndroid Build Coastguard Worker def fn(a, b, c, d, x): 9053*da0073e9SAndroid Build Coastguard Worker l = [a, b, c, d, x] 9054*da0073e9SAndroid Build Coastguard Worker for i, t in enumerate(l): 9055*da0073e9SAndroid Build Coastguard Worker l[i] = t * x 9056*da0073e9SAndroid Build Coastguard Worker return itertools.accumulate(l, builtin_op) 9057*da0073e9SAndroid Build Coastguard Worker 9058*da0073e9SAndroid Build Coastguard Worker t_list = [torch.tensor([i + 1]) for i in range(4)] 9059*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[1, 2], [3, 4]]) 9060*da0073e9SAndroid Build Coastguard Worker eager = 
fn(*t_list, x) 9061*da0073e9SAndroid Build Coastguard Worker 9062*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9063*da0073e9SAndroid Build Coastguard Worker compiled = compiled_fn(*t_list, x) 9064*da0073e9SAndroid Build Coastguard Worker 9065*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(eager), list(compiled)) 9066*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 0) 9067*da0073e9SAndroid Build Coastguard Worker 9068*da0073e9SAndroid Build Coastguard Worker def test_itertools_accumulate_tensors_kwargs(self): 9069*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.utils import counters 9070*da0073e9SAndroid Build Coastguard Worker 9071*da0073e9SAndroid Build Coastguard Worker for kwargs in [ 9072*da0073e9SAndroid Build Coastguard Worker {"func": operator.mul}, 9073*da0073e9SAndroid Build Coastguard Worker {"initial": 100}, 9074*da0073e9SAndroid Build Coastguard Worker {"func": operator.sub, "initial": -1}, 9075*da0073e9SAndroid Build Coastguard Worker ]: 9076*da0073e9SAndroid Build Coastguard Worker counters.clear() 9077*da0073e9SAndroid Build Coastguard Worker 9078*da0073e9SAndroid Build Coastguard Worker def fn(a, b, c, d, x): 9079*da0073e9SAndroid Build Coastguard Worker l = [a, b, c, d, x] 9080*da0073e9SAndroid Build Coastguard Worker for i, t in enumerate(l): 9081*da0073e9SAndroid Build Coastguard Worker l[i] = t * x 9082*da0073e9SAndroid Build Coastguard Worker return itertools.accumulate(l, **kwargs) 9083*da0073e9SAndroid Build Coastguard Worker 9084*da0073e9SAndroid Build Coastguard Worker t_list = [torch.tensor([i + 1]) for i in range(4)] 9085*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[1, 2], [3, 4]]) 9086*da0073e9SAndroid Build Coastguard Worker 9087*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9088*da0073e9SAndroid Build Coastguard Worker 
compiled = compiled_fn(*t_list, x) 9089*da0073e9SAndroid Build Coastguard Worker eager = fn(*t_list, x) 9090*da0073e9SAndroid Build Coastguard Worker 9091*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(eager), list(compiled)) 9092*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 0) 9093*da0073e9SAndroid Build Coastguard Worker 9094*da0073e9SAndroid Build Coastguard Worker def test_packaging_version_parse(self): 9095*da0073e9SAndroid Build Coastguard Worker from packaging import version 9096*da0073e9SAndroid Build Coastguard Worker 9097*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 9098*da0073e9SAndroid Build Coastguard Worker def fn(): 9099*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(1) 9100*da0073e9SAndroid Build Coastguard Worker if version.parse(torch.__version__) >= version.parse("2.0.0"): 9101*da0073e9SAndroid Build Coastguard Worker return x + 1 9102*da0073e9SAndroid Build Coastguard Worker return x 9103*da0073e9SAndroid Build Coastguard Worker 9104*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn().item(), 1) 9105*da0073e9SAndroid Build Coastguard Worker 9106*da0073e9SAndroid Build Coastguard Worker def test_itertools_accumulate_tensors_user_defined(self): 9107*da0073e9SAndroid Build Coastguard Worker def udo_fn_0(a, b): 9108*da0073e9SAndroid Build Coastguard Worker return -1 9109*da0073e9SAndroid Build Coastguard Worker 9110*da0073e9SAndroid Build Coastguard Worker rando = random.randint(0, 1) 9111*da0073e9SAndroid Build Coastguard Worker 9112*da0073e9SAndroid Build Coastguard Worker def udo_fn_1(a, b): 9113*da0073e9SAndroid Build Coastguard Worker return a * rando + b * rando 9114*da0073e9SAndroid Build Coastguard Worker 9115*da0073e9SAndroid Build Coastguard Worker seen = [] 9116*da0073e9SAndroid Build Coastguard Worker 9117*da0073e9SAndroid Build Coastguard Worker def udo_fn_2(a, b): 9118*da0073e9SAndroid Build Coastguard Worker 
seen.append(a) 9119*da0073e9SAndroid Build Coastguard Worker seen.append(b) 9120*da0073e9SAndroid Build Coastguard Worker return a * len(seen) 9121*da0073e9SAndroid Build Coastguard Worker 9122*da0073e9SAndroid Build Coastguard Worker for udo_fn in [udo_fn_0, udo_fn_1, udo_fn_2]: 9123*da0073e9SAndroid Build Coastguard Worker counters.clear() 9124*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 9125*da0073e9SAndroid Build Coastguard Worker 9126*da0073e9SAndroid Build Coastguard Worker def fn(a, b, c, d, x): 9127*da0073e9SAndroid Build Coastguard Worker l = [a, b, c, d, x] 9128*da0073e9SAndroid Build Coastguard Worker for i, t in enumerate(l): 9129*da0073e9SAndroid Build Coastguard Worker l[i] = t * x 9130*da0073e9SAndroid Build Coastguard Worker return itertools.accumulate(l, udo_fn) 9131*da0073e9SAndroid Build Coastguard Worker 9132*da0073e9SAndroid Build Coastguard Worker t_list = [torch.tensor([i]) for i in range(4)] 9133*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[1, 2], [3, 4]]) 9134*da0073e9SAndroid Build Coastguard Worker eager = fn(*t_list, x) 9135*da0073e9SAndroid Build Coastguard Worker 9136*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9137*da0073e9SAndroid Build Coastguard Worker compiled = compiled_fn(*t_list, x) 9138*da0073e9SAndroid Build Coastguard Worker 9139*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(eager), list(compiled)) 9140*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 0) 9141*da0073e9SAndroid Build Coastguard Worker 9142*da0073e9SAndroid Build Coastguard Worker def test_pure_python_accumulate(self): 9143*da0073e9SAndroid Build Coastguard Worker def accumulate(iterable, func=lambda x, y: x + y): 9144*da0073e9SAndroid Build Coastguard Worker it = iter(iterable) 9145*da0073e9SAndroid Build Coastguard Worker try: 9146*da0073e9SAndroid Build Coastguard Worker # Initialize the accumulator 
with the first value from the iterable 9147*da0073e9SAndroid Build Coastguard Worker accumulator = next(it) 9148*da0073e9SAndroid Build Coastguard Worker except StopIteration: 9149*da0073e9SAndroid Build Coastguard Worker # If the iterable is empty, return an empty generator 9150*da0073e9SAndroid Build Coastguard Worker return 9151*da0073e9SAndroid Build Coastguard Worker yield accumulator 9152*da0073e9SAndroid Build Coastguard Worker 9153*da0073e9SAndroid Build Coastguard Worker for element in it: 9154*da0073e9SAndroid Build Coastguard Worker accumulator = func(accumulator, element) 9155*da0073e9SAndroid Build Coastguard Worker yield accumulator 9156*da0073e9SAndroid Build Coastguard Worker 9157*da0073e9SAndroid Build Coastguard Worker def fn(it): 9158*da0073e9SAndroid Build Coastguard Worker return accumulate(it) 9159*da0073e9SAndroid Build Coastguard Worker 9160*da0073e9SAndroid Build Coastguard Worker t_list = [torch.tensor([i]) for i in range(4)] 9161*da0073e9SAndroid Build Coastguard Worker eager = fn(t_list) 9162*da0073e9SAndroid Build Coastguard Worker 9163*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 9164*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(counter)(fn) 9165*da0073e9SAndroid Build Coastguard Worker compiled = compiled_fn(t_list) 9166*da0073e9SAndroid Build Coastguard Worker 9167*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(eager), list(compiled)) 9168*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 9169*da0073e9SAndroid Build Coastguard Worker 9170*da0073e9SAndroid Build Coastguard Worker def test_itertools_groupby_pure_python_default_identify_func(self): 9171*da0073e9SAndroid Build Coastguard Worker counters.clear() 9172*da0073e9SAndroid Build Coastguard Worker 9173*da0073e9SAndroid Build Coastguard Worker def fn(l): 9174*da0073e9SAndroid Build Coastguard Worker return [(k, list(g)) for k, g in itertools.groupby(l)] 9175*da0073e9SAndroid 
Build Coastguard Worker 9176*da0073e9SAndroid Build Coastguard Worker l = [1, 2, 2, 3, 4, 4, 4, 1, 2] 9177*da0073e9SAndroid Build Coastguard Worker eager = fn(l) 9178*da0073e9SAndroid Build Coastguard Worker 9179*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9180*da0073e9SAndroid Build Coastguard Worker compiled = compiled_fn(l) 9181*da0073e9SAndroid Build Coastguard Worker 9182*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 9183*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 0) 9184*da0073e9SAndroid Build Coastguard Worker 9185*da0073e9SAndroid Build Coastguard Worker def test_itertools_groupby_pure_python_key_func(self): 9186*da0073e9SAndroid Build Coastguard Worker counters.clear() 9187*da0073e9SAndroid Build Coastguard Worker 9188*da0073e9SAndroid Build Coastguard Worker def fn(l): 9189*da0073e9SAndroid Build Coastguard Worker return [(k, list(g)) for k, g in itertools.groupby(l, key=operator.neg)] 9190*da0073e9SAndroid Build Coastguard Worker 9191*da0073e9SAndroid Build Coastguard Worker l = [1, 2, -2, 3, 4, 4, -4, 0, -2] 9192*da0073e9SAndroid Build Coastguard Worker eager = fn(l) 9193*da0073e9SAndroid Build Coastguard Worker 9194*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9195*da0073e9SAndroid Build Coastguard Worker compiled = compiled_fn(l) 9196*da0073e9SAndroid Build Coastguard Worker 9197*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager, compiled) 9198*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 0) 9199*da0073e9SAndroid Build Coastguard Worker 9200*da0073e9SAndroid Build Coastguard Worker def test_list_iterator_contains(self): 9201*da0073e9SAndroid Build Coastguard Worker def fn(x): 9202*da0073e9SAndroid Build Coastguard Worker it = iter(["my_weight", "not_my_weight"]) 9203*da0073e9SAndroid 
Build Coastguard Worker next(it) 9204*da0073e9SAndroid Build Coastguard Worker if "my_weight" in it: 9205*da0073e9SAndroid Build Coastguard Worker return x + 2 9206*da0073e9SAndroid Build Coastguard Worker return x + 1 9207*da0073e9SAndroid Build Coastguard Worker 9208*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(3) 9209*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 9210*da0073e9SAndroid Build Coastguard Worker 9211*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), compiled_fn(x)) 9212*da0073e9SAndroid Build Coastguard Worker 9213*da0073e9SAndroid Build Coastguard Worker def test_storage_return(self): 9214*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 9215*da0073e9SAndroid Build Coastguard Worker def fn(x): 9216*da0073e9SAndroid Build Coastguard Worker y = torch.sin(x + 1) 9217*da0073e9SAndroid Build Coastguard Worker storage = x.untyped_storage() 9218*da0073e9SAndroid Build Coastguard Worker storage.resize_(0) 9219*da0073e9SAndroid Build Coastguard Worker y = torch.cos(y) 9220*da0073e9SAndroid Build Coastguard Worker return y, storage 9221*da0073e9SAndroid Build Coastguard Worker 9222*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10) 9223*da0073e9SAndroid Build Coastguard Worker expected = torch.cos(torch.sin(x + 1)) 9224*da0073e9SAndroid Build Coastguard Worker y, s = fn(x) 9225*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, expected) 9226*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.untyped_storage().size(), 0) 9227*da0073e9SAndroid Build Coastguard Worker self.assertIs(s, x.untyped_storage()) 9228*da0073e9SAndroid Build Coastguard Worker 9229*da0073e9SAndroid Build Coastguard Worker def test_flat_name_to_original_fqn(self): 9230*da0073e9SAndroid Build Coastguard Worker class FooBarModule(torch.nn.Module): 9231*da0073e9SAndroid Build Coastguard Worker def __init__(self): 
9232*da0073e9SAndroid Build Coastguard Worker super().__init__() 9233*da0073e9SAndroid Build Coastguard Worker self.register_parameter("0", torch.nn.Parameter(torch.randn(3, 4))) 9234*da0073e9SAndroid Build Coastguard Worker self.register_buffer("test_buf", torch.randn(3, 4)) 9235*da0073e9SAndroid Build Coastguard Worker self.register_parameter( 9236*da0073e9SAndroid Build Coastguard Worker "test_param", torch.nn.Parameter(torch.randn(3, 4)) 9237*da0073e9SAndroid Build Coastguard Worker ) 9238*da0073e9SAndroid Build Coastguard Worker 9239*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 9240*da0073e9SAndroid Build Coastguard Worker return ((x + self.test_buf) * getattr(self, "0")) / self.test_param 9241*da0073e9SAndroid Build Coastguard Worker 9242*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 9243*da0073e9SAndroid Build Coastguard Worker def __init__(self): 9244*da0073e9SAndroid Build Coastguard Worker super().__init__() 9245*da0073e9SAndroid Build Coastguard Worker self.foo_bar = FooBarModule() 9246*da0073e9SAndroid Build Coastguard Worker self.register_parameter( 9247*da0073e9SAndroid Build Coastguard Worker "test_param", torch.nn.Parameter(torch.randn(3, 4)) 9248*da0073e9SAndroid Build Coastguard Worker ) 9249*da0073e9SAndroid Build Coastguard Worker self.register_buffer("test_buf", torch.randn(3, 4)) 9250*da0073e9SAndroid Build Coastguard Worker 9251*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 9252*da0073e9SAndroid Build Coastguard Worker return (self.foo_bar(x) + self.test_param) * self.test_buf 9253*da0073e9SAndroid Build Coastguard Worker 9254*da0073e9SAndroid Build Coastguard Worker gm, _ = torch._dynamo.export(TestModule(), torch.randn(3, 4)) 9255*da0073e9SAndroid Build Coastguard Worker self.assertIn("dynamo_flat_name_to_original_fqn", gm.meta) 9256*da0073e9SAndroid Build Coastguard Worker expected_fqn = { 9257*da0073e9SAndroid Build Coastguard Worker "L__self___test_param": "test_param", 
9258*da0073e9SAndroid Build Coastguard Worker "L__self___test_buf": "test_buf", 9259*da0073e9SAndroid Build Coastguard Worker "getattr_L__self___foo_bar___0__": "foo_bar.0", 9260*da0073e9SAndroid Build Coastguard Worker "L__self___foo_bar_test_param": "foo_bar.test_param", 9261*da0073e9SAndroid Build Coastguard Worker "L__self___foo_bar_test_buf": "foo_bar.test_buf", 9262*da0073e9SAndroid Build Coastguard Worker } 9263*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_fqn, gm.meta["dynamo_flat_name_to_original_fqn"]) 9264*da0073e9SAndroid Build Coastguard Worker 9265*da0073e9SAndroid Build Coastguard Worker def test_shape_env_no_recording(self): 9266*da0073e9SAndroid Build Coastguard Worker main = ShapeEnv(should_record_events=False) 9267*da0073e9SAndroid Build Coastguard Worker 9268*da0073e9SAndroid Build Coastguard Worker # The main ShapeEnv should have no event recorded. 9269*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(main.events), 0) 9270*da0073e9SAndroid Build Coastguard Worker 9271*da0073e9SAndroid Build Coastguard Worker # Call create_symbolic_sizes_strides_storage_offset on both of them. 9272*da0073e9SAndroid Build Coastguard Worker r = main.create_symbolic_sizes_strides_storage_offset( 9273*da0073e9SAndroid Build Coastguard Worker torch.randn(3, 2), ConstantSource("x") 9274*da0073e9SAndroid Build Coastguard Worker ) 9275*da0073e9SAndroid Build Coastguard Worker 9276*da0073e9SAndroid Build Coastguard Worker # Create a guard: size[0] == 3 (call evaluate_expr) 9277*da0073e9SAndroid Build Coastguard Worker # - +1 guard entry 9278*da0073e9SAndroid Build Coastguard Worker # - +1 replacement entry 9279*da0073e9SAndroid Build Coastguard Worker size = r[0] 9280*da0073e9SAndroid Build Coastguard Worker bool(size[0] == 3) 9281*da0073e9SAndroid Build Coastguard Worker 9282*da0073e9SAndroid Build Coastguard Worker # The main ShapeEnv should remain with no event recorded. 
9283*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(main.events), 0) 9284*da0073e9SAndroid Build Coastguard Worker 9285*da0073e9SAndroid Build Coastguard Worker if torch.fx.experimental.validator.translation_validation_enabled(): 9286*da0073e9SAndroid Build Coastguard Worker from torch.fx.experimental.symbolic_shapes import ( 9287*da0073e9SAndroid Build Coastguard Worker CURRENT_NODE_KEY, 9288*da0073e9SAndroid Build Coastguard Worker SHAPEENV_EVENT_KEY, 9289*da0073e9SAndroid Build Coastguard Worker ) 9290*da0073e9SAndroid Build Coastguard Worker 9291*da0073e9SAndroid Build Coastguard Worker # Check that we don't store any recording metadata on nodes 9292*da0073e9SAndroid Build Coastguard Worker # from the symbolic shape FX graph. 9293*da0073e9SAndroid Build Coastguard Worker for n in main.graph.nodes: 9294*da0073e9SAndroid Build Coastguard Worker self.assertFalse(SHAPEENV_EVENT_KEY in n.meta) 9295*da0073e9SAndroid Build Coastguard Worker self.assertFalse(CURRENT_NODE_KEY in n.meta) 9296*da0073e9SAndroid Build Coastguard Worker 9297*da0073e9SAndroid Build Coastguard Worker def _replay_and_check(self, shape_env: ShapeEnv): 9298*da0073e9SAndroid Build Coastguard Worker if shape_env.should_record_events: 9299*da0073e9SAndroid Build Coastguard Worker replayed = replay_shape_env_events(shape_env.events) 9300*da0073e9SAndroid Build Coastguard Worker shape_env.check_equal(replayed) 9301*da0073e9SAndroid Build Coastguard Worker 9302*da0073e9SAndroid Build Coastguard Worker def test_shape_env_equal_empty(self): 9303*da0073e9SAndroid Build Coastguard Worker main, other = ShapeEnv(), ShapeEnv() 9304*da0073e9SAndroid Build Coastguard Worker main.check_equal(other) 9305*da0073e9SAndroid Build Coastguard Worker self._replay_and_check(main) 9306*da0073e9SAndroid Build Coastguard Worker 9307*da0073e9SAndroid Build Coastguard Worker @onlyIfTranslationValidation 9308*da0073e9SAndroid Build Coastguard Worker def test_shape_env_equal_constructor(self): 
9309*da0073e9SAndroid Build Coastguard Worker main, other = ShapeEnv(allow_scalar_outputs=False), ShapeEnv() 9310*da0073e9SAndroid Build Coastguard Worker self.assertExpectedRaisesInline( 9311*da0073e9SAndroid Build Coastguard Worker NotEqualError, 9312*da0073e9SAndroid Build Coastguard Worker lambda: main.check_equal(other), 9313*da0073e9SAndroid Build Coastguard Worker """\ 9314*da0073e9SAndroid Build Coastguard WorkerShapeEnv not equal: field values don't match: 9315*da0073e9SAndroid Build Coastguard Worker 9316*da0073e9SAndroid Build Coastguard Worker==> settings: values don't match. 9317*da0073e9SAndroid Build Coastguard Worker > Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, _allow_complex_guards_as_runtime_asserts=False) 9318*da0073e9SAndroid Build Coastguard Worker > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, _allow_complex_guards_as_runtime_asserts=False) 9319*da0073e9SAndroid Build Coastguard Worker""", 9320*da0073e9SAndroid Build Coastguard Worker ) 9321*da0073e9SAndroid Build Coastguard Worker self._replay_and_check(main) 9322*da0073e9SAndroid Build Coastguard Worker 9323*da0073e9SAndroid Build Coastguard Worker @onlyIfTranslationValidation 9324*da0073e9SAndroid Build Coastguard Worker def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self): 9325*da0073e9SAndroid Build Coastguard Worker main, other = ShapeEnv(), ShapeEnv() 9326*da0073e9SAndroid Build Coastguard Worker main.create_symbolic_sizes_strides_storage_offset( 9327*da0073e9SAndroid Build Coastguard Worker torch.randn(3, 2), ConstantSource("x") 9328*da0073e9SAndroid Build Coastguard Worker ) 9329*da0073e9SAndroid Build Coastguard Worker 
self.assertExpectedRaisesInline( 9330*da0073e9SAndroid Build Coastguard Worker NotEqualError, 9331*da0073e9SAndroid Build Coastguard Worker lambda: main.check_equal(other), 9332*da0073e9SAndroid Build Coastguard Worker """\ 9333*da0073e9SAndroid Build Coastguard WorkerShapeEnv not equal: field values don't match: 9334*da0073e9SAndroid Build Coastguard Worker 9335*da0073e9SAndroid Build Coastguard Worker==> name_to_node: values don't match. 9336*da0073e9SAndroid Build Coastguard Worker > Left: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} 9337*da0073e9SAndroid Build Coastguard Worker > Right: {} 9338*da0073e9SAndroid Build Coastguard Worker==> source_to_symbol: values don't match. 9339*da0073e9SAndroid Build Coastguard Worker > Left: {x.size()[0]: x.size()[0], x.size()[1]: x.size()[1], x.storage_offset(): x.storage_offset(), x.stride()[0]: x.stride()[0], x.stride()[1]: x.stride()[1]} 9340*da0073e9SAndroid Build Coastguard Worker > Right: {} 9341*da0073e9SAndroid Build Coastguard Worker==> val_to_var: values don't match. 9342*da0073e9SAndroid Build Coastguard Worker > Left: {0: 0, 1: 1, 2: s1, 3: s0} 9343*da0073e9SAndroid Build Coastguard Worker > Right: {0: 0, 1: 1} 9344*da0073e9SAndroid Build Coastguard Worker==> var_to_range: values don't match. 9345*da0073e9SAndroid Build Coastguard Worker > Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]} 9346*da0073e9SAndroid Build Coastguard Worker > Right: {} 9347*da0073e9SAndroid Build Coastguard Worker==> var_to_sources: values don't match. 9348*da0073e9SAndroid Build Coastguard Worker > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]} 9349*da0073e9SAndroid Build Coastguard Worker > Right: {} 9350*da0073e9SAndroid Build Coastguard Worker==> var_to_val: values don't match. 
9351*da0073e9SAndroid Build Coastguard Worker > Left: {s0: 3, s1: 2} 9352*da0073e9SAndroid Build Coastguard Worker > Right: {} 9353*da0073e9SAndroid Build Coastguard Worker""", 9354*da0073e9SAndroid Build Coastguard Worker ) 9355*da0073e9SAndroid Build Coastguard Worker self._replay_and_check(main) 9356*da0073e9SAndroid Build Coastguard Worker 9357*da0073e9SAndroid Build Coastguard Worker @onlyIfTranslationValidation 9358*da0073e9SAndroid Build Coastguard Worker def test_shape_env_equal_unbacked(self): 9359*da0073e9SAndroid Build Coastguard Worker main, other = ShapeEnv(), ShapeEnv() 9360*da0073e9SAndroid Build Coastguard Worker main.create_unbacked_symint() 9361*da0073e9SAndroid Build Coastguard Worker main.create_unbacked_symfloat() 9362*da0073e9SAndroid Build Coastguard Worker main.create_unbacked_symbool() 9363*da0073e9SAndroid Build Coastguard Worker self.assertExpectedRaisesInline( 9364*da0073e9SAndroid Build Coastguard Worker NotEqualError, 9365*da0073e9SAndroid Build Coastguard Worker lambda: main.check_equal(other), 9366*da0073e9SAndroid Build Coastguard Worker """\ 9367*da0073e9SAndroid Build Coastguard WorkerShapeEnv not equal: field values don't match: 9368*da0073e9SAndroid Build Coastguard Worker 9369*da0073e9SAndroid Build Coastguard Worker==> name_to_node: values don't match. 9370*da0073e9SAndroid Build Coastguard Worker > Left: {u0, u1, zuf0} 9371*da0073e9SAndroid Build Coastguard Worker > Right: {} 9372*da0073e9SAndroid Build Coastguard Worker==> unbacked_symfloat_counter: values don't match. 9373*da0073e9SAndroid Build Coastguard Worker > Left: 1 9374*da0073e9SAndroid Build Coastguard Worker > Right: 0 9375*da0073e9SAndroid Build Coastguard Worker==> unbacked_symint_counter: values don't match. 9376*da0073e9SAndroid Build Coastguard Worker > Left: 2 9377*da0073e9SAndroid Build Coastguard Worker > Right: 0 9378*da0073e9SAndroid Build Coastguard Worker==> var_to_range: values don't match. 
9379*da0073e9SAndroid Build Coastguard Worker > Left: {u0: VR[-int_oo, int_oo], u1: VR[0, 1], zuf0: VR[-oo, oo]} 9380*da0073e9SAndroid Build Coastguard Worker > Right: {} 9381*da0073e9SAndroid Build Coastguard Worker""", 9382*da0073e9SAndroid Build Coastguard Worker ) 9383*da0073e9SAndroid Build Coastguard Worker self._replay_and_check(main) 9384*da0073e9SAndroid Build Coastguard Worker 9385*da0073e9SAndroid Build Coastguard Worker @onlyIfTranslationValidation 9386*da0073e9SAndroid Build Coastguard Worker def test_shape_env_equal_evaluate_expr_divisible(self): 9387*da0073e9SAndroid Build Coastguard Worker main, other = ShapeEnv(), ShapeEnv() 9388*da0073e9SAndroid Build Coastguard Worker 9389*da0073e9SAndroid Build Coastguard Worker # Call create_symbolic_sizes_strides_storage_offset on both of them. 9390*da0073e9SAndroid Build Coastguard Worker r = main.create_symbolic_sizes_strides_storage_offset( 9391*da0073e9SAndroid Build Coastguard Worker torch.randn(3, 2), ConstantSource("x") 9392*da0073e9SAndroid Build Coastguard Worker ) 9393*da0073e9SAndroid Build Coastguard Worker other.create_symbolic_sizes_strides_storage_offset( 9394*da0073e9SAndroid Build Coastguard Worker torch.randn(3, 2), ConstantSource("x") 9395*da0073e9SAndroid Build Coastguard Worker ) 9396*da0073e9SAndroid Build Coastguard Worker 9397*da0073e9SAndroid Build Coastguard Worker # Create a guard: size[0] % 3 == 0 (only in the main ShapeEnv) 9398*da0073e9SAndroid Build Coastguard Worker # - +1 guard entry 9399*da0073e9SAndroid Build Coastguard Worker # - +1 divisible entry 9400*da0073e9SAndroid Build Coastguard Worker size = r[0] 9401*da0073e9SAndroid Build Coastguard Worker bool(size[0] % 3 == 0) 9402*da0073e9SAndroid Build Coastguard Worker 9403*da0073e9SAndroid Build Coastguard Worker self.assertExpectedRaisesInline( 9404*da0073e9SAndroid Build Coastguard Worker NotEqualError, 9405*da0073e9SAndroid Build Coastguard Worker lambda: main.check_equal(other), 9406*da0073e9SAndroid Build Coastguard 
Worker """\ 9407*da0073e9SAndroid Build Coastguard WorkerShapeEnv not equal: field values don't match: 9408*da0073e9SAndroid Build Coastguard Worker 9409*da0073e9SAndroid Build Coastguard Worker==> divisible: values don't match. 9410*da0073e9SAndroid Build Coastguard Worker > Left: {Mod(s0, 3)} 9411*da0073e9SAndroid Build Coastguard Worker > Right: {} 9412*da0073e9SAndroid Build Coastguard Worker==> guards: values don't match. 9413*da0073e9SAndroid Build Coastguard Worker > Left: [Eq(Mod(s0, 3), 0)] 9414*da0073e9SAndroid Build Coastguard Worker > Right: [] 9415*da0073e9SAndroid Build Coastguard Worker==> name_to_node: values don't match. 9416*da0073e9SAndroid Build Coastguard Worker > Left: {_assert, eq, mod, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} 9417*da0073e9SAndroid Build Coastguard Worker > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} 9418*da0073e9SAndroid Build Coastguard Worker""", 9419*da0073e9SAndroid Build Coastguard Worker ) 9420*da0073e9SAndroid Build Coastguard Worker self._replay_and_check(main) 9421*da0073e9SAndroid Build Coastguard Worker 9422*da0073e9SAndroid Build Coastguard Worker @onlyIfTranslationValidation 9423*da0073e9SAndroid Build Coastguard Worker def test_shape_env_equal_evaluate_expr_replacement(self): 9424*da0073e9SAndroid Build Coastguard Worker main, other = ShapeEnv(), ShapeEnv() 9425*da0073e9SAndroid Build Coastguard Worker 9426*da0073e9SAndroid Build Coastguard Worker # Call create_symbolic_sizes_strides_storage_offset on both of them. 
9427*da0073e9SAndroid Build Coastguard Worker r = main.create_symbolic_sizes_strides_storage_offset( 9428*da0073e9SAndroid Build Coastguard Worker torch.randn(3, 2), ConstantSource("x") 9429*da0073e9SAndroid Build Coastguard Worker ) 9430*da0073e9SAndroid Build Coastguard Worker other.create_symbolic_sizes_strides_storage_offset( 9431*da0073e9SAndroid Build Coastguard Worker torch.randn(3, 2), ConstantSource("x") 9432*da0073e9SAndroid Build Coastguard Worker ) 9433*da0073e9SAndroid Build Coastguard Worker 9434*da0073e9SAndroid Build Coastguard Worker # Create a guard: size[0] == 3 (only in the main ShapeEnv) 9435*da0073e9SAndroid Build Coastguard Worker # - +1 guard entry 9436*da0073e9SAndroid Build Coastguard Worker # - +1 replacement entry 9437*da0073e9SAndroid Build Coastguard Worker size = r[0] 9438*da0073e9SAndroid Build Coastguard Worker bool(size[0] == 3) 9439*da0073e9SAndroid Build Coastguard Worker 9440*da0073e9SAndroid Build Coastguard Worker self.assertExpectedRaisesInline( 9441*da0073e9SAndroid Build Coastguard Worker NotEqualError, 9442*da0073e9SAndroid Build Coastguard Worker lambda: main.check_equal(other), 9443*da0073e9SAndroid Build Coastguard Worker """\ 9444*da0073e9SAndroid Build Coastguard WorkerShapeEnv not equal: field values don't match: 9445*da0073e9SAndroid Build Coastguard Worker 9446*da0073e9SAndroid Build Coastguard Worker==> guards: values don't match. 9447*da0073e9SAndroid Build Coastguard Worker > Left: [Eq(s0, 3)] 9448*da0073e9SAndroid Build Coastguard Worker > Right: [] 9449*da0073e9SAndroid Build Coastguard Worker==> name_to_node: values don't match. 9450*da0073e9SAndroid Build Coastguard Worker > Left: {_assert, eq, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} 9451*da0073e9SAndroid Build Coastguard Worker > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} 9452*da0073e9SAndroid Build Coastguard Worker==> replacements: values don't match. 
9453*da0073e9SAndroid Build Coastguard Worker > Left: {s0: 3} 9454*da0073e9SAndroid Build Coastguard Worker > Right: {} 9455*da0073e9SAndroid Build Coastguard Worker==> var_to_range: values don't match. 9456*da0073e9SAndroid Build Coastguard Worker > Left: {s0: VR[3, 3], s1: VR[2, int_oo]} 9457*da0073e9SAndroid Build Coastguard Worker > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]} 9458*da0073e9SAndroid Build Coastguard Worker""", 9459*da0073e9SAndroid Build Coastguard Worker ) 9460*da0073e9SAndroid Build Coastguard Worker self._replay_and_check(main) 9461*da0073e9SAndroid Build Coastguard Worker 9462*da0073e9SAndroid Build Coastguard Worker @onlyIfTranslationValidation 9463*da0073e9SAndroid Build Coastguard Worker def test_shape_env_equal_evaluate_expr_refinement(self): 9464*da0073e9SAndroid Build Coastguard Worker main, other = ShapeEnv(), ShapeEnv() 9465*da0073e9SAndroid Build Coastguard Worker 9466*da0073e9SAndroid Build Coastguard Worker # Call create_symbolic_sizes_strides_storage_offset on both of them. 
9467*da0073e9SAndroid Build Coastguard Worker r = main.create_symbolic_sizes_strides_storage_offset( 9468*da0073e9SAndroid Build Coastguard Worker torch.randn(3, 2), ConstantSource("x") 9469*da0073e9SAndroid Build Coastguard Worker ) 9470*da0073e9SAndroid Build Coastguard Worker other.create_symbolic_sizes_strides_storage_offset( 9471*da0073e9SAndroid Build Coastguard Worker torch.randn(3, 2), ConstantSource("x") 9472*da0073e9SAndroid Build Coastguard Worker ) 9473*da0073e9SAndroid Build Coastguard Worker 9474*da0073e9SAndroid Build Coastguard Worker # Create a guard: size[0] >= 3 (only in the main ShapeEnv) 9475*da0073e9SAndroid Build Coastguard Worker # - +1 guard entry 9476*da0073e9SAndroid Build Coastguard Worker # - +1 var_to_guard entry 9477*da0073e9SAndroid Build Coastguard Worker # - Change: var_to_range 9478*da0073e9SAndroid Build Coastguard Worker size = r[0] 9479*da0073e9SAndroid Build Coastguard Worker bool(size[0] >= 3) 9480*da0073e9SAndroid Build Coastguard Worker 9481*da0073e9SAndroid Build Coastguard Worker self.assertExpectedRaisesInline( 9482*da0073e9SAndroid Build Coastguard Worker NotEqualError, 9483*da0073e9SAndroid Build Coastguard Worker lambda: main.check_equal(other), 9484*da0073e9SAndroid Build Coastguard Worker """\ 9485*da0073e9SAndroid Build Coastguard WorkerShapeEnv not equal: field values don't match: 9486*da0073e9SAndroid Build Coastguard Worker 9487*da0073e9SAndroid Build Coastguard Worker==> guards: values don't match. 9488*da0073e9SAndroid Build Coastguard Worker > Left: [s0 >= 3] 9489*da0073e9SAndroid Build Coastguard Worker > Right: [] 9490*da0073e9SAndroid Build Coastguard Worker==> name_to_node: values don't match. 
9491*da0073e9SAndroid Build Coastguard Worker > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} 9492*da0073e9SAndroid Build Coastguard Worker > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} 9493*da0073e9SAndroid Build Coastguard Worker==> var_to_range: values don't match. 9494*da0073e9SAndroid Build Coastguard Worker > Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]} 9495*da0073e9SAndroid Build Coastguard Worker > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]} 9496*da0073e9SAndroid Build Coastguard Worker""", 9497*da0073e9SAndroid Build Coastguard Worker ) 9498*da0073e9SAndroid Build Coastguard Worker self._replay_and_check(main) 9499*da0073e9SAndroid Build Coastguard Worker 9500*da0073e9SAndroid Build Coastguard Worker @onlyIfTranslationValidation 9501*da0073e9SAndroid Build Coastguard Worker def test_shape_env_equal_runtime_assert(self): 9502*da0073e9SAndroid Build Coastguard Worker main, other = ShapeEnv(), ShapeEnv() 9503*da0073e9SAndroid Build Coastguard Worker 9504*da0073e9SAndroid Build Coastguard Worker # Call create_unbacked_symint on both of them. 
9505*da0073e9SAndroid Build Coastguard Worker r = main.create_unbacked_symint() 9506*da0073e9SAndroid Build Coastguard Worker other.create_unbacked_symint() 9507*da0073e9SAndroid Build Coastguard Worker 9508*da0073e9SAndroid Build Coastguard Worker # Create a runtime assert: r % 3 == 0 (only in the main ShapeEnv) 9509*da0073e9SAndroid Build Coastguard Worker # - +1 deferred_runtime_asserts entry 9510*da0073e9SAndroid Build Coastguard Worker # - Change: num_deferred_runtime_asserts 9511*da0073e9SAndroid Build Coastguard Worker expect_true(r % 3 == 0) 9512*da0073e9SAndroid Build Coastguard Worker 9513*da0073e9SAndroid Build Coastguard Worker self.assertExpectedRaisesInline( 9514*da0073e9SAndroid Build Coastguard Worker NotEqualError, 9515*da0073e9SAndroid Build Coastguard Worker lambda: main.check_equal(other), 9516*da0073e9SAndroid Build Coastguard Worker """\ 9517*da0073e9SAndroid Build Coastguard WorkerShapeEnv not equal: field values don't match: 9518*da0073e9SAndroid Build Coastguard Worker 9519*da0073e9SAndroid Build Coastguard Worker==> deferred_runtime_asserts: values don't match. 9520*da0073e9SAndroid Build Coastguard Worker > Left: {u0: [Eq(PythonMod(u0, 3), 0)]} 9521*da0073e9SAndroid Build Coastguard Worker > Right: {} 9522*da0073e9SAndroid Build Coastguard Worker==> name_to_node: values don't match. 9523*da0073e9SAndroid Build Coastguard Worker > Left: {_assert, eq, mod, u0} 9524*da0073e9SAndroid Build Coastguard Worker > Right: {u0} 9525*da0073e9SAndroid Build Coastguard Worker==> num_deferred_runtime_asserts: values don't match. 
9526*da0073e9SAndroid Build Coastguard Worker > Left: 1 9527*da0073e9SAndroid Build Coastguard Worker > Right: 0 9528*da0073e9SAndroid Build Coastguard Worker""", 9529*da0073e9SAndroid Build Coastguard Worker ) 9530*da0073e9SAndroid Build Coastguard Worker self._replay_and_check(main) 9531*da0073e9SAndroid Build Coastguard Worker 9532*da0073e9SAndroid Build Coastguard Worker def test_shape_env_recorded_function_fallback(self): 9533*da0073e9SAndroid Build Coastguard Worker # Make sure the record/replay mechanism for ShapeEnv will fallback 9534*da0073e9SAndroid Build Coastguard Worker # if no ShapeEnv instance is found. 9535*da0073e9SAndroid Build Coastguard Worker constrain_range(5, min=2, max=10) 9536*da0073e9SAndroid Build Coastguard Worker constrain_unify(5, 5) 9537*da0073e9SAndroid Build Coastguard Worker 9538*da0073e9SAndroid Build Coastguard Worker self.assertExpectedRaisesInline( 9539*da0073e9SAndroid Build Coastguard Worker AssertionError, 9540*da0073e9SAndroid Build Coastguard Worker lambda: _constrain_range_for_size(5, min=2, max=10), 9541*da0073e9SAndroid Build Coastguard Worker """can only constrain range for SymInt""", 9542*da0073e9SAndroid Build Coastguard Worker ) 9543*da0073e9SAndroid Build Coastguard Worker 9544*da0073e9SAndroid Build Coastguard Worker def test_default_dtype_change(self): 9545*da0073e9SAndroid Build Coastguard Worker @torch.compile 9546*da0073e9SAndroid Build Coastguard Worker def foo(): 9547*da0073e9SAndroid Build Coastguard Worker def inner(a, b, res_dtype): 9548*da0073e9SAndroid Build Coastguard Worker print(a, b, res_dtype) 9549*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.result_type(a, b), res_dtype) 9550*da0073e9SAndroid Build Coastguard Worker 9551*da0073e9SAndroid Build Coastguard Worker inner(torch.tensor(1, device="cpu"), 1.0, torch.get_default_dtype()) 9552*da0073e9SAndroid Build Coastguard Worker 9553*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 9554*da0073e9SAndroid 
Build Coastguard Worker foo() 9555*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.double): 9556*da0073e9SAndroid Build Coastguard Worker foo() 9557*da0073e9SAndroid Build Coastguard Worker 9558*da0073e9SAndroid Build Coastguard Worker def test_numpy_ufunc_out(self): 9559*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 9560*da0073e9SAndroid Build Coastguard Worker def foo(): 9561*da0073e9SAndroid Build Coastguard Worker x = np.arange(5) 9562*da0073e9SAndroid Build Coastguard Worker out = np.empty((x.shape[0], x.shape[0])) 9563*da0073e9SAndroid Build Coastguard Worker res_out = np.sin(x, out=out) 9564*da0073e9SAndroid Build Coastguard Worker assert res_out is out 9565*da0073e9SAndroid Build Coastguard Worker 9566*da0073e9SAndroid Build Coastguard Worker foo() 9567*da0073e9SAndroid Build Coastguard Worker 9568*da0073e9SAndroid Build Coastguard Worker # Unfortunately, we don't currently preserve the ids of 9569*da0073e9SAndroid Build Coastguard Worker # res_out and out correctly across the graph break 9570*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure 9571*da0073e9SAndroid Build Coastguard Worker def test_numpy_ufunc_out_graph_break(self): 9572*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 9573*da0073e9SAndroid Build Coastguard Worker def foo(): 9574*da0073e9SAndroid Build Coastguard Worker x = np.arange(5) 9575*da0073e9SAndroid Build Coastguard Worker out = np.empty((x.shape[0], x.shape[0])) 9576*da0073e9SAndroid Build Coastguard Worker res_out = np.sin(x, out=out) 9577*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 9578*da0073e9SAndroid Build Coastguard Worker assert res_out is out 9579*da0073e9SAndroid Build Coastguard Worker 9580*da0073e9SAndroid Build Coastguard Worker foo() 9581*da0073e9SAndroid Build Coastguard Worker 9582*da0073e9SAndroid Build Coastguard Worker def test_dict_subclass_cannot_be_initialized_in_graph(self): 9583*da0073e9SAndroid 
Build Coastguard Worker for super_class in ( 9584*da0073e9SAndroid Build Coastguard Worker collections.OrderedDict, 9585*da0073e9SAndroid Build Coastguard Worker dict, 9586*da0073e9SAndroid Build Coastguard Worker ): 9587*da0073e9SAndroid Build Coastguard Worker 9588*da0073e9SAndroid Build Coastguard Worker class CustomDict(super_class): 9589*da0073e9SAndroid Build Coastguard Worker def __init__(self, *args, **kwargs): 9590*da0073e9SAndroid Build Coastguard Worker super().__init__(*args, **kwargs) 9591*da0073e9SAndroid Build Coastguard Worker 9592*da0073e9SAndroid Build Coastguard Worker def fn(x): 9593*da0073e9SAndroid Build Coastguard Worker c = CustomDict() 9594*da0073e9SAndroid Build Coastguard Worker c["key"] = x 9595*da0073e9SAndroid Build Coastguard Worker assert "key" in c 9596*da0073e9SAndroid Build Coastguard Worker return c["key"] + 1 9597*da0073e9SAndroid Build Coastguard Worker 9598*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(fn, backend="eager", fullgraph=True) 9599*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 9600*da0073e9SAndroid Build Coastguard Worker torch._dynamo.exc.Unsupported, "call_function UserDefinedClassVariable" 9601*da0073e9SAndroid Build Coastguard Worker ): 9602*da0073e9SAndroid Build Coastguard Worker print(fn_opt(torch.zeros(1))) 9603*da0073e9SAndroid Build Coastguard Worker 9604*da0073e9SAndroid Build Coastguard Worker @wrapDeterministicFlagAPITest 9605*da0073e9SAndroid Build Coastguard Worker def test_backward_deterministic_mode_mismatch_warning(self): 9606*da0073e9SAndroid Build Coastguard Worker @torch.compile 9607*da0073e9SAndroid Build Coastguard Worker def func(a, b): 9608*da0073e9SAndroid Build Coastguard Worker return a + b 9609*da0073e9SAndroid Build Coastguard Worker 9610*da0073e9SAndroid Build Coastguard Worker for forward_deterministic, backward_deterministic in itertools.product( 9611*da0073e9SAndroid Build Coastguard Worker [True, False], [True, False] 9612*da0073e9SAndroid 
Build Coastguard Worker ): 9613*da0073e9SAndroid Build Coastguard Worker torch.use_deterministic_algorithms(forward_deterministic) 9614*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10, requires_grad=True) 9615*da0073e9SAndroid Build Coastguard Worker res = func(a, 1) 9616*da0073e9SAndroid Build Coastguard Worker grad = torch.ones_like(res) 9617*da0073e9SAndroid Build Coastguard Worker torch.use_deterministic_algorithms(backward_deterministic) 9618*da0073e9SAndroid Build Coastguard Worker 9619*da0073e9SAndroid Build Coastguard Worker if not forward_deterministic and backward_deterministic: 9620*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 9621*da0073e9SAndroid Build Coastguard Worker RuntimeError, 9622*da0073e9SAndroid Build Coastguard Worker "^This compiled backward function is being run with torch\.use_deterministic_algorithms", 9623*da0073e9SAndroid Build Coastguard Worker ): 9624*da0073e9SAndroid Build Coastguard Worker res.backward(grad) 9625*da0073e9SAndroid Build Coastguard Worker 9626*da0073e9SAndroid Build Coastguard Worker else: 9627*da0073e9SAndroid Build Coastguard Worker res.backward(grad) 9628*da0073e9SAndroid Build Coastguard Worker 9629*da0073e9SAndroid Build Coastguard Worker def test_torch_dynamo_codegen_pow(self): 9630*da0073e9SAndroid Build Coastguard Worker def pow(x): 9631*da0073e9SAndroid Build Coastguard Worker return x**2 9632*da0073e9SAndroid Build Coastguard Worker 9633*da0073e9SAndroid Build Coastguard Worker x = np.arange(8) 9634*da0073e9SAndroid Build Coastguard Worker pow_opt = torch.compile(pow) 9635*da0073e9SAndroid Build Coastguard Worker 9636*da0073e9SAndroid Build Coastguard Worker actual, source_code = run_and_get_code(pow_opt, x) 9637*da0073e9SAndroid Build Coastguard Worker expect = pow(x) 9638*da0073e9SAndroid Build Coastguard Worker 9639*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expect, actual) 9640*da0073e9SAndroid Build Coastguard Worker 9641*da0073e9SAndroid Build 
Coastguard Worker self.assertTrue( 9642*da0073e9SAndroid Build Coastguard Worker all("aten.pow" not in code for code in source_code), 9643*da0073e9SAndroid Build Coastguard Worker msg="Encountered an unexpected fallback to 'aten pow' in dynamo compiled code", 9644*da0073e9SAndroid Build Coastguard Worker ) 9645*da0073e9SAndroid Build Coastguard Worker 9646*da0073e9SAndroid Build Coastguard Worker def test_graph_break_compilation_metrics(self): 9647*da0073e9SAndroid Build Coastguard Worker def fn(x): 9648*da0073e9SAndroid Build Coastguard Worker x.cos() 9649*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 9650*da0073e9SAndroid Build Coastguard Worker x.sin() 9651*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 9652*da0073e9SAndroid Build Coastguard Worker return x.cos() 9653*da0073e9SAndroid Build Coastguard Worker 9654*da0073e9SAndroid Build Coastguard Worker torch._dynamo.utils.clear_compilation_metrics() 9655*da0073e9SAndroid Build Coastguard Worker x = torch.rand((4, 4)) 9656*da0073e9SAndroid Build Coastguard Worker f = torch.compile(fn, backend="eager") 9657*da0073e9SAndroid Build Coastguard Worker f(x) 9658*da0073e9SAndroid Build Coastguard Worker metrics = torch._dynamo.utils.get_compilation_metrics() 9659*da0073e9SAndroid Build Coastguard Worker # Should only be one restart per event 9660*da0073e9SAndroid Build Coastguard Worker (restart_reason,) = metrics[0].restart_reasons 9661*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 9662*da0073e9SAndroid Build Coastguard Worker "skip function graph_break" in restart_reason, 9663*da0073e9SAndroid Build Coastguard Worker "Should have logged graph break reason", 9664*da0073e9SAndroid Build Coastguard Worker ) 9665*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 9666*da0073e9SAndroid Build Coastguard Worker metrics[0].dynamo_time_before_restart_s 9667*da0073e9SAndroid Build Coastguard Worker <= metrics[0].entire_frame_compile_time_s 9668*da0073e9SAndroid 
Build Coastguard Worker ) 9669*da0073e9SAndroid Build Coastguard Worker 9670*da0073e9SAndroid Build Coastguard Worker (restart_reason,) = metrics[1].restart_reasons 9671*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 9672*da0073e9SAndroid Build Coastguard Worker "skip function graph_break" in restart_reason, 9673*da0073e9SAndroid Build Coastguard Worker "Should have logged graph break reason", 9674*da0073e9SAndroid Build Coastguard Worker ) 9675*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 9676*da0073e9SAndroid Build Coastguard Worker metrics[1].dynamo_time_before_restart_s 9677*da0073e9SAndroid Build Coastguard Worker <= metrics[1].entire_frame_compile_time_s 9678*da0073e9SAndroid Build Coastguard Worker ) 9679*da0073e9SAndroid Build Coastguard Worker 9680*da0073e9SAndroid Build Coastguard Worker # No restarts 9681*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 9682*da0073e9SAndroid Build Coastguard Worker len(metrics[2].restart_reasons) == 0, "Last compile has no graph break" 9683*da0073e9SAndroid Build Coastguard Worker ) 9684*da0073e9SAndroid Build Coastguard Worker self.assertTrue(metrics[2].dynamo_time_before_restart_s == 0) 9685*da0073e9SAndroid Build Coastguard Worker 9686*da0073e9SAndroid Build Coastguard Worker def test_graph_break_compilation_metrics_on_failure(self): 9687*da0073e9SAndroid Build Coastguard Worker def fn(x): 9688*da0073e9SAndroid Build Coastguard Worker return x.sin() 9689*da0073e9SAndroid Build Coastguard Worker 9690*da0073e9SAndroid Build Coastguard Worker def broken_backend(gm, example_inputs): 9691*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("broken backend") 9692*da0073e9SAndroid Build Coastguard Worker 9693*da0073e9SAndroid Build Coastguard Worker x = torch.rand((4, 4)) 9694*da0073e9SAndroid Build Coastguard Worker f = torch.compile(fn, backend=broken_backend) 9695*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.suppress_errors", True): 
9696*da0073e9SAndroid Build Coastguard Worker torch._dynamo.utils.clear_compilation_metrics() 9697*da0073e9SAndroid Build Coastguard Worker f(x) 9698*da0073e9SAndroid Build Coastguard Worker metrics = torch._dynamo.utils.get_compilation_metrics() 9699*da0073e9SAndroid Build Coastguard Worker for metric in metrics: 9700*da0073e9SAndroid Build Coastguard Worker self.assertTrue(metric.dynamo_time_before_restart_s > 0) 9701*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 9702*da0073e9SAndroid Build Coastguard Worker "RuntimeError: broken backend" in metric.fail_reason, 9703*da0073e9SAndroid Build Coastguard Worker "Should have logged fail reason", 9704*da0073e9SAndroid Build Coastguard Worker ) 9705*da0073e9SAndroid Build Coastguard Worker 9706*da0073e9SAndroid Build Coastguard Worker def test_compilation_metrics_size_limit(self): 9707*da0073e9SAndroid Build Coastguard Worker def fn1(x): 9708*da0073e9SAndroid Build Coastguard Worker return x.relu() 9709*da0073e9SAndroid Build Coastguard Worker 9710*da0073e9SAndroid Build Coastguard Worker def fn2(x): 9711*da0073e9SAndroid Build Coastguard Worker return x.cos() 9712*da0073e9SAndroid Build Coastguard Worker 9713*da0073e9SAndroid Build Coastguard Worker def fn3(x): 9714*da0073e9SAndroid Build Coastguard Worker return x.sin() 9715*da0073e9SAndroid Build Coastguard Worker 9716*da0073e9SAndroid Build Coastguard Worker def fn4(x): 9717*da0073e9SAndroid Build Coastguard Worker return x.exp() 9718*da0073e9SAndroid Build Coastguard Worker 9719*da0073e9SAndroid Build Coastguard Worker import contextlib 9720*da0073e9SAndroid Build Coastguard Worker 9721*da0073e9SAndroid Build Coastguard Worker @contextlib.contextmanager 9722*da0073e9SAndroid Build Coastguard Worker def metrics_limit_ctx(): 9723*da0073e9SAndroid Build Coastguard Worker try: 9724*da0073e9SAndroid Build Coastguard Worker torch._dynamo.utils.set_compilation_metrics_limit(3) 9725*da0073e9SAndroid Build Coastguard Worker yield 9726*da0073e9SAndroid Build 
Coastguard Worker finally: 9727*da0073e9SAndroid Build Coastguard Worker torch._dynamo.utils.set_compilation_metrics_limit( 9728*da0073e9SAndroid Build Coastguard Worker torch._dynamo.utils.DEFAULT_COMPILATION_METRICS_LIMIT 9729*da0073e9SAndroid Build Coastguard Worker ) 9730*da0073e9SAndroid Build Coastguard Worker 9731*da0073e9SAndroid Build Coastguard Worker x = torch.rand((4, 4)) 9732*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 9733*da0073e9SAndroid Build Coastguard Worker torch.compile(fn1, backend="eager")(x) 9734*da0073e9SAndroid Build Coastguard Worker torch.compile(fn2, backend="eager")(x) 9735*da0073e9SAndroid Build Coastguard Worker torch.compile(fn3, backend="eager")(x) 9736*da0073e9SAndroid Build Coastguard Worker torch.compile(fn4, backend="eager")(x) 9737*da0073e9SAndroid Build Coastguard Worker 9738*da0073e9SAndroid Build Coastguard Worker with metrics_limit_ctx(): 9739*da0073e9SAndroid Build Coastguard Worker torch._dynamo.utils.clear_compilation_metrics() 9740*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 9741*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, len(torch._dynamo.utils.get_compilation_metrics())) 9742*da0073e9SAndroid Build Coastguard Worker torch.compile(fn1, backend="eager")(x) 9743*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, len(torch._dynamo.utils.get_compilation_metrics())) 9744*da0073e9SAndroid Build Coastguard Worker torch.compile(fn2, backend="eager")(x) 9745*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, len(torch._dynamo.utils.get_compilation_metrics())) 9746*da0073e9SAndroid Build Coastguard Worker torch.compile(fn3, backend="eager")(x) 9747*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3, len(torch._dynamo.utils.get_compilation_metrics())) 9748*da0073e9SAndroid Build Coastguard Worker torch.compile(fn4, backend="eager")(x) 9749*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3, 
len(torch._dynamo.utils.get_compilation_metrics())) 9750*da0073e9SAndroid Build Coastguard Worker 9751*da0073e9SAndroid Build Coastguard Worker def test_funcname_cache(self): 9752*da0073e9SAndroid Build Coastguard Worker src = """\ 9753*da0073e9SAndroid Build Coastguard Workerimport torch 9754*da0073e9SAndroid Build Coastguard Workerif True: 9755*da0073e9SAndroid Build Coastguard Worker test = 3 9756*da0073e9SAndroid Build Coastguard Worker 9757*da0073e9SAndroid Build Coastguard Workerclass AAA: 9758*da0073e9SAndroid Build Coastguard Worker class DUMMY: 9759*da0073e9SAndroid Build Coastguard Worker class DUMMY2: 9760*da0073e9SAndroid Build Coastguard Worker pass 9761*da0073e9SAndroid Build Coastguard Worker 9762*da0073e9SAndroid Build Coastguard Worker def dummy(self): 9763*da0073e9SAndroid Build Coastguard Worker def dummy2(): 9764*da0073e9SAndroid Build Coastguard Worker pass 9765*da0073e9SAndroid Build Coastguard Worker class BBB: 9766*da0073e9SAndroid Build Coastguard Worker @staticmethod 9767*da0073e9SAndroid Build Coastguard Worker def CCC(): 9768*da0073e9SAndroid Build Coastguard Worker class DDD: 9769*da0073e9SAndroid Build Coastguard Worker if True: 9770*da0073e9SAndroid Build Coastguard Worker @staticmethod 9771*da0073e9SAndroid Build Coastguard Worker def EEE(): 9772*da0073e9SAndroid Build Coastguard Worker x = [torch.ones(3, 3) for _ in range(5)] 9773*da0073e9SAndroid Build Coastguard Worker return x 9774*da0073e9SAndroid Build Coastguard Worker return DDD 9775*da0073e9SAndroid Build Coastguard Workerdef fn(): 9776*da0073e9SAndroid Build Coastguard Worker return 3 9777*da0073e9SAndroid Build Coastguard Worker""" 9778*da0073e9SAndroid Build Coastguard Worker with tempfile.NamedTemporaryFile(mode="w") as f: 9779*da0073e9SAndroid Build Coastguard Worker f.write(src) 9780*da0073e9SAndroid Build Coastguard Worker f.flush() 9781*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.funcname_cache import get_funcname 9782*da0073e9SAndroid Build 
Coastguard Worker 9783*da0073e9SAndroid Build Coastguard Worker names = [get_funcname(f.name, i + 1) for i in range(src.count("\n") + 1)] 9784*da0073e9SAndroid Build Coastguard Worker 9785*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 9786*da0073e9SAndroid Build Coastguard Worker "\n".join(names), 9787*da0073e9SAndroid Build Coastguard Worker """\ 9788*da0073e9SAndroid Build Coastguard Worker 9789*da0073e9SAndroid Build Coastguard Worker 9790*da0073e9SAndroid Build Coastguard Worker 9791*da0073e9SAndroid Build Coastguard Worker 9792*da0073e9SAndroid Build Coastguard WorkerAAA 9793*da0073e9SAndroid Build Coastguard WorkerAAA.DUMMY 9794*da0073e9SAndroid Build Coastguard WorkerAAA.DUMMY.DUMMY2 9795*da0073e9SAndroid Build Coastguard WorkerAAA.DUMMY.DUMMY2 9796*da0073e9SAndroid Build Coastguard WorkerAAA.DUMMY.DUMMY2 9797*da0073e9SAndroid Build Coastguard WorkerAAA.dummy 9798*da0073e9SAndroid Build Coastguard WorkerAAA.dummy.dummy2 9799*da0073e9SAndroid Build Coastguard WorkerAAA.dummy.dummy2 9800*da0073e9SAndroid Build Coastguard WorkerAAA.BBB 9801*da0073e9SAndroid Build Coastguard WorkerAAA.BBB 9802*da0073e9SAndroid Build Coastguard WorkerAAA.BBB.CCC 9803*da0073e9SAndroid Build Coastguard WorkerAAA.BBB.CCC.DDD 9804*da0073e9SAndroid Build Coastguard WorkerAAA.BBB.CCC.DDD 9805*da0073e9SAndroid Build Coastguard WorkerAAA.BBB.CCC.DDD 9806*da0073e9SAndroid Build Coastguard WorkerAAA.BBB.CCC.DDD.EEE 9807*da0073e9SAndroid Build Coastguard WorkerAAA.BBB.CCC.DDD.EEE 9808*da0073e9SAndroid Build Coastguard WorkerAAA.BBB.CCC.DDD.EEE 9809*da0073e9SAndroid Build Coastguard WorkerAAA.BBB.CCC 9810*da0073e9SAndroid Build Coastguard Workerfn 9811*da0073e9SAndroid Build Coastguard Workerfn 9812*da0073e9SAndroid Build Coastguard Worker""", 9813*da0073e9SAndroid Build Coastguard Worker ) 9814*da0073e9SAndroid Build Coastguard Worker 9815*da0073e9SAndroid Build Coastguard Worker def test_return_dict_with_graph_break_and_update(self): 9816*da0073e9SAndroid Build 
Coastguard Worker def create(): 9817*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 9818*da0073e9SAndroid Build Coastguard Worker return {0: torch.tensor(3)} 9819*da0073e9SAndroid Build Coastguard Worker 9820*da0073e9SAndroid Build Coastguard Worker def fn(): 9821*da0073e9SAndroid Build Coastguard Worker return {**create()} 9822*da0073e9SAndroid Build Coastguard Worker 9823*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(backend="eager")(fn) 9824*da0073e9SAndroid Build Coastguard Worker result = opt_fn() 9825*da0073e9SAndroid Build Coastguard Worker self.assertIn(0, result) 9826*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(result[0], torch.tensor(3))) 9827*da0073e9SAndroid Build Coastguard Worker 9828*da0073e9SAndroid Build Coastguard Worker def test_dynamo_reset_clears_cache(self): 9829*da0073e9SAndroid Build Coastguard Worker """Test that dynamo bytecode cache is freed 9830*da0073e9SAndroid Build Coastguard Worker when dynamo reset is called 9831*da0073e9SAndroid Build Coastguard Worker """ 9832*da0073e9SAndroid Build Coastguard Worker 9833*da0073e9SAndroid Build Coastguard Worker def fn(x): 9834*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) 9835*da0073e9SAndroid Build Coastguard Worker 9836*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(backend="eager")(fn) 9837*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.randn(3, 3)) 9838*da0073e9SAndroid Build Coastguard Worker 9839*da0073e9SAndroid Build Coastguard Worker c1 = _debug_get_cache_entry_list(fn.__code__) 9840*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(c1), 1) 9841*da0073e9SAndroid Build Coastguard Worker 9842*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 9843*da0073e9SAndroid Build Coastguard Worker c2 = _debug_get_cache_entry_list(fn.__code__) 9844*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(c2), 0) 9845*da0073e9SAndroid Build Coastguard Worker 
9846*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(capture_scalar_outputs=True) 9847*da0073e9SAndroid Build Coastguard Worker def test_guard_size_oblivious(self): 9848*da0073e9SAndroid Build Coastguard Worker # This code, in fact, does NOT work in eager 9849*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 9850*da0073e9SAndroid Build Coastguard Worker def fn(x): 9851*da0073e9SAndroid Build Coastguard Worker y = torch.zeros(x.item()) 9852*da0073e9SAndroid Build Coastguard Worker if guard_size_oblivious(y.size(0) == 0): 9853*da0073e9SAndroid Build Coastguard Worker assert False 9854*da0073e9SAndroid Build Coastguard Worker return y 9855*da0073e9SAndroid Build Coastguard Worker 9856*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(torch.tensor([0])), torch.zeros(0)) 9857*da0073e9SAndroid Build Coastguard Worker 9858*da0073e9SAndroid Build Coastguard Worker def test_guard_size_oblivious_backed(self): 9859*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 9860*da0073e9SAndroid Build Coastguard Worker def f(x): 9861*da0073e9SAndroid Build Coastguard Worker y = x.size(0) 9862*da0073e9SAndroid Build Coastguard Worker # This doesn't actually do anything 9863*da0073e9SAndroid Build Coastguard Worker if guard_size_oblivious(y == 0): 9864*da0073e9SAndroid Build Coastguard Worker return torch.randn(1) 9865*da0073e9SAndroid Build Coastguard Worker else: 9866*da0073e9SAndroid Build Coastguard Worker return torch.randn(2) 9867*da0073e9SAndroid Build Coastguard Worker 9868*da0073e9SAndroid Build Coastguard Worker # Should not fail in either case 9869*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(torch.randn(0)).shape, (1,)) 9870*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(torch.randn(2)).shape, (2,)) 9871*da0073e9SAndroid Build Coastguard Worker 9872*da0073e9SAndroid Build Coastguard Worker def _test_compile_model_free(self, model_inp_ctr, 
weakref_watch): 9873*da0073e9SAndroid Build Coastguard Worker """ 9874*da0073e9SAndroid Build Coastguard Worker Args: 9875*da0073e9SAndroid Build Coastguard Worker model_inp_ctr 9876*da0073e9SAndroid Build Coastguard Worker - constructor that returns a new model and inputs to that model 9877*da0073e9SAndroid Build Coastguard Worker weakref_watch 9878*da0073e9SAndroid Build Coastguard Worker - function that returns a layer of the model for weakref to 9879*da0073e9SAndroid Build Coastguard Worker finalize on, so we can check that the layer is freed after 9880*da0073e9SAndroid Build Coastguard Worker the model goes out of scope 9881*da0073e9SAndroid Build Coastguard Worker """ 9882*da0073e9SAndroid Build Coastguard Worker cleared = False 9883*da0073e9SAndroid Build Coastguard Worker 9884*da0073e9SAndroid Build Coastguard Worker def finalize(): 9885*da0073e9SAndroid Build Coastguard Worker nonlocal cleared 9886*da0073e9SAndroid Build Coastguard Worker cleared = True 9887*da0073e9SAndroid Build Coastguard Worker 9888*da0073e9SAndroid Build Coastguard Worker def run(): 9889*da0073e9SAndroid Build Coastguard Worker mod, inp = model_inp_ctr() 9890*da0073e9SAndroid Build Coastguard Worker weakref.finalize(weakref_watch(mod), finalize) 9891*da0073e9SAndroid Build Coastguard Worker torch.compile(mod, backend="eager")(inp) 9892*da0073e9SAndroid Build Coastguard Worker 9893*da0073e9SAndroid Build Coastguard Worker run() 9894*da0073e9SAndroid Build Coastguard Worker gc.collect() 9895*da0073e9SAndroid Build Coastguard Worker self.assertTrue(cleared) 9896*da0073e9SAndroid Build Coastguard Worker 9897*da0073e9SAndroid Build Coastguard Worker def test_custom_module_free(self): 9898*da0073e9SAndroid Build Coastguard Worker """Test that a model is freed when it goes out of scope""" 9899*da0073e9SAndroid Build Coastguard Worker 9900*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 9901*da0073e9SAndroid Build Coastguard Worker def __init__(self): 
9902*da0073e9SAndroid Build Coastguard Worker super(Mod, self).__init__() 9903*da0073e9SAndroid Build Coastguard Worker self.fc = torch.nn.Linear(100, 100) 9904*da0073e9SAndroid Build Coastguard Worker 9905*da0073e9SAndroid Build Coastguard Worker def forward(self, out): 9906*da0073e9SAndroid Build Coastguard Worker return self.fc(out) 9907*da0073e9SAndroid Build Coastguard Worker 9908*da0073e9SAndroid Build Coastguard Worker self._test_compile_model_free( 9909*da0073e9SAndroid Build Coastguard Worker lambda: (Mod(), torch.randn(100, 100)), 9910*da0073e9SAndroid Build Coastguard Worker lambda mod: mod.fc, 9911*da0073e9SAndroid Build Coastguard Worker ) 9912*da0073e9SAndroid Build Coastguard Worker 9913*da0073e9SAndroid Build Coastguard Worker def test_sequential_module_free(self): 9914*da0073e9SAndroid Build Coastguard Worker self._test_compile_model_free( 9915*da0073e9SAndroid Build Coastguard Worker lambda: ( 9916*da0073e9SAndroid Build Coastguard Worker torch.nn.Sequential( 9917*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(100, 100), 9918*da0073e9SAndroid Build Coastguard Worker torch.nn.ReLU(), 9919*da0073e9SAndroid Build Coastguard Worker ), 9920*da0073e9SAndroid Build Coastguard Worker torch.randn(100, 100), 9921*da0073e9SAndroid Build Coastguard Worker ), 9922*da0073e9SAndroid Build Coastguard Worker lambda mod: mod[0], 9923*da0073e9SAndroid Build Coastguard Worker ) 9924*da0073e9SAndroid Build Coastguard Worker 9925*da0073e9SAndroid Build Coastguard Worker def test_linear_module_free(self): 9926*da0073e9SAndroid Build Coastguard Worker self._test_compile_model_free( 9927*da0073e9SAndroid Build Coastguard Worker lambda: (torch.nn.Linear(100, 100), torch.randn(100, 100)), 9928*da0073e9SAndroid Build Coastguard Worker lambda mod: mod, 9929*da0073e9SAndroid Build Coastguard Worker ) 9930*da0073e9SAndroid Build Coastguard Worker 9931*da0073e9SAndroid Build Coastguard Worker # The following 2 tests fail due to 
https://github.com/python/cpython/issues/118013. 9932*da0073e9SAndroid Build Coastguard Worker # Tracked by https://github.com/pytorch/pytorch/issues/124302. 9933*da0073e9SAndroid Build Coastguard Worker # The xfails can be removed once Python 3.12 is updated on CI. 9934*da0073e9SAndroid Build Coastguard Worker @xfailIfPy312 9935*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(True, "Skipping this test for release/2.4") 9936*da0073e9SAndroid Build Coastguard Worker def test_outside_linear_module_free(self): 9937*da0073e9SAndroid Build Coastguard Worker # Compared to test_linear_module_free, the linear 9938*da0073e9SAndroid Build Coastguard Worker # layer is not the code object that is directly compiled. 9939*da0073e9SAndroid Build Coastguard Worker 9940*da0073e9SAndroid Build Coastguard Worker # This test does not use _test_compile_model_free because of difficulty 9941*da0073e9SAndroid Build Coastguard Worker # in handling variable fc. 9942*da0073e9SAndroid Build Coastguard Worker 9943*da0073e9SAndroid Build Coastguard Worker cleared = False 9944*da0073e9SAndroid Build Coastguard Worker 9945*da0073e9SAndroid Build Coastguard Worker def finalize(): 9946*da0073e9SAndroid Build Coastguard Worker nonlocal cleared 9947*da0073e9SAndroid Build Coastguard Worker cleared = True 9948*da0073e9SAndroid Build Coastguard Worker 9949*da0073e9SAndroid Build Coastguard Worker def run(): 9950*da0073e9SAndroid Build Coastguard Worker fc = torch.nn.Linear(100, 100) 9951*da0073e9SAndroid Build Coastguard Worker 9952*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 9953*da0073e9SAndroid Build Coastguard Worker def __init__(self): 9954*da0073e9SAndroid Build Coastguard Worker super().__init__() 9955*da0073e9SAndroid Build Coastguard Worker self.fc_ref = fc 9956*da0073e9SAndroid Build Coastguard Worker 9957*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 9958*da0073e9SAndroid Build Coastguard Worker return self.fc_ref(x) 9959*da0073e9SAndroid 
Build Coastguard Worker 9960*da0073e9SAndroid Build Coastguard Worker mod = Mod() 9961*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(100, 100) 9962*da0073e9SAndroid Build Coastguard Worker weakref.finalize(fc, finalize) 9963*da0073e9SAndroid Build Coastguard Worker torch.compile(mod, backend="eager")(inp) 9964*da0073e9SAndroid Build Coastguard Worker 9965*da0073e9SAndroid Build Coastguard Worker run() 9966*da0073e9SAndroid Build Coastguard Worker # del fc # This should delete all the references 9967*da0073e9SAndroid Build Coastguard Worker gc.collect() 9968*da0073e9SAndroid Build Coastguard Worker self.assertTrue(cleared) 9969*da0073e9SAndroid Build Coastguard Worker 9970*da0073e9SAndroid Build Coastguard Worker @xfailIfPy312 9971*da0073e9SAndroid Build Coastguard Worker def test_parameter_free(self): 9972*da0073e9SAndroid Build Coastguard Worker def model_inp_ctr(): 9973*da0073e9SAndroid Build Coastguard Worker param = torch.nn.Parameter(torch.randn(100, 100)) 9974*da0073e9SAndroid Build Coastguard Worker 9975*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 9976*da0073e9SAndroid Build Coastguard Worker def __init__(self): 9977*da0073e9SAndroid Build Coastguard Worker super().__init__() 9978*da0073e9SAndroid Build Coastguard Worker self.param = param 9979*da0073e9SAndroid Build Coastguard Worker 9980*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 9981*da0073e9SAndroid Build Coastguard Worker return self.param * x[0] 9982*da0073e9SAndroid Build Coastguard Worker 9983*da0073e9SAndroid Build Coastguard Worker # return param to keep it alive in _test_compile_model_free 9984*da0073e9SAndroid Build Coastguard Worker return Mod(), (torch.randn(100, 100), param) 9985*da0073e9SAndroid Build Coastguard Worker 9986*da0073e9SAndroid Build Coastguard Worker self._test_compile_model_free(model_inp_ctr, lambda mod: mod.param) 9987*da0073e9SAndroid Build Coastguard Worker 9988*da0073e9SAndroid Build Coastguard Worker def 
test_conditional_list_comp_in_context(self): 9989*da0073e9SAndroid Build Coastguard Worker def fn(inp): 9990*da0073e9SAndroid Build Coastguard Worker try: 9991*da0073e9SAndroid Build Coastguard Worker return [torch.sin(x) for x in inp if x is not None] 9992*da0073e9SAndroid Build Coastguard Worker except Exception: 9993*da0073e9SAndroid Build Coastguard Worker pass 9994*da0073e9SAndroid Build Coastguard Worker 9995*da0073e9SAndroid Build Coastguard Worker inp = [torch.randn(3, 3) for _ in range(3)] + [None] 9996*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager") 9997*da0073e9SAndroid Build Coastguard Worker opt_fn(inp) 9998*da0073e9SAndroid Build Coastguard Worker 9999*da0073e9SAndroid Build Coastguard Worker def test_312_binary_slice_with_graph_break1(self): 10000*da0073e9SAndroid Build Coastguard Worker l1 = torch.nn.Linear(5, 5) 10001*da0073e9SAndroid Build Coastguard Worker l2 = torch.nn.Linear(5, 5) 10002*da0073e9SAndroid Build Coastguard Worker 10003*da0073e9SAndroid Build Coastguard Worker def fn(x): 10004*da0073e9SAndroid Build Coastguard Worker # causes a graph break with items in the stack 10005*da0073e9SAndroid Build Coastguard Worker n = torch.nn.Sequential(l1, l2) 10006*da0073e9SAndroid Build Coastguard Worker out = n[1:](x) 10007*da0073e9SAndroid Build Coastguard Worker return out 10008*da0073e9SAndroid Build Coastguard Worker 10009*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager") 10010*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.randn(5, 5)) 10011*da0073e9SAndroid Build Coastguard Worker 10012*da0073e9SAndroid Build Coastguard Worker def test_312_binary_slice_with_graph_break2(self): 10013*da0073e9SAndroid Build Coastguard Worker class Foo: 10014*da0073e9SAndroid Build Coastguard Worker def __setitem__(self, key, val): 10015*da0073e9SAndroid Build Coastguard Worker pass 10016*da0073e9SAndroid Build Coastguard Worker 10017*da0073e9SAndroid Build Coastguard Worker def 
__getitem__(self, key): 10018*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 10019*da0073e9SAndroid Build Coastguard Worker return 1 10020*da0073e9SAndroid Build Coastguard Worker 10021*da0073e9SAndroid Build Coastguard Worker foo = Foo() 10022*da0073e9SAndroid Build Coastguard Worker 10023*da0073e9SAndroid Build Coastguard Worker def fn(x): 10024*da0073e9SAndroid Build Coastguard Worker # graph break in a STORE_SLICE instruction 10025*da0073e9SAndroid Build Coastguard Worker foo[:] = x 10026*da0073e9SAndroid Build Coastguard Worker # graph break in BINARY_SLICE with has_backedge check 10027*da0073e9SAndroid Build Coastguard Worker x = x + foo[:] 10028*da0073e9SAndroid Build Coastguard Worker if x is None: 10029*da0073e9SAndroid Build Coastguard Worker x = x + 1 10030*da0073e9SAndroid Build Coastguard Worker else: 10031*da0073e9SAndroid Build Coastguard Worker x = x + 1 10032*da0073e9SAndroid Build Coastguard Worker return x 10033*da0073e9SAndroid Build Coastguard Worker 10034*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager") 10035*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.randn(5, 5)) 10036*da0073e9SAndroid Build Coastguard Worker 10037*da0073e9SAndroid Build Coastguard Worker def test_super_after_graph_break(self): 10038*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Sequential): 10039*da0073e9SAndroid Build Coastguard Worker def __init__(self, layers): 10040*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 10041*da0073e9SAndroid Build Coastguard Worker super().__init__(*layers) 10042*da0073e9SAndroid Build Coastguard Worker 10043*da0073e9SAndroid Build Coastguard Worker def fn(x): 10044*da0073e9SAndroid Build Coastguard Worker layers = [torch.nn.Linear(3, 3) for _ in range(3)] 10045*da0073e9SAndroid Build Coastguard Worker mod = Foo(layers) 10046*da0073e9SAndroid Build Coastguard Worker return mod(x) 10047*da0073e9SAndroid Build Coastguard Worker 
10048*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager") 10049*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.randn(3, 3)) 10050*da0073e9SAndroid Build Coastguard Worker 10051*da0073e9SAndroid Build Coastguard Worker def test_load_fast_and_clear_graph_break(self): 10052*da0073e9SAndroid Build Coastguard Worker # Can result in a segfault in 3.12+ if LOAD_FAST_AND_CLEAR 10053*da0073e9SAndroid Build Coastguard Worker # is not handled properly in a graph break 10054*da0073e9SAndroid Build Coastguard Worker def fn(): 10055*da0073e9SAndroid Build Coastguard Worker out = torch.cat([torch.randn(r, 5) for r in range(3)]) 10056*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 10057*da0073e9SAndroid Build Coastguard Worker out = torch.cat([torch.randn(r, 5) for r in range(3)]) 10058*da0073e9SAndroid Build Coastguard Worker return out 10059*da0073e9SAndroid Build Coastguard Worker 10060*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch._dynamo.optimize("eager")(fn)().shape, (3, 5)) 10061*da0073e9SAndroid Build Coastguard Worker 10062*da0073e9SAndroid Build Coastguard Worker def test_raises_importerror1(self): 10063*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 10064*da0073e9SAndroid Build Coastguard Worker def fn(x): 10065*da0073e9SAndroid Build Coastguard Worker try: 10066*da0073e9SAndroid Build Coastguard Worker import some_module_that_surely_does_not_exist 10067*da0073e9SAndroid Build Coastguard Worker 10068*da0073e9SAndroid Build Coastguard Worker return 10069*da0073e9SAndroid Build Coastguard Worker except ImportError: 10070*da0073e9SAndroid Build Coastguard Worker pass 10071*da0073e9SAndroid Build Coastguard Worker return x.sin() 10072*da0073e9SAndroid Build Coastguard Worker 10073*da0073e9SAndroid Build Coastguard Worker x = torch.randn(8) 10074*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), x.sin()) 10075*da0073e9SAndroid Build Coastguard Worker 
10076*da0073e9SAndroid Build Coastguard Worker def test_raises_importerror2(self): 10077*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 10078*da0073e9SAndroid Build Coastguard Worker def fn(x): 10079*da0073e9SAndroid Build Coastguard Worker import some_module_that_surely_does_not_exist 10080*da0073e9SAndroid Build Coastguard Worker 10081*da0073e9SAndroid Build Coastguard Worker return x + 1 10082*da0073e9SAndroid Build Coastguard Worker 10083*da0073e9SAndroid Build Coastguard Worker x = torch.randn(8) 10084*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ImportError): 10085*da0073e9SAndroid Build Coastguard Worker fn(x) 10086*da0073e9SAndroid Build Coastguard Worker 10087*da0073e9SAndroid Build Coastguard Worker def test_dynamo_cache_move_to_front(self): 10088*da0073e9SAndroid Build Coastguard Worker def fn(x, const): 10089*da0073e9SAndroid Build Coastguard Worker return x + const 10090*da0073e9SAndroid Build Coastguard Worker 10091*da0073e9SAndroid Build Coastguard Worker # dynamic=False forces Dynamo to recompile 10092*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", dynamic=False) 10093*da0073e9SAndroid Build Coastguard Worker 10094*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(3, 3) 10095*da0073e9SAndroid Build Coastguard Worker 10096*da0073e9SAndroid Build Coastguard Worker # NOTE: assumes that each cache entry is guarded 10097*da0073e9SAndroid Build Coastguard Worker # on unique Mod instance 10098*da0073e9SAndroid Build Coastguard Worker opt_fn(inp, 1) 10099*da0073e9SAndroid Build Coastguard Worker opt_fn(inp, 2) 10100*da0073e9SAndroid Build Coastguard Worker opt_fn(inp, 3) 10101*da0073e9SAndroid Build Coastguard Worker 10102*da0073e9SAndroid Build Coastguard Worker c1 = _debug_get_cache_entry_list(fn.__code__) 10103*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(c1), 3) 10104*da0073e9SAndroid Build Coastguard Worker 10105*da0073e9SAndroid Build 
Coastguard Worker # move cache entry to front 10106*da0073e9SAndroid Build Coastguard Worker opt_fn(inp, 2) 10107*da0073e9SAndroid Build Coastguard Worker c2 = _debug_get_cache_entry_list(fn.__code__) 10108*da0073e9SAndroid Build Coastguard Worker self.assertIs(c1[1], c2[0]) 10109*da0073e9SAndroid Build Coastguard Worker 10110*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) 10111*da0073e9SAndroid Build Coastguard Worker def test_dynamo_cache_invalidate(self): 10112*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 10113*da0073e9SAndroid Build Coastguard Worker def __init__(self): 10114*da0073e9SAndroid Build Coastguard Worker super(Mod, self).__init__() 10115*da0073e9SAndroid Build Coastguard Worker self.fc = torch.nn.Linear(3, 3) 10116*da0073e9SAndroid Build Coastguard Worker 10117*da0073e9SAndroid Build Coastguard Worker def forward(self, out): 10118*da0073e9SAndroid Build Coastguard Worker return self.fc(out) 10119*da0073e9SAndroid Build Coastguard Worker 10120*da0073e9SAndroid Build Coastguard Worker def fn(x, mod): 10121*da0073e9SAndroid Build Coastguard Worker return mod(x) 10122*da0073e9SAndroid Build Coastguard Worker 10123*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager") 10124*da0073e9SAndroid Build Coastguard Worker 10125*da0073e9SAndroid Build Coastguard Worker m1 = Mod() 10126*da0073e9SAndroid Build Coastguard Worker m2 = Mod() 10127*da0073e9SAndroid Build Coastguard Worker m3 = Mod() 10128*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(3, 3) 10129*da0073e9SAndroid Build Coastguard Worker 10130*da0073e9SAndroid Build Coastguard Worker # NOTE: assumes that each cache entry is guarded 10131*da0073e9SAndroid Build Coastguard Worker # on unique Mod instance 10132*da0073e9SAndroid Build Coastguard Worker opt_fn(inp, m1) 10133*da0073e9SAndroid Build Coastguard Worker opt_fn(inp, m2) 10134*da0073e9SAndroid Build Coastguard Worker opt_fn(inp, 
m3) 10135*da0073e9SAndroid Build Coastguard Worker 10136*da0073e9SAndroid Build Coastguard Worker c1 = _debug_get_cache_entry_list(fn.__code__) 10137*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(c1), 3) 10138*da0073e9SAndroid Build Coastguard Worker 10139*da0073e9SAndroid Build Coastguard Worker # move cache entry to front 10140*da0073e9SAndroid Build Coastguard Worker opt_fn(inp, m2) 10141*da0073e9SAndroid Build Coastguard Worker c2 = _debug_get_cache_entry_list(fn.__code__) 10142*da0073e9SAndroid Build Coastguard Worker self.assertIs(c1[1], c2[0]) 10143*da0073e9SAndroid Build Coastguard Worker 10144*da0073e9SAndroid Build Coastguard Worker # delete center of cache 10145*da0073e9SAndroid Build Coastguard Worker del m3 10146*da0073e9SAndroid Build Coastguard Worker c3 = _debug_get_cache_entry_list(fn.__code__) 10147*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(c3), 2) 10148*da0073e9SAndroid Build Coastguard Worker self.assertIs(c3[0], c2[0]) 10149*da0073e9SAndroid Build Coastguard Worker self.assertIs(c3[1], c2[2]) 10150*da0073e9SAndroid Build Coastguard Worker 10151*da0073e9SAndroid Build Coastguard Worker # delete end of cache 10152*da0073e9SAndroid Build Coastguard Worker del m1 10153*da0073e9SAndroid Build Coastguard Worker c4 = _debug_get_cache_entry_list(fn.__code__) 10154*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(c4), 1) 10155*da0073e9SAndroid Build Coastguard Worker self.assertIs(c4[0], c3[0]) 10156*da0073e9SAndroid Build Coastguard Worker 10157*da0073e9SAndroid Build Coastguard Worker del m2 10158*da0073e9SAndroid Build Coastguard Worker c5 = _debug_get_cache_entry_list(fn.__code__) 10159*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(c5), 0) 10160*da0073e9SAndroid Build Coastguard Worker 10161*da0073e9SAndroid Build Coastguard Worker def test_grad_none(self): 10162*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 10163*da0073e9SAndroid Build Coastguard Worker x.grad = torch.abs(y) 
10164*da0073e9SAndroid Build Coastguard Worker x.grad.add_(y) 10165*da0073e9SAndroid Build Coastguard Worker return torch.abs(y) 10166*da0073e9SAndroid Build Coastguard Worker 10167*da0073e9SAndroid Build Coastguard Worker y = torch.arange(4).reshape(2, 2).to(torch.float) 10168*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 10169*da0073e9SAndroid Build Coastguard Worker x.grad = None 10170*da0073e9SAndroid Build Coastguard Worker 10171*da0073e9SAndroid Build Coastguard Worker z = fn(x, y) 10172*da0073e9SAndroid Build Coastguard Worker ref_y = torch.clone(z).detach() 10173*da0073e9SAndroid Build Coastguard Worker ref_x_grad = torch.clone(x.grad).detach() 10174*da0073e9SAndroid Build Coastguard Worker 10175*da0073e9SAndroid Build Coastguard Worker y = torch.arange(4).reshape(2, 2).to(torch.float) 10176*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 10177*da0073e9SAndroid Build Coastguard Worker x.grad = None 10178*da0073e9SAndroid Build Coastguard Worker 10179*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager") 10180*da0073e9SAndroid Build Coastguard Worker z = opt_fn(x, y) 10181*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, ref_y) 10182*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, ref_x_grad) 10183*da0073e9SAndroid Build Coastguard Worker 10184*da0073e9SAndroid Build Coastguard Worker def test_grad_non_none(self): 10185*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 10186*da0073e9SAndroid Build Coastguard Worker x.grad.add_(y) 10187*da0073e9SAndroid Build Coastguard Worker return torch.abs(y) 10188*da0073e9SAndroid Build Coastguard Worker 10189*da0073e9SAndroid Build Coastguard Worker y = torch.ones(2, 2) 10190*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 10191*da0073e9SAndroid Build Coastguard Worker x.grad = torch.arange(4).reshape(2, 2).to(torch.float) 10192*da0073e9SAndroid Build Coastguard Worker 10193*da0073e9SAndroid Build Coastguard 
Worker z = fn(x, y) 10194*da0073e9SAndroid Build Coastguard Worker ref_y = torch.clone(z).detach() 10195*da0073e9SAndroid Build Coastguard Worker ref_x_grad = torch.clone(x.grad).detach() 10196*da0073e9SAndroid Build Coastguard Worker 10197*da0073e9SAndroid Build Coastguard Worker y = torch.ones(2, 2) 10198*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 10199*da0073e9SAndroid Build Coastguard Worker x.grad = torch.arange(4).reshape(2, 2).to(torch.float) 10200*da0073e9SAndroid Build Coastguard Worker 10201*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounterWithBackend("eager") 10202*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend=cnt) 10203*da0073e9SAndroid Build Coastguard Worker z = opt_fn(x, y) 10204*da0073e9SAndroid Build Coastguard Worker 10205*da0073e9SAndroid Build Coastguard Worker # Ensure that the generated graph returns only one output. We want the 10206*da0073e9SAndroid Build Coastguard Worker # add_ on the grad to be part of the graph itself, so that inductor can 10207*da0073e9SAndroid Build Coastguard Worker # theoretically move the add_ and resutling copy_ nodes at the right 10208*da0073e9SAndroid Build Coastguard Worker # place to free memory. 10209*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(cnt.graphs[0].graph.nodes)[-1].all_input_nodes), 1) 10210*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, ref_y) 10211*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, ref_x_grad) 10212*da0073e9SAndroid Build Coastguard Worker 10213*da0073e9SAndroid Build Coastguard Worker def test_new_with_int_list(self): 10214*da0073e9SAndroid Build Coastguard Worker # Make sure torch.Tensor.new(int argument list) behaves the same on dynamo. 
10215*da0073e9SAndroid Build Coastguard Worker def fn(x): 10216*da0073e9SAndroid Build Coastguard Worker return x.new(*x.size()) + 5 10217*da0073e9SAndroid Build Coastguard Worker 10218*da0073e9SAndroid Build Coastguard Worker optfn = torch.compile(backend="eager")(fn) 10219*da0073e9SAndroid Build Coastguard Worker 10220*da0073e9SAndroid Build Coastguard Worker x = torch.arange(10).view(2, 5) 10221*da0073e9SAndroid Build Coastguard Worker 10222*da0073e9SAndroid Build Coastguard Worker expected = fn(x) 10223*da0073e9SAndroid Build Coastguard Worker actual = optfn(x) 10224*da0073e9SAndroid Build Coastguard Worker 10225*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected.dtype, actual.dtype) 10226*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected.shape, actual.shape) 10227*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected.stride(), actual.stride()) 10228*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected.storage_offset(), actual.storage_offset()) 10229*da0073e9SAndroid Build Coastguard Worker 10230*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(guard_nn_modules=True) 10231*da0073e9SAndroid Build Coastguard Worker def test_hasattr_nn_module_guard(self): 10232*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 10233*da0073e9SAndroid Build Coastguard Worker def __init__(self): 10234*da0073e9SAndroid Build Coastguard Worker super().__init__() 10235*da0073e9SAndroid Build Coastguard Worker self.a = torch.nn.Linear(3, 3) 10236*da0073e9SAndroid Build Coastguard Worker 10237*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 10238*da0073e9SAndroid Build Coastguard Worker if hasattr(self, "a"): 10239*da0073e9SAndroid Build Coastguard Worker return self.a(x) 10240*da0073e9SAndroid Build Coastguard Worker else: 10241*da0073e9SAndroid Build Coastguard Worker return x 10242*da0073e9SAndroid Build Coastguard Worker 10243*da0073e9SAndroid Build Coastguard Worker m = M() 
10244*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3) 10245*da0073e9SAndroid Build Coastguard Worker ref = m(x) 10246*da0073e9SAndroid Build Coastguard Worker 10247*da0073e9SAndroid Build Coastguard Worker opt_m = torch.compile(backend="eager")(m) 10248*da0073e9SAndroid Build Coastguard Worker res = opt_m(x) 10249*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 10250*da0073e9SAndroid Build Coastguard Worker 10251*da0073e9SAndroid Build Coastguard Worker def test_ordered_dict_move_to_end(self): 10252*da0073e9SAndroid Build Coastguard Worker d = { 10253*da0073e9SAndroid Build Coastguard Worker "foo": 1, 10254*da0073e9SAndroid Build Coastguard Worker "bar": 2, 10255*da0073e9SAndroid Build Coastguard Worker } 10256*da0073e9SAndroid Build Coastguard Worker 10257*da0073e9SAndroid Build Coastguard Worker d = collections.OrderedDict(d) 10258*da0073e9SAndroid Build Coastguard Worker d.move_to_end("foo") 10259*da0073e9SAndroid Build Coastguard Worker 10260*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 10261*da0073e9SAndroid Build Coastguard Worker def fn(x, d): 10262*da0073e9SAndroid Build Coastguard Worker return x * d["foo"] * d["bar"] 10263*da0073e9SAndroid Build Coastguard Worker 10264*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(4), d) 10265*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 10266*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(4), d) 10267*da0073e9SAndroid Build Coastguard Worker 10268*da0073e9SAndroid Build Coastguard Worker def test_defaultdict(self): 10269*da0073e9SAndroid Build Coastguard Worker d = collections.defaultdict() 10270*da0073e9SAndroid Build Coastguard Worker d["foo"] = 1 10271*da0073e9SAndroid Build Coastguard Worker d["bar"] = 2 10272*da0073e9SAndroid Build Coastguard Worker 10273*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 10274*da0073e9SAndroid Build Coastguard 
Worker def fn(x, d): 10275*da0073e9SAndroid Build Coastguard Worker return x * d["foo"] * d["bar"] 10276*da0073e9SAndroid Build Coastguard Worker 10277*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(4), d) 10278*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 10279*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(4), d) 10280*da0073e9SAndroid Build Coastguard Worker 10281*da0073e9SAndroid Build Coastguard Worker def test_custom_dict(self): 10282*da0073e9SAndroid Build Coastguard Worker class MyDict(dict): 10283*da0073e9SAndroid Build Coastguard Worker pass 10284*da0073e9SAndroid Build Coastguard Worker 10285*da0073e9SAndroid Build Coastguard Worker d = { 10286*da0073e9SAndroid Build Coastguard Worker "foo": 1, 10287*da0073e9SAndroid Build Coastguard Worker "bar": 2, 10288*da0073e9SAndroid Build Coastguard Worker } 10289*da0073e9SAndroid Build Coastguard Worker 10290*da0073e9SAndroid Build Coastguard Worker d = MyDict(d) 10291*da0073e9SAndroid Build Coastguard Worker 10292*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 10293*da0073e9SAndroid Build Coastguard Worker def fn(x, d): 10294*da0073e9SAndroid Build Coastguard Worker return x * d["foo"] * d["bar"] 10295*da0073e9SAndroid Build Coastguard Worker 10296*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(4), d) 10297*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 10298*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(4), d) 10299*da0073e9SAndroid Build Coastguard Worker 10300*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "requires cuda") 10301*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch( 10302*da0073e9SAndroid Build Coastguard Worker capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True 10303*da0073e9SAndroid Build Coastguard Worker ) 10304*da0073e9SAndroid 
Build Coastguard Worker @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True) 10305*da0073e9SAndroid Build Coastguard Worker def test_interpolate_propagate_real_tensors(self): 10306*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 10307*da0073e9SAndroid Build Coastguard Worker def f(mask, box): 10308*da0073e9SAndroid Build Coastguard Worker # u0, u1 = mask.tolist() 10309*da0073e9SAndroid Build Coastguard Worker mask = torch.randn(1, 1, 30, 30, device="cuda") 10310*da0073e9SAndroid Build Coastguard Worker h, w = box.tolist() 10311*da0073e9SAndroid Build Coastguard Worker return torch.nn.functional.interpolate( 10312*da0073e9SAndroid Build Coastguard Worker mask, (h, w), mode="bilinear", align_corners=False 10313*da0073e9SAndroid Build Coastguard Worker ) 10314*da0073e9SAndroid Build Coastguard Worker 10315*da0073e9SAndroid Build Coastguard Worker f(torch.tensor([30, 30], device="cuda"), torch.tensor([68, 32], device="cuda")) 10316*da0073e9SAndroid Build Coastguard Worker 10317*da0073e9SAndroid Build Coastguard Worker def test_custom_iter_dict(self): 10318*da0073e9SAndroid Build Coastguard Worker class ReversedDict(dict): 10319*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 10320*da0073e9SAndroid Build Coastguard Worker return reversed(list(self.keys())) 10321*da0073e9SAndroid Build Coastguard Worker 10322*da0073e9SAndroid Build Coastguard Worker d = { 10323*da0073e9SAndroid Build Coastguard Worker "foo": 1, 10324*da0073e9SAndroid Build Coastguard Worker "bar": 2, 10325*da0073e9SAndroid Build Coastguard Worker } 10326*da0073e9SAndroid Build Coastguard Worker 10327*da0073e9SAndroid Build Coastguard Worker d = ReversedDict(d) 10328*da0073e9SAndroid Build Coastguard Worker 10329*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 10330*da0073e9SAndroid Build Coastguard Worker def fn(x, d): 10331*da0073e9SAndroid Build Coastguard Worker return x * d["foo"] * d["bar"] 
10332*da0073e9SAndroid Build Coastguard Worker 10333*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(4), d) 10334*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 10335*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(4), d) 10336*da0073e9SAndroid Build Coastguard Worker 10337*da0073e9SAndroid Build Coastguard Worker def test_custom_keys_iter_dict(self): 10338*da0073e9SAndroid Build Coastguard Worker class ReversedDict(dict): 10339*da0073e9SAndroid Build Coastguard Worker def keys(self): 10340*da0073e9SAndroid Build Coastguard Worker return ["bar", "foo"] 10341*da0073e9SAndroid Build Coastguard Worker 10342*da0073e9SAndroid Build Coastguard Worker d = { 10343*da0073e9SAndroid Build Coastguard Worker "foo": 1, 10344*da0073e9SAndroid Build Coastguard Worker "bar": 2, 10345*da0073e9SAndroid Build Coastguard Worker } 10346*da0073e9SAndroid Build Coastguard Worker 10347*da0073e9SAndroid Build Coastguard Worker d = ReversedDict(d) 10348*da0073e9SAndroid Build Coastguard Worker 10349*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 10350*da0073e9SAndroid Build Coastguard Worker def fn(x, d): 10351*da0073e9SAndroid Build Coastguard Worker return x * d["foo"] * d["bar"] 10352*da0073e9SAndroid Build Coastguard Worker 10353*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(4), d) 10354*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 10355*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(4), d) 10356*da0073e9SAndroid Build Coastguard Worker 10357*da0073e9SAndroid Build Coastguard Worker def test_dict_guard_on_keys_order(self): 10358*da0073e9SAndroid Build Coastguard Worker d = { 10359*da0073e9SAndroid Build Coastguard Worker 2: 4, 10360*da0073e9SAndroid Build Coastguard Worker 3: 5, 10361*da0073e9SAndroid Build Coastguard Worker } 10362*da0073e9SAndroid Build Coastguard Worker 
10363*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 10364*da0073e9SAndroid Build Coastguard Worker 10365*da0073e9SAndroid Build Coastguard Worker def fn(x, d): 10366*da0073e9SAndroid Build Coastguard Worker for key, value in d.items(): 10367*da0073e9SAndroid Build Coastguard Worker x = x * key + value 10368*da0073e9SAndroid Build Coastguard Worker return x 10369*da0073e9SAndroid Build Coastguard Worker 10370*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend=cnts) 10371*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.randn(4), d) 10372*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.randn(4), d) 10373*da0073e9SAndroid Build Coastguard Worker # No recompilation 10374*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 10375*da0073e9SAndroid Build Coastguard Worker 10376*da0073e9SAndroid Build Coastguard Worker # move 2 to the end 10377*da0073e9SAndroid Build Coastguard Worker d[2] = d.pop(2) 10378*da0073e9SAndroid Build Coastguard Worker 10379*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 10380*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, d) 10381*da0073e9SAndroid Build Coastguard Worker # Check recompilation 10382*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 10383*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, fn(x, d)) 10384*da0073e9SAndroid Build Coastguard Worker 10385*da0073e9SAndroid Build Coastguard Worker def test_dict_guard_on_keys_order2(self): 10386*da0073e9SAndroid Build Coastguard Worker d = { 10387*da0073e9SAndroid Build Coastguard Worker 2: 4, 10388*da0073e9SAndroid Build Coastguard Worker 3: 5, 10389*da0073e9SAndroid Build Coastguard Worker } 10390*da0073e9SAndroid Build Coastguard Worker 10391*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 10392*da0073e9SAndroid Build Coastguard Worker 10393*da0073e9SAndroid Build Coastguard Worker 
def fn(x, d): 10394*da0073e9SAndroid Build Coastguard Worker for key in d: 10395*da0073e9SAndroid Build Coastguard Worker value = d[key] 10396*da0073e9SAndroid Build Coastguard Worker x = x * key + value 10397*da0073e9SAndroid Build Coastguard Worker return x 10398*da0073e9SAndroid Build Coastguard Worker 10399*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend=cnts) 10400*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.randn(4), d) 10401*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.randn(4), d) 10402*da0073e9SAndroid Build Coastguard Worker # No recompilation 10403*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 10404*da0073e9SAndroid Build Coastguard Worker 10405*da0073e9SAndroid Build Coastguard Worker # move 2 to the end 10406*da0073e9SAndroid Build Coastguard Worker d[2] = d.pop(2) 10407*da0073e9SAndroid Build Coastguard Worker 10408*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 10409*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, d) 10410*da0073e9SAndroid Build Coastguard Worker # Check recompilation 10411*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 10412*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, fn(x, d)) 10413*da0073e9SAndroid Build Coastguard Worker 10414*da0073e9SAndroid Build Coastguard Worker def test_contains_dunder_dict(self): 10415*da0073e9SAndroid Build Coastguard Worker class UserDefined: 10416*da0073e9SAndroid Build Coastguard Worker def __init__(self): 10417*da0073e9SAndroid Build Coastguard Worker self.a = 3 10418*da0073e9SAndroid Build Coastguard Worker self.b = 5 10419*da0073e9SAndroid Build Coastguard Worker 10420*da0073e9SAndroid Build Coastguard Worker def run(self, x): 10421*da0073e9SAndroid Build Coastguard Worker if "a" in self.__dict__: 10422*da0073e9SAndroid Build Coastguard Worker x = x * self.a 10423*da0073e9SAndroid Build Coastguard Worker if "b" in self.__dict__: 
10424*da0073e9SAndroid Build Coastguard Worker x = x * self.b 10425*da0073e9SAndroid Build Coastguard Worker self.c = 7 10426*da0073e9SAndroid Build Coastguard Worker if "c" in self.__dict__: 10427*da0073e9SAndroid Build Coastguard Worker x = x * self.c 10428*da0073e9SAndroid Build Coastguard Worker return x * self.__dict__.get("a") * self.__dict__.get("z", 2) 10429*da0073e9SAndroid Build Coastguard Worker 10430*da0073e9SAndroid Build Coastguard Worker obj = UserDefined() 10431*da0073e9SAndroid Build Coastguard Worker 10432*da0073e9SAndroid Build Coastguard Worker def fn(x): 10433*da0073e9SAndroid Build Coastguard Worker return obj.run(x) 10434*da0073e9SAndroid Build Coastguard Worker 10435*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 10436*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 10437*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 10438*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 10439*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 10440*da0073e9SAndroid Build Coastguard Worker 10441*da0073e9SAndroid Build Coastguard Worker def test_module_dunder_dict(self): 10442*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 10443*da0073e9SAndroid Build Coastguard Worker def __init__(self): 10444*da0073e9SAndroid Build Coastguard Worker super().__init__() 10445*da0073e9SAndroid Build Coastguard Worker self.foo = 1 10446*da0073e9SAndroid Build Coastguard Worker self.bar = 2 10447*da0073e9SAndroid Build Coastguard Worker self.baz = 3 10448*da0073e9SAndroid Build Coastguard Worker 10449*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 10450*da0073e9SAndroid Build Coastguard Worker if "foo" in self.__dict__: 10451*da0073e9SAndroid Build Coastguard Worker return x * self.bar 10452*da0073e9SAndroid Build Coastguard Worker return x * self.baz 10453*da0073e9SAndroid Build Coastguard Worker 10454*da0073e9SAndroid Build Coastguard 
Worker mod = MyModule() 10455*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10) 10456*da0073e9SAndroid Build Coastguard Worker opt_mod = torch.compile(mod, backend="eager", fullgraph=True) 10457*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod(x), opt_mod(x)) 10458*da0073e9SAndroid Build Coastguard Worker 10459*da0073e9SAndroid Build Coastguard Worker 10460*da0073e9SAndroid Build Coastguard Workerclass TestTracer(JitTestCase): 10461*da0073e9SAndroid Build Coastguard Worker def test_jit_save(self): 10462*da0073e9SAndroid Build Coastguard Worker def fn(): 10463*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 10464*da0073e9SAndroid Build Coastguard Worker def __init__(self): 10465*da0073e9SAndroid Build Coastguard Worker super().__init__() 10466*da0073e9SAndroid Build Coastguard Worker self.a = 3 10467*da0073e9SAndroid Build Coastguard Worker 10468*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 10469*da0073e9SAndroid Build Coastguard Worker def __getstate__(self): 10470*da0073e9SAndroid Build Coastguard Worker return (3, self.training) 10471*da0073e9SAndroid Build Coastguard Worker 10472*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 10473*da0073e9SAndroid Build Coastguard Worker def __setstate__(self, state): 10474*da0073e9SAndroid Build Coastguard Worker self.a = state[0] 10475*da0073e9SAndroid Build Coastguard Worker self.training = state[1] 10476*da0073e9SAndroid Build Coastguard Worker 10477*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 10478*da0073e9SAndroid Build Coastguard Worker return x + self.a 10479*da0073e9SAndroid Build Coastguard Worker 10480*da0073e9SAndroid Build Coastguard Worker f = Foo() 10481*da0073e9SAndroid Build Coastguard Worker 10482*da0073e9SAndroid Build Coastguard Worker return torch.jit.trace(f, (torch.rand(3, 4),)) 10483*da0073e9SAndroid Build Coastguard Worker 10484*da0073e9SAndroid Build Coastguard Worker fn() 10485*da0073e9SAndroid Build Coastguard 
Worker opt_fn = torch._dynamo.optimize("eager")(fn) 10486*da0073e9SAndroid Build Coastguard Worker opt_fn() 10487*da0073e9SAndroid Build Coastguard Worker 10488*da0073e9SAndroid Build Coastguard Worker 10489*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 10490*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 10491*da0073e9SAndroid Build Coastguard Worker 10492*da0073e9SAndroid Build Coastguard Worker run_tests() 10493