xref: /aosp_15_r20/external/pytorch/aten/src/ATen/functorch/BatchingMetaprogramming.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6 
7 #pragma once
8 #include <ATen/Tensor.h>
9 #include <ATen/VmapGeneratedPlumbing.h>
10 
11 // This file contains template metaprogramming things that are used for our
12 // batching rules.
13 //
14 // See NOTE: [vmap plumbing] for more details on why this is necessary.
15 // The plumbing has a bunch of metaprogramming hacks for determining the signature
16 // of a batching rule from the signature of the operator, many of which use the
17 // helper functions in this file.
18 
19 namespace at::functorch {
20 
21 // Metaprogramming things
22 template <class... Items> using typelist = c10::guts::typelist::typelist<Items...>;
23 template <class TypeList> using head_t = c10::guts::typelist::head_t<TypeList>;
24 template <class TL1, class TL2> using concat_t = c10::guts::typelist::concat_t<TL1, TL2>;
25 template <typename T> class debug_t;
26 
27 // tail operation
28 template<class TypeList>
29 struct tail final {
30     static_assert(c10::guts::false_t<TypeList>::value,
31                   "In typelist::tail<T>, the T argument must be typelist<...>.");
32 };
33 template<class Head, class... Tail>
34 struct tail<typelist<Head, Tail...>> final {
35   using type = typelist<Tail...>;
36 };
37 template<class TypeList> using tail_t = typename tail<TypeList>::type;
38 
39 template <class First, class Second, class Next, class Tail>
40 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext {
41   using type = Next;
42 };
43 template <class Next, class Tail>
44 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor, std::optional<int64_t>, Next, Tail> {
45   using type = Tail;
46 };
47 template <class Next, class Tail>
48 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const Tensor&, std::optional<int64_t>, Next, Tail> {
49   using type = Tail;
50 };
51 template <class Next, class Tail>
52 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor&, std::optional<int64_t>, Next, Tail> {
53   using type = Tail;
54 };
55 template <class Next, class Tail>
56 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::optional<Tensor>, std::optional<int64_t>, Next, Tail> {
57   using type = Tail;
58 };
59 template <class Next, class Tail>
60 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const std::optional<Tensor>&, std::optional<int64_t>, Next, Tail> {
61   using type = Tail;
62 };
63 template <class Next, class Tail>
64 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::optional<Tensor>&, std::optional<int64_t>, Next, Tail> {
65   using type = Tail;
66 };
67 template <class Next, class Tail>
68 struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::vector<Tensor>, std::optional<int64_t>, Next, Tail> {
69   using type = Tail;
70 };
71 template <class TypeList> struct RemoveBatchDimAfterTensor {
72   using first = head_t<TypeList>;
73   using next = tail_t<TypeList>;
74   using second = head_t<next>;
75   using tail = tail_t<next>;
76 
77   using type = concat_t<
78     typelist<first>,
79     typename RemoveBatchDimAfterTensor<
80       typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<first, second, next, tail>::type
81     >::type
82   >;
83 };
84 template <class Type> struct RemoveBatchDimAfterTensor<typelist<Type>> {
85   using type = typelist<Type>;
86 };
87 template <> struct RemoveBatchDimAfterTensor<typelist<>> {
88   using type = typelist<>;
89 };
90 template<class TypeList> using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor<TypeList>::type;
91 
// unpack_single_item_tuple_t<std::tuple<E>> is E. Every other type, including
// std::tuple<> and tuples with two or more elements, is passed through
// unchanged.
template <typename T>
struct UnpackSingleItemTuple {
  using type = T;
};

template <typename Elem>
struct UnpackSingleItemTuple<std::tuple<Elem>> {
  using type = Elem;
};

template <typename T>
using unpack_single_item_tuple_t = typename UnpackSingleItemTuple<T>::type;
99 
// BuildFunctionHelper<R, std::tuple<A...>>::type is the function type R(A...).
// The primary template is only declared here, so any TupleArgs that is not a
// std::tuple fails to compile.
template <typename Return, typename TupleArgs>
struct BuildFunctionHelper;

template <typename Return, typename... Params>
struct BuildFunctionHelper<Return, std::tuple<Params...>> {
  using type = Return(Params...);
};
104 template <typename Return, typename TL>
105 struct BuildFunction {
106   using type = typename BuildFunctionHelper<Return, c10::guts::typelist::to_tuple_t<TL>>::type;
107 };
108 template <typename Return, typename TL> using build_function_t = typename BuildFunction<Return, TL>::type;
109 
110 
111 template <typename batch_rule_t> struct ToOperatorType {
112   using batch_rule_return_type = typename c10::guts::function_traits<batch_rule_t>::return_type;
113   using batch_rule_parameter_types = typename c10::guts::function_traits<batch_rule_t>::parameter_types;
114 
115   using operator_parameter_types = remove_batch_dim_after_tensor_t<batch_rule_parameter_types>;
116   using operator_return_type =
117     unpack_single_item_tuple_t<
118       c10::guts::typelist::to_tuple_t<
119         remove_batch_dim_after_tensor_t<
120           c10::guts::typelist::from_tuple_t<batch_rule_return_type>>>>;
121 
122   using type = build_function_t<operator_return_type, operator_parameter_types>;
123 };
124 template <typename batch_rule_t> using to_operator_t = typename ToOperatorType<batch_rule_t>::type;
125 
126 } // namespace at::functorch
127