/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"

#define EIGEN_USE_THREADS

#include "absl/base/dynamic_annotations.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_lightweight_check.h"

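// When TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL is defined, the header below
// replaces Eigen's default contraction (matmul) kernel with a custom one
// (typically oneDNN-backed); the rest of this file is unaffected, since the
// kernel is picked up inside Eigen's contract().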
#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

namespace {

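// Eigen can emit aligned SIMD loads/stores when told (via Eigen::Aligned16)
// that a buffer is 16-byte aligned. The dispatch functions below check the
// actual pointers at runtime and fall back to Eigen::Unaligned otherwise.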
bool Is16BytesAligned(void* ptr) {
  return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
}

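// Single matmul: out = lhs * rhs, where out is m x n, lhs is m x k and rhs is
// k x n (row/column counts swapped for an operand whose transpose_* flag is
// set). All buffers are interpreted in Eigen's default column-major layout.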
template <typename T, Eigen::AlignmentType Alignment>
void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64_t m,
            int64_t n, int64_t k, int32_t transpose_lhs,
            int32_t transpose_rhs) {
  const xla::ExecutableRunOptions* run_options =
      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);

  int64_t lhs_rows = m;
  int64_t lhs_cols = k;
  if (transpose_lhs) {
    std::swap(lhs_rows, lhs_cols);
  }

  int64_t rhs_rows = k;
  int64_t rhs_cols = n;
  if (transpose_rhs) {
    std::swap(rhs_rows, rhs_cols);
  }

  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> A(lhs, lhs_rows,
                                                                 lhs_cols);
  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> B(rhs, rhs_rows,
                                                                 rhs_cols);
  Eigen::TensorMap<Eigen::Tensor<T, 2>, Alignment> C(out, m, n);

  typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
  int lhs_contract_dim = transpose_lhs ? 0 : 1;
  int rhs_contract_dim = transpose_rhs ? 1 : 0;
  const Eigen::array<DimPair, 1> dims(
      {DimPair(lhs_contract_dim, rhs_contract_dim)});

  // Matrix multiply is a special case of the "contract" operation where the
  // contraction is performed along dimension 1 of the lhs and dimension 0 of
  // the rhs (swapped for whichever operand is transposed).
  XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr);
  C.device(*run_options->intra_op_thread_pool()) = A.contract(B, dims);
}

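// Batched variant: each operand carries a trailing batch dimension, and every
// batch element is contracted independently by slicing it out with chip(i, 2).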
template <typename T, Eigen::AlignmentType Alignment>
void MatMul_Batch(const void* run_options_ptr, T* out, T* lhs, T* rhs,
                  int64_t m, int64_t n, int64_t k, Eigen::Index batch_size,
                  int32_t transpose_lhs, int32_t transpose_rhs) {
  const xla::ExecutableRunOptions* run_options =
      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);

  int64_t lhs_rows = m;
  int64_t lhs_cols = k;
  if (transpose_lhs) {
    std::swap(lhs_rows, lhs_cols);
  }

  int64_t rhs_rows = k;
  int64_t rhs_cols = n;
  if (transpose_rhs) {
    std::swap(rhs_rows, rhs_cols);
  }

  const Eigen::TensorMap<Eigen::Tensor<const T, 3>, Alignment> A(
      lhs, lhs_rows, lhs_cols, batch_size);
  const Eigen::TensorMap<Eigen::Tensor<const T, 3>, Alignment> B(
      rhs, rhs_rows, rhs_cols, batch_size);
  Eigen::TensorMap<Eigen::Tensor<T, 3>, Alignment> C(out, m, n, batch_size);

  typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
  int lhs_contract_dim = transpose_lhs ? 0 : 1;
  int rhs_contract_dim = transpose_rhs ? 1 : 0;

  const Eigen::array<DimPair, 1> dims(
      {DimPair(lhs_contract_dim, rhs_contract_dim)});

  // Matrix multiply is a special case of the "contract" operation where the
  // contraction is performed along dimension 1 of the lhs and dimension 0 of
  // the rhs (swapped for whichever operand is transposed).
  XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr);

  for (int64_t i = 0; i < batch_size; ++i) {
    C.chip(i, 2).device(*run_options->intra_op_thread_pool()) =
        A.chip(i, 2).contract(B.chip(i, 2), dims);
  }
}

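// Runs the Aligned16 instantiation when all three buffers are 16-byte aligned
// and the Unaligned one otherwise; both produce the same result.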
template <typename T>
void MatMulDispatch(const void* run_options_ptr, T* out, T* lhs, T* rhs,
                    int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
                    int32_t transpose_rhs) {
  bool all_buffers_16b_aligned =
      Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs);

  if (!all_buffers_16b_aligned) {
    MatMul<T, Eigen::Unaligned>(run_options_ptr, out, lhs, rhs, m, n, k,
                                transpose_lhs, transpose_rhs);
    return;
  }

  MatMul<T, Eigen::Aligned16>(run_options_ptr, out, lhs, rhs, m, n, k,
                              transpose_lhs, transpose_rhs);
}

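// Same alignment-based dispatch as above, for the batched kernel.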
template <typename T>
void BatchMatMulDispatch(const void* run_options_ptr, T* out, T* lhs, T* rhs,
                         int64_t m, int64_t n, int64_t k, int64_t batch_size,
                         int32_t transpose_lhs, int32_t transpose_rhs) {
  bool all_buffers_16b_aligned =
      Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs);

  if (!all_buffers_16b_aligned) {
    MatMul_Batch<T, Eigen::Unaligned>(run_options_ptr, out, lhs, rhs, m, n, k,
                                      batch_size, transpose_lhs, transpose_rhs);
    return;
  }
  MatMul_Batch<T, Eigen::Aligned16>(run_options_ptr, out, lhs, rhs, m, n, k,
                                    batch_size, transpose_lhs, transpose_rhs);
}

}  // namespace

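// Entry points referenced from XLA-generated CPU code (declared in
// runtime_matmul.h), one per element type. Illustrative call, sketch only
// (assumes `run_options` provides a valid intra-op thread pool and the
// buffers use column-major layout):
//
//   // c (m x n) = a (m x k) * b (k x n), all float buffers.
//   __xla_cpu_runtime_EigenMatMulF32(&run_options, c, a, b,
//                                    /*m=*/128, /*n=*/64, /*k=*/256,
//                                    /*transpose_lhs=*/0,
//                                    /*transpose_rhs=*/0);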
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16(
    const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
    Eigen::half* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
    int32_t transpose_rhs) {
  MatMulDispatch<Eigen::half>(run_options_ptr, out, lhs, rhs, m, n, k,
                              transpose_lhs, transpose_rhs);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32(
    const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m,
    int64_t n, int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
  MatMulDispatch<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
                        transpose_rhs);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64(
    const void* run_options_ptr, double* out, double* lhs, double* rhs,
    int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
    int32_t transpose_rhs) {
  MatMulDispatch<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
                         transpose_rhs);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC64(
    const void* run_options_ptr, std::complex<float>* out,
    std::complex<float>* lhs, std::complex<float>* rhs, int64_t m, int64_t n,
    int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
  MatMulDispatch<std::complex<float>>(run_options_ptr, out, lhs, rhs, m, n, k,
                                      transpose_lhs, transpose_rhs);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC128(
    const void* run_options_ptr, std::complex<double>* out,
    std::complex<double>* lhs, std::complex<double>* rhs, int64_t m, int64_t n,
    int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
  MatMulDispatch<std::complex<double>>(run_options_ptr, out, lhs, rhs, m, n, k,
                                       transpose_lhs, transpose_rhs);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulS32(
    const void* run_options_ptr, int32_t* out, int32_t* lhs, int32_t* rhs,
    int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
    int32_t transpose_rhs) {
  MatMulDispatch<int32_t>(run_options_ptr, out, lhs, rhs, m, n, k,
                          transpose_lhs, transpose_rhs);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenBatchMatMulF32(
    const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m,
    int64_t n, int64_t k, int64_t batch_size, int32_t transpose_lhs,
    int32_t transpose_rhs) {
  BatchMatMulDispatch<float>(run_options_ptr, out, lhs, rhs, m, n, k,
                             batch_size, transpose_lhs, transpose_rhs);
}
205