xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/Variadic.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <utility>
4 
5 #include <c10/util/ArrayRef.h>
6 #include <ATen/core/List.h>
7 
8 namespace at {
9 
10 // This class allows you to write variadic functions which
11 // call a (possibly overloaded) function on each argument,
12 // in order.  This is most commonly used in autogenerated code,
13 // where it is convenient to have a function that can uniformly
14 // take arguments of different types.  If your arguments
15 // are homogenous consider using a std::initializer_list instead.
16 //
17 // For examples of this in use, see torch/csrc/utils/variadic.h
18 template <typename F>
19 struct IterArgs {
20   template <typename... Args>
applyIterArgs21   inline F& apply() {
22     return self();
23   }
24 
25   // NB: Use perfect forwarding here, otherwise we'll make value
26   // copies of all arguments!
27   template <typename T, typename... Args>
applyIterArgs28   inline F& apply(T&& arg, Args&&... args) {
29     self()(std::forward<T>(arg));
30     if (self().short_circuit()) {
31       return self();
32     } else {
33       return apply(std::forward<Args>(args)...);
34     }
35   }
36 
37   // Here are some handy overloads which provide sensible
38   // defaults for container-like structures that one might
39   // be interested in recursing into.  You can enable them
40   // by adding:
41   //
42   //    using IterArgs<YourStructName>::operator()
43   //
44   // to your struct.  These are not enabled by default because
45   // you may be able to process these structures more efficiently
46   // than handling them one-by-one.
47 
48   template <typename T>
operatorIterArgs49   void operator()(c10::IListRef<T> args) {
50     for (const auto& arg : args) {
51       self()(arg);
52       if (self().short_circuit())
53         return;
54     }
55   }
56 
57   template <typename T>
operatorIterArgs58   void operator()(at::ArrayRef<T> args) {
59     for (const auto& arg : args) {
60       self()(arg);
61       if (self().short_circuit())
62         return;
63     }
64   }
65 
66   template <typename T>
operatorIterArgs67   void operator()(const torch::List<T>& args) {
68     for (const auto& arg : args) {
69       self()(arg);
70       if (self().short_circuit())
71         return;
72     }
73   }
74 
75   // NB: we need to specify std::vector manually as C++ won't
76   // do an implicit conversion to make a template deduction go through.
77   template <typename T>
operatorIterArgs78   void operator()(const std::vector<T>& args) {
79     self()(at::ArrayRef<T>{args});
80   }
81 
short_circuitIterArgs82   constexpr bool short_circuit() const {
83     return false;
84   }
85 
86  private:
selfIterArgs87   inline F& self() {
88     return *static_cast<F*>(this);
89   }
90 };
91 
92 } // namespace torch
93