xref: /aosp_15_r20/external/pytorch/test/test_fx_reinplace_pass.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: functionalization"]
2import torch
3from torch.testing._internal.common_utils import TestCase, run_tests
4from torch.fx.passes.reinplace import reinplace
5from torch.fx.experimental.proxy_tensor import make_fx
6from torch.fx.experimental.symbolic_shapes import ShapeEnv
7from torch._dynamo.source import ConstantSource
8from torch.fx.experimental.sym_node import SymNode
9
10try:
11    from functorch.experimental import functionalize
12    HAS_FUNCTIONALIZATION = True
13except Exception as e:
14    HAS_FUNCTIONALIZATION = False
15
class TestReinplacePass(TestCase):
    """Tests for the FX re-inplacing pass (torch.fx.passes.reinplace).

    Each test traces a small function with make_fx (optionally through
    functorch's functionalize()), runs reinplace(), checks the transformed
    graph still matches eager execution, and pins the exact post-pass FX
    code with assertExpectedInline.
    """

    def test_reinplace_basic(self):
        # Basic test: the out-of-place add() call should be converted
        # into add_()
        def f(x):
            a = x.clone()
            b = a.add(1)
            return b

        inpt = torch.ones(2)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1):
    clone = torch.ops.aten.clone.default(x_1);  x_1 = None
    add = torch.ops.aten.add_.Tensor(clone, 1);  add = None
    return clone
    """)


    def test_reinplace_with_view(self):
        """An op whose target tensor has a later-used alias must stay out-of-place."""
        def f(x):
            a = x.clone()
            a_view = a.view(-1)
            # We shouldn't re-inplace the first add(), because an alias of a is re-used later in the program
            b = a.add(1)
            # Second add() is fine to re-inplace
            c = a_view.add(1)
            return c

        inpt = torch.ones(2)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1):
    clone = torch.ops.aten.clone.default(x_1);  x_1 = None
    view = torch.ops.aten.view.default(clone, [-1])
    add = torch.ops.aten.add.Tensor(clone, 1);  clone = add = None
    add_1 = torch.ops.aten.add_.Tensor(view, 1);  add_1 = None
    return view
    """)

    def test_reinplace_different_metadata(self):
        def f(a_):
            a = a_.clone()
            b = a + 1
            # Naively, we shouldn't try to inplace the .ge() call,
            # because that would require resizing "b" (from a float to a bool tensor).
            c = torch.ge(b, a)
            return c
        inpt = torch.ones(4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # The .ge() should not be reinplaced.
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    add = torch.ops.aten.add.Tensor(clone, 1)
    ge = torch.ops.aten.ge.Tensor(add, clone);  add = clone = None
    return ge
    """)

    def test_reinplace_overlapping_memory(self):
        def f(a_):
            a = a_.clone()
            b = a.expand(4, 4)
            # Can't reinplace because b has overlapping memory.
            c = b.add(1)
            return c
        inpt = torch.ones(1)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    expand = torch.ops.aten.expand.default(clone, [4, 4]);  clone = None
    add = torch.ops.aten.add.Tensor(expand, 1);  expand = None
    return add
    """)

    # This test won't actually run in CI, because it requires functionalize() from functorch.
    # I'm planning on testing more comprehensively with torchbench models,
    # but we can make this testing better once functorch moves into pytorch/pytorch.
    def test_reinplace_scatter_op(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            e = a.view(-1)
            b = a.view(-1)
            c = b[0]
            d = c.view(-1)
            d.add_(1)
            return a + e

        if not HAS_FUNCTIONALIZATION:
            return
        inpt = torch.ones(4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # NOTE: one slight pessimization here is the fact that
        # there are a bunch of redundant views in the graph.
        # Technically, half of these views are duplicates that we could de-dup.
        # This shouldn't really hurt performance though, since creating an extra view
        # is effectively just moving some metadata around (and allocating a new TensorImpl).
        # We can/should update the pass in the future to clean this up.
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    view = torch.ops.aten.view.default(clone, [-1]);  view = None
    view_1 = torch.ops.aten.view.default(clone, [-1])
    select = torch.ops.aten.select.int(view_1, 0, 0);  view_1 = None
    view_2 = torch.ops.aten.view.default(select, [-1]);  select = None
    add = torch.ops.aten.add_.Tensor(view_2, 1);  add = None
    view_3 = torch.ops.aten.view.default(clone, [-1]);  clone = None
    select_1 = torch.ops.aten.select.int(view_3, 0, 0);  select_1 = None
    view_4 = torch.ops.aten.view.default(view_2, []);  view_2 = view_4 = None
    view_5 = torch.ops.aten.view.default(view_3, [4]);  view_3 = None
    view_6 = torch.ops.aten.view.default(view_5, [-1])
    select_2 = torch.ops.aten.select.int(view_6, 0, 0);  view_6 = None
    view_7 = torch.ops.aten.view.default(select_2, [-1]);  select_2 = view_7 = None
    view_8 = torch.ops.aten.view.default(view_5, [-1])
    add_1 = torch.ops.aten.add_.Tensor(view_5, view_8);  view_8 = add_1 = None
    return view_5
    """)

    def test_reinplace_scatter_twice(self):
        """Chained view mutation (a[:, 1][1].add_(1)) re-inplaces cleanly."""
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c.add_(1)
            return a

        if not HAS_FUNCTIONALIZATION:
            return

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add_.Tensor(select_1, 1);  select_1 = add = None
    slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select_2 = torch.ops.aten.select.int(slice_2, 1, 1);  slice_2 = select_2 = None
    slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select_3 = torch.ops.aten.select.int(slice_3, 1, 1);  slice_3 = None
    select_4 = torch.ops.aten.select.int(select_3, 0, 1);  select_3 = select_4 = None
    return clone
    """)

    def test_reinplace_scatter_twice_with_different_view_op_valid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4,), (4,), 1)
            # good_mirror_of_b points to the same region of memory as b.
            # and this scatter op below tries to scatter c_updated into the same region
            # that c currently takes up.
            # reinplacing logic checks this by confirming that:
            #   c_updated
            #   good_mirror_of_b.select(0, 1)
            # have the same size/stride/storage_offset.
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add_.Tensor(select_1, 1);  select_1 = add = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1);  clone = None
    return as_strided
    """)

    # Test example where the scatter op's base tensor is an equivalent view
    # of b (same size/stride/storage offset), but the scatter targets a
    # *different* slice than the one the updated value came from, making it
    # invalid to re-inplace.
    def test_reinplace_scatter_twice_with_different_view_op_invalid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4,), (4,), 1)
            # The first arg to select_scatter is an equivalent view to b.
            # However, the select_scatter call below tries to put c_updated
            # into a different slice of "b" than what "c" currently occupies.
            #
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 0)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add.Tensor(select_1, 1);  select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1);  clone = None
    select_int = torch.ops.aten.select.int(as_strided, 0, 0)
    copy__default = torch.ops.aten.copy_.default(select_int, add);  select_int = add = copy__default = None
    return as_strided
    """)  # noqa: B950

    def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            bad_mirror_of_b = a.as_strided((4,), (4,), 0)
            # The first arg to select_scatter points to a different than c's base.
            # This makes it invalid to re-inplace.
            b_updated = torch.select_scatter(bad_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        # NOTE(review): the eager-vs-traced equality check below is disabled;
        # presumably the traced graph's output diverges from eager in this
        # invalid-mirror case — confirm before re-enabling.
        # self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add.Tensor(select_1, 1);  select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 0);  clone = None
    select_int = torch.ops.aten.select.int(as_strided, 0, 1)
    copy__default = torch.ops.aten.copy_.default(select_int, add);  select_int = add = copy__default = None
    return as_strided
    """)  # noqa: B950


    def test_out_node_updated(self):
        def f():
            x = torch.zeros(2, 2)
            y = x.diagonal()
            y_updated = y.add(1)
            z = torch.diagonal_scatter(x, y_updated)
            # reinplace needs to know to replace output [z] with [x]
            return [z]

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
    diagonal = torch.ops.aten.diagonal.default(zeros)
    add = torch.ops.aten.add_.Tensor(diagonal, 1);  diagonal = add = None
    return [zeros]
    """)

    def test_reinplace_index_mutation(self):
        """Slice-assignment (a[:, 2:] = ...) re-inplaces into a copy_ on the view."""
        def f():
            a = torch.zeros(4, 4, 4)
            a[:, 2:] = torch.ones(4, 2, 4)
            return a

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807);  slice_1 = None
    copy = torch.ops.aten.copy_.default(slice_2, ones);  slice_2 = ones = copy = None
    slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807);  slice_3 = None
    slice_4 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_5 = torch.ops.aten.slice.Tensor(slice_4, 1, 2, 9223372036854775807);  slice_4 = slice_5 = None
    return zeros
    """)

    def test_reinplace_sym_input(self):
        # Symbolic input test: the out-of-place add() call should be converted
        # into add_(), and symbolic input won't cause any error.
        def f(x, index):
            a = torch.select(x, 0, index)
            clone = a.clone()
            b = clone.add(1)
            return b

        x = torch.randn((4, 8, 16, 16), requires_grad=False)
        index = 2
        shape_env = ShapeEnv()
        symbol = shape_env.create_symbol(index, source=ConstantSource(
            f"__testing_only{len(shape_env.var_to_val)}"))
        sym_index = torch.SymInt(SymNode(symbol, shape_env, int, hint=index))

        # Trace/reinplace with the symbolic index, then run with the plain int.
        inpt = [x, sym_index]
        f2 = reinplace(make_fx(f)(*inpt), *inpt)

        real_inpt = [x, index]
        expected_out = f(*real_inpt)
        actual_out = f2(*real_inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1, index_1):
    select = torch.ops.aten.select.int(x_1, 0, index_1);  x_1 = index_1 = None
    clone = torch.ops.aten.clone.default(select);  select = None
    add = torch.ops.aten.add_.Tensor(clone, 1);  add = None
    return clone
    """)
398
399
# Allow running this test file directly as a script.
if __name__ == "__main__":
    run_tests()