xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/core/functional.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::autograd::Variable;
using torch::autograd::variable_list;

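// Constructs the backward node recorded when this worker receives an RPC
// with autograd metadata attached; fromWorkerId identifies the caller so
// gradients can be routed back to it during the backward pass.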
RecvRpcBackward::RecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    ContextPtr autogradContext,
    rpc::worker_id_t fromWorkerId,
    rpc::DeviceMap deviceMap)
    : autogradMetadata_(autogradMetadata),
      autogradContext_(std::move(autogradContext)),
      fromWorkerId_(fromWorkerId),
      deviceMap_(std::move(deviceMap)) {}

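// Invoked by the autograd engine during the distributed backward pass:
// packages the incoming gradients and ships them to the worker that
// originated the forward-pass RPC.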
variable_list RecvRpcBackward::apply(variable_list&& grads) {
  std::vector<Variable> outputGrads;
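  // The engine may pass undefined tensors for inputs that received no
  // gradient; substitute zeros so the peer always gets a complete list.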
  for (const auto i : c10::irange(grads.size())) {
    const auto& grad = grads[i];
    if (grad.defined()) {
      outputGrads.emplace_back(grad);
    } else {
      // Put in zeros for a tensor with no grad.
      outputGrads.emplace_back(input_metadata(i).zeros_like());
    }
  }

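  // The context is only held as a weak reference, since another thread may
  // tear it down (e.g. on error) while this node is still pending.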
  auto sharedContext = autogradContext_.lock();
  TORCH_CHECK(
      sharedContext,
      c10::str(
          "Autograd context no longer valid! This usually ",
          "means the autograd context was cleaned up by a different thread due ",
          "to an error before RecvRpcBackward had a chance to run"));

  // Send the gradients over the wire and record the future in the autograd
  // context.
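  // The request also carries the caller's retain_graph setting so the remote
  // worker preserves its autograd graph whenever this one does.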
  PropagateGradientsReq gradCall(
      autogradMetadata_,
      outputGrads,
      sharedContext->retrieveGraphTask()->keep_graph_);

  // Send the gradients over to the appropriate node.
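  // kUnsetRpcTimeout makes the agent apply its default RPC timeout;
  // deviceMap_ tells the agent how the gradient tensors map onto the remote
  // worker's devices.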
  auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
  auto jitFuture = rpcAgent->send(
      rpcAgent->getWorkerInfo(fromWorkerId_),
      std::move(gradCall).toMessage(),
      rpc::kUnsetRpcTimeout,
      deviceMap_);

  // Record the future in the context so the backward pass can wait for all
  // outstanding RPCs to complete before finishing.
  sharedContext->addOutstandingRpc(jitFuture);

  // The 'recv' function sends the gradients over the wire using RPC, so it
  // doesn't need to return anything to any downstream autograd function.
  return variable_list();
}

} // namespace autograd
} // namespace distributed
} // namespace torch