xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/arm_gemm/interleave_indirect_impl.hpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #pragma once
26 
27 // Implementations of interleave functions
28 // These must be included with a "namespace arm_gemm" block.
29 
30 /*
31  * Core function that does heavy lifting - interleave 'int_by' rows of width 'width' together.
32  *
 * 'height' indicates the actual number of rows to interleave, so if it's less than int_by then the remaining
 * entries are padded (note that this is "GEMM" padding rather than convolution padding, so there is no need to pad
 * with a particular value).
36  *
37  * Note that it is not expected for this templated version to ever be used - all cases that matter should be
38  * explicitly specialized with an optimized implementation.
39  */
40 template<unsigned int height_vectors, unsigned int block, VLType vlt, bool integrate_sums, typename TIn, typename TOut>
interleave_block(TOut * & out,const TIn * const * in,size_t width,size_t height,size_t row_offset,bool first)41 void interleave_block( TOut * &out, const TIn * const *in, size_t width, size_t height, size_t row_offset, bool first) {
42     const unsigned int int_by = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block :
43                                                   (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 ));
44 
45     std::vector<int32_t> the_sums;
46 
47     if (integrate_sums) {
48         the_sums = std::vector<int32_t>(int_by, 0);
49 
50         if (!first) {
51             // In 'integrate sums' mode, we dump the sums at the end on each pass.
52 
53             // On the last pass this is correct, but on other passes it is not -
54             // so on the subsequent pass we need to take the output written by
55             // the previous pass as starting point for the sums, and then
56             // overwrite them with new interleaved data.
57             int32_t *out_int32 = reinterpret_cast<int32_t *>(out);
58 
59             // Rewind pointer to where we wrote out the sums last time.
60             out_int32 -= int_by;
61 
62             // Restore the running sums.
63             memcpy(the_sums.data(), out_int32, int_by * sizeof(int32_t));
64 
65             // Update the "real" pointer so that the next output will clobber the old sums.
66             out = reinterpret_cast<TOut *>(out_int32);
67         }
68     }
69 
70     for (unsigned int pos=0; pos<width; pos+=block) {
71         for (unsigned int row=0; row<int_by; row++) {
72             // Row out of range - pad 'block' entries.
73             if (row >= height) {
74                 for (unsigned int col=0; col<block; col++) {
75                     *out++ = 0;
76                 }
77                 continue;
78             }
79 
80             for (unsigned int col=0; col<block; col++) {
81                 // Column out of range - pad a single entry
82                 if (pos + col >= width) {
83                     *out++ = 0;
84                     continue;
85                 }
86 
87                 if (integrate_sums) {
88                     the_sums[row] += in[row][row_offset + pos + col];
89                 }
90 
91                 *out++ = in[row][row_offset + pos + col];
92             }
93         }
94     }
95 
96     if (integrate_sums) {
97         int32_t *out_int32 = reinterpret_cast<int32_t *>(out);
98 
99         memcpy(out_int32, the_sums.data(), int_by * sizeof(int32_t));
100 
101         out = reinterpret_cast<TOut *>(out_int32 + int_by);
102     }
103 }
104 
105 template<unsigned int height_vectors, unsigned int block, VLType vlt, typename TOut>
FixupRowSums(TOut * & out,const int32_t row_sum_multiplier)106 inline void FixupRowSums(TOut * &out, const int32_t row_sum_multiplier) {
107     const unsigned int height = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block :
108                                                   (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 ));
109 
110     // If we are integrating row sums, we need to do some fix up, depending on whether the multiplier is non-zero or not.
111     if (row_sum_multiplier) {
112         // Non-zero: interleave_block<>() will have done the sums, so 'out' will point to the start of the
113         // next block (post sums).
114         // We need to go back and apply the multiplier to the computed sums.  We don't need to change 'out'.
115         int32_t *out_int32 = reinterpret_cast<int32_t *>(out);
116 
117         out_int32 -= height;
118         for (unsigned int i=0; i<height; i++) {
119             out_int32[i] *= row_sum_multiplier;
120         }
121     } else {
122         // Zero: interleave_block<>() will *not* have done the sums, so 'out' will point to the start of the
123         // sum block.  We need to insert the (zero) sums, and advance 'out'.
124         int32_t *out_int32 = reinterpret_cast<int32_t *>(out);
125 
126         for (unsigned int i=0; i<height; i++) {
127             out_int32[i] = 0;
128         }
129 
130         out_int32 += height;
131 
132         out = reinterpret_cast<TOut *>(out_int32);
133     }
134 }
135 
136 template<unsigned int height_vectors, unsigned int block, VLType vlt, typename TIn, typename TOut>
IndirectInterleave(TOut * out,const TIn * const * const * ptr,unsigned int stringlen,unsigned int rounded_stringlen,const unsigned int y0,const unsigned int ymax,const unsigned int k0,const unsigned int kmax,bool integrate_sums,const int32_t row_sum_multiplier)137 void IndirectInterleave(TOut *out, const TIn * const * const *ptr, unsigned int stringlen,
138                         unsigned int rounded_stringlen, const unsigned int y0, const unsigned int ymax,
139                         const unsigned int k0, const unsigned int kmax, bool integrate_sums,
140                         const int32_t row_sum_multiplier) {
141     const unsigned int height = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block :
142                                                   (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 ));
143 
144     // 'interleave_block' implementations are entitled to read a pointer for each row they handle from the input
145     // pointer array, even for out of range rows (although they must not subsequently dereference those pointers for
146     // out of range rows).  This allows interleave_block to use techniques like row predication, or loading all
147     // pointers and conditionally overriding the out of range ones.
148 
149     // This is problematic in the "pure" indirect case when we get to the last rows, where it can lead to out of
150     // range reads.  Avoid this with a local buffer to use in last-rows cases.  Use alloca as a std::vector can be
151     // expensive in highly threaded scenarios.
152     const TIn **row_ptrs = reinterpret_cast<const TIn **>(alloca(height * sizeof(const TIn *)));
153 
154     // Figure out the starting position based on k0 (with rounded length)
155     unsigned int start_string      = k0 / rounded_stringlen;
156     unsigned int start_stringpos   = k0 % rounded_stringlen;
157 
158     // Process blocks of 'height' height...
159     for (unsigned int ybase = y0; ybase < ymax; ybase+=height) {
160         // Height to process
161         unsigned int active_height = std::min(ymax - ybase, height);
162 
163         // Track our progress through the various strings
164         unsigned int k_left    = (kmax - k0);
165         unsigned int string    = start_string;
166         unsigned int stringpos = start_stringpos;
167 
168         bool first = true;
169 
170         // Prepare to call 'interleave_block' above for each string encompassed by K range
171         while (k_left > 0) {
172             // Width to process - and the width we will generate (with padding)
173             unsigned int in_width   = std::min(k_left, stringlen - stringpos);
174             unsigned int out_width  = std::min(k_left, rounded_stringlen - stringpos);
175 
176             const TIn * const *row_base = ptr[string] + ybase;
177 
178             // If not all rows are valid, copy the ones that are into local array (see above comment).
179             if (active_height < height) {
180                 for (unsigned int i=0; i<active_height; i++) {
181                     row_ptrs[i] = ptr[string][ybase + i];
182                 }
183 
184                 row_base = row_ptrs;
185             }
186 
187             // 'integrate_sums' is a function parameter rather than a template parameter to prevent duplicating too
188             // much code.  However, integrated sums make no sense for non-integral types and won't ever be
189             // requested.  So put a type trait check here to avoid generating pointless code.
190             if (std::is_integral<TOut>::value && integrate_sums && row_sum_multiplier) {
191                 interleave_block<height_vectors, block, vlt, true>(out, row_base, in_width, active_height, stringpos, first);
192             } else {
193                 interleave_block<height_vectors, block, vlt, false>(out, row_base, in_width, active_height, stringpos, first);
194             }
195 
196             k_left -= out_width;
197             string++;
198             stringpos=0;
199             first=false;
200         }
201 
202         if (std::is_integral<TOut>::value && integrate_sums) {
203             FixupRowSums<height_vectors, block, vlt>(out, row_sum_multiplier);
204         }
205     }
206 }
207 
208 template<unsigned int height_vectors, unsigned int block, VLType vlt, typename TIn, typename TOut>
ConvolutionInterleave(TOut * out,const TIn * in,size_t in_stride,const convolver<TIn> & conv,const unsigned int rounded_stringlen,const unsigned int y0,const unsigned int ymax,const unsigned int k0,const unsigned int kmax,bool integrate_sums,const int32_t row_sum_multiplier)209 void ConvolutionInterleave(TOut *out, const TIn *in, size_t in_stride, const convolver<TIn> &conv, const unsigned int rounded_stringlen,
210         const unsigned int y0, const unsigned int ymax, const unsigned int k0, const unsigned int kmax, bool integrate_sums, const int32_t row_sum_multiplier) {
211     const unsigned int height = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block :
212                                                   (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 ));
213     auto conv_cols = conv.process_columns(in, in_stride, k0, kmax, rounded_stringlen);
214 
215     // Use alloca here as a std::vector can be expensive in highly threaded scenarios.
216     const TIn **row_ptrs = reinterpret_cast<const TIn **>(alloca(height * sizeof(const TIn *)));
217 
218     for (unsigned int ybase = y0; ybase < ymax; ybase += height) {
219         // How many of the rows are active - the rest will get padded in interleave_block.
220         unsigned int active_height   = std::min(ymax - ybase, height);
221         bool first = true;
222 
223         auto conv_rows = conv_cols.process_rows(ybase, active_height);
224 
225         while (!conv_rows.finished()) {
226             unsigned int width, offset;
227 
228             // Get next set of parameters
229             std::tie(width, offset) = conv_rows.next_block(row_ptrs);
230 
231             // Perform the interleave
232             if (std::is_integral<TOut>::value && integrate_sums && row_sum_multiplier) {
233                 interleave_block<height_vectors, block, vlt, true>(out, row_ptrs, width, active_height, offset, first);
234             } else {
235                 interleave_block<height_vectors, block, vlt, false>(out, row_ptrs, width, active_height, offset, first);
236             }
237 
238             first=false;
239         }
240 
241         if (std::is_integral<TOut>::value && integrate_sums) {
242             FixupRowSums<height_vectors, block, vlt>(out, row_sum_multiplier);
243         }
244     }
245 }
246 
247 template<unsigned int height_vectors, unsigned int block, VLType vlt, typename TIn, typename TOut>
Interleave(TOut * out,const TIn * in,size_t in_stride,const unsigned int y0,const unsigned int ymax,const unsigned int k0,const unsigned int kmax,bool integrate_sums,const int32_t row_sum_multiplier)248 void Interleave(TOut *out, const TIn *in, size_t in_stride, const unsigned int y0, const unsigned int ymax, const unsigned int k0, const unsigned int kmax, bool integrate_sums, const int32_t row_sum_multiplier) {
249     const unsigned int height = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block :
250                                                   (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 ));
251     // Use alloca here as a std::vector can be expensive in highly threaded scenarios.
252     const TIn **row_ptrs = reinterpret_cast<const TIn **>(alloca(height * sizeof(const TIn *)));
253 
254     const unsigned int width=kmax-k0;
255 
256     for (unsigned int y=y0; y<ymax; y+=height) {
257         for (unsigned int r=0; r<height; r++) {
258             row_ptrs[r] = in + ((y + r) * in_stride);
259         }
260 
261         if (std::is_integral<TOut>::value && integrate_sums && row_sum_multiplier) {
262             interleave_block<height_vectors, block, vlt, true>(out, row_ptrs, width, std::min(height, ymax-y), k0, true);
263         } else {
264             interleave_block<height_vectors, block, vlt, false>(out, row_ptrs, width, std::min(height, ymax-y), k0, true);
265         }
266 
267         if (std::is_integral<TOut>::value && integrate_sums) {
268             FixupRowSums<height_vectors, block, vlt>(out, row_sum_multiplier);
269         }
270     }
271 }
272