xref: /aosp_15_r20/external/gemmlowp/meta/base.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han #ifndef GEMMLOWP_META_BASE_H_
16*5f39d1b3SJooyung Han #define GEMMLOWP_META_BASE_H_
17*5f39d1b3SJooyung Han 
18*5f39d1b3SJooyung Han #include <cassert>
19*5f39d1b3SJooyung Han #include <cstdint>
20*5f39d1b3SJooyung Han 
21*5f39d1b3SJooyung Han #include "../internal/common.h"
22*5f39d1b3SJooyung Han 
23*5f39d1b3SJooyung Han namespace gemmlowp {
24*5f39d1b3SJooyung Han namespace meta {
25*5f39d1b3SJooyung Han 
26*5f39d1b3SJooyung Han template <int align>
AlignTo(int value)27*5f39d1b3SJooyung Han inline int AlignTo(int value) {
28*5f39d1b3SJooyung Han   return ((value + align - 1) / align) * align;
29*5f39d1b3SJooyung Han }
30*5f39d1b3SJooyung Han 
AlignTo(int align,int value)31*5f39d1b3SJooyung Han inline int AlignTo(int align, int value) {
32*5f39d1b3SJooyung Han   return ((value + align - 1) / align) * align;
33*5f39d1b3SJooyung Han }
34*5f39d1b3SJooyung Han 
35*5f39d1b3SJooyung Han template <typename Kernel_, typename OutputStream_>
36*5f39d1b3SJooyung Han struct FusedKernelParams {
37*5f39d1b3SJooyung Han  public:
38*5f39d1b3SJooyung Han   typedef Kernel_ Kernel;
39*5f39d1b3SJooyung Han   typedef OutputStream_ OutputStream;
40*5f39d1b3SJooyung Han 
41*5f39d1b3SJooyung Han   Kernel kernel;
42*5f39d1b3SJooyung Han   OutputStream output_stream;
43*5f39d1b3SJooyung Han };
44*5f39d1b3SJooyung Han 
45*5f39d1b3SJooyung Han template <typename InType_, typename OutType_, typename LeftStream_,
46*5f39d1b3SJooyung Han           typename RightStream_, typename Kernel_, typename OutputStream_>
47*5f39d1b3SJooyung Han struct GemmParams {
48*5f39d1b3SJooyung Han  public:
49*5f39d1b3SJooyung Han   typedef InType_ InType;
50*5f39d1b3SJooyung Han   typedef OutType_ OutType;
51*5f39d1b3SJooyung Han   typedef LeftStream_ LeftStream;
52*5f39d1b3SJooyung Han   typedef RightStream_ RightStream;
53*5f39d1b3SJooyung Han   typedef Kernel_ Kernel;
54*5f39d1b3SJooyung Han   typedef OutputStream_ OutputStream;
55*5f39d1b3SJooyung Han 
56*5f39d1b3SJooyung Han   typedef FusedKernelParams<Kernel, OutputStream> FusedKernel;
57*5f39d1b3SJooyung Han 
58*5f39d1b3SJooyung Han   // Common parameters.
59*5f39d1b3SJooyung Han 
60*5f39d1b3SJooyung Han   int m;
61*5f39d1b3SJooyung Han   int n;
62*5f39d1b3SJooyung Han   int k;
63*5f39d1b3SJooyung Han 
64*5f39d1b3SJooyung Han   const InType* lhs;
65*5f39d1b3SJooyung Han   const InType* rhs;
66*5f39d1b3SJooyung Han   OutType* result;
67*5f39d1b3SJooyung Han   std::uint8_t* scratch;
68*5f39d1b3SJooyung Han 
69*5f39d1b3SJooyung Han   // Specialized parameters.
70*5f39d1b3SJooyung Han 
71*5f39d1b3SJooyung Han   LeftStream left_stream;
72*5f39d1b3SJooyung Han   RightStream right_stream;
73*5f39d1b3SJooyung Han   FusedKernel fused_kernel;
74*5f39d1b3SJooyung Han };
75*5f39d1b3SJooyung Han 
76*5f39d1b3SJooyung Han template <typename InType, int lanes_count, int pack_size, int leftovers,
77*5f39d1b3SJooyung Han           typename StreamParams>
78*5f39d1b3SJooyung Han class Stream {
79*5f39d1b3SJooyung Han  public:
80*5f39d1b3SJooyung Han   static void Pack(const InType* in, const StreamParams& params, InType* out);
81*5f39d1b3SJooyung Han 
82*5f39d1b3SJooyung Han   static int UnpackedAdvance(const StreamParams& params);
83*5f39d1b3SJooyung Han 
84*5f39d1b3SJooyung Han   static int PackedAdvance(const StreamParams& params);
85*5f39d1b3SJooyung Han 
86*5f39d1b3SJooyung Han   static int UnpackedStride(const StreamParams& params);
87*5f39d1b3SJooyung Han 
88*5f39d1b3SJooyung Han   static int PackedStride(const StreamParams& params);
89*5f39d1b3SJooyung Han };
90*5f39d1b3SJooyung Han 
91*5f39d1b3SJooyung Han template <typename InType, typename StreamType>
92*5f39d1b3SJooyung Han class StreamUtil {
93*5f39d1b3SJooyung Han  public:
94*5f39d1b3SJooyung Han   static const InType* Offset(const StreamType& params, const InType* source,
95*5f39d1b3SJooyung Han                               int offset_stride, int offset_advance);
96*5f39d1b3SJooyung Han 
97*5f39d1b3SJooyung Han   static int Scratch(const StreamType& params, int lanes);
98*5f39d1b3SJooyung Han };
99*5f39d1b3SJooyung Han 
100*5f39d1b3SJooyung Han template <typename InType, typename OutType, typename Kernel,
101*5f39d1b3SJooyung Han           typename OutputStream, int kernel_m, int kernel_n, int pack_size>
102*5f39d1b3SJooyung Han class MulKernel {
103*5f39d1b3SJooyung Han  public:
104*5f39d1b3SJooyung Han   static void Multiply(const InType* lhs, const InType* rhs,
105*5f39d1b3SJooyung Han                        const FusedKernelParams<Kernel, OutputStream>& params,
106*5f39d1b3SJooyung Han                        OutType* result);
107*5f39d1b3SJooyung Han };
108*5f39d1b3SJooyung Han 
109*5f39d1b3SJooyung Han template <typename InType_, typename OutType_, typename Kernel_>
110*5f39d1b3SJooyung Han struct Transform1DParams {
111*5f39d1b3SJooyung Han   typedef InType_ InType;
112*5f39d1b3SJooyung Han   typedef OutType_ OutType;
113*5f39d1b3SJooyung Han   typedef Kernel_ Kernel;
114*5f39d1b3SJooyung Han 
115*5f39d1b3SJooyung Han   const InType* input;
116*5f39d1b3SJooyung Han   OutType* output;
117*5f39d1b3SJooyung Han   std::uint8_t* scratch;
118*5f39d1b3SJooyung Han 
119*5f39d1b3SJooyung Han   Kernel kernel;
120*5f39d1b3SJooyung Han };
121*5f39d1b3SJooyung Han 
122*5f39d1b3SJooyung Han template <typename InType, typename OutType, typename Kernel, int kernel_size,
123*5f39d1b3SJooyung Han           int leftovers>
124*5f39d1b3SJooyung Han class Transform1DKernel {
125*5f39d1b3SJooyung Han  public:
126*5f39d1b3SJooyung Han   static void Transform(const InType* input, const Kernel& params,
127*5f39d1b3SJooyung Han                         OutType* output);
128*5f39d1b3SJooyung Han };
129*5f39d1b3SJooyung Han 
130*5f39d1b3SJooyung Han template <typename InType, typename OutType, typename Transform>
131*5f39d1b3SJooyung Han class Transform1DUtil {
132*5f39d1b3SJooyung Han  public:
133*5f39d1b3SJooyung Han   static int EstimateComputeCost(const Transform& params);
134*5f39d1b3SJooyung Han 
135*5f39d1b3SJooyung Han   static const InType* OffsetInput(const Transform& params, const InType* input,
136*5f39d1b3SJooyung Han                                    int offset);
137*5f39d1b3SJooyung Han 
138*5f39d1b3SJooyung Han   static OutType* OffsetOutput(const Transform& params, OutType* output,
139*5f39d1b3SJooyung Han                                int offset);
140*5f39d1b3SJooyung Han };
141*5f39d1b3SJooyung Han 
142*5f39d1b3SJooyung Han }  // namespace meta
143*5f39d1b3SJooyung Han }  // namespace gemmlowp
144*5f39d1b3SJooyung Han 
145*5f39d1b3SJooyung Han #endif  // GEMMLOWP_META_BASE_H_
146