xref: /aosp_15_r20/external/ComputeLibrary/src/core/CL/cl_kernels/load_store_utility.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 /** Store the 0 to (n-1)th rows of the given variables
26  * @name STORE_ROW_n
27  *
28  * @param[in] N0        The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
29  * @param[in] DATA_TYPE The data type of the vectors
30  * @param[in] BASENAME  The basename of the variables
31  * @param[in] PTR       The base pointer
32  * @param[in] STRIDE_Y  The stride value in y-axis direction
33  * @param[in] Z         The offset in z-axis direction
34  * @{
35  */
/* STORE_ROW_k first expands STORE_ROW_(k-1) and then emits one more VSTORE(N0)
 * call for row k-1, so the calls appear in ascending row order.
 * The row variable and the z offset are obtained by pasting the row suffix onto
 * BASENAME and Z: suffixes are 0-9 for rows 0-9 and A-F for rows 10-15.
 * Each expansion already carries its own terminating ';'.
 */
#define STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_n
115 
116 /** Convert and store the 0th to (n-1)th rows of the given variables
117  * @name CONVERT_STORE_ROW_n
118  *
119  * @param[in] N0        The size of the vectors
120  * @param[in] DATA_TYPE The data type of the vectors
121  * @param[in] BASENAME  The basename of the variables
122  * @param[in] PTR       The base pointer
123  * @param[in] STRIDE_Y  The stride value in y-axis direction
124  * @param[in] Z         The offset in z-axis direction
125  * @{
126  */
/* CONVERT_STORE_ROW_k expands CONVERT_STORE_ROW_(k-1) and then emits one more
 * VSTORE(N0) call for row k-1. The stored value is the row variable passed
 * through CONVERT_SAT with target type VEC_DATA_TYPE(DATA_TYPE, N0).
 * Row variables and z offsets are formed by pasting the row suffix (0-8 here)
 * onto BASENAME and Z. Each expansion already ends with ';'.
 */
#define CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##0), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##1), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##2), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##3), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##4), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##5), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##6), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##7), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##8), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));
170 
/** Convert and store rows 0 to 9 of the given variables.
 *
 * Expands CONVERT_STORE_ROW_9 for rows 0-8, then emits a VSTORE(N0) call for row 9
 * (variable BASENAME##9, z offset Z##9) whose value is passed through CONVERT_SAT
 * with target type VEC_DATA_TYPE(DATA_TYPE, N0).
 *
 * The second parameter has to be spelled DATA_TYPE: the expansion refers to
 * DATA_TYPE three times, and with any other parameter name those references
 * would not be bound to the caller's argument.
 *
 * @param[in] N0        The size of the vectors
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 */
#define CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##9), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));
175 
/* Rows 10-15 of the CONVERT_STORE_ROW_n chain. Each macro expands its
 * predecessor and then emits one more VSTORE(N0) call whose value is the row
 * variable passed through CONVERT_SAT with target type
 * VEC_DATA_TYPE(DATA_TYPE, N0). Rows 10-15 use the suffixes A-F on BASENAME
 * and Z. Each expansion already ends with ';'.
 */
#define CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##A), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##B), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##C), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##D), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##E), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define CONVERT_STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)(CONVERT_SAT((BASENAME##F), VEC_DATA_TYPE(DATA_TYPE, N0)), \
               0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
205 
/** @} */ // end of group CONVERT_STORE_ROW_n
207 
208 /** Store a block of the given size M0xN0
209  * @name STORE_BLOCK
210  *
211  * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
212  * The data to store is expected to have consecutive names for each row.
213  * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
214  * The Z offset is expected to have consecutive names.
215  * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
216  *
217  * @param[in] M0        The number of rows to store
218  * @param[in] N0        The size of each vector
219  * @param[in] DATA_TYPE The data type of the vectors
220  * @param[in] BASENAME  The basename of the variables
221  * @param[in] PTR       The base pointer
222  * @param[in] STRIDE_Y  The stride value in y-axis direction
223  * @param[in] Z         The offset in z-axis direction
224  * @{
225  */
/* Two-step expansion: STORE_BLOCK forwards to STORE_BLOCK_STR so that M0 is
 * macro-expanded before it is pasted onto the STORE_ROW_ prefix.
 */
#define STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
228 /** @} */ // end of group STORE_BLOCK
229 
230 /** Convert and store a block of the given size M0xN0
231  * @name CONVERT_STORE_BLOCK
232  *
233  * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
234  * The data to store is expected to have consecutive names for each row.
235  * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
236  * The Z offset is expected to have consecutive names.
237  * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
238  *
239  * @param[in] M0        The number of rows to store
240  * @param[in] N0        The size of each vector
241  * @param[in] DATA_TYPE The data type of the vectors
242  * @param[in] BASENAME  The basename of the variables
243  * @param[in] PTR       The base pointer
244  * @param[in] STRIDE_Y  The stride value in y-axis direction
245  * @param[in] Z         The offset in z-axis direction
246  * @{
247  */
/* Two-step expansion: CONVERT_STORE_BLOCK forwards to CONVERT_STORE_BLOCK_STR so
 * that M0 is macro-expanded before it is pasted onto the CONVERT_STORE_ROW_ prefix.
 */
#define CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
250 /** @} */ // end of group CONVERT_STORE_BLOCK
251 
252 /** Partially store the 0 to (n-1)th rows of the given variables
253  * @name STORE_ROW_PARTIAL_n
254  * Within each row, store the lower @p STORE_N0 elements of vectors of width @p N0
255  *
256  * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
257  *
258  * @param[in] N0        The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] STORE_N0  The **lower** size of the vectors to store. Supported: [1-16] and <= @p N0
260  * @param[in] DATA_TYPE The data type of the vectors
261  * @param[in] BASENAME  The basename of the variables
262  * @param[in] PTR       The base pointer
263  * @param[in] STRIDE_Y  The stride value in y-axis direction
264  * @param[in] Z         The offset in z-axis direction
265  * @{
266  */
/* STORE_ROW_PARTIAL_k first expands STORE_ROW_PARTIAL_(k-1) and then emits one
 * more VSTORE_PARTIAL(N0, STORE_N0) call for row k-1, so the calls appear in
 * ascending row order. The row variable and the z offset are obtained by
 * pasting the row suffix onto BASENAME and Z: 0-9 for rows 0-9, A-F for rows
 * 10-15. Each expansion already carries its own terminating ';'.
 */
#define STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define STORE_ROW_PARTIAL_16(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)(BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_PARTIAL_n
346 
347 /** Partially store a block of the given size STORE_M0xSTORE_N0
348  * @name STORE_BLOCK_PARTIAL
349  *
350  * @note The vector width @p N0 is also required for correct partial storing behaviour.
351  * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
352  *
353  * The data to store is expected to have consecutive names for each row.
354  * E.g., for STORE_M0=3 and basename=c, the expected names are c0, c1 and c2.
355  * The Z offset is expected to have consecutive names.
356  * E.g., for STORE_M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
357  *
358  * @param[in] STORE_M0  The number of rows to store. Supported: 1-16
359  * @param[in] STORE_N0  The lower number of elements of vectors to store. Supported: 1-16 and <= @p N0
360  * @param[in] N0        The size of each vector. Supported: 1, 2, 3, 4, 8, 16
361  * @param[in] DATA_TYPE The data type of the vectors
362  * @param[in] BASENAME  The basename of the variables
363  * @param[in] PTR       The base pointer
364  * @param[in] STRIDE_Y  The stride value in y-axis direction
365  * @param[in] Z         The offset in z-axis direction
366  * @{
367  */
/* Two-step expansion: STORE_BLOCK_PARTIAL forwards to STORE_BLOCK_PARTIAL_STR so
 * that STORE_M0 is macro-expanded before it is pasted onto the
 * STORE_ROW_PARTIAL_ prefix. Note the argument order: this macro takes
 * (STORE_M0, STORE_N0, N0, ...) while the row macros take (N0, STORE_N0, ...).
 */
#define STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_##STORE_M0(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK_PARTIAL(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
370 /** Store a block that can be partial in both x and y dimensions
371  *
372  * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
373  *
374  * The data to store is expected to have consecutive names for each row.
375  * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
376  * The Z offset is expected to have consecutive names.
377  * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
378  *
379  * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
380  * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
381  * @param[in] DATA_TYPE        The data type of the vectors
382  * @param[in] BASENAME         The basename of the variables
383  * @param[in] PTR              The base pointer
384  * @param[in] STRIDE_Y         The stride value in y-axis direction
385  * @param[in] Z                The offset in z-axis direction
386  * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
387  * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
388  * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
389  * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
390  */
/* Branch selection (rows stored x elements stored per row):
 *   PARTIAL_COND_X  PARTIAL_COND_Y   block stored
 *   false           false            M0               x N0
 *   false           true             PARTIAL_STORE_M0 x N0
 *   true            false            M0               x PARTIAL_STORE_N0
 *   true            true             PARTIAL_STORE_M0 x PARTIAL_STORE_N0
 * The expansion is a bare if/else statement (no enclosing do-while).
 */
#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    if(PARTIAL_COND_X)                                                                                          \
    {                                                                                                           \
        if(PARTIAL_COND_Y)                                                                                      \
        {                                                                                                       \
            STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
        }                                                                                                       \
        else                                                                                                    \
        {                                                                                                       \
            STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);               \
        }                                                                                                       \
    }                                                                                                           \
    else                                                                                                        \
    {                                                                                                           \
        if(PARTIAL_COND_Y)                                                                                      \
        {                                                                                                       \
            STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);               \
        }                                                                                                       \
        else                                                                                                    \
        {                                                                                                       \
            STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                             \
        }                                                                                                       \
    }
408 /** Store a block that can only be partial in x but not y.
409  *
410  * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
411  *
412  * The data to store is expected to have consecutive names for each row.
413  * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
414  * The Z offset is expected to have consecutive names.
415  * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
416  *
417  * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
418  * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
419  * @param[in] DATA_TYPE        The data type of the vectors
420  * @param[in] BASENAME         The basename of the variables
421  * @param[in] PTR              The base pointer
422  * @param[in] STRIDE_Y         The stride value in y-axis direction
423  * @param[in] Z                The offset in z-axis direction
424  * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
425  * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
426  */
/* When PARTIAL_COND_X holds, PARTIAL_STORE_N0 elements are stored per row;
 * otherwise the full N0. All M0 rows are stored in both cases.
 * The expansion is a bare if/else statement (no enclosing do-while).
 */
#define STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X) \
    if(PARTIAL_COND_X)                                                                        \
    {                                                                                         \
        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    }                                                                                         \
    else                                                                                      \
    {                                                                                         \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);               \
    }
436 /** Store a block that can only be partial in y but not x.
437  *
438  * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
439  *
440  * The data to store is expected to have consecutive names for each row.
441  * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
442  * The Z offset is expected to have consecutive names.
443  * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
444  *
445  * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
446  * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
447  * @param[in] DATA_TYPE        The data type of the vectors
448  * @param[in] BASENAME         The basename of the variables
449  * @param[in] PTR              The base pointer
450  * @param[in] STRIDE_Y         The stride value in y-axis direction
451  * @param[in] Z                The offset in z-axis direction
452  * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
453  * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
454  */
/* When PARTIAL_COND_Y holds, PARTIAL_STORE_M0 rows are stored; otherwise the
 * full M0. Each row is stored with the full N0 elements in both cases.
 * The expansion is a bare if/else statement (no enclosing do-while).
 */
#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y) \
    if(PARTIAL_COND_Y)                                                                        \
    {                                                                                         \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    }                                                                                         \
    else                                                                                      \
    {                                                                                         \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);               \
    }
464 /** @} */ // end of group STORE_BLOCK_PARTIAL
465 
466 #if defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
467 
468 /** Boundary-aware GEMM block store
469  * @name STORE_BLOCK_BOUNDARY_AWARE
470  * This macro assumes the following schemes to achieve boundary-awareness:
471  *  - Overlapping load in Y axis from lhs tensor. This implies lhs has no padding along y dim.
 *  - Non-Overlapping(normal) load from rhs tensor. This implies rhs can have paddings.
473  *  - Overlapping load in Y axis from bias tensor. This implies rhs has no padding along y dim.
474  * The macro then ensures that the dst tensor can be stored without any paddings in both x and y dim.
475  *
476  * In the y dimension, we place the partial blocks **at the beginning** while in the x dimension, we place the partial
477  * blocks **at the end**.
478  * Say, the dst tensor is of shape MxN and we have M0 and N0 as the block size, this is how we define "partial blocks"/
479  * "boundary block" (we use the 2 terms "partial blocks" and "boundary blocks" interchangeably) and its various parameters:
480  *
481  *  *--x-->                         x == 0                        x == 1
482  *  |                  |<------------------------------N-------------------------->|
483  *  y                  |<--------------N0------------->|<----PARTIAL_STORE_N0----->|
484  *  |     -------------#############################################################
485  *  *     |          | |...............................|...........................|
486  * y == 0 | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
487  *        |          | |...............................|...........................|
488  *        M          --#############################################################
489  *        |          | |                               |...........................|
490  * y == 1 |         M0 |      Non-boundary block       |....Boundary block in x....|
491  *        |          | |                               |...........................|
492  *        |------------#############################################################
493  *
494  * Then @p PARTIAL_STORE_M0 = M % M0      and @p PARTIAL_STORE_N0 = N % N0
495  *
496  * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
497  *
498  * It automatically detects if a giving M,N,M0,N0 combination can yield partial blocks in either X and Y dimension,
499  * and select corresponding store methods such that the boundary detection logic is only added when needed.
500  *
501  * The data to store is expected to have consecutive names for each row.
502  * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
503  * The Z offset is expected to have consecutive names.
504  * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
505  *
506  * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
507  * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
508  * @param[in] DATA_TYPE        The data type of the vectors
509  * @param[in] BASENAME         The basename of the variables
510  * @param[in] PTR              The base pointer
511  * @param[in] STRIDE_Y         The stride value in y-axis direction
512  * @param[in] Z                The offset in z-axis direction
513  * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
514  * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported: [0, @p N0)
515  * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
516  * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
517  * @{
518  */
519 #if PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
520 // Case1: No partial blocks in either x or y
521 #define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
522     STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
523 
524 #elif PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 == 0
525 // Case2: Partial blocks in y
526 #define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
527     STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y)
528 
529 #elif PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 > 0
530 // Case3: Partial blocks in x
531 #define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
532     STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X)
533 
534 #else // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
535 // Case4: Partial blocks in both x and y
536 #define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
537     STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X)
538 
539 #endif // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
540 
541 #endif    // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
542 /** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE
543 
544 #if defined(PARTIAL_STORE_M0)
545 /** Compute the start m0 row (LHS, BIAS and DST) in a boundary-aware way so as to avoid padding
546  * @name COMPUTE_M0_START_ROW
547  * If there are any partial blocks in the y dimension, they are placed at the beginning of the rows.
548  * The shift (M0 - PARTIAL_STORE_M0) % M0 is subtracted from every block's start row (clamped at 0) so that the
549  * partial block (at the beginning) overlaps with the subsequent blocks in the y dimension to avoid any padding.
550  * EG: M0=4, PARTIAL_STORE_M0=1:
551  *                  | Non-overlapping | +M0_ROW_SHIFT (Overlapping)
552  * block 0 (partial)| start row = 0   | start row = 0
553  * block 1 (full)   | start row = 4   | start row = 1
554  * block 2 (full)   | start row = 8   | start row = 5
555  * @note Every macro argument is parenthesized in the expansion, so compound expressions (e.g. a + b) can be passed.
556  * @param[in] y                Global id of current block in y.
557  * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
558  * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
559  * @{
560  */
561 #define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
562     ((uint)(max(0, (int)((y) * (M0)) - (int)(((M0) - (PARTIAL_STORE_M0)) % (M0)))))
563 #else // defined(PARTIAL_STORE_M0)
564 #define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
565     ((uint)((y) * (M0)))
566 #endif    // defined(PARTIAL_STORE_M0)
567 /** @} */ // end of group COMPUTE_M0_START_ROW
568 
569 /** Store a vector that can only be partial in x.
570  *
571  * @note in case @p vec_size or @p leftover != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
572  *
573  * The name of the variable to store is expected to end in a 0; e.g., for basename=c, the expected name is c0.
574  * The store is forwarded to STORE_BLOCK_PARTIAL_IN_X as a single row (M0 = 1) with STRIDE_Y = 0 and Z = 0.
575  *
576  * @param[in] basename  The name of the variable without trailing 0
577  * @param[in] data_type The data type of the vector
578  * @param[in] ptr       The base pointer
579  * @param[in] vec_size  The vector size if cond = false. Supported: 1, 2, 3, 4, 8, 16
580  * @param[in] leftover  The vector size if cond = true. Supported range: [1, @p vec_size)
581  * @param[in] cond      Condition to select either @p vec_size (false) or @p leftover (true)
582  * @{
583  */
584 #define STORE_VECTOR_SELECT(basename, data_type, ptr, vec_size, leftover, cond) \
585     STORE_BLOCK_PARTIAL_IN_X(1, vec_size, data_type, basename, ptr, 0, 0, leftover, cond)
586 /** @} */ // end of group STORE_VECTOR_SELECT