xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/StridedRandomAccessor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 namespace at::native {
4 
// (Const)StridedRandomAccessor is a (const) random-access iterator
// over a strided array: each step moves the underlying pointer by
// `stride` elements instead of by one.
8 
// The traits below select the raw pointer type used by the accessors,
// optionally adding the non-standard restrict qualifier, whose spelling
// differs between platforms.
11 
// Pointer traits that yield an unqualified raw pointer to Elem.
template <class Elem>
struct DefaultPtrTraits {
  using PtrType = Elem*;
};
16 
// Spelling of the non-standard restrict qualifier: `__restrict__` unless
// _WIN32 or _WIN64 is defined, in which case `__restrict` is used.
#if !defined(_WIN32) && !defined(_WIN64)
#define RESTRICT __restrict__
#else
#define RESTRICT __restrict
#endif
22 
23 template <typename T>
24 struct RestrictPtrTraits {
25   using PtrType = T* RESTRICT;
26 };
27 
28 template <
29   typename T,
30   typename index_t = int64_t,
31   template <typename U> class PtrTraits = DefaultPtrTraits
32 >
33 class ConstStridedRandomAccessor {
34 public:
35   using difference_type = index_t;
36   using value_type = const T;
37   using pointer = const typename PtrTraits<T>::PtrType;
38   using reference = const value_type&;
39   using iterator_category = std::random_access_iterator_tag;
40 
41   using PtrType = typename PtrTraits<T>::PtrType;
42   using index_type = index_t;
43 
44   // Constructors {
45   C10_HOST_DEVICE
ConstStridedRandomAccessor(PtrType ptr,index_t stride)46   ConstStridedRandomAccessor(PtrType ptr, index_t stride)
47     : ptr{ptr}, stride{stride}
48   {}
49 
50   C10_HOST_DEVICE
ConstStridedRandomAccessor(PtrType ptr)51   explicit ConstStridedRandomAccessor(PtrType ptr)
52     : ptr{ptr}, stride{static_cast<index_t>(1)}
53   {}
54 
55   C10_HOST_DEVICE
ConstStridedRandomAccessor()56   ConstStridedRandomAccessor()
57     : ptr{nullptr}, stride{static_cast<index_t>(1)}
58   {}
59   // }
60 
61   // Pointer-like operations {
62   C10_HOST_DEVICE
63   reference operator*() const {
64     return *ptr;
65   }
66 
67   C10_HOST_DEVICE
68   const value_type* operator->() const {
69     return reinterpret_cast<const value_type*>(ptr);
70   }
71 
72   C10_HOST_DEVICE
73   reference operator[](index_t idx) const {
74     return ptr[idx * stride];
75   }
76   // }
77 
78   // Prefix/postfix increment/decrement {
79   C10_HOST_DEVICE
80   ConstStridedRandomAccessor& operator++() {
81     ptr += stride;
82     return *this;
83   }
84 
85   C10_HOST_DEVICE
86   ConstStridedRandomAccessor operator++(int) {
87     ConstStridedRandomAccessor copy(*this);
88     ++*this;
89     return copy;
90   }
91 
92   C10_HOST_DEVICE
93   ConstStridedRandomAccessor& operator--() {
94     ptr -= stride;
95     return *this;
96   }
97 
98   C10_HOST_DEVICE
99   ConstStridedRandomAccessor operator--(int) {
100     ConstStridedRandomAccessor copy(*this);
101     --*this;
102     return copy;
103   }
104   // }
105 
106   // Arithmetic operations {
107   C10_HOST_DEVICE
108   ConstStridedRandomAccessor& operator+=(index_t offset) {
109     ptr += offset * stride;
110     return *this;
111   }
112 
113   C10_HOST_DEVICE
114   ConstStridedRandomAccessor operator+(index_t offset) const {
115     return ConstStridedRandomAccessor(ptr + offset * stride, stride);
116   }
117 
118   C10_HOST_DEVICE
119   friend ConstStridedRandomAccessor operator+(
120     index_t offset,
121     const ConstStridedRandomAccessor& accessor
122   ) {
123     return accessor + offset;
124   }
125 
126   C10_HOST_DEVICE
127   ConstStridedRandomAccessor& operator-=(index_t offset) {
128     ptr -= offset * stride;
129     return *this;
130   }
131 
132   C10_HOST_DEVICE
133   ConstStridedRandomAccessor operator-(index_t offset) const {
134     return ConstStridedRandomAccessor(ptr - offset * stride, stride);
135   }
136 
137   // Note that this operator is well-defined when `this` and `other`
138   // represent the same sequences, i.e. when
139   // 1. this.stride == other.stride,
140   // 2. |other - this| / this.stride is an Integer.
141   C10_HOST_DEVICE
142   difference_type operator-(const ConstStridedRandomAccessor& other) const {
143     return (ptr - other.ptr) / stride;
144   }
145   // }
146 
147   // Comparison operators {
148   C10_HOST_DEVICE
149   bool operator==(const ConstStridedRandomAccessor& other) const {
150     return (ptr == other.ptr) && (stride == other.stride);
151   }
152 
153   C10_HOST_DEVICE
154   bool operator!=(const ConstStridedRandomAccessor& other) const {
155     return !(*this == other);
156   }
157 
158   C10_HOST_DEVICE
159   bool operator<(const ConstStridedRandomAccessor& other) const {
160     return ptr < other.ptr;
161   }
162 
163   C10_HOST_DEVICE
164   bool operator<=(const ConstStridedRandomAccessor& other) const {
165     return (*this < other) || (*this == other);
166   }
167 
168   C10_HOST_DEVICE
169   bool operator>(const ConstStridedRandomAccessor& other) const {
170     return !(*this <= other);
171   }
172 
173   C10_HOST_DEVICE
174   bool operator>=(const ConstStridedRandomAccessor& other) const {
175     return !(*this < other);
176   }
177   // }
178 
179 protected:
180   PtrType ptr;
181   index_t stride;
182 };
183 
184 template <
185   typename T,
186   typename index_t = int64_t,
187   template <typename U> class PtrTraits = DefaultPtrTraits
188 >
189 class StridedRandomAccessor
190   : public ConstStridedRandomAccessor<T, index_t, PtrTraits> {
191 public:
192   using difference_type = index_t;
193   using value_type = T;
194   using pointer = typename PtrTraits<T>::PtrType;
195   using reference = value_type&;
196 
197   using BaseType = ConstStridedRandomAccessor<T, index_t, PtrTraits>;
198   using PtrType = typename PtrTraits<T>::PtrType;
199 
200   // Constructors {
201   C10_HOST_DEVICE
StridedRandomAccessor(PtrType ptr,index_t stride)202   StridedRandomAccessor(PtrType ptr, index_t stride)
203     : BaseType(ptr, stride)
204   {}
205 
206   C10_HOST_DEVICE
StridedRandomAccessor(PtrType ptr)207   explicit StridedRandomAccessor(PtrType ptr)
208     : BaseType(ptr)
209   {}
210 
211   C10_HOST_DEVICE
StridedRandomAccessor()212   StridedRandomAccessor()
213     : BaseType()
214   {}
215   // }
216 
217   // Pointer-like operations {
218   C10_HOST_DEVICE
219   reference operator*() const {
220     return *this->ptr;
221   }
222 
223   C10_HOST_DEVICE
224   value_type* operator->() const {
225     return reinterpret_cast<value_type*>(this->ptr);
226   }
227 
228   C10_HOST_DEVICE
229   reference operator[](index_t idx) const {
230     return this->ptr[idx * this->stride];
231   }
232   // }
233 
234   // Prefix/postfix increment/decrement {
235   C10_HOST_DEVICE
236   StridedRandomAccessor& operator++() {
237     this->ptr += this->stride;
238     return *this;
239   }
240 
241   C10_HOST_DEVICE
242   StridedRandomAccessor operator++(int) {
243     StridedRandomAccessor copy(*this);
244     ++*this;
245     return copy;
246   }
247 
248   C10_HOST_DEVICE
249   StridedRandomAccessor& operator--() {
250     this->ptr -= this->stride;
251     return *this;
252   }
253 
254   C10_HOST_DEVICE
255   StridedRandomAccessor operator--(int) {
256     StridedRandomAccessor copy(*this);
257     --*this;
258     return copy;
259   }
260   // }
261 
262   // Arithmetic operations {
263   C10_HOST_DEVICE
264   StridedRandomAccessor& operator+=(index_t offset) {
265     this->ptr += offset * this->stride;
266     return *this;
267   }
268 
269   C10_HOST_DEVICE
270   StridedRandomAccessor operator+(index_t offset) const {
271     return StridedRandomAccessor(this->ptr + offset * this->stride, this->stride);
272   }
273 
274   C10_HOST_DEVICE
275   friend StridedRandomAccessor operator+(
276     index_t offset,
277     const StridedRandomAccessor& accessor
278   ) {
279     return accessor + offset;
280   }
281 
282   C10_HOST_DEVICE
283   StridedRandomAccessor& operator-=(index_t offset) {
284     this->ptr -= offset * this->stride;
285     return *this;
286   }
287 
288   C10_HOST_DEVICE
289   StridedRandomAccessor operator-(index_t offset) const {
290     return StridedRandomAccessor(this->ptr - offset * this->stride, this->stride);
291   }
292 
293   // Note that here we call BaseType::operator- version
294   C10_HOST_DEVICE
295   difference_type operator-(const BaseType& other) const {
296     return (static_cast<const BaseType&>(*this) - other);
297   }
298   // }
299 };
300 
301 } // namespace at::native
302