/*
 * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h"

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace cpu
{
namespace kernels
{
namespace
{
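// Offset-contribution math implemented below: for a quantized GEMM, expanding
// sum_k (A[y][k] + a_offset) * (B[k][x] + b_offset) on top of the raw product gives
//
//   mm_result[y][x] += a_offset * vector_sum_col[x]   // column sums of matrix B
//                    + b_offset * vector_sum_row[y]   // row sums of matrix A
//                    + a_offset * b_offset * k        // constant term, precomputed as k_offset
//
// Terms whose offset is zero are skipped, which is why the sum vectors may be nullptr.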
Status validate_arguments(const ITensorInfo *mm_result, const ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row,
                          int32_t a_offset, int32_t b_offset)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32);

    // If a_offset == 0, vector_sum_col can be a nullptr
    if(a_offset != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_col, 1, DataType::S32);
        ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_col->dimension(0) != mm_result->dimension(0));
    }

    // If b_offset == 0, vector_sum_row can be a nullptr
    if(b_offset != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_row, 1, DataType::S32);

        // Check if input is a 3D reinterpretation
        const bool reinterpret_as_3d = mm_result->num_dimensions() > 1 && mm_result->tensor_shape().y() != vector_sum_row->tensor_shape().x();

        // Validate input
        ARM_COMPUTE_RETURN_ERROR_ON(reinterpret_as_3d && vector_sum_row->dimension(0) != (mm_result->dimension(1) * mm_result->dimension(2)));
        ARM_COMPUTE_RETURN_ERROR_ON(!reinterpret_as_3d && vector_sum_row->dimension(0) != mm_result->dimension(1));

        TensorShape output_shape = mm_result->tensor_shape();
        if(output_shape.num_dimensions() > 1)
        {
            const unsigned int output_batch_idx = reinterpret_as_3d ? 3 : 2;

            TensorShape vector_sum_row_shape = vector_sum_row->tensor_shape();
            vector_sum_row_shape.collapse_from(1);
            output_shape.collapse_from(output_batch_idx);

            ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_row_shape[1] != output_shape[output_batch_idx],
                                            "mm_result tensor must have the same number of batches as the output tensor");

            if(a_offset != 0)
            {
                TensorShape vector_sum_col_shape = vector_sum_col->tensor_shape();
                vector_sum_col_shape.collapse_from(1);

                ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_col_shape[1] != 1 && vector_sum_col_shape[1] != vector_sum_row_shape[1],
                                                "vector_sum_col tensor must have the same number of batches as vector_sum_row or the number of batches must be set to 1");
            }
        }
    }

    return Status{};
}

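// Adds the offset terms to mm_result in place. The implementation is specialised on which
// offsets are non-zero: both (a_offset, b_offset and k_offset terms), only b_offset (row
// term), only a_offset (column term), or neither (early return with no work to do).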
void run_offset_contribution(const Window &window,
                             ITensor *mm_result, const ITensor *vector_sum_col, const ITensor *vector_sum_row,
                             int32_t a_offset, int32_t b_offset, int32_t k_offset, bool slide_vector_sum_col, bool is_gemm3d)
{
    Window collapsed_window = window.collapse_if_possible(window, Window::DimZ);
    collapsed_window.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int height_input = is_gemm3d ? mm_result->info()->dimension(1) : 0;
    const int depth_input  = is_gemm3d ? mm_result->info()->dimension(2) : 1;

    const int window_start_x = window.x().start();
    const int window_end_x   = window.x().end();
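    // The main loops below process 16 int32 values per iteration as four 4-lane NEON
    // vectors; any remainder is handled by the scalar left-overs loops.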
    const int window_step_x  = 16;

    // If vector_sum_col is nullptr then stride_y is 0, otherwise use its y stride
    const size_t sum_col_stride_y = (vector_sum_col != nullptr) ? (vector_sum_col->info()->strides_in_bytes().y()) : 0;
    Iterator     mm_result_it(mm_result, collapsed_window);

    if((a_offset != 0) && (b_offset != 0) && (vector_sum_col != nullptr) && (vector_sum_row != nullptr)) // a_offset != 0, b_offset != 0
    {
        // Set window for vector_sum_col
        Window win_vector_sum_col(collapsed_window);
        win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
        win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));

        // Set window for vector_sum_row
        Window win_vector_sum_row(collapsed_window);
        win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
        win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col);
        Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);

        const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();

        // Offset in case vector_sum_col is batched
        const int vector_sum_col_batch_offset = slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;

        execute_window_loop(collapsed_window, [&](const Coordinates & id)
        {
            const int    batch_id           = id.z() / depth_input;
            const size_t batch_offset_col   = batch_id * sum_col_stride_y;
            auto         vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_offset_col + batch_id * vector_sum_col_batch_offset);
            auto         mm_result_ptr      = reinterpret_cast<int32_t *>(mm_result_it.ptr());

            // Compute the leftover term due to b_offset.
            int32_t b_offset_term_s32 = *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) + id.y() + (id.z() % depth_input) * height_input);
            b_offset_term_s32 *= b_offset;

            const int32x4_t b_offset_term_s32_vec = vdupq_n_s32(b_offset_term_s32);

            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                // Compute the leftover term due to a_offset.
                int32x4x4_t a_offset_term_s32 =
                {
                    {
                        vld1q_s32(vector_sum_col_ptr + x + 0),
                        vld1q_s32(vector_sum_col_ptr + x + 4),
                        vld1q_s32(vector_sum_col_ptr + x + 8),
                        vld1q_s32(vector_sum_col_ptr + x + 12)
                    }
                };

                a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
                a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
                a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
                a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);

                // Add a_offset_term_s32 and b_offset_term_s32
                int32x4x4_t offset_term_s32 =
                {
                    {
                        vdupq_n_s32(k_offset),
                        vdupq_n_s32(k_offset),
                        vdupq_n_s32(k_offset),
                        vdupq_n_s32(k_offset)
                    }
                };

                offset_term_s32.val[0] = vaddq_s32(offset_term_s32.val[0], vaddq_s32(a_offset_term_s32.val[0], b_offset_term_s32_vec));
                offset_term_s32.val[1] = vaddq_s32(offset_term_s32.val[1], vaddq_s32(a_offset_term_s32.val[1], b_offset_term_s32_vec));
                offset_term_s32.val[2] = vaddq_s32(offset_term_s32.val[2], vaddq_s32(a_offset_term_s32.val[2], b_offset_term_s32_vec));
                offset_term_s32.val[3] = vaddq_s32(offset_term_s32.val[3], vaddq_s32(a_offset_term_s32.val[3], b_offset_term_s32_vec));

                int32x4x4_t in_s32 =
                {
                    {
                        vld1q_s32(mm_result_ptr + x + 0),
                        vld1q_s32(mm_result_ptr + x + 4),
                        vld1q_s32(mm_result_ptr + x + 8),
                        vld1q_s32(mm_result_ptr + x + 12)
                    }
                };

                // Add the offset terms to GEMM's result
                in_s32.val[0] = vaddq_s32(in_s32.val[0], offset_term_s32.val[0]);
                in_s32.val[1] = vaddq_s32(in_s32.val[1], offset_term_s32.val[1]);
                in_s32.val[2] = vaddq_s32(in_s32.val[2], offset_term_s32.val[2]);
                in_s32.val[3] = vaddq_s32(in_s32.val[3], offset_term_s32.val[3]);

                // Store the result with the offset contribution
                vst1q_s32(mm_result_ptr + x + 0, in_s32.val[0]);
                vst1q_s32(mm_result_ptr + x + 4, in_s32.val[1]);
                vst1q_s32(mm_result_ptr + x + 8, in_s32.val[2]);
                vst1q_s32(mm_result_ptr + x + 12, in_s32.val[3]);
            }

            // Left-overs loop
            for(; x < window_end_x; ++x)
            {
                // Compute the leftover term due to a_offset.
                int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x);

                a_offset_term_s32 *= a_offset;

                // Add the offset terms to GEMM's result
                // Store the result with the offset contribution
                mm_result_ptr[x] += k_offset + a_offset_term_s32 + b_offset_term_s32;
            }
        },
        vector_sum_col_it, vector_sum_row_it, mm_result_it);
    }
    else if((a_offset == 0) && (b_offset != 0) && (vector_sum_row != nullptr)) // a_offset == 0, b_offset != 0
    {
        ARM_COMPUTE_ERROR_ON_NULLPTR(vector_sum_row);

        // Set window for vector_sum_row
        Window win_vector_sum_row(collapsed_window);
        win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
        win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);

        const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();

        execute_window_loop(collapsed_window, [&](const Coordinates & id)
        {
            const int batch_id      = id.z() / depth_input;
            auto      mm_result_ptr = reinterpret_cast<int32_t *>(mm_result_it.ptr());

            // Compute the leftover term due to b_offset.
            int32_t b_offset_term_s32 = *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) + id.y() + (id.z() % depth_input) * height_input);
            b_offset_term_s32 *= b_offset;

            const int32x4_t b_offset_term_s32_vec = vdupq_n_s32(b_offset_term_s32);

            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                int32x4x4_t in_s32 =
                {
                    {
                        vld1q_s32(mm_result_ptr + x + 0),
                        vld1q_s32(mm_result_ptr + x + 4),
                        vld1q_s32(mm_result_ptr + x + 8),
                        vld1q_s32(mm_result_ptr + x + 12)
                    }
                };

                // Add the offset terms to GEMM's result
                in_s32.val[0] = vaddq_s32(in_s32.val[0], b_offset_term_s32_vec);
                in_s32.val[1] = vaddq_s32(in_s32.val[1], b_offset_term_s32_vec);
                in_s32.val[2] = vaddq_s32(in_s32.val[2], b_offset_term_s32_vec);
                in_s32.val[3] = vaddq_s32(in_s32.val[3], b_offset_term_s32_vec);

                // Store the result with the offset contribution
                vst1q_s32(mm_result_ptr + x + 0, in_s32.val[0]);
                vst1q_s32(mm_result_ptr + x + 4, in_s32.val[1]);
                vst1q_s32(mm_result_ptr + x + 8, in_s32.val[2]);
                vst1q_s32(mm_result_ptr + x + 12, in_s32.val[3]);
            }

            // Left-overs loop
            for(; x < window_end_x; ++x)
            {
                // Add the offset terms to GEMM's result
                // Store the result with the offset contribution
                mm_result_ptr[x] += b_offset_term_s32;
            }
        },
        vector_sum_row_it, mm_result_it);
    }
    else if((a_offset != 0) && (b_offset == 0) && (vector_sum_col != nullptr)) // a_offset != 0, b_offset == 0
    {
        // Set window for vector_sum_col
        Window win_vector_sum_col(collapsed_window);
        win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
        win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col);

        // Offset in case vector_sum_col is batched
        const int vector_sum_col_batch_offset = slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;

        execute_window_loop(collapsed_window, [&](const Coordinates & id)
        {
            const int    batch_id           = id.z() / depth_input;
            const size_t batch_offset_col   = batch_id * sum_col_stride_y; // Value to offset vector_sum_col_ptr to allow for iteration of y values in tensor
            auto         vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_offset_col + batch_id * vector_sum_col_batch_offset);
            auto         mm_result_ptr      = reinterpret_cast<int32_t *>(mm_result_it.ptr());

            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                // Compute the leftover term due to a_offset.
                int32x4x4_t a_offset_term_s32 =
                {
                    {
                        vld1q_s32(vector_sum_col_ptr + x + 0),
                        vld1q_s32(vector_sum_col_ptr + x + 4),
                        vld1q_s32(vector_sum_col_ptr + x + 8),
                        vld1q_s32(vector_sum_col_ptr + x + 12)
                    }
                };

                a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
                a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
                a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
                a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);

                int32x4x4_t in_s32 =
                {
                    {
                        vld1q_s32(mm_result_ptr + x + 0),
                        vld1q_s32(mm_result_ptr + x + 4),
                        vld1q_s32(mm_result_ptr + x + 8),
                        vld1q_s32(mm_result_ptr + x + 12)
                    }
                };

                // Add the offset terms to GEMM's result
                in_s32.val[0] = vaddq_s32(in_s32.val[0], a_offset_term_s32.val[0]);
                in_s32.val[1] = vaddq_s32(in_s32.val[1], a_offset_term_s32.val[1]);
                in_s32.val[2] = vaddq_s32(in_s32.val[2], a_offset_term_s32.val[2]);
                in_s32.val[3] = vaddq_s32(in_s32.val[3], a_offset_term_s32.val[3]);

                // Store the result with the offset contribution
                vst1q_s32(mm_result_ptr + x + 0, in_s32.val[0]);
                vst1q_s32(mm_result_ptr + x + 4, in_s32.val[1]);
                vst1q_s32(mm_result_ptr + x + 8, in_s32.val[2]);
                vst1q_s32(mm_result_ptr + x + 12, in_s32.val[3]);
            }

            // Left-overs loop
            for(; x < window_end_x; ++x)
            {
                // Compute the leftover term due to a_offset.
                const int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x);

                // Add the offset terms to GEMM's result
                // Store the result with the offset contribution
                mm_result_ptr[x] += a_offset_term_s32 * a_offset;
            }
        },
        vector_sum_col_it, mm_result_it);
    }
    else // a_offset == 0, b_offset == 0
    {
        // No offset contribution from matrix A and matrix B
        return;
    }
}
} // namespace

void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result, ITensorInfo *vector_sum_col, ITensorInfo *vector_sum_row, int32_t k, int32_t a_offset, int32_t b_offset)
{
    // Perform validate step
    ARM_COMPUTE_UNUSED(vector_sum_row);
    ARM_COMPUTE_ERROR_ON_NULLPTR(mm_result);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(mm_result, vector_sum_col, vector_sum_row, a_offset, b_offset));

    _a_offset = a_offset;
    _b_offset = b_offset;
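    // Constant term of the offset contribution: each of the k accumulations picks up a_offset * b_offset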
    _k_offset = a_offset * b_offset * k;

    // If a_offset == 0, vector_sum_col can be a nullptr
    if(a_offset != 0)
    {
        // Check whether vector_sum_col_shape should be slid or not
        // Don't slide vector_sum_col_shape along the y dimension if vector_sum_col_shape has just 1 dimension and vector_sum_row_shape more than 1
        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
        _slide_vector_sum_col = vector_sum_col->tensor_shape().num_dimensions() > 1;
    }

    // Configure kernel window
    Window win = calculate_max_window(*mm_result, Steps());
    ICpuKernel::configure(win);
}

Status CpuGemmLowpOffsetContributionKernel::validate(const ITensorInfo *mm_result, const ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row,
                                                     int32_t a_offset, int32_t b_offset)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(mm_result, vector_sum_col, vector_sum_row, a_offset, b_offset));
    return Status{};
}

void CpuGemmLowpOffsetContributionKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);

    auto vector_sum_col = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto vector_sum_row = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto mm_result      = tensors.get_tensor(TensorType::ACL_DST);

    // Check if input is a 3D reinterpretation
    const bool reinterpret_as_3d = vector_sum_row != nullptr
                                   && mm_result->info()->num_dimensions() > 1
                                   && mm_result->info()->tensor_shape().y() != vector_sum_row->info()->tensor_shape().x();
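    // A mismatch between mm_result's height and vector_sum_row's length indicates that
    // mm_result is a 2D view of a 3D (GEMM3D) output, so rows are addressed as
    // id.y() + (id.z() % depth_input) * height_input inside run_offset_contribution().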

    run_offset_contribution(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, _k_offset, _slide_vector_sum_col, reinterpret_as_3d);
}

const char *CpuGemmLowpOffsetContributionKernel::name() const
{
    return "CpuGemmLowpOffsetContributionKernel";
}
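
// Illustrative usage sketch (not part of this file). The pack slots mirror run_op() above;
// the scheduler call follows the usual ACL pattern, and any names not defined in this file
// should be treated as assumptions:
//
//   CpuGemmLowpOffsetContributionKernel kernel;
//   kernel.configure(mm_result.info(), vector_sum_col.info(), vector_sum_row.info(), k, a_offset, b_offset);
//
//   ITensorPack pack;
//   pack.add_const_tensor(TensorType::ACL_SRC_0, &vector_sum_col);
//   pack.add_const_tensor(TensorType::ACL_SRC_1, &vector_sum_row);
//   pack.add_tensor(TensorType::ACL_DST, &mm_result);
//
//   NEScheduler::get().schedule_op(&kernel, Window::DimY, kernel.window(), pack);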
} // namespace kernels
} // namespace cpu
} // namespace arm_compute