#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/BucketizationUtils.h>
#include <ATen/native/Resize.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/bucketize_native.h>
#include <ATen/ops/searchsorted_native.h>
#endif

/* Implement a numpy-like searchsorted and a TF-like bucketize function running on cpu
 *
 * - torch.searchsorted(sorted_sequence, values, right=False, side=None, out_int32=False, sorter=None)
 *   sorted_sequence - N*D tensor or 1D tensor (applied to all values) containing sorted sequences in the last dimension
 *   values          - N*D tensor or a Scalar (when sorted_sequence is 1D) containing the search values
 *   right           - corresponds to the lower bound if False and the upper bound if True
 *   side            - (preferred over right) corresponds to the lower bound if 'left' and the upper bound if 'right'
 *   out_int32       - the output tensor is int64_t if False and int (normally 32-bit) if True
 *   sorter          - if provided, sorted_sequence may be unsorted and its sorted order is given by this tensor
 *
 * - torch.bucketize(values, boundaries, right=False, out_int32=False)
 *   values     - N*D tensor or a Scalar containing the search values
 *   boundaries - 1D tensor containing a sorted sequence
 *   right      - corresponds to the lower bound if False and the upper bound if True
 *   out_int32  - the output tensor is int64_t if False and int (normally 32-bit) if True
 *
 * - Restrictions are defined in searchsorted_pre_check()
 */
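
/* Example with illustrative values (for intuition, following the semantics above):
 *   sorted_sequence = [1, 3, 5, 7, 9], values = [3, 6, 9]
 *     torch.searchsorted(sorted_sequence, values)              -> [1, 3, 4]  (lower bound, right=False)
 *     torch.searchsorted(sorted_sequence, values, right=True)  -> [2, 3, 5]  (upper bound)
 *   boundaries = [1, 3, 5, 7, 9]
 *     torch.bucketize(3, boundaries)              -> 1  (boundaries[0] < 3 <= boundaries[1])
 *     torch.bucketize(3, boundaries, right=True)  -> 2  (boundaries[1] <= 3 < boundaries[2])
 */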

namespace at::native {

namespace {

// minimum number of elements for searchsorted_cpu_contiguous to run in parallel (multithreaded)
constexpr int64_t SEARCHSORTED_GRAIN_SIZE = 200;

// customized lower_bound func to ensure the lower bound of 'nan', 'inf', etc. is the end of the boundary,
// and so that we can properly handle a sorter argument
// std::lower_bound cannot be used here since its customized comparator needs strict weak ordering,
// and a customized comparator requires both arguments to have the same type, which wouldn't
// happen when comparing a val of input_t to an indexer value from a sorter of int64
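// For intuition (illustrative values): with bd = [1, 3, 3, 5] and no sorter,
//   val = 3   -> returns 1 (the first index where bd[i] >= val)
//   val = NaN -> returns end (4): !(x >= NaN) is true for every element, so the search
//                always moves right and NaN is placed past the last boundary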
template<typename input_t>
int64_t cus_lower_bound(int64_t start, int64_t end, const input_t val, const input_t* bd, const int64_t* sort) {
  // sorter gives relative ordering for ND tensors, so we need to save and add the non-updated start as an offset
  // i.e. the second row of a 3x3 tensor starts at element 3 but sorter's second row only contains 0, 1, or 2
  const int64_t orig_start = start;
  while (start < end) {
    const int64_t mid = start + ((end - start) >> 1);
    const input_t mid_val = sort ? bd[sort[mid] + orig_start] : bd[mid];
    if (!(mid_val >= val)) {
      start = mid + 1;
    }
    else {
      end = mid;
    }
  }
  return start;
}

// customized upper_bound func to ensure we can properly handle a sorter argument
// std::upper_bound cannot be used here since its customized comparator requires both arguments to have the
// same type, which wouldn't happen when comparing a val of input_t to an indexer value from a sorter of int64
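// For intuition (illustrative values): with bd = [1, 3, 3, 5] and no sorter,
//   val = 3 -> returns 3 (the first index where bd[i] > val), whereas cus_lower_bound returns 1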
template<typename input_t>
int64_t cus_upper_bound(int64_t start, int64_t end, const input_t val, const input_t* bd, const int64_t* sort) {
  // sorter gives relative ordering for ND tensors, so we need to save and add the non-updated start as an offset
  // i.e. the second row of a 3x3 tensor starts at element 3 but sorter's second row only contains 0, 1, or 2
  const int64_t orig_start = start;
  while (start < end) {
    const int64_t mid = start + ((end - start) >> 1);
    const input_t mid_val = sort ? bd[sort[mid] + orig_start] : bd[mid];
    if (!(mid_val > val)) {
      start = mid + 1;
    }
    else {
      end = mid;
    }
  }
  return start;
}

template<typename input_t, typename output_t>
void searchsorted_cpu_contiguous(Tensor& result, const Tensor& input, const Tensor& boundaries, const bool& right, const Tensor& sorter) {
  int64_t numel_in = input.numel();
  bool is_scalar_input = input.dim() == 0 && numel_in == 1;
  // innermost dim size of input and boundaries
  int64_t idim_in = is_scalar_input ? 1 : input.sizes().back();
  int64_t idim_bd = boundaries.sizes().back();

  const input_t *data_in = input.const_data_ptr<input_t>();
  const input_t *data_bd = boundaries.const_data_ptr<input_t>();
  const int64_t *data_st = sorter.defined() ? sorter.const_data_ptr<int64_t>() : nullptr;
  output_t *data_out = result.data_ptr<output_t>();

  bool is_1d_boundaries = boundaries.dim() == 1;
  at::parallel_for(0, numel_in, SEARCHSORTED_GRAIN_SIZE, [&](int64_t start, int64_t end) {
    for (const auto i : c10::irange(start, end)) {
      // If the boundaries tensor is 1d, we always search the entire boundary tensor
      int64_t start_bd = is_1d_boundaries ? 0 : i / idim_in * idim_bd;
      int64_t end_bd = start_bd + idim_bd;
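      // For intuition (illustrative shapes): with input of shape [2, 4] and boundaries of
      // shape [2, 5], idim_in = 4 and idim_bd = 5, so element i = 5 (row 1) searches the
      // half-open range [5, 10) of the flattened boundaries, i.e. the second row only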

      int64_t pos = !right ?
        cus_lower_bound(start_bd, end_bd, data_in[i], data_bd, data_st) - start_bd :
        cus_upper_bound(start_bd, end_bd, data_in[i], data_bd, data_st) - start_bd;

      // type conversion (narrowing to output_t) might happen here
      data_out[i] = pos;
    }
  });
}

void dispatch(Tensor& result, const Tensor& input, const Tensor& boundaries, bool out_int32, bool right, const Tensor& sorter) {
  if (!out_int32) {
    AT_DISPATCH_ALL_TYPES_AND2(
        ScalarType::Half,
        ScalarType::BFloat16,
        input.scalar_type(),
        "searchsorted_out_cpu",
        [&] {
          searchsorted_cpu_contiguous<scalar_t, int64_t>(
              result, input, boundaries, right, sorter);
        });
  }
  else {
    AT_DISPATCH_ALL_TYPES_AND2(
        ScalarType::Half,
        ScalarType::BFloat16,
        input.scalar_type(),
        "searchsorted_out_cpu",
        [&] {
          searchsorted_cpu_contiguous<scalar_t, int>(
              result, input, boundaries, right, sorter);
        });
  }
}

} // namespace

Tensor& searchsorted_out_cpu(
    const Tensor& sorted_sequence,
    const Tensor& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt,
    Tensor& result) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> sorter_maybe_owned = at::borrow_from_optional_tensor(sorter_opt);
  const Tensor& sorter = *sorter_maybe_owned;
  searchsorted_pre_check(sorted_sequence, self, result, out_int32, right, side_opt, sorter);
  resize_output(result, self.sizes());

  // we have two inputs that can set right; pre_check verifies that they aren't set to opposite values
  bool is_right = side_opt ? *side_opt == "right" : right;

  if (self.numel() == 0) {
    return result;
  }

  // for non-contiguous result tensors, we write the output to a contiguous copy so we can
  // copy it back afterwards, preserving the original result tensor
  Tensor out = result;
  if (!result.is_contiguous()) {
    out = result.contiguous();
  }
  if (sorted_sequence.is_contiguous() && self.is_contiguous() && sorted_sequence.dtype() == self.dtype() && sorter.is_contiguous()) {
    dispatch(out, self, sorted_sequence, out_int32, is_right, sorter);
  }
  else {
    Tensor trimmed_input;
    Tensor trimmed_boundaries;
    Tensor trimmed_sorter;
    searchsorted_maybe_trim_input_tensors(trimmed_input, trimmed_boundaries, trimmed_sorter, self, sorted_sequence, sorter);
    const Tensor& final_input = trimmed_input.defined() ? trimmed_input : self;
    const Tensor& final_boundaries = trimmed_boundaries.defined() ? trimmed_boundaries : sorted_sequence;
    const Tensor& final_sorter = trimmed_sorter.defined() ? trimmed_sorter : sorter;
    dispatch(out, final_input, final_boundaries, out_int32, is_right, final_sorter);
  }

  // if result is non-contiguous, the answer was written to a contiguous copy, so copy it back to the original result tensor
  if (!result.is_contiguous()) {
    result.copy_(out);
  }
  return result;
}

Tensor& searchsorted_out_cpu(
    const Tensor& sorted_sequence,
    const Scalar& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt,
    Tensor& result) {
  const Tensor& scalar_tensor = searchsorted_scalar_tensor(self, sorted_sequence.device());
  return searchsorted_out_cpu(sorted_sequence, scalar_tensor, out_int32, right, side_opt, sorter_opt, result);
}

Tensor searchsorted_cpu(
      const Tensor& sorted_sequence,
      const Tensor& self,
      bool out_int32,
      bool right,
      const std::optional<c10::string_view> side_opt,
      const std::optional<Tensor>& sorter_opt) {
  ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
  c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type);
  Tensor result = at::empty({0}, options, MemoryFormat::Contiguous);
  at::native::searchsorted_out_cpu(sorted_sequence, self, out_int32, right, side_opt, sorter_opt, result);
  return result;
}

Tensor searchsorted_cpu(
    const Tensor& sorted_sequence,
    const Scalar& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt) {
  const Tensor& scalar_tensor = searchsorted_scalar_tensor(self, sorted_sequence.device());
  return searchsorted_cpu(sorted_sequence, scalar_tensor, out_int32, right, side_opt, sorter_opt);
}

Tensor& bucketize_out_cpu(const Tensor& self, const Tensor& boundaries, bool out_int32, bool right, Tensor& result) {
  TORCH_CHECK(boundaries.dim() == 1, "boundaries tensor must be 1 dimension, but got dim(", boundaries.dim(), ")");
  at::native::searchsorted_out_cpu(boundaries, self, out_int32, right, std::nullopt, std::nullopt, result);
  return result;
}

Tensor bucketize_cpu(const Tensor& self, const Tensor& boundaries, bool out_int32, bool right) {
  ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
  c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type);
  Tensor result = at::empty({0}, options, MemoryFormat::Contiguous);
  at::native::bucketize_out_cpu(self, boundaries, out_int32, right, result);
  return result;
}

Tensor bucketize_cpu(const Scalar& self, const Tensor& boundaries, bool out_int32, bool right) {
  return bucketize_cpu(searchsorted_scalar_tensor(self, boundaries.device()), boundaries, out_int32, right);
}

} // namespace at::native