xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/isotonic_regression_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <cmath>
16 
17 #include "tensorflow/core/framework/bounds_check.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/register_types.h"
21 #include "tensorflow/core/platform/threadpool.h"
22 
23 namespace {
24 
25 using ::int64_t;
26 using tensorflow::int32;
27 
// Per-element cost factor for the isotonic regression solver: the estimated
// number of ops for one problem is its array size times this constant. The
// thread pool executor uses that estimate when deciding how many threads to
// use.
constexpr int kCostMultiplier{100};
32 
// Separable chain-constrained problems have the form
//
//  min_{y_1 >= y_2 >= ... >= y_n} \sum_{i=1}^n h_i(y_i)
//
// for convex functions h_i. Contiguous runs of coordinates play a central role
// in solving them, and this class models one such run as the half-closed
// column interval [col_start(), col_limit()).
class Segment {
 public:
  // Builds the single-column segment [col_index, col_index + 1).
  explicit Segment(int col_index)
      : col_start_(col_index), col_limit_(col_index + 1) {}

  // First column covered by the segment (inclusive).
  int col_start() const { return col_start_; }

  // One past the last column covered by the segment (exclusive).
  int col_limit() const { return col_limit_; }

  // Number of columns covered by the segment.
  int num_points() const { return col_limit_ - col_start_; }

  // Grows this segment so that it also spans `other`: the start moves down to
  // the smaller of the two starts and the limit moves up to the larger limit.
  void merge_with(const Segment& other) {
    if (other.col_start_ < col_start_) {
      col_start_ = other.col_start_;
    }
    if (other.col_limit_ > col_limit_) {
      col_limit_ = other.col_limit_;
    }
  }

 private:
  int col_start_;
  int col_limit_;
};
63 
64 // If we can solve for each segment {j, j+1, ..., j+m} the interval problem
65 //
66 //  argmin_y \sum_{i=j}^{j+m} h_i(y),
67 //
68 // we can use such an oracle to solve the general problem. The following class
69 // implements such an oracle for the case when h_i is the squared (l2) loss,
70 // or formally h_i(y) = (y - x_i)^2, where x_i is the i-th input.
71 //
72 // TODO(josipd): We know how and can extend this to other functions if needed.
73 template <typename T>
74 class L2PavaSegment : public Segment {
75  public:
L2PavaSegment(T y,int col_index)76   L2PavaSegment(T y, int col_index)
77       : Segment(col_index), y_sum_(y), minimum_(y) {}
78 
merge_with(const L2PavaSegment & other)79   void merge_with(const L2PavaSegment& other) {
80     Segment::merge_with(other);
81     y_sum_ += other.y_sum_;
82     minimum_ = y_sum_ / static_cast<T>(num_points());
83   }
84 
minimum() const85   T minimum() const { return minimum_; }
86 
87  private:
88   T y_sum_;    // The sum of the inputs within the segment.
89   T minimum_;  // The minimum, cached to avoid expensive divisions.
90 };
91 
// Solves the row_index'th problem of the batch with the pool-adjacent
// violators algorithm (PAVA).
//
// PAVA goes back to
//
// Nonmetric Multidimensional Scaling: A numerical method
// Kruskal, J. B. (1964), Psychometrika (1964)
//
// and a more recent analysis can be found in
//
// Active set algorithms for isotonic regression; a unifying framework
// Best, Michael J., and Nilotpal Chakravarti
// Mathematical Programming 47.1-3 (1990)
//
// The inputs start out as singleton blocks. Whenever two consecutive blocks
// have minima that violate the inequality constraint, they are merged. The
// solution is then constant on each block and equal to that block's minimum.
//
// `make_segment(row, col)` must return the singleton segment for one input.
// `solution` and `segments` are indexed as (row, column) and must be two
// dimensional; on return, row `row_index` of `solution` holds the block minima
// and row `row_index` of `segments` holds the 0-based index of the block each
// column belongs to. The segment type has to provide minimum(), merge_with(),
// col_start() and col_limit().
template <typename SegmentType, typename FloatTensor, typename IntTensor>
void solve_pava(const std::function<SegmentType(int, int)>& make_segment,
                FloatTensor* solution, IntTensor* segments, int row_index) {
  const size_t num_cols = solution->dimensions()[1];
  std::vector<SegmentType> blocks;
  blocks.reserve(num_cols);

  for (size_t col = 0; col < num_cols; ++col) {
    blocks.push_back(make_segment(row_index, static_cast<int>(col)));

    // Fold the newest block into its predecessor for as long as the newest
    // minimum is strictly greater than the predecessor's.
    for (;;) {
      const size_t block_count = blocks.size();
      if (block_count < 2) {
        break;
      }
      SegmentType& previous = blocks[block_count - 2];
      const SegmentType& newest = blocks[block_count - 1];
      if (!(newest.minimum() > previous.minimum())) {
        break;
      }
      previous.merge_with(newest);
      blocks.pop_back();
    }
  }

  int segment_id = 0;
  for (const SegmentType& block : blocks) {
    const auto block_value = block.minimum();
    const int first_col = block.col_start();
    const int width = block.col_limit() - first_col;
    // The matrices are row major, so a block occupies consecutive memory
    // starting at the address of its first column.
    auto* solution_out = &(*solution)(row_index, first_col);
    auto* segments_out = &(*segments)(row_index, first_col);
    for (int offset = 0; offset < width; ++offset) {
      solution_out[offset] = block_value;
      segments_out[offset] = segment_id;
    }
    ++segment_id;
  }
}
144 
145 // Solve a batch of problems using the pool-adjacent violators algorithm.
146 // The problems are solved in parallel using tensorflow's thread pool.
147 template <typename SegmentType, typename FloatTensor, typename IntTensor>
solve_pava_batch(const std::function<SegmentType (int,int)> & make_segment,FloatTensor * solution,IntTensor * segments,tensorflow::OpKernelContext * context)148 void solve_pava_batch(const std::function<SegmentType(int, int)>& make_segment,
149                       FloatTensor* solution, IntTensor* segments,
150                       tensorflow::OpKernelContext* context) {
151   const int batch_size = solution->dimensions()[0];
152   const int problem_size = solution->dimensions()[1];
153 
154   auto thread_pool =
155       context->device()->tensorflow_cpu_worker_threads()->workers;
156 
157   thread_pool->ParallelFor(
158       batch_size, kCostMultiplier * problem_size,
159       [&make_segment, &solution, &segments](int64_t row_start,
160                                             int64_t row_limit) {
161         // Casting to int is safe, as we do boundary checks in `Compute`.
162         for (int row_index = static_cast<int>(row_start);
163              row_index < static_cast<int>(row_limit); ++row_index) {
164           solve_pava(make_segment, solution, segments, row_index);
165         }
166       });
167 }
168 
169 }  // namespace
170 
171 template <typename Tin, typename Tout>
172 class IsotonicRegressionOp : public tensorflow::OpKernel {
173  public:
IsotonicRegressionOp(tensorflow::OpKernelConstruction * context)174   explicit IsotonicRegressionOp(tensorflow::OpKernelConstruction* context)
175       : tensorflow::OpKernel(context) {}
176 
Compute(tensorflow::OpKernelContext * context)177   void Compute(tensorflow::OpKernelContext* context) override {
178     // Grab the input tensor.
179     const tensorflow::Tensor& input_tensor = context->input(0);
180     const auto input = input_tensor.flat_inner_dims<Tin, 2>();
181     int int_max = std::numeric_limits<int32>::max();
182     OP_REQUIRES(context,
183                 tensorflow::FastBoundsCheck(input.dimensions()[0], int_max) &&
184                     tensorflow::FastBoundsCheck(input.dimensions()[1], int_max),
185                 tensorflow::errors::InvalidArgument("Tensor too large"));
186 
187     // Create the output tensor holding the minimizers.
188     const auto shape = input_tensor.shape();
189     tensorflow::Tensor* output_tensor = nullptr;
190     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
191                                 {0}, 0, shape, &output_tensor));
192     auto output = output_tensor->flat_inner_dims<Tout, 2>();
193 
194     // Create the output tensor holidng the segment memberships.
195     tensorflow::Tensor* segments_tensor = nullptr;
196     OP_REQUIRES_OK(context,
197                    context->allocate_output(1, shape, &segments_tensor));
198     auto segments = segments_tensor->flat_inner_dims<int>();
199 
200     auto make_l2_segment = [&input](int row_index, int col_index) {
201       return L2PavaSegment<Tout>(input(row_index, col_index), col_index);
202     };
203     solve_pava_batch<L2PavaSegment<Tout>>(make_l2_segment, &output, &segments,
204                                           context);
205   }
206 };
207 
208 #define REGISTER_CPU_KERNEL(Tin, Tout)                               \
209   REGISTER_KERNEL_BUILDER(Name("IsotonicRegression")                 \
210                               .Device(tensorflow::DEVICE_CPU)        \
211                               .TypeConstraint<Tin>("T")              \
212                               .TypeConstraint<Tout>("output_dtype"), \
213                           IsotonicRegressionOp<Tin, Tout>);
214 
215 // Float types have the same input and output.
216 #define REGISTER_CPU_SAME_KERNEL(T) REGISTER_CPU_KERNEL(T, T)
217 TF_CALL_FLOAT_TYPES(REGISTER_CPU_SAME_KERNEL);
218 
219 // 8 and 16 bit integers get converted to 32 bit floats.
220 #define REGISTER_CPU_KERNEL_FLOAT(Tin) REGISTER_CPU_KERNEL(Tin, float)
221 TF_CALL_int16(REGISTER_CPU_KERNEL_FLOAT);
222 TF_CALL_int8(REGISTER_CPU_KERNEL_FLOAT);
223 
224 // 32 and 64 bit integers get converted to 64 bit floats.
225 #define REGISTER_CPU_KERNEL_DOUBLE(Tin) REGISTER_CPU_KERNEL(Tin, double)
226 TF_CALL_int64(REGISTER_CPU_KERNEL_DOUBLE);
227 TF_CALL_int32(REGISTER_CPU_KERNEL_DOUBLE);
228