/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h"

#include <iostream>

#include "absl/base/dynamic_annotations.h"
#include "tensorflow/compiler/xla/executable_run_options.h"

#ifdef ENABLE_MKL
#include <omp.h>

#include "dnnl.hpp"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"

namespace {

// Downcast an int64_t to int, checking that the value is in range.
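// For example, ToInt(42) returns 42, while ToInt(int64_t{1} << 40) prints an
// error and exits, since the value does not fit in an int.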
int ToInt(int64_t input) {
  int output = static_cast<int>(input);
  if (static_cast<int64_t>(output) != input) {
    std::cerr << "Error occurred in downcasting int64_t to int: Value "
              << input << " is out of range for type int.\n";
    exit(1);
  }
  return output;
}

using dnnl::convolution_direct;
using dnnl::convolution_forward;
using dnnl::engine;
using dnnl::memory;
using dnnl::padding_kind;
using dnnl::primitive;
using dnnl::prop_kind;
using dnnl::reorder;
using dnnl::stream;

template <typename EigenDevice, typename ScalarType>
void MKLConvImpl(const EigenDevice& device, ScalarType* out, ScalarType* lhs,
                 ScalarType* rhs, int64_t input_batch, int64_t input_rows,
                 int64_t input_cols, int64_t input_channels,
                 int64_t kernel_rows, int64_t kernel_cols,
                 int64_t kernel_channels, int64_t kernel_filters,
                 int64_t output_rows, int64_t output_cols, int64_t row_stride,
                 int64_t col_stride, int64_t padding_top,
                 int64_t padding_bottom, int64_t padding_left,
                 int64_t padding_right, int64_t lhs_row_dilation,
                 int64_t lhs_col_dilation, int64_t rhs_row_dilation,
                 int64_t rhs_col_dilation) {
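  // Create an MKL-DNN engine bound to CPU device 0; all memory objects and
  // primitives below are created on this engine.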
  auto cpu_engine = engine(engine::cpu, 0);

  // Create a vector of primitives to hold the network.
  std::vector<primitive> net;

  // Since memory::dims takes int for each dimension, we downcast the int64_t
  // values to int using the ToInt function defined above.
  memory::dims conv1_src_dim = {ToInt(input_batch), ToInt(input_channels),
                                ToInt(input_rows), ToInt(input_cols)};
  memory::dims conv1_weights_dim = {ToInt(kernel_filters),
                                    ToInt(kernel_channels), ToInt(kernel_rows),
                                    ToInt(kernel_cols)};
  memory::dims conv1_dst_dim = {ToInt(input_batch), ToInt(kernel_filters),
                                ToInt(output_rows), ToInt(output_cols)};
  memory::dims conv1_strides = {ToInt(row_stride), ToInt(col_stride)};
  // Note: In MKL-DNN, dilation is zero-based (0 means no dilation), so we
  // subtract 1 from XLA's one-based dilation factors.
  memory::dims conv1_dilates = {ToInt(rhs_row_dilation - 1),
                                ToInt(rhs_col_dilation - 1)};
  memory::dims conv1_padding_l = {ToInt(padding_top), ToInt(padding_left)};
  memory::dims conv1_padding_r = {ToInt(padding_bottom), ToInt(padding_right)};
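  // With these parameters, MKL-DNN computes each output spatial dimension as
  // (in + pad_l + pad_r - ((kernel - 1) * (dilate + 1) + 1)) / stride + 1,
  // which must match the output_rows/output_cols passed in by XLA.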

  // Create memory for user data. Input and output data are in NHWC format and
  // kernel data is in HWIO format.
  // Note that as a convention in MKL-DNN, the dimensions of the data are
  // always described in NCHW/OIHW order, regardless of the actual layout.
  auto user_src_memory =
      memory({{{conv1_src_dim}, memory::data_type::f32, memory::format::nhwc},
              cpu_engine},
             lhs);
  auto user_weights_memory = memory(
      {{{conv1_weights_dim}, memory::data_type::f32, memory::format::hwio},
       cpu_engine},
      rhs);
  auto user_dst_memory =
      memory({{{conv1_dst_dim}, memory::data_type::f32, memory::format::nhwc},
              cpu_engine},
             out);

  // Create memory descriptors for the convolution data with format::any, so
  // MKL-DNN can choose the best-performing layout.
  auto conv1_src_mem_desc = memory::desc(
      {conv1_src_dim}, memory::data_type::f32, memory::format::any);
  auto conv1_weights_mem_desc = memory::desc(
      {conv1_weights_dim}, memory::data_type::f32, memory::format::any);
  auto conv1_dst_mem_desc = memory::desc(
      {conv1_dst_dim}, memory::data_type::f32, memory::format::any);

  // Create the convolution descriptor and its primitive descriptor.
  auto conv1_desc = convolution_forward::desc(
      prop_kind::forward_inference, convolution_direct, conv1_src_mem_desc,
      conv1_weights_mem_desc, conv1_dst_mem_desc, conv1_strides, conv1_dilates,
      conv1_padding_l, conv1_padding_r, padding_kind::zero);
  auto conv1_prim_desc =
      convolution_forward::primitive_desc(conv1_desc, cpu_engine);
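  // The primitive descriptor records the memory layouts MKL-DNN prefers for
  // this convolution; the checks below compare them against the user layouts.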

  // Create reorders for source and weights if the layout requested by the
  // convolution differs from the user-provided NHWC/HWIO layouts.
  auto conv1_src_memory = user_src_memory;
  if (memory::primitive_desc(conv1_prim_desc.src_primitive_desc()) !=
      user_src_memory.get_primitive_desc()) {
    conv1_src_memory = memory(conv1_prim_desc.src_primitive_desc());
    net.push_back(reorder(user_src_memory, conv1_src_memory));
  }

  auto conv1_weights_memory = user_weights_memory;
  if (memory::primitive_desc(conv1_prim_desc.weights_primitive_desc()) !=
      user_weights_memory.get_primitive_desc()) {
    conv1_weights_memory = memory(conv1_prim_desc.weights_primitive_desc());
    net.push_back(reorder(user_weights_memory, conv1_weights_memory));
  }

  // Check if the output needs layout conversion. If so, create intermediate
  // memory for the convolution output, conv1_dst_memory.
  bool need_output_conversion =
      (memory::primitive_desc(conv1_prim_desc.dst_primitive_desc()) !=
       user_dst_memory.get_primitive_desc());
  auto conv1_dst_memory = need_output_conversion
                              ? memory(conv1_prim_desc.dst_primitive_desc())
                              : user_dst_memory;

  // Create convolution primitive and add it to net.
  net.push_back(convolution_forward(conv1_prim_desc, conv1_src_memory,
                                    conv1_weights_memory, conv1_dst_memory));
  if (need_output_conversion) {
    net.push_back(reorder(conv1_dst_memory, user_dst_memory));
  }
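  // Execute the net: the optional input/weight reorders, the convolution, and
  // the optional output reorder, in that order.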
  stream(stream::kind::eager_nostore).submit(net).wait();
}
}  // namespace
#endif  // ENABLE_MKL

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLConv2DF32(
    const void* run_options_ptr, float* out, float* lhs, float* rhs,
    int64_t input_batch, int64_t input_rows, int64_t input_cols,
    int64_t input_channels, int64_t kernel_rows, int64_t kernel_cols,
    int64_t kernel_channels, int64_t kernel_filters, int64_t output_rows,
    int64_t output_cols, int64_t row_stride, int64_t col_stride,
    int64_t padding_top, int64_t padding_bottom, int64_t padding_left,
    int64_t padding_right, int64_t lhs_row_dilation, int64_t lhs_col_dilation,
    int64_t rhs_row_dilation, int64_t rhs_col_dilation) {
#ifdef ENABLE_MKL
  // MKL-DNN does not support transposed convolution (indicated by lhs/base
  // dilation > 1), so fall back to the Eigen implementation for that case.
  if (lhs_row_dilation > 1 || lhs_col_dilation > 1) {
    __xla_cpu_runtime_EigenConvF32(
        run_options_ptr, out, lhs, rhs, input_batch, input_rows, input_cols,
        input_channels, kernel_rows, kernel_cols, kernel_channels,
        kernel_filters, output_rows, output_cols, row_stride, col_stride,
        padding_top, padding_bottom, padding_left, padding_right,
        lhs_row_dilation, lhs_col_dilation, rhs_row_dilation,
        rhs_col_dilation);
  } else {
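    // The EigenDevice argument is unused by MKLConvImpl, so nullptr is passed.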
    MKLConvImpl(nullptr, out, lhs, rhs, input_batch, input_rows, input_cols,
                input_channels, kernel_rows, kernel_cols, kernel_channels,
                kernel_filters, output_rows, output_cols, row_stride,
                col_stride, padding_top, padding_bottom, padding_left,
                padding_right, lhs_row_dilation, lhs_col_dilation,
                rhs_row_dilation, rhs_col_dilation);
  }
#else
  std::cerr << "Attempt to call MKL Conv2D runtime library without defining "
               "ENABLE_MKL. Add --config=mkl to build with MKL.\n";
  exit(1);
#endif  // ENABLE_MKL
}