import torch


RPC_SPARSE = "rpc_sparse"
RPC_DENSE = "rpc_dense"


def sparse_tensor_to_rpc_format(sparse_tensor):
    r"""
    A helper function that creates a list containing the indices, values, and
    size of a coalesced sparse tensor.
    Args:
        sparse_tensor (torch.Tensor): the sparse_coo_tensor to be represented
            as a list of [indices, values, size]
    """
    sparse_tensor = sparse_tensor.coalesce()
    return [sparse_tensor.indices(), sparse_tensor.values(), sparse_tensor.size()]


def sparse_rpc_format_to_tensor(sparse_rpc_format):
    r"""
    A helper function that creates a sparse_coo_tensor from indices, values,
    and size.
    Args:
        sparse_rpc_format (list): a sparse_coo_tensor represented as a list of
            [indices, values, size]
    """
    return torch.sparse_coo_tensor(
        sparse_rpc_format[0], sparse_rpc_format[1], sparse_rpc_format[2]
    ).coalesce()


def process_bucket_with_remote_server(state, bucket):
    r"""
    Processes a gradient bucket passed by a DDP communication hook
    during .backward(). The method supports processing sparse and dense
    tensors. It records the RPC future completion time metric for the trainer.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    cref = state.cref
    tensor = bucket.buffer()
    # Move the gradient tensor to the CPU when CUDA RPC is not in use.
    if not cref.use_cuda_rpc:
        tensor = tensor.cpu()
    sparse = tensor.is_sparse
    if sparse:
        tensor = sparse_tensor_to_rpc_format(tensor)
    b_index = bucket.get_index()
    server_args = [cref.server_rref, state.batch_number, b_index, tensor]
    key = state.get_key(b_index)
    cref.record_start("hook_future_metric", key, RPC_SPARSE if sparse else RPC_DENSE)
    # Ask the remote server to average the gradient bucket asynchronously.
    fut = cref.server_rref.rpc_async().average_gradient(*server_args)

    def callback(fut):
        cref.record_end("hook_future_metric", key)
        tensor = fut.wait()
        # A list result is a sparse tensor in RPC format; rebuild it.
        if isinstance(tensor, list):
            tensor = sparse_rpc_format_to_tensor(tensor)
        tensor = tensor.cuda(cref.rank)
        return [tensor]

    return fut.then(callback)
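

# The snippet below is not part of the benchmark; it is a minimal, self-contained
# sketch illustrating the round trip between the two helpers above: a sparse COO
# tensor is flattened into the RPC-friendly [indices, values, size] list and then
# rebuilt as it would be on the receiving side. The shape and values here are
# illustrative only. (In the benchmark itself, process_bucket_with_remote_server
# is presumably registered on the DDP model via register_comm_hook.)
if __name__ == "__main__":
    indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
    values = torch.tensor([3.0, 4.0, 5.0])
    original = torch.sparse_coo_tensor(indices, values, (2, 3))

    rpc_format = sparse_tensor_to_rpc_format(original)
    rebuilt = sparse_rpc_format_to_tensor(rpc_format)

    # The rebuilt tensor matches the coalesced original.
    assert torch.equal(rebuilt.to_dense(), original.to_dense())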