xref: /aosp_15_r20/external/ComputeLibrary/src/core/utils/AssemblyUtils.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2021-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/core/utils/AssemblyUtils.h"
25 
26 namespace arm_compute
27 {
28 namespace assembly_utils
29 {
map_to_arm_gemm_activation(const ActivationLayerInfo & act)30 arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act)
31 {
32     arm_gemm::Activation gemm_act;
33 
34     // Early exit in case lower bound is other than 0, as it's not yet supported
35     if(act.b() != 0.f)
36     {
37         return gemm_act;
38     }
39 
40     switch(act.activation())
41     {
42         case ActivationLayerInfo::ActivationFunction::RELU:
43             gemm_act.type = arm_gemm::Activation::Type::ReLU;
44             break;
45         case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
46             gemm_act.type   = arm_gemm::Activation::Type::BoundedReLU;
47             gemm_act.param1 = act.a();
48             gemm_act.param2 = 0.f;
49             break;
50         case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU:
51             gemm_act.type   = arm_gemm::Activation::Type::BoundedReLU;
52             gemm_act.param1 = act.a();
53             gemm_act.param2 = act.b();
54             break;
55         default:
56             gemm_act.type = arm_gemm::Activation::Type::None;
57     }
58 
59     return gemm_act;
60 }
61 
map_to_arm_conv_padding(const PadStrideInfo & pad_stride_info)62 arm_conv::PaddingValues map_to_arm_conv_padding(const PadStrideInfo &pad_stride_info)
63 {
64     return arm_conv::PaddingValues{ pad_stride_info.pad_left(),
65                                     pad_stride_info.pad_top(),
66                                     pad_stride_info.pad_right(),
67                                     pad_stride_info.pad_bottom() };
68 }
69 
map_to_arm_gemm_weight_format(const arm_compute::WeightFormat & weight_format)70 arm_gemm::WeightFormat map_to_arm_gemm_weight_format(const arm_compute::WeightFormat &weight_format)
71 {
72     arm_gemm::WeightFormat gemm_weight_fromat;
73 
74     switch(weight_format)
75     {
76         case arm_compute::WeightFormat::UNSPECIFIED:
77             gemm_weight_fromat = arm_gemm::WeightFormat::UNSPECIFIED;
78             break;
79         case arm_compute::WeightFormat::ANY:
80             gemm_weight_fromat = arm_gemm::WeightFormat::ANY;
81             break;
82         case arm_compute::WeightFormat::OHWI:
83             gemm_weight_fromat = arm_gemm::WeightFormat::OHWI;
84             break;
85         case arm_compute::WeightFormat::OHWIo2:
86             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo2;
87             break;
88         case arm_compute::WeightFormat::OHWIo4:
89             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4;
90             break;
91         case arm_compute::WeightFormat::OHWIo8:
92             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8;
93             break;
94         case arm_compute::WeightFormat::OHWIo16:
95             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16;
96             break;
97         case arm_compute::WeightFormat::OHWIo32:
98             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32;
99             break;
100         case arm_compute::WeightFormat::OHWIo64:
101             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64;
102             break;
103         case arm_compute::WeightFormat::OHWIo128:
104             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo128;
105             break;
106         case arm_compute::WeightFormat::OHWIo4i2:
107             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i2;
108             break;
109         case arm_compute::WeightFormat::OHWIo4i2_bf16:
110             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i2_bf16;
111             break;
112         case arm_compute::WeightFormat::OHWIo8i2:
113             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i2;
114             break;
115         case arm_compute::WeightFormat::OHWIo8i2_bf16:
116             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i2_bf16;
117             break;
118         case arm_compute::WeightFormat::OHWIo16i2:
119             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i2;
120             break;
121         case arm_compute::WeightFormat::OHWIo16i2_bf16:
122             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i2_bf16;
123             break;
124         case arm_compute::WeightFormat::OHWIo32i2:
125             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i2;
126             break;
127         case arm_compute::WeightFormat::OHWIo32i2_bf16:
128             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i2_bf16;
129             break;
130         case arm_compute::WeightFormat::OHWIo64i2:
131             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i2;
132             break;
133         case arm_compute::WeightFormat::OHWIo64i2_bf16:
134             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i2_bf16;
135             break;
136         case arm_compute::WeightFormat::OHWIo4i4:
137             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i4;
138             break;
139         case arm_compute::WeightFormat::OHWIo4i4_bf16:
140             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i4_bf16;
141             break;
142         case arm_compute::WeightFormat::OHWIo8i4:
143             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i4;
144             break;
145         case arm_compute::WeightFormat::OHWIo8i4_bf16:
146             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i4_bf16;
147             break;
148         case arm_compute::WeightFormat::OHWIo16i4:
149             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i4;
150             break;
151         case arm_compute::WeightFormat::OHWIo16i4_bf16:
152             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i4_bf16;
153             break;
154         case arm_compute::WeightFormat::OHWIo32i4:
155             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i4;
156             break;
157         case arm_compute::WeightFormat::OHWIo32i4_bf16:
158             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i4_bf16;
159             break;
160         case arm_compute::WeightFormat::OHWIo64i4:
161             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i4;
162             break;
163         case arm_compute::WeightFormat::OHWIo64i4_bf16:
164             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i4_bf16;
165             break;
166         case arm_compute::WeightFormat::OHWIo2i8:
167             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo2i8;
168             break;
169         case arm_compute::WeightFormat::OHWIo4i8:
170             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i8;
171             break;
172         case arm_compute::WeightFormat::OHWIo8i8:
173             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i8;
174             break;
175         case arm_compute::WeightFormat::OHWIo16i8:
176             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i8;
177             break;
178         case arm_compute::WeightFormat::OHWIo32i8:
179             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i8;
180             break;
181         case arm_compute::WeightFormat::OHWIo64i8:
182             gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i8;
183             break;
184         default:
185             gemm_weight_fromat = arm_gemm::WeightFormat::UNSPECIFIED;
186     }
187     return gemm_weight_fromat;
188 }
189 
map_to_arm_compute_weight_format(const arm_gemm::WeightFormat & weight_format)190 arm_compute::WeightFormat map_to_arm_compute_weight_format(const arm_gemm::WeightFormat &weight_format)
191 {
192     arm_compute::WeightFormat acl_weight_fromat;
193 
194     switch(weight_format)
195     {
196         case arm_gemm::WeightFormat::UNSPECIFIED:
197             acl_weight_fromat = arm_compute::WeightFormat::UNSPECIFIED;
198             break;
199         case arm_gemm::WeightFormat::ANY:
200             acl_weight_fromat = arm_compute::WeightFormat::ANY;
201             break;
202         case arm_gemm::WeightFormat::OHWI:
203             acl_weight_fromat = arm_compute::WeightFormat::OHWI;
204             break;
205         case arm_gemm::WeightFormat::OHWIo2:
206             acl_weight_fromat = arm_compute::WeightFormat::OHWIo2;
207             break;
208         case arm_gemm::WeightFormat::OHWIo4:
209             acl_weight_fromat = arm_compute::WeightFormat::OHWIo4;
210             break;
211         case arm_gemm::WeightFormat::OHWIo8:
212             acl_weight_fromat = arm_compute::WeightFormat::OHWIo8;
213             break;
214         case arm_gemm::WeightFormat::OHWIo16:
215             acl_weight_fromat = arm_compute::WeightFormat::OHWIo16;
216             break;
217         case arm_gemm::WeightFormat::OHWIo32:
218             acl_weight_fromat = arm_compute::WeightFormat::OHWIo32;
219             break;
220         case arm_gemm::WeightFormat::OHWIo64:
221             acl_weight_fromat = arm_compute::WeightFormat::OHWIo64;
222             break;
223         case arm_gemm::WeightFormat::OHWIo128:
224             acl_weight_fromat = arm_compute::WeightFormat::OHWIo128;
225             break;
226         case arm_gemm::WeightFormat::OHWIo4i2:
227             acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i2;
228             break;
229         case arm_gemm::WeightFormat::OHWIo4i2_bf16:
230             acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i2_bf16;
231             break;
232         case arm_gemm::WeightFormat::OHWIo8i2:
233             acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i2;
234             break;
235         case arm_gemm::WeightFormat::OHWIo8i2_bf16:
236             acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i2_bf16;
237             break;
238         case arm_gemm::WeightFormat::OHWIo16i2:
239             acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i2;
240             break;
241         case arm_gemm::WeightFormat::OHWIo16i2_bf16:
242             acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i2_bf16;
243             break;
244         case arm_gemm::WeightFormat::OHWIo32i2:
245             acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i2;
246             break;
247         case arm_gemm::WeightFormat::OHWIo32i2_bf16:
248             acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i2_bf16;
249             break;
250         case arm_gemm::WeightFormat::OHWIo64i2:
251             acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i2;
252             break;
253         case arm_gemm::WeightFormat::OHWIo64i2_bf16:
254             acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i2_bf16;
255             break;
256         case arm_gemm::WeightFormat::OHWIo4i4:
257             acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i4;
258             break;
259         case arm_gemm::WeightFormat::OHWIo4i4_bf16:
260             acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i4_bf16;
261             break;
262         case arm_gemm::WeightFormat::OHWIo8i4:
263             acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i4;
264             break;
265         case arm_gemm::WeightFormat::OHWIo8i4_bf16:
266             acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i4_bf16;
267             break;
268         case arm_gemm::WeightFormat::OHWIo16i4:
269             acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i4;
270             break;
271         case arm_gemm::WeightFormat::OHWIo16i4_bf16:
272             acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i4_bf16;
273             break;
274         case arm_gemm::WeightFormat::OHWIo32i4:
275             acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i4;
276             break;
277         case arm_gemm::WeightFormat::OHWIo32i4_bf16:
278             acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i4_bf16;
279             break;
280         case arm_gemm::WeightFormat::OHWIo64i4:
281             acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i4;
282             break;
283         case arm_gemm::WeightFormat::OHWIo64i4_bf16:
284             acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i4_bf16;
285             break;
286         case arm_gemm::WeightFormat::OHWIo2i8:
287             acl_weight_fromat = arm_compute::WeightFormat::OHWIo2i8;
288             break;
289         case arm_gemm::WeightFormat::OHWIo4i8:
290             acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i8;
291             break;
292         case arm_gemm::WeightFormat::OHWIo8i8:
293             acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i8;
294             break;
295         case arm_gemm::WeightFormat::OHWIo16i8:
296             acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i8;
297             break;
298         case arm_gemm::WeightFormat::OHWIo32i8:
299             acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i8;
300             break;
301         case arm_gemm::WeightFormat::OHWIo64i8:
302             acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i8;
303             break;
304         default:
305             acl_weight_fromat = arm_compute::WeightFormat::UNSPECIFIED;
306     }
307     return acl_weight_fromat;
308 }
309 } // namespace assembly_utils
310 } // namespace arm_compute
311