xref: /aosp_15_r20/external/gemmlowp/meta/single_thread_gemm.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han #ifndef GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
16*5f39d1b3SJooyung Han #define GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
17*5f39d1b3SJooyung Han 
18*5f39d1b3SJooyung Han #include <iostream>
19*5f39d1b3SJooyung Han #include "base.h"
20*5f39d1b3SJooyung Han 
21*5f39d1b3SJooyung Han namespace gemmlowp {
22*5f39d1b3SJooyung Han namespace meta {
23*5f39d1b3SJooyung Han 
24*5f39d1b3SJooyung Han template <typename Executor, typename Params, int kernel_m, int kernel_n,
25*5f39d1b3SJooyung Han           int kernel_k>
26*5f39d1b3SJooyung Han void Gemm(const Params& params);
27*5f39d1b3SJooyung Han 
28*5f39d1b3SJooyung Han class GemmExecutorPackRHS {
29*5f39d1b3SJooyung Han  public:
30*5f39d1b3SJooyung Han   template <typename P>
EstimateScratchSize(const P & params,int kernel_m,int kernel_n,int kernel_k)31*5f39d1b3SJooyung Han   static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
32*5f39d1b3SJooyung Han                                  int kernel_k) {
33*5f39d1b3SJooyung Han     const int lhs_scratch =
34*5f39d1b3SJooyung Han         StreamUtil<typename P::InType, typename P::LeftStream>::Scratch(
35*5f39d1b3SJooyung Han             params.left_stream, kernel_m, kernel_k);
36*5f39d1b3SJooyung Han     const int rhs_chunks = ((params.n + kernel_n - 1) / kernel_n);
37*5f39d1b3SJooyung Han     const int rhs_scratch =
38*5f39d1b3SJooyung Han         rhs_chunks *
39*5f39d1b3SJooyung Han         StreamUtil<typename P::InType, typename P::RightStream>::Scratch(
40*5f39d1b3SJooyung Han             params.right_stream, kernel_n, kernel_k);
41*5f39d1b3SJooyung Han     return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch);
42*5f39d1b3SJooyung Han   }
43*5f39d1b3SJooyung Han 
44*5f39d1b3SJooyung Han   template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
45*5f39d1b3SJooyung Han             int k_leftovers>
ExecuteDispatch3D(const P & params)46*5f39d1b3SJooyung Han   static void ExecuteDispatch3D(const P& params) {
47*5f39d1b3SJooyung Han     // Shorthand typedefs for streams and multiply kernels.
48*5f39d1b3SJooyung Han     typedef typename P::InType InType;
49*5f39d1b3SJooyung Han     typedef typename P::OutType OutType;
50*5f39d1b3SJooyung Han 
51*5f39d1b3SJooyung Han     typedef Stream<typename P::InType, m, k, k_leftovers,
52*5f39d1b3SJooyung Han                    typename P::LeftStream>
53*5f39d1b3SJooyung Han         LeftStreamF;
54*5f39d1b3SJooyung Han     typedef Stream<typename P::InType, m_leftovers, k, k_leftovers,
55*5f39d1b3SJooyung Han                    typename P::LeftStream>
56*5f39d1b3SJooyung Han         LeftStreamL;
57*5f39d1b3SJooyung Han 
58*5f39d1b3SJooyung Han     typedef Stream<typename P::InType, n, k, k_leftovers,
59*5f39d1b3SJooyung Han                    typename P::RightStream>
60*5f39d1b3SJooyung Han         RightStreamF;
61*5f39d1b3SJooyung Han     typedef Stream<typename P::InType, n_leftovers, k, k_leftovers,
62*5f39d1b3SJooyung Han                    typename P::RightStream>
63*5f39d1b3SJooyung Han         RightStreamL;
64*5f39d1b3SJooyung Han 
65*5f39d1b3SJooyung Han     typedef Stream<typename P::OutType, m, n, 0, typename P::OutputStream>
66*5f39d1b3SJooyung Han         OutputStreamFF;
67*5f39d1b3SJooyung Han     typedef Stream<typename P::OutType, m_leftovers, n, 0,
68*5f39d1b3SJooyung Han                    typename P::OutputStream>
69*5f39d1b3SJooyung Han         OutputStreamLF;
70*5f39d1b3SJooyung Han 
71*5f39d1b3SJooyung Han     typedef MulKernel<typename P::InType, typename P::OutType,
72*5f39d1b3SJooyung Han                       typename P::Kernel, typename P::OutputStream, m, n, k>
73*5f39d1b3SJooyung Han         KernelFF;
74*5f39d1b3SJooyung Han     typedef MulKernel<typename P::InType, typename P::OutType,
75*5f39d1b3SJooyung Han                       typename P::Kernel, typename P::OutputStream, m,
76*5f39d1b3SJooyung Han                       n_leftovers, k>
77*5f39d1b3SJooyung Han         KernelFL;
78*5f39d1b3SJooyung Han     typedef MulKernel<typename P::InType, typename P::OutType,
79*5f39d1b3SJooyung Han                       typename P::Kernel, typename P::OutputStream, m_leftovers,
80*5f39d1b3SJooyung Han                       n, k>
81*5f39d1b3SJooyung Han         KernelLF;
82*5f39d1b3SJooyung Han     typedef MulKernel<typename P::InType, typename P::OutType,
83*5f39d1b3SJooyung Han                       typename P::Kernel, typename P::OutputStream, m_leftovers,
84*5f39d1b3SJooyung Han                       n_leftovers, k>
85*5f39d1b3SJooyung Han         KernelLL;
86*5f39d1b3SJooyung Han 
87*5f39d1b3SJooyung Han #ifdef DEBUG
88*5f39d1b3SJooyung Han #ifdef DEBUG_METAGEMM_VERBOSE
89*5f39d1b3SJooyung Han     std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n
90*5f39d1b3SJooyung Han               << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x"
91*5f39d1b3SJooyung Han               << k_leftovers << " -- " << params.m << "x" << params.n << "x"
92*5f39d1b3SJooyung Han               << params.k << std::endl;
93*5f39d1b3SJooyung Han     LeftStreamF::Debug(params.left_stream);
94*5f39d1b3SJooyung Han     LeftStreamL::Debug(params.left_stream);
95*5f39d1b3SJooyung Han 
96*5f39d1b3SJooyung Han     RightStreamF::Debug(params.right_stream);
97*5f39d1b3SJooyung Han     RightStreamL::Debug(params.right_stream);
98*5f39d1b3SJooyung Han 
99*5f39d1b3SJooyung Han     OutputStreamFF::Debug(params.fused_kernel.output_stream);
100*5f39d1b3SJooyung Han     OutputStreamLF::Debug(params.fused_kernel.output_stream);
101*5f39d1b3SJooyung Han 
102*5f39d1b3SJooyung Han     KernelFF::Debug(params.fused_kernel);
103*5f39d1b3SJooyung Han     KernelFL::Debug(params.fused_kernel);
104*5f39d1b3SJooyung Han     KernelLF::Debug(params.fused_kernel);
105*5f39d1b3SJooyung Han     KernelLL::Debug(params.fused_kernel);
106*5f39d1b3SJooyung Han #endif
107*5f39d1b3SJooyung Han #endif
108*5f39d1b3SJooyung Han 
109*5f39d1b3SJooyung Han     int lhs_chunks = params.m / m;
110*5f39d1b3SJooyung Han     int rhs_chunks = params.n / n;
111*5f39d1b3SJooyung Han 
112*5f39d1b3SJooyung Han     // Scratch memory for packed LHS & RHS chunks.
113*5f39d1b3SJooyung Han 
114*5f39d1b3SJooyung Han     std::uint8_t* packed_lhs = params.scratch;
115*5f39d1b3SJooyung Han     std::uint8_t* packed_rhs =
116*5f39d1b3SJooyung Han         params.scratch + LeftStreamF::Scratch(params.left_stream);
117*5f39d1b3SJooyung Han 
118*5f39d1b3SJooyung Han     // Pack full RHS first.
119*5f39d1b3SJooyung Han 
120*5f39d1b3SJooyung Han     std::uint8_t* packed_rhs_chunk = packed_rhs;
121*5f39d1b3SJooyung Han     const int packed_rhs_chunk_size =
122*5f39d1b3SJooyung Han         RightStreamF::PackedStride(params.right_stream);
123*5f39d1b3SJooyung Han 
124*5f39d1b3SJooyung Han     {
125*5f39d1b3SJooyung Han       const std::uint8_t* rhs_chunk =
126*5f39d1b3SJooyung Han           reinterpret_cast<const std::uint8_t*>(params.rhs);
127*5f39d1b3SJooyung Han       const int rhs_chunk_size =
128*5f39d1b3SJooyung Han           RightStreamF::UnpackedStride(params.right_stream);
129*5f39d1b3SJooyung Han 
130*5f39d1b3SJooyung Han       for (int i = 0; i < rhs_chunks; ++i) {
131*5f39d1b3SJooyung Han         RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk),
132*5f39d1b3SJooyung Han                            params.right_stream,
133*5f39d1b3SJooyung Han                            reinterpret_cast<InType*>(packed_rhs_chunk));
134*5f39d1b3SJooyung Han 
135*5f39d1b3SJooyung Han         rhs_chunk += rhs_chunk_size;
136*5f39d1b3SJooyung Han         packed_rhs_chunk += packed_rhs_chunk_size;
137*5f39d1b3SJooyung Han       }
138*5f39d1b3SJooyung Han 
139*5f39d1b3SJooyung Han       RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk),
140*5f39d1b3SJooyung Han                          params.right_stream,
141*5f39d1b3SJooyung Han                          reinterpret_cast<InType*>(packed_rhs_chunk));
142*5f39d1b3SJooyung Han     }
143*5f39d1b3SJooyung Han 
144*5f39d1b3SJooyung Han     // Multiply RHS by LHS one LHS chunk at a time.
145*5f39d1b3SJooyung Han 
146*5f39d1b3SJooyung Han     const std::uint8_t* lhs_chunk =
147*5f39d1b3SJooyung Han         reinterpret_cast<const std::uint8_t*>(params.lhs);
148*5f39d1b3SJooyung Han     std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result);
149*5f39d1b3SJooyung Han     std::uint8_t* result_chunk = result_strip;
150*5f39d1b3SJooyung Han 
151*5f39d1b3SJooyung Han     {
152*5f39d1b3SJooyung Han       const int lhs_chunk_size =
153*5f39d1b3SJooyung Han           LeftStreamF::UnpackedStride(params.left_stream);
154*5f39d1b3SJooyung Han       const int result_strip_size =
155*5f39d1b3SJooyung Han           OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream);
156*5f39d1b3SJooyung Han       const int result_chunk_size =
157*5f39d1b3SJooyung Han           OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream);
158*5f39d1b3SJooyung Han 
159*5f39d1b3SJooyung Han       for (int i = 0; i < lhs_chunks; ++i) {
160*5f39d1b3SJooyung Han         LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk),
161*5f39d1b3SJooyung Han                           params.left_stream,
162*5f39d1b3SJooyung Han                           reinterpret_cast<InType*>(packed_lhs));
163*5f39d1b3SJooyung Han 
164*5f39d1b3SJooyung Han         result_chunk = result_strip;
165*5f39d1b3SJooyung Han         packed_rhs_chunk = packed_rhs;
166*5f39d1b3SJooyung Han 
167*5f39d1b3SJooyung Han         for (int j = 0; j < rhs_chunks; ++j) {
168*5f39d1b3SJooyung Han           KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs),
169*5f39d1b3SJooyung Han                              reinterpret_cast<const InType*>(packed_rhs_chunk),
170*5f39d1b3SJooyung Han                              params.fused_kernel,
171*5f39d1b3SJooyung Han                              reinterpret_cast<OutType*>(result_chunk));
172*5f39d1b3SJooyung Han 
173*5f39d1b3SJooyung Han           result_chunk += result_chunk_size;
174*5f39d1b3SJooyung Han           packed_rhs_chunk += packed_rhs_chunk_size;
175*5f39d1b3SJooyung Han         }
176*5f39d1b3SJooyung Han 
177*5f39d1b3SJooyung Han         KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs),
178*5f39d1b3SJooyung Han                            reinterpret_cast<const InType*>(packed_rhs_chunk),
179*5f39d1b3SJooyung Han                            params.fused_kernel,
180*5f39d1b3SJooyung Han                            reinterpret_cast<OutType*>(result_chunk));
181*5f39d1b3SJooyung Han 
182*5f39d1b3SJooyung Han         lhs_chunk += lhs_chunk_size;
183*5f39d1b3SJooyung Han         result_strip += result_strip_size;
184*5f39d1b3SJooyung Han       }
185*5f39d1b3SJooyung Han     }
186*5f39d1b3SJooyung Han 
187*5f39d1b3SJooyung Han     // Leftover LHS chunk.
188*5f39d1b3SJooyung Han     if (m_leftovers > 0) {  // static if
189*5f39d1b3SJooyung Han       const int result_chunk_size =
190*5f39d1b3SJooyung Han           OutputStreamLF::UnpackedAdvance(params.fused_kernel.output_stream);
191*5f39d1b3SJooyung Han 
192*5f39d1b3SJooyung Han       LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk),
193*5f39d1b3SJooyung Han                         params.left_stream,
194*5f39d1b3SJooyung Han                         reinterpret_cast<InType*>(packed_lhs));
195*5f39d1b3SJooyung Han 
196*5f39d1b3SJooyung Han       result_chunk = result_strip;
197*5f39d1b3SJooyung Han       packed_rhs_chunk = packed_rhs;
198*5f39d1b3SJooyung Han 
199*5f39d1b3SJooyung Han       for (int i = 0; i < rhs_chunks; ++i) {
200*5f39d1b3SJooyung Han         KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs),
201*5f39d1b3SJooyung Han                            reinterpret_cast<const InType*>(packed_rhs_chunk),
202*5f39d1b3SJooyung Han                            params.fused_kernel,
203*5f39d1b3SJooyung Han                            reinterpret_cast<OutType*>(result_chunk));
204*5f39d1b3SJooyung Han 
205*5f39d1b3SJooyung Han         result_chunk += result_chunk_size;
206*5f39d1b3SJooyung Han         packed_rhs_chunk += packed_rhs_chunk_size;
207*5f39d1b3SJooyung Han       }
208*5f39d1b3SJooyung Han 
209*5f39d1b3SJooyung Han       KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs),
210*5f39d1b3SJooyung Han                          reinterpret_cast<const InType*>(packed_rhs_chunk),
211*5f39d1b3SJooyung Han                          params.fused_kernel,
212*5f39d1b3SJooyung Han                          reinterpret_cast<OutType*>(result_chunk));
213*5f39d1b3SJooyung Han     }
214*5f39d1b3SJooyung Han   }
215*5f39d1b3SJooyung Han };
216*5f39d1b3SJooyung Han 
217*5f39d1b3SJooyung Han class GemmExecutorPackLHS {
218*5f39d1b3SJooyung Han  public:
219*5f39d1b3SJooyung Han   template <typename P>
EstimateScratchSize(const P & params,int kernel_m,int kernel_n,int kernel_k)220*5f39d1b3SJooyung Han   static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
221*5f39d1b3SJooyung Han                                  int kernel_k) {
222*5f39d1b3SJooyung Han     const int lhs_chunks = ((params.m + kernel_m - 1) / kernel_m);
223*5f39d1b3SJooyung Han     const int lhs_scratch =
224*5f39d1b3SJooyung Han         lhs_chunks *
225*5f39d1b3SJooyung Han         StreamUtil<typename P::InType, typename P::LeftStream>::Scratch(
226*5f39d1b3SJooyung Han             params.left_stream, kernel_m, kernel_k);
227*5f39d1b3SJooyung Han     const int rhs_scratch =
228*5f39d1b3SJooyung Han         StreamUtil<typename P::InType, typename P::RightStream>::Scratch(
229*5f39d1b3SJooyung Han             params.right_stream, kernel_n, kernel_k);
230*5f39d1b3SJooyung Han     return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch);
231*5f39d1b3SJooyung Han   }
232*5f39d1b3SJooyung Han 
233*5f39d1b3SJooyung Han   template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
234*5f39d1b3SJooyung Han             int k_leftovers>
ExecuteDispatch3D(const P & params)235*5f39d1b3SJooyung Han   static void ExecuteDispatch3D(const P& params) {
236*5f39d1b3SJooyung Han     // Shorthand typedefs for streams and multiply kernels.
237*5f39d1b3SJooyung Han     typedef typename P::InType InType;
238*5f39d1b3SJooyung Han     typedef typename P::OutType OutType;
239*5f39d1b3SJooyung Han 
240*5f39d1b3SJooyung Han     typedef Stream<typename P::InType, m, k, k_leftovers,
241*5f39d1b3SJooyung Han                    typename P::LeftStream>
242*5f39d1b3SJooyung Han         LeftStreamF;
243*5f39d1b3SJooyung Han     typedef Stream<typename P::InType, m_leftovers, k, k_leftovers,
244*5f39d1b3SJooyung Han                    typename P::LeftStream>
245*5f39d1b3SJooyung Han         LeftStreamL;
246*5f39d1b3SJooyung Han 
247*5f39d1b3SJooyung Han     typedef Stream<typename P::InType, n, k, k_leftovers,
248*5f39d1b3SJooyung Han                    typename P::RightStream>
249*5f39d1b3SJooyung Han         RightStreamF;
250*5f39d1b3SJooyung Han     typedef Stream<typename P::InType, n_leftovers, k, k_leftovers,
251*5f39d1b3SJooyung Han                    typename P::RightStream>
252*5f39d1b3SJooyung Han         RightStreamL;
253*5f39d1b3SJooyung Han 
254*5f39d1b3SJooyung Han     typedef Stream<typename P::OutType, m, n, 0, typename P::OutputStream>
255*5f39d1b3SJooyung Han         OutputStreamFF;
256*5f39d1b3SJooyung Han     typedef Stream<typename P::OutType, m, n_leftovers, 0,
257*5f39d1b3SJooyung Han                    typename P::OutputStream>
258*5f39d1b3SJooyung Han         OutputStreamFL;
259*5f39d1b3SJooyung Han 
260*5f39d1b3SJooyung Han     typedef MulKernel<typename P::InType, typename P::OutType,
261*5f39d1b3SJooyung Han                       typename P::Kernel, typename P::OutputStream, m, n, k>
262*5f39d1b3SJooyung Han         KernelFF;
263*5f39d1b3SJooyung Han     typedef MulKernel<typename P::InType, typename P::OutType,
264*5f39d1b3SJooyung Han                       typename P::Kernel, typename P::OutputStream, m,
265*5f39d1b3SJooyung Han                       n_leftovers, k>
266*5f39d1b3SJooyung Han         KernelFL;
267*5f39d1b3SJooyung Han     typedef MulKernel<typename P::InType, typename P::OutType,
268*5f39d1b3SJooyung Han                       typename P::Kernel, typename P::OutputStream, m_leftovers,
269*5f39d1b3SJooyung Han                       n, k>
270*5f39d1b3SJooyung Han         KernelLF;
271*5f39d1b3SJooyung Han     typedef MulKernel<typename P::InType, typename P::OutType,
272*5f39d1b3SJooyung Han                       typename P::Kernel, typename P::OutputStream, m_leftovers,
273*5f39d1b3SJooyung Han                       n_leftovers, k>
274*5f39d1b3SJooyung Han         KernelLL;
275*5f39d1b3SJooyung Han #ifdef DEBUG
276*5f39d1b3SJooyung Han #ifdef DEBUG_METAGEMM_VERBOSE
277*5f39d1b3SJooyung Han     std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n
278*5f39d1b3SJooyung Han               << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x"
279*5f39d1b3SJooyung Han               << k_leftovers << " -- " << params.m << "x" << params.n << "x"
280*5f39d1b3SJooyung Han               << params.k << std::endl;
281*5f39d1b3SJooyung Han     LeftStreamF::Debug(params.left_stream);
282*5f39d1b3SJooyung Han     LeftStreamL::Debug(params.left_stream);
283*5f39d1b3SJooyung Han 
284*5f39d1b3SJooyung Han     RightStreamF::Debug(params.right_stream);
285*5f39d1b3SJooyung Han     RightStreamL::Debug(params.right_stream);
286*5f39d1b3SJooyung Han 
287*5f39d1b3SJooyung Han     OutputStreamFF::Debug(params.fused_kernel.output_stream);
288*5f39d1b3SJooyung Han     OutputStreamFL::Debug(params.fused_kernel.output_stream);
289*5f39d1b3SJooyung Han 
290*5f39d1b3SJooyung Han     KernelFF::Debug(params.fused_kernel);
291*5f39d1b3SJooyung Han     KernelFL::Debug(params.fused_kernel);
292*5f39d1b3SJooyung Han     KernelLF::Debug(params.fused_kernel);
293*5f39d1b3SJooyung Han     KernelLL::Debug(params.fused_kernel);
294*5f39d1b3SJooyung Han #endif
295*5f39d1b3SJooyung Han #endif
296*5f39d1b3SJooyung Han 
297*5f39d1b3SJooyung Han     int lhs_chunks = params.m / m;
298*5f39d1b3SJooyung Han     int rhs_chunks = params.n / n;
299*5f39d1b3SJooyung Han 
300*5f39d1b3SJooyung Han     // Scratch memory for packed LHS & RHS chunks.
301*5f39d1b3SJooyung Han     std::uint8_t* packed_rhs = params.scratch;
302*5f39d1b3SJooyung Han     std::uint8_t* packed_lhs =
303*5f39d1b3SJooyung Han         params.scratch + RightStreamF::Scratch(params.right_stream);
304*5f39d1b3SJooyung Han 
305*5f39d1b3SJooyung Han     // Pack full LHS first.
306*5f39d1b3SJooyung Han 
307*5f39d1b3SJooyung Han     std::uint8_t* packed_lhs_chunk = packed_lhs;
308*5f39d1b3SJooyung Han     const int packed_lhs_chunk_size =
309*5f39d1b3SJooyung Han         LeftStreamF::PackedStride(params.left_stream);
310*5f39d1b3SJooyung Han 
311*5f39d1b3SJooyung Han     {
312*5f39d1b3SJooyung Han       const std::uint8_t* lhs_chunk =
313*5f39d1b3SJooyung Han           reinterpret_cast<const std::uint8_t*>(params.lhs);
314*5f39d1b3SJooyung Han       const int lhs_chunk_size =
315*5f39d1b3SJooyung Han           LeftStreamF::UnpackedStride(params.left_stream);
316*5f39d1b3SJooyung Han 
317*5f39d1b3SJooyung Han       for (int i = 0; i < lhs_chunks; ++i) {
318*5f39d1b3SJooyung Han         LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk),
319*5f39d1b3SJooyung Han                           params.left_stream,
320*5f39d1b3SJooyung Han                           reinterpret_cast<InType*>(packed_lhs_chunk));
321*5f39d1b3SJooyung Han 
322*5f39d1b3SJooyung Han         lhs_chunk += lhs_chunk_size;
323*5f39d1b3SJooyung Han         packed_lhs_chunk += packed_lhs_chunk_size;
324*5f39d1b3SJooyung Han       }
325*5f39d1b3SJooyung Han 
326*5f39d1b3SJooyung Han       LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk),
327*5f39d1b3SJooyung Han                         params.left_stream,
328*5f39d1b3SJooyung Han                         reinterpret_cast<InType*>(packed_lhs_chunk));
329*5f39d1b3SJooyung Han     }
330*5f39d1b3SJooyung Han 
331*5f39d1b3SJooyung Han     // Multiply RHS by LHS one RHS chunk at a time.
332*5f39d1b3SJooyung Han 
333*5f39d1b3SJooyung Han     const std::uint8_t* rhs_chunk =
334*5f39d1b3SJooyung Han         reinterpret_cast<const std::uint8_t*>(params.rhs);
335*5f39d1b3SJooyung Han     std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result);
336*5f39d1b3SJooyung Han     std::uint8_t* result_chunk = result_strip;
337*5f39d1b3SJooyung Han 
338*5f39d1b3SJooyung Han     {
339*5f39d1b3SJooyung Han       const int rhs_chunk_size =
340*5f39d1b3SJooyung Han           RightStreamF::UnpackedStride(params.right_stream);
341*5f39d1b3SJooyung Han       const int result_strip_size =
342*5f39d1b3SJooyung Han           OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream);
343*5f39d1b3SJooyung Han       const int result_chunk_size =
344*5f39d1b3SJooyung Han           OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream);
345*5f39d1b3SJooyung Han 
346*5f39d1b3SJooyung Han       for (int i = 0; i < rhs_chunks; ++i) {
347*5f39d1b3SJooyung Han         RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk),
348*5f39d1b3SJooyung Han                            params.right_stream,
349*5f39d1b3SJooyung Han                            reinterpret_cast<InType*>(packed_rhs));
350*5f39d1b3SJooyung Han 
351*5f39d1b3SJooyung Han         result_chunk = result_strip;
352*5f39d1b3SJooyung Han         packed_lhs_chunk = packed_lhs;
353*5f39d1b3SJooyung Han 
354*5f39d1b3SJooyung Han         for (int j = 0; j < lhs_chunks; ++j) {
355*5f39d1b3SJooyung Han           KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
356*5f39d1b3SJooyung Han                              reinterpret_cast<const InType*>(packed_rhs),
357*5f39d1b3SJooyung Han                              params.fused_kernel,
358*5f39d1b3SJooyung Han                              reinterpret_cast<OutType*>(result_chunk));
359*5f39d1b3SJooyung Han 
360*5f39d1b3SJooyung Han           result_chunk += result_chunk_size;
361*5f39d1b3SJooyung Han           packed_lhs_chunk += packed_lhs_chunk_size;
362*5f39d1b3SJooyung Han         }
363*5f39d1b3SJooyung Han 
364*5f39d1b3SJooyung Han         KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
365*5f39d1b3SJooyung Han                            reinterpret_cast<const InType*>(packed_rhs),
366*5f39d1b3SJooyung Han                            params.fused_kernel,
367*5f39d1b3SJooyung Han                            reinterpret_cast<OutType*>(result_chunk));
368*5f39d1b3SJooyung Han 
369*5f39d1b3SJooyung Han         rhs_chunk += rhs_chunk_size;
370*5f39d1b3SJooyung Han         result_strip += result_strip_size;
371*5f39d1b3SJooyung Han       }
372*5f39d1b3SJooyung Han     }
373*5f39d1b3SJooyung Han 
374*5f39d1b3SJooyung Han     // Leftover RHS chunk.
375*5f39d1b3SJooyung Han     if (n_leftovers > 0) {  // static if
376*5f39d1b3SJooyung Han       const int result_chunk_size =
377*5f39d1b3SJooyung Han           OutputStreamFL::UnpackedStride(params.fused_kernel.output_stream);
378*5f39d1b3SJooyung Han 
379*5f39d1b3SJooyung Han       RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk),
380*5f39d1b3SJooyung Han                          params.right_stream,
381*5f39d1b3SJooyung Han                          reinterpret_cast<InType*>(packed_rhs));
382*5f39d1b3SJooyung Han 
383*5f39d1b3SJooyung Han       result_chunk = result_strip;
384*5f39d1b3SJooyung Han       packed_lhs_chunk = packed_lhs;
385*5f39d1b3SJooyung Han 
386*5f39d1b3SJooyung Han       for (int i = 0; i < lhs_chunks; ++i) {
387*5f39d1b3SJooyung Han         KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
388*5f39d1b3SJooyung Han                            reinterpret_cast<const InType*>(packed_rhs),
389*5f39d1b3SJooyung Han                            params.fused_kernel,
390*5f39d1b3SJooyung Han                            reinterpret_cast<OutType*>(result_chunk));
391*5f39d1b3SJooyung Han 
392*5f39d1b3SJooyung Han         result_chunk += result_chunk_size;
393*5f39d1b3SJooyung Han         packed_lhs_chunk += packed_lhs_chunk_size;
394*5f39d1b3SJooyung Han       }
395*5f39d1b3SJooyung Han 
396*5f39d1b3SJooyung Han       KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
397*5f39d1b3SJooyung Han                          reinterpret_cast<const InType*>(packed_rhs),
398*5f39d1b3SJooyung Han                          params.fused_kernel,
399*5f39d1b3SJooyung Han                          reinterpret_cast<OutType*>(result_chunk));
400*5f39d1b3SJooyung Han     }
401*5f39d1b3SJooyung Han   }
402*5f39d1b3SJooyung Han };
403*5f39d1b3SJooyung Han 
404*5f39d1b3SJooyung Han namespace internal {
405*5f39d1b3SJooyung Han 
CalculateCacheFriendlyTasksCount(int cache_size,int constant_memory,int per_chunk_memory,int total_dim,int chunk_dim)406*5f39d1b3SJooyung Han inline int CalculateCacheFriendlyTasksCount(int cache_size, int constant_memory,
407*5f39d1b3SJooyung Han                                             int per_chunk_memory, int total_dim,
408*5f39d1b3SJooyung Han                                             int chunk_dim) {
409*5f39d1b3SJooyung Han   assert(constant_memory + per_chunk_memory < cache_size);
410*5f39d1b3SJooyung Han   const int available_cache = cache_size - constant_memory;
411*5f39d1b3SJooyung Han   const int available_chunks = available_cache / per_chunk_memory;
412*5f39d1b3SJooyung Han   const int chunks_count = (total_dim + chunk_dim - 1) / chunk_dim;
413*5f39d1b3SJooyung Han   return (chunks_count + available_chunks - 1) / available_chunks;
414*5f39d1b3SJooyung Han }
415*5f39d1b3SJooyung Han 
416*5f39d1b3SJooyung Han template <typename Params>
UpdateCacheFriendlyTask(int m_offset,int m,int n_offset,int n,const Params & params,Params * task_params)417*5f39d1b3SJooyung Han inline void UpdateCacheFriendlyTask(int m_offset, int m, int n_offset, int n,
418*5f39d1b3SJooyung Han                                     const Params& params, Params* task_params) {
419*5f39d1b3SJooyung Han   task_params->m = m;
420*5f39d1b3SJooyung Han   task_params->lhs =
421*5f39d1b3SJooyung Han       StreamUtil<typename Params::InType, typename Params::LeftStream>::Offset(
422*5f39d1b3SJooyung Han           params.left_stream, params.lhs, m_offset, 0);
423*5f39d1b3SJooyung Han 
424*5f39d1b3SJooyung Han   task_params->n = n;
425*5f39d1b3SJooyung Han   task_params->rhs =
426*5f39d1b3SJooyung Han       StreamUtil<typename Params::InType, typename Params::RightStream>::Offset(
427*5f39d1b3SJooyung Han           params.right_stream, params.rhs, n_offset, 0);
428*5f39d1b3SJooyung Han 
429*5f39d1b3SJooyung Han   task_params->result =
430*5f39d1b3SJooyung Han       StreamUtil<typename Params::OutType, typename Params::OutputStream>::
431*5f39d1b3SJooyung Han           Offset(params.fused_kernel.output_stream, params.result, m_offset,
432*5f39d1b3SJooyung Han                  n_offset);
433*5f39d1b3SJooyung Han }
434*5f39d1b3SJooyung Han 
435*5f39d1b3SJooyung Han }  // namespace internal
436*5f39d1b3SJooyung Han 
437*5f39d1b3SJooyung Han template <int cache_size = 256 * 1024>
438*5f39d1b3SJooyung Han class GemmExecutorPackRHSCacheFriendly {
439*5f39d1b3SJooyung Han  public:
440*5f39d1b3SJooyung Han   template <typename P>
EstimateScratchSize(const P & params,int kernel_m,int kernel_n,int kernel_k)441*5f39d1b3SJooyung Han   static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
442*5f39d1b3SJooyung Han                                  int kernel_k) {
443*5f39d1b3SJooyung Han     return cache_size;
444*5f39d1b3SJooyung Han   }
445*5f39d1b3SJooyung Han 
446*5f39d1b3SJooyung Han   template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
447*5f39d1b3SJooyung Han             int k_leftovers>
ExecuteDispatch3D(const P & params)448*5f39d1b3SJooyung Han   static void ExecuteDispatch3D(const P& params) {
449*5f39d1b3SJooyung Han     typedef Stream<typename P::InType, m, k, k_leftovers,
450*5f39d1b3SJooyung Han                    typename P::LeftStream>
451*5f39d1b3SJooyung Han         LeftStream;
452*5f39d1b3SJooyung Han 
453*5f39d1b3SJooyung Han     typedef Stream<typename P::InType, n, k, k_leftovers,
454*5f39d1b3SJooyung Han                    typename P::RightStream>
455*5f39d1b3SJooyung Han         RightStream;
456*5f39d1b3SJooyung Han 
457*5f39d1b3SJooyung Han     const int lhs_scratch = LeftStream::Scratch(params.left_stream);
458*5f39d1b3SJooyung Han     const int rhs_scratch = RightStream::Scratch(params.right_stream);
459*5f39d1b3SJooyung Han 
460*5f39d1b3SJooyung Han     const int cache_friendly_tasks_count =
461*5f39d1b3SJooyung Han         internal::CalculateCacheFriendlyTasksCount(cache_size, lhs_scratch,
462*5f39d1b3SJooyung Han                                                    rhs_scratch, params.n, n);
463*5f39d1b3SJooyung Han 
464*5f39d1b3SJooyung Han     if (cache_friendly_tasks_count == 1) {
465*5f39d1b3SJooyung Han       GemmExecutorPackRHS::ExecuteDispatch3D<P, m, n, k, m_leftovers,
466*5f39d1b3SJooyung Han                                              n_leftovers, k_leftovers>(params);
467*5f39d1b3SJooyung Han       return;
468*5f39d1b3SJooyung Han     }
469*5f39d1b3SJooyung Han 
470*5f39d1b3SJooyung Han     const int cache_friendly_dim = params.n / cache_friendly_tasks_count;
471*5f39d1b3SJooyung Han 
472*5f39d1b3SJooyung Han     P task_params = params;
473*5f39d1b3SJooyung Han     for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) {
474*5f39d1b3SJooyung Han       internal::UpdateCacheFriendlyTask(0, params.m, i * cache_friendly_dim,
475*5f39d1b3SJooyung Han                                         cache_friendly_dim, params,
476*5f39d1b3SJooyung Han                                         &task_params);
477*5f39d1b3SJooyung Han       Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params);
478*5f39d1b3SJooyung Han     }
479*5f39d1b3SJooyung Han     const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim;
480*5f39d1b3SJooyung Han     internal::UpdateCacheFriendlyTask(0, params.m, dim_sum, params.n - dim_sum,
481*5f39d1b3SJooyung Han                                       params, &task_params);
482*5f39d1b3SJooyung Han     Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params);
483*5f39d1b3SJooyung Han   }
484*5f39d1b3SJooyung Han };
485*5f39d1b3SJooyung Han 
486*5f39d1b3SJooyung Han template <int cache_size = 256 * 1024>
487*5f39d1b3SJooyung Han class GemmExecutorPackLHSCacheFriendly {
488*5f39d1b3SJooyung Han  public:
489*5f39d1b3SJooyung Han   template <typename P>
EstimateScratchSize(const P & params,int kernel_m,int kernel_n,int kernel_k)490*5f39d1b3SJooyung Han   static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
491*5f39d1b3SJooyung Han                                  int kernel_k) {
492*5f39d1b3SJooyung Han     return cache_size;
493*5f39d1b3SJooyung Han   }
494*5f39d1b3SJooyung Han 
495*5f39d1b3SJooyung Han   template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
496*5f39d1b3SJooyung Han             int k_leftovers>
ExecuteDispatch3D(const P & params)497*5f39d1b3SJooyung Han   static void ExecuteDispatch3D(const P& params) {
498*5f39d1b3SJooyung Han     typedef Stream<typename P::InType, m, k, k_leftovers,
499*5f39d1b3SJooyung Han                    typename P::LeftStream>
500*5f39d1b3SJooyung Han         LeftStream;
501*5f39d1b3SJooyung Han 
502*5f39d1b3SJooyung Han     typedef Stream<typename P::InType, n, k, k_leftovers,
503*5f39d1b3SJooyung Han                    typename P::RightStream>
504*5f39d1b3SJooyung Han         RightStream;
505*5f39d1b3SJooyung Han 
506*5f39d1b3SJooyung Han     const int lhs_scratch = LeftStream::Scratch(params.left_stream);
507*5f39d1b3SJooyung Han     const int rhs_scratch = RightStream::Scratch(params.right_stream);
508*5f39d1b3SJooyung Han 
509*5f39d1b3SJooyung Han     const int cache_friendly_tasks_count =
510*5f39d1b3SJooyung Han         internal::CalculateCacheFriendlyTasksCount(cache_size, rhs_scratch,
511*5f39d1b3SJooyung Han                                                    lhs_scratch, params.m, m);
512*5f39d1b3SJooyung Han 
513*5f39d1b3SJooyung Han     if (cache_friendly_tasks_count == 1) {
514*5f39d1b3SJooyung Han       GemmExecutorPackLHS::ExecuteDispatch3D<P, m, n, k, m_leftovers,
515*5f39d1b3SJooyung Han                                              n_leftovers, k_leftovers>(params);
516*5f39d1b3SJooyung Han       return;
517*5f39d1b3SJooyung Han     }
518*5f39d1b3SJooyung Han 
519*5f39d1b3SJooyung Han     const int cache_friendly_dim = params.m / cache_friendly_tasks_count;
520*5f39d1b3SJooyung Han 
521*5f39d1b3SJooyung Han     P task_params = params;
522*5f39d1b3SJooyung Han     for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) {
523*5f39d1b3SJooyung Han       internal::UpdateCacheFriendlyTask(i * cache_friendly_dim,
524*5f39d1b3SJooyung Han                                         cache_friendly_dim, 0, params.n, params,
525*5f39d1b3SJooyung Han                                         &task_params);
526*5f39d1b3SJooyung Han       Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params);
527*5f39d1b3SJooyung Han     }
528*5f39d1b3SJooyung Han     const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim;
529*5f39d1b3SJooyung Han     internal::UpdateCacheFriendlyTask(dim_sum, params.m - dim_sum, 0, params.n,
530*5f39d1b3SJooyung Han                                       params, &task_params);
531*5f39d1b3SJooyung Han     Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params);
532*5f39d1b3SJooyung Han   }
533*5f39d1b3SJooyung Han };
534*5f39d1b3SJooyung Han 
535*5f39d1b3SJooyung Han namespace internal {
536*5f39d1b3SJooyung Han 
537*5f39d1b3SJooyung Han // Stage 3.
538*5f39d1b3SJooyung Han 
539*5f39d1b3SJooyung Han template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
540*5f39d1b3SJooyung Han           int fixed_n, int variable_k>
541*5f39d1b3SJooyung Han struct Dispatch3DStage3 {
ExecuteDispatch3DStage3542*5f39d1b3SJooyung Han   static void Execute(const P& params, int k) {
543*5f39d1b3SJooyung Han #ifdef DEBUG
544*5f39d1b3SJooyung Han #ifdef DEBUG_METAGEMM_VERBOSE
545*5f39d1b3SJooyung Han     std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
546*5f39d1b3SJooyung Han               << " : " << fixed_m << "x" << fixed_n << "x" << variable_k
547*5f39d1b3SJooyung Han               << std::endl
548*5f39d1b3SJooyung Han               << std::flush;
549*5f39d1b3SJooyung Han #endif
550*5f39d1b3SJooyung Han #endif
551*5f39d1b3SJooyung Han     if (k == variable_k) {
552*5f39d1b3SJooyung Han       E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
553*5f39d1b3SJooyung Han                                     variable_k>(params);
554*5f39d1b3SJooyung Han     } else {
555*5f39d1b3SJooyung Han       Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
556*5f39d1b3SJooyung Han                        variable_k - 1>::Execute(params, k);
557*5f39d1b3SJooyung Han     }
558*5f39d1b3SJooyung Han   }
559*5f39d1b3SJooyung Han };
560*5f39d1b3SJooyung Han 
561*5f39d1b3SJooyung Han template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
562*5f39d1b3SJooyung Han           int fixed_n>
563*5f39d1b3SJooyung Han struct Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n, 0> {
564*5f39d1b3SJooyung Han   static void Execute(const P& params, int k) {
565*5f39d1b3SJooyung Han #ifdef DEBUG
566*5f39d1b3SJooyung Han #ifdef DEBUG_METAGEMM_VERBOSE
567*5f39d1b3SJooyung Han     std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
568*5f39d1b3SJooyung Han               << " : " << fixed_m << "x" << fixed_n << "x" << 0 << std::endl
569*5f39d1b3SJooyung Han               << std::flush;
570*5f39d1b3SJooyung Han #endif
571*5f39d1b3SJooyung Han #endif
572*5f39d1b3SJooyung Han     if (k == 0) {
573*5f39d1b3SJooyung Han       E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
574*5f39d1b3SJooyung Han                                     0>(params);
575*5f39d1b3SJooyung Han     } else {
576*5f39d1b3SJooyung Han       std::cerr << "FATAL: dispatch3DStage3 failed: ran out of cases."
577*5f39d1b3SJooyung Han                 << std::endl
578*5f39d1b3SJooyung Han                 << std::flush;
579*5f39d1b3SJooyung Han       std::exit(1);
580*5f39d1b3SJooyung Han     }
581*5f39d1b3SJooyung Han   }
582*5f39d1b3SJooyung Han };
583*5f39d1b3SJooyung Han 
584*5f39d1b3SJooyung Han // Stage 2.
585*5f39d1b3SJooyung Han 
586*5f39d1b3SJooyung Han template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
587*5f39d1b3SJooyung Han           int variable_n>
588*5f39d1b3SJooyung Han struct Dispatch3DStage2 {
589*5f39d1b3SJooyung Han   static void Execute(const P& params, int n, int k) {
590*5f39d1b3SJooyung Han #ifdef DEBUG
591*5f39d1b3SJooyung Han #ifdef DEBUG_METAGEMM_VERBOSE
592*5f39d1b3SJooyung Han     std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k
593*5f39d1b3SJooyung Han               << " : " << fixed_m << "x" << variable_n << std::endl
594*5f39d1b3SJooyung Han               << std::flush;
595*5f39d1b3SJooyung Han #endif
596*5f39d1b3SJooyung Han #endif
597*5f39d1b3SJooyung Han     if (n == variable_n) {
598*5f39d1b3SJooyung Han       Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, variable_n,
599*5f39d1b3SJooyung Han                        dim_k - 1>::Execute(params, k);
600*5f39d1b3SJooyung Han     } else {
601*5f39d1b3SJooyung Han       Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m,
602*5f39d1b3SJooyung Han                        variable_n - 1>::Execute(params, n, k);
603*5f39d1b3SJooyung Han     }
604*5f39d1b3SJooyung Han   }
605*5f39d1b3SJooyung Han };
606*5f39d1b3SJooyung Han 
607*5f39d1b3SJooyung Han template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m>
608*5f39d1b3SJooyung Han struct Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m, 0> {
609*5f39d1b3SJooyung Han   static void Execute(const P& params, int n, int k) {
610*5f39d1b3SJooyung Han #ifdef DEBUG
611*5f39d1b3SJooyung Han #ifdef DEBUG_METAGEMM_VERBOSE
612*5f39d1b3SJooyung Han     std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k
613*5f39d1b3SJooyung Han               << " : " << fixed_m << "x" << 0 << std::endl
614*5f39d1b3SJooyung Han               << std::flush;
615*5f39d1b3SJooyung Han #endif
616*5f39d1b3SJooyung Han #endif
617*5f39d1b3SJooyung Han     if (n == 0) {
618*5f39d1b3SJooyung Han       Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, 0,
619*5f39d1b3SJooyung Han                        dim_k - 1>::Execute(params, k);
620*5f39d1b3SJooyung Han     } else {
621*5f39d1b3SJooyung Han       std::cerr << "FATAL: dispatch3DStage2 failed: ran out of cases."
622*5f39d1b3SJooyung Han                 << std::endl
623*5f39d1b3SJooyung Han                 << std::flush;
624*5f39d1b3SJooyung Han       std::exit(1);
625*5f39d1b3SJooyung Han     }
626*5f39d1b3SJooyung Han   }
627*5f39d1b3SJooyung Han };
628*5f39d1b3SJooyung Han 
629*5f39d1b3SJooyung Han // Stage 1.
630*5f39d1b3SJooyung Han 
631*5f39d1b3SJooyung Han template <typename E, typename P, int dim_m, int dim_n, int dim_k,
632*5f39d1b3SJooyung Han           int variable_m>
633*5f39d1b3SJooyung Han struct Dispatch3DStage1 {
634*5f39d1b3SJooyung Han   static void Execute(const P& params, int m, int n, int k) {
635*5f39d1b3SJooyung Han #ifdef DEBUG
636*5f39d1b3SJooyung Han #ifdef DEBUG_METAGEMM_VERBOSE
637*5f39d1b3SJooyung Han     std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k
638*5f39d1b3SJooyung Han               << " : " << variable_m << std::endl
639*5f39d1b3SJooyung Han               << std::flush;
640*5f39d1b3SJooyung Han #endif
641*5f39d1b3SJooyung Han #endif
642*5f39d1b3SJooyung Han     if (m == variable_m) {
643*5f39d1b3SJooyung Han       Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, variable_m,
644*5f39d1b3SJooyung Han                        dim_n - 1>::Execute(params, n, k);
645*5f39d1b3SJooyung Han     } else {
646*5f39d1b3SJooyung Han       Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, variable_m - 1>::Execute(
647*5f39d1b3SJooyung Han           params, m, n, k);
648*5f39d1b3SJooyung Han     }
649*5f39d1b3SJooyung Han   }
650*5f39d1b3SJooyung Han };
651*5f39d1b3SJooyung Han 
652*5f39d1b3SJooyung Han template <typename E, typename P, int dim_m, int dim_n, int dim_k>
653*5f39d1b3SJooyung Han struct Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, 0> {
654*5f39d1b3SJooyung Han   static void Execute(const P& params, int m, int n, int k) {
655*5f39d1b3SJooyung Han #ifdef DEBUG
656*5f39d1b3SJooyung Han #ifdef DEBUG_METAGEMM_VERBOSE
657*5f39d1b3SJooyung Han     std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k
658*5f39d1b3SJooyung Han               << " : " << 0 << std::endl
659*5f39d1b3SJooyung Han               << std::flush;
660*5f39d1b3SJooyung Han #endif
661*5f39d1b3SJooyung Han #endif
662*5f39d1b3SJooyung Han     if (m == 0) {
663*5f39d1b3SJooyung Han       Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, 0, dim_n - 1>::Execute(params,
664*5f39d1b3SJooyung Han                                                                          n, k);
665*5f39d1b3SJooyung Han     } else {
666*5f39d1b3SJooyung Han       std::cerr << "FATAL: dispatch3DStage1 failed: ran out of cases."
667*5f39d1b3SJooyung Han                 << std::endl
668*5f39d1b3SJooyung Han                 << std::flush;
669*5f39d1b3SJooyung Han       std::exit(1);
670*5f39d1b3SJooyung Han     }
671*5f39d1b3SJooyung Han   }
672*5f39d1b3SJooyung Han };
673*5f39d1b3SJooyung Han 
674*5f39d1b3SJooyung Han }  // namespace internal
675*5f39d1b3SJooyung Han 
676*5f39d1b3SJooyung Han template <typename Executor, typename Params, int kernel_m, int kernel_n,
677*5f39d1b3SJooyung Han           int kernel_k>
678*5f39d1b3SJooyung Han inline void Gemm(const Params& params) {
679*5f39d1b3SJooyung Han   internal::Dispatch3DStage1<Executor, Params, kernel_m, kernel_n, kernel_k,
680*5f39d1b3SJooyung Han                              kernel_m - 1>::Execute(params, params.m % kernel_m,
681*5f39d1b3SJooyung Han                                                     params.n % kernel_n,
682*5f39d1b3SJooyung Han                                                     params.k % kernel_k);
683*5f39d1b3SJooyung Han }
684*5f39d1b3SJooyung Han 
685*5f39d1b3SJooyung Han }  // namespace meta
686*5f39d1b3SJooyung Han }  // namespace gemmlowp
687*5f39d1b3SJooyung Han 
688*5f39d1b3SJooyung Han #endif  // GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
689