/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#ifdef INTEL_MKL

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "dnnl.hpp"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"
#ifdef DNNL_AARCH64_USE_ACL
#include "tensorflow/core/platform/mutex.h"
#endif

using dnnl::prop_kind;
using dnnl::softmax_forward;
using dnnl::stream;

namespace tensorflow {

class MklSoftmaxParams {
 public:
  memory::dims src_dims;
  MklTensorFormat src_fmt;
  int axis;
#ifdef DNNL_AARCH64_USE_ACL
  int aarch64_counter;
#endif
  MklSoftmaxParams(memory::dims src_dims, MklTensorFormat src_fmt, int axis)
      : src_dims(src_dims), src_fmt(src_fmt), axis(axis) {}
};

template <typename T>
class MklSoftmaxPrimitive : public MklPrimitive {
 public:
  explicit MklSoftmaxPrimitive(const MklSoftmaxParams& fwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    Setup(fwdParams);
  }

  ~MklSoftmaxPrimitive() {}

  // Softmax forward execute.
  //   src_data: input data buffer of src
  //   dst_data: output data buffer of dst
  void Execute(const T* src_data, T* dst_data,
               std::shared_ptr<stream> fwd_cpu_stream) {
#ifdef DNNL_AARCH64_USE_ACL
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)), *fwd_cpu_stream);
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
                                      *fwd_cpu_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
#endif  // !ENABLE_ONEDNN_OPENMP

    DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_net_args.size());
    execute_primitives(context_.fwd_primitives, fwd_cpu_stream,
                       context_.fwd_net_args);

    // After execution, set data handles back to dummy data.
    context_.src_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<dnnl::softmax_forward::primitive_desc> GetSoftmaxFwdPd() {
    return context_.fwd_pd;
  }

 private:
  struct SoftmaxFwdContext {
    // MKL-DNN memory.
    std::shared_ptr<memory> src_mem;
    std::shared_ptr<memory> dst_mem;

    // Softmax forward operation descriptor.
    std::shared_ptr<dnnl::softmax_forward::desc> fwd_desc;

    // Memory descriptor.
    std::shared_ptr<memory::desc> src_md;

    // Softmax primitive.
    std::shared_ptr<dnnl::softmax_forward::primitive_desc> fwd_pd;
    std::shared_ptr<dnnl::primitive> softmax_fwd;

    std::vector<dnnl::primitive> fwd_primitives;
    std::vector<MemoryArgsMap> fwd_net_args;

    SoftmaxFwdContext()
        : src_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          src_md(nullptr),
          fwd_pd(nullptr),
          softmax_fwd(nullptr) {}
  };

  // Softmax forward primitive setup.
  void Setup(const MklSoftmaxParams& fwdParams) {
    // Create memory descriptors for softmax data with specified format.
    auto src_format = MklTensorFormatToMklDnnDataFormat(fwdParams.src_fmt);
    context_.src_md.reset(
        new memory::desc({fwdParams.src_dims}, MklDnnType<T>(), src_format));

    // Create softmax descriptor and primitive descriptor.
    context_.fwd_desc.reset(new dnnl::softmax_forward::desc(
        prop_kind::forward_scoring, *context_.src_md, fwdParams.axis));
    context_.fwd_pd.reset(new dnnl::softmax_forward::primitive_desc(
        *context_.fwd_desc, cpu_engine_));

    // Create memory primitives based on dummy data.
    context_.src_mem.reset(
        new memory(*context_.src_md, cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));

    // Create softmax primitive and add it to the net.
    context_.softmax_fwd.reset(new dnnl::softmax_forward(*context_.fwd_pd));
    context_.fwd_net_args.push_back(
        {{DNNL_ARG_SRC, *context_.src_mem}, {DNNL_ARG_DST, *context_.dst_mem}});

    context_.fwd_primitives.push_back(*context_.softmax_fwd);
  }

  struct SoftmaxFwdContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  mutex primitive_execution_mu_;
#endif
};

template <typename T>
class MklSoftmaxPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklSoftmaxPrimitive<T>* Get(const MklSoftmaxParams& fwdParams) {
    // Get a softmax fwd primitive from the cached pool.
    MklSoftmaxPrimitive<T>* softmax_forward =
        static_cast<MklSoftmaxPrimitive<T>*>(
            MklSoftmaxPrimitiveFactory<T>::GetInstance().GetSoftmaxFwd(
                fwdParams));
    if (softmax_forward == nullptr) {
      softmax_forward = new MklSoftmaxPrimitive<T>(fwdParams);
      MklSoftmaxPrimitiveFactory<T>::GetInstance().SetSoftmaxFwd(
          fwdParams, softmax_forward);
    }
    return softmax_forward;
  }

  static MklSoftmaxPrimitiveFactory& GetInstance() {
    static MklSoftmaxPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklSoftmaxPrimitiveFactory() {}
  ~MklSoftmaxPrimitiveFactory() {}

  static string CreateKey(const MklSoftmaxParams& fwdParams) {
    string prefix = "softmax_fwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(fwdParams.src_dims);
    key_creator.AddAsKey<int>(static_cast<int>(fwdParams.src_fmt));
    key_creator.AddAsKey<int>(fwdParams.axis);
#ifdef DNNL_AARCH64_USE_ACL
    key_creator.AddAsKey(fwdParams.aarch64_counter);
#endif
    return key_creator.GetKey();
  }

  MklPrimitive* GetSoftmaxFwd(const MklSoftmaxParams& fwdParams) {
    string key = CreateKey(fwdParams);
    return this->GetOp(key);
  }

  void SetSoftmaxFwd(const MklSoftmaxParams& fwdParams, MklPrimitive* op) {
    string key = CreateKey(fwdParams);
    this->SetOp(key, op);
  }
};

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T>
class MklSoftmaxOp : public OpKernel {
 public:
  ~MklSoftmaxOp() {}

  explicit MklSoftmaxOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::kind::cpu, 0);
      // src_tensor is the 0-th input of the op, fetched from the kernel
      // context.
      size_t src_idx = 0;
      const Tensor& src_tensor = MklGetInput(context, src_idx);
      MklDnnShape src_mkl_shape;
      GetMklShape(context, src_idx, &src_mkl_shape);

      // src_dims holds the dimensions of src_tensor; the dst tensor will
      // have the same dimensions.
      auto src_tf_shape = src_mkl_shape.IsMklTensor()
                              ? src_mkl_shape.GetTfShape()
                              : src_tensor.shape();
      const int input_dims = src_tf_shape.dims();
      memory::dims src_dims;
      int axis;
      if (src_mkl_shape.IsMklTensor()) {
        src_dims = src_mkl_shape.GetSizesAsMklDnnDims();
        axis = 1;
      } else {
        src_dims = TFShapeToMklDnnDims(src_tf_shape);
        axis = input_dims - 1;
      }
      MklTensorFormat layout_type;
      // The data format passed to the MKL softmax op depends on the
      // dimensionality of the input tensor: "x" is used for 1-D tensors,
      // "nc" for 2-D, "tnc" for 3-D, "nchw" for 4-D, and "ncdhw" for 5-D.
      // The symbols mean: n = batch, c = channels, t = sequence length,
      // h = height, w = width, d = depth. When the src tensor is in MKL
      // layout, layout_type is only used to set the TF layout type of the
      // output tensor. When the input is a plain TF tensor, the layout has
      // no special meaning here; axis defines the dimension along which
      // softmax is computed.
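      // For illustration only (hypothetical shape, not taken from any real
      // input): a plain TF tensor of shape {8, 128, 128, 3} has
      // input_dims = 4, so the switch below picks FORMAT_NCHW and softmax
      // runs over axis = 3 (the innermost dimension); the same 4-D tensor
      // arriving in MKL layout would instead use FORMAT_NHWC with axis = 1.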
      switch (input_dims) {
        case 1:
          layout_type = MklTensorFormat::FORMAT_X;
          break;
        case 2:
          layout_type = MklTensorFormat::FORMAT_NC;
          break;
        case 3:
          layout_type = MklTensorFormat::FORMAT_TNC;
          break;
        case 4:
          if (src_mkl_shape.IsMklTensor()) {
            layout_type = MklTensorFormat::FORMAT_NHWC;
          } else {
            layout_type = MklTensorFormat::FORMAT_NCHW;
          }
          break;
        case 5:
          if (src_mkl_shape.IsMklTensor()) {
            layout_type = MklTensorFormat::FORMAT_NDHWC;
          } else {
            layout_type = MklTensorFormat::FORMAT_NCDHW;
          }
          break;
        default:
          OP_REQUIRES_OK(context,
                         errors::Aborted("Input dims must be <= 5 and >= 1"));
          return;
      }

      // If the input is in MKL layout, simply get the format from the input;
      // otherwise, use the TF layout determined above.
      auto src_fmt = src_mkl_shape.IsMklTensor()
                         ? MklTensorFormat::FORMAT_BLOCKED
                         : layout_type;

      // Get a softmax fwd primitive from the primitive pool.
      MklSoftmaxParams fwdParams(src_dims, src_fmt, axis);
#ifdef DNNL_AARCH64_USE_ACL
      // ACL does not support reuse of primitives with different data.
      // For softmax, the previous approach (PR #47775) of keying on tensor
      // addresses does not work, as the addresses are reused in matmul with
      // different data. The counter ensures we still benefit from caching
      // via SetSoftmaxFwd().
      fwdParams.aarch64_counter =
          MklSoftmaxPrimitiveFactory<T>::IncrementCounter();
#endif
      MklSoftmaxPrimitive<T>* softmax_fwd =
          MklSoftmaxPrimitiveFactory<T>::Get(fwdParams);

      // Prepare for creating the output tensor.
      Tensor* output_tensor = nullptr;
      MklDnnShape output_mkl_shape;
      TensorShape output_tf_shape;  // Shape of the output TF tensor.

      auto dst_pd = softmax_fwd->GetSoftmaxFwdPd()->dst_desc();

      // If the input is in MKL shape, the output is also in MKL shape;
      // if the input is in TF shape, the output is also in TF shape.
      if (src_mkl_shape.IsMklTensor()) {
        output_mkl_shape.SetMklTensor(true);
        output_mkl_shape.SetMklLayout(&dst_pd);
        output_mkl_shape.SetElemType(MklDnnType<T>());
        output_mkl_shape.SetTfLayout(src_dims.size(), src_dims, layout_type);
        output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T)));
      } else {
        output_mkl_shape.SetMklTensor(false);
        output_tf_shape = MklDnnDimsToTFShape(src_dims);
      }
      // Allocate output tensor.
      AllocateOutputSetMklShape(context, 0, &output_tensor, output_tf_shape,
                                output_mkl_shape);

      const T* src_data = src_tensor.flat<T>().data();
      T* dst_data = reinterpret_cast<T*>(output_tensor->flat<T>().data());
      std::shared_ptr<stream> fwd_cpu_stream;
      MklDnnThreadPool eigen_tp(context);
      fwd_cpu_stream.reset(CreateStream(&eigen_tp, softmax_fwd->GetEngine()));
      softmax_fwd->Execute(src_data, dst_data, fwd_cpu_stream);
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }
};

/* Register DNN kernels for supported operations and supported types -
 * currently only _MklSoftmax with float and bfloat16. */
#define REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES(type)     \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklSoftmax")                                      \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<type>("T")                           \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklSoftmaxOp<CPUDevice, type>);

TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_bfloat16(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);

}  // namespace tensorflow

#endif  // INTEL_MKL