#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Histogram.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/aminmax.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like_ops.h>
#endif

#include <algorithm>
#include <numeric>
#include <functional>

namespace at::native {

namespace {

constexpr int64_t HISTOGRAM_GRAIN_SIZE = 200;

/* The main algorithm. Expects that the input tensor has shape (N, D).
 * Expects that bin_edges contains D one-dimensional tensors, each specifying
 * an increasing sequence of bin edges.
 *
 * Interprets the input as N different D-dimensional coordinates and maps them
 * into the D-dimensional bins defined by bin_edges, accumulating a D-dimensional
 * histogram in the hist tensor.
 *
 * Accepts a template argument of type BIN_SELECTION_ALGORITHM specifying how
 * the scalars in each dimension should be mapped into the dimension's bins:
 *
 * - LINEAR_INTERPOLATION: each bin edge sequence must form a linear progression.
 *   Scalars are mapped to bins by computing
 *       (element - leftmost_edge)/(rightmost_edge - leftmost_edge) * bin_ct
 *   and truncating the result to an integer.
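 *
 *   For example (illustrative numbers, not from any particular caller): with
 *   bin_edges = [0.0, 1.0, 2.0, 3.0, 4.0] (bin_ct = 4), the element 2.5 maps to
 *   trunc((2.5 - 0.0) / (4.0 - 0.0) * 4) = trunc(2.5) = 2, i.e. the half-open
 *   bin [2.0, 3.0).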
 *
 *   This is the fastest option, but its results may not be perfectly consistent
 *   with the boundaries specified in bin_edges due to precision issues.
 *
 *   Used by torch.histc, which doesn't need consistency with bin_edges as it does not
 *   return bin_edges. Additionally, this implementation is identical to the legacy histc
 *   implementation, which was replaced when histogram was implemented.
 *
 * - LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH: Also expects that each bin edge sequence
 *   forms a linear progression. For each scalar, if 'pos' is the bin selected by the
 *   LINEAR_INTERPOLATION approach, this approach inspects the boundaries in bin_edges to
 *   place the scalar into pos - 1, pos, or pos + 1. The "local search" over neighboring
 *   bins allows for correction of misclassifications due to precision issues (a scalar
 *   very close to a bin_edge may be misclassified by LINEAR_INTERPOLATION).
 *
 *   Should produce the same output as the general case BINARY_SEARCH, but run about
 *   3x faster asymptotically.
 *
 *   Used by torch.histogram for cases in which bin_edges is constructed using
 *   torch.linspace. The behavior of LINEAR_INTERPOLATION may not perfectly align
 *   with linspace bin_edges due to precision issues. torch.histogram returns both
 *   the hist and bin_edges tensors as output, so the "local search" is needed to
 *   keep its output internally consistent.
 *
 * - BINARY_SEARCH: Handles torch.histogram's general case by searching over the
 *   elements of bin_edges. Implemented using std::upper_bound.
 *
 * See discussion at https://github.com/pytorch/pytorch/pull/58780#discussion_r648604866
 * for further details on relative performance of the bin selection algorithms.
 */
enum BIN_SELECTION_ALGORITHM {
    LINEAR_INTERPOLATION,
    LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH,
    BINARY_SEARCH,
};
template<typename input_t, BIN_SELECTION_ALGORITHM algorithm>
void histogramdd_cpu_contiguous(Tensor& hist, const TensorList& bin_edges,
        const Tensor& input, const std::optional<Tensor>& weight) {
    TORCH_INTERNAL_ASSERT(input.dim() == 2);

    const int64_t N = input.size(0);
    if (weight.has_value()) {
        TORCH_INTERNAL_ASSERT(weight.value().dim() == 1 && weight.value().numel() == N);
    }

    const int64_t D = input.size(1);
    TORCH_INTERNAL_ASSERT(int64_t(bin_edges.size()) == D);
    for (const auto dim : c10::irange(D)) {
        TORCH_INTERNAL_ASSERT(bin_edges[dim].is_contiguous());
        TORCH_INTERNAL_ASSERT(hist.size(dim) + 1 == bin_edges[dim].numel());
    }

    if (D == 0) {
        // hist is an empty tensor in this case; nothing to do here
        return;
    }

    TensorAccessor<const input_t, 2> accessor_in = input.accessor<const input_t, 2>();

    /* Constructs a std::optional<TensorAccessor> containing an accessor if
     * the optional weight tensor has a value.
     */
    const auto accessor_wt = weight.has_value()
            ? std::optional<TensorAccessor<const input_t, 1>>(weight.value().accessor<const input_t, 1>())
            : std::optional<TensorAccessor<const input_t, 1>>();

    std::vector<input_t*> bin_seq(D);
    std::vector<int64_t> num_bin_edges(D);
    std::vector<input_t> leftmost_edge(D), rightmost_edge(D);

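    // Caches a raw pointer to each dimension's bin edges, the number of edges,
    // and the outermost edge values, so the per-element loop below can work on
    // plain arrays rather than going through tensor accessors.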
    for (const auto dim : c10::irange(D)) {
        bin_seq[dim] = bin_edges[dim].data_ptr<input_t>();
        num_bin_edges[dim] = bin_edges[dim].numel();
        leftmost_edge[dim] = bin_seq[dim][0];
        rightmost_edge[dim] = bin_seq[dim][num_bin_edges[dim] - 1];
    }

    int64_t GRAIN_SIZE = std::max(int64_t(1), HISTOGRAM_GRAIN_SIZE / D);
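    // Dividing by D keeps the amount of work per parallel task roughly
    // constant, since each element requires O(D) work in the loop below.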

    /* Parallelizes processing of input using at::parallel_for.
     * Each thread accumulates a local result into its own slice of
     * thread_histograms which gets summed together at the end.
     */
    const auto num_threads = at::get_num_threads();
    const auto hist_sizes = hist.sizes();
    DimVector thread_hist_sizes(hist_sizes.size() + 1);
    thread_hist_sizes[0] = num_threads;
    std::copy(hist_sizes.begin(), hist_sizes.end(),
              thread_hist_sizes.begin() + 1);
    Tensor thread_histograms = at::zeros(thread_hist_sizes, hist.dtype());
    TORCH_INTERNAL_ASSERT(thread_histograms.is_contiguous());
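    // thread_histograms has shape (num_threads, *hist.sizes()); each thread
    // writes exclusively to the slice indexed by its own thread id.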

    at::parallel_for(0, N, GRAIN_SIZE, [&](int64_t start, int64_t end) {
        const auto tid = at::get_thread_num();
        auto hist_strides = thread_histograms.strides();
        input_t *hist_local_data = thread_histograms.data_ptr<input_t>();

        // View only this thread's local results
        hist_local_data += hist_strides[0] * tid;
        hist_strides = hist_strides.slice(1);

        for (const auto i : c10::irange(start, end)) {
            bool skip_elt = false;
            int64_t hist_index = 0;

            for (const auto dim : c10::irange(D)) {
                const input_t elt = accessor_in[i][dim];

                // Skips elements which fall outside the specified bins and NaN elements
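                // (The negated comparison is deliberate: comparisons against
                // NaN are always false, so NaN elements take this branch.)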
                if (!(elt >= leftmost_edge[dim] && elt <= rightmost_edge[dim])) {
                    skip_elt = true;
                    break;
                }

                int64_t pos = -1;

                if (algorithm == BINARY_SEARCH) {
                    // Handles the general case via binary search on the bin edges.
                    pos = std::upper_bound(bin_seq[dim], bin_seq[dim] + num_bin_edges[dim], elt)
                            - bin_seq[dim] - 1;
                } else if (algorithm == LINEAR_INTERPOLATION
                        || algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
                    /* When bin_edges is known to be a linear progression, maps elt to
                     * the appropriate bin via simple division.
                     */
                    pos = static_cast<int64_t>((elt - leftmost_edge[dim])
                            * (num_bin_edges[dim] - 1)
                            / (rightmost_edge[dim] - leftmost_edge[dim]));
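                    // (num_bin_edges[dim] - 1) is the number of bins in this
                    // dimension, so this is (elt - leftmost) / bin_width,
                    // truncated to an integer.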

                    /* Ensures consistency with bin_edges by checking the bins to the left and right
                     * of the selected position. Necessary for cases in which an element very close
                     * to a bin edge may be misclassified by simple division.
                     */
                    if (algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
                        int64_t pos_min = std::max(static_cast<int64_t>(0), pos - 1);
                        int64_t pos_max = std::min(pos + 2, num_bin_edges[dim]);
                        pos = std::upper_bound(bin_seq[dim] + pos_min, bin_seq[dim] + pos_max, elt)
                                - bin_seq[dim] - 1;
                    }
                } else {
                    TORCH_INTERNAL_ASSERT(false);
                }

                // Unlike other bins, the rightmost bin includes its right boundary
                if (pos == (num_bin_edges[dim] - 1)) {
                    pos -= 1;
                }

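                // Folds this dimension's bin position into a single flattened
                // index into the contiguous thread-local histogram buffer.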
                hist_index += hist_strides[dim] * pos;
            }

            if (!skip_elt) {
                // In the unweighted case, the default weight is 1
                input_t wt = accessor_wt.has_value() ? accessor_wt.value()[i] : static_cast<input_t>(1);

                hist_local_data[hist_index] += wt;
            }
        }
    });

    at::sum_out(hist, thread_histograms, /*dim=*/{0});
}

/* Some pre- and post-processing steps for the main algorithm.
 * Initializes hist to 0, calls into the main algorithm, and normalizes output if necessary.
 */
template<BIN_SELECTION_ALGORITHM bin_algorithm>
void histogramdd_out_cpu_template(const Tensor& self, const std::optional<Tensor>& weight, bool density,
        Tensor& hist, const TensorList& bin_edges) {
    hist.fill_(0);

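    // self has shape (..., N); the leading dimensions are flattened below so
    // that each of the M rows of reshaped_input is one N-dimensional coordinate.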
    const int64_t N = self.size(-1);
    const int64_t M = std::accumulate(self.sizes().begin(), self.sizes().end() - 1,
            (int64_t)1, std::multiplies<int64_t>());

    const Tensor reshaped_input = self.reshape({M, N});

    const auto reshaped_weight = weight.has_value()
            ? std::optional<Tensor>(weight.value().reshape({M}))
            : std::optional<Tensor>();

    std::vector<Tensor> bin_edges_contig(bin_edges.size());
    for (const auto dim : c10::irange(bin_edges_contig.size())) {
        bin_edges_contig[dim] = bin_edges[dim].contiguous();
    }

    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "histogram_cpu", [&]() {
        histogramdd_cpu_contiguous<scalar_t, bin_algorithm>(
                hist, bin_edges_contig, reshaped_input, reshaped_weight);
    });

    /* Divides each bin's value by the total count/weight in all bins,
     * and by the bin's volume.
     */
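    /* Concretely, a bin that accumulated total weight w ends up holding
     *     w / (sum of all bin weights) / (product of that bin's edge lengths),
     * i.e. a conventionally normalized histogram density.
     */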
    if (density) {
        const auto hist_sum = hist.sum().item();
        hist.div_(hist_sum);

        /* For each dimension, divides each bin's value
         * by the bin's length in that dimension.
         */
        for (const auto dim : c10::irange(N)) {
            const auto bin_lengths = bin_edges[dim].diff();

            // Used to reshape bin_lengths to align with the corresponding dimension of hist.
            std::vector<int64_t> shape(N, 1);
            shape[dim] = bin_lengths.numel();

            hist.div_(bin_lengths.reshape(shape));
        }
    }
}

/* The general implementation of the histogram kernel. Maps each element of the input tensor
 * to its corresponding bin by performing a binary search over the elements of bin_edges.
 *
 * Refer to histogramdd_out_cpu_template for more details.
 */
static void histogramdd_kernel_impl(const Tensor& self, const std::optional<Tensor>& weight, bool density,
        Tensor& hist, const TensorList& bin_edges) {
    histogramdd_out_cpu_template<BINARY_SEARCH>(self, weight, density, hist, bin_edges);
}

/* A faster version of the histogram kernel for cases in which bin_edges are known
 * to form a linear progression.
 *
 * Refer to histogramdd_out_cpu_template for more details.
 */
static void histogramdd_linear_kernel_impl(const Tensor& self, const std::optional<Tensor>& weight,
        bool density, Tensor& hist, const TensorList& bin_edges, bool local_search) {
    if (local_search) {
        // histogramdd codepath: both hist and bin_edges are eventually returned as output,
        // so we'll keep them consistent
        histogramdd_out_cpu_template<LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH>(
                self, weight, density, hist, bin_edges);
    } else {
        // histc codepath: bin_edges are not returned to the caller
        histogramdd_out_cpu_template<LINEAR_INTERPOLATION>(
                self, weight, density, hist, bin_edges);
    }
}

template<typename scalar_t>
void infer_bin_edges_from_input(const Tensor& input, const int64_t N,
        std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
    // Calls aminmax on input with dim=0, reducing all but the innermost dimension of input.
    auto [min, max] = aminmax(input, 0);

    TORCH_INTERNAL_ASSERT(min.is_contiguous() && max.is_contiguous());

    const scalar_t *min_data = min.const_data_ptr<scalar_t>();
    std::copy(min_data, min_data + N, leftmost_edges.begin());

    const scalar_t *max_data = max.const_data_ptr<scalar_t>();
    std::copy(max_data, max_data + N, rightmost_edges.begin());
}

static void histogram_select_outer_bin_edges_impl(const Tensor& input, const int64_t N,
        std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "histogramdd", [&]() {
        infer_bin_edges_from_input<scalar_t>(input, N, leftmost_edges, rightmost_edges);
    });
}

} // namespace

REGISTER_DISPATCH(histogramdd_stub, &histogramdd_kernel_impl);
REGISTER_DISPATCH(histogramdd_linear_stub, &histogramdd_linear_kernel_impl);
REGISTER_DISPATCH(histogram_select_outer_bin_edges_stub, &histogram_select_outer_bin_edges_impl);

} // namespace at::native