xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/HistogramKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/Histogram.h>
3 
4 #include <ATen/core/Tensor.h>
5 #include <ATen/Dispatch.h>
6 #include <ATen/Parallel.h>
7 #include <c10/util/irange.h>
8 
9 #ifndef AT_PER_OPERATOR_HEADERS
10 #include <ATen/Functions.h>
11 #include <ATen/NativeFunctions.h>
12 #else
13 #include <ATen/ops/aminmax.h>
14 #include <ATen/ops/sum.h>
15 #include <ATen/ops/zeros.h>
16 #include <ATen/ops/zeros_like_ops.h>
17 #endif
18 
19 #include <algorithm>
20 #include <numeric>
21 #include <functional>
22 
23 namespace at::native {
24 
25 namespace {
26 
27 constexpr int64_t HISTOGRAM_GRAIN_SIZE = 200;
28 
29 /* The main algorithm. Expects that the input tensor has shape (N, D).
30  * Expects that bin_edges contains D one-dimensional tensors, each specifying
31  * an increasing sequence of bin edges.
32  *
33  * Interprets the input as N different D-dimensional coordinates and maps them
34  * into the D-dimensional bins defined by bin_edges, accumulating a D-dimensional
35  * histogram in the hist tensor.
36  *
37  * Accepts a template argument of type BIN_SELECTION_ALGORITHM specifying how
38  * the scalars in each dimension should be mapped into the dimension's bins:
39  *
40  *     - LINEAR_INTERPOLATION: each bin edge sequence must form a linear progression.
41  *       Scalars are mapped to bins by computing
42  *           (element - leftmost_edge)/(rightmost_edge - leftmost_edge) * bin_ct
43  *       and truncating the result to an integer.
44  *
45  *       This is the fastest option, but its results may not be perfectly consistent
46  *       with the boundaries specified in bin_edges due to precision issues.
47  *
48  *       Used by torch.histc, which doesn't need consistency with bin_edges as it does not
49  *       return bin_edges. Additionally, this implementation is identical to the legacy histc
50  *       implementation, which was replaced when histogram was implemented.
51  *
52  *     - LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH: Also expects that each bin edge sequence
53  *       forms a linear progression. For each scalar, if 'pos' is the bin selected by the
54  *       LINEAR_INTERPOLATION approach, this approach inspects the boundaries in bin_edges to
55  *       place the scalar into pos - 1, pos, or pos + 1. The "local search" over neighboring
56  *       bins allows for correction of misclassifications due to precision issues (a scalar
57  *       very close to a bin_edge may be misclassified by LINEAR_INTERPOLATION).
58  *
59  *       Should produce the same output as the general case BINARY_SEARCH, but run about
60  *       3x faster asymptotically.
61  *
62  *       Used by torch.histogram for cases in which bin_edges is constructed using
63  *       torch.linspace. The behavior of LINEAR_INTERPOLATION may not perfectly align
64  *       with linspace bin_edges due to precision issues. torch.histogram returns both
65  *       the hist and bin_edges tensors as output, so the "local search" is needed to
66  *       keep its output internally consistent.
67  *
68  *     - BINARY_SEARCH: Handles torch.histogram's general case by searching over the
69  *       elements of bin_edges. Implemented using std::upper_bound.
70  *
71  * See discussion at https://github.com/pytorch/pytorch/pull/58780#discussion_r648604866
72  * for further details on relative performance of the bin selection algorithms.
73  */
74 enum BIN_SELECTION_ALGORITHM {
75     LINEAR_INTERPOLATION,
76     LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH,
77     BINARY_SEARCH,
78 };
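/* An illustrative sketch of the three strategies, using hypothetical values: suppose a
 * dimension has bin edges {0, 2, 4, 6, 8} (bin_ct = 4, a linear progression) and the
 * element being binned is elt = 5.0.
 *
 *     const double edges[] = {0., 2., 4., 6., 8.};
 *     const double elt = 5.0;
 *
 *     // LINEAR_INTERPOLATION: (5 - 0) * 4 / (8 - 0) = 2.5, truncated to bin 2
 *     auto pos_lin = static_cast<int64_t>((elt - edges[0]) * 4 / (edges[4] - edges[0]));
 *
 *     // BINARY_SEARCH: the first edge greater than 5.0 is 6 (index 3), so 3 - 1 = bin 2
 *     auto pos_bin = std::upper_bound(edges, edges + 5, elt) - edges - 1;
 *
 * LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH starts from pos_lin and repeats the upper_bound
 * restricted to edges[pos_lin - 1, pos_lin + 2), which corrects an off-by-one when rounding
 * error in the division places an element just on the wrong side of an edge.
 */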
79 template<typename input_t, BIN_SELECTION_ALGORITHM algorithm>
80 void histogramdd_cpu_contiguous(Tensor& hist, const TensorList& bin_edges,
81         const Tensor& input, const std::optional<Tensor>& weight) {
82     TORCH_INTERNAL_ASSERT(input.dim() == 2);
83 
84     const int64_t N = input.size(0);
85     if (weight.has_value()) {
86         TORCH_INTERNAL_ASSERT(weight.value().dim() == 1 && weight.value().numel() == N);
87     }
88 
89     const int64_t D = input.size(1);
90     TORCH_INTERNAL_ASSERT(int64_t(bin_edges.size()) == D);
91     for (const auto dim : c10::irange(D)) {
92         TORCH_INTERNAL_ASSERT(bin_edges[dim].is_contiguous());
93         TORCH_INTERNAL_ASSERT(hist.size(dim) + 1 == bin_edges[dim].numel());
94     }
95 
96     if (D == 0) {
97         // hist is an empty tensor in this case; nothing to do here
98         return;
99     }
100 
101     TensorAccessor<const input_t, 2> accessor_in = input.accessor<const input_t, 2>();
102 
103     /* Constructs a std::optional<TensorAccessor> containing an accessor if
104      * the optional weight tensor has a value.
105      */
106     const auto accessor_wt = weight.has_value()
107             ? std::optional<TensorAccessor<const input_t, 1>>(weight.value().accessor<const input_t, 1>())
108             : std::optional<TensorAccessor<const input_t, 1>>();
109 
110     std::vector<input_t*> bin_seq(D);
111     std::vector<int64_t> num_bin_edges(D);
112     std::vector<input_t> leftmost_edge(D), rightmost_edge(D);
113 
114     for (const auto dim : c10::irange(D)) {
115         bin_seq[dim] = bin_edges[dim].data_ptr<input_t>();
116         num_bin_edges[dim] = bin_edges[dim].numel();
117         leftmost_edge[dim] = bin_seq[dim][0];
118         rightmost_edge[dim] = bin_seq[dim][num_bin_edges[dim] - 1];
119     }
120 
121     int64_t GRAIN_SIZE = std::max(int64_t(1), HISTOGRAM_GRAIN_SIZE / D);
122 
123     /* Parallelizes processing of input using at::parallel_for.
124  * Each thread accumulates a local result into its own slice of
125  * thread_histograms; the slices are summed together at the end.
126      */
127     const auto num_threads = at::get_num_threads();
128     const auto hist_sizes = hist.sizes();
129     DimVector thread_hist_sizes(hist_sizes.size() + 1);
130     thread_hist_sizes[0] = num_threads;
131     std::copy(hist_sizes.begin(), hist_sizes.end(),
132               thread_hist_sizes.begin() + 1);
133     Tensor thread_histograms = at::zeros(thread_hist_sizes, hist.dtype());
134     TORCH_INTERNAL_ASSERT(thread_histograms.is_contiguous());
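    // For example, with 8 threads and a 2-D hist of shape (3, 4), thread_histograms
    // has shape (8, 3, 4) and is contiguous; thread tid accumulates only into
    // thread_histograms[tid], and the per-thread slices are reduced into hist by
    // the at::sum_out call below.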
135 
136     at::parallel_for(0, N, GRAIN_SIZE, [&](int64_t start, int64_t end) {
137         const auto tid = at::get_thread_num();
138         auto hist_strides = thread_histograms.strides();
139         input_t *hist_local_data = thread_histograms.data_ptr<input_t>();
140 
141         // View only this thread's local results
142         hist_local_data += hist_strides[0] * tid;
143         hist_strides = hist_strides.slice(1);
144 
145         for (const auto i : c10::irange(start, end)) {
146             bool skip_elt = false;
147             int64_t hist_index = 0;
148 
149             for (const auto dim : c10::irange(D)) {
150                 const input_t elt = accessor_in[i][dim];
151 
152                 // Skips elements which fall outside the specified bins and NaN elements
153                 if (!(elt >= leftmost_edge[dim] && elt <= rightmost_edge[dim])) {
154                     skip_elt = true;
155                     break;
156                 }
157 
158                 int64_t pos = -1;
159 
160                 if (algorithm == BINARY_SEARCH) {
161                     // Handles the general case via binary search on the bin edges.
162                     pos = std::upper_bound(bin_seq[dim], bin_seq[dim] + num_bin_edges[dim], elt)
163                             - bin_seq[dim] - 1;
164                 } else if (algorithm == LINEAR_INTERPOLATION
165                         || algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
166                     /* When bin_edges is known to be a linear progression, maps elt to
167                      * the appropriate bin via simple division.
168                      */
169                     pos = static_cast<int64_t>((elt - leftmost_edge[dim])
170                             * (num_bin_edges[dim] - 1)
171                             / (rightmost_edge[dim] - leftmost_edge[dim]));
172 
173                     /* Ensures consistency with bin_edges by checking the bins to the left and right
174                      * of the selected position. Necessary for cases in which an element very close
175                      * to a bin edge may be misclassified by simple division.
176                      */
177                     if (algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
178                         int64_t pos_min = std::max(static_cast<int64_t>(0), pos - 1);
179                         int64_t pos_max = std::min(pos + 2, num_bin_edges[dim]);
180                         pos = std::upper_bound(bin_seq[dim] + pos_min, bin_seq[dim] + pos_max, elt)
181                                 - bin_seq[dim] - 1;
182                     }
183                 } else {
184                     TORCH_INTERNAL_ASSERT(false);
185                 }
186 
187                 // Unlike other bins, the rightmost bin includes its right boundary
188                 if (pos == (num_bin_edges[dim] - 1)) {
189                     pos -= 1;
190                 }
191 
192                 hist_index += hist_strides[dim] * pos;
193             }
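            // Example with hypothetical shapes: for a (3, 4) hist, the sliced
            // hist_strides are (4, 1), so per-dimension bins (2, 1) flatten to
            // hist_index = 2*4 + 1*1 = 9 within this thread's slice.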
194 
195             if (!skip_elt) {
196                 // In the unweighted case, the default weight is 1
197                 input_t wt = accessor_wt.has_value() ? accessor_wt.value()[i] : static_cast<input_t>(1);
198 
199                 hist_local_data[hist_index] += wt;
200             }
201         }
202     });
203 
204     at::sum_out(hist, thread_histograms, /*dim=*/{0});
205 }
206 
207 /* Some pre- and post-processing steps for the main algorithm.
208  * Initializes hist to 0, calls into the main algorithm, and normalizes output if necessary.
209  */
210 template<BIN_SELECTION_ALGORITHM bin_algorithm>
211 void histogramdd_out_cpu_template(const Tensor& self, const std::optional<Tensor>& weight, bool density,
212         Tensor& hist, const TensorList& bin_edges) {
213     hist.fill_(0);
214 
215     const int64_t N = self.size(-1);
216     const int64_t M = std::accumulate(self.sizes().begin(), self.sizes().end() - 1,
217             (int64_t)1, std::multiplies<int64_t>());
218 
219     const Tensor reshaped_input = self.reshape({M, N});
220 
221     const auto reshaped_weight = weight.has_value()
222             ? std::optional<Tensor>(weight.value().reshape({M}))
223             : std::optional<Tensor>();
224 
225     std::vector<Tensor> bin_edges_contig(bin_edges.size());
226     for (const auto dim : c10::irange(bin_edges_contig.size())) {
227         bin_edges_contig[dim] = bin_edges[dim].contiguous();
228     }
229 
230     AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "histogram_cpu", [&]() {
231         histogramdd_cpu_contiguous<scalar_t, bin_algorithm>(
232                 hist, bin_edges_contig, reshaped_input, reshaped_weight);
233     });
234 
235     /* Divides each bin's value by the total count/weight in all bins,
236      * and by the bin's volume.
237      */
238     if (density) {
239         const auto hist_sum = hist.sum().item();
240         hist.div_(hist_sum);
241 
242         /* For each dimension, divides each bin's value
243          * by the bin's length in that dimension.
244          */
245         for (const auto dim : c10::irange(N)) {
246             const auto bin_lengths = bin_edges[dim].diff();
247 
248             // Used to reshape bin_lengths to align with the corresponding dimension of hist.
249             std::vector<int64_t> shape(N, 1);
250             shape[dim] = bin_lengths.numel();
251 
252             hist.div_(bin_lengths.reshape(shape));
253         }
254     }
255 }
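/* A worked example of the density normalization above, with hypothetical values:
 * for a 1-D histogram with bin_edges = {0, 1, 3} and unweighted counts hist = {1, 3},
 * hist_sum is 4 and the bin lengths are {1, 2}, so the density output is
 * {1/(4*1), 3/(4*2)} = {0.25, 0.375}. The bin areas then sum to
 * 0.25*1 + 0.375*2 = 1, as expected for a probability density.
 */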
256 
257 /* The general implementation of the histogram kernel. Maps each element of the input tensor
258  * to its corresponding bin by performing a binary search over the elements of bin_edges.
259  *
260  * Refer to histogramdd_out_cpu_template for more details.
261  */
262 static void histogramdd_kernel_impl(const Tensor& self, const std::optional<Tensor>& weight, bool density,
263         Tensor& hist, const TensorList& bin_edges) {
264     histogramdd_out_cpu_template<BINARY_SEARCH>(self, weight, density, hist, bin_edges);
265 }
266 
267 /* A faster version of the histogram kernel for cases in which bin_edges are known
268  * to form a linear progression.
269  *
270  * Refer to histogramdd_out_cpu_template for more details.
271  */
272 static void histogramdd_linear_kernel_impl(const Tensor& self, const std::optional<Tensor>& weight,
273         bool density, Tensor& hist, const TensorList& bin_edges, bool local_search) {
274     if (local_search) {
275         // histogramdd codepath: both hist and bin_edges are eventually returned as output,
276         // so we'll keep them consistent
277         histogramdd_out_cpu_template<LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH>(
278               self, weight, density, hist, bin_edges);
279     } else {
280         // histc codepath: bin_edges are not returned to the caller
281         histogramdd_out_cpu_template<LINEAR_INTERPOLATION>(
282               self, weight, density, hist, bin_edges);
283     }
284 }
285 
286 template<typename scalar_t>
287 void infer_bin_edges_from_input(const Tensor& input, const int64_t N,
288         std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
289     // Calls aminmax on input with dim=0, reducing all but the innermost dimension of input.
290     auto [min, max] = aminmax(input, 0);
291 
292     TORCH_INTERNAL_ASSERT(min.is_contiguous() && max.is_contiguous());
293 
294     const scalar_t *min_data = min.const_data_ptr<scalar_t>();
295     std::copy(min_data, min_data + N, leftmost_edges.begin());
296 
297     const scalar_t *max_data = max.const_data_ptr<scalar_t>();
298     std::copy(max_data, max_data + N, rightmost_edges.begin());
299 }
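/* For instance, with hypothetical shapes: if the input has shape (M, N) = (100, 3),
 * aminmax(input, 0) yields length-3 min and max tensors, and dimension dim receives
 * the default outer edges [min_data[dim], max_data[dim]].
 */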
300 
301 static void histogram_select_outer_bin_edges_impl(const Tensor& input, const int64_t N,
302         std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
303     AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "histogramdd", [&]() {
304         infer_bin_edges_from_input<scalar_t>(input, N, leftmost_edges, rightmost_edges);
305     });
306 }
307 
308 } // namespace
309 
310 REGISTER_DISPATCH(histogramdd_stub, &histogramdd_kernel_impl);
311 REGISTER_DISPATCH(histogramdd_linear_stub, &histogramdd_linear_kernel_impl);
312 REGISTER_DISPATCH(histogram_select_outer_bin_edges_stub, &histogram_select_outer_bin_edges_impl);
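// A minimal usage sketch, assuming the per-operator headers <ATen/ops/histogram.h> and
// <ATen/ops/histc.h>: on CPU, at::histogram with an integer bin count builds linspace
// bin_edges and should reach histogramdd_linear_stub with local_search enabled, tensor
// bin_edges go through histogramdd_stub (binary search), and at::histc takes the
// linear path with local_search disabled.
//
//     at::Tensor t = at::randn({1000});
//     auto [hist, edges] = at::histogram(t, 10);     // hist and bin_edges stay consistent
//     at::Tensor counts = at::histc(t, 10, -3, 3);   // fastest path; bin_edges not returned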
313 
314 } // namespace at::native
315