# Owner(s): ["module: dynamo"]
import math
import random
import unittest

import numpy as np

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.nn.functional as F
from torch._dynamo.comptime import comptime
from torch._dynamo.testing import CompileCounter, same
from torch.testing._internal.common_utils import skipIfWindows
from torch.testing._internal.logging_utils import logs_to_string


# The intention of this test file is that you should put test cases
# specifically for assume_static_by_default=False here, aka you want to YOLO
# make everything as dynamic as possible.  If you want to test the more normal
# situation where you assume static by default, put it in a regular test file
# and test_dynamic_shapes will cover both the YOLO and non-YOLO cases.
23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker@torch._dynamo.config.patch(assume_static_by_default=False) 26*da0073e9SAndroid Build Coastguard Workerclass UnspecTests(torch._dynamo.test_case.TestCase): 27*da0073e9SAndroid Build Coastguard Worker def test_numpy_correctness(self): 28*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 29*da0073e9SAndroid Build Coastguard Worker xy = [x + y, y, False] 30*da0073e9SAndroid Build Coastguard Worker np_x = x.numpy() 31*da0073e9SAndroid Build Coastguard Worker np_y = y.numpy() 32*da0073e9SAndroid Build Coastguard Worker return { 33*da0073e9SAndroid Build Coastguard Worker "x": x, 34*da0073e9SAndroid Build Coastguard Worker "z": z, 35*da0073e9SAndroid Build Coastguard Worker "a": np_y.sum(), 36*da0073e9SAndroid Build Coastguard Worker "b": xy, 37*da0073e9SAndroid Build Coastguard Worker "c": np_y[0][0] / 68, 38*da0073e9SAndroid Build Coastguard Worker "d": np_x.sum(), 39*da0073e9SAndroid Build Coastguard Worker "e": np_x + np_y, 40*da0073e9SAndroid Build Coastguard Worker }, x + np_y.sum() + z 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64) 43*da0073e9SAndroid Build Coastguard Worker y = torch.ones([2, 2], dtype=torch.int64) 44*da0073e9SAndroid Build Coastguard Worker z = np.int64(12) 45*da0073e9SAndroid Build Coastguard Worker res1 = fn(x, y, z) 46*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 47*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 48*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(x, y, z) 49*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker def test_no_recompilations(self): 52*da0073e9SAndroid Build Coastguard Worker # no 
recompilations if passing on different numpy int values 53*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 54*da0073e9SAndroid Build Coastguard Worker return {"a": x + 1, "b": y / 2} 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64) 57*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 58*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 59*da0073e9SAndroid Build Coastguard Worker for i in range(10): 60*da0073e9SAndroid Build Coastguard Worker opt_fn(x, np.int64(i)) 61*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 62*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure # array scalars decay to 0D arrays 65*da0073e9SAndroid Build Coastguard Worker def test_builtin_max_min(self): 66*da0073e9SAndroid Build Coastguard Worker # test unspecialized primitive max/min 67*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 68*da0073e9SAndroid Build Coastguard Worker return z + 1, max(x, y), min(x - 4, y) 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker x = np.int64(12) 71*da0073e9SAndroid Build Coastguard Worker y = 10 72*da0073e9SAndroid Build Coastguard Worker z = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64) 73*da0073e9SAndroid Build Coastguard Worker res1 = fn(x, y, z) 74*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 75*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 76*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(x, y, z) 77*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res1, res2, relax_numpy_equality=True)) 78*da0073e9SAndroid Build Coastguard Worker 
79*da0073e9SAndroid Build Coastguard Worker def test_feed_random_values_into_graph_only(self): 80*da0073e9SAndroid Build Coastguard Worker def fn(shape): 81*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(123) 82*da0073e9SAndroid Build Coastguard Worker x = torch.randn(shape, device="cpu") * random.randint(30, 100) 83*da0073e9SAndroid Build Coastguard Worker return x 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker shape = [2, 3] 86*da0073e9SAndroid Build Coastguard Worker random.seed(1) 87*da0073e9SAndroid Build Coastguard Worker res1 = fn(shape) 88*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 89*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 90*da0073e9SAndroid Build Coastguard Worker random.seed(1) 91*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(shape) 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res1, res2)) 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker def test_random_values_with_graph_break(self): 96*da0073e9SAndroid Build Coastguard Worker def fn(x): 97*da0073e9SAndroid Build Coastguard Worker r1 = random.random() 98*da0073e9SAndroid Build Coastguard Worker y = x + random.uniform(10, 20) 99*da0073e9SAndroid Build Coastguard Worker y.sum().item() 100*da0073e9SAndroid Build Coastguard Worker r2 = random.randint(2, 18) # no graph output in this frame 101*da0073e9SAndroid Build Coastguard Worker y.sum().item() 102*da0073e9SAndroid Build Coastguard Worker return y + r1, r2 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) 105*da0073e9SAndroid Build Coastguard Worker random.seed(1) 106*da0073e9SAndroid Build Coastguard Worker res1 = fn(x) 107*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 
108*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 109*da0073e9SAndroid Build Coastguard Worker random.seed(1) 110*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(x) 111*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res1, res2)) 112*da0073e9SAndroid Build Coastguard Worker 113*da0073e9SAndroid Build Coastguard Worker # Really annoying intersection of specialization and RandomValueSource 114*da0073e9SAndroid Build Coastguard Worker # If we get a RandomValueSource with a single element tensor, we should return a ConstantVariable like other 115*da0073e9SAndroid Build Coastguard Worker # unspects... but if we do, we break the bytecode assumptions and guards will not work as we will be referring 116*da0073e9SAndroid Build Coastguard Worker # to a name from a source that is not there. If we call .item() and take the wrapped_value out, where we do 117*da0073e9SAndroid Build Coastguard Worker # wrapped_value = wrapped_value.item() where we send unspec down to wrap_fx_proxy, this test passes and then 118*da0073e9SAndroid Build Coastguard Worker # some models fail on missing codegen.tx.output.random_values_var. If we let the tensor value go into wrap as 119*da0073e9SAndroid Build Coastguard Worker # it is, this test fails. 120*da0073e9SAndroid Build Coastguard Worker # The real solution here is to rewrite RandomValueSource and all the codegen it does from the ground up. 
121*da0073e9SAndroid Build Coastguard Worker def test_multiple_consecutive_random_calls_before_graph(self): 122*da0073e9SAndroid Build Coastguard Worker def fn(x): 123*da0073e9SAndroid Build Coastguard Worker dim1 = random.randrange(start=0, stop=5) 124*da0073e9SAndroid Build Coastguard Worker dim2 = random.randrange(start=0, stop=5) 125*da0073e9SAndroid Build Coastguard Worker dim3 = random.randrange(start=0, stop=5) 126*da0073e9SAndroid Build Coastguard Worker y = torch.rand(dim1, dim2, dim3) 127*da0073e9SAndroid Build Coastguard Worker return x + 2, y 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) 130*da0073e9SAndroid Build Coastguard Worker random.seed(1) 131*da0073e9SAndroid Build Coastguard Worker res1 = fn(x) 132*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 133*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 134*da0073e9SAndroid Build Coastguard Worker random.seed(1) 135*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(x) 136*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res1, res2)) 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker def test_compiled_random_calls_are_random(self): 139*da0073e9SAndroid Build Coastguard Worker # For compiled functions with random calls, 140*da0073e9SAndroid Build Coastguard Worker # it should return different values for every iteration. 
141*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/95425 142*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 143*da0073e9SAndroid Build Coastguard Worker def fn(x): 144*da0073e9SAndroid Build Coastguard Worker return (x + 1) * random.uniform(0, 1) 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker res = [] 147*da0073e9SAndroid Build Coastguard Worker for _ in range(5): 148*da0073e9SAndroid Build Coastguard Worker res.append(fn(torch.ones(2))) 149*da0073e9SAndroid Build Coastguard Worker for i in range(1, 5): 150*da0073e9SAndroid Build Coastguard Worker self.assertFalse(same(res[i - 1], res[i])) 151*da0073e9SAndroid Build Coastguard Worker 152*da0073e9SAndroid Build Coastguard Worker def test_random_call_with_while_loop(self): 153*da0073e9SAndroid Build Coastguard Worker def fn(x): 154*da0073e9SAndroid Build Coastguard Worker dim1 = random.randrange(start=0, stop=3) 155*da0073e9SAndroid Build Coastguard Worker dim2 = dim1 156*da0073e9SAndroid Build Coastguard Worker while dim1 == dim2: 157*da0073e9SAndroid Build Coastguard Worker dim2 = random.randrange(start=0, stop=3) 158*da0073e9SAndroid Build Coastguard Worker return x * 2 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 161*da0073e9SAndroid Build Coastguard Worker random.seed(1) 162*da0073e9SAndroid Build Coastguard Worker res1 = fn(x) 163*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 164*da0073e9SAndroid Build Coastguard Worker random.seed(1) 165*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(x) 166*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res1, res2)) 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker random.seed(10) 169*da0073e9SAndroid Build Coastguard Worker res1 = fn(x) 170*da0073e9SAndroid Build Coastguard Worker 
random.seed(10) 171*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(x) 172*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res1, res2)) 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker def test_random_object(self): 175*da0073e9SAndroid Build Coastguard Worker # test argument passing, mutation, reconstruction, state correctness 176*da0073e9SAndroid Build Coastguard Worker def fn(x, rand2): 177*da0073e9SAndroid Build Coastguard Worker r1 = random.randint(1, 9) 178*da0073e9SAndroid Build Coastguard Worker r2 = rand2.randint(1, 9) 179*da0073e9SAndroid Build Coastguard Worker rand3 = random.Random(42) 180*da0073e9SAndroid Build Coastguard Worker r3 = rand3.randint(1, 9) 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker y = x + r1 + r2 + r3 183*da0073e9SAndroid Build Coastguard Worker return y, rand2, rand3 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(3, 3) 186*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 187*da0073e9SAndroid Build Coastguard Worker random.seed(0) 188*da0073e9SAndroid Build Coastguard Worker y_1, rand2_1, rand3_1 = fn(inp, random.Random(12)) 189*da0073e9SAndroid Build Coastguard Worker state_1 = random.getstate() 190*da0073e9SAndroid Build Coastguard Worker random.seed(0) 191*da0073e9SAndroid Build Coastguard Worker y_2, rand2_2, rand3_2 = opt_fn(inp, random.Random(12)) 192*da0073e9SAndroid Build Coastguard Worker state_2 = random.getstate() 193*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y_1, y_2) 194*da0073e9SAndroid Build Coastguard Worker self.assertEqual(state_1, state_2) 195*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rand2_1.getstate(), rand2_2.getstate()) 196*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rand3_1.getstate(), rand3_2.getstate()) 197*da0073e9SAndroid Build Coastguard Worker 
198*da0073e9SAndroid Build Coastguard Worker def test_random_object_methods(self): 199*da0073e9SAndroid Build Coastguard Worker def fn(x, rand1, rand2, rand3): 200*da0073e9SAndroid Build Coastguard Worker rand1.seed(42) 201*da0073e9SAndroid Build Coastguard Worker rand4 = random.Random(9002) 202*da0073e9SAndroid Build Coastguard Worker rand2.setstate(rand4.getstate()) 203*da0073e9SAndroid Build Coastguard Worker r1 = rand1.random() 204*da0073e9SAndroid Build Coastguard Worker r2 = rand2.randint(1, 10) 205*da0073e9SAndroid Build Coastguard Worker r3 = rand3.randrange(10) 206*da0073e9SAndroid Build Coastguard Worker r4 = rand4.uniform(0, 1) 207*da0073e9SAndroid Build Coastguard Worker return x + r1 + r2 + r3 + r4 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(3, 3) 210*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 211*da0073e9SAndroid Build Coastguard Worker rand1_1 = random.Random(1) 212*da0073e9SAndroid Build Coastguard Worker rand2_1 = random.Random(2) 213*da0073e9SAndroid Build Coastguard Worker rand3_1 = random.Random(3) 214*da0073e9SAndroid Build Coastguard Worker rand1_2 = random.Random(1) 215*da0073e9SAndroid Build Coastguard Worker rand2_2 = random.Random(2) 216*da0073e9SAndroid Build Coastguard Worker rand3_2 = random.Random(3) 217*da0073e9SAndroid Build Coastguard Worker y1 = fn(inp, rand1_1, rand2_1, rand3_1) 218*da0073e9SAndroid Build Coastguard Worker y2 = opt_fn(inp, rand1_2, rand2_2, rand3_2) 219*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y1, y2) 220*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rand1_1.getstate(), rand1_2.getstate()) 221*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rand2_1.getstate(), rand2_2.getstate()) 222*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rand3_1.getstate(), rand3_2.getstate()) 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build 
Coastguard Worker def test_random_object_overriden_methods(self): 225*da0073e9SAndroid Build Coastguard Worker # these will result in graph breaks, but we shouldn't crash 226*da0073e9SAndroid Build Coastguard Worker def get_rng(): 227*da0073e9SAndroid Build Coastguard Worker rand1 = random.Random(1) 228*da0073e9SAndroid Build Coastguard Worker rand2 = random.Random(2) 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker orig_random = rand1.random 231*da0073e9SAndroid Build Coastguard Worker 232*da0073e9SAndroid Build Coastguard Worker def custom_random(): 233*da0073e9SAndroid Build Coastguard Worker return orig_random() 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker orig_getstate = rand2.getstate 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker def custom_getstate(): 238*da0073e9SAndroid Build Coastguard Worker return orig_getstate() 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker rand1.random = custom_random 241*da0073e9SAndroid Build Coastguard Worker rand2.getstate = custom_getstate 242*da0073e9SAndroid Build Coastguard Worker return rand1, rand2 243*da0073e9SAndroid Build Coastguard Worker 244*da0073e9SAndroid Build Coastguard Worker def fn(x, rand1, rand2): 245*da0073e9SAndroid Build Coastguard Worker r1 = rand1.random() 246*da0073e9SAndroid Build Coastguard Worker rand3 = random.Random() 247*da0073e9SAndroid Build Coastguard Worker rand3.setstate(rand2.getstate()) 248*da0073e9SAndroid Build Coastguard Worker r2 = rand3.random() 249*da0073e9SAndroid Build Coastguard Worker return x + r1 + r2 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(3, 3) 252*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager") 253*da0073e9SAndroid Build Coastguard Worker y1 = fn(inp, *get_rng()) 254*da0073e9SAndroid Build Coastguard 
Worker y2 = opt_fn(inp, *get_rng()) 255*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y1, y2) 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker def test_builtin_getitem(self): 258*da0073e9SAndroid Build Coastguard Worker # builtin getitem args[0] is python list and args[1] is unspec 259*da0073e9SAndroid Build Coastguard Worker def fn(x, idx): 260*da0073e9SAndroid Build Coastguard Worker return (torch.zeros(idx), x[idx], x[idx:]) 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker x = list(range(50)) 263*da0073e9SAndroid Build Coastguard Worker ref = fn(x, 48) # 48 is unspecialized 264*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 265*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 266*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, 48) 267*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 268*da0073e9SAndroid Build Coastguard Worker 269*da0073e9SAndroid Build Coastguard Worker def test_use_and_specialize(self): 270*da0073e9SAndroid Build Coastguard Worker cnt = CompileCounter() 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt, fullgraph=True, dynamic=True) 273*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 274*da0073e9SAndroid Build Coastguard Worker x = x + y 275*da0073e9SAndroid Build Coastguard Worker if y == 2: 276*da0073e9SAndroid Build Coastguard Worker return x - 1 277*da0073e9SAndroid Build Coastguard Worker else: 278*da0073e9SAndroid Build Coastguard Worker return x + 1 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(fn(torch.tensor([5]), 2), 6)) 281*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(fn(torch.tensor([6]), 2), 7)) 282*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(same(fn(torch.tensor([5]), 3), 9)) 283*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(fn(torch.tensor([4]), 3), 8)) 284*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 2) 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker def test_no_recompiles(self): 287*da0073e9SAndroid Build Coastguard Worker cnt = CompileCounter() 288*da0073e9SAndroid Build Coastguard Worker 289*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt, fullgraph=True, dynamic=True) 290*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 291*da0073e9SAndroid Build Coastguard Worker return x + y 292*da0073e9SAndroid Build Coastguard Worker 293*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(fn(torch.tensor([5]), 100), 105)) 294*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(fn(torch.tensor([4]), 200), 204)) 295*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(fn(torch.tensor([3]), 300), 303)) 296*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(fn(torch.tensor([2]), 400), 402)) 297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 298*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, 1) 299*da0073e9SAndroid Build Coastguard Worker 300*da0073e9SAndroid Build Coastguard Worker def test_no_recompiles_prod_backward(self): 301*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/120608 302*da0073e9SAndroid Build Coastguard Worker cnt = CompileCounter() 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt, fullgraph=True, dynamic=True) 305*da0073e9SAndroid Build Coastguard Worker def fn(t): 306*da0073e9SAndroid Build Coastguard Worker return torch.prod(t, 3, keepdim=True) 307*da0073e9SAndroid Build Coastguard Worker 308*da0073e9SAndroid Build Coastguard Worker input_shapes = [(8, 10, 3, 2), (8, 
3, 5, 2), (8, 4, 8, 2)] 309*da0073e9SAndroid Build Coastguard Worker for s in input_shapes: 310*da0073e9SAndroid Build Coastguard Worker t1 = torch.randn(s, requires_grad=True) 311*da0073e9SAndroid Build Coastguard Worker h_result = fn(t1) 312*da0073e9SAndroid Build Coastguard Worker grad = torch.ones_like(h_result) 313*da0073e9SAndroid Build Coastguard Worker h_result.backward(grad) 314*da0073e9SAndroid Build Coastguard Worker 315*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 316*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, 1) 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") 319*da0073e9SAndroid Build Coastguard Worker def test_builtin_functions_on_cuda(self): 320*da0073e9SAndroid Build Coastguard Worker def fn(x, scaler): 321*da0073e9SAndroid Build Coastguard Worker m = torch.nn.ReLU() 322*da0073e9SAndroid Build Coastguard Worker y = m(x) * scaler 323*da0073e9SAndroid Build Coastguard Worker return y 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker x = torch.randn([3, 6], device="cuda") 326*da0073e9SAndroid Build Coastguard Worker scaler = 0.23 # 0.23 is unspecialized 327*da0073e9SAndroid Build Coastguard Worker ref = fn(x, scaler) 328*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 329*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 330*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, scaler) 331*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 332*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref.device, res.device) 333*da0073e9SAndroid Build Coastguard Worker 334*da0073e9SAndroid Build Coastguard Worker def test_unspec_float_precision(self): 335*da0073e9SAndroid Build Coastguard Worker def fn(image, scale_factor): 336*da0073e9SAndroid Build 
Coastguard Worker image = torch.nn.functional.interpolate( 337*da0073e9SAndroid Build Coastguard Worker image[None], 338*da0073e9SAndroid Build Coastguard Worker size=None, 339*da0073e9SAndroid Build Coastguard Worker scale_factor=scale_factor, 340*da0073e9SAndroid Build Coastguard Worker mode="bilinear", 341*da0073e9SAndroid Build Coastguard Worker recompute_scale_factor=True, 342*da0073e9SAndroid Build Coastguard Worker align_corners=False, 343*da0073e9SAndroid Build Coastguard Worker )[0] 344*da0073e9SAndroid Build Coastguard Worker 345*da0073e9SAndroid Build Coastguard Worker return image.shape 346*da0073e9SAndroid Build Coastguard Worker 347*da0073e9SAndroid Build Coastguard Worker x = torch.rand([3, 427, 640]) 348*da0073e9SAndroid Build Coastguard Worker scale_factor = 1.873536229133606 349*da0073e9SAndroid Build Coastguard Worker ref = fn(x, scale_factor) 350*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 351*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(fn) 352*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, scale_factor) 353*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure # fails as long as numpy scalars are 0D arrays 356*da0073e9SAndroid Build Coastguard Worker def test_specializing_numpy_float_in_control_flow(self): 357*da0073e9SAndroid Build Coastguard Worker # np.float64 is unspecialized by default, 358*da0073e9SAndroid Build Coastguard Worker # but it should be specialized when used in control flow. 
359*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 360*da0073e9SAndroid Build Coastguard Worker if y > 1.0: 361*da0073e9SAndroid Build Coastguard Worker return x + 1 362*da0073e9SAndroid Build Coastguard Worker else: 363*da0073e9SAndroid Build Coastguard Worker return x - 1 364*da0073e9SAndroid Build Coastguard Worker 365*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4) 366*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 367*da0073e9SAndroid Build Coastguard Worker for t in [np.float16, np.float32, np.float64]: 368*da0073e9SAndroid Build Coastguard Worker y = t(1.23) 369*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 370*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 371*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker def test_mark_static_inside(self): 374*da0073e9SAndroid Build Coastguard Worker def fn(x): 375*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_static(x, 0) 376*da0073e9SAndroid Build Coastguard Worker comptime.assert_static(x.size(0)) 377*da0073e9SAndroid Build Coastguard Worker return x + 1 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, dynamic=True, fullgraph=True) 380*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.randn(12, 23)) 381*da0073e9SAndroid Build Coastguard Worker 382*da0073e9SAndroid Build Coastguard Worker def test_shape_graph_break(self): 383*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.comptime import comptime 384*da0073e9SAndroid Build Coastguard Worker 385*da0073e9SAndroid Build Coastguard Worker def fn(x): 386*da0073e9SAndroid Build Coastguard Worker x_shape = x.size() 387*da0073e9SAndroid Build Coastguard Worker comptime.graph_break() 388*da0073e9SAndroid Build Coastguard Worker return x + torch.randn(x_shape) 
        x = torch.randn(20)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn(x)

    # Runs the compiled function twice: once with a plain length-20 input and
    # once with a length-30 input whose dim 0 is marked dynamic. The assert
    # inside fn requires x.size(0) to satisfy isinstance(..., int) both times.
    def test_isinstance_symint(self):
        def fn(x):
            assert isinstance(x.size(0), int)
            return x * 2

        x = torch.randn(20)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn(x)
        y = torch.randn(30)
        torch._dynamo.mark_dynamic(y, 0)
        opt_fn(y)

    # Marks dim 0 of a size-1 tensor as dynamic and checks that compiling and
    # calling the function does not raise. No output value is checked.
    def test_mark_01_dynamic(self):
        def fn(x):
            return x * 2

        x = torch.randn(1)
        torch._dynamo.mark_dynamic(x, 0)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        # This will fail to compile a generic kernel, but we should not
        # complain about it (mark dynamic will try its best but 0/1
        # specialization is allowed)
        opt_fn(x)

    # conv1d whose padding is computed from the input's last dimension, called
    # with two different input lengths. No output value is checked.
    def test_conv1d_symint_padding(self):
        kernel = torch.randn(1, 1, 4)

        def func(x):
            # padding depends on the parity of the input length (x.shape[-1] % 2)
            padding = math.ceil((kernel.shape[-1] + x.shape[-1] % 2) / 2) - 1
            out = F.conv1d(x, kernel, padding=padding, stride=2)
            return out

        opt_func = torch.compile(func)

        x = torch.randn(1, 1, 175)
        opt_func(x)  # first call, input length 175
        x = torch.randn(1, 1, 249)
        # Second call with a different input length. The original note on this
        # line read "crashes"; nothing in this test expects an exception, so
        # the test passes only if this call completes.
        opt_func(x)

    # Overrides the class-level patch (assume_static_by_default=True here) and
    # checks that the output produced after a graph break reports dim 0 in
    # _dynamo_weak_dynamic_indices when the input's dim 0 was marked dynamic.
    @torch._dynamo.config.patch("assume_static_by_default", True)
    def test_propagate_dynamic_dim(self):
        x = torch.randn(20)
        torch._dynamo.mark_dynamic(x, 0)

        @torch.compile()
        def fn(x):
            y = x * 2
            # force a graph break between the two multiplications
            comptime.graph_break()
            z = y * 2
            return z

        z = fn(x)
        self.assertEqual(z._dynamo_weak_dynamic_indices, {0})

    # Smoke test: `>>` on a uint8 tensor followed by a cast to long must
    # compile with fullgraph=True and dynamic=True. No output value is checked.
    def test_rshift_dynamic(self):
        def shift_right(tensor: torch.Tensor) -> torch.Tensor:
            return (tensor >> 2).to(torch.long)

        opt_fn = torch.compile(shift_right, fullgraph=True, dynamic=True)
        sample_input = torch.tensor([4, 4, 16, 32], dtype=torch.uint8)
        opt_fn(sample_input)

    # Builds tensors from v.item() in four container shapes (list, nested
    # list, bare scalar, tuple) and compares eager against aot_eager results.
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_symfloat_to_tensor(self):
        def f1(v):
            return torch.tensor([v.item()])

        def f2(v):
            return torch.tensor([[v.item()], [2.0]])

        def f3(v):
            return torch.tensor(v.item())

        def f4(v):
            return torch.tensor((v.item(),))

        # torch.compile called without a function; the result is applied to
        # each of f1..f4 below
        optimize = torch.compile(backend="aot_eager", fullgraph=True)

        r = torch.randn(1)

        self.assertEqual(f1(r), optimize(f1)(r))
        self.assertEqual(f2(r), optimize(f2)(r))
        self.assertEqual(f3(r), optimize(f3)(r))
        self.assertEqual(f4(r), optimize(f4)(r))

    # torch.tensor(...) over lists of numpy arrays and lists of tensors,
    # comparing eager against aot_eager results.
    @skipIfWindows(
        msg="AssertionError: The values for attribute 'dtype' do not match: torch.int32 != torch.int64."
    )
    def test_to_tensor(self):
        def f1():
            a = np.random.uniform(low=-1, high=1, size=(20, 1))
            return torch.tensor([a, a, a, a], dtype=torch.float64, device="cpu")

        def f2():
            a = torch.tensor([[[123]]])
            return torch.tensor([a, a])

        def f3():
            a = torch.tensor(123)
            return torch.tensor([a, a])

        def f4():
            a = torch.tensor(123)
            b = torch.tensor([[[456]]])
            return torch.tensor([a, b])

        def f5():
            a = np.array([1, 2])
            return torch.tensor([a, a])

        optimize = torch.compile(backend="aot_eager", fullgraph=True)

        # f1 draws new random values on every call, so only shapes are compared
        self.assertEqual(f1().shape, optimize(f1)().shape)
        self.assertEqual(f2(), optimize(f2)())
        self.assertEqual(f3(), optimize(f3)())
        self.assertEqual(f4(), optimize(f4)())
        self.assertEqual(f5(), optimize(f5)())

    # Smoke test: int() applied to a comparison on x.size(0) must compile with
    # fullgraph=True. No output value is checked.
    def test_sym_int_conversion(self):
        def f(x):
            y = x.size(0)
            return x * int(y == 0)

        opt_fn = torch.compile(f, backend="eager", fullgraph=True)
        x = torch.randn(2, 3)
        opt_fn(x)

    # torch.sum with a tuple of dims (including a negative one) passed as an
    # argument; compiled and eager results must match.
    def test_sum_dimlist_spec(self):
        def fn(inputs, dim):
            return torch.sum(inputs, dim)

        inputs = torch.randn(128, 5, 24, 24)
        dim = (-1, 1, 0, 2)
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(inputs, dim), fn(inputs, dim))

    # Python max() over x.item() and a constant, used as a size for torch.ones.
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_max(self):
        def fn(x):
            return torch.ones(max(x.item(), 1024))

        # 1000 is below the 1024 constant and 2000 is above it, so both
        # outcomes of max() are exercised
        x = torch.tensor([1000])
        y = torch.tensor([2000])
        compl_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), compl_fn(x))
        self.assertEqual(fn(y), compl_fn(y))

    # https://github.com/pytorch/pytorch/issues/104812
    def test_argmin_coerces_symint_to_intlist_spec(self):
        def fn(x, dim):
            # the python arg parser coerces dim into a vector<int>
            return torch.amin(x, dim=dim, keepdim=True)

        x = torch.randn(4, 4, 4)
        dim = 2
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(x, dim), fn(x, dim))

    # In-place exponential_ with keyword arguments unpacked from a dict.
    # NOTE(review): exponential_ is random and mutates `inputs` in place; both
    # calls below return the same tensor object, so the equality check looks
    # like it compares that tensor with itself -- confirm this is intended.
    def test_exponential(self):
        def fn(inputs, op_inputs_dict):
            res = inputs.exponential_(**op_inputs_dict)
            return res

        inputs = torch.randn(2, 3, 4)
        op_inputs_dict = {"lambd": 10, "generator": None}
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(inputs, op_inputs_dict), fn(inputs, op_inputs_dict))

    # Runs the same three calls twice: once with default settings (expects a
    # single compiled frame) and once with
    # symbol_guard_limit_before_specialize patched to 3 (expects three frames).
    def test_symbol_guard_limit_before_specialize(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnts, dynamic=True)
        def fn(x):
            # four separate checks on the same size expression
            torch._check(x.size(0) != 3)
            torch._check(x.size(0) != 4)
            torch._check(x.size(0) != 5)
            torch._check(x.size(0) != 6)
            return x + 2

        # Control test
        fn(torch.randn(12))
        fn(torch.randn(13))
        fn(torch.randn(14))

        self.assertExpectedInline(cnts.frame_count, """1""")
        # reset the counter and Dynamo state before the patched run
        cnts.frame_count = 0

        torch._dynamo.reset()

        with torch.fx.experimental._config.patch(
            symbol_guard_limit_before_specialize=3
        ):
            fn(torch.randn(12))
            fn(torch.randn(13))
            fn(torch.randn(14))

        self.assertExpectedInline(cnts.frame_count, """3""")

    # A default argument value (i=8) of an inner callee must pass
    # comptime.assert_static even when compiling with dynamic=True.
    def test_defaults(self):
        def g(x, i=8):
            comptime.assert_static(i)
            return x * i

        def fn(x):
            return g(x)

        inputs = torch.randn(2, 3, 4)
        compl_fn = torch.compile(fn, dynamic=True, backend="eager")
        self.assertEqual(compl_fn(inputs), fn(inputs))

    # Float argument with specialize_float=False: checks results and the
    # number of compiled frames as the float value changes.
    @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=True)
    def test_unspec_float_input(self):
        cnts = torch._dynamo.testing.CompileCounter()

        def f(x, y):
            if y == 5.0:
                return x + 2
            else:
                return x + y

        cf = torch.compile(backend=cnts, fullgraph=True)(f)

        x = torch.randn(3)
        self.assertEqual(f(x, 3.0), cf(x, 3.0))
        self.assertEqual(f(x, 4.0), cf(x, 4.0))
        self.assertExpectedInline(cnts.frame_count, """1""")  # no recompile
        self.assertEqual(f(x, 5.0), cf(x, 5.0))
        self.assertExpectedInline(cnts.frame_count, """2""")  # guard worked
        self.assertEqual(f(x, math.nan), cf(x, math.nan))
        self.assertExpectedInline(cnts.frame_count, """3""")  # nan always recompiles

    # Float argument that is also part of the return value (y * 2); compiled
    # and eager results must match for three different floats. `cnts` is used
    # as the backend, but its counters are not asserted in this test.
    @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=True)
    def test_unspec_float_output(self):
        cnts = torch._dynamo.testing.CompileCounter()

        def f(x, y):
            return x + 1, y * 2

        cf = torch.compile(backend=cnts, fullgraph=True)(f)
        x = torch.randn(3)

        self.assertEqual(f(x, 3.0), cf(x, 3.0))
        self.assertEqual(f(x, 4.0), cf(x, 4.0))
        self.assertEqual(f(x, 5.0), cf(x, 5.0))

    # Branching on a value obtained from .item(); the single call below is
    # expected to produce two compiled frames and four ops in total.
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_data_dependent_evaluate_expr_graph_break(self):
        cnts = torch._dynamo.testing.CompileCounter()

        # To ensure that the continuation frame is compiled,
        # have to write the test function in this funny way.
        # See https://github.com/pytorch/pytorch/issues/111918
        def test(y):
            if y > 2:
                return True
            else:
                return False

        @torch._dynamo.optimize(cnts)
        def fn(x):
            x = x + 1
            y = x.item()
            if test(y):
                return x * 2
            else:
                return x * 3

        x = torch.tensor([3.0])
        fn(x)

        self.assertExpectedInline(cnts.frame_count, """2""")
        self.assertExpectedInline(cnts.op_count, """4""")

    # Captures the "graph_code" log of torch._dynamo.output_graph and expects
    # the logged graph for f to be an empty forward() that returns ().
    def test_prune_torch_check(self):
        log_stream, ctx = logs_to_string("torch._dynamo.output_graph", "graph_code")

        @torch.compile(fullgraph=True, dynamic=True, backend="eager")
        def f(x, y):
            # f contains only checks and no return statement
            torch._check(y + 5 == 85)
            torch._check(x.size(0) == 80)

        with ctx():
            f(torch.randn(80, 100), 80)

        # drop the first three lines of the captured log before comparing
        out = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()
        self.assertExpectedInline(
            out,
            """\
def forward(self):
    return ()""",
        )

    # torch.split with split sizes taken from a tensor via .tolist(), on an
    # input that requires grad, under the aot_eager backend. The result is
    # only printed; nothing is asserted.
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_split_aot_autograd(self):
        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, i):
            y, z = i.tolist()
            return torch.split(x, [y, z])

        print(f(torch.randn(10, requires_grad=True), torch.tensor([7, 3])))

    # Builds a bool tensor from a comparison on a size-derived numel.
    # `cnts` is used as the backend, but its counters are not asserted.
    def test_bool_tensor_ctor(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, dynamic=True, fullgraph=True)
        def f(x):
            y = torch.empty((x.size(0) // 13) * 13)
            return torch.tensor(y.numel() == 0)

        # (8 // 13) * 13 == 0, so numel is 0 and the result is True
        self.assertTrue(f(torch.empty(8)).item())
        # (13 // 13) * 13 == 13, so numel is non-zero and the result is False
        self.assertFalse(f(torch.empty(13)).item())

    # Marks dim 0 of x1 as unbacked, then calls the compiled model with x1
    # (dim 0 == 3) and with x2 (dim 0 == 1, not marked), comparing each result
    # with the eager model.
    # NOTE(review): error_on_recompile=True presumably turns a recompilation
    # on the second call into an error -- confirm.
    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked(self):
        class TestModel(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            # `val` is accepted but not used in the computation
            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
                return x * 2

        main_model = TestModel()
        opt_model = torch.compile(main_model, mode="max-autotune", dynamic=True)

        x1 = torch.rand(3, 5, 4, 8)
        x2 = torch.rand(1, 5, 4, 8)

        torch._dynamo.decorators.mark_unbacked(x1, 0)

        o1_ref = main_model(x1, 2)
        o1 = opt_model(x1, 2)
        self.assertEqual(o1_ref, o1)

        o1_2_ref = main_model(x2, 2)
        o1_2 = opt_model(x2, 2)
        self.assertEqual(o1_2_ref, o1_2)

    # The input has size 1 on dim 0 and that dim is marked unbacked. The
    # expected result is x + 3, i.e. the branch taken when
    # guard_size_oblivious(x.size(0) != 1) is truthy.
    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked_hint_consistency(self):
        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

        x = torch.randn(1)
        torch._dynamo.decorators.mark_unbacked(x, 0)

        @torch.compile()
        def f(x):
            if guard_size_oblivious(x.size(0) != 1):
                return x + 3
            else:
                return x + 4

        self.assertEqual(f(x), x + 3)

    # Same sequence as test_mark_unbacked, with both inputs converted to
    # channels_last memory format before use.
    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked_channels_last(self):
        class TestModel(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            # `val` is accepted but not used in the computation
            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
                return x * 2

        main_model = TestModel()
        opt_model = torch.compile(main_model, mode="max-autotune", dynamic=True)

        x1 = torch.rand(3, 5, 4, 8).to(memory_format=torch.channels_last)
        x2 = torch.rand(1, 5, 4, 8).to(memory_format=torch.channels_last)

        torch._dynamo.decorators.mark_unbacked(x1, 0)

        o1_ref = main_model(x1, 2)
        o1 = opt_model(x1, 2)
        self.assertEqual(o1_ref, o1)

        o1_2_ref = main_model(x2, 2)
        o1_2 = opt_model(x2, 2)
        self.assertEqual(o1_2_ref, o1_2)


# Run this file's tests through the Dynamo test runner when executed directly.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()