# mypy: allow-untyped-defs

import builtins

import torch
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardMetadata,
)
from torch.distributed._shard.sharding_spec._internals import (
    get_chunked_dim_size,
    get_split_size,
)


def generate_chunk_sharding_specs_for_test(sharding_dim):
    return [
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        ),
        # Test different ordering. (Case 1)
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        ),
        # Test different ordering. (Case 2)
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
            ],
        ),
    ]


def generate_enumerable_sharding_specs_for_test():
    return [
        EnumerableShardingSpec(
            [
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[5, 5],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[5, 0],
                    shard_sizes=[5, 5],
                    placement="rank:1/cuda:1",
                ),
                ShardMetadata(
                    shard_offsets=[0, 5],
                    shard_sizes=[5, 5],
                    placement="rank:2/cuda:2",
                ),
                ShardMetadata(
                    shard_offsets=[5, 5],
                    shard_sizes=[5, 5],
                    placement="rank:3/cuda:3",
                ),
            ]
        )
    ]


def generate_local_weight_sharding_params_for_test(
    local_weight, sharded_dim, gpu_num, spec, rank
):
    """
    Shard the local weight based on the given spec, so we can compare against
    the one from the sharded tensor.

    Args:
        local_weight: weight matrix to be sharded.
        sharded_dim: the dimension which we shard on.
        gpu_num: number of ranks.
        spec: sharding spec.
        rank: rank of the CUDA process.

    Returns:
        start_pos: start position of the sharded weight on the given rank.
        chunk_size: chunk size of the sharded weight on the given rank.
    """
    sharding_dim_size = local_weight.size(sharded_dim)
    split_size = get_split_size(sharding_dim_size, gpu_num)
    current_offsets = 0
    start_pos = current_offsets
    for idx, placement in enumerate(spec.placements):
        chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
        if rank == placement.rank():
            start_pos = current_offsets
            break
        current_offsets += chunk_size
    return start_pos, chunk_size


def clone_module_parameter(module, param_name):
    """
    Clone a parameter from a given existing module.

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be cloned.
        param_name (str): Name of the parameter of ``module`` that needs to be cloned.

    Returns: cloned tensor as :class:`torch.nn.Parameter`.
    """
    tensor = getattr(module, param_name)
    return torch.nn.Parameter(tensor.detach().clone())


def gen_binary_op_func(python_op, inplace=False):
    """
    Generate a binary-op function ``f(lhs, rhs)`` from the given operator string:
    a torch function (e.g. ``"torch.add"``), an in-place operator when
    ``inplace=True`` (e.g. ``"+"`` becomes ``+=``), or a plain binary operator.
    """
    src_lines = ["def f(lhs, rhs):"]
    if "torch" in python_op:
        src_lines.append(f"    return {python_op}(lhs, rhs)\n")
    elif inplace:
        src_lines.append(f"    lhs {python_op}= rhs\n    return lhs\n")
    else:
        src_lines.append(f"    return lhs {python_op} rhs\n")

    code_str = "\n".join(src_lines)
    g = {"torch": torch}
    builtins.exec(code_str, g)
    return g["f"]
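

# Usage sketch (illustrative addition, not part of the original utilities above):
# a minimal example of how these helpers might be exercised without an initialized
# process group or GPUs. All names introduced below (``add_fn``, ``demo_linear``,
# ...) are hypothetical; the expected shard offsets assume torch.chunk-style
# splitting of an evenly divisible dimension.
if __name__ == "__main__":
    # Build binary-op callables from operator strings.
    add_fn = gen_binary_op_func("+")
    inplace_add_fn = gen_binary_op_func("+", inplace=True)
    torch_add_fn = gen_binary_op_func("torch.add")

    lhs, rhs = torch.ones(2, 2), torch.full((2, 2), 2.0)
    assert torch.equal(add_fn(lhs.clone(), rhs), torch_add_fn(lhs, rhs))
    assert torch.equal(inplace_add_fn(lhs.clone(), rhs), lhs + rhs)

    # Clone a parameter so it can later be compared against a sharded copy.
    demo_linear = torch.nn.Linear(4, 4)
    weight_copy = clone_module_parameter(demo_linear, "weight")
    assert torch.equal(weight_copy, demo_linear.weight)

    # Compute the shard offset/size a given rank would own under a chunk spec.
    # Only offset arithmetic happens here, so CUDA is not required: 16 rows split
    # across 4 ranks gives chunks of 4, and rank 2 starts at offset 8.
    spec = generate_chunk_sharding_specs_for_test(0)[0]
    start_pos, chunk_size = generate_local_weight_sharding_params_for_test(
        torch.randn(16, 8), 0, 4, spec, rank=2
    )
    assert (start_pos, chunk_size) == (8, 4)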