1*5f39d1b3SJooyung Han // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han // http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han
15*5f39d1b3SJooyung Han // multi_thread_common.h: Multithreading code shared by different meta gemm
16*5f39d1b3SJooyung Han // versions.
17*5f39d1b3SJooyung Han
18*5f39d1b3SJooyung Han #ifndef GEMMLOWP_META_MULTI_THREAD_COMMON_H_
19*5f39d1b3SJooyung Han #define GEMMLOWP_META_MULTI_THREAD_COMMON_H_
20*5f39d1b3SJooyung Han
21*5f39d1b3SJooyung Han #include "../internal/multi_thread_gemm.h"
22*5f39d1b3SJooyung Han
23*5f39d1b3SJooyung Han namespace gemmlowp {
24*5f39d1b3SJooyung Han namespace meta {
25*5f39d1b3SJooyung Han namespace internal {
26*5f39d1b3SJooyung Han
27*5f39d1b3SJooyung Han const std::int32_t kMinTaskSize = 16000;
28*5f39d1b3SJooyung Han const std::int32_t kMinTaskDimension = 4;
29*5f39d1b3SJooyung Han
30*5f39d1b3SJooyung Han struct TaskRect {
31*5f39d1b3SJooyung Han std::int32_t m_offset;
32*5f39d1b3SJooyung Han std::int32_t m;
33*5f39d1b3SJooyung Han std::int32_t n_offset;
34*5f39d1b3SJooyung Han std::int32_t n;
35*5f39d1b3SJooyung Han
TaskRectTaskRect36*5f39d1b3SJooyung Han TaskRect(std::int32_t m_offset, std::int32_t m, std::int32_t n_offset,
37*5f39d1b3SJooyung Han std::int32_t n)
38*5f39d1b3SJooyung Han : m_offset(m_offset), m(m), n_offset(n_offset), n(n) {}
39*5f39d1b3SJooyung Han };
40*5f39d1b3SJooyung Han
41*5f39d1b3SJooyung Han template <typename IN_TYPE, typename OUT_TYPE, typename F>
42*5f39d1b3SJooyung Han struct MetaTask : gemmlowp::Task {
43*5f39d1b3SJooyung Han std::uint8_t* scratch;
44*5f39d1b3SJooyung Han const IN_TYPE* lhs;
45*5f39d1b3SJooyung Han const IN_TYPE* rhs;
46*5f39d1b3SJooyung Han TaskRect task_rect;
47*5f39d1b3SJooyung Han std::int32_t k;
48*5f39d1b3SJooyung Han OUT_TYPE* result;
49*5f39d1b3SJooyung Han std::int32_t result_stride;
50*5f39d1b3SJooyung Han const F& operation;
51*5f39d1b3SJooyung Han
MetaTaskMetaTask52*5f39d1b3SJooyung Han MetaTask(std::uint8_t* scratch, const IN_TYPE* lhs, const IN_TYPE* rhs,
53*5f39d1b3SJooyung Han const TaskRect& task_rect, std::int32_t k, OUT_TYPE* result,
54*5f39d1b3SJooyung Han std::int32_t result_stride, const F& operation)
55*5f39d1b3SJooyung Han : scratch(scratch),
56*5f39d1b3SJooyung Han lhs(lhs),
57*5f39d1b3SJooyung Han rhs(rhs),
58*5f39d1b3SJooyung Han task_rect(task_rect),
59*5f39d1b3SJooyung Han k(k),
60*5f39d1b3SJooyung Han result(result),
61*5f39d1b3SJooyung Han result_stride(result_stride),
62*5f39d1b3SJooyung Han operation(operation) {}
63*5f39d1b3SJooyung Han
RunMetaTask64*5f39d1b3SJooyung Han void Run() override {
65*5f39d1b3SJooyung Han const IN_TYPE* task_lhs = lhs + task_rect.m_offset * k;
66*5f39d1b3SJooyung Han const IN_TYPE* task_rhs = rhs + task_rect.n_offset * k;
67*5f39d1b3SJooyung Han OUT_TYPE* task_result =
68*5f39d1b3SJooyung Han result + task_rect.m_offset * result_stride + task_rect.n_offset;
69*5f39d1b3SJooyung Han operation.ExecuteMatrixMatrix(scratch, task_lhs, task_rhs, task_rect.m,
70*5f39d1b3SJooyung Han task_rect.n, k, task_result, result_stride);
71*5f39d1b3SJooyung Han }
72*5f39d1b3SJooyung Han };
73*5f39d1b3SJooyung Han
ResolveMaxThreads(std::int32_t max_threads)74*5f39d1b3SJooyung Han std::int32_t ResolveMaxThreads(std::int32_t max_threads) {
75*5f39d1b3SJooyung Han if (max_threads == 0) {
76*5f39d1b3SJooyung Han static const int hardware_threads_count =
77*5f39d1b3SJooyung Han static_cast<int>(sysconf(_SC_NPROCESSORS_CONF));
78*5f39d1b3SJooyung Han return hardware_threads_count;
79*5f39d1b3SJooyung Han }
80*5f39d1b3SJooyung Han return max_threads;
81*5f39d1b3SJooyung Han }
82*5f39d1b3SJooyung Han
PrepareTasks(std::int32_t max_tasks,std::int32_t m,std::int32_t n,std::int32_t k,std::vector<internal::TaskRect> * tasks)83*5f39d1b3SJooyung Han void PrepareTasks(std::int32_t max_tasks, std::int32_t m, std::int32_t n,
84*5f39d1b3SJooyung Han std::int32_t k, std::vector<internal::TaskRect>* tasks) {
85*5f39d1b3SJooyung Han const std::int32_t max_tasks_by_size = (m * n * k) / kMinTaskSize;
86*5f39d1b3SJooyung Han const std::int32_t max_tasks_m = m / kMinTaskDimension;
87*5f39d1b3SJooyung Han const std::int32_t max_tasks_n = n / kMinTaskDimension;
88*5f39d1b3SJooyung Han const std::int32_t max_tasks_dimension = std::max(max_tasks_m, max_tasks_n);
89*5f39d1b3SJooyung Han
90*5f39d1b3SJooyung Han std::int32_t real_tasks = std::max(
91*5f39d1b3SJooyung Han 1, std::min(max_tasks, std::min(max_tasks_by_size, max_tasks_dimension)));
92*5f39d1b3SJooyung Han
93*5f39d1b3SJooyung Han if (real_tasks == 1) {
94*5f39d1b3SJooyung Han tasks->push_back(TaskRect(0, m, 0, n));
95*5f39d1b3SJooyung Han return;
96*5f39d1b3SJooyung Han }
97*5f39d1b3SJooyung Han
98*5f39d1b3SJooyung Han if (max_tasks_m > max_tasks_n) {
99*5f39d1b3SJooyung Han const std::int32_t m_chunk = m / real_tasks;
100*5f39d1b3SJooyung Han for (int i = 0; i < real_tasks - 1; ++i) {
101*5f39d1b3SJooyung Han tasks->push_back(TaskRect(i * m_chunk, m_chunk, 0, n));
102*5f39d1b3SJooyung Han }
103*5f39d1b3SJooyung Han const std::int32_t last_m_offset = (real_tasks - 1) * m_chunk;
104*5f39d1b3SJooyung Han tasks->push_back(TaskRect(last_m_offset, m - last_m_offset, 0, n));
105*5f39d1b3SJooyung Han } else {
106*5f39d1b3SJooyung Han const std::int32_t n_chunk = n / real_tasks;
107*5f39d1b3SJooyung Han for (int i = 0; i < real_tasks - 1; ++i) {
108*5f39d1b3SJooyung Han tasks->push_back(TaskRect(0, m, i * n_chunk, n_chunk));
109*5f39d1b3SJooyung Han }
110*5f39d1b3SJooyung Han const std::int32_t last_n_offset = (real_tasks - 1) * n_chunk;
111*5f39d1b3SJooyung Han tasks->push_back(TaskRect(0, m, last_n_offset, n - last_n_offset));
112*5f39d1b3SJooyung Han }
113*5f39d1b3SJooyung Han }
114*5f39d1b3SJooyung Han
115*5f39d1b3SJooyung Han template <typename IN_TYPE, typename OUT_TYPE, typename F>
MultiThreadedMatrixMatrix(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const IN_TYPE * lhs,const IN_TYPE * rhs,std::int32_t m,std::int32_t n,std::int32_t k,OUT_TYPE * result,std::int32_t result_stride,const F & operation)116*5f39d1b3SJooyung Han void MultiThreadedMatrixMatrix(gemmlowp::WorkersPool* pool,
117*5f39d1b3SJooyung Han std::int32_t max_threads, std::uint8_t* scratch,
118*5f39d1b3SJooyung Han const IN_TYPE* lhs, const IN_TYPE* rhs,
119*5f39d1b3SJooyung Han std::int32_t m, std::int32_t n, std::int32_t k,
120*5f39d1b3SJooyung Han OUT_TYPE* result, std::int32_t result_stride,
121*5f39d1b3SJooyung Han const F& operation) {
122*5f39d1b3SJooyung Han max_threads = internal::ResolveMaxThreads(max_threads);
123*5f39d1b3SJooyung Han
124*5f39d1b3SJooyung Han std::vector<internal::TaskRect> task_rects;
125*5f39d1b3SJooyung Han internal::PrepareTasks(max_threads, m, n, k, &task_rects);
126*5f39d1b3SJooyung Han
127*5f39d1b3SJooyung Han if (task_rects.size() == 1) {
128*5f39d1b3SJooyung Han operation.ExecuteMatrixMatrix(scratch, lhs, rhs, m, n, k, result,
129*5f39d1b3SJooyung Han result_stride);
130*5f39d1b3SJooyung Han return;
131*5f39d1b3SJooyung Han }
132*5f39d1b3SJooyung Han
133*5f39d1b3SJooyung Han std::uint8_t* task_scratch = scratch;
134*5f39d1b3SJooyung Han std::int32_t scratch_per_thread = operation.ScratchPerThread(m, n, k);
135*5f39d1b3SJooyung Han std::vector<Task*> tasks;
136*5f39d1b3SJooyung Han std::for_each(
137*5f39d1b3SJooyung Han task_rects.begin(), task_rects.end(),
138*5f39d1b3SJooyung Han [&tasks, &task_scratch, lhs, rhs, k, result, result_stride, operation,
139*5f39d1b3SJooyung Han scratch_per_thread](internal::TaskRect& rect) {
140*5f39d1b3SJooyung Han tasks.push_back(new internal::MetaTask<IN_TYPE, OUT_TYPE, F>(
141*5f39d1b3SJooyung Han task_scratch, lhs, rhs, rect, k, result, result_stride, operation));
142*5f39d1b3SJooyung Han task_scratch += scratch_per_thread;
143*5f39d1b3SJooyung Han });
144*5f39d1b3SJooyung Han pool->Execute(tasks);
145*5f39d1b3SJooyung Han }
146*5f39d1b3SJooyung Han
147*5f39d1b3SJooyung Han } // namespace internal
148*5f39d1b3SJooyung Han } // namespace meta
149*5f39d1b3SJooyung Han } // namespace gemmlowp
150*5f39d1b3SJooyung Han
151*5f39d1b3SJooyung Han #endif // GEMMLOWP_META_MULTI_THREAD_COMMON_H_
152