#pragma once

#include <limits>
#include <ATen/core/Tensor.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/NonEmptyUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/scalar_tensor.h>
#endif

namespace at::native {

// Maximum and minimum possible scalar values, including infinities
template <typename scalar_t>
constexpr scalar_t upper_bound() {
  using lim = std::numeric_limits<scalar_t>;
  return lim::has_infinity ? lim::infinity() : lim::max();
}

template <typename scalar_t>
constexpr scalar_t lower_bound() {
  using lim = std::numeric_limits<scalar_t>;
  return lim::has_infinity ? -lim::infinity() : lim::lowest();
}

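// Illustrative note (not part of the original header): lower_bound<T>() is the
// identity element for a max-style reduction and upper_bound<T>() for a
// min-style one. For example:
//   static_assert(upper_bound<int32_t>() == std::numeric_limits<int32_t>::max());
//   // lower_bound<float>() == -std::numeric_limits<float>::infinity()
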
inline Tensor restride_dim(
  const Tensor& src, int64_t dim,
  IntArrayRef replacement_shape
) {
  auto strides = ensure_nonempty_vec(src.strides().vec());
  strides[dim] = 0;
  return src.as_strided(replacement_shape, strides);
}

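// Illustrative sketch (not part of the original header): restride_dim builds a
// broadcasting view by zeroing the stride of `dim`, so every index along that
// dimension aliases the same underlying data. Assuming a contiguous input:
//   auto t = at::zeros({3, 1});
//   auto v = restride_dim(t, /*dim=*/1, /*replacement_shape=*/{3, 4});
//   // v.sizes() == {3, 4} and v.stride(1) == 0
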
inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
                                int64_t dim) {
  IntArrayRef self_sizes = self.sizes();
  std::vector<int64_t> result_sizes;
  result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
  result_sizes[dim] = 1;
  result.resize_(result_sizes);
}

inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
                                      const Scalar& ident, int64_t dim, bool keepdim) {
  if (self.numel() == 1 && self.ndimension() == 0) {
    result.resize_({});
    result.fill_(self);
    return true;
  }
  // Return identity
  if (self.numel() == 0) {
    _dimreduce_setup(result, self, dim);
    result.fill_(ident);
    if (!keepdim) result.squeeze_(dim);
    return true;
  }
  return false;
}

inline bool _dimreduce_return_trivial_no_ident(Tensor &result, const Tensor &self,
                                               int64_t /*dim*/, bool /*keepdim*/, const char* /*fn_name*/) {
  if (self.numel() == 1 && self.ndimension() == 0) {
    result.resize_({});
    result.fill_(self);
    return true;
  }

  return false;
}

inline std::optional<Tensor> _allreduce_return_trivial(
    const Tensor& self,
    const Scalar& ident) {
  // Return identity
  if (self.numel() == 0) {
    return at::scalar_tensor(ident, self.options());
  }
  return std::nullopt;
}

#define OPTION_TYPE_EQUALITY_CHECK(option, out, self) \
{ \
  TORCH_CHECK(\
    out.option() == self.option(),\
    "expected ", #option, " ",\
    self.option(),\
    " but found ", out.option())\
}

inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) {
  OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self);
  OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options());
  OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options());
}

inline Tensor integer_upcast(const Tensor& self, std::optional<ScalarType> dtype) {
  ScalarType scalarType = self.scalar_type();
  TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented");
  ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType);
  return self.toType(upcast_scalarType);
}

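// Illustrative sketch (not part of the original header): integral (and bool)
// inputs are accumulated in int64 unless an explicit dtype is requested:
//   auto t = at::ones({4}, at::kByte);
//   integer_upcast(t, std::nullopt).scalar_type();  // at::kLong
//   integer_upcast(t, at::kInt).scalar_type();      // at::kInt
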
using DimMask = TensorIterator::DimMask;

inline DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
  if (opt_dims.has_value()) {
    return DimVector(opt_dims.value());
  } else {
    std::vector<int64_t> all_dims(ndim);
    std::iota(all_dims.begin(), all_dims.end(), 0);
    return DimVector(all_dims);
  }
}

inline DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim, bool allow_empty_dims=false) {
  DimMask mask;
  if (opt_dims.has_value()) {
    auto dims = opt_dims.value();
    if (dims.empty() && !allow_empty_dims) {
      mask = DimMask().flip();
    } else {
      mask = at::dim_list_to_bitset(dims, ndim);
    }
  } else {
    mask = DimMask().flip();
  }
  return mask;
}

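// Illustrative sketch (not part of the original header): an empty dim list (or
// no list at all) means "reduce over everything", so every bit of the mask is
// set; otherwise only the listed (wrapped) dims are set:
//   auto all = make_dim_mask(IntArrayRef{}, /*ndim=*/3);    // all bits set
//   auto one = make_dim_mask(IntArrayRef{-1}, /*ndim=*/3);  // only bit 2 set
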
inline DimVector shape_from_dim_mask(const Tensor& self, DimMask mask, bool keepdim) {
  auto shape = DimVector(self.sizes());
  for (int dim = shape.size() - 1; dim >= 0; dim--) {
    if (mask[dim]) {
      if (keepdim) {
        shape[dim] = 1;
      } else {
        shape.erase(shape.begin() + dim);
      }
    }
  }
  return shape;
}

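// Illustrative sketch (not part of the original header): for a {2, 3, 4} input
// reduced over dim 1, the result shape keeps or drops the reduced dim:
//   auto self = at::empty({2, 3, 4});
//   auto mask = make_dim_mask(IntArrayRef{1}, self.dim());
//   shape_from_dim_mask(self, mask, /*keepdim=*/true);   // -> {2, 1, 4}
//   shape_from_dim_mask(self, mask, /*keepdim=*/false);  // -> {2, 4}
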
inline void resize_reduction_result(
    Tensor& result, const Tensor& self, DimMask mask, bool keepdim,
    ScalarType /*dtype*/)
{
  auto shape = shape_from_dim_mask(self, mask, keepdim);
  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
  at::native::resize_output(result, shape);
}

inline Tensor create_reduction_result(
  const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype
) {
  DimMask mask = make_dim_mask(dim, self.dim());
  auto shape = shape_from_dim_mask(self, mask, keepdim);
  return at::empty(shape, self.options().dtype(dtype));
}

inline Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
  if (keepdim) {
    return result;
  }
  auto shape = DimVector(result.sizes());
  auto stride = DimVector(result.strides());
  for (const auto dim : c10::irange(ndim)) {
    if (mask[dim]) {
      shape.insert(shape.begin() + dim, 1);
      stride.insert(stride.begin() + dim, 0);
    }
  }
  return result.as_strided(shape, stride);
}

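// Illustrative sketch (not part of the original header): when keepdim=false the
// caller-visible result has the reduced dims removed, but TensorIterator wants
// output and input ranks to line up. review_reduce_result re-inserts each
// reduced dim as size 1 with stride 0 before the iterator is built:
//   auto result = at::empty({2, 4});
//   auto mask = make_dim_mask(IntArrayRef{1}, /*ndim=*/3);
//   auto viewed = review_reduce_result(result, /*ndim=*/3, mask, /*keepdim=*/false);
//   // viewed.sizes() == {2, 1, 4} and viewed.stride(1) == 0
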
inline TensorIterator make_reduction(
    const char* name, Tensor& result, const Tensor& self,
    at::OptionalIntArrayRef dim_opt,
    bool keepdim, ScalarType in_dtype, ScalarType out_dtype) {
  // check that result type and dtype match if provided
  TORCH_CHECK(
      !result.defined() || result.scalar_type() == out_dtype,
      name, ": provided dtype must match dtype of result. Got ",
      toString(result.scalar_type()),
      " and ",
      toString(out_dtype),
      ".");
  // dim={} performs an all-reduce, same as dim=None
  IntArrayRef dim = dim_opt.value_or(IntArrayRef{});
  int64_t ndim = self.dim();
  auto mask = make_dim_mask(dim, ndim);
  resize_reduction_result(result, self, mask, keepdim, out_dtype);
  auto viewed_result = review_reduce_result(result, ndim, mask, keepdim);
  namedinference::propagate_names_for_reduction(result, self, dim, keepdim);
  if (self.scalar_type() == in_dtype) {
    return TensorIterator::reduce_op(viewed_result, self);
  }
  return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
}

inline C10_UNUSED TensorIterator make_reduction(
    const char* name, Tensor& result, const Tensor& self,
    at::OptionalIntArrayRef dim, bool keepdim, ScalarType out_dtype) {
  // Special case for type promotion in mixed precision; improves computational
  // efficiency. We don't generalize this to common mismatched input/output
  // types to avoid a cross product of templated kernel launches.
  const bool gpu_lowp_to_f32 = (
    self.is_cuda() && (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && out_dtype == kFloat);
  auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type()
                   : self.is_complex() ? c10::toComplexType(out_dtype)
                                       : out_dtype;
  return make_reduction(name, result, self, dim, keepdim, in_dtype, out_dtype);
}

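// Illustrative usage sketch (assumptions, not code from this header): an
// out-variant reduction would typically build the iterator and then dispatch a
// device stub on it. `my_sum_out` and the stub call are hypothetical:
//   Tensor& my_sum_out(const Tensor& self, IntArrayRef dim, bool keepdim,
//                      Tensor& result) {
//     auto iter = make_reduction("my_sum", result, self, dim, keepdim,
//                                result.scalar_type());
//     // dispatch a reduction stub on `iter` here, then:
//     return result;
//   }
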
inline TensorIterator make_reduction(
    const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
    at::OptionalIntArrayRef dim_opt, bool keepdim, ScalarType dtype1,
    ScalarType dtype2) {
  // check that result type and dtype match if provided
  TORCH_CHECK(
    (!result1.defined() || result1.scalar_type() == dtype1) && (!result2.defined() || result2.scalar_type() == dtype2),
    name, ": provided dtype must match dtype of result. Got ",
    toString(result1.scalar_type()), toString(result2.scalar_type()),
    " and ",
    toString(dtype1), toString(dtype2),
    ".");

  // dim={} performs an all-reduce, same as dim=None
  auto dim = dim_opt.value_or(IntArrayRef{});
  int64_t ndim = self.dim();
  DimMask mask = make_dim_mask(dim, ndim);
  resize_reduction_result(result1, self, mask, keepdim, dtype1);
  auto viewed_result1 = review_reduce_result(result1, ndim, mask, keepdim);

  resize_reduction_result(result2, self, mask, keepdim, dtype2);
  auto viewed_result2 = review_reduce_result(result2, ndim, mask, keepdim);

  namedinference::propagate_names_for_reduction(result1, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(result2, self, dim, keepdim);

  // Special case for type promotion in mixed precision; improves computational
  // efficiency. We don't generalize this to common mismatched input/output
  // types to avoid a cross product of templated kernel launches.
  if (self.scalar_type() == dtype1 ||
      (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
    return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
  }
  return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
}

inline C10_UNUSED TensorIterator make_reduction(
    const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
    at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype) {
  return make_reduction(name, result1, result2, self, dim, keepdim, dtype, dtype);
}

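// Illustrative sketch (not part of the original header): the two-output
// overloads back reductions that produce a value plus an index (or two related
// values). A hypothetical min-with-indices op would pass kLong for the second
// output, with `values`, `indices`, `dim`, and `keepdim` supplied by the caller:
//   auto iter = make_reduction("my_min", values, indices, self, dim, keepdim,
//                              self.scalar_type(), kLong);
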
inline void zero_numel_check_dims(const Tensor& self, const int64_t dim, const char *fn_name) {
  if (self.ndimension() == 0) {
    TORCH_CHECK_INDEX(dim == 0 || dim == -1, fn_name,
      ": Expected reduction dim -1 or 0 for scalar but got ", dim);
  }
  else {
    TORCH_CHECK_INDEX(self.size(dim) != 0, fn_name,
      ": Expected reduction dim ", dim, " to have non-zero size.");
  }
}

inline void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
  TORCH_CHECK(
    !dim.empty(),
      fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ",
        "Specify the reduction dim with the 'dim' argument.");
  for (const int64_t d : dim) {
    zero_numel_check_dims(self, d, fn_name);
  }
}

inline std::vector<int64_t> get_zero_numel_tensor_size(
    const Tensor& self,
    const int64_t dim,
    const bool keepdim,
    const char* fn_name) {
  TORCH_INTERNAL_ASSERT(self.numel() == 0,  fn_name, ": Expected self.numel() == 0.");
  zero_numel_check_dims(self, dim, fn_name);
  std::vector<int64_t> sizes;
  if (keepdim) {
    sizes = self.sizes().vec();
    sizes[dim] = 1;
  }
  else {
    for (const auto d : c10::irange(self.dim())) {
      if (d != dim) {
        sizes.push_back(self.sizes()[d]);
      }
    }
  }
  return sizes;
}

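// Illustrative sketch (not part of the original header): for a {0, 3} input
// reduced over dim 1 (whose size is non-zero), the returned size keeps or drops
// the reduced dim; reducing over dim 0 would instead fail the check above:
//   auto self = at::empty({0, 3});
//   get_zero_numel_tensor_size(self, /*dim=*/1, /*keepdim=*/true, "my_op");   // -> {0, 1}
//   get_zero_numel_tensor_size(self, /*dim=*/1, /*keepdim=*/false, "my_op");  // -> {0}
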
// Resize the result and indices tensors when result.numel() == 0, based on the
// values of dim and keepdim, so they can be returned as reduction results.
// This function should be called when you are reducing a zero-numel tensor and
// want to resize the outputs and return them. It exists for resizing zero-numel
// tensors when the size of the reduction dimension is non-zero.
inline C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices,
                                     const Tensor& self, const int64_t dim,
                                     const bool keepdim, const char *fn_name) {
  auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name);
  at::native::resize_output(result, sizes);
  at::native::resize_output(result_indices, sizes);
}

inline ScalarType get_dtype_from_self(
    const Tensor& self,
    const std::optional<ScalarType>& dtype,
    bool promote_integers) {
  if (dtype.has_value()) {
    return dtype.value();
  }
  ScalarType src_type = self.scalar_type();
  if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
    return kLong;
  }
  return src_type;
}

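// Illustrative sketch (not part of the original header): without an explicit
// dtype, integral (and bool) inputs promote to int64 when promote_integers is
// true, while floating-point inputs keep their own type:
//   auto b = at::zeros({2}, at::kBool);
//   get_dtype_from_self(b, std::nullopt, /*promote_integers=*/true);  // kLong
//   auto f = at::zeros({2}, at::kFloat);
//   get_dtype_from_self(f, std::nullopt, /*promote_integers=*/true);  // kFloat
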
inline ScalarType get_dtype_from_result(Tensor& result, std::optional<ScalarType> dtype) {
  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
  if (dtype.has_value()) {
    return dtype.value();
  } else {
    return result.scalar_type();
  }
}


} // namespace at::native

namespace at::meta {

inline C10_UNUSED DimVector get_reduction_shape(
    const Tensor& self,
    IntArrayRef dims,
    bool keepdim,
    bool allow_empty_dims=false) {
  auto mask = native::make_dim_mask(dims, self.dim(), allow_empty_dims);
  return native::shape_from_dim_mask(self, mask, keepdim);
}

inline void resize_reduction(
    impl::MetaBase& meta,
    const Tensor& self,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    ScalarType out_dtype,
    bool allow_empty_dims=false) {
  DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim());
  maybe_wrap_dims(dims_, self.dim());
  auto shape = get_reduction_shape(self, dims_, keepdim, allow_empty_dims);
  if (self.layout() == kStrided) {
    meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
  } else if (shape.empty()) {
    meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype).layout(kStrided));
  } else {
    TORCH_CHECK(false, "resize_reduction: support for output with ", self.layout(), " layout is not implemented yet");
  }
  namedinference::propagate_names_for_reduction(
      meta.maybe_get_output(), self, dims_, keepdim);
}

inline void resize_reduction_with_indices(
    impl::MetaBase& meta,
    const Tensor& self,
    IntArrayRef dims,
    bool keepdim,
    ScalarType out_dtype) {
  DimVector dims_(dims);
  maybe_wrap_dims(dims_, self.dim());
  auto shape = get_reduction_shape(self, dims_, keepdim);
  meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
  meta.set_output_raw_strided(1, shape, {}, self.options().dtype(kLong));
  namedinference::propagate_names_for_reduction(
      meta.maybe_get_output(0), self, dims_, keepdim);
  namedinference::propagate_names_for_reduction(
      meta.maybe_get_output(1), self, dims_, keepdim);
}

inline TensorIterator make_reduction(
    const Tensor& self,
    const Tensor& result,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    ScalarType in_dtype) {
  int64_t ndim = self.dim();
  auto mask = at::native::make_dim_mask(opt_dims, ndim);
  auto viewed_result =
      at::native::review_reduce_result(result, ndim, mask, keepdim);
  if (self.scalar_type() == in_dtype) {
    return TensorIterator::reduce_op(viewed_result, self);
  }
  return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
}

inline TensorIterator make_reduction(
    const Tensor& self,
    const Tensor& result1,
    const Tensor& result2,
    IntArrayRef dims,
    bool keepdim,
    ScalarType dtype1,
    ScalarType /*dtype2*/) {
  int64_t ndim = self.dim();
  auto mask = at::native::make_dim_mask(dims, ndim);
  auto viewed_result1 = at::native::review_reduce_result(result1, ndim, mask, keepdim);
  auto viewed_result2 = at::native::review_reduce_result(result2, ndim, mask, keepdim);
  // Special case for type promotion in mixed precision; improves computational
  // efficiency. We don't generalize this to common mismatched input/output
  // types to avoid a cross product of templated kernel launches.
  if (self.scalar_type() == dtype1 ||
      (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
    return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
  }
  return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
}

inline C10_UNUSED TensorIterator make_reduction_from_out_ty(
    const Tensor& self,
    const Tensor& result,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    ScalarType out_dtype) {
  // Special case for type promotion in mixed precision; improves computational
  // efficiency. We don't generalize this to common mismatched input/output
  // types to avoid a cross product of templated kernel launches.
  const bool gpu_lowp_to_f32 =
      (self.is_cuda() &&
       (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) &&
       out_dtype == kFloat);
  auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype;
  return make_reduction(self, result, opt_dims, keepdim, in_dtype);
}

} // namespace at::meta