xref: /aosp_15_r20/external/gemmlowp/meta/legacy_multi_thread_gemv.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
// legacy_multi_thread_gemv.h: Entry point to the multithreaded version of the
// generated (meta) gemv library.
17*5f39d1b3SJooyung Han 
18*5f39d1b3SJooyung Han #ifndef GEMMLOWP_META_MULTI_THREAD_GEMV_H_
19*5f39d1b3SJooyung Han #define GEMMLOWP_META_MULTI_THREAD_GEMV_H_
20*5f39d1b3SJooyung Han 
21*5f39d1b3SJooyung Han #ifdef GEMMLOWP_NEON
22*5f39d1b3SJooyung Han 
23*5f39d1b3SJooyung Han #include "legacy_multi_thread_common.h"
24*5f39d1b3SJooyung Han #include "legacy_operations_common.h"
25*5f39d1b3SJooyung Han #include "legacy_single_thread_gemm.h"
26*5f39d1b3SJooyung Han 
27*5f39d1b3SJooyung Han namespace gemmlowp {
28*5f39d1b3SJooyung Han namespace meta {
29*5f39d1b3SJooyung Han namespace internal {
30*5f39d1b3SJooyung Han 
31*5f39d1b3SJooyung Han class GemvQuantized8BitOperation : public Quantized8BitOperation {
32*5f39d1b3SJooyung Han  public:
GemvQuantized8BitOperation(std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t multiplier,std::int32_t shift)33*5f39d1b3SJooyung Han   GemvQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
34*5f39d1b3SJooyung Han                              std::int32_t sum_offset, std::int32_t multiplier,
35*5f39d1b3SJooyung Han                              std::int32_t shift)
36*5f39d1b3SJooyung Han       : Quantized8BitOperation(lhs_offset, rhs_offset, sum_offset, multiplier,
37*5f39d1b3SJooyung Han                                shift) {}
38*5f39d1b3SJooyung Han 
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::uint8_t * result,std::int32_t result_stride)39*5f39d1b3SJooyung Han   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
40*5f39d1b3SJooyung Han                            const std::uint8_t* rhs, std::int32_t m,
41*5f39d1b3SJooyung Han                            std::int32_t n, std::int32_t k, std::uint8_t* result,
42*5f39d1b3SJooyung Han                            std::int32_t result_stride) const {
43*5f39d1b3SJooyung Han     gemv_q8(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, sum_offset,
44*5f39d1b3SJooyung Han             multiplier, shift, result);
45*5f39d1b3SJooyung Han   }
46*5f39d1b3SJooyung Han 
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)47*5f39d1b3SJooyung Han   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
48*5f39d1b3SJooyung Han                                        std::int32_t k) {
49*5f39d1b3SJooyung Han     return 128 * 1024;
50*5f39d1b3SJooyung Han   }
51*5f39d1b3SJooyung Han };
52*5f39d1b3SJooyung Han 
53*5f39d1b3SJooyung Han class GemvFloatOperation : public FloatOperation {
54*5f39d1b3SJooyung Han  public:
GemvFloatOperation(std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset)55*5f39d1b3SJooyung Han   GemvFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
56*5f39d1b3SJooyung Han                      float result_offset)
57*5f39d1b3SJooyung Han       : FloatOperation(lhs_offset, rhs_offset, result_offset) {}
58*5f39d1b3SJooyung Han 
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,float * result,std::int32_t result_stride)59*5f39d1b3SJooyung Han   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
60*5f39d1b3SJooyung Han                            const std::uint8_t* rhs, std::int32_t m,
61*5f39d1b3SJooyung Han                            std::int32_t n, std::int32_t k, float* result,
62*5f39d1b3SJooyung Han                            std::int32_t result_stride) const {
63*5f39d1b3SJooyung Han     gemv_f(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result_offset,
64*5f39d1b3SJooyung Han            result);
65*5f39d1b3SJooyung Han   }
66*5f39d1b3SJooyung Han 
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)67*5f39d1b3SJooyung Han   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
68*5f39d1b3SJooyung Han                                        std::int32_t k) {
69*5f39d1b3SJooyung Han     return 128 * 1024;
70*5f39d1b3SJooyung Han   }
71*5f39d1b3SJooyung Han };
72*5f39d1b3SJooyung Han 
73*5f39d1b3SJooyung Han class GemvInt32Operation : public Int32Operation {
74*5f39d1b3SJooyung Han  public:
GemvInt32Operation(std::int32_t lhs_offset,std::int32_t rhs_offset)75*5f39d1b3SJooyung Han   GemvInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset)
76*5f39d1b3SJooyung Han       : Int32Operation(lhs_offset, rhs_offset) {}
77*5f39d1b3SJooyung Han 
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t * result,std::int32_t result_stride)78*5f39d1b3SJooyung Han   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
79*5f39d1b3SJooyung Han                            const std::uint8_t* rhs, std::int32_t m,
80*5f39d1b3SJooyung Han                            std::int32_t n, std::int32_t k, std::int32_t* result,
81*5f39d1b3SJooyung Han                            std::int32_t result_stride) const {
82*5f39d1b3SJooyung Han     gemv_i32(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result);
83*5f39d1b3SJooyung Han   }
84*5f39d1b3SJooyung Han 
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)85*5f39d1b3SJooyung Han   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
86*5f39d1b3SJooyung Han                                        std::int32_t k) {
87*5f39d1b3SJooyung Han     return 128 * 1024;
88*5f39d1b3SJooyung Han   }
89*5f39d1b3SJooyung Han };
90*5f39d1b3SJooyung Han 
91*5f39d1b3SJooyung Han }  // namespace internal
92*5f39d1b3SJooyung Han 
gemv_q8_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)93*5f39d1b3SJooyung Han std::int32_t gemv_q8_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
94*5f39d1b3SJooyung Han                              std::int32_t max_threads) {
95*5f39d1b3SJooyung Han   return internal::ResolveMaxThreads(max_threads) *
96*5f39d1b3SJooyung Han          internal::GemvQuantized8BitOperation::ScratchPerThread(m, n, k);
97*5f39d1b3SJooyung Han }
98*5f39d1b3SJooyung Han 
multi_thread_gemv_q8(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t multiplier,std::int32_t shift,std::uint8_t * result)99*5f39d1b3SJooyung Han void multi_thread_gemv_q8(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
100*5f39d1b3SJooyung Han                           std::uint8_t* scratch, const std::uint8_t* lhs,
101*5f39d1b3SJooyung Han                           const std::uint8_t* rhs, std::int32_t n,
102*5f39d1b3SJooyung Han                           std::int32_t k, std::int32_t lhs_offset,
103*5f39d1b3SJooyung Han                           std::int32_t rhs_offset, std::int32_t sum_offset,
104*5f39d1b3SJooyung Han                           std::int32_t multiplier, std::int32_t shift,
105*5f39d1b3SJooyung Han                           std::uint8_t* result) {
106*5f39d1b3SJooyung Han   max_threads = internal::ResolveMaxThreads(max_threads);
107*5f39d1b3SJooyung Han   internal::GemvQuantized8BitOperation operation(lhs_offset, rhs_offset,
108*5f39d1b3SJooyung Han                                                  sum_offset, multiplier, shift);
109*5f39d1b3SJooyung Han   if (max_threads == 1) {
110*5f39d1b3SJooyung Han     operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
111*5f39d1b3SJooyung Han   } else {
112*5f39d1b3SJooyung Han     internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
113*5f39d1b3SJooyung Han                                         n, k, result, n, operation);
114*5f39d1b3SJooyung Han   }
115*5f39d1b3SJooyung Han }
116*5f39d1b3SJooyung Han 
gemv_f_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)117*5f39d1b3SJooyung Han std::int32_t gemv_f_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
118*5f39d1b3SJooyung Han                             std::int32_t max_threads) {
119*5f39d1b3SJooyung Han   return internal::ResolveMaxThreads(max_threads) *
120*5f39d1b3SJooyung Han          internal::GemvFloatOperation::ScratchPerThread(m, n, k);
121*5f39d1b3SJooyung Han }
122*5f39d1b3SJooyung Han 
multi_thread_gemv_f(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,float * result)123*5f39d1b3SJooyung Han void multi_thread_gemv_f(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
124*5f39d1b3SJooyung Han                          std::uint8_t* scratch, const std::uint8_t* lhs,
125*5f39d1b3SJooyung Han                          const std::uint8_t* rhs, std::int32_t n,
126*5f39d1b3SJooyung Han                          std::int32_t k, std::int32_t lhs_offset,
127*5f39d1b3SJooyung Han                          std::int32_t rhs_offset, float result_offset,
128*5f39d1b3SJooyung Han                          float* result) {
129*5f39d1b3SJooyung Han   max_threads = internal::ResolveMaxThreads(max_threads);
130*5f39d1b3SJooyung Han   internal::GemvFloatOperation operation(lhs_offset, rhs_offset, result_offset);
131*5f39d1b3SJooyung Han   if (max_threads == 1) {
132*5f39d1b3SJooyung Han     operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
133*5f39d1b3SJooyung Han   } else {
134*5f39d1b3SJooyung Han     internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
135*5f39d1b3SJooyung Han                                         n, k, result, n, operation);
136*5f39d1b3SJooyung Han   }
137*5f39d1b3SJooyung Han }
138*5f39d1b3SJooyung Han 
gemv_i32_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)139*5f39d1b3SJooyung Han std::int32_t gemv_i32_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
140*5f39d1b3SJooyung Han                               std::int32_t max_threads) {
141*5f39d1b3SJooyung Han   return internal::ResolveMaxThreads(max_threads) *
142*5f39d1b3SJooyung Han          internal::GemvInt32Operation::ScratchPerThread(m, n, k);
143*5f39d1b3SJooyung Han }
144*5f39d1b3SJooyung Han 
multi_thread_gemv_i32(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t * result)145*5f39d1b3SJooyung Han void multi_thread_gemv_i32(gemmlowp::WorkersPool* pool,
146*5f39d1b3SJooyung Han                            std::int32_t max_threads, std::uint8_t* scratch,
147*5f39d1b3SJooyung Han                            const std::uint8_t* lhs, const std::uint8_t* rhs,
148*5f39d1b3SJooyung Han                            std::int32_t n, std::int32_t k,
149*5f39d1b3SJooyung Han                            std::int32_t lhs_offset, std::int32_t rhs_offset,
150*5f39d1b3SJooyung Han                            std::int32_t* result) {
151*5f39d1b3SJooyung Han   max_threads = internal::ResolveMaxThreads(max_threads);
152*5f39d1b3SJooyung Han   internal::GemvInt32Operation operation(lhs_offset, rhs_offset);
153*5f39d1b3SJooyung Han   if (max_threads == 1) {
154*5f39d1b3SJooyung Han     operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
155*5f39d1b3SJooyung Han   } else {
156*5f39d1b3SJooyung Han     internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
157*5f39d1b3SJooyung Han                                         n, k, result, n, operation);
158*5f39d1b3SJooyung Han   }
159*5f39d1b3SJooyung Han }
160*5f39d1b3SJooyung Han 
161*5f39d1b3SJooyung Han }  // namespace meta
162*5f39d1b3SJooyung Han }  // namespace gemmlowp
163*5f39d1b3SJooyung Han 
164*5f39d1b3SJooyung Han #else
165*5f39d1b3SJooyung Han #warning "Meta gemm fast-path requires GEMMLOWP_NEON_(32|64)!"
166*5f39d1b3SJooyung Han #endif
167*5f39d1b3SJooyung Han 
168*5f39d1b3SJooyung Han #endif  // GEMMLOWP_META_MULTI_THREAD_GEMV_H_
169