# Owner(s): ["oncall: pt2"]
import dataclasses
import functools

import torch
from torch import nn
from torch._dynamo import compiled_autograd
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import IS_MACOS, skipIfRocm, skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, requires_gpu


# Fake distributed
WORLD_SIZE = 2


def init_fake_distributed(device="cpu"):
    @torch.no_grad
    def all_gather(t):
        return torch.cat([t] * WORLD_SIZE, 0)

    @torch.no_grad
    def reduce_scatter(t):
        # clone since reduce_scatter input and output should not be aliases.
        return t.narrow(0, 0, t.size(0) // WORLD_SIZE).clone()

    def fw_pre_hook(mod, inp):
        if not compiled_autograd.compiled_autograd_enabled:
            # torch.ops.fsdp.set_ doesn't work well in eager mode, so use the slow copy_ path instead.
            mod.unsharded_weight.untyped_storage().resize_(
                mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
            )
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
                mod.unsharded_weight
            ):
                mod.unsharded_weight.copy_(all_gather(mod.sharded_weight))
        else:
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
                mod.unsharded_weight
            ):
                torch.ops.fsdp.set_(
                    mod.unsharded_weight, all_gather(mod.sharded_weight)
                )
        mod._parameters["weight"] = mod.unsharded_weight

    # Forward:
    #   mod.sharded_weight = local_shard (always)
    #   Before:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    #   After:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather

    def fw_post_hook(mod, inp, out):
        mod._parameters["weight"] = mod.sharded_weight
        mod.unsharded_weight.untyped_storage().resize_(0)

    def bw_pre_hook(mod, gO):
        if not compiled_autograd.compiled_autograd_enabled:
            # torch.ops.fsdp.set_ doesn't work well in eager mode, so use the slow copy_ path instead.
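            # Eager fallback: re-allocate the freed unsharded buffer to its full size, then
            # copy in the all-gathered shard under _unsafe_preserve_version_counter so the
            # in-place copy_ does not bump the tensor's autograd version counter.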
            mod.unsharded_weight.untyped_storage().resize_(
                mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
            )
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
                mod.unsharded_weight
            ):
                mod.unsharded_weight.copy_(all_gather(mod.sharded_weight))
        else:
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
                mod.unsharded_weight
            ):
                torch.ops.fsdp.set_(
                    mod.unsharded_weight, all_gather(mod.sharded_weight)
                )
        mod._parameters["weight"] = mod.unsharded_weight

    # Backward:
    #   mod.sharded_weight = local_shard (always)
    #   Before:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    #   After:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather

    def bw_post_hook(mod, gI, gO):
        grad = mod.weight.grad
        new_grad = reduce_scatter(grad)
        mod._parameters["weight"] = mod.sharded_weight
        mod.weight.grad = new_grad
        mod.unsharded_weight.untyped_storage().resize_(0)

    torch.manual_seed(1234)
    m = nn.Linear(20, 10, bias=False, device=device)

    # Mimics eager 1st iteration
    m.sharded_weight = nn.Parameter(reduce_scatter(m.weight))
    m.unsharded_weight = nn.Parameter(all_gather(m.sharded_weight))
    m.unsharded_weight.untyped_storage().resize_(0)

    m.register_full_backward_pre_hook(bw_pre_hook)
    m.register_full_backward_hook(bw_post_hook)
    m.register_forward_pre_hook(fw_pre_hook)
    m.register_forward_hook(fw_post_hook)
    return m, torch.rand(2, 20, requires_grad=True, device=device)


def init_module_bw_hooks(allow_eager):
    def bw_pre_hook(mod, gO):
        assert allow_eager or torch._dynamo.is_compiling()
        assert mod.weight.size() == (10, 10)
        mod.hook_count_pre.add_(1)
        return (torch.sin(gO[0] + 1.2),)

    def bw_post_hook(mod, gI, gO):
        assert allow_eager or torch._dynamo.is_compiling()
        assert mod.weight.size() == (10, 10)
        mod.hook_count_post.add_(1)
        return (torch.sin(gI[0] + 3.4),)

    torch.manual_seed(1234)
    m = nn.Linear(10, 10)
    m.hook_count_pre = torch.tensor(0)
    m.hook_count_post = torch.tensor(0)
    m.register_full_backward_pre_hook(bw_pre_hook)
    m.register_full_backward_hook(bw_post_hook)
    return m, torch.rand(2, 10, requires_grad=True)


def steps(m, inp):
    for _ in range(4):
        out = m(inp)
        out.sum().backward()
    return out


class DistributedPatternTests(TestCase):
    def test_intermediate_hook_with_closure(self):
        @dataclasses.dataclass
        class CustomObj:
            val: torch.Tensor

        def fn(x, obj):
            y = x.sin()
            closure_var = y + 1
            y.register_hook(lambda grad: grad + obj.val + closure_var)
            z = y.sin()
            return z

        opt = torch.compile(fn, fullgraph=True)

        obj1 = CustomObj(torch.tensor(88))
        obj2 = CustomObj(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)
        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd.enable(functools.partial(torch.compile, fullgraph=True)):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    @torch.no_grad()
    def _test_storage_resize_zero(self, device):
        @torch.compile(fullgraph=True)
        def fn(x):
            y = torch.sin(x)
            x.untyped_storage().resize_(0)
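            # Freeing x's storage mid-graph mimics FSDP releasing an unsharded parameter
            # after use; only the saved intermediate y is read afterwards.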
            return torch.cos(y)

        x = torch.randn(10, device=device)
        expected = torch.cos(torch.sin(x))
        y = fn(x)
        self.assertEqual(y, expected)
        self.assertEqual(x.untyped_storage().size(), 0)

    def test_storage_resize_zero_cpu(self):
        self._test_storage_resize_zero("cpu")

    @skipIfRocm
    @requires_gpu()
    def test_storage_resize_zero_gpu(self):
        self._test_storage_resize_zero(GPU_TYPE)

    @torch.no_grad()
    def _test_storage_resize_nonzero(self, device):
        @torch.compile(fullgraph=True)
        def fn(x, out):
            y = torch.sin(x)
            assert out.untyped_storage().size() == 0
            out.untyped_storage().resize_(x.untyped_storage().size())
            out.copy_(y.cos())

        x = torch.randn(10, device=device)
        out = torch.randn(10, device=device)
        expected = torch.cos(torch.sin(x))
        out.untyped_storage().resize_(0)
        fn(x, out)
        self.assertEqual(out.untyped_storage().size(), x.untyped_storage().size())
        self.assertEqual(out, expected)

    def test_storage_resize_nonzero_cpu(self):
        self._test_storage_resize_nonzero("cpu")

    @skipIfRocm
    @requires_gpu()
    def test_storage_resize_nonzero_gpu(self):
        self._test_storage_resize_nonzero(GPU_TYPE)

    @torch.no_grad()
    def test_unsafe_set_version_counter1(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(w, x):
            x = x.sin()
            v = w._version
            w.copy_(x + 1)
            torch._C._autograd._unsafe_set_version_counter(w, v)
            return w, v

        for v in (3, 0, 1):
            w1 = torch.randn(16)
            for i in range(v):
                w1.fill_(i)  # bump w1._version
            self.assertEqual(w1._version, v)
            x1 = torch.randn(16)
            w2, v2 = fn(w1, x1)

            self.assertIs(w1, w2)
            self.assertEqual(w1, x1.sin() + 1)
            self.assertEqual(v2, v)
            self.assertEqual(w1._version, v)
            self.assertEqual(cnt.frame_count, 1)

    def test_unsafe_set_version_counter2(self):
        @torch.compile(backend="inductor", fullgraph=True)
        def fn(w, x):
            r = w.sin()
            with torch.no_grad():
                v = w._version
                w.copy_(x)
                torch._C._autograd._unsafe_set_version_counter(w, v)
            return r

        w1 = torch.randn(1, requires_grad=True)
        x1 = torch.randn(1)
        expected_r1 = w1.detach().sin()

        r1 = fn(w1, x1)
        r1.backward()
        self.assertEqual(r1, expected_r1)
        self.assertEqual(w1, x1)
        self.assertEqual(w1.grad, x1.cos())

    @torch.no_grad()
    def test_unsafe_preserve_version_counter1(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(w, x):
            x = x.sin()
            with torch.autograd._unsafe_preserve_version_counter(w):
                w.copy_(x + 1)
            return w

        w1 = torch.randn(16).fill_(0).fill_(1)
        x1 = torch.randn(16)
        v1 = w1._version
        w2 = fn(w1, x1)
        v2 = w1._version

        self.assertIs(w1, w2)
        self.assertEqual(w1, x1.sin() + 1)
        self.assertEqual(v1, v2)

    def test_unsafe_preserve_version_counter2(self):
        @torch.compile(backend="inductor", fullgraph=True)
        def fn(w, x):
            r = w.sin()
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(w):
                w.copy_(x)
            return r

        w1 = torch.randn(1, requires_grad=True)
        x1 = torch.randn(1)
        expected_r1 = w1.detach().sin()

        r1 = fn(w1, x1)
        r1.backward()
        self.assertEqual(r1, expected_r1)
        self.assertEqual(w1, x1)
        self.assertEqual(w1.grad, x1.cos())

    def test_module_backward_hooks_eager(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

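        # Repeat the same steps under compiled autograd, compiling the forward with
        # fw_cnt and the backward with bw_cnt so the two graphs are counted separately.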
        m2, inp2 = init_module_bw_hooks(False)
        fw_cnt = CompileCounter()
        bw_cnt = CompileCounter()
        with compiled_autograd.enable(torch.compile(backend=bw_cnt, fullgraph=True)):
            m2 = torch.compile(m2, backend=fw_cnt, fullgraph=True)
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

        self.assertEqual(fw_cnt.frame_count, 1)
        self.assertEqual(fw_cnt.op_count, 5)
        self.assertEqual(bw_cnt.frame_count, 2)  # grad=None and grad!=None
        self.assertEqual(bw_cnt.op_count, 48)

    def test_module_backward_hooks_aot(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(True)
        m2 = torch.compile(m2, backend="aot_eager", fullgraph=True)
        with compiled_autograd.enable(lambda gm: gm):
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

    def test_module_backward_hooks_inductor(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(False)
        m2 = torch.compile(m2, fullgraph=True)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

    def test_module_backward_hooks_multi_layers(self):
        a1, inp1 = init_module_bw_hooks(True)
        b1, _ = init_module_bw_hooks(True)
        out1 = steps(torch.nn.Sequential(a1, b1), inp1)

        a2, inp2 = init_module_bw_hooks(False)
        b2, _ = init_module_bw_hooks(False)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(
                torch.compile(torch.nn.Sequential(a2, b2), fullgraph=True), inp2
            )

        self.assertEqual(a1.hook_count_pre, a2.hook_count_pre)
        self.assertEqual(a1.hook_count_post, a2.hook_count_post)
        self.assertEqual(b1.hook_count_pre, b2.hook_count_pre)
        self.assertEqual(b1.hook_count_post, b2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(a1.weight.grad, a2.weight.grad)
        self.assertEqual(a1.bias.grad, a2.bias.grad)
        self.assertEqual(b1.weight.grad, b2.weight.grad)
        self.assertEqual(b1.bias.grad, b2.bias.grad)

    # TODO(jansel): support bw hooks with graph break

    def _assert_same_grad(self, a, b):
        self.assertEqual(type(a), type(b))
        self.assertEqual(a, b)
        self.assertEqual(a.grad, b.grad)
        self.assertEqual(a.requires_grad, b.requires_grad)

    def test_nn_param_return1(self):
        def fn(x):
            p = torch.nn.Parameter(x)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
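        # The Parameter constructed inside the compiled function must match eager:
        # _assert_same_grad checks type, value, .grad, and requires_grad.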
        self._assert_same_grad(p1, p2)

    def test_nn_param_return2(self):
        def fn(x):
            p = torch.nn.Parameter(x, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return3(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return4(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    @torch._functorch.config.patch(recompute_views=True)
    def test_fake_distributed_aot_eager(self):
        m1, inp1 = init_fake_distributed()
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed()
        m2 = torch.compile(m2, backend="aot_eager", fullgraph=True)
        bw_cnt = CompileCounter()
        with compiled_autograd.enable(torch.compile(backend=bw_cnt, fullgraph=True)):
            out2 = steps(m2, inp2)

        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)
        # Recompile on grad==None/grad!=None
        self.assertEqual(bw_cnt.frame_count, 2)

    @skipIfRocm
    @skipIfXpu
    @requires_gpu()
    @torch._functorch.config.patch(recompute_views=True)
    def test_fake_distributed_inductor(self):
        # TODO: fix .set_ lowering in CPU inductor, and enable the CPU test.
        m1, inp1 = init_fake_distributed(GPU_TYPE)
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed(GPU_TYPE)
        m2 = torch.compile(m2, fullgraph=True)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(m2, inp2)

        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)


if __name__ == "__main__":
    if HAS_CPU and not IS_MACOS:
        run_tests(needs="filelock")