xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/mkl/mkl_relu_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/nn_ops.cc.
17 #ifdef INTEL_MKL
18 
19 #include <unordered_map>
20 
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "dnnl.hpp"
23 #include "tensorflow/core/framework/numeric_op.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/util/mkl_util.h"
29 #ifdef DNNL_AARCH64_USE_ACL
30 #include "tensorflow/core/platform/mutex.h"
31 #endif
32 
33 using dnnl::algorithm;
34 using dnnl::eltwise_forward;
35 using dnnl::memory;
36 using dnnl::prop_kind;
37 using dnnl::stream;
38 
39 using EltwiseFwdPd = dnnl::eltwise_forward::primitive_desc;
40 using EltwiseBwdPd = dnnl::eltwise_backward::primitive_desc;
41 
42 namespace tensorflow {
43 
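// The fields below fully describe a forward eltwise primitive; they are also
// what MklEltwiseFwdPrimitiveFactory hashes into the cache key, so primitives
// are shared across ops with an identical configuration.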
44 template <typename T>
45 class MklEltwiseFwdParams {
46  public:
47   memory::dims src_dims;
48   memory::desc src_md;
49   algorithm alg_kind;
50   float alpha;
51   float beta;
52 
53   MklEltwiseFwdParams(memory::dims src_dims, memory::desc src_md,
54                       algorithm alg_kind, float alpha, float beta)
55       : src_dims(src_dims),
56         src_md(src_md),
57         alg_kind(alg_kind),
58         alpha(alpha),
59         beta(beta) {}
60 };
61 
62 template <typename T>
63 class MklEltwiseFwdPrimitive : public MklPrimitive {
64  public:
65   explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
66       : MklPrimitive(engine(engine::kind::cpu, 0)) {
67     // create eltwise primitive
68     if (context_.eltwise_fwd == nullptr) {
69       Setup(fwdParams);
70     }
71   }
72 
73   ~MklEltwiseFwdPrimitive() {}
74 
75   // Eltwise forward execute
76   //   src_data:  input data buffer of src
77   //   dst_data:  output data buffer of dst
78   void Execute(const T* src_data, T* dst_data,
79                std::shared_ptr<stream> fwd_stream) {
80 #ifdef DNNL_AARCH64_USE_ACL
81     mutex_lock lock(primitive_execution_mu_);
82 #endif
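    // When oneDNN is built against TF's threadpool (no OpenMP), the data
    // handles are set with the stream-aware overload, presumably so the
    // update is tied to the stream that will execute the primitive; the
    // OpenMP build sets the handles directly.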
83 #ifndef ENABLE_ONEDNN_OPENMP
84     context_.src_mem->set_data_handle(
85         static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
86     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
87                                       *fwd_stream);
88 #else
89     context_.src_mem->set_data_handle(
90         static_cast<void*>(const_cast<T*>(src_data)));
91     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
92 #endif  // !ENABLE_ONEDNN_OPENMP
93     DCHECK_EQ(context_.fwd_primitives.size(),
94               context_.fwd_primitives_args.size());
95     execute_primitives(context_.fwd_primitives, fwd_stream,
96                        context_.fwd_primitives_args);
97 
98     // After execution, set data handle back.
99     context_.src_mem->set_data_handle(DummyData);
100     context_.dst_mem->set_data_handle(DummyData);
101   }
102 
103   std::shared_ptr<EltwiseFwdPd> GetEltwiseFwdPd() { return context_.fwd_pd; }
104 
105  private:
106   // Primitive reuse context for eltwise Fwd ops: Relu, Elu, Tanh
107   struct EltwiseFwdContext {
108     // oneDNN memory
109     std::shared_ptr<memory> src_mem;
110     std::shared_ptr<memory> dst_mem;
111 
112     // desc & primitive desc
113     std::shared_ptr<dnnl::eltwise_forward::desc> fwd_desc;
114     std::shared_ptr<EltwiseFwdPd> fwd_pd;
115 
116     // memory desc
117     std::shared_ptr<memory::desc> src_md;
118     std::shared_ptr<memory::desc> dst_md;
119 
120     // Eltwise primitive
121     std::shared_ptr<dnnl::primitive> eltwise_fwd;
122 
123     std::vector<dnnl::primitive> fwd_primitives;
124 
125     std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
126 
127     EltwiseFwdContext()
128         : src_mem(nullptr),
129           dst_mem(nullptr),
130           fwd_desc(nullptr),
131           fwd_pd(nullptr),
132           src_md(nullptr),
133           dst_md(nullptr),
134           eltwise_fwd(nullptr) {}
135   };
136 
137   // Eltwise forward primitive setup
138   void Setup(const MklEltwiseFwdParams<T>& fwdParams) {
139     // create memory descriptors for eltwise data with specified format
140     context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
141 
142     // Create an eltwise forward descriptor and primitive descriptor
143     context_.fwd_desc.reset(new eltwise_forward::desc(
144         prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
145         fwdParams.alpha, fwdParams.beta));
146     context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
147     auto fwd_pd = context_.fwd_pd.get();
148 
149     // Create memory primitive based on dummy data
150     context_.src_mem.reset(
151         new memory(fwd_pd->src_desc(), cpu_engine_, DummyData));
152     context_.dst_mem.reset(
153         new memory(fwd_pd->dst_desc(), cpu_engine_, DummyData));
154     // Create eltwise primitive and add it to net
155     context_.eltwise_fwd.reset(new eltwise_forward(*context_.fwd_pd));
156     context_.fwd_primitives_args.push_back(
157         {{DNNL_ARG_SRC, *context_.src_mem}, {DNNL_ARG_DST, *context_.dst_mem}});
158     context_.fwd_primitives.push_back(*context_.eltwise_fwd);
159   }
160 
161   struct EltwiseFwdContext context_;
162 
163 #ifdef DNNL_AARCH64_USE_ACL
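  // Guards Execute() on Arm Compute Library builds; the ACL-backed primitives
  // cached here do not appear to be safe to run concurrently.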
164   mutex primitive_execution_mu_;
165 #endif
166 };
167 
168 template <typename T>
169 class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
170  public:
171   static MklEltwiseFwdPrimitive<T>* Get(
172       const MklEltwiseFwdParams<T>& fwdParams) {
173     MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;
174 
175     // Get an eltwise fwd primitive from the cached pool
176     eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
177         MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(
178             fwdParams));
179     if (eltwise_forward == nullptr) {
180       eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
181       MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
182           fwdParams, eltwise_forward);
183     }
184 
185     return eltwise_forward;
186   }
187 
188   static MklEltwiseFwdPrimitiveFactory& GetInstance() {
189     static MklEltwiseFwdPrimitiveFactory instance_;
190     return instance_;
191   }
192 
193  private:
194   MklEltwiseFwdPrimitiveFactory() {}
195   ~MklEltwiseFwdPrimitiveFactory() {}
196 
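  // The cache key encodes src dims, algorithm kind, alpha, and beta, so a
  // cached primitive is reused only for an identical configuration.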
197   static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams) {
198     string prefix = "eltwise_fwd";
199     FactoryKeyCreator key_creator;
200     key_creator.AddAsKey(prefix);
201     key_creator.AddAsKey(fwdParams.src_dims);
202     key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
203     key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
204     key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
205     return key_creator.GetKey();
206   }
207 
208   MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams) {
209     string key = CreateKey(fwdParams);
210     return this->GetOp(key);
211   }
212 
213   void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
214                      MklPrimitive* op) {
215     string key = CreateKey(fwdParams);
216     this->SetOp(key, op);
217   }
218 };
219 
220 template <typename T>
221 class MklEltwiseBwdParams {
222  public:
223   memory::dims src_dims;
224   memory::desc common_md;
225   algorithm alg_kind;
226   float alpha;
227   float beta;
228   // Whether the input that grad op gets from forward op is SRC
229   // of forward op or DST of forward op.
230   int forward_input_type;
231 
232   MklEltwiseBwdParams(const memory::dims& src_dims,
233                       const memory::desc& common_md, algorithm alg_kind,
234                       float alpha, float beta, int forward_input_type = -1)
235       : src_dims(src_dims),
236         common_md(common_md),
237         alg_kind(alg_kind),
238         alpha(alpha),
239         beta(beta),
240         forward_input_type(forward_input_type) {}
241 };
242 
243 template <typename T>
244 class MklEltwiseBwdPrimitive : public MklPrimitive {
245  public:
246   explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams)
247       : MklPrimitive(engine(engine::kind::cpu, 0)) {
248     // create eltwise primitive
249     if (context_.eltwise_bwd == nullptr) {
250       Setup(bwdParams);
251     }
252   }
253 
254   ~MklEltwiseBwdPrimitive() {}
255 
256   // Eltwise backward execute
257   //   src_data:       input data buffer of src
258   //   diff_dst_data:  input data buffer of diff_dst
259   //   diff_src_data:  output data buffer of diff_src
260   void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data,
261                std::shared_ptr<stream> bwd_stream) {
262 #ifdef DNNL_AARCH64_USE_ACL
263     mutex_lock lock(primitive_execution_mu_);
264 #endif
265 #ifndef ENABLE_ONEDNN_OPENMP
266     context_.src_mem->set_data_handle(
267         static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
268     context_.diff_dst_mem->set_data_handle(
269         static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);
270     context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
271                                            *bwd_stream);
272 #else
273     context_.src_mem->set_data_handle(
274         static_cast<void*>(const_cast<T*>(src_data)));
275     context_.diff_dst_mem->set_data_handle(
276         static_cast<void*>(const_cast<T*>(diff_dst_data)));
277     context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
278 #endif  // !ENABLE_ONEDNN_OPENMP
279     DCHECK_EQ(context_.bwd_primitives.size(),
280               context_.bwd_primitives_args.size());
281     execute_primitives(context_.bwd_primitives, bwd_stream,
282                        context_.bwd_primitives_args);
283 
284     // after execution, set data handle back
285     context_.src_mem->set_data_handle(DummyData);
286     context_.diff_dst_mem->set_data_handle(DummyData);
287     context_.diff_src_mem->set_data_handle(DummyData);
288   }
289 
290   std::shared_ptr<EltwiseBwdPd> GetEltwiseBwdPd() { return context_.bwd_pd; }
291 
292  private:
293   // Primitive reuse context for eltwise Bwd ops: Relu, Elu, Tanh
294   struct EltwiseBwdContext {
295     // oneDNN memory
296     std::shared_ptr<memory> src_mem;
297     std::shared_ptr<memory> diff_dst_mem;
298     std::shared_ptr<memory> diff_src_mem;
299 
300     // Backward Eltwise descriptor.
301     std::shared_ptr<dnnl::eltwise_backward::desc> bwd_desc;
302 
303     // Memory descriptors.
304     std::shared_ptr<memory::desc> src_md;
305     std::shared_ptr<memory::desc> diff_dst_md;
306     std::shared_ptr<memory::desc> common_md;
307 
308     // Forward and backward descriptors and primitive descriptors.
309     std::shared_ptr<dnnl::eltwise_forward::desc> fwd_desc;
310     std::shared_ptr<EltwiseFwdPd> fwd_pd;
311     std::shared_ptr<EltwiseBwdPd> bwd_pd;
312 
313     // Eltwise primitive.
314     std::shared_ptr<dnnl::primitive> eltwise_bwd;
315 
316     std::vector<dnnl::primitive> bwd_primitives;
317 
318     std::vector<MemoryArgsMap> bwd_primitives_args;
319 
320     EltwiseBwdContext()
321         : src_mem(nullptr),
322           diff_dst_mem(nullptr),
323           diff_src_mem(nullptr),
324           src_md(nullptr),
325           diff_dst_md(nullptr),
326           common_md(nullptr),
327           fwd_desc(nullptr),
328           fwd_pd(nullptr),
329           bwd_pd(nullptr),
330           eltwise_bwd(nullptr) {}
331   };
332 
333   // Eltwise backward primitive setup
334   void Setup(const MklEltwiseBwdParams<T>& bwdParams) {
335     // Create memory descriptors for eltwise data w/ no specified format
336     context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
337     context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
338 
339     // Create the forward desc/PD; the forward PD also serves as the hint for the backward PD below.
340     context_.fwd_desc.reset(new dnnl::eltwise_forward::desc(
341         prop_kind::forward_training, bwdParams.alg_kind, *context_.src_md,
342         bwdParams.alpha, bwdParams.beta));
343     context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
344     context_.bwd_desc.reset(new dnnl::eltwise_backward::desc(
345         bwdParams.alg_kind, *context_.diff_dst_md, *context_.src_md,
346         bwdParams.alpha, bwdParams.beta));
347     context_.bwd_pd.reset(
348         new EltwiseBwdPd(*context_.bwd_desc, cpu_engine_, *context_.fwd_pd));
349 
350     auto bwd_pd = context_.bwd_pd.get();
351 
352     // Create memory primitive based on dummy data.
353     context_.src_mem.reset(
354         new memory(bwd_pd->src_desc(), cpu_engine_, DummyData));
355     context_.diff_dst_mem.reset(
356         new memory(bwd_pd->diff_dst_desc(), cpu_engine_, DummyData));
357     context_.diff_src_mem.reset(
358         new memory(bwd_pd->diff_src_desc(), cpu_engine_, DummyData));
359     // Create eltwise primitive and add it to net.
360     context_.eltwise_bwd.reset(new dnnl::eltwise_backward(*context_.bwd_pd));
361     context_.bwd_primitives_args.push_back(
362         {{bwdParams.forward_input_type, *context_.src_mem},
363          {DNNL_ARG_DIFF_DST, *context_.diff_dst_mem},
364          {DNNL_ARG_DIFF_SRC, *context_.diff_src_mem}});
365 
366     context_.bwd_primitives.push_back(*context_.eltwise_bwd);
367   }
368 
369   struct EltwiseBwdContext context_;
370 
371 #ifdef DNNL_AARCH64_USE_ACL
372   mutex primitive_execution_mu_;
373 #endif
374 };
375 
376 template <typename T>
377 class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
378  private:
379   MklEltwiseBwdPrimitiveFactory() {}
380   ~MklEltwiseBwdPrimitiveFactory() {}
381 
382  public:
383   static MklEltwiseBwdPrimitive<T>* Get(
384       const MklEltwiseBwdParams<T>& bwdParams) {
385     MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr;
386 
387     // Try to find a suitable one in the pool
388     eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
389         MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
390             bwdParams));
391 
392     if (eltwise_backward == nullptr) {
393       eltwise_backward = new MklEltwiseBwdPrimitive<T>(bwdParams);
394       MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
395           bwdParams, eltwise_backward);
396     }
397     return eltwise_backward;
398   }
399 
400   static MklEltwiseBwdPrimitiveFactory& GetInstance() {
401     static MklEltwiseBwdPrimitiveFactory instance_;
402     return instance_;
403   }
404 
405  private:
406   static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams) {
407     string prefix = "eltwise_bwd";
408     FactoryKeyCreator key_creator;
409     key_creator.AddAsKey(prefix);
410     key_creator.AddAsKey(bwdParams.src_dims);
411     key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
412     key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
413     key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
414     return key_creator.GetKey();
415   }
416 
417   MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams) {
418     string key = CreateKey(bwdParams);
419     return this->GetOp(key);
420   }
421 
422   void SetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
423                      MklPrimitive* op) {
424     string key = CreateKey(bwdParams);
425     this->SetOp(key, op);
426   }
427 };
428 
429 typedef Eigen::ThreadPoolDevice CPUDevice;
430 
431 template <typename Device, typename T, algorithm alg_kind>
432 class MklReluOpBase : public OpKernel {
433  public:
434   ~MklReluOpBase() {}
435 
436   explicit MklReluOpBase(OpKernelConstruction* context, float alpha, float beta)
437       : OpKernel(context), alpha_(alpha), beta_(beta) {}
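  // Scalar (0-dim) inputs bypass oneDNN entirely; each derived op computes
  // them directly in its Compute_Scalar() override.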
438   virtual void Compute_Scalar(OpKernelContext* context) = 0;
439 
440   void Compute(OpKernelContext* context) override {
441     try {
442       const size_t src_index = 0;  // index of src input tensor
443       const size_t dst_index = 0;  // index of dst output tensor
444       const Tensor& src_tensor = MklGetInput(context, src_index);
445       MklDnnShape dnn_shape_src;
446       GetMklShape(context, src_index, &dnn_shape_src);
447       if (src_tensor.dims() == 0) {
448         Compute_Scalar(context);
449         return;
450       }
451       MklDnnShape dnn_shape_dst;
452       TensorShape tf_shape_dst;
453       Tensor* dst_tensor = nullptr;
454       // Nothing to compute, return.
455       if (src_tensor.shape().num_elements() == 0) {
456         dnn_shape_dst.SetMklTensor(false);
457         tf_shape_dst = MklGetInput(context, src_index).shape();
458         AllocateOutputSetMklShape(context, dst_index, &dst_tensor, tf_shape_dst,
459                                   dnn_shape_dst);
460         return;
461       }
462       // Set DNN primitive - src
463       MklDnnData<T> src(&cpu_engine);
464       memory::dims src_dims;
465       memory::desc src_md({}, memory::data_type::undef,
466                           memory::format_tag::undef);
467       if (dnn_shape_src.IsMklTensor()) {
468         src_md = dnn_shape_src.GetMklLayout();
469         src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
470       } else {
471         src_dims = TFShapeToMklDnnDims(src_tensor.shape());
472         auto src_strides = CalculateTFStrides(src_dims);
473         // Create blocked memory descriptor
474         src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
475       }
476       // Try to get an eltwise forward primitive from the caching pool
477       MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha_,
478                                        beta_);
479       MklEltwiseFwdPrimitive<T>* eltwise_fwd =
480           MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);
481       auto eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd();
482       std::shared_ptr<stream> fwd_cpu_stream;
483       MklDnnThreadPool eigen_tp(context);
484       fwd_cpu_stream.reset(CreateStream(&eigen_tp, eltwise_fwd->GetEngine()));
485       // Check if src needs to be reordered
486       bool is_src_reordered = false;
487       const T* src_data = src_tensor.flat<T>().data();
488       if (src_md != eltwise_fwd_pd->src_desc()) {
489         src.SetUsrMem(src_md, &src_tensor);
490         src.CheckReorderToOpMem(eltwise_fwd_pd->src_desc(), cpu_engine,
491                                 context);
492         src_data = const_cast<T*>(
493             reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
494         is_src_reordered = true;
495       }
496 
497       // If src is reordered, then dst tensor would be in blocked layout.
498       // So we propagate this blocked layout to the output. We follow the same
499       // logic when src is in blocked (MKL) layout to start off with as well.
500       if (is_src_reordered || dnn_shape_src.IsMklTensor()) {
501         dnn_shape_dst.SetMklTensor(true);
502         auto dst_pd = eltwise_fwd_pd->dst_desc();
503         dnn_shape_dst.SetMklLayout(&dst_pd);
504         dnn_shape_dst.SetElemType(MklDnnType<T>());
505         if (dnn_shape_src.IsMklTensor()) {
506           dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
507                                     dnn_shape_src.GetSizesAsMklDnnDims(),
508                                     dnn_shape_src.GetTfDataFormat());
509         } else {
510           dnn_shape_dst.SetTfLayout(src_tensor.dims(),
511                                     TFShapeToMklDnnDims(src_tensor.shape()),
512                                     MklTensorFormat::FORMAT_BLOCKED);
513         }
514         tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
515       } else {
516         // If src is not in blocked layout and is not reordered, then dst is
517         // in native layout.
518         dnn_shape_dst.SetMklTensor(false);
519         tf_shape_dst = src_tensor.shape();
520       }
521 
522       if (is_src_reordered) {
523         // If src is reordered, then src and dst would be in different layouts.
524         AllocateOutputSetMklShape(context, dst_index, &dst_tensor, tf_shape_dst,
525                                   dnn_shape_dst);
526       } else {
527         // Forwarding the input to the output works only when the layouts of
528         // the src and dst tensors remain the same -- either both of them are
529         // in native layout or both are in blocked (MKL) layout.
530         OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
531                                     {static_cast<const int>(src_index)},
532                                     static_cast<const int>(dst_index),
533                                     tf_shape_dst, &dst_tensor));
534         AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst);
535       }
536       T* dst_data = dst_tensor->flat<T>().data();
537 
538       // execute eltwise
539       eltwise_fwd->Execute(src_data, dst_data, fwd_cpu_stream);
540     } catch (dnnl::error& e) {
541       string error_msg = "Status: " + std::to_string(e.status) +
542                          ", message: " + string(e.message) + ", in file " +
543                          string(__FILE__) + ":" + std::to_string(__LINE__);
544       OP_REQUIRES_OK(
545           context,
546           errors::Aborted("Operation received an exception:", error_msg));
547     }
548   }
549 
550  private:
551   engine cpu_engine = engine(engine::kind::cpu, 0);
552   std::shared_ptr<EltwiseFwdPd> relu_fwd_pd;
553 
554  protected:
555   float alpha_;
556   float beta_;
557 };
558 
559 template <typename Device, typename T, algorithm alg_kind>
560 class MklReluGradOpBase : public OpKernel {
561  public:
562   ~MklReluGradOpBase() {}
563 
564   explicit MklReluGradOpBase(OpKernelConstruction* context, float alpha,
565                              float beta)
566       : OpKernel(context), alpha_(alpha), beta_(beta) {}
567 
568   virtual void Compute_Scalar(OpKernelContext* context) = 0;
569 
570   // All activation functions that are part of NN ops, such as Relu, Elu,
571   // LeakyRelu, Relu6, etc., have dy at index 0 and y at index 1.
572   //
573   // If the forward op is defined as: y = f(x),
574   // {Relu,Elu,Relu6,LeakyRelu}Grad is: z = f_grad(dy,x)
575   // TanhGrad is: z = tanh_grad(y,dy)
576   //
577   // Src below refers to the tensor that the gradient op receives from the
578   // forward operator. For Relu-family ops, it is 'x'; for TanhGrad, it is 'y'.
579   virtual int GetDiffDstIndex() const { return 0; }
580   virtual int GetSrcIndex() const { return 1; }
581   virtual int GetDiffSrcIndex() const { return 0; }
582   // What is the type of input tensor that grad op receives from forward op --
583   // is it 'x' (SRC) or 'y' (DST). For Relu-family, it is 'x', so fwd op SRC.
584 
585   virtual int GetTypeOfInputTensorFromFwdOp() const { return DNNL_ARG_SRC; }
586 
587   void Compute(OpKernelContext* context) {
588     try {
589       MklDnnData<T> src(&cpu_engine);
590       MklDnnData<T> diff_dst(&cpu_engine);
591 
592       size_t diff_dst_index = GetDiffDstIndex();
593       size_t src_index = GetSrcIndex();
594       const size_t diff_src_index = GetDiffSrcIndex();
595 
596       const Tensor& src_tensor = MklGetInput(context, src_index);
597       const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
598       Tensor* diff_src_tensor = nullptr;
599 
600       MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
601       GetMklShape(context, src_index, &dnn_shape_src);
602       GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
603 
604       int src_dims_size = src_tensor.dims();
605       if (src_dims_size == 0) {
606         Compute_Scalar(context);
607         return;
608       }
609 
610       TensorShape tf_shape_diff_src;
611       MklDnnShape dnn_shape_diff_src;
612       // Nothing to compute, return.
613       if (src_tensor.shape().num_elements() == 0) {
614         dnn_shape_diff_src.SetMklTensor(false);
615         tf_shape_diff_src = MklGetInput(context, diff_src_index).shape();
616         AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
617                                   tf_shape_diff_src, dnn_shape_diff_src);
618         return;
619       }
620 
621       // Get an eltwise bwd primitive from the primitive pool
622       memory::dims src_dims = {};
623       memory::desc src_md({}, memory::data_type::undef,
624                           memory::format_tag::undef);
625       memory::desc diff_dst_md({}, memory::data_type::undef,
626                                memory::format_tag::undef);
627       if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
628         src_dims = TFShapeToMklDnnDims(src_tensor.shape());
629         auto src_strides = CalculateTFStrides(src_dims);
630         src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
631         diff_dst_md = src_md;
632       } else if (dnn_shape_src.IsMklTensor() &&
633                  !dnn_shape_diff_dst.IsMklTensor()) {
634         src_md = dnn_shape_src.GetMklLayout();
635         src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
636 
637         MklTensorFormat src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
638         auto src_tf_data_format =
639             MklDnnDataFormatToTFDataFormat(src_mkl_data_format);
640         auto diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
641                                                        src_tf_data_format);
642         diff_dst_md = memory::desc(
643             diff_dst_dims, MklDnnType<T>(),
644             MklTensorFormatToMklDnnDataFormat(src_mkl_data_format));
645       } else if (!dnn_shape_src.IsMklTensor() &&
646                  dnn_shape_diff_dst.IsMklTensor()) {
647         diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
648 
649         MklTensorFormat diff_dst_mkl_data_format =
650             dnn_shape_diff_dst.GetTfDataFormat();
651         auto diff_dst_tf_data_format =
652             MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
653 
654         src_dims = (src_tensor.dims() == 4)
655                        ? TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
656                                                    diff_dst_tf_data_format)
657                        : TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
658                                                     diff_dst_tf_data_format);
659         src_md = memory::desc(
660             src_dims, MklDnnType<T>(),
661             MklTensorFormatToMklDnnDataFormat(diff_dst_mkl_data_format));
662       } else {
663         src_md = dnn_shape_src.GetMklLayout();
664         diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
665         src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
666       }
667 
668       // As per the comment above, we tell oneDNN that both inputs are in the
669       // same format. So we set a common memory descriptor in MKL format if
670       // either of the inputs is in MKL format. Let's get the memory descriptor
671       // that we will use for both inputs.
672       memory::desc common_md({}, memory::data_type::undef,
673                              memory::format_tag::undef);
674       if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
675         common_md = dnn_shape_src.IsMklTensor() ? src_md : diff_dst_md;
676       } else {
677         // Since both the inputs are in Tensorflow format, and have
678         // same shape, we can get memory descriptor from any input.
679         common_md = src_md;
680       }
681 
682       MklEltwiseBwdParams<T> bwdParams(src_dims, common_md, alg_kind, alpha_,
683                                        beta_, GetTypeOfInputTensorFromFwdOp());
684 
685       MklEltwiseBwdPrimitive<T>* eltwise_bwd =
686           MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);
687 
688       auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();
689       std::shared_ptr<stream> bwd_cpu_stream;
690       MklDnnThreadPool eigen_tp(context);
691       bwd_cpu_stream.reset(CreateStream(&eigen_tp, eltwise_bwd->GetEngine()));
692       // Check whether src / diff_dst need to be reordered
693       const T* src_data = src_tensor.flat<T>().data();
694       if (src_md != eltwise_bwd_pd->src_desc()) {
695         src.SetUsrMem(src_md, &src_tensor);
696         src.CheckReorderToOpMem(eltwise_bwd_pd.get()->diff_src_desc(),
697                                 cpu_engine, context);
698         src_data = const_cast<T*>(
699             reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
700       }
701 
702       const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
703       if (diff_dst_md != eltwise_bwd_pd->diff_dst_desc()) {
704         diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
705         diff_dst.CheckReorderToOpMem(eltwise_bwd_pd.get()->diff_src_desc(),
706                                      cpu_engine, context);
707         diff_dst_data = const_cast<T*>(
708             reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
709       }
710 
711       // allocate diff_src tensor
712       if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
713         auto diff_src_pd = eltwise_bwd_pd->diff_src_desc();
714         dnn_shape_diff_src.SetMklTensor(true);
715         dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
716         dnn_shape_diff_src.SetElemType(MklDnnType<T>());
717         if (dnn_shape_src.IsMklTensor()) {
718           dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(),
719                                          dnn_shape_src.GetSizesAsMklDnnDims(),
720                                          dnn_shape_src.GetTfDataFormat());
721         } else {
722           dnn_shape_diff_src.SetTfLayout(
723               dnn_shape_diff_dst.GetDimension(),
724               dnn_shape_diff_dst.GetSizesAsMklDnnDims(),
725               dnn_shape_diff_dst.GetTfDataFormat());
726         }
727         tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
728       } else {
729         dnn_shape_diff_src.SetMklTensor(false);
730         tf_shape_diff_src = src_tensor.shape();
731       }
732 
733       OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
734                                   {static_cast<const int>(diff_dst_index)},
735                                   static_cast<const int>(diff_src_index),
736                                   tf_shape_diff_src, &diff_src_tensor));
737       AllocateOutputSetMklShape(context, diff_src_index, dnn_shape_diff_src);
738 
739       T* diff_src_data = diff_src_tensor->flat<T>().data();
740 
741       // execute eltwise bwd
742       eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data,
743                            bwd_cpu_stream);
744     } catch (dnnl::error& e) {
745       string error_msg = "Status: " + std::to_string(e.status) +
746                          ", message: " + string(e.message) + ", in file " +
747                          string(__FILE__) + ":" + std::to_string(__LINE__);
748       OP_REQUIRES_OK(
749           context,
750           errors::Aborted("Operation received an exception:", error_msg));
751     }
752   }
753 
754  private:
755   engine cpu_engine = engine(engine::kind::cpu, 0);
756   std::shared_ptr<EltwiseFwdPd> relu_fwd_pd;
757 
758  protected:
759   float alpha_;
760   float beta_;
761 };
762 
763 template <typename Device, typename T>
764 class MklReluOp
765     : public MklReluOpBase<Device, T, dnnl::algorithm::eltwise_relu> {
766  public:
767   ~MklReluOp() {}
768 
769   explicit MklReluOp(OpKernelConstruction* context)
770       : MklReluOpBase<Device, T, dnnl::algorithm::eltwise_relu>(context, 0.0f,
771                                                                 0.0f) {}
772 
773   virtual void Compute_Scalar(OpKernelContext* context) {
774     const size_t src_index = 0;  // index of src input tensor
775     const size_t dst_index = 0;  // index of dst output tensor
776     const Tensor& src_tensor = MklGetInput(context, src_index);
777     MklDnnShape dnn_shape_src;
778     GetMklShape(context, src_index, &dnn_shape_src);
779 
780     Tensor* dst_tensor = nullptr;
781     void* user_i =
782         static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
783     MklDnnShape dnn_shape_dst;
784     dnn_shape_dst.SetMklTensor(false);
785     AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
786                               src_tensor.shape(), dnn_shape_dst);
787     void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
788     (static_cast<T*>(out_o))[0] =
789         std::max((static_cast<T*>(user_i))[0], static_cast<T>(0));
790     return;
791   }
792 };
793 
794 template <typename Device, typename T>
795 class MklReluGradOp
796     : public MklReluGradOpBase<Device, T, dnnl::algorithm::eltwise_relu> {
797  public:
798   ~MklReluGradOp() {}
799 
800   explicit MklReluGradOp(OpKernelConstruction* context)
801       : MklReluGradOpBase<Device, T, dnnl::algorithm::eltwise_relu>(
802             context, 0.0f, 0.0f) {}
803 
804   virtual void Compute_Scalar(OpKernelContext* context) {
805     const size_t diff_dst_index = 0;  // index of diff_dst input tensor
806     const size_t src_index = 1;       // index of src input tensor
807     const size_t diff_src_index = 0;  // index of diff_src output tensor
808     const Tensor& src_tensor = MklGetInput(context, src_index);
809     const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
810     Tensor* diff_src_tensor = nullptr;
811 
812     MklDnnShape dnn_shape_diff_dst;
813     GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
814 
815     MklDnnShape dnn_shape_diff_src;
816     dnn_shape_diff_src.SetMklTensor(false);
817     AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
818                               diff_dst_tensor.shape(), dnn_shape_diff_src);
819     void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
820     void* user_i =
821         static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
822     void* user_g =
823         static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
824     (static_cast<T*>(out_o))[0] =
825         (static_cast<T*>(user_g))[0] *
826         (static_cast<T>((static_cast<T*>(user_i))[0] > static_cast<T>(0)));
827     return;
828   }
829 };
830 
831 template <typename Device, typename T>
832 class MklEluOp : public MklReluOpBase<Device, T, dnnl::algorithm::eltwise_elu> {
833  public:
834   ~MklEluOp() {}
835 
836   explicit MklEluOp(OpKernelConstruction* context)
837       : MklReluOpBase<Device, T, dnnl::algorithm::eltwise_elu>(context, 0.0f,
838                                                                0.0f) {}
839 
840   virtual void Compute_Scalar(OpKernelContext* context) {
841     const size_t src_index = 0;  // index of src input tensor
842     const size_t dst_index = 0;  // index of dst output tensor
843     const Tensor& src_tensor = MklGetInput(context, src_index);
844     MklDnnShape dnn_shape_src;
845     GetMklShape(context, src_index, &dnn_shape_src);
846 
847     Tensor* dst_tensor = nullptr;
848     void* user_i =
849         static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
850     MklDnnShape dnn_shape_dst;
851     dnn_shape_dst.SetMklTensor(false);
852     AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
853                               src_tensor.shape(), dnn_shape_dst);
854     void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
855     // elu(feature) = exp(feature) - 1 if feature < 0; feature otherwise
856     T feature = (static_cast<T*>(user_i))[0];
857     if (feature < static_cast<T>(0))
858       (static_cast<T*>(out_o))[0] = Eigen::numext::exp(feature) - static_cast<T>(1);
859     else
860       (static_cast<T*>(out_o))[0] = feature;
861     return;
862   }
863 };
864 
865 template <typename Device, typename T>
866 class MklEluGradOp
867     : public MklReluGradOpBase<Device, T, dnnl::algorithm::eltwise_elu> {
868  public:
869   ~MklEluGradOp() {}
870 
871   explicit MklEluGradOp(OpKernelConstruction* context)
872       : MklReluGradOpBase<Device, T, dnnl::algorithm::eltwise_elu>(
873             context, 0.0f, 0.0f) {}
874 
875   virtual void Compute_Scalar(OpKernelContext* context) {
876     const size_t diff_dst_index = 0;  // index of diff_dst input tensor
877     const size_t src_index = 1;       // index of src input tensor
878     const size_t diff_src_index = 0;  // index of diff_src output tensor
879     const Tensor& src_tensor = MklGetInput(context, src_index);
880     const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
881     Tensor* diff_src_tensor = nullptr;
882 
883     MklDnnShape dnn_shape_diff_dst;
884     GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
885 
886     MklDnnShape dnn_shape_diff_src;
887     dnn_shape_diff_src.SetMklTensor(false);
888     AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
889                               diff_dst_tensor.shape(), dnn_shape_diff_src);
890     void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
891     void* user_i =
892         static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
893     void* user_g =
894         static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
895     // gradient of elu(x) = 1 if x > 0; elu(x) + 1 otherwise
896     T feature = (static_cast<T*>(user_i))[0];
897     if (feature > static_cast<T>(0)) {
898       (static_cast<T*>(out_o))[0] = (static_cast<T*>(user_g))[0];
899     } else {
900       T elu = Eigen::numext::exp(feature) - static_cast<T>(1);
901       (static_cast<T*>(out_o))[0] =
902           (static_cast<T*>(user_g))[0] * (elu + static_cast<T>(1));
903     }
904   }
905 };
906 
907 // Optimized TanhGrad support exists in DNNL1.x only
908 // (eltwise_tanh_use_dst_for_bwd). We can still support it with DNNL0.x, but
909 // it will not be optimized. So we disable it for DNNL0.x.
910 
911 template <typename Device, typename T>
912 class MklTanhOp
913     : public MklReluOpBase<Device, T, dnnl::algorithm::eltwise_tanh> {
914  public:
915   ~MklTanhOp() {}
916 
917   explicit MklTanhOp(OpKernelConstruction* context)
918       : MklReluOpBase<Device, T, dnnl::algorithm::eltwise_tanh>(context, 0.0f,
919                                                                 0.0f) {}
920 
921   virtual void Compute_Scalar(OpKernelContext* context) {
922     const size_t src_index = 0;  // index of src input tensor
923     const size_t dst_index = 0;  // index of dst output tensor
924     const Tensor& src_tensor = MklGetInput(context, src_index);
925     MklDnnShape dnn_shape_src;
926     GetMklShape(context, src_index, &dnn_shape_src);
927 
928     Tensor* dst_tensor = nullptr;
929     void* user_i =
930         static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
931     MklDnnShape dnn_shape_dst;
932     dnn_shape_dst.SetMklTensor(false);
933     AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
934                               src_tensor.shape(), dnn_shape_dst);
935     void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
936     // tanh(x) = (e^x - e^(-x))/ (e^x + e^(-x))
937     T feature = (static_cast<T*>(user_i))[0];
938     T e1 = Eigen::numext::exp(feature);
939     T e2 = Eigen::numext::exp(-feature);
940     (static_cast<T*>(out_o))[0] = (e1 - e2) / (e1 + e2);
941     return;
942   }
943 };
944 
945 template <typename Device, typename T>
946 class MklTanhGradOp
947     : public MklReluGradOpBase<Device, T,
948                                dnnl::algorithm::eltwise_tanh_use_dst_for_bwd> {
949  public:
950   ~MklTanhGradOp() {}
951 
952   explicit MklTanhGradOp(OpKernelConstruction* context)
953       : MklReluGradOpBase<Device, T,
954                           dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>(
955             context, 0.0f, 0.0f) {}
956 
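  // TanhGrad's inputs arrive as (y, dy), the reverse of the Relu-family grad
  // ops, so the input indices from the base class are overridden here.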
957   virtual int GetDiffDstIndex() const { return 1; }
958   virtual int GetSrcIndex() const { return 0; }
959   virtual int GetDiffSrcIndex() const { return 0; }
960 
961   // TanhGrad gets 'y' from Tanh, where 'y' is output of Tanh(x).
962   virtual int GetTypeOfInputTensorFromFwdOp() const { return DNNL_ARG_DST; }
963 
964   virtual void Compute_Scalar(OpKernelContext* context) {
965     // NOTE: The order of y and dy for Tanh is the reverse of that for Relu/Elu/
966     // other element-wise ops. Tanh is a math op in TensorFlow; the others are NN ops.
967     const size_t diff_dst_index = GetDiffDstIndex();
968     const size_t src_index = GetSrcIndex();
969     const size_t diff_src_index = GetDiffSrcIndex();
970     const Tensor& src_tensor = MklGetInput(context, src_index);
971     const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
972     Tensor* diff_src_tensor = nullptr;
973 
974     MklDnnShape dnn_shape_diff_dst;
975     GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
976 
977     MklDnnShape dnn_shape_diff_src;
978     dnn_shape_diff_src.SetMklTensor(false);
979     AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
980                               diff_dst_tensor.shape(), dnn_shape_diff_src);
981     void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
982     void* user_i =
983         static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
984     // gradient of tanh(x) = 1 - tanh(x)^2
985     // Input to TanhGrad is output of Tanh. So we do not need to compute
986     // Tanh again.
987     T tanh = (static_cast<T*>(user_i))[0];
988     void* user_g =
989         static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
990     (static_cast<T*>(out_o))[0] =
991         (static_cast<T*>(user_g))[0] * (static_cast<T>(1) - tanh * tanh);
992   }
993 };
994 
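// Relu6 is expressed through oneDNN's bounded ReLU, with the upper bound (6)
// passed as the primitive's alpha parameter below.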
995 #define RELU6_UPPER_BOUND 6.0f
996 template <typename Device, typename T>
997 class MklRelu6Op
998     : public MklReluOpBase<Device, T, dnnl::algorithm::eltwise_bounded_relu> {
999  public:
1000   ~MklRelu6Op() {}
1001 
1002   explicit MklRelu6Op(OpKernelConstruction* context)
1003       : MklReluOpBase<Device, T, dnnl::algorithm::eltwise_bounded_relu>(
1004             context, RELU6_UPPER_BOUND, 0.0f) {}
1005 
1006   virtual void Compute_Scalar(OpKernelContext* context) {
1007     const size_t src_index = 0;  // index of src input tensor
1008     const size_t dst_index = 0;  // index of dst output tensor
1009     const Tensor& src_tensor = MklGetInput(context, src_index);
1010     MklDnnShape dnn_shape_src;
1011     GetMklShape(context, src_index, &dnn_shape_src);
1012 
1013     Tensor* dst_tensor = nullptr;
1014     T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
1015     MklDnnShape dnn_shape_dst;
1016     dnn_shape_dst.SetMklTensor(false);
1017     AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
1018                               src_tensor.shape(), dnn_shape_dst);
1019     T* out_o = dst_tensor->flat<T>().data();
1020     out_o[0] = std::min(std::max(user_i[0], static_cast<T>(0)),
1021                         static_cast<T>(RELU6_UPPER_BOUND));
1022     return;
1023   }
1024 };
1025 
1026 template <typename Device, typename T>
1027 class MklRelu6GradOp
1028     : public MklReluGradOpBase<Device, T,
1029                                dnnl::algorithm::eltwise_bounded_relu> {
1030  public:
1031   ~MklRelu6GradOp() {}
1032 
1033   explicit MklRelu6GradOp(OpKernelConstruction* context)
1034       : MklReluGradOpBase<Device, T, dnnl::algorithm::eltwise_bounded_relu>(
1035             context, RELU6_UPPER_BOUND, 0.0f) {}
1036 
1037   virtual void Compute_Scalar(OpKernelContext* context) {
1038     const size_t diff_dst_index = 0;  // index of diff_dst input tensor
1039     const size_t src_index = 1;       // index of src input tensor
1040     const size_t diff_src_index = 0;  // index of diff_src output tensor
1041     const Tensor& src_tensor = MklGetInput(context, src_index);
1042     const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
1043     Tensor* diff_src_tensor = nullptr;
1044 
1045     MklDnnShape dnn_shape_diff_dst;
1046     GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
1047 
1048     MklDnnShape dnn_shape_diff_src;
1049     dnn_shape_diff_src.SetMklTensor(false);
1050     AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
1051                               diff_dst_tensor.shape(), dnn_shape_diff_src);
1052     T* out_o = diff_src_tensor->flat<T>().data();
1053     T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
1054     T* user_g = const_cast<T*>(diff_dst_tensor.flat<T>().data());
1055     out_o[0] = user_g[0] *
1056                static_cast<T>(user_i[0] > static_cast<T>(0) &&
1057                               (user_i[0] < static_cast<T>(RELU6_UPPER_BOUND)));
1058     return;
1059   }
1060 };
1061 
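// LeakyRelu reuses oneDNN's eltwise_relu, whose alpha parameter acts as the
// negative slope; the op's "alpha" attribute is validated and forwarded below.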
1062 template <typename Device, typename T>
1063 class MklLeakyReluOp
1064     : public MklReluOpBase<Device, T, dnnl::algorithm::eltwise_relu> {
1065  public:
1066   ~MklLeakyReluOp() {}
1067 
1068   explicit MklLeakyReluOp(OpKernelConstruction* context)
1069       : MklReluOpBase<Device, T, dnnl::algorithm::eltwise_relu>(context, 0.0f,
1070                                                                 0.0f) {
1071     float alpha;
1072     OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
1073     OP_REQUIRES(
1074         context, alpha <= 1,
1075         errors::InvalidArgument("MKL LeakyRelu only supports alpha <= 1. "
1076                                 "alpha is: ",
1077                                 alpha));
1078 
1079     this->alpha_ = alpha;
1080   }
1081 
1082   virtual void Compute_Scalar(OpKernelContext* context) {
1083     const size_t src_index = 0;  // index of src input tensor
1084     const size_t dst_index = 0;  // index of dst output tensor
1085     const Tensor& src_tensor = MklGetInput(context, src_index);
1086     MklDnnShape dnn_shape_src;
1087     GetMklShape(context, src_index, &dnn_shape_src);
1088 
1089     Tensor* dst_tensor = nullptr;
1090     T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
1091     MklDnnShape dnn_shape_dst;
1092     dnn_shape_dst.SetMklTensor(false);
1093     AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
1094                               src_tensor.shape(), dnn_shape_dst);
1095     T* out_o = dst_tensor->flat<T>().data();
1096     out_o[0] = user_i[0] >= T(0) ? user_i[0] : user_i[0] * T(this->alpha_);
1097     return;
1098   }
1099 };
1100 
1101 template <typename Device, typename T>
1102 class MklLeakyReluGradOp
1103     : public MklReluGradOpBase<Device, T, dnnl::algorithm::eltwise_relu> {
1104  public:
1105   ~MklLeakyReluGradOp() {}
1106 
1107   explicit MklLeakyReluGradOp(OpKernelConstruction* context)
1108       : MklReluGradOpBase<Device, T, dnnl::algorithm::eltwise_relu>(
1109             context, 0.0f, 0.0f) {
1110     float alpha;
1111     OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
1112     OP_REQUIRES(
1113         context, alpha <= 1,
1114         errors::InvalidArgument("MKL LeakyRelu only supports alpha <= 1. "
1115                                 "alpha is: ",
1116                                 alpha));
1117 
1118     this->alpha_ = alpha;
1119   }
1120 
1121   virtual void Compute_Scalar(OpKernelContext* context) {
1122     const size_t diff_dst_index = 0;  // index of diff_dst input tensor
1123     const size_t src_index = 1;       // index of src input tensor
1124     const size_t diff_src_index = 0;  // index of diff_src output tensor
1125     const Tensor& src_tensor = MklGetInput(context, src_index);
1126     const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
1127     Tensor* diff_src_tensor = nullptr;
1128 
1129     MklDnnShape dnn_shape_diff_dst;
1130     GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
1131 
1132     MklDnnShape dnn_shape_diff_src;
1133     dnn_shape_diff_src.SetMklTensor(false);
1134     AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
1135                               diff_dst_tensor.shape(), dnn_shape_diff_src);
1136     T* out_o = diff_src_tensor->flat<T>().data();
1137     T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
1138     T* user_g = const_cast<T*>(diff_dst_tensor.flat<T>().data());
1139     out_o[0] = user_i[0] >= static_cast<T>(0)
1140                    ? user_g[0]
1141                    : user_g[0] * static_cast<T>(this->alpha_);
1142     return;
1143   }
1144 };
1145 
1146 // register dnn kernels for supported operations and supported types
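// The kMklLayoutDependentOpLabel restricts these kernels to the _Mkl* nodes
// that TF's MKL layout-rewrite pass emits, so they never match plain CPU ops.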
1147 #define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type)        \
1148   REGISTER_KERNEL_BUILDER(                                     \
1149       Name("_MklRelu")                                         \
1150           .Device(DEVICE_CPU)                                  \
1151           .TypeConstraint<type>("T")                           \
1152           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1153       MklReluOp<CPUDevice, type>);                             \
1154   REGISTER_KERNEL_BUILDER(                                     \
1155       Name("_MklReluGrad")                                     \
1156           .Device(DEVICE_CPU)                                  \
1157           .TypeConstraint<type>("T")                           \
1158           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1159       MklReluGradOp<CPUDevice, type>);
1160 TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);
1161 TF_CALL_bfloat16(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);
1162 
1163 // register dnn kernels for supported operations and supported types
1164 #define REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES(type)         \
1165   REGISTER_KERNEL_BUILDER(                                     \
1166       Name("_MklElu")                                          \
1167           .Device(DEVICE_CPU)                                  \
1168           .TypeConstraint<type>("T")                           \
1169           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1170       MklEluOp<CPUDevice, type>);                              \
1171   REGISTER_KERNEL_BUILDER(                                     \
1172       Name("_MklEluGrad")                                      \
1173           .Device(DEVICE_CPU)                                  \
1174           .TypeConstraint<type>("T")                           \
1175           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1176       MklEluGradOp<CPUDevice, type>);
1177 TF_CALL_float(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES);
1178 TF_CALL_bfloat16(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES);
1179 
1180 #define REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES(type)        \
1181   REGISTER_KERNEL_BUILDER(                                     \
1182       Name("_MklTanh")                                         \
1183           .Device(DEVICE_CPU)                                  \
1184           .TypeConstraint<type>("T")                           \
1185           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1186       MklTanhOp<CPUDevice, type>);                             \
1187   REGISTER_KERNEL_BUILDER(                                     \
1188       Name("_MklTanhGrad")                                     \
1189           .Device(DEVICE_CPU)                                  \
1190           .TypeConstraint<type>("T")                           \
1191           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1192       MklTanhGradOp<CPUDevice, type>);
1193 TF_CALL_float(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES);
1194 TF_CALL_bfloat16(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES);
1195 
1196 #define REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES(type)       \
1197   REGISTER_KERNEL_BUILDER(                                     \
1198       Name("_MklRelu6")                                        \
1199           .Device(DEVICE_CPU)                                  \
1200           .TypeConstraint<type>("T")                           \
1201           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1202       MklRelu6Op<CPUDevice, type>);                            \
1203   REGISTER_KERNEL_BUILDER(                                     \
1204       Name("_MklRelu6Grad")                                    \
1205           .Device(DEVICE_CPU)                                  \
1206           .TypeConstraint<type>("T")                           \
1207           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1208       MklRelu6GradOp<CPUDevice, type>);
1209 TF_CALL_float(REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES);
1210 TF_CALL_bfloat16(REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES);
1211 
1212 #define REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES(type)   \
1213   REGISTER_KERNEL_BUILDER(                                     \
1214       Name("_MklLeakyRelu")                                    \
1215           .Device(DEVICE_CPU)                                  \
1216           .TypeConstraint<type>("T")                           \
1217           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1218       MklLeakyReluOp<CPUDevice, type>);                        \
1219   REGISTER_KERNEL_BUILDER(                                     \
1220       Name("_MklLeakyReluGrad")                                \
1221           .Device(DEVICE_CPU)                                  \
1222           .TypeConstraint<type>("T")                           \
1223           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1224       MklLeakyReluGradOp<CPUDevice, type>);
1225 TF_CALL_float(REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES);
1226 TF_CALL_bfloat16(REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES);
1227 
1228 }  // namespace tensorflow
1229 
1230 #endif  // INTEL_MKL
1231