# Owner(s): ["module: dynamo"]
import contextlib
import copy
import functools
import random
import unittest
from contextlib import contextmanager
from datetime import timedelta
from io import StringIO
from typing import List
from unittest.mock import patch

import numpy as np

import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
import torch.distributed as dist
import torch.optim as optim
from torch import nn
from torch._C import FileCheck
from torch._dynamo import config
from torch._dynamo.backends.distributed import DDPOptimizer
from torch._dynamo.comptime import comptime
from torch._dynamo.testing import collect_results
from torch._dynamo.utils import same
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.distributed._functional_collectives import _maybe_wrap_tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_distributed import (
    _dynamo_dist_per_rank_init,
    DynamoDistributedMultiProcTestCase,
    DynamoDistributedSingleProcTestCase,
    import_transformers_or_skip,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import requires_cuda
from torch.utils._triton import has_triton


def reset_rng_state():
    torch.manual_seed(1337)
    random.seed(1337)
    np.random.seed(1337)


def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


class ToyModel(nn.Module):
    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
        super().__init__()
        self.ctx_manager = ctx_manager
        self.net = nn.Sequential(
            *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
        )

    def forward(self, inputs):
        if self.ctx_manager is not None:
            with self.ctx_manager():
                return self.net(inputs)
        else:
            return self.net(inputs)


def get_model(
    device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
):
    m = ToyModel(
        in_feat=in_feat,
        hidden_feat=hidden_feat,
        out_feat=out_feat,
        ctx_manager=ctx_manager,
    ).to(device)
    m.apply(init_weights)
    inputs = torch.rand(bsz, in_feat).to(device)
    outputs = m(inputs)
    return m, inputs, outputs


class MutatingModel(nn.Module):
    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
        super().__init__()
        self.ctx_manager = ctx_manager
        self.net = nn.Sequential(
            *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
        )
        self.state = 1

    def forward(self, inputs):
        self.state = 2
        return self.net(inputs) * self.state


def get_mutating_model(
    device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
):
    m = MutatingModel(
        in_feat=in_feat,
        hidden_feat=hidden_feat,
        out_feat=out_feat,
        ctx_manager=ctx_manager,
    ).to(device)
    m.apply(init_weights)
    inputs = torch.rand(bsz, in_feat).to(device)
    outputs = m(inputs)
    return m, inputs, outputs


class ToyInnerModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = [nn.Linear(100, 100), nn.Linear(100, 100)]
        self.layers = nn.Sequential(*self.layers)

    def forward(self, inputs):
        return self.layers(inputs)


class ToyOuterModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.layers = [ToyInnerModel().to(device) for _ in range(2)]
        self.layers = nn.Sequential(
            self.layers[0], nn.ReLU(), self.layers[1], nn.ReLU()
        )

    def forward(self, inputs):
        return self.layers(inputs)


def get_toy_model_for_activation_checkpointing(device):
    m = ToyOuterModel(device).to(device)
    m.apply(init_weights)
    inputs = torch.rand(100, 100).to(device)
    return m, inputs


def find_first_node(gm, func):
    for node in gm.graph.nodes:
        if node.target is func:
            return node
    return None


def apply_fsdp_with_checkpointing(
    model, wrap_policy, checkpoint_policy, use_activation_checkpointing=True
):
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing,
        checkpoint_wrapper,
        CheckpointImpl,
    )

    model = FSDP(
        copy.deepcopy(model), auto_wrap_policy=wrap_policy, use_orig_params=True
    )
    if use_activation_checkpointing:
        checkpoint_wrapper_fn = functools.partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=checkpoint_wrapper_fn,
            check_fn=checkpoint_policy,
        )
    return model


def get_custom_model(device):
    class MyCustomLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter(torch.randn(512, 512))

        def forward(self, x):
            tmp = torch.mm(x, self.weight.t())
            # test an edge case where torch.where.scalar was decomposed to aten.where.self(tensor, tensor, tensor)
            # and the tensors T(0.4) and T(0.5) were not wrapped in FakeTensors during DDPOptimizer compilation
            return tmp + torch.where(tmp < 0.5, 0.3, 0.6)

    class MyLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(512, 512)

        def forward(self, x):
            return self.linear(x)

    class MyModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            mods = [
                (MyLinear(), torch.nn.ReLU()),
                # sandwich the custom in the middle so it comes before and after
                (MyCustomLinear(), torch.nn.ReLU()),
                (MyLinear(), torch.nn.ReLU()),
            ]
            self.seq = torch.nn.Sequential(*[x for items in mods for x in items])

        def forward(self, x, y):
            # test special case where the 0th bucket (layers close to graph input) is at capacity, which would
            # trigger a new bucket, but there are only trivial ops without parameters to put into the new bucket.
            # optimize this case by fusing that 'empty bucket' back together with the previous full one
            return self.seq(x + y)

    m = MyModule().to(device)
    m.apply(init_weights)
    inputs = torch.rand((512, 512)).to(device)
    # test duplicated inputs
    inputs = (inputs, inputs)
    correct_outputs = m(*inputs)
    return m, inputs, correct_outputs


def get_hf_bert(rank):
    # Note: use @import_transformers_or_skip on your test case if you use this
    # in a multiprocessing test
    try:
        from transformers import AutoModelForMaskedLM, BertConfig
    except ImportError as e:
        raise unittest.SkipTest("Unable to import transformers") from e

    batch_size, max_length, config, device = 4, 512, BertConfig(), f"cuda:{rank}"
    model = AutoModelForMaskedLM.from_config(config).to(device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device)
    decoder_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(
        device
    )
    inputs = {"input_ids": input_ids, "labels": decoder_ids}
    model.train()
    return model, inputs


class CheckSplitsCompiler:
    def __init__(self) -> None:
        self.compiler_called = 0

    def compile_fn(self, gm, example_inputs):
        self.compiler_called += 1
        return gm


# This simulates DDP, but it doesn't actually do any process communication;
# it just has enough properties so that the dynamo distributed optimization is
# able to optimize. Feel free to simulate more properties as necessary. The
# other important thing is patching _active_ddp_module, which is what actually
# triggers DDP optimization
class FakeDDP(nn.Module):
    def __init__(self, module, bucket_cap_mb=25):
        super().__init__()
        self.module = module
        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)

    @contextmanager
    def _inside_ddp_forward(self):
        DDP._active_ddp_module = self
        try:
            yield
        finally:
            DDP._active_ddp_module = None

    def forward(self, *inputs, **kwargs):
        with self._inside_ddp_forward():
            return self.module.forward(*inputs, **kwargs)


def run_hf_bert_ddp(self, model, inputs, backend):
    reset_rng_state()
    correct_outputs = model(**inputs)
    correct_loss = correct_outputs.loss
    correct_loss.backward()

    reset_rng_state()
    opt_model = torch._dynamo.optimize(backend)(model)
    opt_outputs = opt_model(**inputs)
    opt_loss = opt_outputs.loss
    opt_loss.backward()

    inputs_flat = [inputs[k] for k in inputs]
    correct_results = collect_results(
        model, correct_outputs.logits, correct_loss, inputs_flat
    )
    opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat)
    self.assertTrue(same(correct_results, opt_results))


class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor(self):
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "inductor")

    @patch.object(config, "optimize_ddp", True)
    def test_hf_bert_ddp_aot_eager(self):
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "aot_eager")

    @patch.object(config, "optimize_ddp", True)
    def test_issue90375(self):
        class Model(nn.Module):
            def forward(self):
                return torch.randn(3) * torch.randn(3)

        model = Model()
        model = FakeDDP(model)

        opt_model = torch._dynamo.optimize("aot_eager")(model)
        opt_model()

    @patch.object(config, "optimize_ddp", True)
    def test_symbol_splitting(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x):
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = x + y @ self.weight2
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512))

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_direct(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                u0, u1 = y.tolist()
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * u0
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_indirect(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                u0, u1 = y.tolist()
                a = torch.ones(u0)
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * a.sum()
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_torture_multi(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))
                self.weight3 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                # partition one (contains the u0 def)
                u0, u1 = y.tolist()
                x = torch.cat([x, x])
                y1 = x @ self.weight1
                # partition two (contains the variable)
                y2 = y1 @ self.weight2
                a = torch.ones(u0)
                # partition three
                z = (x + y2 @ self.weight3) * a.sum()
                return z

        model = Model()
        model = FakeDDP(model, bucket_cap_mb=1)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True)
    def test_unbacked_symbol_splitting_no_binding(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                nz = y.nonzero()
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * (nz + 1).sum()
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([0.0, 12.0, 0.0, 11.0]))

    @patch.object(config, "optimize_ddp", True)
    def test_call_method_forward(self):
        class Model(nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                layers = []
                for l in range(2):
                    layer = nn.ModuleList(
                        [
                            nn.LayerNorm(96),
                            nn.MultiheadAttention(
                                embed_dim=96, num_heads=4, batch_first=True
                            ),
                        ]
                    )
                    layers.append(layer)
                self.layers = nn.ModuleList(layers)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # x: [Batch, Freq, Time, Feature]
                B, F, T, H = x.shape
                for m in self.layers:
                    x = x.reshape(B * F, T, H)
                    x = m[0](x)
                    x, attn = m[1].forward(x, x, x)
                    x = x.reshape(B, F, T, H)
                return x

        model = Model()
        model = FakeDDP(model)
        opt_model = torch.compile(model)
        opt_model(torch.randn(2, 129, 100, 96))


# Are these tests failing? Check and see if TestFakeDistributedSingleProc has a
# single process version; if it's just a problem in the Dynamo distributed
# optimizer, you should be able to repro it single process!
@requires_nccl()
class TestMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Note: MultiProcTestCase spawns processes per test and is slow.
    Prefer MultiThreadedTestCase for most tests. Perhaps use this one
    sparingly for integration tests.
    """

    @skip_if_lt_x_gpu(2)
    @config.patch(optimize_ddp=False, enable_compiler_collectives=True)
    def test_ddp_baseline_aot_eager_multiprocess(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            self.assertFalse(config.optimize_ddp)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            m = DDP(m, device_ids=[self.rank])
            m = torch._dynamo.optimize("aot_eager")(m)
            outputs = m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    def _test_hf_bert_ddp_inductor(self, static_graph):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model, static_graph=static_graph)
            run_hf_bert_ddp(self, model, inputs, "inductor")

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor(self):
        self._test_hf_bert_ddp_inductor(static_graph=False)

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor_static_graph(self):
        self._test_hf_bert_ddp_inductor(static_graph=True)

    def _test_hf_bert_aot_eager(self, static_graph):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model, static_graph=static_graph)
            run_hf_bert_ddp(self, model, inputs, "aot_eager")

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    def test_hf_bert_ddp_aot_eager(self):
        self._test_hf_bert_aot_eager(static_graph=False)

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    def test_hf_bert_ddp_aot_eager_static_graph(self):
        self._test_hf_bert_aot_eager(static_graph=True)

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(optimize_ddp=False, enable_compiler_collectives=True)
    def test_ddp_activation_checkpointing(self):
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            apply_activation_checkpointing,
            checkpoint_wrapper,
            CheckpointImpl,
        )

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = torch.nn.Linear(64, 32)
                self.fc2 = torch.nn.Linear(32, 16)
                self.fc3 = torch.nn.Linear(16, 8)

            def forward(self, inp):
                return self.fc3(self.fc2(self.fc1(inp)))

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            self.assertFalse(config.optimize_ddp)
            model = MyModel().to(device="cuda")

            # Activation checkpointing for Linear layers.
            non_reentrant_wrapper = functools.partial(
                checkpoint_wrapper,
                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            )
            check_fn = lambda submodule: isinstance(  # noqa: E731
                submodule, torch.nn.Linear
            )
            apply_activation_checkpointing(
                model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
            )

            model = DDP(model)
            x = torch.randn(10, 64).cuda()
            correct_outputs = model(x)

            opt_model = torch.compile(model)
            outputs = opt_model(x)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_aot_eager(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # Test with recursive wrapping, nested FSDP around each Linear
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(
                m,
                auto_wrap_policy=functools.partial(
                    transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear,)
                ),
                use_orig_params=True,
            )
            fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_setattr(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            from torch._dynamo.utils import counters

            counters.clear()
            m, inputs, correct_outputs = get_mutating_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))
            self.assertEqual(len(counters["graph_break"]), 1)
            first_graph_break = list(counters["graph_break"].keys())[0]  # noqa: RUF015
            self.assertTrue("setattr" not in first_graph_break)

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_fsdp_inductor(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # Test with recursive wrapping, nested FSDP around each Linear
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(
                m,
                auto_wrap_policy=functools.partial(
                    transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear,)
                ),
                use_orig_params=True,
            )
            fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_fsdp_activation_checkpointing(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model, inputs = get_toy_model_for_activation_checkpointing(
                f"cuda:{self.rank}"
            )
            is_inner = lambda module: isinstance(module, ToyInnerModel)  # noqa: E731
            wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=is_inner)
            model = apply_fsdp_with_checkpointing(model, wrap_policy, is_inner)
            correct_outputs = model(inputs)
            cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
            opt_model = torch._dynamo.optimize(cnt)(model)
            outputs = opt_model(inputs)
            self.assertTrue(same(correct_outputs, outputs))
            # Each FSDP module is a separate graph
            self.assertEqual(cnt.frame_count, 2)
            self.assertTrue(
                find_first_node(cnt.graphs[0], tag_activation_checkpoint) is not None
            )

    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
    @patch.object(torch._inductor.config.triton, "cudagraphs", False)
    @patch.object(torch._inductor.config, "fallback_random", True)
    @config.patch(enable_compiler_collectives=True)
    @unittest.skipIf(
        PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Inaccurate results with fused SDPA kernels",
    )
    def test_hf_bert_fsdp(self):
        def apply_fsdp(model, wrap_policy):
            model = FSDP(
                copy.deepcopy(model), auto_wrap_policy=wrap_policy, use_orig_params=True
            )
            return model

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            for wrap_policy, test_instance in (
                (None, "FSDP without recursive wrapping"),
            ):
                print(f"Running hf_bert test for {test_instance}")
                model, inputs = get_hf_bert(self.rank)
                reset_rng_state()
                eager_model = apply_fsdp(model, wrap_policy)
                correct_outputs = eager_model(**inputs)
                correct_loss = correct_outputs.loss
                correct_loss.backward()

                reset_rng_state()
                opt_model = apply_fsdp(model, wrap_policy)
                opt_model = torch._dynamo.optimize("inductor")(opt_model)
                opt_outputs = opt_model(**inputs)
                opt_loss = opt_outputs.loss
                opt_loss.backward()

                inputs_flat = [inputs[k] for k in inputs]
                correct_results = collect_results(
                    eager_model, correct_outputs.logits, correct_loss, inputs_flat
                )
                opt_results = collect_results(
                    opt_model, opt_outputs.logits, opt_loss, inputs_flat
                )
                self.assertTrue(same(correct_results, opt_results))

    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
    @patch.object(torch._inductor.config.triton, "cudagraphs", False)
    @patch.object(torch._inductor.config, "fallback_random", True)
    @config.patch(guard_nn_modules=True, enable_compiler_collectives=True)
    def test_hf_bert_fsdp_activation_checkpointing(self):
        from transformers.models.bert.modeling_bert import BertLayer

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            for wrap_policy, test_instance in (
                (
                    functools.partial(
                        transformer_auto_wrap_policy, transformer_layer_cls=(BertLayer,)
                    ),
                    "FSDP with recursive wrapping BertLayer instances",
                ),
            ):
                print(
                    f"Running hf_bert_activation_checkpointing test for {test_instance}"
                )
                model, inputs = get_hf_bert(self.rank)
                check_fn = lambda submodule: isinstance(  # noqa: E731
                    submodule, BertLayer
                )
                reset_rng_state()
                eager_model = apply_fsdp_with_checkpointing(
                    model, wrap_policy, check_fn
                )
                correct_outputs = eager_model(**inputs)
                correct_loss = correct_outputs.loss
                correct_loss.backward()

                reset_rng_state()
                opt_model = apply_fsdp_with_checkpointing(model, wrap_policy, check_fn)
                opt_model = torch._dynamo.optimize("inductor")(opt_model)
                opt_outputs = opt_model(**inputs)
                opt_loss = opt_outputs.loss
                opt_loss.backward()

                inputs_flat = [inputs[k] for k in inputs]
                correct_results = collect_results(
                    eager_model, correct_outputs.logits, correct_loss, inputs_flat
                )
                opt_results = collect_results(
                    opt_model, opt_outputs.logits, opt_loss, inputs_flat
                )
                self.assertTrue(same(correct_results, opt_results))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_automatic_dynamic_tensor(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):

            class SimpleModel(nn.Module):
                def __init__(self, input_size, output_size):
                    super().__init__()
                    self.linear = nn.Linear(input_size, output_size)

                def forward(self, x):
                    return self.linear(x)

            torch._dynamo.utils.clear_compilation_metrics()

            model = SimpleModel(10, 2).to(self.rank)
            model.forward = torch.compile(model.forward)
            ddp_model = DDP(model, device_ids=[self.rank])

            loss_fn = nn.CrossEntropyLoss()
            optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

            def B(s):
                return [torch.randn(s, 10), torch.randint(0, 2, (s,))]

            if self.rank == 0:
                dataloader = [B(5), B(8), B(6)]
            else:
                dataloader = [B(6), B(6), B(3)]

            for data, labels in dataloader:
                data, labels = data.to(self.rank), labels.to(self.rank)
                optimizer.zero_grad()
                output = ddp_model(data)
                loss = loss_fn(output, labels)
                loss.backward()
                optimizer.step()

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_automatic_dynamic_scalar(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            # TODO: This should be possible to do inside the function, but
            device = f"cuda:{self.rank}"

            @torch.compile()
            def f(x, y):
                return x + torch.ones(y, device=device).sum()

            if self.rank == 0:
                dataloader = [3, 3, 7]
            else:
                dataloader = [3, 4, 9]

            for data in dataloader:
                f(torch.randn(5, device=self.rank), data)

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_automatic_dynamic_speculation_divergence(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            # TODO: This should be possible to do inside the function, but
            device = f"cuda:{self.rank}"

            @torch.compile()
            def f(x, y):
                zx = x.shape
                zy = y.shape
                return x.sum() + y.sum()

            if self.rank == 0:
                dataloader = [4, 4]
            else:
                dataloader = [3, 4]

            for data in dataloader:
                f(
                    torch.randn(data, device=self.rank),
                    torch.randn(data, device=self.rank),
                )

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_graph_break_empty_graph_still_collective(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            device = f"cuda:{self.rank}"

            @torch.compile()
            def f(x, y):
                z = y
                print("woof")
                zx = x.shape
                zy = y.shape
                return x.sum() + y.sum()

            if self.rank == 0:
                dataloader = [5, 5, 6]
            else:
                dataloader = [3, 4, 5]

            for data in dataloader:
                f(
                    torch.randn(data, device=self.rank),
                    torch.randn(data, device=self.rank),
                )

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_dim_mismatch(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            @torch.compile()
            def f(x, y):
                zx = x.shape
                zy = y.shape
                return x.sum() + y.sum()

            if self.rank == 0:
                dataloader = [[4, 2]]
            else:
                dataloader = [[3]]

            for data in dataloader:
                f(
                    torch.randn(data, device=self.rank),
                    torch.randn(data, device=self.rank),
                )

            metrics = torch._dynamo.utils.get_compilation_metrics()
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_missing_source(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            @torch.compile()
            def f(rank, xs):
                return xs[rank].sum()

            xs = []
            for _ in range(self.world_size):
                xs.append(torch.randn(10, device=self.rank))

            f(self.rank, xs)

            metrics = torch._dynamo.utils.get_compilation_metrics()
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "fx_graph_cache", False)
    @patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
    def test_asymmetric_compilation(self):
        from torch._dynamo.comptime import comptime

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            device = f"cuda:{self.rank}"

            pg = dist.distributed_c10d._get_default_group()

            cnt = torch._dynamo.testing.CompileCounter()
            sleep_time = 5

            @torch._dynamo.optimize(cnt)
            def f(x):
                if self.rank == 0:
                    comptime.sleep(sleep_time)

                y = 2 * x
                return y.sum()

            backend = pg._get_backend(torch.device(device))
            backend._set_default_timeout(timedelta(seconds=sleep_time - 2))

            x = torch.ones(4, device=device)

            # NCCL startup is lazy
            w = pg.allreduce(x)
            w.wait()

            f(x)
            if self.rank != 0:
                # test fails with NCCL timeout without this line
                dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs(
                    timedelta(seconds=sleep_time)
                )

            w = pg.allreduce(x)
            w.wait()
            torch.cuda.synchronize(device)

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "fx_graph_cache", True)
    @patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
    @patch.object(torch._inductor.config, "sleep_sec_TESTING_ONLY", 10)
    def test_asymmetric_compilation_with_fx_cache(self):
        from torch._dynamo.utils import counters
        from torch._inductor.utils import fresh_inductor_cache

        with fresh_inductor_cache(), _dynamo_dist_per_rank_init(
            self.rank, self.world_size
        ):
            torch._dynamo.utils.clear_compilation_metrics()

            device = f"cuda:{self.rank}"

            pg = dist.distributed_c10d._get_default_group()

            @torch.compile
            def f(x):
                y = 2 * x
                return y.sum()

            backend = pg._get_backend(torch.device(device))
            backend._set_default_timeout(timedelta(seconds=5))
            counters.clear()

            x = torch.ones(4, device=device)

            f(x)

            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

            w = pg.allreduce(x)
            w.wait()
            torch.cuda.synchronize(device)
            torch._dynamo.reset()

            if self.rank == 0:
                with fresh_inductor_cache():
                    f(x)
                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
            else:
                f(x)
                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

            w = pg.allreduce(x)
            w.wait()
            torch.cuda.synchronize(device)


@requires_nccl()
@requires_cuda
class TestSingleProc(DynamoDistributedSingleProcTestCase):
    """
    Test harness initializes dist process group.

    Test simple things here since they are simpler to debug.
    Use TestMultiProc for things that really need to run on multiple nodes
    """

    def get_model(
        self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
    ):
        m = ToyModel(
            in_feat=in_feat,
            hidden_feat=hidden_feat,
            out_feat=out_feat,
            ctx_manager=ctx_manager,
        ).to(self.device)
        m.apply(init_weights)
        inputs = torch.rand(bsz, in_feat).to(self.device)
        outputs = m(inputs)
        return m, inputs, outputs

    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_aot_eager(self):
        from torch.nn.parallel import DistributedDataParallel as DDP

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch._dynamo.optimize("aot_eager")(ddp_m)
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_inductor(self):
        from torch.nn.parallel import DistributedDataParallel as DDP

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch._dynamo.optimize("inductor")(ddp_m)
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_graph_split(self):
        assert config.optimize_ddp
        """
        Just ensures that the appropriate number of splits happen (based on
        bucket size and model parameters) - verifies the number of times
        the user-provided compiler is called by the DDPOptimizer which is
        doing the graph splitting
        """

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

        # ensure compatibility with dynamo explain

        explain_out = torch._dynamo.explain(ddp_m)(inputs)
        break_reasons = explain_out.break_reasons
        self.assertEqual(len(break_reasons), 3)
        self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))

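    # Rough arithmetic behind the expected split count of 3 above (approximate;
    # the exact partitioning is decided by DDPOptimizer's bucketing): each hidden
    # Linear(5000, 5000) layer of ToyModel holds ~25M fp32 parameters (~100MB),
    # well above the 25MB bucket_cap_mb, while the first and last layers are
    # comparatively tiny, so the graph ends up split into three buckets and the
    # user-provided compiler is invoked three times.
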
    @patch.object(config, "optimize_ddp", True)
    def test_graph_split_ctx_manager(self):
        """
        Ensures that we get the right number of splits and that the respective
        context managers' effects are applied to the computation.
        """

        for get_compiler in [
            lambda: CheckSplitsCompiler(),
            lambda: None,
        ]:
            for ctx_manager, output_test in [
                (
                    lambda: torch.autocast(
                        torch.device(self.device).type, torch.float16
                    ),
                    lambda out: self.assertEqual(out.dtype, torch.float16),
                ),
                (torch.enable_grad, lambda out: self.assertTrue(out.requires_grad)),
                (torch.no_grad, lambda out: self.assertTrue(not out.requires_grad)),
            ]:
                m, inputs, correct_outputs = self.get_model(
                    out_feat=1000,
                    hidden_feat=1000,
                    in_feat=1000,
                    ctx_manager=ctx_manager,
                )
                # inp - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
                # hidden - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
                bucket_cap_mb = 3.5  # 4MB
                ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=bucket_cap_mb)

                compiler = get_compiler()

                @torch._dynamo.optimize(
                    compiler.compile_fn if compiler else "aot_eager"
                )
                def opt_fn(inputs):
                    return ddp_m(inputs)

                opt_outputs = opt_fn(inputs)
                self.assertTrue(same(correct_outputs, opt_outputs))
                if compiler:
                    self.assertEqual(compiler.compiler_called, 4)

                output_test(opt_outputs)

                # ensure compatibility with dynamo explain

                explain_out = torch._dynamo.explain(ddp_m)(inputs)
                break_reasons = explain_out.break_reasons
                self.assertEqual(len(break_reasons), 4)
                self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))

    @patch.object(config, "optimize_ddp", True)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor(self):
        assert config.optimize_ddp
        """
        Same as above, but using inductor backend.
        We observed issues with inductor/fx interface in the past.
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch._dynamo.optimize("inductor")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))

    @torch._inductor.config.patch(
        {"layout_optimization": True, "keep_output_stride": False}
    )
    @patch.object(config, "optimize_ddp", True)
    def _test_graph_split_inductor_layout_optimizations_impl(self, context):
        assert config.optimize_ddp
        channel_dim = 512
        # channel dim must be > 64 for inductor to do layout optimization and use NHWC

        class ToyModelConv(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net = nn.Sequential(
                    *[
                        nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
                        nn.ReLU(),
                    ]
                    + [
                        nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
                        nn.ReLU(),
                    ]
                    + [
                        nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
                        nn.ReLU(),
                    ]
                    + [
                        nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
                        nn.ReLU(),
                    ]
                )

            def forward(self, inputs):
                return self.net(inputs)

        def get_model():
            m = ToyModelConv().to(self.device)
            m.apply(init_weights)
            inputs = torch.rand(2, channel_dim, channel_dim, 128).to(self.device)
            outputs = m(inputs)
            return m, inputs, outputs

        with context():
            m, inputs, correct_outputs = get_model()
            ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

            @torch._dynamo.optimize("inductor")
            def opt_fn(inputs):
                return ddp_m(inputs)

            opt_outputs = opt_fn(inputs)
            self.assertTrue(same(correct_outputs, opt_outputs))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor_layout_optimizations_training(self):
        self._test_graph_split_inductor_layout_optimizations_impl(
            contextlib.nullcontext
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor_layout_optimizations_inference(self):
        self._test_graph_split_inductor_layout_optimizations_impl(torch.no_grad)

    @patch.object(config, "optimize_ddp", True)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor_transpose(self):
        assert config.optimize_ddp

        B = 100
        N = 30
        D = 50
        K = 70

        class Foo(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear0 = nn.Linear(N, K)
                self.linear1 = torch.nn.Linear(D * K, 2048)

            def forward(self, x):
                xt = x.transpose(2, 1)
                xt = self.linear0(xt).flatten(1)
                return self.linear1(xt)

        mod = Foo().to(self.device)

        compiled_mod = torch.compile(mod, backend="inductor")
        ddp_compiled_mod = DDP(compiled_mod, device_ids=self.device_ids)

        x = torch.randn((B, N, D), dtype=torch.float32, device=self.device)
        self.assertTrue(same(mod(x), ddp_compiled_mod(x)))

        x_1 = torch.randn((B * 2, N, D), dtype=torch.float32, device=self.device)
        self.assertTrue(same(mod(x_1), ddp_compiled_mod(x_1)))

        x_2 = torch.randn((B * 3, N, D), dtype=torch.float32, device=self.device)
        self.assertTrue(same(mod(x_2), ddp_compiled_mod(x_2)))

    @patch.object(config, "optimize_ddp", True)
    def test_no_split(self):
        """
        Ensures the DDPOptimizer returns a correct, compiled module without
        introducing graph splits. (Based on model parameters fitting in the bucket)
        """
        # DDP will always do a 'first bucket' with a really small size; so only a tiny model will escape this
        m, inputs, correct_outputs = self.get_model(hidden_feat=5)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250)
        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 1)

    @patch.object(config, "optimize_ddp", True)
    def test_aot_autograd(self):
        """
        Explicitly check AotAutograd family of compilers work,
        since they require example inputs propagated between graph splits.
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch._dynamo.optimize("aot_eager")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        opt_outputs.sum().backward()
        self.assertTrue(same(correct_outputs, opt_outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_custom_layer(self):
        """
        Just ensures that the appropriate number of splits happen (based on
        bucket size and model parameters) - verifies the number of times
        the user-provided compiler is called by the DDPOptimizer which is
        doing the graph splitting
        """
        m, inputs, correct_outputs = get_custom_model(self.device)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1)

        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(*inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_empty_graph_inductor(self):
        def fn():
            get_world_size = torch.distributed.distributed_c10d.get_world_size()
            return (get_world_size,)

        opt_fn = torch._dynamo.optimize("inductor")(fn)
        res = None
        try:
            res = opt_fn()[0]
        except Exception:
            pass
        self.assertEqual(res, 1)

    @patch.object(config, "optimize_ddp", False)
    def test_ignored_parameters(self):
        """
        Verifies ddp graph-split logic ignores parameters marked to ignore on DDP module.
        Hooks up graph-split optimizer manually so it can peek at internal state.
        """
        m, inputs, correct_outputs = get_custom_model(self.device)
        parameters_to_ignore = ["seq.2.weight", "seq.4.linear.bias"]
        DDP._set_params_and_buffers_to_ignore_for_model(m, parameters_to_ignore)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
        parameter_ids_to_ignore = [
            id(ddp_m.module.get_parameter(p)) for p in ddp_m.parameters_to_ignore
        ]

        check_splits_compiler = CheckSplitsCompiler()
        ddp_optimizer = DDPOptimizer(
            bucket_bytes_cap=ddp_m.bucket_bytes_cap,
            backend_compile_fn=check_splits_compiler.compile_fn,
        )

        @torch._dynamo.optimize(ddp_optimizer.compile_fn)
        def opt_fn(inputs):
            return ddp_m(*inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 2)
        for b in ddp_optimizer.buckets:
            for p_id in b.param_ids:
                self.assertFalse(p_id in parameter_ids_to_ignore)

    @patch.object(config, "optimize_ddp", True)
    def test_higher_order_op(self):
        from torch.utils.checkpoint import checkpoint

        N = 1000

        class InnerModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(N, N)
                self.linear2 = torch.nn.Linear(N, N)

            def forward(self, x):
                a = self.linear1(x)
                a = self.linear2(a)
                return a

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.inner_mod1 = InnerModule()
                self.inner_mod2 = InnerModule()

            def forward(self, x):
                a = checkpoint(self.inner_mod1, x, use_reentrant=False)
                a = torch.cos(a)
                a = checkpoint(self.inner_mod2, a, use_reentrant=False)
                a = torch.cos(a)
                return a

        mod = MockModule().cuda()
        mod = DDP(mod, bucket_cap_mb=1)
        x = torch.randn(N, N, device="cuda", requires_grad=True)
        args = (x,)

        backend = "aot_eager"
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        with self.assertRaisesRegex(
            torch._dynamo.exc.BackendCompilerFailed,
            "DDPOptimizer backend: Found a higher order op in the graph",
        ):
            torch.compile(mod, backend=cnt)(*args)

    def test_fsdp_orig_params_assert(self):
        # Test with basic FSDP wrapping (outer wrap around whole model)
        m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
        fsdp_m = FSDP(m, use_orig_params=False)
        fsdp_m = torch._dynamo.optimize()(fsdp_m)
        self.assertRaisesRegex(
            AssertionError,
            "Dynamo only supports FSDP with use_orig_params=True",
            fsdp_m,
            inputs,
        )

    def test_fsdp_skip_guards(self):
        """
        It's currently difficult to test dynamo guards. Most guards tests are indirect - modify something and
        observe that the guard in question failed. In this case, since the FSDP guards were already deemed
        useless and skipping them is expected to have no practical effect, it's pretty contrived to even try to
        make those guards fail. Instead, we observe the 'guard source' printed by dynamo's comptime print_guards
        function.

        Note: comptime prints the guards before they get installed or not installed, so in both cases
        (skip or no skip) the same guards get printed. The difference is that in the skip case, they show up
        with a special 'guard source' which will cause them not to be installed. So all we check for is the expected
        guard source 'local_fsdp_module'.
        """
        global GUARDS_FILE
        GUARDS_FILE = StringIO()

        for skip_guards, expected_guard_source in (
            (True, "local_fsdp_module"),
            (False, "local_unspecialized_nn_module"),
        ):
            torch._dynamo.reset()

            class ToyModel(nn.Module):
                def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
                    super().__init__()
                    self.net = nn.Sequential(
                        *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
                        + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
                        + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
                        + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
                    )

                def forward(self, inputs):
                    out = self.net(inputs)

                    @comptime
                    def _(ctx):
                        ctx.print_guards(file=GUARDS_FILE)

                    return out

            device = f"cuda:{self.rank}"
            m = ToyModel(
                in_feat=10,
                hidden_feat=5000,
                out_feat=5,
            ).to(device)
            inputs = torch.rand(20, 10).to(device)
            m.apply(init_weights)
            correct_outputs = m(inputs)
            fsdp_m = FSDP(m, use_orig_params=True)

            with torch._dynamo.config.patch(skip_fsdp_guards=skip_guards):
                opt_m = torch._dynamo.optimize("aot_eager")(fsdp_m)
                outputs = opt_m(inputs)

            # far from an exhaustive check of all the expected guards, just check a couple of them.
            FileCheck().check("""local "L['self']" TYPE_MATCH""").check(
                f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH"""
            ).check(
                f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH"""
            ).run(
                GUARDS_FILE.getvalue()
            )

            self.assertTrue(same(correct_outputs, outputs))

    def test_fsdp_skip_register_attr_or_module(self):
        """
        Ensure FSDP modules are not registered as attributes
        in the fx graph.
        See `not source.guard_source().is_fsdp_module()`
        before calling `register_attr_or_module`
        in variables/builder.py.
        """

        class ToyModel(nn.Module):
            def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
                super().__init__()
                self.net = nn.Sequential(
                    *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
                    + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
                )

            def forward(self, inputs):
                out = self.net(inputs)
                return out

        torch._dynamo.reset()

        device = f"cuda:{self.rank}"
        m = ToyModel(
            in_feat=10,
            hidden_feat=5000,
            out_feat=5,
        ).to(device)
        inputs = torch.rand(20, 10).to(device)
        m.apply(init_weights)
        correct_outputs = m(inputs)
        fsdp_m = FSDP(m, use_orig_params=True)

        def debug_compiler(gm, _):
            for node in gm.graph.nodes:
                if node.op == "get_attr":
                    for name in [
                        "l__self___net_0_weight",
                        "l__self___net_0_bias",
                        "l__self___net_2_weight",
                        "l__self___net_2_bias",
                    ]:
                        self.assertFalse(
                            name in node.name,
                            f"FSDP module {name} should not be registered as attributes",
                        )
            return gm

        opt_m = torch._dynamo.optimize(backend=debug_compiler)(fsdp_m)
        outputs = opt_m(inputs)

        self.assertTrue(same(correct_outputs, outputs))

    def test_fsdp_dup_tensors_same_source(self):
        """
        Tests that FSDP-managed modules' parameters and buffers with the same
        source are de-duplicated, meaning that they are each only passed once
        as a graph input.
        """

        class DuplicateModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._param = torch.randn((3,), device="cuda")
                self._buf = torch.nn.Buffer(
                    torch.randn((3,), requires_grad=False, device="cuda")
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Use `_param` and `_buf` each twice in this compiled forward
                # to exercise if they are de-duplicated by TorchDynamo
                z = x + self._buf + self._buf
                z += self._param + self._param
                return z

        model = DuplicateModule()
        fsdp_model = FSDP(copy.deepcopy(model), use_orig_params=True)
        fsdp_model = torch._dynamo.optimize("aot_eager")(fsdp_model)
        inp = torch.randn((2, 3), device="cuda")
        local_out = model(inp)
        fsdp_out = fsdp_model(inp)
        self.assertEqual(local_out, fsdp_out)

    @patch.object(config, "guard_nn_modules", True)
    def test_fsdp_dup_tensors_diff_source(self):
        """
        Tests that FSDP-managed modules' parameters and buffers with different
        source do not result in incorrect AOTAutograd de-dup guards like
        ``a is b``, where ``a`` and ``b`` are certainly not the same. We check
        this by checking for per-invocation recompiles.
        """

        class BufModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._buf = nn.Buffer(
                    torch.randn((3,), requires_grad=False, device="cuda")
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self._buf

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._param = nn.Parameter(torch.randn((1,), device="cuda"))
                self._buf_module = BufModule()
                # Share the buffer, meaning same tensor but different source
                self._buf = self._buf_module._buf

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Use the same buffer tensor twice in the compiled forward,
                # including a data mutation to trigger de-dup logic
                self._buf.mul_(2)
                z = x + self._buf
                z = self._buf_module(z)
                z += self._param
                return z

        fsdp_model = FSDP(Model(), use_orig_params=True)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        fsdp_model = torch._dynamo.optimize(cnt)(fsdp_model)
        inp = torch.randn((2, 3), device="cuda")
        for _ in range(15):
            fsdp_model(inp)
        # Check for no recompiles (if there were incorrect de-dup guards, then
        # the frame count would be equal to the number of forward calls)
        self.assertEqual(cnt.frame_count, 1)

    def test_fsdp_staticmethod(self):
        """
        Tests that Dynamo compiles staticmethods for FSDP-managed modules
        correctly both when the staticmethod is invoked from the class and from
        the object itself.
        """

        class ModuleWithStaticMethod(nn.Module):
            def __init__(self, use_self: bool):
                super().__init__()
                self._use_self = use_self
                torch.manual_seed(42)  # force `_param` to be deterministic
                self._param = nn.Parameter(torch.randn((3,), device="cuda"))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                if self._use_self:
                    z = self._add(x, self._param)
                else:
                    z = ModuleWithStaticMethod._add(x, self._param)
                z *= 2
                return z

            @staticmethod
            def _add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

        model = ModuleWithStaticMethod(False)
        x = torch.randn((2, 3), device="cuda")
        ref_out = model(x)
        test_outs: List[torch.Tensor] = []

        for use_self in (False, True):
            model = ModuleWithStaticMethod(use_self)
            fsdp_model = FSDP(model, use_orig_params=True)
            cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
            fsdp_model = torch._dynamo.optimize(cnt)(fsdp_model)
            test_outs.append(fsdp_model(x))
            # Check for no recompiles, which could happen if incorrectly
            # passing args to the staticmethod (e.g. doubly passing `self`)
            # 3 is expected here for 1 forward.
            # Graph 1 should be add and imul
            self.assertEqual(cnt.frame_count, 1)
        for test_out in test_outs:
            self.assertEqual(test_out, ref_out)

    def test_async_subclass_no_specialize(self):
        cnt = torch._dynamo.testing.CompileCounterWithBackend("eager")

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def f(x):
            return x + 1

        f(_maybe_wrap_tensor(torch.randn(10)))
        f(_maybe_wrap_tensor(torch.randn(12)))

        self.assertEqual(cnt.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()