xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/fractional_max_pool_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include <algorithm>
#include <cmath>
#include <random>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/kernels/fractional_pool_common.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/guarded_philox_random.h"

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename T>
class FractionalMaxPoolOp : public OpKernel {
 public:
  explicit FractionalMaxPoolOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("pooling_ratio", &pooling_ratio_));
    OP_REQUIRES_OK(context, context->GetAttr("pseudo_random", &pseudo_random_));
    OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));

    OP_REQUIRES(context, pooling_ratio_.size() == 4,
                errors::InvalidArgument("pooling_ratio field must "
                                        "specify 4 dimensions"));

    // Pooling is only performed along the height and width dimensions, so
    // both the batch and channel ratios must be exactly 1.
    OP_REQUIRES(
        context, pooling_ratio_[0] == 1 && pooling_ratio_[3] == 1,
        errors::Unimplemented("Fractional max pooling is not yet "
                              "supported on the batch or channel dimension."));

    OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
    OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
    OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
    if (deterministic_) {
      // If deterministic_ is true but neither seed is set, pick fresh seeds.
      if ((seed_ == 0) && (seed2_ == 0)) {
        seed_ = random::New64();
        seed2_ = random::New64();
      }
    } else {
      OP_REQUIRES(
          context, (seed_ == 0) && (seed2_ == 0),
          errors::InvalidArgument(
              "Both seed and seed2 should be 0 if deterministic is false."));
    }
  }

  void Compute(OpKernelContext* context) override {
    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        ConstEigenMatrixMap;
    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        EigenMatrixMap;

    constexpr int tensor_in_and_out_dims = 4;

    const Tensor& tensor_in = context->input(0);
    OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims,
                errors::InvalidArgument("tensor_in must be 4-dimensional"));

    std::vector<int> input_size(tensor_in_and_out_dims);
    std::vector<int> output_size(tensor_in_and_out_dims);
    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
      input_size[i] = tensor_in.dim_size(i);

      OP_REQUIRES(
          context, input_size[i] >= pooling_ratio_[i],
          errors::InvalidArgument("Pooling ratio is higher than input "
                                  "dimension size for dimension ",
                                  i, ". Input dim size: ", input_size[i],
                                  " pooling ratio: ", pooling_ratio_[i]));
    }
    // Output size.
    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
      // This must match the same logic in the shape function in
      // core/ops/nn_ops.cc.
      output_size[i] =
          static_cast<int>(std::floor(input_size[i] / pooling_ratio_[i]));
      DCHECK_GT(output_size[i], 0);
    }
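    // For example (illustrative values, not taken from this file): a
    // 1x10x10x3 input with pooling_ratio = {1, 1.44, 1.73, 1} yields
    // output_size {1, floor(10 / 1.44), floor(10 / 1.73), 3} = {1, 6, 5, 3}.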

    // Generate pooling sequence.
    std::vector<int64_t> height_cum_seq;
    std::vector<int64_t> width_cum_seq;
    GuardedPhiloxRandom generator;
    generator.Init(seed_, seed2_);
    height_cum_seq = GeneratePoolingSequence(input_size[1], output_size[1],
                                             &generator, pseudo_random_);
    width_cum_seq = GeneratePoolingSequence(input_size[2], output_size[2],
                                            &generator, pseudo_random_);
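    // Each cumulative sequence holds output_size + 1 increasing row/col
    // boundaries; entries i and i + 1 delimit pooling region i. When
    // overlapping_ is set, region i also includes the boundary at i + 1.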

    // Prepare output.
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0,
                                TensorShape({output_size[0], output_size[1],
                                             output_size[2], output_size[3]}),
                                &output_tensor));
    Tensor* output_height_seq_tensor = nullptr;
    OP_REQUIRES_OK(
        context,
        context->allocate_output(
            1, TensorShape({static_cast<int64_t>(height_cum_seq.size())}),
            &output_height_seq_tensor));
    Tensor* output_width_seq_tensor = nullptr;
    OP_REQUIRES_OK(
        context,
        context->allocate_output(
            2, TensorShape({static_cast<int64_t>(width_cum_seq.size())}),
            &output_width_seq_tensor));

    ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), input_size[3],
                               input_size[2] * input_size[1] * input_size[0]);

    EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size[3],
                           output_size[2] * output_size[1] * output_size[0]);
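    // Both maps put depth on the rows and the flattened
    // batch * height * width extent on the columns, so each column is one
    // spatial location across all channels; cwiseMax below is channel-wise.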

    // Initializes the output tensor with MIN<T>.
    output_tensor->flat<T>().setConstant(Eigen::NumTraits<T>::lowest());

    auto output_height_seq_flat = output_height_seq_tensor->flat<int64_t>();
    auto output_width_seq_flat = output_width_seq_tensor->flat<int64_t>();

    // Copy the pooling sequences to the sequence output tensors.
    for (int i = 0; i < height_cum_seq.size(); ++i) {
      output_height_seq_flat(i) = height_cum_seq[i];
    }

    for (int i = 0; i < width_cum_seq.size(); ++i) {
      output_width_seq_flat(i) = width_cum_seq[i];
    }

    // For both input and output,
    // 0: batch
    // 1: height / row
    // 2: width / col
    // 3: depth / channel
    const int64_t height_max = input_size[1] - 1;
    const int64_t width_max = input_size[2] - 1;
    for (int64_t b = 0; b < input_size[0]; ++b) {
      // height sequence.
      for (int64_t hs = 0; hs < height_cum_seq.size() - 1; ++hs) {
        // height start and end.
        const int64_t height_start = height_cum_seq[hs];
        int64_t height_end =
            overlapping_ ? height_cum_seq[hs + 1] : height_cum_seq[hs + 1] - 1;
        height_end = std::min(height_end, height_max);

        // width sequence.
        for (int64_t ws = 0; ws < width_cum_seq.size() - 1; ++ws) {
          const int64_t out_offset =
              (b * output_size[1] + hs) * output_size[2] + ws;
          // width start and end.
          const int64_t width_start = width_cum_seq[ws];
          int64_t width_end =
              overlapping_ ? width_cum_seq[ws + 1] : width_cum_seq[ws + 1] - 1;
          width_end = std::min(width_end, width_max);
          for (int64_t h = height_start; h <= height_end; ++h) {
            for (int64_t w = width_start; w <= width_end; ++w) {
              const int64_t in_offset =
                  (b * input_size[1] + h) * input_size[2] + w;
              out_mat.col(out_offset) =
                  out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset));
            }
          }
        }
      }
    }
  }

 private:
  bool deterministic_;
  int64_t seed_;
  int64_t seed2_;
  std::vector<float> pooling_ratio_;
  bool pseudo_random_;
  bool overlapping_;
};

#define REGISTER_FRACTIONALMAXPOOL(type)                                      \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("FractionalMaxPool").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      FractionalMaxPoolOp<type>)

REGISTER_FRACTIONALMAXPOOL(int32);
REGISTER_FRACTIONALMAXPOOL(int64_t);
REGISTER_FRACTIONALMAXPOOL(float);
REGISTER_FRACTIONALMAXPOOL(double);

#undef REGISTER_FRACTIONALMAXPOOL
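// Example invocation through the Python wrapper of this kernel (a hedged
// sketch; the ratio values are illustrative, not taken from this file):
//
//   pooled, row_seq, col_seq = tf.nn.fractional_max_pool(
//       x, pooling_ratio=[1.0, 1.44, 1.73, 1.0],
//       pseudo_random=True, overlapping=True)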

static const int kInvalidMaxPoolingIndex = -1;

template <class T>
class FractionalMaxPoolGradOp : public OpKernel {
 public:
  explicit FractionalMaxPoolGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
  }

  void Compute(OpKernelContext* context) override {
    // There are two steps to calculating the gradient of FractionalMaxPool.
    // 1) Walk through the process of calculating fractional pooling given the
    //    pooling regions; along the way, keep track of where each max element
    //    comes from (arg_max).
    // 2) Scatter out_backprop to the locations arg_max indicates. With
    //    overlapping pooling, multiple out_backprop[i] values may propagate
    //    back to the same arg_max location.
    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        ConstEigenMatrixMap;
    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        EigenMatrixMap;
    typedef Eigen::Map<Eigen::Matrix<int64_t, Eigen::Dynamic, Eigen::Dynamic>>
        EigenIndexMatrixMap;

    const Tensor& tensor_in = context->input(0);
    const Tensor& tensor_out = context->input(1);
    const Tensor& out_backprop = context->input(2);
    const Tensor& height_seq_tensor = context->input(3);
    const Tensor& width_seq_tensor = context->input(4);

    // Use the same rank constant as FractionalMaxPoolOp.
    constexpr int tensor_in_and_out_dims = 4;
    OP_REQUIRES(
        context, tensor_in.dims() == tensor_in_and_out_dims,
        errors::InvalidArgument("orig_input should be a tensor of rank 4, got ",
                                tensor_in.DebugString()));
    OP_REQUIRES(context, tensor_in.NumElements() > 0,
                errors::InvalidArgument("orig_input must not be empty, got ",
                                        tensor_in.DebugString()));
    OP_REQUIRES(context, tensor_out.dims() == tensor_in_and_out_dims,
                errors::InvalidArgument(
                    "orig_output should be a tensor of rank 4, got ",
                    tensor_out.DebugString()));
    OP_REQUIRES(context, tensor_out.NumElements() > 0,
                errors::InvalidArgument("orig_output must not be empty, got ",
                                        tensor_out.DebugString()));
    std::vector<int64_t> input_size(tensor_in_and_out_dims);
    std::vector<int64_t> output_size(tensor_in_and_out_dims);
    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
      input_size[i] = tensor_in.dim_size(i);
    }
    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
      output_size[i] = tensor_out.dim_size(i);
    }

    // ---------
    // Step 1
    // ---------
    Tensor tensor_out_dup;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_temp(
                                {1}, DataTypeToEnum<T>::v(), tensor_out.shape(),
                                &tensor_out_dup));
    Tensor tensor_out_arg_max;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<int64_t>::v(),
                                                   tensor_out.shape(),
                                                   &tensor_out_arg_max));
    // Find arg_max for each element of tensor_out.
    ConstEigenMatrixMap tensor_in_mat(
        tensor_in.flat<T>().data(), input_size[3],
        input_size[2] * input_size[1] * input_size[0]);
    EigenMatrixMap tensor_out_dup_mat(
        tensor_out_dup.flat<T>().data(), output_size[3],
        output_size[2] * output_size[1] * output_size[0]);
    EigenIndexMatrixMap tensor_out_arg_max_mat(
        tensor_out_arg_max.flat<int64_t>().data(), output_size[3],
        output_size[2] * output_size[1] * output_size[0]);

    tensor_out_arg_max.flat<int64_t>().setConstant(kInvalidMaxPoolingIndex);
    // Initializes the duplicate output tensor with MIN<T>.
    tensor_out_dup.flat<T>().setConstant(Eigen::NumTraits<T>::lowest());

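    // Inputs 3 and 4 are the cumulative row and column pooling sequences
    // produced by the forward FractionalMaxPool op.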
    auto height_seq_tensor_flat = height_seq_tensor.flat<int64_t>();
    auto width_seq_tensor_flat = width_seq_tensor.flat<int64_t>();

    // Now walk through the process of fractional max pooling again.
    // For both input and output,
    // 0: batch
    // 1: height / row
    // 2: width / col
    // 3: depth / channel
    const int64_t height_max = input_size[1] - 1;
    const int64_t width_max = input_size[2] - 1;
    for (int64_t b = 0; b < input_size[0]; ++b) {
      // height sequence.
      for (int64_t hs = 0; hs < height_seq_tensor.dim_size(0) - 1; ++hs) {
        // height start and end.
        const int64_t height_start = height_seq_tensor_flat(hs);
        int64_t height_end = overlapping_ ? height_seq_tensor_flat(hs + 1)
                                          : height_seq_tensor_flat(hs + 1) - 1;
        height_end = std::min(height_end, height_max);

        // width sequence.
        for (int64_t ws = 0; ws < width_seq_tensor.dim_size(0) - 1; ++ws) {
          const int64_t out_index =
              (b * output_size[1] + hs) * output_size[2] + ws;
          // width start and end.
          const int64_t width_start = width_seq_tensor_flat(ws);
          int64_t width_end = overlapping_ ? width_seq_tensor_flat(ws + 1)
                                           : width_seq_tensor_flat(ws + 1) - 1;
          width_end = std::min(width_end, width_max);
          for (int64_t h = height_start; h <= height_end; ++h) {
            for (int64_t w = width_start; w <= width_end; ++w) {
              const int64_t in_index =
                  (b * input_size[1] + h) * input_size[2] + w;
              // Walk through each channel (depth).
              for (int64_t d = 0; d < input_size[3]; ++d) {
                const T& input_ref = tensor_in_mat.coeffRef(d, in_index);
                T& output_ref = tensor_out_dup_mat.coeffRef(d, out_index);
                int64_t& out_arg_max_ref =
                    tensor_out_arg_max_mat.coeffRef(d, out_index);
                if (output_ref < input_ref ||
                    out_arg_max_ref == kInvalidMaxPoolingIndex) {
                  output_ref = input_ref;
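                  // Record the flat NHWC index of the new max; it matches the
                  // output->flat<T>() indexing used in Step 2 below.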
                  const int64_t input_offset = in_index * input_size[3] + d;
                  out_arg_max_ref = input_offset;
                }
              }
            }
          }
        }
      }
    }

    // Check that the recomputed tensor_out_dup matches the tensor_out the
    // caller supplied; mismatched inputs would produce bogus gradients.
    ConstEigenMatrixMap tensor_out_mat(
        tensor_out.flat<T>().data(), output_size[3],
        output_size[2] * output_size[1] * output_size[0]);
    const int64_t num_reshaped_cols =
        output_size[2] * output_size[1] * output_size[0];
    for (int64_t i = 0; i < num_reshaped_cols; ++i) {
      for (int64_t j = 0; j < output_size[3]; ++j) {
        OP_REQUIRES(context, tensor_out_dup_mat(j, i) == tensor_out_mat(j, i),
                    errors::InvalidArgument(
                        "tensor_out_dup is not the same as tensor_out"));
      }
    }

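    // ---------
    // Step 2
    // ---------
    // Scatter out_backprop into input_backprop at the arg_max positions
    // recorded above.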
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, tensor_in.shape(), &output));
    output->flat<T>().setZero();

    auto out_backprop_flat = out_backprop.flat<T>();
    auto input_backprop_flat = output->flat<T>();
    auto out_arg_max_flat = tensor_out_arg_max.flat<int64_t>();
    const int64_t num_total_outputs = out_backprop_flat.size();
    const int64_t num_total_inputs = input_backprop_flat.size();

    for (int64_t index = 0; index < num_total_outputs; ++index) {
      const int64_t input_backprop_index = out_arg_max_flat(index);
      OP_REQUIRES(
          context,
          input_backprop_index >= 0 && input_backprop_index < num_total_inputs,
          errors::InvalidArgument(
              "Invalid input backprop index: ", input_backprop_index,
              "; it must lie in [0, ", num_total_inputs, ")."));
      input_backprop_flat(input_backprop_index) += out_backprop_flat(index);
    }
  }

 private:
  bool overlapping_;
};

#define REGISTER_FRACTIONALMAXPOOLGRAD(type)              \
  REGISTER_KERNEL_BUILDER(Name("FractionalMaxPoolGrad")   \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<type>("T"), \
                          FractionalMaxPoolGradOp<type>)

REGISTER_FRACTIONALMAXPOOLGRAD(int32);
REGISTER_FRACTIONALMAXPOOLGRAD(int64_t);
REGISTER_FRACTIONALMAXPOOLGRAD(float);
REGISTER_FRACTIONALMAXPOOLGRAD(double);

#undef REGISTER_FRACTIONALMAXPOOLGRAD
}  // namespace tensorflow