1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_IMPL_H_
16 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_IMPL_H_
17
#include <cassert>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_spatial_convolutions.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif
24
25 // 'tensorflow' namespace is used so that types don't require qualification.
26 namespace tensorflow {
27 namespace xla {
28
29 template <typename EigenDevice, typename ScalarType>
EigenConv2DImpl(const EigenDevice & device,ScalarType * out,ScalarType * lhs,ScalarType * rhs,Eigen::Index input_batch,Eigen::Index input_x,Eigen::Index input_y,Eigen::Index input_channels,Eigen::Index kernel_x,Eigen::Index kernel_y,Eigen::Index kernel_channels,Eigen::Index kernel_filters,Eigen::Index output_x,Eigen::Index output_y,Eigen::Index x_stride,Eigen::Index y_stride,Eigen::Index padding_x_before,Eigen::Index padding_x_after,Eigen::Index padding_y_before,Eigen::Index padding_y_after,Eigen::Index lhs_x_dilation,Eigen::Index lhs_y_dilation,Eigen::Index rhs_x_dilation,Eigen::Index rhs_y_dilation,Eigen::Index feature_group_count)30 void EigenConv2DImpl(
31 const EigenDevice& device, ScalarType* out, ScalarType* lhs,
32 ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x,
33 Eigen::Index input_y, Eigen::Index input_channels, Eigen::Index kernel_x,
34 Eigen::Index kernel_y, Eigen::Index kernel_channels,
35 Eigen::Index kernel_filters, Eigen::Index output_x, Eigen::Index output_y,
36 Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index padding_x_before,
37 Eigen::Index padding_x_after, Eigen::Index padding_y_before,
38 Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation,
39 Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation,
40 Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count) {
41 const Eigen::TensorMap<Eigen::Tensor<const ScalarType, 4, Eigen::RowMajor>,
42 Eigen::Aligned>
43 input(lhs, input_batch, input_x, input_y, input_channels);
44
45 const Eigen::TensorMap<Eigen::Tensor<const ScalarType, 4, Eigen::RowMajor>,
46 Eigen::Aligned>
47 kernel(rhs, kernel_x, kernel_y, kernel_channels, kernel_filters);
48
49 Eigen::TensorMap<Eigen::Tensor<ScalarType, 4, Eigen::RowMajor>,
50 Eigen::Aligned>
51 output(out, input_batch, output_x, output_y, kernel_filters);
52
53 Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims;
54 contract_dims[0] = Eigen::IndexPair<Eigen::Index>(1, 0);
55
56 Eigen::DSizes<Eigen::Index, 5> input_reshaped_dims;
57 input_reshaped_dims[0] = input_batch;
58 input_reshaped_dims[1] = input_x;
59 input_reshaped_dims[2] = input_y;
60 input_reshaped_dims[3] = feature_group_count;
61 input_reshaped_dims[4] = input_channels / feature_group_count;
62
63 Eigen::DSizes<Eigen::Index, 5> output_reshaped_dims;
64 output_reshaped_dims[0] = input_batch;
65 output_reshaped_dims[1] = output_x;
66 output_reshaped_dims[2] = output_y;
67 output_reshaped_dims[3] = feature_group_count;
68 output_reshaped_dims[4] = kernel_filters / feature_group_count;
69
70 // Molds the output of the patch extraction code into a 2d tensor:
71 // - the first dimension (dims[0]): the patch values to be multiplied with the
72 // kernels
73 // - the second dimension (dims[1]): everything else
74 Eigen::DSizes<Eigen::Index, 2> pre_contract_dims;
75 pre_contract_dims[0] = output_y * output_x * input_batch;
76 pre_contract_dims[1] = kernel_channels * kernel_y * kernel_x;
77
78 // Molds the output of the contraction into the shape expected by the user:
79 Eigen::DSizes<Eigen::Index, 4> post_contract_dims;
80 post_contract_dims[0] = input_batch;
81 post_contract_dims[1] = output_x;
82 post_contract_dims[2] = output_y;
83 post_contract_dims[3] = kernel_filters / feature_group_count;
84
85 Eigen::DSizes<Eigen::Index, 3> kernel_dims;
86 kernel_dims[0] = kernel_channels * kernel_y * kernel_x;
87 kernel_dims[1] = feature_group_count;
88 kernel_dims[2] = kernel_filters / feature_group_count;
89
90 for (Eigen::Index i = 0; i < feature_group_count; ++i) {
91 // The row and column dimensions must be flipped when passed to Eigen.
92 output.reshape(output_reshaped_dims).chip(i, 3).device(device) =
93 input.reshape(input_reshaped_dims)
94 .chip(i, 3)
95 .extract_image_patches(
96 kernel_y, kernel_x, y_stride, x_stride, rhs_y_dilation,
97 rhs_x_dilation, lhs_y_dilation, lhs_x_dilation,
98 padding_y_before, padding_y_after, padding_x_before,
99 padding_x_after, static_cast<ScalarType>(0.0f))
100 .reshape(pre_contract_dims)
101 .contract(kernel.reshape(kernel_dims).chip(i, 1), contract_dims)
102 .reshape(post_contract_dims);
103 }
104 }
105
106 template <typename EigenDevice, typename ScalarType>
EigenConv3DImpl(const EigenDevice & device,ScalarType * out,ScalarType * lhs,ScalarType * rhs,Eigen::Index input_batch,Eigen::Index input_x,Eigen::Index input_y,Eigen::Index input_z,Eigen::Index input_channels,Eigen::Index kernel_x,Eigen::Index kernel_y,Eigen::Index kernel_z,Eigen::Index kernel_channels,Eigen::Index kernel_filters,Eigen::Index output_x,Eigen::Index output_y,Eigen::Index output_z,Eigen::Index x_stride,Eigen::Index y_stride,Eigen::Index z_stride,Eigen::Index padding_x_before,Eigen::Index padding_x_after,Eigen::Index padding_y_before,Eigen::Index padding_y_after,Eigen::Index padding_z_before,Eigen::Index padding_z_after,Eigen::Index lhs_x_dilation,Eigen::Index lhs_y_dilation,Eigen::Index lhs_z_dilation,Eigen::Index rhs_x_dilation,Eigen::Index rhs_y_dilation,Eigen::Index rhs_z_dilation,Eigen::Index feature_group_count)107 void EigenConv3DImpl(
108 const EigenDevice& device, ScalarType* out, ScalarType* lhs,
109 ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x,
110 Eigen::Index input_y, Eigen::Index input_z, Eigen::Index input_channels,
111 Eigen::Index kernel_x, Eigen::Index kernel_y, Eigen::Index kernel_z,
112 Eigen::Index kernel_channels, Eigen::Index kernel_filters,
113 Eigen::Index output_x, Eigen::Index output_y, Eigen::Index output_z,
114 Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index z_stride,
115 Eigen::Index padding_x_before, Eigen::Index padding_x_after,
116 Eigen::Index padding_y_before, Eigen::Index padding_y_after,
117 Eigen::Index padding_z_before, Eigen::Index padding_z_after,
118 Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation,
119 Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation,
120 Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation,
121 Eigen::Index feature_group_count) {
122 using ConstTType =
123 Eigen::TensorMap<Eigen::Tensor<const ScalarType, 5, Eigen::RowMajor>,
124 Eigen::Aligned>;
125 const ConstTType input(lhs, input_batch, input_x, input_y, input_z,
126 input_channels);
127
128 const ConstTType kernel(rhs, kernel_x, kernel_y, kernel_z, kernel_channels,
129 kernel_filters);
130
131 Eigen::TensorMap<Eigen::Tensor<ScalarType, 5, Eigen::RowMajor>,
132 Eigen::Aligned>
133 output(out, input_batch, output_x, output_y, output_z, kernel_filters);
134
135 Eigen::DSizes<Eigen::Index, 6> input_reshaped_dims;
136 input_reshaped_dims[0] = input_batch;
137 input_reshaped_dims[1] = input_x;
138 input_reshaped_dims[2] = input_y;
139 input_reshaped_dims[3] = input_z;
140 input_reshaped_dims[4] = feature_group_count;
141 input_reshaped_dims[5] = input_channels / feature_group_count;
142
143 Eigen::DSizes<Eigen::Index, 6> output_reshaped_dims;
144 output_reshaped_dims[0] = input_batch;
145 output_reshaped_dims[1] = output_x;
146 output_reshaped_dims[2] = output_y;
147 output_reshaped_dims[3] = output_z;
148 output_reshaped_dims[4] = feature_group_count;
149 output_reshaped_dims[5] = kernel_filters / feature_group_count;
150
151 Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims;
152 contract_dims[0] = Eigen::IndexPair<Eigen::Index>(1, 0);
153
154 // Molds the output of the patch extraction code into a 2d tensor:
155 // - the first dimension (dims[0]): the patch values to be multiplied with the
156 // kernels
157 // - the second dimension (dims[1]): everything else
158 Eigen::DSizes<Eigen::Index, 2> pre_contract_dims;
159 pre_contract_dims[0] = output_x * output_y * output_z * input_batch;
160 pre_contract_dims[1] = kernel_channels * kernel_x * kernel_y * kernel_z;
161
162 // Molds the output of the contraction into the shape expected by the user:
163 Eigen::DSizes<Eigen::Index, 5> post_contract_dims;
164 post_contract_dims[0] = input_batch;
165 post_contract_dims[1] = output_x;
166 post_contract_dims[2] = output_y;
167 post_contract_dims[3] = output_z;
168 post_contract_dims[4] = kernel_filters / feature_group_count;
169
170 Eigen::DSizes<Eigen::Index, 3> kernel_dims;
171 kernel_dims[0] = kernel_channels * kernel_x * kernel_y * kernel_z;
172 kernel_dims[1] = feature_group_count;
173 kernel_dims[2] = kernel_filters / feature_group_count;
174
175 for (Eigen::Index i = 0; i < feature_group_count; ++i) {
176 // The dimension order must be flipped when passed to Eigen.
177 auto input_chip = input.reshape(input_reshaped_dims).chip(i, 4);
178 auto patches =
179 Eigen::TensorVolumePatchOp<Eigen::Dynamic, Eigen::Dynamic,
180 Eigen::Dynamic, decltype(input_chip)>(
181 input_chip, kernel_z, kernel_y, kernel_x, z_stride, y_stride,
182 x_stride, rhs_z_dilation, rhs_y_dilation, rhs_x_dilation,
183 lhs_z_dilation, lhs_y_dilation, lhs_x_dilation, padding_z_before,
184 padding_z_after, padding_y_before, padding_y_after,
185 padding_x_before, padding_x_after, static_cast<ScalarType>(0.0f));
186
187 output.reshape(output_reshaped_dims).chip(i, 4).device(device) =
188 patches.reshape(pre_contract_dims)
189 .contract(kernel.reshape(kernel_dims).chip(i, 1), contract_dims)
190 .reshape(post_contract_dims);
191 }
192 }
193
194 } // namespace xla
195 } // namespace tensorflow
196
197 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_IMPL_H_
198