import argparse
import functools
import importlib
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


try:
    from .torchbench import setup_torchbench_cwd
except ImportError:
    from torchbench import setup_torchbench_cwd

from transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead
from transformers.models.t5.modeling_t5 import T5Block


def setup(rank, world_size):
    """Populate the rendezvous environment variables and start the process group.

    Values already present in the environment always win. Missing values
    fall back to:

    * ``MASTER_ADDR`` -> ``"localhost"``
    * ``MASTER_PORT`` -> ``"12355"``
    * ``RANK`` -> ``str(rank)``
    * ``WORLD_SIZE`` -> ``str(world_size)``

    Args:
        rank: Rank to advertise when ``RANK`` is not set in the environment.
        world_size: World size to advertise when ``WORLD_SIZE`` is not set.
    """
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    # Previously these fell back to the literals "0"/"1" and the arguments
    # were ignored; using the arguments keeps single-process callers
    # (rank=0, world_size=1) unchanged.
    os.environ.setdefault("RANK", str(rank))
    os.environ.setdefault("WORLD_SIZE", str(world_size))
    # Only the backend name is passed; presumably the process group then
    # reads the variables set above (env:// style init) -- see torch docs.
    dist.init_process_group("nccl")


def cleanup():
    """Tear down the process group created by :func:`setup`."""
    dist.destroy_process_group()


class CustomLinear(torch.nn.Module):
    """Bias-free linear map implemented as ``torch.mm(x, weight)``."""

    def __init__(self, a, b):
        """Create a randomly initialised ``(a, b)`` weight parameter."""
        super().__init__()
        self.weight = nn.Parameter(torch.randn(a, b))

    def forward(self, x):
        """Return the matrix product of ``x`` and the weight."""
        return torch.mm(x, self.weight)


class MyModule(torch.nn.Module):
    """``nn.Linear(a, b)`` followed by ``nn.ReLU``, wrapped in a Sequential."""

    def __init__(self, a, b):
        """Build the Linear(a, b) + ReLU stack."""
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(a, b),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the Linear + ReLU stack to ``x``."""
        return self.net(x)


class ToyModel(nn.Module):
    """Feed-forward toy network: 10 input features in, 5 output features out.

    The stack is Linear(10, 10000) + ReLU, Linear(10000, 10000) + ReLU,
    MyModule(10000, 10000), MyModule(10000, 1000), seven
    MyModule(1000, 1000) blocks, and a final Linear(1000, 5).
    """

    def __init__(self):
        super().__init__()
        # Modules are created in the same order as they appear in the
        # network so parameter initialisation consumes the RNG identically.
        layers = [
            nn.Linear(10, 10000),
            nn.ReLU(),
            nn.Linear(10000, 10000),
            nn.ReLU(),
            MyModule(10000, 10000),
            MyModule(10000, 1000),
        ]
        layers.extend(MyModule(1000, 1000) for _ in range(7))
        layers.append(nn.Linear(1000, 5))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        result = self.net(x)
        return result


def model_iter_fn(model, example_inputs, collect_outputs=False):
    """Run one forward and backward pass of ``model``.

    Args:
        model: Callable module invoked as ``model(*example_inputs)``.
        example_inputs: Iterable of positional inputs for the model.
        collect_outputs: When true, return the forward outputs.

    Returns:
        The forward outputs if ``collect_outputs`` is true, otherwise None.
    """
    forward_result = model(*example_inputs)
    # Collapse the outputs to a scalar so backward() can be called on it.
    reduce_to_scalar_loss(forward_result).backward()
    return forward_result if collect_outputs else None


def get_model(args):
    """Build the model and example inputs selected by ``args``.

    If ``args.torchbench_model`` is set, the module
    ``torchbenchmark.models.<name>`` is imported and its ``Model`` class is
    instantiated with ``test="train"``, ``args.device`` and
    ``args.batch_size``; the result of ``get_module()`` is returned.
    Otherwise, if ``args.toy_model`` is set, a ``ToyModel`` and a single
    random ``(20, 10)`` input tensor are returned.

    Returns:
        A ``(model, inputs)`` pair.

    Raises:
        argparse.ArgumentError: If neither model option is set.
    """
    if args.torchbench_model:
        # Called for its effect only; the return value was previously bound
        # to an unused local.
        setup_torchbench_cwd()
        module = importlib.import_module(
            f"torchbenchmark.models.{args.torchbench_model}"
        )
        benchmark_cls = getattr(module, "Model", None)
        bm = benchmark_cls(test="train", device=args.device, batch_size=args.batch_size)
        model, inputs = bm.get_module()
    elif args.toy_model:
        model = ToyModel()
        inputs = (torch.randn(20, 10),)
    else:
        # ArgumentError expects an argparse Action (or None) as its first
        # argument, not the option's value; None yields the bare message.
        raise argparse.ArgumentError(None, message="Must specify a model")

    return model, inputs


def fsdp_checkpointing_base(model, blocks):
    """Apply non-reentrant activation checkpointing to ``model`` in place.

    Every submodule that is an instance of ``blocks`` (a class or tuple of
    classes, as accepted by ``isinstance``) is wrapped with
    ``checkpoint_wrapper`` configured with ``offload_to_cpu=False`` and
    ``CheckpointImpl.NO_REENTRANT``.

    Returns:
        None; the model is updated directly.
    """
    non_reentrant_wrapper = functools.partial(
        checkpoint_wrapper,
        offload_to_cpu=False,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )

    def check_fn(submodule):
        return isinstance(submodule, blocks)

    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )


# Module classes used as FSDP wrapping / checkpointing units, keyed by
# model name ("toy_model" or a torchbench model name).
MODEL_FSDP_WRAP = {
    "toy_model": (MyModule,),
    "hf_Bert": (BertLayer, BertLMPredictionHead),
    "hf_T5": (T5Block,),
}


def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True):
    """Wrap ``model`` in FSDP, optionally with auto-wrapping and checkpointing.

    Args:
        args: Parsed options; ``args.torchbench_model`` selects the entry in
            ``MODEL_FSDP_WRAP`` when ``model`` is not a ``ToyModel``.
        model: Module to wrap.
        use_checkpointing: If true, apply activation checkpointing to the
            wrapped model via ``fsdp_checkpointing_base``.
        use_wrap_policy: If true, pass a ``ModuleWrapPolicy`` built from the
            model's block classes as ``auto_wrap_policy``.

    Returns:
        The FSDP-wrapped model.

    Raises:
        KeyError: If block classes are needed (wrap policy or checkpointing
            requested) and the model has no entry in ``MODEL_FSDP_WRAP``.
    """
    blocks = None
    # Look the block classes up only when something consumes them, so that a
    # model without a table entry can still be wrapped as a single FSDP unit.
    if use_wrap_policy or use_checkpointing:
        model_key = (
            "toy_model" if model.__class__ is ToyModel else args.torchbench_model
        )
        blocks = MODEL_FSDP_WRAP[model_key]

    wrap_policy = ModuleWrapPolicy(blocks) if use_wrap_policy else None

    model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True)
    if use_checkpointing:
        fsdp_checkpointing_base(model, blocks)
    return model