1 #pragma once 2 3 #include <torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h> 4 #include <torch/csrc/distributed/rpc/message.h> 5 #include <torch/csrc/distributed/rpc/rpc_command_base.h> 6 #include <vector> 7 8 namespace torch { 9 namespace distributed { 10 namespace autograd { 11 12 // Used to propagate gradients from one node to another during a distributed 13 // backwards pass. This RPC call is invoked when we hit a `recv` autograd 14 // function during backward pass execution. 15 class TORCH_API PropagateGradientsReq : public rpc::RpcCommandBase { 16 public: 17 PropagateGradientsReq( 18 const AutogradMetadata& autogradMetadata, 19 std::vector<torch::autograd::Variable> grads, 20 bool retainGraph = false); 21 22 const AutogradMetadata& getAutogradMetadata(); 23 24 const std::vector<torch::autograd::Variable>& getGrads(); 25 26 // Serialization and deserialization methods. 27 c10::intrusive_ptr<rpc::Message> toMessageImpl() && override; 28 static std::unique_ptr<PropagateGradientsReq> fromMessage( 29 const rpc::Message& message); 30 31 // Whether or not to retain the autograd graph. 32 bool retainGraph(); 33 34 private: 35 AutogradMetadata autogradMetadata_; 36 std::vector<torch::autograd::Variable> grads_; 37 bool retainGraph_; 38 }; 39 40 } // namespace autograd 41 } // namespace distributed 42 } // namespace torch 43