# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    ShardMetadata,
)
from torch.distributed._shard.sharded_tensor.metadata import (
    TensorProperties as TensorProperties_Shard,
)
from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint.api import CheckpointException
from torch.distributed.checkpoint.default_planner import (
    _create_default_local_metadata,
    create_default_global_save_plan,
    create_default_local_load_plan,
    create_default_local_save_plan,
    DefaultLoadPlanner,
)
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    MetadataIndex,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType
from torch.distributed.checkpoint.planner_helpers import (
    create_read_items_for_chunk_list,
)
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
from torch.testing._internal.distributed.distributed_utils import (
    with_dist,
    with_fake_comms,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


def create_sharded_tensor(rank, world_size, shards_per_rank, shard_size=8):
    """Build a 1-D ShardedTensor with ``shards_per_rank`` shards of
    ``shard_size`` elements per rank, laid out contiguously across ranks."""
    shards_metadata = []
    local_shards = []
    for idx in range(0, world_size * shards_per_rank):
        shard_rank = idx // shards_per_rank
        shard_md = ShardMetadata(
            shard_offsets=[idx * shard_size],
            shard_sizes=[shard_size],
            placement=f"rank:{shard_rank}/cpu",
        )
        shards_metadata.append(shard_md)
        if shard_rank == rank:
            shard = Shard.from_tensor_and_offsets(
                torch.rand(*shard_md.shard_sizes),
                shard_offsets=shard_md.shard_offsets,
                rank=rank,
            )
            local_shards.append(shard)

    sharded_tensor_md = ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=torch.Size([shard_size * len(shards_metadata)]),
        tensor_properties=TensorProperties_Shard.create_from_tensor(torch.zeros(1)),
    )

    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md
    )


class TestSavePlan(TestCase):
    @with_fake_comms(rank=1, world_size=4)
    def test_local_plan(self):
        tensor = torch.rand(10)
        val = [1, 2, 3]
        st = create_sharded_tensor(rank=1, world_size=4, shards_per_rank=1)
        state_dict = {"tensor": tensor, "value": val, "st": st}
        plan = create_default_local_save_plan(state_dict, False)
        self.assertEqual(3, len(plan.items))
        wi = plan.items[0]
        self.assertEqual(wi.index, MetadataIndex("tensor", [0]))
        self.assertEqual(wi.type, WriteItemType.TENSOR)
        self.assertEqual(wi.tensor_data.size, tensor.size())
        self.assertEqual(
            wi.tensor_data.properties,
            TensorProperties.create_from_tensor(torch.zeros(1)),
        )
        self.assertEqual(wi.tensor_data.chunk.offsets, torch.Size([0]))
        self.assertEqual(wi.tensor_data.chunk.sizes, torch.Size([10]))
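
        # rank=1 of world_size=4 with one shard of size 8 per rank places this
        # rank's shard at global offset 8, which is why the index and chunk
        # below both carry [8].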
        st_wi = plan.items[2]
        self.assertEqual(st_wi.index, MetadataIndex("st", [8]))
        self.assertEqual(st_wi.type, WriteItemType.SHARD)
        self.assertEqual(st_wi.tensor_data.size, st.size())
        self.assertEqual(
            st_wi.tensor_data.properties,
            TensorProperties.create_from_tensor(torch.zeros(1)),
        )
        self.assertEqual(st_wi.tensor_data.chunk.offsets, torch.Size([8]))
        self.assertEqual(st_wi.tensor_data.chunk.sizes, torch.Size([8]))

        # Coordinator rank, should include replicated items as well
        plan = create_default_local_save_plan(state_dict, True)
        self.assertEqual(3, len(plan.items))

        tensor_wi = next(wi for wi in plan.items if wi.type == WriteItemType.TENSOR)
        self.assertEqual(tensor_wi.index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_wi.tensor_data.size, tensor.size())
        self.assertEqual(
            tensor_wi.tensor_data.properties,
            TensorProperties.create_from_tensor(tensor),
        )
        self.assertEqual(tensor_wi.tensor_data.chunk.offsets, torch.Size([0]))
        self.assertEqual(tensor_wi.tensor_data.chunk.sizes, torch.Size([10]))

        bytes_wi = next(wi for wi in plan.items if wi.type == WriteItemType.BYTE_IO)
        self.assertEqual(bytes_wi.index, MetadataIndex("value"))
        self.assertIsNone(bytes_wi.tensor_data)

    def test_global_plan(self):
        def create_data(rank):
            with with_dist(rank=rank, world_size=4):
                tensor = torch.rand(10)
                val = [1, 2, 3]
                st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
                state_dict = {"tensor": tensor, "value": val, "st": st}
                return create_default_local_save_plan(state_dict, rank == 0)

        all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
        all_plans = dedup_save_plans(all_plans)
        final_plans, metadata = create_default_global_save_plan(all_plans=all_plans)

        # The default global plan updates all indexes to include hints
        for new_plan, old_plan in zip(final_plans, all_plans):
            for new_item, old_item in zip(new_plan.items, old_plan.items):
                self.assertEqual(new_item.index, old_item.index)
                self.assertEqual(new_item.type, old_item.type)
                self.assertEqual(new_item.tensor_data, old_item.tensor_data)
                self.assertIn(new_item.index.fqn, metadata.state_dict_metadata)

                item_md = metadata.state_dict_metadata[new_item.index.fqn]
                if new_item.type == WriteItemType.BYTE_IO:
                    self.assertTrue(isinstance(item_md, BytesStorageMetadata))
                else:
                    self.assertTrue(isinstance(item_md, TensorStorageMetadata))
                    self.assertEqual(item_md.size, old_item.tensor_data.size)
                    self.assertEqual(
                        item_md.properties, old_item.tensor_data.properties
                    )

                    self.assertIsNotNone(new_item.index.index)
                    # Make sure the hint is correct
                    self.assertEqual(
                        item_md.chunks[new_item.index.index], old_item.tensor_data.chunk
                    )
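
    # Loading with metadata built from the very same state_dict should produce
    # identity read items: destination and storage chunks match exactly.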
    def test_local_load_plan(self):
        def create_state_dict(rank):
            with with_dist(rank=rank, world_size=4):
                tensor = torch.rand(10)
                val = [1, 2, 3]
                st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
                return {"tensor": tensor, "value": val, "st": st}

        state_dict = create_state_dict(1)
        metadata = _create_default_local_metadata(state_dict)

        load_plan = create_default_local_load_plan(state_dict, metadata)
        # This will create 3 entries
        self.assertEqual(3, len(load_plan.items))
        st_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "st")
        tensor_item = next(
            ri for ri in load_plan.items if ri.dest_index.fqn == "tensor"
        )
        bytes_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "value")

        self.assertEqual(st_item.type, LoadItemType.TENSOR)
        # This is an exact copy
        self.assertEqual(st_item.dest_index, MetadataIndex("st", [8]))
        self.assertEqual(st_item.dest_offsets, torch.Size([0]))
        self.assertEqual(st_item.storage_index, MetadataIndex("st", [8]))
        self.assertEqual(st_item.storage_offsets, torch.Size([0]))
        self.assertEqual(st_item.lengths, torch.Size([8]))

        self.assertEqual(tensor_item.type, LoadItemType.TENSOR)
        self.assertEqual(tensor_item.dest_index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_item.dest_offsets, torch.Size([0]))
        self.assertEqual(tensor_item.storage_index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_item.storage_offsets, torch.Size([0]))
        self.assertEqual(tensor_item.lengths, torch.Size([10]))

        self.assertEqual(bytes_item.type, LoadItemType.BYTE_IO)
        self.assertEqual(bytes_item.dest_index, MetadataIndex("value"))

    def test_load_with_resharding(self):
        def create_state_dict(rank, world_size):
            with with_dist(rank=rank, world_size=world_size):
                return {
                    "st": create_sharded_tensor(
                        rank=rank,
                        world_size=world_size,
                        shards_per_rank=1,
                        shard_size=128 // world_size,
                    )
                }

        # Rank 1 has a 16 element shard covering [16, 32)
        world8_state_dict = create_state_dict(rank=1, world_size=8)
        world8_metadata = _create_default_local_metadata(world8_state_dict)

        # Rank 1 has a 32 element shard covering [32, 64)
        world4_state_dict = create_state_dict(rank=1, world_size=4)
        world4_metadata = _create_default_local_metadata(world4_state_dict)

        # First scenario: going from world=8 to world=4. Each 4-world shard has
        # 32 elements, so it spans two of the saved 16-element shards.
        load_plan = create_default_local_load_plan(world4_state_dict, world8_metadata)
        self.assertEqual(2, len(load_plan.items))
        low_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
        )
        high_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([16])
        )

        self.assertEqual(low_ri.storage_index, MetadataIndex("st", [32]))
        self.assertEqual(low_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(low_ri.dest_index, MetadataIndex("st", [32]))
        self.assertEqual(low_ri.dest_offsets, torch.Size([0]))
        self.assertEqual(low_ri.lengths, torch.Size([16]))

        self.assertEqual(high_ri.storage_index, MetadataIndex("st", [48]))
        self.assertEqual(high_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(high_ri.dest_index, MetadataIndex("st", [32]))
        self.assertEqual(high_ri.dest_offsets, torch.Size([16]))
        self.assertEqual(high_ri.lengths, torch.Size([16]))

        # Second scenario: going from world=4 to world=8. rank 1's 8-world
        # shard [16, 32) is the upper half of rank 0's saved 4-world shard
        # [0, 32), so only half of one saved shard needs to be read.
        load_plan = create_default_local_load_plan(world8_state_dict, world4_metadata)
        self.assertEqual(1, len(load_plan.items))
        ri = load_plan.items[0]
        self.assertEqual(ri.storage_index, MetadataIndex("st", [0]))
        self.assertEqual(ri.storage_offsets, torch.Size([16]))
        self.assertEqual(ri.dest_index, MetadataIndex("st", [16]))
        self.assertEqual(ri.dest_offsets, torch.Size([0]))
        self.assertEqual(ri.lengths, torch.Size([16]))

    def test_load_with_world_size_diff_by_one(self):
        def create_state_dict(rank, world_size):
            with with_dist(rank=rank, world_size=world_size):
                return {
                    "st": create_sharded_tensor(
                        rank=rank,
                        world_size=world_size,
                        shards_per_rank=1,
                        shard_size=120 // world_size,
                    )
                }
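
        # 120 elements split across 4 ranks gives 30-element shards; across 3
        # ranks, 40-element shards. The boundaries no longer line up, so a
        # destination shard can straddle two saved shards.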

        # rank 1 has a 30 element shard covering [30, 60)
        world4_state_dict = create_state_dict(rank=1, world_size=4)
        world4_metadata = _create_default_local_metadata(world4_state_dict)

        # rank 1 has a 40 element shard covering [40, 80)
        world3_state_dict = create_state_dict(rank=1, world_size=3)

        load_plan = create_default_local_load_plan(world3_state_dict, world4_metadata)
        self.assertEqual(2, len(load_plan.items))
        # saved shard [30, 60) supplies the destination range [40, 60)
        low_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
        )
        # saved shard [60, 90) supplies the destination range [60, 80)
        high_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([20])
        )

        self.assertEqual(low_ri.storage_index, MetadataIndex("st", [30]))
        self.assertEqual(low_ri.storage_offsets, torch.Size([10]))
        self.assertEqual(low_ri.dest_index, MetadataIndex("st", [40]))
        self.assertEqual(low_ri.dest_offsets, torch.Size([0]))
        self.assertEqual(low_ri.lengths, torch.Size([20]))

        self.assertEqual(high_ri.storage_index, MetadataIndex("st", [60]))
        self.assertEqual(high_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(high_ri.dest_index, MetadataIndex("st", [40]))
        self.assertEqual(high_ri.dest_offsets, torch.Size([20]))
        self.assertEqual(high_ri.lengths, torch.Size([20]))


class TestPlannerHelpers(TestCase):
    def test_create_read_item_from_chunks(self):
        tensor_md = TensorStorageMetadata(
            properties=TensorProperties.create_from_tensor(torch.empty([16])),
            size=torch.Size([16]),
            chunks=[
                ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([8])),
                ChunkStorageMetadata(offsets=torch.Size([8]), sizes=torch.Size([8])),
            ],
        )

        # The requested chunk [4, 11) straddles both stored chunks, so two read
        # items are produced: [4, 8) from the first and [8, 11) from the second.
        chunk = ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([7]))
        read_items = create_read_items_for_chunk_list("foo", tensor_md, [chunk])

        self.assertEqual(2, len(read_items))
        self.assertEqual(MetadataIndex("foo", [4]), read_items[0].dest_index)
        self.assertEqual(torch.Size([0]), read_items[0].dest_offsets)

        self.assertEqual(MetadataIndex("foo", [0]), read_items[0].storage_index)
        self.assertEqual(torch.Size([4]), read_items[0].storage_offsets)

        self.assertEqual(torch.Size([4]), read_items[0].lengths)

        self.assertEqual(MetadataIndex("foo", [4]), read_items[1].dest_index)
        self.assertEqual(torch.Size([4]), read_items[1].dest_offsets)

        self.assertEqual(MetadataIndex("foo", [8]), read_items[1].storage_index)
        self.assertEqual(torch.Size([0]), read_items[1].storage_offsets)

        self.assertEqual(torch.Size([3]), read_items[1].lengths)


class TestLoadPlanner(TestCase):
    @with_temp_dir
    def test_strict(self):
        original_module = nn.Linear(2, 2)
        dcp.save(state_dict={"module": original_module}, checkpoint_id=self.temp_dir)

        # The new module has an extra parameter the checkpoint knows nothing
        # about; with allow_partial_load=True the load should still succeed.
        new_module = nn.Linear(2, 2)
        new_module.extra_param = nn.Parameter(torch.randn(2, 2))
        dcp.load(
            state_dict={"module": new_module},
            checkpoint_id=self.temp_dir,
            planner=DefaultLoadPlanner(allow_partial_load=True),
        )

        with self.assertRaisesRegex(CheckpointException, "Missing key in checkpoint"):
            dcp.load(
                state_dict={"module": new_module},
                checkpoint_id=self.temp_dir,
                planner=DefaultLoadPlanner(allow_partial_load=False),
            )


if __name__ == "__main__":
    run_tests()