#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/BucketizationUtils.h>
#include <ATen/native/Resize.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/bucketize_native.h>
#include <ATen/ops/searchsorted_native.h>
#endif

/* Implement a NumPy-like searchsorted and a TensorFlow-like bucketize function running on the CPU
 *
 * - torch.searchsorted(sorted_sequence, values, right=False, side=None, out_int32=False, sorter=None)
 *   sorted_sequence - N*D or 1D (applies to all values) tensor containing sorted sequences in its last dimension
 *   values          - N*D tensor or a Scalar (when sorted_sequence is 1D) containing the search values
 *   right           - lower bound if False, upper bound if True
 *   side            - (preferred over right) lower bound if 'left', upper bound if 'right'
 *   out_int32       - the output tensor has int64_t dtype if False and int (normally 32-bit) dtype if True
 *   sorter          - if provided, sorted_sequence need not be sorted; its sorted order is given by this tensor
 *
 * - torch.bucketize(values, boundaries, right=False, out_int32=False)
 *   values     - N*D tensor or a Scalar containing the search values
 *   boundaries - 1D tensor containing a sorted sequence
 *   right      - lower bound if False, upper bound if True
 *   out_int32  - the output tensor has int64_t dtype if False and int (normally 32-bit) dtype if True
 *
 * - Restrictions are defined in searchsorted_pre_check()
 */
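
// Illustrative example (a sketch with assumed values, not part of the kernels below); it shows how
// right selects between lower-bound and upper-bound semantics:
//
//   sorted_sequence = [1, 3, 5, 7, 9]
//   values          = [3, 6, 9]
//   torch.searchsorted(sorted_sequence, values)              -> [1, 3, 4]   (first index where seq[i] >= v)
//   torch.searchsorted(sorted_sequence, values, right=True)  -> [2, 3, 5]   (first index where seq[i] >  v)
//   torch.bucketize(values, sorted_sequence)                 -> [1, 3, 4]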

namespace at::native {

namespace {

// minimum input size for searchsorted_cpu_contiguous to run in parallel (multithreaded)
constexpr int64_t SEARCHSORTED_GRAIN_SIZE = 200;

// customized lower_bound func to ensure that the lower bound of 'nan', 'inf', etc. is the end of the boundary,
// and that we can properly handle a sorter argument
// std::lower_bound cannot be used here since its customized comparator needs strict weak ordering,
// and the customized comparator requires both arguments to have the same type, which is not the case
// when comparing a val of input_t with an int64 index value taken from sorter
template<typename input_t>
int64_t cus_lower_bound(int64_t start, int64_t end, const input_t val, const input_t* bd, const int64_t* sort) {
  // sorter gives relative ordering for ND tensors, so we need to save and add the non-updated start as an offset
  // i.e. the second row of a 3x3 tensor starts at element 3, but sorter's second row only contains 0, 1, or 2
  const int64_t orig_start = start;
  while (start < end) {
    const int64_t mid = start + ((end - start) >> 1);
    const input_t mid_val = sort ? bd[sort[mid] + orig_start] : bd[mid];
    if (!(mid_val >= val)) {
      start = mid + 1;
    }
    else {
      end = mid;
    }
  }
  return start;
}
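
// Worked illustration of the two non-obvious points above (assumed values, comment only):
// - NaN handling: with bd = {1.0, 3.0, 5.0}, cus_lower_bound(0, 3, NAN, bd, nullptr) returns 3,
//   because every comparison against NaN is false, so !(mid_val >= val) keeps moving start to the end.
// - sorter offset: for the second row of a 3x3 boundaries tensor, start == 3 and sort[mid] is in {0, 1, 2},
//   so the probed element is bd[sort[mid] + orig_start], e.g. sort[mid] == 2 probes the flat element bd[5].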

// customized upper_bound func to ensure we can properly handle a sorter argument
// std::upper_bound cannot be used here since its customized comparator requires both arguments to have the
// same type, which is not the case when comparing a val of input_t with an int64 index value taken from sorter
template<typename input_t>
int64_t cus_upper_bound(int64_t start, int64_t end, const input_t val, const input_t* bd, const int64_t* sort) {
  // sorter gives relative ordering for ND tensors, so we need to save and add the non-updated start as an offset
  // i.e. the second row of a 3x3 tensor starts at element 3, but sorter's second row only contains 0, 1, or 2
  const int64_t orig_start = start;
  while (start < end) {
    const int64_t mid = start + ((end - start) >> 1);
    const input_t mid_val = sort ? bd[sort[mid] + orig_start] : bd[mid];
    if (!(mid_val > val)) {
      start = mid + 1;
    }
    else {
      end = mid;
    }
  }
  return start;
}

template<typename input_t, typename output_t>
void searchsorted_cpu_contiguous(Tensor& result, const Tensor& input, const Tensor& boundaries, const bool& right, const Tensor& sorter) {
  int64_t numel_in = input.numel();
  bool is_scalar_input = input.dim() == 0 && numel_in == 1;
  // innermost dim size of input and boundaries
  int64_t idim_in = is_scalar_input ? 1 : input.sizes().back();
  int64_t idim_bd = boundaries.sizes().back();

  const input_t *data_in = input.const_data_ptr<input_t>();
  const input_t *data_bd = boundaries.const_data_ptr<input_t>();
  const int64_t *data_st = sorter.defined() ? sorter.const_data_ptr<int64_t>() : nullptr;
  output_t *data_out = result.data_ptr<output_t>();

  bool is_1d_boundaries = boundaries.dim() == 1;
  at::parallel_for(0, numel_in, SEARCHSORTED_GRAIN_SIZE, [&](int64_t start, int64_t end) {
    for (const auto i : c10::irange(start, end)) {
      // If boundaries tensor is 1d, we always search the entire boundary tensor
      int64_t start_bd = is_1d_boundaries ? 0 : i / idim_in * idim_bd;
      int64_t end_bd = start_bd + idim_bd;

      int64_t pos = !right ?
        cus_lower_bound(start_bd, end_bd, data_in[i], data_bd, data_st) - start_bd :
        cus_upper_bound(start_bd, end_bd, data_in[i], data_bd, data_st) - start_bd;

      // type conversion might happen here
      data_out[i] = pos;
    }
  });
}
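
// Row-selection sketch for the index math above (assumed shapes, for illustration only):
// with input of shape 2x3 (idim_in = 3) and boundaries of shape 2x4 (idim_bd = 4), the flattened
// input element i = 4 lies in the second row, so start_bd = i / idim_in * idim_bd = 4 / 3 * 4 = 4
// and end_bd = 8, i.e. the search is confined to the second row of boundaries.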

void dispatch(Tensor& result, const Tensor& input, const Tensor& boundaries, bool out_int32, bool right, const Tensor& sorter) {
  if (!out_int32) {
    AT_DISPATCH_ALL_TYPES_AND2(
        ScalarType::Half,
        ScalarType::BFloat16,
        input.scalar_type(),
        "searchsorted_out_cpu",
        [&] {
          searchsorted_cpu_contiguous<scalar_t, int64_t>(
              result, input, boundaries, right, sorter);
        });
  }
  else {
    AT_DISPATCH_ALL_TYPES_AND2(
        ScalarType::Half,
        ScalarType::BFloat16,
        input.scalar_type(),
        "searchsorted_out_cpu",
        [&] {
          searchsorted_cpu_contiguous<scalar_t, int>(
              result, input, boundaries, right, sorter);
        });
  }
}

} // anonymous namespace

Tensor& searchsorted_out_cpu(
    const Tensor& sorted_sequence,
    const Tensor& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt,
    Tensor& result) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> sorter_maybe_owned = at::borrow_from_optional_tensor(sorter_opt);
  const Tensor& sorter = *sorter_maybe_owned;
  searchsorted_pre_check(sorted_sequence, self, result, out_int32, right, side_opt, sorter);
  resize_output(result, self.sizes());

  // right can be set by two inputs; pre_check verifies that they are not set to opposite values
  bool is_right = side_opt ? *side_opt == "right" : right;

  if (self.numel() == 0) {
    return result;
  }

  // for a non-contiguous result tensor, write the output to a contiguous copy so we can copy it back later,
  // preserving the original result tensor
  Tensor out = result;
  if (!result.is_contiguous()) {
    out = result.contiguous();
  }
  if (sorted_sequence.is_contiguous() && self.is_contiguous() && sorted_sequence.dtype() == self.dtype() && sorter.is_contiguous()) {
    dispatch(out, self, sorted_sequence, out_int32, is_right, sorter);
  }
  else {
    Tensor trimmed_input;
    Tensor trimmed_boundaries;
    Tensor trimmed_sorter;
    searchsorted_maybe_trim_input_tensors(trimmed_input, trimmed_boundaries, trimmed_sorter, self, sorted_sequence, sorter);
    const Tensor& final_input = trimmed_input.defined() ? trimmed_input : self;
    const Tensor& final_boundaries = trimmed_boundaries.defined() ? trimmed_boundaries : sorted_sequence;
    const Tensor& final_sorter = trimmed_sorter.defined() ? trimmed_sorter : sorter;
    dispatch(out, final_input, final_boundaries, out_int32, is_right, final_sorter);
  }

  // if result is non-contiguous, the answer was written to a contiguous copy, so copy it back into the original result tensor
  if (!result.is_contiguous()) {
    result.copy_(out);
  }
  return result;
}

Tensor& searchsorted_out_cpu(
    const Tensor& sorted_sequence,
    const Scalar& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt,
    Tensor& result) {
  const Tensor& scalar_tensor = searchsorted_scalar_tensor(self, sorted_sequence.device());
  return searchsorted_out_cpu(sorted_sequence, scalar_tensor, out_int32, right, side_opt, sorter_opt, result);
}

Tensor searchsorted_cpu(
    const Tensor& sorted_sequence,
    const Tensor& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt) {
  ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
  c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type);
  Tensor result = at::empty({0}, options, MemoryFormat::Contiguous);
  at::native::searchsorted_out_cpu(sorted_sequence, self, out_int32, right, side_opt, sorter_opt, result);
  return result;
}

Tensor searchsorted_cpu(
    const Tensor& sorted_sequence,
    const Scalar& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt) {
  const Tensor& scalar_tensor = searchsorted_scalar_tensor(self, sorted_sequence.device());
  return searchsorted_cpu(sorted_sequence, scalar_tensor, out_int32, right, side_opt, sorter_opt);
}

Tensor& bucketize_out_cpu(const Tensor& self, const Tensor& boundaries, bool out_int32, bool right, Tensor& result) {
  TORCH_CHECK(boundaries.dim() == 1, "boundaries tensor must be 1 dimension, but got dim(", boundaries.dim(), ")");
  at::native::searchsorted_out_cpu(boundaries, self, out_int32, right, std::nullopt, std::nullopt, result);
  return result;
}

Tensor bucketize_cpu(const Tensor& self, const Tensor& boundaries, bool out_int32, bool right) {
  ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
  c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type);
  Tensor result = at::empty({0}, options, MemoryFormat::Contiguous);
  at::native::bucketize_out_cpu(self, boundaries, out_int32, right, result);
  return result;
}

Tensor bucketize_cpu(const Scalar& self, const Tensor& boundaries, bool out_int32, bool right) {
  return bucketize_cpu(searchsorted_scalar_tensor(self, boundaries.device()), boundaries, out_int32, right);
}

} // namespace at::native