#pragma once

#include <cstdint>
#include <iterator>

// NOTE: every member function below is annotated with C10_HOST_DEVICE. This
// header does not define that macro; it must already be defined by whoever
// includes this file.
// NOTE(review): in PyTorch the macro presumably comes from
// c10/macros/Macros.h -- consider including it here so the header is fully
// self-contained.

namespace at::native {

// (Const)StridedRandomAccessor is a (const) random access iterator defined
// over a strided array: the i-th element of the sequence is `ptr[i * stride]`.

// The traits below select the pointer type stored by the accessors. They exist
// so that the `restrict` qualifier (spelled differently on different
// platforms) can be introduced on the stored pointer.

// Plain pointer, no aliasing promise.
template <typename T>
struct DefaultPtrTraits {
  using PtrType = T*;
};

// Platform-specific spelling of the `restrict` qualifier.
#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif

// Pointer carrying the `restrict` qualifier.
template <typename T>
struct RestrictPtrTraits {
  using PtrType = T* RESTRICT;
};

// Read-only random access iterator over a strided array.
//
// Template parameters:
//   T         - element type.
//   index_t   - type of the stride, of offsets and of iterator differences.
//   PtrTraits - trait providing `PtrType`, the stored pointer type.
//
// The accessor stores a pointer to the current element and the distance
// (in elements of T) between two consecutive elements of the sequence.
template <
  typename T,
  typename index_t = int64_t,
  template <typename U> class PtrTraits = DefaultPtrTraits
>
class ConstStridedRandomAccessor {
public:
  using difference_type = index_t;
  using value_type = const T;
  // NOTE(review): this alias is `T* const` (a const pointer), not a pointer to
  // const T; operator-> below returns `const value_type*`. Kept unchanged for
  // interface compatibility.
  using pointer = const typename PtrTraits<T>::PtrType;
  using reference = const value_type&;
  using iterator_category = std::random_access_iterator_tag;

  using PtrType = typename PtrTraits<T>::PtrType;
  using index_type = index_t;

  // Constructors {
  // Accessor positioned at `ptr` that advances by `stride` elements per step.
  C10_HOST_DEVICE
  ConstStridedRandomAccessor(PtrType ptr, index_t stride)
    : ptr{ptr}, stride{stride}
  {}

  // Accessor positioned at `ptr` with a stride of 1.
  C10_HOST_DEVICE
  explicit ConstStridedRandomAccessor(PtrType ptr)
    : ptr{ptr}, stride{static_cast<index_t>(1)}
  {}

  // Null accessor with a stride of 1.
  C10_HOST_DEVICE
  ConstStridedRandomAccessor()
    : ptr{nullptr}, stride{static_cast<index_t>(1)}
  {}
  // }

  // Pointer-like operations {
  // Current element.
  C10_HOST_DEVICE
  reference operator*() const {
    return *ptr;
  }

  // Address of the current element.
  C10_HOST_DEVICE
  const value_type* operator->() const {
    // Adding const is an implicit qualification conversion; no cast needed.
    return ptr;
  }

  // Element `idx` steps away from the current one.
  C10_HOST_DEVICE
  reference operator[](index_t idx) const {
    return ptr[idx * stride];
  }
  // }

  // Prefix/postfix increment/decrement {
  C10_HOST_DEVICE
  ConstStridedRandomAccessor& operator++() {
    ptr += stride;
    return *this;
  }

  C10_HOST_DEVICE
  ConstStridedRandomAccessor operator++(int) {
    ConstStridedRandomAccessor copy(*this);
    ++*this;
    return copy;
  }

  C10_HOST_DEVICE
  ConstStridedRandomAccessor& operator--() {
    ptr -= stride;
    return *this;
  }

  C10_HOST_DEVICE
  ConstStridedRandomAccessor operator--(int) {
    ConstStridedRandomAccessor copy(*this);
    --*this;
    return copy;
  }
  // }

  // Arithmetic operations {
  // All offsets are expressed in sequence steps, i.e. they are multiplied by
  // the stride before being applied to the pointer.
  C10_HOST_DEVICE
  ConstStridedRandomAccessor& operator+=(index_t offset) {
    ptr += offset * stride;
    return *this;
  }

  C10_HOST_DEVICE
  ConstStridedRandomAccessor operator+(index_t offset) const {
    return ConstStridedRandomAccessor(ptr + offset * stride, stride);
  }

  C10_HOST_DEVICE
  friend ConstStridedRandomAccessor operator+(
    index_t offset,
    const ConstStridedRandomAccessor& accessor
  ) {
    return accessor + offset;
  }

  C10_HOST_DEVICE
  ConstStridedRandomAccessor& operator-=(index_t offset) {
    ptr -= offset * stride;
    return *this;
  }

  C10_HOST_DEVICE
  ConstStridedRandomAccessor operator-(index_t offset) const {
    return ConstStridedRandomAccessor(ptr - offset * stride, stride);
  }

  // Number of steps between `other` and `this`.
  // Note that this operator is well-defined only when `this` and `other`
  // represent the same sequence, i.e. when
  // 1. this.stride == other.stride,
  // 2. |other - this| / this.stride is an integer.
  C10_HOST_DEVICE
  difference_type operator-(const ConstStridedRandomAccessor& other) const {
    return static_cast<difference_type>((ptr - other.ptr) / stride);
  }
  // }

  // Comparison operators {
  // Equality requires both the same position and the same stride.
  C10_HOST_DEVICE
  bool operator==(const ConstStridedRandomAccessor& other) const {
    return (ptr == other.ptr) && (stride == other.stride);
  }

  C10_HOST_DEVICE
  bool operator!=(const ConstStridedRandomAccessor& other) const {
    return !(*this == other);
  }

  // Ordering compares the stored pointers only. The remaining ordering
  // operators are all derived from operator< so that they stay mutually
  // consistent even when the strides differ (`a > b` and `b > a` can never
  // both hold).
  C10_HOST_DEVICE
  bool operator<(const ConstStridedRandomAccessor& other) const {
    return ptr < other.ptr;
  }

  C10_HOST_DEVICE
  bool operator<=(const ConstStridedRandomAccessor& other) const {
    return !(other < *this);
  }

  C10_HOST_DEVICE
  bool operator>(const ConstStridedRandomAccessor& other) const {
    return other < *this;
  }

  C10_HOST_DEVICE
  bool operator>=(const ConstStridedRandomAccessor& other) const {
    return !(*this < other);
  }
  // }

protected:
  PtrType ptr;     // current element
  index_t stride;  // distance, in elements of T, between consecutive elements
};

// Mutable counterpart of ConstStridedRandomAccessor.
//
// It reuses the storage and the comparison operators of the const accessor and
// redeclares the remaining operations so that they yield mutable references
// and StridedRandomAccessor objects instead of the const base type.
template <
  typename T,
  typename index_t = int64_t,
  template <typename U> class PtrTraits = DefaultPtrTraits
>
class StridedRandomAccessor
  : public ConstStridedRandomAccessor<T, index_t, PtrTraits> {
public:
  using difference_type = index_t;
  using value_type = T;
  using pointer = typename PtrTraits<T>::PtrType;
  using reference = value_type&;

  using BaseType = ConstStridedRandomAccessor<T, index_t, PtrTraits>;
  using PtrType = typename PtrTraits<T>::PtrType;

  // Constructors {
  // Same meaning as the corresponding base class constructors.
  C10_HOST_DEVICE
  StridedRandomAccessor(PtrType ptr, index_t stride)
    : BaseType(ptr, stride)
  {}

  C10_HOST_DEVICE
  explicit StridedRandomAccessor(PtrType ptr)
    : BaseType(ptr)
  {}

  C10_HOST_DEVICE
  StridedRandomAccessor()
    : BaseType()
  {}
  // }

  // Pointer-like operations {
  C10_HOST_DEVICE
  reference operator*() const {
    return *this->ptr;
  }

  C10_HOST_DEVICE
  value_type* operator->() const {
    return this->ptr;
  }

  C10_HOST_DEVICE
  reference operator[](index_t idx) const {
    return this->ptr[idx * this->stride];
  }
  // }

  // Prefix/postfix increment/decrement {
  C10_HOST_DEVICE
  StridedRandomAccessor& operator++() {
    this->ptr += this->stride;
    return *this;
  }

  C10_HOST_DEVICE
  StridedRandomAccessor operator++(int) {
    StridedRandomAccessor copy(*this);
    ++*this;
    return copy;
  }

  C10_HOST_DEVICE
  StridedRandomAccessor& operator--() {
    this->ptr -= this->stride;
    return *this;
  }

  C10_HOST_DEVICE
  StridedRandomAccessor operator--(int) {
    StridedRandomAccessor copy(*this);
    --*this;
    return copy;
  }
  // }

  // Arithmetic operations {
  C10_HOST_DEVICE
  StridedRandomAccessor& operator+=(index_t offset) {
    this->ptr += offset * this->stride;
    return *this;
  }

  C10_HOST_DEVICE
  StridedRandomAccessor operator+(index_t offset) const {
    return StridedRandomAccessor(this->ptr + offset * this->stride, this->stride);
  }

  C10_HOST_DEVICE
  friend StridedRandomAccessor operator+(
    index_t offset,
    const StridedRandomAccessor& accessor
  ) {
    return accessor + offset;
  }

  C10_HOST_DEVICE
  StridedRandomAccessor& operator-=(index_t offset) {
    this->ptr -= offset * this->stride;
    return *this;
  }

  C10_HOST_DEVICE
  StridedRandomAccessor operator-(index_t offset) const {
    return StridedRandomAccessor(this->ptr - offset * this->stride, this->stride);
  }

  // Note that here we call BaseType::operator- version
  C10_HOST_DEVICE
  difference_type operator-(const BaseType& other) const {
    return (static_cast<const BaseType&>(*this) - other);
  }
  // }
};

} // namespace at::native