xref: /aosp_15_r20/external/pytorch/test/distributed/fsdp/test_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2
3import random
4import sys
5import unittest
6from collections import OrderedDict
7from dataclasses import dataclass
8from typing import List
9
10import torch
11import torch.nn as nn
12from torch import distributed as dist
13from torch.distributed.utils import _apply_to_tensors, _replace_by_prefix
14from torch.testing._internal.common_utils import (
15    instantiate_parametrized_tests,
16    parametrize,
17    run_tests,
18    subtest,
19    TEST_WITH_DEV_DBG_ASAN,
20    TestCase,
21)
22
23
# This module exercises torch.distributed helpers; if the current torch build
# has no distributed support, exit 0 so the test run is treated as skipped,
# not failed.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

# dev-asan builds have known issues with torch + multiprocessing spawn, so
# skip the whole module there as well (again exiting 0, i.e. "skipped").
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
34
35
36class TestUtils(TestCase):
37    @parametrize(
38        "devices", [["cpu"], ["cuda"], subtest(["cpu", "cuda"], name="cpu_cuda")]
39    )
40    def test_apply_to_tensors(self, devices):
41        if "cuda" in devices and (
42            not torch.cuda.is_available() or torch.cuda.device_count() < 1
43        ):
44            raise unittest.SkipTest("Skipped due to lack of GPU")
45
46        expected = 0
47
48        def get_a_tensor():
49            """Return a random tensor on random device."""
50            dev = random.choice(devices)
51            shape = random.choice(((1), (2, 3), (4, 5, 6), (7, 8, 9, 10)))
52            t = torch.rand(shape).to(dev)
53            nonlocal expected
54            expected += t.numel()
55            return t
56
57        @dataclass
58        class NonFrozenDataClass:
59            some_key: str
60            some_float: float
61            some_tensor: List[torch.Tensor]
62
63        @dataclass(frozen=True)
64        class FrozenDataClass:
65            some_key: str
66            some_float: float
67            some_tensor: List[torch.Tensor]
68
69        # create a mixed bag of data.
70        data = [1, "str"]
71        data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3})
72        data.insert(0, {"x", get_a_tensor(), get_a_tensor()})
73        data.append(([1], get_a_tensor(), (1), [get_a_tensor()], {1, 2}))
74        data.append(
75            {"non_frozen_ds": NonFrozenDataClass("some_key", 1.0, [get_a_tensor()])}
76        )
77        data.append({"frozen_ds": FrozenDataClass("some_key", 1.0, [get_a_tensor()])})
78        od = OrderedDict()
79        od["k"] = "value"
80        data.append(od)
81
82        total = 0
83
84        def fn(t):
85            nonlocal total
86            total += t.numel()
87            return t
88
89        new_data = _apply_to_tensors(fn, data)
90        self.assertEqual(total, expected)
91        for i, v in enumerate(data):
92            self.assertEqual(type(new_data[i]), type(v))
93
94    def test_replace_by_prefix(self):
95        state_dict = {
96            "layer.a": torch.tensor(1),
97            "abc.layer.def": torch.tensor(2),
98            "layer.b": torch.tensor(3),
99        }
100        original_state_dict = state_dict.copy()
101        _replace_by_prefix(state_dict, "layer.", "module.layer.")
102        assert state_dict == {
103            "module.layer.a": torch.tensor(1),
104            "abc.layer.def": torch.tensor(2),
105            "module.layer.b": torch.tensor(3),
106        }
107        _replace_by_prefix(state_dict, "module.layer.", "layer.")
108        assert state_dict == original_state_dict
109
110    def test_packed_sequence(self):
111        """Test to ensure RNN packed sequences are modified correctly."""
112        rnn = nn.RNN(5, 5)
113
114        x = torch.rand((5, 1, 5), dtype=torch.float)
115        seq_length = torch.tensor([4], dtype=torch.int)
116
117        def fill_fn(x):
118            x.fill_(0)
119
120        x = nn.utils.rnn.pack_padded_sequence(x, seq_length)
121        x, h = rnn(x)
122        x = _apply_to_tensors(fill_fn, x)
123        x, _ = nn.utils.rnn.pad_packed_sequence(x)
124        self.assertEqual(torch.sum(x), 0)
125
126
# Materialize one concrete test method per @parametrize configuration
# (e.g. test_apply_to_tensors_cpu, _cuda, _cpu_cuda).
instantiate_parametrized_tests(TestUtils)

if __name__ == "__main__":
    run_tests()
131