1 /*
2  * Copyright (c) 2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in
14  * all copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22  * IN THE SOFTWARE.
23  */
24 #pragma once
25 #ifdef ARM_COMPUTE_ENABLE_SVE
26 
27 #include "../std_transforms_sve.hpp"
28 #include "../kernel_weight_format.hpp"
29 #include "../performance_parameters.hpp"
30 
// Parameter list shared by every kernel entry point declared in this header
// and by the kern_type function-pointer alias. It is #undef'd at the end of
// the file so the name does not leak into files that include this header.
#define ARGLIST                 \
    unsigned int,               \
    const unsigned int *,       \
    IndirectInputArg<float>,    \
    size_t,                     \
    size_t,                     \
    const float *,              \
    size_t,                     \
    IndirectOutputArg<float>,   \
    const float *,              \
    Activation,                 \
    bool
39 
40 namespace arm_gemm
41 {
42 // Actual kernel implementations
43 void sve_ffhybrid_fp32_mla_6x4VL( ARGLIST );
44 void sve_ffhybrid_fp32_mla_6x4VL_a64fx( ARGLIST );
45 
46 class cls_sve_ffhybrid_fp32_mla_6x4VL
47 {
48 public:
49     typedef float lhs_operand_type;
50     typedef float rhs_operand_type;
51     typedef float result_type;
52 
53     typedef void (*kern_type)( ARGLIST );
54 
55     /* Kernel blocking parameters */
out_height()56     static constexpr unsigned int out_height()
57     {
58         return 6;
59     }
stripe_width()60     static unsigned int stripe_width()
61     {
62         return get_vector_length<float>() * 1;
63     }
64 
kernel_weight_format()65     static KernelWeightFormat kernel_weight_format()
66     {
67         return KernelWeightFormat::VL1VL_BL32;
68     }
69 
out_width()70     static unsigned int out_width()
71     {
72         return get_vector_length<float>() * 4;
73     }
74 
k_unroll()75     static constexpr unsigned int k_unroll()
76     {
77         return 1;
78     }
79 
supports_accumulate()80     static constexpr bool supports_accumulate()
81     {
82         return true;
83     }
84 
85     StdTransformsSVE<rhs_operand_type, result_type, 6, 4, 1> transforms = {};
86     template<typename T>
get_performance_parameters(const CPUInfo * ci)87     static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci)
88     {
89         if (std::is_same<T, float>::value) {
90             switch (ci->get_cpu_model()) {
91                 default:
92                     return { 15.27 };
93             }
94         }
95 
96         return { 1.0 };
97     }
98 
99     // Default to the generic kernel
100     kern_type kernel=sve_ffhybrid_fp32_mla_6x4VL;
cls_sve_ffhybrid_fp32_mla_6x4VL(const CPUInfo * ci)101     cls_sve_ffhybrid_fp32_mla_6x4VL(const CPUInfo *ci)
102     {
103         switch(ci->get_cpu_model()) {
104             default:
105                 break;
106             case CPUModel::A64FX:
107                 kernel=sve_ffhybrid_fp32_mla_6x4VL_a64fx;
108                 break;
109         }
110     }
111 };
112 
113 } // namespace arm_gemm
114 
115 #undef ARGLIST
116 #endif // ARM_COMPUTE_ENABLE_SVE
117