xref: /aosp_15_r20/external/pytorch/c10/util/OptionalArrayRef.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// This file defines OptionalArrayRef<T>, a class that has almost exactly the
// same functionality as std::optional<ArrayRef<T>>, except that its
// converting constructor fixes a dangling pointer issue.
//
// The implicit converting constructor of std::optional<ArrayRef<T>> can cause
// the underlying ArrayRef<T> to store a dangling pointer. OptionalArrayRef<T>
// prevents this by wrapping a std::optional<ArrayRef<T>> and fixing the
// constructor implementation.
//
// See https://github.com/pytorch/pytorch/issues/63645 for more on this.
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker #pragma once
13*da0073e9SAndroid Build Coastguard Worker 
14*da0073e9SAndroid Build Coastguard Worker #include <c10/util/ArrayRef.h>
15*da0073e9SAndroid Build Coastguard Worker #include <cstdint>
16*da0073e9SAndroid Build Coastguard Worker #include <initializer_list>
17*da0073e9SAndroid Build Coastguard Worker #include <optional>
18*da0073e9SAndroid Build Coastguard Worker #include <type_traits>
19*da0073e9SAndroid Build Coastguard Worker #include <utility>
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker namespace c10 {
22*da0073e9SAndroid Build Coastguard Worker 
23*da0073e9SAndroid Build Coastguard Worker template <typename T>
24*da0073e9SAndroid Build Coastguard Worker class OptionalArrayRef final {
25*da0073e9SAndroid Build Coastguard Worker  public:
26*da0073e9SAndroid Build Coastguard Worker   // Constructors
27*da0073e9SAndroid Build Coastguard Worker 
28*da0073e9SAndroid Build Coastguard Worker   constexpr OptionalArrayRef() noexcept = default;
29*da0073e9SAndroid Build Coastguard Worker 
OptionalArrayRef(std::nullopt_t)30*da0073e9SAndroid Build Coastguard Worker   constexpr OptionalArrayRef(std::nullopt_t) noexcept {}
31*da0073e9SAndroid Build Coastguard Worker 
32*da0073e9SAndroid Build Coastguard Worker   OptionalArrayRef(const OptionalArrayRef& other) = default;
33*da0073e9SAndroid Build Coastguard Worker 
34*da0073e9SAndroid Build Coastguard Worker   OptionalArrayRef(OptionalArrayRef&& other) noexcept = default;
35*da0073e9SAndroid Build Coastguard Worker 
OptionalArrayRef(const std::optional<ArrayRef<T>> & other)36*da0073e9SAndroid Build Coastguard Worker   constexpr OptionalArrayRef(const std::optional<ArrayRef<T>>& other) noexcept
37*da0073e9SAndroid Build Coastguard Worker       : wrapped_opt_array_ref(other) {}
38*da0073e9SAndroid Build Coastguard Worker 
OptionalArrayRef(std::optional<ArrayRef<T>> && other)39*da0073e9SAndroid Build Coastguard Worker   constexpr OptionalArrayRef(std::optional<ArrayRef<T>>&& other) noexcept
40*da0073e9SAndroid Build Coastguard Worker       : wrapped_opt_array_ref(std::move(other)) {}
41*da0073e9SAndroid Build Coastguard Worker 
OptionalArrayRef(const T & value)42*da0073e9SAndroid Build Coastguard Worker   constexpr OptionalArrayRef(const T& value) noexcept
43*da0073e9SAndroid Build Coastguard Worker       : wrapped_opt_array_ref(value) {}
44*da0073e9SAndroid Build Coastguard Worker 
45*da0073e9SAndroid Build Coastguard Worker   template <
46*da0073e9SAndroid Build Coastguard Worker       typename U = ArrayRef<T>,
47*da0073e9SAndroid Build Coastguard Worker       std::enable_if_t<
48*da0073e9SAndroid Build Coastguard Worker           !std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
49*da0073e9SAndroid Build Coastguard Worker               !std::is_same_v<std::decay_t<U>, std::in_place_t> &&
50*da0073e9SAndroid Build Coastguard Worker               std::is_constructible_v<ArrayRef<T>, U&&> &&
51*da0073e9SAndroid Build Coastguard Worker               std::is_convertible_v<U&&, ArrayRef<T>> &&
52*da0073e9SAndroid Build Coastguard Worker               !std::is_convertible_v<U&&, T>,
53*da0073e9SAndroid Build Coastguard Worker           bool> = false>
OptionalArrayRef(U && value)54*da0073e9SAndroid Build Coastguard Worker   constexpr OptionalArrayRef(U&& value) noexcept(
55*da0073e9SAndroid Build Coastguard Worker       std::is_nothrow_constructible_v<ArrayRef<T>, U&&>)
56*da0073e9SAndroid Build Coastguard Worker       : wrapped_opt_array_ref(std::forward<U>(value)) {}
57*da0073e9SAndroid Build Coastguard Worker 
58*da0073e9SAndroid Build Coastguard Worker   template <
59*da0073e9SAndroid Build Coastguard Worker       typename U = ArrayRef<T>,
60*da0073e9SAndroid Build Coastguard Worker       std::enable_if_t<
61*da0073e9SAndroid Build Coastguard Worker           !std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
62*da0073e9SAndroid Build Coastguard Worker               !std::is_same_v<std::decay_t<U>, std::in_place_t> &&
63*da0073e9SAndroid Build Coastguard Worker               std::is_constructible_v<ArrayRef<T>, U&&> &&
64*da0073e9SAndroid Build Coastguard Worker               !std::is_convertible_v<U&&, ArrayRef<T>>,
65*da0073e9SAndroid Build Coastguard Worker           bool> = false>
OptionalArrayRef(U && value)66*da0073e9SAndroid Build Coastguard Worker   constexpr explicit OptionalArrayRef(U&& value) noexcept(
67*da0073e9SAndroid Build Coastguard Worker       std::is_nothrow_constructible_v<ArrayRef<T>, U&&>)
68*da0073e9SAndroid Build Coastguard Worker       : wrapped_opt_array_ref(std::forward<U>(value)) {}
69*da0073e9SAndroid Build Coastguard Worker 
70*da0073e9SAndroid Build Coastguard Worker   template <typename... Args>
OptionalArrayRef(std::in_place_t ip,Args &&...args)71*da0073e9SAndroid Build Coastguard Worker   constexpr explicit OptionalArrayRef(
72*da0073e9SAndroid Build Coastguard Worker       std::in_place_t ip,
73*da0073e9SAndroid Build Coastguard Worker       Args&&... args) noexcept
74*da0073e9SAndroid Build Coastguard Worker       : wrapped_opt_array_ref(ip, std::forward<Args>(args)...) {}
75*da0073e9SAndroid Build Coastguard Worker 
76*da0073e9SAndroid Build Coastguard Worker   template <typename U, typename... Args>
OptionalArrayRef(std::in_place_t ip,std::initializer_list<U> il,Args &&...args)77*da0073e9SAndroid Build Coastguard Worker   constexpr explicit OptionalArrayRef(
78*da0073e9SAndroid Build Coastguard Worker       std::in_place_t ip,
79*da0073e9SAndroid Build Coastguard Worker       std::initializer_list<U> il,
80*da0073e9SAndroid Build Coastguard Worker       Args&&... args)
81*da0073e9SAndroid Build Coastguard Worker       : wrapped_opt_array_ref(ip, il, std::forward<Args>(args)...) {}
82*da0073e9SAndroid Build Coastguard Worker 
OptionalArrayRef(const std::initializer_list<T> & Vec)83*da0073e9SAndroid Build Coastguard Worker   constexpr OptionalArrayRef(const std::initializer_list<T>& Vec)
84*da0073e9SAndroid Build Coastguard Worker       : wrapped_opt_array_ref(ArrayRef<T>(Vec)) {}
85*da0073e9SAndroid Build Coastguard Worker 
86*da0073e9SAndroid Build Coastguard Worker   // Destructor
87*da0073e9SAndroid Build Coastguard Worker 
88*da0073e9SAndroid Build Coastguard Worker   ~OptionalArrayRef() = default;
89*da0073e9SAndroid Build Coastguard Worker 
90*da0073e9SAndroid Build Coastguard Worker   // Assignment
91*da0073e9SAndroid Build Coastguard Worker 
92*da0073e9SAndroid Build Coastguard Worker   constexpr OptionalArrayRef& operator=(std::nullopt_t) noexcept {
93*da0073e9SAndroid Build Coastguard Worker     wrapped_opt_array_ref = std::nullopt;
94*da0073e9SAndroid Build Coastguard Worker     return *this;
95*da0073e9SAndroid Build Coastguard Worker   }
96*da0073e9SAndroid Build Coastguard Worker 
97*da0073e9SAndroid Build Coastguard Worker   OptionalArrayRef& operator=(const OptionalArrayRef& other) = default;
98*da0073e9SAndroid Build Coastguard Worker 
99*da0073e9SAndroid Build Coastguard Worker   OptionalArrayRef& operator=(OptionalArrayRef&& other) noexcept = default;
100*da0073e9SAndroid Build Coastguard Worker 
101*da0073e9SAndroid Build Coastguard Worker   constexpr OptionalArrayRef& operator=(
102*da0073e9SAndroid Build Coastguard Worker       const std::optional<ArrayRef<T>>& other) noexcept {
103*da0073e9SAndroid Build Coastguard Worker     wrapped_opt_array_ref = other;
104*da0073e9SAndroid Build Coastguard Worker     return *this;
105*da0073e9SAndroid Build Coastguard Worker   }
106*da0073e9SAndroid Build Coastguard Worker 
107*da0073e9SAndroid Build Coastguard Worker   constexpr OptionalArrayRef& operator=(
108*da0073e9SAndroid Build Coastguard Worker       std::optional<ArrayRef<T>>&& other) noexcept {
109*da0073e9SAndroid Build Coastguard Worker     wrapped_opt_array_ref = std::move(other);
110*da0073e9SAndroid Build Coastguard Worker     return *this;
111*da0073e9SAndroid Build Coastguard Worker   }
112*da0073e9SAndroid Build Coastguard Worker 
113*da0073e9SAndroid Build Coastguard Worker   template <
114*da0073e9SAndroid Build Coastguard Worker       typename U = ArrayRef<T>,
115*da0073e9SAndroid Build Coastguard Worker       typename = std::enable_if_t<
116*da0073e9SAndroid Build Coastguard Worker           !std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
117*da0073e9SAndroid Build Coastguard Worker           std::is_constructible_v<ArrayRef<T>, U&&> &&
118*da0073e9SAndroid Build Coastguard Worker           std::is_assignable_v<ArrayRef<T>&, U&&>>>
noexcept(std::is_nothrow_constructible_v<ArrayRef<T>,U &&> && std::is_nothrow_assignable_v<ArrayRef<T> &,U &&>)119*da0073e9SAndroid Build Coastguard Worker   constexpr OptionalArrayRef& operator=(U&& value) noexcept(
120*da0073e9SAndroid Build Coastguard Worker       std::is_nothrow_constructible_v<ArrayRef<T>, U&&> &&
121*da0073e9SAndroid Build Coastguard Worker       std::is_nothrow_assignable_v<ArrayRef<T>&, U&&>) {
122*da0073e9SAndroid Build Coastguard Worker     wrapped_opt_array_ref = std::forward<U>(value);
123*da0073e9SAndroid Build Coastguard Worker     return *this;
124*da0073e9SAndroid Build Coastguard Worker   }
125*da0073e9SAndroid Build Coastguard Worker 
126*da0073e9SAndroid Build Coastguard Worker   // Observers
127*da0073e9SAndroid Build Coastguard Worker 
128*da0073e9SAndroid Build Coastguard Worker   constexpr ArrayRef<T>* operator->() noexcept {
129*da0073e9SAndroid Build Coastguard Worker     return &wrapped_opt_array_ref.value();
130*da0073e9SAndroid Build Coastguard Worker   }
131*da0073e9SAndroid Build Coastguard Worker 
132*da0073e9SAndroid Build Coastguard Worker   constexpr const ArrayRef<T>* operator->() const noexcept {
133*da0073e9SAndroid Build Coastguard Worker     return &wrapped_opt_array_ref.value();
134*da0073e9SAndroid Build Coastguard Worker   }
135*da0073e9SAndroid Build Coastguard Worker 
136*da0073e9SAndroid Build Coastguard Worker   constexpr ArrayRef<T>& operator*() & noexcept {
137*da0073e9SAndroid Build Coastguard Worker     return wrapped_opt_array_ref.value();
138*da0073e9SAndroid Build Coastguard Worker   }
139*da0073e9SAndroid Build Coastguard Worker 
140*da0073e9SAndroid Build Coastguard Worker   constexpr const ArrayRef<T>& operator*() const& noexcept {
141*da0073e9SAndroid Build Coastguard Worker     return wrapped_opt_array_ref.value();
142*da0073e9SAndroid Build Coastguard Worker   }
143*da0073e9SAndroid Build Coastguard Worker 
144*da0073e9SAndroid Build Coastguard Worker   constexpr ArrayRef<T>&& operator*() && noexcept {
145*da0073e9SAndroid Build Coastguard Worker     return std::move(wrapped_opt_array_ref.value());
146*da0073e9SAndroid Build Coastguard Worker   }
147*da0073e9SAndroid Build Coastguard Worker 
148*da0073e9SAndroid Build Coastguard Worker   constexpr const ArrayRef<T>&& operator*() const&& noexcept {
149*da0073e9SAndroid Build Coastguard Worker     return std::move(wrapped_opt_array_ref.value());
150*da0073e9SAndroid Build Coastguard Worker   }
151*da0073e9SAndroid Build Coastguard Worker 
152*da0073e9SAndroid Build Coastguard Worker   constexpr explicit operator bool() const noexcept {
153*da0073e9SAndroid Build Coastguard Worker     return wrapped_opt_array_ref.has_value();
154*da0073e9SAndroid Build Coastguard Worker   }
155*da0073e9SAndroid Build Coastguard Worker 
has_value()156*da0073e9SAndroid Build Coastguard Worker   constexpr bool has_value() const noexcept {
157*da0073e9SAndroid Build Coastguard Worker     return wrapped_opt_array_ref.has_value();
158*da0073e9SAndroid Build Coastguard Worker   }
159*da0073e9SAndroid Build Coastguard Worker 
value()160*da0073e9SAndroid Build Coastguard Worker   constexpr ArrayRef<T>& value() & {
161*da0073e9SAndroid Build Coastguard Worker     return wrapped_opt_array_ref.value();
162*da0073e9SAndroid Build Coastguard Worker   }
163*da0073e9SAndroid Build Coastguard Worker 
value()164*da0073e9SAndroid Build Coastguard Worker   constexpr const ArrayRef<T>& value() const& {
165*da0073e9SAndroid Build Coastguard Worker     return wrapped_opt_array_ref.value();
166*da0073e9SAndroid Build Coastguard Worker   }
167*da0073e9SAndroid Build Coastguard Worker 
value()168*da0073e9SAndroid Build Coastguard Worker   constexpr ArrayRef<T>&& value() && {
169*da0073e9SAndroid Build Coastguard Worker     return std::move(wrapped_opt_array_ref.value());
170*da0073e9SAndroid Build Coastguard Worker   }
171*da0073e9SAndroid Build Coastguard Worker 
value()172*da0073e9SAndroid Build Coastguard Worker   constexpr const ArrayRef<T>&& value() const&& {
173*da0073e9SAndroid Build Coastguard Worker     return std::move(wrapped_opt_array_ref.value());
174*da0073e9SAndroid Build Coastguard Worker   }
175*da0073e9SAndroid Build Coastguard Worker 
176*da0073e9SAndroid Build Coastguard Worker   template <typename U>
177*da0073e9SAndroid Build Coastguard Worker   constexpr std::
178*da0073e9SAndroid Build Coastguard Worker       enable_if_t<std::is_convertible_v<U&&, ArrayRef<T>>, ArrayRef<T>>
value_or(U && default_value)179*da0073e9SAndroid Build Coastguard Worker       value_or(U&& default_value) const& {
180*da0073e9SAndroid Build Coastguard Worker     return wrapped_opt_array_ref.value_or(std::forward<U>(default_value));
181*da0073e9SAndroid Build Coastguard Worker   }
182*da0073e9SAndroid Build Coastguard Worker 
183*da0073e9SAndroid Build Coastguard Worker   template <typename U>
184*da0073e9SAndroid Build Coastguard Worker   constexpr std::
185*da0073e9SAndroid Build Coastguard Worker       enable_if_t<std::is_convertible_v<U&&, ArrayRef<T>>, ArrayRef<T>>
value_or(U && default_value)186*da0073e9SAndroid Build Coastguard Worker       value_or(U&& default_value) && {
187*da0073e9SAndroid Build Coastguard Worker     return wrapped_opt_array_ref.value_or(std::forward<U>(default_value));
188*da0073e9SAndroid Build Coastguard Worker   }
189*da0073e9SAndroid Build Coastguard Worker 
190*da0073e9SAndroid Build Coastguard Worker   // Modifiers
191*da0073e9SAndroid Build Coastguard Worker 
swap(OptionalArrayRef & other)192*da0073e9SAndroid Build Coastguard Worker   constexpr void swap(OptionalArrayRef& other) noexcept {
193*da0073e9SAndroid Build Coastguard Worker     std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref);
194*da0073e9SAndroid Build Coastguard Worker   }
195*da0073e9SAndroid Build Coastguard Worker 
reset()196*da0073e9SAndroid Build Coastguard Worker   constexpr void reset() noexcept {
197*da0073e9SAndroid Build Coastguard Worker     wrapped_opt_array_ref.reset();
198*da0073e9SAndroid Build Coastguard Worker   }
199*da0073e9SAndroid Build Coastguard Worker 
200*da0073e9SAndroid Build Coastguard Worker   template <typename... Args>
201*da0073e9SAndroid Build Coastguard Worker   constexpr std::
202*da0073e9SAndroid Build Coastguard Worker       enable_if_t<std::is_constructible_v<ArrayRef<T>, Args&&...>, ArrayRef<T>&>
emplace(Args &&...args)203*da0073e9SAndroid Build Coastguard Worker       emplace(Args&&... args) noexcept(
204*da0073e9SAndroid Build Coastguard Worker           std::is_nothrow_constructible_v<ArrayRef<T>, Args&&...>) {
205*da0073e9SAndroid Build Coastguard Worker     return wrapped_opt_array_ref.emplace(std::forward<Args>(args)...);
206*da0073e9SAndroid Build Coastguard Worker   }
207*da0073e9SAndroid Build Coastguard Worker 
208*da0073e9SAndroid Build Coastguard Worker   template <typename U, typename... Args>
emplace(std::initializer_list<U> il,Args &&...args)209*da0073e9SAndroid Build Coastguard Worker   constexpr ArrayRef<T>& emplace(
210*da0073e9SAndroid Build Coastguard Worker       std::initializer_list<U> il,
211*da0073e9SAndroid Build Coastguard Worker       Args&&... args) noexcept {
212*da0073e9SAndroid Build Coastguard Worker     return wrapped_opt_array_ref.emplace(il, std::forward<Args>(args)...);
213*da0073e9SAndroid Build Coastguard Worker   }
214*da0073e9SAndroid Build Coastguard Worker 
215*da0073e9SAndroid Build Coastguard Worker  private:
216*da0073e9SAndroid Build Coastguard Worker   std::optional<ArrayRef<T>> wrapped_opt_array_ref;
217*da0073e9SAndroid Build Coastguard Worker };
218*da0073e9SAndroid Build Coastguard Worker 
219*da0073e9SAndroid Build Coastguard Worker using OptionalIntArrayRef = OptionalArrayRef<int64_t>;
220*da0073e9SAndroid Build Coastguard Worker 
221*da0073e9SAndroid Build Coastguard Worker inline bool operator==(
222*da0073e9SAndroid Build Coastguard Worker     const OptionalIntArrayRef& a1,
223*da0073e9SAndroid Build Coastguard Worker     const IntArrayRef& other) {
224*da0073e9SAndroid Build Coastguard Worker   if (!a1.has_value()) {
225*da0073e9SAndroid Build Coastguard Worker     return false;
226*da0073e9SAndroid Build Coastguard Worker   }
227*da0073e9SAndroid Build Coastguard Worker   return a1.value() == other;
228*da0073e9SAndroid Build Coastguard Worker }
229*da0073e9SAndroid Build Coastguard Worker 
230*da0073e9SAndroid Build Coastguard Worker inline bool operator==(
231*da0073e9SAndroid Build Coastguard Worker     const c10::IntArrayRef& a1,
232*da0073e9SAndroid Build Coastguard Worker     const c10::OptionalIntArrayRef& a2) {
233*da0073e9SAndroid Build Coastguard Worker   return a2 == a1;
234*da0073e9SAndroid Build Coastguard Worker }
235*da0073e9SAndroid Build Coastguard Worker 
236*da0073e9SAndroid Build Coastguard Worker } // namespace c10
237