/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL

#include <unordered_map>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "dnnl.hpp"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"
#ifdef DNNL_AARCH64_USE_ACL
#include "tensorflow/core/platform/mutex.h"
#endif

using dnnl::algorithm;
using dnnl::eltwise_forward;
using dnnl::memory;
using dnnl::prop_kind;
using dnnl::stream;

using EltwiseFwdPd = dnnl::eltwise_forward::primitive_desc;
using EltwiseBwdPd = dnnl::eltwise_backward::primitive_desc;

namespace tensorflow {

template <typename T>
class MklEltwiseFwdParams {
 public:
  memory::dims src_dims;
  memory::desc src_md;
  algorithm alg_kind;
  float alpha;
  float beta;

  MklEltwiseFwdParams(memory::dims src_dims, memory::desc src_md,
                      algorithm alg_kind, float alpha, float beta)
      : src_dims(src_dims),
        src_md(src_md),
        alg_kind(alg_kind),
        alpha(alpha),
        beta(beta) {}
};

template <typename T>
class MklEltwiseFwdPrimitive : public MklPrimitive {
 public:
  explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // create eltwise primitive
    if (context_.eltwise_fwd == nullptr) {
      Setup(fwdParams);
    }
  }

  ~MklEltwiseFwdPrimitive() {}

  // Eltwise forward execute
  //   src_data: input data buffer of src
  //   dst_data: output data buffer of dst
  void Execute(const T* src_data, T* dst_data,
               std::shared_ptr<stream> fwd_stream) {
#ifdef DNNL_AARCH64_USE_ACL
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
                                      *fwd_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
#endif  // !ENABLE_ONEDNN_OPENMP
    DCHECK_EQ(context_.fwd_primitives.size(),
              context_.fwd_primitives_args.size());
    execute_primitives(context_.fwd_primitives, fwd_stream,
                       context_.fwd_primitives_args);

    // After execution, set data handle back.
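    // The primitive object lives in a process-wide cache (see
    // MklEltwiseFwdPrimitiveFactory below) and is reused across calls, so it
    // must not retain pointers into tensors owned by this invocation;
    // DummyData serves as a placeholder until the next call rebinds real
    // buffers.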
    context_.src_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<EltwiseFwdPd> GetEltwiseFwdPd() { return context_.fwd_pd; }

 private:
  // Primitive reuse context for eltwise Fwd ops: Relu, Elu, Tanh
  struct EltwiseFwdContext {
    // oneDNN memory
    std::shared_ptr<memory> src_mem;
    std::shared_ptr<memory> dst_mem;

    // desc & primitive desc
    std::shared_ptr<dnnl::eltwise_forward::desc> fwd_desc;
    std::shared_ptr<EltwiseFwdPd> fwd_pd;

    // memory desc
    std::shared_ptr<memory::desc> src_md;
    std::shared_ptr<memory::desc> dst_md;

    // Eltwise primitive
    std::shared_ptr<dnnl::primitive> eltwise_fwd;

    std::vector<dnnl::primitive> fwd_primitives;

    std::vector<std::unordered_map<int, memory>> fwd_primitives_args;

    EltwiseFwdContext()
        : src_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          dst_md(nullptr),
          eltwise_fwd(nullptr) {}
  };

  // Eltwise forward primitive setup
  void Setup(const MklEltwiseFwdParams<T>& fwdParams) {
    // create memory descriptors for eltwise data with specified format
    context_.src_md.reset(new memory::desc(fwdParams.src_md.data));

    // Create an eltwise forward descriptor and primitive descriptor
    context_.fwd_desc.reset(new eltwise_forward::desc(
        prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
        fwdParams.alpha, fwdParams.beta));
    context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
    auto fwd_pd = context_.fwd_pd.get();

    // Create memory primitive based on dummy data
    context_.src_mem.reset(
        new memory(fwd_pd->src_desc(), cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(fwd_pd->dst_desc(), cpu_engine_, DummyData));
    // Create eltwise primitive and add it to net
    context_.eltwise_fwd.reset(new eltwise_forward(*context_.fwd_pd));
    context_.fwd_primitives_args.push_back(
        {{DNNL_ARG_SRC, *context_.src_mem}, {DNNL_ARG_DST, *context_.dst_mem}});
    context_.fwd_primitives.push_back(*context_.eltwise_fwd);
  }

  struct EltwiseFwdContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  mutex primitive_execution_mu_;
#endif
};

template <typename T>
class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklEltwiseFwdPrimitive<T>* Get(
      const MklEltwiseFwdParams<T>& fwdParams) {
    MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;

    // Get an eltwise fwd primitive from the cached pool
    eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
        MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(
            fwdParams));
    if (eltwise_forward == nullptr) {
      eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
      MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
          fwdParams, eltwise_forward);
    }

    return eltwise_forward;
  }

  static MklEltwiseFwdPrimitiveFactory& GetInstance() {
    static MklEltwiseFwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklEltwiseFwdPrimitiveFactory() {}
  ~MklEltwiseFwdPrimitiveFactory() {}

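  // The cache key concatenates the op family, the source dimensions, the
  // algorithm kind, and alpha/beta through FactoryKeyCreator. Conceptually
  // (illustrative only; the exact encoding is defined by FactoryKeyCreator):
  //   "eltwise_fwd" + src_dims + alg_kind + alpha + beta
  // so primitives are shared only between calls that agree on all of these
  // fields.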
  static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams) {
    string prefix = "eltwise_fwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(fwdParams.src_dims);
    key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
    key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
    key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
    return key_creator.GetKey();
  }

  MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams) {
    string key = CreateKey(fwdParams);
    return this->GetOp(key);
  }

  void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
                     MklPrimitive* op) {
    string key = CreateKey(fwdParams);
    this->SetOp(key, op);
  }
};

template <typename T>
class MklEltwiseBwdParams {
 public:
  memory::dims src_dims;
  memory::desc common_md;
  algorithm alg_kind;
  float alpha;
  float beta;
  // Whether the input that the grad op gets from the forward op is the SRC
  // or the DST of the forward op.
  int forward_input_type;

  MklEltwiseBwdParams(const memory::dims& src_dims,
                      const memory::desc& common_md, algorithm alg_kind,
                      float alpha, float beta, int forward_input_type = -1)
      : src_dims(src_dims),
        common_md(common_md),
        alg_kind(alg_kind),
        alpha(alpha),
        beta(beta),
        forward_input_type(forward_input_type) {}
};

template <typename T>
class MklEltwiseBwdPrimitive : public MklPrimitive {
 public:
  explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // create eltwise primitive
    if (context_.eltwise_bwd == nullptr) {
      Setup(bwdParams);
    }
  }

  ~MklEltwiseBwdPrimitive() {}

  // Eltwise backward execute
  //   src_data:      input data buffer of src
  //   diff_dst_data: input data buffer of diff_dst
  //   diff_src_data: output data buffer of diff_src
  void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data,
               std::shared_ptr<stream> bwd_stream) {
#ifdef DNNL_AARCH64_USE_ACL
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
    context_.diff_dst_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);
    context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
                                           *bwd_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.diff_dst_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_dst_data)));
    context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
#endif  // !ENABLE_ONEDNN_OPENMP
    DCHECK_EQ(context_.bwd_primitives.size(),
              context_.bwd_primitives_args.size());
    execute_primitives(context_.bwd_primitives, bwd_stream,
                       context_.bwd_primitives_args);

    // after execution, set data handle back
    context_.src_mem->set_data_handle(DummyData);
    context_.diff_dst_mem->set_data_handle(DummyData);
    context_.diff_src_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<EltwiseBwdPd> GetEltwiseBwdPd() { return context_.bwd_pd; }

 private:
  // Primitive reuse context for eltwise Bwd ops: Relu, Elu, Tanh
  struct EltwiseBwdContext {
    // oneDNN memory
    std::shared_ptr<memory> src_mem;
    std::shared_ptr<memory> diff_dst_mem;
    std::shared_ptr<memory> diff_src_mem;

    // Backward Eltwise descriptor.
    std::shared_ptr<dnnl::eltwise_backward::desc> bwd_desc;

    // Memory descriptors.
    std::shared_ptr<memory::desc> src_md;
    std::shared_ptr<memory::desc> diff_dst_md;
    std::shared_ptr<memory::desc> common_md;

    // Forward and backward descriptors and primitive descriptors.
    std::shared_ptr<dnnl::eltwise_forward::desc> fwd_desc;
    std::shared_ptr<EltwiseFwdPd> fwd_pd;
    std::shared_ptr<EltwiseBwdPd> bwd_pd;

    // Eltwise primitive.
    std::shared_ptr<dnnl::primitive> eltwise_bwd;

    std::vector<dnnl::primitive> bwd_primitives;

    std::vector<MemoryArgsMap> bwd_primitives_args;

    EltwiseBwdContext()
        : src_mem(nullptr),
          diff_dst_mem(nullptr),
          diff_src_mem(nullptr),
          src_md(nullptr),
          diff_dst_md(nullptr),
          common_md(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          bwd_pd(nullptr),
          eltwise_bwd(nullptr) {}
  };

  // Eltwise backward primitive setup
  void Setup(const MklEltwiseBwdParams<T>& bwdParams) {
    // Create memory descriptors for eltwise data w/ no specified format
    context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
    context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));

    // Create forward eltwise primitive.
    context_.fwd_desc.reset(new dnnl::eltwise_forward::desc(
        prop_kind::forward_training, bwdParams.alg_kind, *context_.src_md,
        bwdParams.alpha, bwdParams.beta));
    context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
    context_.bwd_desc.reset(new dnnl::eltwise_backward::desc(
        bwdParams.alg_kind, *context_.diff_dst_md, *context_.src_md,
        bwdParams.alpha, bwdParams.beta));
    context_.bwd_pd.reset(
        new EltwiseBwdPd(*context_.bwd_desc, cpu_engine_, *context_.fwd_pd));

    auto bwd_pd = context_.bwd_pd.get();

    // Create memory primitive based on dummy data.
    context_.src_mem.reset(
        new memory(bwd_pd->src_desc(), cpu_engine_, DummyData));
    context_.diff_dst_mem.reset(
        new memory(bwd_pd->diff_dst_desc(), cpu_engine_, DummyData));
    context_.diff_src_mem.reset(
        new memory(bwd_pd->diff_src_desc(), cpu_engine_, DummyData));
    // Create eltwise primitive and add it to net.
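    // The first execution argument below is keyed by
    // bwdParams.forward_input_type: DNNL_ARG_SRC for Relu/Elu/Relu6/LeakyRelu
    // gradients (which consume x) and DNNL_ARG_DST for TanhGrad (which
    // consumes y), so the same context serves both kinds of gradient ops.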
    context_.eltwise_bwd.reset(new dnnl::eltwise_backward(*context_.bwd_pd));
    context_.bwd_primitives_args.push_back(
        {{bwdParams.forward_input_type, *context_.src_mem},
         {DNNL_ARG_DIFF_DST, *context_.diff_dst_mem},
         {DNNL_ARG_DIFF_SRC, *context_.diff_src_mem}});

    context_.bwd_primitives.push_back(*context_.eltwise_bwd);
  }

  struct EltwiseBwdContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  mutex primitive_execution_mu_;
#endif
};

template <typename T>
class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 private:
  MklEltwiseBwdPrimitiveFactory() {}
  ~MklEltwiseBwdPrimitiveFactory() {}

 public:
  static MklEltwiseBwdPrimitive<T>* Get(
      const MklEltwiseBwdParams<T>& bwdParams) {
    MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr;

    // Try to find a suitable primitive in the cached pool.
    eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
        MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
            bwdParams));

    if (eltwise_backward == nullptr) {
      eltwise_backward = new MklEltwiseBwdPrimitive<T>(bwdParams);
      MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
          bwdParams, eltwise_backward);
    }
    return eltwise_backward;
  }

  static MklEltwiseBwdPrimitiveFactory& GetInstance() {
    static MklEltwiseBwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams) {
    string prefix = "eltwise_bwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(bwdParams.src_dims);
    key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
    key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
    key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
    return key_creator.GetKey();
  }

  MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams) {
    string key = CreateKey(bwdParams);
    return this->GetOp(key);
  }

  void SetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
                     MklPrimitive* op) {
    string key = CreateKey(bwdParams);
    this->SetOp(key, op);
  }
};

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T, algorithm alg_kind>
class MklReluOpBase : public OpKernel {
 public:
  ~MklReluOpBase() {}

  explicit MklReluOpBase(OpKernelConstruction* context, float alpha,
                         float beta)
      : OpKernel(context), alpha_(alpha), beta_(beta) {}
  virtual void Compute_Scalar(OpKernelContext* context) = 0;

  void Compute(OpKernelContext* context) override {
    try {
      const size_t src_index = 0;  // index of src input tensor
      const size_t dst_index = 0;  // index of dst output tensor
      const Tensor& src_tensor = MklGetInput(context, src_index);
      MklDnnShape dnn_shape_src;
      GetMklShape(context, src_index, &dnn_shape_src);
      if (src_tensor.dims() == 0) {
        Compute_Scalar(context);
        return;
      }
      MklDnnShape dnn_shape_dst;
      TensorShape tf_shape_dst;
      Tensor* dst_tensor = nullptr;
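      // The non-scalar path below: (1) return early for empty inputs,
      // (2) build a source memory descriptor (either the incoming MKL layout
      // or a blocked descriptor derived from the TF shape), (3) fetch or
      // create a cached eltwise forward primitive, (4) reorder src if its
      // layout differs from what the primitive expects, (5) allocate or
      // forward the output tensor, and (6) execute the primitive.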
      // Nothing to compute, return.
      if (src_tensor.shape().num_elements() == 0) {
        dnn_shape_dst.SetMklTensor(false);
        tf_shape_dst = MklGetInput(context, src_index).shape();
        AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                                  tf_shape_dst, dnn_shape_dst);
        return;
      }
      // Set DNN primitive - src
      MklDnnData<T> src(&cpu_engine);
      memory::dims src_dims;
      memory::desc src_md({}, memory::data_type::undef,
                          memory::format_tag::undef);
      if (dnn_shape_src.IsMklTensor()) {
        src_md = dnn_shape_src.GetMklLayout();
        src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
      } else {
        src_dims = TFShapeToMklDnnDims(src_tensor.shape());
        auto src_strides = CalculateTFStrides(src_dims);
        // Create blocked memory descriptor
        src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
      }
      // Try to get an eltwise forward primitive from the caching pool
      MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha_,
                                       beta_);
      MklEltwiseFwdPrimitive<T>* eltwise_fwd =
          MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);
      auto eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd();
      std::shared_ptr<stream> fwd_cpu_stream;
      MklDnnThreadPool eigen_tp(context);
      fwd_cpu_stream.reset(CreateStream(&eigen_tp, eltwise_fwd->GetEngine()));
      // Check if src needs to be reordered
      bool is_src_reordered = false;
      const T* src_data = src_tensor.flat<T>().data();
      if (src_md != eltwise_fwd_pd->src_desc()) {
        src.SetUsrMem(src_md, &src_tensor);
        src.CheckReorderToOpMem(eltwise_fwd_pd->src_desc(), cpu_engine,
                                context);
        src_data = const_cast<T*>(
            reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
        is_src_reordered = true;
      }

      // If src is reordered, then the dst tensor will be in blocked layout,
      // so we propagate this blocked layout to the output. We follow the
      // same logic when src is in blocked (MKL) layout to start off with.
      if (is_src_reordered || dnn_shape_src.IsMklTensor()) {
        dnn_shape_dst.SetMklTensor(true);
        auto dst_pd = eltwise_fwd_pd->dst_desc();
        dnn_shape_dst.SetMklLayout(&dst_pd);
        dnn_shape_dst.SetElemType(MklDnnType<T>());
        if (dnn_shape_src.IsMklTensor()) {
          dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
                                    dnn_shape_src.GetSizesAsMklDnnDims(),
                                    dnn_shape_src.GetTfDataFormat());
        } else {
          dnn_shape_dst.SetTfLayout(src_tensor.dims(),
                                    TFShapeToMklDnnDims(src_tensor.shape()),
                                    MklTensorFormat::FORMAT_BLOCKED);
        }
        tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
      } else {
        // If src is neither in blocked layout nor reordered, then dst is in
        // native layout.
        dnn_shape_dst.SetMklTensor(false);
        tf_shape_dst = src_tensor.shape();
      }

      if (is_src_reordered) {
        // If src is reordered, then src and dst would be in different layouts.
        AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                                  tf_shape_dst, dnn_shape_dst);
      } else {
        // Forwarding input to output works only when the layouts of the src
        // and dst tensors remain the same -- either both of them are in
        // native layout or both are in blocked (MKL) layout.
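        // forward_input_or_allocate_output() reuses the src buffer for dst
        // when the runtime allows it (e.g. the input buffer is not referenced
        // elsewhere); otherwise it allocates a fresh tensor of tf_shape_dst.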
        OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                    {static_cast<const int>(src_index)},
                                    static_cast<const int>(dst_index),
                                    tf_shape_dst, &dst_tensor));
        AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst);
      }
      T* dst_data = dst_tensor->flat<T>().data();

      // execute eltwise
      eltwise_fwd->Execute(src_data, dst_data, fwd_cpu_stream);
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  engine cpu_engine = engine(engine::kind::cpu, 0);
  std::shared_ptr<EltwiseFwdPd> relu_fwd_pd;

 protected:
  float alpha_;
  float beta_;
};

template <typename Device, typename T, algorithm alg_kind>
class MklReluGradOpBase : public OpKernel {
 public:
  ~MklReluGradOpBase() {}

  explicit MklReluGradOpBase(OpKernelConstruction* context, float alpha,
                             float beta)
      : OpKernel(context), alpha_(alpha), beta_(beta) {}

  virtual void Compute_Scalar(OpKernelContext* context) = 0;

  // All activation functions that are part of NN ops, such as Relu, Elu,
  // LeakyRelu, and Relu6, have dy at index 0 and y at index 1.
  //
  // If the forward op is defined as y = f(x), then
  // {Relu,Elu,Relu6,LeakyRelu}Grad is z = f_grad(dy, x)
  // and TanhGrad is z = tanh_grad(y, dy).
  //
  // 'Src' below refers to the tensor that the gradient op receives from the
  // forward operator. For Relu-family ops it is 'x'; for TanhGrad it is 'y'.
  virtual int GetDiffDstIndex() const { return 0; }
  virtual int GetSrcIndex() const { return 1; }
  virtual int GetDiffSrcIndex() const { return 0; }
  // What is the type of input tensor that the grad op receives from the
  // forward op -- is it 'x' (SRC) or 'y' (DST)? For the Relu family it is
  // 'x', so fwd op SRC.

  virtual int GetTypeOfInputTensorFromFwdOp() const { return DNNL_ARG_SRC; }

  void Compute(OpKernelContext* context) {
    try {
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> diff_dst(&cpu_engine);

      size_t diff_dst_index = GetDiffDstIndex();
      size_t src_index = GetSrcIndex();
      const size_t diff_src_index = GetDiffSrcIndex();

      const Tensor& src_tensor = MklGetInput(context, src_index);
      const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
      Tensor* diff_src_tensor = nullptr;

      MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
      GetMklShape(context, src_index, &dnn_shape_src);
      GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

      int src_dims_size = src_tensor.dims();
      if (src_dims_size == 0) {
        Compute_Scalar(context);
        return;
      }

      TensorShape tf_shape_diff_src;
      MklDnnShape dnn_shape_diff_src;
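      // As in the forward kernel: handle empty inputs first, then build
      // memory descriptors for src and diff_dst covering the four
      // combinations of native (TF) vs. blocked (MKL) layouts, choose a
      // common descriptor for both inputs, reorder if needed, and run the
      // cached backward primitive.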
      // Nothing to compute, return.
      if (src_tensor.shape().num_elements() == 0) {
        dnn_shape_diff_src.SetMklTensor(false);
        tf_shape_diff_src = MklGetInput(context, diff_src_index).shape();
        AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                                  tf_shape_diff_src, dnn_shape_diff_src);
        return;
      }

      // Get an eltwise bwd primitive from the primitive pool.
      memory::dims src_dims = {};
      memory::desc src_md({}, memory::data_type::undef,
                          memory::format_tag::undef);
      memory::desc diff_dst_md({}, memory::data_type::undef,
                               memory::format_tag::undef);
      if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
        src_dims = TFShapeToMklDnnDims(src_tensor.shape());
        auto src_strides = CalculateTFStrides(src_dims);
        src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
        diff_dst_md = src_md;
      } else if (dnn_shape_src.IsMklTensor() &&
                 !dnn_shape_diff_dst.IsMklTensor()) {
        src_md = dnn_shape_src.GetMklLayout();
        src_dims = dnn_shape_src.GetSizesAsMklDnnDims();

        MklTensorFormat src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
        auto src_tf_data_format =
            MklDnnDataFormatToTFDataFormat(src_mkl_data_format);
        auto diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
                                                       src_tf_data_format);
        diff_dst_md = memory::desc(
            diff_dst_dims, MklDnnType<T>(),
            MklTensorFormatToMklDnnDataFormat(src_mkl_data_format));
      } else if (!dnn_shape_src.IsMklTensor() &&
                 dnn_shape_diff_dst.IsMklTensor()) {
        diff_dst_md = dnn_shape_diff_dst.GetMklLayout();

        MklTensorFormat diff_dst_mkl_data_format =
            dnn_shape_diff_dst.GetTfDataFormat();
        auto diff_dst_tf_data_format =
            MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);

        src_dims = (src_tensor.dims() == 4)
                       ? TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
                                                   diff_dst_tf_data_format)
                       : TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
                                                    diff_dst_tf_data_format);
        src_md = memory::desc(
            src_dims, MklDnnType<T>(),
            MklTensorFormatToMklDnnDataFormat(diff_dst_mkl_data_format));
      } else {
        src_md = dnn_shape_src.GetMklLayout();
        diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
        src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
      }

      // As per the comment above, we tell oneDNN that both inputs are in the
      // same format, so we set a common memory descriptor in MKL format if
      // either of the inputs is in MKL format. Get the memory descriptor
      // that we will use for both inputs.
      memory::desc common_md({}, memory::data_type::undef,
                             memory::format_tag::undef);
      if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
        common_md = dnn_shape_src.IsMklTensor() ? src_md : diff_dst_md;
      } else {
        // Since both inputs are in TensorFlow format and have the same
        // shape, we can get the memory descriptor from either input.
        common_md = src_md;
      }

      MklEltwiseBwdParams<T> bwdParams(src_dims, common_md, alg_kind, alpha_,
                                       beta_, GetTypeOfInputTensorFromFwdOp());

      MklEltwiseBwdPrimitive<T>* eltwise_bwd =
          MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);

      auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();
      std::shared_ptr<stream> bwd_cpu_stream;
      MklDnnThreadPool eigen_tp(context);
      bwd_cpu_stream.reset(CreateStream(&eigen_tp, eltwise_bwd->GetEngine()));
      // Check whether src / diff_dst need to be reordered.
      const T* src_data = src_tensor.flat<T>().data();
      if (src_md != eltwise_bwd_pd->src_desc()) {
        src.SetUsrMem(src_md, &src_tensor);
        src.CheckReorderToOpMem(eltwise_bwd_pd.get()->diff_src_desc(),
                                cpu_engine, context);
        src_data = const_cast<T*>(
            reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
      }

      const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
      if (diff_dst_md != eltwise_bwd_pd->diff_dst_desc()) {
        diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
        diff_dst.CheckReorderToOpMem(eltwise_bwd_pd.get()->diff_src_desc(),
                                     cpu_engine, context);
        diff_dst_data = const_cast<T*>(
            reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
      }

      // Allocate diff_src tensor.
      if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
        auto diff_src_pd = eltwise_bwd_pd->diff_src_desc();
        dnn_shape_diff_src.SetMklTensor(true);
        dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
        dnn_shape_diff_src.SetElemType(MklDnnType<T>());
        if (dnn_shape_src.IsMklTensor()) {
          dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(),
                                         dnn_shape_src.GetSizesAsMklDnnDims(),
                                         dnn_shape_src.GetTfDataFormat());
        } else {
          dnn_shape_diff_src.SetTfLayout(
              dnn_shape_diff_dst.GetDimension(),
              dnn_shape_diff_dst.GetSizesAsMklDnnDims(),
              dnn_shape_diff_dst.GetTfDataFormat());
        }
        tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
      } else {
        dnn_shape_diff_src.SetMklTensor(false);
        tf_shape_diff_src = src_tensor.shape();
      }

      OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                  {static_cast<const int>(diff_dst_index)},
                                  static_cast<const int>(diff_src_index),
                                  tf_shape_diff_src, &diff_src_tensor));
      AllocateOutputSetMklShape(context, diff_src_index, dnn_shape_diff_src);

      T* diff_src_data = diff_src_tensor->flat<T>().data();

      // Execute eltwise bwd.
      eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data,
                           bwd_cpu_stream);
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  engine cpu_engine = engine(engine::kind::cpu, 0);
  std::shared_ptr<EltwiseFwdPd> relu_fwd_pd;

 protected:
  float alpha_;
  float beta_;
};

template <typename Device, typename T>
class MklReluOp
    : public MklReluOpBase<Device, T, dnnl::algorithm::eltwise_relu> {
 public:
  ~MklReluOp() {}

  explicit MklReluOp(OpKernelConstruction* context)
      : MklReluOpBase<Device, T, dnnl::algorithm::eltwise_relu>(context, 0.0f,
                                                                0.0f) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t src_index = 0;  // index of src input tensor
    const size_t dst_index = 0;  // index of dst output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    MklDnnShape dnn_shape_src;
    GetMklShape(context, src_index, &dnn_shape_src);

    Tensor* dst_tensor = nullptr;
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                              src_tensor.shape(), dnn_shape_dst);
    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
    (static_cast<T*>(out_o))[0] =
        std::max((static_cast<T*>(user_i))[0], static_cast<T>(0));
    return;
  }
};

template <typename Device, typename T>
class MklReluGradOp
    : public MklReluGradOpBase<Device, T, dnnl::algorithm::eltwise_relu> {
 public:
  ~MklReluGradOp() {}

  explicit MklReluGradOp(OpKernelConstruction* context)
      : MklReluGradOpBase<Device, T, dnnl::algorithm::eltwise_relu>(
            context, 0.0f, 0.0f) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
    const size_t src_index = 1;       // index of src input tensor
    const size_t diff_src_index = 0;  // index of diff_src output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
    Tensor* diff_src_tensor = nullptr;

    MklDnnShape dnn_shape_diff_dst;
    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                              diff_dst_tensor.shape(), dnn_shape_diff_src);
    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    void* user_g =
        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
    (static_cast<T*>(out_o))[0] =
        (static_cast<T*>(user_g))[0] *
        (static_cast<T>((static_cast<T*>(user_i))[0] > static_cast<T>(0)));
    return;
  }
};

template <typename Device, typename T>
class MklEluOp : public MklReluOpBase<Device, T, dnnl::algorithm::eltwise_elu> {
 public:
  ~MklEluOp() {}

  explicit MklEluOp(OpKernelConstruction* context)
      : MklReluOpBase<Device, T, dnnl::algorithm::eltwise_elu>(context, 0.0f,
                                                               0.0f) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t src_index = 0;  // index of src input tensor
    const size_t dst_index = 0;  // index of dst output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    MklDnnShape dnn_shape_src;
    GetMklShape(context, src_index, &dnn_shape_src);

    Tensor* dst_tensor = nullptr;
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                              src_tensor.shape(), dnn_shape_dst);
    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
    // elu(feature) = feature if feature > 0; exp(feature) - 1 otherwise
    T feature = (static_cast<T*>(user_i))[0];
    if (feature < static_cast<T>(0))
      (static_cast<T*>(out_o))[0] =
          Eigen::numext::exp(feature) - static_cast<T>(1);
    else
      (static_cast<T*>(out_o))[0] = feature;
    return;
  }
};

template <typename Device, typename T>
class MklEluGradOp
    : public MklReluGradOpBase<Device, T, dnnl::algorithm::eltwise_elu> {
 public:
  ~MklEluGradOp() {}

  explicit MklEluGradOp(OpKernelConstruction* context)
      : MklReluGradOpBase<Device, T, dnnl::algorithm::eltwise_elu>(
            context, 0.0f, 0.0f) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
    const size_t src_index = 1;       // index of src input tensor
    const size_t diff_src_index = 0;  // index of diff_src output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
    Tensor* diff_src_tensor = nullptr;

    MklDnnShape dnn_shape_diff_dst;
    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                              diff_dst_tensor.shape(), dnn_shape_diff_src);
    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    void* user_g =
        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
    // gradient of elu(x) = 1 if x > 0; elu(x) + 1 otherwise
    T feature = (static_cast<T*>(user_i))[0];
    if (feature > static_cast<T>(0)) {
      (static_cast<T*>(out_o))[0] = (static_cast<T*>(user_g))[0];
    } else {
      T elu = Eigen::numext::exp(feature) - static_cast<T>(1);
      (static_cast<T*>(out_o))[0] =
          (static_cast<T*>(user_g))[0] * (elu + static_cast<T>(1));
    }
  }
};

// Optimized TanhGrad support exists in DNNL 1.x only
// (eltwise_tanh_use_dst_for_bwd). We could still support it with DNNL 0.x,
// but it would not be optimized, so it is disabled for DNNL 0.x.
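// MklTanhGradOp below therefore uses eltwise_tanh_use_dst_for_bwd: the
// gradient is computed from y = tanh(x) rather than from x, which is why its
// GetTypeOfInputTensorFromFwdOp() returns DNNL_ARG_DST and its input indices
// are swapped relative to the Relu-family gradient ops.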

template <typename Device, typename T>
class MklTanhOp
    : public MklReluOpBase<Device, T, dnnl::algorithm::eltwise_tanh> {
 public:
  ~MklTanhOp() {}

  explicit MklTanhOp(OpKernelConstruction* context)
      : MklReluOpBase<Device, T, dnnl::algorithm::eltwise_tanh>(context, 0.0f,
                                                                0.0f) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t src_index = 0;  // index of src input tensor
    const size_t dst_index = 0;  // index of dst output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    MklDnnShape dnn_shape_src;
    GetMklShape(context, src_index, &dnn_shape_src);

    Tensor* dst_tensor = nullptr;
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                              src_tensor.shape(), dnn_shape_dst);
    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
    // tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x))
    T feature = (static_cast<T*>(user_i))[0];
    T e1 = Eigen::numext::exp(feature);
    T e2 = Eigen::numext::exp(-feature);
    (static_cast<T*>(out_o))[0] = (e1 - e2) / (e1 + e2);
    return;
  }
};

template <typename Device, typename T>
class MklTanhGradOp
    : public MklReluGradOpBase<Device, T,
                               dnnl::algorithm::eltwise_tanh_use_dst_for_bwd> {
 public:
  ~MklTanhGradOp() {}

  explicit MklTanhGradOp(OpKernelConstruction* context)
      : MklReluGradOpBase<Device, T,
                          dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>(
            context, 0.0f, 0.0f) {}

  virtual int GetDiffDstIndex() const { return 1; }
  virtual int GetSrcIndex() const { return 0; }
  virtual int GetDiffSrcIndex() const { return 0; }

  // TanhGrad gets 'y' from Tanh, where 'y' is the output of Tanh(x).
  virtual int GetTypeOfInputTensorFromFwdOp() const { return DNNL_ARG_DST; }

  virtual void Compute_Scalar(OpKernelContext* context) {
    // NOTE: The order of y and dy for Tanh is the reverse of that for Relu,
    // Elu, and the other element-wise ops. Tanh is a math op in TensorFlow;
    // the others are NN ops.
    const size_t diff_dst_index = GetDiffDstIndex();
    const size_t src_index = GetSrcIndex();
    const size_t diff_src_index = GetDiffSrcIndex();
    const Tensor& src_tensor = MklGetInput(context, src_index);
    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
    Tensor* diff_src_tensor = nullptr;

    MklDnnShape dnn_shape_diff_dst;
    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                              diff_dst_tensor.shape(), dnn_shape_diff_src);
    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    // gradient of tanh(x) = 1 - tanh(x)^2
    // The input to TanhGrad is the output of Tanh, so we do not need to
    // compute tanh again.
    T tanh = (static_cast<T*>(user_i))[0];
    void* user_g =
        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
    (static_cast<T*>(out_o))[0] =
        (static_cast<T*>(user_g))[0] * (static_cast<T>(1) - tanh * tanh);
  }
};

#define RELU6_UPPER_BOUND 6.0f
template <typename Device, typename T>
class MklRelu6Op
    : public MklReluOpBase<Device, T, dnnl::algorithm::eltwise_bounded_relu> {
 public:
  ~MklRelu6Op() {}

  explicit MklRelu6Op(OpKernelConstruction* context)
      : MklReluOpBase<Device, T, dnnl::algorithm::eltwise_bounded_relu>(
            context, RELU6_UPPER_BOUND, 0.0f) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t src_index = 0;  // index of src input tensor
    const size_t dst_index = 0;  // index of dst output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    MklDnnShape dnn_shape_src;
    GetMklShape(context, src_index, &dnn_shape_src);

    Tensor* dst_tensor = nullptr;
    T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                              src_tensor.shape(), dnn_shape_dst);
    T* out_o = dst_tensor->flat<T>().data();
    out_o[0] = std::min(std::max(user_i[0], static_cast<T>(0)),
                        static_cast<T>(RELU6_UPPER_BOUND));
    return;
  }
};

template <typename Device, typename T>
class MklRelu6GradOp
    : public MklReluGradOpBase<Device, T,
                               dnnl::algorithm::eltwise_bounded_relu> {
 public:
  ~MklRelu6GradOp() {}

  explicit MklRelu6GradOp(OpKernelConstruction* context)
      : MklReluGradOpBase<Device, T, dnnl::algorithm::eltwise_bounded_relu>(
            context, RELU6_UPPER_BOUND, 0.0f) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
    const size_t src_index = 1;       // index of src input tensor
    const size_t diff_src_index = 0;  // index of diff_src output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
    Tensor* diff_src_tensor = nullptr;

    MklDnnShape dnn_shape_diff_dst;
    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                              diff_dst_tensor.shape(), dnn_shape_diff_src);
    T* out_o = diff_src_tensor->flat<T>().data();
    T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
    T* user_g = const_cast<T*>(diff_dst_tensor.flat<T>().data());
    out_o[0] = user_g[0] *
               static_cast<T>(user_i[0] > static_cast<T>(0) &&
                              (user_i[0] < static_cast<T>(RELU6_UPPER_BOUND)));
    return;
  }
};

template <typename Device, typename T>
class MklLeakyReluOp
    : public MklReluOpBase<Device, T, dnnl::algorithm::eltwise_relu> {
 public:
  ~MklLeakyReluOp() {}

  explicit MklLeakyReluOp(OpKernelConstruction* context)
      : MklReluOpBase<Device, T, dnnl::algorithm::eltwise_relu>(context, 0.0f,
                                                                0.0f) {
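    // LeakyRelu maps onto oneDNN's eltwise_relu, with the negative-slope
    // coefficient carried in alpha_ (beta stays 0); the attribute is read and
    // validated below.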
    float alpha;
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
    OP_REQUIRES(
        context, alpha <= 1,
        errors::InvalidArgument("MKL LeakyRelu only supports alpha <= 1. "
                                "alpha is: ",
                                alpha));

    this->alpha_ = alpha;
  }

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t src_index = 0;  // index of src input tensor
    const size_t dst_index = 0;  // index of dst output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    MklDnnShape dnn_shape_src;
    GetMklShape(context, src_index, &dnn_shape_src);

    Tensor* dst_tensor = nullptr;
    T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                              src_tensor.shape(), dnn_shape_dst);
    T* out_o = dst_tensor->flat<T>().data();
    out_o[0] = user_i[0] >= T(0) ? user_i[0] : user_i[0] * T(this->alpha_);
    return;
  }
};

template <typename Device, typename T>
class MklLeakyReluGradOp
    : public MklReluGradOpBase<Device, T, dnnl::algorithm::eltwise_relu> {
 public:
  ~MklLeakyReluGradOp() {}

  explicit MklLeakyReluGradOp(OpKernelConstruction* context)
      : MklReluGradOpBase<Device, T, dnnl::algorithm::eltwise_relu>(
            context, 0.0f, 0.0f) {
    float alpha;
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
    OP_REQUIRES(
        context, alpha <= 1,
        errors::InvalidArgument("MKL LeakyRelu only supports alpha <= 1. "
                                "alpha is: ",
                                alpha));

    this->alpha_ = alpha;
  }

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
    const size_t src_index = 1;       // index of src input tensor
    const size_t diff_src_index = 0;  // index of diff_src output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
    Tensor* diff_src_tensor = nullptr;

    MklDnnShape dnn_shape_diff_dst;
    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                              diff_dst_tensor.shape(), dnn_shape_diff_src);
    T* out_o = diff_src_tensor->flat<T>().data();
    T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
    T* user_g = const_cast<T*>(diff_dst_tensor.flat<T>().data());
    out_o[0] = user_i[0] >= static_cast<T>(0)
                   ? user_g[0]
                   : user_g[0] * static_cast<T>(this->alpha_);
    return;
  }
};

// register dnn kernels for supported operations and supported types
#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type)           \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("_MklRelu")                                            \
          .Device(DEVICE_CPU)                                     \
          .TypeConstraint<type>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
      MklReluOp<CPUDevice, type>);                                \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("_MklReluGrad")                                        \
          .Device(DEVICE_CPU)                                     \
          .TypeConstraint<type>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
      MklReluGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_bfloat16(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);

// register dnn kernels for supported operations and supported types
#define REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES(type)            \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("_MklElu")                                             \
          .Device(DEVICE_CPU)                                     \
          .TypeConstraint<type>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
      MklEluOp<CPUDevice, type>);                                 \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("_MklEluGrad")                                         \
          .Device(DEVICE_CPU)                                     \
          .TypeConstraint<type>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
      MklEluGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_bfloat16(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES);

#define REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES(type)           \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("_MklTanh")                                            \
          .Device(DEVICE_CPU)                                     \
          .TypeConstraint<type>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
      MklTanhOp<CPUDevice, type>);                                \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("_MklTanhGrad")                                        \
          .Device(DEVICE_CPU)                                     \
          .TypeConstraint<type>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
      MklTanhGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_bfloat16(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES);

#define REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES(type)          \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("_MklRelu6")                                           \
          .Device(DEVICE_CPU)                                     \
          .TypeConstraint<type>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
      MklRelu6Op<CPUDevice, type>);                               \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("_MklRelu6Grad")                                       \
          .Device(DEVICE_CPU)                                     \
          .TypeConstraint<type>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
      MklRelu6GradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_bfloat16(REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES);

#define REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES(type)      \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("_MklLeakyRelu")                                       \
          .Device(DEVICE_CPU)                                     \
          .TypeConstraint<type>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
      MklLeakyReluOp<CPUDevice, type>);                           \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("_MklLeakyReluGrad")                                   \
          .Device(DEVICE_CPU)                                     \
          .TypeConstraint<type>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
      MklLeakyReluGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_bfloat16(REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES);

}  // namespace tensorflow

#endif  // INTEL_MKL