import argparse
import functools
import importlib
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


try:
    from .torchbench import setup_torchbench_cwd
except ImportError:
    from torchbench import setup_torchbench_cwd

from transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead
from transformers.models.t5.modeling_t5 import T5Block


def setup(rank, world_size):
    # Fill in single-process defaults only when a launcher (torchrun, etc.)
    # has not already populated the rendezvous environment variables.
    os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")
    os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355")
    os.environ["RANK"] = os.getenv("RANK", str(rank))
    os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", str(world_size))
    dist.init_process_group("nccl")


def cleanup():
    dist.destroy_process_group()


class CustomLinear(torch.nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(a, b))

    def forward(self, x):
        return torch.mm(x, self.weight)


class MyModule(torch.nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(a, b),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            *[nn.Linear(10, 10000), nn.ReLU()]
            + [nn.Linear(10000, 10000), nn.ReLU()]
            + [MyModule(10000, 10000)]
            + [MyModule(10000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [nn.Linear(1000, 5)]
        )

    def forward(self, x):
        return self.net(x)


def model_iter_fn(model, example_inputs, collect_outputs=False):
    # One training step: forward, reduce the output to a scalar loss, backward.
    outputs = model(*example_inputs)
    loss = reduce_to_scalar_loss(outputs)
    loss.backward()
    if collect_outputs:
        return outputs


def get_model(args):
    if args.torchbench_model:
        old_cwd = setup_torchbench_cwd()
        module = importlib.import_module(
            f"torchbenchmark.models.{args.torchbench_model}"
        )
        benchmark_cls = getattr(module, "Model", None)
        if benchmark_cls is None:
            raise RuntimeError(
                f"Unable to find class Model in {args.torchbench_model}"
            )
        bm = benchmark_cls(test="train", device=args.device, batch_size=args.batch_size)
        model, inputs = bm.get_module()
    elif args.toy_model:
        model = ToyModel()
        inputs = (torch.randn(20, 10),)
    else:
        raise argparse.ArgumentError(None, "Must specify a model")

    return model, inputs


def fsdp_checkpointing_base(model, blocks):
    """Apply activation checkpointing to `model`.

    Returns None; the model is updated in place.
    """
    non_reentrant_wrapper = functools.partial(
        checkpoint_wrapper,
        offload_to_cpu=False,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )

    def check_fn(submodule):
        return isinstance(submodule, blocks)

    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )


# Block types to wrap, per model name; these units get their own FSDP shard
# and (optionally) their own activation-checkpointing wrapper.
MODEL_FSDP_WRAP = {
    "toy_model": (MyModule,),
    "hf_Bert": (BertLayer, BertLMPredictionHead),
    "hf_T5": (T5Block,),
}
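

# Illustrative sketch (not one of the benchmark entry points; the helper name
# is hypothetical): how `fsdp_checkpointing_base` composes with
# `MODEL_FSDP_WRAP`. Each `MyModule` block of the toy model is wrapped in a
# non-reentrant `checkpoint_wrapper` in place, so its activations are
# recomputed during backward rather than stored.
def _example_toy_checkpointing():
    model = ToyModel()
    fsdp_checkpointing_base(model, MODEL_FSDP_WRAP["toy_model"])
    return model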


def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True):
    wrap_policy = None
    blocks = MODEL_FSDP_WRAP[
        "toy_model" if model.__class__ is ToyModel else args.torchbench_model
    ]
    if use_wrap_policy:
        # With a wrap policy, each listed block type becomes its own FSDP unit;
        # otherwise the whole model is sharded as a single flat unit.
        wrap_policy = ModuleWrapPolicy(blocks)

    model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True)
    if use_checkpointing:
        fsdp_checkpointing_base(model, blocks)
    return model
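

# Minimal single-process usage sketch (hypothetical; the real benchmarks drive
# these helpers from their own argument parser). Assumes one CUDA device, an
# NCCL build of torch.distributed, and the rendezvous defaults filled in by
# `setup` above.
if __name__ == "__main__":
    setup(rank=0, world_size=1)
    model = ToyModel().to("cuda")
    model = apply_fsdp(
        argparse.Namespace(torchbench_model=None),  # unused for ToyModel
        model,
        use_checkpointing=True,
        use_wrap_policy=True,
    )
    inputs = (torch.randn(20, 10, device="cuda"),)
    model_iter_fn(model, inputs)  # one forward/backward step
    cleanup()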