# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    subtest,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestInput(FSDPTest):
    @property
    def world_size(self):
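        # A single rank suffices here: the test exercises input-container
        # handling, not cross-rank sharding.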
        return 1

    @skip_if_lt_x_gpu(1)
    @parametrize("input_cls", [subtest(dict, name="dict"), subtest(list, name="list")])
    def test_input_type(self, input_cls):
        """Test FSDP with an input that is a list or a dict; runs on a single GPU."""

        class Model(Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
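                # Unwrap the container: a list carries the tensor at index 0,
                # a dict carries it under the "in" key.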
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

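        # Wrap the model with FSDP, then move the wrapper to the GPU.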
        model = FSDP(Model()).cuda()
        optim = SGD(model.parameters(), lr=0.1)

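        # Run a few optimizer steps to check that forward, backward, and the
        # parameter update all work with the container-wrapped input.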
        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
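            # Require grad on the input so the backward pass also reaches the
            # data, not just the model parameters.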
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                self.assertTrue(input_cls is dict)
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()


instantiate_parametrized_tests(TestInput)

if __name__ == "__main__":
    run_tests()