1 #pragma once
2
3 #include <ATen/core/Tensor.h>
4 #include <ATen/core/Variadic.h>
5 #include <torch/csrc/autograd/variable.h>
6
7 #include <type_traits>
8 #include <utility>
9
10 namespace torch {
11
12 using at::IterArgs;
13
14 struct CountTensors : IterArgs<CountTensors> {
15 size_t out = 0;
operatorCountTensors16 void operator()(const at::Tensor& x) {
17 out += 1;
18 }
operatorCountTensors19 void operator()(const std::optional<at::Tensor>& x) {
20 out += x.has_value();
21 }
operatorCountTensors22 void operator()(at::ArrayRef<at::Tensor> xs) {
23 out += xs.size();
24 }
25 };
26
27 template <typename... Args>
count_tensors(Args &&...args)28 size_t count_tensors(Args&&... args) {
29 return CountTensors().apply(std::forward<Args>(args)...).out;
30 }
31
32 struct CountVariables : IterArgs<CountVariables> {
33 size_t out = 0;
operatorCountVariables34 void operator()(const autograd::Variable& x) {
35 out += 1;
36 }
operatorCountVariables37 void operator()(at::ArrayRef<autograd::Variable> xs) {
38 out += xs.size();
39 }
40 };
41
42 template <typename... Args>
count_variables(Args &&...args)43 inline size_t count_variables(Args&&... args) {
44 return CountVariables().apply(std::forward<Args>(args)...).out;
45 }
46
47 //===----------------------------------------------------------------------===//
48 // std::index_sequence shim for C++11
49 //===----------------------------------------------------------------------===//
50
// Empty tag type that carries a pack of size_t non-type template parameters.
// Passing an instance to a function lets that function deduce the pack and
// expand it.
template <size_t... Idx>
struct Indices {};
54
55 // Decrements the index N, adds N-1 to the list of indices and forwards
56 // whatever we already have.
57 template <size_t N, size_t... Is>
58 struct MakeIndices : MakeIndices<N - 1, N - 1, Is...> {};
59
60 // Partial specialization that forms our base case. When N is zero, we stop
61 // and define a typedef that will be visible to earlier classes due to
62 // inheritance. The typedef we define is an index list containing the numbers
63 // 0 through N-1.
64 template <size_t... Is>
65 struct MakeIndices<0, Is...> {
66 using indices = Indices<Is...>;
67 };
68
69 //===----------------------------------------------------------------------===//
70 // Utilities
71 //===----------------------------------------------------------------------===//
72
/// Invokes `function` once per argument, strictly left to right:
/// `function(t0); function(t1); ...`.
///
/// @param function  Callable invoked with each element of `ts`. It is taken by
///                  value, and that one instance is reused for every call.
/// @param ts        Arguments, each perfectly forwarded to its own call.
///
/// Any value returned by `function` is discarded.
template <typename Function, typename... Ts>
void apply(Function function, Ts&&... ts) {
  // A C++17 comma fold is sequenced left to right, so it replaces the old
  // dummy-array trick. The cast to void discards each result. It also forces
  // the built-in comma operator, so a return type that overloads operator,
  // cannot hijack the expansion. An empty pack folds to void().
  (static_cast<void>(function(std::forward<Ts>(ts))), ...);
}
85
86 template <
87 typename ReturnType,
88 typename... Ts,
89 typename Function,
90 typename Accessor>
91 ReturnType unpack(Function function, Accessor accessor) {
92 return ReturnType(unpack<ReturnType, Ts...>(
93 std::move(function),
94 std::move(accessor),
95 typename MakeIndices<sizeof...(Ts)>::indices()));
96 }
97
98 template <
99 typename ReturnType,
100 typename... Ts,
101 typename Function,
102 typename Accessor,
103 size_t... Is>
104 ReturnType unpack(Function function, Accessor accessor, Indices<Is...>) {
105 return ReturnType(function(accessor.template operator()<Ts>(Is)...));
106 }
107
108 } // namespace torch
109