1*5f39d1b3SJooyung Han // Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han // http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han
15*5f39d1b3SJooyung Han #ifndef GEMMLOWP_META_MULTI_THREAD_GEMM_H_
16*5f39d1b3SJooyung Han #define GEMMLOWP_META_MULTI_THREAD_GEMM_H_
17*5f39d1b3SJooyung Han
18*5f39d1b3SJooyung Han #include "multi_thread_common.h"
19*5f39d1b3SJooyung Han #include "single_thread_gemm.h"
20*5f39d1b3SJooyung Han
21*5f39d1b3SJooyung Han namespace gemmlowp {
22*5f39d1b3SJooyung Han namespace meta {
23*5f39d1b3SJooyung Han namespace internal {
24*5f39d1b3SJooyung Han
25*5f39d1b3SJooyung Han const std::int32_t kMinGemmTaskSize = 16000;
26*5f39d1b3SJooyung Han const std::int32_t kMinGemmTaskDimension = 4;
27*5f39d1b3SJooyung Han
28*5f39d1b3SJooyung Han template <typename Executor, typename Params>
PrepareGemmTask(const Params & params,int kernel_m,int kernel_n,int kernel_k,std::uint8_t * scratch,int m_start,int m,int n_start,int n,std::vector<Params> * tasks)29*5f39d1b3SJooyung Han std::uint8_t* PrepareGemmTask(const Params& params, int kernel_m, int kernel_n,
30*5f39d1b3SJooyung Han int kernel_k, std::uint8_t* scratch, int m_start,
31*5f39d1b3SJooyung Han int m, int n_start, int n,
32*5f39d1b3SJooyung Han std::vector<Params>* tasks) {
33*5f39d1b3SJooyung Han tasks->push_back(params);
34*5f39d1b3SJooyung Han Params& task = tasks->back();
35*5f39d1b3SJooyung Han task.scratch = scratch;
36*5f39d1b3SJooyung Han
37*5f39d1b3SJooyung Han task.m = m;
38*5f39d1b3SJooyung Han task.lhs =
39*5f39d1b3SJooyung Han StreamUtil<typename Params::InType, typename Params::LeftStream>::Offset(
40*5f39d1b3SJooyung Han params.left_stream, params.lhs, m_start, 0);
41*5f39d1b3SJooyung Han
42*5f39d1b3SJooyung Han task.n = n;
43*5f39d1b3SJooyung Han task.rhs =
44*5f39d1b3SJooyung Han StreamUtil<typename Params::InType, typename Params::RightStream>::Offset(
45*5f39d1b3SJooyung Han params.right_stream, params.rhs, n_start, 0);
46*5f39d1b3SJooyung Han
47*5f39d1b3SJooyung Han task.result =
48*5f39d1b3SJooyung Han StreamUtil<typename Params::OutType, typename Params::OutputStream>::
49*5f39d1b3SJooyung Han Offset(params.fused_kernel.output_stream, params.result, m_start,
50*5f39d1b3SJooyung Han n_start);
51*5f39d1b3SJooyung Han
52*5f39d1b3SJooyung Han return scratch + Executor::template EstimateScratchSize<Params>(
53*5f39d1b3SJooyung Han task, kernel_m, kernel_n, kernel_k);
54*5f39d1b3SJooyung Han }
55*5f39d1b3SJooyung Han
56*5f39d1b3SJooyung Han template <typename MultiThreadingContext, typename Executor, typename Params>
PrepareGemmTasks(MultiThreadingContext * context,const Params & params,int kernel_m,int kernel_n,int kernel_k,std::vector<Params> * task_params)57*5f39d1b3SJooyung Han bool PrepareGemmTasks(MultiThreadingContext* context, const Params& params,
58*5f39d1b3SJooyung Han int kernel_m, int kernel_n, int kernel_k,
59*5f39d1b3SJooyung Han std::vector<Params>* task_params) {
60*5f39d1b3SJooyung Han const int max_threads = ResolveMaxThreads(context->max_num_threads());
61*5f39d1b3SJooyung Han const int max_tasks_by_size =
62*5f39d1b3SJooyung Han (params.m * params.n * params.k) / kMinGemmTaskSize;
63*5f39d1b3SJooyung Han const int max_tasks_m = params.m / kMinGemmTaskDimension;
64*5f39d1b3SJooyung Han const int max_tasks_n = params.n / kMinGemmTaskDimension;
65*5f39d1b3SJooyung Han const int max_tasks_dimension = std::max(max_tasks_m, max_tasks_n);
66*5f39d1b3SJooyung Han
67*5f39d1b3SJooyung Han const int real_tasks = std::max(
68*5f39d1b3SJooyung Han 1,
69*5f39d1b3SJooyung Han std::min(max_threads, std::min(max_tasks_by_size, max_tasks_dimension)));
70*5f39d1b3SJooyung Han
71*5f39d1b3SJooyung Han if (real_tasks == 1) {
72*5f39d1b3SJooyung Han return false;
73*5f39d1b3SJooyung Han }
74*5f39d1b3SJooyung Han
75*5f39d1b3SJooyung Han std::uint8_t* scratch = params.scratch;
76*5f39d1b3SJooyung Han
77*5f39d1b3SJooyung Han if (max_tasks_m > max_tasks_n) {
78*5f39d1b3SJooyung Han const int m_chunk = params.m / real_tasks;
79*5f39d1b3SJooyung Han for (int i = 0; i < real_tasks - 1; ++i) {
80*5f39d1b3SJooyung Han scratch = PrepareGemmTask<Executor, Params>(
81*5f39d1b3SJooyung Han params, kernel_m, kernel_n, kernel_k, scratch, i * m_chunk, m_chunk,
82*5f39d1b3SJooyung Han 0, params.n, task_params);
83*5f39d1b3SJooyung Han }
84*5f39d1b3SJooyung Han const int sum_m = (real_tasks - 1) * m_chunk;
85*5f39d1b3SJooyung Han PrepareGemmTask<Executor, Params>(params, kernel_m, kernel_n, kernel_k,
86*5f39d1b3SJooyung Han scratch, sum_m, params.m - sum_m, 0,
87*5f39d1b3SJooyung Han params.n, task_params);
88*5f39d1b3SJooyung Han } else {
89*5f39d1b3SJooyung Han const int n_chunk = params.n / real_tasks;
90*5f39d1b3SJooyung Han for (int i = 0; i < real_tasks - 1; ++i) {
91*5f39d1b3SJooyung Han scratch = PrepareGemmTask<Executor, Params>(
92*5f39d1b3SJooyung Han params, kernel_m, kernel_n, kernel_k, scratch, 0, params.m,
93*5f39d1b3SJooyung Han i * n_chunk, n_chunk, task_params);
94*5f39d1b3SJooyung Han }
95*5f39d1b3SJooyung Han int sum_n = (real_tasks - 1) * n_chunk;
96*5f39d1b3SJooyung Han PrepareGemmTask<Executor, Params>(params, kernel_m, kernel_n, kernel_k,
97*5f39d1b3SJooyung Han scratch, 0, params.m, sum_n,
98*5f39d1b3SJooyung Han params.n - sum_n, task_params);
99*5f39d1b3SJooyung Han }
100*5f39d1b3SJooyung Han
101*5f39d1b3SJooyung Han return true;
102*5f39d1b3SJooyung Han }
103*5f39d1b3SJooyung Han
104*5f39d1b3SJooyung Han template <typename Executor, typename Params, int kernel_m, int kernel_n,
105*5f39d1b3SJooyung Han int kernel_k>
106*5f39d1b3SJooyung Han struct GemmTaskRunner : gemmlowp::Task {
GemmTaskRunnerGemmTaskRunner107*5f39d1b3SJooyung Han GemmTaskRunner(const Params& params) : params(params) {}
108*5f39d1b3SJooyung Han
RunGemmTaskRunner109*5f39d1b3SJooyung Han void Run() override {
110*5f39d1b3SJooyung Han Gemm<Executor, Params, kernel_m, kernel_n, kernel_k>(params);
111*5f39d1b3SJooyung Han }
112*5f39d1b3SJooyung Han
113*5f39d1b3SJooyung Han Params params;
114*5f39d1b3SJooyung Han };
115*5f39d1b3SJooyung Han
116*5f39d1b3SJooyung Han } // namespace internal
117*5f39d1b3SJooyung Han
118*5f39d1b3SJooyung Han template <typename MultiThreadingContext, typename Executor, typename Params,
119*5f39d1b3SJooyung Han int kernel_m, int kernel_n, int kernel_k>
MultiThreadGemm(MultiThreadingContext * context,const Params & params)120*5f39d1b3SJooyung Han inline void MultiThreadGemm(MultiThreadingContext* context,
121*5f39d1b3SJooyung Han const Params& params) {
122*5f39d1b3SJooyung Han typedef internal::GemmTaskRunner<Executor, Params, kernel_m, kernel_n,
123*5f39d1b3SJooyung Han kernel_k>
124*5f39d1b3SJooyung Han TaskRunnerType;
125*5f39d1b3SJooyung Han
126*5f39d1b3SJooyung Han std::vector<Params> task_params;
127*5f39d1b3SJooyung Han if (!internal::PrepareGemmTasks<MultiThreadingContext, Executor, Params>(
128*5f39d1b3SJooyung Han context, params, kernel_m, kernel_n, kernel_k, &task_params)) {
129*5f39d1b3SJooyung Han Gemm<Executor, Params, kernel_m, kernel_n, kernel_k>(params);
130*5f39d1b3SJooyung Han return;
131*5f39d1b3SJooyung Han }
132*5f39d1b3SJooyung Han
133*5f39d1b3SJooyung Han auto workers_pool = context->workers_pool();
134*5f39d1b3SJooyung Han std::vector<Task*> tasks;
135*5f39d1b3SJooyung Han for (auto& task_param : task_params) {
136*5f39d1b3SJooyung Han tasks.push_back(new TaskRunnerType(task_param));
137*5f39d1b3SJooyung Han };
138*5f39d1b3SJooyung Han workers_pool->Execute(tasks);
139*5f39d1b3SJooyung Han }
140*5f39d1b3SJooyung Han
141*5f39d1b3SJooyung Han } // namespace meta
142*5f39d1b3SJooyung Han } // namespace gemmlowp
143*5f39d1b3SJooyung Han
144*5f39d1b3SJooyung Han #endif // GEMMLOWP_META_MULTI_THREAD_GEMM_H_
145