1# Owner(s): ["module: dynamo"] 2# flake8: noqa 3import torch 4import torch._dynamo 5import torch._dynamo.test_case 6import torch._dynamo.testing 7from torch._dynamo.testing import ( 8 CompileCounter, 9 CompileCounterWithBackend, 10 EagerAndRecordGraphs, 11 normalize_gm, 12) 13 14 15class TestInputAttrTracking(torch._dynamo.test_case.TestCase): 16 def test_tensor_property_on_tensor(self): 17 def fn(x): 18 return x * x.y 19 20 x_ = torch.randn([2, 2]) 21 y_ = torch.randn([2, 2]) 22 x_.y = y_ 23 24 eager_result = fn(x_) 25 26 graph = None 27 28 def grab_graph_backend(gm, inps): 29 nonlocal graph 30 graph = gm 31 return gm 32 33 fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn) 34 compile_result = fn(x_) 35 self.assertEqual(eager_result, compile_result) 36 37 placeholder_cnt = 0 38 for node in graph.graph.nodes: 39 if node.op == "placeholder": 40 placeholder_cnt += 1 41 42 # We want to be very sure that this lifts y to inputs! 43 self.assertEqual(placeholder_cnt, 2) 44 45 def test_tensor_property_assigned_on_tensor(self): 46 def fn(x, y): 47 x.y = y 48 return x * x.y 49 50 x_ = torch.randn([2, 2]) 51 y_ = torch.randn([2, 2]) 52 53 eager_result = fn(x_, y_) 54 55 graph = None 56 57 def grab_graph_backend(gm, inps): 58 nonlocal graph 59 graph = gm 60 return gm 61 62 fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn) 63 compile_result = fn(x_, y_) 64 self.assertEqual(eager_result, compile_result) 65 66 placeholder_cnt = 0 67 for node in graph.graph.nodes: 68 if node.op == "placeholder": 69 placeholder_cnt += 1 70 71 # y is already an input 72 self.assertEqual(placeholder_cnt, 2) 73 74 def test_const_property_on_tensor(self): 75 def fn(x): 76 return x * x.y 77 78 x_ = torch.randn([2, 2]) 79 y_ = 4 80 x_.y = y_ 81 82 eager_result = fn(x_) 83 84 graph = None 85 86 def grab_graph_backend(gm, inps): 87 nonlocal graph 88 graph = gm 89 return gm 90 91 fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn) 92 compile_result = fn(x_) 93 self.assertEqual(eager_result, compile_result) 94 95 placeholder_cnt = 0 96 for node in graph.graph.nodes: 97 if node.op == "placeholder": 98 placeholder_cnt += 1 99 100 # We want to be very sure that this does not lifts y to inputs, as its a const 101 self.assertEqual(placeholder_cnt, 1) 102 103 def test_const_property_assigned_on_tensor(self): 104 def fn(x, y): 105 x.y = y 106 return x * x.y 107 108 x_ = torch.randn([2, 2]) 109 y_ = 4 110 111 eager_result = fn(x_, y_) 112 113 fn = torch._dynamo.optimize("eager", nopython=True)(fn) 114 compile_result = fn(x_, y_) 115 self.assertEqual(eager_result, compile_result) 116 117 def test_guards_correctly_property_assigned_on_tensor_type_change(self): 118 def fn(x, y): 119 x.y = y 120 return x * x.y 121 122 x_ = torch.randn([2, 2]) 123 124 fn = torch._dynamo.optimize("eager", nopython=True)(fn) 125 compile_result_const = fn(x_, 4) 126 self.assertEqual(compile_result_const, x_ * 4) 127 128 y = torch.randn([2, 2]) 129 compile_result_tensor = fn(x_, y) 130 self.assertEqual(compile_result_tensor, x_ * y) 131 132 def test_guards_correctly_property_assigned_on_tensor_type_change_inductor(self): 133 def fn(x, y): 134 x.y = y 135 return x * x.y 136 137 x_ = torch.randn([2, 2]) 138 139 fn = torch._dynamo.optimize("inductor", nopython=True)(fn) 140 compile_result_const = fn(x_, 4) 141 self.assertEqual(compile_result_const, x_ * 4) 142 143 y = torch.randn([2, 2]) 144 compile_result_tensor = fn(x_, y) 145 self.assertEqual(compile_result_tensor, x_ * y) 146 147 def test_complex_attr_access_without_graph_breaks(self): 148 def fn(x, y, z): 149 for t in x: 150 t.y = y 151 t.z = y * z 152 153 new_y = 1 154 new_z = 1 155 for t in x: 156 new_y = t.y * new_y 157 new_z = t.z * new_z 158 159 return new_y, new_z 160 161 x_0 = torch.randn([2, 2]) 162 x_1 = torch.randn([2, 2]) 163 x_2 = torch.randn([2, 2]) 164 x = [x_0, x_1, x_2] 165 166 y = torch.randn([2, 2]) 167 z = 5 168 169 eager_result = fn(x, y, z) 170 171 counter = CompileCounter() 172 fn = torch._dynamo.optimize(counter, nopython=True)(fn) 173 174 compile_result = fn(x, y, z) 175 self.assertEqual(compile_result, eager_result) 176 self.assertEqual(counter.frame_count, 1) 177 self.assertEqual(counter.op_count, 9) 178 # Graph for reference 179 # ------------- ------ ----------------------- ------------------------------------ -------- 180 # placeholder l_y_ L_y_ () {} 181 # call_function mul <built-in function mul> (l_y_, 5) {} 182 # call_function mul_1 <built-in function mul> (l_y_, 5) {} 183 # call_function mul_2 <built-in function mul> (l_y_, 5) {} 184 # call_function mul_3 <built-in function mul> (l_y_, 1) {} 185 # call_function mul_4 <built-in function mul> (mul, 1) {} 186 # call_function mul_5 <built-in function mul> (l_y_, mul_3) {} 187 # call_function mul_6 <built-in function mul> (mul_1, mul_4) {} 188 # call_function mul_7 <built-in function mul> (l_y_, mul_5) {} 189 # call_function mul_8 <built-in function mul> (mul_2, mul_6) {} 190 # output output output ((mul_7, mul_8, mul, mul_1, mul_2),) {} 191 192 def test_complex_attr_access_with_graph_breaks(self): 193 def fn(x, y, z): 194 for t in x: 195 t.y = y 196 t.z = y * z 197 198 print("Break!") 199 200 new_y = 1 201 new_z = 1 202 for t in x: 203 new_y = t.y * new_y 204 new_z = t.z * new_z 205 206 return new_y, new_z 207 208 x_0 = torch.randn([2, 2]) 209 x_1 = torch.randn([2, 2]) 210 x_2 = torch.randn([2, 2]) 211 x = [x_0, x_1, x_2] 212 213 y = torch.randn([2, 2]) 214 z = 5 215 216 eager_result = fn(x, y, z) 217 218 counter = CompileCounter() 219 fn = torch._dynamo.optimize(counter, nopython=False)(fn) 220 221 compile_result = fn(x, y, z) 222 self.assertEqual(compile_result, eager_result) 223 self.assertEqual(counter.frame_count, 2) 224 self.assertEqual(counter.op_count, 9) 225 # Graph for reference 226 # ------------- ------ ----------------------- ---------------------- -------- 227 # placeholder l_y_ L_y_ () {} 228 # call_function mul <built-in function mul> (l_y_, 5) {} 229 # call_function mul_1 <built-in function mul> (l_y_, 5) {} 230 # call_function mul_2 <built-in function mul> (l_y_, 5) {} 231 # output output output ((mul, mul_1, mul_2),) {} 232 # [GRAPH BREAK!] 233 # ------------- ------- ----------------------- ----------------- -------- 234 # placeholder l_x_0_y L_x_0_y () {} 235 # placeholder l_x_0_z L_x_0_z () {} 236 # placeholder l_x_1_y L_x_1_y () {} 237 # placeholder l_x_1_z L_x_1_z () {} 238 # placeholder l_x_2_y L_x_2_y () {} 239 # placeholder l_x_2_z L_x_2_z () {} 240 # call_function mul <built-in function mul> (l_x_0_y, 1) {} 241 # call_function mul_1 <built-in function mul> (l_x_0_z, 1) {} 242 # call_function mul_2 <built-in function mul> (l_x_1_y, mul) {} 243 # call_function mul_3 <built-in function mul> (l_x_1_z, mul_1) {} 244 # call_function mul_4 <built-in function mul> (l_x_2_y, mul_2) {} 245 # call_function mul_5 <built-in function mul> (l_x_2_z, mul_3) {} 246 # output output output ((mul_4, mul_5),) {} 247 248 def test_complex_attr_access_with_inline_reconstruct(self): 249 def inline_test_fn(x, y, z): 250 print("f") 251 return x.a + y.a + z.a 252 253 def fn(x, y, z): 254 x.a = 1 255 y.a = 2 256 z.a = 3 257 258 mult = inline_test_fn(x, y, z) 259 y = y * mult 260 x = x * mult 261 return x, y 262 263 x = torch.randn([2, 2]) 264 y = torch.randn([2, 2]) 265 z = torch.randn([2, 2]) 266 267 eager_result = fn(x, y, z) 268 269 counter = CompileCounter() 270 271 fn = torch._dynamo.optimize(counter, nopython=False)(fn) 272 273 compile_result = fn(x, y, z) 274 self.assertEqual(compile_result, eager_result) 275 self.assertEqual(counter.frame_count, 1) 276 self.assertEqual(counter.op_count, 2) 277 # Graph for reference 278 # __compiled_fn_2 <eval_with_key>.0 opcode name target args kwargs 279 # ------------- ------ ----------------------- --------------- -------- 280 # placeholder l_x_ L_x_ () {} 281 # placeholder l_y_ L_y_ () {} 282 # call_function mul <built-in function mul> (l_y_, 6) {} 283 # call_function mul_1 <built-in function mul> (l_x_, 6) {} 284 # output output output ((mul_1, mul),) {} 285 286 def test_set_data_on_input_tensor(self): 287 def fn(x, y): 288 x.data = y.data 289 if x.size() == y.size(): 290 return x * y 291 else: 292 return y * y 293 294 x = torch.randn([5, 5]) 295 y = torch.randn([2, 2]) 296 297 eager_result = fn(x, y) 298 299 eager_and_record = EagerAndRecordGraphs() 300 301 counter = CompileCounterWithBackend(eager_and_record) 302 303 fn = torch._dynamo.optimize(counter, nopython=True)(fn) 304 305 compile_result = fn(x, y) 306 307 graph = eager_and_record.graphs[0] 308 actual = normalize_gm(graph.print_readable(False)) 309 310 self.assertEqual(compile_result, eager_result) 311 self.assertEqual(counter.frame_count, 1) 312 self.assertEqual(counter.op_count, 6) 313 self.assertExpectedInline( 314 actual, 315 """\ 316class GraphModule(torch.nn.Module): 317 def forward(self, L_y_: "f32[2, 2]", L_x_: "f32[2, 2]"): 318 l_y_ = L_y_ 319 l_x_ = L_x_ 320 321 _get_data_attr: "f32[2, 2]" = torch._C._autograd._get_data_attr(l_y_) 322 323 _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None 324 325 set_: "f32[2, 2]" = torch_Tensor_set_(l_x_, _get_data_attr); _get_data_attr = None 326 327 _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None 328 329 _lower_version_count_by_1 = torch__dynamo_variables_builtin__lower_version_count_by_1(set_); set_ = _lower_version_count_by_1 = None 330 331 mul: "f32[2, 2]" = l_x_ * l_y_; l_x_ = l_y_ = None 332 return (mul,) 333""", 334 ) 335 336 # Note - this does not actually get captured in the graph yet. 337 # The plan of record is to introduce a set_data op, entirely subsume the operation into a call_function 338 # in the fx graph, and let aot_autograd handle it. 339 def test_set_data_on_scoped_tensor(self): 340 def fn(x): 341 z = torch.zeros([4, 4]) 342 z.data = x.data 343 if x.size() == z.size(): 344 return z * x 345 else: 346 return x 347 348 x = torch.randn([5, 5]) 349 350 eager_result = fn(x) 351 352 counter = CompileCounter() 353 354 fn = torch._dynamo.optimize(counter, nopython=False)(fn) 355 356 compile_result = fn(x) 357 self.assertEqual(compile_result, eager_result) 358 self.assertEqual(counter.frame_count, 2) 359 self.assertEqual(counter.op_count, 3) 360 361 def test_set_data_on_user_defined_class_input_tensor(self): 362 class MyUserDefinedClass: 363 def __init__(self, x, y): 364 self.x = x 365 self.y = y 366 367 def do_some_setattr_stuff(self): 368 self.z = x * y 369 self.a = x + x 370 return self.z * self.a 371 372 x = torch.randn([5, 5]) 373 y = torch.randn([5, 5]) 374 mudc_1 = MyUserDefinedClass(x, y) 375 376 eager_result = mudc_1.do_some_setattr_stuff() 377 378 counter = CompileCounter() 379 380 mudc_2 = MyUserDefinedClass(x, y) 381 do_some_setattr_stuff = torch._dynamo.optimize(counter, nopython=True)( 382 mudc_2.do_some_setattr_stuff 383 ) 384 385 compile_result = do_some_setattr_stuff() 386 self.assertEqual(compile_result, eager_result) 387 self.assertEqual(counter.frame_count, 1) 388 self.assertEqual(counter.op_count, 3) 389 # Graph for reference 390 # __compiled_fn_0 <eval_with_key>.0 opcode name target args kwargs 391 # ------------- ------ ----------------------- -------------------- -------- 392 # placeholder l_x_ L_x_ () {} 393 # placeholder l_y_ L_y_ () {} 394 # call_function mul <built-in function mul> (l_x_, l_y_) {} 395 # call_function add <built-in function add> (l_x_, l_x_) {} 396 # call_function mul_1 <built-in function mul> (mul, add) {} 397 # output output output ((mul_1, mul, add),) {} 398 399 400if __name__ == "__main__": 401 from torch._dynamo.test_case import run_tests 402 403 run_tests() 404