xref: /aosp_15_r20/external/federated-compute/fcp/base/match.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // 'Match' expressions for {std, absl}::variant.
18 //
19 // {std, absl}::variant is an algebraic sum type. However, the standard library
20 // does not provide a convenient way to destructure or match on them - unlike in
21 // Haskell, Rust, etc.
22 //
// This file provides a way to match on a variant in a way akin to a switch
24 // statement.
25 //
26 // Example:
27 //
28 //   using V = std::variant<X, Y, Z>;
29 //   V v = ...;
30 //   ...
31 //   int i = Match(v,
32 //     [](X const& x) { return 1; },
33 //     [](Y const& y) { return 2; },
34 //     [](Z const& z) { return 3; });
35 //
36 // It is a compile-time error if the match is not exhaustive. A 'Default' case
37 // can be provided:
38 //
39 //   int i = Match(v,
40 //     [](X const& x) { return 1; },
41 //     // Called with the otherwise-unhandled alternative (see decltype(alt)).
42 //     [](Default, auto const& alt) { ...; });
43 //
44 //   int i = Match(v,
45 //     [](X const& x) { return 1; },
46 //     // Called with the variant itself.
47 //     [](Default, V const& v) { ...; });
48 //
49 // If constructing the matcher lambdas is non-trivial, it might be worthwhile to
50 // create a re-usable matcher object. See 'MakeMatcher'.
51 
52 #ifndef FCP_BASE_MATCH_H_
53 #define FCP_BASE_MATCH_H_
54 
#include <functional>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>
#include <variant>

#include "fcp/base/meta.h"
60 
61 namespace fcp {
62 
// Tag type that marks a fallback ("default") case in a Match. A case function
// whose first parameter is Default is consulted only for alternatives that no
// other case accepts; ApplyCase defines the exact lookup order.
struct Default {
  // Intentionally stateless: only the type itself carries meaning.
};
65 
66 namespace match_internal {
67 
// Overload set built from the case functions: inherits operator() from every
// CaseFn so that a call is dispatched by ordinary overload resolution.
template <typename... CaseFns>
struct MatchCasesCallable : public CaseFns... {
  using CaseFns::operator()...;
};

// Wraps a MatchCasesCallable. When ToType is non-void, every case result is
// explicitly converted to ToType; when it is void, results are returned as-is.
template <typename ToType, typename... CaseFns>
class MatchCases {
 public:
  explicit constexpr MatchCases(MatchCasesCallable<CaseFns...> c)
      : callable_(std::move(c)) {}

  // Primary template: a call with argument types T... is not handled.
  template <typename Enable, typename... T>
  struct IsCaseHandledImpl : public std::false_type {};

  // Selected when `c(std::declval<T>()...)` is well-formed for a *const lvalue*
  // callable `c`. operator() below invokes callable_ from a const member
  // function, so the probe must use the same constness. Otherwise a `mutable`
  // case lambda would be reported as handling a case that cannot actually be
  // called, and compilation would fail with a hard error instead of falling
  // through to a Default case.
  template <typename... T>
  struct IsCaseHandledImpl<
      std::void_t<decltype(std::declval<MatchCasesCallable<CaseFns...> const&>()(
          std::declval<T>()...))>,
      T...> : public std::true_type {};

  // True iff some case function accepts arguments of types T... when invoked
  // the way operator() invokes it.
  template <typename... T>
  static constexpr bool IsCaseHandled() {
    return IsCaseHandledImpl<void, T...>::value;
  }

  // Invokes the case selected by overload resolution on `args`. If ToType_ is
  // non-void, the result is explicitly converted to ToType_.
  template <typename ToType_ = ToType, typename... Args>
  constexpr auto operator()(Args&&... args) const {
    if constexpr (std::is_void_v<ToType_>) {
      return callable_(std::forward<Args>(args)...);
    } else {
      return ToType_(callable_(std::forward<Args>(args)...));
    }
  }

 private:
  MatchCasesCallable<CaseFns...> callable_;
};

// Builds a MatchCases<ToType, CaseFns...> that owns the given case functions.
template <typename ToType, typename... CaseFns>
constexpr MatchCases<ToType, CaseFns...> MakeMatchCases(CaseFns... case_fns) {
  // case_fns are by-value copies owned by this call; move them into the
  // callable instead of copying each one (and its captures) a second time.
  return MatchCases<ToType, CaseFns...>(
      MatchCasesCallable<CaseFns...>{std::move(case_fns)...});
}
116 
117 template <typename CasesType, typename VariantType, typename ArgType>
118 constexpr auto ApplyCase(CasesType const& cases, VariantType&& v,
119                          ArgType&& arg) {
120   if constexpr (CasesType::template IsCaseHandled<ArgType>()) {
121     return cases(std::forward<ArgType>(arg));
122   } else if constexpr (CasesType::template IsCaseHandled<Default, ArgType>()) {
123     return cases(Default{}, std::forward<ArgType>(arg));
124   } else if constexpr (CasesType::template IsCaseHandled<Default,
125                                                          VariantType>()) {
126     return cases(Default{}, std::forward<VariantType>(v));
127   } else if constexpr (CasesType::template IsCaseHandled<Default>()) {
128     return cases(Default{});
129   } else {
130     static_assert(
131         FailIfReached<ArgType>(),
132         "Provide a case for all variant alternatives, or a 'Default' case");
133   }
134 }
135 
// Binds a set of match cases (CasesType) to a variant-like type described by
// Traits. Traits supplies ValueType and a static Visit(); every visited
// alternative is routed to ApplyCase together with the original value.
template <typename Traits, typename CasesType>
class VariantMatcherImpl {
 public:
  using ValueType = typename Traits::ValueType;

  explicit constexpr VariantMatcherImpl(CasesType cases)
      : case_set_(std::move(cases)) {}

  // Matches the value `v` points to. Only a non-const pointer is accepted.
  constexpr auto Match(ValueType* v) const { return Dispatch(v); }

  // Matches an lvalue (bound as const).
  constexpr auto Match(ValueType const& v) const { return Dispatch(v); }

  // Matches an rvalue; it is forwarded as an rvalue to Traits::Visit and
  // ApplyCase.
  constexpr auto Match(ValueType&& v) const { return Dispatch(std::move(v)); }

 private:
  // Visits `value` through Traits. For each alternative, forwards both the
  // alternative and `value` (with its original value category) to ApplyCase.
  template <typename FromType>
  constexpr auto Dispatch(FromType&& value) const {
    auto on_alternative = [this, &value](auto&& alternative) {
      return ApplyCase(case_set_, std::forward<FromType>(value),
                       std::forward<decltype(alternative)>(alternative));
    };
    return Traits::Visit(std::forward<FromType>(value),
                         std::move(on_alternative));
  }

  CasesType case_set_;
};
163 
164 template <typename T, typename Enable = void>
165 struct MatchTraits {
166   static_assert(FailIfReached<T>(),
167                 "Only variant-like (e.g. std::variant<...> types can be "
168                 "matched. See MatchTraits.");
169 };
170 
171 template <typename... AltTypes>
172 struct MatchTraits<std::variant<AltTypes...>> {
173   using ValueType = std::variant<AltTypes...>;
174 
175   template <typename VisitFn>
176   static constexpr auto Visit(ValueType const& v, VisitFn&& fn) {
177     return absl::visit(std::forward<VisitFn>(fn), v);
178   }
179 
180   template <typename VisitFn>
181   static constexpr auto Visit(ValueType&& v, VisitFn&& fn) {
182     return absl::visit(std::forward<VisitFn>(fn), std::move(v));
183   }
184 
185   template <typename VisitFn>
186   static constexpr auto Visit(ValueType* v, VisitFn&& fn) {
187     return absl::visit([fn = std::forward<VisitFn>(fn)](
188                            auto& alt) mutable { return fn(&alt); },
189                        *v);
190   }
191 };
192 
193 template <typename T>
194 struct MatchTraits<std::optional<T>> {
195   using ValueType = std::optional<T>;
196 
197   static constexpr auto Wrap(std::optional<T>* o)
198       -> std::variant<T*, std::nullopt_t> {
199     if (o->has_value()) {
200       return &**o;
201     } else {
202       return std::nullopt;
203     }
204   }
205 
206   static constexpr auto Wrap(std::optional<T> const& o)
207       -> std::variant<std::reference_wrapper<T const>, std::nullopt_t> {
208     if (o.has_value()) {
209       return std::ref(*o);
210     } else {
211       return std::nullopt;
212     }
213   }
214 
215   static constexpr auto Wrap(std::optional<T>&& o)
216       -> std::variant<T, std::nullopt_t> {
217     if (o.has_value()) {
218       return *std::move(o);
219     } else {
220       return std::nullopt;
221     }
222   }
223 
224   template <typename V, typename VisitFn>
225   static constexpr auto Visit(V&& v, VisitFn&& fn) {
226     return absl::visit(std::forward<VisitFn>(fn), Wrap(std::forward<V>(v)));
227   }
228 };
229 
230 template <typename T>
231 struct MatchTraits<T, std::void_t<typename T::VariantType>> {
232   using ValueType = T;
233 
234   template <typename VisitFn>
235   static constexpr auto Visit(ValueType const& v, VisitFn&& fn) {
236     return MatchTraits<typename T::VariantType>::Visit(
237         v.variant(), std::forward<VisitFn>(fn));
238   }
239 
240   template <typename VisitFn>
241   static constexpr auto Visit(ValueType&& v, VisitFn&& fn) {
242     return MatchTraits<typename T::VariantType>::Visit(
243         std::move(v).variant(), std::forward<VisitFn>(fn));
244   }
245 
246   template <typename VisitFn>
247   static constexpr auto Visit(ValueType* v, VisitFn&& fn) {
248     return MatchTraits<typename T::VariantType>::Visit(
249         &v->variant(), std::forward<VisitFn>(fn));
250   }
251 };
252 
253 template <typename VariantType, typename CasesType>
254 constexpr auto CreateMatcherImpl(CasesType cases) {
255   return VariantMatcherImpl<MatchTraits<VariantType>, CasesType>(
256       std::move(cases));
257 }
258 
259 }  // namespace match_internal
260 
261 // See file remarks.
262 template <typename From, typename To = void, typename... CaseFnTypes>
263 constexpr auto MakeMatcher(CaseFnTypes... fns) {
264   return match_internal::CreateMatcherImpl<From>(
265       match_internal::MakeMatchCases<To>(fns...));
266 }
267 
268 // See file remarks.
269 //
270 // Note that the order of template arguments differs from MakeMatcher; it is
271 // expected that 'From' is always deduced (but it can be useful to specify 'To'
272 // explicitly).
273 template <typename To = void, typename From, typename... CaseFnTypes>
274 constexpr auto Match(From&& v, CaseFnTypes... fns) {
275   // 'From' is intended to be deduced. For MakeMatcher, we want V (not e.g. V
276   // const&).
277   auto m = MakeMatcher<std::decay_t<From>, To>(fns...);
278   // The full type is still relevant for forwarding.
279   return m.Match(std::forward<From>(v));
280 }
281 
282 template <typename To = void, typename From, typename... CaseFnTypes>
283 constexpr auto Match(From* v, CaseFnTypes... fns) {
284   // 'From' is intended to be deduced. For MakeMatcher, we want V (not e.g. V
285   // const*).
286   auto m = MakeMatcher<std::decay_t<From>, To>(fns...);
287   return m.Match(v);
288 }
289 
290 }  // namespace fcp
291 
292 #endif  // FCP_BASE_MATCH_H_
293