1*c217d954SCole Faust /*
2*c217d954SCole Faust * Copyright (c) 2017-2022 Arm Limited.
3*c217d954SCole Faust *
4*c217d954SCole Faust * SPDX-License-Identifier: MIT
5*c217d954SCole Faust *
6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust *
13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust * copies or substantial portions of the Software.
15*c217d954SCole Faust *
16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust * SOFTWARE.
23*c217d954SCole Faust */
24*c217d954SCole Faust #include "src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h"
25*c217d954SCole Faust
26*c217d954SCole Faust #include "arm_compute/core/Error.h"
27*c217d954SCole Faust #include "arm_compute/core/Helpers.h"
28*c217d954SCole Faust #include "arm_compute/core/ITensor.h"
29*c217d954SCole Faust #include "arm_compute/core/TensorInfo.h"
30*c217d954SCole Faust #include "arm_compute/core/Types.h"
31*c217d954SCole Faust #include "arm_compute/core/Utils.h"
32*c217d954SCole Faust #include "arm_compute/core/Validate.h"
33*c217d954SCole Faust #include "arm_compute/core/Window.h"
34*c217d954SCole Faust #include "src/core/helpers/AutoConfiguration.h"
35*c217d954SCole Faust #include "src/core/helpers/WindowHelpers.h"
36*c217d954SCole Faust
37*c217d954SCole Faust #include <arm_neon.h>
38*c217d954SCole Faust
39*c217d954SCole Faust namespace arm_compute
40*c217d954SCole Faust {
41*c217d954SCole Faust namespace cpu
42*c217d954SCole Faust {
43*c217d954SCole Faust namespace kernels
44*c217d954SCole Faust {
45*c217d954SCole Faust namespace
46*c217d954SCole Faust {
validate_arguments(const ITensorInfo * mm_result,const ITensorInfo * vector_sum_col,const ITensorInfo * vector_sum_row,int32_t a_offset,int32_t b_offset)47*c217d954SCole Faust Status validate_arguments(const ITensorInfo *mm_result, const ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row,
48*c217d954SCole Faust int32_t a_offset, int32_t b_offset)
49*c217d954SCole Faust {
50*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32);
51*c217d954SCole Faust
52*c217d954SCole Faust // If a_offset == 0, vector_sum_col can be a nullptr
53*c217d954SCole Faust if(a_offset != 0)
54*c217d954SCole Faust {
55*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_col, 1, DataType::S32);
56*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_col->dimension(0) != mm_result->dimension(0));
57*c217d954SCole Faust }
58*c217d954SCole Faust
59*c217d954SCole Faust // If b_offset == 0, vector_sum_row can be a nullptr
60*c217d954SCole Faust if(b_offset != 0)
61*c217d954SCole Faust {
62*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_row, 1, DataType::S32);
63*c217d954SCole Faust
64*c217d954SCole Faust // Check if input is a 3D reinterpretation
65*c217d954SCole Faust const bool reinterpret_as_3d = mm_result->num_dimensions() > 1 && mm_result->tensor_shape().y() != vector_sum_row->tensor_shape().x();
66*c217d954SCole Faust
67*c217d954SCole Faust // Validate input
68*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(reinterpret_as_3d && vector_sum_row->dimension(0) != (mm_result->dimension(1) * mm_result->dimension(2)));
69*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(!reinterpret_as_3d && vector_sum_row->dimension(0) != mm_result->dimension(1));
70*c217d954SCole Faust
71*c217d954SCole Faust TensorShape output_shape = mm_result->tensor_shape();
72*c217d954SCole Faust if(output_shape.num_dimensions() > 1)
73*c217d954SCole Faust {
74*c217d954SCole Faust const unsigned int output_batch_idx = reinterpret_as_3d ? 3 : 2;
75*c217d954SCole Faust
76*c217d954SCole Faust TensorShape vector_sum_row_shape = vector_sum_row->tensor_shape();
77*c217d954SCole Faust vector_sum_row_shape.collapse_from(1);
78*c217d954SCole Faust output_shape.collapse_from(output_batch_idx);
79*c217d954SCole Faust
80*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_row_shape[1] != output_shape[output_batch_idx],
81*c217d954SCole Faust "mm_result tensor must have the same number of batches of output tensor");
82*c217d954SCole Faust
83*c217d954SCole Faust if(a_offset != 0)
84*c217d954SCole Faust {
85*c217d954SCole Faust TensorShape vector_sum_col_shape = vector_sum_col->tensor_shape();
86*c217d954SCole Faust vector_sum_col_shape.collapse_from(1);
87*c217d954SCole Faust
88*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_col_shape[1] != 1 && vector_sum_col_shape[1] != vector_sum_row_shape[1],
89*c217d954SCole Faust "vector_sum_col tensor must have the same number of batches of vector_sum_row_shape or the number of batches must be set to 1");
90*c217d954SCole Faust }
91*c217d954SCole Faust }
92*c217d954SCole Faust }
93*c217d954SCole Faust
94*c217d954SCole Faust return Status{};
95*c217d954SCole Faust }
96*c217d954SCole Faust
run_offset_contribution(const Window & window,ITensor * mm_result,const ITensor * vector_sum_col,const ITensor * vector_sum_row,int32_t a_offset,int32_t b_offset,int32_t k_offset,bool slide_vector_sum_col,bool is_gemm3d)97*c217d954SCole Faust void run_offset_contribution(const Window &window,
98*c217d954SCole Faust ITensor *mm_result, const ITensor *vector_sum_col, const ITensor *vector_sum_row,
99*c217d954SCole Faust int32_t a_offset, int32_t b_offset, int32_t k_offset, bool slide_vector_sum_col, bool is_gemm3d)
100*c217d954SCole Faust {
101*c217d954SCole Faust Window collapsed_window = window.collapse_if_possible(window, Window::DimZ);
102*c217d954SCole Faust collapsed_window.set(Window::DimX, Window::Dimension(0, 1, 1));
103*c217d954SCole Faust
104*c217d954SCole Faust const int height_input = is_gemm3d ? mm_result->info()->dimension(1) : 0;
105*c217d954SCole Faust const int depth_input = is_gemm3d ? mm_result->info()->dimension(2) : 1;
106*c217d954SCole Faust
107*c217d954SCole Faust const int window_start_x = window.x().start();
108*c217d954SCole Faust const int window_end_x = window.x().end();
109*c217d954SCole Faust const int window_step_x = 16;
110*c217d954SCole Faust
111*c217d954SCole Faust // if vector_sum_col is nullptr then stride_y is 0, else get stride_y
112*c217d954SCole Faust const size_t sum_col_stride_y = (vector_sum_col != nullptr) ? (vector_sum_col->info()->strides_in_bytes().y()) : 0;
113*c217d954SCole Faust Iterator mm_result_it(mm_result, collapsed_window);
114*c217d954SCole Faust
115*c217d954SCole Faust if((a_offset != 0) && (b_offset != 0) && (vector_sum_col != nullptr) && (vector_sum_row != nullptr)) // true, true
116*c217d954SCole Faust {
117*c217d954SCole Faust // Set window for vector_sum_col
118*c217d954SCole Faust Window win_vector_sum_col(collapsed_window);
119*c217d954SCole Faust win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
120*c217d954SCole Faust win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));
121*c217d954SCole Faust
122*c217d954SCole Faust // Set window for vector_sum_row
123*c217d954SCole Faust Window win_vector_sum_row(collapsed_window);
124*c217d954SCole Faust win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
125*c217d954SCole Faust win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
126*c217d954SCole Faust win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));
127*c217d954SCole Faust
128*c217d954SCole Faust Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col);
129*c217d954SCole Faust Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);
130*c217d954SCole Faust
131*c217d954SCole Faust const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();
132*c217d954SCole Faust
133*c217d954SCole Faust // Offset in case vector_sum_col is batched
134*c217d954SCole Faust const int vector_sum_col_batch_offset = slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
135*c217d954SCole Faust
136*c217d954SCole Faust execute_window_loop(collapsed_window, [&](const Coordinates & id)
137*c217d954SCole Faust {
138*c217d954SCole Faust const int batch_id = id.z() / depth_input;
139*c217d954SCole Faust const size_t batch_offset_col = batch_id * (sum_col_stride_y );
140*c217d954SCole Faust auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_offset_col + batch_id * vector_sum_col_batch_offset);
141*c217d954SCole Faust auto mm_result_ptr = reinterpret_cast<int32_t *>(mm_result_it.ptr());
142*c217d954SCole Faust
143*c217d954SCole Faust // Compute the leftover term due to b_offset.
144*c217d954SCole Faust int32_t b_offset_term_s32 = *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) + id.y() + (id.z() % depth_input) * height_input);
145*c217d954SCole Faust b_offset_term_s32 *= b_offset;
146*c217d954SCole Faust
147*c217d954SCole Faust const int32x4_t b_offset_term_s32_vec = vdupq_n_s32(b_offset_term_s32);
148*c217d954SCole Faust
149*c217d954SCole Faust int x = window_start_x;
150*c217d954SCole Faust for(; x <= (window_end_x - window_step_x); x += window_step_x)
151*c217d954SCole Faust {
152*c217d954SCole Faust // Compute the leftover term due to a_offset.
153*c217d954SCole Faust int32x4x4_t a_offset_term_s32 =
154*c217d954SCole Faust {
155*c217d954SCole Faust {
156*c217d954SCole Faust vld1q_s32(vector_sum_col_ptr + x + 0),
157*c217d954SCole Faust vld1q_s32(vector_sum_col_ptr + x + 4),
158*c217d954SCole Faust vld1q_s32(vector_sum_col_ptr + x + 8),
159*c217d954SCole Faust vld1q_s32(vector_sum_col_ptr + x + 12)
160*c217d954SCole Faust }
161*c217d954SCole Faust };
162*c217d954SCole Faust
163*c217d954SCole Faust a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
164*c217d954SCole Faust a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
165*c217d954SCole Faust a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
166*c217d954SCole Faust a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);
167*c217d954SCole Faust
168*c217d954SCole Faust // Add a_offset_term_s32 and b_offset_term_s32
169*c217d954SCole Faust int32x4x4_t offset_term_s32 =
170*c217d954SCole Faust {
171*c217d954SCole Faust {
172*c217d954SCole Faust vdupq_n_s32(k_offset),
173*c217d954SCole Faust vdupq_n_s32(k_offset),
174*c217d954SCole Faust vdupq_n_s32(k_offset),
175*c217d954SCole Faust vdupq_n_s32(k_offset)
176*c217d954SCole Faust }
177*c217d954SCole Faust };
178*c217d954SCole Faust
179*c217d954SCole Faust offset_term_s32.val[0] = vaddq_s32(offset_term_s32.val[0], vaddq_s32(a_offset_term_s32.val[0], b_offset_term_s32_vec));
180*c217d954SCole Faust offset_term_s32.val[1] = vaddq_s32(offset_term_s32.val[1], vaddq_s32(a_offset_term_s32.val[1], b_offset_term_s32_vec));
181*c217d954SCole Faust offset_term_s32.val[2] = vaddq_s32(offset_term_s32.val[2], vaddq_s32(a_offset_term_s32.val[2], b_offset_term_s32_vec));
182*c217d954SCole Faust offset_term_s32.val[3] = vaddq_s32(offset_term_s32.val[3], vaddq_s32(a_offset_term_s32.val[3], b_offset_term_s32_vec));
183*c217d954SCole Faust
184*c217d954SCole Faust int32x4x4_t in_s32 =
185*c217d954SCole Faust {
186*c217d954SCole Faust {
187*c217d954SCole Faust vld1q_s32(mm_result_ptr + x + 0),
188*c217d954SCole Faust vld1q_s32(mm_result_ptr + x + 4),
189*c217d954SCole Faust vld1q_s32(mm_result_ptr + x + 8),
190*c217d954SCole Faust vld1q_s32(mm_result_ptr + x + 12)
191*c217d954SCole Faust }
192*c217d954SCole Faust };
193*c217d954SCole Faust
194*c217d954SCole Faust // Add the offset terms to GEMM's result
195*c217d954SCole Faust in_s32.val[0] = vaddq_s32(in_s32.val[0], offset_term_s32.val[0]);
196*c217d954SCole Faust in_s32.val[1] = vaddq_s32(in_s32.val[1], offset_term_s32.val[1]);
197*c217d954SCole Faust in_s32.val[2] = vaddq_s32(in_s32.val[2], offset_term_s32.val[2]);
198*c217d954SCole Faust in_s32.val[3] = vaddq_s32(in_s32.val[3], offset_term_s32.val[3]);
199*c217d954SCole Faust
200*c217d954SCole Faust // Store the result with the offset contribution
201*c217d954SCole Faust vst1q_s32(mm_result_ptr + x + 0, in_s32.val[0]);
202*c217d954SCole Faust vst1q_s32(mm_result_ptr + x + 4, in_s32.val[1]);
203*c217d954SCole Faust vst1q_s32(mm_result_ptr + x + 8, in_s32.val[2]);
204*c217d954SCole Faust vst1q_s32(mm_result_ptr + x + 12, in_s32.val[3]);
205*c217d954SCole Faust }
206*c217d954SCole Faust
207*c217d954SCole Faust // Left-overs loop
208*c217d954SCole Faust for(; x < window_end_x; ++x)
209*c217d954SCole Faust {
210*c217d954SCole Faust // Compute the leftover term due to a_offset.
211*c217d954SCole Faust int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x);
212*c217d954SCole Faust
213*c217d954SCole Faust a_offset_term_s32 *= a_offset;
214*c217d954SCole Faust
215*c217d954SCole Faust // Add the offset terms to GEMM's result
216*c217d954SCole Faust // Store the result with the offset contribution
217*c217d954SCole Faust mm_result_ptr[x] += k_offset + a_offset_term_s32 + b_offset_term_s32;
218*c217d954SCole Faust }
219*c217d954SCole Faust },
220*c217d954SCole Faust vector_sum_col_it, vector_sum_row_it, mm_result_it);
221*c217d954SCole Faust }
222*c217d954SCole Faust else if((a_offset == 0) && (b_offset != 0) && (vector_sum_row != nullptr)) // false, true
223*c217d954SCole Faust {
224*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_NULLPTR(vector_sum_row);
225*c217d954SCole Faust
226*c217d954SCole Faust // Set window for vector_sum_row
227*c217d954SCole Faust Window win_vector_sum_row(collapsed_window);
228*c217d954SCole Faust win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
229*c217d954SCole Faust win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
230*c217d954SCole Faust win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));
231*c217d954SCole Faust
232*c217d954SCole Faust Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);
233*c217d954SCole Faust
234*c217d954SCole Faust const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();
235*c217d954SCole Faust
236*c217d954SCole Faust execute_window_loop(collapsed_window, [&](const Coordinates & id)
237*c217d954SCole Faust {
238*c217d954SCole Faust const int batch_id = id.z() / depth_input;
239*c217d954SCole Faust auto mm_result_ptr = reinterpret_cast<int32_t *>(mm_result_it.ptr());
240*c217d954SCole Faust
241*c217d954SCole Faust // Compute the leftover term due to b_offset.
242*c217d954SCole Faust int32_t b_offset_term_s32 = *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) + id.y() + (id.z() % depth_input) * height_input);
243*c217d954SCole Faust b_offset_term_s32 *= b_offset;
244*c217d954SCole Faust
245*c217d954SCole Faust const int32x4_t b_offset_term_s32_vec = vdupq_n_s32(b_offset_term_s32);
246*c217d954SCole Faust
247*c217d954SCole Faust int x = window_start_x;
248*c217d954SCole Faust for(; x <= (window_end_x - window_step_x); x += window_step_x)
249*c217d954SCole Faust {
250*c217d954SCole Faust int32x4x4_t in_s32 =
251*c217d954SCole Faust {
252*c217d954SCole Faust {
253*c217d954SCole Faust vld1q_s32(mm_result_ptr + x + 0),
254*c217d954SCole Faust vld1q_s32(mm_result_ptr + x + 4),
255*c217d954SCole Faust vld1q_s32(mm_result_ptr + x + 8),
256*c217d954SCole Faust vld1q_s32(mm_result_ptr + x + 12)
257*c217d954SCole Faust }
258*c217d954SCole Faust };
259*c217d954SCole Faust
260*c217d954SCole Faust // Add the offset terms to GEMM's result
261*c217d954SCole Faust in_s32.val[0] = vaddq_s32(in_s32.val[0], b_offset_term_s32_vec);
262*c217d954SCole Faust in_s32.val[1] = vaddq_s32(in_s32.val[1], b_offset_term_s32_vec);
263*c217d954SCole Faust in_s32.val[2] = vaddq_s32(in_s32.val[2], b_offset_term_s32_vec);
264*c217d954SCole Faust in_s32.val[3] = vaddq_s32(in_s32.val[3], b_offset_term_s32_vec);
265*c217d954SCole Faust
266*c217d954SCole Faust // Store the result with the offset contribution
267*c217d954SCole Faust vst1q_s32(mm_result_ptr + x + 0, in_s32.val[0]);
268*c217d954SCole Faust vst1q_s32(mm_result_ptr + x + 4, in_s32.val[1]);
269*c217d954SCole Faust vst1q_s32(mm_result_ptr + x + 8, in_s32.val[2]);
270*c217d954SCole Faust vst1q_s32(mm_result_ptr + x + 12, in_s32.val[3]);
271*c217d954SCole Faust }
272*c217d954SCole Faust
273*c217d954SCole Faust // Left-overs loop
274*c217d954SCole Faust for(; x < window_end_x; ++x)
275*c217d954SCole Faust {
276*c217d954SCole Faust // Add the offset terms to GEMM's result
277*c217d954SCole Faust // Store the result with the offset contribution
278*c217d954SCole Faust mm_result_ptr[x] += b_offset_term_s32;
279*c217d954SCole Faust }
280*c217d954SCole Faust },
281*c217d954SCole Faust vector_sum_row_it, mm_result_it);
282*c217d954SCole Faust }
283*c217d954SCole Faust else if((a_offset != 0) && (b_offset == 0) && (vector_sum_col != nullptr)) // true, false
284*c217d954SCole Faust {
285*c217d954SCole Faust // Set window for vector_sum_col
286*c217d954SCole Faust Window win_vector_sum_col(collapsed_window);
287*c217d954SCole Faust win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
288*c217d954SCole Faust win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));
289*c217d954SCole Faust
290*c217d954SCole Faust Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col);
291*c217d954SCole Faust
292*c217d954SCole Faust // Offset in case vector_sum_col is batched
293*c217d954SCole Faust const int vector_sum_col_batch_offset = slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
294*c217d954SCole Faust
295*c217d954SCole Faust execute_window_loop(collapsed_window, [&](const Coordinates & id)
296*c217d954SCole Faust {
297*c217d954SCole Faust const int batch_id = id.z() / depth_input;
298*c217d954SCole Faust const size_t batch_offset_col = batch_id * (sum_col_stride_y ); // Value to offset vector_sum_col_ptr to allow for iteration of y values in tensor
299*c217d954SCole Faust auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_offset_col + batch_id * vector_sum_col_batch_offset);
300*c217d954SCole Faust auto mm_result_ptr = reinterpret_cast<int32_t *>(mm_result_it.ptr());
301*c217d954SCole Faust
302*c217d954SCole Faust int x = window_start_x;
303*c217d954SCole Faust for(; x <= (window_end_x - window_step_x); x += window_step_x)
304*c217d954SCole Faust {
305*c217d954SCole Faust // Compute the leftover term due to a_offset.
306*c217d954SCole Faust int32x4x4_t a_offset_term_s32 =
307*c217d954SCole Faust {
308*c217d954SCole Faust {
309*c217d954SCole Faust vld1q_s32(vector_sum_col_ptr + x + 0),
310*c217d954SCole Faust vld1q_s32(vector_sum_col_ptr + x + 4),
311*c217d954SCole Faust vld1q_s32(vector_sum_col_ptr + x + 8),
312*c217d954SCole Faust vld1q_s32(vector_sum_col_ptr + x + 12)
313*c217d954SCole Faust }
314*c217d954SCole Faust };
315*c217d954SCole Faust
316*c217d954SCole Faust a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
317*c217d954SCole Faust a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
318*c217d954SCole Faust a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
319*c217d954SCole Faust a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);
320*c217d954SCole Faust
321*c217d954SCole Faust int32x4x4_t in_s32 =
322*c217d954SCole Faust {
323*c217d954SCole Faust {
324*c217d954SCole Faust vld1q_s32(mm_result_ptr + x + 0),
325*c217d954SCole Faust vld1q_s32(mm_result_ptr + x + 4),
326*c217d954SCole Faust vld1q_s32(mm_result_ptr + x + 8),
327*c217d954SCole Faust vld1q_s32(mm_result_ptr + x + 12)
328*c217d954SCole Faust }
329*c217d954SCole Faust };
330*c217d954SCole Faust
331*c217d954SCole Faust // Add the offset terms to GEMM's result
332*c217d954SCole Faust in_s32.val[0] = vaddq_s32(in_s32.val[0], a_offset_term_s32.val[0]);
333*c217d954SCole Faust in_s32.val[1] = vaddq_s32(in_s32.val[1], a_offset_term_s32.val[1]);
334*c217d954SCole Faust in_s32.val[2] = vaddq_s32(in_s32.val[2], a_offset_term_s32.val[2]);
335*c217d954SCole Faust in_s32.val[3] = vaddq_s32(in_s32.val[3], a_offset_term_s32.val[3]);
336*c217d954SCole Faust
337*c217d954SCole Faust // Store the result with the offset contribution
338*c217d954SCole Faust vst1q_s32(mm_result_ptr + x + 0, in_s32.val[0]);
339*c217d954SCole Faust vst1q_s32(mm_result_ptr + x + 4, in_s32.val[1]);
340*c217d954SCole Faust vst1q_s32(mm_result_ptr + x + 8, in_s32.val[2]);
341*c217d954SCole Faust vst1q_s32(mm_result_ptr + x + 12, in_s32.val[3]);
342*c217d954SCole Faust }
343*c217d954SCole Faust
344*c217d954SCole Faust // Left-overs loop
345*c217d954SCole Faust for(; x < window_end_x; ++x)
346*c217d954SCole Faust {
347*c217d954SCole Faust // Compute the leftover term due to a_offset.
348*c217d954SCole Faust const int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x);
349*c217d954SCole Faust
350*c217d954SCole Faust // Add the offset terms to GEMM's result
351*c217d954SCole Faust // Store the result with the offset contribution
352*c217d954SCole Faust mm_result_ptr[x] += a_offset_term_s32 * a_offset;
353*c217d954SCole Faust }
354*c217d954SCole Faust },
355*c217d954SCole Faust vector_sum_col_it, mm_result_it);
356*c217d954SCole Faust }
357*c217d954SCole Faust else // false, false
358*c217d954SCole Faust {
359*c217d954SCole Faust // No offset contribution from matrix A and matrix B
360*c217d954SCole Faust return;
361*c217d954SCole Faust }
362*c217d954SCole Faust }
363*c217d954SCole Faust } // namespace
364*c217d954SCole Faust
configure(ITensorInfo * mm_result,ITensorInfo * vector_sum_col,ITensorInfo * vector_sum_row,int32_t k,int32_t a_offset,int32_t b_offset)365*c217d954SCole Faust void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result, ITensorInfo *vector_sum_col, ITensorInfo *vector_sum_row, int32_t k, int32_t a_offset, int32_t b_offset)
366*c217d954SCole Faust {
367*c217d954SCole Faust // Perform validate step
368*c217d954SCole Faust ARM_COMPUTE_UNUSED(vector_sum_row);
369*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_NULLPTR(mm_result);
370*c217d954SCole Faust ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(mm_result, vector_sum_col, vector_sum_row, a_offset, b_offset));
371*c217d954SCole Faust
372*c217d954SCole Faust _a_offset = a_offset;
373*c217d954SCole Faust _b_offset = b_offset;
374*c217d954SCole Faust _k_offset = a_offset * b_offset * k;
375*c217d954SCole Faust
376*c217d954SCole Faust // If a_offset == 0, vector_sum_col can be a nullptr
377*c217d954SCole Faust if(a_offset != 0)
378*c217d954SCole Faust {
379*c217d954SCole Faust // Check if vector_sum_col_shape should be slidden or not
380*c217d954SCole Faust // Don't slide vector_sum_col_shape along the y dimension if vector_sum_col_shape has just 1 dimension and vector_sum_row_shape more than 1
381*c217d954SCole Faust // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
382*c217d954SCole Faust _slide_vector_sum_col = vector_sum_col->tensor_shape().num_dimensions() > 1;
383*c217d954SCole Faust }
384*c217d954SCole Faust
385*c217d954SCole Faust // Configure kernel window
386*c217d954SCole Faust Window win = calculate_max_window(*mm_result, Steps());
387*c217d954SCole Faust ICpuKernel::configure(win);
388*c217d954SCole Faust }
389*c217d954SCole Faust
validate(const ITensorInfo * mm_result,const ITensorInfo * vector_sum_col,const ITensorInfo * vector_sum_row,int32_t a_offset,int32_t b_offset)390*c217d954SCole Faust Status CpuGemmLowpOffsetContributionKernel::validate(const ITensorInfo *mm_result, const ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row,
391*c217d954SCole Faust int32_t a_offset, int32_t b_offset)
392*c217d954SCole Faust {
393*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(mm_result, vector_sum_col, vector_sum_row, a_offset, b_offset));
394*c217d954SCole Faust return Status{};
395*c217d954SCole Faust }
396*c217d954SCole Faust
run_op(ITensorPack & tensors,const Window & window,const ThreadInfo & info)397*c217d954SCole Faust void CpuGemmLowpOffsetContributionKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
398*c217d954SCole Faust {
399*c217d954SCole Faust ARM_COMPUTE_UNUSED(info);
400*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
401*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
402*c217d954SCole Faust
403*c217d954SCole Faust auto vector_sum_col = tensors.get_const_tensor(TensorType::ACL_SRC_0);
404*c217d954SCole Faust auto vector_sum_row = tensors.get_const_tensor(TensorType::ACL_SRC_1);
405*c217d954SCole Faust auto mm_result = tensors.get_tensor(TensorType::ACL_DST);
406*c217d954SCole Faust
407*c217d954SCole Faust // Check if input is a 3D reinterpretation
408*c217d954SCole Faust const bool reinterpret_as_3d = vector_sum_row != nullptr
409*c217d954SCole Faust && mm_result->info()->num_dimensions() > 1
410*c217d954SCole Faust && mm_result->info()->tensor_shape().y() != vector_sum_row->info()->tensor_shape().x();
411*c217d954SCole Faust
412*c217d954SCole Faust run_offset_contribution(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, _k_offset, _slide_vector_sum_col, reinterpret_as_3d);
413*c217d954SCole Faust }
414*c217d954SCole Faust
name() const415*c217d954SCole Faust const char *CpuGemmLowpOffsetContributionKernel::name() const
416*c217d954SCole Faust {
417*c217d954SCole Faust return "CpuGemmLowpOffsetContributionKernel";
418*c217d954SCole Faust }
419*c217d954SCole Faust } // namespace kernels
420*c217d954SCole Faust } // namespace cpu
421*c217d954SCole Faust } // namespace arm_compute