# Owner(s): ["oncall: jit"]

import os
import sys

import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import FileCheck, JitTestCase, warmup_backward


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


@skipIfTorchDynamo()
class TestProfiler(JitTestCase):
    def setUp(self):
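        # Save the current executor/fuser flags, then enable the profiling
        # executor with the TensorExpr fuser on CPU. The default dtype is
        # switched to double so specialized tensor types show up as Double(...)
        # in the IR dumps checked below.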
        self.prev_exec = torch._C._jit_set_profiling_executor(True)
        self.prev_profiling = torch._C._get_graph_executor_optimize(True)
        self.inline_autodiff = torch._C._debug_set_autodiff_subgraph_inlining(False)
        self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
        self.can_fuse_on_cpu = torch._C._jit_can_fuse_on_cpu()
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        self.default_dtype = torch.get_default_dtype()
        self.old_reduction_enabled = torch._C._jit_set_texpr_reductions_enabled(True)
        torch.set_default_dtype(torch.double)
        self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
        torch._C._debug_set_fusion_group_inlining(False)
        self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
        torch._C._jit_set_te_must_use_llvm_cpu(False)

    def tearDown(self):
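        # Restore every global flag saved in setUp.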
        torch._C._jit_set_profiling_executor(self.prev_exec)
        torch._C._get_graph_executor_optimize(self.prev_profiling)
        torch._C._debug_set_autodiff_subgraph_inlining(self.inline_autodiff)
        torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
        torch._C._jit_override_can_fuse_on_cpu(self.can_fuse_on_cpu)
        torch.set_default_dtype(self.default_dtype)
        torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled)
        torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
        torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)

    def test_tensor_type_not_determined_by_inputs(self):
        @torch.jit.script
        def scalar_type_input(x, y, z):
            return x + y + 4 + z.item()

        x = torch.tensor([2, 2])
        scalar_type_input(x, x, torch.tensor(1))
        scalar_type_input(x, x, torch.tensor(1))
        scalar_type_input(x, x, torch.tensor(1.0))
        g = torch.jit.last_executed_optimized_graph()

        # item & add should not get pulled into the fusion group -
        # we expect to see FusionGroup, (item / add), FusionGroup in the IR dump
        FileCheck().check("TensorExpr").check("Scalar = aten::item").check_next(
            "Tensor = aten::add"
        ).check("TensorExpr").run(g)

        @torch.jit.script
        def non_const_dtype(x, y, cond: bool):
            dtype = torch.int16 if cond else torch.int32
            return (x + y + 3).sum(dtype=dtype)

        non_const_dtype(x, x, True)
        non_const_dtype(x, x, True)
        g = torch.jit.last_executed_optimized_graph()
        # because dtype is non-const, sum should not get pulled into the Fusion Group
        FileCheck().check("TensorExpr").check("TensorExpr").check_not("aten::sum").run(
            g
        )

    def test_specialize_backward(self):
        def test_fuse(a, b):
            c = a * b
            d = c * b
            return d

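        # Disable function caching so the second torch.jit.script(test_fuse)
        # call below compiles a fresh graph instead of reusing this one.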
        test_fuse.__disable_jit_function_caching__ = True

        scripted_f = torch.jit.script(test_fuse)
        x = torch.ones(1, requires_grad=True)
        y = torch.ones(1, requires_grad=True)
        scripted_f(x, y)
        b = scripted_f(x, y)
        warmup_backward(b)
        g = torch.jit.last_executed_optimized_graph()
        # Backward has an if node guarding specializations; within the true
        # block of that if node there is only one if node, which guards a
        # tensorexpr group.
        optimized_block = next(g.findNode("prim::If").blocks())
        if_nodes = list(optimized_block.findAllNodes("prim::If"))

        self.assertEqual(len(if_nodes), 1)
        FileCheck().check("Group[Subgraph").run(str(if_nodes[0]))
        # no broadcasts occurred, so the _grad_sum_to_size calls have been specialized out
        self.assertIsNone(optimized_block.findNode("aten::_grad_sum_to_size"))

        broadcast_f = torch.jit.script(test_fuse)
        x = torch.ones([2, 2], requires_grad=True)
        y = torch.ones([1], requires_grad=True)
        broadcast_f(x, y)
        b = broadcast_f(x, y)
        b.backward(torch.ones([2, 2], dtype=torch.float), retain_graph=True)
        b.backward(torch.ones([2, 2], dtype=torch.float))
        # warmup_backward(b, torch.ones([2, 2], dtype=torch.float))
        g = torch.jit.last_executed_optimized_graph()
        optimized_block = next(g.findNode("prim::If").blocks())
        # broadcasts occurred, so we currently expect to see aten::_grad_sum_to_size
        self.assertIsNotNone(optimized_block.findNode("aten::_grad_sum_to_size"))

    def test_specialized_types(self):
        @torch.jit.script
        def test_fuse(a, b):
            c = a * b
            d = c * b
            return d

        x = torch.tensor([0.5])
        for _ in range(3):
            test_fuse(x, x)

        g = torch.jit.last_executed_optimized_graph()
        # Types should remain specialized for typecheck outputs & fusion outputs
        FileCheck().check("Double(").check_same("prim::TypeCheck").check_same(
            "\n"
        ).check("Double").check_same("TensorExpr").run(g)

        # other outputs should not be specialized
        FileCheck().check("Tensor = prim::If").run(g)

    def test_aliasing_merge(self):
        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            d.add_(b)
            e = d * b
            return d + e

        x = torch.ones(1)
        y = torch.ones(1)
        foo(x, y)
        b = foo(x, y)
        g = torch.jit.last_executed_optimized_graph()
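        # The in-place add_ aliases d, so the multiplies before and after it
        # cannot be merged into one fusion group; we expect two guarded
        # TensorExpr groups separated by aten::add_.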
        self.assertEqual(len(list(g.findAllNodes("prim::TypeCheck"))), 2)
        FileCheck().check("TensorExpr").check("aten::add_").check("TensorExpr").run(g)

    def test_use_not_profiled(self):
        def foo(t1, t2, t3, t4, t: float):
            h = t1 + t2 + t3 + t4
            if t > 0.5:
                # Putting a use of t1 in a never-executed conditional prevents
                # that use from being profiled
                return t1 + 1
            return h

        t = torch.rand(8, dtype=torch.float)

        foo_script = torch.jit.script(foo)
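        # Run once more than the number of profiling runs so the final call
        # executes the optimized plan.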
        for _ in range(torch._C._jit_get_num_profiled_runs() + 1):
            foo_script(t, t, t, t, 0.1)

        self.assertEqual(foo(t, t, t, t, 0.1), foo_script(t, t, t, t, 0.1))
        g = torch.jit.last_executed_optimized_graph()
        # all adds fused
        FileCheck().check("graph").check_not("aten::add").check("prim::If").run(g)

    def test_not_fusing_scalar_ops(self):
        @torch.jit.script
        def foo(x: int, y: int):
            return x + y + 2 + 4 + 5 + 6

        foo(1, 2)
        foo(2, 3)
        g = torch.jit.last_executed_optimized_graph()
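        # Purely scalar (int) arithmetic should not be handed to the TensorExpr fuser.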
        FileCheck().check_not("TensorExpr").run(g)

    def test_not_optimizing_property(self):
        @torch.jit.script
        def foo(x, y):
            return x + y + 1 + 2 + 3, x.size()

        x = torch.ones(1)
        foo(x, x)
        foo(x, x)
        g = torch.jit.last_executed_optimized_graph()
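        # x.size() must stay in the graph rather than being baked in from the
        # profiled shape, so foo remains correct for inputs of other sizes.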
        FileCheck().check("aten::size").run(g)
        x = torch.ones([2, 3, 5])
        self.assertEqual(foo(x, x), (x + x + 1 + 2 + 3, x.size()))

    def test_fallback_graph_not_specialized(self):
        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            e = d * b
            return d + e

        x = torch.ones(1)
        y = torch.ones(1)
        foo(x, y)
        foo(x, y)
        g = torch.jit.last_executed_optimized_graph()
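        # The fallback path goes through a function call whose unpacked outputs
        # keep the unspecialized Tensor type.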
        FileCheck().check("CallFunction").check_next("Tensor = prim::TupleUnpack").run(
            g
        )

    def test_autograd_fallback_graph(self):
        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            e = d * b
            return d + e

        x = torch.ones(1, requires_grad=True)
        y = torch.ones(1, requires_grad=True)
        foo(x, y)
        b = foo(x, y)
        b.backward(torch.ones([1], dtype=torch.float), retain_graph=True)
        b.backward(torch.ones([1], dtype=torch.float))

        g = torch.jit.last_executed_optimized_graph()
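        # After warming up backward, the optimized graph should still contain a
        # call to the autodiff fallback function.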
        FileCheck().check("fallback_function").check_next("CallFunction").run(g)

    def test_tensor_constant(self):
        def foo(a, b):
            return a + b + torch.tensor([2])

        x = torch.ones(1, requires_grad=False)
        foo_script = torch.jit.script(foo)
        foo_script(x, x)
        foo_script(x, x)

        self.assertEqual(foo_script(x, x), foo(x, x))
        g = torch.jit.last_executed_optimized_graph()
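        # Exactly the two adds written in foo should survive in the optimized
        # graph; the torch.tensor([2]) constant must not add or remove any.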
        FileCheck().check_count("aten::add", 2, exactly=True).run(g)

    def test_local_fusion_strategy(self):
        @torch.jit.script
        def foo(x):
            return x + x + x

        torch.jit.set_fusion_strategy([("STATIC", 1)])
        for _ in range(3):
            foo(torch.rand([10]))

        torch.jit.set_fusion_strategy([("STATIC", 10)])

        for i in range(10):
            foo(torch.rand([i]))
            foo(torch.rand([i]))

        g = torch.jit.last_executed_optimized_graph()
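        # foo was compiled while the ("STATIC", 1) strategy was active, so
        # raising the global limit afterwards should not grow the number of
        # TensorExprGroup nodes beyond the expected two.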
        FileCheck().check_count(":TensorExprGroup", 2, exactly=True).run(g)

    def test_iterative_fusion(self):
        @torch.jit.script
        def foo(a, b, c, d):
            a = a + b
            b.add_(3)
            c = c + b + d
            a = a + 1
            return a, c

        x = torch.ones(1, requires_grad=False)
        foo(x, x, x, x)
        foo(x, x, x, x)

        # When we iterate through the block, we start by fusing a = a + b with
        # a = a + 1. If we were to continue iterating from that fusion point, we
        # would miss the fusion opportunity of c = c + b + d.

        g = torch.jit.last_executed_optimized_graph()
        self.assertEqual(len(list(g.findAllNodes("prim::TensorExprGroup"))), 2)