1 /*
2  * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_mat_mult_kernel_q7_q15.c
22  * Description:  Matrix-multiplication function for convolution
23  *
24  * $Date:        17. January 2018
25  * $Revision:    V.1.0.0
26  *
27  * Target Processor:  Cortex-M cores
28  * -------------------------------------------------------------------- */
29 
30 #include "arm_math.h"
31 #include "arm_nnfunctions.h"
32 
33   /**
34    * @brief Matrix-multiplication function for convolution
35    * @param[in]       pA          pointer to operand A
36    * @param[in]       pInBuffer   pointer to operand B, always conssists of 2 vectors
37    * @param[in]       ch_im_out   numRow of A
38    * @param[in]       numCol_A    numCol of A
39    * @param[in]       bias_shift  amount of left-shift for bias
40    * @param[in]       out_shift   amount of right-shift for output
41    * @param[in]       bias        the bias
42    * @param[in,out]   pOut        pointer to output
43    * @return     The function returns the incremented output pointer
44    *
45    * @details
46    *
47    * This function does the matrix multiplication with weight matrix
48    * and 2 columns from im2col.
49    */
50 
arm_nn_mat_mult_kernel_q7_q15(const q7_t * pA,const q15_t * pInBuffer,const uint16_t ch_im_out,const uint16_t numCol_A,const uint16_t bias_shift,const uint16_t out_shift,const q7_t * bias,q7_t * pOut)51 q7_t     *arm_nn_mat_mult_kernel_q7_q15(const q7_t * pA,
52                                         const q15_t * pInBuffer,
53                                         const uint16_t ch_im_out,
54                                         const uint16_t numCol_A,
55                                         const uint16_t bias_shift,
56                                         const uint16_t out_shift,
57                                         const q7_t * bias,
58                                         q7_t * pOut)
59 {
60 #if defined (ARM_MATH_DSP)
61     /* set up the second output pointers */
62     q7_t     *pOut2 = pOut + ch_im_out;
63     const q7_t *pBias = bias;
64 
65     uint16_t  rowCnt = ch_im_out >> 1;
66     /* this loop over rows in A */
67     while (rowCnt)
68     {
69         /* setup pointers for B */
70         const q15_t *pB = pInBuffer;
71         const q15_t *pB2 = pB + numCol_A;
72 
73         /* align the second pointer for A */
74         const q7_t *pA2 = pA + numCol_A;
75 
76         /* init the sum with bias */
77         q31_t     sum =  ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
78         q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
79         q31_t     sum3 = ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
80         q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
81 
82         uint16_t  colCnt = numCol_A >> 2;
83         /* accumulate over the vector */
84         while (colCnt)
85         {
86             q31_t     inA11, inA12, inA21, inA22;
87             q31_t     inB1 = *__SIMD32(pB)++;
88             q31_t     inB2 = *__SIMD32(pB2)++;
89 
90             pA = (q7_t *) read_and_pad((void *)pA, &inA11, &inA12);
91             pA2 = (q7_t *) read_and_pad((void *)pA2, &inA21, &inA22);
92 
93             sum = __SMLAD(inA11, inB1, sum);
94             sum2 = __SMLAD(inA11, inB2, sum2);
95             sum3 = __SMLAD(inA21, inB1, sum3);
96             sum4 = __SMLAD(inA21, inB2, sum4);
97 
98             inB1 = *__SIMD32(pB)++;
99             inB2 = *__SIMD32(pB2)++;
100 
101             sum = __SMLAD(inA12, inB1, sum);
102             sum2 = __SMLAD(inA12, inB2, sum2);
103             sum3 = __SMLAD(inA22, inB1, sum3);
104             sum4 = __SMLAD(inA22, inB2, sum4);
105 
106             colCnt--;
107         }                       /* while over colCnt */
108         colCnt = numCol_A & 0x3;
109         while (colCnt)
110         {
111             q7_t      inA1 = *pA++;
112             q15_t     inB1 = *pB++;
113             q7_t      inA2 = *pA2++;
114             q15_t     inB2 = *pB2++;
115 
116             sum += inA1 * inB1;
117             sum2 += inA1 * inB2;
118             sum3 += inA2 * inB1;
119             sum4 += inA2 * inB2;
120             colCnt--;
121         }                       /* while over colCnt */
122         *pOut++ = (q7_t) __SSAT((sum >> out_shift), 8);
123         *pOut++ = (q7_t) __SSAT((sum3 >> out_shift), 8);
124         *pOut2++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
125         *pOut2++ = (q7_t) __SSAT((sum4 >> out_shift), 8);
126 
127         /* skip the row computed with A2 */
128         pA += numCol_A;
129         rowCnt--;
130     }                           /* for over ch_im_out */
131 
132     /* compute left-over row if any */
133     if (ch_im_out & 0x1)
134     {
135         /* setup pointers for B */
136         const q15_t *pB = pInBuffer;
137         const q15_t *pB2 = pB + numCol_A;
138 
139         /* load the bias */
140         q31_t     sum = ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
141         q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
142 
143         uint16_t  colCnt = numCol_A >> 2;
144         while (colCnt)
145         {
146             q31_t     inA11, inA12;
147             q31_t     inB1 = *__SIMD32(pB)++;
148             q31_t     inB2 = *__SIMD32(pB2)++;
149 
150             pA = (q7_t *) read_and_pad((void *)pA, &inA11, &inA12);
151 
152             sum = __SMLAD(inA11, inB1, sum);
153             sum2 = __SMLAD(inA11, inB2, sum2);
154 
155             inB1 = *__SIMD32(pB)++;
156             inB2 = *__SIMD32(pB2)++;
157             sum = __SMLAD(inA12, inB1, sum);
158             sum2 = __SMLAD(inA12, inB2, sum2);
159 
160             colCnt--;
161         }
162         colCnt = numCol_A & 0x3;
163         while (colCnt)
164         {
165             q7_t      inA1 = *pA++;
166             q15_t     inB1 = *pB++;
167             q15_t     inB2 = *pB2++;
168 
169             sum += inA1 * inB1;
170             sum2 += inA1 * inB2;
171             colCnt--;
172         }
173 
174         *pOut++ = (q7_t) __SSAT((sum >> out_shift), 8);
175         *pOut2++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
176     }
177 
178     pOut += ch_im_out;
179 
180     /* return the new output pointer with offset */
181     return pOut;
182 #else
183     /* To be completed */
184     return NULL;
185 #endif                          /* ARM_MATH_DSP */
186 
187 }
188