xref: /aosp_15_r20/external/pytorch/c10/util/OptionalArrayRef.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // This file defines OptionalArrayRef<T>, a class that has almost the same
2 // exact functionality as std::optional<ArrayRef<T>>, except that its
3 // converting constructor fixes a dangling pointer issue.
4 //
5 // The implicit converting constructor of both std::optional<ArrayRef<T>> and
6 // std::optional<ArrayRef<T>> can cause the underlying ArrayRef<T> to store
7 // a dangling pointer. OptionalArrayRef<T> prevents this by wrapping
8 // a std::optional<ArrayRef<T>> and fixing the constructor implementation.
9 //
10 // See https://github.com/pytorch/pytorch/issues/63645 for more on this.
11 
12 #pragma once
13 
14 #include <c10/util/ArrayRef.h>
15 #include <cstdint>
16 #include <initializer_list>
17 #include <optional>
18 #include <type_traits>
19 #include <utility>
20 
21 namespace c10 {
22 
23 template <typename T>
24 class OptionalArrayRef final {
25  public:
26   // Constructors
27 
28   constexpr OptionalArrayRef() noexcept = default;
29 
OptionalArrayRef(std::nullopt_t)30   constexpr OptionalArrayRef(std::nullopt_t) noexcept {}
31 
32   OptionalArrayRef(const OptionalArrayRef& other) = default;
33 
34   OptionalArrayRef(OptionalArrayRef&& other) noexcept = default;
35 
OptionalArrayRef(const std::optional<ArrayRef<T>> & other)36   constexpr OptionalArrayRef(const std::optional<ArrayRef<T>>& other) noexcept
37       : wrapped_opt_array_ref(other) {}
38 
OptionalArrayRef(std::optional<ArrayRef<T>> && other)39   constexpr OptionalArrayRef(std::optional<ArrayRef<T>>&& other) noexcept
40       : wrapped_opt_array_ref(std::move(other)) {}
41 
OptionalArrayRef(const T & value)42   constexpr OptionalArrayRef(const T& value) noexcept
43       : wrapped_opt_array_ref(value) {}
44 
45   template <
46       typename U = ArrayRef<T>,
47       std::enable_if_t<
48           !std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
49               !std::is_same_v<std::decay_t<U>, std::in_place_t> &&
50               std::is_constructible_v<ArrayRef<T>, U&&> &&
51               std::is_convertible_v<U&&, ArrayRef<T>> &&
52               !std::is_convertible_v<U&&, T>,
53           bool> = false>
OptionalArrayRef(U && value)54   constexpr OptionalArrayRef(U&& value) noexcept(
55       std::is_nothrow_constructible_v<ArrayRef<T>, U&&>)
56       : wrapped_opt_array_ref(std::forward<U>(value)) {}
57 
58   template <
59       typename U = ArrayRef<T>,
60       std::enable_if_t<
61           !std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
62               !std::is_same_v<std::decay_t<U>, std::in_place_t> &&
63               std::is_constructible_v<ArrayRef<T>, U&&> &&
64               !std::is_convertible_v<U&&, ArrayRef<T>>,
65           bool> = false>
OptionalArrayRef(U && value)66   constexpr explicit OptionalArrayRef(U&& value) noexcept(
67       std::is_nothrow_constructible_v<ArrayRef<T>, U&&>)
68       : wrapped_opt_array_ref(std::forward<U>(value)) {}
69 
70   template <typename... Args>
OptionalArrayRef(std::in_place_t ip,Args &&...args)71   constexpr explicit OptionalArrayRef(
72       std::in_place_t ip,
73       Args&&... args) noexcept
74       : wrapped_opt_array_ref(ip, std::forward<Args>(args)...) {}
75 
76   template <typename U, typename... Args>
OptionalArrayRef(std::in_place_t ip,std::initializer_list<U> il,Args &&...args)77   constexpr explicit OptionalArrayRef(
78       std::in_place_t ip,
79       std::initializer_list<U> il,
80       Args&&... args)
81       : wrapped_opt_array_ref(ip, il, std::forward<Args>(args)...) {}
82 
OptionalArrayRef(const std::initializer_list<T> & Vec)83   constexpr OptionalArrayRef(const std::initializer_list<T>& Vec)
84       : wrapped_opt_array_ref(ArrayRef<T>(Vec)) {}
85 
86   // Destructor
87 
88   ~OptionalArrayRef() = default;
89 
90   // Assignment
91 
92   constexpr OptionalArrayRef& operator=(std::nullopt_t) noexcept {
93     wrapped_opt_array_ref = std::nullopt;
94     return *this;
95   }
96 
97   OptionalArrayRef& operator=(const OptionalArrayRef& other) = default;
98 
99   OptionalArrayRef& operator=(OptionalArrayRef&& other) noexcept = default;
100 
101   constexpr OptionalArrayRef& operator=(
102       const std::optional<ArrayRef<T>>& other) noexcept {
103     wrapped_opt_array_ref = other;
104     return *this;
105   }
106 
107   constexpr OptionalArrayRef& operator=(
108       std::optional<ArrayRef<T>>&& other) noexcept {
109     wrapped_opt_array_ref = std::move(other);
110     return *this;
111   }
112 
113   template <
114       typename U = ArrayRef<T>,
115       typename = std::enable_if_t<
116           !std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
117           std::is_constructible_v<ArrayRef<T>, U&&> &&
118           std::is_assignable_v<ArrayRef<T>&, U&&>>>
noexcept(std::is_nothrow_constructible_v<ArrayRef<T>,U &&> && std::is_nothrow_assignable_v<ArrayRef<T> &,U &&>)119   constexpr OptionalArrayRef& operator=(U&& value) noexcept(
120       std::is_nothrow_constructible_v<ArrayRef<T>, U&&> &&
121       std::is_nothrow_assignable_v<ArrayRef<T>&, U&&>) {
122     wrapped_opt_array_ref = std::forward<U>(value);
123     return *this;
124   }
125 
126   // Observers
127 
128   constexpr ArrayRef<T>* operator->() noexcept {
129     return &wrapped_opt_array_ref.value();
130   }
131 
132   constexpr const ArrayRef<T>* operator->() const noexcept {
133     return &wrapped_opt_array_ref.value();
134   }
135 
136   constexpr ArrayRef<T>& operator*() & noexcept {
137     return wrapped_opt_array_ref.value();
138   }
139 
140   constexpr const ArrayRef<T>& operator*() const& noexcept {
141     return wrapped_opt_array_ref.value();
142   }
143 
144   constexpr ArrayRef<T>&& operator*() && noexcept {
145     return std::move(wrapped_opt_array_ref.value());
146   }
147 
148   constexpr const ArrayRef<T>&& operator*() const&& noexcept {
149     return std::move(wrapped_opt_array_ref.value());
150   }
151 
152   constexpr explicit operator bool() const noexcept {
153     return wrapped_opt_array_ref.has_value();
154   }
155 
has_value()156   constexpr bool has_value() const noexcept {
157     return wrapped_opt_array_ref.has_value();
158   }
159 
value()160   constexpr ArrayRef<T>& value() & {
161     return wrapped_opt_array_ref.value();
162   }
163 
value()164   constexpr const ArrayRef<T>& value() const& {
165     return wrapped_opt_array_ref.value();
166   }
167 
value()168   constexpr ArrayRef<T>&& value() && {
169     return std::move(wrapped_opt_array_ref.value());
170   }
171 
value()172   constexpr const ArrayRef<T>&& value() const&& {
173     return std::move(wrapped_opt_array_ref.value());
174   }
175 
176   template <typename U>
177   constexpr std::
178       enable_if_t<std::is_convertible_v<U&&, ArrayRef<T>>, ArrayRef<T>>
value_or(U && default_value)179       value_or(U&& default_value) const& {
180     return wrapped_opt_array_ref.value_or(std::forward<U>(default_value));
181   }
182 
183   template <typename U>
184   constexpr std::
185       enable_if_t<std::is_convertible_v<U&&, ArrayRef<T>>, ArrayRef<T>>
value_or(U && default_value)186       value_or(U&& default_value) && {
187     return wrapped_opt_array_ref.value_or(std::forward<U>(default_value));
188   }
189 
190   // Modifiers
191 
swap(OptionalArrayRef & other)192   constexpr void swap(OptionalArrayRef& other) noexcept {
193     std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref);
194   }
195 
reset()196   constexpr void reset() noexcept {
197     wrapped_opt_array_ref.reset();
198   }
199 
200   template <typename... Args>
201   constexpr std::
202       enable_if_t<std::is_constructible_v<ArrayRef<T>, Args&&...>, ArrayRef<T>&>
emplace(Args &&...args)203       emplace(Args&&... args) noexcept(
204           std::is_nothrow_constructible_v<ArrayRef<T>, Args&&...>) {
205     return wrapped_opt_array_ref.emplace(std::forward<Args>(args)...);
206   }
207 
208   template <typename U, typename... Args>
emplace(std::initializer_list<U> il,Args &&...args)209   constexpr ArrayRef<T>& emplace(
210       std::initializer_list<U> il,
211       Args&&... args) noexcept {
212     return wrapped_opt_array_ref.emplace(il, std::forward<Args>(args)...);
213   }
214 
215  private:
216   std::optional<ArrayRef<T>> wrapped_opt_array_ref;
217 };
218 
219 using OptionalIntArrayRef = OptionalArrayRef<int64_t>;
220 
221 inline bool operator==(
222     const OptionalIntArrayRef& a1,
223     const IntArrayRef& other) {
224   if (!a1.has_value()) {
225     return false;
226   }
227   return a1.value() == other;
228 }
229 
230 inline bool operator==(
231     const c10::IntArrayRef& a1,
232     const c10::OptionalIntArrayRef& a2) {
233   return a2 == a1;
234 }
235 
236 } // namespace c10
237