/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#ifdef INTEL_MKL

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "dnnl.hpp"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"
#ifdef DNNL_AARCH64_USE_ACL
#include "tensorflow/core/platform/mutex.h"
#endif

using dnnl::prop_kind;
using dnnl::softmax_forward;
using dnnl::stream;

namespace tensorflow {

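// Parameters that identify a cached softmax forward primitive: the source
// dimensions, the source memory format, and the axis along which softmax is
// computed. The factory below builds its cache key from these fields.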
class MklSoftmaxParams {
 public:
  memory::dims src_dims;
  MklTensorFormat src_fmt;
  int axis;
#ifdef DNNL_AARCH64_USE_ACL
  int aarch64_counter;
#endif
  MklSoftmaxParams(memory::dims src_dims, MklTensorFormat src_fmt, int axis)
      : src_dims(src_dims), src_fmt(src_fmt), axis(axis) {}
};

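// Wraps a oneDNN softmax forward primitive and its src/dst memory objects so
// the primitive can be cached and re-executed on different data buffers via
// Execute().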
template <typename T>
class MklSoftmaxPrimitive : public MklPrimitive {
 public:
  explicit MklSoftmaxPrimitive(const MklSoftmaxParams& fwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    Setup(fwdParams);
  }

  ~MklSoftmaxPrimitive() {}

  // Softmax forward execute
  //   src_data:  input data buffer of src
  //   dst_data:  output data buffer of dst
  void Execute(const T* src_data, T* dst_data,
               std::shared_ptr<stream> fwd_cpu_stream) {
#ifdef DNNL_AARCH64_USE_ACL
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)), *fwd_cpu_stream);
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
                                      *fwd_cpu_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
#endif  // !ENABLE_ONEDNN_OPENMP

    DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_net_args.size());
    execute_primitives(context_.fwd_primitives, fwd_cpu_stream,
                       context_.fwd_net_args);

    // After execution, set data handle back.
    context_.src_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<dnnl::softmax_forward::primitive_desc> GetSoftmaxFwdPd() {
    return context_.fwd_pd;
  }

 private:
  struct SoftmaxFwdContext {
    // MKL-DNN memory.
    std::shared_ptr<memory> src_mem;
    std::shared_ptr<memory> dst_mem;

    // Primitive descriptor.
    std::shared_ptr<dnnl::softmax_forward::desc> fwd_desc;

    // Memory descriptor.
    std::shared_ptr<memory::desc> src_md;

    // Softmax primitive.
    std::shared_ptr<dnnl::softmax_forward::primitive_desc> fwd_pd;
    std::shared_ptr<dnnl::primitive> softmax_fwd;

    std::vector<dnnl::primitive> fwd_primitives;
    std::vector<MemoryArgsMap> fwd_net_args;

    SoftmaxFwdContext()
        : src_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          src_md(nullptr),
          fwd_pd(nullptr),
          softmax_fwd(nullptr) {}
  };

  // Softmax forward primitive setup
  void Setup(const MklSoftmaxParams& fwdParams) {
    // Create memory descriptors for softmax data with specified format.
    auto src_format = MklTensorFormatToMklDnnDataFormat(fwdParams.src_fmt);
    context_.src_md.reset(
        new memory::desc({fwdParams.src_dims}, MklDnnType<T>(), src_format));

    // Create softmax descriptor and primitive descriptor.
    context_.fwd_desc.reset(new dnnl::softmax_forward::desc(
        prop_kind::forward_scoring, *context_.src_md, fwdParams.axis));
    context_.fwd_pd.reset(new dnnl::softmax_forward::primitive_desc(
        *context_.fwd_desc, cpu_engine_));

    // Create memory primitive based on dummy data.
    context_.src_mem.reset(
        new memory(*context_.src_md, cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));

    // Create softmax primitive and add it to net
    context_.softmax_fwd.reset(new dnnl::softmax_forward(*context_.fwd_pd));
    context_.fwd_net_args.push_back(
        {{DNNL_ARG_SRC, *context_.src_mem}, {DNNL_ARG_DST, *context_.dst_mem}});

    context_.fwd_primitives.push_back(*context_.softmax_fwd);
  }

  struct SoftmaxFwdContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  mutex primitive_execution_mu_;
#endif
};

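// Factory that caches MklSoftmaxPrimitive objects. Get() returns the cached
// primitive for the given parameters if one exists; otherwise it creates a
// new primitive and adds it to the cache.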
template <typename T>
class MklSoftmaxPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklSoftmaxPrimitive<T>* Get(const MklSoftmaxParams& fwdParams) {
    // Get a softmax fwd primitive from the cached pool.
    MklSoftmaxPrimitive<T>* softmax_forward =
        static_cast<MklSoftmaxPrimitive<T>*>(
            MklSoftmaxPrimitiveFactory<T>::GetInstance().GetSoftmaxFwd(
                fwdParams));
    if (softmax_forward == nullptr) {
      softmax_forward = new MklSoftmaxPrimitive<T>(fwdParams);
      MklSoftmaxPrimitiveFactory<T>::GetInstance().SetSoftmaxFwd(
          fwdParams, softmax_forward);
    }
    return softmax_forward;
  }

  static MklSoftmaxPrimitiveFactory& GetInstance() {
    static MklSoftmaxPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklSoftmaxPrimitiveFactory() {}
  ~MklSoftmaxPrimitiveFactory() {}

  static string CreateKey(const MklSoftmaxParams& fwdParams) {
    string prefix = "softmax_fwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(fwdParams.src_dims);
    key_creator.AddAsKey<int>(static_cast<int>(fwdParams.src_fmt));
    key_creator.AddAsKey<int>(fwdParams.axis);
#ifdef DNNL_AARCH64_USE_ACL
    key_creator.AddAsKey(fwdParams.aarch64_counter);
#endif
    return key_creator.GetKey();
  }

  MklPrimitive* GetSoftmaxFwd(const MklSoftmaxParams& fwdParams) {
    string key = CreateKey(fwdParams);
    return this->GetOp(key);
  }

  void SetSoftmaxFwd(const MklSoftmaxParams& fwdParams, MklPrimitive* op) {
    string key = CreateKey(fwdParams);
    this->SetOp(key, op);
  }
};

typedef Eigen::ThreadPoolDevice CPUDevice;

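// OpKernel that computes softmax with oneDNN. Compute() derives the source
// dimensions, layout, and softmax axis from the input tensor, fetches (or
// creates) a cached MklSoftmaxPrimitive through the factory above, allocates
// the output tensor, and executes the primitive on a stream backed by the
// Eigen thread pool.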
template <typename Device, typename T>
class MklSoftmaxOp : public OpKernel {
 public:
  ~MklSoftmaxOp() {}

  explicit MklSoftmaxOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::kind::cpu, 0);
      // src_tensor is the 0-th input of this op, fetched from "context".
      size_t src_idx = 0;
      const Tensor& src_tensor = MklGetInput(context, src_idx);
      MklDnnShape src_mkl_shape;
      GetMklShape(context, src_idx, &src_mkl_shape);

      // src_dims holds the dimensions of src_tensor; the dst tensor has the
      // same dimensions.
      auto src_tf_shape = src_mkl_shape.IsMklTensor()
                              ? src_mkl_shape.GetTfShape()
                              : src_tensor.shape();
      const int input_dims = src_tf_shape.dims();
      memory::dims src_dims;
      int axis;
      if (src_mkl_shape.IsMklTensor()) {
        src_dims = src_mkl_shape.GetSizesAsMklDnnDims();
        axis = 1;
      } else {
        src_dims = TFShapeToMklDnnDims(src_tf_shape);
        axis = input_dims - 1;
      }
      MklTensorFormat layout_type;
      // The data format passed to the MKL softmax op depends on the rank of
      // the input tensor: "x" is used for a 1-D tensor, "nc" for 2-D, "tnc"
      // for 3-D, "nchw" for 4-D, and "ncdhw" for 5-D. The symbols mean:
      // n = batch, c = channels, t = sequence length, h = height, w = width,
      // d = depth. When the src tensor is in MKL layout, layout_type is only
      // used to set the TF layout of the output tensor. When the input is a
      // plain TF tensor, the layout carries no special meaning; axis alone
      // determines the dimension along which softmax is computed.
      switch (input_dims) {
        case 1:
          layout_type = MklTensorFormat::FORMAT_X;
          break;
        case 2:
          layout_type = MklTensorFormat::FORMAT_NC;
          break;
        case 3:
          layout_type = MklTensorFormat::FORMAT_TNC;
          break;
        case 4:
          if (src_mkl_shape.IsMklTensor()) {
            layout_type = MklTensorFormat::FORMAT_NHWC;
          } else {
            layout_type = MklTensorFormat::FORMAT_NCHW;
          }
          break;
        case 5:
          if (src_mkl_shape.IsMklTensor()) {
            layout_type = MklTensorFormat::FORMAT_NDHWC;
          } else {
            layout_type = MklTensorFormat::FORMAT_NCDHW;
          }
          break;
        default:
          OP_REQUIRES_OK(context,
                         errors::Aborted("Input dims must be <= 5 and >=1"));
          return;
      }

      // If input is in MKL layout, then simply get the format from input;
      // otherwise, use TF layout defined before.
      auto src_fmt = src_mkl_shape.IsMklTensor()
                         ? MklTensorFormat::FORMAT_BLOCKED
                         : layout_type;

      // Get a softmax fwd primitive from primitive pool.
      MklSoftmaxParams fwdParams(src_dims, src_fmt, axis);
#ifdef DNNL_AARCH64_USE_ACL
      // ACL does not support reuse of primitives with different data.
      // For softmax, the previous approach (PR #47775) of using Tensor
      // addresses does not work, as the addresses are re-used in matmul with
      // different data. The counter ensures we still benefit from caching via
      // SetSoftmaxFwd().
      fwdParams.aarch64_counter =
          MklSoftmaxPrimitiveFactory<T>::IncrementCounter();
#endif
      MklSoftmaxPrimitive<T>* softmax_fwd =
          MklSoftmaxPrimitiveFactory<T>::Get(fwdParams);

      // Prepare for creating output tensor.
      Tensor* output_tensor = nullptr;
      MklDnnShape output_mkl_shape;
      TensorShape output_tf_shape;  // shape of output TF tensor.

      auto dst_pd = softmax_fwd->GetSoftmaxFwdPd()->dst_desc();

      // If input is MKL shape, output is also MKL shape.
      // If input is TF shape, output is also TF shape.
      if (src_mkl_shape.IsMklTensor()) {
        output_mkl_shape.SetMklTensor(true);
        output_mkl_shape.SetMklLayout(&dst_pd);
        output_mkl_shape.SetElemType(MklDnnType<T>());
        output_mkl_shape.SetTfLayout(src_dims.size(), src_dims, layout_type);
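        // dst_pd.get_size() is the size in bytes of the oneDNN destination
        // layout (which may be blocked), so the flat TF output tensor is
        // allocated with enough elements to hold that layout.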
        output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T)));
      } else {
        output_mkl_shape.SetMklTensor(false);
        output_tf_shape = MklDnnDimsToTFShape(src_dims);
      }
      // Allocate output tensor.
      AllocateOutputSetMklShape(context, 0, &output_tensor, output_tf_shape,
                                output_mkl_shape);

      const T* src_data = src_tensor.flat<T>().data();
      T* dst_data = reinterpret_cast<T*>(output_tensor->flat<T>().data());
      std::shared_ptr<stream> fwd_cpu_stream;
      MklDnnThreadPool eigen_tp(context);
      fwd_cpu_stream.reset(CreateStream(&eigen_tp, softmax_fwd->GetEngine()));
      softmax_fwd->Execute(src_data, dst_data, fwd_cpu_stream);
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }
};

/* Register DNN kernels for supported operations and supported types - right
 * now it is only Softmax, for the float and bfloat16 types. */
#define REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES(type)     \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklSoftmax")                                      \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<type>("T")                           \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklSoftmaxOp<CPUDevice, type>);

TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_bfloat16(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);

}  // namespace tensorflow

#endif  // INTEL_MKL