xref: /aosp_15_r20/external/gemmlowp/internal/dispatch_gemm_shape.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han // dispatch_gemm_shape.h: dispatch GEMM calls according to their shape
16*5f39d1b3SJooyung Han 
17*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
18*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
19*5f39d1b3SJooyung Han 
#include <cstddef>
#include <tuple>

#include "../internal/kernel_default.h"
#include "../public/map.h"
#include "../public/output_stages.h"
#include "multi_thread_gemm.h"
24*5f39d1b3SJooyung Han 
25*5f39d1b3SJooyung Han namespace gemmlowp {
26*5f39d1b3SJooyung Han 
// Transposition of GEMM operands and output stages.
//
// TransposeImpl<T> describes how a value of type T is transposed: DstType is
// the resulting type and Run() performs the conversion. The primary template
// below is the identity and applies to every type that has no dedicated
// specialization; such values are returned unchanged, with their own type.
template <typename T>
struct TransposeImpl {
  using DstType = T;
  static T Run(const T& value) { return value; }
};

// Type obtained by transposing a value of type T.
template <typename T>
using TransposeType = typename TransposeImpl<T>::DstType;

// Returns the transpose of `t`, selecting the conversion from the static type
// of `t` through TransposeImpl.
template <typename T>
TransposeType<T> Transpose(const T& t) {
  using Impl = TransposeImpl<T>;
  return Impl::Run(t);
}
40*5f39d1b3SJooyung Han 
41*5f39d1b3SJooyung Han template <MapOrder Order>
42*5f39d1b3SJooyung Han struct TransposeMapOrder {
43*5f39d1b3SJooyung Han   static constexpr MapOrder Value =
44*5f39d1b3SJooyung Han       Order == MapOrder::RowMajor ? MapOrder::ColMajor : MapOrder::RowMajor;
45*5f39d1b3SJooyung Han };
46*5f39d1b3SJooyung Han 
47*5f39d1b3SJooyung Han template <VectorShape Shape>
48*5f39d1b3SJooyung Han struct TransposeVectorShape {
49*5f39d1b3SJooyung Han   static constexpr VectorShape Value =
50*5f39d1b3SJooyung Han       Shape == VectorShape::Row ? VectorShape::Col : VectorShape::Row;
51*5f39d1b3SJooyung Han };
52*5f39d1b3SJooyung Han 
53*5f39d1b3SJooyung Han template <typename Scalar, VectorShape Shape>
54*5f39d1b3SJooyung Han struct TransposeImpl<VectorMap<Scalar, Shape>> {
55*5f39d1b3SJooyung Han   typedef VectorMap<Scalar, Shape> SrcType;
56*5f39d1b3SJooyung Han   static constexpr VectorShape TransposedShape =
57*5f39d1b3SJooyung Han       TransposeVectorShape<Shape>::Value;
58*5f39d1b3SJooyung Han   typedef VectorMap<Scalar, TransposedShape> DstType;
59*5f39d1b3SJooyung Han   static DstType Run(const SrcType& src) {
60*5f39d1b3SJooyung Han     return DstType(src.data(), src.size());
61*5f39d1b3SJooyung Han   }
62*5f39d1b3SJooyung Han };
63*5f39d1b3SJooyung Han 
64*5f39d1b3SJooyung Han template <typename Scalar, MapOrder Order>
65*5f39d1b3SJooyung Han struct TransposeImpl<MatrixMap<Scalar, Order>> {
66*5f39d1b3SJooyung Han   typedef MatrixMap<Scalar, Order> SrcType;
67*5f39d1b3SJooyung Han   static constexpr MapOrder TransposedOrder = TransposeMapOrder<Order>::Value;
68*5f39d1b3SJooyung Han   typedef MatrixMap<Scalar, TransposedOrder> DstType;
69*5f39d1b3SJooyung Han   static DstType Run(const SrcType& src) {
70*5f39d1b3SJooyung Han     return DstType(src.data(), src.cols(), src.rows(), src.stride());
71*5f39d1b3SJooyung Han   }
72*5f39d1b3SJooyung Han };
73*5f39d1b3SJooyung Han 
74*5f39d1b3SJooyung Han template <VectorShape Shape>
75*5f39d1b3SJooyung Han struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
76*5f39d1b3SJooyung Han   typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType;
77*5f39d1b3SJooyung Han   static constexpr VectorShape TransposedShape =
78*5f39d1b3SJooyung Han       TransposeVectorShape<Shape>::Value;
79*5f39d1b3SJooyung Han   typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType;
80*5f39d1b3SJooyung Han   static DstType Run(const SrcType& src) {
81*5f39d1b3SJooyung Han     DstType dst;
82*5f39d1b3SJooyung Han     dst.result_shift = src.result_shift;
83*5f39d1b3SJooyung Han     dst.result_offset = Transpose(src.result_offset);
84*5f39d1b3SJooyung Han     dst.result_mult_int = Transpose(src.result_mult_int);
85*5f39d1b3SJooyung Han     return dst;
86*5f39d1b3SJooyung Han   }
87*5f39d1b3SJooyung Han };
88*5f39d1b3SJooyung Han 
89*5f39d1b3SJooyung Han template <VectorShape Shape>
90*5f39d1b3SJooyung Han struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> {
91*5f39d1b3SJooyung Han   typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> SrcType;
92*5f39d1b3SJooyung Han   static constexpr VectorShape TransposedShape =
93*5f39d1b3SJooyung Han       TransposeVectorShape<Shape>::Value;
94*5f39d1b3SJooyung Han   typedef OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape>
95*5f39d1b3SJooyung Han       DstType;
96*5f39d1b3SJooyung Han   static DstType Run(const SrcType& src) {
97*5f39d1b3SJooyung Han     DstType dst;
98*5f39d1b3SJooyung Han     dst.result_fixedpoint_multiplier =
99*5f39d1b3SJooyung Han         Transpose(src.result_fixedpoint_multiplier);
100*5f39d1b3SJooyung Han     dst.result_exponent = Transpose(src.result_exponent);
101*5f39d1b3SJooyung Han     dst.result_offset_after_shift = src.result_offset_after_shift;
102*5f39d1b3SJooyung Han     return dst;
103*5f39d1b3SJooyung Han   }
104*5f39d1b3SJooyung Han };
105*5f39d1b3SJooyung Han 
106*5f39d1b3SJooyung Han template <typename VectorMapType>
107*5f39d1b3SJooyung Han struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> {
108*5f39d1b3SJooyung Han   typedef OutputStageBiasAddition<VectorMapType> SrcType;
109*5f39d1b3SJooyung Han   typedef TransposeType<VectorMapType> TransposedVectorMapType;
110*5f39d1b3SJooyung Han   typedef OutputStageBiasAddition<TransposedVectorMapType> DstType;
111*5f39d1b3SJooyung Han   static DstType Run(const SrcType& src) {
112*5f39d1b3SJooyung Han     DstType dst;
113*5f39d1b3SJooyung Han     dst.bias_vector = Transpose(src.bias_vector);
114*5f39d1b3SJooyung Han     return dst;
115*5f39d1b3SJooyung Han   }
116*5f39d1b3SJooyung Han };
117*5f39d1b3SJooyung Han 
118*5f39d1b3SJooyung Han // TODO(benoitjacob) - does anyone understand C++ variadic templates?
119*5f39d1b3SJooyung Han // How to use them to implement TransposeTuple? Note: there are lots
120*5f39d1b3SJooyung Han // of answers on StackOverflow but they seem to all involve either
121*5f39d1b3SJooyung Han // C++14/C++17 (we can only use C++11) or lots of abstract nonsense.
122*5f39d1b3SJooyung Han inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; }
123*5f39d1b3SJooyung Han 
124*5f39d1b3SJooyung Han template <typename T0>
125*5f39d1b3SJooyung Han std::tuple<TransposeType<T0>> TransposeTuple(const std::tuple<T0>& t) {
126*5f39d1b3SJooyung Han   return std::make_tuple(Transpose(std::get<0>(t)));
127*5f39d1b3SJooyung Han }
128*5f39d1b3SJooyung Han 
129*5f39d1b3SJooyung Han template <typename T0, typename T1>
130*5f39d1b3SJooyung Han std::tuple<TransposeType<T0>, TransposeType<T1>> TransposeTuple(
131*5f39d1b3SJooyung Han     const std::tuple<T0, T1>& t) {
132*5f39d1b3SJooyung Han   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)));
133*5f39d1b3SJooyung Han }
134*5f39d1b3SJooyung Han 
135*5f39d1b3SJooyung Han template <typename T0, typename T1, typename T2>
136*5f39d1b3SJooyung Han std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>>
137*5f39d1b3SJooyung Han TransposeTuple(const std::tuple<T0, T1, T2>& t) {
138*5f39d1b3SJooyung Han   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
139*5f39d1b3SJooyung Han                          Transpose(std::get<2>(t)));
140*5f39d1b3SJooyung Han }
141*5f39d1b3SJooyung Han 
142*5f39d1b3SJooyung Han template <typename T0, typename T1, typename T2, typename T3>
143*5f39d1b3SJooyung Han std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
144*5f39d1b3SJooyung Han            TransposeType<T3>>
145*5f39d1b3SJooyung Han TransposeTuple(const std::tuple<T0, T1, T2, T3>& t) {
146*5f39d1b3SJooyung Han   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
147*5f39d1b3SJooyung Han                          Transpose(std::get<2>(t)), Transpose(std::get<3>(t)));
148*5f39d1b3SJooyung Han }
149*5f39d1b3SJooyung Han 
150*5f39d1b3SJooyung Han template <typename T0, typename T1, typename T2, typename T3, typename T4>
151*5f39d1b3SJooyung Han std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
152*5f39d1b3SJooyung Han            TransposeType<T3>, TransposeType<T4>>
153*5f39d1b3SJooyung Han TransposeTuple(const std::tuple<T0, T1, T2, T3, T4>& t) {
154*5f39d1b3SJooyung Han   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
155*5f39d1b3SJooyung Han                          Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
156*5f39d1b3SJooyung Han                          Transpose(std::get<4>(t)));
157*5f39d1b3SJooyung Han }
158*5f39d1b3SJooyung Han 
159*5f39d1b3SJooyung Han template <typename T0, typename T1, typename T2, typename T3, typename T4,
160*5f39d1b3SJooyung Han           typename T5>
161*5f39d1b3SJooyung Han std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
162*5f39d1b3SJooyung Han            TransposeType<T3>, TransposeType<T4>, TransposeType<T5>>
163*5f39d1b3SJooyung Han TransposeTuple(const std::tuple<T0, T1, T2, T3, T4, T5>& t) {
164*5f39d1b3SJooyung Han   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
165*5f39d1b3SJooyung Han                          Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
166*5f39d1b3SJooyung Han                          Transpose(std::get<4>(t)), Transpose(std::get<5>(t)));
167*5f39d1b3SJooyung Han }
168*5f39d1b3SJooyung Han 
169*5f39d1b3SJooyung Han template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
170*5f39d1b3SJooyung Han           MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
171*5f39d1b3SJooyung Han           typename LhsOffset, typename RhsOffset, typename OutputPipelineType,
172*5f39d1b3SJooyung Han           typename GemmContextType>
173*5f39d1b3SJooyung Han void DispatchGemmShape(GemmContextType* context,
174*5f39d1b3SJooyung Han                        const MatrixMap<const InputScalar, LhsOrder>& lhs,
175*5f39d1b3SJooyung Han                        const MatrixMap<const InputScalar, RhsOrder>& rhs,
176*5f39d1b3SJooyung Han                        MatrixMap<OutputScalar, ResultOrder>* result,
177*5f39d1b3SJooyung Han                        const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
178*5f39d1b3SJooyung Han                        const OutputPipelineType& output_pipeline) {
179*5f39d1b3SJooyung Han   assert(lhs.cols() == rhs.rows());
180*5f39d1b3SJooyung Han 
181*5f39d1b3SJooyung Han   int rows = result->rows();
182*5f39d1b3SJooyung Han   int cols = result->cols();
183*5f39d1b3SJooyung Han   int depth = lhs.cols();
184*5f39d1b3SJooyung Han 
185*5f39d1b3SJooyung Han   if (rows == 0 || cols == 0 || depth == 0) {
186*5f39d1b3SJooyung Han     // Vacuous GEMM, return early to avoid having to deal with
187*5f39d1b3SJooyung Han     // zero sizes below.
188*5f39d1b3SJooyung Han     return;
189*5f39d1b3SJooyung Han   }
190*5f39d1b3SJooyung Han 
191*5f39d1b3SJooyung Han   if (rows < cols) {
192*5f39d1b3SJooyung Han     auto transposed_result_map = Transpose(*result);
193*5f39d1b3SJooyung Han     return DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>(
194*5f39d1b3SJooyung Han         context, Transpose(rhs), Transpose(lhs), &transposed_result_map,
195*5f39d1b3SJooyung Han         Transpose(rhs_offset), Transpose(lhs_offset),
196*5f39d1b3SJooyung Han         TransposeTuple(output_pipeline));
197*5f39d1b3SJooyung Han   }
198*5f39d1b3SJooyung Han 
199*5f39d1b3SJooyung Han   typedef DefaultKernel<BitDepthParams> Kernel;
200*5f39d1b3SJooyung Han   MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
201*5f39d1b3SJooyung Han                   BitDepthParams>(context, Kernel(), lhs, rhs, result,
202*5f39d1b3SJooyung Han                                   lhs_offset, rhs_offset, output_pipeline);
203*5f39d1b3SJooyung Han }
204*5f39d1b3SJooyung Han 
205*5f39d1b3SJooyung Han }  // end namespace gemmlowp
206*5f39d1b3SJooyung Han 
207*5f39d1b3SJooyung Han #endif  // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
208