# Owner(s): ["module: dynamo"]
# flake8: noqa
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import (
    CompileCounter,
    CompileCounterWithBackend,
    EagerAndRecordGraphs,
    normalize_gm,
)
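
# A small helper sketch (added for illustration; the tests below inline the same
# loop): Dynamo reports lifted tensor attributes as extra "placeholder" nodes in
# the captured fx graph, so counting placeholders on a grabbed GraphModule tells
# us what was lifted to graph inputs.
def count_placeholders(gm):
    return sum(1 for node in gm.graph.nodes if node.op == "placeholder")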


class TestInputAttrTracking(torch._dynamo.test_case.TestCase):
    def test_tensor_property_on_tensor(self):
        def fn(x):
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = torch.randn([2, 2])
        x_.y = y_

        eager_result = fn(x_)

        graph = None

        def grab_graph_backend(gm, inps):
            nonlocal graph
            graph = gm
            return gm

        fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn)
        compile_result = fn(x_)
        self.assertEqual(eager_result, compile_result)

        placeholder_cnt = 0
        for node in graph.graph.nodes:
            if node.op == "placeholder":
                placeholder_cnt += 1

        # We want to be very sure that this lifts y to inputs!
        self.assertEqual(placeholder_cnt, 2)

    def test_tensor_property_assigned_on_tensor(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = torch.randn([2, 2])

        eager_result = fn(x_, y_)

        graph = None

        def grab_graph_backend(gm, inps):
            nonlocal graph
            graph = gm
            return gm

        fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn)
        compile_result = fn(x_, y_)
        self.assertEqual(eager_result, compile_result)

        placeholder_cnt = 0
        for node in graph.graph.nodes:
            if node.op == "placeholder":
                placeholder_cnt += 1

        # y is already an input
        self.assertEqual(placeholder_cnt, 2)

    def test_const_property_on_tensor(self):
        def fn(x):
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = 4
        x_.y = y_

        eager_result = fn(x_)

        graph = None

        def grab_graph_backend(gm, inps):
            nonlocal graph
            graph = gm
            return gm

        fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn)
        compile_result = fn(x_)
        self.assertEqual(eager_result, compile_result)

        placeholder_cnt = 0
        for node in graph.graph.nodes:
            if node.op == "placeholder":
                placeholder_cnt += 1

        # We want to be very sure that this does not lift y to inputs, as it is a const
        self.assertEqual(placeholder_cnt, 1)

    def test_const_property_assigned_on_tensor(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = 4

        eager_result = fn(x_, y_)

        fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        compile_result = fn(x_, y_)
        self.assertEqual(eager_result, compile_result)

    def test_guards_correctly_property_assigned_on_tensor_type_change(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])

        fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        compile_result_const = fn(x_, 4)
        self.assertEqual(compile_result_const, x_ * 4)

        y = torch.randn([2, 2])
        compile_result_tensor = fn(x_, y)
        self.assertEqual(compile_result_tensor, x_ * y)

    def test_guards_correctly_property_assigned_on_tensor_type_change_inductor(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])

        fn = torch._dynamo.optimize("inductor", nopython=True)(fn)
        compile_result_const = fn(x_, 4)
        self.assertEqual(compile_result_const, x_ * 4)

        y = torch.randn([2, 2])
        compile_result_tensor = fn(x_, y)
        self.assertEqual(compile_result_tensor, x_ * y)
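
    # A companion sketch (added for illustration, not part of the original suite):
    # the two tests above work because Dynamo guards on the type/value of y, so
    # switching y from an int to a tensor should trigger a recompile rather than
    # reuse the first graph, which specialized on the constant.
    def test_recompiles_when_property_type_changes(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        counter = CompileCounter()
        opt_fn = torch._dynamo.optimize(counter, nopython=True)(fn)

        x_ = torch.randn([2, 2])
        opt_fn(x_, 4)
        opt_fn(x_, torch.randn([2, 2]))
        # At least two compiled frames: one for the int case, one for the tensor case.
        self.assertTrue(counter.frame_count >= 2)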

    def test_complex_attr_access_without_graph_breaks(self):
        def fn(x, y, z):
            for t in x:
                t.y = y
                t.z = y * z

            new_y = 1
            new_z = 1
            for t in x:
                new_y = t.y * new_y
                new_z = t.z * new_z

            return new_y, new_z

        x_0 = torch.randn([2, 2])
        x_1 = torch.randn([2, 2])
        x_2 = torch.randn([2, 2])
        x = [x_0, x_1, x_2]

        y = torch.randn([2, 2])
        z = 5

        eager_result = fn(x, y, z)

        counter = CompileCounter()
        fn = torch._dynamo.optimize(counter, nopython=True)(fn)

        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 9)
        # Graph for reference
        # opcode         name    target                   args                                  kwargs
        # -------------  ------  -----------------------  ------------------------------------  --------
        # placeholder    l_y_    L_y_                     ()                                    {}
        # call_function  mul     <built-in function mul>  (l_y_, 5)                             {}
        # call_function  mul_1   <built-in function mul>  (l_y_, 5)                             {}
        # call_function  mul_2   <built-in function mul>  (l_y_, 5)                             {}
        # call_function  mul_3   <built-in function mul>  (l_y_, 1)                             {}
        # call_function  mul_4   <built-in function mul>  (mul, 1)                              {}
        # call_function  mul_5   <built-in function mul>  (l_y_, mul_3)                         {}
        # call_function  mul_6   <built-in function mul>  (mul_1, mul_4)                        {}
        # call_function  mul_7   <built-in function mul>  (l_y_, mul_5)                         {}
        # call_function  mul_8   <built-in function mul>  (mul_2, mul_6)                        {}
        # output         output  output                   ((mul_7, mul_8, mul, mul_1, mul_2),)  {}

    def test_complex_attr_access_with_graph_breaks(self):
        def fn(x, y, z):
            for t in x:
                t.y = y
                t.z = y * z

            print("Break!")

            new_y = 1
            new_z = 1
            for t in x:
                new_y = t.y * new_y
                new_z = t.z * new_z

            return new_y, new_z

        x_0 = torch.randn([2, 2])
        x_1 = torch.randn([2, 2])
        x_2 = torch.randn([2, 2])
        x = [x_0, x_1, x_2]

        y = torch.randn([2, 2])
        z = 5

        eager_result = fn(x, y, z)

        counter = CompileCounter()
        fn = torch._dynamo.optimize(counter, nopython=False)(fn)

        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 2)
        self.assertEqual(counter.op_count, 9)
        # Graph for reference
        # opcode         name    target                   args                    kwargs
        # -------------  ------  -----------------------  ----------------------  --------
        # placeholder    l_y_    L_y_                     ()                      {}
        # call_function  mul     <built-in function mul>  (l_y_, 5)               {}
        # call_function  mul_1   <built-in function mul>  (l_y_, 5)               {}
        # call_function  mul_2   <built-in function mul>  (l_y_, 5)               {}
        # output         output  output                   ((mul, mul_1, mul_2),)  {}
        # [GRAPH BREAK!]
        # opcode         name     target                   args               kwargs
        # -------------  -------  -----------------------  -----------------  --------
        # placeholder    l_x_0_y  L_x_0_y                  ()                 {}
        # placeholder    l_x_0_z  L_x_0_z                  ()                 {}
        # placeholder    l_x_1_y  L_x_1_y                  ()                 {}
        # placeholder    l_x_1_z  L_x_1_z                  ()                 {}
        # placeholder    l_x_2_y  L_x_2_y                  ()                 {}
        # placeholder    l_x_2_z  L_x_2_z                  ()                 {}
        # call_function  mul      <built-in function mul>  (l_x_0_y, 1)       {}
        # call_function  mul_1    <built-in function mul>  (l_x_0_z, 1)       {}
        # call_function  mul_2    <built-in function mul>  (l_x_1_y, mul)     {}
        # call_function  mul_3    <built-in function mul>  (l_x_1_z, mul_1)   {}
        # call_function  mul_4    <built-in function mul>  (l_x_2_y, mul_2)   {}
        # call_function  mul_5    <built-in function mul>  (l_x_2_z, mul_3)   {}
        # output         output   output                   ((mul_4, mul_5),)  {}

    def test_complex_attr_access_with_inline_reconstruct(self):
        def inline_test_fn(x, y, z):
            print("f")
            return x.a + y.a + z.a

        def fn(x, y, z):
            x.a = 1
            y.a = 2
            z.a = 3

            mult = inline_test_fn(x, y, z)
            y = y * mult
            x = x * mult
            return x, y

        x = torch.randn([2, 2])
        y = torch.randn([2, 2])
        z = torch.randn([2, 2])

        eager_result = fn(x, y, z)

        counter = CompileCounter()

        fn = torch._dynamo.optimize(counter, nopython=False)(fn)

        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 2)
        # Graph for reference
        # __compiled_fn_2 <eval_with_key>.0 opcode         name    target                   args             kwargs
        # -------------  ------  -----------------------  ---------------  --------
        # placeholder    l_x_    L_x_                     ()               {}
        # placeholder    l_y_    L_y_                     ()               {}
        # call_function  mul     <built-in function mul>  (l_y_, 6)        {}
        # call_function  mul_1   <built-in function mul>  (l_x_, 6)        {}
        # output         output  output                   ((mul_1, mul),)  {}
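
    # A small companion sketch (added for illustration): Dynamo replays attribute
    # mutations on the real input objects after the compiled graph runs, so the
    # constants assigned inside fn above should remain visible on the tensors
    # afterwards.
    def test_attr_assignment_visible_after_compile(self):
        def fn(x):
            x.a = 1
            return x * 2

        x = torch.randn([2, 2])
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        opt_fn(x)
        self.assertEqual(x.a, 1)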

    def test_set_data_on_input_tensor(self):
        def fn(x, y):
            x.data = y.data
            if x.size() == y.size():
                return x * y
            else:
                return y * y

        x = torch.randn([5, 5])
        y = torch.randn([2, 2])

        eager_result = fn(x, y)

        eager_and_record = EagerAndRecordGraphs()

        counter = CompileCounterWithBackend(eager_and_record)

        fn = torch._dynamo.optimize(counter, nopython=True)(fn)

        compile_result = fn(x, y)

        graph = eager_and_record.graphs[0]
        actual = normalize_gm(graph.print_readable(False))

        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 6)
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_y_: "f32[2, 2]", L_x_: "f32[2, 2]"):
        l_y_ = L_y_
        l_x_ = L_x_

        _get_data_attr: "f32[2, 2]" = torch._C._autograd._get_data_attr(l_y_)

        _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None

        set_: "f32[2, 2]" = torch_Tensor_set_(l_x_, _get_data_attr);  _get_data_attr = None

        _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None

        _lower_version_count_by_1 = torch__dynamo_variables_builtin__lower_version_count_by_1(set_);  set_ = _lower_version_count_by_1 = None

        mul: "f32[2, 2]" = l_x_ * l_y_;  l_x_ = l_y_ = None
        return (mul,)
""",
        )

    # Note - this does not actually get captured in the graph yet.
    # The plan of record is to introduce a set_data op, entirely subsume the operation into a call_function
    # in the fx graph, and let aot_autograd handle it.
    def test_set_data_on_scoped_tensor(self):
        def fn(x):
            z = torch.zeros([4, 4])
            z.data = x.data
            if x.size() == z.size():
                return z * x
            else:
                return x

        x = torch.randn([5, 5])

        eager_result = fn(x)

        counter = CompileCounter()

        fn = torch._dynamo.optimize(counter, nopython=False)(fn)

        compile_result = fn(x)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 2)
        self.assertEqual(counter.op_count, 3)
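
    # Eager-semantics sketch (added for illustration): assigning to .data rebinds
    # the tensor's storage and metadata in place, which is what the size() checks
    # in the two set_data tests above observe.
    def test_set_data_eager_semantics(self):
        x = torch.randn([5, 5])
        y = torch.randn([2, 2])
        x.data = y.data
        # x now shares y's storage and takes on y's shape and values.
        self.assertEqual(x.size(), y.size())
        self.assertEqual(x, y)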

    def test_set_data_on_user_defined_class_input_tensor(self):
        class MyUserDefinedClass:
            def __init__(self, x, y):
                self.x = x
                self.y = y

            def do_some_setattr_stuff(self):
                # Note: x and y here are read from the enclosing test scope
                # (closure), not from self.
                self.z = x * y
                self.a = x + x
                return self.z * self.a

        x = torch.randn([5, 5])
        y = torch.randn([5, 5])
        mudc_1 = MyUserDefinedClass(x, y)

        eager_result = mudc_1.do_some_setattr_stuff()

        counter = CompileCounter()

        mudc_2 = MyUserDefinedClass(x, y)
        do_some_setattr_stuff = torch._dynamo.optimize(counter, nopython=True)(
            mudc_2.do_some_setattr_stuff
        )

        compile_result = do_some_setattr_stuff()
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 3)
        # Graph for reference
        # __compiled_fn_0 <eval_with_key>.0 opcode         name    target                   args                  kwargs
        # -------------  ------  -----------------------  --------------------  --------
        # placeholder    l_x_    L_x_                     ()                    {}
        # placeholder    l_y_    L_y_                     ()                    {}
        # call_function  mul     <built-in function mul>  (l_x_, l_y_)          {}
        # call_function  add     <built-in function add>  (l_x_, l_x_)          {}
        # call_function  mul_1   <built-in function mul>  (mul, add)            {}
        # output         output  output                   ((mul_1, mul, add),)  {}


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()