#pragma once

#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::autograd::variable_list;

/// C++ API of Distributed Autograd that kicks off the distributed backward
/// pass using the provided roots. This currently implements the
/// :ref:`fast-mode-algorithm` which assumes all RPC messages sent in the same
/// distributed autograd context across workers would be part of the autograd
/// graph during the backward pass.
///
/// We use the provided roots to discover the autograd graph and compute
/// appropriate dependencies. This method blocks until the entire
/// autograd computation is done.
/// This function accumulates gradients in the leaves - you might need to zero
/// them before calling it.
///
/// \param context_id The autograd context id for which we should retrieve the
///                   gradients.
/// \param roots Tensors which represent the roots of the autograd computation.
///              All the tensors should be scalars.
/// \param retain_graph If `false`, the graph used to compute the grad will be
///                     freed. Note that in nearly all cases setting this
///                     option to `true` is not needed and often can be worked
///                     around in a much more efficient way. Usually, you need
///                     to set this to `true` to run backward multiple times.
TORCH_API void backward(
    int64_t context_id,
    const variable_list& roots,
    bool retain_graph = false);

} // namespace autograd
} // namespace distributed
} // namespace torch
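
// Example usage (illustrative sketch only, kept as a comment so it is not
// compiled into translation units that include this header):
//
// The snippet shows how the blocking backward() declared above is typically
// driven from a worker. The container calls
// (DistAutogradContainer::getInstance(), newContext(), contextId(),
// releaseContext()) come from context/container.h; treat their exact
// signatures here as an assumption. computeLoss() is a hypothetical helper
// standing in for a forward pass that may issue RPCs to other workers.
//
//   namespace dist_autograd = torch::distributed::autograd;
//
//   void runBackwardPass() {
//     // Open a fresh distributed autograd context for this
//     // forward/backward pass.
//     auto& container = dist_autograd::DistAutogradContainer::getInstance();
//     auto context = container.newContext();
//
//     // Forward pass: RPCs sent inside this context become part of the
//     // autograd graph during the backward pass (fast-mode assumption).
//     torch::Tensor loss = computeLoss(); // hypothetical helper; scalar loss
//
//     // Blocks until gradients for this context have been accumulated on
//     // all participating workers.
//     dist_autograd::backward(context->contextId(), {loss});
//
//     // Release the context once its gradients have been consumed.
//     container.releaseContext(context->contextId());
//   }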