# The low-precision paradigm in gemmlowp, and how it's implemented

## Introduction

"Low-precision" means that the input and output matrix entries are integers on
at most 8 bits. The scalar type is uint8_t.

This isn't the same as just doing plain matrix arithmetic over uint8_t, because
that would overflow. To avoid overflow, we internally accumulate results on more
than 8 bits, and at the end we keep only some significant 8 bits. This relies on
the caller providing suitable offset/multiplier/shift parameters, which
effectively govern how we extract 8 significant bits from our more-than-8-bit
temporary accumulators.

## Low-precision paradigms

gemmlowp is flexible enough to support multiple low-precision paradigms, i.e.
multiple ways that a meaning is attached to 8bit values so that a computation
can rely on an 8bit GEMM provided by gemmlowp.

### The current flexible design with arbitrary "output pipelines"

See [output.md](output.md) for more details about output pipelines. This is a
mechanism by which gemmlowp becomes generic enough to support multiple 8bit
computation paradigms, by allowing the user to set up a chain of transformations
to be performed on internal 32bit accumulators to obtain the final outputs.

The public entry point in [public/gemmlowp.h](../public/gemmlowp.h) that allows
setting up an arbitrary output pipeline is `GemmWithOutputPipeline`.

Refer to [quantization.md](quantization.md) for details of how one gets from
first principles to the actual output pipelines to assemble for successful
real-world quantized calculations.

For the scope of the present document, it suffices to say that quantized matrix
multiplication takes the following parameters:

-   The lhs matrix of uint8 quantized values.
-   The rhs matrix of uint8 quantized values.
-   An int32 lhs_offset, that will be added to each entry of the lhs matrix.
-   An int32 rhs_offset, that will be added to each entry of the rhs matrix.
-   An output pipeline, that will process int32 accumulators into final outputs.

The overall computation goes through the following steps:

1.  Cast lhs entries from uint8 to int32 and add lhs_offset to each of them.
2.  Cast rhs entries from uint8 to int32 and add rhs_offset to each of them.
3.  Compute the int32 matrix product of the resulting lhs times rhs.
4.  Apply the output pipeline on these int32 accumulators, to obtain the final
    outputs.
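
These steps can be sketched in plain Python. This is only an illustration of
the arithmetic, not gemmlowp's actual API; the `clamp_to_uint8` pipeline here
is a hypothetical stand-in for a real output pipeline:

```python
def quantized_matmul(lhs, rhs, lhs_offset, rhs_offset, output_pipeline):
    """Reference sketch of the quantized matrix multiplication steps.

    lhs, rhs: matrices as lists of rows of uint8 values.
    output_pipeline: function mapping an int32 accumulator to a final output.
    """
    rows, depth, cols = len(lhs), len(rhs), len(rhs[0])
    result = []
    for r in range(rows):
        out_row = []
        for c in range(cols):
            acc = 0  # int32 accumulator
            for d in range(depth):
                # Steps 1-3: widen to int32, add the offsets, accumulate.
                acc += (lhs[r][d] + lhs_offset) * (rhs[d][c] + rhs_offset)
            # Step 4: the output pipeline turns the accumulator into an output.
            out_row.append(output_pipeline(acc))
        result.append(out_row)
    return result


# A toy output pipeline: clamp the raw accumulator into the uint8 range.
def clamp_to_uint8(acc):
    return min(255, max(0, acc))
```

For example, `quantized_matmul([[1, 2]], [[3], [4]], -1, -2, clamp_to_uint8)`
accumulates `(1-1)*(3-2) + (2-1)*(4-2)` and yields `[[2]]`.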

### The legacy low-precision paradigm

This older paradigm is the one exposed by the following entry points:

*   In [public/gemmlowp.h](../public/gemmlowp.h), the `Gemm` entry point.
*   The deprecated `eight_bit_int_gemm` directory.

Originally, gemmlowp started as an implementation of the (now deprecated)
EightBitIntGemm paradigm, where quantized matrix multiplication takes the
following input parameters:

-   the lhs matrix of uint8 quantized values
-   the rhs matrix of uint8 quantized values
-   the following int32 "quantization parameters", which control how the uint8
    quantized values in the matrices are to be interpreted during the matrix
    computation:
    -   lhs_offset
    -   rhs_offset
    -   result_offset
    -   result_mult_int
    -   result_shift

In that legacy paradigm, the mathematical expression to be computed is the
result of the following steps:

1.  Cast lhs entries from uint8 to int32 and add lhs_offset to each of them.
2.  Cast rhs entries from uint8 to int32 and add rhs_offset to each of them.
3.  Compute the int32 matrix product of the resulting lhs times rhs.
4.  Add result_offset to each entry of the result.
5.  Multiply each entry of the result by the following fraction, and round to
    the nearest integer:

```
result_mult_int
---------------                             (1)
2^result_shift
```

6.  Clamp the resulting int32 values to the `[0..255]` range and cast to uint8.
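
The final offset/scale/clamp stages of this legacy pipeline can be sketched on
a single int32 accumulator. This is an illustration of formula (1), not the
actual eight_bit_int_gemm code; in particular, the round-to-nearest here is
implemented by adding half of the divisor before shifting, which is an
assumption made for this sketch:

```python
def legacy_requantize(acc, result_offset, result_mult_int, result_shift):
    """Sketch of the legacy offset/scale/clamp stages on one accumulator."""
    # Add result_offset to the int32 accumulator.
    acc += result_offset
    # Multiply by result_mult_int / 2^result_shift, rounding to nearest
    # (rounding scheme assumed for illustration).
    rounding = (1 << (result_shift - 1)) if result_shift > 0 else 0
    acc = (acc * result_mult_int + rounding) >> result_shift
    # Clamp to [0..255] and "cast" to uint8.
    return min(255, max(0, acc))
```

For example, `legacy_requantize(1000, 24, 3, 5)` computes
`(1024 * 3 + 16) >> 5`, i.e. `96`.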

Again, this paradigm is not recommended for new usage. See
[quantization.md](quantization.md) for how, reasoning from first principles, one
arrives at a substantially different quantization paradigm.

In addition, note that the integer multiplication by the numerator in step 5
above risks overflowing. That concern is avoided in the currently recommended
output stages by performing a fixed-point multiplication instead of an ordinary
integer multiplication.

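The idea of that fixed-point multiplication is to keep the high 32 bits of the
64-bit product of two int32 values, so the intermediate cannot overflow. A
minimal sketch of such a "rounding doubling high multiply", modeled after
gemmlowp's fixed-point utilities (the exact saturation and rounding details of
the real code may differ; this is an illustration only):

```python
INT32_MIN, INT32_MAX = -(1 << 31), (1 << 31) - 1

def rounding_doubling_high_mul(a, b):
    """Return the high 32 bits of 2*a*b, rounded to nearest.

    Interpreting int32 values as Q31 fixed-point numbers in [-1, 1), this
    multiplies them without the overflow risk of a plain int32 product.
    """
    if a == INT32_MIN and b == INT32_MIN:
        return INT32_MAX  # the single overflowing case saturates
    ab = a * b  # Python ints are arbitrary precision, standing in for int64
    nudge = (1 << 30) if ab >= 0 else 1 - (1 << 30)
    return (2 * ab + nudge) >> 31
```

For example, multiplying `1 << 30` by itself (both representing 0.5 in Q31)
yields `1 << 30` again, i.e. 0.5 * 0.5 doubled back into Q31.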
# Efficient handling of offsets

At first glance it may seem like the above-described quantized computation
scheme requires adding the lhs_offset and rhs_offset to each of the lhs and rhs
matrix entries.

Doing that in the GEMM kernel would incur substantial overhead:

-   It would mean extra arithmetic work in the GEMM kernel;
-   It would require storing the lhs_offset and rhs_offset in registers, which
    would eat into the register space available for the rest of the GEMM
    kernel.

One may then consider adding the lhs_offset and rhs_offset once and for all to
lhs and rhs blocks, in a GEMM implementation operating on one lhs block and one
rhs block at a time. However, doing so would require storing lhs and rhs blocks
in 32 bit (or at least in 16 bit in real-world cases), which would partially
negate the memory bandwidth benefits of low-precision computation.

Fortunately, there is another way to handle these offsets that has none of the
costs of the approaches described above. The idea is as follows.

Let `P` denote the matrix shaped like `lhs`, but filled with 1's.

Let `Q` denote the matrix shaped like `rhs`, but filled with 1's.

Adding lhs_offset to each entry of `lhs` means adding `lhs_offset * P` to
`lhs`.

Adding rhs_offset to each entry of `rhs` means adding `rhs_offset * Q` to
`rhs`.

Thus, as far as handling `lhs_offset` and `rhs_offset` goes, the matrix product
to be computed is:

```
(lhs + lhs_offset * P) * (rhs + rhs_offset * Q)
```

Expanding this (using distributivity of matrix multiplication over addition), we
see that the above product is equal to the following sum of 4 terms:

```
  lhs * rhs                                 (2)
+ lhs_offset * P * rhs
+ lhs * rhs_offset * Q
+ lhs_offset * rhs_offset * P * Q
```
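
This four-term expansion is easy to check numerically. A small plain-Python
sketch (illustration only, with ad-hoc helper functions, not gemmlowp code):

```python
def matmul(a, b):
    """Plain integer matrix product of a (m x k) times b (k x n)."""
    return [[sum(a[i][d] * b[d][j] for d in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def scale(m, s):
    """Multiply every entry of m by the scalar s."""
    return [[s * x for x in row] for row in m]

def add(m, *rest):
    """Entry-wise sum of one or more equal-shaped matrices."""
    for n in rest:
        m = [[x + y for x, y in zip(r1, r2)] for r1, r2 in zip(m, n)]
    return m

lhs = [[1, 2, 3], [4, 5, 6]]        # 2x3
rhs = [[7, 8], [9, 10], [11, 12]]   # 3x2
lhs_offset, rhs_offset = -5, -7
P = [[1] * 3 for _ in range(2)]     # ones, shaped like lhs
Q = [[1] * 2 for _ in range(3)]     # ones, shaped like rhs

# The product with offsets folded into the operands...
left = matmul(add(lhs, scale(P, lhs_offset)), add(rhs, scale(Q, rhs_offset)))
# ...equals the sum of the four terms in (2).
right = add(matmul(lhs, rhs),
            scale(matmul(P, rhs), lhs_offset),
            scale(matmul(lhs, Q), rhs_offset),
            scale(matmul(P, Q), lhs_offset * rhs_offset))
assert left == right
```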

The first term, `lhs * rhs`, is just the matrix multiplication ignoring the
offsets, i.e. as if `lhs_offset==rhs_offset==0`. Our claim here is that this is
all we have to compute in the GEMM kernel.

In the second term, `lhs_offset * P * rhs`, notice that since P is filled with
1's, `P * rhs` has all its rows equal to each other, and equal to the row-vector
of sums of all the entries in each column of rhs.

Thus, we can compute the second term, `lhs_offset * P * rhs`, by summing each
column of rhs. This produces a single row-vector, and in order to add the second
term, we simply need to add this row-vector (multiplied by lhs_offset) to each
row of the result. This is just a rank one update of the result (equivalently,
the second term is a rank one matrix), and we can efficiently store it as a
single vector.

The third term, `lhs * rhs_offset * Q`, is entirely similar to the second one,
and can be similarly computed by summing each row of lhs, storing this in a
single column-vector, and later multiplying these sums by rhs_offset.

The fourth term is a single constant, repeated into all the entries of the
matrix. The matrix `P * Q` is filled with the single constant value `depth` (the
depth of the matrix product, i.e. the number of columns of the lhs). Thus the
fourth term is simply the rank zero update adding this constant to each matrix
entry:

```
lhs_offset * rhs_offset * depth
```
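
Putting the pieces together: the kernel only computes `lhs * rhs`, and the
offset terms are recovered afterwards from the column sums of rhs, the row sums
of lhs, and the constant `lhs_offset * rhs_offset * depth`. A plain-Python
sketch of this reconstruction (illustration only, not gemmlowp code):

```python
def matmul(a, b):
    """Plain integer matrix product of a (m x k) times b (k x n)."""
    return [[sum(a[i][d] * b[d][j] for d in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

lhs = [[1, 2, 3], [4, 5, 6]]
rhs = [[7, 8], [9, 10], [11, 12]]
lhs_offset, rhs_offset = -5, -7
depth = len(rhs)  # number of columns of lhs == number of rows of rhs

# What we want: the product with the offsets added to every entry.
direct = matmul([[x + lhs_offset for x in row] for row in lhs],
                [[x + rhs_offset for x in row] for row in rhs])

# What the kernel computes: the plain product, ignoring the offsets...
kernel = matmul(lhs, rhs)
# ...plus the sums computed once at packing time:
col_sums = [sum(rhs[d][j] for d in range(depth)) for j in range(len(rhs[0]))]
row_sums = [sum(lhs[i][d] for d in range(depth)) for i in range(len(lhs))]

# Unpacking stage: apply the rank-one updates and the rank-zero constant.
reconstructed = [[kernel[i][j]
                  + lhs_offset * col_sums[j]         # second term
                  + rhs_offset * row_sums[i]         # third term
                  + lhs_offset * rhs_offset * depth  # fourth term
                  for j in range(len(rhs[0]))] for i in range(len(lhs))]

assert direct == reconstructed
```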

# Implementation of this technique in gemmlowp

In gemmlowp, at the packing stage (where we traverse blocks of the lhs and rhs
to prepare them for efficient repeated traversal by the kernel), we compute the
sum of each row of the lhs block and the sum of each column of the rhs block.

See, in [internal/pack.h](../internal/pack.h), the following member of the
PackedSideBlock class:

```
// Handle on the additional buffer backing the vector of sums of slices
// associated with this block. Owned.
Allocator::Handle sums_of_each_slice_handle_;
```

`sums_of_each_slice_handle_` is the handle to the buffer allocated to store the
vector containing the sums of rows of lhs, or the sums of columns of rhs.
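
Conceptually, the packing stage can accumulate these per-slice sums in the same
pass that copies a block into packed storage. A simplified sketch (the real
pack.h code is templatized and SIMD-optimized, and reorders data for the
kernel; this illustration only keeps the sums idea):

```python
def pack_lhs_block(block):
    """Copy an lhs block into packed storage, accumulating per-row sums.

    For the lhs, a "slice" is a row, so sums_of_each_slice[i] holds the sum
    of row i of the block, ready for the rank-one update at unpacking time.
    """
    packed = []
    sums_of_each_slice = []
    for row in block:
        row_sum = 0
        for x in row:
            packed.append(x)  # the real packing reorders entries here
            row_sum += x
        sums_of_each_slice.append(row_sum)
    return packed, sums_of_each_slice
```

For example, `pack_lhs_block([[1, 2], [3, 4]])` returns the packed data
`[1, 2, 3, 4]` together with the row sums `[3, 7]`.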

After these rank one updates have been computed at the packing stage, they are
ignored at the compute kernel stage, since that stage is only concerned with the
first of the four terms in (2); they are only used at the unpacking stage. See
the default/reference implementation, `UnpackResultImpl`, in
[internal/unpack.h](../internal/unpack.h).
193