xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/variadic.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/core/Variadic.h>
5 #include <torch/csrc/autograd/variable.h>
6 
7 #include <type_traits>
8 #include <utility>
9 
10 namespace torch {
11 
12 using at::IterArgs;
13 
14 struct CountTensors : IterArgs<CountTensors> {
15   size_t out = 0;
operatorCountTensors16   void operator()(const at::Tensor& x) {
17     out += 1;
18   }
operatorCountTensors19   void operator()(const std::optional<at::Tensor>& x) {
20     out += x.has_value();
21   }
operatorCountTensors22   void operator()(at::ArrayRef<at::Tensor> xs) {
23     out += xs.size();
24   }
25 };
26 
27 template <typename... Args>
count_tensors(Args &&...args)28 size_t count_tensors(Args&&... args) {
29   return CountTensors().apply(std::forward<Args>(args)...).out;
30 }
31 
32 struct CountVariables : IterArgs<CountVariables> {
33   size_t out = 0;
operatorCountVariables34   void operator()(const autograd::Variable& x) {
35     out += 1;
36   }
operatorCountVariables37   void operator()(at::ArrayRef<autograd::Variable> xs) {
38     out += xs.size();
39   }
40 };
41 
42 template <typename... Args>
count_variables(Args &&...args)43 inline size_t count_variables(Args&&... args) {
44   return CountVariables().apply(std::forward<Args>(args)...).out;
45 }
46 
47 //===----------------------------------------------------------------------===//
48 //                std::index_sequence shim for C++11
49 //===----------------------------------------------------------------------===//
50 
// A compile-time list of `size_t` values (non-type template arguments), used
// to expand a parameter pack alongside the position of each element.
template <size_t... Values>
struct Indices {};

// `MakeIndices<Count>::indices` names `Indices<0, 1, ..., Count - 1>`.
//
// Each step takes one off `Remaining`, prepends `Remaining - 1` to the values
// collected so far in `Built`, and inherits from the next step. The recursion
// ends at the `Remaining == 0` specialization, whose `indices` alias is then
// visible from every step above it through inheritance.
template <size_t Remaining, size_t... Built>
struct MakeIndices : MakeIndices<Remaining - 1, Remaining - 1, Built...> {};

// Base case: nothing left to prepend, so expose the accumulated list
// (for an original request of `Count`, this is 0 through Count - 1).
template <size_t... Built>
struct MakeIndices<0, Built...> {
  using indices = Indices<Built...>;
};
68 
69 //===----------------------------------------------------------------------===//
70 //                                 Utilities
71 //===----------------------------------------------------------------------===//
72 
/// Invokes `function` once for each element of `ts`, strictly left to right.
///
/// Each argument is perfectly forwarded into its own call. Whatever
/// `function` returns is discarded. The result is cast to `void` first, so a
/// return type with an overloaded comma operator cannot hijack the expansion.
///
/// @param function Callable invoked as `function(t)` for every `t` in `ts`.
///                 Taken by value; every call goes through this one copy.
/// @param ts       Arguments to visit. If empty, `function` is never called.
template <typename Function, typename... Ts>
void apply(Function function, Ts&&... ts) {
  // A fold over the built-in comma operator is sequenced left to right
  // (C++17). This replaces the earlier dummy-array idiom for ordered pack
  // expansion. The void cast guarantees the built-in comma is the one used.
  (static_cast<void>(function(std::forward<Ts>(ts))), ...);
}
85 
86 template <
87     typename ReturnType,
88     typename... Ts,
89     typename Function,
90     typename Accessor>
91 ReturnType unpack(Function function, Accessor accessor) {
92   return ReturnType(unpack<ReturnType, Ts...>(
93       std::move(function),
94       std::move(accessor),
95       typename MakeIndices<sizeof...(Ts)>::indices()));
96 }
97 
98 template <
99     typename ReturnType,
100     typename... Ts,
101     typename Function,
102     typename Accessor,
103     size_t... Is>
104 ReturnType unpack(Function function, Accessor accessor, Indices<Is...>) {
105   return ReturnType(function(accessor.template operator()<Ts>(Is)...));
106 }
107 
108 } // namespace torch
109