/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"

#define EIGEN_USE_THREADS

#include "absl/base/dynamic_annotations.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_lightweight_check.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

namespace {

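// Returns true if `ptr` lies on a 16-byte boundary. The dispatch helpers
// below use this to decide whether Eigen may assume Aligned16 buffers, which
// lets it use aligned vector loads and stores.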
bool Is16BytesAligned(void* ptr) {
  return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
}

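// Single matrix multiply: computes out = lhs * rhs (with either operand
// optionally transposed), where lhs is m x k, rhs is k x n, and out is m x n.
// The buffers are mapped as column-major Eigen tensors (Eigen::Tensor's
// default layout), and the product is evaluated as a tensor contraction over
// one dimension pair on the intra-op thread pool.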
template <typename T, Eigen::AlignmentType Alignment>
void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64_t m,
            int64_t n, int64_t k, int32_t transpose_lhs,
            int32_t transpose_rhs) {
  const xla::ExecutableRunOptions* run_options =
      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);

  int64_t lhs_rows = m;
  int64_t lhs_cols = k;
  if (transpose_lhs) {
    std::swap(lhs_rows, lhs_cols);
  }

  int64_t rhs_rows = k;
  int64_t rhs_cols = n;
  if (transpose_rhs) {
    std::swap(rhs_rows, rhs_cols);
  }

  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> A(lhs, lhs_rows,
                                                                 lhs_cols);
  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> B(rhs, rhs_rows,
                                                                 rhs_cols);
  Eigen::TensorMap<Eigen::Tensor<T, 2>, Alignment> C(out, m, n);

  typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
  int lhs_contract_dim = transpose_lhs ? 0 : 1;
  int rhs_contract_dim = transpose_rhs ? 1 : 0;
  const Eigen::array<DimPair, 1> dims(
      {DimPair(lhs_contract_dim, rhs_contract_dim)});

  // Matrix multiply is a special case of the "contract" operation where the
  // contraction is performed along dimension 1 of the lhs and dimension 0 of
  // the rhs; the transpose flags above swap the contracted dimension for the
  // corresponding operand.
  XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr);
  C.device(*run_options->intra_op_thread_pool()) = A.contract(B, dims);
}

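// Batched matrix multiply: lhs, rhs, and out each hold batch_size matrices,
// with the batch index as the last (and, in column-major layout, outermost)
// tensor dimension. Each batch element is sliced out with chip() and
// multiplied independently on the intra-op thread pool.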
template <typename T, Eigen::AlignmentType Alignment>
void MatMul_Batch(const void* run_options_ptr, T* out, T* lhs, T* rhs,
                  int64_t m, int64_t n, int64_t k, Eigen::Index batch_size,
                  int32_t transpose_lhs, int32_t transpose_rhs) {
  const xla::ExecutableRunOptions* run_options =
      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);

  int64_t lhs_rows = m;
  int64_t lhs_cols = k;
  if (transpose_lhs) {
    std::swap(lhs_rows, lhs_cols);
  }

  int64_t rhs_rows = k;
  int64_t rhs_cols = n;
  if (transpose_rhs) {
    std::swap(rhs_rows, rhs_cols);
  }

  const Eigen::TensorMap<Eigen::Tensor<const T, 3>, Alignment> A(
      lhs, lhs_rows, lhs_cols, batch_size);
  const Eigen::TensorMap<Eigen::Tensor<const T, 3>, Alignment> B(
      rhs, rhs_rows, rhs_cols, batch_size);
  Eigen::TensorMap<Eigen::Tensor<T, 3>, Alignment> C(out, m, n, batch_size);

  typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
  int lhs_contract_dim = transpose_lhs ? 0 : 1;
  int rhs_contract_dim = transpose_rhs ? 1 : 0;

  const Eigen::array<DimPair, 1> dims(
      {DimPair(lhs_contract_dim, rhs_contract_dim)});

  // Matrix multiply is a special case of the "contract" operation where the
  // contraction is performed along dimension 1 of the lhs and dimension 0 of
  // the rhs; the transpose flags above swap the contracted dimension for the
  // corresponding operand.
  XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr);

  for (int64_t i = 0; i < batch_size; ++i) {
    C.chip(i, 2).device(*run_options->intra_op_thread_pool()) =
        A.chip(i, 2).contract(B.chip(i, 2), dims);
  }
}

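// Picks the Aligned16 instantiation of MatMul when every buffer is 16-byte
// aligned and falls back to the Unaligned one otherwise; both compute the
// same result, the aligned path merely permits faster loads.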
template <typename T>
void MatMulDispatch(const void* run_options_ptr, T* out, T* lhs, T* rhs,
                    int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
                    int32_t transpose_rhs) {
  bool all_buffers_16b_aligned =
      Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs);

  if (!all_buffers_16b_aligned) {
    MatMul<T, Eigen::Unaligned>(run_options_ptr, out, lhs, rhs, m, n, k,
                                transpose_lhs, transpose_rhs);
    return;
  }

  MatMul<T, Eigen::Aligned16>(run_options_ptr, out, lhs, rhs, m, n, k,
                              transpose_lhs, transpose_rhs);
}

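// Alignment dispatch for the batched kernel, mirroring MatMulDispatch above.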
template <typename T>
void BatchMatMulDispatch(const void* run_options_ptr, T* out, T* lhs, T* rhs,
                         int64_t m, int64_t n, int64_t k, int64_t batch_size,
                         int32_t transpose_lhs, int32_t transpose_rhs) {
  bool all_buffers_16b_aligned =
      Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs);

  if (!all_buffers_16b_aligned) {
    MatMul_Batch<T, Eigen::Unaligned>(run_options_ptr, out, lhs, rhs, m, n, k,
                                      batch_size, transpose_lhs,
                                      transpose_rhs);
    return;
  }
  MatMul_Batch<T, Eigen::Aligned16>(run_options_ptr, out, lhs, rhs, m, n, k,
                                    batch_size, transpose_lhs, transpose_rhs);
}

}  // namespace

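// The __xla_cpu_runtime_* functions below are the entry points that code
// generated by the XLA CPU backend calls at runtime; their signatures must
// match the declarations in runtime_matmul.h. MSan instrumentation is
// disabled on them, presumably because the JIT-compiled callers that
// initialize the argument buffers are invisible to the sanitizer.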
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16(
    const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
    Eigen::half* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
    int32_t transpose_rhs) {
  MatMulDispatch<Eigen::half>(run_options_ptr, out, lhs, rhs, m, n, k,
                              transpose_lhs, transpose_rhs);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32(
    const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m,
    int64_t n, int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
  MatMulDispatch<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
                        transpose_rhs);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64(
    const void* run_options_ptr, double* out, double* lhs, double* rhs,
    int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
    int32_t transpose_rhs) {
  MatMulDispatch<double>(run_options_ptr, out, lhs, rhs, m, n, k,
                         transpose_lhs, transpose_rhs);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC64(
    const void* run_options_ptr, std::complex<float>* out,
    std::complex<float>* lhs, std::complex<float>* rhs, int64_t m, int64_t n,
    int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
  MatMulDispatch<std::complex<float>>(run_options_ptr, out, lhs, rhs, m, n, k,
                                      transpose_lhs, transpose_rhs);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC128(
    const void* run_options_ptr, std::complex<double>* out,
    std::complex<double>* lhs, std::complex<double>* rhs, int64_t m, int64_t n,
    int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
  MatMulDispatch<std::complex<double>>(run_options_ptr, out, lhs, rhs, m, n, k,
                                       transpose_lhs, transpose_rhs);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulS32(
    const void* run_options_ptr, int32_t* out, int32_t* lhs, int32_t* rhs,
    int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
    int32_t transpose_rhs) {
  MatMulDispatch<int32_t>(run_options_ptr, out, lhs, rhs, m, n, k,
                          transpose_lhs, transpose_rhs);
}

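// Batched variant; of the element types above, only F32 gets a batch entry
// point in this file.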
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenBatchMatMulF32(
    const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m,
    int64_t n, int64_t k, int64_t batch_size, int32_t transpose_lhs,
    int32_t transpose_rhs) {
  BatchMatMulDispatch<float>(run_options_ptr, out, lhs, rhs, m, n, k,
                             batch_size, transpose_lhs, transpose_rhs);
}