import argparse
import functools
import importlib
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

try:
    from .torchbench import setup_torchbench_cwd
except ImportError:
    from torchbench import setup_torchbench_cwd

from transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead
from transformers.models.t5.modeling_t5 import T5Block


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")
    os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355")
    os.environ["RANK"] = os.getenv("RANK", "0")
    os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1")
    dist.init_process_group("nccl")


def cleanup():
    dist.destroy_process_group()


class CustomLinear(torch.nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(a, b))

    def forward(self, x):
        return torch.mm(x, self.weight)


class MyModule(torch.nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(a, b),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            *[nn.Linear(10, 10000), nn.ReLU()]
            + [nn.Linear(10000, 10000), nn.ReLU()]
            + [MyModule(10000, 10000)]
            + [MyModule(10000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [nn.Linear(1000, 5)]
        )

    def forward(self, x):
        return self.net(x)


def model_iter_fn(model, example_inputs, collect_outputs=False):
    outputs = model(*example_inputs)
    loss = reduce_to_scalar_loss(outputs)
    loss.backward()
    if collect_outputs:
        return outputs


def get_model(args):
    if args.torchbench_model:
        old_cwd = setup_torchbench_cwd()
        module = importlib.import_module(
            f"torchbenchmark.models.{args.torchbench_model}"
        )
        benchmark_cls = getattr(module, "Model", None)
        bm = benchmark_cls(
            test="train", device=args.device, batch_size=args.batch_size
        )
        model, inputs = bm.get_module()
    elif args.toy_model:
        model = ToyModel()
        inputs = (torch.randn(20, 10),)
    else:
        raise argparse.ArgumentError(
            args.torchbench_model, message="Must specify a model"
        )

    return model, inputs


def fsdp_checkpointing_base(model, blocks):
    """apply activation checkpointing to model
    returns None as model is updated directly
    """
    non_reentrant_wrapper = functools.partial(
        checkpoint_wrapper,
        offload_to_cpu=False,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )

    def check_fn(submodule):
        return isinstance(submodule, blocks)

    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )


MODEL_FSDP_WRAP = {
    "toy_model": (MyModule,),
    "hf_Bert": (BertLayer, BertLMPredictionHead),
    "hf_T5": (T5Block,),
}


def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True):
    wrap_policy = None
    blocks = MODEL_FSDP_WRAP[
        "toy_model" if model.__class__ is ToyModel else args.torchbench_model
    ]
    if use_wrap_policy:
        wrap_policy = ModuleWrapPolicy(blocks)

    model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True)
    if use_checkpointing:
        fsdp_checkpointing_base(model, blocks)
    return model
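

# Usage sketch (illustrative, not part of the benchmark harness above): a minimal
# single-rank wiring of these helpers that builds the toy model, shards it with FSDP
# plus activation checkpointing, and runs one forward/backward step. It assumes one
# CUDA device and the NCCL backend configured by setup(); the `example_args`
# namespace below is hypothetical and only supplies the fields get_model() reads.
if __name__ == "__main__":
    example_args = argparse.Namespace(
        torchbench_model=None, toy_model=True, device="cuda", batch_size=None
    )
    setup(rank=0, world_size=1)
    model, inputs = get_model(example_args)
    model = model.to("cuda")
    inputs = tuple(t.to("cuda") for t in inputs)
    # Wrap each MyModule block in its own FSDP unit and checkpoint its activations.
    model = apply_fsdp(
        example_args, model, use_checkpointing=True, use_wrap_policy=True
    )
    model_iter_fn(model, inputs)  # one forward + backward pass
    cleanup()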