# Owner(s): ["module: dynamo"]
from unittest.mock import patch

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import unsupported
from torch._dynamo.utils import ifdynstaticdefault


# Module-level nn.Module; test_control_flow3 reaches it through a global
# lookup (`m = globalmod`) from inside the compiled function.
globalmod = torch.nn.ReLU()


# Wrapper that computes `c = a + b` and then calls `unsupported(a, c)`, so the
# `unsupported` call happens one Python frame below the function under test
# (the test_indirect_* / test_resume2-4 / test_start2 tests rely on this).
def indirectly_unsupported(a, b):
    c = a + b
    return unsupported(a, c)


class SubGraphTests(torch._dynamo.test_case.TestCase):
    """Pin how many frames/ops dynamo captures for small functions.

    Most tests run a small function through ``torch._dynamo.optimize`` with a
    ``CompileCounter`` backend, check the result against eager execution, and
    assert the counter's ``frame_count`` and ``op_count``.
    """

    def _common(self, fn, frame_count, op_count):
        """Compile ``fn`` with a fresh CompileCounter and compare against eager.

        Resets dynamo state, then runs ``fn`` eagerly and compiled on
        ``(v1, v2)`` and ``(v2, v1)``, where ``v1 = ones(10)`` and
        ``v2 = ones(10) * -2.0``. Swapping the argument order flips
        comparisons such as ``a.sum() > b.sum()``, so both sides of a
        data-dependent branch get exercised.

        Args:
            fn: two-argument function under test.
            frame_count: expected ``cnt.frame_count`` after both compiled calls.
            op_count: expected ``cnt.op_count`` after both compiled calls.
        """
        torch._dynamo.reset()
        v1 = torch.ones(10)
        v2 = torch.ones(10) * -2.0
        correct1 = fn(v1, v2)
        correct2 = fn(v2, v1)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        r1 = opt_fn(v1, v2)
        r2 = opt_fn(v2, v1)
        self.assertTrue(torch._dynamo.testing.same(r1, correct1))
        self.assertTrue(torch._dynamo.testing.same(r2, correct2))
        self.assertEqual(
            cnt.frame_count,
            frame_count,
            f"actual {cnt.frame_count} != expected {frame_count}",
        )
        self.assertEqual(cnt.op_count, op_count)

    # Branch on a tensor comparison, returning one of two already-computed
    # tensors.
    def test_control_flow1(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            if c1.sum() > c2.sum():
                return c1
            else:
                return c2

        self._common(fn, 1, 5)

    # Branch on a tensor comparison, returning Python ints.
    def test_control_flow2(self):
        def fn(a, b):
            if a.sum() > b.sum():
                return 1
            else:
                return 2

        self._common(fn, 1, 3)

    # Like test_control_flow1, but each branch applies the module-level
    # `globalmod`; the expected frame count here is 3 instead of 1.
    def test_control_flow3(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            m = globalmod
            if c1.sum() > c2.sum():
                return m(c1)
            else:
                return m(c2)

        self._common(fn, 3, 7)

    # `and` of two tensor comparisons stored in a local before the `if`.
    def test_control_flow4(self):
        def fn(a, b):
            tmp1 = a.sum() > b.sum() and a.sum() > 0
            if tmp1:
                return 1
            else:
                return 2

        self._common(fn, 3, 5)

    # `and`/`or` chains of tensor comparisons; the intermediate results are
    # returned alongside the branch constant.
    def test_control_flow5(self):
        def fn(a, b):
            tmp1 = a.sum() > b.sum() and a.sum() > 0
            tmp2 = a.sum() < b.sum() or b.sum() > 0
            if tmp1 and tmp2:
                return 1, tmp1, tmp2
            else:
                return 2, tmp1, tmp2

        self._common(fn, 6, 13)

    # `unsupported` called directly as the final (returned) expression.
    def test_capi_call1(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return unsupported(c1, c2)

        self._common(fn, 1, 2)

    # `unsupported` nested inside arithmetic, so tensor work remains after
    # the call returns.
    def test_capi_call2(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return a - (b - unsupported(c1, c2))

        self._common(fn, 2, 4)

    # Same as test_capi_call1, but `unsupported` is reached through the fully
    # qualified attribute path instead of the imported name.
    def test_capi_call3(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return torch._dynamo.testing.unsupported(c1, c2)

        self._common(fn, 1, 2)

    # `unsupported` reached through the module-level `indirectly_unsupported`
    # wrapper rather than called directly.
    def test_indirect_unsupported1(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return indirectly_unsupported(c1, c2)

        self._common(fn, 2, 3)

    # Python-int locals and a pending division are live across the indirect
    # call.
    def test_indirect_unsupported2(self):
        def fn(a, b):
            local_const1 = 7
            local_const2 = 22
            c1 = a - b
            c2 = b - a
            return local_const1 / (local_const2 - indirectly_unsupported(c1, c2))

        self._common(fn, 3, 5)

    # Arguments forwarded through `*args` unpacking of a list.
    def test_indirect_unsupported3(self):
        def fn(a, b):
            args = [a - b, b - a]
            return indirectly_unsupported(*args)

        self._common(fn, 2, 3)

    # Python evaluates the left operands first, so `t1` and `t2` are already
    # on the evaluation stack when `unsupported` runs.
    def test_stack_state1(self):
        def fn(a, b):
            t1 = 1.23 * a
            t2 = 4.56 * a
            c1 = a - b
            c2 = b - a
            return t1 / (t2 - unsupported(c1, c2))

        self._common(fn, 2, 6)

    # Same stack layout as test_stack_state1, but through the indirect
    # wrapper.
    def test_stack_state2(self):
        def fn(a, b):
            t1 = 1.23 * a
            t2 = 4.56 * a
            c1 = a - b
            c2 = b - a
            return t1 / (t2 - indirectly_unsupported(c1, c2))

        self._common(fn, 3, 7)

# --- SubGraphTests (continued): multigraph, extended-args, test_resume* and test_start1/2 ---
    # Data-dependent early return after some tensor math.
    def test_multigraph(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            if x.sum() < 0:
                return x * -1.0
            return x

        self._common(fn, 2, 5)

    # `too_many_adds` is a 512-operand sum used in both arms of a conditional
    # expression inside a lambda. Per the test name, this presumably makes the
    # bytecode large enough to need EXTENDED_ARG. `eval` only ever sees this
    # locally built string.
    def test_extended_args(self):
        too_many_adds = "+".join(["a", "b"] * 256)
        source = (
            f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)"
        )
        self._common(eval(source), 3, 1026)

    # Direct `unsupported` call in the middle of straight-line code, with
    # further tensor ops after it.
    def test_resume1(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = unsupported(x, a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 2, 6)

    # Like test_resume1, but via `indirectly_unsupported` with positional
    # arguments.
    def test_resume2(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = indirectly_unsupported(x, a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 3, 7)

    # Like test_resume2, but with one positional and one keyword argument.
    def test_resume3(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = indirectly_unsupported(x, b=a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 3, 7)

    # Like test_resume2, but with keyword-only call arguments.
    def test_resume4(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = indirectly_unsupported(a=x, b=a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 3, 7)

    # `print` of a tensor in the middle of the function; expectations match
    # test_resume1.
    def test_resume5(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            print(x)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 2, 6)

    # `print` as the very first statement of the function.
    def test_start1(self):
        def fn(a, b):
            print(a)
            x = a + b
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 1, 3)

    # Indirect `unsupported` call as the very first statement.
    def test_start2(self):
        def fn(a, b):
            x = indirectly_unsupported(a, b)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 2, 4)

# --- SubGraphTests (continued): test_start3/4, state restoration, dynamic shapes, grad mode, iterators ---
    # Direct `unsupported` call as the very first statement.
    def test_start3(self):
        def fn(a, b):
            x = unsupported(a, b)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 1, 3)

    # Branch on a one-element int32 tensor argument `check`, exercised with
    # both ones (`t`) and zeros (`f`); compiled results are compared with eager.
    def test_start4(self):
        def fn(a, b, check):
            if check:
                return a + b + 10
            else:
                return a + b - 10

        v1 = torch.randn(10)
        v2 = torch.randn(10)
        f = torch.zeros(1, dtype=torch.int32)
        t = torch.ones(1, dtype=torch.int32)
        correct1 = fn(v1, v2, t)
        correct2 = fn(v1, v2, f)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        r1 = opt_fn(v1, v2, t)
        r2 = opt_fn(v1, v2, f)
        self.assertTrue(torch._dynamo.testing.same(r1, correct1))
        self.assertTrue(torch._dynamo.testing.same(r2, correct2))
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 4)

    # The compiled function closes over tensors `c1`/`c2` and uses them on
    # both sides of the `unsupported` call.
    def test_resume_freevars(self):
        c1 = torch.randn(10)
        c2 = torch.randn(10)

        def fn(a, b):
            x = a + b + (c1 - c2)
            x = unsupported(x, x)
            return x + (c1 - c2)

        self._common(fn, 2, 5)

    # `len_` is a local alias of the builtin `len` that is still needed after
    # the `unsupported` call.
    def test_restore_state(self):
        def fn(a, b):
            len_ = len
            x = a + b
            x = torch.add(unsupported(x, x), 1)
            return a * x + len_(b)

        self._common(fn, 2, 4)

    # A `range` object created before the `unsupported` call is iterated
    # after it.
    def test_restore_range(self):
        def fn(a, b):
            x = a + b
            rng = range(3, 8, 2)
            x = unsupported(x, x)
            for i in rng:
                x = x + i
            return x

        # We don't specialize on range with dynamic shapes, which
        # means we fail to unroll the loop.
        # TODO: Consider forcing specialization when we iterate over
        # the loop
        # NOTE(review): ifdynstaticdefault(a, b) presumably picks `a` when
        # dynamo assumes static shapes by default and `b` otherwise — confirm
        # in torch._dynamo.utils.
        self._common(fn, ifdynstaticdefault(2, 1), ifdynstaticdefault(4, 1))

    # A range iterator created before the `unsupported` call is advanced
    # after it, and its remainder is returned.
    def test_restore_range_iter(self):
        def fn(a, b):
            x = a + b
            rng = iter(range(3, 8, 2))
            x = unsupported(x, x)
            x += next(rng)
            return x, list(rng)

        self._common(fn, 2, 2)

    # A list of tensors built before the `unsupported` call is popped from
    # after it.
    def test_pop_after_resume(self):
        def fn(a, b):
            tmp = [a + 1, b + 2, a + b]
            x = a
            x = unsupported(x, x)
            for i in range(3):
                x += tmp.pop(-1)
            return x

        self._common(fn, 2, 6)

    # With assume_static_by_default disabled, indexing by `b.size(0) - 1`
    # across nine different input sizes (3..11) must reuse a single frame.
    @patch("torch._dynamo.config.assume_static_by_default", False)
    def test_dynamic_getitem(self):
        def fn(a, b):
            return a[b.size(0) - 1]

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        for i in range(3, 12):
            opt_fn(torch.randn(i), torch.randn(i))
        # just one graph
        self.assertEqual(cnt.frame_count, 1)

    # `dynamic=True` passed to optimize: ten input sizes (2..11) must compile
    # a single frame.
    def test_dynamic_kwarg(self):
        def fn(a, b):
            return a - b * 10

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        start = 2
        end = 12
        # NOTE(review): `steps` is computed but never used.
        steps = end - start
        for i in range(start, end):
            opt_fn(torch.randn(i), torch.randn(i))

        self.assertEqual(cnt_dynamic.frame_count, 1)

    # Branch on whether the two inputs have equal size(0), with dynamic=True:
    # an equal-size call followed by an unequal-size call gives two frames.
    def test_dynamic_duck_size(self):
        def fn(a, b):
            if a.size(0) == b.size(0):
                return a + b
            else:
                return a.sum() + b.sum()

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        x = torch.randn(2)
        y = torch.randn(3)
        self.assertEqual(opt_fn(x, x), fn(x, x))
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    # Calls (x, y) and (x, x) in both orders, without `dynamic=True`, and
    # checks the frame count each time.
    def test_dynamic_order_dependence(self):
        def fn(a, b):
            return a.sum() + b.sum()

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
        x = torch.randn(2)
        y = torch.randn(3)
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(opt_fn(x, x), fn(x, x))
        # NB: This COULD validly be 2, but we don't test disjointness in the
        # guards for when x and y didn't duck size together, so we end up
        # with a generic graph that also works when x and y happen to duck
        # size together.
        # NOTE(review): the assertion below expects 2, which reads at odds
        # with the "COULD validly be 2, but ..." wording above; the comment
        # may predate a change in dynamic-shape defaults — confirm.
        self.assertEqual(cnt_dynamic.frame_count, 2)

        # Second phase: reverse the call order. Only frame_count is reset by
        # hand here; op_count is not checked.
        torch._dynamo.reset()
        cnt_dynamic.frame_count = 0
        self.assertEqual(opt_fn(x, x), fn(x, x))  # this overspecializes!
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    # Branch on `size(0) != 0` with dynamic=True: a size-0 input after a
    # size-2 input must produce a second frame.
    def test_dynamic_zero_inference(self):
        def fn(a):
            if a.size(0) != 0:
                return a * 2
            else:
                return a + 1

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        x = torch.randn(0)
        y = torch.randn(2)
        self.assertEqual(opt_fn(y), fn(y))
        self.assertEqual(opt_fn(x), fn(x))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    # With capture_scalar_outputs enabled, `.item()` must not split the
    # function: one frame (compare test_graph_break_on_item).
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_no_graph_break_on_item(self):
        def fn(a, b):
            x = a + b - 1.5
            x = x.sum()
            x.item()
            x = x / (a + b)
            return x

        self._common(fn, 1, 5)  # item gets DCE'd

    # Same function with capture_scalar_outputs disabled: two frames.
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
    def test_graph_break_on_item(self):
        def fn(a, b):
            x = a + b - 1.5
            x = x.sum()
            x.item()
            x = x / (a + b)
            return x

        self._common(fn, 2, 5)

    # Three independent `if`s on one-element tensors; all 8 true/false
    # combinations are run through the same compiled function.
    def test_resume_paths_join(self):
        def fn(x, c1, c2, c3):
            x = x + 1
            if c1:
                x = x + 2
            x = x + 3
            if c2:
                x = x + 4
            x = x + 5
            if c3:
                x = x + 6
            return x + 7

        v1 = torch.randn(10)
        t = torch.Tensor([True])
        f = torch.Tensor([False])
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        for a in (t, f):
            for b in (t, f):
                for c in (t, f):
                    opt_fn(v1, a, b, c)

        # checking here we don't create 2^n graphs
        self.assertEqual(cnt.frame_count, 7)
        self.assertEqual(cnt.op_count, 10)

    # Graph break inside a torch.no_grad() block. The test is then repeated
    # under an outer no_grad, where the expected op count drops from 9 to 5
    # (presumably because grad-mode toggles are no longer recorded when grad
    # is already off — TODO confirm).
    def test_resume_with_no_grad1(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                x = x + 1
                x.sum().tolist()  # graph break
                x = x + 2
            x = x + 3
            return x

        self._common(fn, 2, 9)
        torch._dynamo.reset()
        with torch.no_grad():
            self._common(fn, 2, 5)

    # Two graph breaks inside the same no_grad block.
    def test_resume_with_no_grad2(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                x = x + 1
                x.sum().tolist()  # graph break
                x = x + 2
                x.sum().tolist()  # graph break
                x = x + 3
            x = x + 4
            return x

        self._common(fn, 3, 13)

    # Graph break inside enable_grad, nested within two no_grad blocks.
    def test_resume_with_no_grad3(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                with torch.no_grad():
                    x = x + 1
                    with torch.enable_grad():
                        x.sum().tolist()  # graph break
                        x = x[0] + 2
                    x = x + 3
            x = x + 4
            return x

        self._common(fn, 2, 11)

    # A tuple iterator consumed three times before the `unsupported` call and
    # four times after it.
    def test_resume_tuple_iterator(self):
        def fn(a, b):
            x = a + b
            it = iter(tuple(range(10)))
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            return x

        self._common(fn, 2, 8)

    # A tuple iterator is live across two `unsupported` calls and is itself
    # returned; its remaining items are compared with the eager run.
    def test_tuple_iterator_return(self):
        def fn(x):
            it = iter(tuple(range(10)))
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            return x, it

        v1 = torch.randn(10)
        v2, it2 = fn(v1)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        v3, it3 = opt_fn(v1)
        # NOTE(review): `it4` is never checked. Note that `list(it2)` below
        # exhausts it2, so a comparison against it4 would need its own
        # reference list.
        v4, it4 = opt_fn(v1)
        self.assertEqual(v2.tolist(), v3.tolist())
        self.assertEqual(v2.tolist(), v4.tolist())
        self.assertEqual(list(it2), list(it3))
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 6)

    # The caller passes in a tuple iterator; after the compiled call it must
    # have advanced by exactly four items (remaining: 4..9).
    def test_tuple_iterator_mutate(self):
        def fn(x, it):
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            return x

        v1 = torch.randn(10)
        it1 = iter(tuple(range(10)))
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        # NOTE(review): cnt.frame_count / cnt.op_count are not asserted here.
        self.assertEqual(opt_fn(v1, it1).tolist(), (v1 + 1 + 2 + 3).tolist())
        self.assertEqual(list(it1), [4, 5, 6, 7, 8, 9])

    # `enumerate` over tensor shapes, with and without a start value, must
    # stay in a single frame.
    def test_enumerate_not_break_graph(self):
        def fn(a, b):
            for i, x in enumerate(a.shape):
                b = b + x
            for i, x in enumerate(b.shape, 8):
                b = b + x * i
            return b

        self._common(fn, 1, ifdynstaticdefault(2, 3))

if __name__ == "__main__":
    # When run as a script, hand off to the dynamo test runner. It is reached
    # through the torch._dynamo.test_case module already imported at file
    # top, so no extra import is needed here.
    torch._dynamo.test_case.run_tests()