1 /*
2 * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_mat_mult_kernel_q7_q15.c
22 * Description: Matrix-multiplication function for convolution
23 *
24 * $Date: 17. January 2018
25 * $Revision: V.1.0.0
26 *
27 * Target Processor: Cortex-M cores
28 * -------------------------------------------------------------------- */
29
30 #include "arm_math.h"
31 #include "arm_nnfunctions.h"
32
33 /**
34 * @brief Matrix-multiplication function for convolution
35 * @param[in] pA pointer to operand A
36 * @param[in] pInBuffer pointer to operand B, always conssists of 2 vectors
37 * @param[in] ch_im_out numRow of A
38 * @param[in] numCol_A numCol of A
39 * @param[in] bias_shift amount of left-shift for bias
40 * @param[in] out_shift amount of right-shift for output
41 * @param[in] bias the bias
42 * @param[in,out] pOut pointer to output
43 * @return The function returns the incremented output pointer
44 *
45 * @details
46 *
47 * This function does the matrix multiplication with weight matrix
48 * and 2 columns from im2col.
49 */
50
arm_nn_mat_mult_kernel_q7_q15(const q7_t * pA,const q15_t * pInBuffer,const uint16_t ch_im_out,const uint16_t numCol_A,const uint16_t bias_shift,const uint16_t out_shift,const q7_t * bias,q7_t * pOut)51 q7_t *arm_nn_mat_mult_kernel_q7_q15(const q7_t * pA,
52 const q15_t * pInBuffer,
53 const uint16_t ch_im_out,
54 const uint16_t numCol_A,
55 const uint16_t bias_shift,
56 const uint16_t out_shift,
57 const q7_t * bias,
58 q7_t * pOut)
59 {
60 #if defined (ARM_MATH_DSP)
61 /* set up the second output pointers */
62 q7_t *pOut2 = pOut + ch_im_out;
63 const q7_t *pBias = bias;
64
65 uint16_t rowCnt = ch_im_out >> 1;
66 /* this loop over rows in A */
67 while (rowCnt)
68 {
69 /* setup pointers for B */
70 const q15_t *pB = pInBuffer;
71 const q15_t *pB2 = pB + numCol_A;
72
73 /* align the second pointer for A */
74 const q7_t *pA2 = pA + numCol_A;
75
76 /* init the sum with bias */
77 q31_t sum = ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
78 q31_t sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
79 q31_t sum3 = ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
80 q31_t sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
81
82 uint16_t colCnt = numCol_A >> 2;
83 /* accumulate over the vector */
84 while (colCnt)
85 {
86 q31_t inA11, inA12, inA21, inA22;
87 q31_t inB1 = *__SIMD32(pB)++;
88 q31_t inB2 = *__SIMD32(pB2)++;
89
90 pA = (q7_t *) read_and_pad((void *)pA, &inA11, &inA12);
91 pA2 = (q7_t *) read_and_pad((void *)pA2, &inA21, &inA22);
92
93 sum = __SMLAD(inA11, inB1, sum);
94 sum2 = __SMLAD(inA11, inB2, sum2);
95 sum3 = __SMLAD(inA21, inB1, sum3);
96 sum4 = __SMLAD(inA21, inB2, sum4);
97
98 inB1 = *__SIMD32(pB)++;
99 inB2 = *__SIMD32(pB2)++;
100
101 sum = __SMLAD(inA12, inB1, sum);
102 sum2 = __SMLAD(inA12, inB2, sum2);
103 sum3 = __SMLAD(inA22, inB1, sum3);
104 sum4 = __SMLAD(inA22, inB2, sum4);
105
106 colCnt--;
107 } /* while over colCnt */
108 colCnt = numCol_A & 0x3;
109 while (colCnt)
110 {
111 q7_t inA1 = *pA++;
112 q15_t inB1 = *pB++;
113 q7_t inA2 = *pA2++;
114 q15_t inB2 = *pB2++;
115
116 sum += inA1 * inB1;
117 sum2 += inA1 * inB2;
118 sum3 += inA2 * inB1;
119 sum4 += inA2 * inB2;
120 colCnt--;
121 } /* while over colCnt */
122 *pOut++ = (q7_t) __SSAT((sum >> out_shift), 8);
123 *pOut++ = (q7_t) __SSAT((sum3 >> out_shift), 8);
124 *pOut2++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
125 *pOut2++ = (q7_t) __SSAT((sum4 >> out_shift), 8);
126
127 /* skip the row computed with A2 */
128 pA += numCol_A;
129 rowCnt--;
130 } /* for over ch_im_out */
131
132 /* compute left-over row if any */
133 if (ch_im_out & 0x1)
134 {
135 /* setup pointers for B */
136 const q15_t *pB = pInBuffer;
137 const q15_t *pB2 = pB + numCol_A;
138
139 /* load the bias */
140 q31_t sum = ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
141 q31_t sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
142
143 uint16_t colCnt = numCol_A >> 2;
144 while (colCnt)
145 {
146 q31_t inA11, inA12;
147 q31_t inB1 = *__SIMD32(pB)++;
148 q31_t inB2 = *__SIMD32(pB2)++;
149
150 pA = (q7_t *) read_and_pad((void *)pA, &inA11, &inA12);
151
152 sum = __SMLAD(inA11, inB1, sum);
153 sum2 = __SMLAD(inA11, inB2, sum2);
154
155 inB1 = *__SIMD32(pB)++;
156 inB2 = *__SIMD32(pB2)++;
157 sum = __SMLAD(inA12, inB1, sum);
158 sum2 = __SMLAD(inA12, inB2, sum2);
159
160 colCnt--;
161 }
162 colCnt = numCol_A & 0x3;
163 while (colCnt)
164 {
165 q7_t inA1 = *pA++;
166 q15_t inB1 = *pB++;
167 q15_t inB2 = *pB2++;
168
169 sum += inA1 * inB1;
170 sum2 += inA1 * inB2;
171 colCnt--;
172 }
173
174 *pOut++ = (q7_t) __SSAT((sum >> out_shift), 8);
175 *pOut2++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
176 }
177
178 pOut += ch_im_out;
179
180 /* return the new output pointer with offset */
181 return pOut;
182 #else
183 /* To be completed */
184 return NULL;
185 #endif /* ARM_MATH_DSP */
186
187 }
188