/*
 * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_fully_connected_q15_opt.c
 * Description:  Q15 opt fully-connected layer function
 *
 * $Date:        17. January 2018
 * $Revision:    V.1.0.0
 *
 * Target Processor:  Cortex-M cores
 *
 * -------------------------------------------------------------------- */

#include "arm_math.h"
#include "arm_nnfunctions.h"

/**
 *  @ingroup groupNN
 */

/**
 * @addtogroup FC
 * @{
 */

  /**
   * @brief Q15 opt fully-connected layer function
   * @param[in]       pV          pointer to input vector
   * @param[in]       pM          pointer to matrix weights
   * @param[in]       dim_vec     length of the vector
   * @param[in]       num_of_rows number of rows in weight matrix
   * @param[in]       bias_shift  amount of left-shift for bias
   * @param[in]       out_shift   amount of right-shift for output
   * @param[in]       bias        pointer to bias
   * @param[in,out]   pOut        pointer to output vector
   * @param[in,out]   vec_buffer  pointer to buffer space for input
   * @return     The function returns <code>ARM_MATH_SUCCESS</code>
   *
   *
   * @details
   *
   * <b>Buffer size:</b>
   *
   * vec_buffer size: 0
   *
   *  Here we use only one pointer to read four rows of the weight
   *  matrix. So if the original matrix looks like this:
   *
   *  | a11 | a12 | a13 |
   *
   *  | a21 | a22 | a23 |
   *
   *  | a31 | a32 | a33 |
   *
   *  | a41 | a42 | a43 |
   *
   *  | a51 | a52 | a53 |
   *
   *  | a61 | a62 | a63 |
   *
   *  We operate on blocks of four rows, so the first four rows become
   *
   *  | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 |
   *
   *  | a13 | a23 | a33 | a43 |
   *
   *  The remaining rows are kept in their original order.
   *
   *  So the stored weight matrix looks like this:
   *
   *
   *  | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 |
   *
   *  | a13 | a23 | a33 | a43 | a51 | a52 | a53 | a61 |
   *
   *  | a62 | a63 |
   *
   *  An illustrative helper that performs this reordering is sketched
   *  after the function body.
   */

arm_status
arm_fully_connected_q15_opt(const q15_t * pV,
                            const q15_t * pM,
                            const uint16_t dim_vec,
                            const uint16_t num_of_rows,
                            const uint16_t bias_shift,
                            const uint16_t out_shift,
                            const q15_t * bias,
                            q15_t * pOut,
                            q15_t * vec_buffer)
{

#if defined (ARM_MATH_DSP)
    /* Run the following code for Cortex-M4 and Cortex-M7 */

    const q15_t *pB = pM;
    q15_t    *pO = pOut;
    const q15_t *pBias = bias;
    const q15_t *pA = pV;

    uint16_t  rowCnt = num_of_rows >> 2;

    while (rowCnt)
    {
        q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        uint16_t  colCnt = dim_vec >> 1;

        pA = pV;

#ifdef USE_INTRINSIC

        while (colCnt)
        {
            q31_t     inM11, inM12, inM13, inM14;
            q31_t     inV;

            inV = *__SIMD32(pA)++;
            inM11 = *__SIMD32(pB)++;
            sum = __SMLAD(inV, inM11, sum);
            inM12 = *__SIMD32(pB)++;
            sum2 = __SMLAD(inV, inM12, sum2);
            inM13 = *__SIMD32(pB)++;
            sum3 = __SMLAD(inV, inM13, sum3);
            inM14 = *__SIMD32(pB)++;
            sum4 = __SMLAD(inV, inM14, sum4);
            colCnt--;
        }

#else

        /*
         * registers needed:
         * loop counter: colCnt
         * accumulators: sum, sum2, sum3, sum4
         * pointers: pB, pA
         * weight data: inM11, inM12, inM13, inM14
         * activation data: inV
         */

        asm volatile ("COL_LOOP_%=:\n"
                      "ldr.w r4, [%[pA]], #4\n"
                      "ldr.w r0, [%[pB]], #16\n"
                      "smlad %[sum], r4, r0, %[sum]\n"
                      "ldr.w r1, [%[pB] , #-12]\n"
                      "smlad %[sum2], r4, r1, %[sum2]\n"
                      "ldr.w r2, [%[pB] , #-8]\n"
                      "smlad %[sum3], r4, r2, %[sum3]\n"
                      "ldr.w r3, [%[pB] , #-4]\n"
                      "smlad %[sum4], r4, r3, %[sum4]\n"
                      "subs %[colCnt], #1\n"
                      "bne COL_LOOP_%=\n":[sum] "+r"(sum),
                      [sum2] "+r"(sum2),[sum3] "+r"(sum3),
                      [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");

#endif                          /* USE_INTRINSIC */

        colCnt = dim_vec & 0x1;
        while (colCnt)
        {

            q15_t     inV = *pA++;
            q15_t     inM = *pB++;
            q15_t     inM2 = *pB++;
            q15_t     inM3 = *pB++;
            q15_t     inM4 = *pB++;

            sum += inV * inM;
            sum2 += inV * inM2;
            sum3 += inV * inM3;
            sum4 += inV * inM4;
            colCnt--;
        }                       /* while over colCnt */
        *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));
        *pO++ = (q15_t) (__SSAT((sum2 >> out_shift), 16));
        *pO++ = (q15_t) (__SSAT((sum3 >> out_shift), 16));
        *pO++ = (q15_t) (__SSAT((sum4 >> out_shift), 16));

        /* adjust the pointers and counters */
        rowCnt--;
    }

    /* left-over part of the rows */
    rowCnt = num_of_rows & 0x3;

    while (rowCnt)
    {
        q31_t     sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        uint16_t  colCnt = dim_vec >> 2;

        pA = pV;

        while (colCnt)
        {
            q31_t     inV1, inV2, inM1, inM2;

            inM1 = *__SIMD32(pB)++;
            inV1 = *__SIMD32(pA)++;
            sum = __SMLAD(inV1, inM1, sum);

            inM2 = *__SIMD32(pB)++;
            inV2 = *__SIMD32(pA)++;
            sum = __SMLAD(inV2, inM2, sum);

            colCnt--;
        }

        /* left-over of the vector */
        colCnt = dim_vec & 0x3;
        while (colCnt)
        {
            q15_t     inV = *pA++;
            q15_t     inM = *pB++;
            sum += inV * inM;
            colCnt--;
        }

        *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));

        rowCnt--;
    }

#else
    /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
    uint16_t  rowCnt = num_of_rows >> 2;
    const q15_t *pB = pM;
    const q15_t *pA;
    q15_t    *pO = pOut;
    const q15_t *pBias = bias;

    while (rowCnt)
    {
        q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        uint16_t  colCnt = dim_vec >> 1;

        pA = pV;
        while (colCnt)
        {
            q15_t     inA1 = *pA++;
            q15_t     inA2 = *pA++;

            q15_t     inB1 = *pB++;
            q15_t     inB2 = *pB++;
            sum += inA1 * inB1 + inA2 * inB2;

            inB1 = *pB++;
            inB2 = *pB++;
            sum2 += inA1 * inB1 + inA2 * inB2;

            inB1 = *pB++;
            inB2 = *pB++;
            sum3 += inA1 * inB1 + inA2 * inB2;

            inB1 = *pB++;
            inB2 = *pB++;
            sum4 += inA1 * inB1 + inA2 * inB2;

            colCnt--;
        }
        colCnt = dim_vec & 0x1;
        while (colCnt)
        {
            q15_t     inA = *pA++;
            q15_t     inB = *pB++;
            sum += inA * inB;
            inB = *pB++;
            sum2 += inA * inB;
            inB = *pB++;
            sum3 += inA * inB;
            inB = *pB++;
            sum4 += inA * inB;
            colCnt--;
        }
        *pO++ = (q15_t) __SSAT((sum >> out_shift), 16);
        *pO++ = (q15_t) __SSAT((sum2 >> out_shift), 16);
        *pO++ = (q15_t) __SSAT((sum3 >> out_shift), 16);
        *pO++ = (q15_t) __SSAT((sum4 >> out_shift), 16);

        rowCnt--;
    }
    rowCnt = num_of_rows & 0x3;

    while (rowCnt)
    {
        int       ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        int       j;

        pA = pV;
        for (j = 0; j < dim_vec; j++)
        {
            q15_t     inA = *pA++;
            q15_t     inB = *pB++;
            ip_out += inA * inB;
        }
        *pO++ = (q15_t) __SSAT((ip_out >> out_shift), 16);

        rowCnt--;
    }

#endif                          /* ARM_MATH_DSP */

    /* Return ARM_MATH_SUCCESS */
    return (ARM_MATH_SUCCESS);

}
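
/*
 * Illustrative sketch only, not part of the CMSIS-NN API: one possible way to
 * reorder a plain row-major q15 weight matrix into the interleaved layout
 * described above, which arm_fully_connected_q15_opt() expects in pM. The
 * helper name and its run-time form are assumptions made for illustration; in
 * practice the weights are normally reordered once, offline, at model
 * conversion time.
 */
static void reorder_q15_weights_for_opt(const q15_t * pSrc,
                                        q15_t * pDst,
                                        const uint16_t dim_vec,
                                        const uint16_t num_of_rows)
{
    q15_t    *pO = pDst;
    uint16_t  row = 0;

    /* Blocks of four rows: column pairs are interleaved across the four rows,
     * i.e. | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 | ... */
    for (; row + 4 <= num_of_rows; row += 4)
    {
        uint16_t  col;
        for (col = 0; col + 2 <= dim_vec; col += 2)
        {
            uint16_t  r;
            for (r = 0; r < 4; r++)
            {
                *pO++ = pSrc[(uint32_t)(row + r) * dim_vec + col];
                *pO++ = pSrc[(uint32_t)(row + r) * dim_vec + col + 1];
            }
        }
        /* Odd left-over column: one entry per row, i.e. | a13 | a23 | a33 | a43 |. */
        if (dim_vec & 0x1)
        {
            uint16_t  r;
            for (r = 0; r < 4; r++)
            {
                *pO++ = pSrc[(uint32_t)(row + r) * dim_vec + (dim_vec - 1)];
            }
        }
    }

    /* Remaining rows (num_of_rows & 0x3) keep their original row-major order. */
    for (; row < num_of_rows; row++)
    {
        uint16_t  col;
        for (col = 0; col < dim_vec; col++)
        {
            *pO++ = pSrc[(uint32_t)row * dim_vec + col];
        }
    }
}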

/**
 * @} end of FC group
 */