# Owner(s): ["module: dynamo"]
import contextlib
import copy
import functools
import random
import unittest
from contextlib import contextmanager
from datetime import timedelta
from io import StringIO
from typing import List
from unittest.mock import patch

import numpy as np

import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
import torch.distributed as dist
import torch.optim as optim
from torch import nn
from torch._C import FileCheck
from torch._dynamo import config
from torch._dynamo.backends.distributed import DDPOptimizer
from torch._dynamo.comptime import comptime
from torch._dynamo.testing import collect_results
from torch._dynamo.utils import same
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.distributed._functional_collectives import _maybe_wrap_tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_distributed import (
    _dynamo_dist_per_rank_init,
    DynamoDistributedMultiProcTestCase,
    DynamoDistributedSingleProcTestCase,
    import_transformers_or_skip,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import requires_cuda
from torch.utils._triton import has_triton


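# Seed the torch, Python, and NumPy RNGs identically so eager and compiled
# runs can be compared against each other.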
def reset_rng_state():
    torch.manual_seed(1337)
    random.seed(1337)
    np.random.seed(1337)


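# Initialize Linear layers (Xavier-uniform weights, constant bias); applied
# via Module.apply in the model builders below.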
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


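# Simple MLP used throughout the DDP/FSDP tests; ctx_manager optionally wraps
# the forward pass (e.g. autocast, no_grad) to test graph splits under
# context managers.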
class ToyModel(nn.Module):
    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
        super().__init__()
        self.ctx_manager = ctx_manager
        self.net = nn.Sequential(
            *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
        )

    def forward(self, inputs):
        if self.ctx_manager is not None:
            with self.ctx_manager():
                return self.net(inputs)
        else:
            return self.net(inputs)


def get_model(
    device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
):
    m = ToyModel(
        in_feat=in_feat,
        hidden_feat=hidden_feat,
        out_feat=out_feat,
        ctx_manager=ctx_manager,
    ).to(device)
    m.apply(init_weights)
    inputs = torch.rand(bsz, in_feat).to(device)
    outputs = m(inputs)
    return m, inputs, outputs


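# Like ToyModel, but mutates module state inside forward to exercise setattr
# handling and the resulting graph breaks.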
class MutatingModel(nn.Module):
    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
        super().__init__()
        self.ctx_manager = ctx_manager
        self.net = nn.Sequential(
            *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
        )
        self.state = 1

    def forward(self, inputs):
        self.state = 2
        return self.net(inputs) * self.state


def get_mutating_model(
    device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
):
    m = MutatingModel(
        in_feat=in_feat,
        hidden_feat=hidden_feat,
        out_feat=out_feat,
        ctx_manager=ctx_manager,
    ).to(device)
    m.apply(init_weights)
    inputs = torch.rand(bsz, in_feat).to(device)
    outputs = m(inputs)
    return m, inputs, outputs


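# Small two-layer block; the activation-checkpointing tests wrap each instance
# in FSDP and checkpoint it.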
class ToyInnerModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = [nn.Linear(100, 100), nn.Linear(100, 100)]
        self.layers = nn.Sequential(*self.layers)

    def forward(self, inputs):
        return self.layers(inputs)


class ToyOuterModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.layers = [ToyInnerModel().to(device) for _ in range(2)]
        self.layers = nn.Sequential(
            self.layers[0], nn.ReLU(), self.layers[1], nn.ReLU()
        )

    def forward(self, inputs):
        return self.layers(inputs)


def get_toy_model_for_activation_checkpointing(device):
    m = ToyOuterModel(device).to(device)
    m.apply(init_weights)
    inputs = torch.rand(100, 100).to(device)
    return m, inputs


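# Return the first FX node in the graph whose target is `func`, or None.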
def find_first_node(gm, func):
    for node in gm.graph.nodes:
        if node.target is func:
            return node
    return None


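# Wrap the model in FSDP and, optionally, apply non-reentrant activation
# checkpointing to the submodules selected by checkpoint_policy.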
def apply_fsdp_with_checkpointing(
    model, wrap_policy, checkpoint_policy, use_activation_checkpointing=True
):
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing,
        checkpoint_wrapper,
        CheckpointImpl,
    )

    model = FSDP(
        copy.deepcopy(model), auto_wrap_policy=wrap_policy, use_orig_params=True
    )
    if use_activation_checkpointing:
        checkpoint_wrapper_fn = functools.partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=checkpoint_wrapper_fn,
            check_fn=checkpoint_policy,
        )
    return model


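# Model with a custom Linear-like layer sandwiched between standard Linears;
# exercises DDPOptimizer bucketing edge cases (see inline comments below).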
def get_custom_model(device):
    class MyCustomLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter(torch.randn(512, 512))

        def forward(self, x):
            tmp = torch.mm(x, self.weight.t())
            # test an edge case where torch.where.scalar was decomposed to aten.where.self(tensor, tensor, tensor)
            # and the tensors T(0.4) and T(0.5) were not wrapped in FakeTensors during DDPOptimizer compilation
            return tmp + torch.where(tmp < 0.5, 0.3, 0.6)

    class MyLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(512, 512)

        def forward(self, x):
            return self.linear(x)

    class MyModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            mods = [
                (MyLinear(), torch.nn.ReLU()),
                # sandwich the custom in the middle so it comes before and after
                (MyCustomLinear(), torch.nn.ReLU()),
                (MyLinear(), torch.nn.ReLU()),
            ]
            self.seq = torch.nn.Sequential(*[x for items in mods for x in items])

        def forward(self, x, y):
            # test special case where the 0th bucket (layers close to graph input) is at capacity, which would
            # trigger a new bucket, but there are only trivial ops without parameters to put into the new bucket.
            # optimize this case by fusing that 'empty bucket' back together with the previous full one
            return self.seq(x + y)

    m = MyModule().to(device)
    m.apply(init_weights)
    inputs = torch.rand((512, 512)).to(device)
    # test duplicated inputs
    inputs = (inputs, inputs)
    correct_outputs = m(*inputs)
    return m, inputs, correct_outputs


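# Build a small HF BERT masked-LM model and random token inputs on the given
# rank's GPU.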
def get_hf_bert(rank):
    # Note: use @import_transformers_or_skip on your test case if you use this
    # in a multiprocessing test
    try:
        from transformers import AutoModelForMaskedLM, BertConfig
    except ImportError as e:
        raise unittest.SkipTest("Unable to import transformers") from e

    batch_size, max_length, config, device = 4, 512, BertConfig(), f"cuda:{rank}"
    model = AutoModelForMaskedLM.from_config(config).to(device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device)
    decoder_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(
        device
    )
    inputs = {"input_ids": input_ids, "labels": decoder_ids}
    model.train()
    return model, inputs


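# Minimal Dynamo backend that counts how many times it is invoked, i.e. how
# many subgraphs the DDPOptimizer produced.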
class CheckSplitsCompiler:
    def __init__(self) -> None:
        self.compiler_called = 0

    def compile_fn(self, gm, example_inputs):
        self.compiler_called += 1
        return gm


# This simulates DDP, but it doesn't actually do any process communication;
# it just has enough properties for the Dynamo distributed optimization to
# kick in.  Feel free to simulate more properties as necessary.  The other
# important thing is patching _active_ddp_module, which is what actually
# triggers DDP optimization.
class FakeDDP(nn.Module):
    def __init__(self, module, bucket_cap_mb=25):
        super().__init__()
        self.module = module
        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)

    @contextmanager
    def _inside_ddp_forward(self):
        DDP._active_ddp_module = self
        try:
            yield
        finally:
            DDP._active_ddp_module = None

    def forward(self, *inputs, **kwargs):
        with self._inside_ddp_forward():
            return self.module.forward(*inputs, **kwargs)


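# Run the model both eagerly and under the given Dynamo backend, then check
# that losses, logits, and gradients match between the two runs.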
def run_hf_bert_ddp(self, model, inputs, backend):
    reset_rng_state()
    correct_outputs = model(**inputs)
    correct_loss = correct_outputs.loss
    correct_loss.backward()

    reset_rng_state()
    opt_model = torch._dynamo.optimize(backend)(model)
    opt_outputs = opt_model(**inputs)
    opt_loss = opt_outputs.loss
    opt_loss.backward()

    inputs_flat = [inputs[k] for k in inputs]
    correct_results = collect_results(
        model, correct_outputs.logits, correct_loss, inputs_flat
    )
    opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat)
    self.assertTrue(same(correct_results, opt_results))


class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor(self):
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "inductor")

    @patch.object(config, "optimize_ddp", True)
    def test_hf_bert_ddp_aot_eager(self):
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "aot_eager")

    @patch.object(config, "optimize_ddp", True)
    def test_issue90375(self):
        class Model(nn.Module):
            def forward(self):
                return torch.randn(3) * torch.randn(3)

        model = Model()
        model = FakeDDP(model)

        opt_model = torch._dynamo.optimize("aot_eager")(model)
        opt_model()

    @patch.object(config, "optimize_ddp", True)
    def test_symbol_splitting(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x):
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = x + y @ self.weight2
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512))

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_direct(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                u0, u1 = y.tolist()
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * u0
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_indirect(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                u0, u1 = y.tolist()
                a = torch.ones(u0)
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * a.sum()
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_torture_multi(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))
                self.weight3 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                # partition one (contains the u0 def)
                u0, u1 = y.tolist()
                x = torch.cat([x, x])
                y1 = x @ self.weight1
                # partition two (contains the variable)
                y2 = y1 @ self.weight2
                a = torch.ones(u0)
                # partition three
                z = (x + y2 @ self.weight3) * a.sum()
                return z

        model = Model()
        model = FakeDDP(model, bucket_cap_mb=1)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True)
    def test_unbacked_symbol_splitting_no_binding(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                nz = y.nonzero()
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * (nz + 1).sum()
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([0.0, 12.0, 0.0, 11.0]))

    @patch.object(config, "optimize_ddp", True)
    def test_call_method_forward(self):
        class Model(nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                layers = []
                for l in range(2):
                    layer = nn.ModuleList(
                        [
                            nn.LayerNorm(96),
                            nn.MultiheadAttention(
                                embed_dim=96, num_heads=4, batch_first=True
                            ),
                        ]
                    )
                    layers.append(layer)
                self.layers = nn.ModuleList(layers)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # x: [Batch, Freq, Time, Feature]
                B, F, T, H = x.shape
                for m in self.layers:
                    x = x.reshape(B * F, T, H)
                    x = m[0](x)
                    x, attn = m[1].forward(x, x, x)
                    x = x.reshape(B, F, T, H)
                return x

        model = Model()
        model = FakeDDP(model)
        opt_model = torch.compile(model)
        opt_model(torch.randn(2, 129, 100, 96))


# Are these tests failing?  Check and see if TestFakeDistributedSingleProc has a
# single process version; if it's just a problem in the Dynamo distributed
# optimizer, you should be able to repro it single process!
@requires_nccl()
class TestMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Note: MultiProcTestCase spawns processes per test and is slow.
    Prefer MultiThreadedTestCase for most tests. Perhaps use this one
    sparingly for integration tests.
    """

    @skip_if_lt_x_gpu(2)
    @config.patch(optimize_ddp=False, enable_compiler_collectives=True)
    def test_ddp_baseline_aot_eager_multiprocess(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            self.assertFalse(config.optimize_ddp)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            m = DDP(m, device_ids=[self.rank])
            m = torch._dynamo.optimize("aot_eager")(m)
            outputs = m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    def _test_hf_bert_ddp_inductor(self, static_graph):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model, static_graph=static_graph)
            run_hf_bert_ddp(self, model, inputs, "inductor")

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor(self):
        self._test_hf_bert_ddp_inductor(static_graph=False)

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor_static_graph(self):
        self._test_hf_bert_ddp_inductor(static_graph=True)

    def _test_hf_bert_aot_eager(self, static_graph):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model, static_graph=static_graph)
            run_hf_bert_ddp(self, model, inputs, "aot_eager")

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    def test_hf_bert_ddp_aot_eager(self):
        self._test_hf_bert_aot_eager(static_graph=False)

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    def test_hf_bert_ddp_aot_eager_static_graph(self):
        self._test_hf_bert_aot_eager(static_graph=True)

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(optimize_ddp=False, enable_compiler_collectives=True)
    def test_ddp_activation_checkpointing(self):
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            apply_activation_checkpointing,
            checkpoint_wrapper,
            CheckpointImpl,
        )

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = torch.nn.Linear(64, 32)
                self.fc2 = torch.nn.Linear(32, 16)
                self.fc3 = torch.nn.Linear(16, 8)

            def forward(self, inp):
                return self.fc3(self.fc2(self.fc1(inp)))

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            self.assertFalse(config.optimize_ddp)
            model = MyModel().to(device="cuda")

            # Activation checkpointing for Linear layers.
            non_reentrant_wrapper = functools.partial(
                checkpoint_wrapper,
                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            )
            check_fn = lambda submodule: isinstance(  # noqa: E731
                submodule, torch.nn.Linear
            )
            apply_activation_checkpointing(
                model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
            )

            model = DDP(model)
            x = torch.randn(10, 64).cuda()
            correct_outputs = model(x)

            opt_model = torch.compile(model)
            outputs = opt_model(x)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_aot_eager(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # Test with recursive wrapping, nested FSDP around each Linear
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(
                m,
                auto_wrap_policy=functools.partial(
                    transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear,)
                ),
                use_orig_params=True,
            )
            fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_setattr(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            from torch._dynamo.utils import counters

            counters.clear()
            m, inputs, correct_outputs = get_mutating_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))
            self.assertEqual(len(counters["graph_break"]), 1)
            first_graph_break = list(counters["graph_break"].keys())[0]  # noqa: RUF015
            self.assertTrue("setattr" not in first_graph_break)

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_fsdp_inductor(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # Test with recursive wrapping, nested FSDP around each Linear
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(
                m,
                auto_wrap_policy=functools.partial(
                    transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear,)
                ),
                use_orig_params=True,
            )
            fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_fsdp_activation_checkpointing(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model, inputs = get_toy_model_for_activation_checkpointing(
                f"cuda:{self.rank}"
            )
            is_inner = lambda module: isinstance(module, ToyInnerModel)  # noqa: E731
            wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=is_inner)
            model = apply_fsdp_with_checkpointing(model, wrap_policy, is_inner)
            correct_outputs = model(inputs)
            cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
            opt_model = torch._dynamo.optimize(cnt)(model)
            outputs = opt_model(inputs)
            self.assertTrue(same(correct_outputs, outputs))
            # Each FSDP module is a separate graph
            self.assertEqual(cnt.frame_count, 2)
            self.assertTrue(
                find_first_node(cnt.graphs[0], tag_activation_checkpoint) is not None
            )

    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
    @patch.object(torch._inductor.config.triton, "cudagraphs", False)
    @patch.object(torch._inductor.config, "fallback_random", True)
    @config.patch(enable_compiler_collectives=True)
    @unittest.skipIf(
        PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Inaccurate results with fused SDPA kernels",
    )
    def test_hf_bert_fsdp(self):
        def apply_fsdp(model, wrap_policy):
            model = FSDP(
                copy.deepcopy(model), auto_wrap_policy=wrap_policy, use_orig_params=True
            )
            return model

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            for wrap_policy, test_instance in (
                (None, "FSDP without recursive wrapping"),
            ):
                print(f"Running hf_bert test for {test_instance}")
                model, inputs = get_hf_bert(self.rank)
                reset_rng_state()
                eager_model = apply_fsdp(model, wrap_policy)
                correct_outputs = eager_model(**inputs)
                correct_loss = correct_outputs.loss
                correct_loss.backward()

                reset_rng_state()
                opt_model = apply_fsdp(model, wrap_policy)
                opt_model = torch._dynamo.optimize("inductor")(opt_model)
                opt_outputs = opt_model(**inputs)
                opt_loss = opt_outputs.loss
                opt_loss.backward()

                inputs_flat = [inputs[k] for k in inputs]
                correct_results = collect_results(
                    eager_model, correct_outputs.logits, correct_loss, inputs_flat
                )
                opt_results = collect_results(
                    opt_model, opt_outputs.logits, opt_loss, inputs_flat
                )
                self.assertTrue(same(correct_results, opt_results))

    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
    @patch.object(torch._inductor.config.triton, "cudagraphs", False)
    @patch.object(torch._inductor.config, "fallback_random", True)
    @config.patch(guard_nn_modules=True, enable_compiler_collectives=True)
    def test_hf_bert_fsdp_activation_checkpointing(self):
        from transformers.models.bert.modeling_bert import BertLayer

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            for wrap_policy, test_instance in (
                (
                    functools.partial(
                        transformer_auto_wrap_policy, transformer_layer_cls=(BertLayer,)
                    ),
                    "FSDP with recursive wrapping BertLayer instances",
                ),
            ):
                print(
                    f"Running hf_bert_activation_checkpointing test for {test_instance}"
                )
                model, inputs = get_hf_bert(self.rank)
                check_fn = lambda submodule: isinstance(  # noqa: E731
                    submodule, BertLayer
                )
                reset_rng_state()
                eager_model = apply_fsdp_with_checkpointing(
                    model, wrap_policy, check_fn
                )
                correct_outputs = eager_model(**inputs)
                correct_loss = correct_outputs.loss
                correct_loss.backward()

                reset_rng_state()
                opt_model = apply_fsdp_with_checkpointing(model, wrap_policy, check_fn)
                opt_model = torch._dynamo.optimize("inductor")(opt_model)
                opt_outputs = opt_model(**inputs)
                opt_loss = opt_outputs.loss
                opt_loss.backward()

                inputs_flat = [inputs[k] for k in inputs]
                correct_results = collect_results(
                    eager_model, correct_outputs.logits, correct_loss, inputs_flat
                )
                opt_results = collect_results(
                    opt_model, opt_outputs.logits, opt_loss, inputs_flat
                )
                self.assertTrue(same(correct_results, opt_results))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_automatic_dynamic_tensor(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):

            class SimpleModel(nn.Module):
                def __init__(self, input_size, output_size):
                    super().__init__()
                    self.linear = nn.Linear(input_size, output_size)

                def forward(self, x):
                    return self.linear(x)

            torch._dynamo.utils.clear_compilation_metrics()

            model = SimpleModel(10, 2).to(self.rank)
            model.forward = torch.compile(model.forward)
            ddp_model = DDP(model, device_ids=[self.rank])

            loss_fn = nn.CrossEntropyLoss()
            optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

            def B(s):
                return [torch.randn(s, 10), torch.randint(0, 2, (s,))]

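            # Ranks see different batch sizes, so automatic dynamic shapes
            # trigger at different steps; compiler collectives must keep the
            # number of recompiles in sync across ranks.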
            if self.rank == 0:
                dataloader = [B(5), B(8), B(6)]
            else:
                dataloader = [B(6), B(6), B(3)]

            for data, labels in dataloader:
                data, labels = data.to(self.rank), labels.to(self.rank)
                optimizer.zero_grad()
                output = ddp_model(data)
                loss = loss_fn(output, labels)
                loss.backward()
                optimizer.step()

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_automatic_dynamic_scalar(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            # TODO: This should be possible to do inside the function, but it
            # is not supported yet.
            device = f"cuda:{self.rank}"

            @torch.compile()
            def f(x, y):
                return x + torch.ones(y, device=device).sum()

            if self.rank == 0:
                dataloader = [3, 3, 7]
            else:
                dataloader = [3, 4, 9]

            for data in dataloader:
                f(torch.randn(5, device=self.rank), data)

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_automatic_dynamic_speculation_divergence(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            # TODO: This should be possible to do inside the function, but it
            # is not supported yet.
            device = f"cuda:{self.rank}"

            @torch.compile()
            def f(x, y):
                zx = x.shape
                zy = y.shape
                return x.sum() + y.sum()

            if self.rank == 0:
                dataloader = [4, 4]
            else:
                dataloader = [3, 4]

            for data in dataloader:
                f(
                    torch.randn(data, device=self.rank),
                    torch.randn(data, device=self.rank),
                )

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_graph_break_empty_graph_still_collective(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            device = f"cuda:{self.rank}"

            @torch.compile()
            def f(x, y):
                z = y
                print("woof")
                zx = x.shape
                zy = y.shape
                return x.sum() + y.sum()

            if self.rank == 0:
                dataloader = [5, 5, 6]
            else:
                dataloader = [3, 4, 5]

            for data in dataloader:
                f(
                    torch.randn(data, device=self.rank),
                    torch.randn(data, device=self.rank),
                )

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_dim_mismatch(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            @torch.compile()
            def f(x, y):
                zx = x.shape
                zy = y.shape
                return x.sum() + y.sum()

            if self.rank == 0:
                dataloader = [[4, 2]]
            else:
                dataloader = [[3]]

            for data in dataloader:
                f(
                    torch.randn(data, device=self.rank),
                    torch.randn(data, device=self.rank),
                )

            metrics = torch._dynamo.utils.get_compilation_metrics()
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_missing_source(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            @torch.compile()
            def f(rank, xs):
                return xs[rank].sum()

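            # Each rank indexes a different element of xs, so the traced
            # tensor source differs across ranks.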
            xs = []
            for _ in range(self.world_size):
                xs.append(torch.randn(10, device=self.rank))

            f(self.rank, xs)

            metrics = torch._dynamo.utils.get_compilation_metrics()
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "fx_graph_cache", False)
    @patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
    def test_asymmetric_compilation(self):
        from torch._dynamo.comptime import comptime

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            device = f"cuda:{self.rank}"

            pg = dist.distributed_c10d._get_default_group()

            cnt = torch._dynamo.testing.CompileCounter()
            sleep_time = 5

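            # Rank 0 stalls inside compilation via comptime.sleep, simulating
            # one rank compiling much more slowly than the others.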
            @torch._dynamo.optimize(cnt)
            def f(x):
                if self.rank == 0:
                    comptime.sleep(sleep_time)

                y = 2 * x
                return y.sum()

            backend = pg._get_backend(torch.device(device))
            backend._set_default_timeout(timedelta(seconds=sleep_time - 2))

            x = torch.ones(4, device=device)

            # NCCL startup is lazy
            w = pg.allreduce(x)
            w.wait()

            f(x)
            if self.rank != 0:
                # test fails with NCCL timeout without this line
                dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs(
                    timedelta(seconds=sleep_time)
                )

            w = pg.allreduce(x)
            w.wait()
            torch.cuda.synchronize(device)

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "fx_graph_cache", True)
    @patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
    @patch.object(torch._inductor.config, "sleep_sec_TESTING_ONLY", 10)
    def test_asymmetric_compilation_with_fx_cache(self):
        from torch._dynamo.utils import counters
        from torch._inductor.utils import fresh_inductor_cache

        with fresh_inductor_cache(), _dynamo_dist_per_rank_init(
            self.rank, self.world_size
        ):
            torch._dynamo.utils.clear_compilation_metrics()

            device = f"cuda:{self.rank}"

            pg = dist.distributed_c10d._get_default_group()

            @torch.compile
            def f(x):
                y = 2 * x
                return y.sum()

            backend = pg._get_backend(torch.device(device))
            backend._set_default_timeout(timedelta(seconds=5))
            counters.clear()

            x = torch.ones(4, device=device)

            f(x)

            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

            w = pg.allreduce(x)
            w.wait()
            torch.cuda.synchronize(device)
            torch._dynamo.reset()

            if self.rank == 0:
                with fresh_inductor_cache():
                    f(x)
                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
            else:
                f(x)
                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

            w = pg.allreduce(x)
            w.wait()
            torch.cuda.synchronize(device)


@requires_nccl()
@requires_cuda
class TestSingleProc(DynamoDistributedSingleProcTestCase):
1072    """
1073    Test harness initializes dist process group.
1074
1075    Test simple things here since they are simpler to debug.
1076    Use TestMultiProc for things that really need to run on multiple nodes
1077    """

    def get_model(
        self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
    ):
        m = ToyModel(
            in_feat=in_feat,
            hidden_feat=hidden_feat,
            out_feat=out_feat,
            ctx_manager=ctx_manager,
        ).to(self.device)
        m.apply(init_weights)
        inputs = torch.rand(bsz, in_feat).to(self.device)
        outputs = m(inputs)
        return m, inputs, outputs

    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_aot_eager(self):
        from torch.nn.parallel import DistributedDataParallel as DDP

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch._dynamo.optimize("aot_eager")(ddp_m)
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_inductor(self):
        from torch.nn.parallel import DistributedDataParallel as DDP

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch._dynamo.optimize("inductor")(ddp_m)
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_graph_split(self):
        """
        Ensures the appropriate number of splits happen (based on bucket size
        and model parameters) by verifying the number of times the
        user-provided compiler is called by the DDPOptimizer, which performs
        the graph splitting.
        """
        assert config.optimize_ddp

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

        # ensure compatibility with dynamo explain

        explain_out = torch._dynamo.explain(ddp_m)(inputs)
        break_reasons = explain_out.break_reasons
        self.assertEqual(len(break_reasons), 3)
        self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))

    @patch.object(config, "optimize_ddp", True)
    def test_graph_split_ctx_manager(self):
        """
        Ensures that we get the right number of splits and that the respective
        context managers' effects are applied to the computation.
        """

        for get_compiler in [
            lambda: CheckSplitsCompiler(),
            lambda: None,
        ]:
            for ctx_manager, output_test in [
                (
                    lambda: torch.autocast(
                        torch.device(self.device).type, torch.float16
                    ),
                    lambda out: self.assertEqual(out.dtype, torch.float16),
                ),
                (torch.enable_grad, lambda out: self.assertTrue(out.requires_grad)),
                (torch.no_grad, lambda out: self.assertTrue(not out.requires_grad)),
            ]:
                m, inputs, correct_outputs = self.get_model(
                    out_feat=1000,
                    hidden_feat=1000,
                    in_feat=1000,
                    ctx_manager=ctx_manager,
                )
                # input layer - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
                # hidden layer - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
                # cap just below one 4MB layer, forcing a split per layer
                bucket_cap_mb = 3.5
                ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=bucket_cap_mb)

                compiler = get_compiler()

                @torch._dynamo.optimize(
                    compiler.compile_fn if compiler else "aot_eager"
                )
                def opt_fn(inputs):
                    return ddp_m(inputs)

                opt_outputs = opt_fn(inputs)
                self.assertTrue(same(correct_outputs, opt_outputs))
                if compiler:
                    self.assertEqual(compiler.compiler_called, 4)

                output_test(opt_outputs)

                # ensure compatibility with dynamo explain

                explain_out = torch._dynamo.explain(ddp_m)(inputs)
                break_reasons = explain_out.break_reasons
                self.assertEqual(len(break_reasons), 4)
                self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))

    @patch.object(config, "optimize_ddp", True)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor(self):
        """
        Same as above, but using the inductor backend.
        We observed issues with the inductor/fx interface in the past.
        """
        assert config.optimize_ddp
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch._dynamo.optimize("inductor")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))

    @torch._inductor.config.patch(
        {"layout_optimization": True, "keep_output_stride": False}
    )
    @patch.object(config, "optimize_ddp", True)
    def _test_graph_split_inductor_layout_optimizations_impl(self, context):
        assert config.optimize_ddp
        channel_dim = 512
        # channel dim must be > 64 for inductor to do layout optimization and use NHWC

        class ToyModelConv(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net = nn.Sequential(
                    *[
                        nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
                        nn.ReLU(),
                    ]
                    + [
                        nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
                        nn.ReLU(),
                    ]
                    + [
                        nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
                        nn.ReLU(),
                    ]
                    + [
                        nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
                        nn.ReLU(),
                    ]
                )

            def forward(self, inputs):
                return self.net(inputs)

        def get_model():
            m = ToyModelConv().to(self.device)
            m.apply(init_weights)
            inputs = torch.rand(2, channel_dim, channel_dim, 128).to(self.device)
            outputs = m(inputs)
            return m, inputs, outputs

        with context():
            m, inputs, correct_outputs = get_model()
            ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

            @torch._dynamo.optimize("inductor")
            def opt_fn(inputs):
                return ddp_m(inputs)

            opt_outputs = opt_fn(inputs)
            self.assertTrue(same(correct_outputs, opt_outputs))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor_layout_optimizations_training(self):
        self._test_graph_split_inductor_layout_optimizations_impl(
            contextlib.nullcontext
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor_layout_optimizations_inference(self):
        self._test_graph_split_inductor_layout_optimizations_impl(torch.no_grad)

    @patch.object(config, "optimize_ddp", True)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor_transpose(self):
        assert config.optimize_ddp

        B = 100
        N = 30
        D = 50
        K = 70

        class Foo(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear0 = nn.Linear(N, K)
                self.linear1 = torch.nn.Linear(D * K, 2048)

            def forward(self, x):
                xt = x.transpose(2, 1)
                xt = self.linear0(xt).flatten(1)
                return self.linear1(xt)

        mod = Foo().to(self.device)

        compiled_mod = torch.compile(mod, backend="inductor")
        ddp_compiled_mod = DDP(compiled_mod, device_ids=self.device_ids)

        x = torch.randn((B, N, D), dtype=torch.float32, device=self.device)
        self.assertTrue(same(mod(x), ddp_compiled_mod(x)))

        x_1 = torch.randn((B * 2, N, D), dtype=torch.float32, device=self.device)
        self.assertTrue(same(mod(x_1), ddp_compiled_mod(x_1)))

        x_2 = torch.randn((B * 3, N, D), dtype=torch.float32, device=self.device)
        self.assertTrue(same(mod(x_2), ddp_compiled_mod(x_2)))

    @patch.object(config, "optimize_ddp", True)
    def test_no_split(self):
        """
        Ensures the DDPOptimizer returns a correct, compiled module without
        introducing graph splits. (Based on model parameters fitting in the bucket)
        """
        # DDP will always do a 'first bucket' with a really small size; so only a tiny model will escape this
        m, inputs, correct_outputs = self.get_model(hidden_feat=5)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250)
        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 1)

    @patch.object(config, "optimize_ddp", True)
    def test_aot_autograd(self):
        """
        Explicitly check AotAutograd family of compilers work,
        since they require example inputs propagated between graph splits.
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch._dynamo.optimize("aot_eager")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        opt_outputs.sum().backward()
        self.assertTrue(same(correct_outputs, opt_outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_custom_layer(self):
        """
        Ensures the appropriate number of splits happen for the custom model
        (based on bucket size and model parameters) by verifying the number
        of times the user-provided compiler is called by the DDPOptimizer,
        which performs the graph splitting.
        """
1357        m, inputs, correct_outputs = get_custom_model(self.device)
1358        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1)
1359
1360        check_splits_compiler = CheckSplitsCompiler()
1361
1362        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
1363        def opt_fn(inputs):
1364            return ddp_m(*inputs)
1365
1366        opt_outputs = opt_fn(inputs)
1367        self.assertTrue(same(correct_outputs, opt_outputs))
1368        self.assertEqual(check_splits_compiler.compiler_called, 3)
1369
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_empty_graph_inductor(self):
        def fn():
            get_world_size = torch.distributed.distributed_c10d.get_world_size()
            return (get_world_size,)

        opt_fn = torch._dynamo.optimize("inductor")(fn)
        res = None
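        # If compiling this parameter-free ("empty") graph raises, res stays
        # None and the assertion below fails, so errors are still surfaced.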
        try:
            res = opt_fn()[0]
        except Exception:
            pass
        self.assertEqual(res, 1)

    @patch.object(config, "optimize_ddp", False)
    def test_ignored_parameters(self):
        """
        Verifies that the DDP graph-split logic skips parameters marked as
        ignored on the DDP module. Hooks up the graph-split optimizer manually
        so the test can peek at its internal state.
        """
        m, inputs, correct_outputs = get_custom_model(self.device)
        parameters_to_ignore = ["seq.2.weight", "seq.4.linear.bias"]
        DDP._set_params_and_buffers_to_ignore_for_model(m, parameters_to_ignore)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
        parameter_ids_to_ignore = [
            id(ddp_m.module.get_parameter(p)) for p in ddp_m.parameters_to_ignore
        ]

        check_splits_compiler = CheckSplitsCompiler()
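        # optimize_ddp is patched off above so that the DDPOptimizer built by
        # hand here is the only one in play; its buckets can then be inspected
        # directly after compilation.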
        ddp_optimizer = DDPOptimizer(
            bucket_bytes_cap=ddp_m.bucket_bytes_cap,
            backend_compile_fn=check_splits_compiler.compile_fn,
        )

        @torch._dynamo.optimize(ddp_optimizer.compile_fn)
        def opt_fn(inputs):
            return ddp_m(*inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 2)
        for b in ddp_optimizer.buckets:
            for p_id in b.param_ids:
                self.assertNotIn(p_id, parameter_ids_to_ignore)

    @patch.object(config, "optimize_ddp", True)
    def test_higher_order_op(self):
        from torch.utils.checkpoint import checkpoint

        N = 1000

        class InnerModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(N, N)
                self.linear2 = torch.nn.Linear(N, N)

            def forward(self, x):
                a = self.linear1(x)
                a = self.linear2(a)
                return a

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.inner_mod1 = InnerModule()
                self.inner_mod2 = InnerModule()

            def forward(self, x):
                a = checkpoint(self.inner_mod1, x, use_reentrant=False)
                a = torch.cos(a)
                a = checkpoint(self.inner_mod2, a, use_reentrant=False)
                a = torch.cos(a)
                return a

        mod = MockModule().cuda()
        mod = DDP(mod, bucket_cap_mb=1)
        x = torch.randn(N, N, device="cuda", requires_grad=True)
        args = (x,)

        backend = "aot_eager"
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

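        # Activation checkpointing appears in the captured graph as a
        # higher-order op, which DDPOptimizer cannot split; it should fail
        # loudly rather than silently miscompile.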
        with self.assertRaisesRegex(
            torch._dynamo.exc.BackendCompilerFailed,
            "DDPOptimizer backend: Found a higher order op in the graph",
        ):
            torch.compile(mod, backend=cnt)(*args)

    def test_fsdp_orig_params_assert(self):
        # Test with basic FSDP wrapping (outer wrap around whole model)
        m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
        fsdp_m = FSDP(m, use_orig_params=False)
        fsdp_m = torch._dynamo.optimize()(fsdp_m)
        self.assertRaisesRegex(
            AssertionError,
            "Dynamo only supports FSDP with use_orig_params=True",
            fsdp_m,
            inputs,
        )

    def test_fsdp_skip_guards(self):
        """
        It's currently difficult to test dynamo guards. Most guards tests are
        indirect: modify something and observe that the guard in question
        failed. In this case, since the FSDP guards were already deemed useless
        and skipping them is expected to have no practical effect, it's pretty
        contrived to even try to make those guards fail. Instead, we observe
        the 'guard source' printed by dynamo's comptime print_guards function.

        Note: comptime prints the guards before they are (or are not)
        installed, so in both cases (skip or no skip) the same guards get
        printed. The difference is that in the skip case, they show up with a
        special 'guard source' which will cause them not to be installed. So
        all we check for is the expected guard source 'local_fsdp_module'.
        """
        global GUARDS_FILE
        GUARDS_FILE = StringIO()

        for skip_guards, expected_guard_source in (
            (True, "local_fsdp_module"),
            (False, "local_unspecialized_nn_module"),
        ):
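            # Reset dynamo between iterations so the model is re-traced and
            # its guards are re-printed under the new skip_fsdp_guards setting.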
            torch._dynamo.reset()

            class ToyModel(nn.Module):
                def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
                    super().__init__()
                    self.net = nn.Sequential(
                        *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
                        + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
                        + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
                        + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
                    )

                def forward(self, inputs):
                    out = self.net(inputs)

                    @comptime
                    def _(ctx):
                        ctx.print_guards(file=GUARDS_FILE)

                    return out

            device = f"cuda:{self.rank}"
            m = ToyModel(
                in_feat=10,
                hidden_feat=5000,
                out_feat=5,
            ).to(device)
            inputs = torch.rand(20, 10).to(device)
            m.apply(init_weights)
            correct_outputs = m(inputs)
            fsdp_m = FSDP(m, use_orig_params=True)

            with torch._dynamo.config.patch(skip_fsdp_guards=skip_guards):
                opt_m = torch._dynamo.optimize("aot_eager")(fsdp_m)
                outputs = opt_m(inputs)

            # Far from an exhaustive check of all the expected guards; just check a couple of them.
            FileCheck().check("""local "L['self']" TYPE_MATCH""").check(
                f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH"""
            ).check(
                f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH"""
            ).run(
                GUARDS_FILE.getvalue()
            )

            self.assertTrue(same(correct_outputs, outputs))

    def test_fsdp_skip_register_attr_or_module(self):
        """
        Ensures that FSDP-managed parameters are not registered as attributes
        in the fx graph; see the `not source.guard_source().is_fsdp_module()`
        check before the call to `register_attr_or_module` in
        variables/builder.py.
        """

        class ToyModel(nn.Module):
            def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
                super().__init__()
                self.net = nn.Sequential(
                    *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
                    + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
                )

            def forward(self, inputs):
                out = self.net(inputs)
                return out

        torch._dynamo.reset()

        device = f"cuda:{self.rank}"
        m = ToyModel(
            in_feat=10,
            hidden_feat=5000,
            out_feat=5,
        ).to(device)
        inputs = torch.rand(20, 10).to(device)
        m.apply(init_weights)
        correct_outputs = m(inputs)
        fsdp_m = FSDP(m, use_orig_params=True)

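        # The debug backend below inspects the captured FX graph: a get_attr
        # node named after one of the FSDP-managed parameters would mean the
        # parameter was baked into the graph instead of being lifted to an input.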
        def debug_compiler(gm, _):
            for node in gm.graph.nodes:
                if node.op == "get_attr":
                    for name in [
                        "l__self___net_0_weight",
                        "l__self___net_0_bias",
                        "l__self___net_2_weight",
                        "l__self___net_2_bias",
                    ]:
                        self.assertNotIn(
                            name,
                            node.name,
                            f"FSDP-managed parameter {name} should not be registered as an attribute",
                        )
            return gm

        opt_m = torch._dynamo.optimize(backend=debug_compiler)(fsdp_m)
        outputs = opt_m(inputs)

        self.assertTrue(same(correct_outputs, outputs))

    def test_fsdp_dup_tensors_same_source(self):
        """
        Tests that FSDP-managed modules' parameters and buffers with the same
        source are de-duplicated, meaning that they are each only passed once
        as a graph input.
        """

        class DuplicateModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._param = torch.randn((3,), device="cuda")
                self._buf = torch.nn.Buffer(
                    torch.randn((3,), requires_grad=False, device="cuda")
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Use `_param` and `_buf` each twice in this compiled forward
                # to exercise whether they are de-duplicated by TorchDynamo
                z = x + self._buf + self._buf
                z += self._param + self._param
                return z

        model = DuplicateModule()
        fsdp_model = FSDP(copy.deepcopy(model), use_orig_params=True)
        fsdp_model = torch._dynamo.optimize("aot_eager")(fsdp_model)
        inp = torch.randn((2, 3), device="cuda")
        local_out = model(inp)
        fsdp_out = fsdp_model(inp)
        self.assertEqual(local_out, fsdp_out)

    @patch.object(config, "guard_nn_modules", True)
    def test_fsdp_dup_tensors_diff_source(self):
        """
        Tests that FSDP-managed modules' parameters and buffers with different
        sources do not result in incorrect AOTAutograd de-dup guards like
        ``a is b``, where ``a`` and ``b`` are certainly not the same. We check
        this by verifying that the model does not recompile on every invocation.
        """

        class BufModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._buf = nn.Buffer(
                    torch.randn((3,), requires_grad=False, device="cuda")
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self._buf

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._param = nn.Parameter(torch.randn((1,), device="cuda"))
                self._buf_module = BufModule()
                # Share the buffer, meaning same tensor but different source
                self._buf = self._buf_module._buf

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Use the same buffer tensor twice in the compiled forward,
                # including a data mutation to trigger de-dup logic
                self._buf.mul_(2)
                z = x + self._buf
                z = self._buf_module(z)
                z += self._param
                return z

        fsdp_model = FSDP(Model(), use_orig_params=True)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        fsdp_model = torch._dynamo.optimize(cnt)(fsdp_model)
        inp = torch.randn((2, 3), device="cuda")
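        # Run many iterations: a bogus `a is b` de-dup guard would fail on
        # every call and force a fresh compile each time.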
        for _ in range(15):
            fsdp_model(inp)
        # Check for no recompiles (if there were incorrect de-dup guards, then
        # the frame count would be equal to the number of forward calls)
        self.assertEqual(cnt.frame_count, 1)

    def test_fsdp_staticmethod(self):
        """
        Tests that Dynamo compiles staticmethods for FSDP-managed modules
        correctly both when the staticmethod is invoked from the class and from
        the object itself.
        """

        class ModuleWithStaticMethod(nn.Module):
            def __init__(self, use_self: bool):
                super().__init__()
                self._use_self = use_self
                torch.manual_seed(42)  # force `_param` to be deterministic
                self._param = nn.Parameter(torch.randn((3,), device="cuda"))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                if self._use_self:
                    z = self._add(x, self._param)
                else:
                    z = ModuleWithStaticMethod._add(x, self._param)
                z *= 2
                return z

            @staticmethod
            def _add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

        model = ModuleWithStaticMethod(False)
        x = torch.randn((2, 3), device="cuda")
        ref_out = model(x)
        test_outs: List[torch.Tensor] = []

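        # Compile each invocation flavor with its own counter; both flavors
        # should match the eager reference computed above.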
        for use_self in (False, True):
            model = ModuleWithStaticMethod(use_self)
            fsdp_model = FSDP(model, use_orig_params=True)
            cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
            fsdp_model = torch._dynamo.optimize(cnt)(fsdp_model)
            test_outs.append(fsdp_model(x))
            # Check for no recompiles, which could happen if incorrectly
            # passing args to the staticmethod (e.g. doubly passing `self`).
            # A single frame is expected for one forward; its graph should
            # contain the add and the imul.
            self.assertEqual(cnt.frame_count, 1)
        for test_out in test_outs:
            self.assertEqual(test_out, ref_out)

    def test_async_subclass_no_specialize(self):
        cnt = torch._dynamo.testing.CompileCounterWithBackend("eager")

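        # `_maybe_wrap_tensor` wraps the input in an AsyncCollectiveTensor
        # subclass; with dynamic=True, the two differently sized inputs below
        # should share one compiled frame rather than specializing per size.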
        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def f(x):
            return x + 1

        f(_maybe_wrap_tensor(torch.randn(10)))
        f(_maybe_wrap_tensor(torch.randn(12)))

        self.assertEqual(cnt.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()