1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport os 4*da0073e9SAndroid Build Coastguard Workerimport sys 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import skipIfTorchDynamo 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable 11*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 12*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir) 13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import FileCheck, JitTestCase, warmup_backward 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 17*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 18*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 19*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 20*da0073e9SAndroid Build Coastguard Worker "instead." 21*da0073e9SAndroid Build Coastguard Worker ) 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo() 25*da0073e9SAndroid Build Coastguard Workerclass TestProfiler(JitTestCase): 26*da0073e9SAndroid Build Coastguard Worker def setUp(self): 27*da0073e9SAndroid Build Coastguard Worker self.prev_exec = torch._C._jit_set_profiling_executor(True) 28*da0073e9SAndroid Build Coastguard Worker self.prev_profiling = torch._C._get_graph_executor_optimize(True) 29*da0073e9SAndroid Build Coastguard Worker self.inline_autodiff = torch._C._debug_set_autodiff_subgraph_inlining(False) 30*da0073e9SAndroid Build Coastguard Worker self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() 31*da0073e9SAndroid Build Coastguard Worker self.can_fuse_on_cpu = torch._C._jit_can_fuse_on_cpu() 32*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_texpr_fuser_enabled(True) 33*da0073e9SAndroid Build Coastguard Worker torch._C._jit_override_can_fuse_on_cpu(True) 34*da0073e9SAndroid Build Coastguard Worker self.default_dtype = torch.get_default_dtype() 35*da0073e9SAndroid Build Coastguard Worker self.old_reduction_enabled = torch._C._jit_set_texpr_reductions_enabled(True) 36*da0073e9SAndroid Build Coastguard Worker torch.set_default_dtype(torch.double) 37*da0073e9SAndroid Build Coastguard Worker self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining() 38*da0073e9SAndroid Build Coastguard Worker torch._C._debug_set_fusion_group_inlining(False) 39*da0073e9SAndroid Build Coastguard Worker self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu() 40*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_te_must_use_llvm_cpu(False) 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 43*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_profiling_executor(self.prev_exec) 44*da0073e9SAndroid Build Coastguard Worker torch._C._get_graph_executor_optimize(self.prev_profiling) 45*da0073e9SAndroid Build Coastguard Worker torch._C._debug_set_autodiff_subgraph_inlining(self.inline_autodiff) 46*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) 47*da0073e9SAndroid Build Coastguard Worker torch._C._jit_override_can_fuse_on_cpu(self.can_fuse_on_cpu) 48*da0073e9SAndroid Build Coastguard Worker torch.set_default_dtype(self.default_dtype) 49*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled) 50*da0073e9SAndroid Build Coastguard Worker torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) 51*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu) 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker def test_tensor_type_not_determined_by_inputs(self): 54*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 55*da0073e9SAndroid Build Coastguard Worker def scalar_type_input(x, y, z): 56*da0073e9SAndroid Build Coastguard Worker return x + y + 4 + z.item() 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([2, 2]) 59*da0073e9SAndroid Build Coastguard Worker scalar_type_input(x, x, torch.tensor(1)) 60*da0073e9SAndroid Build Coastguard Worker scalar_type_input(x, x, torch.tensor(1)) 61*da0073e9SAndroid Build Coastguard Worker scalar_type_input(x, x, torch.tensor(1.0)) 62*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker # item & add should not get pulled into the fusion group - 65*da0073e9SAndroid Build Coastguard Worker # we expect to see Fusion Group (item / add) Fusion Group in ir dump 66*da0073e9SAndroid Build Coastguard Worker FileCheck().check("TensorExpr").check("Scalar = aten::item").check_next( 67*da0073e9SAndroid Build Coastguard Worker "Tensor = aten::add" 68*da0073e9SAndroid Build Coastguard Worker ).check("TensorExpr").run(g) 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 71*da0073e9SAndroid Build Coastguard Worker def non_const_dtype(x, y, cond: bool): 72*da0073e9SAndroid Build Coastguard Worker dtype = torch.int16 if cond else torch.int32 73*da0073e9SAndroid Build Coastguard Worker return (x + y + 3).sum(dtype=dtype) 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker non_const_dtype(x, x, True) 76*da0073e9SAndroid Build Coastguard Worker non_const_dtype(x, x, True) 77*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 78*da0073e9SAndroid Build Coastguard Worker # because dtype is non-const, sum should not get pulled into the Fusion Group 79*da0073e9SAndroid Build Coastguard Worker FileCheck().check("TensorExpr").check("TensorExpr").check_not("aten::sum").run( 80*da0073e9SAndroid Build Coastguard Worker g 81*da0073e9SAndroid Build Coastguard Worker ) 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker def test_specialize_backward(self): 84*da0073e9SAndroid Build Coastguard Worker def test_fuse(a, b): 85*da0073e9SAndroid Build Coastguard Worker c = a * b 86*da0073e9SAndroid Build Coastguard Worker d = c * b 87*da0073e9SAndroid Build Coastguard Worker return d 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker test_fuse.__disable_jit_function_caching__ = True 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker scripted_f = torch.jit.script(test_fuse) 92*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1, requires_grad=True) 93*da0073e9SAndroid Build Coastguard Worker y = torch.ones(1, requires_grad=True) 94*da0073e9SAndroid Build Coastguard Worker scripted_f(x, y) 95*da0073e9SAndroid Build Coastguard Worker b = scripted_f(x, y) 96*da0073e9SAndroid Build Coastguard Worker warmup_backward(b) 97*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 98*da0073e9SAndroid Build Coastguard Worker # Backward has an if node guarding specializations, 99*da0073e9SAndroid Build Coastguard Worker # within the if node true block there is only one if node 100*da0073e9SAndroid Build Coastguard Worker # that guards a tensorexpr group 101*da0073e9SAndroid Build Coastguard Worker optimized_block = next(g.findNode("prim::If").blocks()) 102*da0073e9SAndroid Build Coastguard Worker if_nodes = list(optimized_block.findAllNodes("prim::If")) 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(if_nodes), 1) 105*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Group[Subgraph").run(str(if_nodes[0])) 106*da0073e9SAndroid Build Coastguard Worker # no broadcasts occurred, sum_to_size have been specialized out 107*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(optimized_block.findNode("aten::_grad_sum_to_size")) 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker broadcast_f = torch.jit.script(test_fuse) 110*da0073e9SAndroid Build Coastguard Worker x = torch.ones([2, 2], requires_grad=True) 111*da0073e9SAndroid Build Coastguard Worker y = torch.ones([1], requires_grad=True) 112*da0073e9SAndroid Build Coastguard Worker broadcast_f(x, y) 113*da0073e9SAndroid Build Coastguard Worker b = broadcast_f(x, y) 114*da0073e9SAndroid Build Coastguard Worker b.backward(torch.ones([2, 2], dtype=torch.float), retain_graph=True) 115*da0073e9SAndroid Build Coastguard Worker b.backward(torch.ones([2, 2], dtype=torch.float)) 116*da0073e9SAndroid Build Coastguard Worker # warmup_backward(b, torch.ones([2, 2], dtype=torch.float)) 117*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 118*da0073e9SAndroid Build Coastguard Worker optimized_block = next(g.findNode("prim::If").blocks()) 119*da0073e9SAndroid Build Coastguard Worker # broadcasts occurred, currently expect to see aten::_grad_sum_to_size 120*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(optimized_block.findNode("aten::_grad_sum_to_size")) 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker def test_specialized_types(self): 123*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 124*da0073e9SAndroid Build Coastguard Worker def test_fuse(a, b): 125*da0073e9SAndroid Build Coastguard Worker c = a * b 126*da0073e9SAndroid Build Coastguard Worker d = c * b 127*da0073e9SAndroid Build Coastguard Worker return d 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.5]) 130*da0073e9SAndroid Build Coastguard Worker for _ in range(3): 131*da0073e9SAndroid Build Coastguard Worker test_fuse(x, x) 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 134*da0073e9SAndroid Build Coastguard Worker # Types should remain specialized for typecheck outputs & fusion outputs 135*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Double(").check_same("prim::TypeCheck").check_same( 136*da0073e9SAndroid Build Coastguard Worker "\n" 137*da0073e9SAndroid Build Coastguard Worker ).check("Double").check_same("TensorExpr").run(g) 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Worker # other outputs should not be specialized 140*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Tensor = prim::If").run(g) 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker def test_aliasing_merge(self): 143*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 144*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 145*da0073e9SAndroid Build Coastguard Worker c = a * b 146*da0073e9SAndroid Build Coastguard Worker d = c * b 147*da0073e9SAndroid Build Coastguard Worker d.add_(b) 148*da0073e9SAndroid Build Coastguard Worker e = d * b 149*da0073e9SAndroid Build Coastguard Worker return d + e 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1) 152*da0073e9SAndroid Build Coastguard Worker y = torch.ones(1) 153*da0073e9SAndroid Build Coastguard Worker foo(x, y) 154*da0073e9SAndroid Build Coastguard Worker b = foo(x, y) 155*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 156*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(g.findAllNodes("prim::TypeCheck"))), 2) 157*da0073e9SAndroid Build Coastguard Worker FileCheck().check("TensorExpr").check("aten::add_").check("TensorExpr").run(g) 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker def test_use_not_profiled(self): 160*da0073e9SAndroid Build Coastguard Worker def foo(t1, t2, t3, t4, t: float): 161*da0073e9SAndroid Build Coastguard Worker h = t1 + t2 + t3 + t4 162*da0073e9SAndroid Build Coastguard Worker if t > 0.5: 163*da0073e9SAndroid Build Coastguard Worker # Putting a use of t1 in a never-executed conditional prevents 164*da0073e9SAndroid Build Coastguard Worker return t1 + 1 165*da0073e9SAndroid Build Coastguard Worker return h 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker t = torch.rand(8, dtype=torch.float) 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker foo_script = torch.jit.script(foo) 170*da0073e9SAndroid Build Coastguard Worker for _ in range(torch._C._jit_get_num_profiled_runs() + 1): 171*da0073e9SAndroid Build Coastguard Worker foo_script(t, t, t, t, 0.1) 172*da0073e9SAndroid Build Coastguard Worker 173*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(t, t, t, t, 0.1), foo_script(t, t, t, t, 0.1)) 174*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 175*da0073e9SAndroid Build Coastguard Worker # all adds fused 176*da0073e9SAndroid Build Coastguard Worker FileCheck().check("graph").check_not("aten::add").check("prim::If").run(g) 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker def test_not_fusing_scalar_ops(self): 179*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 180*da0073e9SAndroid Build Coastguard Worker def foo(x: int, y: int): 181*da0073e9SAndroid Build Coastguard Worker return x + y + 2 + 4 + 5 + 6 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker foo(1, 2) 184*da0073e9SAndroid Build Coastguard Worker foo(2, 3) 185*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 186*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("TensorExpr").run(g) 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker def test_not_optimizing_property(self): 189*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 190*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 191*da0073e9SAndroid Build Coastguard Worker return x + y + 1 + 2 + 3, x.size() 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1) 194*da0073e9SAndroid Build Coastguard Worker foo(x, x) 195*da0073e9SAndroid Build Coastguard Worker foo(x, x) 196*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 197*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::size").run(g) 198*da0073e9SAndroid Build Coastguard Worker x = torch.ones([2, 3, 5]) 199*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(x, x), (x + x + 1 + 2 + 3, x.size())) 200*da0073e9SAndroid Build Coastguard Worker 201*da0073e9SAndroid Build Coastguard Worker def test_fallback_graph_not_specialized(self): 202*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 203*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 204*da0073e9SAndroid Build Coastguard Worker c = a * b 205*da0073e9SAndroid Build Coastguard Worker d = c * b 206*da0073e9SAndroid Build Coastguard Worker e = d * b 207*da0073e9SAndroid Build Coastguard Worker return d + e 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1) 210*da0073e9SAndroid Build Coastguard Worker y = torch.ones(1) 211*da0073e9SAndroid Build Coastguard Worker foo(x, y) 212*da0073e9SAndroid Build Coastguard Worker foo(x, y) 213*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 214*da0073e9SAndroid Build Coastguard Worker FileCheck().check("CallFunction").check_next("Tensor = prim::TupleUnpack").run( 215*da0073e9SAndroid Build Coastguard Worker g 216*da0073e9SAndroid Build Coastguard Worker ) 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker def test_autograd_fallback_graph(self): 219*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 220*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 221*da0073e9SAndroid Build Coastguard Worker c = a * b 222*da0073e9SAndroid Build Coastguard Worker d = c * b 223*da0073e9SAndroid Build Coastguard Worker e = d * b 224*da0073e9SAndroid Build Coastguard Worker return d + e 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1, requires_grad=True) 227*da0073e9SAndroid Build Coastguard Worker y = torch.ones(1, requires_grad=True) 228*da0073e9SAndroid Build Coastguard Worker foo(x, y) 229*da0073e9SAndroid Build Coastguard Worker b = foo(x, y) 230*da0073e9SAndroid Build Coastguard Worker b.backward(torch.ones([1], dtype=torch.float), retain_graph=True) 231*da0073e9SAndroid Build Coastguard Worker b.backward(torch.ones([1], dtype=torch.float)) 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 234*da0073e9SAndroid Build Coastguard Worker FileCheck().check("fallback_function").check_next("CallFunction").run(g) 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker def test_tensor_constant(self): 237*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 238*da0073e9SAndroid Build Coastguard Worker return a + b + torch.tensor([2]) 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1, requires_grad=False) 241*da0073e9SAndroid Build Coastguard Worker foo_script = torch.jit.script(foo) 242*da0073e9SAndroid Build Coastguard Worker foo_script(x, x) 243*da0073e9SAndroid Build Coastguard Worker foo_script(x, x) 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo_script(x, x), foo(x, x)) 246*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 247*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::add", 2, exactly=True).run(g) 248*da0073e9SAndroid Build Coastguard Worker 249*da0073e9SAndroid Build Coastguard Worker def test_local_fusion_strategy(self): 250*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 251*da0073e9SAndroid Build Coastguard Worker def foo(x): 252*da0073e9SAndroid Build Coastguard Worker return x + x + x 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker torch.jit.set_fusion_strategy([("STATIC", 1)]) 255*da0073e9SAndroid Build Coastguard Worker for _ in range(3): 256*da0073e9SAndroid Build Coastguard Worker foo(torch.rand([10])) 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Worker torch.jit.set_fusion_strategy([("STATIC", 10)]) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker for i in range(10): 261*da0073e9SAndroid Build Coastguard Worker foo(torch.rand([i])) 262*da0073e9SAndroid Build Coastguard Worker foo(torch.rand([i])) 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 265*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count(":TensorExprGroup", 2, exactly=True).run(g) 266*da0073e9SAndroid Build Coastguard Worker 267*da0073e9SAndroid Build Coastguard Worker def test_iterative_fusion(self): 268*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 269*da0073e9SAndroid Build Coastguard Worker def foo(a, b, c, d): 270*da0073e9SAndroid Build Coastguard Worker a = a + b 271*da0073e9SAndroid Build Coastguard Worker b.add_(3) 272*da0073e9SAndroid Build Coastguard Worker c = c + b + d 273*da0073e9SAndroid Build Coastguard Worker a = a + 1 274*da0073e9SAndroid Build Coastguard Worker return a, c 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1, requires_grad=False) 277*da0073e9SAndroid Build Coastguard Worker foo(x, x, x, x) 278*da0073e9SAndroid Build Coastguard Worker foo(x, x, x, x) 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker # when we iterate through the block, we will start 281*da0073e9SAndroid Build Coastguard Worker # by fusing a = a + b with a = a + 1 282*da0073e9SAndroid Build Coastguard Worker # if we were to continue iteration from that fusion point, 283*da0073e9SAndroid Build Coastguard Worker # would miss the fusion opportunity of c = c + d + b 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 286*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(g.findAllNodes("prim::TensorExprGroup"))), 2) 287