/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h"

#include <iostream>

#include "absl/base/dynamic_annotations.h"
#include "tensorflow/compiler/xla/executable_run_options.h"

#ifdef ENABLE_MKL
#include <omp.h>

#include "dnnl.hpp"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"

namespace {

// Downcast an int64_t to int and check that the value is in range.
int ToInt(int64_t input) {
  int output = static_cast<int>(input);
  if (static_cast<int64_t>(output) != input) {
    std::cerr << "Error occurred in downcasting int64_t to int32_t: Value "
              << input << " is out of range for type int32_t.\n";
    exit(1);
  }
  return output;
}
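
// For example, ToInt(3) returns 3, while ToInt(int64_t{1} << 40) prints an
// out-of-range error and terminates the process.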

using dnnl::convolution_direct;
using dnnl::convolution_forward;
using dnnl::engine;
using dnnl::memory;
using dnnl::padding_kind;
using dnnl::primitive;
using dnnl::prop_kind;
using dnnl::reorder;
using dnnl::stream;

template <typename EigenDevice, typename ScalarType>
void MKLConvImpl(const EigenDevice& device, ScalarType* out, ScalarType* lhs,
                 ScalarType* rhs, int64_t input_batch, int64_t input_rows,
                 int64_t input_cols, int64_t input_channels,
                 int64_t kernel_rows, int64_t kernel_cols,
                 int64_t kernel_channels, int64_t kernel_filters,
                 int64_t output_rows, int64_t output_cols, int64_t row_stride,
                 int64_t col_stride, int64_t padding_top,
                 int64_t padding_bottom, int64_t padding_left,
                 int64_t padding_right, int64_t lhs_row_dilation,
                 int64_t lhs_col_dilation, int64_t rhs_row_dilation,
                 int64_t rhs_col_dilation) {
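  // Note: the EigenDevice argument is unused on this MKL-DNN path; the caller
  // below passes nullptr for it, and all work runs synchronously on the CPU
  // engine created here.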
  auto cpu_engine = engine(engine::cpu, 0);

  // Create a vector of primitives to hold the network.
  std::vector<primitive> net;

  // Since memory::dims takes int for each dimension, we downcast the int64_t
  // values to int using the ToInt function defined above.
  memory::dims conv1_src_dim = {ToInt(input_batch), ToInt(input_channels),
                                ToInt(input_rows), ToInt(input_cols)};
  memory::dims conv1_weights_dim = {ToInt(kernel_filters),
                                    ToInt(kernel_channels), ToInt(kernel_rows),
                                    ToInt(kernel_cols)};
  memory::dims conv1_dst_dim = {ToInt(input_batch), ToInt(kernel_filters),
                                ToInt(output_rows), ToInt(output_cols)};
  memory::dims conv1_strides = {ToInt(row_stride), ToInt(col_stride)};
  // Note: in MKL-DNN, dilation starts from 0.
  memory::dims conv1_dilates = {ToInt(rhs_row_dilation - 1),
                                ToInt(rhs_col_dilation - 1)};
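  // For example, an XLA rhs dilation of 1 (i.e. no dilation) becomes an
  // MKL-DNN dilation of 0, and an XLA dilation of 2 becomes 1.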
  memory::dims conv1_padding_l = {ToInt(padding_top), ToInt(padding_left)};
  memory::dims conv1_padding_r = {ToInt(padding_bottom), ToInt(padding_right)};

  // Create memory for user data. Input and output data have the NHWC format
  // and kernel data has the HWIO format.
  // Note that as a convention in MKL-DNN, the dimensions of the data are
  // always described in NCHW/IOHW order, regardless of the actual layout of
  // the data.
  auto user_src_memory =
      memory({{{conv1_src_dim}, memory::data_type::f32, memory::format::nhwc},
              cpu_engine},
             lhs);
  auto user_weights_memory = memory(
      {{{conv1_weights_dim}, memory::data_type::f32, memory::format::hwio},
       cpu_engine},
      rhs);
  auto user_dst_memory =
      memory({{{conv1_dst_dim}, memory::data_type::f32, memory::format::nhwc},
              cpu_engine},
             out);
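  // These memory objects wrap the raw XLA buffers (lhs, rhs, out) in place;
  // no data is copied unless one of the reorder primitives below runs.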

  // Create memory descriptors for convolution data with no specified format
  // for best performance.
  auto conv1_src_mem_desc = memory::desc(
      {conv1_src_dim}, memory::data_type::f32, memory::format::any);
  auto conv1_weights_mem_desc = memory::desc(
      {conv1_weights_dim}, memory::data_type::f32, memory::format::any);
  auto conv1_dst_mem_desc = memory::desc(
      {conv1_dst_dim}, memory::data_type::f32, memory::format::any);
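  // memory::format::any lets the primitive descriptor pick whatever (possibly
  // blocked) layout is fastest on this CPU; the reorders created below
  // convert the user's NHWC/HWIO data to and from that layout as needed.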

  // Create the convolution descriptor and its primitive descriptor.
  auto conv1_desc = convolution_forward::desc(
      prop_kind::forward_inference, convolution_direct, conv1_src_mem_desc,
      conv1_weights_mem_desc, conv1_dst_mem_desc, conv1_strides, conv1_dilates,
      conv1_padding_l, conv1_padding_r, padding_kind::zero);
  auto conv1_prim_desc =
      convolution_forward::primitive_desc(conv1_desc, cpu_engine);

  // Create reorders for data and weights if the layout requested by the
  // convolution differs from the user's NHWC/HWIO layouts.
  auto conv1_src_memory = user_src_memory;
  if (memory::primitive_desc(conv1_prim_desc.src_primitive_desc()) !=
      user_src_memory.get_primitive_desc()) {
    conv1_src_memory = memory(conv1_prim_desc.src_primitive_desc());
    net.push_back(reorder(user_src_memory, conv1_src_memory));
  }

  auto conv1_weights_memory = user_weights_memory;
  if (memory::primitive_desc(conv1_prim_desc.weights_primitive_desc()) !=
      user_weights_memory.get_primitive_desc()) {
    conv1_weights_memory = memory(conv1_prim_desc.weights_primitive_desc());
    net.push_back(reorder(user_weights_memory, conv1_weights_memory));
  }

  // Check if the output needs layout conversion. If so, create an
  // intermediate conv1_dst_memory to hold the convolution result.
  bool need_output_conversion =
      (memory::primitive_desc(conv1_prim_desc.dst_primitive_desc()) !=
       user_dst_memory.get_primitive_desc());
  auto conv1_dst_memory = need_output_conversion
                              ? memory(conv1_prim_desc.dst_primitive_desc())
                              : user_dst_memory;

  // Create the convolution primitive and add it to the net.
  net.push_back(convolution_forward(conv1_prim_desc, conv1_src_memory,
                                    conv1_weights_memory, conv1_dst_memory));
  if (need_output_conversion) {
    net.push_back(reorder(conv1_dst_memory, user_dst_memory));
  }
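  // Submitting the net executes the queued primitives in order: the optional
  // input/weight reorders, the convolution, and the optional output reorder.
  // wait() blocks until execution completes.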
  stream(stream::kind::eager_nostore).submit(net).wait();
}
}  // namespace
#endif  // ENABLE_MKL

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLConv2DF32(
    const void* run_options_ptr, float* out, float* lhs, float* rhs,
    int64_t input_batch, int64_t input_rows, int64_t input_cols,
    int64_t input_channels, int64_t kernel_rows, int64_t kernel_cols,
    int64_t kernel_channels, int64_t kernel_filters, int64_t output_rows,
    int64_t output_cols, int64_t row_stride, int64_t col_stride,
    int64_t padding_top, int64_t padding_bottom, int64_t padding_left,
    int64_t padding_right, int64_t lhs_row_dilation, int64_t lhs_col_dilation,
    int64_t rhs_row_dilation, int64_t rhs_col_dilation) {
#ifdef ENABLE_MKL
  // MKL-DNN cannot handle input (lhs) dilation, which is what a transposed
  // convolution lowers to, so those cases fall back to the Eigen
  // implementation.
  if (lhs_row_dilation > 1 || lhs_col_dilation > 1) {
    __xla_cpu_runtime_EigenConvF32(
        run_options_ptr, out, lhs, rhs, input_batch, input_rows, input_cols,
        input_channels, kernel_rows, kernel_cols, kernel_channels,
        kernel_filters, output_rows, output_cols, row_stride, col_stride,
        padding_top, padding_bottom, padding_left, padding_right,
        lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation);
  } else {
    MKLConvImpl(nullptr, out, lhs, rhs, input_batch, input_rows, input_cols,
                input_channels, kernel_rows, kernel_cols, kernel_channels,
                kernel_filters, output_rows, output_cols, row_stride,
                col_stride, padding_top, padding_bottom, padding_left,
                padding_right, lhs_row_dilation, lhs_col_dilation,
                rhs_row_dilation, rhs_col_dilation);
  }
#else
  std::cerr << "Attempt to call MKL Conv2D runtime library without defining "
               "ENABLE_MKL. Add --config=mkl to build with MKL.\n";
  exit(1);
#endif  // ENABLE_MKL
}
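
// Example invocation (a sketch; real call sites are emitted by the XLA CPU
// backend, and the shapes below are made up for illustration): convolving a
// 1x8x8x3 NHWC input with a 3x3x3x16 HWIO kernel at stride 1, with no padding
// and no dilation (XLA dilation factors of 1), yields a 1x6x6x16 output.
// Passing a null run_options_ptr is only acceptable here because the
// non-dilated path does not use it.
//
//   std::vector<float> in(1 * 8 * 8 * 3), w(3 * 3 * 3 * 16);
//   std::vector<float> out(1 * 6 * 6 * 16);
//   __xla_cpu_runtime_MKLConv2DF32(
//       /*run_options_ptr=*/nullptr, out.data(), in.data(), w.data(),
//       /*input_batch=*/1, /*input_rows=*/8, /*input_cols=*/8,
//       /*input_channels=*/3, /*kernel_rows=*/3, /*kernel_cols=*/3,
//       /*kernel_channels=*/3, /*kernel_filters=*/16,
//       /*output_rows=*/6, /*output_cols=*/6,
//       /*row_stride=*/1, /*col_stride=*/1,
//       /*padding_top=*/0, /*padding_bottom=*/0,
//       /*padding_left=*/0, /*padding_right=*/0,
//       /*lhs_row_dilation=*/1, /*lhs_col_dilation=*/1,
//       /*rhs_row_dilation=*/1, /*rhs_col_dilation=*/1);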