xref: /aosp_15_r20/external/pytorch/benchmarks/distributed/rpc/parameter_server/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2
3
# Payload-kind labels handed to ``cref.record_start`` by
# ``process_bucket_with_remote_server`` to tag each hook RPC as carrying a
# sparse or a dense gradient bucket.
RPC_SPARSE, RPC_DENSE = "rpc_sparse", "rpc_dense"
6
7
def sparse_tensor_to_rpc_format(sparse_tensor):
    r"""
    Pack a sparse COO tensor into a plain list that can be sent over RPC.

    The tensor is coalesced first (duplicate indices are summed), and the
    indices, values and size of that coalesced tensor are returned.

    Args:
        sparse_tensor (torch.Tensor): a sparse COO tensor.

    Returns:
        list: ``[indices, values, size]`` of the coalesced tensor, the format
        consumed by :func:`sparse_rpc_format_to_tensor`.
    """
    coalesced = sparse_tensor.coalesce()
    return [coalesced.indices(), coalesced.values(), coalesced.size()]
17
18
def sparse_rpc_format_to_tensor(sparse_rpc_format):
    r"""
    Rebuild a coalesced sparse COO tensor from its RPC list form.

    Args:
        sparse_rpc_format (list): ``[indices, values, size]``, the layout
            produced by :func:`sparse_tensor_to_rpc_format`.

    Returns:
        torch.Tensor: the reconstructed sparse COO tensor, coalesced.
    """
    indices = sparse_rpc_format[0]
    values = sparse_rpc_format[1]
    size = sparse_rpc_format[2]
    rebuilt = torch.sparse_coo_tensor(indices, values, size)
    return rebuilt.coalesce()
28
29
def process_bucket_with_remote_server(state, bucket):
    r"""
    Send one gradient bucket to the server behind ``state.cref.server_rref``
    via its ``average_gradient`` RPC and return the future of the reply.

    Intended to be called from a DDP communication hook during
    ``.backward()``. Sparse bucket tensors are packed with
    :func:`sparse_tensor_to_rpc_format` before sending, and a reply that
    comes back as a list is unpacked with :func:`sparse_rpc_format_to_tensor`.
    The time until the RPC future completes is recorded under the
    ``"hook_future_metric"`` name.

    Args:
        state (object): training-process state; must expose ``cref``,
            ``batch_number`` and ``get_key``.
        bucket (GradBucket): gradient bucket supplied by DDP.

    Returns:
        torch.futures.Future: resolves to a one-element list holding the
        server's result tensor, moved to ``cuda(cref.rank)``.
    """
    cref = state.cref
    grad = bucket.buffer()
    if not cref.use_cuda_rpc:
        # When CUDA RPC is disabled, the payload is sent from CPU memory.
        grad = grad.cpu()
    is_sparse = grad.is_sparse
    # Sparse tensors travel as a plain [indices, values, size] list.
    payload = sparse_tensor_to_rpc_format(grad) if is_sparse else grad
    bucket_index = bucket.get_index()
    key = state.get_key(bucket_index)
    cref.record_start(
        "hook_future_metric", key, RPC_SPARSE if is_sparse else RPC_DENSE
    )
    rpc_future = cref.server_rref.rpc_async().average_gradient(
        cref.server_rref, state.batch_number, bucket_index, payload
    )

    def callback(completed):
        # The future has completed by the time this runs, so stop the timer
        # first; wait() then returns the value (or re-raises the RPC error).
        cref.record_end("hook_future_metric", key)
        result = completed.wait()
        if isinstance(result, list):
            # A list reply is the packed sparse format; rebuild the tensor.
            result = sparse_rpc_format_to_tensor(result)
        # NOTE(review): the result is always moved to cuda(cref.rank), even if
        # the bucket came from a CPU model — confirm this hook is GPU-only.
        return [result.cuda(cref.rank)]

    return rpc_future.then(callback)
61