/*
 * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_fully_connected_q15_opt.c
 * Description:  Q15 opt fully-connected layer function
 *
 * $Date:        17. January 2018
 * $Revision:    V.1.0.0
 *
 * Target Processor:  Cortex-M cores
 *
 * -------------------------------------------------------------------- */

#include "arm_math.h"
#include "arm_nnfunctions.h"

/**
 * @ingroup groupNN
 */

/**
 * @addtogroup FC
 * @{
 */

/**
 * @brief Q15 opt fully-connected layer function
 * @param[in]       pV          pointer to input vector
 * @param[in]       pM          pointer to matrix weights
 * @param[in]       dim_vec     length of the vector
 * @param[in]       num_of_rows number of rows in weight matrix
 * @param[in]       bias_shift  amount of left-shift for bias
 * @param[in]       out_shift   amount of right-shift for output
 * @param[in]       bias        pointer to bias
 * @param[in,out]   pOut        pointer to output vector
 * @param[in,out]   vec_buffer  pointer to buffer space for input
 * @return     The function returns <code>ARM_MATH_SUCCESS</code>
 *
 *
 * @details
 *
 * <b>Buffer size:</b>
 *
 * vec_buffer size: 0
 *
 * Here we use only one pointer to read 4 rows in the weight
 * matrix. So if the original matrix looks like this:
 *
 * | a11 | a12 | a13 |
 *
 * | a21 | a22 | a23 |
 *
 * | a31 | a32 | a33 |
 *
 * | a41 | a42 | a43 |
 *
 * | a51 | a52 | a53 |
 *
 * | a61 | a62 | a63 |
 *
 * We operate on multiples of 4 rows, so the first four rows become
 *
 * | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 |
 *
 * | a13 | a23 | a33 | a43 |
 *
 * The remaining rows are kept in their original order.
 *
 * So the stored weight matrix looks like this:
 *
 *
 * | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 |
 *
 * | a13 | a23 | a33 | a43 | a51 | a52 | a53 | a61 |
 *
 * | a62 | a63 |
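 *
 * A minimal reordering sketch for producing this layout is shown below.
 * It is illustrative only; the helper name reorder_q15_weights_opt and
 * its signature are assumptions, not part of CMSIS-NN:
 *
 * <pre>
 * void reorder_q15_weights_opt(const q15_t * src, q15_t * dst,
 *                              uint16_t dim_vec, uint16_t num_of_rows)
 * {
 *     uint32_t i, j, r, k = 0;
 *
 *     // interleave groups of 4 rows, two columns at a time
 *     for (i = 0; i + 4 <= num_of_rows; i += 4)
 *     {
 *         for (j = 0; j + 2 <= dim_vec; j += 2)
 *         {
 *             // a(i+r, j) and a(i+r, j+1) for r = 0..3
 *             for (r = 0; r < 4; r++)
 *             {
 *                 dst[k++] = src[(i + r) * dim_vec + j];
 *                 dst[k++] = src[(i + r) * dim_vec + j + 1];
 *             }
 *         }
 *         if (dim_vec & 0x1)
 *         {
 *             // odd left-over column: the four row values are stored back to back
 *             for (r = 0; r < 4; r++)
 *             {
 *                 dst[k++] = src[(i + r) * dim_vec + (dim_vec - 1)];
 *             }
 *         }
 *     }
 *
 *     // remaining (num_of_rows & 0x3) rows keep their original row-major order
 *     for (; i < num_of_rows; i++)
 *     {
 *         for (j = 0; j < dim_vec; j++)
 *         {
 *             dst[k++] = src[i * dim_vec + j];
 *         }
 *     }
 * }
 * </pre>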
 */

arm_status
arm_fully_connected_q15_opt(const q15_t * pV,
                            const q15_t * pM,
                            const uint16_t dim_vec,
                            const uint16_t num_of_rows,
                            const uint16_t bias_shift,
                            const uint16_t out_shift,
                            const q15_t * bias,
                            q15_t * pOut,
                            q15_t * vec_buffer)
{

#if defined (ARM_MATH_DSP)
    /* Run the following code for Cortex-M4 and Cortex-M7 */

    const q15_t *pB = pM;
    q15_t *pO = pOut;
    const q15_t *pBias = bias;
    const q15_t *pA = pV;

    uint16_t rowCnt = num_of_rows >> 2;

    while (rowCnt)
    {
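        /* start each accumulator from the left-shifted bias plus the NN_ROUND
         * offset used to round the final right shift by out_shift */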
        q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        uint16_t colCnt = dim_vec >> 1;

        pA = pV;

#ifdef USE_INTRINSIC

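        /* each iteration reads one pair of activations and one pair of weights
         * from each of the four interleaved rows, accumulating with the dual
         * 16-bit multiply-accumulate __SMLAD */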
        while (colCnt)
        {
            q31_t inM11, inM12, inM13, inM14;
            q31_t inV;

            inV = *__SIMD32(pA)++;
            inM11 = *__SIMD32(pB)++;
            sum = __SMLAD(inV, inM11, sum);
            inM12 = *__SIMD32(pB)++;
            sum2 = __SMLAD(inV, inM12, sum2);
            inM13 = *__SIMD32(pB)++;
            sum3 = __SMLAD(inV, inM13, sum3);
            inM14 = *__SIMD32(pB)++;
            sum4 = __SMLAD(inV, inM14, sum4);
            colCnt--;
        }

#else

        /*
         * registers needed:
         * loop counter: colCnt
         * accumulators: sum, sum2, sum3, sum4
         * pointers: pB, pA
         * weight data: inM11, inM12, inM13, inM14
         * activation data: inV
         */

        asm volatile ("COL_LOOP_%=:\n"
                      "ldr.w r4, [%[pA]], #4\n"
                      "ldr.w r0, [%[pB]], #16\n"
                      "smlad %[sum], r4, r0, %[sum]\n"
                      "ldr.w r1, [%[pB] , #-12]\n"
                      "smlad %[sum2], r4, r1, %[sum2]\n"
                      "ldr.w r2, [%[pB] , #-8]\n"
                      "smlad %[sum3], r4, r2, %[sum3]\n"
                      "ldr.w r3, [%[pB] , #-4]\n"
                      "smlad %[sum4], r4, r3, %[sum4]\n"
                      "subs %[colCnt], #1\n"
                      "bne COL_LOOP_%=\n":[sum] "+r"(sum),
                      [sum2] "+r"(sum2),[sum3] "+r"(sum3),
                      [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");

#endif                          /* USE_INTRINSIC */

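        /* left-over column when dim_vec is odd: the four row values are stored
         * back to back in the reordered weight matrix */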
        colCnt = dim_vec & 0x1;
        while (colCnt)
        {

            q15_t inV = *pA++;
            q15_t inM = *pB++;
            q15_t inM2 = *pB++;
            q15_t inM3 = *pB++;
            q15_t inM4 = *pB++;

            sum += inV * inM;
            sum2 += inV * inM2;
            sum3 += inV * inM3;
            sum4 += inV * inM4;
            colCnt--;
        }                       /* while over colCnt */
        *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));
        *pO++ = (q15_t) (__SSAT((sum2 >> out_shift), 16));
        *pO++ = (q15_t) (__SSAT((sum3 >> out_shift), 16));
        *pO++ = (q15_t) (__SSAT((sum4 >> out_shift), 16));

        /* adjust the pointers and counters */
        rowCnt--;
    }

    /* left-over part of the rows */
    rowCnt = num_of_rows & 0x3;

    while (rowCnt)
    {
        q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        uint16_t colCnt = dim_vec >> 2;

        pA = pV;

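        /* standard dot product on the remaining rows (stored row-major),
         * processing four q15 elements per iteration with SIMD loads */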
        while (colCnt)
        {
            q31_t inV1, inV2, inM1, inM2;

            inM1 = *__SIMD32(pB)++;
            inV1 = *__SIMD32(pA)++;
            sum = __SMLAD(inV1, inM1, sum);

            inM2 = *__SIMD32(pB)++;
            inV2 = *__SIMD32(pA)++;
            sum = __SMLAD(inV2, inM2, sum);

            colCnt--;
        }

        /* left-over of the vector */
        colCnt = dim_vec & 0x3;
        while (colCnt)
        {
            q15_t inV = *pA++;
            q15_t inM = *pB++;
            sum += inV * inM;
            colCnt--;
        }

        *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));

        rowCnt--;
    }

#else
    /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
    uint16_t rowCnt = num_of_rows >> 2;
    const q15_t *pB = pM;
    const q15_t *pA;
    q15_t *pO = pOut;
    const q15_t *pBias = bias;

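    /* walk the interleaved weights exactly as the DSP path does: four rows
     * at a time, two columns per iteration, using plain C multiplies */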
    while (rowCnt)
    {
        q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        uint16_t colCnt = dim_vec >> 1;

        pA = pV;
        while (colCnt)
        {
            q15_t inA1 = *pA++;
            q15_t inA2 = *pA++;

            q15_t inB1 = *pB++;
            q15_t inB2 = *pB++;
            sum += inA1 * inB1 + inA2 * inB2;

            inB1 = *pB++;
            inB2 = *pB++;
            sum2 += inA1 * inB1 + inA2 * inB2;

            inB1 = *pB++;
            inB2 = *pB++;
            sum3 += inA1 * inB1 + inA2 * inB2;

            inB1 = *pB++;
            inB2 = *pB++;
            sum4 += inA1 * inB1 + inA2 * inB2;

            colCnt--;
        }
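        /* left-over column when dim_vec is odd */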
        colCnt = dim_vec & 0x1;
        while (colCnt)
        {
            q15_t inA = *pA++;
            q15_t inB = *pB++;
            sum += inA * inB;
            inB = *pB++;
            sum2 += inA * inB;
            inB = *pB++;
            sum3 += inA * inB;
            inB = *pB++;
            sum4 += inA * inB;
            colCnt--;
        }
        *pO++ = (q15_t) __SSAT((sum >> out_shift), 16);
        *pO++ = (q15_t) __SSAT((sum2 >> out_shift), 16);
        *pO++ = (q15_t) __SSAT((sum3 >> out_shift), 16);
        *pO++ = (q15_t) __SSAT((sum4 >> out_shift), 16);

        rowCnt--;
    }
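
    /* left-over part of the rows */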
    rowCnt = num_of_rows & 0x3;

    while (rowCnt)
    {
        int ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        int j;

        pA = pV;
        for (j = 0; j < dim_vec; j++)
        {
            q15_t inA = *pA++;
            q15_t inB = *pB++;
            ip_out += inA * inB;
        }
        *pO++ = (q15_t) __SSAT((ip_out >> out_shift), 16);

        rowCnt--;
    }

#endif                          /* ARM_MATH_DSP */

    /* Return ARM_MATH_SUCCESS */
    return (ARM_MATH_SUCCESS);

}

/**
 * @} end of FC group
 */