# Owner(s): ["module: functionalization"]
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.fx.passes.reinplace import reinplace
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._dynamo.source import ConstantSource
from torch.fx.experimental.sym_node import SymNode

try:
    from functorch.experimental import functionalize
    HAS_FUNCTIONALIZATION = True
except Exception:
    HAS_FUNCTIONALIZATION = False

class TestReinplacePass(TestCase):

    def test_reinplace_basic(self):
        # Basic test: the out-of-place add() call should be converted
        # into add_().
        def f(x):
            a = x.clone()
            b = a.add(1)
            return b

        inpt = torch.ones(2)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1):
    clone = torch.ops.aten.clone.default(x_1); x_1 = None
    add = torch.ops.aten.add_.Tensor(clone, 1); add = None
    return clone
    """)

    def test_reinplace_with_view(self):
        def f(x):
            a = x.clone()
            a_view = a.view(-1)
            # We shouldn't re-inplace the first add(), because an alias of a
            # is re-used later in the program.
            b = a.add(1)
            # The second add() is fine to re-inplace.
            c = a_view.add(1)
            return c

        inpt = torch.ones(2)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1):
    clone = torch.ops.aten.clone.default(x_1); x_1 = None
    view = torch.ops.aten.view.default(clone, [-1])
    add = torch.ops.aten.add.Tensor(clone, 1); clone = add = None
    add_1 = torch.ops.aten.add_.Tensor(view, 1); add_1 = None
    return view
    """)

    def test_reinplace_different_metadata(self):
        def f(a_):
            a = a_.clone()
            b = a + 1
            # We shouldn't try to re-inplace the .ge() call, because that would
            # require changing "b"'s metadata (from a float tensor to a bool tensor).
            c = torch.ge(b, a)
            return c
        inpt = torch.ones(4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # The .ge() should not be reinplaced.
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1); a__1 = None
    add = torch.ops.aten.add.Tensor(clone, 1)
    ge = torch.ops.aten.ge.Tensor(add, clone); add = clone = None
    return ge
    """)

    def test_reinplace_overlapping_memory(self):
        def f(a_):
            a = a_.clone()
            b = a.expand(4, 4)
            # Can't reinplace because b has overlapping memory.
            c = b.add(1)
            return c
        inpt = torch.ones(1)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1); a__1 = None
    expand = torch.ops.aten.expand.default(clone, [4, 4]); clone = None
    add = torch.ops.aten.add.Tensor(expand, 1); expand = None
    return add
    """)
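
    # Editorial addition (a minimal sketch, not part of the original suite; the
    # test name is ours): the test above relies on expand() producing a view
    # whose dimensions have stride 0, so every logical element of "b" aliases
    # the single storage element of "a". Writing through such a view in place
    # would not be equivalent to the out-of-place add(), which is why the pass
    # leaves that add() alone.
    def test_expand_overlapping_memory_sketch(self):
        a = torch.ones(1)
        b = a.expand(4, 4)
        # Both dimensions of the expanded view have stride 0: all 16 logical
        # elements of b map onto the one storage element of a.
        self.assertEqual(b.stride(), (0, 0))
        self.assertEqual(b.storage_offset(), 0)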

    # This test won't actually run in CI, because it requires functionalize() from functorch.
    # I'm planning on testing more comprehensively with torchbench models,
    # but we can make this testing better once functorch moves into pytorch/pytorch.
    def test_reinplace_scatter_op(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            e = a.view(-1)
            b = a.view(-1)
            c = b[0]
            d = c.view(-1)
            d.add_(1)
            return a + e

        if not HAS_FUNCTIONALIZATION:
            return
        inpt = torch.ones(4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # NOTE: one slight pessimization here is the fact that
        # there are a bunch of redundant views in the graph.
        # Technically, half of these views are duplicates that we could de-dup.
        # This shouldn't really hurt performance though, since creating an extra view
        # is effectively just moving some metadata around (and allocating a new TensorImpl).
        # We can/should update the pass in the future to clean this up.
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1); a__1 = None
    view = torch.ops.aten.view.default(clone, [-1]); view = None
    view_1 = torch.ops.aten.view.default(clone, [-1])
    select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None
    view_2 = torch.ops.aten.view.default(select, [-1]); select = None
    add = torch.ops.aten.add_.Tensor(view_2, 1); add = None
    view_3 = torch.ops.aten.view.default(clone, [-1]); clone = None
    select_1 = torch.ops.aten.select.int(view_3, 0, 0); select_1 = None
    view_4 = torch.ops.aten.view.default(view_2, []); view_2 = view_4 = None
    view_5 = torch.ops.aten.view.default(view_3, [4]); view_3 = None
    view_6 = torch.ops.aten.view.default(view_5, [-1])
    select_2 = torch.ops.aten.select.int(view_6, 0, 0); view_6 = None
    view_7 = torch.ops.aten.view.default(select_2, [-1]); select_2 = view_7 = None
    view_8 = torch.ops.aten.view.default(view_5, [-1])
    add_1 = torch.ops.aten.add_.Tensor(view_5, view_8); view_8 = add_1 = None
    return view_5
    """)

    def test_reinplace_scatter_twice(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c.add_(1)
            return a

        if not HAS_FUNCTIONALIZATION:
            return

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
    add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = add = None
    slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select_2 = torch.ops.aten.select.int(slice_2, 1, 1); slice_2 = select_2 = None
    slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select_3 = torch.ops.aten.select.int(slice_3, 1, 1); slice_3 = None
    select_4 = torch.ops.aten.select.int(select_3, 0, 1); select_3 = select_4 = None
    return clone
    """)
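
    # Editorial addition (a minimal sketch, not part of the original suite; the
    # test name is ours): the next two tests reason about size/stride/storage
    # offset, so it helps to spell out the arithmetic for a contiguous 4x4
    # input. b = a[:, 1] is a view of size (4,), stride (4,), storage offset 1,
    # so a.as_strided((4,), (4,), 1) is an exact alias of it, and c = b[1]
    # lives at storage offset 1 + 1 * 4 = 5.
    def test_select_scatter_alias_arithmetic_sketch(self):
        a = torch.ones(4, 4)
        b = a[:, 1]
        c = b[1]
        good_mirror_of_b = a.as_strided((4,), (4,), 1)
        # good_mirror_of_b has the same size/stride/storage_offset as b.
        self.assertEqual(b.size(), good_mirror_of_b.size())
        self.assertEqual(b.stride(), good_mirror_of_b.stride())
        self.assertEqual(b.storage_offset(), good_mirror_of_b.storage_offset())
        # c and good_mirror_of_b.select(0, 1) point at the same storage element,
        # which is the kind of check the reinplacing logic performs below when
        # deciding whether a select_scatter() can be re-inplaced.
        self.assertEqual(c.storage_offset(), 5)
        self.assertEqual(good_mirror_of_b.select(0, 1).storage_offset(), 5)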

    def test_reinplace_scatter_twice_with_different_view_op_valid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4,), (4,), 1)
            # good_mirror_of_b points to the same region of memory as b,
            # and the scatter op below tries to scatter c_updated into the same
            # region that c currently occupies.
            # The reinplacing logic checks this by confirming that
            #   c_updated
            #   good_mirror_of_b.select(0, 1)
            # have the same size/stride/storage_offset.
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
    add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = add = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None
    return as_strided
    """)

    # Test example with a scatter op whose base tensor has the same
    # size/stride/storage offset as "b" (even though it is a different view),
    # but where the scatter writes to a different slice than the one "c"
    # occupies, making it invalid to re-inplace the add().
    def test_reinplace_scatter_twice_with_different_view_op_invalid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4,), (4,), 1)
            # The first arg to select_scatter is an equivalent view to b.
            # However, the select_scatter call below tries to put c_updated
            # into a different slice of "b" than the one "c" currently occupies.
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 0)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
    add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None
    select_int = torch.ops.aten.select.int(as_strided, 0, 0)
    copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = copy__default = None
    return as_strided
    """)  # noqa: B950
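
    # Editorial note on the expected graph above: the add() stays out-of-place
    # (its result does not alias the slice being scattered into), but the
    # select_scatter() is still rewritten into a select() + copy_() onto the
    # as_strided view of the clone, rather than being kept as a functional
    # scatter op.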

    def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            bad_mirror_of_b = a.as_strided((4,), (4,), 0)
            # The first arg to select_scatter points to a different region of
            # memory than b (its storage offset is 0 instead of 1).
            # This makes it invalid to re-inplace the add().
            b_updated = torch.select_scatter(bad_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        # self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
    add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 0); clone = None
    select_int = torch.ops.aten.select.int(as_strided, 0, 1)
    copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = copy__default = None
    return as_strided
    """)  # noqa: B950

    def test_out_node_updated(self):
        def f():
            x = torch.zeros(2, 2)
            y = x.diagonal()
            y_updated = y.add(1)
            z = torch.diagonal_scatter(x, y_updated)
            # reinplace needs to know to replace the output [z] with [x]
            return [z]

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
    diagonal = torch.ops.aten.diagonal.default(zeros)
    add = torch.ops.aten.add_.Tensor(diagonal, 1); diagonal = add = None
    return [zeros]
    """)

    def test_reinplace_index_mutation(self):
        def f():
            a = torch.zeros(4, 4, 4)
            a[:, 2:] = torch.ones(4, 2, 4)
            return a

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807); slice_1 = None
    copy = torch.ops.aten.copy_.default(slice_2, ones); slice_2 = ones = copy = None
    slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807); slice_3 = None
    slice_4 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_5 = torch.ops.aten.slice.Tensor(slice_4, 1, 2, 9223372036854775807); slice_4 = slice_5 = None
    return zeros
    """)
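
    # Editorial note: the test below constructs a SymInt by hand (via ShapeEnv
    # and SymNode) instead of tracing through dynamo; that is enough to check
    # that the pass tolerates a symbolic scalar input alongside the tensor input.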

    def test_reinplace_sym_input(self):
        # Symbolic input test: the out-of-place add() call should be converted
        # into add_(), and the symbolic (SymInt) index input should not cause
        # any errors.
        def f(x, index):
            a = torch.select(x, 0, index)
            clone = a.clone()
            b = clone.add(1)
            return b

        x = torch.randn((4, 8, 16, 16), requires_grad=False)
        index = 2
        shape_env = ShapeEnv()
        symbol = shape_env.create_symbol(index, source=ConstantSource(
            f"__testing_only{len(shape_env.var_to_val)}"))
        sym_index = torch.SymInt(SymNode(symbol, shape_env, int, hint=index))

        inpt = [x, sym_index]
        f2 = reinplace(make_fx(f)(*inpt), *inpt)

        real_inpt = [x, index]
        expected_out = f(*real_inpt)
        actual_out = f2(*real_inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1, index_1):
    select = torch.ops.aten.select.int(x_1, 0, index_1); x_1 = index_1 = None
    clone = torch.ops.aten.clone.default(select); select = None
    add = torch.ops.aten.add_.Tensor(clone, 1); add = None
    return clone
    """)


if __name__ == '__main__':
    run_tests()