xref: /aosp_15_r20/external/gemmlowp/internal/block_params.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han // block_params.h: Logic to choose L1 and L2 block sizes
16*5f39d1b3SJooyung Han // to optimize cache-friendliness.
17*5f39d1b3SJooyung Han 
18*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_BLOCK_PARAMS_H_
19*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_BLOCK_PARAMS_H_
20*5f39d1b3SJooyung Han 
21*5f39d1b3SJooyung Han #include "common.h"
22*5f39d1b3SJooyung Han 
23*5f39d1b3SJooyung Han namespace gemmlowp {
24*5f39d1b3SJooyung Han 
25*5f39d1b3SJooyung Han // A BlockParams instance contains a full description of all the block size
26*5f39d1b3SJooyung Han // parameters to be used by a Gemm.
27*5f39d1b3SJooyung Han // There are two nested levels of block subdivisions: first a subdivision
28*5f39d1b3SJooyung Han // into large blocks that should fit in last-level cache (what we call L2 here)
29*5f39d1b3SJooyung Han // and then another subdivision into smaller blocks that should fit in
30*5f39d1b3SJooyung Han // L1 cache. There is then actually a third level of subdivision to fit
31*5f39d1b3SJooyung Han // in registers, but we are not concerned with that here.
32*5f39d1b3SJooyung Han struct BlockParams {
33*5f39d1b3SJooyung Han   // L1 block parameters determine the size of small blocks that should
34*5f39d1b3SJooyung Han   // fit in L1 cache.
35*5f39d1b3SJooyung Han   int l1_rows;
36*5f39d1b3SJooyung Han   int l1_cols;
37*5f39d1b3SJooyung Han   int l1_depth;
38*5f39d1b3SJooyung Han 
39*5f39d1b3SJooyung Han   // L2 block parameters determine the size of larger blocks that should
40*5f39d1b3SJooyung Han   // fit in L2 cache.
41*5f39d1b3SJooyung Han   int l2_rows;
42*5f39d1b3SJooyung Han   int l2_cols;
43*5f39d1b3SJooyung Han   int l2_depth;
44*5f39d1b3SJooyung Han 
45*5f39d1b3SJooyung Han   template <typename KernelFormat>
InitBlockParams46*5f39d1b3SJooyung Han   void Init(int rows, int cols, int depth, int num_threads, int l1_bytes_to_use,
47*5f39d1b3SJooyung Han             int l2_bytes_to_use, float l2_rhs_factor) {
48*5f39d1b3SJooyung Han     FindL2BlockSizes<KernelFormat>(rows, cols, depth, num_threads,
49*5f39d1b3SJooyung Han                                    l2_bytes_to_use, l2_rhs_factor, &l2_rows,
50*5f39d1b3SJooyung Han                                    &l2_cols, &l2_depth);
51*5f39d1b3SJooyung Han     FindL1BlockSizes<KernelFormat>(l2_rows, l2_cols, l2_depth, l1_bytes_to_use,
52*5f39d1b3SJooyung Han                                    &l1_rows, &l1_cols, &l1_depth);
53*5f39d1b3SJooyung Han   }
54*5f39d1b3SJooyung Han 
55*5f39d1b3SJooyung Han   template <typename KernelFormat>
FindL2BlockSizesBlockParams56*5f39d1b3SJooyung Han   static void FindL2BlockSizes(int rows, int cols, int depth, int num_threads,
57*5f39d1b3SJooyung Han                                int l2_bytes_to_use, float l2_rhs_factor,
58*5f39d1b3SJooyung Han                                int* out_l2_rows, int* out_l2_cols,
59*5f39d1b3SJooyung Han                                int* out_l2_depth) {
60*5f39d1b3SJooyung Han     int l2_rows = 0;
61*5f39d1b3SJooyung Han     int l2_cols = 0;
62*5f39d1b3SJooyung Han     int l2_depth = 0;
63*5f39d1b3SJooyung Han 
64*5f39d1b3SJooyung Han     int per_thread_rows =
65*5f39d1b3SJooyung Han         std::max(1, RoundUp<KernelFormat::kRows>(rows) / num_threads);
66*5f39d1b3SJooyung Han 
67*5f39d1b3SJooyung Han     // No L2 blocking in the depth dimension at the moment.
68*5f39d1b3SJooyung Han     // Too much loss of accuracy due to storing intermediate results in
69*5f39d1b3SJooyung Han     // low precision.
70*5f39d1b3SJooyung Han     // However, we still want to round l2_depth up to the next multiple
71*5f39d1b3SJooyung Han     // of register size, so as to avoid having to special-case unaligned depths.
72*5f39d1b3SJooyung Han     l2_depth = RoundUp<kRegisterSize>(depth);
73*5f39d1b3SJooyung Han 
74*5f39d1b3SJooyung Han     {
75*5f39d1b3SJooyung Han       int max_cache_friendly_l2_cols = std::max(
76*5f39d1b3SJooyung Han           1, static_cast<int>(l2_rhs_factor * (l2_bytes_to_use / l2_depth)));
77*5f39d1b3SJooyung Han       int min_l2_cols_blocks =
78*5f39d1b3SJooyung Han           std::max(1, CeilQuotient(cols, max_cache_friendly_l2_cols));
79*5f39d1b3SJooyung Han       l2_cols =
80*5f39d1b3SJooyung Han           RoundUp<KernelFormat::kCols>(CeilQuotient(cols, min_l2_cols_blocks));
81*5f39d1b3SJooyung Han     }
82*5f39d1b3SJooyung Han 
83*5f39d1b3SJooyung Han     // No L2 blocking in the row dimension if l2_rhs_factor is 1.0 as the row
84*5f39d1b3SJooyung Han     // dimension concerns only the LHS. Blocking only RHS matrix for L2 enhances
85*5f39d1b3SJooyung Han     // the performance on x86.
86*5f39d1b3SJooyung Han     if (l2_rhs_factor == 1.0f) {
87*5f39d1b3SJooyung Han       l2_rows = RoundUp<KernelFormat::kRows>(per_thread_rows);
88*5f39d1b3SJooyung Han     } else {
89*5f39d1b3SJooyung Han       int max_cache_friendly_l2_rows =
90*5f39d1b3SJooyung Han           std::max(1, (l2_bytes_to_use - l2_depth * l2_cols) /
91*5f39d1b3SJooyung Han                           (num_threads * (l2_depth + 4 * l2_cols)));
92*5f39d1b3SJooyung Han       int min_l2_rows_blocks = std::max(
93*5f39d1b3SJooyung Han           1, CeilQuotient(per_thread_rows, max_cache_friendly_l2_rows));
94*5f39d1b3SJooyung Han       l2_rows = RoundUp<KernelFormat::kRows>(
95*5f39d1b3SJooyung Han           CeilQuotient(per_thread_rows, min_l2_rows_blocks));
96*5f39d1b3SJooyung Han     }
97*5f39d1b3SJooyung Han 
98*5f39d1b3SJooyung Han     *out_l2_rows = l2_rows;
99*5f39d1b3SJooyung Han     *out_l2_cols = l2_cols;
100*5f39d1b3SJooyung Han     *out_l2_depth = l2_depth;
101*5f39d1b3SJooyung Han   }
102*5f39d1b3SJooyung Han 
103*5f39d1b3SJooyung Han   template <typename KernelFormat>
FindL1BlockSizesBlockParams104*5f39d1b3SJooyung Han   static void FindL1BlockSizes(int rows, int cols, int depth,
105*5f39d1b3SJooyung Han                                int l1_bytes_to_use, int* out_l1_rows,
106*5f39d1b3SJooyung Han                                int* out_l1_cols, int* out_l1_depth) {
107*5f39d1b3SJooyung Han     int l1_rows = 0;
108*5f39d1b3SJooyung Han     int l1_cols = 0;
109*5f39d1b3SJooyung Han     int l1_depth = 0;
110*5f39d1b3SJooyung Han 
111*5f39d1b3SJooyung Han     // L2 block sizes should already be multiples of kernel block sizes.
112*5f39d1b3SJooyung Han     assert(rows % KernelFormat::kRows == 0);
113*5f39d1b3SJooyung Han     assert(cols % KernelFormat::kCols == 0);
114*5f39d1b3SJooyung Han     assert(depth % KernelFormat::kDepth == 0);
115*5f39d1b3SJooyung Han 
116*5f39d1b3SJooyung Han     // No L1 blocking in the columns dimension at the moment.
117*5f39d1b3SJooyung Han     // Thought not to be needed. Similar to Eigen.
118*5f39d1b3SJooyung Han     l1_cols = cols;
119*5f39d1b3SJooyung Han 
120*5f39d1b3SJooyung Han     {
121*5f39d1b3SJooyung Han       int max_cache_friendly_l1_depth = std::max(
122*5f39d1b3SJooyung Han           1, (l1_bytes_to_use - 4 * KernelFormat::kRows * KernelFormat::kCols) /
123*5f39d1b3SJooyung Han                  (KernelFormat::kRows + KernelFormat::kCols));
124*5f39d1b3SJooyung Han       int min_l1_depth_blocks =
125*5f39d1b3SJooyung Han           std::max(1, CeilQuotient(depth, max_cache_friendly_l1_depth));
126*5f39d1b3SJooyung Han       l1_depth =
127*5f39d1b3SJooyung Han           RoundUp<kRegisterSize>(CeilQuotient(depth, min_l1_depth_blocks));
128*5f39d1b3SJooyung Han     }
129*5f39d1b3SJooyung Han 
130*5f39d1b3SJooyung Han     {
131*5f39d1b3SJooyung Han       int max_cache_friendly_l1_rows =
132*5f39d1b3SJooyung Han           std::max(1, l1_bytes_to_use / (l1_depth + 4 * l1_cols));
133*5f39d1b3SJooyung Han       int min_l1_rows_blocks =
134*5f39d1b3SJooyung Han           std::max(1, CeilQuotient(rows, max_cache_friendly_l1_rows));
135*5f39d1b3SJooyung Han       l1_rows =
136*5f39d1b3SJooyung Han           RoundUp<KernelFormat::kRows>(CeilQuotient(rows, min_l1_rows_blocks));
137*5f39d1b3SJooyung Han     }
138*5f39d1b3SJooyung Han 
139*5f39d1b3SJooyung Han     *out_l1_rows = l1_rows;
140*5f39d1b3SJooyung Han     *out_l1_cols = l1_cols;
141*5f39d1b3SJooyung Han     *out_l1_depth = l1_depth;
142*5f39d1b3SJooyung Han   }
143*5f39d1b3SJooyung Han };
144*5f39d1b3SJooyung Han 
// SideBlockParams carries just the block parameters that matter for a single
// side (LHS or RHS). Sizes are expressed as a 'width' rather than as
// rows/columns; as explained in kernel.h, 'width' is the number of rows for
// the LHS and the number of columns for the RHS. Thanks to this, generic code
// can be written once and applied to either side.
struct SideBlockParams {
  // Sizes of the small blocks that should fit in L1 cache.
  int l1_width;
  int l1_depth;

  // Sizes of the larger blocks that should fit in L2 cache.
  int l2_width;
  int l2_depth;
};
161*5f39d1b3SJooyung Han 
// Selects one of the two sides, LHS or RHS.
enum class Side {
  Lhs,
  Rhs
};
163*5f39d1b3SJooyung Han 
GetSideBlockParams(Side side,SideBlockParams * side_block_params,const BlockParams & block_params)164*5f39d1b3SJooyung Han inline void GetSideBlockParams(Side side, SideBlockParams* side_block_params,
165*5f39d1b3SJooyung Han                                const BlockParams& block_params) {
166*5f39d1b3SJooyung Han   side_block_params->l1_width =
167*5f39d1b3SJooyung Han       side == Side::Lhs ? block_params.l1_rows : block_params.l1_cols;
168*5f39d1b3SJooyung Han   side_block_params->l2_width =
169*5f39d1b3SJooyung Han       side == Side::Lhs ? block_params.l2_rows : block_params.l2_cols;
170*5f39d1b3SJooyung Han 
171*5f39d1b3SJooyung Han   side_block_params->l1_depth = block_params.l1_depth;
172*5f39d1b3SJooyung Han   side_block_params->l2_depth = block_params.l2_depth;
173*5f39d1b3SJooyung Han }
174*5f39d1b3SJooyung Han 
175*5f39d1b3SJooyung Han }  // namespace gemmlowp
176*5f39d1b3SJooyung Han 
177*5f39d1b3SJooyung Han #endif  // GEMMLOWP_INTERNAL_BLOCK_PARAMS_H_
178