xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/dist_util.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import argparse
import functools
import importlib
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


try:
    from .torchbench import setup_torchbench_cwd
except ImportError:
    from torchbench import setup_torchbench_cwd

from transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead
from transformers.models.t5.modeling_t5 import T5Block


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")
    os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355")
    os.environ["RANK"] = os.getenv("RANK", "0")
    os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1")
    dist.init_process_group("nccl")


def cleanup():
    dist.destroy_process_group()

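# Illustrative usage sketch (not part of the original utilities): a driver
# script would typically bracket its distributed work with these two helpers.
# Note that setup() only fills in environment-variable defaults; values already
# exported by the launcher take precedence, and the rank/world_size arguments
# are not otherwise used because init_process_group reads RANK/WORLD_SIZE from
# the environment.
#
#     setup(rank=0, world_size=1)
#     try:
#         ...  # build the model and run training iterations
#     finally:
#         cleanup()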

class CustomLinear(torch.nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(a, b))

    def forward(self, x):
        return torch.mm(x, self.weight)


class MyModule(torch.nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(a, b),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            *[nn.Linear(10, 10000), nn.ReLU()]
            + [nn.Linear(10000, 10000), nn.ReLU()]
            + [MyModule(10000, 10000)]
            + [MyModule(10000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [nn.Linear(1000, 5)]
        )

    def forward(self, x):
        return self.net(x)


def model_iter_fn(model, example_inputs, collect_outputs=False):
    outputs = model(*example_inputs)
    loss = reduce_to_scalar_loss(outputs)
    loss.backward()
    if collect_outputs:
        return outputs

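# Illustrative sketch (an assumption, not from the original file): a single
# forward/backward step on the toy model via model_iter_fn. The (20, 10) input
# matches what get_model() below hands out when args.toy_model is set.
#
#     model = ToyModel()
#     example_inputs = (torch.randn(20, 10),)
#     model_iter_fn(model, example_inputs)  # reduces outputs to a scalar loss and backprops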

def get_model(args):
    if args.torchbench_model:
        old_cwd = setup_torchbench_cwd()
        module = importlib.import_module(
            f"torchbenchmark.models.{args.torchbench_model}"
        )
        benchmark_cls = getattr(module, "Model", None)
        bm = benchmark_cls(test="train", device=args.device, batch_size=args.batch_size)
        model, inputs = bm.get_module()
    elif args.toy_model:
        model = ToyModel()
        inputs = (torch.randn(20, 10),)
    else:
        raise argparse.ArgumentError(
            args.torchbench_model, message="Must specify a model"
        )

    return model, inputs

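# Illustrative sketch (assumption, not part of the original file): get_model()
# only reads a handful of attributes, so a plain argparse.Namespace is enough
# when driving it outside the benchmark harness; `toy_args` is a hypothetical name.
#
#     toy_args = argparse.Namespace(
#         torchbench_model=None, toy_model=True, device="cuda", batch_size=None
#     )
#     model, inputs = get_model(toy_args)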

def fsdp_checkpointing_base(model, blocks):
    """Apply activation checkpointing to ``model``.

    Returns None; the model is updated in place.
    """
    non_reentrant_wrapper = functools.partial(
        checkpoint_wrapper,
        offload_to_cpu=False,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )

    def check_fn(submodule):
        return isinstance(submodule, blocks)

    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )

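# Illustrative sketch (assumption): the helper can also be called directly,
# e.g. with the block types listed in MODEL_FSDP_WRAP below to wrap every
# T5Block of an hf_T5 model in a non-reentrant checkpoint:
#
#     fsdp_checkpointing_base(model, MODEL_FSDP_WRAP["hf_T5"])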

MODEL_FSDP_WRAP = {
    "toy_model": (MyModule,),
    "hf_Bert": (BertLayer, BertLMPredictionHead),
    "hf_T5": (T5Block,),
}


def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True):
    wrap_policy = None
    blocks = MODEL_FSDP_WRAP[
        "toy_model" if model.__class__ is ToyModel else args.torchbench_model
    ]
    if use_wrap_policy:
        wrap_policy = ModuleWrapPolicy(blocks)

    model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True)
    if use_checkpointing:
        fsdp_checkpointing_base(model, blocks)
    return model
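
# Illustrative end-to-end sketch (not part of the original file; the function
# name and the args object are assumptions). It strings the helpers above
# together the way a single-process driver might: initialize the NCCL process
# group, build the toy model, shard it with FSDP plus activation checkpointing,
# and run one training iteration. Requires a CUDA device since the backend is NCCL.
def _example_fsdp_toy_run():
    setup(rank=0, world_size=1)
    try:
        args = argparse.Namespace(
            torchbench_model=None, toy_model=True, device="cuda", batch_size=None
        )
        model, inputs = get_model(args)
        model = model.to("cuda")
        inputs = tuple(t.to("cuda") for t in inputs)
        fsdp_model = apply_fsdp(args, model, use_checkpointing=True, use_wrap_policy=True)
        model_iter_fn(fsdp_model, inputs)
    finally:
        cleanup()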