#include <ATen/core/functional.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::autograd::Variable;
using torch::autograd::variable_list;

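// RecvRpcBackward represents the 'recv' autograd function in the distributed
// autograd graph. During the backward pass it forwards the gradients it
// receives to the worker that performed the matching 'send' in the forward
// pass.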
RecvRpcBackward::RecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    ContextPtr autogradContext,
    rpc::worker_id_t fromWorkerId,
    rpc::DeviceMap deviceMap)
    : autogradMetadata_(autogradMetadata),
      autogradContext_(std::move(autogradContext)),
      fromWorkerId_(fromWorkerId),
      deviceMap_(std::move(deviceMap)) {}

variable_list RecvRpcBackward::apply(variable_list&& grads) {
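  // Collect the incoming gradients, substituting an appropriately shaped
  // zero tensor for any undefined grad so the receiving worker always sees
  // a dense gradient list.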
  std::vector<Variable> outputGrads;
  for (const auto i : c10::irange(grads.size())) {
    const auto& grad = grads[i];
    if (grad.defined()) {
      outputGrads.emplace_back(grad);
    } else {
      // Put in zeros for a tensor with no grad.
      outputGrads.emplace_back(input_metadata(i).zeros_like());
    }
  }

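  // The autograd context is held as a weak pointer because another thread
  // may clean it up (e.g. on error); locking verifies it is still alive
  // before we use it.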
  auto sharedContext = autogradContext_.lock();
  TORCH_CHECK(
      sharedContext,
      c10::str(
          "Autograd context no longer valid! This usually ",
          "means the autograd context was cleaned up by a different thread due ",
          "to an error before RecvRpcBackward had a chance to run"));

  // Send the gradients over the wire and record the future in the autograd
  // context.
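  // Propagate the retain_graph flag from the local GraphTask so the remote
  // worker also keeps its graph alive when the caller requested
  // retain_graph=True.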
  PropagateGradientsReq gradCall(
      autogradMetadata_,
      outputGrads,
      sharedContext->retrieveGraphTask()->keep_graph_);

  // Send the gradients over to the appropriate node.
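  // kUnsetRpcTimeout defers to the RPC agent's default timeout; the device
  // map recorded during the forward pass tells the agent how to place the
  // gradient tensors on the destination worker's devices.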
  auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
  auto jitFuture = rpcAgent->send(
      rpcAgent->getWorkerInfo(fromWorkerId_),
      std::move(gradCall).toMessage(),
      rpc::kUnsetRpcTimeout,
      deviceMap_);

  // Record the future in the context.
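  // Tracking the future as an outstanding RPC lets the distributed autograd
  // engine wait on it before marking the backward pass for this context as
  // complete.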
  sharedContext->addOutstandingRpc(jitFuture);

  // The 'recv' function sends the gradients over the wire using RPC; it
  // doesn't need to return anything to any downstream autograd function.
  return variable_list();
}

} // namespace autograd
} // namespace distributed
} // namespace torch