/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include #include #include #include namespace executorch { namespace cpublas { enum class TransposeType { NoTranspose, Transpose, ConjTranspose, }; // clang-format off void normalize_last_dims( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, int64_t *lda, int64_t *ldb, int64_t *ldc); // clang-format on inline char to_blas(TransposeType trans) { switch (trans) { case TransposeType::Transpose: return 'T'; case TransposeType::NoTranspose: return 'N'; case TransposeType::ConjTranspose: return 'C'; } // Assume no transpose by default return 'N'; } // clang-format off template void gemm_impl( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, opmath_t alpha, const scalar_t *a, int64_t lda, const scalar_t *b, int64_t ldb, opmath_t beta, scalar_t *c, int64_t ldc) { if (transa == TransposeType::NoTranspose && transb == TransposeType::NoTranspose) { return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } else if ( transa == TransposeType::Transpose && transb != TransposeType::Transpose) { gemm_transa_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } else if ( transa == TransposeType::NoTranspose && transb == TransposeType::Transpose) { gemm_transb_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } else { // transa == TransposeType::Transpose && transb == // TransposeType::Transpose gemm_transab_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } } // clang-format on // clang-format off void gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, double alpha, const double *a, int64_t lda, const double *b, int64_t ldb, double beta, double *c, int64_t ldc); // clang-format on // clang-format off void gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, const float alpha, const float *a, int64_t lda, const float *b, int64_t ldb, const float beta, float *c, int64_t ldc); // clang-format on // clang-format off void gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, const exec_aten::Half alpha, const exec_aten::Half *a, int64_t lda, const exec_aten::Half *b, int64_t ldb, const exec_aten::Half beta, exec_aten::Half *c, int64_t ldc); void gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, const exec_aten::BFloat16 alpha, const exec_aten::BFloat16 *a, int64_t lda, const exec_aten::BFloat16 *b, int64_t ldb, const exec_aten::BFloat16 beta, exec_aten::BFloat16 *c, int64_t ldc); // clang-format on // clang-format off template ::value, int>::type = 0> void gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, const T alpha, const T *a, int64_t lda, const T *b, int64_t ldb, const T beta, T *c, int64_t ldc) { normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); using acc_type = utils::compute_dtype; gemm_impl( transa, transb, m, n, k, static_cast(alpha), a, lda, b, ldb, static_cast(beta), c, ldc); } // clang-format on } // namespace cpublas } // namespace executorch