# mypy: allow-untyped-defs

import builtins

import torch
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardMetadata,
)
from torch.distributed._shard.sharding_spec._internals import (
    get_chunked_dim_size,
    get_split_size,
)


def generate_chunk_sharding_specs_for_test(sharding_dim):
    return [
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        ),
        # Test different ordering. (Case 1)
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        ),
        # Test different ordering. (Case 2)
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
            ],
        ),
    ]
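

# Illustrative sketch only (assumed usage, not invoked by the tests in this
# file): every spec returned above shards the same dimension across ranks 0-3
# and differs only in placement order.
def _example_chunk_sharding_specs():
    specs = generate_chunk_sharding_specs_for_test(0)
    assert all(spec.dim == 0 for spec in specs)
    assert all(len(spec.placements) == 4 for spec in specs)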


def generate_enumerable_sharding_specs_for_test():
    return [
        EnumerableShardingSpec(
            [
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[5, 5],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[5, 0],
                    shard_sizes=[5, 5],
                    placement="rank:1/cuda:1",
                ),
                ShardMetadata(
                    shard_offsets=[0, 5],
                    shard_sizes=[5, 5],
                    placement="rank:2/cuda:2",
                ),
                ShardMetadata(
                    shard_offsets=[5, 5],
                    shard_sizes=[5, 5],
                    placement="rank:3/cuda:3",
                ),
            ]
        )
    ]
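

# Illustrative sketch only (assumed usage, not invoked by the tests in this
# file): the spec above partitions a 10 x 10 tensor into a 2 x 2 grid of
# 5 x 5 shards, one per rank.
def _example_enumerable_sharding_layout():
    spec = generate_enumerable_sharding_specs_for_test()[0]
    assert len(spec.shards) == 4
    assert all(shard.shard_sizes == [5, 5] for shard in spec.shards)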


def generate_local_weight_sharding_params_for_test(
    local_weight, sharded_dim, gpu_num, spec, rank
):
    """
    Shard the local weight based on the given spec, so we can compare against
    the one from the sharded tensor.

    Args:
        local_weight: weight matrix to be sharded.
        sharded_dim: the dimension which we shard on.
        gpu_num: number of ranks.
        spec: sharding spec.
        rank: rank of the current CUDA process.

    Returns:
        start_pos: start position of the sharded weight on the given rank.
        chunk_size: chunk size of the sharded weight on the given rank.
    """
    sharding_dim_size = local_weight.size(sharded_dim)
    split_size = get_split_size(sharding_dim_size, gpu_num)
    current_offsets = 0
    start_pos = current_offsets
    for idx, placement in enumerate(spec.placements):
        chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
        if rank == placement.rank():
            start_pos = current_offsets
            break
        current_offsets += chunk_size
    return start_pos, chunk_size
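

# Illustrative sketch only (assumed usage, not invoked by the tests in this
# file): for a 16 x 8 weight sharded on dim 0 across 4 ranks with placements
# ordered [rank2, rank3, rank0, rank1] (Case 1 above), rank 0 owns the third
# chunk, so its shard starts at row offset 8 and spans 4 rows.
def _example_local_weight_sharding_params():
    spec = generate_chunk_sharding_specs_for_test(0)[1]
    weight = torch.randn(16, 8)
    start_pos, chunk_size = generate_local_weight_sharding_params_for_test(
        weight, 0, 4, spec, 0
    )
    assert (start_pos, chunk_size) == (8, 4)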


def clone_module_parameter(module, param_name):
    """
    Clone a parameter from a given existing module.

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be cloned.
        param_name (str): Name of the parameter of ``module`` that needs to be cloned.

    Returns: cloned tensor as :class:`torch.nn.Parameter`.
    """
    tensor = getattr(module, param_name)
    return torch.nn.Parameter(tensor.detach().clone())
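

# Illustrative sketch only (assumed usage, not invoked by the tests in this
# file): the clone shares values with, but not storage of, the original
# parameter.
def _example_clone_module_parameter():
    linear = torch.nn.Linear(4, 4)
    cloned = clone_module_parameter(linear, "weight")
    assert torch.equal(cloned, linear.weight)
    assert cloned.data_ptr() != linear.weight.data_ptr()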


def gen_binary_op_func(python_op, inplace=False):
    """
    Generate a binary op function from a Python operator (e.g. ``"+"``) or a
    ``torch`` function name (e.g. ``"torch.add"``). If ``inplace`` is True,
    the Python operator is applied in place (``lhs op= rhs``).
    """
    src_lines = ["def f(lhs, rhs):"]
    if "torch" in python_op:
        src_lines.append(f"  return {python_op}(lhs, rhs)\n")
    elif inplace:
        src_lines.append(f"  lhs {python_op}= rhs\n  return lhs\n")
    else:
        src_lines.append(f"  return lhs {python_op} rhs\n")

    code_str = "\n".join(src_lines)
    g = {"torch": torch}
    builtins.exec(code_str, g)
    return g["f"]
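

# Illustrative sketch only (assumed usage, not invoked by the tests in this
# file): the generated callables dispatch to Python operators or torch
# functions, and the in-place variant mutates its first argument.
def _example_gen_binary_op_func():
    add = gen_binary_op_func("+")
    assert torch.equal(add(torch.ones(2), torch.ones(2)), torch.full((2,), 2.0))
    torch_mul = gen_binary_op_func("torch.mul")
    product = torch_mul(torch.ones(2), torch.full((2,), 3.0))
    assert torch.equal(product, torch.full((2,), 3.0))
    iadd = gen_binary_op_func("+", inplace=True)
    t = torch.ones(2)
    iadd(t, torch.ones(2))  # mutates ``t`` in place
    assert torch.equal(t, torch.full((2,), 2.0))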