xref: /aosp_15_r20/external/federated-compute/fcp/base/match.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2019 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker 
17*14675a02SAndroid Build Coastguard Worker // 'Match' expressions for {std, absl}::variant.
18*14675a02SAndroid Build Coastguard Worker //
19*14675a02SAndroid Build Coastguard Worker // {std, absl}::variant is an algebraic sum type. However, the standard library
20*14675a02SAndroid Build Coastguard Worker // does not provide a convenient way to destructure or match on them - unlike in
21*14675a02SAndroid Build Coastguard Worker // Haskell, Rust, etc.
22*14675a02SAndroid Build Coastguard Worker //
// This file provides a way to match on such variants (and on std::optional) in
// a way akin to a switch statement.
25*14675a02SAndroid Build Coastguard Worker //
26*14675a02SAndroid Build Coastguard Worker // Example:
27*14675a02SAndroid Build Coastguard Worker //
28*14675a02SAndroid Build Coastguard Worker //   using V = std::variant<X, Y, Z>;
29*14675a02SAndroid Build Coastguard Worker //   V v = ...;
30*14675a02SAndroid Build Coastguard Worker //   ...
31*14675a02SAndroid Build Coastguard Worker //   int i = Match(v,
32*14675a02SAndroid Build Coastguard Worker //     [](X const& x) { return 1; },
33*14675a02SAndroid Build Coastguard Worker //     [](Y const& y) { return 2; },
34*14675a02SAndroid Build Coastguard Worker //     [](Z const& z) { return 3; });
35*14675a02SAndroid Build Coastguard Worker //
36*14675a02SAndroid Build Coastguard Worker // It is a compile-time error if the match is not exhaustive. A 'Default' case
37*14675a02SAndroid Build Coastguard Worker // can be provided:
38*14675a02SAndroid Build Coastguard Worker //
39*14675a02SAndroid Build Coastguard Worker //   int i = Match(v,
40*14675a02SAndroid Build Coastguard Worker //     [](X const& x) { return 1; },
41*14675a02SAndroid Build Coastguard Worker //     // Called with the otherwise-unhandled alternative (see decltype(alt)).
42*14675a02SAndroid Build Coastguard Worker //     [](Default, auto const& alt) { ...; });
43*14675a02SAndroid Build Coastguard Worker //
44*14675a02SAndroid Build Coastguard Worker //   int i = Match(v,
45*14675a02SAndroid Build Coastguard Worker //     [](X const& x) { return 1; },
46*14675a02SAndroid Build Coastguard Worker //     // Called with the variant itself.
47*14675a02SAndroid Build Coastguard Worker //     [](Default, V const& v) { ...; });
48*14675a02SAndroid Build Coastguard Worker //
49*14675a02SAndroid Build Coastguard Worker // If constructing the matcher lambdas is non-trivial, it might be worthwhile to
50*14675a02SAndroid Build Coastguard Worker // create a re-usable matcher object. See 'MakeMatcher'.
51*14675a02SAndroid Build Coastguard Worker 
52*14675a02SAndroid Build Coastguard Worker #ifndef FCP_BASE_MATCH_H_
53*14675a02SAndroid Build Coastguard Worker #define FCP_BASE_MATCH_H_
54*14675a02SAndroid Build Coastguard Worker 
#include <functional>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>
#include <variant>

#include "fcp/base/meta.h"
60*14675a02SAndroid Build Coastguard Worker 
61*14675a02SAndroid Build Coastguard Worker namespace fcp {
62*14675a02SAndroid Build Coastguard Worker 
// Tag type marking a fallback case in a Match expression. A case whose first
// parameter is Default is only considered for alternatives that no regular
// case accepts.
struct Default {
  // Intentionally empty: only the type is used, as a marker.
};
65*14675a02SAndroid Build Coastguard Worker 
66*14675a02SAndroid Build Coastguard Worker namespace match_internal {
67*14675a02SAndroid Build Coastguard Worker 
// A single callable object assembled from all case functors. It inherits
// (publicly) from every CaseFn and re-exports each one's operator(), so a call
// on it is resolved by ordinary overload resolution across all cases.
template <typename... CaseFns>
struct MatchCasesCallable : CaseFns... {
  using CaseFns::operator()...;
};
74*14675a02SAndroid Build Coastguard Worker 
75*14675a02SAndroid Build Coastguard Worker template <typename ToType, typename... CaseFns>
76*14675a02SAndroid Build Coastguard Worker class MatchCases {
77*14675a02SAndroid Build Coastguard Worker  public:
MatchCases(MatchCasesCallable<CaseFns...> c)78*14675a02SAndroid Build Coastguard Worker   explicit constexpr MatchCases(MatchCasesCallable<CaseFns...> c)
79*14675a02SAndroid Build Coastguard Worker       : callable_(std::move(c)) {}
80*14675a02SAndroid Build Coastguard Worker 
81*14675a02SAndroid Build Coastguard Worker   // False by default
82*14675a02SAndroid Build Coastguard Worker   template <typename Enable, typename... T>
83*14675a02SAndroid Build Coastguard Worker   struct IsCaseHandledImpl : public std::false_type {};
84*14675a02SAndroid Build Coastguard Worker 
85*14675a02SAndroid Build Coastguard Worker   // True when m.MatchCases(args...) is well-formed, for a
86*14675a02SAndroid Build Coastguard Worker   // MatchCases<CaseFns...> m and T arg.
87*14675a02SAndroid Build Coastguard Worker   template <typename... T>
88*14675a02SAndroid Build Coastguard Worker   struct IsCaseHandledImpl<
89*14675a02SAndroid Build Coastguard Worker       std::void_t<decltype(std::declval<MatchCasesCallable<CaseFns...>>()(
90*14675a02SAndroid Build Coastguard Worker           std::declval<T>()...))>,
91*14675a02SAndroid Build Coastguard Worker       T...> : public std::true_type {};
92*14675a02SAndroid Build Coastguard Worker 
93*14675a02SAndroid Build Coastguard Worker   template <typename... T>
94*14675a02SAndroid Build Coastguard Worker   static constexpr bool IsCaseHandled() {
95*14675a02SAndroid Build Coastguard Worker     return IsCaseHandledImpl<void, T...>::value;
96*14675a02SAndroid Build Coastguard Worker   }
97*14675a02SAndroid Build Coastguard Worker 
98*14675a02SAndroid Build Coastguard Worker   template <typename ToType_ = ToType, typename... Args>
99*14675a02SAndroid Build Coastguard Worker   constexpr auto operator()(Args&&... args) const {
100*14675a02SAndroid Build Coastguard Worker     if constexpr (std::is_void_v<ToType_>) {
101*14675a02SAndroid Build Coastguard Worker       return callable_(std::forward<Args>(args)...);
102*14675a02SAndroid Build Coastguard Worker     } else {
103*14675a02SAndroid Build Coastguard Worker       return ToType_(callable_(std::forward<Args>(args)...));
104*14675a02SAndroid Build Coastguard Worker     }
105*14675a02SAndroid Build Coastguard Worker   }
106*14675a02SAndroid Build Coastguard Worker 
107*14675a02SAndroid Build Coastguard Worker  private:
108*14675a02SAndroid Build Coastguard Worker   MatchCasesCallable<CaseFns...> callable_;
109*14675a02SAndroid Build Coastguard Worker };
110*14675a02SAndroid Build Coastguard Worker 
111*14675a02SAndroid Build Coastguard Worker template <typename ToType, typename... CaseFns>
112*14675a02SAndroid Build Coastguard Worker constexpr MatchCases<ToType, CaseFns...> MakeMatchCases(CaseFns... case_fns) {
113*14675a02SAndroid Build Coastguard Worker   return MatchCases<ToType, CaseFns...>(
114*14675a02SAndroid Build Coastguard Worker       MatchCasesCallable<CaseFns...>{case_fns...});
115*14675a02SAndroid Build Coastguard Worker }
116*14675a02SAndroid Build Coastguard Worker 
117*14675a02SAndroid Build Coastguard Worker template <typename CasesType, typename VariantType, typename ArgType>
118*14675a02SAndroid Build Coastguard Worker constexpr auto ApplyCase(CasesType const& cases, VariantType&& v,
119*14675a02SAndroid Build Coastguard Worker                          ArgType&& arg) {
120*14675a02SAndroid Build Coastguard Worker   if constexpr (CasesType::template IsCaseHandled<ArgType>()) {
121*14675a02SAndroid Build Coastguard Worker     return cases(std::forward<ArgType>(arg));
122*14675a02SAndroid Build Coastguard Worker   } else if constexpr (CasesType::template IsCaseHandled<Default, ArgType>()) {
123*14675a02SAndroid Build Coastguard Worker     return cases(Default{}, std::forward<ArgType>(arg));
124*14675a02SAndroid Build Coastguard Worker   } else if constexpr (CasesType::template IsCaseHandled<Default,
125*14675a02SAndroid Build Coastguard Worker                                                          VariantType>()) {
126*14675a02SAndroid Build Coastguard Worker     return cases(Default{}, std::forward<VariantType>(v));
127*14675a02SAndroid Build Coastguard Worker   } else if constexpr (CasesType::template IsCaseHandled<Default>()) {
128*14675a02SAndroid Build Coastguard Worker     return cases(Default{});
129*14675a02SAndroid Build Coastguard Worker   } else {
130*14675a02SAndroid Build Coastguard Worker     static_assert(
131*14675a02SAndroid Build Coastguard Worker         FailIfReached<ArgType>(),
132*14675a02SAndroid Build Coastguard Worker         "Provide a case for all variant alternatives, or a 'Default' case");
133*14675a02SAndroid Build Coastguard Worker   }
134*14675a02SAndroid Build Coastguard Worker }
135*14675a02SAndroid Build Coastguard Worker 
// A matcher bound to a fixed set of cases.
//
// Traits describes how to visit values of Traits::ValueType. Every visited
// alternative is routed through ApplyCase together with the matched value
// itself, so a `(Default, V)` case can receive the value in whatever form
// (pointer, const&, or rvalue) was passed to Match.
template <typename Traits, typename CasesType>
class VariantMatcherImpl {
 public:
  using ValueType = typename Traits::ValueType;

  explicit constexpr VariantMatcherImpl(CasesType cases)
      : cases_(std::move(cases)) {}

  constexpr auto Match(ValueType* v) const { return MatchInternal(v); }
  constexpr auto Match(ValueType const& v) const { return MatchInternal(v); }
  constexpr auto Match(ValueType&& v) const {
    return MatchInternal(std::move(v));
  }

 private:
  template <typename Fwd>
  constexpr auto MatchInternal(Fwd&& value) const {
    // Invoked by Traits::Visit with the active alternative; picks the case.
    const auto on_alternative = [this, &value](auto&& alt) {
      return ApplyCase(cases_, std::forward<Fwd>(value),
                       std::forward<decltype(alt)>(alt));
    };
    return Traits::Visit(std::forward<Fwd>(value), on_alternative);
  }

  CasesType cases_;
};
163*14675a02SAndroid Build Coastguard Worker 
164*14675a02SAndroid Build Coastguard Worker template <typename T, typename Enable = void>
165*14675a02SAndroid Build Coastguard Worker struct MatchTraits {
166*14675a02SAndroid Build Coastguard Worker   static_assert(FailIfReached<T>(),
167*14675a02SAndroid Build Coastguard Worker                 "Only variant-like (e.g. std::variant<...> types can be "
168*14675a02SAndroid Build Coastguard Worker                 "matched. See MatchTraits.");
169*14675a02SAndroid Build Coastguard Worker };
170*14675a02SAndroid Build Coastguard Worker 
171*14675a02SAndroid Build Coastguard Worker template <typename... AltTypes>
172*14675a02SAndroid Build Coastguard Worker struct MatchTraits<std::variant<AltTypes...>> {
173*14675a02SAndroid Build Coastguard Worker   using ValueType = std::variant<AltTypes...>;
174*14675a02SAndroid Build Coastguard Worker 
175*14675a02SAndroid Build Coastguard Worker   template <typename VisitFn>
176*14675a02SAndroid Build Coastguard Worker   static constexpr auto Visit(ValueType const& v, VisitFn&& fn) {
177*14675a02SAndroid Build Coastguard Worker     return absl::visit(std::forward<VisitFn>(fn), v);
178*14675a02SAndroid Build Coastguard Worker   }
179*14675a02SAndroid Build Coastguard Worker 
180*14675a02SAndroid Build Coastguard Worker   template <typename VisitFn>
181*14675a02SAndroid Build Coastguard Worker   static constexpr auto Visit(ValueType&& v, VisitFn&& fn) {
182*14675a02SAndroid Build Coastguard Worker     return absl::visit(std::forward<VisitFn>(fn), std::move(v));
183*14675a02SAndroid Build Coastguard Worker   }
184*14675a02SAndroid Build Coastguard Worker 
185*14675a02SAndroid Build Coastguard Worker   template <typename VisitFn>
186*14675a02SAndroid Build Coastguard Worker   static constexpr auto Visit(ValueType* v, VisitFn&& fn) {
187*14675a02SAndroid Build Coastguard Worker     return absl::visit([fn = std::forward<VisitFn>(fn)](
188*14675a02SAndroid Build Coastguard Worker                            auto& alt) mutable { return fn(&alt); },
189*14675a02SAndroid Build Coastguard Worker                        *v);
190*14675a02SAndroid Build Coastguard Worker   }
191*14675a02SAndroid Build Coastguard Worker };
192*14675a02SAndroid Build Coastguard Worker 
193*14675a02SAndroid Build Coastguard Worker template <typename T>
194*14675a02SAndroid Build Coastguard Worker struct MatchTraits<std::optional<T>> {
195*14675a02SAndroid Build Coastguard Worker   using ValueType = std::optional<T>;
196*14675a02SAndroid Build Coastguard Worker 
197*14675a02SAndroid Build Coastguard Worker   static constexpr auto Wrap(std::optional<T>* o)
198*14675a02SAndroid Build Coastguard Worker       -> std::variant<T*, std::nullopt_t> {
199*14675a02SAndroid Build Coastguard Worker     if (o->has_value()) {
200*14675a02SAndroid Build Coastguard Worker       return &**o;
201*14675a02SAndroid Build Coastguard Worker     } else {
202*14675a02SAndroid Build Coastguard Worker       return std::nullopt;
203*14675a02SAndroid Build Coastguard Worker     }
204*14675a02SAndroid Build Coastguard Worker   }
205*14675a02SAndroid Build Coastguard Worker 
206*14675a02SAndroid Build Coastguard Worker   static constexpr auto Wrap(std::optional<T> const& o)
207*14675a02SAndroid Build Coastguard Worker       -> std::variant<std::reference_wrapper<T const>, std::nullopt_t> {
208*14675a02SAndroid Build Coastguard Worker     if (o.has_value()) {
209*14675a02SAndroid Build Coastguard Worker       return std::ref(*o);
210*14675a02SAndroid Build Coastguard Worker     } else {
211*14675a02SAndroid Build Coastguard Worker       return std::nullopt;
212*14675a02SAndroid Build Coastguard Worker     }
213*14675a02SAndroid Build Coastguard Worker   }
214*14675a02SAndroid Build Coastguard Worker 
215*14675a02SAndroid Build Coastguard Worker   static constexpr auto Wrap(std::optional<T>&& o)
216*14675a02SAndroid Build Coastguard Worker       -> std::variant<T, std::nullopt_t> {
217*14675a02SAndroid Build Coastguard Worker     if (o.has_value()) {
218*14675a02SAndroid Build Coastguard Worker       return *std::move(o);
219*14675a02SAndroid Build Coastguard Worker     } else {
220*14675a02SAndroid Build Coastguard Worker       return std::nullopt;
221*14675a02SAndroid Build Coastguard Worker     }
222*14675a02SAndroid Build Coastguard Worker   }
223*14675a02SAndroid Build Coastguard Worker 
224*14675a02SAndroid Build Coastguard Worker   template <typename V, typename VisitFn>
225*14675a02SAndroid Build Coastguard Worker   static constexpr auto Visit(V&& v, VisitFn&& fn) {
226*14675a02SAndroid Build Coastguard Worker     return absl::visit(std::forward<VisitFn>(fn), Wrap(std::forward<V>(v)));
227*14675a02SAndroid Build Coastguard Worker   }
228*14675a02SAndroid Build Coastguard Worker };
229*14675a02SAndroid Build Coastguard Worker 
230*14675a02SAndroid Build Coastguard Worker template <typename T>
231*14675a02SAndroid Build Coastguard Worker struct MatchTraits<T, std::void_t<typename T::VariantType>> {
232*14675a02SAndroid Build Coastguard Worker   using ValueType = T;
233*14675a02SAndroid Build Coastguard Worker 
234*14675a02SAndroid Build Coastguard Worker   template <typename VisitFn>
235*14675a02SAndroid Build Coastguard Worker   static constexpr auto Visit(ValueType const& v, VisitFn&& fn) {
236*14675a02SAndroid Build Coastguard Worker     return MatchTraits<typename T::VariantType>::Visit(
237*14675a02SAndroid Build Coastguard Worker         v.variant(), std::forward<VisitFn>(fn));
238*14675a02SAndroid Build Coastguard Worker   }
239*14675a02SAndroid Build Coastguard Worker 
240*14675a02SAndroid Build Coastguard Worker   template <typename VisitFn>
241*14675a02SAndroid Build Coastguard Worker   static constexpr auto Visit(ValueType&& v, VisitFn&& fn) {
242*14675a02SAndroid Build Coastguard Worker     return MatchTraits<typename T::VariantType>::Visit(
243*14675a02SAndroid Build Coastguard Worker         std::move(v).variant(), std::forward<VisitFn>(fn));
244*14675a02SAndroid Build Coastguard Worker   }
245*14675a02SAndroid Build Coastguard Worker 
246*14675a02SAndroid Build Coastguard Worker   template <typename VisitFn>
247*14675a02SAndroid Build Coastguard Worker   static constexpr auto Visit(ValueType* v, VisitFn&& fn) {
248*14675a02SAndroid Build Coastguard Worker     return MatchTraits<typename T::VariantType>::Visit(
249*14675a02SAndroid Build Coastguard Worker         &v->variant(), std::forward<VisitFn>(fn));
250*14675a02SAndroid Build Coastguard Worker   }
251*14675a02SAndroid Build Coastguard Worker };
252*14675a02SAndroid Build Coastguard Worker 
253*14675a02SAndroid Build Coastguard Worker template <typename VariantType, typename CasesType>
254*14675a02SAndroid Build Coastguard Worker constexpr auto CreateMatcherImpl(CasesType cases) {
255*14675a02SAndroid Build Coastguard Worker   return VariantMatcherImpl<MatchTraits<VariantType>, CasesType>(
256*14675a02SAndroid Build Coastguard Worker       std::move(cases));
257*14675a02SAndroid Build Coastguard Worker }
258*14675a02SAndroid Build Coastguard Worker 
259*14675a02SAndroid Build Coastguard Worker }  // namespace match_internal
260*14675a02SAndroid Build Coastguard Worker 
261*14675a02SAndroid Build Coastguard Worker // See file remarks.
262*14675a02SAndroid Build Coastguard Worker template <typename From, typename To = void, typename... CaseFnTypes>
263*14675a02SAndroid Build Coastguard Worker constexpr auto MakeMatcher(CaseFnTypes... fns) {
264*14675a02SAndroid Build Coastguard Worker   return match_internal::CreateMatcherImpl<From>(
265*14675a02SAndroid Build Coastguard Worker       match_internal::MakeMatchCases<To>(fns...));
266*14675a02SAndroid Build Coastguard Worker }
267*14675a02SAndroid Build Coastguard Worker 
268*14675a02SAndroid Build Coastguard Worker // See file remarks.
269*14675a02SAndroid Build Coastguard Worker //
270*14675a02SAndroid Build Coastguard Worker // Note that the order of template arguments differs from MakeMatcher; it is
271*14675a02SAndroid Build Coastguard Worker // expected that 'From' is always deduced (but it can be useful to specify 'To'
272*14675a02SAndroid Build Coastguard Worker // explicitly).
273*14675a02SAndroid Build Coastguard Worker template <typename To = void, typename From, typename... CaseFnTypes>
274*14675a02SAndroid Build Coastguard Worker constexpr auto Match(From&& v, CaseFnTypes... fns) {
275*14675a02SAndroid Build Coastguard Worker   // 'From' is intended to be deduced. For MakeMatcher, we want V (not e.g. V
276*14675a02SAndroid Build Coastguard Worker   // const&).
277*14675a02SAndroid Build Coastguard Worker   auto m = MakeMatcher<std::decay_t<From>, To>(fns...);
278*14675a02SAndroid Build Coastguard Worker   // The full type is still relevant for forwarding.
279*14675a02SAndroid Build Coastguard Worker   return m.Match(std::forward<From>(v));
280*14675a02SAndroid Build Coastguard Worker }
281*14675a02SAndroid Build Coastguard Worker 
282*14675a02SAndroid Build Coastguard Worker template <typename To = void, typename From, typename... CaseFnTypes>
283*14675a02SAndroid Build Coastguard Worker constexpr auto Match(From* v, CaseFnTypes... fns) {
284*14675a02SAndroid Build Coastguard Worker   // 'From' is intended to be deduced. For MakeMatcher, we want V (not e.g. V
285*14675a02SAndroid Build Coastguard Worker   // const*).
286*14675a02SAndroid Build Coastguard Worker   auto m = MakeMatcher<std::decay_t<From>, To>(fns...);
287*14675a02SAndroid Build Coastguard Worker   return m.Match(v);
288*14675a02SAndroid Build Coastguard Worker }
289*14675a02SAndroid Build Coastguard Worker 
290*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
291*14675a02SAndroid Build Coastguard Worker 
292*14675a02SAndroid Build Coastguard Worker #endif  // FCP_BASE_MATCH_H_
293