# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    subtest,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestInput(FSDPTest):
    @property
    def world_size(self):
        return 1

    @skip_if_lt_x_gpu(1)
    @parametrize("input_cls", [subtest(dict, name="dict"), subtest(list, name="list")])
    def test_input_type(self, input_cls):
        """Test FSDP with input being a list or a dict, only single GPU."""

        class Model(Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

        model = FSDP(Model()).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                self.assertTrue(input_cls is dict)
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()


instantiate_parametrized_tests(TestInput)

if __name__ == "__main__":
    run_tests()