xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/mkl/mkl_tfconv_op.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_TFCONV_OP_H_
17 #define TENSORFLOW_CORE_KERNELS_MKL_MKL_TFCONV_OP_H_
18 
19 #ifdef INTEL_MKL
20 
21 #include <algorithm>
22 #include <string>
23 #include <vector>
24 
25 #include "tensorflow/core/framework/numeric_op.h"
26 #include "tensorflow/core/framework/op.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/register_types.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor_shape.h"
31 #include "tensorflow/core/kernels/ops_util.h"
32 #include "tensorflow/core/platform/byte_order.h"
33 #include "tensorflow/core/platform/cpu_info.h"
34 #include "tensorflow/core/platform/macros.h"
35 #include "tensorflow/core/util/mkl_util.h"
36 #include "tensorflow/core/util/tensor_format.h"
37 
38 using dnnl::stream;
39 
40 namespace tensorflow {
41 
42 typedef Eigen::ThreadPoolDevice CPUDevice;
43 
44 ///////////////////////////////////////////////////////////
45 //               Op kernel
46 ///////////////////////////////////////////////////////////
47 
48 template <typename Device, typename T>
49 class MklToTfOp : public OpKernel {
50  public:
MklToTfOp(OpKernelConstruction * context)51   explicit MklToTfOp(OpKernelConstruction* context) : OpKernel(context) {
52     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
53     OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
54     has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
55   }
56 
Compute(OpKernelContext * context)57   void Compute(OpKernelContext* context) override {
58     ConvertMklToTf(this, context, data_format_str, op_data_type, has_avx512f_,
59                    0);
60     VLOG(1) << "MKLToTFConversion complete successfully.";
61   }
62 
63   // TODO(intel-tf): Move the below ConvertMklToTf() to mkl_util.h
ConvertMklToTf(OpKernel * op_kernel,OpKernelContext * context,string data_format_str,DataType op_data_type,bool has_avx512f,uint input_number)64   static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context,
65                              string data_format_str, DataType op_data_type,
66                              bool has_avx512f, uint input_number) {
67     try {
68       // Check that input tensor is in MKL format.
69       const Tensor& input_tensor = MklGetInput(context, input_number);
70       MklDnnShape input_shape;
71       GetMklShape(context, input_number, &input_shape);
72 
73       // if input is already in Tf format, then copy input tensor to output.
74       if (!input_shape.IsMklTensor()) {
75         context->set_output(input_number, input_tensor);
76         VLOG(1) << "MKLToTFConversion: No conversion needed, "
77                 << "copying input to output";
78         return;
79       }
80 
81       // Check that input data type is same as operator data type and that it
82       // is same as output data type.
83       DataType input_data_type = op_kernel->input_type(input_number);
84       DataType output_data_type = op_kernel->output_type(input_number);
85       CHECK_EQ(op_data_type, input_data_type);
86       CHECK_EQ(op_data_type, output_data_type);
87 
88       auto cpu_engine = engine(engine::kind::cpu, 0);
89       MklDnnData<T> input(&cpu_engine);
90 
91       // Get MKL layout of input tensor.
92       auto input_mkl_md = input_shape.GetMklLayout();
93       // Get TensorFlow layout of input tensor. Expected output of conversion
94       // has same layout as Tensorflow layout of input tensor.
95       auto output_tf_md = input_shape.GetTfLayout();
96       // Set input MKL layout as the user layout.
97       input.SetUsrMem(input_mkl_md, &input_tensor);
98 
99       // Allocate output tensor.
100       TensorShape output_shape = input_shape.GetTfShape();
101       Tensor* output_tensor = nullptr;
102       OP_REQUIRES_OK(context, context->allocate_output(
103                                   input_number, output_shape, &output_tensor));
104       DCHECK(output_tensor);
105 
106       // Check if input needs to be reordered
107       if (input.IsReorderNeeded(output_tf_md)) {
108         // Insert reorder between MKL layout and TensorFlow layout
109         OP_REQUIRES(
110             context,
111             input.CheckReorderToOpMem(output_tf_md, output_tensor, context),
112             errors::Internal("MklToTfOp: Failed to create input reorder"));
113       } else {
114         // If not, just forward input tensor to output tensor.
115         OP_REQUIRES(context,
116                     output_tensor->CopyFrom(input_tensor, output_shape),
117                     errors::Internal(
118                         "MklToTfOp: Failed to forward input tensor to output"));
119       }
120     } catch (dnnl::error& e) {
121       OP_REQUIRES_OK(
122           context,
123           errors::Aborted("Operation received an exception: Status: ", e.status,
124                           ", message: ", StringPiece(e.message), ", in file ",
125                           __FILE__, ":", __LINE__));
126     }
127   }
128 
129  private:
130   /// Data format of the operation
131   string data_format_str;
132 
133   /// Data type of the operation
134   DataType op_data_type;
135 
136   /// CPUIDInfo
137   bool has_avx512f_ = false;
138 };
139 
140 ///////////////////////////////////////////////////////////
141 //               Register kernel
142 ///////////////////////////////////////////////////////////
143 
144 #define REGISTER_CPU(T)                                        \
145   REGISTER_KERNEL_BUILDER(                                     \
146       Name("_MklToTf")                                         \
147           .Device(DEVICE_CPU)                                  \
148           .TypeConstraint<T>("T")                              \
149           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
150       MklToTfOp<CPUDevice, T>);
151 
152 TF_CALL_NUMBER_TYPES(REGISTER_CPU);
153 TF_CALL_QUANTIZED_TYPES(REGISTER_CPU);
154 
155 #undef REGISTER_CPU
156 
157 }  // namespace tensorflow
158 #endif  // INTEL_MKL
159 #endif  // TENSORFLOW_CORE_KERNELS_MKL_MKL_TFCONV_OP_H_
160