xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/dist_util.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport functools
3*da0073e9SAndroid Build Coastguard Workerimport importlib
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerimport torch.distributed as dist
8*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn
9*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import reduce_to_scalar_loss
10*da0073e9SAndroid Build Coastguard Workerfrom torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
11*da0073e9SAndroid Build Coastguard Worker    apply_activation_checkpointing,
12*da0073e9SAndroid Build Coastguard Worker    checkpoint_wrapper,
13*da0073e9SAndroid Build Coastguard Worker    CheckpointImpl,
14*da0073e9SAndroid Build Coastguard Worker)
15*da0073e9SAndroid Build Coastguard Workerfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP
16*da0073e9SAndroid Build Coastguard Workerfrom torch.distributed.fsdp.wrap import ModuleWrapPolicy
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Workertry:
20*da0073e9SAndroid Build Coastguard Worker    from .torchbench import setup_torchbench_cwd
21*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
22*da0073e9SAndroid Build Coastguard Worker    from torchbench import setup_torchbench_cwd
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Workerfrom transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead
25*da0073e9SAndroid Build Coastguard Workerfrom transformers.models.t5.modeling_t5 import T5Block
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker
def setup(rank, world_size):
    """Initialise the default NCCL process group for this worker.

    Rendezvous settings already present in the environment take precedence.
    Otherwise ``MASTER_ADDR``/``MASTER_PORT`` default to ``localhost:12355``,
    and ``RANK``/``WORLD_SIZE`` default to the ``rank``/``world_size``
    arguments. Previously the arguments were ignored and ``"0"``/``"1"`` were
    used unconditionally.

    Args:
        rank: this process's rank, used when ``RANK`` is not set.
        world_size: total number of processes, used when ``WORLD_SIZE`` is
            not set.
    """
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    os.environ.setdefault("RANK", str(rank))
    os.environ.setdefault("WORLD_SIZE", str(world_size))
    # No init_method is given, so init_process_group falls back to env://,
    # which reads the four variables populated above.
    dist.init_process_group("nccl")
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker
def cleanup():
    """Destroy the default process group; the counterpart of ``setup``."""
    dist.destroy_process_group()
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker
class CustomLinear(torch.nn.Module):
    """Bias-free linear map implemented with a raw matrix multiply.

    Holds a single ``(a, b)`` weight parameter initialised with
    ``torch.randn``. The forward pass is ``torch.mm(x, weight)``, so ``x``
    must be a 2-D tensor with ``a`` columns.
    """

    def __init__(self, a, b):
        super().__init__()
        initial = torch.randn(a, b)
        self.weight = nn.Parameter(initial)

    def forward(self, x):
        weight = self.weight
        return torch.mm(x, weight)
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker
class MyModule(torch.nn.Module):
    """``Linear(a, b)`` followed by ``ReLU``, held in ``self.net``.

    ``ToyModel`` stacks several of these, and ``MODEL_FSDP_WRAP`` names this
    class as the toy model's FSDP wrap / checkpointing unit.
    """

    def __init__(self, a, b):
        super().__init__()
        layers = [nn.Linear(a, b), nn.ReLU()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the linear layer then the ReLU to ``x``."""
        return self.net(x)
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker
class ToyModel(nn.Module):
    """Synthetic deep MLP used as the ``toy_model`` benchmark target.

    Maps 10 input features to 5 outputs through a stack of Linear/ReLU
    layers (roughly 217M parameters, dominated by the two 10000x10000
    projections). The nine ``MyModule`` blocks are the unit that
    ``MODEL_FSDP_WRAP`` lists for ``"toy_model"``.
    """

    def __init__(self):
        super().__init__()
        # Layers are constructed in stack order, so random initialisation
        # draws from the RNG in the same sequence as the layer order.
        stack = [
            nn.Linear(10, 10000),
            nn.ReLU(),
            nn.Linear(10000, 10000),
            nn.ReLU(),
            MyModule(10000, 10000),
            MyModule(10000, 1000),
        ]
        stack += [MyModule(1000, 1000) for _ in range(7)]
        stack.append(nn.Linear(1000, 5))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the whole layer stack."""
        return self.net(x)
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker
def model_iter_fn(model, example_inputs, collect_outputs=False):
    """Run one forward and backward pass of ``model``.

    Args:
        model: callable module invoked as ``model(*example_inputs)``.
        example_inputs: positional inputs for the forward call.
        collect_outputs: when True, return the forward outputs.

    Returns:
        The forward outputs if ``collect_outputs`` is true, else ``None``.
        Gradients are accumulated on the model's parameters either way.
    """
    result = model(*example_inputs)
    reduce_to_scalar_loss(result).backward()
    return result if collect_outputs else None
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard Worker
def get_model(args):
    """Build the model and example inputs selected on the command line.

    Args:
        args: parsed CLI namespace. Reads ``torchbench_model`` (name of a
            ``torchbenchmark.models`` submodule) and ``toy_model`` (flag).
            For TorchBench models it also reads ``device`` and
            ``batch_size``.

    Returns:
        ``(model, inputs)``. For TorchBench this is the pair returned by the
        benchmark's ``get_module()``. For the toy model it is a fresh
        ``ToyModel`` and a 1-tuple holding a random ``(20, 10)`` tensor.

    Raises:
        argparse.ArgumentError: if neither a TorchBench model nor the toy
            model was requested.
        AttributeError: if the TorchBench module defines no ``Model`` class.
    """
    if args.torchbench_model:
        # Called for its side effect. Presumably it switches the cwd to the
        # TorchBench checkout so the import resolves; the original code named
        # its return value ``old_cwd`` but never used it. Confirm against
        # torchbench.py.
        setup_torchbench_cwd()
        module_name = f"torchbenchmark.models.{args.torchbench_model}"
        module = importlib.import_module(module_name)
        benchmark_cls = getattr(module, "Model", None)
        if benchmark_cls is None:
            # Fail clearly instead of "'NoneType' object is not callable".
            raise AttributeError(f"{module_name} does not define a 'Model' class")
        bm = benchmark_cls(test="train", device=args.device, batch_size=args.batch_size)
        model, inputs = bm.get_module()
    elif args.toy_model:
        model = ToyModel()
        inputs = (torch.randn(20, 10),)
    else:
        # ArgumentError's first parameter must be an argparse Action or None.
        # Passing the (falsy) option value crashes for e.g. "" with an
        # AttributeError before the intended error is raised.
        raise argparse.ArgumentError(None, message="Must specify a model")

    return model, inputs
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker
def fsdp_checkpointing_base(model, blocks):
    """Apply non-reentrant activation checkpointing to ``model`` in place.

    Every submodule that is an instance of ``blocks`` (a class or tuple of
    classes, as accepted by ``isinstance``) is replaced by a checkpoint
    wrapper.

    Returns:
        None; ``model`` is modified directly.
    """
    # NOTE(review): verify that checkpoint_wrapper in the pinned torch version
    # still takes offload_to_cpu as a named parameter rather than forwarding
    # it through **checkpoint_fn_kwargs.
    wrapper_fn = functools.partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        offload_to_cpu=False,
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=wrapper_fn,
        check_fn=lambda submodule: isinstance(submodule, blocks),
    )
127*da0073e9SAndroid Build Coastguard Worker
128*da0073e9SAndroid Build Coastguard Worker
# Maps a model key ("toy_model" or a TorchBench model name) to the module
# classes that apply_fsdp uses as FSDP auto-wrap units and as activation
# checkpointing targets.
MODEL_FSDP_WRAP = dict(
    toy_model=(MyModule,),
    hf_Bert=(BertLayer, BertLMPredictionHead),
    hf_T5=(T5Block,),
)
134*da0073e9SAndroid Build Coastguard Worker
135*da0073e9SAndroid Build Coastguard Worker
def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True):
    """Wrap ``model`` in FSDP, optionally auto-wrapping and checkpointing.

    Args:
        args: parsed CLI namespace. ``args.torchbench_model`` selects the
            ``MODEL_FSDP_WRAP`` entry when ``model`` is not a ``ToyModel``.
        model: the module to shard.
        use_checkpointing: if True, apply non-reentrant activation
            checkpointing to the model's block types after FSDP wrapping.
        use_wrap_policy: if True, wrap each block type in its own FSDP unit
            via ``ModuleWrapPolicy``.

    Returns:
        The ``FullyShardedDataParallel``-wrapped model.

    Raises:
        KeyError: if a wrap policy or checkpointing is requested for a model
            with no ``MODEL_FSDP_WRAP`` entry.
    """
    wrap_policy = None
    blocks = None
    # Block types are only needed for auto-wrapping or checkpointing. Plain
    # FSDP wrapping therefore works for any model, even one not listed in
    # MODEL_FSDP_WRAP.
    if use_wrap_policy or use_checkpointing:
        key = "toy_model" if isinstance(model, ToyModel) else args.torchbench_model
        try:
            blocks = MODEL_FSDP_WRAP[key]
        except KeyError:
            raise KeyError(
                f"no FSDP block types registered for {key!r}; "
                f"known models: {sorted(MODEL_FSDP_WRAP)}"
            ) from None
    if use_wrap_policy:
        wrap_policy = ModuleWrapPolicy(blocks)

    model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True)
    if use_checkpointing:
        fsdp_checkpointing_base(model, blocks)
    return model
148