xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
17 
18 #include "absl/base/attributes.h"
19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20 
21 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
22 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
23 #endif
24 
25 namespace {
26 
// Reports whether `ptr` lies on a 16-byte boundary, i.e. whether its address
// is an exact multiple of 16.
bool Is16BytesAligned(void* ptr) {
  constexpr uintptr_t kRequiredAlignment = 16;
  const uintptr_t address = reinterpret_cast<uintptr_t>(ptr);
  return address % kRequiredAlignment == 0;
}
30 
31 template <typename T, Eigen::AlignmentType Alignment>
MatMul(const void * run_options_ptr,T * out,T * lhs,T * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)32 void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64_t m,
33             int64_t n, int64_t k, int32_t transpose_lhs,
34             int32_t transpose_rhs) {
35   int64_t lhs_rows = m;
36   int64_t lhs_cols = k;
37   if (transpose_lhs) {
38     std::swap(lhs_rows, lhs_cols);
39   }
40 
41   int64_t rhs_rows = k;
42   int64_t rhs_cols = n;
43   if (transpose_rhs) {
44     std::swap(rhs_rows, rhs_cols);
45   }
46 
47   const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> A(lhs, lhs_rows,
48                                                                  lhs_cols);
49   const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> B(rhs, rhs_rows,
50                                                                  rhs_cols);
51   Eigen::TensorMap<Eigen::Tensor<T, 2>, Alignment> C(out, m, n);
52 
53   typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
54   int lhs_contract_dim = transpose_lhs ? 0 : 1;
55   int rhs_contract_dim = transpose_rhs ? 1 : 0;
56   const Eigen::array<DimPair, 1> dims(
57       {DimPair(lhs_contract_dim, rhs_contract_dim)});
58 
59   // Matrix multiply is a special case of the "contract" operation where
60   // the contraction is performed along dimension 1 of the lhs and dimension
61   // 0 of the rhs.
62   C = A.contract(B, dims);
63 }
64 
65 template <typename T>
SingleThreadedMatMulDispatch(const void * run_options_ptr,T * out,T * lhs,T * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)66 void SingleThreadedMatMulDispatch(const void* run_options_ptr, T* out, T* lhs,
67                                   T* rhs, int64_t m, int64_t n, int64_t k,
68                                   int32_t transpose_lhs,
69                                   int32_t transpose_rhs) {
70   bool all_buffers_16b_aligned =
71       Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs);
72 
73   if (!all_buffers_16b_aligned) {
74     MatMul<T, Eigen::Unaligned>(run_options_ptr, out, lhs, rhs, m, n, k,
75                                 transpose_lhs, transpose_rhs);
76   }
77 
78   MatMul<T, Eigen::Aligned16>(run_options_ptr, out, lhs, rhs, m, n, k,
79                               transpose_lhs, transpose_rhs);
80 }
81 
82 }  // namespace
83 
84 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulF16(const void * run_options_ptr,Eigen::half * out,Eigen::half * lhs,Eigen::half * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)85 __xla_cpu_runtime_EigenSingleThreadedMatMulF16(
86     const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
87     Eigen::half* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
88     int32_t transpose_rhs) {
89   SingleThreadedMatMulDispatch<Eigen::half>(run_options_ptr, out, lhs, rhs, m,
90                                             n, k, transpose_lhs, transpose_rhs);
91 }
92 
93 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void * run_options_ptr,float * out,float * lhs,float * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)94 __xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr,
95                                                float* out, float* lhs,
96                                                float* rhs, int64_t m, int64_t n,
97                                                int64_t k, int32_t transpose_lhs,
98                                                int32_t transpose_rhs) {
99   SingleThreadedMatMulDispatch<float>(run_options_ptr, out, lhs, rhs, m, n, k,
100                                       transpose_lhs, transpose_rhs);
101 }
102 
103 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void * run_options_ptr,double * out,double * lhs,double * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)104 __xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr,
105                                                double* out, double* lhs,
106                                                double* rhs, int64_t m,
107                                                int64_t n, int64_t k,
108                                                int32_t transpose_lhs,
109                                                int32_t transpose_rhs) {
110   SingleThreadedMatMulDispatch<double>(run_options_ptr, out, lhs, rhs, m, n, k,
111                                        transpose_lhs, transpose_rhs);
112 }
113 
114 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulC64(const void * run_options_ptr,std::complex<float> * out,std::complex<float> * lhs,std::complex<float> * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)115 __xla_cpu_runtime_EigenSingleThreadedMatMulC64(
116     const void* run_options_ptr, std::complex<float>* out,
117     std::complex<float>* lhs, std::complex<float>* rhs, int64_t m, int64_t n,
118     int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
119   SingleThreadedMatMulDispatch<std::complex<float>>(
120       run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
121 }
122 
123 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulC128(const void * run_options_ptr,std::complex<double> * out,std::complex<double> * lhs,std::complex<double> * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)124 __xla_cpu_runtime_EigenSingleThreadedMatMulC128(
125     const void* run_options_ptr, std::complex<double>* out,
126     std::complex<double>* lhs, std::complex<double>* rhs, int64_t m, int64_t n,
127     int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
128   SingleThreadedMatMulDispatch<std::complex<double>>(
129       run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
130 }
131 
132 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulS32(const void * run_options_ptr,int32_t * out,int32_t * lhs,int32_t * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)133 __xla_cpu_runtime_EigenSingleThreadedMatMulS32(const void* run_options_ptr,
134                                                int32_t* out, int32_t* lhs,
135                                                int32_t* rhs, int64_t m,
136                                                int64_t n, int64_t k,
137                                                int32_t transpose_lhs,
138                                                int32_t transpose_rhs) {
139   SingleThreadedMatMulDispatch<int32_t>(run_options_ptr, out, lhs, rhs, m, n, k,
140                                         transpose_lhs, transpose_rhs);
141 }
142