#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/CPUBlas.h>
#include <ATen/native/mkl/LinearAlgebra.h>
#include <ATen/native/mkldnn/Matmul.h>
#include <ATen/Config.h>

#include <c10/util/SmallBuffer.h>
#include <c10/util/irange.h>

#include <climits>

#if AT_BUILD_WITH_BLAS()
#if C10_IOS
#include <Accelerate/Accelerate.h>
#else
extern "C" void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, const double *a, int *lda, const double *b, int *ldb, double *beta, double *c, int *ldc);
extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc);
extern "C" void cgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc);
extern "C" void zgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc);
#ifdef BLAS_HAS_SBGEMM
extern "C" void sbgemm_(char *transa, char *transb, int *m, int *n, int *k,
                        float *alpha,
                        const at::BFloat16 *a, int *lda,
                        const at::BFloat16 *b, int *ldb,
                        float *beta,
                        float *c, int *ldc);
#endif  // BLAS_HAS_SBGEMM
extern "C" void cswap_(int *n, const void *x, int *incx, void *y, int *incy);
extern "C" void dcopy_(int *n, const double *x, int *incx, double *y, int *incy);
extern "C" void scopy_(int *n, const float *x, int *incx, float *y, int *incy);
extern "C" void zcopy_(int *n, const void *x, int *incx, void *y, int *incy);
extern "C" void ccopy_(int *n, const void *x, int *incx, void *y, int *incy);
extern "C" void daxpy_(int *n, double *a, const double *x, int *incx, double *y, int *incy);
extern "C" void saxpy_(int *n, float *a, const float *x, int *incx, float *y, int *incy);
extern "C" void caxpy_(int *n, void *a, const void *x, int *incx, void *y, int *incy);
extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int *incy);
#endif  // C10_IOS
#endif  // AT_BUILD_WITH_BLAS

#ifdef USE_FBGEMM
#include <fbgemm/FbgemmI64.h>
#endif  // USE_FBGEMM

#if AT_MKLDNN_ENABLED()
#include <oneapi/dnnl/dnnl_version.h>
#endif // oneDNN

#define ONEDNN_UKERNEL_ENABLED (DNNL_VERSION_MAJOR >= 3 && DNNL_VERSION_MINOR >= 5)

#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
#include <oneapi/dnnl/dnnl_ukernel.hpp>
#include <oneapi/dnnl/dnnl.hpp>
#endif // oneDNN BRGEMM

namespace at::native::cpublas {
namespace internal {

void normalize_last_dims(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    int64_t *lda, int64_t *ldb, int64_t *ldc) {
  if (n == 1) {
    *ldc = m;
  }

  if (transa != TransposeType::NoTranspose) {
    if (m == 1) {
      *lda = k;
    }
  } else if (k == 1) {
    *lda = m;
  }

  if (transb != TransposeType::NoTranspose) {
    if (k == 1) {
      *ldb = n;
    }
  } else if (n == 1) {
    *ldb = k;
  }
}
}  // namespace internal

namespace {

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunneeded-internal-declaration")
bool use_blas_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    int64_t lda, int64_t ldb, int64_t ldc) {
  const bool transa_ = transa != TransposeType::NoTranspose;
  const bool transb_ = transb != TransposeType::NoTranspose;
  return (
      (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) &&
      (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) &&
      (lda >= std::max(int64_t{1}, (transa_ ? k : m))) &&
      (ldb >= std::max(int64_t{1}, (transb_ ? n : k))) &&
      (ldc >= std::max(int64_t{1}, m)));
}
C10_DIAGNOSTIC_POP()
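
// Example of how the two helpers above interact: for a matrix-vector product
// C(m x 1) = A(m x k) * B(k x 1) with transa == transb == NoTranspose and
// n == 1, normalize_last_dims() forces *ldb = k and *ldc = m, so callers that
// passed degenerate leading dimensions for the vector operands still satisfy
// the "ld >= max(1, rows of op(X))" requirement that use_blas_gemm() checks
// (alongside the 32-bit range checks) before dispatching to an external BLAS.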

#ifdef USE_FBGEMM
fbgemm::matrix_op_t to_fbgemm(TransposeType trans) {
  switch (trans) {
    case TransposeType::Transpose:
      return fbgemm::matrix_op_t::Transpose;
    case TransposeType::NoTranspose:
      return fbgemm::matrix_op_t::NoTranspose;
    case TransposeType::ConjTranspose:
      TORCH_INTERNAL_ASSERT(false, "ConjTranspose type is not supported in fbgemm");
  }
  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}
#endif  // USE_FBGEMM

#if (AT_BUILD_WITH_BLAS() && C10_IOS)
CBLAS_TRANSPOSE to_apple_accelerate_transpose(TransposeType trans) {
  switch (trans) {
    case TransposeType::Transpose:
      return CblasTrans;
    case TransposeType::NoTranspose:
      return CblasNoTrans;
    case TransposeType::ConjTranspose:
      return CblasConjTrans;
  }
  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}
#endif

}  // namespace (anonymous)

DEFINE_DISPATCH(gemm_stub);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const double alpha,
    const double *a, int64_t lda,
    const double *b, int64_t ldb,
    const double beta,
    double *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS()
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    double alpha_ = alpha, beta_ = beta;
#if C10_IOS
    CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa);
    CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb);
    cblas_dgemm(CblasColMajor, transa_, transb_, m_, n_, k_, alpha_, a, lda_, b, ldb_, beta_, c, ldc_);
#else
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    dgemm_(&transa_, &transb_, &m_, &n_, &k_, &alpha_, a, &lda_, b, &ldb_, &beta_, c, &ldc_);
#endif
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kDouble,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
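
// Usage sketch (illustrative, not a call site in this file): these overloads
// follow the Fortran/BLAS column-major convention. A plain C = A * B with
// A (m x k), B (k x n) and C (m x n), all column-major and contiguous, would be
//   at::native::cpublas::gemm(
//       TransposeType::NoTranspose, TransposeType::NoTranspose,
//       m, n, k, /*alpha=*/1.0, A, /*lda=*/m, B, /*ldb=*/k,
//       /*beta=*/0.0, C, /*ldc=*/m);
// Row-major callers can instead swap the operands and transpose flags; the
// int64_t overload below uses exactly that trick to talk to FBGEMM.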

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const float *a, int64_t lda,
    const float *b, int64_t ldb,
    const float beta,
    float *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_MKLDNN_ENABLED()
  if (mkldnn_bf32_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
    return;
  }
#endif
#if AT_BUILD_WITH_BLAS()
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    float alpha_ = alpha, beta_ = beta;
#if C10_IOS
    CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa);
    CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb);
    cblas_sgemm(CblasColMajor, transa_, transb_, m_, n_, k_, alpha_, a, lda_, b, ldb_, beta_, c, ldc_);
#else
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    sgemm_(&transa_, &transb_, &m_, &n_, &k_, &alpha_, a, &lda_, b, &ldb_, &beta_, c, &ldc_);
#endif
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kFloat,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const c10::complex<double> alpha,
    const c10::complex<double> *a, int64_t lda,
    const c10::complex<double> *b, int64_t ldb,
    const c10::complex<double> beta,
    c10::complex<double> *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS()
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    c10::complex<double> alpha_ = alpha, beta_ = beta;
#if C10_IOS
    CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa);
    CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb);
    cblas_zgemm(CblasColMajor, transa_, transb_, m_, n_, k_, &alpha_, a, lda_, b, ldb_, &beta_, c, ldc_);
#else
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    zgemm_(&transa_, &transb_, &m_, &n_, &k_, &alpha_, a, &lda_, b, &ldb_, &beta_, c, &ldc_);
#endif
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kComplexDouble,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const c10::complex<float> alpha,
    const c10::complex<float> *a, int64_t lda,
    const c10::complex<float> *b, int64_t ldb,
    const c10::complex<float> beta,
    c10::complex<float> *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS()
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    c10::complex<float> alpha_ = alpha, beta_ = beta;
#if C10_IOS
    CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa);
    CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb);
    cblas_cgemm(CblasColMajor, transa_, transb_, m_, n_, k_, &alpha_, a, lda_, b, ldb_, &beta_, c, ldc_);
#else
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    cgemm_(&transa_, &transb_, &m_, &n_, &k_, &alpha_, a, &lda_, b, &ldb_, &beta_, c, &ldc_);
#endif
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kComplexFloat,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const at::BFloat16 *a, int64_t lda,
    const at::BFloat16 *b, int64_t ldb,
    const float beta,
    at::BFloat16 *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SBGEMM)
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    float alpha_ = alpha, beta_ = beta;
    int c_size = n_ * ldc_;
    // The C matrix in OpenBLAS sbgemm is of type "float",
    // so we have to convert, copy and copy back.
    std::vector<float> float_v(c, c + c_size);
    sbgemm_(&transa_, &transb_, &m_, &n_, &k_, &alpha_, a, &lda_, b, &ldb_, &beta_, float_v.data(), &ldc_);
    for (auto cv : float_v) {
      *(c++) = c10::convert<at::BFloat16>(cv);
    }
    return;
  }
#endif
#if AT_MKLDNN_ENABLED()
  if (mkldnn_bf16_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kBFloat16,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const at::Half *a, int64_t lda,
    const at::Half *b, int64_t ldb,
    const float beta,
    at::Half *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_MKLDNN_ENABLED()
  if (mkldnn_fp16_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kHalf,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const at::BFloat16 *a, int64_t lda,
    const at::BFloat16 *b, int64_t ldb,
    const float beta,
    float *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SBGEMM)
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    float alpha_ = alpha, beta_ = beta;
    sbgemm_(&transa_, &transb_, &m_, &n_, &k_, &alpha_, a, &lda_, b, &ldb_, &beta_, c, &ldc_);
    return;
  }
#endif
#ifdef MKL_HAS_SBGEMM
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    mkl_gemm_bf16bf16f32(transa, transb, m_, n_, k_, alpha, a, lda_, b, ldb_, beta, c, ldc_);
    return;
  }
#endif
  // For the fallback path, first compute gemm with beta = 0,
  // and then add c in full precision.
  int64_t c_size = n * m;
  std::vector<at::BFloat16> bfloat_c(c_size, 0.f);
  gemm_stub(
      at::kCPU, at::kBFloat16,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, bfloat_c.data(), m);
  for (const auto j : c10::irange(n)) {
    for (const auto i : c10::irange(m)) {
      auto offset = j * ldc + i;
      // beta == 0 won't propagate NaN from C
      if (beta == 0.f) {
        c[offset] = c10::convert<float>(bfloat_c[j * m + i]);
      } else {
        c[offset] = beta * c[offset] + c10::convert<float>(bfloat_c[j * m + i]);
      }
    }
  }
}

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const at::Half *a, int64_t lda,
    const at::Half *b, int64_t ldb,
    const float beta,
    float *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#ifdef MKL_HAS_SHGEMM
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    mkl_gemm_f16f16f32(transa, transb, m_, n_, k_, alpha, a, lda_, b, ldb_, beta, c, ldc_);
    return;
  }
#endif
  // For the fallback path, first compute gemm with beta = 0,
  // and then add c in full precision.
  int64_t c_size = n * m;
  std::vector<at::Half> float16_c(c_size, 0.f);
  gemm_stub(
      at::kCPU, at::kHalf,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float16_c.data(), m);
  for (const auto j : c10::irange(n)) {
    for (const auto i : c10::irange(m)) {
      auto offset = j * ldc + i;
      // beta == 0 won't propagate NaN from C
      if (beta == 0.f) {
        c[offset] = c10::convert<float>(float16_c[j * m + i]);
      } else {
        c[offset] = beta * c[offset] + c10::convert<float>(float16_c[j * m + i]);
      }
    }
  }
}

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const int64_t alpha,
    const int64_t *a, int64_t lda,
    const int64_t *b, int64_t ldb,
    const int64_t beta,
    int64_t *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#ifdef USE_FBGEMM
  if (alpha == 1 && (beta == 0 || beta == 1)) {
    // In FBGEMM, we assume row-major ordering; however, here we assume the
    // column-major ordering following the FORTRAN tradition in BLAS interfaces
    // in this function: we can configure the layout (row/column-major ordering)
    // of A and B by changing transa_ and transb_, but we cannot change the
    // layout of C with this FORTRAN-style BLAS interface.
    //
    // The workaround is that we compute
    // C^T (n x m) = B^T (n x k) * A^T (k x m) instead.
    //
    // In this way we view C^T as the row-major ordering when passing to FBGEMM.
    fbgemm::cblas_gemm_i64_i64acc(
        to_fbgemm(transb), to_fbgemm(transa),
        n, m, k,
        b, ldb,
        a, lda,
        beta == 1,
        c, ldc);
    return;
  }
#endif
  gemm_stub(
      kCPU, kLong,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
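
// Spelled out, the identity used in the FBGEMM branch above: with column-major
// C (m x n) and alpha == 1, beta in {0, 1},
//   C = op(A) * op(B) + beta * C
// is computed as C^T = op(B)^T * op(A)^T. Viewed row-major, C^T is an (n x m)
// matrix whose leading dimension is still ldc, which is exactly the output
// layout that FBGEMM's cblas_gemm_i64_i64acc expects.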

template <typename scalar_t>
static void gemm_batched_mkl_impl(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t **a, int64_t lda,
    const scalar_t **b, int64_t ldb,
    scalar_t beta,
    scalar_t **c, int64_t ldc) {
  for (int64_t i = 0; i < batch_size;) {
    int sub_batch = std::min(batch_size - i, int64_t{INT_MAX});
    mkl_gemm_batched(transa, transb, sub_batch, m, n, k, alpha,
                     &a[i], lda, &b[i], ldb, beta, &c[i], ldc);
    i += sub_batch;
  }
}

template <typename scalar_t>
using is_blas_library_type = std::integral_constant<bool,
    std::is_same_v<scalar_t, double> ||
    std::is_same_v<scalar_t, float> ||
    std::is_same_v<scalar_t, c10::complex<double>> ||
    std::is_same_v<scalar_t, c10::complex<float>>>;

template <typename scalar_t>
void gemm_batched_generic(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t **a, int64_t lda,
    const scalar_t **b, int64_t ldb,
    scalar_t beta,
    scalar_t **c, int64_t ldc) {
  for (const auto batch : c10::irange(batch_size)) {
    gemm(transa, transb, m, n, k, alpha, a[batch], lda, b[batch], ldb, beta, c[batch], ldc);
  }
}

template <typename scalar_t>
void gemm_batched(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t **a, int64_t lda,
    const scalar_t **b, int64_t ldb,
    scalar_t beta,
    scalar_t **c, int64_t ldc) {
  if (batch_size == 1) {
    return gemm(transa, transb, m, n, k, alpha, a[0], lda, b[0], ldb, beta, c[0], ldc);
  }

  if constexpr (AT_MKL_ENABLED() && is_blas_library_type<scalar_t>::value) {
    internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
    if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
      gemm_batched_mkl_impl(
          transa, transb, batch_size, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    } else {
      gemm_batched_generic(
          transa, transb, batch_size, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    }
  } else {
    gemm_batched_generic(
        transa, transb, batch_size, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
}

template <typename scalar_t>
void gemm_batched_with_stride_generic(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t *a, int64_t lda, int64_t batch_stride_a,
    const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
    scalar_t beta,
    scalar_t *c, int64_t ldc, int64_t batch_stride_c) {
  for (const auto batch : c10::irange(batch_size)) {
    const auto a_batch = a + batch_stride_a * batch;
    const auto b_batch = b + batch_stride_b * batch;
    const auto c_batch = c + batch_stride_c * batch;
    gemm(transa, transb, m, n, k, alpha, a_batch, lda, b_batch, ldb, beta, c_batch, ldc);
  }
}

template <typename scalar_t>
void gemm_batched_with_stride(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t *a, int64_t lda, int64_t batch_stride_a,
    const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
    scalar_t beta,
    scalar_t *c, int64_t ldc, int64_t batch_stride_c) {
  if (batch_size == 1) {
    return gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }

  if constexpr (AT_MKL_ENABLED() && is_blas_library_type<scalar_t>::value) {
    internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
    if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
      c10::SmallBuffer<const scalar_t*, 16> a_ptrs(batch_size);
      c10::SmallBuffer<const scalar_t*, 16> b_ptrs(batch_size);
      c10::SmallBuffer<scalar_t*, 16> c_ptrs(batch_size);
      for (const auto batch : c10::irange(batch_size)) {
        a_ptrs[batch] = a + batch_stride_a * batch;
        b_ptrs[batch] = b + batch_stride_b * batch;
        c_ptrs[batch] = c + batch_stride_c * batch;
      }
      gemm_batched_mkl_impl(
          transa, transb, batch_size, m, n, k, alpha,
          a_ptrs.data(), lda, b_ptrs.data(), ldb, beta, c_ptrs.data(), ldc);
    } else {
      gemm_batched_with_stride_generic(
          transa, transb, batch_size, m, n, k, alpha,
          a, lda, batch_stride_a, b, ldb, batch_stride_b, beta, c, ldc, batch_stride_c);
    }
  } else {
    gemm_batched_with_stride_generic(
        transa, transb, batch_size, m, n, k, alpha,
        a, lda, batch_stride_a, b, ldb, batch_stride_b, beta, c, ldc, batch_stride_c);
  }
}

#define INSTANTIATE_BATCHED_GEMM(scalar_t, DType)               \
  template void gemm_batched(                                   \
      TransposeType transa, TransposeType transb,               \
      int64_t batch_size, int64_t m, int64_t n, int64_t k,      \
      scalar_t alpha,                                           \
      const scalar_t **a, int64_t lda,                          \
      const scalar_t **b, int64_t ldb,                          \
      scalar_t beta,                                            \
      scalar_t **c, int64_t ldc);                               \
  template void gemm_batched_with_stride(                       \
      TransposeType transa, TransposeType transb,               \
      int64_t batch_size, int64_t m, int64_t n, int64_t k,      \
      scalar_t alpha,                                           \
      const scalar_t *a, int64_t lda, int64_t batch_stride_a,   \
      const scalar_t *b, int64_t ldb, int64_t batch_stride_b,   \
      scalar_t beta,                                            \
      scalar_t *c, int64_t ldc, int64_t batch_stride_c);

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(INSTANTIATE_BATCHED_GEMM)
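
// Usage sketch (illustrative): for `batch_size` problems laid out contiguously
// as column-major (m x k), (k x n) and (m x n) blocks, C_i = A_i * B_i could be
// issued as
//   at::native::cpublas::gemm_batched_with_stride(
//       TransposeType::NoTranspose, TransposeType::NoTranspose,
//       batch_size, m, n, k, /*alpha=*/1.0f,
//       a, /*lda=*/m, /*batch_stride_a=*/m * k,
//       b, /*ldb=*/k, /*batch_stride_b=*/k * n,
//       /*beta=*/0.0f, c, /*ldc=*/m, /*batch_stride_c=*/m * n);
// When MKL is enabled, the dtype is a BLAS library type, and the sizes pass
// use_blas_gemm(), this lowers to mkl_gemm_batched(); otherwise it falls back
// to one gemm() call per batch.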

DEFINE_DISPATCH(axpy_stub);

void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_daxpy(i_n, a, x, i_incx, y, i_incy);
#else
    daxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  axpy_stub(
      kCPU, at::kDouble,
      n, a, x, incx, y, incy);
}

void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_saxpy(i_n, a, x, i_incx, y, i_incy);
#else
    saxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  axpy_stub(
      kCPU, at::kFloat,
      n, a, x, incx, y, incy);
}

void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_zaxpy(i_n, &a, x, i_incx, y, i_incy);
#else
    zaxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  axpy_stub(
      kCPU, at::kComplexDouble,
      n, a, x, incx, y, incy);
}

void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_caxpy(i_n, &a, x, i_incx, y, i_incy);
#else
    caxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  axpy_stub(
      kCPU, at::kComplexFloat,
      n, a, x, incx, y, incy);
}

DEFINE_DISPATCH(copy_stub);

void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_dcopy(i_n, x, i_incx, y, i_incy);
#else
    dcopy_(&i_n, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  copy_stub(
      kCPU, at::kDouble,
      n, x, incx, y, incy);
}

void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_scopy(i_n, x, i_incx, y, i_incy);
#else
    scopy_(&i_n, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  copy_stub(
      kCPU, at::kFloat,
      n, x, incx, y, incy);
}

void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_zcopy(i_n, x, i_incx, y, i_incy);
#else
    zcopy_(&i_n, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  copy_stub(
      kCPU, at::kComplexDouble,
      n, x, incx, y, incy);
}

void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_ccopy(i_n, x, i_incx, y, i_incy);
#else
    ccopy_(&i_n, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  copy_stub(
      kCPU, at::kComplexFloat,
      n, x, incx, y, incy);
}

// oneDNN BRGEMM
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
struct BrgemmKey {
  int64_t M;
  int64_t N;
  int64_t K;
  int64_t batch_size;
  int64_t lda;
  int64_t ldb;
  int64_t ldc;
  ScalarType dt_a;
  ScalarType dt_b;
  ScalarType dt_c;
  float alpha;
  float beta;
  BrgemmKey(
      int64_t M, int64_t N, int64_t K, int64_t batch_size,
      int64_t lda, int64_t ldb, int64_t ldc,
      ScalarType dt_a, ScalarType dt_b, ScalarType dt_c,
      float alpha, float beta)
      : M(M), N(N), K(K), batch_size(batch_size),
        lda(lda), ldb(ldb), ldc(ldc),
        dt_a(dt_a), dt_b(dt_b), dt_c(dt_c),
        alpha(alpha), beta(beta) {}
  bool operator==(const BrgemmKey& other) const {
    return M == other.M && N == other.N && K == other.K &&
        batch_size == other.batch_size &&
        lda == other.lda && ldb == other.ldb && ldc == other.ldc &&
        dt_a == other.dt_a && dt_b == other.dt_b && dt_c == other.dt_c &&
        alpha == other.alpha && beta == other.beta;
  }
};

struct PackKey {
  int64_t K;
  int64_t N;
  int64_t ld_in;
  int64_t ld_out;
  ScalarType dt_in;
  ScalarType dt_out;
  PackKey(
      int64_t K, int64_t N, int64_t ld_in, int64_t ld_out,
      ScalarType dt_in, ScalarType dt_out)
      : K(K), N(N), ld_in(ld_in), ld_out(ld_out),
        dt_in(dt_in), dt_out(dt_out) {}
  bool operator==(const PackKey& other) const {
    return N == other.N && K == other.K &&
        ld_in == other.ld_in && ld_out == other.ld_out &&
        dt_in == other.dt_in && dt_out == other.dt_out;
  }
};

inline dnnl::memory::data_type get_dnnl_dtype(ScalarType dtype) {
  if (dtype == ScalarType::Float) {
    return dnnl::memory::data_type::f32;
  } else if (dtype == ScalarType::BFloat16) {
    return dnnl::memory::data_type::bf16;
  } else if (dtype == ScalarType::Half) {
    return dnnl::memory::data_type::f16;
  } else if (dtype == ScalarType::Byte) {
    return dnnl::memory::data_type::u8;
  } else if (dtype == ScalarType::Char) {
    return dnnl::memory::data_type::s8;
  } else {
    TORCH_CHECK(false, "get_dnnl_dtype expects float/bfloat16/half/int8 tensor input");
  }
}

template <typename key_t>
struct UnsafeUkernelKeyHasher {
  std::size_t operator()(const key_t& key) const;
};

template <>
std::size_t UnsafeUkernelKeyHasher<BrgemmKey>::operator()(const BrgemmKey& key) const {
  // Use beta, M, N, and K to compute the hash to reduce the overhead, as
  // batch size, alpha, and data types are unlikely to change within the same
  // kernel, and the leading dimensions are likely to be related to M, K, N or
  // use fixed values.
  std::size_t h = std::hash<float>()(key.beta + 1);
  h = std::hash<int64_t>()(key.M) ^ (h << 1);
  h = std::hash<int64_t>()(key.N) ^ (h << 1);
  h = std::hash<int64_t>()(key.K) ^ (h << 1);
  h = std::hash<int64_t>()(key.ldc) ^ (h << 1);
  return h;
}

template <>
std::size_t UnsafeUkernelKeyHasher<PackKey>::operator()(const PackKey& key) const {
  // Use K and N to compute the hash to reduce the overhead, as data types are
  // unlikely to change and ld_in/ld_out is likely to be related to K, N or use
  // fixed values.
  std::size_t h = std::hash<int64_t>()(key.K);
  h = std::hash<int64_t>()(key.N) ^ (h << 1);
  return h;
}

template <typename key_t, typename value_t>
struct KernelCache {
  using kstore_t = std::unordered_map<key_t, std::shared_ptr<value_t>, UnsafeUkernelKeyHasher<key_t>>;
  static inline std::shared_ptr<value_t>&& fetch_or_create(
      const key_t& key,
      const std::function<std::shared_ptr<value_t>()>& callback) {
    auto&& search = get_store().find(key);
    if (search != get_store().end()) {
      return std::move(search->second);
    } else {
      get_store().insert({key, callback()});
      return std::move(get_store()[key]);
    }
  }

  static inline kstore_t& get_store() {
    static thread_local kstore_t cache_kernels;
    return cache_kernels;
  }
};
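
// Usage sketch (illustrative): a key type is paired with a kernel wrapper and
// kernels are built lazily, once per thread, e.g.
//   auto&& helper = KernelCache<BrgemmKey, GemmHelper>::fetch_or_create(
//       key, [&]() { return std::make_shared<GemmHelper>(/*...*/); });
// This is how Brgemm::call() below obtains its GemmHelper. Because get_store()
// is thread_local, no locking is needed around the unordered_map.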

// Helper struct for convenient brgemm configuration
struct GemmHelper {
  GemmHelper(
      int64_t M, int64_t N, int64_t K, int64_t bs,
      int64_t ld_a, int64_t ld_b, int64_t ld_c,
      ScalarType dt_a, ScalarType dt_b, ScalarType dt_c,
      const float alpha, const float beta) {
    // Create brgemm
    brg = dnnl::ukernel::brgemm(
        M, N, K, bs, ld_a, ld_b, ld_c,
        get_dnnl_dtype(dt_a), get_dnnl_dtype(dt_b), get_dnnl_dtype(dt_c),
        alpha, beta);
    // Create a scratchpad buffer for the brgemm execution
    scratchpad = std::vector<uint8_t>(brg.get_scratchpad_size());
    // Prepare the default vector of A/B offset pairs for each batch.
    A_B_offsets.reserve(1);
    A_B_offsets.emplace_back(0, 0);
  }
  dnnl::ukernel::brgemm brg;
  std::vector<uint8_t> scratchpad;
  std::vector<std::pair<int64_t, int64_t>> A_B_offsets;
};

struct Brgemm : public KernelCache<BrgemmKey, GemmHelper> {
  // Fetch/create a GemmHelper object and execute brgemm with batch size = 1
  template <typename scalar_t_a, typename scalar_t_b, typename scalar_t_c>
  static inline void call(
      int64_t M, int64_t N, int64_t K,
      int64_t ld_a, int64_t ld_b, int64_t ld_c,
      const float alpha, const float beta,
      const scalar_t_a* A, const scalar_t_b* B, scalar_t_c* C) {
    auto&& key = BrgemmKey(
        M, N, K, int64_t(1),
        ld_a, ld_b, ld_c,
        c10::CppTypeToScalarType<scalar_t_a>::value,
        c10::CppTypeToScalarType<scalar_t_b>::value,
        c10::CppTypeToScalarType<scalar_t_c>::value,
        alpha, beta);
    // Fetch/create a GemmHelper object
    auto&& value = fetch_or_create(key, [&]() {
      auto&& v = std::make_shared<GemmHelper>(
          M, N, K, 1,
          ld_a, ld_b, ld_c,
          c10::CppTypeToScalarType<scalar_t_a>::value,
          c10::CppTypeToScalarType<scalar_t_b>::value,
          c10::CppTypeToScalarType<scalar_t_c>::value,
          alpha, beta);
      (*v).brg.generate();
      return std::move(v);
    });
    if (get_current() != value) {
      dnnl::ukernel::brgemm::release_hw_context();
      ((*value).brg).set_hw_context();
      get_current() = value;
    }
    ((*value).brg)
        .execute(A, B, (*value).A_B_offsets, C, (*value).scratchpad.data());
  }

  static inline std::shared_ptr<GemmHelper>& get_current() {
    static thread_local std::shared_ptr<GemmHelper> current;
    return current;
  }

  static inline bool device_check(ScalarType dtype) {
    if (!at::globalContext().userEnabledMkldnn()) {
      return false;
    }
    if (dtype == ScalarType::Half) {
      static bool fp16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_fp16;
      return fp16_support;
    }
    return false;
  }
};

using pack_t = dnnl::ukernel::brgemm_pack_B;
struct Pack : public KernelCache<PackKey, pack_t> {
  static inline void call(
      int64_t K, int64_t N, int64_t ld_in, int64_t ld_out,
      ScalarType dt_in, ScalarType dt_out,
      const void* in, void* out) {
    auto&& key = PackKey(K, N, ld_in, ld_out, dt_in, dt_out);
    auto&& pack = fetch_or_create(key, [&]() {
      auto&& p = std::make_shared<pack_t>(
          K, N, ld_in, ld_out, get_dnnl_dtype(dt_in), get_dnnl_dtype(dt_out));
      if (need_pack(dt_in)) {
        (*p).generate();
      }
      return std::move(p);
    });
    if (need_pack(dt_in)) {
      (*pack).execute(in, out);
    } else {
      TORCH_CHECK(false, "No need to pack");
    }
  }

  static inline bool need_pack(ScalarType dtype) {
    if (!at::globalContext().userEnabledMkldnn()) {
      return false;
    }
    if (dtype == ScalarType::Half) {
      static bool fp16_pack = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_amx_fp16;
      return fp16_pack;
    }
    return false;
  }
};
#endif
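
// Public ukernel entry points. The intended flow (sketch): callers query
// need_pack(dtype) first; if it returns true, B is repacked once via pack()
// into the blocked layout the ukernel expects, brgemm() is then called
// (possibly many times) with the packed B, and brgemm_release() drops the
// hardware context that the ukernel configured for the current thread.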
void brgemm(
    int64_t M, int64_t N, int64_t K,
    int64_t ld_a, int64_t ld_b, int64_t ld_c,
    const float alpha, const float beta,
    const at::Half* A, const at::Half* B, float* C) {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
  if (Brgemm::device_check(ScalarType::Half)) {
    Brgemm::call<at::Half, at::Half, float>(M, N, K, ld_a, ld_b, ld_c, alpha, beta, A, B, C);
    return;
  }
#endif
  TORCH_CHECK(false,
      "Half Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512_fp16 is supported");
}

void brgemm_release() {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
  dnnl::ukernel::brgemm::release_hw_context();
#endif
}

void pack(
    int64_t K, int64_t N, int64_t ld_in, int64_t ld_out,
    ScalarType dt_in, ScalarType dt_out,
    const void* in, void* out) {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
  Pack::call(K, N, ld_in, ld_out, dt_in, dt_out, in, out);
#else
  TORCH_CHECK(false, "pack is only supported on X64 with oneDNN ukernel enabled");
#endif
}

bool need_pack(ScalarType dt_in) {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
  return Pack::need_pack(dt_in);
#else
  return false;
#endif
}

} // namespace at::native::cpublas