xref: /aosp_15_r20/external/gemmlowp/meta/multi_thread_gemm.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han #ifndef GEMMLOWP_META_MULTI_THREAD_GEMM_H_
16*5f39d1b3SJooyung Han #define GEMMLOWP_META_MULTI_THREAD_GEMM_H_
17*5f39d1b3SJooyung Han 
18*5f39d1b3SJooyung Han #include "multi_thread_common.h"
19*5f39d1b3SJooyung Han #include "single_thread_gemm.h"
20*5f39d1b3SJooyung Han 
21*5f39d1b3SJooyung Han namespace gemmlowp {
22*5f39d1b3SJooyung Han namespace meta {
23*5f39d1b3SJooyung Han namespace internal {
24*5f39d1b3SJooyung Han 
25*5f39d1b3SJooyung Han const std::int32_t kMinGemmTaskSize = 16000;
26*5f39d1b3SJooyung Han const std::int32_t kMinGemmTaskDimension = 4;
27*5f39d1b3SJooyung Han 
28*5f39d1b3SJooyung Han template <typename Executor, typename Params>
PrepareGemmTask(const Params & params,int kernel_m,int kernel_n,int kernel_k,std::uint8_t * scratch,int m_start,int m,int n_start,int n,std::vector<Params> * tasks)29*5f39d1b3SJooyung Han std::uint8_t* PrepareGemmTask(const Params& params, int kernel_m, int kernel_n,
30*5f39d1b3SJooyung Han                               int kernel_k, std::uint8_t* scratch, int m_start,
31*5f39d1b3SJooyung Han                               int m, int n_start, int n,
32*5f39d1b3SJooyung Han                               std::vector<Params>* tasks) {
33*5f39d1b3SJooyung Han   tasks->push_back(params);
34*5f39d1b3SJooyung Han   Params& task = tasks->back();
35*5f39d1b3SJooyung Han   task.scratch = scratch;
36*5f39d1b3SJooyung Han 
37*5f39d1b3SJooyung Han   task.m = m;
38*5f39d1b3SJooyung Han   task.lhs =
39*5f39d1b3SJooyung Han       StreamUtil<typename Params::InType, typename Params::LeftStream>::Offset(
40*5f39d1b3SJooyung Han           params.left_stream, params.lhs, m_start, 0);
41*5f39d1b3SJooyung Han 
42*5f39d1b3SJooyung Han   task.n = n;
43*5f39d1b3SJooyung Han   task.rhs =
44*5f39d1b3SJooyung Han       StreamUtil<typename Params::InType, typename Params::RightStream>::Offset(
45*5f39d1b3SJooyung Han           params.right_stream, params.rhs, n_start, 0);
46*5f39d1b3SJooyung Han 
47*5f39d1b3SJooyung Han   task.result =
48*5f39d1b3SJooyung Han       StreamUtil<typename Params::OutType, typename Params::OutputStream>::
49*5f39d1b3SJooyung Han           Offset(params.fused_kernel.output_stream, params.result, m_start,
50*5f39d1b3SJooyung Han                  n_start);
51*5f39d1b3SJooyung Han 
52*5f39d1b3SJooyung Han   return scratch + Executor::template EstimateScratchSize<Params>(
53*5f39d1b3SJooyung Han                        task, kernel_m, kernel_n, kernel_k);
54*5f39d1b3SJooyung Han }
55*5f39d1b3SJooyung Han 
56*5f39d1b3SJooyung Han template <typename MultiThreadingContext, typename Executor, typename Params>
PrepareGemmTasks(MultiThreadingContext * context,const Params & params,int kernel_m,int kernel_n,int kernel_k,std::vector<Params> * task_params)57*5f39d1b3SJooyung Han bool PrepareGemmTasks(MultiThreadingContext* context, const Params& params,
58*5f39d1b3SJooyung Han                       int kernel_m, int kernel_n, int kernel_k,
59*5f39d1b3SJooyung Han                       std::vector<Params>* task_params) {
60*5f39d1b3SJooyung Han   const int max_threads = ResolveMaxThreads(context->max_num_threads());
61*5f39d1b3SJooyung Han   const int max_tasks_by_size =
62*5f39d1b3SJooyung Han       (params.m * params.n * params.k) / kMinGemmTaskSize;
63*5f39d1b3SJooyung Han   const int max_tasks_m = params.m / kMinGemmTaskDimension;
64*5f39d1b3SJooyung Han   const int max_tasks_n = params.n / kMinGemmTaskDimension;
65*5f39d1b3SJooyung Han   const int max_tasks_dimension = std::max(max_tasks_m, max_tasks_n);
66*5f39d1b3SJooyung Han 
67*5f39d1b3SJooyung Han   const int real_tasks = std::max(
68*5f39d1b3SJooyung Han       1,
69*5f39d1b3SJooyung Han       std::min(max_threads, std::min(max_tasks_by_size, max_tasks_dimension)));
70*5f39d1b3SJooyung Han 
71*5f39d1b3SJooyung Han   if (real_tasks == 1) {
72*5f39d1b3SJooyung Han     return false;
73*5f39d1b3SJooyung Han   }
74*5f39d1b3SJooyung Han 
75*5f39d1b3SJooyung Han   std::uint8_t* scratch = params.scratch;
76*5f39d1b3SJooyung Han 
77*5f39d1b3SJooyung Han   if (max_tasks_m > max_tasks_n) {
78*5f39d1b3SJooyung Han     const int m_chunk = params.m / real_tasks;
79*5f39d1b3SJooyung Han     for (int i = 0; i < real_tasks - 1; ++i) {
80*5f39d1b3SJooyung Han       scratch = PrepareGemmTask<Executor, Params>(
81*5f39d1b3SJooyung Han           params, kernel_m, kernel_n, kernel_k, scratch, i * m_chunk, m_chunk,
82*5f39d1b3SJooyung Han           0, params.n, task_params);
83*5f39d1b3SJooyung Han     }
84*5f39d1b3SJooyung Han     const int sum_m = (real_tasks - 1) * m_chunk;
85*5f39d1b3SJooyung Han     PrepareGemmTask<Executor, Params>(params, kernel_m, kernel_n, kernel_k,
86*5f39d1b3SJooyung Han                                       scratch, sum_m, params.m - sum_m, 0,
87*5f39d1b3SJooyung Han                                       params.n, task_params);
88*5f39d1b3SJooyung Han   } else {
89*5f39d1b3SJooyung Han     const int n_chunk = params.n / real_tasks;
90*5f39d1b3SJooyung Han     for (int i = 0; i < real_tasks - 1; ++i) {
91*5f39d1b3SJooyung Han       scratch = PrepareGemmTask<Executor, Params>(
92*5f39d1b3SJooyung Han           params, kernel_m, kernel_n, kernel_k, scratch, 0, params.m,
93*5f39d1b3SJooyung Han           i * n_chunk, n_chunk, task_params);
94*5f39d1b3SJooyung Han     }
95*5f39d1b3SJooyung Han     int sum_n = (real_tasks - 1) * n_chunk;
96*5f39d1b3SJooyung Han     PrepareGemmTask<Executor, Params>(params, kernel_m, kernel_n, kernel_k,
97*5f39d1b3SJooyung Han                                       scratch, 0, params.m, sum_n,
98*5f39d1b3SJooyung Han                                       params.n - sum_n, task_params);
99*5f39d1b3SJooyung Han   }
100*5f39d1b3SJooyung Han 
101*5f39d1b3SJooyung Han   return true;
102*5f39d1b3SJooyung Han }
103*5f39d1b3SJooyung Han 
104*5f39d1b3SJooyung Han template <typename Executor, typename Params, int kernel_m, int kernel_n,
105*5f39d1b3SJooyung Han           int kernel_k>
106*5f39d1b3SJooyung Han struct GemmTaskRunner : gemmlowp::Task {
GemmTaskRunnerGemmTaskRunner107*5f39d1b3SJooyung Han   GemmTaskRunner(const Params& params) : params(params) {}
108*5f39d1b3SJooyung Han 
RunGemmTaskRunner109*5f39d1b3SJooyung Han   void Run() override {
110*5f39d1b3SJooyung Han     Gemm<Executor, Params, kernel_m, kernel_n, kernel_k>(params);
111*5f39d1b3SJooyung Han   }
112*5f39d1b3SJooyung Han 
113*5f39d1b3SJooyung Han   Params params;
114*5f39d1b3SJooyung Han };
115*5f39d1b3SJooyung Han 
116*5f39d1b3SJooyung Han }  // namespace internal
117*5f39d1b3SJooyung Han 
118*5f39d1b3SJooyung Han template <typename MultiThreadingContext, typename Executor, typename Params,
119*5f39d1b3SJooyung Han           int kernel_m, int kernel_n, int kernel_k>
MultiThreadGemm(MultiThreadingContext * context,const Params & params)120*5f39d1b3SJooyung Han inline void MultiThreadGemm(MultiThreadingContext* context,
121*5f39d1b3SJooyung Han                             const Params& params) {
122*5f39d1b3SJooyung Han   typedef internal::GemmTaskRunner<Executor, Params, kernel_m, kernel_n,
123*5f39d1b3SJooyung Han                                    kernel_k>
124*5f39d1b3SJooyung Han       TaskRunnerType;
125*5f39d1b3SJooyung Han 
126*5f39d1b3SJooyung Han   std::vector<Params> task_params;
127*5f39d1b3SJooyung Han   if (!internal::PrepareGemmTasks<MultiThreadingContext, Executor, Params>(
128*5f39d1b3SJooyung Han           context, params, kernel_m, kernel_n, kernel_k, &task_params)) {
129*5f39d1b3SJooyung Han     Gemm<Executor, Params, kernel_m, kernel_n, kernel_k>(params);
130*5f39d1b3SJooyung Han     return;
131*5f39d1b3SJooyung Han   }
132*5f39d1b3SJooyung Han 
133*5f39d1b3SJooyung Han   auto workers_pool = context->workers_pool();
134*5f39d1b3SJooyung Han   std::vector<Task*> tasks;
135*5f39d1b3SJooyung Han   for (auto& task_param : task_params) {
136*5f39d1b3SJooyung Han     tasks.push_back(new TaskRunnerType(task_param));
137*5f39d1b3SJooyung Han   };
138*5f39d1b3SJooyung Han   workers_pool->Execute(tasks);
139*5f39d1b3SJooyung Han }
140*5f39d1b3SJooyung Han 
141*5f39d1b3SJooyung Han }  // namespace meta
142*5f39d1b3SJooyung Han }  // namespace gemmlowp
143*5f39d1b3SJooyung Han 
144*5f39d1b3SJooyung Han #endif  // GEMMLOWP_META_MULTI_THREAD_GEMM_H_
145