xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/input_buffer.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // The InputBuffer class accumulates a list of Variables for use by a
4 // function. It implements logic to avoid modifying the passed
5 // values in-place (adding an input twice will accumulate the result).
6 // This behaviour is needed and used only in backward graphs.
7 
8 #include <utility>
9 #include <vector>
10 
11 #include <c10/core/Stream.h>
12 #include <torch/csrc/autograd/variable.h>
13 #include <optional>
14 
15 namespace torch::autograd {
16 
17 struct InputBuffer {
InputBufferInputBuffer18   explicit InputBuffer(size_t size) : buffer(size) {}
19   InputBuffer(const InputBuffer& other) = delete;
20   InputBuffer(InputBuffer&& other) = default;
InputBufferInputBuffer21   explicit InputBuffer(variable_list&& inputs) : buffer(std::move(inputs)){};
22   InputBuffer& operator=(InputBuffer&& other) = default;
23 
24   // Accumulates the variable at a specified index.
25   // The optional CUDA streams determine which stream the accumulation
26   // is run on and how the addition is synchronized.
27   TORCH_API void add(
28       size_t pos,
29       Variable&& var,
30       const std::optional<c10::Stream>& opt_producer_stream,
31       const std::optional<c10::Stream>& opt_consumer_stream);
32 
33   at::Device device() const;
34 
35   Variable operator[](size_t pos) {
36     return buffer[pos];
37   }
38 
39   // Returns the inputs as a list of variables. Destroys given InputBuffer.
40   static std::vector<Variable> variables(InputBuffer&& g);
41 
42   std::vector<Variable> buffer;
43 };
44 
45 } // namespace torch::autograd
46