1 /* 2 * Copyright (c) 2022 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #pragma once 25 26 #include "arm_gemm.hpp" 27 28 namespace arm_gemm { 29 30 /* Internal enum to define the weight format a kernel is expecting. 31 * 32 * This is distinct from the "external" WeightFormat defined in arm_gemm.hpp primarily to allow for SVE, where 33 * internally kernels are defined in terms of multiples of the SVE vector length, but externally they are converted 34 * to a fixed format (based on the VL of the machine we are running on). 35 * 36 * Encoded as a bitfield: 37 * bit 0 : SVE flag 38 * bit 4 : BF16 convert flag (fast mode) 39 * bits 11-8 : block length (bytes) 40 * bits 15-12: vector count 41 */ 42 enum class KernelWeightFormat { 43 NON_FIXED = 0, 44 VL128_BL16 = 0x1200, 45 VL128_BL32 = 0x1400, 46 VL128_BL32_BF16 = 0x1410, 47 VL128_BL64 = 0x1800, 48 VL256_BL64 = 0x2800, 49 VL256_BL64_BF16 = 0x2810, 50 VL1VL_BL16 = 0x1201, 51 VL1VL_BL32 = 0x1401, 52 VL1VL_BL32_BF16 = 0x1411, 53 VL1VL_BL64 = 0x1801, 54 VL2VL_BL64 = 0x2801, 55 VL2VL_BL64_BF16 = 0x2811 56 }; 57 58 WeightFormat get_weight_format(const KernelWeightFormat, size_t); 59 60 } // namespace arm_gemm 61