#pragma once

#include <torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h>
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <vector>

namespace torch {
namespace distributed {
namespace autograd {

// Used to propagate gradients from one node to another during a distributed
// backward pass. This RPC call is invoked when we hit a `recv` autograd
// function during backward pass execution.
class TORCH_API PropagateGradientsReq : public rpc::RpcCommandBase {
 public:
  PropagateGradientsReq(
      const AutogradMetadata& autogradMetadata,
      std::vector<torch::autograd::Variable> grads,
      bool retainGraph = false);

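  // Retrieves the autograd metadata (autograd context id and message id)
  // attached to this request.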
  const AutogradMetadata& getAutogradMetadata();

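  // Retrieves the gradient tensors carried by this request.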
  const std::vector<torch::autograd::Variable>& getGrads();

  // Serialization and deserialization methods.
  c10::intrusive_ptr<rpc::Message> toMessageImpl() && override;
  static std::unique_ptr<PropagateGradientsReq> fromMessage(
      const rpc::Message& message);

  // Whether to retain the autograd graph so it can be reused by subsequent
  // backward passes.
  bool retainGraph();

 private:
  AutogradMetadata autogradMetadata_;
  std::vector<torch::autograd::Variable> grads_;
  bool retainGraph_;
};
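
// Illustrative sketch (an assumption for this note, not part of the original
// header): how a sender might build this request and how a receiver might
// reconstruct it, using only the methods declared above. `contextId`,
// `messageId`, and `grads` are hypothetical local values.
//
//   AutogradMetadata meta(contextId, messageId);
//   PropagateGradientsReq req(meta, std::move(grads), /*retainGraph=*/false);
//   c10::intrusive_ptr<rpc::Message> msg = std::move(req).toMessageImpl();
//
//   // Receiving side: rebuild the request from the incoming message.
//   std::unique_ptr<PropagateGradientsReq> decoded =
//       PropagateGradientsReq::fromMessage(*msg);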

} // namespace autograd
} // namespace distributed
} // namespace torch