1 // Copyright 2016 The gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // This is a standalone testbed and benchmark for gemmlowp-style GEMM kernels,
16 // either doing integer or float arithmetic.
17 // It verifies that a kernel produces correct results, then benchmarks it.
18 //
19 // Some benchmark results are recorded in this spreadsheet:
20 //
21 // https://docs.google.com/spreadsheets/d/1UPbzbp9rdsD6RXxOr5q6AZ0n1omgEknLYO2ogiw6Kqk/edit?usp=sharing
22 //
23 // This program is entirely self-contained, and can be compiled manually
24 // such as suggested in the command lines below.
25 // It currently supports only Android/ARM but would trivially generalize to
26 // other OSes (it's mostly standard POSIX) or architectures (each kernel
27 // targets a specific architecture, one may simply add more).
28
29 /*
30 Build and run this benchmark on Android/ARM/32bit:
31 ~/android/toolchains/arm-linux-androideabi/bin/arm-linux-androideabi-clang++ \
32 -fPIE -pie -O3 --std=c++11 standalone/neon-gemm-kernel-benchmark.cc -o \
33 /tmp/benchmark -mfloat-abi=softfp -mfpu=neon-vfpv4 && adb push /tmp/benchmark \
34 /data/local/tmp && adb shell /data/local/tmp/benchmark
35 Build and run this benchmark on Android/ARM/64bit:
36 ~/android/toolchains/aarch64-linux-android/bin/aarch64-linux-android-clang++ \
37 -fPIE -static -O3 --std=c++11 standalone/neon-gemm-kernel-benchmark.cc -o \
38 /tmp/benchmark && adb push /tmp/benchmark /data/local/tmp && adb shell \
39 /data/local/tmp/benchmark
40 */
41
42 // For big.LITTLE devices, use 'taskset' to select which cores to benchmark.
43 //
44 // The syntax is: taskset <mask> <commandline>
45 // where mask is a binary mask where each bit corresponds to a core,
46 // and low bits are little cores.
47 //
48 // Examples:
49 // Nexus 5X big cores: taskset 30
50 // Nexus 5X little cores: taskset 0f
51 // Pixel XL big cores: taskset 0c
52 // Pixel XL little cores: taskset 03
53 //
54 // Full example:
55 // adb shell taskset 0c /data/local/tmp/benchmark
56
#include <sched.h>
#include <unistd.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>
#include <random>
#include <type_traits>
68
69 #if !defined(__arm__) && !defined(__aarch64__) && \
70 !(defined(__mips) && (__mips_isa_rev >= 5) && defined(__mips_msa))
71 #error This benchmark assumes ARM or MIPS (for intrinsics and inline assembly sections).
72 #endif
73
74 #if defined(__arm__) || defined(__aarch64__)
75 #include <arm_neon.h>
76 #endif
77
78 #if defined(__mips)
79 #include <msa.h>
80
81 // Some convenience macros to hide differences between MIPS32 and MIPS64.
82 #ifdef __LP64__
83 #define GEMMLOWP_MIPS_XADDIU "daddiu"
84 #else
85 #define GEMMLOWP_MIPS_XADDIU "addiu"
86 #endif
87 #endif
88
// Typically one wants to fit in L1 cache, and GEMM implementations
// are carefully optimized to tune their access patterns to that effect.
// Most devices have at least 16k of L1 cache. The Kraits have exactly 16k.
const int kDefaultCacheSizeK = 16;

// Cache line size in bytes assumed by this benchmark; typical for ARM cores.
const int kCacheLineSize = 64;

// These definitions are used for labels within assembly code. Required for
// iOS toolchain compatibility.
// They are numeric local labels: in the asm below they are referenced with a
// "f"/"b" suffix, meaning the next ("forward") / previous ("backward")
// definition of that label.
#define GEMMLOWP_LABEL_AFTER_LOOP "1"
#define GEMMLOWP_LABEL_LOOP "2"
#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
#define GEMMLOWP_LABEL_STORE "4"
102
103 // BEGIN code copied from gemmlowp/internal/kernel.h
104
105 // Explanation of general gemmlowp terminology
106 // ===========================================
107 //
108 // We use the following abbreviations:
109 // LHS = "left-hand side"
110 // RHS = "right-hand side"
111 // Sometimes when referring to either LHS or RHS, we just say a "Side".
112 //
113 // In a matrix product of a MxK matrix times a KxN matrix,
114 // we call K the 'depth'. Note that M is the number of rows
115 // of the result (and of the LHS), and N is the number of columns
116 // of the result (and of the RHS).
117 //
118 // In each of the LHS and RHS matrices, we call 'width' the
119 // other dimension, besides the depth. So in the LHS, 'width'
120 // is the number of rows, while in the RHS, 'width' is the number
121 // of columns.
122 //
// So in the LHS MxK matrix, the depth is K and the width is M.
124 // And in the RHS KxN matrix, the depth is K and the width in N.
125 //
126 // This is illustrated in this picture:
127 //
128 // RHS width
129 // <----------------->
130 // +-----------------+ ^
131 // | RHS | | Depth
132 // +-----------------+ v
133 // ^ +--+ +-----------------+
134 // | |L | | |
135 // LHS width | |H | | Result |
136 // | |S | | |
137 // v +--+ +-----------------+
138 // <-->
139 // Depth
140
141 // Explanation of gemmlowp kernel formats and "cells"
142 // ==================================================
143 //
144 // Kernels operate on small LHS and RHS blocks that fit in registers.
145 // These blocks are stored contiguously in memory, but not always
146 // in a traditional column-major or row-major order; instead,
147 // they consist of a number of sub-blocks, which we call "cells",
148 // that are stored in column-major or row-major order. However,
149 // what really matters to us is not so much rows vs columns, but
150 // rather width vs depth. So we refer to "width-major" and "depth-major"
151 // storage orders. In the LHS, width-major means row-major,
152 // while in the RHS, width-major means column-major.
153 // There is also a third possibility, "diagonal order",
154 // which is unused at the moment.
155 //
156 // We aim to treat both sides, LHS and RHS, on an equal footing,
157 // so we call them both 'sides'. A KernelFormat thus is just a pair
158 // of KernelSideFormat's, one for LHS and one for RHS; each KernelSideFormat
159 // contains a CellFormat and a number of cells; cells are only ever
160 // stacked in the width dimension, which means stacked vertically in the
// LHS and stacked horizontally in the RHS.
162 //
163 // Example
164 // =======
165 //
166 // Let's work out the data layout expected by a kernel having the
167 // following format (the struct names here are defined below in this file):
168 //
169 // KernelFormat<
170 // KernelSideFormat<CellFormat<3, 4>, 3>,
171 // KernelSideFormat<CellFormat<5, 4>, 2>
172 // >
173 //
174 // The LHS format, KernelSideFormat<CellFormat<3, 4>, 3>, means:
175 // 3 cells, each cell having dimensions (width=3, depth=4), laid out in
176 // DepthMajor order (the default value, see CellFormat). In the LHS,
177 // DepthMajor means column-major, so the LHS cells are of size 3x4 in
178 // column-major order, so the LHS layout is:
179 //
180 // 0 3 6 9
181 // 1 4 7 10
182 // 2 5 8 11
183 // 12 15 18 21
184 // 13 16 19 22
185 // 14 17 20 23
186 // 24 27 30 33
187 // 25 28 31 34
188 // 26 29 32 35
189 //
190 // The RHS format, KernelSideFormat<CellFormat<5, 4>, 2>, means:
191 // 2 cells each having dimensions (width=5, depth=4), laid out in
192 // DepthMajor order (the default value, see CellFormat). In the RHS,
193 // DepthMajor means row-major, so the RHS cells are of size 4x5 in
194 // row-major order, so the RHS layout is:
195 //
196 // 0 1 2 3 4 20 21 22 23 24
197 // 5 6 7 8 9 25 26 27 28 29
198 // 10 11 12 13 14 30 31 32 33 34
199 // 15 16 17 18 19 35 36 37 38 39
200
// CellOrder enumerates the possible storage orders (= layouts) for a
// cell (see the explanation above).
enum class CellOrder {
  DepthMajor,  // in the LHS this means column-major storage
  WidthMajor,  // in the LHS this means row-major storage
  Diagonal     // diagonal order; unused at the moment
};
204
205 // CellFormat describes how data is laid
206 // out in a cell. That is, a CellOrder together with actual dimensions.
207 template <int tWidth, int tDepth, CellOrder tOrder>
208 struct CellFormat {
209 static const int kWidth = tWidth;
210 static const int kDepth = tDepth;
211 static const CellOrder kOrder = tOrder;
212
213 static const int kSize = kWidth * kDepth;
214 };
215
216 // KernelSideFormat describes how data is laid out in a kernel side
217 // (i.e. LHS or RHS). That is, a CellFormat together with a number of
218 // cells. These cells are always stacked in the Width dimension.
219 // For example, in the LHS case, the Width dimension is the rows dimension,
220 // se we're saying that in the LHS, cells are stacked vertically.
221 // We never stack cells in the Depth dimension.
222 template <typename tCellFormat, int tCells>
223 struct KernelSideFormat {
224 typedef tCellFormat Cell;
225 static const int kCells = tCells;
226 static const int kWidth = kCells * Cell::kWidth;
227 static const int kDepth = Cell::kDepth;
228 };
229
230 // KernelFormat describes fully the input data layout that a kernel expects.
231 // It consists of two KernelSideFormat's, one for LHS and one for RHS.
232 template <typename tLhs, typename tRhs>
233 struct KernelFormat {
234 typedef tLhs Lhs;
235 typedef tRhs Rhs;
236
237 static_assert(Lhs::Cell::kDepth == Rhs::Cell::kDepth, "");
238 static const int kDepth = Lhs::Cell::kDepth;
239 static const int kRows = Lhs::Cell::kWidth * Lhs::kCells;
240 static const int kCols = Rhs::Cell::kWidth * Rhs::kCells;
241 };
242
// KernelOperandRanges specifies the minimum and maximum values an operand
// can take: one range for the LHS and one for the RHS. This primary
// template uses the full range of the operand data type; specializations
// below narrow the range for kernels that require it.
template <typename Kernel, typename OperandType = typename Kernel::OperandType>
struct KernelOperandRanges {
  static OperandType LhsMin() { return std::numeric_limits<OperandType>::lowest(); }
  static OperandType LhsMax() { return std::numeric_limits<OperandType>::max(); }
  static OperandType RhsMin() { return std::numeric_limits<OperandType>::lowest(); }
  static OperandType RhsMax() { return std::numeric_limits<OperandType>::max(); }
};
261
262 template <typename Kernel>
263 struct KernelOperandRanges<Kernel, float> {
LhsMinKernelOperandRanges264 static float LhsMin() { return -100.f; }
LhsMaxKernelOperandRanges265 static float LhsMax() { return 100.f; }
RhsMinKernelOperandRanges266 static float RhsMin() { return -100.f; }
RhsMaxKernelOperandRanges267 static float RhsMax() { return 100.f; }
268 };
269
// Some int8 kernels require operands narrower than a full 8 bits to avoid
// intermediate overflow. These helpers declare a KernelOperandRanges
// specialization for such a kernel.
//
// SET_7BIT_RANGES: LHS in [-63, 63], RHS in [-64, 63].
#define SET_7BIT_RANGES(kernel) \
  template <> \
  struct KernelOperandRanges<kernel, std::int8_t> { \
    static std::int8_t LhsMin() { return -63; } \
    static std::int8_t LhsMax() { return 63; } \
    static std::int8_t RhsMin() { return -64; } \
    static std::int8_t RhsMax() { return 63; } \
  };

// SET_425BIT_RANGES: LHS in [-7, 7], RHS in [-9, 9] ("4.25-bit" operands).
#define SET_425BIT_RANGES(kernel) \
  template <> \
  struct KernelOperandRanges<kernel, std::int8_t> { \
    static std::int8_t LhsMin() { return -7; } \
    static std::int8_t LhsMax() { return 7; } \
    static std::int8_t RhsMin() { return -9; } \
    static std::int8_t RhsMax() { return 9; } \
  };
287
CellOrderName(CellOrder o)288 inline const char* CellOrderName(CellOrder o) {
289 switch (o) {
290 case CellOrder::DepthMajor:
291 return "DepthMajor";
292 case CellOrder::WidthMajor:
293 return "WidthMajor";
294 case CellOrder::Diagonal:
295 return "Diagonal";
296 default:
297 assert(false);
298 return nullptr;
299 }
300 }
301
302 // Returns the offset into a cell, at which a given coefficient is stored.
303 template <typename CellFormat>
OffsetIntoCell(int w,int d)304 inline int OffsetIntoCell(int w, int d) {
305 switch (CellFormat::kOrder) {
306 case CellOrder::DepthMajor:
307 return w + d * CellFormat::kWidth;
308 case CellOrder::WidthMajor:
309 return d + w * CellFormat::kDepth;
310 case CellOrder::Diagonal:
311 assert(CellFormat::kWidth == CellFormat::kDepth);
312 static const int size = CellFormat::kWidth;
313 return ((size + w - d) * size + d) % (size * size);
314 default:
315 assert(false);
316 return 0;
317 }
318 }
319
320 // END code copied from gemmlowp/internal/kernel.h
321
322 #ifdef __arm__
323
324 // This is the current standard kernel in gemmlowp, see:
325 // https://github.com/google/gemmlowp/blob/b1e2a29ff866680028f3080efc244e10e8dd7f46/internal/kernel_neon.h#L33
struct NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators {
  typedef std::uint8_t OperandType;
  typedef std::uint32_t AccumulatorType;
  // LHS: 3 cells of 4x2 DepthMajor; RHS: 1 cell of 4x2 DepthMajor.
  // The kernel therefore produces a 12x4 block of uint32 accumulators.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
      Format;
  // Computes accum += lhs * rhs over 'depth' levels of depth.
  // lhs_ptr/rhs_ptr point to operand blocks packed as described by Format;
  // accum_ptr points to the 12x4 int32 accumulator block, which is read,
  // updated, and stored back.
  // depth is assumed to be a multiple of 2 (= Format::kDepth) and at least
  // 2: the 'subs #2' before the loop peels off the final pair of depth
  // levels, which is always handled by the epilogue after the loop.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load 1 Rhs cell of size 2x4
        "vld1.8 {d0}, [%[rhs_ptr]]!\n"
        // Load 3 Lhs cells of size 4x2 each
        "vld1.8 {d2}, [%[lhs_ptr]]!\n"
        "vld1.8 {d4}, [%[lhs_ptr]]!\n"
        "vld1.8 {d6}, [%[lhs_ptr]]!\n"
        // Load accumulators. Note the interleaved order
        // (q4, q8, q12, q5, q9, q13, ...): accumulators are stored
        // column by column, each column spanning the three LHS cells.
        "mov r0, %[accum_ptr]\n"
        "vld1.32 {d8, d9}, [r0]!\n"
        "vld1.32 {d16, d17}, [r0]!\n"
        "vld1.32 {d24, d25}, [r0]!\n"
        "vld1.32 {d10, d11}, [r0]!\n"
        "vld1.32 {d18, d19}, [r0]!\n"
        "vld1.32 {d26, d27}, [r0]!\n"
        "vld1.32 {d12, d13}, [r0]!\n"
        "vld1.32 {d20, d21}, [r0]!\n"
        "vld1.32 {d28, d29}, [r0]!\n"
        "vld1.32 {d14, d15}, [r0]!\n"
        "vld1.32 {d22, d23}, [r0]!\n"
        "vld1.32 {d30, d31}, [r0]!\n"

        "subs %[depth], #2\n"

        // If depth was exactly 2, skip the main loop: the epilogue
        // below handles the final (here, only) two levels of depth.
        "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"
        // Overview of register layout:
        //
        // A 2x4 cell of Rhs is stored in 16bit in d0--d1 (q0).
        // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in d2--d7
        // (q1--q3).
        // A 12x4 block of accumulators is stored in 32bit in q4--q15.
        //
        //                   +-----+-----+-----+-----+
        //                   |d0[0]|d0[1]|d0[2]|d0[3]|
        //              Rhs  +-----+-----+-----+-----+
        //                   |d1[0]|d1[1]|d1[2]|d1[3]|
        //                   +-----+-----+-----+-----+
        //
        //                   |     |     |     |     |
        //
        //    Lhs            |     |     |     |     |
        //
        //  +--+--+ - - - -  +-----+-----+-----+-----+
        //  |d2|d3|          | q4  | q5  | q6  | q7  |
        //  |d2|d3|          | q4  | q5  | q6  | q7  |
        //  |d2|d3|          | q4  | q5  | q6  | q7  |
        //  |d2|d3|          | q4  | q5  | q6  | q7  |
        //  +--+--+ - - - -  +-----+-----+-----+-----+
        //  |d4|d5|          | q8  | q9  | q10 | q11 |
        //  |d4|d5|          | q8  | q9  | q10 | q11 |
        //  |d4|d5|          | q8  | q9  | q10 | q11 |
        //  |d4|d5|          | q8  | q9  | q10 | q11 |
        //  +--+--+ - - - -  +-----+-----+-----+-----+
        //  |d6|d7|          | q12 | q13 | q14 | q15 |
        //  |d6|d7|          | q12 | q13 | q14 | q15 |
        //  |d6|d7|          | q12 | q13 | q14 | q15 |
        //  |d6|d7|          | q12 | q13 | q14 | q15 |
        //  +--+--+ - - - -  +-----+-----+-----+-----+
        //
        //                            Accumulator

        // Expand Lhs/Rhs cells to 16 bit.
        // Note: moving these vmovls further down to allow for
        // longer data pipelining helps a little on A57 but is
        // harmful on A53 --- It looks as if A53 doesn't like
        // interleaving vmovl's into the vmlal's.
        "vmovl.u8 q0, d0\n"
        "vmovl.u8 q1, d2\n"
        "vmovl.u8 q2, d4\n"
        "vmovl.u8 q3, d6\n"

        // Multiply-accumulate, level of depth 0.
        // The loads for the next iteration are interleaved with the
        // arithmetic to hide memory latency.
        "vmlal.u16 q4, d2, d0[0]\n"
        "vmlal.u16 q5, d2, d0[1]\n"
        "vmlal.u16 q6, d2, d0[2]\n"
        "vmlal.u16 q7, d2, d0[3]\n"
        "vldr d2, [%[lhs_ptr]]\n"
        "vmlal.u16 q8, d4, d0[0]\n"
        "vmlal.u16 q9, d4, d0[1]\n"
        "vmlal.u16 q10, d4, d0[2]\n"
        "vmlal.u16 q11, d4, d0[3]\n"
        "vldr d4, [%[lhs_ptr], #8]\n"
        "vmlal.u16 q12, d6, d0[0]\n"
        "vmlal.u16 q13, d6, d0[1]\n"
        "vmlal.u16 q14, d6, d0[2]\n"
        "vmlal.u16 q15, d6, d0[3]\n"
        "vldr d6, [%[lhs_ptr], #16]\n"
        "vldr d0, [%[rhs_ptr]]\n"

        // Multiply-accumulate, level of depth 1.
        // Pointer/counter updates are likewise interleaved.
        "vmlal.u16 q4, d3, d1[0]\n"
        "vmlal.u16 q5, d3, d1[1]\n"
        "add %[lhs_ptr], #24\n"
        "vmlal.u16 q6, d3, d1[2]\n"
        "vmlal.u16 q7, d3, d1[3]\n"
        "add %[rhs_ptr], #8\n"
        "vmlal.u16 q8, d5, d1[0]\n"
        "vmlal.u16 q9, d5, d1[1]\n"
        "subs %[depth], #2\n"
        "vmlal.u16 q10, d5, d1[2]\n"
        "vmlal.u16 q11, d5, d1[3]\n"
        "vmlal.u16 q12, d7, d1[0]\n"
        "vmlal.u16 q13, d7, d1[1]\n"
        "vmlal.u16 q14, d7, d1[2]\n"
        "vmlal.u16 q15, d7, d1[3]\n"

        "bne " GEMMLOWP_LABEL_LOOP "b\n"

        GEMMLOWP_LABEL_AFTER_LOOP
        ":\n"

        // Epilogue: process the final two levels of depth (already
        // loaded into d0--d7), with no further loads needed.

        // Expand Lhs/Rhs cells to 16 bit.
        "vmovl.u8 q0, d0\n"
        "vmovl.u8 q1, d2\n"
        "vmovl.u8 q2, d4\n"
        "vmovl.u8 q3, d6\n"

        // Multiply-accumulate, level of depth 0
        "vmlal.u16 q4, d2, d0[0]\n"
        "vmlal.u16 q5, d2, d0[1]\n"
        "vmlal.u16 q6, d2, d0[2]\n"
        "vmlal.u16 q7, d2, d0[3]\n"
        "vmlal.u16 q8, d4, d0[0]\n"
        "vmlal.u16 q9, d4, d0[1]\n"
        "vmlal.u16 q10, d4, d0[2]\n"
        "vmlal.u16 q11, d4, d0[3]\n"
        "vmlal.u16 q12, d6, d0[0]\n"
        "vmlal.u16 q13, d6, d0[1]\n"
        "vmlal.u16 q14, d6, d0[2]\n"
        "vmlal.u16 q15, d6, d0[3]\n"

        // Multiply-accumulate, level of depth 1
        "vmlal.u16 q4, d3, d1[0]\n"
        "vmlal.u16 q5, d3, d1[1]\n"
        "vmlal.u16 q6, d3, d1[2]\n"
        "vmlal.u16 q7, d3, d1[3]\n"
        "vmlal.u16 q8, d5, d1[0]\n"
        "vmlal.u16 q9, d5, d1[1]\n"
        "vmlal.u16 q10, d5, d1[2]\n"
        "vmlal.u16 q11, d5, d1[3]\n"
        "vmlal.u16 q12, d7, d1[0]\n"
        "vmlal.u16 q13, d7, d1[1]\n"
        "vmlal.u16 q14, d7, d1[2]\n"
        "vmlal.u16 q15, d7, d1[3]\n"

        // Store accumulators back, in the same interleaved order in
        // which they were loaded.
        "mov r0, %[accum_ptr]\n"
        "vst1.32 {d8, d9}, [r0]!\n"
        "vst1.32 {d16, d17}, [r0]!\n"
        "vst1.32 {d24, d25}, [r0]!\n"
        "vst1.32 {d10, d11}, [r0]!\n"
        "vst1.32 {d18, d19}, [r0]!\n"
        "vst1.32 {d26, d27}, [r0]!\n"
        "vst1.32 {d12, d13}, [r0]!\n"
        "vst1.32 {d20, d21}, [r0]!\n"
        "vst1.32 {d28, d29}, [r0]!\n"
        "vst1.32 {d14, d15}, [r0]!\n"
        "vst1.32 {d22, d23}, [r0]!\n"
        "vst1.32 {d30, d31}, [r0]!\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
        "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
        "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
        "d28", "d29", "d30", "d31");
  }
};
509
510 // This is Maciek Chociej's fast kernel not expanding operands,
511 // from gemmlowp/meta/. Search for
512 // mul_3x8_3x8_int32_lhsadd_rhsadd
513 // in this file:
514 // https://raw.githubusercontent.com/google/gemmlowp/e4b9d858b6637d5d0058bfa3d869d2b95864251b/meta/single_thread_gemm.h
struct NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand {
  typedef std::uint8_t OperandType;
  typedef std::uint32_t AccumulatorType;
  // Both sides: 1 cell of 3x8, WidthMajor. The result is a 3x3 block.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1>,
      KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1> >
      Format;
  // Computes accum += lhs * rhs over 'depth' levels of depth.
  // Unlike the kernel above, operands are NOT expanded to 16 bit first:
  // 8-bit operands are multiplied directly (vmull.u8), and the 16-bit
  // products are pairwise-accumulated into 32-bit aggregators (vpadal).
  // depth is assumed to be a positive multiple of 8 (= Format::kDepth).
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Clear the 9 aggregators (q0--q8), one per entry of the 3x3
        // result block.
        "vmov.i32 q0, #0\n"
        "vmov.i32 q1, #0\n"
        "vmov.i32 q2, #0\n"
        "vmov.i32 q3, q0\n"
        "vmov.i32 q4, q1\n"
        "vmov.i32 q5, q2\n"
        "vmov.i32 q6, q3\n"
        "vmov.i32 q7, q4\n"
        "vmov.i32 q8, q5\n"

        // Loop head
        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Subtract counter.
        "subs %[depth], %[depth], #8\n"

        // Load 8 depth levels of each of the 3 RHS and 3 LHS rows,
        // then form all 9 row-by-row products.
        "vld1.8 {d18, d19, d20}, [%[rhs_ptr]]!\n"
        "vld1.8 {d21, d22, d23}, [%[lhs_ptr]]!\n"
        "vmull.u8 q12, d18, d21\n"
        "vmull.u8 q13, d18, d22\n"
        "vmull.u8 q14, d18, d23\n"
        "vmull.u8 q15, d19, d21\n"
        "vpadal.u16 q0, q12\n"
        "vpadal.u16 q1, q13\n"
        "vpadal.u16 q2, q14\n"
        "vpadal.u16 q3, q15\n"
        "vmull.u8 q12, d19, d22\n"
        "vmull.u8 q13, d19, d23\n"
        "vmull.u8 q14, d20, d21\n"
        "vmull.u8 q15, d20, d22\n"
        "vmull.u8 q9, d20, d23\n"
        "vpadal.u16 q4, q12\n"
        "vpadal.u16 q5, q13\n"
        "vpadal.u16 q6, q14\n"
        "vpadal.u16 q7, q15\n"
        "vpadal.u16 q8, q9\n"

        // Loop branch: fall through once depth is exhausted.
        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Horizontal reduce aggregators, step 1
        "vpadd.u32 d0, d0, d1\n"
        "vpadd.u32 d2, d2, d3\n"
        "vpadd.u32 d4, d4, d5\n"
        "vpadd.u32 d6, d6, d7\n"
        "vpadd.u32 d8, d8, d9\n"
        "vpadd.u32 d10, d10, d11\n"
        "vpadd.u32 d12, d12, d13\n"
        "vpadd.u32 d14, d14, d15\n"
        "vpadd.u32 d16, d16, d17\n"

        // Horizontal reduce aggregators, step 2.
        // Note that d1, d7 and d13 pair a register with itself: only
        // lane 0 of those halves is meaningful, matching the 3-wide
        // rows loaded/stored below (d2 + d3[0], etc.).
        "vpadd.u32 d0, d0, d2\n"
        "vpadd.u32 d1, d4, d4\n"
        "vpadd.u32 d6, d6, d8\n"
        "vpadd.u32 d7, d10, d10\n"
        "vpadd.u32 d12, d12, d14\n"
        "vpadd.u32 d13, d16, d16\n"

        // Load accumulators (3 rows of 3 values: 2 + 1 lanes each).
        "mov r0, %[accum_ptr]\n"
        "vld1.32 {d2}, [r0]!\n"
        "vld1.32 {d3[0]}, [r0]!\n"

        "vld1.32 {d8}, [r0]!\n"
        "vld1.32 {d9[0]}, [r0]!\n"

        "vld1.32 {d14}, [r0]!\n"
        "vld1.32 {d15[0]}, [r0]!\n"

        // Accumulate
        "vadd.s32 q0, q0, q1\n"
        "vadd.s32 q3, q3, q4\n"
        "vadd.s32 q6, q6, q7\n"

        // Store accumulators
        "mov r0, %[accum_ptr]\n"
        "vst1.32 {d0}, [r0]!\n"
        "vst1.32 {d1[0]}, [r0]!\n"

        "vst1.32 {d6}, [r0]!\n"
        "vst1.32 {d7[0]}, [r0]!\n"

        "vst1.32 {d12}, [r0]!\n"
        "vst1.32 {d13[0]}, [r0]!\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
        "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
        "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
        "d28", "d29", "d30", "d31");
  }
};
625
626 // Fast kernel operating on int8 operands.
627 // It is assumed that one of the two int8 operands only takes values
628 // in [-127, 127], while the other may freely range in [-128, 127].
629 // The issue with both operands taking the value -128 is that:
630 // -128*-128 + -128*-128 == -32768 overflows int16.
631 // Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16
632 // range. That is the basic idea of this kernel.
struct NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits {
  typedef std::int8_t OperandType;
  typedef std::int32_t AccumulatorType;
  // LHS: 1 cell of 4x16 WidthMajor; RHS: 1 cell of 2x16 WidthMajor.
  // The result is a 4x2 block of int32 accumulators.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
      KernelSideFormat<CellFormat<2, 16, CellOrder::WidthMajor>, 1> >
      Format;
  // Computes accum += lhs * rhs over 'depth' levels of depth.
  // depth is assumed to be a positive multiple of 16 (= Format::kDepth).
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    // start_depth is fixed to an arbitrary non-zero value so that the
    // "accumulate into existing destination values" path is always taken
    // (see 'cmp %[start_depth], #0' below), matching the accumulate
    // semantics of the other kernels in this file.
    std::size_t start_depth = 123;
    std::size_t run_depth = depth;
    AccumulatorType* dst_ptr = accum_ptr;
    asm volatile(

        // Overview of register layout:
        //
        // A 2x16 block of Rhs is stored in 8 bit in d0--d3.
        // A 4x16 block of Lhs is stored in 8 bit in d4--d7. That is only
        // half of the register space required, so we loop over these registers
        // twice. Only half of it, a 2x16 block, is stored in d4--d7 at
        // any given time.
        //
        // A 4x2 block of accumulators is stored in q8--q15 (as 4x32 bit
        // components which need to be horizontally-added at the end)
        //
        // The Lhs vectors are multiplied by the Rhs vectors with a widening
        // multiply over the 8 first levels of depth, producing int16x8
        // vectors of products for each position in the accumulator matrix.
        // Here comes the special trick: since the operands are signed int8,
        // their range being [ -2^7 , 2^7 ), their products are in range
        // [ -2^14 , 2^14 - 1 ), meaning that we can add two such values
        // without any risk of overflowing int16.
        // We thus proceed with the 8 next levels of depth, multiplying
        // again Lhs by Rhs, accumulating into this existing int16x8 vector.
        //
        // Only then, having processed 16 levels of depth, do we need to
        // horizontally add these int16x8 accumulators into the final
        // int32x4 accumulators.
        //
        // As we do not have enough registers to store all 16 int16x8
        // temporary-16bit-accumulators, we have them cycle through q4--q7.
        //
        //
        // Register layout (ignoring the q4--q7 temporary 16bit accumulators):
        //
        //                            +----+----+
        //                            | d0 | d2 |
        //                            | .  | .  |
        //                            | .  | .  |
        //                            | .  | .  |
        //                       Rhs  +----+----+
        //                            | d1 | d3 |
        //                            | .  | .  |
        //                            | .  | .  |
        //                            | .  | .  |
        //                            +----+----+
        //
        //                            |    |    |
        //
        //    Lhs                     |    |    |
        //
        //  +--------+--------+ - - - +----+----+
        //  | d4 ... | d5 ... |       | q8 | q9 |
        //  | d6 ... | d7 ... |       | q10| q11|
        //  | d4 ... | d5 ... |       | q12| q13|
        //  | d6 ... | d7 ... |       | q14| q15|
        //  +--------+--------+ - - - +----+----+
        //
        //                            Accumulator
        //

        // Clear accumulators, and, interleaved with it,
        // initial loads of the first loop iteration,
        // taken out of the loop so that in the loop itself we have
        // optimal streaming of data from memory.
        "vldr d0, [%[rhs_ptr], #0]\n"
        "vmov.i32 q8, #0\n"
        "vldr d4, [%[lhs_ptr], #0]\n"
        "vmov.i32 q9, #0\n"
        "vldr d2, [%[rhs_ptr], #16]\n"
        "vmov.i32 q10, q8\n"
        "vldr d6, [%[lhs_ptr], #16]\n"
        "vmov.i32 q11, q8\n"
        "vldr d1, [%[rhs_ptr], #8]\n"
        "vmov.i32 q12, q8\n"
        "vldr d5, [%[lhs_ptr], #8]\n"
        "vmov.i32 q13, q8\n"
        "vldr d3, [%[rhs_ptr], #24]\n"
        "vmov.i32 q14, q8\n"
        "vldr d7, [%[lhs_ptr], #24]\n"
        "vmov.i32 q15, q8\n"

        // General loop.
        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Multiply 8 first levels of depth.
        "vmull.s8 q4, d0, d4\n"
        "add %[rhs_ptr], %[rhs_ptr], #32\n"
        "vmull.s8 q5, d2, d4\n"
        "vldr d4, [%[lhs_ptr], #32]\n"
        "vmull.s8 q6, d0, d6\n"
        "vmull.s8 q7, d2, d6\n"
        "vldr d6, [%[lhs_ptr], #48]\n"

        // Multiply-accumulate second-half, again into the same
        // 16bit local accumulator registers. This is where we
        // take advantage of having int8 instead of uint8 and therefore
        // being able to accumulate two products into int16.
        "vmlal.s8 q4, d1, d5\n"
        "vmlal.s8 q5, d3, d5\n"
        "vldr d5, [%[lhs_ptr], #40]\n"
        "vmlal.s8 q6, d1, d7\n"
        "vmlal.s8 q7, d3, d7\n"
        "vldr d7, [%[lhs_ptr], #56]\n"

        // Add pairwise, accumulate into 32-bit accumulators.
        "vpadal.s16 q8, q4\n"
        "add %[lhs_ptr], %[lhs_ptr], #64\n"
        "vpadal.s16 q9, q5\n"
        "subs %[run_depth], %[run_depth], #16\n"
        "vpadal.s16 q10, q6\n"
        "vpadal.s16 q11, q7\n"

        "beq " GEMMLOWP_LABEL_AFTER_LOOP
        "f\n"

        // Multiply first half.
        "vmull.s8 q4, d0, d4\n"
        "vmull.s8 q5, d2, d4\n"
        "vldr d4, [%[lhs_ptr], #0]\n"
        "vmull.s8 q6, d0, d6\n"
        "vldr d0, [%[rhs_ptr], #0]\n"
        "vmull.s8 q7, d2, d6\n"
        "vldr d2, [%[rhs_ptr], #16]\n"

        // Multiply-accumulate second-half, again into the same
        // 16bit local accumulator registers. This is where we
        // take advantage of having int8 instead of uint8 and therefore
        // being able to accumulate two products into int16.
        "vmlal.s8 q4, d1, d5\n"
        "vldr d6, [%[lhs_ptr], #16]\n"
        "vmlal.s8 q5, d3, d5\n"
        "vldr d5, [%[lhs_ptr], #8]\n"
        "vmlal.s8 q6, d1, d7\n"
        "vldr d1, [%[rhs_ptr], #8]\n"
        "vmlal.s8 q7, d3, d7\n"
        "vldr d3, [%[rhs_ptr], #24]\n"

        // Add pairwise, accumulate into 32-bit accumulators.
        "vpadal.s16 q12, q4\n"
        "vldr d7, [%[lhs_ptr], #24]\n"
        "vpadal.s16 q13, q5\n"
        "vpadal.s16 q14, q6\n"
        "vpadal.s16 q15, q7\n"

        "b " GEMMLOWP_LABEL_LOOP "b\n"

        GEMMLOWP_LABEL_AFTER_LOOP
        ":\n"

        // Multiply first half.
        "vmull.s8 q4, d0, d4\n"
        "vmull.s8 q5, d2, d4\n"
        "vmull.s8 q6, d0, d6\n"
        "vmull.s8 q7, d2, d6\n"

        // Multiply-accumulate second-half, again into the same
        // 16bit local accumulator registers. This is where we
        // take advantage of having int8 instead of uint8 and therefore
        // being able to accumulate two products into int16.
        "vmlal.s8 q4, d1, d5\n"
        "vmlal.s8 q5, d3, d5\n"
        "vmlal.s8 q6, d1, d7\n"
        "vmlal.s8 q7, d3, d7\n"

        // Add pairwise, accumulate into 32-bit accumulators.
        "vpadal.s16 q12, q4\n"
        "vpadal.s16 q13, q5\n"
        "vpadal.s16 q14, q6\n"
        "vpadal.s16 q15, q7\n"
        // Set flags for the accumulate/overwrite choice below.
        // start_depth is always non-zero here (see above), so the
        // accumulate path is taken.
        "cmp %[start_depth], #0\n"

        // Reduce 32bit accumulators horizontally.
        "vpadd.s32 d0, d16, d17\n"
        "vpadd.s32 d1, d18, d19\n"
        "vpadd.s32 d2, d20, d21\n"
        "vpadd.s32 d3, d22, d23\n"
        "vpadd.s32 d4, d24, d25\n"
        "vpadd.s32 d5, d26, d27\n"
        "vpadd.s32 d6, d28, d29\n"
        "vpadd.s32 d7, d30, d31\n"

        "bne " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
        "f\n"

        // Reduce 32bit accumulators horizontally, second pass
        // (each pass adds pairwise. we need to add 4-wise).
        "vpadd.s32 d8, d0, d2\n"
        "vpadd.s32 d9, d4, d6\n"
        "vpadd.s32 d10, d1, d3\n"
        "vpadd.s32 d11, d5, d7\n"

        "b " GEMMLOWP_LABEL_STORE "f\n"

        GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
        ":\n"

        // Reduce 32bit accumulators horizontally, second pass
        // (each pass adds pairwise. we need to add 4-wise),
        // and load destination values from memory.
        "mov r0, %[dst_ptr]\n"
        "vld1.32 {d16, d17}, [r0]!\n"
        "vpadd.s32 d8, d0, d2\n"
        "vpadd.s32 d9, d4, d6\n"
        "vld1.32 {d18, d19}, [r0]\n"
        "vpadd.s32 d10, d1, d3\n"
        "vpadd.s32 d11, d5, d7\n"

        // Add horizontally-reduced accumulators into
        // the values loaded from memory
        "vadd.s32 q4, q8, q4\n"
        "vadd.s32 q5, q9, q5\n"

        GEMMLOWP_LABEL_STORE
        ":\n"
        // Store back into memory
        "mov r0, %[dst_ptr]\n"
        "vst1.32 {d8, d9}, [r0]!\n"
        "vst1.32 {d10, d11}, [r0]\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [dst_ptr] "+r"(dst_ptr), [run_depth] "+r"(run_depth)
        :  // inputs
        [start_depth] "r"(start_depth)
        :  // clobbers
        "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
        "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
        "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
        "d28", "d29", "d30", "d31");
  }
};
875
876 // We don't actually use int32*int32 in production. This is just an
877 // experiment to help dissociate the effect of integer-vs-float, from the
878 // effect of operands width.
struct NEON_32bit_GEMM_Int32_WithScalar {
  typedef std::int32_t OperandType;
  typedef std::int32_t AccumulatorType;
  // LHS: 3 cells of 4x1 DepthMajor; RHS: 1 cell of 4x1 DepthMajor.
  // The kernel produces a 12x4 block of int32 accumulators.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> >
      Format;
  // Computes accum += lhs * rhs over 'depth' levels of depth, using
  // vmla with a scalar operand. Handles one depth level per loop
  // iteration; depth is assumed to be at least 1.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators. Interleaved order (q4, q8, q12, q5, ...):
        // accumulators are stored column by column, each column spanning
        // the three 4-row LHS cells.
        "mov r0, %[accum_ptr]\n"
        "vld1.32 {d8, d9}, [r0]!\n"
        "vld1.32 {d16, d17}, [r0]!\n"
        "vld1.32 {d24, d25}, [r0]!\n"
        "vld1.32 {d10, d11}, [r0]!\n"
        "vld1.32 {d18, d19}, [r0]!\n"
        "vld1.32 {d26, d27}, [r0]!\n"
        "vld1.32 {d12, d13}, [r0]!\n"
        "vld1.32 {d20, d21}, [r0]!\n"
        "vld1.32 {d28, d29}, [r0]!\n"
        "vld1.32 {d14, d15}, [r0]!\n"
        "vld1.32 {d22, d23}, [r0]!\n"
        "vld1.32 {d30, d31}, [r0]!\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Load 1 Rhs cell of size 1x4
        "vld1.32 {d0, d1}, [%[rhs_ptr]]!\n"

        // Load 3 Lhs cells of size 4x1 each
        "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
        "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n"
        "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n"

        // Multiply-accumulate: each accumulator column gets one LHS
        // vector scaled by one RHS scalar lane.
        "vmla.s32 q4, q1, d0[0]\n"
        "vmla.s32 q5, q1, d0[1]\n"
        "vmla.s32 q6, q1, d1[0]\n"
        "vmla.s32 q7, q1, d1[1]\n"
        "vmla.s32 q8, q2, d0[0]\n"
        "vmla.s32 q9, q2, d0[1]\n"
        "vmla.s32 q10, q2, d1[0]\n"
        "vmla.s32 q11, q2, d1[1]\n"
        "vmla.s32 q12, q3, d0[0]\n"
        "vmla.s32 q13, q3, d0[1]\n"
        "vmla.s32 q14, q3, d1[0]\n"
        "vmla.s32 q15, q3, d1[1]\n"

        // Loop. Decrement loop index (depth) by 1, since we just handled 1
        // level of depth.
        "subs %[depth], #1\n"
        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Store accumulators
        "mov r0, %[accum_ptr]\n"
        "vst1.32 {d8, d9}, [r0]!\n"
        "vst1.32 {d16, d17}, [r0]!\n"
        "vst1.32 {d24, d25}, [r0]!\n"
        "vst1.32 {d10, d11}, [r0]!\n"
        "vst1.32 {d18, d19}, [r0]!\n"
        "vst1.32 {d26, d27}, [r0]!\n"
        "vst1.32 {d12, d13}, [r0]!\n"
        "vst1.32 {d20, d21}, [r0]!\n"
        "vst1.32 {d28, d29}, [r0]!\n"
        "vst1.32 {d14, d15}, [r0]!\n"
        "vst1.32 {d22, d23}, [r0]!\n"
        "vst1.32 {d30, d31}, [r0]!\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
        "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
        "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
        "d28", "d29", "d30", "d31");
  }
};
961
962 // Not very efficient kernel, just an experiment to see what we can do
963 // without using NEON multiply-with-scalar instructions.
struct NEON_32bit_GEMM_Float32_MLA_WithVectorDuplicatingScalar {
  typedef float OperandType;
  typedef float AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> >
      Format;
  // Computes accum += lhs * rhs for a 12x4 float block in q4--q15, one
  // level of depth per iteration. Instead of vmla-with-scalar, each Rhs
  // scalar is broadcast into a full q register with the
  // "vld1.32 {d0[], d1[]}" duplicating-load form, then multiplied as a
  // plain vector-vector vmla.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators
        "mov r0, %[accum_ptr]\n"
        "vld1.32 {d8, d9}, [r0]!\n"
        "vld1.32 {d16, d17}, [r0]!\n"
        "vld1.32 {d24, d25}, [r0]!\n"
        "vld1.32 {d10, d11}, [r0]!\n"
        "vld1.32 {d18, d19}, [r0]!\n"
        "vld1.32 {d26, d27}, [r0]!\n"
        "vld1.32 {d12, d13}, [r0]!\n"
        "vld1.32 {d20, d21}, [r0]!\n"
        "vld1.32 {d28, d29}, [r0]!\n"
        "vld1.32 {d14, d15}, [r0]!\n"
        "vld1.32 {d22, d23}, [r0]!\n"
        "vld1.32 {d30, d31}, [r0]!\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Load 3 Lhs cells of size 4x1 each
        "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
        "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n"
        "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n"

        // Multiply-accumulate. Each duplicating load broadcasts one Rhs
        // scalar into all four lanes of q0, then q0 multiplies all three
        // Lhs cells for that output column.
        "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
        "vmla.f32 q4, q1, q0\n"
        "vmla.f32 q8, q2, q0\n"
        "vmla.f32 q12, q3, q0\n"
        "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
        "vmla.f32 q5, q1, q0\n"
        "vmla.f32 q9, q2, q0\n"
        "vmla.f32 q13, q3, q0\n"
        "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
        "vmla.f32 q6, q1, q0\n"
        "vmla.f32 q10, q2, q0\n"
        "vmla.f32 q14, q3, q0\n"
        "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
        "vmla.f32 q7, q1, q0\n"
        "vmla.f32 q11, q2, q0\n"
        "vmla.f32 q15, q3, q0\n"

        // Loop. Decrement loop index (depth) by 1, since we just handled 1
        // level of depth.
        "subs %[depth], #1\n"
        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Store accumulators
        "mov r0, %[accum_ptr]\n"
        "vst1.32 {d8, d9}, [r0]!\n"
        "vst1.32 {d16, d17}, [r0]!\n"
        "vst1.32 {d24, d25}, [r0]!\n"
        "vst1.32 {d10, d11}, [r0]!\n"
        "vst1.32 {d18, d19}, [r0]!\n"
        "vst1.32 {d26, d27}, [r0]!\n"
        "vst1.32 {d12, d13}, [r0]!\n"
        "vst1.32 {d20, d21}, [r0]!\n"
        "vst1.32 {d28, d29}, [r0]!\n"
        "vst1.32 {d14, d15}, [r0]!\n"
        "vst1.32 {d22, d23}, [r0]!\n"
        "vst1.32 {d30, d31}, [r0]!\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
        "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
        "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
        "d28", "d29", "d30", "d31");
  }
};
1047
1048 // Not very efficient kernel, just an experiment to see what we can do
1049 // without using NEON multiply-with-scalar instructions.
1050 // This variant is relevant as on ARMv7 FMA does not have a with-scalar variant.
struct NEON_32bit_GEMM_Float32_FMA_WithVectorDuplicatingScalar {
  typedef float OperandType;
  typedef float AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> >
      Format;
  // Same scheme as the MLA_WithVectorDuplicatingScalar kernel above, but
  // using fused multiply-add (vfma.f32) instead of vmla.f32. Relevant
  // because ARMv7 VFMA has no with-scalar form, so the Rhs scalar must be
  // broadcast into a vector first.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators
        "mov r0, %[accum_ptr]\n"
        "vld1.32 {d8, d9}, [r0]!\n"
        "vld1.32 {d16, d17}, [r0]!\n"
        "vld1.32 {d24, d25}, [r0]!\n"
        "vld1.32 {d10, d11}, [r0]!\n"
        "vld1.32 {d18, d19}, [r0]!\n"
        "vld1.32 {d26, d27}, [r0]!\n"
        "vld1.32 {d12, d13}, [r0]!\n"
        "vld1.32 {d20, d21}, [r0]!\n"
        "vld1.32 {d28, d29}, [r0]!\n"
        "vld1.32 {d14, d15}, [r0]!\n"
        "vld1.32 {d22, d23}, [r0]!\n"
        "vld1.32 {d30, d31}, [r0]!\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Load 3 Lhs cells of size 4x1 each
        "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
        "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n"
        "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n"

        // Multiply-accumulate: broadcast one Rhs scalar into q0 per
        // output column, then fused multiply-add against all 3 Lhs cells.
        "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
        "vfma.f32 q4, q1, q0\n"
        "vfma.f32 q8, q2, q0\n"
        "vfma.f32 q12, q3, q0\n"
        "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
        "vfma.f32 q5, q1, q0\n"
        "vfma.f32 q9, q2, q0\n"
        "vfma.f32 q13, q3, q0\n"
        "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
        "vfma.f32 q6, q1, q0\n"
        "vfma.f32 q10, q2, q0\n"
        "vfma.f32 q14, q3, q0\n"
        "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
        "vfma.f32 q7, q1, q0\n"
        "vfma.f32 q11, q2, q0\n"
        "vfma.f32 q15, q3, q0\n"

        // Loop. Decrement loop index (depth) by 1, since we just handled 1
        // level of depth.
        "subs %[depth], #1\n"
        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Store accumulators
        "mov r0, %[accum_ptr]\n"
        "vst1.32 {d8, d9}, [r0]!\n"
        "vst1.32 {d16, d17}, [r0]!\n"
        "vst1.32 {d24, d25}, [r0]!\n"
        "vst1.32 {d10, d11}, [r0]!\n"
        "vst1.32 {d18, d19}, [r0]!\n"
        "vst1.32 {d26, d27}, [r0]!\n"
        "vst1.32 {d12, d13}, [r0]!\n"
        "vst1.32 {d20, d21}, [r0]!\n"
        "vst1.32 {d28, d29}, [r0]!\n"
        "vst1.32 {d14, d15}, [r0]!\n"
        "vst1.32 {d22, d23}, [r0]!\n"
        "vst1.32 {d30, d31}, [r0]!\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
        "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
        "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
        "d28", "d29", "d30", "d31");
  }
};
1134
1135 // This is the "most natural" kernel, using NEON multiply-with-scalar
1136 // instructions.
struct NEON_32bit_GEMM_Float32_MLA_WithScalar {
  typedef float OperandType;
  typedef float AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> >
      Format;
  // The "most natural" float kernel: accum += lhs * rhs for a 12x4 block
  // in q4--q15, one level of depth per iteration, using the NEON
  // multiply-with-scalar form vmla.f32 qN, qM, dK[lane].
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators
        "mov r0, %[accum_ptr]\n"
        "vld1.32 {d8, d9}, [r0]!\n"
        "vld1.32 {d16, d17}, [r0]!\n"
        "vld1.32 {d24, d25}, [r0]!\n"
        "vld1.32 {d10, d11}, [r0]!\n"
        "vld1.32 {d18, d19}, [r0]!\n"
        "vld1.32 {d26, d27}, [r0]!\n"
        "vld1.32 {d12, d13}, [r0]!\n"
        "vld1.32 {d20, d21}, [r0]!\n"
        "vld1.32 {d28, d29}, [r0]!\n"
        "vld1.32 {d14, d15}, [r0]!\n"
        "vld1.32 {d22, d23}, [r0]!\n"
        "vld1.32 {d30, d31}, [r0]!\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Load 1 Rhs cell of size 1x4
        "vld1.32 {d0, d1}, [%[rhs_ptr]]!\n"

        // Load 3 Lhs cells of size 4x1 each
        "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
        "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n"
        "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n"

        // Multiply-accumulate: each Lhs cell (q1..q3) times each Rhs
        // scalar lane (d0[0..1], d1[0..1]) into one accumulator register.
        "vmla.f32 q4, q1, d0[0]\n"
        "vmla.f32 q5, q1, d0[1]\n"
        "vmla.f32 q6, q1, d1[0]\n"
        "vmla.f32 q7, q1, d1[1]\n"
        "vmla.f32 q8, q2, d0[0]\n"
        "vmla.f32 q9, q2, d0[1]\n"
        "vmla.f32 q10, q2, d1[0]\n"
        "vmla.f32 q11, q2, d1[1]\n"
        "vmla.f32 q12, q3, d0[0]\n"
        "vmla.f32 q13, q3, d0[1]\n"
        "vmla.f32 q14, q3, d1[0]\n"
        "vmla.f32 q15, q3, d1[1]\n"

        // Loop. Decrement loop index (depth) by 1, since we just handled 1
        // level of depth.
        "subs %[depth], #1\n"
        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Store accumulators
        "mov r0, %[accum_ptr]\n"
        "vst1.32 {d8, d9}, [r0]!\n"
        "vst1.32 {d16, d17}, [r0]!\n"
        "vst1.32 {d24, d25}, [r0]!\n"
        "vst1.32 {d10, d11}, [r0]!\n"
        "vst1.32 {d18, d19}, [r0]!\n"
        "vst1.32 {d26, d27}, [r0]!\n"
        "vst1.32 {d12, d13}, [r0]!\n"
        "vst1.32 {d20, d21}, [r0]!\n"
        "vst1.32 {d28, d29}, [r0]!\n"
        "vst1.32 {d14, d15}, [r0]!\n"
        "vst1.32 {d22, d23}, [r0]!\n"
        "vst1.32 {d30, d31}, [r0]!\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
        "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
        "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
        "d28", "d29", "d30", "d31");
  }
};
1219
1220 // Faster kernel contributed by ARM in 64bit form
1221 // (see NEON_64bit_GEMM_Float32_WithScalar_A53) then ported to 32bit code.
1222 // Tuned for A53.
struct NEON_32bit_GEMM_Float32_WithScalar_A53 {
  typedef float OperandType;
  typedef float AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> >
      Format;
  // A53-tuned variant of the MLA_WithScalar kernel: operand loads are
  // split into vldr/ldr+vmov pairs and interleaved between the vmla
  // instructions so scalar loads can dual-issue with NEON arithmetic.
  // Uses r2/r3 as scratch to stage 8-byte halves of cells.
  //
  // NOTE(review): each iteration preloads the NEXT iteration's Lhs/Rhs
  // data before the loop branch, so on the final iteration this reads up
  // to one cell past the nominal end of both operand arrays — the packed
  // buffers must remain readable there. Confirm against the callers.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators
        "mov r0, %[accum_ptr]\n"
        "vld1.32 {d8, d9}, [r0]!\n"
        "vld1.32 {d16, d17}, [r0]!\n"
        "vld1.32 {d24, d25}, [r0]!\n"
        "vld1.32 {d10, d11}, [r0]!\n"
        "vld1.32 {d18, d19}, [r0]!\n"
        "vld1.32 {d26, d27}, [r0]!\n"
        "vld1.32 {d12, d13}, [r0]!\n"
        "vld1.32 {d20, d21}, [r0]!\n"
        "vld1.32 {d28, d29}, [r0]!\n"
        "vld1.32 {d14, d15}, [r0]!\n"
        "vld1.32 {d22, d23}, [r0]!\n"
        "vld1.32 {d30, d31}, [r0]!\n"

        // Overview of register layout:
        //
        // A 1x4 cell of Rhs is stored in d0--d1 (q0).
        // A 12x1 block of 3 4x1 cells Lhs is stored in d2--d7
        // (q1--q3).
        // A 12x4 block of accumulators is stored in q4--q15.
        //
        //                   +-----+-----+-----+-----+
        //              Rhs  |d0[0]|d0[1]|d1[0]|d1[1]|
        //                   +-----+-----+-----+-----+
        //
        //                   |     |     |     |     |
        //
        //  Lhs              |     |     |     |     |
        //
        //  +--+- - - - - - -+-----+-----+-----+-----+
        //  |d2|             | q4  | q5  | q6  | q7  |
        //  |d2|             | q4  | q5  | q6  | q7  |
        //  |d3|             | q4  | q5  | q6  | q7  |
        //  |d3|             | q4  | q5  | q6  | q7  |
        //  +--+- - - - - - -+-----+-----+-----+-----+
        //  |d4|             | q8  | q9  | q10 | q11 |
        //  |d4|             | q8  | q9  | q10 | q11 |
        //  |d5|             | q8  | q9  | q10 | q11 |
        //  |d5|             | q8  | q9  | q10 | q11 |
        //  +--+- - - - - - -+-----+-----+-----+-----+
        //  |d6|             | q12 | q13 | q14 | q15 |
        //  |d6|             | q12 | q13 | q14 | q15 |
        //  |d7|             | q12 | q13 | q14 | q15 |
        //  |d7|             | q12 | q13 | q14 | q15 |
        //  +--+- - - - - - -+-----+-----+-----+-----+
        //
        //                            Accumulator

        // Load Rhs cell: first 8 bytes via vldr, second 8 bytes staged in
        // r2/r3 and assembled into d1 inside the loop via vmov.
        "vldr d0, [%[rhs_ptr]]\n"
        "ldr r2, [%[rhs_ptr], #8]\n"
        "ldr r3, [%[rhs_ptr], #12]\n"

        // Load 1st Lhs Cell
        "vld1.32 {d2, d3}, [%[lhs_ptr]]\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"

        "vldr d4, [%[lhs_ptr], #16]\n"  // Load 1st half of 2nd Lhs cell
        "vmov d1, r2, r3\n"             // Prepare 2nd half of Rhs cell
        "vmla.f32 q4, q1, d0[0]\n"      // Multiply 1st Lhs cell with column 0
        "ldr r2, [%[lhs_ptr], #24]\n"  // Load 2nd half of 2nd Lhs cell, part 1
        "vmla.f32 q5, q1, d0[1]\n"     // Multiply 1st Lhs cell with column 1
        "ldr r3, [%[lhs_ptr], #28]\n"  // Load 2nd half of 2nd Lhs cell, part 2
        "vmla.f32 q6, q1, d1[0]\n"     // Multiply 1st Lhs cell with column 2
        "subs %[depth], #1\n"          // Decrement depth; sets flags for bne

        "vldr d6, [%[lhs_ptr], #32]\n"  // Load 1st half of 3rd Lhs cell
        "vmov d5, r2, r3\n"             // Prepare 2nd half of 2nd Lhs cell
        "vmla.f32 q7, q1, d1[1]\n"      // Multiply 1st Lhs cell with column 3
        "ldr r2, [%[lhs_ptr], #40]\n"  // Load 2nd half of 3rd Lhs cell, part 1
        "vmla.f32 q8, q2, d0[0]\n"     // Multiply 2nd Lhs cell with column 0
        "ldr r3, [%[lhs_ptr], #44]\n"  // Load 2nd half of 3rd Lhs cell, part 2
        "vmla.f32 q9, q2, d0[1]\n"     // Multiply 2nd Lhs cell with column 1
        "add %[rhs_ptr], %[rhs_ptr], #16\n"  // Move forward by 1 Rhs cell

        "vldr d2, [%[lhs_ptr], #48]\n"  // Load 1st half of 1st Lhs cell of next
                                        // iteration
        "vmov d7, r2, r3\n"             // Prepare 2nd half of 3rd Lhs cell
        "vmla.f32 q10, q2, d1[0]\n"     // Multiply 2nd Lhs cell with column 2
        "ldr r2, [%[lhs_ptr], #56]\n"   // Load 2nd half of 1st Lhs cell of next
                                        // iter, part 1
        "vmla.f32 q12, q3, d0[0]\n"     // Multiply 3rd Lhs cell with column 0
        "ldr r3, [%[lhs_ptr], #60]\n"   // Load 2nd half of 1st Lhs cell of next
                                        // iter, part 2
        "vmla.f32 q13, q3, d0[1]\n"     // Multiply 3rd Lhs cell with column 1
        "add %[lhs_ptr], %[lhs_ptr], #48\n"  // Move forward by 3 Lhs cells

        "vldr d0, [%[rhs_ptr]]\n"  // Load 1st half of Rhs cell of next
                                   // iteration
        "vmov d3, r2, r3\n"  // Prepare 2nd half of 1st Lhs cell of next
                             // iteration
        "vmla.f32 q11, q2, d1[1]\n"  // Multiply 2nd Lhs cell with column 3
        "ldr r2, [%[rhs_ptr], #8]\n"  // Load 2nd half of Rhs cell of next
                                      // iteration, part 1
        "vmla.f32 q14, q3, d1[0]\n"  // Multiply 3rd Lhs cell with column 2
        "ldr r3, [%[rhs_ptr], #12]\n"  // Load 2nd half of Rhs cell of next
                                       // iteration, part 2
        "vmla.f32 q15, q3, d1[1]\n"  // Multiply 3rd Lhs cell with column 3

        // Loop branch.  This will dual issue in fmla cycle 3 of the 4th block.
        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Store accumulators
        "mov r0, %[accum_ptr]\n"
        "vst1.32 {d8, d9}, [r0]!\n"
        "vst1.32 {d16, d17}, [r0]!\n"
        "vst1.32 {d24, d25}, [r0]!\n"
        "vst1.32 {d10, d11}, [r0]!\n"
        "vst1.32 {d18, d19}, [r0]!\n"
        "vst1.32 {d26, d27}, [r0]!\n"
        "vst1.32 {d12, d13}, [r0]!\n"
        "vst1.32 {d20, d21}, [r0]!\n"
        "vst1.32 {d28, d29}, [r0]!\n"
        "vst1.32 {d14, d15}, [r0]!\n"
        "vst1.32 {d22, d23}, [r0]!\n"
        "vst1.32 {d30, d31}, [r0]!\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "cc", "memory", "r0", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5",
        "d6", "d7", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16",
        "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26",
        "d27", "d28", "d29", "d30", "d31");
  }
};
1365
struct NEON_32bit_GEMM_Float32_WithScalar_A53_depth2 {
  typedef float OperandType;
  typedef float AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
      Format;
  // Depth-2-unrolled variant of the A53 kernel above: each loop iteration
  // handles 2 levels of depth (cells are 4x2, so one Rhs cell is 32 bytes
  // and three Lhs cells are 96 bytes). Same ldr/vmov staging scheme with
  // r2/r3 as scratch, same next-iteration preloading.
  //
  // NOTE(review): like the depth-1 A53 kernel, this preloads the next
  // iteration's operands before the loop branch, so it reads past the
  // nominal end of the packed buffers on the final iteration.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators
        "mov r0, %[accum_ptr]\n"
        "vld1.32 {d8, d9}, [r0]!\n"
        "vld1.32 {d16, d17}, [r0]!\n"
        "vld1.32 {d24, d25}, [r0]!\n"
        "vld1.32 {d10, d11}, [r0]!\n"
        "vld1.32 {d18, d19}, [r0]!\n"
        "vld1.32 {d26, d27}, [r0]!\n"
        "vld1.32 {d12, d13}, [r0]!\n"
        "vld1.32 {d20, d21}, [r0]!\n"
        "vld1.32 {d28, d29}, [r0]!\n"
        "vld1.32 {d14, d15}, [r0]!\n"
        "vld1.32 {d22, d23}, [r0]!\n"
        "vld1.32 {d30, d31}, [r0]!\n"

        // Overview of register layout (per level of depth):
        //
        // A 1x4 cell of Rhs is stored in d0--d1 (q0).
        // A 12x1 block of 3 4x1 cells Lhs is stored in d2--d7
        // (q1--q3).
        // A 12x4 block of accumulators is stored in q4--q15.
        //
        //                   +-----+-----+-----+-----+
        //              Rhs  |d0[0]|d0[1]|d1[0]|d1[1]|
        //                   +-----+-----+-----+-----+
        //
        //                   |     |     |     |     |
        //
        //  Lhs              |     |     |     |     |
        //
        //  +--+- - - - - - -+-----+-----+-----+-----+
        //  |d2|             | q4  | q5  | q6  | q7  |
        //  |d2|             | q4  | q5  | q6  | q7  |
        //  |d3|             | q4  | q5  | q6  | q7  |
        //  |d3|             | q4  | q5  | q6  | q7  |
        //  +--+- - - - - - -+-----+-----+-----+-----+
        //  |d4|             | q8  | q9  | q10 | q11 |
        //  |d4|             | q8  | q9  | q10 | q11 |
        //  |d5|             | q8  | q9  | q10 | q11 |
        //  |d5|             | q8  | q9  | q10 | q11 |
        //  +--+- - - - - - -+-----+-----+-----+-----+
        //  |d6|             | q12 | q13 | q14 | q15 |
        //  |d6|             | q12 | q13 | q14 | q15 |
        //  |d7|             | q12 | q13 | q14 | q15 |
        //  |d7|             | q12 | q13 | q14 | q15 |
        //  +--+- - - - - - -+-----+-----+-----+-----+
        //
        //                            Accumulator

        // Load Rhs cell (1st depth level): first 8 bytes via vldr, second
        // 8 bytes staged in r2/r3, assembled into d1 inside the loop.
        "vldr d0, [%[rhs_ptr]]\n"
        "ldr r2, [%[rhs_ptr], #8]\n"
        "ldr r3, [%[rhs_ptr], #12]\n"

        // Load 1st Lhs Cell
        "vld1.32 {d2, d3}, [%[lhs_ptr]]\n"

        // Loop head - handling 2 levels of depth at once
        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Level of depth 1

        "vldr d4, [%[lhs_ptr], #32]\n"  // Load 1st half of 2nd Lhs cell
        "vmov d1, r2, r3\n"             // Prepare 2nd half of Rhs cell
        "vmla.f32 q4, q1, d0[0]\n"      // Multiply 1st Lhs cell with column 0
        "ldr r2, [%[lhs_ptr], #40]\n"  // Load 2nd half of 2nd Lhs cell, part 1
        "vmla.f32 q5, q1, d0[1]\n"     // Multiply 1st Lhs cell with column 1
        "ldr r3, [%[lhs_ptr], #44]\n"  // Load 2nd half of 2nd Lhs cell, part 2
        "vmla.f32 q6, q1, d1[0]\n"     // Multiply 1st Lhs cell with column 2

        "vldr d6, [%[lhs_ptr], #64]\n"  // Load 1st half of 3rd Lhs cell
        "vmov d5, r2, r3\n"             // Prepare 2nd half of 2nd Lhs cell
        "vmla.f32 q7, q1, d1[1]\n"      // Multiply 1st Lhs cell with column 3
        "ldr r2, [%[lhs_ptr], #72]\n"  // Load 2nd half of 3rd Lhs cell, part 1
        "vmla.f32 q8, q2, d0[0]\n"     // Multiply 2nd Lhs cell with column 0
        "ldr r3, [%[lhs_ptr], #76]\n"  // Load 2nd half of 3rd Lhs cell, part 2
        "vmla.f32 q9, q2, d0[1]\n"     // Multiply 2nd Lhs cell with column 1

        "vldr d2, [%[lhs_ptr], #16]\n"  // Load 1st half of 1st Lhs cell of
                                        // 2nd depth level
        "vmov d7, r2, r3\n"             // Prepare 2nd half of 3rd Lhs cell
        "vmla.f32 q10, q2, d1[0]\n"     // Multiply 2nd Lhs cell with column 2
        "ldr r2, [%[lhs_ptr], #24]\n"   // Load 2nd half of 1st Lhs cell of
                                        // 2nd depth level, part 1
        "vmla.f32 q12, q3, d0[0]\n"     // Multiply 3rd Lhs cell with column 0
        "ldr r3, [%[lhs_ptr], #28]\n"   // Load 2nd half of 1st Lhs cell of
                                        // 2nd depth level, part 2
        "vmla.f32 q13, q3, d0[1]\n"     // Multiply 3rd Lhs cell with column 1

        "vldr d0, [%[rhs_ptr], #16]\n"  // Load 1st half of Rhs cell of 2nd
                                        // depth level
        "vmov d3, r2, r3\n"  // Prepare 2nd half of 1st Lhs cell of 2nd
                             // depth level
        "vmla.f32 q11, q2, d1[1]\n"  // Multiply 2nd Lhs cell with column 3
        "ldr r2, [%[rhs_ptr], #24]\n"  // Load 2nd half of Rhs cell of 2nd
                                       // depth level, part 1
        "vmla.f32 q14, q3, d1[0]\n"  // Multiply 3rd Lhs cell with column 2
        "ldr r3, [%[rhs_ptr], #28]\n"  // Load 2nd half of Rhs cell of 2nd
                                       // depth level, part 2
        "vmla.f32 q15, q3, d1[1]\n"  // Multiply 3rd Lhs cell with column 3

        // Level of depth 2
        "vldr d4, [%[lhs_ptr], #48]\n"  // Load 1st half of 2nd Lhs cell
        "vmov d1, r2, r3\n"             // Prepare 2nd half of Rhs cell
        "vmla.f32 q4, q1, d0[0]\n"      // Multiply 1st Lhs cell with column 0
        "ldr r2, [%[lhs_ptr], #56]\n"  // Load 2nd half of 2nd Lhs cell, part 1
        "vmla.f32 q5, q1, d0[1]\n"     // Multiply 1st Lhs cell with column 1
        "ldr r3, [%[lhs_ptr], #60]\n"  // Load 2nd half of 2nd Lhs cell, part 2
        "vmla.f32 q6, q1, d1[0]\n"     // Multiply 1st Lhs cell with column 2
        "subs %[depth], #2\n"          // Decrement depth counter

        "vldr d6, [%[lhs_ptr], #80]\n"  // Load 1st half of 3rd Lhs cell
        "vmov d5, r2, r3\n"             // Prepare 2nd half of 2nd Lhs cell
        "vmla.f32 q7, q1, d1[1]\n"      // Multiply 1st Lhs cell with column 3
        "ldr r2, [%[lhs_ptr], #88]\n"  // Load 2nd half of 3rd Lhs cell, part 1
        "vmla.f32 q8, q2, d0[0]\n"     // Multiply 2nd Lhs cell with column 0
        "ldr r3, [%[lhs_ptr], #92]\n"  // Load 2nd half of 3rd Lhs cell, part 2
        "vmla.f32 q9, q2, d0[1]\n"     // Multiply 2nd Lhs cell with column 1
        "add %[rhs_ptr], %[rhs_ptr], #32\n"  // Move forward by 1 Rhs cell
                                             // (4x2 floats = 32 bytes)

        "vldr d2, [%[lhs_ptr], #96]\n"  // Load 1st half of 1st Lhs cell of next
                                        // iteration
        "vmov d7, r2, r3\n"             // Prepare 2nd half of 3rd Lhs cell
        "vmla.f32 q10, q2, d1[0]\n"     // Multiply 2nd Lhs cell with column 2
        "ldr r2, [%[lhs_ptr], #104]\n"  // Load 2nd half of 1st Lhs cell of next
                                        // iter, part 1
        "vmla.f32 q12, q3, d0[0]\n"     // Multiply 3rd Lhs cell with column 0
        "ldr r3, [%[lhs_ptr], #108]\n"  // Load 2nd half of 1st Lhs cell of next
                                        // iter, part 2
        "vmla.f32 q13, q3, d0[1]\n"     // Multiply 3rd Lhs cell with column 1
        "add %[lhs_ptr], %[lhs_ptr], #96\n"  // Move forward by 3 Lhs cells
                                             // (3 * 4x2 floats = 96 bytes)

        "vldr d0, [%[rhs_ptr]]\n"  // Load 1st half of Rhs cell of next
                                   // iteration
        "vmov d3, r2, r3\n"  // Prepare 2nd half of 1st Lhs cell of next
                             // iteration
        "vmla.f32 q11, q2, d1[1]\n"  // Multiply 2nd Lhs cell with column 3
        "ldr r2, [%[rhs_ptr], #8]\n"  // Load 2nd half of Rhs cell of next
                                      // iteration, part 1
        "vmla.f32 q14, q3, d1[0]\n"  // Multiply 3rd Lhs cell with column 2
        "ldr r3, [%[rhs_ptr], #12]\n"  // Load 2nd half of Rhs cell of next
                                       // iteration, part 2
        "vmla.f32 q15, q3, d1[1]\n"  // Multiply 3rd Lhs cell with column 3

        // Loop branch.  This will dual issue in fmla cycle 3 of the 4th block.
        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Store accumulators
        "mov r0, %[accum_ptr]\n"
        "vst1.32 {d8, d9}, [r0]!\n"
        "vst1.32 {d16, d17}, [r0]!\n"
        "vst1.32 {d24, d25}, [r0]!\n"
        "vst1.32 {d10, d11}, [r0]!\n"
        "vst1.32 {d18, d19}, [r0]!\n"
        "vst1.32 {d26, d27}, [r0]!\n"
        "vst1.32 {d12, d13}, [r0]!\n"
        "vst1.32 {d20, d21}, [r0]!\n"
        "vst1.32 {d28, d29}, [r0]!\n"
        "vst1.32 {d14, d15}, [r0]!\n"
        "vst1.32 {d22, d23}, [r0]!\n"
        "vst1.32 {d30, d31}, [r0]!\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "cc", "memory", "r0", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5",
        "d6", "d7", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16",
        "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26",
        "d27", "d28", "d29", "d30", "d31");
  }
};
1552
1553 // This rotating variant performs well when permutations (vext) can be
1554 // dual-issued with arithmetic instructions.
struct NEON_32bit_GEMM_Float32_MLA_Rotating {
  typedef float OperandType;
  typedef float AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> >
      Format;
  // "Rotating" scheme: instead of multiply-with-scalar, the Rhs vector q0
  // is rotated one lane at a time (vext) between plain vector-vector vmla
  // instructions. To make that produce the right results, the accumulator
  // registers are pre-permuted with the ROTATE(1,2,3) macro before the
  // loop and un-permuted with ROTATE(3,2,1) before storing. Performs well
  // when vext can dual-issue with arithmetic.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators
        "mov r0, %[accum_ptr]\n"
        "vld1.32 {d8, d9}, [r0]!\n"
        "vld1.32 {d16, d17}, [r0]!\n"
        "vld1.32 {d24, d25}, [r0]!\n"
        "vld1.32 {d10, d11}, [r0]!\n"
        "vld1.32 {d18, d19}, [r0]!\n"
        "vld1.32 {d26, d27}, [r0]!\n"
        "vld1.32 {d12, d13}, [r0]!\n"
        "vld1.32 {d20, d21}, [r0]!\n"
        "vld1.32 {d28, d29}, [r0]!\n"
        "vld1.32 {d14, d15}, [r0]!\n"
        "vld1.32 {d22, d23}, [r0]!\n"
        "vld1.32 {d30, d31}, [r0]!\n"

// 4x4 in-register transpose of each of the three accumulator cells
// (q4-q7, q8-q11, q12-q15), done as vtrn of adjacent rows followed by
// vswp of the off-diagonal d-halves.
#define NEON_32BIT_ROTATING_FLOAT_KERNEL_TRANSPOSE_ACCUMULATOR_CELLS \
  "vtrn.32 q4, q5\n"                                                 \
  "vtrn.32 q6, q7\n"                                                 \
  "vswp d9, d12\n"                                                   \
  "vswp d11, d14\n"                                                  \
  "vtrn.32 q8, q9\n"                                                 \
  "vtrn.32 q10, q11\n"                                               \
  "vswp d17, d20\n"                                                  \
  "vswp d19, d22\n"                                                  \
  "vtrn.32 q12, q13\n"                                               \
  "vtrn.32 q14, q15\n"                                               \
  "vswp d25, d28\n"                                                  \
  "vswp d27, d30\n"

// Rotate rows 1..3 of each transposed accumulator cell by a/b/c lanes
// (transpose, vext each row, transpose back). Used with (1,2,3) to set up
// the rotating layout and (3,2,1) to undo it.
#define NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(a, b, c) \
  NEON_32BIT_ROTATING_FLOAT_KERNEL_TRANSPOSE_ACCUMULATOR_CELLS             \
  "vext.32 q5, q5, q5, #" #a                                               \
  "\n"                                                                     \
  "vext.32 q6, q6, q6, #" #b                                               \
  "\n"                                                                     \
  "vext.32 q7, q7, q7, #" #c                                               \
  "\n"                                                                     \
  "vext.32 q9, q9, q9, #" #a                                               \
  "\n"                                                                     \
  "vext.32 q10, q10, q10, #" #b                                            \
  "\n"                                                                     \
  "vext.32 q11, q11, q11, #" #c                                            \
  "\n"                                                                     \
  "vext.32 q13, q13, q13, #" #a                                            \
  "\n"                                                                     \
  "vext.32 q14, q14, q14, #" #b                                            \
  "\n"                                                                     \
  "vext.32 q15, q15, q15, #" #c                                            \
  "\n" NEON_32BIT_ROTATING_FLOAT_KERNEL_TRANSPOSE_ACCUMULATOR_CELLS

        // Enter the rotated accumulator layout.
        NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(1, 2, 3)

        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Load 1 Rhs cell of size 1x4
        "vld1.32 {d0, d1}, [%[rhs_ptr]]!\n"

        // Load 3 Lhs cells of size 4x1 each
        "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
        "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n"
        "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n"

        // Multiply-accumulate, rotating q0 by one lane between each group
        // of three vmla so every Rhs scalar meets every Lhs lane.
        "vmla.f32 q4, q1, q0\n"
        "vmla.f32 q8, q2, q0\n"
        "vmla.f32 q12, q3, q0\n"
        "vext.f32 q0, q0, q0, #1\n"
        "vmla.f32 q5, q1, q0\n"
        "vmla.f32 q9, q2, q0\n"
        "vmla.f32 q13, q3, q0\n"
        "vext.f32 q0, q0, q0, #1\n"
        "vmla.f32 q6, q1, q0\n"
        "vmla.f32 q10, q2, q0\n"
        "vmla.f32 q14, q3, q0\n"
        "vext.f32 q0, q0, q0, #1\n"
        "vmla.f32 q7, q1, q0\n"
        "vmla.f32 q11, q2, q0\n"
        "vmla.f32 q15, q3, q0\n"

        // Loop. Decrement loop index (depth) by 1, since we just handled 1
        // level of depth.
        "subs %[depth], #1\n"
        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Store accumulators
        "mov r0, %[accum_ptr]\n"

        // Undo the rotated accumulator layout before storing.
        NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(3, 2, 1)

        "vst1.32 {d8, d9}, [r0]!\n"
        "vst1.32 {d16, d17}, [r0]!\n"
        "vst1.32 {d24, d25}, [r0]!\n"
        "vst1.32 {d10, d11}, [r0]!\n"
        "vst1.32 {d18, d19}, [r0]!\n"
        "vst1.32 {d26, d27}, [r0]!\n"
        "vst1.32 {d12, d13}, [r0]!\n"
        "vst1.32 {d20, d21}, [r0]!\n"
        "vst1.32 {d28, d29}, [r0]!\n"
        "vst1.32 {d14, d15}, [r0]!\n"
        "vst1.32 {d22, d23}, [r0]!\n"
        "vst1.32 {d30, d31}, [r0]!\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
        "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
        "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
        "d28", "d29", "d30", "d31");
  }
};
1682
1683 // This rotating variant performs well when permutations (vext) can be
1684 // dual-issued with arithmetic instructions. It is relevant as the rotating
1685 // approach removes the need for multiply-with-scalar instructions, and ARMv7
1686 // FMA does not have a with-scalar variant.
struct NEON_32bit_GEMM_Float32_FMA_Rotating {
  typedef float OperandType;
  typedef float AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> >
      Format;
  // FMA version of the rotating kernel above; reuses the
  // NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS macro
  // defined there. Relevant because ARMv7 VFMA has no with-scalar form,
  // and rotating q0 removes the need for one.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators
        "mov r0, %[accum_ptr]\n"
        "vld1.32 {d8, d9}, [r0]!\n"
        "vld1.32 {d16, d17}, [r0]!\n"
        "vld1.32 {d24, d25}, [r0]!\n"
        "vld1.32 {d10, d11}, [r0]!\n"
        "vld1.32 {d18, d19}, [r0]!\n"
        "vld1.32 {d26, d27}, [r0]!\n"
        "vld1.32 {d12, d13}, [r0]!\n"
        "vld1.32 {d20, d21}, [r0]!\n"
        "vld1.32 {d28, d29}, [r0]!\n"
        "vld1.32 {d14, d15}, [r0]!\n"
        "vld1.32 {d22, d23}, [r0]!\n"
        "vld1.32 {d30, d31}, [r0]!\n"

        // Enter the rotated accumulator layout.
        NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(1, 2, 3)

        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Load 1 Rhs cell of size 1x4
        "vld1.32 {d0, d1}, [%[rhs_ptr]]!\n"

        // Load 3 Lhs cells of size 4x1 each
        "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
        "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n"
        "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n"

        // Multiply-accumulate, rotating q0 by one lane between each group
        // of three vfma so every Rhs scalar meets every Lhs lane.
        "vfma.f32 q4, q1, q0\n"
        "vfma.f32 q8, q2, q0\n"
        "vfma.f32 q12, q3, q0\n"
        "vext.f32 q0, q0, q0, #1\n"
        "vfma.f32 q5, q1, q0\n"
        "vfma.f32 q9, q2, q0\n"
        "vfma.f32 q13, q3, q0\n"
        "vext.f32 q0, q0, q0, #1\n"
        "vfma.f32 q6, q1, q0\n"
        "vfma.f32 q10, q2, q0\n"
        "vfma.f32 q14, q3, q0\n"
        "vext.f32 q0, q0, q0, #1\n"
        "vfma.f32 q7, q1, q0\n"
        "vfma.f32 q11, q2, q0\n"
        "vfma.f32 q15, q3, q0\n"

        // Loop. Decrement loop index (depth) by 1, since we just handled 1
        // level of depth.
        "subs %[depth], #1\n"
        "bne " GEMMLOWP_LABEL_LOOP "b\n"

        // Undo the rotated accumulator layout before storing.
        NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(3, 2, 1)

        // Store accumulators
        "mov r0, %[accum_ptr]\n"
        "vst1.32 {d8, d9}, [r0]!\n"
        "vst1.32 {d16, d17}, [r0]!\n"
        "vst1.32 {d24, d25}, [r0]!\n"
        "vst1.32 {d10, d11}, [r0]!\n"
        "vst1.32 {d18, d19}, [r0]!\n"
        "vst1.32 {d26, d27}, [r0]!\n"
        "vst1.32 {d12, d13}, [r0]!\n"
        "vst1.32 {d20, d21}, [r0]!\n"
        "vst1.32 {d28, d29}, [r0]!\n"
        "vst1.32 {d14, d15}, [r0]!\n"
        "vst1.32 {d22, d23}, [r0]!\n"
        "vst1.32 {d30, d31}, [r0]!\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
        "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
        "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
        "d28", "d29", "d30", "d31");
  }
};
1777
1778 #endif // __arm__
1779
1780 #ifdef __aarch64__
1781
1782 // This is the current standard kernel in gemmlowp, see:
1783 // https://github.com/google/gemmlowp/blob/b1e2a29ff866680028f3080efc244e10e8dd7f46/internal/kernel_neon.h#L646
// Uint8 GEMM kernel for AArch64: widens each uint8 operand vector to 16 bit
// with UXTL, then multiply-accumulates into 32-bit lanes with UMLAL/UMLAL2.
// Produces a 12x8 block of uint32 accumulators per call.
struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators {
  typedef std::uint8_t OperandType;
  typedef std::uint32_t AccumulatorType;
  // 12x8 result block: 3 LHS cells of 4x2 (DepthMajor) by 2 RHS cells of 4x2.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
      Format;
  // Computes accum += lhs * rhs over `depth` levels of depth.
  // The kernel consumes 2 levels of depth per iteration and unconditionally
  // runs a 2-level tail after the loop (see the initial "subs #2" and the
  // AFTER_LOOP block), so `depth` must be a positive multiple of 2.
  // accum_ptr is both read (initial accumulator values) and written.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load 1 Rhs cell of size 2x8
        "ld1 {v5.8b}, [%[rhs_ptr]], #8\n"
        "ld1 {v6.8b}, [%[rhs_ptr]], #8\n"

        // Load 3 Lhs cells of size 4x2 each
        "ld1 {v2.8b}, [%[lhs_ptr]], #8\n"
        "ld1 {v3.8b}, [%[lhs_ptr]], #8\n"
        "ld1 {v4.8b}, [%[lhs_ptr]], #8\n"

        // Pre-decrement depth by 2: the last 2 levels of depth are always
        // handled by the after-loop tail below, using the operands already
        // loaded above.
        "subs %w[depth], %w[depth], #2\n"

        // Load accumulators. The load order (v8,v16,v24, v9,v17,v25, ...)
        // matches the column-major layout of the 12x8 accumulator block.
        "mov x0, %[accum_ptr]\n"
        "ld1 {v8.16b}, [x0], #16\n"
        "ld1 {v16.16b}, [x0], #16\n"
        "ld1 {v24.16b}, [x0], #16\n"
        "ld1 {v9.16b}, [x0], #16\n"
        "ld1 {v17.16b}, [x0], #16\n"
        "ld1 {v25.16b}, [x0], #16\n"
        "ld1 {v10.16b}, [x0], #16\n"
        "ld1 {v18.16b}, [x0], #16\n"
        "ld1 {v26.16b}, [x0], #16\n"
        "ld1 {v11.16b}, [x0], #16\n"
        "ld1 {v19.16b}, [x0], #16\n"
        "ld1 {v27.16b}, [x0], #16\n"
        "ld1 {v12.16b}, [x0], #16\n"
        "ld1 {v20.16b}, [x0], #16\n"
        "ld1 {v28.16b}, [x0], #16\n"
        "ld1 {v13.16b}, [x0], #16\n"
        "ld1 {v21.16b}, [x0], #16\n"
        "ld1 {v29.16b}, [x0], #16\n"
        "ld1 {v14.16b}, [x0], #16\n"
        "ld1 {v22.16b}, [x0], #16\n"
        "ld1 {v30.16b}, [x0], #16\n"
        "ld1 {v15.16b}, [x0], #16\n"
        "ld1 {v23.16b}, [x0], #16\n"
        "ld1 {v31.16b}, [x0], #16\n"

        // If depth == 2, skip straight to the tail.
        "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n"

        //"loop_%=:\n"
        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Overview of register layout:
        //
        // A 2x8 block of 2 2x4 cells of Rhs is stored in 16bit in v0--v1.
        // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in v2--v4.
        // A 12x8 block of accumulators is stored in 32bit in v8--v31.
        //
        // +--------+--------+-----+--------+--------+
        // |v0.h[0] |v0.h[1] | ... |v1.h[2] |v1.h[3] |
        // Rhs +--------+--------+-----+--------+--------+
        // |v0.h[4] |v0.h[5] | ... |v1.h[6] |v1.h[7] |
        // +--------+--------+-----+--------+--------+
        //
        // | | | | | |
        //
        // Lhs | | | | | |
        //
        // +-------+-------+ - - +--------+--------+-----+--------+--------+
        // |v2.h[0]|v2.h[4]| |v8.s[0] |v9.s[0] | ... |v14.s[0]|v15.s[0]|
        // |v2.h[1]|v2.h[5]| |v8.s[1] |v9.s[1] | ... |v14.s[1]|v15.s[1]|
        // |v2.h[2]|v2.h[6]| |v8.s[2] |v9.s[2] | ... |v14.s[2]|v15.s[2]|
        // |v2.h[3]|v2.h[7]| |v8.s[3] |v9.s[3] | ... |v14.s[3]|v15.s[3]|
        // +-------+-------+ - - +--------+--------+-----+--------+--------+
        // |v3.h[0]|v3.h[4]| |v16.s[0]|v17.s[0]| ... |v22.s[0]|v23.s[0]|
        // |v3.h[1]|v3.h[5]| |v16.s[1]|v17.s[1]| ... |v22.s[1]|v23.s[1]|
        // |v3.h[2]|v3.h[6]| |v16.s[2]|v17.s[2]| ... |v22.s[2]|v23.s[2]|
        // |v3.h[3]|v3.h[7]| |v16.s[3]|v17.s[3]| ... |v22.s[3]|v23.s[3]|
        // +-------+-------+ - - +--------+--------+-----+--------+--------+
        // |v4.h[0]|v4.h[4]| |v24.s[0]|v25.s[0]| ... |v30.s[0]|v31.s[0]|
        // |v4.h[1]|v4.h[5]| |v24.s[1]|v25.s[1]| ... |v30.s[1]|v31.s[1]|
        // |v4.h[2]|v4.h[6]| |v24.s[2]|v25.s[2]| ... |v30.s[2]|v31.s[2]|
        // |v4.h[3]|v4.h[7]| |v24.s[3]|v25.s[3]| ... |v30.s[3]|v31.s[3]|
        // +-------+-------+ - - +--------+--------+-----+--------+--------+
        //
        // Accumulator

        // Expand Lhs/Rhs cells to 16 bit. The loads of v5/v6 here fetch the
        // *next* iteration's Rhs cell, interleaved to hide load latency.
        "uxtl v0.8h, v5.8b\n"
        "ld1 {v5.8b}, [%[rhs_ptr]], #8\n"
        "uxtl v1.8h, v6.8b\n"
        "ld1 {v6.8b}, [%[rhs_ptr]], #8\n"
        "uxtl v2.8h, v2.8b\n"
        "uxtl v3.8h, v3.8b\n"
        "uxtl v4.8h, v4.8b\n"

        // Multiply-accumulate, top third. UMLAL uses the depth-0 lanes
        // (h[0..3]), UMLAL2 the depth-1 lanes (h[4..7]).
        "umlal v8.4s, v2.4h, v0.h[0]\n"
        "umlal v9.4s, v2.4h, v0.h[1]\n"
        "umlal v10.4s, v2.4h, v0.h[2]\n"
        "umlal v11.4s, v2.4h, v0.h[3]\n"
        "umlal v12.4s, v2.4h, v1.h[0]\n"
        "umlal v13.4s, v2.4h, v1.h[1]\n"
        "umlal v14.4s, v2.4h, v1.h[2]\n"
        "umlal v15.4s, v2.4h, v1.h[3]\n"
        "umlal2 v8.4s, v2.8h, v0.h[4]\n"
        "umlal2 v9.4s, v2.8h, v0.h[5]\n"
        "umlal2 v10.4s, v2.8h, v0.h[6]\n"
        "umlal2 v11.4s, v2.8h, v0.h[7]\n"
        "umlal2 v12.4s, v2.8h, v1.h[4]\n"
        "umlal2 v13.4s, v2.8h, v1.h[5]\n"
        "umlal2 v14.4s, v2.8h, v1.h[6]\n"
        "umlal2 v15.4s, v2.8h, v1.h[7]\n"
        "ld1 {v2.8b}, [%[lhs_ptr]], #8\n"

        // Multiply-accumulate, middle third
        "umlal v16.4s, v3.4h, v0.h[0]\n"
        "umlal v17.4s, v3.4h, v0.h[1]\n"
        "umlal v18.4s, v3.4h, v0.h[2]\n"
        "umlal v19.4s, v3.4h, v0.h[3]\n"
        "umlal v20.4s, v3.4h, v1.h[0]\n"
        "umlal v21.4s, v3.4h, v1.h[1]\n"
        "umlal v22.4s, v3.4h, v1.h[2]\n"
        "umlal v23.4s, v3.4h, v1.h[3]\n"
        "umlal2 v16.4s, v3.8h, v0.h[4]\n"
        "umlal2 v17.4s, v3.8h, v0.h[5]\n"
        "umlal2 v18.4s, v3.8h, v0.h[6]\n"
        "umlal2 v19.4s, v3.8h, v0.h[7]\n"
        "umlal2 v20.4s, v3.8h, v1.h[4]\n"
        "umlal2 v21.4s, v3.8h, v1.h[5]\n"
        "umlal2 v22.4s, v3.8h, v1.h[6]\n"
        "umlal2 v23.4s, v3.8h, v1.h[7]\n"
        "ld1 {v3.8b}, [%[lhs_ptr]], #8\n"

        // Decrement loop index (depth) by 2 early, well before the branch,
        // so the flags are ready when the branch dispatches.
        "subs %w[depth], %w[depth], #2\n"

        // Multiply-accumulate, bottom third
        "umlal v24.4s, v4.4h, v0.h[0]\n"
        "umlal v25.4s, v4.4h, v0.h[1]\n"
        "umlal v26.4s, v4.4h, v0.h[2]\n"
        "umlal v27.4s, v4.4h, v0.h[3]\n"
        "umlal v28.4s, v4.4h, v1.h[0]\n"
        "umlal v29.4s, v4.4h, v1.h[1]\n"
        "umlal v30.4s, v4.4h, v1.h[2]\n"
        "umlal v31.4s, v4.4h, v1.h[3]\n"
        "umlal2 v24.4s, v4.8h, v0.h[4]\n"
        "umlal2 v25.4s, v4.8h, v0.h[5]\n"
        "umlal2 v26.4s, v4.8h, v0.h[6]\n"
        "umlal2 v27.4s, v4.8h, v0.h[7]\n"
        "umlal2 v28.4s, v4.8h, v1.h[4]\n"
        "umlal2 v29.4s, v4.8h, v1.h[5]\n"
        "umlal2 v30.4s, v4.8h, v1.h[6]\n"
        "umlal2 v31.4s, v4.8h, v1.h[7]\n"
        "ld1 {v4.8b}, [%[lhs_ptr]], #8\n"

        "bne " GEMMLOWP_LABEL_LOOP "b\n"

        GEMMLOWP_LABEL_AFTER_LOOP
        ":\n"

        // Tail: handle the final 2 levels of depth with the operands that
        // were loaded (but not yet consumed) above.
        // Expand Lhs/Rhs cells to 16 bit.
        "uxtl v0.8h, v5.8b\n"
        "uxtl v1.8h, v6.8b\n"
        "uxtl v2.8h, v2.8b\n"
        "uxtl v3.8h, v3.8b\n"
        "uxtl v4.8h, v4.8b\n"

        // Multiply-accumulate, level of depth 0
        "umlal v8.4s, v2.4h, v0.h[0]\n"
        "umlal v9.4s, v2.4h, v0.h[1]\n"
        "umlal v10.4s, v2.4h, v0.h[2]\n"
        "umlal v11.4s, v2.4h, v0.h[3]\n"
        "umlal v12.4s, v2.4h, v1.h[0]\n"
        "umlal v13.4s, v2.4h, v1.h[1]\n"
        "umlal v14.4s, v2.4h, v1.h[2]\n"
        "umlal v15.4s, v2.4h, v1.h[3]\n"
        "umlal v16.4s, v3.4h, v0.h[0]\n"
        "umlal v17.4s, v3.4h, v0.h[1]\n"
        "umlal v18.4s, v3.4h, v0.h[2]\n"
        "umlal v19.4s, v3.4h, v0.h[3]\n"
        "umlal v20.4s, v3.4h, v1.h[0]\n"
        "umlal v21.4s, v3.4h, v1.h[1]\n"
        "umlal v22.4s, v3.4h, v1.h[2]\n"
        "umlal v23.4s, v3.4h, v1.h[3]\n"
        "umlal v24.4s, v4.4h, v0.h[0]\n"
        "umlal v25.4s, v4.4h, v0.h[1]\n"
        "umlal v26.4s, v4.4h, v0.h[2]\n"
        "umlal v27.4s, v4.4h, v0.h[3]\n"
        "umlal v28.4s, v4.4h, v1.h[0]\n"
        "umlal v29.4s, v4.4h, v1.h[1]\n"
        "umlal v30.4s, v4.4h, v1.h[2]\n"
        "umlal v31.4s, v4.4h, v1.h[3]\n"

        // Multiply-accumulate, level of depth 1
        "umlal2 v8.4s, v2.8h, v0.h[4]\n"
        "umlal2 v9.4s, v2.8h, v0.h[5]\n"
        "umlal2 v10.4s, v2.8h, v0.h[6]\n"
        "umlal2 v11.4s, v2.8h, v0.h[7]\n"
        "umlal2 v12.4s, v2.8h, v1.h[4]\n"
        "umlal2 v13.4s, v2.8h, v1.h[5]\n"
        "umlal2 v14.4s, v2.8h, v1.h[6]\n"
        "umlal2 v15.4s, v2.8h, v1.h[7]\n"
        "umlal2 v16.4s, v3.8h, v0.h[4]\n"
        "umlal2 v17.4s, v3.8h, v0.h[5]\n"
        "umlal2 v18.4s, v3.8h, v0.h[6]\n"
        "umlal2 v19.4s, v3.8h, v0.h[7]\n"
        "umlal2 v20.4s, v3.8h, v1.h[4]\n"
        "umlal2 v21.4s, v3.8h, v1.h[5]\n"
        "umlal2 v22.4s, v3.8h, v1.h[6]\n"
        "umlal2 v23.4s, v3.8h, v1.h[7]\n"
        "umlal2 v24.4s, v4.8h, v0.h[4]\n"
        "umlal2 v25.4s, v4.8h, v0.h[5]\n"
        "umlal2 v26.4s, v4.8h, v0.h[6]\n"
        "umlal2 v27.4s, v4.8h, v0.h[7]\n"
        "umlal2 v28.4s, v4.8h, v1.h[4]\n"
        "umlal2 v29.4s, v4.8h, v1.h[5]\n"
        "umlal2 v30.4s, v4.8h, v1.h[6]\n"
        "umlal2 v31.4s, v4.8h, v1.h[7]\n"

        // Store accumulators, in the same order they were loaded.
        "mov x0, %[accum_ptr]\n"
        "st1 {v8.16b}, [x0], #16\n"
        "st1 {v16.16b}, [x0], #16\n"
        "st1 {v24.16b}, [x0], #16\n"
        "st1 {v9.16b}, [x0], #16\n"
        "st1 {v17.16b}, [x0], #16\n"
        "st1 {v25.16b}, [x0], #16\n"
        "st1 {v10.16b}, [x0], #16\n"
        "st1 {v18.16b}, [x0], #16\n"
        "st1 {v26.16b}, [x0], #16\n"
        "st1 {v11.16b}, [x0], #16\n"
        "st1 {v19.16b}, [x0], #16\n"
        "st1 {v27.16b}, [x0], #16\n"
        "st1 {v12.16b}, [x0], #16\n"
        "st1 {v20.16b}, [x0], #16\n"
        "st1 {v28.16b}, [x0], #16\n"
        "st1 {v13.16b}, [x0], #16\n"
        "st1 {v21.16b}, [x0], #16\n"
        "st1 {v29.16b}, [x0], #16\n"
        "st1 {v14.16b}, [x0], #16\n"
        "st1 {v22.16b}, [x0], #16\n"
        "st1 {v30.16b}, [x0], #16\n"
        "st1 {v15.16b}, [x0], #16\n"
        "st1 {v23.16b}, [x0], #16\n"
        "st1 {v31.16b}, [x0], #16\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
        "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
        "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
        "v28", "v29", "v30", "v31");
  }
};
2043
2044 // Faster kernel by ARM. Not expanding operands before multiplication.
2045 // Tuned for A57. Compare to
2046 // NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand
struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand_A57 {
  typedef std::uint8_t OperandType;
  typedef std::uint32_t AccumulatorType;
  // 5x4 result block from one 5x16 WidthMajor LHS cell and one 4x16
  // WidthMajor RHS cell per loop iteration (16 levels of depth at a time).
  typedef KernelFormat<
      KernelSideFormat<CellFormat<5, 16, CellOrder::WidthMajor>, 1>,
      KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
      Format;
  // Computes accum += lhs * rhs over `depth` levels of depth.
  // The asm loop consumes 16 levels per iteration with no tail handling
  // (single "subs #16" / "bne"), so `depth` must be a positive multiple
  // of 16. The asm produces a row-major 5x4 partial result in a local
  // buffer which the trailing C++ loop adds into the column-major
  // accumulators at accum_ptr.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    static const int kLhsWidth = Format::Lhs::kWidth;
    static const int kRhsWidth = Format::Rhs::kWidth;
    // Scratch for the asm's row-major result before the final transpose-add.
    AccumulatorType rowmajor_accumulator_buffer[kLhsWidth * kRhsWidth];
    asm volatile(
        // Clear aggregators
        "dup v12.4s, wzr\n"
        "dup v13.4s, wzr\n"
        "dup v14.4s, wzr\n"
        "dup v15.4s, wzr\n"
        "dup v16.4s, wzr\n"
        "dup v17.4s, wzr\n"
        "dup v18.4s, wzr\n"
        "dup v19.4s, wzr\n"
        "dup v20.4s, wzr\n"
        "dup v21.4s, wzr\n"
        "dup v22.4s, wzr\n"
        "dup v23.4s, wzr\n"
        "dup v24.4s, wzr\n"
        "dup v25.4s, wzr\n"
        "dup v26.4s, wzr\n"
        "dup v27.4s, wzr\n"
        "dup v28.4s, wzr\n"
        "dup v29.4s, wzr\n"
        "dup v30.4s, wzr\n"
        "dup v31.4s, wzr\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Overview of register layout:
        //
        // A 4x16 block of Rhs is stored in 8 bit in v0--v3.
        // A 5x16 block of Lhs is cycled through v4 and v5 in 8 bit.
        //
        // A 4x5 block of aggregators is stored in v12-v31 (as 4x32 bit
        // components which would need to be added at the end)
        //
        // The Lhs vectors are multiplied by the Rhs vectors with a widening
        // multiply to produce an intermediate result which is stored in
        // v6-v11. Each intermediate result is 8x16 bits so this happens
        // twice for each Lhs/Rhs combination (once with UMULL for elements
        // 0-7 and once with UMULL2 for elements 8-15).
        //
        // UADALP is used to accumulate these intermediate results into the
        // result aggregators.
        //
        //
        //
        // +--------+--------+--------+--------+
        // |v0.b[0] |v1.b[0] |v2.b[0] |v3.b[0] |
        // Rhs +--------+--------+--------+--------+
        // | ... | ... | ... | ... |
        // +--------+--------+--------+--------|
        // |v0.b[15]|v1.b[15]|v2.b[15]|v3.b[15]|
        // +--------+--------+--------+--------+
        //
        // | | | | |
        //
        // Lhs | | | | |
        //
        // +-------+-----+--------+ - - +--------+--------+--------+--------+
        // |v4.b[0]| ... |v4.b[15]| | v12.4s | v13.4s | v14.4s | v15.4s |
        // |v5.b[0]| ... |v5.b[15]| | v16.4s | v17.4s | v18.4s | v19.4s |
        // |v4.b[0]| ... |v4.b[15]| | v20.4s | v21.4s | v22.4s | v23.4s |
        // |v5.b[0]| ... |v5.b[15]| | v24.4s | v25.4s | v26.4s | v27.4s |
        // |v4.b[0]| ... |v4.b[15]| | v28.4s | v29.4s | v30.4s | v31.4s |
        // +-------+--------------+ - - +--------+--------+--------+--------+
        //
        // Accumulator
        //
        //
        // Further possible optimisations (not tried):
        // - Move early loads into previous iteration (see Float32_WithScalar
        //   for example).
        // - Unroll loop 2x to alternate more smoothly between v4 and v5.
        // - A different number of temporary registers might work better.
        // - Pairing umull with corresponding umull2 might allow better
        //   register loading (e.g. at the start of the loop).
        // - Interleaving umull{2} and uadalp even more aggressively might
        //   help, (not sure about latency vs. dispatch rate).
        //
        //
        // Start loading Rhs - further loads are interleaved amongst the
        // multiplies for better dispatch on A57.
        "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"

        // Load first Lhs vector - further loads are interleaved amongst the
        // multiplies
        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"

        "umull v6.8h, v0.8b, v4.8b\n"
        "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"  // 2nd RHS element
        "umull v7.8h, v1.8b, v4.8b\n"
        "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"  // 3rd RHS element
        "umull v8.8h, v2.8b, v4.8b\n"
        "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"  // 4th RHS element
        "umull v9.8h, v3.8b, v4.8b\n"
        "umull2 v10.8h, v0.16b, v4.16b\n"
        "umull2 v11.8h, v1.16b, v4.16b\n"
        "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"  // 2nd LHS element

        "uadalp v12.4s, v6.8h\n"
        "umull2 v6.8h, v2.16b, v4.16b\n"
        "uadalp v13.4s, v7.8h\n"
        "umull2 v7.8h, v3.16b, v4.16b\n"
        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"  // 1st LHS element done - Reuse v4
                                             // for 3rd LHS element
        "uadalp v14.4s, v8.8h\n"
        "umull v8.8h, v0.8b, v5.8b\n"
        "uadalp v15.4s, v9.8h\n"
        "umull v9.8h, v1.8b, v5.8b\n"
        "uadalp v12.4s, v10.8h\n"
        "umull v10.8h, v2.8b, v5.8b\n"
        "uadalp v13.4s, v11.8h\n"
        "umull v11.8h, v3.8b, v5.8b\n"

        "uadalp v14.4s, v6.8h\n"
        "umull2 v6.8h, v0.16b, v5.16b\n"
        "uadalp v15.4s, v7.8h\n"
        "umull2 v7.8h, v1.16b, v5.16b\n"
        "uadalp v16.4s, v8.8h\n"
        "umull2 v8.8h, v2.16b, v5.16b\n"
        "uadalp v17.4s, v9.8h\n"
        "umull2 v9.8h, v3.16b, v5.16b\n"
        "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"  // 2nd LHS element done - Reuse v5
                                             // for 4th LHS element
        "uadalp v18.4s, v10.8h\n"
        "umull v10.8h, v0.8b, v4.8b\n"
        "uadalp v19.4s, v11.8h\n"
        "umull v11.8h, v1.8b, v4.8b\n"

        "uadalp v16.4s, v6.8h\n"
        "umull v6.8h, v2.8b, v4.8b\n"
        "uadalp v17.4s, v7.8h\n"
        "umull v7.8h, v3.8b, v4.8b\n"
        "uadalp v18.4s, v8.8h\n"
        "umull2 v8.8h, v0.16b, v4.16b\n"
        "uadalp v19.4s, v9.8h\n"
        "umull2 v9.8h, v1.16b, v4.16b\n"
        "uadalp v20.4s, v10.8h\n"
        "umull2 v10.8h, v2.16b, v4.16b\n"
        "uadalp v21.4s, v11.8h\n"
        "umull2 v11.8h, v3.16b, v4.16b\n"
        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"  // 3rd LHS element done - Reuse v4
                                             // for 5th LHS element

        "uadalp v22.4s, v6.8h\n"
        "umull v6.8h, v0.8b, v5.8b\n"
        "uadalp v23.4s, v7.8h\n"
        "umull v7.8h, v1.8b, v5.8b\n"
        "uadalp v20.4s, v8.8h\n"
        "umull v8.8h, v2.8b, v5.8b\n"
        "uadalp v21.4s, v9.8h\n"
        "umull v9.8h, v3.8b, v5.8b\n"
        "uadalp v22.4s, v10.8h\n"
        "umull2 v10.8h, v0.16b, v5.16b\n"
        "uadalp v23.4s, v11.8h\n"
        "umull2 v11.8h, v1.16b, v5.16b\n"

        "uadalp v24.4s, v6.8h\n"
        "umull2 v6.8h, v2.16b, v5.16b\n"
        "uadalp v25.4s, v7.8h\n"
        "umull2 v7.8h, v3.16b, v5.16b\n"
        "uadalp v26.4s, v8.8h\n"
        "umull v8.8h, v0.8b, v4.8b\n"
        "uadalp v27.4s, v9.8h\n"
        "umull v9.8h, v1.8b, v4.8b\n"
        "uadalp v24.4s, v10.8h\n"
        "umull v10.8h, v2.8b, v4.8b\n"
        "uadalp v25.4s, v11.8h\n"
        "umull v11.8h, v3.8b, v4.8b\n"

        "uadalp v26.4s, v6.8h\n"
        "umull2 v6.8h, v0.16b, v4.16b\n"
        "uadalp v27.4s, v7.8h\n"
        "umull2 v7.8h, v1.16b, v4.16b\n"
        "uadalp v28.4s, v8.8h\n"
        "umull2 v8.8h, v2.16b, v4.16b\n"
        "uadalp v29.4s, v9.8h\n"
        "umull2 v9.8h, v3.16b, v4.16b\n"
        "uadalp v30.4s, v10.8h\n"
        "uadalp v31.4s, v11.8h\n"

        "uadalp v28.4s, v6.8h\n"
        "uadalp v29.4s, v7.8h\n"
        // Loop. Decrement loop index (depth) by 16, since we just handled
        // 16 levels of depth. Do this subs a bit before the end of the loop
        // for better dispatch on A57.
        "subs %w[depth], %w[depth], #16\n"
        "uadalp v30.4s, v8.8h\n"
        "uadalp v31.4s, v9.8h\n"

        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Reduce aggregators horizontally: two rounds of pairwise ADDP
        // collapse each group of four 4-lane aggregators into one 4-lane
        // vector holding four final (row-major) sums.
        "addp v0.4s, v12.4s, v13.4s\n"
        "addp v1.4s, v14.4s, v15.4s\n"
        "addp v2.4s, v16.4s, v17.4s\n"
        "addp v3.4s, v18.4s, v19.4s\n"
        "addp v4.4s, v20.4s, v21.4s\n"
        "addp v5.4s, v22.4s, v23.4s\n"
        "addp v6.4s, v24.4s, v25.4s\n"
        "addp v7.4s, v26.4s, v27.4s\n"
        "addp v8.4s, v28.4s, v29.4s\n"
        "addp v9.4s, v30.4s, v31.4s\n"

        "addp v10.4s, v0.4s, v1.4s\n"
        "addp v11.4s, v2.4s, v3.4s\n"
        "addp v12.4s, v4.4s, v5.4s\n"
        "addp v13.4s, v6.4s, v7.4s\n"
        "addp v14.4s, v8.4s, v9.4s\n"

        // Spill the row-major 5x4 result to the scratch buffer; the C++
        // code below folds it into the column-major accumulators.
        "mov x0, %[rowmajor_accumulator_buffer]\n"
        "st1 {v10.16b}, [x0], #16\n"
        "st1 {v11.16b}, [x0], #16\n"
        "st1 {v12.16b}, [x0], #16\n"
        "st1 {v13.16b}, [x0], #16\n"
        "st1 {v14.16b}, [x0], #16\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [rowmajor_accumulator_buffer] "r"(rowmajor_accumulator_buffer)
        :  // clobbers
        "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
        "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
        "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
        "v28", "v29", "v30", "v31");

    // accumulate row-major accumulators into global (column-major) accumulators
    for (int l = 0; l < kLhsWidth; l++) {
      for (int r = 0; r < kRhsWidth; r++) {
        accum_ptr[l + kLhsWidth * r] +=
            rowmajor_accumulator_buffer[r + l * kRhsWidth];
      }
    }
  }
};
2295
2296 // Fast kernel operating on int8 operands.
2297 // It is assumed that one of the two int8 operands only takes values
2298 // in [-127, 127], while the other may freely range in [-128, 127].
2299 // The issue with both operands taking the value -128 is that:
2300 // -128*-128 + -128*-128 == -32768 overflows int16.
2301 // Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16
2302 // range. That is the basic idea of this kernel.
struct NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits {
  typedef std::int8_t OperandType;
  typedef std::int32_t AccumulatorType;
  // 4x4 result block from one 4x16 WidthMajor cell on each side,
  // 16 levels of depth per loop iteration.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
      KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
      Format;
  // Computes accum += lhs * rhs over `depth` levels of depth.
  // Preconditions: `depth` must be a positive multiple of 16 (the asm
  // pre-decrements by 16 and always runs a 16-level tail after the loop),
  // and one of the two operands must stay within [-127, 127] so that two
  // int8 products can be summed in int16 without overflow (see comment
  // above the struct).
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    // NOTE(review): start_depth and dst_col_stride are bound as asm operands
    // below but never referenced in the asm text — presumably kept to mirror
    // the operand list of the corresponding real gemmlowp kernel.
    std::size_t start_depth = 123;
    std::size_t run_depth = depth;
    std::size_t dst_col_stride = 4;
    AccumulatorType* dst_ptr = accum_ptr;
    asm volatile(
        // Overview of register layout:
        //
        // A 4x16 block of Rhs is stored in 8 bit in v0--v3.
        // A 4x16 block of Lhs is stored in 8 bit in v4--v7.
        //
        // A 4x4 block of accumulators is stored in v16-v31 (as 4x32 bit
        // components which need to be horizontally-added at the end)
        //
        // The Lhs vectors are multiplied by the Rhs vectors with a widening
        // multiply over the 8 first levels of depth, producing int16x8
        // vectors of products for each position in the accumulator matrix.
        // Here comes the special trick: since the operands are signed int8,
        // their range being [ -2^7 , 2^7 ), their products are in range
        // [ -2^14 , 2^14 - 1 ), meaning that we can add two such values
        // without any risk of overflowing int16.
        // We thus proceed with the 8 next levels of depth, multiplying
        // again Lhs by Rhs, accumulating into this existing int16x8 vector.
        //
        // Only then, having processed 16 levels of depth, do we need to
        // horizontally add these int16x8 accumulators into the final
        // int32x4 accumulators.
        //
        // As we do not have enough registers to store all 16 int16x8
        // temporary-16bit-accumulators, we have them cycle through v8--v15.
        //
        //
        // Register layout (ignoring the v8--v15 temporary 16bit accumulators):
        //
        // +--------+--------+--------+--------+
        // |v0.b[0] |v1.b[0] |v2.b[0] |v3.b[0] |
        // Rhs +--------+--------+--------+--------+
        // | ... | ... | ... | ... |
        // +--------+--------+--------+--------|
        // |v0.b[15]|v1.b[15]|v2.b[15]|v3.b[15]|
        // +--------+--------+--------+--------+
        //
        // | | | | |
        //
        // Lhs | | | | |
        //
        // +-------+-----+--------+ - - +--------+--------+--------+--------+
        // |v4.b[0]| ... |v4.b[15]| | v16.4s | v17.4s | v18.4s | v19.4s |
        // |v5.b[0]| ... |v5.b[15]| | v20.4s | v21.4s | v22.4s | v23.4s |
        // |v6.b[0]| ... |v6.b[15]| | v24.4s | v25.4s | v26.4s | v27.4s |
        // |v7.b[0]| ... |v7.b[15]| | v28.4s | v29.4s | v30.4s | v31.4s |
        // +-------+--------------+ - - +--------+--------+--------+--------+
        //
        // Accumulator
        //

        // Clear accumulators, interleaved with the first operand loads and
        // the initial depth decrement to hide load latency.
        "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
        "dup v16.4s, wzr\n"
        "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"
        "dup v17.4s, wzr\n"
        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
        "dup v18.4s, wzr\n"
        "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
        "dup v19.4s, wzr\n"
        "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
        "dup v20.4s, wzr\n"
        "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
        "dup v21.4s, wzr\n"
        "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
        "dup v22.4s, wzr\n"
        "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
        "dup v23.4s, wzr\n"
        "subs %[run_depth], %[run_depth], #16\n"
        "dup v24.4s, wzr\n"
        // x0 keeps the destination pointer; it is not touched by the loop
        // and is used to load/store the memory accumulators after it.
        "mov x0, %[dst_ptr]\n"
        "dup v25.4s, wzr\n"
        "dup v26.4s, wzr\n"
        "dup v27.4s, wzr\n"
        "dup v28.4s, wzr\n"
        "dup v29.4s, wzr\n"
        "dup v30.4s, wzr\n"
        "dup v31.4s, wzr\n"

        // Kick off the software pipeline: first 4 of the 16 int16x8
        // products for this 16-level slice (smull = depth 0-7,
        // smlal2 = depth 8-15).
        "smull v12.8h, v0.8b, v4.8b\n"
        "smull v13.8h, v1.8b, v4.8b\n"
        "smull v14.8h, v0.8b, v5.8b\n"
        "smull v15.8h, v1.8b, v5.8b\n"
        "smlal2 v12.8h, v0.16b, v4.16b\n"
        "smlal2 v13.8h, v1.16b, v4.16b\n"
        "smlal2 v14.8h, v0.16b, v5.16b\n"
        "smlal2 v15.8h, v1.16b, v5.16b\n"

        // If depth == 16, the pipeline drain after the loop handles it all.
        "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"

        "subs %[run_depth], %[run_depth], #16\n"

        // Each sadalp folds a finished int16x8 temporary into its int32x4
        // accumulator while the freed temporary register is immediately
        // reused for the next smull.
        "sadalp v16.4s, v12.8h\n"
        "smull v12.8h, v0.8b, v6.8b\n"
        "sadalp v17.4s, v13.8h\n"
        "smull v13.8h, v0.8b, v7.8b\n"
        "sadalp v20.4s, v14.8h\n"
        "smull v14.8h, v1.8b, v6.8b\n"
        "sadalp v21.4s, v15.8h\n"
        "smull v15.8h, v1.8b, v7.8b\n"
        "smlal2 v12.8h, v0.16b, v6.16b\n"
        "smlal2 v13.8h, v0.16b, v7.16b\n"
        "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
        "smlal2 v14.8h, v1.16b, v6.16b\n"
        "smlal2 v15.8h, v1.16b, v7.16b\n"
        "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"
        "sadalp v24.4s, v12.8h\n"
        "smull v12.8h, v2.8b, v4.8b\n"
        "sadalp v28.4s, v13.8h\n"
        "smull v13.8h, v3.8b, v4.8b\n"
        "sadalp v25.4s, v14.8h\n"
        "smull v14.8h, v2.8b, v5.8b\n"
        "sadalp v29.4s, v15.8h\n"
        "smull v15.8h, v3.8b, v5.8b\n"
        "smlal2 v12.8h, v2.16b, v4.16b\n"
        "smlal2 v13.8h, v3.16b, v4.16b\n"
        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v14.8h, v2.16b, v5.16b\n"
        "smlal2 v15.8h, v3.16b, v5.16b\n"
        "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
        "sadalp v18.4s, v12.8h\n"
        "smull v12.8h, v2.8b, v6.8b\n"
        "sadalp v19.4s, v13.8h\n"
        "smull v13.8h, v2.8b, v7.8b\n"
        "sadalp v22.4s, v14.8h\n"
        "smull v14.8h, v3.8b, v6.8b\n"
        "sadalp v23.4s, v15.8h\n"
        "smull v15.8h, v3.8b, v7.8b\n"
        "smlal2 v12.8h, v2.16b, v6.16b\n"
        "smlal2 v13.8h, v2.16b, v7.16b\n"
        "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
        "smlal2 v14.8h, v3.16b, v6.16b\n"
        "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v15.8h, v3.16b, v7.16b\n"
        "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
        "sadalp v26.4s, v12.8h\n"
        "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
        // Start of the next iteration's first 4 products (pipelined across
        // the loop back-edge, mirroring the pre-loop smull/smlal2 block).
        "smull v12.8h, v0.8b, v4.8b\n"
        "sadalp v30.4s, v13.8h\n"
        "smull v13.8h, v1.8b, v4.8b\n"
        "sadalp v27.4s, v14.8h\n"
        "smull v14.8h, v0.8b, v5.8b\n"
        "sadalp v31.4s, v15.8h\n"
        "smull v15.8h, v1.8b, v5.8b\n"
        "smlal2 v12.8h, v0.16b, v4.16b\n"
        "smlal2 v13.8h, v1.16b, v4.16b\n"
        "smlal2 v14.8h, v0.16b, v5.16b\n"
        "smlal2 v15.8h, v1.16b, v5.16b\n"

        "bne " GEMMLOWP_LABEL_LOOP "b\n"

        GEMMLOWP_LABEL_AFTER_LOOP
        ":\n"

        // Load accumulators from memory
        "ld1 {v8.16b}, [x0], #16\n"
        "ld1 {v9.16b}, [x0], #16\n"
        "ld1 {v10.16b}, [x0], #16\n"
        "ld1 {v11.16b}, [x0], #16\n"
        "mov x0, %[dst_ptr]\n"

        // Do the remaining arithmetic for the 16 last levels of depths.
        // All the operands are already loaded.
        "sadalp v16.4s, v12.8h\n"
        "smull v12.8h, v0.8b, v6.8b\n"
        "sadalp v17.4s, v13.8h\n"
        "smull v13.8h, v0.8b, v7.8b\n"
        "sadalp v20.4s, v14.8h\n"
        "smull v14.8h, v1.8b, v6.8b\n"
        "sadalp v21.4s, v15.8h\n"
        "smull v15.8h, v1.8b, v7.8b\n"
        "smlal2 v12.8h, v0.16b, v6.16b\n"
        "smlal2 v13.8h, v0.16b, v7.16b\n"
        "smlal2 v14.8h, v1.16b, v6.16b\n"
        "smlal2 v15.8h, v1.16b, v7.16b\n"
        "sadalp v24.4s, v12.8h\n"
        "smull v12.8h, v2.8b, v4.8b\n"
        "sadalp v28.4s, v13.8h\n"
        "smull v13.8h, v3.8b, v4.8b\n"
        "sadalp v25.4s, v14.8h\n"
        "smull v14.8h, v2.8b, v5.8b\n"
        "sadalp v29.4s, v15.8h\n"
        "smull v15.8h, v3.8b, v5.8b\n"
        "smlal2 v12.8h, v2.16b, v4.16b\n"
        "smlal2 v13.8h, v3.16b, v4.16b\n"
        "smlal2 v14.8h, v2.16b, v5.16b\n"
        "smlal2 v15.8h, v3.16b, v5.16b\n"
        "sadalp v18.4s, v12.8h\n"
        "smull v12.8h, v2.8b, v6.8b\n"
        "sadalp v19.4s, v13.8h\n"
        "smull v13.8h, v2.8b, v7.8b\n"
        "sadalp v22.4s, v14.8h\n"
        "smull v14.8h, v3.8b, v6.8b\n"
        "sadalp v23.4s, v15.8h\n"
        "smull v15.8h, v3.8b, v7.8b\n"
        "smlal2 v12.8h, v2.16b, v6.16b\n"
        "smlal2 v13.8h, v2.16b, v7.16b\n"
        "smlal2 v14.8h, v3.16b, v6.16b\n"
        "smlal2 v15.8h, v3.16b, v7.16b\n"
        "sadalp v26.4s, v12.8h\n"
        "sadalp v30.4s, v13.8h\n"
        "sadalp v27.4s, v14.8h\n"
        "sadalp v31.4s, v15.8h\n"

        // Reduce aggregators horizontally: two rounds of pairwise ADDP
        // turn the 16 int32x4 aggregators into 4 vectors of final sums.
        "addp v0.4s, v16.4s, v20.4s\n"
        "addp v1.4s, v17.4s, v21.4s\n"
        "addp v2.4s, v18.4s, v22.4s\n"
        "addp v3.4s, v19.4s, v23.4s\n"
        "addp v4.4s, v24.4s, v28.4s\n"
        "addp v5.4s, v25.4s, v29.4s\n"
        "addp v6.4s, v26.4s, v30.4s\n"
        "addp v7.4s, v27.4s, v31.4s\n"

        "addp v12.4s, v0.4s, v4.4s\n"
        "addp v13.4s, v1.4s, v5.4s\n"
        "addp v14.4s, v2.4s, v6.4s\n"
        "addp v15.4s, v3.4s, v7.4s\n"

        // Add to the accumulators loaded from memory
        "add v8.4s, v8.4s, v12.4s\n"
        "add v9.4s, v9.4s, v13.4s\n"
        "add v10.4s, v10.4s, v14.4s\n"
        "add v11.4s, v11.4s, v15.4s\n"

        // Store accumulators back to memory
        "st1 {v8.16b}, [x0], #16\n"
        "st1 {v9.16b}, [x0], #16\n"
        "st1 {v10.16b}, [x0], #16\n"
        "st1 {v11.16b}, [x0], #16\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [dst_ptr] "+r"(dst_ptr), [run_depth] "+r"(run_depth),
        [dst_col_stride] "+r"(dst_col_stride)
        :  // inputs
        [start_depth] "r"(start_depth)
        :  // clobbers
        "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
        "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
        "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
        "v28", "v29", "v30", "v31");
  }
};
2562
2563 struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_narrow {
2564 typedef std::uint8_t OperandType;
2565 typedef std::uint32_t AccumulatorType;
2566 typedef KernelFormat<
2567 KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
2568 KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
2569 Format;
RunNEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_narrow2570 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
2571 AccumulatorType* accum_ptr, int depth) {
2572 std::size_t start_depth = 123;
2573 std::size_t run_depth = depth;
2574 std::size_t dst_col_stride = 4;
2575 AccumulatorType* dst_ptr = accum_ptr;
2576 asm volatile(
2577 // Overview of register layout:
2578 //
2579 // A 4x16 block of Rhs is stored in 8 bit in v0--v3.
2580 // A 4x16 block of Lhs is stored in 8 bit in v4--v7.
2581 //
2582 // A 4x4 block of accumulators is stored in v16-v31 (as 4x32 bit
2583 // components which need to be horizontally-added at the end)
2584 //
2585 // Register layout:
2586 //
2587 // +--------+--------+--------+--------+
2588 // |v0.b[0] |v1.b[0] |v2.b[0] |v3.b[0] |
2589 // Rhs +--------+--------+--------+--------+
2590 // | ... | ... | ... | ... |
2591 // +--------+--------+--------+--------|
2592 // |v0.b[15]|v1.b[15]|v2.b[15]|v3.b[15]|
2593 // +--------+--------+--------+--------+
2594 //
2595 // | | | | |
2596 //
2597 // Lhs | | | | |
2598 //
2599 // +-------+-----+--------+ - - +--------+--------+--------+--------+
2600 // |v4.b[0]| ... |v4.b[15]| | v16.4s | v17.4s | v18.4s | v19.4s |
2601 // |v5.b[0]| ... |v5.b[15]| | v20.4s | v21.4s | v22.4s | v23.4s |
2602 // |v6.b[0]| ... |v6.b[15]| | v24.4s | v25.4s | v26.4s | v27.4s |
2603 // |v7.b[0]| ... |v7.b[15]| | v28.4s | v29.4s | v30.4s | v31.4s |
2604 // +-------+--------------+ - - +--------+--------+--------+--------+
2605 //
2606 // Accumulator
2607 //
2608
2609 // Clear accumulators
2610 "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
2611 "dup v16.4s, wzr\n"
2612 "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"
2613 "dup v17.4s, wzr\n"
2614 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
2615 "dup v18.4s, wzr\n"
2616 "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
2617 "dup v19.4s, wzr\n"
2618 "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
2619 "dup v20.4s, wzr\n"
2620 "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
2621 "dup v21.4s, wzr\n"
2622 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
2623 "dup v22.4s, wzr\n"
2624 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
2625 "dup v23.4s, wzr\n"
2626 "subs %w[run_depth], %w[run_depth], #16\n"
2627 "dup v24.4s, wzr\n"
2628 "mov x0, %[dst_ptr]\n"
2629 "dup v25.4s, wzr\n"
2630 "dup v26.4s, wzr\n"
2631 "dup v27.4s, wzr\n"
2632 "dup v28.4s, wzr\n"
2633 "dup v29.4s, wzr\n"
2634 "dup v30.4s, wzr\n"
2635 "dup v31.4s, wzr\n"
2636
2637 "beq 1f\n"
2638
2639 "cmp %w[run_depth], #32\n"
2640 "blt 2f\n"
2641
2642 "3:\n"
2643 "ld1 {v12.16b}, [%[lhs_ptr]], #16\n"
2644 ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n"
2645 ".word 0x6e819491 // udot v17.4s, v4.16b, v1.16b\n"
2646 "ld1 {v13.16b}, [%[lhs_ptr]], #16\n"
2647 ".word 0x6e829492 // udot v18.4s, v4.16b, v2.16b\n"
2648 ".word 0x6e839493 // udot v19.4s, v4.16b, v3.16b\n"
2649 "ld1 {v8.16b}, [%[rhs_ptr]], #16\n"
2650 ".word 0x6e8094b4 // udot v20.4s, v5.16b, v0.16b\n"
2651 ".word 0x6e8194b5 // udot v21.4s, v5.16b, v1.16b\n"
2652 "ld1 {v9.16b}, [%[rhs_ptr]], #16\n"
2653 ".word 0x6e8294b6 // udot v22.4s, v5.16b, v2.16b\n"
2654 ".word 0x6e8394b7 // udot v23.4s, v5.16b, v3.16b\n"
2655 "ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
2656 ".word 0x6e8094d8 // udot v24.4s, v6.16b, v0.16b\n"
2657 ".word 0x6e8194d9 // udot v25.4s, v6.16b, v1.16b\n"
2658 "ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
2659 ".word 0x6e8294da // udot v26.4s, v6.16b, v2.16b\n"
2660 "prfm pldl1keep, [%[rhs_ptr], #128]\n"
2661 ".word 0x6e8394db // udot v27.4s, v6.16b, v3.16b\n"
2662 "ld1 {v14.16b}, [%[lhs_ptr]], #16\n"
2663 ".word 0x6e8094fc // udot v28.4s, v7.16b, v0.16b\n"
2664 ".word 0x6e8194fd // udot v29.4s, v7.16b, v1.16b\n"
2665 "ld1 {v15.16b}, [%[lhs_ptr]], #16\n"
2666 ".word 0x6e8294fe // udot v30.4s, v7.16b, v2.16b\n"
2667 "prfm pldl1keep, [%[lhs_ptr], #128]\n"
2668 ".word 0x6e8394ff // udot v31.4s, v7.16b, v3.16b\n"
2669 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
2670 ".word 0x6e889590 // udot v16.4s, v12.16b, v8.16b\n"
2671 ".word 0x6e899591 // udot v17.4s, v12.16b, v9.16b\n"
2672 "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
2673 ".word 0x6e8a9592 // udot v18.4s, v12.16b, v10.16b\n"
2674 ".word 0x6e8b9593 // udot v19.4s, v12.16b, v11.16b\n"
2675 "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
2676 ".word 0x6e8895b4 // udot v20.4s, v13.16b, v8.16b\n"
2677 ".word 0x6e8995b5 // udot v21.4s, v13.16b, v9.16b\n"
2678 "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
2679 "sub %[run_depth], %[run_depth], #32\n"
2680 ".word 0x6e8a95b6 // udot v22.4s, v13.16b, v10.16b\n"
2681 ".word 0x6e8b95b7 // udot v23.4s, v13.16b, v11.16b\n"
2682 "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"
2683 ".word 0x6e8895d8 // udot v24.4s, v14.16b, v8.16b\n"
2684 ".word 0x6e8995d9 // udot v25.4s, v14.16b, v9.16b\n"
2685 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
2686 ".word 0x6e8a95da // udot v26.4s, v14.16b, v10.16b\n"
2687 ".word 0x6e8b95db // udot v27.4s, v14.16b, v11.16b\n"
2688 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
2689 ".word 0x6e8895fc // udot v28.4s, v15.16b, v8.16b\n"
2690 "prfm pldl1keep, [%[rhs_ptr], #128]\n"
2691 ".word 0x6e8995fd // udot v29.4s, v15.16b, v9.16b\n"
2692 "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
2693 "cmp %w[run_depth], #32\n"
2694 ".word 0x6e8a95fe // udot v30.4s, v15.16b, v10.16b\n"
2695 "prfm pldl1keep, [%[lhs_ptr], #128]\n"
2696 ".word 0x6e8b95ff // udot v31.4s, v15.16b, v11.16b\n"
2697
2698 "bge 3b\n"
2699
2700 "cmp %w[run_depth], #0\n"
2701 "beq 1f\n"
2702
2703 "2:\n"
2704
2705 "subs %w[run_depth], %w[run_depth], #16\n"
2706
2707 ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n"
2708 ".word 0x6e819491 // udot v17.4s, v4.16b, v1.16b\n"
2709 ".word 0x6e829492 // udot v18.4s, v4.16b, v2.16b\n"
2710 ".word 0x6e839493 // udot v19.4s, v4.16b, v3.16b\n"
2711 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
2712 ".word 0x6e8094b4 // udot v20.4s, v5.16b, v0.16b\n"
2713 ".word 0x6e8194b5 // udot v21.4s, v5.16b, v1.16b\n"
2714 ".word 0x6e8294b6 // udot v22.4s, v5.16b, v2.16b\n"
2715 ".word 0x6e8394b7 // udot v23.4s, v5.16b, v3.16b\n"
2716 "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
2717 ".word 0x6e8094d8 // udot v24.4s, v6.16b, v0.16b\n"
2718 ".word 0x6e8194d9 // udot v25.4s, v6.16b, v1.16b\n"
2719 ".word 0x6e8294da // udot v26.4s, v6.16b, v2.16b\n"
2720 ".word 0x6e8394db // udot v27.4s, v6.16b, v3.16b\n"
2721 "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
2722 ".word 0x6e8094fc // udot v28.4s, v7.16b, v0.16b\n"
2723 "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
2724 ".word 0x6e8194fd // udot v29.4s, v7.16b, v1.16b\n"
2725 "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"
2726 ".word 0x6e8294fe // udot v30.4s, v7.16b, v2.16b\n"
2727 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
2728 ".word 0x6e8394ff // udot v31.4s, v7.16b, v3.16b\n"
2729 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
2730 "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
2731
2732 "bne 2b\n"
2733
2734 "1:\n"
2735
2736 ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n"
2737 ".word 0x6e819491 // udot v17.4s, v4.16b, v1.16b\n"
2738 ".word 0x6e829492 // udot v18.4s, v4.16b, v2.16b\n"
2739 ".word 0x6e839493 // udot v19.4s, v4.16b, v3.16b\n"
2740 ".word 0x6e8094b4 // udot v20.4s, v5.16b, v0.16b\n"
2741 ".word 0x6e8194b5 // udot v21.4s, v5.16b, v1.16b\n"
2742 ".word 0x6e8294b6 // udot v22.4s, v5.16b, v2.16b\n"
2743 ".word 0x6e8394b7 // udot v23.4s, v5.16b, v3.16b\n"
2744 ".word 0x6e8094d8 // udot v24.4s, v6.16b, v0.16b\n"
2745 ".word 0x6e8194d9 // udot v25.4s, v6.16b, v1.16b\n"
2746 ".word 0x6e8294da // udot v26.4s, v6.16b, v2.16b\n"
2747 ".word 0x6e8394db // udot v27.4s, v6.16b, v3.16b\n"
2748 ".word 0x6e8094fc // udot v28.4s, v7.16b, v0.16b\n"
2749 ".word 0x6e8194fd // udot v29.4s, v7.16b, v1.16b\n"
2750 ".word 0x6e8294fe // udot v30.4s, v7.16b, v2.16b\n"
2751 ".word 0x6e8394ff // udot v31.4s, v7.16b, v3.16b\n"
2752
2753 // Load accumulators from memory
2754 "ld1 {v8.16b}, [x0], #16\n"
2755 "ld1 {v9.16b}, [x0], #16\n"
2756 "ld1 {v10.16b}, [x0], #16\n"
2757 "ld1 {v11.16b}, [x0], #16\n"
2758 "mov x0, %[dst_ptr]\n"
2759
2760 // Reduce aggregators horizontally
2761 "addp v0.4s, v16.4s, v20.4s\n"
2762 "addp v1.4s, v17.4s, v21.4s\n"
2763 "addp v2.4s, v18.4s, v22.4s\n"
2764 "addp v3.4s, v19.4s, v23.4s\n"
2765 "addp v4.4s, v24.4s, v28.4s\n"
2766 "addp v5.4s, v25.4s, v29.4s\n"
2767 "addp v6.4s, v26.4s, v30.4s\n"
2768 "addp v7.4s, v27.4s, v31.4s\n"
2769
2770 "addp v12.4s, v0.4s, v4.4s\n"
2771 "addp v13.4s, v1.4s, v5.4s\n"
2772 "addp v14.4s, v2.4s, v6.4s\n"
2773 "addp v15.4s, v3.4s, v7.4s\n"
2774
2775 // Add to the accumulators loaded from memory
2776 "add v8.4s, v8.4s, v12.4s\n"
2777 "add v9.4s, v9.4s, v13.4s\n"
2778 "add v10.4s, v10.4s, v14.4s\n"
2779 "add v11.4s, v11.4s, v15.4s\n"
2780
2781 // Store accumulators back to memory
2782 "st1 {v8.16b}, [x0], #16\n"
2783 "st1 {v9.16b}, [x0], #16\n"
2784 "st1 {v10.16b}, [x0], #16\n"
2785 "st1 {v11.16b}, [x0], #16\n"
2786 : // outputs
2787 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
2788 [dst_ptr] "+r"(dst_ptr), [run_depth] "+r"(run_depth),
2789 [dst_col_stride] "+r"(dst_col_stride)
2790 : // inputs
2791 [start_depth] "r"(start_depth)
2792 : // clobbers
2793 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
2794 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
2795 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
2796 "v28", "v29", "v30", "v31");
2797 }
2798 };
2799
2800 // Fast kernel operating on int8 operands with 7-bit range.
2801 // It is assumed that one of the two operands only takes values in [-63, 63],
// while the other takes values in [-64, 63].
2803 // With this restriction, it is possible to multiply-accumulate operands into
2804 // a 16-bit integer eight times without overflow.
struct NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits {
  typedef std::int8_t OperandType;
  typedef std::int32_t AccumulatorType;
  // LHS cell: 4 rows x 16 levels of depth; RHS cell: 2 columns x 16 levels
  // of depth. Each invocation therefore produces a 4x2 block of int32
  // results.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
      KernelSideFormat<CellFormat<2, 16, CellOrder::WidthMajor>, 1> >
      Format;
  // Multiply-accumulates the 4-row LHS block against the 2-column RHS block
  // over `depth` levels, adding the result into the 4x2 int32 accumulator
  // block at accum_ptr (read-modify-write at the end of the asm).
  // Operands must be int8 values restricted to 7-bit range as described
  // above the struct. The loops below consume 64 or 16 depth levels per
  // iteration with no scalar remainder handling, so `depth` is assumed to
  // be a positive multiple of 16 -- TODO(review): confirm the harness
  // guarantees this.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
// Numeric local-label names used by the asm below.
// NOTE(review): these macros are not #undef'd at the end of this function;
// the next kernel in this file uses different macro names so there is no
// redefinition, but the names do leak past this scope.
#define GEMMLOWP_LABEL_64_DEPTH_LOOP "1"
#define GEMMLOWP_LABEL_64_DEPTH_AFTER_LOOP "2"
#define GEMMLOWP_LABEL_16_DEPTH_LOOP "3"
#define GEMMLOWP_LABEL_16_DEPTH_AFTER_LOOP "4"

    AccumulatorType* dst_ptr = accum_ptr;
    asm volatile(
        // Overview of register layout:
        //
        // A 4x16 block of Lhs is stored in 8 bit in v0--v7.
        // A 2x16 block of Rhs is stored in 8 bit in v8--v15.
        //
        // A 4x2 block of global accumulators is stored in v24-v31 (as 4x32 bit
        // components which need to be horizontally-added at the end).
        //
        // A 4x2 block of local accumulators is stored in v16-v23 (as 8x16 bit
        // components which are added to global accumulators every 64 depth
        // iteration).
        //
        // The Lhs vectors are multiplied by the Rhs vectors with a widening
        // multiply over the 8 first levels of depth, producing int16x8
        // vectors of products for each position in the accumulator matrix.
        //
        // Like the trick used in the fast 8-bit kernel, the operands are
        // restricted to 7-bit range [-2^6, 2^6) so their products are in range
        // [-2^12, 2^12 -1). This enables adding eight such products without any
        // risk of overflowing int16, equating to 64 levels of depth before
        // horizontally adding these int16x8 accumulators into the final int32x4
        // accumulators.
        //
        // Register layout including both local and global accumulators.
        // Since we do not have enough registers to store all Lhs values, we
        // reuse the same registers v0--v7 to load the rest of the Lhs values.
        //
        //                            +-----+-----+
        //                            | v8  | v9  |
        //                       Rhs  +-----+-----+
        //                            | v10 | v11 |
        //                            +-----+-----+
        //                            | v12 | v13 |
        //                            +-----+-----+
        //                            | v14 | v15 |
        //       Lhs                  +-----+-----+
        //  +----+----+----+----+ - - +-----+-----+     +--------+--------+
        //  | v0 | v4 | v0 | v4 |     | v16 | v20 |     | v24.4s | v28.4s |
        //  | v1 | v5 | v1 | v5 |     | v17 | v21 |  -> | v25.4s | v29.4s |
        //  | v2 | v6 | v2 | v6 |     | v18 | v22 |     | v26.4s | v30.4s |
        //  | v3 | v7 | v3 | v7 |     | v19 | v23 |     | v27.4s | v31.4s |
        //  +----+----+----+----+ - - +-----+-----+     +--------+--------+
        //
        //                         Local Accumulator    Global Accumulator
        //

        // Clear accumulators (interleaved with the initial operand loads to
        // hide load latency).
        "dup v16.4s, wzr\n"
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "dup v24.4s, wzr\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "dup v17.4s, wzr\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "dup v25.4s, wzr\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
        "dup v18.4s, wzr\n"
        "ld1 {v8.16b}, [%[rhs_ptr]], #16\n"
        "dup v26.4s, wzr\n"
        "ld1 {v9.16b}, [%[rhs_ptr]], #16\n"
        "dup v19.4s, wzr\n"
        "dup v27.4s, wzr\n"
        "dup v20.4s, wzr\n"
        "dup v28.4s, wzr\n"
        "dup v21.4s, wzr\n"
        "dup v29.4s, wzr\n"
        "dup v22.4s, wzr\n"
        "dup v30.4s, wzr\n"
        "dup v23.4s, wzr\n"
        "dup v31.4s, wzr\n"

        // Skip the main loop if fewer than 64 levels of depth remain.
        "cmp %w[depth], #64\n"
        "blt " GEMMLOWP_LABEL_64_DEPTH_AFTER_LOOP "f\n"

        // Main loop: 64 levels of depth per iteration. Each iteration first
        // drains the int16x8 local accumulators (products of the previous 64
        // levels) into the int32x4 globals via sadalp (signed add and
        // accumulate pairwise long), then rebuilds the locals with
        // smull (widening multiply) and seven smlal/smlal2
        // (widening multiply-accumulate, low/high halves) steps.
        //"loop_%=:\n"
        GEMMLOWP_LABEL_64_DEPTH_LOOP
        ":\n"
        "subs %w[depth], %w[depth], #64\n"
        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
        "sadalp v24.4s, v16.8h\n"
        "smull v16.8h, v0.8b, v8.8b\n"
        "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
        "sadalp v25.4s, v17.8h\n"
        "smull v17.8h, v1.8b, v8.8b\n"
        "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
        "sadalp v26.4s, v18.8h\n"
        "smull v18.8h, v2.8b, v8.8b\n"
        "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
        "sadalp v27.4s, v19.8h\n"
        "smull v19.8h, v3.8b, v8.8b\n"
        "ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
        "sadalp v28.4s, v20.8h\n"
        "smull v20.8h, v0.8b, v9.8b\n"
        "ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
        "sadalp v29.4s, v21.8h\n"
        "smull v21.8h, v1.8b, v9.8b\n"
        "ld1 {v12.16b}, [%[rhs_ptr]], #16\n"
        "sadalp v30.4s, v22.8h\n"
        "smull v22.8h, v2.8b, v9.8b\n"
        "ld1 {v13.16b}, [%[rhs_ptr]], #16\n"
        "sadalp v31.4s, v23.8h\n"
        "smull v23.8h, v3.8b, v9.8b\n"

        // Compare early so the flags are ready for the bge at the bottom.
        "cmp %w[depth], #64\n"
        "smlal2 v16.8h, v0.16b, v8.16b\n"
        "ld1 {v14.16b}, [%[rhs_ptr]], #16\n"
        "smlal2 v17.8h, v1.16b, v8.16b\n"
        "ld1 {v15.16b}, [%[rhs_ptr]], #16\n"
        "smlal2 v18.8h, v2.16b, v8.16b\n"
        "smlal2 v19.8h, v3.16b, v8.16b\n"

        "smlal2 v20.8h, v0.16b, v9.16b\n"
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v21.8h, v1.16b, v9.16b\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v22.8h, v2.16b, v9.16b\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v23.8h, v3.16b, v9.16b\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"

        "smlal v16.8h, v4.8b, v10.8b\n"
        "smlal v17.8h, v5.8b, v10.8b\n"
        "smlal v18.8h, v6.8b, v10.8b\n"
        "smlal v19.8h, v7.8b, v10.8b\n"
        "smlal v20.8h, v4.8b, v11.8b\n"

        "smlal v21.8h, v5.8b, v11.8b\n"
        "smlal v22.8h, v6.8b, v11.8b\n"
        "smlal v23.8h, v7.8b, v11.8b\n"

        "smlal2 v16.8h, v4.16b, v10.16b\n"
        "ld1 {v8.16b}, [%[rhs_ptr]], #16\n"
        "smlal2 v17.8h, v5.16b, v10.16b\n"
        "ld1 {v9.16b}, [%[rhs_ptr]], #16\n"
        "smlal2 v18.8h, v6.16b, v10.16b\n"
        "smlal2 v19.8h, v7.16b, v10.16b\n"

        "smlal2 v20.8h, v4.16b, v11.16b\n"
        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v21.8h, v5.16b, v11.16b\n"
        "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v22.8h, v6.16b, v11.16b\n"
        "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v23.8h, v7.16b, v11.16b\n"
        "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"

        "smlal v16.8h, v0.8b, v12.8b\n"
        "smlal v17.8h, v1.8b, v12.8b\n"
        "smlal v18.8h, v2.8b, v12.8b\n"
        "smlal v19.8h, v3.8b, v12.8b\n"
        "smlal v20.8h, v0.8b, v13.8b\n"
        "smlal v21.8h, v1.8b, v13.8b\n"
        "smlal v22.8h, v2.8b, v13.8b\n"
        "smlal v23.8h, v3.8b, v13.8b\n"

        "smlal2 v16.8h, v0.16b, v12.16b\n"
        "smlal2 v17.8h, v1.16b, v12.16b\n"
        "smlal2 v18.8h, v2.16b, v12.16b\n"
        "smlal2 v19.8h, v3.16b, v12.16b\n"

        "smlal2 v20.8h, v0.16b, v13.16b\n"
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v21.8h, v1.16b, v13.16b\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v22.8h, v2.16b, v13.16b\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v23.8h, v3.16b, v13.16b\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"

        "smlal v16.8h, v4.8b, v14.8b\n"
        "smlal v17.8h, v5.8b, v14.8b\n"
        "smlal v18.8h, v6.8b, v14.8b\n"
        "smlal v19.8h, v7.8b, v14.8b\n"

        "smlal v20.8h, v4.8b, v15.8b\n"
        "smlal v21.8h, v5.8b, v15.8b\n"
        "smlal v22.8h, v6.8b, v15.8b\n"
        "smlal v23.8h, v7.8b, v15.8b\n"

        "smlal2 v16.8h, v4.16b, v14.16b\n"
        "smlal2 v17.8h, v5.16b, v14.16b\n"
        "smlal2 v18.8h, v6.16b, v14.16b\n"
        "smlal2 v19.8h, v7.16b, v14.16b\n"

        "smlal2 v20.8h, v4.16b, v15.16b\n"
        "smlal2 v21.8h, v5.16b, v15.16b\n"
        "smlal2 v22.8h, v6.16b, v15.16b\n"
        "smlal2 v23.8h, v7.16b, v15.16b\n"

        "bge " GEMMLOWP_LABEL_64_DEPTH_LOOP "b\n"

        GEMMLOWP_LABEL_64_DEPTH_AFTER_LOOP
        ":\n"

        // Tail loop: 16 levels of depth per iteration.
        // NOTE(review): both loops preload the operands for the next
        // iteration before the backward branch, so they read some bytes
        // past the last consumed operand data; the benchmark's operand
        // buffers presumably include enough slack for this -- confirm
        // against the allocation code elsewhere in this file.
        "cmp %w[depth], #16\n"
        "blt " GEMMLOWP_LABEL_16_DEPTH_AFTER_LOOP "f\n"

        //"loop_%=:\n"
        GEMMLOWP_LABEL_16_DEPTH_LOOP
        ":\n"
        "sadalp v24.4s, v16.8h\n"
        "smull v16.8h, v0.8b, v8.8b\n"
        "subs %w[depth], %w[depth], #16\n"
        "sadalp v25.4s, v17.8h\n"
        "smull v17.8h, v1.8b, v8.8b\n"
        "sadalp v26.4s, v18.8h\n"
        "smull v18.8h, v2.8b, v8.8b\n"
        "sadalp v27.4s, v19.8h\n"
        "smull v19.8h, v3.8b, v8.8b\n"
        "sadalp v28.4s, v20.8h\n"
        "smull v20.8h, v0.8b, v9.8b\n"
        "sadalp v29.4s, v21.8h\n"
        "smull v21.8h, v1.8b, v9.8b\n"
        "sadalp v30.4s, v22.8h\n"
        "smull v22.8h, v2.8b, v9.8b\n"
        "sadalp v31.4s, v23.8h\n"
        "smull v23.8h, v3.8b, v9.8b\n"

        "cmp %w[depth], #16\n"
        "smlal2 v16.8h, v0.16b, v8.16b\n"
        "smlal2 v17.8h, v1.16b, v8.16b\n"
        "smlal2 v18.8h, v2.16b, v8.16b\n"
        "smlal2 v19.8h, v3.16b, v8.16b\n"
        "ld1 {v8.16b}, [%[rhs_ptr]], #16\n"

        "smlal2 v20.8h, v0.16b, v9.16b\n"
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v21.8h, v1.16b, v9.16b\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v22.8h, v2.16b, v9.16b\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v23.8h, v3.16b, v9.16b\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v9.16b}, [%[rhs_ptr]], #16\n"

        "bge " GEMMLOWP_LABEL_16_DEPTH_LOOP "b\n"

        GEMMLOWP_LABEL_16_DEPTH_AFTER_LOOP
        ":\n"

        // Final drain of the int16 local accumulators into the int32 globals.
        "sadalp v24.4s, v16.8h\n"
        "sadalp v25.4s, v17.8h\n"
        "sadalp v26.4s, v18.8h\n"
        "sadalp v27.4s, v19.8h\n"
        "sadalp v28.4s, v20.8h\n"
        "sadalp v29.4s, v21.8h\n"
        "sadalp v30.4s, v22.8h\n"
        "sadalp v31.4s, v23.8h\n"

        // Reduce aggregators horizontally (two rounds of pairwise addp
        // collapse each int32x4 aggregator into one lane of v4/v5).
        "addp v0.4s, v24.4s, v25.4s\n"
        "addp v1.4s, v26.4s, v27.4s\n"
        "addp v2.4s, v28.4s, v29.4s\n"
        "addp v3.4s, v30.4s, v31.4s\n"

        "addp v4.4s, v0.4s, v1.4s\n"
        "addp v5.4s, v2.4s, v3.4s\n"

        // Load accumulators from memory.
        "mov x0, %[dst_ptr]\n"
        "ld1 {v6.16b}, [x0], #16\n"
        "ld1 {v7.16b}, [x0], #16\n"

        // Add to the accumulators loaded from memory.
        "add v6.4s, v6.4s, v4.4s\n"
        "add v7.4s, v7.4s, v5.4s\n"

        // Store accumulators back to memory.
        "mov x0, %[dst_ptr]\n"
        "st1 {v6.16b}, [x0], #16\n"
        "st1 {v7.16b}, [x0], #16\n"

        :
        // Outputs.
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [dst_ptr] "+r"(dst_ptr), [depth] "+r"(depth)
        :
        // Inputs.

        :
        // Clobbers.
        "cc", "memory",
        // We use these NEON registers
        "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
        "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
        "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
        "v31", "x0");
  }
};
3109
3110 SET_7BIT_RANGES(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits);
3111
3112 // Kernel operating on int8 operands with 4.25-bit range.
3113 // It is assumed that one of the two operands only takes values in [-7, 7],
// while the other takes values in [-9, 9].
3115 // With this restriction, it is possible to multiply-accumulate operands into
3116 // a 16-bit integer thirty-two times without overflow.
struct NEON_64bit_GEMM_Int425Operands {
  typedef std::int8_t OperandType;
  typedef std::int32_t AccumulatorType;
  // LHS cell: 4 rows x 32 levels of depth; RHS cell: 2 columns x 32 levels
  // of depth. Each invocation therefore produces a 4x2 block of int32
  // results.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 32, CellOrder::WidthMajor>, 1>,
      KernelSideFormat<CellFormat<2, 32, CellOrder::WidthMajor>, 1> >
      Format;
  // Multiply-accumulates the 4-row LHS block against the 2-column RHS block
  // over `depth` levels, adding the result into the 4x2 int32 accumulator
  // block at accum_ptr.
  // Operands must be int8 values in the restricted 4.25-bit ranges described
  // above the struct. The inner loop consumes 32 depth levels per iteration
  // with no remainder handling, so `depth` is assumed to be a positive
  // multiple of 32 -- TODO(review): confirm the harness guarantees this.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
// Numeric local-label names used by the asm below (not #undef'd; the names
// leak past this function but do not collide with other kernels here).
#define GEMMLOWP_LABEL_512_DEPTH_LOOP "1"
#define GEMMLOWP_LABEL_32_DEPTH_LOOP "2"
#define GEMMLOWP_LABEL_32_DEPTH_AFTER_LOOP "3"

    AccumulatorType* dst_ptr = accum_ptr;
    // Intended iteration count for the outer 512-depth loop.
    // NOTE(review): this variable is passed to the asm as an in/out operand
    // but is never referenced inside the asm body -- see the note at the
    // outer-loop back-branch below.
    int outer_depth = depth / 512 + 1;

    asm volatile(
        // Overview of register layout:
        //
        // A 4x32 block of Lhs is stored in 8 bit in v0--v7.
        // A 2x32 block of Rhs is stored in 8 bit in v8--v11.
        //
        // A 4x2 block of global accumulators is stored in v24-v31 (as 4x32 bit
        // components which need to be horizontally-added at the end).
        //
        // A 4x2 block of local accumulators is stored in v16-v23 (as 8x16 bit
        // components which are horizontally-added to global accumulators every
        // 512 depth iteration.
        //
        // The Lhs vectors are multiplied by the Rhs vectors with a multiply
        // over the 16 first levels of depth, producing int8x16 vectors of
        // products for each position in the accumulator matrix.
        //
        // Like the trick used in the fast 8-bit and 7-bit kernels, the operands
        // are restricted to 4.25-bit range, [-7, 7] for one operand and [-9, 9]
        // for the other operand. This enables adding two such products without
        // any risk of overflowing int8, and thirty-two such products without
        // overflowing int16. This equates to 512 levels of depth before
        // horizontally adding these int16x8 accumulators into the final int32x4
        // accumulators.
        //
        // Register layout (ignoring the v12--v15 temporary 8-bit accumulators).
        // Since we do not have enough registers to store all Lhs values and Rhs
        // values, we reuse the same registers v0--v7 to load subsequent Lhs
        // values and v8-v11 to subsequent Rhs values.
        //
        //                            +-----+-----+
        //                            | v8  | v9  |
        //                       Rhs  +-----+-----+
        //                            | v10 | v11 |
        //                            +-----+-----+
        //                            | v8  | v9  |
        //                            +-----+-----+
        //                            | v10 | v11 |
        //       Lhs                  +-----+-----+
        //  +----+----+----+----+ - - +-----+-----+     +--------+--------+
        //  | v0 | v4 | v0 | v4 |     | v16 | v17 |     | v24.4s | v25.4s |
        //  | v1 | v5 | v1 | v5 |     | v18 | v19 |  -> | v26.4s | v27.4s |
        //  | v2 | v6 | v2 | v6 |     | v20 | v21 |     | v28.4s | v29.4s |
        //  | v3 | v7 | v3 | v7 |     | v22 | v23 |     | v30.4s | v31.4s |
        //  +----+----+----+----+ - - +-----+-----+     +--------+--------+
        //
        //                         Local Accumulator    Global Accumulator
        //

        // Clear global accumulators (interleaved with initial operand loads).
        "dup v24.4s, wzr\n"
        "ld1 {v8.16b}, [%[rhs_ptr]], #16\n"
        "dup v25.4s, wzr\n"
        "ld1 {v9.16b}, [%[rhs_ptr]], #16\n"
        "dup v26.4s, wzr\n"
        "ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
        "dup v27.4s, wzr\n"
        "ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
        "dup v28.4s, wzr\n"
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "dup v29.4s, wzr\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "dup v30.4s, wzr\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "dup v31.4s, wzr\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"

        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"

        //"loop_%=:\n"
        GEMMLOWP_LABEL_512_DEPTH_LOOP
        ":\n"
        // Clear local accumulators; x1 counts the depth remaining in this
        // 512-level chunk.
        "dup v16.8h, wzr\n"
        "dup v17.8h, wzr\n"
        "dup v18.8h, wzr\n"
        "mov x1, #512\n"
        "dup v19.8h, wzr\n"
        "dup v20.8h, wzr\n"
        "dup v21.8h, wzr\n"
        "dup v22.8h, wzr\n"
        "dup v23.8h, wzr\n"

        // Inner loop: 32 levels of depth per iteration. mul/mla produce
        // int8x16 partial products (safe from overflow thanks to the
        // 4.25-bit operand ranges); sadalp pairwise-accumulates them into
        // the int16x8 local accumulators.
        //"loop_%=:\n"
        GEMMLOWP_LABEL_32_DEPTH_LOOP
        ":\n"
        "mul v12.16b, v0.16b, v8.16b\n"
        "mul v13.16b, v0.16b, v10.16b\n"
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "mul v14.16b, v2.16b, v8.16b\n"
        "mul v15.16b, v2.16b, v10.16b\n"

        "mla v12.16b, v1.16b, v9.16b\n"
        "mla v13.16b, v1.16b, v11.16b\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "mla v14.16b, v3.16b, v9.16b\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "mla v15.16b, v3.16b, v11.16b\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"

        "sadalp v16.8h, v12.16b\n"
        "sadalp v17.8h, v13.16b\n"
        "subs %w[depth], %w[depth], #32\n"
        "sadalp v18.8h, v14.16b\n"
        "sadalp v19.8h, v15.16b\n"
        "subs x1, x1, #32\n"

        "mul v12.16b, v4.16b, v8.16b\n"
        "mul v13.16b, v4.16b, v10.16b\n"
        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
        "mul v14.16b, v6.16b, v8.16b\n"
        "ld1 {v8.16b}, [%[rhs_ptr]], #16\n"
        "mul v15.16b, v6.16b, v10.16b\n"

        "mla v12.16b, v5.16b, v9.16b\n"
        "mla v13.16b, v5.16b, v11.16b\n"
        "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
        "mla v14.16b, v7.16b, v9.16b\n"
        "ld1 {v9.16b}, [%[rhs_ptr]], #16\n"
        "mla v15.16b, v7.16b, v11.16b\n"
        "ld1 {v10.16b}, [%[rhs_ptr]], #16\n"

        "sadalp v20.8h, v12.16b\n"
        "ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
        "sadalp v21.8h, v13.16b\n"
        "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
        "sadalp v22.8h, v14.16b\n"
        "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
        "sadalp v23.8h, v15.16b\n"

        // Second half of the 32-level iteration (same pattern, next 16
        // levels of depth).
        "mul v12.16b, v0.16b, v8.16b\n"
        "mul v13.16b, v0.16b, v10.16b\n"
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "mul v14.16b, v2.16b, v8.16b\n"
        "mul v15.16b, v2.16b, v10.16b\n"

        "mla v12.16b, v1.16b, v9.16b\n"
        "mla v13.16b, v1.16b, v11.16b\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "mla v14.16b, v3.16b, v9.16b\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "mla v15.16b, v3.16b, v11.16b\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"

        "sadalp v16.8h, v12.16b\n"
        "sadalp v17.8h, v13.16b\n"
        "sadalp v18.8h, v14.16b\n"
        "sadalp v19.8h, v15.16b\n"

        "mul v12.16b, v4.16b, v8.16b\n"
        "mul v13.16b, v4.16b, v10.16b\n"
        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
        "mul v14.16b, v6.16b, v8.16b\n"
        "ld1 {v8.16b}, [%[rhs_ptr]], #16\n"
        "mul v15.16b, v6.16b, v10.16b\n"

        "mla v12.16b, v5.16b, v9.16b\n"
        "mla v13.16b, v5.16b, v11.16b\n"
        "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
        "mla v14.16b, v7.16b, v9.16b\n"
        "ld1 {v9.16b}, [%[rhs_ptr]], #16\n"
        "mla v15.16b, v7.16b, v11.16b\n"
        "ld1 {v10.16b}, [%[rhs_ptr]], #16\n"

        "sadalp v20.8h, v12.16b\n"
        "ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
        "sadalp v21.8h, v13.16b\n"
        "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
        "sadalp v22.8h, v14.16b\n"
        "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
        "sadalp v23.8h, v15.16b\n"

        // Flags here are from "subs x1" above: exit the inner loop when the
        // 512-level chunk is complete...
        "beq " GEMMLOWP_LABEL_32_DEPTH_AFTER_LOOP
        "f\n"

        // ...otherwise keep looping while depth remains.
        "cmp %w[depth], #0\n"
        "bne " GEMMLOWP_LABEL_32_DEPTH_LOOP "b\n"

        GEMMLOWP_LABEL_32_DEPTH_AFTER_LOOP
        ":\n"

        // Pairwise add 16-bit local accums to 32-bit global accums.
        "sadalp v24.4s, v16.8h\n"
        "sadalp v25.4s, v17.8h\n"
        "sadalp v26.4s, v18.8h\n"
        "sadalp v27.4s, v19.8h\n"
        "sadalp v28.4s, v20.8h\n"
        "sadalp v29.4s, v21.8h\n"
        "sadalp v30.4s, v22.8h\n"
        "sadalp v31.4s, v23.8h\n"

        // NOTE(review): this branch depends on flags set before the sadalp
        // sequence above (sadalp does not affect NZCV). On both paths that
        // reach this point -- the beq (x1 == 0) and the fall-through
        // (depth == 0 after the cmp) -- the Z flag is set, so this bne
        // appears never to be taken and [outer_depth] is never decremented
        // or read in the asm. TODO: verify against upstream gemmlowp
        // whether a "subs %w[outer_depth], %w[outer_depth], #1" belongs
        // before this branch, and confirm behavior when depth > 512.
        "bne " GEMMLOWP_LABEL_512_DEPTH_LOOP
        "b\n"

        // Reduce aggregators horizontally.
        "addp v0.4s, v24.4s, v26.4s\n"
        "addp v1.4s, v28.4s, v30.4s\n"
        "addp v2.4s, v25.4s, v27.4s\n"
        "addp v3.4s, v29.4s, v31.4s\n"

        "addp v4.4s, v0.4s, v1.4s\n"
        "addp v5.4s, v2.4s, v3.4s\n"

        // Load accumulators from memory.
        "mov x0, %[dst_ptr]\n"
        "ld1 {v6.16b}, [x0], #16\n"
        "ld1 {v7.16b}, [x0], #16\n"

        // Add to the accumulators loaded from memory.
        "add v6.4s, v6.4s, v4.4s\n"
        "add v7.4s, v7.4s, v5.4s\n"

        // Store accumulators back to memory.
        "mov x0, %[dst_ptr]\n"
        "st1 {v6.16b}, [x0], #16\n"
        "st1 {v7.16b}, [x0], #16\n"

        :
        // Outputs.
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [dst_ptr] "+r"(dst_ptr), [depth] "+r"(depth),
        [outer_depth] "+r"(outer_depth)
        :
        // Inputs.

        :
        // Clobbers.
        "cc", "memory",
        // We use these NEON registers
        "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
        "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
        "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
        "v31", "x0", "x1");
  }
};
3371
3372 SET_425BIT_RANGES(NEON_64bit_GEMM_Int425Operands);
3373
3374 #ifdef __ARM_FEATURE_DOTPROD
3375 // Kernels utilizing the Armv8.2 Dot Product extension.
3376 //
3377 // The dot product instructions work by taking 4 consecutive 8-bit depth
3378 // values from each operand, multiplying the 4 pairs together and
3379 // accumulating all the results into the corresponding 32-bit accumulator
3380 // lane. As such, the operation is identical to a 32-bit instruction (like
3381 // FMLA used in SGEMM), except that 4 depth values are processed at a time
3382 // instead of 1.
3383
3384 // Thus, this first kernel is a carbon copy of
3385 // "NEON_64bit_GEMM_Float32_WithScalar_A57" (which should provide good
3386 // performance for most processors) below with the opcode (fmla -> udot) and
3387 // types (float32 -> uint8/uint32) changed.
3388 //
3389 // A signed version of this kernel could be produced by replacing "udot"
3390 // with "sdot" - performance should be identical to this udot kernel.
struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct {
  typedef std::uint8_t OperandType;
  typedef std::uint32_t AccumulatorType;
  // LHS: 3 cells of 4 rows (12 rows total); RHS: 2 cells of 4 columns
  // (8 columns total); both 4 levels deep per cell. The result is a 12x8
  // block of uint32 accumulators, held in the 24 registers v8--v31.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> >
      Format;
  // Multiply-accumulates the 12-row LHS block against the 8-column RHS block
  // over `depth` levels directly into the uint32 accumulators at accum_ptr.
  // The udot instructions are emitted as hand-encoded ".word" constants
  // (with the mnemonic in a trailing comment), presumably so the file still
  // assembles with toolchains that lack dot-product support -- confirm if
  // changing them.
  // The loop decrements depth by 4 and exits on zero, so `depth` must be a
  // positive multiple of 4.
  // GEMMLOWP_LABEL_LOOP is a label-name macro defined earlier in this file.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators: for each of the 8 result columns, load the
        // three 4-row cells (e.g. v8/v16/v24 hold column 0).
        "mov x0, %[accum_ptr]\n"
        "ld1 {v8.4s}, [x0], #16\n"
        "ld1 {v16.4s}, [x0], #16\n"
        "ld1 {v24.4s}, [x0], #16\n"
        "ld1 {v9.4s}, [x0], #16\n"
        "ld1 {v17.4s}, [x0], #16\n"
        "ld1 {v25.4s}, [x0], #16\n"
        "ld1 {v10.4s}, [x0], #16\n"
        "ld1 {v18.4s}, [x0], #16\n"
        "ld1 {v26.4s}, [x0], #16\n"
        "ld1 {v11.4s}, [x0], #16\n"
        "ld1 {v19.4s}, [x0], #16\n"
        "ld1 {v27.4s}, [x0], #16\n"
        "ld1 {v12.4s}, [x0], #16\n"
        "ld1 {v20.4s}, [x0], #16\n"
        "ld1 {v28.4s}, [x0], #16\n"
        "ld1 {v13.4s}, [x0], #16\n"
        "ld1 {v21.4s}, [x0], #16\n"
        "ld1 {v29.4s}, [x0], #16\n"
        "ld1 {v14.4s}, [x0], #16\n"
        "ld1 {v22.4s}, [x0], #16\n"
        "ld1 {v30.4s}, [x0], #16\n"
        "ld1 {v15.4s}, [x0], #16\n"
        "ld1 {v23.4s}, [x0], #16\n"
        "ld1 {v31.4s}, [x0], #16\n"

        // The start of the loop assumes first Rhs cell is already loaded, so
        // do it here for first iteration.
        "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"

        // And the same for the first Lhs cell.
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Start the MACs at the head of the loop - 1st cell from each side
        // already loaded. Each udot multiplies 4 consecutive uint8 depth
        // values from an Lhs cell by one 4-byte Rhs lane and accumulates
        // into a column of 4 uint32 results.
        ".word 0x6f80e048  // udot v8.4s, v2.16b, v0.4b[0]\n"
        ".word 0x6fa0e049  // udot v9.4s, v2.16b, v0.4b[1]\n"
        "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"  // Load second Rhs cell.
        ".word 0x6f80e84a  // udot v10.4s, v2.16b, v0.4b[2]\n"
        ".word 0x6fa0e84b  // udot v11.4s, v2.16b, v0.4b[3]\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"  // Load second Lhs cell.
        ".word 0x6f81e04c  // udot v12.4s, v2.16b, v1.4b[0]\n"
        ".word 0x6fa1e04d  // udot v13.4s, v2.16b, v1.4b[1]\n"
        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"  // Load third Lhs cell.
        ".word 0x6f81e84e  // udot v14.4s, v2.16b, v1.4b[2]\n"
        ".word 0x6fa1e84f  // udot v15.4s, v2.16b, v1.4b[3]\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"  // Done with first Lhs cell - load
        // for the next iteration early.
        ".word 0x6f80e070  // udot v16.4s, v3.16b, v0.4b[0]\n"
        ".word 0x6fa0e071  // udot v17.4s, v3.16b, v0.4b[1]\n"
        ".word 0x6f80e872  // udot v18.4s, v3.16b, v0.4b[2]\n"
        ".word 0x6fa0e873  // udot v19.4s, v3.16b, v0.4b[3]\n"
        ".word 0x6f81e074  // udot v20.4s, v3.16b, v1.4b[0]\n"
        ".word 0x6fa1e075  // udot v21.4s, v3.16b, v1.4b[1]\n"
        ".word 0x6f81e876  // udot v22.4s, v3.16b, v1.4b[2]\n"
        ".word 0x6fa1e877  // udot v23.4s, v3.16b, v1.4b[3]\n"
        ".word 0x6f80e098  // udot v24.4s, v4.16b, v0.4b[0]\n"
        ".word 0x6fa0e099  // udot v25.4s, v4.16b, v0.4b[1]\n"
        ".word 0x6f80e89a  // udot v26.4s, v4.16b, v0.4b[2]\n"
        ".word 0x6fa0e89b  // udot v27.4s, v4.16b, v0.4b[3]\n"
        "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"  // Done with the first Rhs cell -
        // load for the next iteration early.
        ".word 0x6f81e09c  // udot v28.4s, v4.16b, v1.4b[0]\n"
        ".word 0x6fa1e09d  // udot v29.4s, v4.16b, v1.4b[1]\n"

        // Loop. Decrement loop index (depth) by 4 as udot processes 4
        // depth values.
        "subs %w[depth], %w[depth], #4\n"
        ".word 0x6f81e89e  // udot v30.4s, v4.16b, v1.4b[2]\n"
        ".word 0x6fa1e89f  // udot v31.4s, v4.16b, v1.4b[3]\n"

        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Store accumulators (same column-interleaved order as the loads).
        "mov x0, %[accum_ptr]\n"
        "st1 {v8.16b}, [x0], #16\n"
        "st1 {v16.16b}, [x0], #16\n"
        "st1 {v24.16b}, [x0], #16\n"
        "st1 {v9.16b}, [x0], #16\n"
        "st1 {v17.16b}, [x0], #16\n"
        "st1 {v25.16b}, [x0], #16\n"
        "st1 {v10.16b}, [x0], #16\n"
        "st1 {v18.16b}, [x0], #16\n"
        "st1 {v26.16b}, [x0], #16\n"
        "st1 {v11.16b}, [x0], #16\n"
        "st1 {v19.16b}, [x0], #16\n"
        "st1 {v27.16b}, [x0], #16\n"
        "st1 {v12.16b}, [x0], #16\n"
        "st1 {v20.16b}, [x0], #16\n"
        "st1 {v28.16b}, [x0], #16\n"
        "st1 {v13.16b}, [x0], #16\n"
        "st1 {v21.16b}, [x0], #16\n"
        "st1 {v29.16b}, [x0], #16\n"
        "st1 {v14.16b}, [x0], #16\n"
        "st1 {v22.16b}, [x0], #16\n"
        "st1 {v30.16b}, [x0], #16\n"
        "st1 {v15.16b}, [x0], #16\n"
        "st1 {v23.16b}, [x0], #16\n"
        "st1 {v31.16b}, [x0], #16\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
        "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
        "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
        "v28", "v29", "v30", "v31");
  }
};
3517
// As above, except tuned for Cortex-A55r1.
//
// Similarly, this is a clone of NEON_64bit_GEMM_Float32_WithScalar_A55r1
// with the names changed.
struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1 {
  typedef std::uint8_t OperandType;
  typedef std::uint32_t AccumulatorType;
  // 3 LHS cells x 2 RHS cells, each 4x4 WidthMajor uint8: a 12x8 block of
  // uint32 accumulators, held across v8..v31.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> >
      Format;
  // Accumulates lhs*rhs into accum_ptr over `depth` levels. Each loop
  // iteration consumes 4 levels of depth (each udot handles 4 bytes of
  // depth), so `depth` must be a positive multiple of 4.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators. x0 is a scratch copy of accum_ptr, leaving
        // %[accum_ptr] intact for the store phase at the end.
        "mov x0, %[accum_ptr]\n"
        "ld1 {v8.4s}, [x0], #16\n"
        "ld1 {v16.4s}, [x0], #16\n"
        "ld1 {v24.4s}, [x0], #16\n"
        "ld1 {v9.4s}, [x0], #16\n"
        "ld1 {v17.4s}, [x0], #16\n"
        "ld1 {v25.4s}, [x0], #16\n"
        "ld1 {v10.4s}, [x0], #16\n"
        "ld1 {v18.4s}, [x0], #16\n"
        "ld1 {v26.4s}, [x0], #16\n"
        "ld1 {v11.4s}, [x0], #16\n"
        "ld1 {v19.4s}, [x0], #16\n"
        "ld1 {v27.4s}, [x0], #16\n"
        "ld1 {v12.4s}, [x0], #16\n"
        "ld1 {v20.4s}, [x0], #16\n"
        "ld1 {v28.4s}, [x0], #16\n"
        "ld1 {v13.4s}, [x0], #16\n"
        "ld1 {v21.4s}, [x0], #16\n"
        "ld1 {v29.4s}, [x0], #16\n"
        "ld1 {v14.4s}, [x0], #16\n"
        "ld1 {v22.4s}, [x0], #16\n"
        "ld1 {v30.4s}, [x0], #16\n"
        "ld1 {v15.4s}, [x0], #16\n"
        "ld1 {v23.4s}, [x0], #16\n"
        "ld1 {v31.4s}, [x0], #16\n"

        // For details on how this kernel works, see the Float32 kernel below.

        // Preload for the first iteration: v0 is "half loaded" (low half in
        // d0, high half staged in x18 for a later ins), v2/v3 fully loaded.
        // NOTE(review): x18 is the platform-reserved register on some ABIs
        // (e.g. Apple arm64, Windows) — confirm this kernel is only built
        // for targets where x18 is usable. It is in the clobber list.
        "ldr d0, [%[rhs_ptr]]\n"
        "ldr x18, [%[rhs_ptr], #8]\n"

        "ldr q2, [%[lhs_ptr]]\n"
        "ldr q3, [%[lhs_ptr], #16]\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"

        // The .word directives are hand-encoded udot instructions (the
        // intended mnemonic is in each trailing asm comment), for
        // assemblers that do not accept the dotprod mnemonics.
        ".word 0x6f80e048 // udot v8.4s, v2.16b, v0.4b[0]\n"
        "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1
        ".word 0x6fa0e049 // udot v9.4s, v2.16b, v0.4b[1]\n"
        "ins v0.d[1], x18\n" // Finish loading v0
        ".word 0x6f80e070 // udot v16.4s, v3.16b, v0.4b[0]\n" // out of
                                                              // sequence -
                                                              // used to
                                                              // reduce
                                                              // load/use
                                                              // pressure.
        "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register
        ".word 0x6fa0e071 // udot v17.4s, v3.16b, v0.4b[1]\n" // out of
                                                              // sequence -
                                                              // used to
                                                              // reduce
                                                              // load/use
                                                              // pressure.
        "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment
                                            // pointer.
        ".word 0x6f80e84a // udot v10.4s, v2.16b, v0.4b[2]\n"
        "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4
        ".word 0x6fa0e84b // udot v11.4s, v2.16b, v0.4b[3]\n"
        "ins v1.d[1], x18\n" // Finish loading v1
        ".word 0x6f81e04c // udot v12.4s, v2.16b, v1.4b[0]\n"
        "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register
        ".word 0x6fa1e04d // udot v13.4s, v2.16b, v1.4b[1]\n"
        "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment
                                            // pointer.
        ".word 0x6f81e84e // udot v14.4s, v2.16b, v1.4b[2]\n"

        ".word 0x6fa1e84f // udot v15.4s, v2.16b, v1.4b[3]\n"
        "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time)
        ".word 0x6f80e872 // udot v18.4s, v3.16b, v0.4b[2]\n"
        "ins v4.d[1], x18\n" // Finish loading v4
        ".word 0x6fa0e873 // udot v19.4s, v3.16b, v0.4b[3]\n"
        "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register
        ".word 0x6f81e074 // udot v20.4s, v3.16b, v1.4b[0]\n"
        "subs %w[depth], %w[depth], #4\n" // Consumed 4 levels of depth.
        ".word 0x6fa1e075 // udot v21.4s, v3.16b, v1.4b[1]\n"

        ".word 0x6f81e876 // udot v22.4s, v3.16b, v1.4b[2]\n"

        ".word 0x6fa1e877 // udot v23.4s, v3.16b, v1.4b[3]\n"
        "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time)
        ".word 0x6f80e098 // udot v24.4s, v4.16b, v0.4b[0]\n"
        "ins v2.d[1], x18\n" // Finish loading next v2
        ".word 0x6fa0e099 // udot v25.4s, v4.16b, v0.4b[1]\n"
        "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register
        ".word 0x6f80e89a // udot v26.4s, v4.16b, v0.4b[2]\n"

        ".word 0x6fa0e89b // udot v27.4s, v4.16b, v0.4b[3]\n"
        "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time)
        ".word 0x6f81e09c // udot v28.4s, v4.16b, v1.4b[0]\n"
        "ins v3.d[1], x18\n" // Finish loading next v3
        ".word 0x6fa1e09d // udot v29.4s, v4.16b, v1.4b[1]\n"
        "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register
        ".word 0x6f81e89e // udot v30.4s, v4.16b, v1.4b[2]\n"

        ".word 0x6fa1e89f // udot v31.4s, v4.16b, v1.4b[3]\n"
        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Store accumulators, in the same interleaved order as the loads.
        "mov x0, %[accum_ptr]\n"
        "st1 {v8.4s}, [x0], #16\n"
        "st1 {v16.4s}, [x0], #16\n"
        "st1 {v24.4s}, [x0], #16\n"
        "st1 {v9.4s}, [x0], #16\n"
        "st1 {v17.4s}, [x0], #16\n"
        "st1 {v25.4s}, [x0], #16\n"
        "st1 {v10.4s}, [x0], #16\n"
        "st1 {v18.4s}, [x0], #16\n"
        "st1 {v26.4s}, [x0], #16\n"
        "st1 {v11.4s}, [x0], #16\n"
        "st1 {v19.4s}, [x0], #16\n"
        "st1 {v27.4s}, [x0], #16\n"
        "st1 {v12.4s}, [x0], #16\n"
        "st1 {v20.4s}, [x0], #16\n"
        "st1 {v28.4s}, [x0], #16\n"
        "st1 {v13.4s}, [x0], #16\n"
        "st1 {v21.4s}, [x0], #16\n"
        "st1 {v29.4s}, [x0], #16\n"
        "st1 {v14.4s}, [x0], #16\n"
        "st1 {v22.4s}, [x0], #16\n"
        "st1 {v30.4s}, [x0], #16\n"
        "st1 {v15.4s}, [x0], #16\n"
        "st1 {v23.4s}, [x0], #16\n"
        "st1 {v31.4s}, [x0], #16\n"
        : // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        : // inputs
        [accum_ptr] "r"(accum_ptr)
        : // clobbers
        "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
        "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
        "v27", "v28", "v29", "v30", "v31");
  }
};
3670 #endif // __ARM_FEATURE_DOTPROD
3671
// We don't actually use int32*int32 in production. This is just an
// experiment to help dissociate the effect of integer-vs-float, from the
// effect of operands width.
struct NEON_64bit_GEMM_Int32_WithScalar {
  typedef std::int32_t OperandType;
  typedef std::int32_t AccumulatorType;
  // 3 LHS cells x 2 RHS cells of 4x1 DepthMajor int32 each: a 12x8 block
  // of int32 accumulators, held across v8..v31.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
      Format;
  // Accumulates lhs*rhs into accum_ptr over `depth` levels, consuming one
  // level of depth per loop iteration; `depth` must be >= 1.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators. x0 is a scratch copy of accum_ptr, leaving
        // %[accum_ptr] intact for the store phase at the end.
        "mov x0, %[accum_ptr]\n"
        "ld1 {v8.16b}, [x0], #16\n"
        "ld1 {v16.16b}, [x0], #16\n"
        "ld1 {v24.16b}, [x0], #16\n"
        "ld1 {v9.16b}, [x0], #16\n"
        "ld1 {v17.16b}, [x0], #16\n"
        "ld1 {v25.16b}, [x0], #16\n"
        "ld1 {v10.16b}, [x0], #16\n"
        "ld1 {v18.16b}, [x0], #16\n"
        "ld1 {v26.16b}, [x0], #16\n"
        "ld1 {v11.16b}, [x0], #16\n"
        "ld1 {v19.16b}, [x0], #16\n"
        "ld1 {v27.16b}, [x0], #16\n"
        "ld1 {v12.16b}, [x0], #16\n"
        "ld1 {v20.16b}, [x0], #16\n"
        "ld1 {v28.16b}, [x0], #16\n"
        "ld1 {v13.16b}, [x0], #16\n"
        "ld1 {v21.16b}, [x0], #16\n"
        "ld1 {v29.16b}, [x0], #16\n"
        "ld1 {v14.16b}, [x0], #16\n"
        "ld1 {v22.16b}, [x0], #16\n"
        "ld1 {v30.16b}, [x0], #16\n"
        "ld1 {v15.16b}, [x0], #16\n"
        "ld1 {v23.16b}, [x0], #16\n"
        "ld1 {v31.16b}, [x0], #16\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Load 2 Rhs cell of size 1x4 each
        "ld1 {v0.4s}, [%[rhs_ptr]], #16\n"
        "ld1 {v1.4s}, [%[rhs_ptr]], #16\n"

        // Load 3 Lhs cells of size 4x1 each
        "ld1 {v2.4s}, [%[lhs_ptr]], #16\n"
        "ld1 {v3.4s}, [%[lhs_ptr]], #16\n"
        "ld1 {v4.4s}, [%[lhs_ptr]], #16\n"

        // Multiply-accumulate: each mla multiplies a whole 4x1 Lhs cell by
        // one scalar lane of an Rhs cell.
        "mla v8.4s, v2.4s, v0.s[0]\n"
        "mla v9.4s, v2.4s, v0.s[1]\n"
        "mla v10.4s, v2.4s, v0.s[2]\n"
        "mla v11.4s, v2.4s, v0.s[3]\n"
        "mla v12.4s, v2.4s, v1.s[0]\n"
        "mla v13.4s, v2.4s, v1.s[1]\n"
        "mla v14.4s, v2.4s, v1.s[2]\n"
        "mla v15.4s, v2.4s, v1.s[3]\n"
        "mla v16.4s, v3.4s, v0.s[0]\n"
        "mla v17.4s, v3.4s, v0.s[1]\n"
        "mla v18.4s, v3.4s, v0.s[2]\n"
        "mla v19.4s, v3.4s, v0.s[3]\n"
        "mla v20.4s, v3.4s, v1.s[0]\n"
        "mla v21.4s, v3.4s, v1.s[1]\n"
        "mla v22.4s, v3.4s, v1.s[2]\n"
        "mla v23.4s, v3.4s, v1.s[3]\n"
        "mla v24.4s, v4.4s, v0.s[0]\n"
        "mla v25.4s, v4.4s, v0.s[1]\n"
        "mla v26.4s, v4.4s, v0.s[2]\n"
        "mla v27.4s, v4.4s, v0.s[3]\n"
        "mla v28.4s, v4.4s, v1.s[0]\n"
        "mla v29.4s, v4.4s, v1.s[1]\n"
        "mla v30.4s, v4.4s, v1.s[2]\n"
        "mla v31.4s, v4.4s, v1.s[3]\n"

        // Loop. Decrement loop index (depth) by 1, since we just handled 1
        // level of depth.
        "subs %w[depth], %w[depth], #1\n"
        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Store accumulators, in the same interleaved order as the loads.
        "mov x0, %[accum_ptr]\n"
        "st1 {v8.16b}, [x0], #16\n"
        "st1 {v16.16b}, [x0], #16\n"
        "st1 {v24.16b}, [x0], #16\n"
        "st1 {v9.16b}, [x0], #16\n"
        "st1 {v17.16b}, [x0], #16\n"
        "st1 {v25.16b}, [x0], #16\n"
        "st1 {v10.16b}, [x0], #16\n"
        "st1 {v18.16b}, [x0], #16\n"
        "st1 {v26.16b}, [x0], #16\n"
        "st1 {v11.16b}, [x0], #16\n"
        "st1 {v19.16b}, [x0], #16\n"
        "st1 {v27.16b}, [x0], #16\n"
        "st1 {v12.16b}, [x0], #16\n"
        "st1 {v20.16b}, [x0], #16\n"
        "st1 {v28.16b}, [x0], #16\n"
        "st1 {v13.16b}, [x0], #16\n"
        "st1 {v21.16b}, [x0], #16\n"
        "st1 {v29.16b}, [x0], #16\n"
        "st1 {v14.16b}, [x0], #16\n"
        "st1 {v22.16b}, [x0], #16\n"
        "st1 {v30.16b}, [x0], #16\n"
        "st1 {v15.16b}, [x0], #16\n"
        "st1 {v23.16b}, [x0], #16\n"
        "st1 {v31.16b}, [x0], #16\n"
        : // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        : // inputs
        [accum_ptr] "r"(accum_ptr)
        : // clobbers
        "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
        "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
        "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
        "v28", "v29", "v30", "v31");
  }
};
3794
// Not very efficient kernel, just an experiment to see what we can do
// without using NEON multiply-with-scalar instructions.
struct NEON_64bit_GEMM_Float32_WithVectorDuplicatingScalar {
  typedef float OperandType;
  typedef float AccumulatorType;
  // 3 LHS cells x 2 RHS cells of 4x1 DepthMajor float each: a 12x8 block
  // of float accumulators, held across v8..v31.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
      Format;
  // Accumulates lhs*rhs into accum_ptr over `depth` levels, consuming one
  // level of depth per loop iteration; `depth` must be >= 1. Instead of
  // fmla-by-scalar-lane, each Rhs lane is first broadcast with dup and
  // then used in a plain vector fmla.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators. x0 is a scratch copy of accum_ptr, leaving
        // %[accum_ptr] intact for the store phase at the end.
        "mov x0, %[accum_ptr]\n"
        "ld1 {v8.16b}, [x0], #16\n"
        "ld1 {v16.16b}, [x0], #16\n"
        "ld1 {v24.16b}, [x0], #16\n"
        "ld1 {v9.16b}, [x0], #16\n"
        "ld1 {v17.16b}, [x0], #16\n"
        "ld1 {v25.16b}, [x0], #16\n"
        "ld1 {v10.16b}, [x0], #16\n"
        "ld1 {v18.16b}, [x0], #16\n"
        "ld1 {v26.16b}, [x0], #16\n"
        "ld1 {v11.16b}, [x0], #16\n"
        "ld1 {v19.16b}, [x0], #16\n"
        "ld1 {v27.16b}, [x0], #16\n"
        "ld1 {v12.16b}, [x0], #16\n"
        "ld1 {v20.16b}, [x0], #16\n"
        "ld1 {v28.16b}, [x0], #16\n"
        "ld1 {v13.16b}, [x0], #16\n"
        "ld1 {v21.16b}, [x0], #16\n"
        "ld1 {v29.16b}, [x0], #16\n"
        "ld1 {v14.16b}, [x0], #16\n"
        "ld1 {v22.16b}, [x0], #16\n"
        "ld1 {v30.16b}, [x0], #16\n"
        "ld1 {v15.16b}, [x0], #16\n"
        "ld1 {v23.16b}, [x0], #16\n"
        "ld1 {v31.16b}, [x0], #16\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Load 2 Rhs cell of size 1x4 each
        "ld1 {v5.4s}, [%[rhs_ptr]], #16\n"
        "ld1 {v6.4s}, [%[rhs_ptr]], #16\n"

        // Load 3 Lhs cells of size 4x1 each
        "ld1 {v2.4s}, [%[lhs_ptr]], #16\n"
        "ld1 {v3.4s}, [%[lhs_ptr]], #16\n"
        "ld1 {v4.4s}, [%[lhs_ptr]], #16\n"

        // Multiply-accumulate. v0/v1 are temporaries holding a broadcast
        // Rhs lane; two lanes are processed per dup/dup/fmla*6 group.
        "dup v0.4s, v5.s[0]\n"
        "dup v1.4s, v5.s[1]\n"
        "fmla v8.4s, v2.4s, v0.4s\n"
        "fmla v16.4s, v3.4s, v0.4s\n"
        "fmla v24.4s, v4.4s, v0.4s\n"
        "fmla v9.4s, v2.4s, v1.4s\n"
        "fmla v17.4s, v3.4s, v1.4s\n"
        "fmla v25.4s, v4.4s, v1.4s\n"
        "dup v0.4s, v5.s[2]\n"
        "dup v1.4s, v5.s[3]\n"
        "fmla v10.4s, v2.4s, v0.4s\n"
        "fmla v18.4s, v3.4s, v0.4s\n"
        "fmla v26.4s, v4.4s, v0.4s\n"
        "fmla v11.4s, v2.4s, v1.4s\n"
        "fmla v19.4s, v3.4s, v1.4s\n"
        "fmla v27.4s, v4.4s, v1.4s\n"
        "dup v0.4s, v6.s[0]\n"
        "dup v1.4s, v6.s[1]\n"
        "fmla v12.4s, v2.4s, v0.4s\n"
        "fmla v20.4s, v3.4s, v0.4s\n"
        "fmla v28.4s, v4.4s, v0.4s\n"
        "fmla v13.4s, v2.4s, v1.4s\n"
        "fmla v21.4s, v3.4s, v1.4s\n"
        "fmla v29.4s, v4.4s, v1.4s\n"
        "dup v0.4s, v6.s[2]\n"
        "dup v1.4s, v6.s[3]\n"
        "fmla v14.4s, v2.4s, v0.4s\n"
        "fmla v22.4s, v3.4s, v0.4s\n"
        "fmla v30.4s, v4.4s, v0.4s\n"
        "fmla v15.4s, v2.4s, v1.4s\n"
        "fmla v23.4s, v3.4s, v1.4s\n"
        "fmla v31.4s, v4.4s, v1.4s\n"

        // Loop. Decrement loop index (depth) by 1, since we just handled 1
        // level of depth.
        "subs %w[depth], %w[depth], #1\n"
        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Store accumulators, in the same interleaved order as the loads.
        "mov x0, %[accum_ptr]\n"
        "st1 {v8.16b}, [x0], #16\n"
        "st1 {v16.16b}, [x0], #16\n"
        "st1 {v24.16b}, [x0], #16\n"
        "st1 {v9.16b}, [x0], #16\n"
        "st1 {v17.16b}, [x0], #16\n"
        "st1 {v25.16b}, [x0], #16\n"
        "st1 {v10.16b}, [x0], #16\n"
        "st1 {v18.16b}, [x0], #16\n"
        "st1 {v26.16b}, [x0], #16\n"
        "st1 {v11.16b}, [x0], #16\n"
        "st1 {v19.16b}, [x0], #16\n"
        "st1 {v27.16b}, [x0], #16\n"
        "st1 {v12.16b}, [x0], #16\n"
        "st1 {v20.16b}, [x0], #16\n"
        "st1 {v28.16b}, [x0], #16\n"
        "st1 {v13.16b}, [x0], #16\n"
        "st1 {v21.16b}, [x0], #16\n"
        "st1 {v29.16b}, [x0], #16\n"
        "st1 {v14.16b}, [x0], #16\n"
        "st1 {v22.16b}, [x0], #16\n"
        "st1 {v30.16b}, [x0], #16\n"
        "st1 {v15.16b}, [x0], #16\n"
        "st1 {v23.16b}, [x0], #16\n"
        "st1 {v31.16b}, [x0], #16\n"
        : // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        : // inputs
        [accum_ptr] "r"(accum_ptr)
        : // clobbers
        "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
        "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
        "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
        "v28", "v29", "v30", "v31");
  }
};
3924
// This is the "most natural" kernel, using NEON multiply-with-scalar
// instructions.
struct NEON_64bit_GEMM_Float32_WithScalar {
  typedef float OperandType;
  typedef float AccumulatorType;
  // 3 LHS cells x 2 RHS cells of 4x1 DepthMajor float each: a 12x8 block
  // of float accumulators, held across v8..v31.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
      Format;
  // Accumulates lhs*rhs into accum_ptr over `depth` levels, consuming one
  // level of depth per loop iteration; `depth` must be >= 1.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators. x0 is a scratch copy of accum_ptr, leaving
        // %[accum_ptr] intact for the store phase at the end.
        "mov x0, %[accum_ptr]\n"
        "ld1 {v8.16b}, [x0], #16\n"
        "ld1 {v16.16b}, [x0], #16\n"
        "ld1 {v24.16b}, [x0], #16\n"
        "ld1 {v9.16b}, [x0], #16\n"
        "ld1 {v17.16b}, [x0], #16\n"
        "ld1 {v25.16b}, [x0], #16\n"
        "ld1 {v10.16b}, [x0], #16\n"
        "ld1 {v18.16b}, [x0], #16\n"
        "ld1 {v26.16b}, [x0], #16\n"
        "ld1 {v11.16b}, [x0], #16\n"
        "ld1 {v19.16b}, [x0], #16\n"
        "ld1 {v27.16b}, [x0], #16\n"
        "ld1 {v12.16b}, [x0], #16\n"
        "ld1 {v20.16b}, [x0], #16\n"
        "ld1 {v28.16b}, [x0], #16\n"
        "ld1 {v13.16b}, [x0], #16\n"
        "ld1 {v21.16b}, [x0], #16\n"
        "ld1 {v29.16b}, [x0], #16\n"
        "ld1 {v14.16b}, [x0], #16\n"
        "ld1 {v22.16b}, [x0], #16\n"
        "ld1 {v30.16b}, [x0], #16\n"
        "ld1 {v15.16b}, [x0], #16\n"
        "ld1 {v23.16b}, [x0], #16\n"
        "ld1 {v31.16b}, [x0], #16\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Load 2 Rhs cell of size 1x4 each
        "ld1 {v0.4s}, [%[rhs_ptr]], #16\n"
        "ld1 {v1.4s}, [%[rhs_ptr]], #16\n"

        // Load 3 Lhs cells of size 4x1 each
        "ld1 {v2.4s}, [%[lhs_ptr]], #16\n"
        "ld1 {v3.4s}, [%[lhs_ptr]], #16\n"
        "ld1 {v4.4s}, [%[lhs_ptr]], #16\n"

        // Multiply-accumulate: each fmla multiplies a whole 4x1 Lhs cell
        // by one scalar lane of an Rhs cell.
        "fmla v8.4s, v2.4s, v0.s[0]\n"
        "fmla v9.4s, v2.4s, v0.s[1]\n"
        "fmla v10.4s, v2.4s, v0.s[2]\n"
        "fmla v11.4s, v2.4s, v0.s[3]\n"
        "fmla v12.4s, v2.4s, v1.s[0]\n"
        "fmla v13.4s, v2.4s, v1.s[1]\n"
        "fmla v14.4s, v2.4s, v1.s[2]\n"
        "fmla v15.4s, v2.4s, v1.s[3]\n"
        "fmla v16.4s, v3.4s, v0.s[0]\n"
        "fmla v17.4s, v3.4s, v0.s[1]\n"
        "fmla v18.4s, v3.4s, v0.s[2]\n"
        "fmla v19.4s, v3.4s, v0.s[3]\n"
        "fmla v20.4s, v3.4s, v1.s[0]\n"
        "fmla v21.4s, v3.4s, v1.s[1]\n"
        "fmla v22.4s, v3.4s, v1.s[2]\n"
        "fmla v23.4s, v3.4s, v1.s[3]\n"
        "fmla v24.4s, v4.4s, v0.s[0]\n"
        "fmla v25.4s, v4.4s, v0.s[1]\n"
        "fmla v26.4s, v4.4s, v0.s[2]\n"
        "fmla v27.4s, v4.4s, v0.s[3]\n"
        "fmla v28.4s, v4.4s, v1.s[0]\n"
        "fmla v29.4s, v4.4s, v1.s[1]\n"
        "fmla v30.4s, v4.4s, v1.s[2]\n"
        "fmla v31.4s, v4.4s, v1.s[3]\n"

        // Loop. Decrement loop index (depth) by 1, since we just handled 1
        // level of depth.
        "subs %w[depth], %w[depth], #1\n"
        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Store accumulators, in the same interleaved order as the loads.
        "mov x0, %[accum_ptr]\n"
        "st1 {v8.16b}, [x0], #16\n"
        "st1 {v16.16b}, [x0], #16\n"
        "st1 {v24.16b}, [x0], #16\n"
        "st1 {v9.16b}, [x0], #16\n"
        "st1 {v17.16b}, [x0], #16\n"
        "st1 {v25.16b}, [x0], #16\n"
        "st1 {v10.16b}, [x0], #16\n"
        "st1 {v18.16b}, [x0], #16\n"
        "st1 {v26.16b}, [x0], #16\n"
        "st1 {v11.16b}, [x0], #16\n"
        "st1 {v19.16b}, [x0], #16\n"
        "st1 {v27.16b}, [x0], #16\n"
        "st1 {v12.16b}, [x0], #16\n"
        "st1 {v20.16b}, [x0], #16\n"
        "st1 {v28.16b}, [x0], #16\n"
        "st1 {v13.16b}, [x0], #16\n"
        "st1 {v21.16b}, [x0], #16\n"
        "st1 {v29.16b}, [x0], #16\n"
        "st1 {v14.16b}, [x0], #16\n"
        "st1 {v22.16b}, [x0], #16\n"
        "st1 {v30.16b}, [x0], #16\n"
        "st1 {v15.16b}, [x0], #16\n"
        "st1 {v23.16b}, [x0], #16\n"
        "st1 {v31.16b}, [x0], #16\n"
        : // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        : // inputs
        [accum_ptr] "r"(accum_ptr)
        : // clobbers
        "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
        "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
        "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
        "v28", "v29", "v30", "v31");
  }
};
4046
// Faster kernel contributed by ARM. Tuned for A57.
struct NEON_64bit_GEMM_Float32_WithScalar_A57 {
  typedef float OperandType;
  typedef float AccumulatorType;
  // 3 LHS cells x 2 RHS cells of 4x1 DepthMajor float each: a 12x8 block
  // of float accumulators, held across v8..v31.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
      Format;
  // Same computation as NEON_64bit_GEMM_Float32_WithScalar (one depth
  // level per iteration, depth >= 1), but with loads interleaved among
  // the fmla instructions and hoisted across iterations to hide load
  // latency on A57.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators. x0 is a scratch copy of accum_ptr, leaving
        // %[accum_ptr] intact for the store phase at the end.
        "mov x0, %[accum_ptr]\n"
        "ld1 {v8.16b}, [x0], #16\n"
        "ld1 {v16.16b}, [x0], #16\n"
        "ld1 {v24.16b}, [x0], #16\n"
        "ld1 {v9.16b}, [x0], #16\n"
        "ld1 {v17.16b}, [x0], #16\n"
        "ld1 {v25.16b}, [x0], #16\n"
        "ld1 {v10.16b}, [x0], #16\n"
        "ld1 {v18.16b}, [x0], #16\n"
        "ld1 {v26.16b}, [x0], #16\n"
        "ld1 {v11.16b}, [x0], #16\n"
        "ld1 {v19.16b}, [x0], #16\n"
        "ld1 {v27.16b}, [x0], #16\n"
        "ld1 {v12.16b}, [x0], #16\n"
        "ld1 {v20.16b}, [x0], #16\n"
        "ld1 {v28.16b}, [x0], #16\n"
        "ld1 {v13.16b}, [x0], #16\n"
        "ld1 {v21.16b}, [x0], #16\n"
        "ld1 {v29.16b}, [x0], #16\n"
        "ld1 {v14.16b}, [x0], #16\n"
        "ld1 {v22.16b}, [x0], #16\n"
        "ld1 {v30.16b}, [x0], #16\n"
        "ld1 {v15.16b}, [x0], #16\n"
        "ld1 {v23.16b}, [x0], #16\n"
        "ld1 {v31.16b}, [x0], #16\n"

        // The start of the loop assumes first Rhs cell is already loaded, so
        // do it here for first iteration.
        "ld1 {v0.4s}, [%[rhs_ptr]], #16\n"

        // And the same for the first Lhs cell.
        "ld1 {v2.4s}, [%[lhs_ptr]], #16\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"

        // Start the MACs at the head of the loop - 1st cell from each side
        // already loaded.
        "fmla v8.4s, v2.4s, v0.s[0]\n"
        "fmla v9.4s, v2.4s, v0.s[1]\n"
        "ld1 {v1.4s}, [%[rhs_ptr]], #16\n" // Load second Rhs cell.
        "fmla v10.4s, v2.4s, v0.s[2]\n"
        "fmla v11.4s, v2.4s, v0.s[3]\n"
        "ld1 {v3.4s}, [%[lhs_ptr]], #16\n" // Load second Lhs cell.
        "fmla v12.4s, v2.4s, v1.s[0]\n"
        "fmla v13.4s, v2.4s, v1.s[1]\n"
        "ld1 {v4.4s}, [%[lhs_ptr]], #16\n" // Load third Lhs cell.
        "fmla v14.4s, v2.4s, v1.s[2]\n"
        "fmla v15.4s, v2.4s, v1.s[3]\n"
        "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load
        // for the next iteration early.
        "fmla v16.4s, v3.4s, v0.s[0]\n"
        "fmla v17.4s, v3.4s, v0.s[1]\n"
        "fmla v18.4s, v3.4s, v0.s[2]\n"
        "fmla v19.4s, v3.4s, v0.s[3]\n"
        "fmla v20.4s, v3.4s, v1.s[0]\n"
        "fmla v21.4s, v3.4s, v1.s[1]\n"
        "fmla v22.4s, v3.4s, v1.s[2]\n"
        "fmla v23.4s, v3.4s, v1.s[3]\n"
        "fmla v24.4s, v4.4s, v0.s[0]\n"
        "fmla v25.4s, v4.4s, v0.s[1]\n"
        "fmla v26.4s, v4.4s, v0.s[2]\n"
        "fmla v27.4s, v4.4s, v0.s[3]\n"
        "ld1 {v0.4s}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell -
        // load for the next iteration
        // early.
        "fmla v28.4s, v4.4s, v1.s[0]\n"
        "fmla v29.4s, v4.4s, v1.s[1]\n"
        // Loop. Decrement loop index (depth) by 1, since we just handled
        // 1 level of depth. Do this a bit before the end of the loop for
        // better dispatch on A57.
        "subs %w[depth], %w[depth], #1\n"
        "fmla v30.4s, v4.4s, v1.s[2]\n"
        "fmla v31.4s, v4.4s, v1.s[3]\n"

        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Store accumulators, in the same interleaved order as the loads.
        // Note: on loop exit, one extra Rhs cell (v0) and Lhs cell (v2)
        // have been loaded and are discarded.
        "mov x0, %[accum_ptr]\n"
        "st1 {v8.16b}, [x0], #16\n"
        "st1 {v16.16b}, [x0], #16\n"
        "st1 {v24.16b}, [x0], #16\n"
        "st1 {v9.16b}, [x0], #16\n"
        "st1 {v17.16b}, [x0], #16\n"
        "st1 {v25.16b}, [x0], #16\n"
        "st1 {v10.16b}, [x0], #16\n"
        "st1 {v18.16b}, [x0], #16\n"
        "st1 {v26.16b}, [x0], #16\n"
        "st1 {v11.16b}, [x0], #16\n"
        "st1 {v19.16b}, [x0], #16\n"
        "st1 {v27.16b}, [x0], #16\n"
        "st1 {v12.16b}, [x0], #16\n"
        "st1 {v20.16b}, [x0], #16\n"
        "st1 {v28.16b}, [x0], #16\n"
        "st1 {v13.16b}, [x0], #16\n"
        "st1 {v21.16b}, [x0], #16\n"
        "st1 {v29.16b}, [x0], #16\n"
        "st1 {v14.16b}, [x0], #16\n"
        "st1 {v22.16b}, [x0], #16\n"
        "st1 {v30.16b}, [x0], #16\n"
        "st1 {v15.16b}, [x0], #16\n"
        "st1 {v23.16b}, [x0], #16\n"
        "st1 {v31.16b}, [x0], #16\n"
        : // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        : // inputs
        [accum_ptr] "r"(accum_ptr)
        : // clobbers
        "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
        "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
        "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
        "v28", "v29", "v30", "v31");
  }
};
4175
4176 #ifndef __APPLE__
// Faster kernel contributed by ARM. Tuned for A53.
// (Guarded by #ifndef __APPLE__ above — presumably because this kernel uses
// x18, which is reserved on Apple's arm64 ABI; TODO confirm.)
struct NEON_64bit_GEMM_Float32_WithScalar_A53 {
  typedef float OperandType;
  typedef float AccumulatorType;
  // 3 LHS cells x 2 RHS cells of 4x1 DepthMajor float each: a 12x8 block
  // of float accumulators, held across v8..v31.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
      Format;
  // Same computation as NEON_64bit_GEMM_Float32_WithScalar (one depth
  // level per iteration, depth >= 1), restructured into cycle-exact
  // 4-instruction blocks for the in-order dual-issue A53 pipeline.
  // The nops are intentional scheduling padding — do not remove.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators. x0 is a scratch copy of accum_ptr, leaving
        // %[accum_ptr] intact for the store phase at the end.
        "mov x0, %[accum_ptr]\n"
        "ld1 {v8.16b}, [x0], #16\n"
        "ld1 {v16.16b}, [x0], #16\n"
        "ld1 {v24.16b}, [x0], #16\n"
        "ld1 {v9.16b}, [x0], #16\n"
        "ld1 {v17.16b}, [x0], #16\n"
        "ld1 {v25.16b}, [x0], #16\n"
        "ld1 {v10.16b}, [x0], #16\n"
        "ld1 {v18.16b}, [x0], #16\n"
        "ld1 {v26.16b}, [x0], #16\n"
        "ld1 {v11.16b}, [x0], #16\n"
        "ld1 {v19.16b}, [x0], #16\n"
        "ld1 {v27.16b}, [x0], #16\n"
        "ld1 {v12.16b}, [x0], #16\n"
        "ld1 {v20.16b}, [x0], #16\n"
        "ld1 {v28.16b}, [x0], #16\n"
        "ld1 {v13.16b}, [x0], #16\n"
        "ld1 {v21.16b}, [x0], #16\n"
        "ld1 {v29.16b}, [x0], #16\n"
        "ld1 {v14.16b}, [x0], #16\n"
        "ld1 {v22.16b}, [x0], #16\n"
        "ld1 {v30.16b}, [x0], #16\n"
        "ld1 {v15.16b}, [x0], #16\n"
        "ld1 {v23.16b}, [x0], #16\n"
        "ld1 {v31.16b}, [x0], #16\n"

        // For A53, a very different-looking loop is needed.
        //
        // The main reason for this is that on A53 128-bit loads take two
        // cycles during which no dual issue can occur. Doing two separate
        // 64-bit loads avoids this issue - they each take one cycle and are
        // able to dual issue. Since vector register loads don't dual issue
        // with FMLA, we load half the register as normal and the other half
        // into an integer register. This second half can then be moved into
        // place later with an INS instruction - which will dual issue with a
        // later FP load.
        //
        // For this kernel there are approximately 3 times as many multiplies
        // as loads, so it makes sense to structure the loop into blocks of 4
        // cycles, with 1 dedicated "load cycle" and 3 "multiply cycles" per
        // block. Strictly preserving this structure with NOPs where no load
        // is needed seems to result in higher performance.
        //
        // Choice of x18 to store the upper halves on their way into the
        // vector registers is arbitrary. Added to the clobber list so that
        // the compiler will make it available.
        //
        //
        // At the start of the loop, it is assumed that v0 is "half loaded" -
        // bottom half in place in d0 and the upper half in x18 ready to
        // insert. So set that up here for the first iteration:
        "ldr d0, [%[rhs_ptr]]\n" // Bottom half of first Rhs cell
        "ldr x18, [%[rhs_ptr], #8]\n" // Upper half
        "add %[rhs_ptr], %[rhs_ptr], #16\n" // Separate increment (needed as
                                            // there is no operation to load at
                                            // reg + 8 but then increment reg
                                            // by 16).

        // v2 should be fully loaded - as it's outside the loop proper it's fine
        // to use a 128-bit load here.
        "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" // first Lhs cell

        GEMMLOWP_LABEL_LOOP
        ":\n"

        // First block of four cycles. Multplies all require v2 and v0; v2 is
        // loaded earlier and v0 is half loaded and completed in the load
        // cycle at the start.
        "ldr d1, [%[rhs_ptr]]\n" // "load" cycle - loading bottom half of v1
                                 // (second Rhs cell).
        "ins v0.d[1], x18\n" // "load" cycle - moving the upper half of v0 into
                             // place.
        "fmla v8.4s, v2.4s, v0.s[0]\n" // "fmla" cycle 1 - first multiply.
        "ldr x18, [%[rhs_ptr], #8]\n" // "fmla" cycle 1 - load upper half of v1
                                      // into x18.
        "fmla v9.4s, v2.4s, v0.s[1]\n" // "fmla" cycle 2 - second multiply
        "add %[rhs_ptr], %[rhs_ptr], #16\n" // "fmla" cycle 2 - increment Rhs
                                            // pointer (if needed)
        "fmla v10.4s, v2.4s, v0.s[2]\n" // "fmla" cycle 3 - third multiply. No
                                        // more work to dual issue.

        // Second block. Start loading v3 (second Lhs cell), finish loading v1.
        "ldr d3, [%[lhs_ptr]]\n"
        "ins v1.d[1], x18\n" // v1 ready here.
        "fmla v11.4s, v2.4s, v0.s[3]\n"
        "ldr x18, [%[lhs_ptr], #8]\n"
        "fmla v12.4s, v2.4s, v1.s[0]\n" // First use of v1.
        "add %[lhs_ptr], %[lhs_ptr], #16\n"
        "fmla v13.4s, v2.4s, v1.s[1]\n"

        // Third block. Start loading v4 (third Lhs cell), finish loading v3.
        "ldr d4, [%[lhs_ptr]]\n"
        "ins v3.d[1], x18\n" // v3 ready here.
        "fmla v14.4s, v2.4s, v1.s[2]\n"
        "ldr x18, [%[lhs_ptr], #8]\n"
        "fmla v15.4s, v2.4s, v1.s[3]\n"
        "add %[lhs_ptr], %[lhs_ptr], #16\n"
        "fmla v16.4s, v3.4s, v0.s[0]\n" // First use of v3.

        // Fourth block. v2 (first Lhs cell) is now finished with, so start
        // loading value for next iteration. Finish loading v4.
        "ldr d2, [%[lhs_ptr]]\n"
        "ins v4.d[1], x18\n" // v4 ready here.
        "fmla v17.4s, v3.4s, v0.s[1]\n"
        "ldr x18, [%[lhs_ptr], #8]\n"
        "fmla v18.4s, v3.4s, v0.s[2]\n"
        "add %[lhs_ptr], %[lhs_ptr], #16\n"
        "fmla v19.4s, v3.4s, v0.s[3]\n"

        // Fifth block, finish loading v2. No new load to start as the other
        // registers are all still live.
        "ins v2.d[1], x18\n"
        "fmla v20.4s, v3.4s, v1.s[0]\n"
        "fmla v21.4s, v3.4s, v1.s[1]\n"
        "fmla v22.4s, v3.4s, v1.s[2]\n"

        // Sixth block, nothing to load. 2 nops needed as a single nop would
        // dual issue with the FMLA and break the timing.
        "nop\n"
        "nop\n"
        "fmla v23.4s, v3.4s, v1.s[3]\n"
        "fmla v24.4s, v4.4s, v0.s[0]\n" // First use of v4.
        "fmla v25.4s, v4.4s, v0.s[1]\n"

        // Seventh block, nothing to load. Decrement the loop counter in this
        // block as the last block is very full.
        "nop\n"
        "nop\n"
        "fmla v26.4s, v4.4s, v0.s[2]\n"
        "subs %w[depth], %w[depth], #1\n"
        "fmla v27.4s, v4.4s, v0.s[3]\n"
        "fmla v28.4s, v4.4s, v1.s[0]\n"

        // Eighth block - start loading v0 for next iteration.
        "ldr d0, [%[rhs_ptr]]\n"
        "fmla v29.4s, v4.4s, v1.s[1]\n"
        "ldr x18, [%[rhs_ptr], #8]\n"
        "fmla v30.4s, v4.4s, v1.s[2]\n"
        "add %[rhs_ptr], %[rhs_ptr], #16\n"
        "fmla v31.4s, v4.4s, v1.s[3]\n"

        // Loop branch. This will dual issue in fmla cycle 3 of the 8th block.
        "bne " GEMMLOWP_LABEL_LOOP
        "b\n"

        // Store accumulators, in the same interleaved order as the loads.
        "mov x0, %[accum_ptr]\n"
        "st1 {v8.16b}, [x0], #16\n"
        "st1 {v16.16b}, [x0], #16\n"
        "st1 {v24.16b}, [x0], #16\n"
        "st1 {v9.16b}, [x0], #16\n"
        "st1 {v17.16b}, [x0], #16\n"
        "st1 {v25.16b}, [x0], #16\n"
        "st1 {v10.16b}, [x0], #16\n"
        "st1 {v18.16b}, [x0], #16\n"
        "st1 {v26.16b}, [x0], #16\n"
        "st1 {v11.16b}, [x0], #16\n"
        "st1 {v19.16b}, [x0], #16\n"
        "st1 {v27.16b}, [x0], #16\n"
        "st1 {v12.16b}, [x0], #16\n"
        "st1 {v20.16b}, [x0], #16\n"
        "st1 {v28.16b}, [x0], #16\n"
        "st1 {v13.16b}, [x0], #16\n"
        "st1 {v21.16b}, [x0], #16\n"
        "st1 {v29.16b}, [x0], #16\n"
        "st1 {v14.16b}, [x0], #16\n"
        "st1 {v22.16b}, [x0], #16\n"
        "st1 {v30.16b}, [x0], #16\n"
        "st1 {v15.16b}, [x0], #16\n"
        "st1 {v23.16b}, [x0], #16\n"
        "st1 {v31.16b}, [x0], #16\n"
        : // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        : // inputs
        [accum_ptr] "r"(accum_ptr)
        : // clobbers
        "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
        "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
        "v27", "v28", "v29", "v30", "v31");
  }
};
4372 #endif
4373
4374 // Faster kernel contributed by ARM. Tuned for A55r1.
struct NEON_64bit_GEMM_Float32_WithScalar_A55r1 {
  typedef float OperandType;
  typedef float AccumulatorType;
  // 12x8 destination block: 3 Lhs cells of 4 rows times 2 Rhs cells of
  // 4 columns, one depth level per cell.
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
      Format;
  // Accumulates `depth` rank-1 updates of a 12x8 float block into accum_ptr
  // (accumulators are loaded on entry and stored back on exit). Each loop
  // iteration consumes one depth level: lhs_ptr advances by 48 bytes
  // (12 floats) and rhs_ptr by 32 bytes (8 floats). The loop body executes
  // before the counter test, so `depth` must be >= 1.
  //
  // NOTE(review): the software-pipelined loop preloads the *next* iteration's
  // operands even on the final pass, so it reads up to 32 bytes past the end
  // of the Lhs data and 16 bytes past the end of the Rhs data; the operand
  // buffers must tolerate this over-read — TODO confirm against the
  // benchmark's buffer allocation.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators
        "mov x0, %[accum_ptr]\n"
        "ld1 {v8.4s}, [x0], #16\n"
        "ld1 {v16.4s}, [x0], #16\n"
        "ld1 {v24.4s}, [x0], #16\n"
        "ld1 {v9.4s}, [x0], #16\n"
        "ld1 {v17.4s}, [x0], #16\n"
        "ld1 {v25.4s}, [x0], #16\n"
        "ld1 {v10.4s}, [x0], #16\n"
        "ld1 {v18.4s}, [x0], #16\n"
        "ld1 {v26.4s}, [x0], #16\n"
        "ld1 {v11.4s}, [x0], #16\n"
        "ld1 {v19.4s}, [x0], #16\n"
        "ld1 {v27.4s}, [x0], #16\n"
        "ld1 {v12.4s}, [x0], #16\n"
        "ld1 {v20.4s}, [x0], #16\n"
        "ld1 {v28.4s}, [x0], #16\n"
        "ld1 {v13.4s}, [x0], #16\n"
        "ld1 {v21.4s}, [x0], #16\n"
        "ld1 {v29.4s}, [x0], #16\n"
        "ld1 {v14.4s}, [x0], #16\n"
        "ld1 {v22.4s}, [x0], #16\n"
        "ld1 {v30.4s}, [x0], #16\n"
        "ld1 {v15.4s}, [x0], #16\n"
        "ld1 {v23.4s}, [x0], #16\n"
        "ld1 {v31.4s}, [x0], #16\n"

        // A55r1 requires a hybrid of the A53 and standard approaches.
        //
        // Like A53, this processor prefers 64-bit loads.
        //
        // Unlike A53, it is capable of dual-issuing a 64-bit vector load
        // (or INS) with a FMLA instruction.
        //
        // Therefore we aim to issue an FMLA instruction every cycle.
        // Alongside three FMLAs we can dual issue a (vector) 64-bit load, a
        // scalar 64-bit load and finally an INS to replicate the effect of
        // a single 128-bit load.
        //
        // The loop contains 24 FMLA instructions, and 5 vector registers
        // need to be loaded, consuming 15 dual issue slots.  This leaves 9
        // dual issue slots.  Four of these are used for loop housekeeping
        // (2 pointer adds, 1 counter update and 1 branch), leaving 5 left
        // over (marked by blank lines).
        //
        // Choice of x18 to store the upper halves on their way into the
        // vector registers is arbitrary.  Added to the clobber list so that
        // the compiler will make it available.
        //
        // NOTE(review): x18 is the platform register in several AArch64 ABIs
        // (and reserved on recent Android); verify it is safe to clobber in
        // every environment this standalone benchmark targets.


        // At the start of the loop, it is assumed that v0 is "half loaded" -
        // bottom half in place in d0 and the upper half in x18 ready to
        // insert.  So set that up here for the first iteration:
        "ldr d0, [%[rhs_ptr]]\n"       // Bottom half of first Rhs cell
        "ldr x18, [%[rhs_ptr], #8]\n"  // Upper half

        // v2-v3 should be fully loaded - as it's outside the loop proper it's
        // fine to use a 128-bit load here.
        "ldr q2, [%[lhs_ptr]]\n"       // first Lhs cell
        "ldr q3, [%[lhs_ptr], #16]\n"  // second Lhs cell

        GEMMLOWP_LABEL_LOOP
        ":\n"

        "fmla v8.4s, v2.4s, v0.s[0]\n"
        "ldr d1, [%[rhs_ptr], #16]\n"  // Bottom half of v1
        "fmla v9.4s, v2.4s, v0.s[1]\n"
        "ins v0.d[1], x18\n"           // Finish loading v0
        "fmla v16.4s, v3.4s, v0.s[0]\n"  // out of sequence - used to reduce load/use pressure.
        "ldr x18, [%[rhs_ptr], #24]\n"   // Top half of v1 to X register
        "fmla v17.4s, v3.4s, v0.s[1]\n"  // out of sequence - used to reduce load/use pressure.
        "add %[rhs_ptr], %[rhs_ptr], #32\n"  // RHS loads complete - increment pointer.
        "fmla v10.4s, v2.4s, v0.s[2]\n"
        "ldr d4, [%[lhs_ptr], #32]\n"  // Bottom half of v4
        "fmla v11.4s, v2.4s, v0.s[3]\n"
        "ins v1.d[1], x18\n"  // Finish loading v1
        "fmla v12.4s, v2.4s, v1.s[0]\n"
        "ldr x18, [%[lhs_ptr], #40]\n"  // Top half of v4 to X register
        "fmla v13.4s, v2.4s, v1.s[1]\n"
        "add %[lhs_ptr], %[lhs_ptr], #48\n"  // LHS loads complete - increment pointer.
        "fmla v14.4s, v2.4s, v1.s[2]\n"

        "fmla v15.4s, v2.4s, v1.s[3]\n"
        "ldr d2, [%[lhs_ptr]]\n"  // Bottom half of v2 (for next time)
        "fmla v18.4s, v3.4s, v0.s[2]\n"
        "ins v4.d[1], x18\n"  // Finish loading v4
        "fmla v19.4s, v3.4s, v0.s[3]\n"
        "ldr x18, [%[lhs_ptr], #8]\n"  // Top half of next v2 to X register
        "fmla v20.4s, v3.4s, v1.s[0]\n"
        "subs %w[depth], %w[depth], #1\n"
        "fmla v21.4s, v3.4s, v1.s[1]\n"

        "fmla v22.4s, v3.4s, v1.s[2]\n"

        "fmla v23.4s, v3.4s, v1.s[3]\n"
        "ldr d3, [%[lhs_ptr], #16]\n"  // Bottom half of v3 (for next time)
        "fmla v24.4s, v4.4s, v0.s[0]\n"
        "ins v2.d[1], x18\n"  // Finish loading next v2
        "fmla v25.4s, v4.4s, v0.s[1]\n"
        "ldr x18, [%[lhs_ptr], #24]\n"  // Top half of next v3 to X register
        "fmla v26.4s, v4.4s, v0.s[2]\n"

        "fmla v27.4s, v4.4s, v0.s[3]\n"
        "ldr d0, [%[rhs_ptr]]\n"  // Bottom half of v0 (for next time)
        "fmla v28.4s, v4.4s, v1.s[0]\n"
        "ins v3.d[1], x18\n"  // Finish loading next v3
        "fmla v29.4s, v4.4s, v1.s[1]\n"
        "ldr x18, [%[rhs_ptr], #8]\n"  // Top half of next v0 to X register
        "fmla v30.4s, v4.4s, v1.s[2]\n"

        "fmla v31.4s, v4.4s, v1.s[3]\n"
        "bne " GEMMLOWP_LABEL_LOOP "b\n"

        // Store accumulators
        "mov x0, %[accum_ptr]\n"
        "st1 {v8.4s}, [x0], #16\n"
        "st1 {v16.4s}, [x0], #16\n"
        "st1 {v24.4s}, [x0], #16\n"
        "st1 {v9.4s}, [x0], #16\n"
        "st1 {v17.4s}, [x0], #16\n"
        "st1 {v25.4s}, [x0], #16\n"
        "st1 {v10.4s}, [x0], #16\n"
        "st1 {v18.4s}, [x0], #16\n"
        "st1 {v26.4s}, [x0], #16\n"
        "st1 {v11.4s}, [x0], #16\n"
        "st1 {v19.4s}, [x0], #16\n"
        "st1 {v27.4s}, [x0], #16\n"
        "st1 {v12.4s}, [x0], #16\n"
        "st1 {v20.4s}, [x0], #16\n"
        "st1 {v28.4s}, [x0], #16\n"
        "st1 {v13.4s}, [x0], #16\n"
        "st1 {v21.4s}, [x0], #16\n"
        "st1 {v29.4s}, [x0], #16\n"
        "st1 {v14.4s}, [x0], #16\n"
        "st1 {v22.4s}, [x0], #16\n"
        "st1 {v30.4s}, [x0], #16\n"
        "st1 {v15.4s}, [x0], #16\n"
        "st1 {v23.4s}, [x0], #16\n"
        "st1 {v31.4s}, [x0], #16\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
        "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
        "v27", "v28", "v29", "v30", "v31");
  }
};
4536
4537 #endif // __aarch64__
4538
4539 #if defined(__arm__) || defined(__aarch64__)
4540 #ifndef __aarch64__
// Emulation of the AArch64-only vpaddq_s32 for 32-bit ARM: pairwise-adds
// adjacent lanes within each operand, then packs the two halved results
// into a single vector (a-pairs in the low half, b-pairs in the high half).
inline int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
  return vcombine_s32(vpadd_s32(vget_low_s32(a), vget_high_s32(a)),
                      vpadd_s32(vget_low_s32(b), vget_high_s32(b)));
}
4546 #endif
4547
// C++ intrinsics-based variant of the deep, int8, fast kernel.
template <int Cols>
struct NEON_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics {
  typedef std::int8_t OperandType;
  typedef std::int32_t AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
      KernelSideFormat<CellFormat<Cols, 16, CellOrder::WidthMajor>, 1> >
      Format;
  // Accumulates a 4xCols block over `depth` levels, 16 levels per pass.
  // Local accumulators start at zero and are merged into accum_ptr at the
  // end, after a horizontal pairwise reduction.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    int32x4_t acc[4][Cols];
    for (int row = 0; row < 4; row++) {
      for (int col = 0; col < Cols; col++) {
        acc[row][col] = vdupq_n_s32(0);
      }
    }
    for (int level = 0; level < depth; level += 16) {
      // One 16-deep slice of each of the 4 Lhs rows and Cols Rhs columns.
      int8x16_t lhs[4];
      for (int row = 0; row < 4; row++) {
        lhs[row] = vld1q_s8(lhs_ptr + 16 * row);
      }
      int8x16_t rhs[Cols];
      for (int col = 0; col < Cols; col++) {
        rhs[col] = vld1q_s8(rhs_ptr + 16 * col);
      }
      lhs_ptr += 64;
      rhs_ptr += 16 * Cols;
      for (int row = 0; row < 4; row++) {
        for (int col = 0; col < Cols; col++) {
          // Two int8 products are summed per int16 lane before widening
          // (hence "AccumTwoWithin16Bits").
          int16x8_t prod =
              vmull_s8(vget_low_s8(lhs[row]), vget_low_s8(rhs[col]));
          prod = vmlal_s8(prod, vget_high_s8(lhs[row]), vget_high_s8(rhs[col]));
          acc[row][col] = vpadalq_s16(acc[row][col], prod);
        }
      }
    }
    // Reduce each per-cell accumulator to a single lane and add into the
    // destination: lane r of the output vector for column `col` is the
    // horizontal sum of acc[r][col].
    for (int col = 0; col < Cols; col++) {
      int32x4_t rows01 = vpaddq_s32(acc[0][col], acc[1][col]);
      int32x4_t rows23 = vpaddq_s32(acc[2][col], acc[3][col]);
      int32x4_t rows0123 = vpaddq_s32(rows01, rows23);
      int32x4_t dst_val = vld1q_s32(accum_ptr + 4 * col);
      vst1q_s32(accum_ptr + 4 * col, vaddq_s32(dst_val, rows0123));
    }
  }
};
4596
// The template parameter is the number of Rhs columns kept in registers:
// 4 for the 64-bit variant, 2 for the 32-bit variant.
using NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics =
    NEON_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics<4>;

using NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics =
    NEON_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics<2>;
4602
// C++ intrinsics-based variant of the wide, uint8, general kernel.
template <int RhsCells>
struct NEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics {
  typedef std::uint8_t OperandType;
  typedef std::int32_t AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, RhsCells> >
      Format;
  // Accumulates a 12 x (4*RhsCells) block over `depth` levels, two levels
  // per pass. Entry (r, c) of the accumulator grid lives at
  // accum_ptr + 4 * (r + 3 * c); it is loaded up front and stored back at
  // the end.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    int32x4_t acc[3][4 * RhsCells];
    for (int r = 0; r < 3; r++) {
      for (int c = 0; c < 4 * RhsCells; c++) {
        acc[r][c] = vld1q_s32(accum_ptr + 4 * (r + 3 * c));
      }
    }
    for (int d = 0; d < depth; d += 2) {
      // Zero-extend the uint8 operands to 16 bit; the low s16 half holds
      // the depth-0 elements and the high half the depth-1 elements of
      // each 4x2 DepthMajor cell.
      int16x8_t lhs[3];
      for (int r = 0; r < 3; r++) {
        lhs[r] = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(lhs_ptr + 8 * r)));
      }
      int16x8_t rhs[RhsCells];
      for (int c = 0; c < RhsCells; c++) {
        rhs[c] = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(rhs_ptr + 8 * c)));
      }
      lhs_ptr += 24;
      rhs_ptr += 8 * RhsCells;
      for (int r = 0; r < 3; r++) {
        const int16x4_t lhs_d0 = vget_low_s16(lhs[r]);
        const int16x4_t lhs_d1 = vget_high_s16(lhs[r]);
        for (int c = 0; c < RhsCells; c++) {
          const int16x4_t rhs_d0 = vget_low_s16(rhs[c]);
          const int16x4_t rhs_d1 = vget_high_s16(rhs[c]);
          int32x4_t* col = acc[r] + 4 * c;
          // The lane argument must be a compile-time constant, hence the
          // manual unroll over the 4 Rhs lanes, depth 0 then depth 1.
          col[0] = vmlal_lane_s16(col[0], lhs_d0, rhs_d0, 0);
          col[0] = vmlal_lane_s16(col[0], lhs_d1, rhs_d1, 0);
          col[1] = vmlal_lane_s16(col[1], lhs_d0, rhs_d0, 1);
          col[1] = vmlal_lane_s16(col[1], lhs_d1, rhs_d1, 1);
          col[2] = vmlal_lane_s16(col[2], lhs_d0, rhs_d0, 2);
          col[2] = vmlal_lane_s16(col[2], lhs_d1, rhs_d1, 2);
          col[3] = vmlal_lane_s16(col[3], lhs_d0, rhs_d0, 3);
          col[3] = vmlal_lane_s16(col[3], lhs_d1, rhs_d1, 3);
        }
      }
    }
    for (int r = 0; r < 3; r++) {
      for (int c = 0; c < 4 * RhsCells; c++) {
        vst1q_s32(accum_ptr + 4 * (r + 3 * c), acc[r][c]);
      }
    }
  }
};
4663
// RhsCells = 1 yields a 12x4 kernel (32-bit variant); RhsCells = 2 yields
// a 12x8 kernel (64-bit variant).
using NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics =
    NEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics<1>;

using NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics =
    NEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics<2>;
4669
// C++ intrinsics-based float kernel: 12 x (4*RhsCells) block, one depth
// level per pass.
template <int RhsCells>
struct NEON_GEMM_Float32_WithScalar_intrinsics {
  typedef float OperandType;
  typedef float AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, RhsCells> >
      Format;
  // Entry (r, c) of the accumulator grid lives at accum_ptr + 4 * (r + 3 * c);
  // it is loaded up front, updated with one fused multiply-accumulate per
  // depth level, and stored back at the end.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    float32x4_t acc[3][4 * RhsCells];
    for (int r = 0; r < 3; r++) {
      for (int c = 0; c < 4 * RhsCells; c++) {
        acc[r][c] = vld1q_f32(accum_ptr + 4 * (r + 3 * c));
      }
    }
    for (int d = 0; d < depth; d++) {
      float32x4_t lhs[3];
      for (int r = 0; r < 3; r++) {
        lhs[r] = vld1q_f32(lhs_ptr + 4 * r);
      }
      float32x4_t rhs[RhsCells];
      for (int c = 0; c < RhsCells; c++) {
        rhs[c] = vld1q_f32(rhs_ptr + 4 * c);
      }
      lhs_ptr += 12;
      rhs_ptr += 4 * RhsCells;
      for (int r = 0; r < 3; r++) {
        for (int c = 0; c < RhsCells; c++) {
          // Lane arguments must be compile-time constants, hence the
          // manual unroll over the 4 Rhs lanes.
          const float32x2_t rhs_lo = vget_low_f32(rhs[c]);
          const float32x2_t rhs_hi = vget_high_f32(rhs[c]);
          float32x4_t* col = acc[r] + 4 * c;
          col[0] = vmlaq_lane_f32(col[0], lhs[r], rhs_lo, 0);
          col[1] = vmlaq_lane_f32(col[1], lhs[r], rhs_lo, 1);
          col[2] = vmlaq_lane_f32(col[2], lhs[r], rhs_hi, 0);
          col[3] = vmlaq_lane_f32(col[3], lhs[r], rhs_hi, 1);
        }
      }
    }
    for (int r = 0; r < 3; r++) {
      for (int c = 0; c < 4 * RhsCells; c++) {
        vst1q_f32(accum_ptr + 4 * (r + 3 * c), acc[r][c]);
      }
    }
  }
};
4717
// RhsCells = 1 yields a 12x4 kernel (32-bit variant); RhsCells = 2 yields
// a 12x8 kernel (64-bit variant).
using NEON_32bit_GEMM_Float32_WithScalar_intrinsics =
    NEON_GEMM_Float32_WithScalar_intrinsics<1>;

using NEON_64bit_GEMM_Float32_WithScalar_intrinsics =
    NEON_GEMM_Float32_WithScalar_intrinsics<2>;
4723
// C++ intrinsics-based variant of the deep, 7-bit, fast kernel.
// Operand values are assumed to fit in 7 bits so that eight int8 products
// can be summed into a single int16 lane without overflow (hence the
// "AccumEightWithin16Bits" name; see SET_7BIT_RANGES below).
struct NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits_intrinsics {
  typedef std::int8_t OperandType;
  typedef std::int32_t AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
      KernelSideFormat<CellFormat<2, 16, CellOrder::WidthMajor>, 1> >
      Format;
  // Accumulates a 4x2 block over `depth` levels into accum_ptr. A fast main
  // loop consumes 64 levels at a time; a tail loop consumes remaining whole
  // 16-level slices. Any leftover depth below 16 is not processed —
  // presumably `depth` is always a multiple of the Format's cell depth (16);
  // TODO confirm against the benchmark driver.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    int32x4_t acc[4][2];
    for (int i = 0; i < 4; i++) {
      for (int j = 0; j < 2; j++) {
        acc[i][j] = vdupq_n_s32(0);
      }
    }

    int d = 0;
    // Main loop: 64 depth levels per iteration. Each int16 lane of
    // local_acc receives exactly 8 int8 products (4 per k pass) before
    // being widened into the 32-bit accumulators.
    for (; d <= depth - 64; d += 64) {
      int16x8_t local_acc[4][2];
      for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 2; j++) {
          local_acc[i][j] = vdupq_n_s16(0);
        }
      }

      // There are not enough registers to fit all lhs and rhs values for 64
      // depth. Instead, load values for 32 depth at a time.
      for (int k = 0; k < 2; k++) {
        int8x16_t lhs[4][2];
        for (int i = 0; i < 4; i++) {
          lhs[i][0] = vld1q_s8(lhs_ptr + 16 * i + 128 * k);
          lhs[i][1] = vld1q_s8(lhs_ptr + 64 + 16 * i + 128 * k);
        }

        int8x16_t rhs[4];
        for (int i = 0; i < 4; i++) {
          rhs[i] = vld1q_s8(rhs_ptr + 16 * i + 64 * k);
        }

        for (int i = 0; i < 4; i++) {
          // On the first pass (k == 0) local_acc is initialized with vmull
          // (multiply); on the second pass it is extended with vmlal
          // (multiply-accumulate).
          if (k == 0) {
            local_acc[i][0] = vmull_s8(vget_low_s8(lhs[i][0]),
                                       vget_low_s8(rhs[0]));
            local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_low_s8(lhs[i][1]),
                                       vget_low_s8(rhs[2]));
            local_acc[i][1] = vmull_s8(vget_low_s8(lhs[i][0]),
                                       vget_low_s8(rhs[1]));
            local_acc[i][1] = vmlal_s8(local_acc[i][1],
                                       vget_low_s8(lhs[i][1]),
                                       vget_low_s8(rhs[3]));
          } else {
            local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_low_s8(lhs[i][0]),
                                       vget_low_s8(rhs[0]));
            local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_low_s8(lhs[i][1]),
                                       vget_low_s8(rhs[2]));
            local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_low_s8(lhs[i][0]),
                                       vget_low_s8(rhs[1]));
            local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_low_s8(lhs[i][1]),
                                       vget_low_s8(rhs[3]));
          }

          local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_high_s8(lhs[i][0]),
                                     vget_high_s8(rhs[0]));
          local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_high_s8(lhs[i][1]),
                                     vget_high_s8(rhs[2]));
          local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_high_s8(lhs[i][0]),
                                     vget_high_s8(rhs[1]));
          local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_high_s8(lhs[i][1]),
                                     vget_high_s8(rhs[3]));
        }
      }

      // Widen the 16-bit local accumulators into the 32-bit accumulators.
      for (int i = 0; i < 4; i++) {
        acc[i][0] = vpadalq_s16(acc[i][0], local_acc[i][0]);
        acc[i][1] = vpadalq_s16(acc[i][1], local_acc[i][1]);
      }

      lhs_ptr += 64 * 4;
      rhs_ptr += 64 * 2;
    }
    // Tail loop: 16 depth levels per iteration, widening after every slice.
    for (; d <= depth - 16; d += 16) {
      int8x16_t lhs[4];
      for (int i = 0; i < 4; i++) {
        lhs[i] = vld1q_s8(lhs_ptr + 16 * i);
      }
      int8x16_t rhs[2];
      for (int i = 0; i < 2; i++) {
        rhs[i] = vld1q_s8(rhs_ptr + 16 * i);
      }

      for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 2; j++) {
          int16x8_t local_acc =
              vmull_s8(vget_low_s8(lhs[i]), vget_low_s8(rhs[j]));
          local_acc =
              vmlal_s8(local_acc, vget_high_s8(lhs[i]), vget_high_s8(rhs[j]));
          acc[i][j] = vpadalq_s16(acc[i][j], local_acc);
        }
      }
      lhs_ptr += 16 * 4;
      rhs_ptr += 16 * 2;
    }
    // Horizontal pairwise reduction: lane r of the output vector for
    // column i is the horizontal sum of acc[r][i]; add into accum_ptr.
    for (int i = 0; i < 2; i++) {
      int32x4_t acc_2x_0 = vpaddq_s32(acc[0][i], acc[1][i]);
      int32x4_t acc_2x_1 = vpaddq_s32(acc[2][i], acc[3][i]);
      int32x4_t acc_4x = vpaddq_s32(acc_2x_0, acc_2x_1);
      int32x4_t dst_val = vld1q_s32(accum_ptr + 4 * i);
      dst_val = vaddq_s32(dst_val, acc_4x);
      vst1q_s32(accum_ptr + 4 * i, dst_val);
    }
  }
};
4837
4838 SET_7BIT_RANGES(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits_intrinsics);
4839
// C++ intrinsics-based variant of the deep, 4.25-bit, fast kernel.
// The very small operand range lets products be accumulated within int8
// (vmulq_s8/vmlaq_s8) before widening — see SET_425BIT_RANGES below.
struct NEON_64bit_GEMM_Int425Operands_intrinsics {
  typedef std::int8_t OperandType;
  typedef std::int32_t AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 32, CellOrder::WidthMajor>, 1>,
      KernelSideFormat<CellFormat<2, 32, CellOrder::WidthMajor>, 1> >
      Format;
  // Accumulates a 4x2 block over `depth` levels into accum_ptr. Depth is
  // processed in groups of up to 512 levels (16 inner passes of 32); the
  // int16 local accumulators are flushed into the 32-bit accumulators after
  // every group, bounding how many products they absorb.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    int32x4_t acc[4][2];
    for (int i = 0; i < 4; i++) {
      for (int j = 0; j < 2; j++) {
        acc[i][j] = vdupq_n_s32(0);
      }
    }

    // "+ 1" covers the final partial group when depth is not a multiple of
    // 512 (when it is a multiple, the extra outer iteration runs zero inner
    // passes and only merges zeroed local accumulators — a harmless no-op).
    const int num_outer_depth_loop = depth / 512 + 1;
    int d = 0;
    for (int od = 0; od < num_outer_depth_loop; od++) {
      int16x8_t local_acc[4][2];
      for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 2; j++) {
          local_acc[i][j] = vdupq_n_s16(0);
        }
      }
      // Inner passes: 32 depth levels each, at most 16 per group; the
      // `d <= depth - 32` guard also stops at the end of the data (any
      // remainder below 32 levels is not processed).
      for (int k = 0; k < 16 && d <= depth - 32; k++, d += 32) {
        int8x16_t lhs[8];
        for (int i = 0; i < 8; i++) {
          lhs[i] = vld1q_s8(lhs_ptr + 16 * i);
        }
        int8x16_t rhs[4];
        for (int i = 0; i < 4; i++) {
          rhs[i] = vld1q_s8(rhs_ptr + 16 * i);
        }
        for (int i = 0; i < 4; i++) {
          for (int j = 0; j < 2; j++) {
            // Products stay within int8 here — only valid for the 4.25-bit
            // operand range this kernel is declared for.
            int8x16_t temp_acc = vmulq_s8(lhs[i * 2], rhs[j * 2]);
            temp_acc = vmlaq_s8(temp_acc, lhs[i * 2 + 1], rhs[j * 2 + 1]);
            local_acc[i][j] = vpadalq_s8(local_acc[i][j], temp_acc);
          }
        }
        lhs_ptr += 128;
        rhs_ptr += 64;
      }

      // Flush the group's 16-bit accumulators into the 32-bit accumulators.
      for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 2; j++) {
          acc[i][j] = vpadalq_s16(acc[i][j], local_acc[i][j]);
        }
      }
    }

    // Horizontal pairwise reduction: lane r of the output vector for
    // column i is the horizontal sum of acc[r][i]; add into accum_ptr.
    for (int i = 0; i < 2; i++) {
      int32x4_t acc_2x_0 = vpaddq_s32(acc[0][i], acc[1][i]);
      int32x4_t acc_2x_1 = vpaddq_s32(acc[2][i], acc[3][i]);
      int32x4_t acc_4x = vpaddq_s32(acc_2x_0, acc_2x_1);

      int32x4_t dst_val = vld1q_s32(accum_ptr + 4 * i);
      dst_val = vaddq_s32(dst_val, acc_4x);
      vst1q_s32(accum_ptr + 4 * i, dst_val);
    }
  }
};
4904
4905 SET_425BIT_RANGES(NEON_64bit_GEMM_Int425Operands_intrinsics);
4906
4907 #endif // __arm__ || __aarch64__
4908
4909 #ifdef __mips
// 12x8 depth 2 depth-major kernel.
struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators1 {
  typedef std::uint8_t OperandType;
  typedef std::uint32_t AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
      Format;
  // Accumulates `depth` levels of a 12x8 uint8 GEMM block into accum_ptr
  // (accumulators are loaded on entry and stored back on exit). Each loop
  // iteration consumes two depth levels (lhs_ptr += 24 bytes, rhs_ptr += 16
  // bytes) and the counter is decremented by 2 with an exact-zero exit test,
  // so `depth` must be a positive even number.
  static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators
        "ld.w $w0, (0*16)(%[accum_ptr])\n"
        "ld.w $w4, (1*16)(%[accum_ptr])\n"
        "ld.w $w8, (2*16)(%[accum_ptr])\n"
        "ld.w $w1, (3*16)(%[accum_ptr])\n"
        "ld.w $w5, (4*16)(%[accum_ptr])\n"
        "ld.w $w9, (5*16)(%[accum_ptr])\n"
        "ld.w $w2, (6*16)(%[accum_ptr])\n"
        "ld.w $w6, (7*16)(%[accum_ptr])\n"
        "ld.w $w10, (8*16)(%[accum_ptr])\n"
        "ld.w $w3, (9*16)(%[accum_ptr])\n"
        "ld.w $w7, (10*16)(%[accum_ptr])\n"
        "ld.w $w11, (11*16)(%[accum_ptr])\n"
        "ld.w $w12, (12*16)(%[accum_ptr])\n"
        "ld.w $w16, (13*16)(%[accum_ptr])\n"
        "ld.w $w20, (14*16)(%[accum_ptr])\n"
        "ld.w $w13, (15*16)(%[accum_ptr])\n"
        "ld.w $w17, (16*16)(%[accum_ptr])\n"
        "ld.w $w21, (17*16)(%[accum_ptr])\n"
        "ld.w $w14, (18*16)(%[accum_ptr])\n"
        "ld.w $w18, (19*16)(%[accum_ptr])\n"
        "ld.w $w22, (20*16)(%[accum_ptr])\n"
        "ld.w $w15, (21*16)(%[accum_ptr])\n"
        "ld.w $w19, (22*16)(%[accum_ptr])\n"
        "ld.w $w23, (23*16)(%[accum_ptr])\n"
        // Set a temp to all zeroes.
        "ldi.b $w31, 0\n"

        GEMMLOWP_LABEL_LOOP ":\n"
        // Overview of register layout:
        //
        // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30
        // (each register contains 4 replicas of a pair of elements).
        // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
        // A 12x8 block of accumulators is stored in 32bit in w0-w23.
        //
        //                    +------+------+------+------+
        //               Rhs  |w27   |w28   |w29   |w30   |
        //                    +------+------+------+------+
        //
        //                    |      |      |      |      |
        //
        //    Lhs             |      |      |      |      |
        //
        //  +---+ - - - - - - +------+------+------+------+
        //  |w24|             |w0/12 |w1/13 |w2/14 |w3/15 |
        //  |w24|             |w0/12 |w1/13 |w2/14 |w3/15 |
        //  |w24|             |w0/12 |w1/13 |w2/14 |w3/15 |
        //  |w24|             |w0/12 |w1/13 |w2/14 |w3/15 |
        //  +---+ - - - - - - +------+------+------+------+
        //  |w25|             |w4/16 |w5/17 |w6/18 |w7/19 |
        //  |w25|             |w4/16 |w5/17 |w6/18 |w7/19 |
        //  |w25|             |w4/16 |w5/17 |w6/18 |w7/19 |
        //  |w25|             |w4/16 |w5/17 |w6/18 |w7/19 |
        //  +---+ - - - - - - +------+------+------+------+
        //  |w26|             |w8/20 |w9/21 |w10/22|w11/23|
        //  |w26|             |w8/20 |w9/21 |w10/22|w11/23|
        //  |w26|             |w8/20 |w9/21 |w10/22|w11/23|
        //  |w26|             |w8/20 |w9/21 |w10/22|w11/23|
        //  +---+ - - - - - - +------+------+------+------+
        //
        //                             Accumulators

        // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
        "ld.b $w24, 0(%[lhs_ptr])\n"
        "ld.b $w25, 8(%[lhs_ptr])\n"

        // Load 4 bytes of rhs[] for the first half of depth 0.
        "lbu $a0, 0(%[rhs_ptr])\n"
        "lbu $a1, 1(%[rhs_ptr])\n"
        "lbu $a2, 2(%[rhs_ptr])\n"
        "lbu $a3, 3(%[rhs_ptr])\n"
        // Load 4 bytes of rhs[] for the first half of depth 1.
        "lbu $v0, 4(%[rhs_ptr])\n"
        "lbu $v1, 5(%[rhs_ptr])\n"
        "lbu $t8, 6(%[rhs_ptr])\n"
        "lbu $t9, 7(%[rhs_ptr])\n"

        // Zero-extend 8-bit elements of lhs[] to 16 bits.
        "ilvr.b $w24, $w31, $w24\n"
        "ilvl.b $w26, $w31, $w25\n"
        "ilvr.b $w25, $w31, $w25\n"
        // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
        "ilvl.d $w27, $w31, $w24\n"
        "ilvl.d $w28, $w31, $w25\n"
        "ilvl.d $w29, $w31, $w26\n"
        "ilvr.h $w24, $w27, $w24\n"
        "ilvr.h $w25, $w28, $w25\n"
        "ilvr.h $w26, $w29, $w26\n"

        // Combine and interleave depth 0 and depth 1 elements of rhs[] for
        // dpadd_u.w (for the first half).
        "ins $a0, $v0, 16, 8\n"
        "ins $a1, $v1, 16, 8\n"
        "ins $a2, $t8, 16, 8\n"
        "ins $a3, $t9, 16, 8\n"
        // Make 4 replicas of every pair of rhs[] elements.
        "fill.w $w27, $a0\n"
        "fill.w $w28, $a1\n"
        "fill.w $w29, $a2\n"
        "fill.w $w30, $a3\n"

        // Load 4 bytes of rhs[] for the second half of depth 0.
        "lbu $a0, 8(%[rhs_ptr])\n"
        "lbu $a1, 9(%[rhs_ptr])\n"
        "lbu $a2, 10(%[rhs_ptr])\n"
        "lbu $a3, 11(%[rhs_ptr])\n"
        // Load 4 bytes of rhs[] for the second half of depth 1.
        "lbu $v0, 12(%[rhs_ptr])\n"
        "lbu $v1, 13(%[rhs_ptr])\n"
        "lbu $t8, 14(%[rhs_ptr])\n"
        "lbu $t9, 15(%[rhs_ptr])\n"

        // First half of depths 0 and 1.
        // Dot-product-(and)-add doubles multiplicand width.
        "dpadd_u.w $w0, $w24, $w27\n"
        "dpadd_u.w $w4, $w25, $w27\n"
        "dpadd_u.w $w8, $w26, $w27\n"
        "dpadd_u.w $w1, $w24, $w28\n"
        "dpadd_u.w $w5, $w25, $w28\n"
        "dpadd_u.w $w9, $w26, $w28\n"
        "dpadd_u.w $w2, $w24, $w29\n"
        "dpadd_u.w $w6, $w25, $w29\n"
        "dpadd_u.w $w10, $w26, $w29\n"
        "dpadd_u.w $w3, $w24, $w30\n"
        "dpadd_u.w $w7, $w25, $w30\n"
        "dpadd_u.w $w11, $w26, $w30\n"

        // Combine and interleave depth 0 and depth 1 elements of rhs[] for
        // dpadd_u.w (for the second half).
        "ins $a0, $v0, 16, 8\n"
        "ins $a1, $v1, 16, 8\n"
        "ins $a2, $t8, 16, 8\n"
        "ins $a3, $t9, 16, 8\n"
        // Make 4 replicas of every pair of rhs[] elements.
        "fill.w $w27, $a0\n"
        "fill.w $w28, $a1\n"
        "fill.w $w29, $a2\n"
        "fill.w $w30, $a3\n"

        // Second half of depths 0 and 1.
        // Dot-product-(and)-add doubles multiplicand width.
        "dpadd_u.w $w12, $w24, $w27\n"
        "dpadd_u.w $w16, $w25, $w27\n"
        "dpadd_u.w $w20, $w26, $w27\n"
        "dpadd_u.w $w13, $w24, $w28\n"
        "dpadd_u.w $w17, $w25, $w28\n"
        "dpadd_u.w $w21, $w26, $w28\n"
        "dpadd_u.w $w14, $w24, $w29\n"
        "dpadd_u.w $w18, $w25, $w29\n"
        "dpadd_u.w $w22, $w26, $w29\n"
        "dpadd_u.w $w15, $w24, $w30\n"
        "dpadd_u.w $w19, $w25, $w30\n"
        "dpadd_u.w $w23, $w26, $w30\n"

        "addiu %[depth], -2\n"
        GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
        GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
        "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"

        // Store accumulators.
        "st.w $w0, (0*16)(%[accum_ptr])\n"
        "st.w $w4, (1*16)(%[accum_ptr])\n"
        "st.w $w8, (2*16)(%[accum_ptr])\n"
        "st.w $w1, (3*16)(%[accum_ptr])\n"
        "st.w $w5, (4*16)(%[accum_ptr])\n"
        "st.w $w9, (5*16)(%[accum_ptr])\n"
        "st.w $w2, (6*16)(%[accum_ptr])\n"
        "st.w $w6, (7*16)(%[accum_ptr])\n"
        "st.w $w10, (8*16)(%[accum_ptr])\n"
        "st.w $w3, (9*16)(%[accum_ptr])\n"
        "st.w $w7, (10*16)(%[accum_ptr])\n"
        "st.w $w11, (11*16)(%[accum_ptr])\n"
        "st.w $w12, (12*16)(%[accum_ptr])\n"
        "st.w $w16, (13*16)(%[accum_ptr])\n"
        "st.w $w20, (14*16)(%[accum_ptr])\n"
        "st.w $w13, (15*16)(%[accum_ptr])\n"
        "st.w $w17, (16*16)(%[accum_ptr])\n"
        "st.w $w21, (17*16)(%[accum_ptr])\n"
        "st.w $w14, (18*16)(%[accum_ptr])\n"
        "st.w $w18, (19*16)(%[accum_ptr])\n"
        "st.w $w22, (20*16)(%[accum_ptr])\n"
        "st.w $w15, (21*16)(%[accum_ptr])\n"
        "st.w $w19, (22*16)(%[accum_ptr])\n"
        "st.w $w23, (23*16)(%[accum_ptr])\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "memory",
        "v0", "v1",
        "a0", "a1", "a2", "a3",
        "t8", "t9",
        "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
        "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
        "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23",
        "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31");
  }
};
5122
// 12x8 depth 2 width-major kernel.
// Does less shuffling and replication than the kernel above.
struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators2 {
  typedef std::uint8_t OperandType;
  typedef std::uint32_t AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2> >
      Format;
  // Accumulates `depth` levels of a 12x8 uint8 GEMM block into accum_ptr
  // (accumulators are loaded on entry and stored back on exit). Each loop
  // iteration consumes two depth levels (lhs_ptr += 24 bytes, rhs_ptr += 16
  // bytes) and the counter is decremented by 2 with an exact-zero exit test,
  // so `depth` must be a positive even number.
  //
  // Note: $w31 serves double duty as the zero register for widening and as a
  // scratch register for the widened rhs halves, so it is re-zeroed (ldi.b)
  // before each use as zero.
  static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    asm volatile(
        // Load accumulators
        "ld.w $w0, (0*16)(%[accum_ptr])\n"
        "ld.w $w4, (1*16)(%[accum_ptr])\n"
        "ld.w $w8, (2*16)(%[accum_ptr])\n"
        "ld.w $w1, (3*16)(%[accum_ptr])\n"
        "ld.w $w5, (4*16)(%[accum_ptr])\n"
        "ld.w $w9, (5*16)(%[accum_ptr])\n"
        "ld.w $w2, (6*16)(%[accum_ptr])\n"
        "ld.w $w6, (7*16)(%[accum_ptr])\n"
        "ld.w $w10, (8*16)(%[accum_ptr])\n"
        "ld.w $w3, (9*16)(%[accum_ptr])\n"
        "ld.w $w7, (10*16)(%[accum_ptr])\n"
        "ld.w $w11, (11*16)(%[accum_ptr])\n"
        "ld.w $w12, (12*16)(%[accum_ptr])\n"
        "ld.w $w16, (13*16)(%[accum_ptr])\n"
        "ld.w $w20, (14*16)(%[accum_ptr])\n"
        "ld.w $w13, (15*16)(%[accum_ptr])\n"
        "ld.w $w17, (16*16)(%[accum_ptr])\n"
        "ld.w $w21, (17*16)(%[accum_ptr])\n"
        "ld.w $w14, (18*16)(%[accum_ptr])\n"
        "ld.w $w18, (19*16)(%[accum_ptr])\n"
        "ld.w $w22, (20*16)(%[accum_ptr])\n"
        "ld.w $w15, (21*16)(%[accum_ptr])\n"
        "ld.w $w19, (22*16)(%[accum_ptr])\n"
        "ld.w $w23, (23*16)(%[accum_ptr])\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"
        // Overview of register layout:
        //
        // A half of the 2 2x4 cells of Rhs is stored in 16bit in w28-w31
        // (each register contains 4 replicas of a pair of elements).
        // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
        // A 12x8 block of accumulators is stored in 32bit in w0-w23.
        //
        //                    +------+------+------+------+
        //               Rhs  |w28   |w29   |w30   |w31   |
        //                    +------+------+------+------+
        //
        //                    |      |      |      |      |
        //
        //    Lhs             |      |      |      |      |
        //
        //  +---+ - - - - - - +------+------+------+------+
        //  |w24|             |w0/12 |w1/13 |w2/14 |w3/15 |
        //  |w24|             |w0/12 |w1/13 |w2/14 |w3/15 |
        //  |w24|             |w0/12 |w1/13 |w2/14 |w3/15 |
        //  |w24|             |w0/12 |w1/13 |w2/14 |w3/15 |
        //  +---+ - - - - - - +------+------+------+------+
        //  |w25|             |w4/16 |w5/17 |w6/18 |w7/19 |
        //  |w25|             |w4/16 |w5/17 |w6/18 |w7/19 |
        //  |w25|             |w4/16 |w5/17 |w6/18 |w7/19 |
        //  |w25|             |w4/16 |w5/17 |w6/18 |w7/19 |
        //  +---+ - - - - - - +------+------+------+------+
        //  |w26|             |w8/20 |w9/21 |w10/22|w11/23|
        //  |w26|             |w8/20 |w9/21 |w10/22|w11/23|
        //  |w26|             |w8/20 |w9/21 |w10/22|w11/23|
        //  |w26|             |w8/20 |w9/21 |w10/22|w11/23|
        //  +---+ - - - - - - +------+------+------+------+
        //
        //                             Accumulators

        // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
        "ld.b $w24, 0(%[lhs_ptr])\n"
        "ld.b $w25, 8(%[lhs_ptr])\n"

        // Load 2 x 8 bytes of rhs[].
        "ld.b $w27, 0(%[rhs_ptr])\n"

        // Zero-extend 8-bit elements of lhs[] to 16 bits.
        "ldi.b $w31, 0\n"
        "ilvr.b $w24, $w31, $w24\n"
        "ilvl.b $w26, $w31, $w25\n"
        "ilvr.b $w25, $w31, $w25\n"

        // First half of depths 0 and 1.
        // Zero-extend 8-bit elements of rhs[] to 16 bits.
        "ilvr.b $w31, $w31, $w27\n"
        // Make 4 replicas of every pair of rhs[] elements.
        "splati.w $w28, $w31[0]\n"
        "splati.w $w29, $w31[1]\n"
        "splati.w $w30, $w31[2]\n"
        "splati.w $w31, $w31[3]\n"
        // Dot-product-(and)-add doubles multiplicand width.
        "dpadd_u.w $w0, $w24, $w28\n"
        "dpadd_u.w $w4, $w25, $w28\n"
        "dpadd_u.w $w8, $w26, $w28\n"
        "dpadd_u.w $w1, $w24, $w29\n"
        "dpadd_u.w $w5, $w25, $w29\n"
        "dpadd_u.w $w9, $w26, $w29\n"
        "dpadd_u.w $w2, $w24, $w30\n"
        "dpadd_u.w $w6, $w25, $w30\n"
        "dpadd_u.w $w10, $w26, $w30\n"
        "dpadd_u.w $w3, $w24, $w31\n"
        "dpadd_u.w $w7, $w25, $w31\n"
        "dpadd_u.w $w11, $w26, $w31\n"

        // Second half of depths 0 and 1.
        // Zero-extend 8-bit elements of rhs[] to 16 bits.
        "ldi.b $w31, 0\n"
        "ilvl.b $w31, $w31, $w27\n"
        // Make 4 replicas of every pair of rhs[] elements.
        "splati.w $w28, $w31[0]\n"
        "splati.w $w29, $w31[1]\n"
        "splati.w $w30, $w31[2]\n"
        "splati.w $w31, $w31[3]\n"
        // Dot-product-(and)-add doubles multiplicand width.
        "dpadd_u.w $w12, $w24, $w28\n"
        "dpadd_u.w $w16, $w25, $w28\n"
        "dpadd_u.w $w20, $w26, $w28\n"
        "dpadd_u.w $w13, $w24, $w29\n"
        "dpadd_u.w $w17, $w25, $w29\n"
        "dpadd_u.w $w21, $w26, $w29\n"
        "dpadd_u.w $w14, $w24, $w30\n"
        "dpadd_u.w $w18, $w25, $w30\n"
        "dpadd_u.w $w22, $w26, $w30\n"
        "dpadd_u.w $w15, $w24, $w31\n"
        "dpadd_u.w $w19, $w25, $w31\n"
        "dpadd_u.w $w23, $w26, $w31\n"

        "addiu %[depth], -2\n"
        GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
        GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
        "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"

        // Store accumulators.
        "st.w $w0, (0*16)(%[accum_ptr])\n"
        "st.w $w4, (1*16)(%[accum_ptr])\n"
        "st.w $w8, (2*16)(%[accum_ptr])\n"
        "st.w $w1, (3*16)(%[accum_ptr])\n"
        "st.w $w5, (4*16)(%[accum_ptr])\n"
        "st.w $w9, (5*16)(%[accum_ptr])\n"
        "st.w $w2, (6*16)(%[accum_ptr])\n"
        "st.w $w6, (7*16)(%[accum_ptr])\n"
        "st.w $w10, (8*16)(%[accum_ptr])\n"
        "st.w $w3, (9*16)(%[accum_ptr])\n"
        "st.w $w7, (10*16)(%[accum_ptr])\n"
        "st.w $w11, (11*16)(%[accum_ptr])\n"
        "st.w $w12, (12*16)(%[accum_ptr])\n"
        "st.w $w16, (13*16)(%[accum_ptr])\n"
        "st.w $w20, (14*16)(%[accum_ptr])\n"
        "st.w $w13, (15*16)(%[accum_ptr])\n"
        "st.w $w17, (16*16)(%[accum_ptr])\n"
        "st.w $w21, (17*16)(%[accum_ptr])\n"
        "st.w $w14, (18*16)(%[accum_ptr])\n"
        "st.w $w18, (19*16)(%[accum_ptr])\n"
        "st.w $w22, (20*16)(%[accum_ptr])\n"
        "st.w $w15, (21*16)(%[accum_ptr])\n"
        "st.w $w19, (22*16)(%[accum_ptr])\n"
        "st.w $w23, (23*16)(%[accum_ptr])\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [depth] "+r"(depth)
        :  // inputs
        [accum_ptr] "r"(accum_ptr)
        :  // clobbers
        "memory", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8",
        "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17",
        "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", "$f24", "$f25", "$f26",
        "$f27", "$f28", "$f29", "$f30", "$f31");
  }
};
5298
// 4x4 depth 16 width-major kernel operating on int8 operands.
// It is assumed that one of the two int8 operands only takes values
// in [-127, 127], while the other may freely range in [-128, 127].
// The issue with both operands taking the value -128 is that:
// -128*-128 + -128*-128 == -32768 overflows int16.
// Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16
// range. That is the basic idea of this kernel.
struct MSA_GEMM_Int8Operands_AccumTwoWithin16Bits {
  typedef std::int8_t OperandType;
  typedef std::int32_t AccumulatorType;
  typedef KernelFormat<
      KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
      KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
      Format;
  // Accumulates a 4x4 block over `depth` levels (a multiple of 16) into
  // accum_ptr. start_depth is set nonzero below, so the code always takes
  // the "accumulate into existing destination values" path, matching this
  // benchmark's in/out accum_ptr semantics.
  static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
                  AccumulatorType* accum_ptr, int depth) {
    std::size_t start_depth = 123;
    std::size_t run_depth = depth;
    std::size_t dst_col_stride = 4;
    AccumulatorType* dst_ptr = accum_ptr;
// Numeric local asm labels, referenced with the "f"(orward)/"b"(ackward)
// suffixes below.
#define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1"
#define GEMMLOWP_LABEL_LOOP "2"
#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
#define GEMMLOWP_LABEL_STORE "4"
    asm volatile(
        // NOTE(review): GEMMLOWP_MIPS_XADDIU is defined earlier in this file;
        // presumably it selects addiu/daddiu for the 32/64-bit ABI — confirm.
        GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n"
        // Load lhs[] and rhs[], zero out internal accumulators.
        "ld.b $w16, 0(%[lhs_ptr])\n"
        "ldi.b $w0, 0\n"
        "ld.b $w20, 0(%[rhs_ptr])\n"
        "ldi.b $w1, 0\n"
        "ld.b $w17, 16(%[lhs_ptr])\n"
        "ldi.b $w2, 0\n"
        "ld.b $w21, 16(%[rhs_ptr])\n"
        "ldi.b $w3, 0\n"
        "ld.b $w18, 32(%[lhs_ptr])\n"
        "ldi.b $w4, 0\n"
        "ld.b $w19, 48(%[lhs_ptr])\n"
        "ldi.b $w5, 0\n"
        "ld.b $w22, 32(%[rhs_ptr])\n"
        "ldi.b $w6, 0\n"
        "ld.b $w23, 48(%[rhs_ptr])\n"
        "ldi.b $w7, 0\n"
        "ldi.b $w8, 0\n"
        "ldi.b $w9, 0\n"
        "ldi.b $w10, 0\n"
        "ldi.b $w11, 0\n"
        "ldi.b $w12, 0\n"
        "ldi.b $w13, 0\n"
        "ldi.b $w14, 0\n"
        "ldi.b $w15, 0\n"
        // w31 := vector of int16 ones. dpadd against it multiplies each
        // int16 by 1 and adds adjacent pairs, i.e. a horizontal add into
        // the 32-bit accumulators.
        "ldi.h $w31, 1\n"
        // If the loop depth is only 16, then we can skip the general loop
        // and go straight to the final part of the code.
        "beqz %[run_depth], " GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "f\n"

        GEMMLOWP_LABEL_LOOP ":\n"
        // Overview of register layout:
        //
        // A 4x16 block of Lhs is stored in 8 bit in w16-w19.
        // A 4x16 block of Rhs is stored in 8 bit in w20-w23.
        //
        // A 4x4 block of accumulators is stored in w0-w15 (as 4x32 bit
        // components which need to be horizontally added at the end).
        //
        // Dot products of Lhs and Rhs are 16-bit values, which can't
        // immediately be accumulated in 32-bit accumulators by that
        // same instruction that calculates them.
        // For example, "dotp_s.h $w25, $w16, $w20" produces 8 16-bit
        // sums in w25 (note, the 16 sums have already been reduced to 8
        // by the horizontal addition of the dotp instruction).
        // They are then sign-extended to 32 bits, horizontally added
        // (again) to form 4 32-bit sums and then they are finally added
        // to the 32-bit accumulators, all by "dpadd_s.w $w0, $w25, $w31".
        //
        //                    +-----+-----+-----+-----+
        //               Rhs  | w20 | w21 | w22 | w23 |
        //                    +-----+-----+-----+-----+
        //
        //                    |     |     |     |     |
        //
        //       Lhs          |     |     |     |     |
        //
        //      +---+ - - - - +-----+-----+-----+-----+
        //      |w16|         | w0  | w4  | w8  | w12 |
        //      |w17|         | w1  | w5  | w9  | w13 |
        //      |w18|         | w2  | w6  | w10 | w14 |
        //      |w19|         | w3  | w7  | w11 | w15 |
        //      +---+ - - - - +-----+-----+-----+-----+
        //
        //                           Accumulators

        // Calculate the results for 16 depths and load
        // lhs[] and rhs[] for the next iteration.
        GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 64\n"
        GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 64\n"
        GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n"

        // Dot product: multiply-add pairs of adjacent int8 elements.
        // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
        "dotp_s.h $w25, $w16, $w20\n"
        "dotp_s.h $w26, $w17, $w20\n"
        "dotp_s.h $w27, $w16, $w21\n"
        "dotp_s.h $w28, $w17, $w21\n"
        "dotp_s.h $w29, $w18, $w20\n"
        // Horizontal add of pairs of adjacent int16 sums into internal int32
        // accumulators.
        "dpadd_s.w $w0, $w25, $w31\n"
        "dpadd_s.w $w1, $w26, $w31\n"
        "dpadd_s.w $w4, $w27, $w31\n"
        "dpadd_s.w $w5, $w28, $w31\n"
        "dpadd_s.w $w2, $w29, $w31\n"

        // Dot product: multiply-add pairs of adjacent int8 elements.
        // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
        // Loads for the next iteration are interleaved to hide latency.
        "dotp_s.h $w24, $w16, $w22\n"
        "dotp_s.h $w25, $w19, $w20\n"
        "dotp_s.h $w26, $w16, $w23\n"
        "dotp_s.h $w27, $w17, $w22\n"
        "ld.b $w20, 0(%[rhs_ptr])\n"
        "dotp_s.h $w28, $w17, $w23\n"
        "ld.b $w16, 0(%[lhs_ptr])\n"
        "dotp_s.h $w29, $w18, $w21\n"
        "ld.b $w17, 16(%[lhs_ptr])\n"
        // Horizontal add of pairs of adjacent int16 sums into internal int32
        // accumulators.
        "dpadd_s.w $w8, $w24, $w31\n"
        "dpadd_s.w $w3, $w25, $w31\n"
        "dpadd_s.w $w12, $w26, $w31\n"
        "dpadd_s.w $w9, $w27, $w31\n"
        "dpadd_s.w $w13, $w28, $w31\n"
        "dpadd_s.w $w6, $w29, $w31\n"

        // Dot product: multiply-add pairs of adjacent int8 elements.
        // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
        "dotp_s.h $w25, $w19, $w21\n"
        "dotp_s.h $w26, $w18, $w22\n"
        "dotp_s.h $w27, $w18, $w23\n"
        "ld.b $w21, 16(%[rhs_ptr])\n"
        "dotp_s.h $w28, $w19, $w22\n"
        "ld.b $w18, 32(%[lhs_ptr])\n"
        "dotp_s.h $w29, $w19, $w23\n"
        "ld.b $w22, 32(%[rhs_ptr])\n"
        // Horizontal add of pairs of adjacent int16 sums into internal int32
        // accumulators.
        "dpadd_s.w $w7, $w25, $w31\n"
        "ld.b $w19, 48(%[lhs_ptr])\n"
        "dpadd_s.w $w10, $w26, $w31\n"
        "ld.b $w23, 48(%[rhs_ptr])\n"
        "dpadd_s.w $w14, $w27, $w31\n"
        "dpadd_s.w $w11, $w28, $w31\n"
        "dpadd_s.w $w15, $w29, $w31\n"

        "bnez %[run_depth], " GEMMLOWP_LABEL_LOOP "b\n"

        GEMMLOWP_LABEL_AFTER_LOOP_LAST16 ":\n"
        // Calculate the results for the last 16 depths
        // (same as the loop body, but with no next-iteration loads).

        // Dot product: multiply-add pairs of adjacent int8 elements.
        // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
        "dotp_s.h $w25, $w16, $w20\n"
        "dotp_s.h $w26, $w17, $w20\n"
        "dotp_s.h $w27, $w16, $w21\n"
        "dotp_s.h $w28, $w17, $w21\n"
        "dotp_s.h $w29, $w18, $w20\n"
        // Horizontal add of pairs of adjacent int16 sums into internal int32
        // accumulators.
        "dpadd_s.w $w0, $w25, $w31\n"
        "dpadd_s.w $w1, $w26, $w31\n"
        "dpadd_s.w $w4, $w27, $w31\n"
        "dpadd_s.w $w5, $w28, $w31\n"
        "dpadd_s.w $w2, $w29, $w31\n"

        // Dot product: multiply-add pairs of adjacent int8 elements.
        // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
        "dotp_s.h $w24, $w16, $w22\n"
        "dotp_s.h $w25, $w19, $w20\n"
        "dotp_s.h $w26, $w16, $w23\n"
        "dotp_s.h $w27, $w17, $w22\n"
        "dotp_s.h $w28, $w17, $w23\n"
        "dotp_s.h $w29, $w18, $w21\n"
        // Horizontal add of pairs of adjacent int16 sums into internal int32
        // accumulators.
        "dpadd_s.w $w8, $w24, $w31\n"
        "dpadd_s.w $w3, $w25, $w31\n"
        "dpadd_s.w $w12, $w26, $w31\n"
        "dpadd_s.w $w9, $w27, $w31\n"
        "dpadd_s.w $w13, $w28, $w31\n"
        "dpadd_s.w $w6, $w29, $w31\n"

        // Dot product: multiply-add pairs of adjacent int8 elements.
        // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
        "dotp_s.h $w25, $w19, $w21\n"
        "dotp_s.h $w26, $w18, $w22\n"
        "dotp_s.h $w27, $w18, $w23\n"
        "dotp_s.h $w28, $w19, $w22\n"
        "dotp_s.h $w29, $w19, $w23\n"
        // Horizontal add of pairs of adjacent int16 sums into internal int32
        // accumulators.
        "dpadd_s.w $w7, $w25, $w31\n"
        "dpadd_s.w $w10, $w26, $w31\n"
        "dpadd_s.w $w14, $w27, $w31\n"
        "dpadd_s.w $w11, $w28, $w31\n"
        "dpadd_s.w $w15, $w29, $w31\n"

        // Horizontal-add internal accumulators: reduce each register's
        // 4 int32 lanes down to a single per-cell sum, pairing registers
        // with pckev along the way.
        "hadd_s.d $w0, $w0, $w0\n"
        "hadd_s.d $w1, $w1, $w1\n"
        "hadd_s.d $w2, $w2, $w2\n"
        "hadd_s.d $w3, $w3, $w3\n"
        "hadd_s.d $w4, $w4, $w4\n"
        "hadd_s.d $w5, $w5, $w5\n"
        "hadd_s.d $w6, $w6, $w6\n"
        "hadd_s.d $w7, $w7, $w7\n"
        "hadd_s.d $w8, $w8, $w8\n"
        "hadd_s.d $w9, $w9, $w9\n"
        "hadd_s.d $w10, $w10, $w10\n"
        "hadd_s.d $w11, $w11, $w11\n"
        "hadd_s.d $w12, $w12, $w12\n"
        "hadd_s.d $w13, $w13, $w13\n"
        "hadd_s.d $w14, $w14, $w14\n"
        "hadd_s.d $w15, $w15, $w15\n"
        "pckev.w $w0, $w1, $w0\n"
        "pckev.w $w2, $w3, $w2\n"
        "pckev.w $w4, $w5, $w4\n"
        "pckev.w $w6, $w7, $w6\n"
        "pckev.w $w8, $w9, $w8\n"
        "pckev.w $w10, $w11, $w10\n"
        "pckev.w $w12, $w13, $w12\n"
        "pckev.w $w14, $w15, $w14\n"
        "hadd_s.d $w0, $w0, $w0\n"
        "hadd_s.d $w2, $w2, $w2\n"
        "hadd_s.d $w4, $w4, $w4\n"
        "hadd_s.d $w6, $w6, $w6\n"
        "hadd_s.d $w8, $w8, $w8\n"
        "hadd_s.d $w10, $w10, $w10\n"
        "hadd_s.d $w12, $w12, $w12\n"
        "hadd_s.d $w14, $w14, $w14\n"
        // 4 more pckev instructions follow in both paths below.

        // Check if start_depth==0 to decide whether we will load
        // existing accumulators from memory.
        "bnez %[start_depth], " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "f\n"

        "pckev.w $w0, $w2, $w0\n"
        "pckev.w $w1, $w6, $w4\n"
        "pckev.w $w2, $w10, $w8\n"
        "pckev.w $w3, $w14, $w12\n"

        "b " GEMMLOWP_LABEL_STORE "f\n"

        GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES ":\n"
        // Load accumulators from memory (interleaved with the final pckev's).
        "ld.w $w16, 0(%[dst_ptr0])\n"
        "pckev.w $w0, $w2, $w0\n"
        "ld.w $w17, 0(%[dst_ptr1])\n"
        "pckev.w $w1, $w6, $w4\n"
        "ld.w $w18, 0(%[dst_ptr2])\n"
        "pckev.w $w2, $w10, $w8\n"
        "ld.w $w19, 0(%[dst_ptr3])\n"
        "pckev.w $w3, $w14, $w12\n"

        // Add them to internal accumulators.
        "addv.w $w0, $w0, $w16\n"
        "addv.w $w1, $w1, $w17\n"
        "addv.w $w2, $w2, $w18\n"
        "addv.w $w3, $w3, $w19\n"

        GEMMLOWP_LABEL_STORE ":\n"
        // Store accumulators, one destination column per register.
        "st.w $w0, 0(%[dst_ptr0])\n"
        "st.w $w1, 0(%[dst_ptr1])\n"
        "st.w $w2, 0(%[dst_ptr2])\n"
        "st.w $w3, 0(%[dst_ptr3])\n"
        :  // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [run_depth] "+r"(run_depth)
        :  // inputs
        [dst_ptr0] "r"(dst_ptr), [dst_ptr1] "r"(dst_ptr + dst_col_stride),
        [dst_ptr2] "r"(dst_ptr + dst_col_stride * 2),
        [dst_ptr3] "r"(dst_ptr + dst_col_stride * 3),
        [start_depth] "r"(start_depth)
        :  // clobbers
        "memory", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8",
        "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17",
        "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", "$f24", "$f25", "$f26",
        "$f27", "$f28", "$f29", "$f30", "$f31");
#undef GEMMLOWP_LABEL_LOOP
#undef GEMMLOWP_LABEL_AFTER_LOOP_LAST16
#undef GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
#undef GEMMLOWP_LABEL_STORE
  }
};
5592 #endif // __mips
5593
5594 // BEGIN code copied from gemmlowp/internal/kernel_reference.h
5595
5596 // This kernel is templatized in an arbitrary Format template parameter,
5597 // allowing it to have any arbitrary format.
5598 template <typename tOperandType, typename tAccumulatorType, typename tFormat>
5599 struct ReferenceKernel {
5600 typedef tOperandType OperandType;
5601 typedef tAccumulatorType AccumulatorType;
5602 typedef tFormat Format;
5603
RunReferenceKernel5604 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
5605 AccumulatorType* accum_ptr, int depth) {
5606 const int depth_cells = static_cast<int>(depth / Format::kDepth);
5607
5608 // The outer loop is over the depth dimension.
5609 for (int dc = 0; dc < depth_cells; dc++) {
5610 // The next two loops are over cells of the Lhs (stacked vertically),
5611 // and over cells of the Rhs (stacked horizontally).
5612 for (int rc = 0; rc < Format::Lhs::kCells; rc++) {
5613 const OperandType* lhs_cell_ptr =
5614 lhs_ptr + (dc * Format::Lhs::kCells + rc) *
5615 Format::Lhs::Cell::kWidth * Format::kDepth;
5616 for (int cc = 0; cc < Format::Rhs::kCells; cc++) {
5617 const OperandType* rhs_cell_ptr =
5618 rhs_ptr + (dc * Format::Rhs::kCells + cc) *
5619 Format::Rhs::Cell::kWidth * Format::kDepth;
5620
5621 // Now we are inside one cell of the Lhs and inside one cell
5622 // of the Rhs, so the remaining inner loops are just
5623 // traditional three loops of matrix multiplication.
5624 for (int di = 0; di < Format::kDepth; di++) {
5625 for (int ri = 0; ri < Format::Lhs::Cell::kWidth; ri++) {
5626 for (int ci = 0; ci < Format::Rhs::Cell::kWidth; ci++) {
5627 const OperandType* lhs_coeff_ptr =
5628 lhs_cell_ptr +
5629 OffsetIntoCell<typename Format::Lhs::Cell>(ri, di);
5630 const OperandType* rhs_coeff_ptr =
5631 rhs_cell_ptr +
5632 OffsetIntoCell<typename Format::Rhs::Cell>(ci, di);
5633 AccumulatorType* accumulator_coeff_ptr =
5634 accum_ptr + (ri + rc * Format::Lhs::Cell::kWidth) +
5635 (ci + cc * Format::Rhs::Cell::kWidth) * Format::kRows;
5636 *accumulator_coeff_ptr += AccumulatorType(*lhs_coeff_ptr) *
5637 AccumulatorType(*rhs_coeff_ptr);
5638 }
5639 }
5640 }
5641 }
5642 }
5643 }
5644 }
5645 };
5646
5647 // END code copied from gemmlowp/internal/kernel_reference.h
5648
5649 template <typename DataType>
5650 class CacheLineAlignedBuffer {
5651 public:
CacheLineAlignedBuffer(std::size_t size)5652 CacheLineAlignedBuffer(std::size_t size) : size_(size) {
5653 data_ = nullptr;
5654 // Adds a few bytes of padding here, because the 64-bit 'A57' kernel
5655 // reads one iteration past the end the buffer, causing a crash on iOS.
5656 int res = posix_memalign(reinterpret_cast<void**>(&data_), kCacheLineSize,
5657 size_ * sizeof(DataType) + 16);
5658 (void)res;
5659 }
5660
~CacheLineAlignedBuffer()5661 ~CacheLineAlignedBuffer() { free(data_); }
5662
data() const5663 const DataType* data() const { return data_; }
data()5664 DataType* data() { return data_; }
5665
size() const5666 std::size_t size() const { return size_; }
5667
5668 private:
5669 const std::size_t size_;
5670 DataType* data_;
5671 };
5672
5673 template <typename DataType>
FillRandom(CacheLineAlignedBuffer<DataType> * buffer,DataType min,DataType max)5674 void FillRandom(CacheLineAlignedBuffer<DataType>* buffer, DataType min,
5675 DataType max) {
5676 static std::mt19937 generator(0);
5677 std::uniform_real_distribution<float> dist(min, max);
5678 for (std::size_t i = 0; i < buffer->size(); i++) {
5679 buffer->data()[i] = DataType(dist(generator));
5680 }
5681 }
5682
5683 template <typename DataType>
FillZero(CacheLineAlignedBuffer<DataType> * buffer)5684 void FillZero(CacheLineAlignedBuffer<DataType>* buffer) {
5685 for (std::size_t i = 0; i < buffer->size(); i++) {
5686 buffer->data()[i] = DataType(0);
5687 }
5688 }
5689
5690 template <typename DataType>
Copy(CacheLineAlignedBuffer<DataType> * dst,const CacheLineAlignedBuffer<DataType> & src)5691 void Copy(CacheLineAlignedBuffer<DataType>* dst,
5692 const CacheLineAlignedBuffer<DataType>& src) {
5693 assert(dst->size() == src.size());
5694 memcpy(dst->data(), src.data(), src.size() * sizeof(DataType));
5695 }
5696
// Dumps a matrix to stderr for debugging, one row per line, each value
// cast to double. Strides are in elements, not bytes.
template <typename DataType>
void PrintMatrix(int rows, int cols, int rowstride, int colstride,
                 const DataType* data) {
  for (int row = 0; row < rows; row++) {
    for (int col = 0; col < cols; col++) {
      const DataType value = data[row * rowstride + col * colstride];
      std::cerr << double(value) << " ";
    }
    std::cerr << std::endl;
  }
  std::cerr << std::endl;
}
5708
// Exact equality for integer accumulator types.
template <typename DataType>
bool approx_equals(DataType a, DataType b) {
  return a == b;
}

// Coarse relative comparison for float accumulators.
template <>
bool approx_equals(float a, float b) {
  if (!a && !b) {
    return true;
  }
  // 1e-1 is very coarse accuracy, we should switch to an overall L2 metric
  // and tighten the tolerance on that metric.
  const float tolerance = 1e-1f * std::min(std::abs(a), std::abs(b));
  return std::abs(a - b) < tolerance;
}
5723
5724 template <typename Kernel>
test_kernel(int depth,const char * kernel_name)5725 void test_kernel(int depth, const char* kernel_name) {
5726 typedef typename Kernel::OperandType OperandType;
5727 typedef typename Kernel::AccumulatorType AccumulatorType;
5728 typedef typename Kernel::Format Format;
5729 static const int kLhsWidth = Format::Lhs::kWidth;
5730 static const int kRhsWidth = Format::Rhs::kWidth;
5731
5732 typedef ReferenceKernel<OperandType, AccumulatorType, Format> ReferenceKernel;
5733
5734 CacheLineAlignedBuffer<OperandType> lhs(kLhsWidth * depth);
5735 CacheLineAlignedBuffer<OperandType> rhs(kRhsWidth * depth);
5736 CacheLineAlignedBuffer<AccumulatorType> accum_initial(kLhsWidth * kRhsWidth);
5737 CacheLineAlignedBuffer<AccumulatorType> accum(kLhsWidth * kRhsWidth);
5738 CacheLineAlignedBuffer<AccumulatorType> accum_reference(kLhsWidth *
5739 kRhsWidth);
5740
5741 FillRandom(&lhs, KernelOperandRanges<Kernel>::LhsMin(),
5742 KernelOperandRanges<Kernel>::LhsMax());
5743 FillRandom(&rhs, KernelOperandRanges<Kernel>::RhsMin(),
5744 KernelOperandRanges<Kernel>::RhsMax());
5745 FillRandom(&accum_initial,
5746 std::is_signed<AccumulatorType>::value
5747 ? AccumulatorType(-100)
5748 : AccumulatorType(0),
5749 AccumulatorType(100));
5750
5751 Copy(&accum, accum_initial);
5752 Copy(&accum_reference, accum_initial);
5753
5754 ReferenceKernel::Run(lhs.data(), rhs.data(), accum_reference.data(), depth);
5755 Kernel::Run(lhs.data(), rhs.data(), accum.data(), depth);
5756
5757 for (int l = 0; l < kLhsWidth; l++) {
5758 for (int r = 0; r < kRhsWidth; r++) {
5759 const int index = l + kLhsWidth * r;
5760 if (!approx_equals(accum.data()[index], accum_reference.data()[index])) {
5761 std::cerr << "Arithmetic error in kernel:" << std::endl
5762 << " " << kernel_name << std::endl
5763 << "Wrong accumulator for depth=" << depth << ", "
5764 << "at l = " << l << ", r = " << r << std::endl;
5765 std::cerr << "reference value: " << accum_reference.data()[index]
5766 << std::endl;
5767 std::cerr << "actual value: " << accum.data()[index] << std::endl;
5768 if (depth <= 16) {
5769 std::cerr << "LHS matrix:" << std::endl;
5770 PrintMatrix(kLhsWidth, depth, 1, kLhsWidth, lhs.data());
5771 std::cerr << "RHS matrix:" << std::endl;
5772 PrintMatrix(depth, kRhsWidth, kRhsWidth, 1, rhs.data());
5773 std::cerr << "Initial Accumulator matrix:" << std::endl;
5774 PrintMatrix(kLhsWidth, kRhsWidth, 1, kLhsWidth, accum_initial.data());
5775 std::cerr << "Reference Accumulator matrix:" << std::endl;
5776 PrintMatrix(kLhsWidth, kRhsWidth, 1, kLhsWidth,
5777 accum_reference.data());
5778 std::cerr << "Actual Accumulator matrix:" << std::endl;
5779 PrintMatrix(kLhsWidth, kRhsWidth, 1, kLhsWidth, accum.data());
5780 }
5781 abort();
5782 }
5783 }
5784 }
5785 }
5786
5787 template <typename Kernel>
ops(int depth)5788 int ops(int depth) {
5789 // 2x the number of multiply-accumulate scalar ops.
5790 return 2 * Kernel::Format::Lhs::kWidth * Kernel::Format::Rhs::kWidth * depth;
5791 }
5792
// Rounds i down to the nearest multiple of Modulus (for non-negative i).
template <unsigned Modulus, typename Integer>
Integer RoundDown(Integer i) {
  const Integer remainder = i % Modulus;
  return i - remainder;
}
5797
CacheSizeInKB()5798 int CacheSizeInKB() {
5799 static const char* cache_size_k_env = getenv("CACHE_SIZE_KB");
5800 static const int cache_size_k =
5801 cache_size_k_env ? atoi(cache_size_k_env) : kDefaultCacheSizeK;
5802 return cache_size_k;
5803 }
5804
5805 template <typename Kernel>
BenchmarkDepthToFitInCache()5806 int BenchmarkDepthToFitInCache() {
5807 const int cache_size_bytes = 1024 * CacheSizeInKB();
5808
5809 // Subtract the typical size of a few cache lines, so
5810 // we don't need to worry too hard about e.g. some stack data.
5811 const int conservative_cache_size_bytes =
5812 cache_size_bytes - 2 * kCacheLineSize;
5813
5814 // We will subtract the memory occupied by accumulators.
5815 typedef typename Kernel::AccumulatorType AccumulatorType;
5816 const int kAccumulatorBytes = sizeof(AccumulatorType) *
5817 Kernel::Format::Lhs::kWidth *
5818 Kernel::Format::Rhs::kWidth;
5819
5820 // Compute the depth.
5821 typedef typename Kernel::OperandType OperandType;
5822 const int kBytesPerUnitOfDepth =
5823 sizeof(OperandType) *
5824 (Kernel::Format::Lhs::kWidth + Kernel::Format::Rhs::kWidth);
5825 const int unrounded_depth =
5826 (conservative_cache_size_bytes - kAccumulatorBytes) /
5827 kBytesPerUnitOfDepth;
5828
5829 // Cap depth, to avoid unfairly favoring narrower kernels
5830 const int kMaxDepth = 1024;
5831 const int clamped_unrounded_depth = std::min(kMaxDepth, unrounded_depth);
5832
5833 // Round depth down to a multiple of cache line size, which helps because
5834 // our kernels may crash if depth is not a multiple of the number of
5835 // depth level that they want to
5836 // handle at each loop iteration, and we don't want to require kernels
5837 // to be more complex. Currently all kernels process 1, 2 or 8 levels of
5838 // depth at a time. The main reason why that might increase in the future
5839 // is if registers get wider, but I don't suppose that register could
5840 // ever get wider than cache lines.
5841 return RoundDown<kCacheLineSize>(clamped_unrounded_depth);
5842 }
5843
// Wall-clock time in seconds, with nanosecond resolution.
// CLOCK_REALTIME suffices here: we only measure durations of about 1s.
double current_time_in_seconds() {
  timespec now;
  clock_gettime(CLOCK_REALTIME, &now);
  return now.tv_sec + 1e-9 * now.tv_nsec;
}
5849
5850 template <typename Kernel>
benchmark(int depth)5851 double benchmark(int depth) {
5852 // Minimum duration for this benchmark to run. If the workload finishes
5853 // sooner, we retry with double the number of iterations.
5854 static const double min_benchmark_time_in_seconds = 1.0;
5855
5856 typedef typename Kernel::OperandType OperandType;
5857 typedef typename Kernel::AccumulatorType AccumulatorType;
5858
5859 CacheLineAlignedBuffer<OperandType> lhs(Kernel::Format::Lhs::kWidth * depth);
5860 CacheLineAlignedBuffer<OperandType> rhs(Kernel::Format::Rhs::kWidth * depth);
5861 CacheLineAlignedBuffer<AccumulatorType> accum(Kernel::Format::Lhs::kWidth *
5862 Kernel::Format::Rhs::kWidth);
5863
5864 for (std::uint64_t iters_at_a_time = 1;; iters_at_a_time *= 2) {
5865 const double t_start = current_time_in_seconds();
5866 for (std::uint64_t i = 0; i < iters_at_a_time; i++) {
5867 Kernel::Run(lhs.data(), rhs.data(), accum.data(), depth);
5868 }
5869 const double t_end = current_time_in_seconds();
5870 const double elapsed = t_end - t_start;
5871 if (elapsed > min_benchmark_time_in_seconds) {
5872 return iters_at_a_time * ops<Kernel>(depth) / elapsed;
5873 }
5874 }
5875 }
5876
5877 template <typename Kernel>
benchmark_and_print_results(const char * kernel_name)5878 void benchmark_and_print_results(const char* kernel_name) {
5879 if (getenv("BENCHMARK_KERNEL")) {
5880 if (strcmp(getenv("BENCHMARK_KERNEL"), kernel_name)) {
5881 return;
5882 }
5883 }
5884 const int kKernelDepth = Kernel::Format::kDepth;
5885 for (int depth = kKernelDepth; depth <= 1024; depth += kKernelDepth) {
5886 test_kernel<Kernel>(depth, kernel_name);
5887 }
5888
5889 if (getenv("BENCHMARK_ALL_DEPTHS")) {
5890 for (int depth = kKernelDepth;
5891 depth <= BenchmarkDepthToFitInCache<Kernel>(); depth *= 2) {
5892 std::cout << kernel_name << "," << depth << ","
5893 << benchmark<Kernel>(depth) * 1e-9f << std::endl;
5894 }
5895 } else {
5896 const int depth = BenchmarkDepthToFitInCache<Kernel>();
5897 std::cout << kernel_name << "," << benchmark<Kernel>(depth) * 1e-9f
5898 << std::endl;
5899 }
5900 }
5901
// Stringizes the kernel type name and runs its correctness tests and
// benchmark via benchmark_and_print_results.
#define BENCHMARK(Kernel)                         \
  do {                                            \
    benchmark_and_print_results<Kernel>(#Kernel); \
  } while (false)
5906
// Entry point: prints a CSV header, then runs every kernel suite compiled
// in for the current architecture (each BENCHMARK validates, then times).
int main() {
  if (getenv("BENCHMARK_ALL_DEPTHS")) {
    std::cout << "kernel,depth,Gop/s" << std::endl;
  } else {
    std::cout << "kernel,Gop/s" << std::endl;
  }

  // 32-bit ARM (NEON) kernels.
#ifdef __arm__
  BENCHMARK(NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits);
  BENCHMARK(NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics);
  BENCHMARK(NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators);
  BENCHMARK(NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics);
  BENCHMARK(NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand);
  BENCHMARK(NEON_32bit_GEMM_Int32_WithScalar);
  BENCHMARK(NEON_32bit_GEMM_Float32_MLA_WithVectorDuplicatingScalar);
#ifdef __ARM_FEATURE_FMA
  BENCHMARK(NEON_32bit_GEMM_Float32_FMA_WithVectorDuplicatingScalar);
#endif
  BENCHMARK(NEON_32bit_GEMM_Float32_MLA_WithScalar);
  BENCHMARK(NEON_32bit_GEMM_Float32_WithScalar_intrinsics);
  BENCHMARK(NEON_32bit_GEMM_Float32_WithScalar_A53);
  BENCHMARK(NEON_32bit_GEMM_Float32_WithScalar_A53_depth2);
  BENCHMARK(NEON_32bit_GEMM_Float32_MLA_Rotating);
#ifdef __ARM_FEATURE_FMA
  BENCHMARK(NEON_32bit_GEMM_Float32_FMA_Rotating);
#endif
#endif

  // 64-bit ARM (AArch64) kernels.
#ifdef __aarch64__
  BENCHMARK(NEON_64bit_GEMM_Int425Operands);
  BENCHMARK(NEON_64bit_GEMM_Int425Operands_intrinsics);
  BENCHMARK(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits);
  BENCHMARK(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits_intrinsics);
  BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits);
  BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics);
  BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators);
  BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics);
  BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand_A57);
#ifdef __ARM_FEATURE_DOTPROD
  BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct);
  BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1);
  BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_narrow);
#endif
  BENCHMARK(NEON_64bit_GEMM_Int32_WithScalar);
  BENCHMARK(NEON_64bit_GEMM_Float32_WithVectorDuplicatingScalar);
  BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar);
  BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_intrinsics);
  BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A57);
#ifndef __APPLE__
  BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A53);
#endif
  BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A55r1);
#endif

  // MIPS (MSA) kernels.
#ifdef __mips
  BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators1);
  BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators2);
  BENCHMARK(MSA_GEMM_Int8Operands_AccumTwoWithin16Bits);
#endif

  return 0;
}
5969