1 // Copyright (c) Facebook, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // This source code is licensed under the BSD-style license found in the 5 // LICENSE file in the root directory of this source tree. 6 7 #pragma once 8 #include <ATen/Tensor.h> 9 #include <ATen/VmapGeneratedPlumbing.h> 10 11 // This file contains template metaprogramming things that are used for our 12 // batching rules. 13 // 14 // See NOTE: [vmap plumbing] for more details on why this is necessary. 15 // The plumbing has a bunch of metaprogramming hacks for determining the signature 16 // of a batching rule from the signature of the operator, many of which use the 17 // helper functions in this file. 18 19 namespace at::functorch { 20 21 // Metaprogramming things 22 template <class... Items> using typelist = c10::guts::typelist::typelist<Items...>; 23 template <class TypeList> using head_t = c10::guts::typelist::head_t<TypeList>; 24 template <class TL1, class TL2> using concat_t = c10::guts::typelist::concat_t<TL1, TL2>; 25 template <typename T> class debug_t; 26 27 // tail operation 28 template<class TypeList> 29 struct tail final { 30 static_assert(c10::guts::false_t<TypeList>::value, 31 "In typelist::tail<T>, the T argument must be typelist<...>."); 32 }; 33 template<class Head, class... Tail> 34 struct tail<typelist<Head, Tail...>> final { 35 using type = typelist<Tail...>; 36 }; 37 template<class TypeList> using tail_t = typename tail<TypeList>::type; 38 39 template <class First, class Second, class Next, class Tail> 40 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext { 41 using type = Next; 42 }; 43 template <class Next, class Tail> 44 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor, std::optional<int64_t>, Next, Tail> { 45 using type = Tail; 46 }; 47 template <class Next, class Tail> 48 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const Tensor&, std::optional<int64_t>, Next, Tail> { 49 using type = Tail; 50 }; 51 template <class Next, class Tail> 52 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor&, std::optional<int64_t>, Next, Tail> { 53 using type = Tail; 54 }; 55 template <class Next, class Tail> 56 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::optional<Tensor>, std::optional<int64_t>, Next, Tail> { 57 using type = Tail; 58 }; 59 template <class Next, class Tail> 60 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const std::optional<Tensor>&, std::optional<int64_t>, Next, Tail> { 61 using type = Tail; 62 }; 63 template <class Next, class Tail> 64 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::optional<Tensor>&, std::optional<int64_t>, Next, Tail> { 65 using type = Tail; 66 }; 67 template <class Next, class Tail> 68 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::vector<Tensor>, std::optional<int64_t>, Next, Tail> { 69 using type = Tail; 70 }; 71 template <class TypeList> struct RemoveBatchDimAfterTensor { 72 using first = head_t<TypeList>; 73 using next = tail_t<TypeList>; 74 using second = head_t<next>; 75 using tail = tail_t<next>; 76 77 using type = concat_t< 78 typelist<first>, 79 typename RemoveBatchDimAfterTensor< 80 typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<first, second, next, tail>::type 81 >::type 82 >; 83 }; 84 template <class Type> struct RemoveBatchDimAfterTensor<typelist<Type>> { 85 using type = typelist<Type>; 86 }; 87 template <> struct RemoveBatchDimAfterTensor<typelist<>> { 88 using type = typelist<>; 89 }; 90 template<class TypeList> using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor<TypeList>::type; 91 92 template <typename T> struct UnpackSingleItemTuple { 93 using type = T; 94 }; 95 template <typename T> struct UnpackSingleItemTuple<std::tuple<T>> { 96 using type = T; 97 }; 98 template <typename T> using unpack_single_item_tuple_t = typename UnpackSingleItemTuple<T>::type; 99 100 template <typename Return, typename TupleArgs> struct BuildFunctionHelper; 101 template <typename Return, typename... Args> struct BuildFunctionHelper<Return, std::tuple<Args...>> { 102 using type = Return(Args...); 103 }; 104 template <typename Return, typename TL> 105 struct BuildFunction { 106 using type = typename BuildFunctionHelper<Return, c10::guts::typelist::to_tuple_t<TL>>::type; 107 }; 108 template <typename Return, typename TL> using build_function_t = typename BuildFunction<Return, TL>::type; 109 110 111 template <typename batch_rule_t> struct ToOperatorType { 112 using batch_rule_return_type = typename c10::guts::function_traits<batch_rule_t>::return_type; 113 using batch_rule_parameter_types = typename c10::guts::function_traits<batch_rule_t>::parameter_types; 114 115 using operator_parameter_types = remove_batch_dim_after_tensor_t<batch_rule_parameter_types>; 116 using operator_return_type = 117 unpack_single_item_tuple_t< 118 c10::guts::typelist::to_tuple_t< 119 remove_batch_dim_after_tensor_t< 120 c10::guts::typelist::from_tuple_t<batch_rule_return_type>>>>; 121 122 using type = build_function_t<operator_return_type, operator_parameter_types>; 123 }; 124 template <typename batch_rule_t> using to_operator_t = typename ToOperatorType<batch_rule_t>::type; 125 126 } // namespace at::functorch 127