# xref: /aosp_15_r20/external/pytorch/test/distributed/checkpoint/test_planner.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2
3import sys
4
5import torch
6import torch.distributed.checkpoint as dcp
7import torch.nn as nn
8from torch.distributed._shard.sharded_tensor import (
9    Shard,
10    ShardedTensor,
11    ShardedTensorMetadata,
12    ShardMetadata,
13)
14from torch.distributed._shard.sharded_tensor.metadata import (
15    TensorProperties as TensorProperties_Shard,
16)
17from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
18from torch.distributed.checkpoint.api import CheckpointException
19from torch.distributed.checkpoint.default_planner import (
20    _create_default_local_metadata,
21    create_default_global_save_plan,
22    create_default_local_load_plan,
23    create_default_local_save_plan,
24    DefaultLoadPlanner,
25)
26from torch.distributed.checkpoint.metadata import (
27    BytesStorageMetadata,
28    ChunkStorageMetadata,
29    MetadataIndex,
30    TensorProperties,
31    TensorStorageMetadata,
32)
33from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType
34from torch.distributed.checkpoint.planner_helpers import (
35    create_read_items_for_chunk_list,
36)
37from torch.testing._internal.common_utils import (
38    run_tests,
39    TEST_WITH_DEV_DBG_ASAN,
40    TestCase,
41)
42from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
43from torch.testing._internal.distributed.distributed_utils import (
44    with_dist,
45    with_fake_comms,
46)
47
48
if TEST_WITH_DEV_DBG_ASAN:
    # ASAN + multiprocessing spawn is a known-bad combination; bail out of
    # the whole module before any test is collected.
    msg = "Skip dev-asan as torch + multiprocessing spawn have known issues"
    print(msg, file=sys.stderr)
    sys.exit(0)
55
56
def create_sharded_tensor(rank, world_size, shards_per_rank, shard_size=8):
    """Build a 1-D ShardedTensor with ``shards_per_rank`` shards per rank.

    Shard ``i`` covers offsets ``[i * shard_size, (i + 1) * shard_size)`` and
    is placed on rank ``i // shards_per_rank``; only the shards owned by
    ``rank`` are materialized locally (filled with random data).
    """
    total_shards = world_size * shards_per_rank
    shards_metadata = []
    local_shards = []
    for shard_idx in range(total_shards):
        owner = shard_idx // shards_per_rank
        md = ShardMetadata(
            shard_offsets=[shard_idx * shard_size],
            shard_sizes=[shard_size],
            placement=f"rank:{owner}/cpu",
        )
        shards_metadata.append(md)
        if owner != rank:
            continue
        # This rank owns the shard — create the backing tensor for it.
        local_shards.append(
            Shard.from_tensor_and_offsets(
                torch.rand(*md.shard_sizes),
                shard_offsets=md.shard_offsets,
                rank=rank,
            )
        )

    global_md = ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=torch.Size([shard_size * len(shards_metadata)]),
        tensor_properties=TensorProperties_Shard.create_from_tensor(torch.zeros(1)),
    )

    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards=local_shards, sharded_tensor_metadata=global_md
    )
85
86
class TestSavePlan(TestCase):
    """Tests for the default local/global save- and load-plan creation in
    torch.distributed.checkpoint.default_planner, including resharding."""

    @with_fake_comms(rank=1, world_size=4)
    def test_local_plan(self):
        """A local save plan has one write item per state_dict entry, with
        correct index/type/chunk metadata for tensors and ShardedTensors."""
        tensor = torch.rand(10)
        val = [1, 2, 3]
        # Rank 1 of 4, one shard per rank: this rank owns the shard at
        # offset 8 of the 32-element global tensor.
        st = create_sharded_tensor(rank=1, world_size=4, shards_per_rank=1)
        state_dict = {"tensor": tensor, "value": val, "st": st}
        plan = create_default_local_save_plan(state_dict, False)
        self.assertEqual(3, len(plan.items))
        # Plain tensor entry: a TENSOR write covering the whole tensor.
        wi = plan.items[0]
        self.assertEqual(wi.index, MetadataIndex("tensor", [0]))
        self.assertEqual(wi.type, WriteItemType.TENSOR)
        self.assertEqual(wi.tensor_data.size, tensor.size())
        self.assertEqual(
            wi.tensor_data.properties,
            TensorProperties.create_from_tensor(torch.zeros(1)),
        )
        self.assertEqual(wi.tensor_data.chunk.offsets, torch.Size([0]))
        self.assertEqual(wi.tensor_data.chunk.sizes, torch.Size([10]))

        # ShardedTensor entry: a SHARD write for this rank's local shard
        # only (global offset 8, length 8).
        st_wi = plan.items[2]
        self.assertEqual(st_wi.index, MetadataIndex("st", [8]))
        self.assertEqual(st_wi.type, WriteItemType.SHARD)
        self.assertEqual(st_wi.tensor_data.size, st.size())
        self.assertEqual(
            st_wi.tensor_data.properties,
            TensorProperties.create_from_tensor(torch.zeros(1)),
        )
        self.assertEqual(st_wi.tensor_data.chunk.offsets, torch.Size([8]))
        self.assertEqual(st_wi.tensor_data.chunk.sizes, torch.Size([8]))

        # Coordinator rank, should include replicated items as well
        plan = create_default_local_save_plan(state_dict, True)
        self.assertEqual(3, len(plan.items))

        tensor_wi = next(wi for wi in plan.items if wi.type == WriteItemType.TENSOR)
        self.assertEqual(tensor_wi.index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_wi.tensor_data.size, tensor.size())
        self.assertEqual(
            tensor_wi.tensor_data.properties,
            TensorProperties.create_from_tensor(tensor),
        )
        self.assertEqual(tensor_wi.tensor_data.chunk.offsets, torch.Size([0]))
        self.assertEqual(tensor_wi.tensor_data.chunk.sizes, torch.Size([10]))

        # Non-tensor values are written as raw bytes with no tensor metadata.
        bytes_wi = next(wi for wi in plan.items if wi.type == WriteItemType.BYTE_IO)
        self.assertEqual(bytes_wi.index, MetadataIndex("value"))
        self.assertIsNone(bytes_wi.tensor_data)

    def test_global_plan(self):
        """Global planning dedups across ranks, preserves each item's
        index/type/tensor_data, and produces metadata with correct
        per-chunk hints."""
        def create_data(rank):
            # Build the local plan each of the 4 ranks would produce;
            # rank 0 acts as the coordinator.
            with with_dist(rank=rank, world_size=4):
                tensor = torch.rand(10)
                val = [1, 2, 3]
                st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
                state_dict = {"tensor": tensor, "value": val, "st": st}
                return create_default_local_save_plan(state_dict, rank == 0)

        all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
        all_plans = dedup_save_plans(all_plans)
        final_plans, metadata = create_default_global_save_plan(all_plans=all_plans)

        # The default global plan updates all indexes to include hints
        for new_plan, old_plan in zip(final_plans, all_plans):
            for new_item, old_item in zip(new_plan.items, old_plan.items):
                self.assertEqual(new_item.index, old_item.index)
                self.assertEqual(new_item.type, old_item.type)
                self.assertEqual(new_item.tensor_data, old_item.tensor_data)
                self.assertIn(new_item.index.fqn, metadata.state_dict_metadata)

                item_md = metadata.state_dict_metadata[new_item.index.fqn]
                if new_item.type == WriteItemType.BYTE_IO:
                    self.assertTrue(isinstance(item_md, BytesStorageMetadata))
                else:
                    self.assertTrue(isinstance(item_md, TensorStorageMetadata))
                    self.assertEqual(item_md.size, old_item.tensor_data.size)
                    self.assertEqual(
                        item_md.properties, old_item.tensor_data.properties
                    )

                    self.assertIsNotNone(new_item.index.index)
                    # Make sure the hint is correct
                    self.assertEqual(
                        item_md.chunks[new_item.index.index], old_item.tensor_data.chunk
                    )

    def test_local_load_plan(self):
        """Loading with matching metadata yields one read item per entry,
        each an exact (offset-0, full-length) copy."""
        def create_state_dict(rank):
            with with_dist(rank=rank, world_size=4):
                tensor = torch.rand(10)
                val = [1, 2, 3]
                st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
                return {"tensor": tensor, "value": val, "st": st}

        state_dict = create_state_dict(1)
        metadata = _create_default_local_metadata(state_dict)

        load_plan = create_default_local_load_plan(state_dict, metadata)
        # This will create 3 entries
        self.assertEqual(3, len(load_plan.items))
        st_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "st")
        tensor_item = next(
            ri for ri in load_plan.items if ri.dest_index.fqn == "tensor"
        )
        bytes_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "value")

        self.assertEqual(st_item.type, LoadItemType.TENSOR)
        # This is an exact copy
        self.assertEqual(st_item.dest_index, MetadataIndex("st", [8]))
        self.assertEqual(st_item.dest_offsets, torch.Size([0]))
        self.assertEqual(st_item.storage_index, MetadataIndex("st", [8]))
        self.assertEqual(st_item.storage_offsets, torch.Size([0]))
        self.assertEqual(st_item.lengths, torch.Size([8]))

        self.assertEqual(tensor_item.type, LoadItemType.TENSOR)
        self.assertEqual(tensor_item.dest_index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_item.dest_offsets, torch.Size([0]))
        self.assertEqual(tensor_item.storage_index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_item.storage_offsets, torch.Size([0]))
        self.assertEqual(tensor_item.lengths, torch.Size([10]))

        self.assertEqual(bytes_item.type, LoadItemType.BYTE_IO)
        self.assertEqual(bytes_item.dest_index, MetadataIndex("value"))

    def test_load_with_resharding(self):
        """Loading a checkpoint saved at a different world size splits or
        slices shards so each destination shard is fully covered."""
        def create_state_dict(rank, world_size):
            with with_dist(rank=rank, world_size=world_size):
                return {
                    "st": create_sharded_tensor(
                        rank=rank,
                        world_size=world_size,
                        shards_per_rank=1,
                        shard_size=128 // world_size,
                    )
                }

        # Rank 1 has a 16 bytes shard from [16, 32[
        world8_state_dict = create_state_dict(rank=1, world_size=8)
        world8_metadata = _create_default_local_metadata(world8_state_dict)

        # Rank 1 has a 32 bytes shard from [32, 64[
        world4_state_dict = create_state_dict(rank=1, world_size=4)
        world4_metadata = _create_default_local_metadata(world4_state_dict)

        # First scenario, going from world=8 to world=4, need to load 2 shards
        # Each 4-world shard has 32 elements, so it needs to load 2 shards
        load_plan = create_default_local_load_plan(world4_state_dict, world8_metadata)
        self.assertEqual(2, len(load_plan.items))
        low_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
        )
        high_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([16])
        )

        # Stored shard [32, 48[ fills the lower half of the dest shard.
        self.assertEqual(low_ri.storage_index, MetadataIndex("st", [32]))
        self.assertEqual(low_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(low_ri.dest_index, MetadataIndex("st", [32]))
        self.assertEqual(low_ri.dest_offsets, torch.Size([0]))
        self.assertEqual(low_ri.lengths, torch.Size([16]))

        # Stored shard [48, 64[ fills the upper half of the dest shard.
        self.assertEqual(high_ri.storage_index, MetadataIndex("st", [48]))
        self.assertEqual(high_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(high_ri.dest_index, MetadataIndex("st", [32]))
        self.assertEqual(high_ri.dest_offsets, torch.Size([16]))
        self.assertEqual(high_ri.lengths, torch.Size([16]))

        # Second scenario, going from world=4 to world=8, need to load half of 1 shard
        # rank1 on 8-world needs to load the upper half of the rank0 4-world shard
        load_plan = create_default_local_load_plan(world8_state_dict, world4_metadata)
        self.assertEqual(1, len(load_plan.items))
        ri = load_plan.items[0]
        self.assertEqual(ri.storage_index, MetadataIndex("st", [0]))
        self.assertEqual(ri.storage_offsets, torch.Size([16]))
        self.assertEqual(ri.dest_index, MetadataIndex("st", [16]))
        self.assertEqual(ri.dest_offsets, torch.Size([0]))
        self.assertEqual(ri.lengths, torch.Size([16]))

    def test_load_with_world_size_diff_by_one(self):
        """Resharding where shard boundaries are not aligned (world 4 -> 3)
        produces reads with non-zero storage offsets."""
        def create_state_dict(rank, world_size):
            with with_dist(rank=rank, world_size=world_size):
                return {
                    "st": create_sharded_tensor(
                        rank=rank,
                        world_size=world_size,
                        shards_per_rank=1,
                        shard_size=120 // world_size,
                    )
                }

        # rank 1 has a 30 bytes shard from [30, 60[
        world4_state_dict = create_state_dict(rank=1, world_size=4)
        world4_metadata = _create_default_local_metadata(world4_state_dict)

        # rank 1 has a 40 bytes shard from [40, 80[
        world3_state_dict = create_state_dict(rank=1, world_size=3)

        load_plan = create_default_local_load_plan(world3_state_dict, world4_metadata)
        self.assertEqual(2, len(load_plan.items))
        # this is [30, 60] to load [40, 60]
        low_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
        )
        # this is [60, 90] to load [60, 80]
        high_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([20])
        )

        self.assertEqual(low_ri.storage_index, MetadataIndex("st", [30]))
        self.assertEqual(low_ri.storage_offsets, torch.Size([10]))
        self.assertEqual(low_ri.dest_index, MetadataIndex("st", [40]))
        self.assertEqual(low_ri.dest_offsets, torch.Size([0]))
        self.assertEqual(low_ri.lengths, torch.Size([20]))

        self.assertEqual(high_ri.storage_index, MetadataIndex("st", [60]))
        self.assertEqual(high_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(high_ri.dest_index, MetadataIndex("st", [40]))
        self.assertEqual(high_ri.dest_offsets, torch.Size([20]))
        self.assertEqual(high_ri.lengths, torch.Size([20]))
306
307
308class TestPlannerHelpers(TestCase):
309    def test_create_read_item_from_chunks(self):
310        tensor_md = TensorStorageMetadata(
311            properties=TensorProperties.create_from_tensor(torch.empty([16])),
312            size=torch.Size([16]),
313            chunks=[
314                ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([8])),
315                ChunkStorageMetadata(offsets=torch.Size([8]), sizes=torch.Size([8])),
316            ],
317        )
318
319        chunk = ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([7]))
320        read_items = create_read_items_for_chunk_list("foo", tensor_md, [chunk])
321
322        self.assertEqual(2, len(read_items))
323        self.assertEqual(MetadataIndex("foo", [4]), read_items[0].dest_index)
324        self.assertEqual(torch.Size([0]), read_items[0].dest_offsets)
325
326        self.assertEqual(MetadataIndex("foo", [0]), read_items[0].storage_index)
327        self.assertEqual(torch.Size([4]), read_items[0].storage_offsets)
328
329        self.assertEqual(torch.Size([4]), read_items[0].lengths)
330
331        self.assertEqual(MetadataIndex("foo", [4]), read_items[1].dest_index)
332        self.assertEqual(torch.Size([4]), read_items[1].dest_offsets)
333
334        self.assertEqual(MetadataIndex("foo", [8]), read_items[1].storage_index)
335        self.assertEqual(torch.Size([0]), read_items[1].storage_offsets)
336
337        self.assertEqual(torch.Size([3]), read_items[1].lengths)
338
339
class TestLoadPlanner(TestCase):
    """End-to-end checks of DefaultLoadPlanner's strictness behavior."""

    @with_temp_dir
    def test_strict(self):
        """allow_partial_load controls whether destination keys missing
        from the checkpoint are tolerated or fatal."""
        saved_module = nn.Linear(2, 2)
        dcp.save(state_dict={"module": saved_module}, checkpoint_id=self.temp_dir)

        # The module we load into carries a parameter the checkpoint
        # never saw.
        loaded_module = nn.Linear(2, 2)
        loaded_module.extra_param = nn.Parameter(torch.randn(2, 2))

        # A partial load tolerates the key missing from the checkpoint...
        dcp.load(
            state_dict={"module": loaded_module},
            checkpoint_id=self.temp_dir,
            planner=DefaultLoadPlanner(allow_partial_load=True),
        )

        # ...while a strict load surfaces it as a CheckpointException.
        with self.assertRaisesRegex(CheckpointException, "Missing key in checkpoint"):
            dcp.load(
                state_dict={"module": loaded_module},
                checkpoint_id=self.temp_dir,
                planner=DefaultLoadPlanner(allow_partial_load=False),
            )
360
361
# Standard entry point for torch's internal test runner.
if __name__ == "__main__":
    run_tests()
364