#pragma once

#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::autograd::variable_list;

/// C++ API of Distributed Autograd that kicks off the distributed backward
/// pass using the provided roots. This currently implements the
/// :ref:`fast-mode-algorithm`, which assumes that all RPC messages sent in the
/// same distributed autograd context across workers will be part of the
/// autograd graph during the backward pass.
///
/// We use the provided roots to discover the autograd graph and compute the
/// appropriate dependencies. This method blocks until the entire autograd
/// computation is done and accumulates gradients in the leaves; you might
/// need to zero them before calling it.
///
/// \param context_id The autograd context id for which we should retrieve the
///                   gradients.
/// \param roots Tensors which represent the roots of the autograd computation.
///              All the tensors should be scalars.
/// \param retain_graph If `false`, the graph used to compute the grad will be
///                     freed. Note that in nearly all cases setting this
///                     option to `true` is not needed and can often be worked
///                     around in a much more efficient way. Usually, you only
///                     need to set this to `true` if you need to run backward
///                     multiple times.
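///
/// A minimal usage sketch, assuming the DistAutogradContainer API declared in
/// container.h (included above); `computeLoss()` is a placeholder for the
/// forward pass, not a real API:
///
/// \code
/// #include <torch/csrc/distributed/autograd/autograd.h>
/// #include <torch/csrc/distributed/autograd/context/container.h>
///
/// using namespace torch::distributed::autograd;
///
/// // Create a fresh distributed autograd context for this iteration.
/// auto context = DistAutogradContainer::getInstance().newContext();
/// const int64_t context_id = context->contextId();
///
/// // Run the forward pass (possibly involving RPCs recorded in this context)
/// // and compute a scalar loss. Placeholder, supplied by the caller.
/// torch::Tensor loss = computeLoss();
///
/// // Kick off the distributed backward pass from the scalar root.
/// backward(context_id, {loss});
///
/// // Release the context once its gradients have been consumed.
/// DistAutogradContainer::getInstance().releaseContext(context_id);
/// \endcode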
TORCH_API void backward(
    int64_t context_id,
    const variable_list& roots,
    bool retain_graph = false);

} // namespace autograd
} // namespace distributed
} // namespace torch