1 // This file defines OptionalArrayRef<T>, a class that has almost the same 2 // exact functionality as std::optional<ArrayRef<T>>, except that its 3 // converting constructor fixes a dangling pointer issue. 4 // 5 // The implicit converting constructor of both std::optional<ArrayRef<T>> and 6 // std::optional<ArrayRef<T>> can cause the underlying ArrayRef<T> to store 7 // a dangling pointer. OptionalArrayRef<T> prevents this by wrapping 8 // a std::optional<ArrayRef<T>> and fixing the constructor implementation. 9 // 10 // See https://github.com/pytorch/pytorch/issues/63645 for more on this. 11 12 #pragma once 13 14 #include <c10/util/ArrayRef.h> 15 #include <cstdint> 16 #include <initializer_list> 17 #include <optional> 18 #include <type_traits> 19 #include <utility> 20 21 namespace c10 { 22 23 template <typename T> 24 class OptionalArrayRef final { 25 public: 26 // Constructors 27 28 constexpr OptionalArrayRef() noexcept = default; 29 OptionalArrayRef(std::nullopt_t)30 constexpr OptionalArrayRef(std::nullopt_t) noexcept {} 31 32 OptionalArrayRef(const OptionalArrayRef& other) = default; 33 34 OptionalArrayRef(OptionalArrayRef&& other) noexcept = default; 35 OptionalArrayRef(const std::optional<ArrayRef<T>> & other)36 constexpr OptionalArrayRef(const std::optional<ArrayRef<T>>& other) noexcept 37 : wrapped_opt_array_ref(other) {} 38 OptionalArrayRef(std::optional<ArrayRef<T>> && other)39 constexpr OptionalArrayRef(std::optional<ArrayRef<T>>&& other) noexcept 40 : wrapped_opt_array_ref(std::move(other)) {} 41 OptionalArrayRef(const T & value)42 constexpr OptionalArrayRef(const T& value) noexcept 43 : wrapped_opt_array_ref(value) {} 44 45 template < 46 typename U = ArrayRef<T>, 47 std::enable_if_t< 48 !std::is_same_v<std::decay_t<U>, OptionalArrayRef> && 49 !std::is_same_v<std::decay_t<U>, std::in_place_t> && 50 std::is_constructible_v<ArrayRef<T>, U&&> && 51 std::is_convertible_v<U&&, ArrayRef<T>> && 52 !std::is_convertible_v<U&&, T>, 53 bool> = false> OptionalArrayRef(U && value)54 constexpr OptionalArrayRef(U&& value) noexcept( 55 std::is_nothrow_constructible_v<ArrayRef<T>, U&&>) 56 : wrapped_opt_array_ref(std::forward<U>(value)) {} 57 58 template < 59 typename U = ArrayRef<T>, 60 std::enable_if_t< 61 !std::is_same_v<std::decay_t<U>, OptionalArrayRef> && 62 !std::is_same_v<std::decay_t<U>, std::in_place_t> && 63 std::is_constructible_v<ArrayRef<T>, U&&> && 64 !std::is_convertible_v<U&&, ArrayRef<T>>, 65 bool> = false> OptionalArrayRef(U && value)66 constexpr explicit OptionalArrayRef(U&& value) noexcept( 67 std::is_nothrow_constructible_v<ArrayRef<T>, U&&>) 68 : wrapped_opt_array_ref(std::forward<U>(value)) {} 69 70 template <typename... Args> OptionalArrayRef(std::in_place_t ip,Args &&...args)71 constexpr explicit OptionalArrayRef( 72 std::in_place_t ip, 73 Args&&... args) noexcept 74 : wrapped_opt_array_ref(ip, std::forward<Args>(args)...) {} 75 76 template <typename U, typename... Args> OptionalArrayRef(std::in_place_t ip,std::initializer_list<U> il,Args &&...args)77 constexpr explicit OptionalArrayRef( 78 std::in_place_t ip, 79 std::initializer_list<U> il, 80 Args&&... args) 81 : wrapped_opt_array_ref(ip, il, std::forward<Args>(args)...) {} 82 OptionalArrayRef(const std::initializer_list<T> & Vec)83 constexpr OptionalArrayRef(const std::initializer_list<T>& Vec) 84 : wrapped_opt_array_ref(ArrayRef<T>(Vec)) {} 85 86 // Destructor 87 88 ~OptionalArrayRef() = default; 89 90 // Assignment 91 92 constexpr OptionalArrayRef& operator=(std::nullopt_t) noexcept { 93 wrapped_opt_array_ref = std::nullopt; 94 return *this; 95 } 96 97 OptionalArrayRef& operator=(const OptionalArrayRef& other) = default; 98 99 OptionalArrayRef& operator=(OptionalArrayRef&& other) noexcept = default; 100 101 constexpr OptionalArrayRef& operator=( 102 const std::optional<ArrayRef<T>>& other) noexcept { 103 wrapped_opt_array_ref = other; 104 return *this; 105 } 106 107 constexpr OptionalArrayRef& operator=( 108 std::optional<ArrayRef<T>>&& other) noexcept { 109 wrapped_opt_array_ref = std::move(other); 110 return *this; 111 } 112 113 template < 114 typename U = ArrayRef<T>, 115 typename = std::enable_if_t< 116 !std::is_same_v<std::decay_t<U>, OptionalArrayRef> && 117 std::is_constructible_v<ArrayRef<T>, U&&> && 118 std::is_assignable_v<ArrayRef<T>&, U&&>>> noexcept(std::is_nothrow_constructible_v<ArrayRef<T>,U &&> && std::is_nothrow_assignable_v<ArrayRef<T> &,U &&>)119 constexpr OptionalArrayRef& operator=(U&& value) noexcept( 120 std::is_nothrow_constructible_v<ArrayRef<T>, U&&> && 121 std::is_nothrow_assignable_v<ArrayRef<T>&, U&&>) { 122 wrapped_opt_array_ref = std::forward<U>(value); 123 return *this; 124 } 125 126 // Observers 127 128 constexpr ArrayRef<T>* operator->() noexcept { 129 return &wrapped_opt_array_ref.value(); 130 } 131 132 constexpr const ArrayRef<T>* operator->() const noexcept { 133 return &wrapped_opt_array_ref.value(); 134 } 135 136 constexpr ArrayRef<T>& operator*() & noexcept { 137 return wrapped_opt_array_ref.value(); 138 } 139 140 constexpr const ArrayRef<T>& operator*() const& noexcept { 141 return wrapped_opt_array_ref.value(); 142 } 143 144 constexpr ArrayRef<T>&& operator*() && noexcept { 145 return std::move(wrapped_opt_array_ref.value()); 146 } 147 148 constexpr const ArrayRef<T>&& operator*() const&& noexcept { 149 return std::move(wrapped_opt_array_ref.value()); 150 } 151 152 constexpr explicit operator bool() const noexcept { 153 return wrapped_opt_array_ref.has_value(); 154 } 155 has_value()156 constexpr bool has_value() const noexcept { 157 return wrapped_opt_array_ref.has_value(); 158 } 159 value()160 constexpr ArrayRef<T>& value() & { 161 return wrapped_opt_array_ref.value(); 162 } 163 value()164 constexpr const ArrayRef<T>& value() const& { 165 return wrapped_opt_array_ref.value(); 166 } 167 value()168 constexpr ArrayRef<T>&& value() && { 169 return std::move(wrapped_opt_array_ref.value()); 170 } 171 value()172 constexpr const ArrayRef<T>&& value() const&& { 173 return std::move(wrapped_opt_array_ref.value()); 174 } 175 176 template <typename U> 177 constexpr std:: 178 enable_if_t<std::is_convertible_v<U&&, ArrayRef<T>>, ArrayRef<T>> value_or(U && default_value)179 value_or(U&& default_value) const& { 180 return wrapped_opt_array_ref.value_or(std::forward<U>(default_value)); 181 } 182 183 template <typename U> 184 constexpr std:: 185 enable_if_t<std::is_convertible_v<U&&, ArrayRef<T>>, ArrayRef<T>> value_or(U && default_value)186 value_or(U&& default_value) && { 187 return wrapped_opt_array_ref.value_or(std::forward<U>(default_value)); 188 } 189 190 // Modifiers 191 swap(OptionalArrayRef & other)192 constexpr void swap(OptionalArrayRef& other) noexcept { 193 std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref); 194 } 195 reset()196 constexpr void reset() noexcept { 197 wrapped_opt_array_ref.reset(); 198 } 199 200 template <typename... Args> 201 constexpr std:: 202 enable_if_t<std::is_constructible_v<ArrayRef<T>, Args&&...>, ArrayRef<T>&> emplace(Args &&...args)203 emplace(Args&&... args) noexcept( 204 std::is_nothrow_constructible_v<ArrayRef<T>, Args&&...>) { 205 return wrapped_opt_array_ref.emplace(std::forward<Args>(args)...); 206 } 207 208 template <typename U, typename... Args> emplace(std::initializer_list<U> il,Args &&...args)209 constexpr ArrayRef<T>& emplace( 210 std::initializer_list<U> il, 211 Args&&... args) noexcept { 212 return wrapped_opt_array_ref.emplace(il, std::forward<Args>(args)...); 213 } 214 215 private: 216 std::optional<ArrayRef<T>> wrapped_opt_array_ref; 217 }; 218 219 using OptionalIntArrayRef = OptionalArrayRef<int64_t>; 220 221 inline bool operator==( 222 const OptionalIntArrayRef& a1, 223 const IntArrayRef& other) { 224 if (!a1.has_value()) { 225 return false; 226 } 227 return a1.value() == other; 228 } 229 230 inline bool operator==( 231 const c10::IntArrayRef& a1, 232 const c10::OptionalIntArrayRef& a2) { 233 return a2 == a1; 234 } 235 236 } // namespace c10 237