/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifdef INTEL_MKL
#define EIGEN_USE_THREADS

#include <numeric>

#include "dnnl.hpp"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/mkl_util.h"

using dnnl::stream;
using dnnl::sum;

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T>
class MklAddNOp : public OpKernel {
 public:
  ~MklAddNOp() {}
  explicit MklAddNOp(OpKernelConstruction* context) : OpKernel(context) {}

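  // Returns the TF shape of the input tensor at src_index. If the input is
  // in MKL layout, the shape is taken from its meta-tensor; otherwise the
  // tensor's own shape is used.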
  TensorShape GetTensorShape(OpKernelContext* ctx, size_t src_index) {
    const Tensor& src_tensor = MklGetInput(ctx, src_index);
    MklDnnShape src_mkl_shape;
    GetMklShape(ctx, src_index, &src_mkl_shape);
    return src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetTfShape()
                                       : src_tensor.shape();
  }

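  // Returns true if all inputs have the same shape as input 0; otherwise
  // sets an InvalidArgument status on the context and returns false.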
  bool CheckInputShape(OpKernelContext* ctx) {
    const int num_inputs = ctx->num_inputs() / 2;
    const TensorShape src0_shape = GetTensorShape(ctx, 0);

    for (int i = 1; i < num_inputs; ++i) {
      if (!src0_shape.IsSameSize(GetTensorShape(ctx, i))) {
        ctx->SetStatus(errors::InvalidArgument(
            "Inputs to operation ", this->name(), " of type ",
            this->type_string(),
            " must have the same size and shape.  Input 0: ",
            src0_shape.DebugString(), " != input ", i, ": ",
            GetTensorShape(ctx, i).DebugString()));

        return false;
      }
    }

    return true;
  }

  // Returns the index of the first input tensor that is in MKL layout, or -1
  // if no input is in MKL layout.
  int FindMKLInputIndex(OpKernelContext* ctx) {
    int mkl_index = -1;
    const int num_inputs = ctx->num_inputs() / 2;

    MklDnnShape src_mkl_shape;
    for (int i = 0; i < num_inputs; ++i) {
      GetMklShape(ctx, i, &src_mkl_shape);
      if (src_mkl_shape.IsMklTensor()) {
        mkl_index = i;
        break;
      }
    }

    return mkl_index;
  }

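  // Sums 0-D (scalar) inputs directly instead of building a oneDNN sum
  // primitive.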
  void ComputeScalar(OpKernelContext* ctx) {
    const int num_inputs = ctx->num_inputs() / 2;
    const size_t kOutputIdx = 0;
    TensorShape output_tf_shape;
    MklDnnShape output_mkl_shape;
    Tensor* dst_tensor = nullptr;

    T sum = static_cast<T>(0);
    for (int src_idx = 0; src_idx < num_inputs; ++src_idx) {
      const Tensor& src_tensor = MklGetInput(ctx, src_idx);
      T* src_i = const_cast<T*>(src_tensor.flat<T>().data());
      sum += src_i[0];
    }

    // The output is a plain TF tensor with the same (scalar) shape as the
    // first input.
    output_mkl_shape.SetMklTensor(false);
    output_tf_shape = MklGetInput(ctx, kOutputIdx).shape();
    AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
                              output_mkl_shape);

    T* out_o = dst_tensor->flat<T>().data();
    out_o[0] = sum;
  }

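  // Computes the element-wise sum of all inputs with the oneDNN sum
  // primitive, forwarding the input directly when there is only one.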
  void Compute(OpKernelContext* ctx) override {
    // Each input tensor in MKL layout has an additional meta-tensor carrying
    // layout information, so the number of actual data tensors is half the
    // total number of inputs.
    const int num_inputs = ctx->num_inputs() / 2;

    MklDnnShape mkl_shape;
    const size_t kSrc0Idx = 0;
    const size_t kOutputIdx = 0;

    if (num_inputs == 1) {
      GetMklShape(ctx, kSrc0Idx, &mkl_shape);
      bool input_in_mkl_format = mkl_shape.IsMklTensor();

      if (input_in_mkl_format) {
        ForwardMklTensorInToOut(ctx, kSrc0Idx, kOutputIdx);
      } else {
        ForwardTfTensorInToOut(ctx, kSrc0Idx, kOutputIdx);
      }
      return;
    }

    // Verify that all inputs have the same shape.
    if (!CheckInputShape(ctx)) return;

    try {
      TensorShape output_tf_shape;
      MklDnnShape output_mkl_shape;
      const Tensor& src_tensor = MklGetInput(ctx, kSrc0Idx);

      Tensor* dst_tensor = nullptr;

      // Nothing to compute, return.
      if (src_tensor.shape().num_elements() == 0) {
        output_mkl_shape.SetMklTensor(false);
        output_tf_shape = src_tensor.shape();
        AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
                                  output_mkl_shape);
        return;
      }

      if (src_tensor.dims() == 0) {
        ComputeScalar(ctx);
        return;
      }

      auto cpu_engine = engine(engine::kind::cpu, 0);
      // Unit coefficients: the sum is a plain element-wise addition.
      std::vector<float> coeff(num_inputs, 1.0);
      std::vector<memory::desc> srcs_pd;
      std::vector<memory> inputs;

      MklDnnData<T> dst(&cpu_engine);
      MklDnnData<T> src(&cpu_engine);
      bool has_mkl_input = false;
      int mkl_input_index = FindMKLInputIndex(ctx);
      MklTensorFormat mkl_data_format;
      TensorFormat tf_data_format;
      memory::format_tag dnn_fmt = memory::format_tag::any;
      if (mkl_input_index >= 0) {
        has_mkl_input = true;
        GetMklShape(ctx, mkl_input_index, &mkl_shape);
        // The MKL input carries the data format information.
        mkl_data_format = mkl_shape.GetTfDataFormat();
        tf_data_format = MklDnnDataFormatToTFDataFormat(mkl_data_format);
        dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format);
      }

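      // Create a oneDNN stream that executes on TensorFlow's Eigen thread
      // pool.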
      std::shared_ptr<stream> fwd_cpu_stream;
      MklDnnThreadPool eigen_tp(ctx);
      fwd_cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine));

      // Create a memory descriptor for each input. MKL-layout inputs already
      // carry a oneDNN layout; TF-layout inputs get either a descriptor in
      // the MKL input's data format (when one exists) or a plain blocked
      // descriptor.
      for (int src_idx = 0; src_idx < num_inputs; ++src_idx) {
        MklDnnShape src_mkl_shape;
        GetMklShape(ctx, src_idx, &src_mkl_shape);
        memory::desc md({}, memory::data_type::undef,
                        memory::format_tag::undef);
        const Tensor& src_tensor = MklGetInput(ctx, src_idx);

        if (src_mkl_shape.IsMklTensor()) {
          md = src_mkl_shape.GetMklLayout();
        } else {
          if (has_mkl_input) {
            memory::dims src_dims;
            if (src_tensor.dims() == 4) {
              src_dims =
                  TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tf_data_format);
            } else {
              DCHECK(src_tensor.dims() == 5);
              src_dims = TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
                                                    tf_data_format);
            }
            md = memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
          } else {
            // Create a blocked memory descriptor for the TensorFlow-format
            // input.
            auto dims = TFShapeToMklDnnDims(src_tensor.shape());
            auto strides = CalculateTFStrides(dims);
            md = MklDnnData<T>::CreateBlockedMemDesc(dims, strides);
          }
        }
        srcs_pd.push_back(memory::desc(md));
        src.SetUsrMem(md, &src_tensor);
        src.SetUsrMemDataHandle(&src_tensor, fwd_cpu_stream);
        inputs.push_back(src.GetOpMem());
      }

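      // The sum primitive descriptor lets oneDNN pick the destination memory
      // format based on the source descriptors.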
      auto sum_pd = sum::primitive_desc(coeff, srcs_pd, cpu_engine);
      output_mkl_shape.SetMklTensor(has_mkl_input);
      auto output_pd = sum_pd.dst_desc();
      dst.SetUsrMem(output_pd);

      if (has_mkl_input) {
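        // The output stays in MKL layout: record the oneDNN layout in the
        // meta-tensor and size the flat TF buffer in elements.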
        output_mkl_shape.SetMklLayout(&output_pd);
        output_mkl_shape.SetElemType(MklDnnType<T>());
        output_mkl_shape.SetTfLayout(mkl_shape.GetDimension(),
                                     mkl_shape.GetSizesAsMklDnnDims(),
                                     mkl_shape.GetTfDataFormat());
        output_tf_shape.AddDim(output_pd.get_size() / sizeof(T));
      } else {
        // All inputs are in TF format; take the shape from the first one.
        output_tf_shape = MklGetInput(ctx, kSrc0Idx).shape();
      }
      AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
                                output_mkl_shape);
      dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);

      // Create the sum primitive and execute it. The N sources are passed as
      // DNNL_ARG_MULTIPLE_SRC + i.
      dnnl::sum sum_op(sum_pd);
      std::unordered_map<int, memory> net_args = {
          {DNNL_ARG_DST, dst.GetOpMem()}};
      for (int i = 0; i < num_inputs; ++i) {
        net_args.insert({DNNL_ARG_MULTIPLE_SRC + i, inputs[i]});
      }
      sum_op.execute(*fwd_cpu_stream, net_args);
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          ctx, errors::Aborted("Operation received an exception:", error_msg));
    }
  }
};

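// Register _MklAddN for CPU. The MKL layout-dependent label restricts this
// kernel to graphs rewritten by the MKL layout pass, which replaces AddN
// with _MklAddN.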
#define REGISTER_MKL_CPU(T)                                    \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklAddN")                                         \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklAddNOp<CPUDevice, T>);

TF_CALL_float(REGISTER_MKL_CPU);
TF_CALL_bfloat16(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU
}  // namespace tensorflow
#endif  // INTEL_MKL