xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/cpu/runtime_conv_impl.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_IMPL_H_
16 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_IMPL_H_
17 
18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
19 #include "tensorflow/core/kernels/eigen_spatial_convolutions.h"
20 
21 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
22 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
23 #endif
24 
25 // 'tensorflow' namespace is used so that types don't require qualification.
26 namespace tensorflow {
27 namespace xla {
28 
29 template <typename EigenDevice, typename ScalarType>
EigenConv2DImpl(const EigenDevice & device,ScalarType * out,ScalarType * lhs,ScalarType * rhs,Eigen::Index input_batch,Eigen::Index input_x,Eigen::Index input_y,Eigen::Index input_channels,Eigen::Index kernel_x,Eigen::Index kernel_y,Eigen::Index kernel_channels,Eigen::Index kernel_filters,Eigen::Index output_x,Eigen::Index output_y,Eigen::Index x_stride,Eigen::Index y_stride,Eigen::Index padding_x_before,Eigen::Index padding_x_after,Eigen::Index padding_y_before,Eigen::Index padding_y_after,Eigen::Index lhs_x_dilation,Eigen::Index lhs_y_dilation,Eigen::Index rhs_x_dilation,Eigen::Index rhs_y_dilation,Eigen::Index feature_group_count)30 void EigenConv2DImpl(
31     const EigenDevice& device, ScalarType* out, ScalarType* lhs,
32     ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x,
33     Eigen::Index input_y, Eigen::Index input_channels, Eigen::Index kernel_x,
34     Eigen::Index kernel_y, Eigen::Index kernel_channels,
35     Eigen::Index kernel_filters, Eigen::Index output_x, Eigen::Index output_y,
36     Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index padding_x_before,
37     Eigen::Index padding_x_after, Eigen::Index padding_y_before,
38     Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation,
39     Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation,
40     Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count) {
41   const Eigen::TensorMap<Eigen::Tensor<const ScalarType, 4, Eigen::RowMajor>,
42                          Eigen::Aligned>
43       input(lhs, input_batch, input_x, input_y, input_channels);
44 
45   const Eigen::TensorMap<Eigen::Tensor<const ScalarType, 4, Eigen::RowMajor>,
46                          Eigen::Aligned>
47       kernel(rhs, kernel_x, kernel_y, kernel_channels, kernel_filters);
48 
49   Eigen::TensorMap<Eigen::Tensor<ScalarType, 4, Eigen::RowMajor>,
50                    Eigen::Aligned>
51       output(out, input_batch, output_x, output_y, kernel_filters);
52 
53   Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims;
54   contract_dims[0] = Eigen::IndexPair<Eigen::Index>(1, 0);
55 
56   Eigen::DSizes<Eigen::Index, 5> input_reshaped_dims;
57   input_reshaped_dims[0] = input_batch;
58   input_reshaped_dims[1] = input_x;
59   input_reshaped_dims[2] = input_y;
60   input_reshaped_dims[3] = feature_group_count;
61   input_reshaped_dims[4] = input_channels / feature_group_count;
62 
63   Eigen::DSizes<Eigen::Index, 5> output_reshaped_dims;
64   output_reshaped_dims[0] = input_batch;
65   output_reshaped_dims[1] = output_x;
66   output_reshaped_dims[2] = output_y;
67   output_reshaped_dims[3] = feature_group_count;
68   output_reshaped_dims[4] = kernel_filters / feature_group_count;
69 
70   // Molds the output of the patch extraction code into a 2d tensor:
71   // - the first dimension (dims[0]): the patch values to be multiplied with the
72   //   kernels
73   // - the second dimension (dims[1]): everything else
74   Eigen::DSizes<Eigen::Index, 2> pre_contract_dims;
75   pre_contract_dims[0] = output_y * output_x * input_batch;
76   pre_contract_dims[1] = kernel_channels * kernel_y * kernel_x;
77 
78   // Molds the output of the contraction into the shape expected by the user:
79   Eigen::DSizes<Eigen::Index, 4> post_contract_dims;
80   post_contract_dims[0] = input_batch;
81   post_contract_dims[1] = output_x;
82   post_contract_dims[2] = output_y;
83   post_contract_dims[3] = kernel_filters / feature_group_count;
84 
85   Eigen::DSizes<Eigen::Index, 3> kernel_dims;
86   kernel_dims[0] = kernel_channels * kernel_y * kernel_x;
87   kernel_dims[1] = feature_group_count;
88   kernel_dims[2] = kernel_filters / feature_group_count;
89 
90   for (Eigen::Index i = 0; i < feature_group_count; ++i) {
91     // The row and column dimensions must be flipped when passed to Eigen.
92     output.reshape(output_reshaped_dims).chip(i, 3).device(device) =
93         input.reshape(input_reshaped_dims)
94             .chip(i, 3)
95             .extract_image_patches(
96                 kernel_y, kernel_x, y_stride, x_stride, rhs_y_dilation,
97                 rhs_x_dilation, lhs_y_dilation, lhs_x_dilation,
98                 padding_y_before, padding_y_after, padding_x_before,
99                 padding_x_after, static_cast<ScalarType>(0.0f))
100             .reshape(pre_contract_dims)
101             .contract(kernel.reshape(kernel_dims).chip(i, 1), contract_dims)
102             .reshape(post_contract_dims);
103   }
104 }
105 
106 template <typename EigenDevice, typename ScalarType>
EigenConv3DImpl(const EigenDevice & device,ScalarType * out,ScalarType * lhs,ScalarType * rhs,Eigen::Index input_batch,Eigen::Index input_x,Eigen::Index input_y,Eigen::Index input_z,Eigen::Index input_channels,Eigen::Index kernel_x,Eigen::Index kernel_y,Eigen::Index kernel_z,Eigen::Index kernel_channels,Eigen::Index kernel_filters,Eigen::Index output_x,Eigen::Index output_y,Eigen::Index output_z,Eigen::Index x_stride,Eigen::Index y_stride,Eigen::Index z_stride,Eigen::Index padding_x_before,Eigen::Index padding_x_after,Eigen::Index padding_y_before,Eigen::Index padding_y_after,Eigen::Index padding_z_before,Eigen::Index padding_z_after,Eigen::Index lhs_x_dilation,Eigen::Index lhs_y_dilation,Eigen::Index lhs_z_dilation,Eigen::Index rhs_x_dilation,Eigen::Index rhs_y_dilation,Eigen::Index rhs_z_dilation,Eigen::Index feature_group_count)107 void EigenConv3DImpl(
108     const EigenDevice& device, ScalarType* out, ScalarType* lhs,
109     ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x,
110     Eigen::Index input_y, Eigen::Index input_z, Eigen::Index input_channels,
111     Eigen::Index kernel_x, Eigen::Index kernel_y, Eigen::Index kernel_z,
112     Eigen::Index kernel_channels, Eigen::Index kernel_filters,
113     Eigen::Index output_x, Eigen::Index output_y, Eigen::Index output_z,
114     Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index z_stride,
115     Eigen::Index padding_x_before, Eigen::Index padding_x_after,
116     Eigen::Index padding_y_before, Eigen::Index padding_y_after,
117     Eigen::Index padding_z_before, Eigen::Index padding_z_after,
118     Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation,
119     Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation,
120     Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation,
121     Eigen::Index feature_group_count) {
122   using ConstTType =
123       Eigen::TensorMap<Eigen::Tensor<const ScalarType, 5, Eigen::RowMajor>,
124                        Eigen::Aligned>;
125   const ConstTType input(lhs, input_batch, input_x, input_y, input_z,
126                          input_channels);
127 
128   const ConstTType kernel(rhs, kernel_x, kernel_y, kernel_z, kernel_channels,
129                           kernel_filters);
130 
131   Eigen::TensorMap<Eigen::Tensor<ScalarType, 5, Eigen::RowMajor>,
132                    Eigen::Aligned>
133       output(out, input_batch, output_x, output_y, output_z, kernel_filters);
134 
135   Eigen::DSizes<Eigen::Index, 6> input_reshaped_dims;
136   input_reshaped_dims[0] = input_batch;
137   input_reshaped_dims[1] = input_x;
138   input_reshaped_dims[2] = input_y;
139   input_reshaped_dims[3] = input_z;
140   input_reshaped_dims[4] = feature_group_count;
141   input_reshaped_dims[5] = input_channels / feature_group_count;
142 
143   Eigen::DSizes<Eigen::Index, 6> output_reshaped_dims;
144   output_reshaped_dims[0] = input_batch;
145   output_reshaped_dims[1] = output_x;
146   output_reshaped_dims[2] = output_y;
147   output_reshaped_dims[3] = output_z;
148   output_reshaped_dims[4] = feature_group_count;
149   output_reshaped_dims[5] = kernel_filters / feature_group_count;
150 
151   Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims;
152   contract_dims[0] = Eigen::IndexPair<Eigen::Index>(1, 0);
153 
154   // Molds the output of the patch extraction code into a 2d tensor:
155   // - the first dimension (dims[0]): the patch values to be multiplied with the
156   //   kernels
157   // - the second dimension (dims[1]): everything else
158   Eigen::DSizes<Eigen::Index, 2> pre_contract_dims;
159   pre_contract_dims[0] = output_x * output_y * output_z * input_batch;
160   pre_contract_dims[1] = kernel_channels * kernel_x * kernel_y * kernel_z;
161 
162   // Molds the output of the contraction into the shape expected by the user:
163   Eigen::DSizes<Eigen::Index, 5> post_contract_dims;
164   post_contract_dims[0] = input_batch;
165   post_contract_dims[1] = output_x;
166   post_contract_dims[2] = output_y;
167   post_contract_dims[3] = output_z;
168   post_contract_dims[4] = kernel_filters / feature_group_count;
169 
170   Eigen::DSizes<Eigen::Index, 3> kernel_dims;
171   kernel_dims[0] = kernel_channels * kernel_x * kernel_y * kernel_z;
172   kernel_dims[1] = feature_group_count;
173   kernel_dims[2] = kernel_filters / feature_group_count;
174 
175   for (Eigen::Index i = 0; i < feature_group_count; ++i) {
176     // The dimension order must be flipped when passed to Eigen.
177     auto input_chip = input.reshape(input_reshaped_dims).chip(i, 4);
178     auto patches =
179         Eigen::TensorVolumePatchOp<Eigen::Dynamic, Eigen::Dynamic,
180                                    Eigen::Dynamic, decltype(input_chip)>(
181             input_chip, kernel_z, kernel_y, kernel_x, z_stride, y_stride,
182             x_stride, rhs_z_dilation, rhs_y_dilation, rhs_x_dilation,
183             lhs_z_dilation, lhs_y_dilation, lhs_x_dilation, padding_z_before,
184             padding_z_after, padding_y_before, padding_y_after,
185             padding_x_before, padding_x_after, static_cast<ScalarType>(0.0f));
186 
187     output.reshape(output_reshaped_dims).chip(i, 4).device(device) =
188         patches.reshape(pre_contract_dims)
189             .contract(kernel.reshape(kernel_dims).chip(i, 1), contract_dims)
190             .reshape(post_contract_dims);
191   }
192 }
193 
194 }  // namespace xla
195 }  // namespace tensorflow
196 
197 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_IMPL_H_
198