/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifdef INTEL_MKL
#define EIGEN_USE_THREADS

#include <numeric>

#include "dnnl.hpp"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/mkl_util.h"

using dnnl::stream;
using dnnl::sum;

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T>
class MklAddNOp : public OpKernel {
 public:
  ~MklAddNOp() {}
  explicit MklAddNOp(OpKernelConstruction* context) : OpKernel(context) {}

  TensorShape GetTensorShape(OpKernelContext* ctx, size_t src_index) {
    const Tensor& src_tensor = MklGetInput(ctx, src_index);
    MklDnnShape src_mkl_shape;
    GetMklShape(ctx, src_index, &src_mkl_shape);
    return src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetTfShape()
                                       : src_tensor.shape();
  }

  bool CheckInputShape(OpKernelContext* ctx) {
    const int num_inputs = ctx->num_inputs() / 2;
    const TensorShape src0_shape = GetTensorShape(ctx, 0);

    for (int i = 1; i < num_inputs; ++i) {
      if (!src0_shape.IsSameSize(GetTensorShape(ctx, i))) {
        ctx->SetStatus(errors::InvalidArgument(
            "Inputs to operation ", this->name(), " of type ",
            this->type_string(),
            " must have the same size and shape. Input 0: ",
            src0_shape.DebugString(), " != input ", i, ": ",
            GetTensorShape(ctx, i).DebugString()));
        return false;
      }
    }

    return true;
  }
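
  // Inputs arrive in pairs under the MKL layout pass: each data tensor is
  // accompanied by a meta tensor describing its (possibly opaque) layout,
  // which is why the helpers here divide ctx->num_inputs() by two. For
  // example, an AddN of two addends reaches this kernel as four inputs
  // (two data tensors plus two layout meta tensors), which MklGetInput()
  // and GetMklShape() resolve to the correct underlying indices.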

  // Returns the index of the first input tensor that is in MKL layout, or
  // -1 if no input is in MKL layout.
  int FindMKLInputIndex(OpKernelContext* ctx) {
    int mkl_index = -1;
    const int num_inputs = ctx->num_inputs() / 2;

    MklDnnShape src_mkl_shape;
    for (int i = 0; i < num_inputs; ++i) {
      GetMklShape(ctx, i, &src_mkl_shape);
      if (src_mkl_shape.IsMklTensor()) {
        mkl_index = i;
        break;
      }
    }

    return mkl_index;
  }

  void ComputeScalar(OpKernelContext* ctx) {
    const int num_inputs = ctx->num_inputs() / 2;
    const size_t kOutputIdx = 0;
    TensorShape output_tf_shape;
    MklDnnShape output_mkl_shape;
    Tensor* dst_tensor = nullptr;

    T sum = static_cast<T>(0);
    for (int src_idx = 0; src_idx < num_inputs; ++src_idx) {
      const Tensor& src_tensor = MklGetInput(ctx, src_idx);
      const T* src_i = src_tensor.flat<T>().data();
      sum += src_i[0];
    }

    output_mkl_shape.SetMklTensor(false);
    output_tf_shape = MklGetInput(ctx, kOutputIdx).shape();
    AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
                              output_mkl_shape);

    T* out_o = dst_tensor->flat<T>().data();
    out_o[0] = sum;
  }

  void Compute(OpKernelContext* ctx) override {
    // Each input tensor in MKL layout has an additional meta tensor carrying
    // layout information, so the number of actual input tensors is half the
    // total number of inputs.
    const int num_inputs = ctx->num_inputs() / 2;

    MklDnnShape mkl_shape;
    const size_t kSrc0Idx = 0;
    const size_t kOutputIdx = 0;

    // A single input is forwarded to the output unchanged.
    if (num_inputs == 1) {
      GetMklShape(ctx, kSrc0Idx, &mkl_shape);
      bool input_in_mkl_format = mkl_shape.IsMklTensor();

      if (input_in_mkl_format) {
        ForwardMklTensorInToOut(ctx, kSrc0Idx, kOutputIdx);
      } else {
        ForwardTfTensorInToOut(ctx, kSrc0Idx, kOutputIdx);
      }
      return;
    }

    // Check that all inputs have the same shape.
    if (!CheckInputShape(ctx)) return;

    try {
      TensorShape output_tf_shape;
      MklDnnShape output_mkl_shape;
      const Tensor& src_tensor = MklGetInput(ctx, kSrc0Idx);

      Tensor* dst_tensor = nullptr;

      // Nothing to compute, return.
      if (src_tensor.shape().num_elements() == 0) {
        output_mkl_shape.SetMklTensor(false);
        output_tf_shape = src_tensor.shape();
        AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor,
                                  output_tf_shape, output_mkl_shape);
        return;
      }

      // Scalar inputs are summed directly, without going through oneDNN.
      if (src_tensor.dims() == 0) {
        ComputeScalar(ctx);
        return;
      }

      auto cpu_engine = engine(engine::kind::cpu, 0);
      std::vector<float> coeff(num_inputs, 1.0f);
      std::vector<memory::desc> srcs_pd;
      std::vector<memory> inputs;

      MklDnnData<T> dst(&cpu_engine);
      MklDnnData<T> src(&cpu_engine);
      bool has_mkl_input = false;
      int mkl_input_index = FindMKLInputIndex(ctx);
      MklTensorFormat mkl_data_format;
      TensorFormat tf_data_format;
      memory::format_tag dnn_fmt = memory::format_tag::any;
      if (mkl_input_index >= 0) {
        has_mkl_input = true;
        GetMklShape(ctx, mkl_input_index, &mkl_shape);
        // The MKL input carries the data format information.
        mkl_data_format = mkl_shape.GetTfDataFormat();
        tf_data_format = MklDnnDataFormatToTFDataFormat(mkl_data_format);
        dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format);
      }

      std::shared_ptr<stream> fwd_cpu_stream;
      MklDnnThreadPool eigen_tp(ctx);
      fwd_cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine));

      // Create an MKL-DNN memory descriptor for each input: if all inputs
      // are in TensorFlow format, create blocked memory descriptors;
      // otherwise derive the descriptors from the MKL input's layout.
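      // Concretely, three cases are handled in the loop below:
      //   (1) the input already carries an MKL layout, whose descriptor is
      //       reused as-is;
      //   (2) the input is a plain TF tensor but some other input is in MKL
      //       layout, so the descriptor is built from NCHW/NCDHW dims and
      //       the format tag derived from that MKL input above;
      //   (3) every input is a plain TF tensor, so a simple blocked
      //       (strided) descriptor is built from the TF shape.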
      for (int src_idx = 0; src_idx < num_inputs; ++src_idx) {
        MklDnnShape src_mkl_shape;
        GetMklShape(ctx, src_idx, &src_mkl_shape);
        memory::desc md({}, memory::data_type::undef,
                        memory::format_tag::undef);
        const Tensor& src_tensor = MklGetInput(ctx, src_idx);

        if (src_mkl_shape.IsMklTensor()) {
          md = src_mkl_shape.GetMklLayout();
        } else {
          if (has_mkl_input) {
            memory::dims src_dims;
            if (src_tensor.dims() == 4) {
              src_dims =
                  TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tf_data_format);
            } else {
              DCHECK_EQ(src_tensor.dims(), 5);
              src_dims = TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
                                                    tf_data_format);
            }
            md = memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
          } else {
            // Create a blocked memory descriptor for the TensorFlow-format
            // input.
            auto dims = TFShapeToMklDnnDims(src_tensor.shape());
            auto strides = CalculateTFStrides(dims);
            md = MklDnnData<T>::CreateBlockedMemDesc(dims, strides);
          }
        }
        srcs_pd.push_back(memory::desc(md));
        src.SetUsrMem(md, &src_tensor);
        src.SetUsrMemDataHandle(&src_tensor, fwd_cpu_stream);
        inputs.push_back(src.GetOpMem());
      }

      auto sum_pd = sum::primitive_desc(coeff, srcs_pd, cpu_engine);
      output_mkl_shape.SetMklTensor(has_mkl_input);
      auto output_pd = sum_pd.dst_desc();
      dst.SetUsrMem(output_pd);

      if (has_mkl_input) {
        output_mkl_shape.SetMklLayout(&output_pd);
        output_mkl_shape.SetElemType(MklDnnType<T>());
        output_mkl_shape.SetTfLayout(mkl_shape.GetDimension(),
                                     mkl_shape.GetSizesAsMklDnnDims(),
                                     mkl_shape.GetTfDataFormat());
        output_tf_shape.AddDim(output_pd.get_size() / sizeof(T));
      } else {
        // All inputs have TF shapes; take the output shape from the first.
        output_tf_shape = MklGetInput(ctx, kSrc0Idx).shape();
      }
      AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
                                output_mkl_shape);
      dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);

      // Create the sum primitive and submit it for execution.
      dnnl::sum sum_op(sum_pd);
      std::unordered_map<int, memory> net_args = {
          {DNNL_ARG_DST, dst.GetOpMem()}};
      for (int i = 0; i < num_inputs; ++i) {
        net_args.insert({DNNL_ARG_MULTIPLE_SRC + i, inputs[i]});
      }
      sum_op.execute(*fwd_cpu_stream, net_args);
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(ctx, errors::Aborted("Operation received an exception: ",
                                          error_msg));
    }
  }
};

#define REGISTER_MKL_CPU(T)                                    \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklAddN")                                         \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklAddNOp<CPUDevice, T>);

TF_CALL_float(REGISTER_MKL_CPU);
TF_CALL_bfloat16(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU

}  // namespace tensorflow

#endif  // INTEL_MKL
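
// For reference, a minimal standalone sketch of the dnnl::sum pattern used
// in Compute() above, kept as a comment since it is illustrative and not
// part of the kernel. It assumes the oneDNN v2.x C++ API and two plain f32
// NCHW inputs of shape {1, 8, 4, 4}:
//
//   using namespace dnnl;
//   engine eng(engine::kind::cpu, 0);
//   stream s(eng);
//   memory::desc md({1, 8, 4, 4}, memory::data_type::f32,
//                   memory::format_tag::nchw);
//   memory src0(md, eng), src1(md, eng), dst(md, eng);
//   // Scales {1.0f, 1.0f} make the primitive an unweighted elementwise sum.
//   auto pd = sum::primitive_desc({1.0f, 1.0f}, {md, md}, eng);
//   sum(pd).execute(s, {{DNNL_ARG_MULTIPLE_SRC + 0, src0},
//                       {DNNL_ARG_MULTIPLE_SRC + 1, src1},
//                       {DNNL_ARG_DST, dst}});
//   s.wait();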