#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/Generator.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Tensor.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorIterator.h>
#include <cmath>
#include <limits>
#include <optional>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/full.h>
#include <ATen/ops/view_as_real.h>
#endif

namespace at::native::templates {

// ==================================================== Random ========================================================

// The purpose of `update_from` and `update_to` is to find the closest valid int64_t values that can be used as the
// actual `from` and `to`. The current implementation of `random_` uses uint64_t arithmetic and casts the result to the
// target dtype (scalar_t). This casting can produce numbers that are greater than or equal to the `to` value. For instance:
//
//    auto actual = torch::empty({3, 3}, torch::half);
//    actual.random_(0, 65504);
//
// If the uint64_t arithmetic produces 65503 as a random value, casting it to torch::half rounds it up to 65504,
// which violates the requirement that the random value must be less than `to`. To resolve this issue, `update_from` and
// `update_to` move `from` to the right and `to` to the left to the next closest values that stay inside [from, to) after
// casting to the target dtype. For `to` = 65504 this moves `to` left by (1 << (log2(to) - 11 + 1)) = 32 to 65472, which is
// the previous representable number for the torch::half dtype.
template<typename scalar_t>
int64_t update_from(int64_t from) {
  static_assert(
    std::is_floating_point<scalar_t>::value ||
    std::is_same<scalar_t, at::Half>::value ||
    std::is_same<scalar_t, at::BFloat16>::value, "scalar_t must be floating-point type");
  const auto from_plus_1 = static_cast<int64_t>(static_cast<scalar_t>(from + 1));
  if (from_plus_1 < from) {
    int64_t from_ = std::abs(from + 1);
    int n = 0;
    while (from_ >>= 1) ++n;
    // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
    from = from_plus_1 + (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
  }
  return from;
}

template<typename scalar_t>
int64_t update_to(int64_t to) {
  static_assert(
    std::is_floating_point<scalar_t>::value ||
    std::is_same<scalar_t, at::Half>::value ||
    std::is_same<scalar_t, at::BFloat16>::value, "scalar_t must be floating-point type");
  const auto to_minus_1 = static_cast<int64_t>(static_cast<scalar_t>(to - 1));
  if (to_minus_1 >= to) {
    int64_t to_ = std::abs(to - 1);
    int n = 0;
    while (to_ >>= 1) ++n;
    // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
    to = to_minus_1 - (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
  }
  return to;
}
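
// A minimal worked example of the adjustment above (a sketch, not part of the API; the numbers
// assume IEEE half, where std::numeric_limits<at::Half>::digits == 11):
//
//    update_to<at::Half>(65504);    // 65503 rounds up to 65504 in half, so `to` is pulled
//                                   // back by 1 << (15 - 11 + 1) = 32 and becomes 65472
//    update_from<at::Half>(-65503); // -65502 rounds down to -65504 in half, so `from` is
//                                   // pushed right by 32 and becomes -65472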

// Return early to avoid invoking the kernel on empty tensors.
// See https://github.com/pytorch/pytorch/issues/103418 for more details
#define CHECK_EMPTY_AND_RETURN(tensor) \
  if (tensor.numel() == 0) {  \
    return tensor;  \
  }

template<template<typename> class random_kernel, typename RNG>
at::Tensor& random_impl(at::Tensor& self, std::optional<Generator> generator) {
  CHECK_EMPTY_AND_RETURN(self);
  auto iter = at::TensorIterator::borrowing_nullary_op(self);
  random_kernel<RNG>()(iter, generator);
  return self;
}

#define CHECK_OUT_OF_BOUNDS(var, name, min, max, dtype) \
  TORCH_CHECK(var >= min && var <= max, name , " is out of bounds for ", dtype); \

#define WARN_OUT_OF_BOUNDS(var, name, digits, dtype) \
  if (var < -(1LL << digits) || var > (1LL << digits)) { \
    TORCH_WARN(name , " is out of bounds [-(2^", digits, "), 2^", digits, "]. ", \
      "Due to precision limitations ", dtype, " can support discrete uniform distribution only within this range. ", \
      "This warning will become an error in version 1.7 release, please fix the code in advance"); \
  }

inline void check_from_to_in_range(int64_t from, int64_t to_inc, caffe2::TypeMeta dtype) {
  const auto scalar_type = typeMetaToScalarType(dtype);
  if (isFloatingType(scalar_type)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "check_random_fp_bounds", [&] {
      const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
      const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
      CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
      CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);

      constexpr auto digits = std::numeric_limits<scalar_t>::digits;
      WARN_OUT_OF_BOUNDS(from, "from", digits, dtype);
      WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, dtype);
    });
  } else if (scalar_type == kUInt64) {
    // When you do a comparison between int64_t and uint64_t, the usual
    // arithmetic conversions say that the int64_t value is promoted to
    // unsigned. But this conversion wraps around: if I had -1 as my int64_t,
    // then it will promote to 0xFFFFFFFFFFFFFFFF in uint64_t. This is never
    // the right thing to do.
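    // A tiny sketch of that wraparound (illustrative only, not executed here):
    //
    //    int64_t from = -1;
    //    uint64_t zero = 0;
    //    bool looks_in_range = from >= zero;  // true: -1 is converted to 0xFFFFFFFFFFFFFFFF
    //
    // which is why the kUInt64 bounds are instead checked as signed values against [0, INT64_MAX].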
    CHECK_OUT_OF_BOUNDS(from, "from", 0, INT64_MAX, dtype);
    CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", 0, INT64_MAX, dtype);
  } else if (isIntegralType(scalar_type, /*includeBool=*/true)) {
    AT_DISPATCH_V2(scalar_type, "check_random_integral_bounds", AT_WRAP([&]() {
      const auto min = static_cast<int64_t>(std::numeric_limits<scalar_t>::lowest());
      const auto max = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
      CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
      CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
    }), AT_EXPAND(AT_INTEGRAL_TYPES), kUInt16, kUInt32, kBool);
  } else {
    TORCH_CHECK(false, "check_random_bounds handles only integral, floating-point and boolean types");
  }
}

template<template<typename> class random_from_to_kernel, typename RNG>
at::Tensor& random_from_to_impl(at::Tensor& self, int64_t from, std::optional<int64_t> to_opt, std::optional<Generator> generator) {
  uint64_t range = 0;
  auto iter = at::TensorIterator::borrowing_nullary_op(self);
  if (to_opt.has_value()) {
    // [from, to)
    int64_t to = *to_opt;
    TORCH_CHECK(from < to, "random_ expects 'from' to be less than 'to', but got from=", from, " >= to=", to);
    if (isFloatingType(iter.dtype())) {
      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_update_from_to", [&] {
        from = update_from<scalar_t>(from);
        to = update_to<scalar_t>(to);
        TORCH_CHECK(from < to, "random_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", from, " >= to=", to);
      });
    }
    check_from_to_in_range(from, to - 1, self.dtype());
    CHECK_EMPTY_AND_RETURN(self);
    range = static_cast<uint64_t>(to) - static_cast<uint64_t>(from);
    random_from_to_kernel<RNG>()(iter, range, from, generator);
  } else if (from != std::numeric_limits<int64_t>::lowest()) {
    // [from, std::numeric_limits<int64_t>::max()]
    int64_t to_inc = 0;
    if (isFloatingType(iter.dtype())) {
      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_from_to_range_calc", [&] {
        constexpr int64_t scalar_t_max = static_cast<int64_t>(1) << std::numeric_limits<scalar_t>::digits;
        to_inc = scalar_t_max > std::numeric_limits<int64_t>::max() ? std::numeric_limits<int64_t>::max() : static_cast<int64_t>(scalar_t_max);
        from = update_from<scalar_t>(from);
        TORCH_CHECK(from < to_inc, "random_ expects 'from' casted to dtype to be less than 'to_inc' casted to dtype, but got from=", from, " >= to_inc=", to_inc);
      });
    } else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) {
      AT_DISPATCH_V2(self.scalar_type(), "random_from_to_range_calc", AT_WRAP([&] {
        if constexpr (std::is_same_v<scalar_t, bool>) {
          to_inc = static_cast<int64_t>(true);
        } else {
          to_inc = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
        }
      }), AT_EXPAND(AT_INTEGRAL_TYPES_V2), kBool);
    } else {
      TORCH_CHECK(false, "random_from_to_impl handles only integral, floating-point and boolean types");
    }
    check_from_to_in_range(from, to_inc, self.dtype());
    CHECK_EMPTY_AND_RETURN(self);
    range = static_cast<uint64_t>(to_inc) - static_cast<uint64_t>(from) + 1;
    random_from_to_kernel<RNG>()(iter, range, from, generator);
  } else {
    // [std::numeric_limits<int64_t>::lowest(), std::numeric_limits<int64_t>::max()]
    // range = 2^64
    CHECK_EMPTY_AND_RETURN(self);
    random_from_to_kernel<RNG>()(iter, generator);
  }
  return self;
}
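
// A minimal usage sketch for random_from_to_impl (the kernel/RNG names below are hypothetical;
// real callers pass backend-specific kernels, e.g. from the CPU/CUDA DistributionKernels):
//
//    // template<typename RNG> struct my_random_from_to_kernel { /* ... */ };
//    // at::Tensor t = at::empty({4}, at::kLong);
//    // templates::random_from_to_impl<my_random_from_to_kernel, MyGeneratorImpl>(
//    //     t, /*from=*/-5, /*to_opt=*/5, std::nullopt);
//    // draws integers uniformly from [-5, 5), i.e. range = 10 with base offset from = -5.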

// ==================================================== Normal ========================================================

#define CHECK_NORMAL_TENSOR_STD(std) \
  do { \
    TORCH_CHECK( \
      !std.is_complex(), \
      "normal expects standard deviation to be non-complex"); \
    TORCH_CHECK( \
      std.numel() == 0 || std.is_meta() || std.min().ge(0).item<bool>(), \
      "normal expects all elements of std >= 0.0"); \
  } while (0)

#define CHECK_NORMAL_STD(std) \
  TORCH_CHECK(std >= 0.0, "normal expects std >= 0.0, but found std ", std);

template<template<typename> class normal_kernel, typename RNG>
Tensor& normal_impl_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
  CHECK_NORMAL_STD(std);
  CHECK_EMPTY_AND_RETURN(self);

  if (self.is_complex()) {
    auto float_tensor = at::view_as_real(self);
    // variance for normal distribution of the real and imaginary values
    // is half of the input variance
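    // (a sketch of the reasoning: with independent components re, im ~ N(mean, s^2) and
    //  s = std / sqrt(2), the complex variance E|z - mean|^2 = 2 * s^2 = std^2 as requested)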
    normal_kernel<RNG>()(float_tensor, mean, std/(std::sqrt(2)), gen);
  } else {
    normal_kernel<RNG>()(self, mean, std, gen);
  }
  return self;
}

template<template<typename> class normal_kernel, typename RNG>
Tensor& normal_out_impl(Tensor& output, const Tensor& mean, double std, std::optional<Generator> gen) {
  CHECK_NORMAL_STD(std);
  auto std_tensor = at::empty_like(output, MemoryFormat::Contiguous);
  auto shape = at::infer_size(mean.sizes(), std_tensor.sizes());
  at::native::resize_output(output, shape);
  normal_impl_<normal_kernel, RNG>(output, 0, std, gen);
  output.add_(mean);
  return output;
}

template<template<typename> class normal_kernel, typename RNG>
Tensor& normal_out_impl(Tensor& output, double mean, const Tensor& std, std::optional<Generator> gen) {
  CHECK_NORMAL_TENSOR_STD(std);
  auto mean_tensor = at::full({}, mean, output.options());
  auto shape = at::infer_size(mean_tensor.sizes(), std.sizes());
  at::native::resize_output(output, shape);
  normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
  // CUDA NB: addcmul_out copies the tensor to be added into the output.
  // The previous function here was addcmul_out(output, mean_tensor, output, std, 1);
  // The third argument is not a constant reference and hence the samples in output are overwritten.
  // Consequently, the computation performed is mean_tensor + mean_tensor * std instead of mean_tensor + output * std
  output.mul_(std).add_(mean_tensor);
  return output;
}

template<template<typename> class normal_kernel, typename RNG>
Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
  CHECK_NORMAL_TENSOR_STD(std);
  auto shape = at::infer_size(mean.sizes(), std.sizes());
  at::native::resize_output(output, shape);
  normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
  // CUDA NB: addcmul_out copies the tensor to be added into the output.
  // The previous function here was addcmul_out(output, mean, output, std, 1);
  // The third argument is not a constant reference and hence the samples in output are overwritten.
  // Consequently, the computation performed is mean + mean * std instead of mean + output * std
  output.mul_(std).add_(mean);
  return output;
}
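
// A short sketch of the transform used by the tensor-parameter overloads above (not new API):
// the kernel first fills `output` with z ~ N(0, 1), and `output.mul_(std).add_(mean)` then maps
// each sample to z * std + mean, which is distributed as N(mean, std^2) elementwise, with
// broadcasting over `mean` and `std`.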

template<template<typename> class normal_kernel, typename RNG>
Tensor normal_impl(const Tensor& mean, double std, std::optional<Generator> gen) {
  CHECK_NORMAL_STD(std);
  Tensor ret = at::empty_like(mean, MemoryFormat::Contiguous);
  normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
  return ret;
}

template<template<typename> class normal_kernel, typename RNG>
Tensor normal_impl(double mean, const Tensor& std, std::optional<Generator> gen) {
  CHECK_NORMAL_TENSOR_STD(std);
  Tensor ret = at::empty_like(std, MemoryFormat::Contiguous);
  normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
  return ret;
}

template<template<typename> class normal_kernel, typename RNG>
Tensor normal_impl(const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
  CHECK_NORMAL_TENSOR_STD(std);
  auto shape = at::infer_size(mean.sizes(), std.sizes());
  Tensor ret = at::empty(shape, mean.options(), MemoryFormat::Contiguous);
  normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
  return ret;
}

// ==================================================== Uniform =======================================================

template<template<typename> class uniform_kernel, typename RNG>
at::Tensor& uniform_impl_(at::Tensor& self, double from, double to, std::optional<Generator> generator) {
  if (self.is_complex()) {
    CHECK_EMPTY_AND_RETURN(self);
    auto float_tensor = at::view_as_real(self);
    uniform_impl_<uniform_kernel, RNG>(float_tensor, from, to, generator);
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "check_uniform_bounds", [&] {
      [[maybe_unused]] const auto dtype = self.dtype();
      const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
      const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
      CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
      CHECK_OUT_OF_BOUNDS(to, "to", min, max, dtype);
      TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to);
      TORCH_CHECK((to - from) <= std::numeric_limits<scalar_t>::max(),
            "uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()),
            ">::max(), but found to=", to, " and from=", from,
            " which results in to-from exceeding the limit");
      from = std::min(std::max(from, min), max);
      to = std::max(std::min(to, max), min);
    });
    CHECK_EMPTY_AND_RETURN(self);
    auto iter = at::TensorIterator::borrowing_nullary_op(self);
    uniform_kernel<RNG>()(iter, from, to, generator);
  }
  return self;
}

// ================================================== LogNormal =======================================================

template<template<typename> class log_normal_kernel, typename RNG>
at::Tensor& log_normal_impl_(at::Tensor& self, double mean, double std, std::optional<Generator> gen) {
  TORCH_CHECK(std > 0.0, "log_normal_ expects std > 0.0, but found std=", std);
  CHECK_EMPTY_AND_RETURN(self);
  auto iter = TensorIterator::borrowing_nullary_op(self);
  log_normal_kernel<RNG>()(iter, mean, std, gen);
  return self;
}

// =================================================== Geometric ======================================================

template<template<typename> class geometric_kernel, typename RNG>
Tensor& geometric_impl_(Tensor& self, double p, std::optional<Generator> gen) {
  TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p);
  CHECK_EMPTY_AND_RETURN(self);
  auto iter = TensorIterator::borrowing_nullary_op(self);
  geometric_kernel<RNG>()(iter, p, gen);
  return self;
}

// ================================================== Exponential =====================================================

template<template<typename> class exponential_kernel, typename RNG>
Tensor& exponential_impl_(Tensor& self, double lambda, std::optional<Generator> gen) {
  TORCH_CHECK(lambda > 0.0, "exponential_ expects lambda > 0.0, but found lambda=", lambda);
  CHECK_EMPTY_AND_RETURN(self);
  auto iter = TensorIterator::borrowing_nullary_op(self);
  exponential_kernel<RNG>()(iter, lambda, gen);
  return self;
}

// ==================================================== Cauchy ========================================================

template<template<typename> class cauchy_kernel, typename RNG>
Tensor& cauchy_impl_(Tensor& self, double median, double sigma, std::optional<Generator> gen) {
  // TODO: instead of the variable name 'sigma', use 'gamma' or 'scale';
  // the variance (sigma squared) is undefined for the Cauchy distribution
  TORCH_CHECK(sigma > 0.0, "cauchy_ expects sigma > 0.0, but found sigma=", sigma);
  TORCH_CHECK(at::isFloatingType(self.scalar_type()), "Cauchy distribution is a continuous probability distribution. dtype must be a floating point but you specified ", self.dtype());
  CHECK_EMPTY_AND_RETURN(self);
  auto iter = TensorIterator::borrowing_nullary_op(self);
  cauchy_kernel<RNG>()(iter, median, sigma, gen);
  return self;
}

// ==================================================== Bernoulli =====================================================

template<template<typename> class bernoulli_tensor_kernel, typename RNG>
Tensor& bernoulli_impl_(Tensor& self, const Tensor& p_, std::optional<Generator> gen) {
  CHECK_EMPTY_AND_RETURN(self);
  NoNamesGuard guard;
  at::assert_no_internal_overlap(self);
  bernoulli_tensor_kernel<RNG>()(self, p_, gen);
  return self;
}

template<template<typename> class bernoulli_scalar_kernel, typename RNG>
Tensor& bernoulli_impl_(Tensor& self, double p, std::optional<Generator> gen) {
  TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
  CHECK_EMPTY_AND_RETURN(self);
  at::assert_no_internal_overlap(self);
  bernoulli_scalar_kernel<RNG>()(self, p, gen);
  return self;
}

template<template<typename> class bernoulli_tensor_kernel, typename RNG>
Tensor& bernoulli_out_impl(Tensor& result, const Tensor& self, std::optional<Generator> gen) {
  // result.resize_as_(self) requires self to have same dtype as result, so we
  // use resize_ instead.
  // TODO: Fix resize_as_. See pytorch/pytorch#11665.
  result.resize_(self.sizes());
  bernoulli_impl_<bernoulli_tensor_kernel, RNG>(result, self, gen);
  namedinference::propagate_names(result, self);
  return result;
}

#undef CHECK_OUT_OF_BOUNDS
#undef WARN_OUT_OF_BOUNDS

} // namespace at::native::templates