1 #pragma once 2 3 #include <utility> 4 5 #include <c10/util/ArrayRef.h> 6 #include <ATen/core/List.h> 7 8 namespace at { 9 10 // This class allows you to write variadic functions which 11 // call a (possibly overloaded) function on each argument, 12 // in order. This is most commonly used in autogenerated code, 13 // where it is convenient to have a function that can uniformly 14 // take arguments of different types. If your arguments 15 // are homogenous consider using a std::initializer_list instead. 16 // 17 // For examples of this in use, see torch/csrc/utils/variadic.h 18 template <typename F> 19 struct IterArgs { 20 template <typename... Args> applyIterArgs21 inline F& apply() { 22 return self(); 23 } 24 25 // NB: Use perfect forwarding here, otherwise we'll make value 26 // copies of all arguments! 27 template <typename T, typename... Args> applyIterArgs28 inline F& apply(T&& arg, Args&&... args) { 29 self()(std::forward<T>(arg)); 30 if (self().short_circuit()) { 31 return self(); 32 } else { 33 return apply(std::forward<Args>(args)...); 34 } 35 } 36 37 // Here are some handy overloads which provide sensible 38 // defaults for container-like structures that one might 39 // be interested in recursing into. You can enable them 40 // by adding: 41 // 42 // using IterArgs<YourStructName>::operator() 43 // 44 // to your struct. These are not enabled by default because 45 // you may be able to process these structures more efficiently 46 // than handling them one-by-one. 47 48 template <typename T> operatorIterArgs49 void operator()(c10::IListRef<T> args) { 50 for (const auto& arg : args) { 51 self()(arg); 52 if (self().short_circuit()) 53 return; 54 } 55 } 56 57 template <typename T> operatorIterArgs58 void operator()(at::ArrayRef<T> args) { 59 for (const auto& arg : args) { 60 self()(arg); 61 if (self().short_circuit()) 62 return; 63 } 64 } 65 66 template <typename T> operatorIterArgs67 void operator()(const torch::List<T>& args) { 68 for (const auto& arg : args) { 69 self()(arg); 70 if (self().short_circuit()) 71 return; 72 } 73 } 74 75 // NB: we need to specify std::vector manually as C++ won't 76 // do an implicit conversion to make a template deduction go through. 77 template <typename T> operatorIterArgs78 void operator()(const std::vector<T>& args) { 79 self()(at::ArrayRef<T>{args}); 80 } 81 short_circuitIterArgs82 constexpr bool short_circuit() const { 83 return false; 84 } 85 86 private: selfIterArgs87 inline F& self() { 88 return *static_cast<F*>(this); 89 } 90 }; 91 92 } // namespace torch 93