xref: /aosp_15_r20/external/ComputeLibrary/arm_compute/core/WindowIterator.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2018-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_WINDOW_ITERATOR_H
25 #define ARM_COMPUTE_WINDOW_ITERATOR_H
#include "arm_compute/core/Coordinates.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/Window.h"

#include <utility>
30 
31 
32 namespace arm_compute
33 {
34 /** Convert an offset in window steps into absolute coordinates.
35  *
36  * @param[in] w      Window @p offset is related to.
37  * @param[in] offset Offset inside the window expressed in number of window steps.
38  *
39  * @return Absolute coordinates.
40  */
convert_window_coord_to_position(const Window & w,const Coordinates & offset)41 inline Coordinates convert_window_coord_to_position(const Window &w, const Coordinates &offset)
42 {
43     Coordinates position;
44     for(unsigned int i = 0; i < Coordinates::num_max_dimensions; ++i)
45     {
46         position.set(i, w[i].start() + offset[i] * w[i].step());
47     }
48     return position;
49 }
50 
51 /** Tensor accessors to make it easier to interface with arm_gemm */
52 template <typename T>
53 class TensorAccessor
54 {
55 public:
56     /** Constructor:
57      *
58      * @param[in] tensor Source tensor, must be allocated.
59      */
TensorAccessor(const ITensor & tensor)60     TensorAccessor(const ITensor &tensor)
61         : _first(tensor.ptr_to_element(Coordinates())), _strides(tensor.info()->strides_in_bytes())
62     {
63     }
64     /** Get the stride of the dimension dim expressed in number of Ts.
65      *
66      * @param[in] dim Dimension of the wanted stride.
67      *
68      * @return Stride in number of Ts.
69      */
stride(size_t dim)70     inline size_t stride(size_t dim) const
71     {
72         ARM_COMPUTE_ERROR_ON(_strides[dim] % sizeof(T) != 0);
73         return _strides[dim] / sizeof(T);
74     }
75 
76     /** Manually set the stride of a dimension
77      *
78      * @param[in] dim  Dimension of the stride to set.
79      * @param[in] size Value to set the stride to (in bytes).
80      */
set_stride(size_t dim,size_t size)81     void set_stride(size_t dim, size_t size)
82     {
83         _strides[dim] = size;
84     }
85 
86     /** Manually set the strides
87      *
88      * @param[in] strides Strides to set
89      */
set_strides(const Strides & strides)90     void set_strides(const Strides &strides)
91     {
92         _strides = strides;
93     }
94 
95     /** Returns a pointer to the element at coordinates (x,y,z,w)
96      *
97      * @param[in] x X coordinates
98      * @param[in] y (optional) Y coordinates
99      * @param[in] z (optional) Z coordinates
100      * @param[in] w (optional) W coordinates
101      */
102     inline T *get_ptr(unsigned int x, unsigned int y = 0, unsigned int z = 0, unsigned int w = 0)
103     {
104         return reinterpret_cast<T *>(_first + x * _strides[0] + y * _strides[1] + z * _strides[2] + w * _strides[3]);
105     }
106 
107     /** Returns a pointer to the element at coordinates (x,y,z,w)
108      *
109      * @param[in] x X coordinates
110      * @param[in] y (optional) Y coordinates
111      * @param[in] z (optional) Z coordinates
112      * @param[in] w (optional) W coordinates
113      */
operator()114     inline T *operator()(unsigned int x, unsigned int y = 0, unsigned int z = 0, unsigned int w = 0)
115     {
116         return get_ptr(x, y, z, w);
117     }
118 
119     /** Returns a pointer to the first element of the tensor
120      *
121      * @return Pointer to the first element.
122      */
first_element()123     inline T *first_element()
124     {
125         return reinterpret_cast<T *>(_first);
126     }
127 
128     /** Returns a pointer to the first element of the tensor
129      *
130      * @return Pointer to the first element.
131      */
operator()132     inline T *operator()()
133     {
134         return first_element();
135     }
136 
137 private:
138     uint8_t *_first;   /**< Pointer to the first element of the tensor.*/
139     Strides  _strides; /**< Strides in bytes of the tensor */
140 };
141 
142 /** Iterate over a portion of a Window */
143 template <typename L>
144 class WindowIterator
145 {
146 public:
147     /** Construct a WindowIterator object
148      *
149      * @param[in] w               Window to use for the iteration
150      * @param[in] start           Where to start iterating from (In Window coordinates)
151      * @param[in] end             Where to stop iterating (In Window coordinates).
152      * @param[in] lambda_function Lambda function to call for every iteration between start and end. (It will be called last for end - 1)
153      */
WindowIterator(const Window & w,const Coordinates & start,const Coordinates & end,L && lambda_function)154     WindowIterator(const Window &w, const Coordinates &start, const Coordinates &end, L &&lambda_function)
155         : _lambda_function(std::move(lambda_function)),
156           _position(convert_window_coord_to_position(w, start)),
157           _end(convert_window_coord_to_position(w, end)),
158           _w(w)
159     {
160     }
161     /** Iterate over the lowest 3 dimensions of the window.
162      *
163      * @param[in] on_new_row_size Callback to be called before lambda_function every time the width of the row processed changes.
164      */
165     template <typename M>
iterate_3D(M && on_new_row_size)166     void iterate_3D(M &&on_new_row_size)
167     {
168         while(_end.z() != _position.z())
169         {
170             iterate_2D_internal(on_new_row_size, _w.x().end() - _w.x().step(), _w.y().end() - _w.y().step());
171             _position[2] += _w.z().step();
172             _position[1] = _w.y().start();
173             _position[0] = _w.x().start();
174         }
175         // Left over:
176         iterate_2D(on_new_row_size);
177     }
178 
179     /** Iterate over the lowest 2 dimensions of the window.
180      *
181      * @param[in] on_new_row_size Callback to be called before lambda_function every time the width of the row processed changes.
182      */
183     template <typename M>
iterate_2D(M && on_new_row_size)184     void iterate_2D(M &&on_new_row_size)
185     {
186         iterate_2D_internal(on_new_row_size, _end.x(), _end.y());
187     }
188 
189     /** Change the step used for the iteration.
190      *
191      * @note Does not affect the start and end points.
192      *
193      * @param[in] dim  Dimension to change
194      * @param[in] step New step to use for the given dimension.
195      */
set_step(size_t dim,int step)196     inline void set_step(size_t dim, int step)
197     {
198         _w.set_dimension_step(dim, step);
199     }
200 
201     /** Returns the coordinates in absolute coordinates of the end position
202          *
203          * @return End position coordinates.
204          */
end_position()205     const Coordinates &end_position() const
206     {
207         return _end;
208     }
209 
210 private:
211     template <typename M>
iterate_2D_internal(M && on_new_row_size,int end_x,int end_y)212     void iterate_2D_internal(M &&on_new_row_size, int end_x, int end_y)
213     {
214         //Is there more than one row to process ?
215         if(end_y == _position.y())
216         {
217             // Both start and end belong to the same row:
218             iterate_over_dim0(end_x + _w.x().step(), on_new_row_size);
219         }
220         else
221         {
222             // Do we start from the beginning of the row ?
223             if(_w.x().start() != _position.x())
224             {
225                 //Start in the middle of a row: process left-over X
226                 iterate_over_dim0(_w.x().end(), on_new_row_size);
227                 _position[1] += _w.y().step();
228             }
229 
230             //Middle rows
231             bool no_leftover = end_x + _w.x().step() == _w.x().end();
232             if(no_leftover)
233             {
234                 //Switch to full row size:
235                 on_new_row_size(_w[0].start(), _w.x().end());
236                 // Shouldn't be possible to reach that point and not have at least one entire row to process
237                 ARM_COMPUTE_ERROR_ON(_w.y().end() == _position.y());
238                 // No leftover: all the rows lefts to process are full width:
239                 iterate_over_dim1(end_y + _w.y().step());
240             }
241             else
242             {
243                 // Are there full rows to process ?
244                 if(_position[1] != end_y)
245                 {
246                     //Switch to full row size:
247                     on_new_row_size(_w[0].start(), _w.x().end());
248                     iterate_over_dim1(end_y);
249                 }
250 
251                 //Leftover end x
252                 _position[0] = _w.x().start();
253                 iterate_over_dim0(end_x + _w.x().step(), on_new_row_size);
254             }
255         }
256     }
257 
258     /** Process full rows below 'end'
259      *
260      * @param[in] end Y position to stop at.
261      */
iterate_over_dim1(int end)262     void iterate_over_dim1(int end)
263     {
264         for(; _position[1] != end; _position[1] += _w[1].step())
265         {
266             _position[0] = _w[0].start();
267             iterate_over_dim0(_w[0].end());
268         }
269     }
270 
271     /** Process elements of a given row up to 'end'
272      *
273      * @param[in] end             X position to stop at.
274      * @param[in] on_new_row_size Callback to call before starting iterating
275      */
276     template <typename M>
iterate_over_dim0(int end,M && on_new_row_size)277     void iterate_over_dim0(int end, M &&on_new_row_size)
278     {
279         on_new_row_size(_position.x(), end);
280         iterate_over_dim0(end);
281     }
282 
283     /** Process elements of a given row up to 'end'
284      *
285      * @param[in] end X position to stop at.
286      */
iterate_over_dim0(int end)287     void iterate_over_dim0(int end)
288     {
289         // Both start and end belong to the same row:
290         ARM_COMPUTE_ERROR_ON(_position[0] > end);
291         for(; _position.x() < end; _position[0] += _w[0].step())
292         {
293             _lambda_function(_position);
294         }
295     }
296 
297     L           _lambda_function; /**< Function to call for each iteration */
298     Coordinates _position;        /**< Absolute coordinates of the current position */
299     Coordinates _end;             /**< Absolute coordinates of the point after the last iteration */
300     Window      _w;               /**< Window to iterate over */
301 };
302 
303 /** Create a WindowIterator object
304  *
305  * @param[in] w               Window to use for the iteration
306  * @param[in] start           Where to start iterating from (In Window coordinates)
307  * @param[in] end             Where to stop iterating (In Window coordinates).
308  * @param[in] lambda_function Lambda function to call for every iteration between start and end. (It will be called last for end - 1)
309  *
310  * @return A WindowIterator object.
311  */
312 template <typename L>
create_window_iterator(const Window & w,const Coordinates & start,const Coordinates & end,L && lambda_function)313 WindowIterator<L> create_window_iterator(const Window &w, const Coordinates &start, const Coordinates &end, L &&lambda_function)
314 {
315     return WindowIterator<L>(w, start, end, std::move(lambda_function));
316 }
317 }
318 #endif /*ARM_COMPUTE_WINDOW_ITERATOR_H*/
319