xref: /aosp_15_r20/external/gemmlowp/meta/legacy_single_thread_gemm.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han #ifndef GEMMLOWP_META_LEGACY_SINGLE_THREAD_GEMM_H_
16*5f39d1b3SJooyung Han #define GEMMLOWP_META_LEGACY_SINGLE_THREAD_GEMM_H_
17*5f39d1b3SJooyung Han 
18*5f39d1b3SJooyung Han #include "../internal/common.h"
19*5f39d1b3SJooyung Han 
20*5f39d1b3SJooyung Han #ifdef GEMMLOWP_NEON
21*5f39d1b3SJooyung Han 
22*5f39d1b3SJooyung Han #include "quantized_mul_kernels.h"
23*5f39d1b3SJooyung Han #include "single_thread_gemm.h"
24*5f39d1b3SJooyung Han #include "streams.h"
25*5f39d1b3SJooyung Han 
26*5f39d1b3SJooyung Han namespace gemmlowp {
27*5f39d1b3SJooyung Han namespace meta {
28*5f39d1b3SJooyung Han 
gemm_q8_strided(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t result_offset,std::int32_t multiplicative_offset,std::int32_t shift,std::uint8_t * result,std::int32_t result_stride)29*5f39d1b3SJooyung Han void gemm_q8_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
30*5f39d1b3SJooyung Han                      const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
31*5f39d1b3SJooyung Han                      std::int32_t k, std::int32_t lhs_offset,
32*5f39d1b3SJooyung Han                      std::int32_t rhs_offset, std::int32_t result_offset,
33*5f39d1b3SJooyung Han                      std::int32_t multiplicative_offset, std::int32_t shift,
34*5f39d1b3SJooyung Han                      std::uint8_t* result, std::int32_t result_stride) {
35*5f39d1b3SJooyung Han #ifdef DEBUG
36*5f39d1b3SJooyung Han #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
37*5f39d1b3SJooyung Han   std::cout << "Legacy::GemmQ8." << std::endl;
38*5f39d1b3SJooyung Han #endif
39*5f39d1b3SJooyung Han #endif
40*5f39d1b3SJooyung Han   typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum,
41*5f39d1b3SJooyung Han                      RowMajorWithSum, QuantizedStaticPreprocessed, RowMajor>
42*5f39d1b3SJooyung Han       Params;
43*5f39d1b3SJooyung Han   Params params;
44*5f39d1b3SJooyung Han 
45*5f39d1b3SJooyung Han   params.m = m;
46*5f39d1b3SJooyung Han   params.n = n;
47*5f39d1b3SJooyung Han   params.k = k;
48*5f39d1b3SJooyung Han 
49*5f39d1b3SJooyung Han   params.lhs = lhs;
50*5f39d1b3SJooyung Han   params.rhs = rhs;
51*5f39d1b3SJooyung Han   params.result = result;
52*5f39d1b3SJooyung Han   params.scratch = scratch;
53*5f39d1b3SJooyung Han 
54*5f39d1b3SJooyung Han   params.left_stream.count = k;
55*5f39d1b3SJooyung Han   params.left_stream.stride = k;
56*5f39d1b3SJooyung Han   params.left_stream.multiplicative_sum_offset = rhs_offset;
57*5f39d1b3SJooyung Han   params.left_stream.additive_sum_offset =
58*5f39d1b3SJooyung Han       result_offset + k * lhs_offset * rhs_offset;
59*5f39d1b3SJooyung Han 
60*5f39d1b3SJooyung Han   params.right_stream.count = k;
61*5f39d1b3SJooyung Han   params.right_stream.stride = k;
62*5f39d1b3SJooyung Han   params.right_stream.multiplicative_sum_offset = lhs_offset;
63*5f39d1b3SJooyung Han   params.right_stream.additive_sum_offset = 0;
64*5f39d1b3SJooyung Han 
65*5f39d1b3SJooyung Han   params.fused_kernel.kernel.multiplicative_offset = multiplicative_offset;
66*5f39d1b3SJooyung Han   params.fused_kernel.kernel.rounding_offset = (1 << (shift - 1));
67*5f39d1b3SJooyung Han   params.fused_kernel.kernel.shift = -shift;
68*5f39d1b3SJooyung Han   params.fused_kernel.kernel.count = k;
69*5f39d1b3SJooyung Han   params.fused_kernel.output_stream.stride = result_stride;
70*5f39d1b3SJooyung Han 
71*5f39d1b3SJooyung Han   Gemm<GemmExecutorPackRHS, Params, 2, 4, 8>(params);
72*5f39d1b3SJooyung Han }
73*5f39d1b3SJooyung Han 
gemv_q8(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t result_offset,std::int32_t multiplicative_offset,std::int32_t shift,std::uint8_t * result)74*5f39d1b3SJooyung Han void gemv_q8(std::uint8_t* scratch, const std::uint8_t* lhs,
75*5f39d1b3SJooyung Han              const std::uint8_t* rhs, std::int32_t n, std::int32_t k,
76*5f39d1b3SJooyung Han              std::int32_t lhs_offset, std::int32_t rhs_offset,
77*5f39d1b3SJooyung Han              std::int32_t result_offset, std::int32_t multiplicative_offset,
78*5f39d1b3SJooyung Han              std::int32_t shift, std::uint8_t* result) {
79*5f39d1b3SJooyung Han #ifdef DEBUG
80*5f39d1b3SJooyung Han #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
81*5f39d1b3SJooyung Han   std::cout << "Legacy::GemvQ8." << std::endl;
82*5f39d1b3SJooyung Han #endif
83*5f39d1b3SJooyung Han #endif
84*5f39d1b3SJooyung Han   typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum,
85*5f39d1b3SJooyung Han                      RowMajorWithSum, QuantizedStaticPreprocessed, RowMajor>
86*5f39d1b3SJooyung Han       Params;
87*5f39d1b3SJooyung Han   Params params;
88*5f39d1b3SJooyung Han 
89*5f39d1b3SJooyung Han   params.m = 1;
90*5f39d1b3SJooyung Han   params.n = n;
91*5f39d1b3SJooyung Han   params.k = k;
92*5f39d1b3SJooyung Han 
93*5f39d1b3SJooyung Han   params.lhs = lhs;
94*5f39d1b3SJooyung Han   params.rhs = rhs;
95*5f39d1b3SJooyung Han   params.result = result;
96*5f39d1b3SJooyung Han   params.scratch = scratch;
97*5f39d1b3SJooyung Han 
98*5f39d1b3SJooyung Han   params.left_stream.count = k;
99*5f39d1b3SJooyung Han   params.left_stream.stride = k;
100*5f39d1b3SJooyung Han   params.left_stream.multiplicative_sum_offset = rhs_offset;
101*5f39d1b3SJooyung Han   params.left_stream.additive_sum_offset =
102*5f39d1b3SJooyung Han       result_offset + k * lhs_offset * rhs_offset;
103*5f39d1b3SJooyung Han 
104*5f39d1b3SJooyung Han   params.right_stream.count = k;
105*5f39d1b3SJooyung Han   params.right_stream.stride = k;
106*5f39d1b3SJooyung Han   params.right_stream.multiplicative_sum_offset = lhs_offset;
107*5f39d1b3SJooyung Han   params.right_stream.additive_sum_offset = 0;
108*5f39d1b3SJooyung Han 
109*5f39d1b3SJooyung Han   params.fused_kernel.kernel.multiplicative_offset = multiplicative_offset;
110*5f39d1b3SJooyung Han   params.fused_kernel.kernel.rounding_offset = (1 << (shift - 1));
111*5f39d1b3SJooyung Han   params.fused_kernel.kernel.shift = -shift;
112*5f39d1b3SJooyung Han   params.fused_kernel.kernel.count = k;
113*5f39d1b3SJooyung Han   params.fused_kernel.output_stream.stride = n;
114*5f39d1b3SJooyung Han 
115*5f39d1b3SJooyung Han   if (k < 1536) {
116*5f39d1b3SJooyung Han     Gemm<GemmExecutorPackLHS, Params, 1, 8, 8>(params);
117*5f39d1b3SJooyung Han   } else {
118*5f39d1b3SJooyung Han     Gemm<GemmExecutorPackLHS, Params, 2, 4, 8>(params);
119*5f39d1b3SJooyung Han   }
120*5f39d1b3SJooyung Han }
121*5f39d1b3SJooyung Han 
gemm_i32_strided(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t * result,std::int32_t result_stride)122*5f39d1b3SJooyung Han void gemm_i32_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
123*5f39d1b3SJooyung Han                       const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
124*5f39d1b3SJooyung Han                       std::int32_t k, std::int32_t lhs_offset,
125*5f39d1b3SJooyung Han                       std::int32_t rhs_offset, std::int32_t* result,
126*5f39d1b3SJooyung Han                       std::int32_t result_stride) {
127*5f39d1b3SJooyung Han #ifdef DEBUG
128*5f39d1b3SJooyung Han #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
129*5f39d1b3SJooyung Han   std::cout << "Legacy::GemmI32." << std::endl;
130*5f39d1b3SJooyung Han #endif
131*5f39d1b3SJooyung Han #endif
132*5f39d1b3SJooyung Han   typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum,
133*5f39d1b3SJooyung Han                      RowMajorWithSum, QuantizedStaticPreprocessedAsInt32,
134*5f39d1b3SJooyung Han                      RowMajor>
135*5f39d1b3SJooyung Han       Params;
136*5f39d1b3SJooyung Han   Params params;
137*5f39d1b3SJooyung Han 
138*5f39d1b3SJooyung Han   params.m = m;
139*5f39d1b3SJooyung Han   params.n = n;
140*5f39d1b3SJooyung Han   params.k = k;
141*5f39d1b3SJooyung Han 
142*5f39d1b3SJooyung Han   params.lhs = lhs;
143*5f39d1b3SJooyung Han   params.rhs = rhs;
144*5f39d1b3SJooyung Han   params.result = result;
145*5f39d1b3SJooyung Han   params.scratch = scratch;
146*5f39d1b3SJooyung Han 
147*5f39d1b3SJooyung Han   params.left_stream.count = k;
148*5f39d1b3SJooyung Han   params.left_stream.stride = k;
149*5f39d1b3SJooyung Han   params.left_stream.multiplicative_sum_offset = rhs_offset;
150*5f39d1b3SJooyung Han   params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset;
151*5f39d1b3SJooyung Han 
152*5f39d1b3SJooyung Han   params.right_stream.count = k;
153*5f39d1b3SJooyung Han   params.right_stream.stride = k;
154*5f39d1b3SJooyung Han   params.right_stream.multiplicative_sum_offset = lhs_offset;
155*5f39d1b3SJooyung Han   params.right_stream.additive_sum_offset = 0;
156*5f39d1b3SJooyung Han 
157*5f39d1b3SJooyung Han   params.fused_kernel.kernel.count = k;
158*5f39d1b3SJooyung Han   params.fused_kernel.output_stream.stride = result_stride * 4;
159*5f39d1b3SJooyung Han 
160*5f39d1b3SJooyung Han   Gemm<GemmExecutorPackRHS, Params, 2, 4, 8>(params);
161*5f39d1b3SJooyung Han }
162*5f39d1b3SJooyung Han 
gemv_i32(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t * result)163*5f39d1b3SJooyung Han void gemv_i32(std::uint8_t* scratch, const std::uint8_t* lhs,
164*5f39d1b3SJooyung Han               const std::uint8_t* rhs, std::int32_t n, std::int32_t k,
165*5f39d1b3SJooyung Han               std::int32_t lhs_offset, std::int32_t rhs_offset,
166*5f39d1b3SJooyung Han               std::int32_t* result) {
167*5f39d1b3SJooyung Han #ifdef DEBUG
168*5f39d1b3SJooyung Han #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
169*5f39d1b3SJooyung Han   std::cout << "Legacy::GemvI32." << std::endl;
170*5f39d1b3SJooyung Han #endif
171*5f39d1b3SJooyung Han #endif
172*5f39d1b3SJooyung Han   typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum,
173*5f39d1b3SJooyung Han                      RowMajorWithSum, QuantizedStaticPreprocessedAsInt32,
174*5f39d1b3SJooyung Han                      RowMajor>
175*5f39d1b3SJooyung Han       Params;
176*5f39d1b3SJooyung Han   Params params;
177*5f39d1b3SJooyung Han 
178*5f39d1b3SJooyung Han   params.m = 1;
179*5f39d1b3SJooyung Han   params.n = n;
180*5f39d1b3SJooyung Han   params.k = k;
181*5f39d1b3SJooyung Han 
182*5f39d1b3SJooyung Han   params.lhs = lhs;
183*5f39d1b3SJooyung Han   params.rhs = rhs;
184*5f39d1b3SJooyung Han   params.result = result;
185*5f39d1b3SJooyung Han   params.scratch = scratch;
186*5f39d1b3SJooyung Han 
187*5f39d1b3SJooyung Han   params.left_stream.count = k;
188*5f39d1b3SJooyung Han   params.left_stream.stride = k;
189*5f39d1b3SJooyung Han   params.left_stream.multiplicative_sum_offset = rhs_offset;
190*5f39d1b3SJooyung Han   params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset;
191*5f39d1b3SJooyung Han 
192*5f39d1b3SJooyung Han   params.right_stream.count = k;
193*5f39d1b3SJooyung Han   params.right_stream.stride = k;
194*5f39d1b3SJooyung Han   params.right_stream.multiplicative_sum_offset = lhs_offset;
195*5f39d1b3SJooyung Han   params.right_stream.additive_sum_offset = 0;
196*5f39d1b3SJooyung Han 
197*5f39d1b3SJooyung Han   params.fused_kernel.kernel.count = k;
198*5f39d1b3SJooyung Han   params.fused_kernel.output_stream.stride = 0;
199*5f39d1b3SJooyung Han 
200*5f39d1b3SJooyung Han   if (k < 1664) {
201*5f39d1b3SJooyung Han     Gemm<GemmExecutorPackLHS, Params, 1, 8, 8>(params);
202*5f39d1b3SJooyung Han   } else {
203*5f39d1b3SJooyung Han     Gemm<GemmExecutorPackLHS, Params, 1, 6, 8>(params);
204*5f39d1b3SJooyung Han   }
205*5f39d1b3SJooyung Han }
206*5f39d1b3SJooyung Han 
gemm_f_strided(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,float * result,std::int32_t result_stride)207*5f39d1b3SJooyung Han void gemm_f_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
208*5f39d1b3SJooyung Han                     const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
209*5f39d1b3SJooyung Han                     std::int32_t k, std::int32_t lhs_offset,
210*5f39d1b3SJooyung Han                     std::int32_t rhs_offset, float result_offset, float* result,
211*5f39d1b3SJooyung Han                     std::int32_t result_stride) {
212*5f39d1b3SJooyung Han #ifdef DEBUG
213*5f39d1b3SJooyung Han #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
214*5f39d1b3SJooyung Han   std::cout << "Legacy::GemmF." << std::endl;
215*5f39d1b3SJooyung Han #endif
216*5f39d1b3SJooyung Han #endif
217*5f39d1b3SJooyung Han   typedef GemmParams<std::uint8_t, float, RowMajorWithSum, RowMajorWithSum,
218*5f39d1b3SJooyung Han                      QuantizedStaticPreprocessedAsFloat, RowMajor>
219*5f39d1b3SJooyung Han       Params;
220*5f39d1b3SJooyung Han   Params params;
221*5f39d1b3SJooyung Han 
222*5f39d1b3SJooyung Han   params.m = m;
223*5f39d1b3SJooyung Han   params.n = n;
224*5f39d1b3SJooyung Han   params.k = k;
225*5f39d1b3SJooyung Han 
226*5f39d1b3SJooyung Han   params.lhs = lhs;
227*5f39d1b3SJooyung Han   params.rhs = rhs;
228*5f39d1b3SJooyung Han   params.result = result;
229*5f39d1b3SJooyung Han   params.scratch = scratch;
230*5f39d1b3SJooyung Han 
231*5f39d1b3SJooyung Han   params.left_stream.count = k;
232*5f39d1b3SJooyung Han   params.left_stream.stride = k;
233*5f39d1b3SJooyung Han   params.left_stream.multiplicative_sum_offset = rhs_offset;
234*5f39d1b3SJooyung Han   params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset;
235*5f39d1b3SJooyung Han 
236*5f39d1b3SJooyung Han   params.right_stream.count = k;
237*5f39d1b3SJooyung Han   params.right_stream.stride = k;
238*5f39d1b3SJooyung Han   params.right_stream.multiplicative_sum_offset = lhs_offset;
239*5f39d1b3SJooyung Han   params.right_stream.additive_sum_offset = 0;
240*5f39d1b3SJooyung Han 
241*5f39d1b3SJooyung Han   params.fused_kernel.kernel.count = k;
242*5f39d1b3SJooyung Han   params.fused_kernel.kernel.scale = result_offset;
243*5f39d1b3SJooyung Han   params.fused_kernel.output_stream.stride = result_stride * 4;
244*5f39d1b3SJooyung Han 
245*5f39d1b3SJooyung Han   Gemm<GemmExecutorPackRHS, Params, 2, 4, 8>(params);
246*5f39d1b3SJooyung Han }
247*5f39d1b3SJooyung Han 
gemv_f(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,float * result)248*5f39d1b3SJooyung Han void gemv_f(std::uint8_t* scratch, const std::uint8_t* lhs,
249*5f39d1b3SJooyung Han             const std::uint8_t* rhs, std::int32_t n, std::int32_t k,
250*5f39d1b3SJooyung Han             std::int32_t lhs_offset, std::int32_t rhs_offset,
251*5f39d1b3SJooyung Han             float result_offset, float* result) {
252*5f39d1b3SJooyung Han #ifdef DEBUG
253*5f39d1b3SJooyung Han #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
254*5f39d1b3SJooyung Han   std::cout << "Legacy::GemvF." << std::endl;
255*5f39d1b3SJooyung Han #endif
256*5f39d1b3SJooyung Han #endif
257*5f39d1b3SJooyung Han   typedef GemmParams<std::uint8_t, float, RowMajorWithSum, RowMajorWithSum,
258*5f39d1b3SJooyung Han                      QuantizedStaticPreprocessedAsFloat, RowMajor>
259*5f39d1b3SJooyung Han       Params;
260*5f39d1b3SJooyung Han   Params params;
261*5f39d1b3SJooyung Han 
262*5f39d1b3SJooyung Han   params.m = 1;
263*5f39d1b3SJooyung Han   params.n = n;
264*5f39d1b3SJooyung Han   params.k = k;
265*5f39d1b3SJooyung Han 
266*5f39d1b3SJooyung Han   params.lhs = lhs;
267*5f39d1b3SJooyung Han   params.rhs = rhs;
268*5f39d1b3SJooyung Han   params.result = result;
269*5f39d1b3SJooyung Han   params.scratch = scratch;
270*5f39d1b3SJooyung Han 
271*5f39d1b3SJooyung Han   params.left_stream.count = k;
272*5f39d1b3SJooyung Han   params.left_stream.stride = k;
273*5f39d1b3SJooyung Han   params.left_stream.multiplicative_sum_offset = rhs_offset;
274*5f39d1b3SJooyung Han   params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset;
275*5f39d1b3SJooyung Han 
276*5f39d1b3SJooyung Han   params.right_stream.count = k;
277*5f39d1b3SJooyung Han   params.right_stream.stride = k;
278*5f39d1b3SJooyung Han   params.right_stream.multiplicative_sum_offset = lhs_offset;
279*5f39d1b3SJooyung Han   params.right_stream.additive_sum_offset = 0;
280*5f39d1b3SJooyung Han 
281*5f39d1b3SJooyung Han   params.fused_kernel.kernel.count = k;
282*5f39d1b3SJooyung Han   params.fused_kernel.kernel.scale = result_offset;
283*5f39d1b3SJooyung Han   params.fused_kernel.output_stream.stride = 0;
284*5f39d1b3SJooyung Han 
285*5f39d1b3SJooyung Han   if (k < 1664) {
286*5f39d1b3SJooyung Han     Gemm<GemmExecutorPackLHS, Params, 1, 8, 8>(params);
287*5f39d1b3SJooyung Han   } else {
288*5f39d1b3SJooyung Han     Gemm<GemmExecutorPackLHS, Params, 1, 6, 8>(params);
289*5f39d1b3SJooyung Han   }
290*5f39d1b3SJooyung Han }
291*5f39d1b3SJooyung Han 
292*5f39d1b3SJooyung Han }  // namespace meta
293*5f39d1b3SJooyung Han }  // namespace gemmlowp
294*5f39d1b3SJooyung Han 
295*5f39d1b3SJooyung Han #else
296*5f39d1b3SJooyung Han #warning "Meta gemm fast-path requires GEMMLOWP_NEON_(32|64)!"
297*5f39d1b3SJooyung Han #endif
298*5f39d1b3SJooyung Han 
299*5f39d1b3SJooyung Han #endif  // GEMMLOWP_META_LEGACY_SINGLE_THREAD_GEMM_H_
300