xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/MultinomialKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 
4 #include <ATen/Dispatch.h>
5 #include <ATen/core/DistributionsHelper.h>
6 #include <ATen/native/Copy.h>
7 #include <ATen/native/TensorIterator.h>
8 #include <ATen/native/UnaryOps.h>
9 #include <ATen/native/cpu/Loops.h>
10 #include <c10/util/irange.h>
11 
12 #ifndef AT_PER_OPERATOR_HEADERS
13 #include <ATen/Functions.h>
14 #else
15 #include <ATen/ops/empty.h>
16 #endif
17 
18 namespace at::native {
19 namespace {
20 
21 template <typename scalar_t>
22 typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, void>
multinomial_with_replacement_apply(Tensor & result,const Tensor & self,const int64_t n_sample,std::optional<Generator> generator)23 multinomial_with_replacement_apply(
24     Tensor& result,
25     const Tensor& self,
26     const int64_t n_sample,
27     std::optional<Generator> generator) {
28   auto gen = get_generator_or_default<CPUGeneratorImpl>(
29       generator, detail::getDefaultCPUGenerator());
30   // See Note [Acquire lock when using random generators]
31   std::lock_guard<std::mutex> lock(gen->mutex_);
32 
33   int64_t n_categories = self.size(-1);
34   int64_t n_dist = self.dim() > 1 ? self.size(-2) : 1;
35 
36   /* cumulative probability distribution vector */
37   Tensor cum_dist = at::empty({n_categories}, self.options());
38 
39   const scalar_t* const self_ptr = self.const_data_ptr<scalar_t>();
40   scalar_t* const cum_dist_ptr = cum_dist.data_ptr<scalar_t>();
41   int64_t* const result_ptr = result.data_ptr<int64_t>();
42 
43   auto self_stride_0 = self.dim() > 1 ? self.stride(-2) : 0;
44   auto self_stride_1 = self.stride(-1);
45 
46   auto cum_dist_stride_0 = cum_dist.stride(0);
47 
48   auto result_dist_stride_0 = result.dim() > 1 ? result.stride(-2) : 0;
49   auto result_dist_stride_1 = result.stride(-1);
50 
51   for (const auto i : c10::irange(n_dist)) {
52     /* Get normalized cumulative distribution from prob distribution */
53     scalar_t sum = 0;
54     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
55     scalar_t val;
56     for (const auto j : c10::irange(n_categories)) {
57       val = self_ptr[i * self_stride_0 + j * self_stride_1];
58       TORCH_CHECK(
59           val >= 0,
60           "invalid multinomial distribution (encountering probability entry < 0)");
61 // NB: std::isfinite doesn't bode well with libc++ for half datatypes,
62 // so we manually cast it to a double and perform the check.
63 #if defined(_LIBCPP_VERSION)
64       TORCH_CHECK(
65           std::isfinite(static_cast<double>(val)),
66           "invalid multinomial distribution (encountering probability entry = infinity or NaN)");
67 #else
68       TORCH_CHECK(
69           std::isfinite(val),
70           "invalid multinomial distribution (encountering probability entry = infinity or NaN)");
71 #endif
72 
73       sum += val;
74       cum_dist_ptr[j * cum_dist_stride_0] = sum;
75     }
76 
77     TORCH_CHECK(
78         sum > 0,
79         "invalid multinomial distribution (sum of probabilities <= 0)");
80 
81     /* normalize cumulative probability distribution so that last val is 1
82     i.e. doesn't assume original self row sums to one */
83     if ((sum > 0) || ((sum < 1.00001) && (sum > 0.99999))) {
84       for (const auto j : c10::irange(n_categories)) {
85         cum_dist_ptr[j * cum_dist_stride_0] /= sum;
86       }
87     }
88 
89     for (const auto j : c10::irange(n_sample)) {
90       /* sample a probability mass from a uniform distribution */
91       at::uniform_real_distribution<double> uniform(0, 1);
92       double uniform_sample = uniform(gen);
93       /* Do a binary search for the slot in which the prob falls
94       ie cum_dist[row][slot-1] < uniform_prob < cum_distr[row][slot] */
95       int left_pointer = 0;
96       int right_pointer = n_categories;
97       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
98       int mid_pointer;
99       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
100       scalar_t cum_prob;
101       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
102       int sample_idx;
103       /* Make sure the last cumulative distribution bucket sums to 1 */
104       cum_dist_ptr[(n_categories - 1) * cum_dist_stride_0] = 1;
105 
106       while (right_pointer - left_pointer > 0) {
107         mid_pointer = left_pointer + (right_pointer - left_pointer) / 2;
108         cum_prob = cum_dist_ptr[mid_pointer * cum_dist_stride_0];
109         if (cum_prob < uniform_sample) {
110           left_pointer = mid_pointer + 1;
111         } else {
112           right_pointer = mid_pointer;
113         }
114       }
115       sample_idx = left_pointer;
116 
117       /* store in result tensor (will be incremented for lua compat by wrapper)
118        */
119       result_ptr[i * result_dist_stride_0 + j * result_dist_stride_1] =
120           sample_idx;
121     }
122   }
123 }
124 
125 template <typename scalar_t>
126 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, void>
multinomial_with_replacement_apply(Tensor & result,const Tensor & self,const int64_t n_sample,std::optional<Generator> generator)127 multinomial_with_replacement_apply(
128     Tensor& result,
129     const Tensor& self,
130     const int64_t n_sample,
131     std::optional<Generator> generator) {
132   auto gen = get_generator_or_default<CPUGeneratorImpl>(
133       generator, detail::getDefaultCPUGenerator());
134   // See Note [Acquire lock when using random generators]
135   std::lock_guard<std::mutex> lock(gen->mutex_);
136 
137   int64_t n_categories = self.size(-1);
138   int64_t n_dist = self.dim() > 1 ? self.size(-2) : 1;
139 
140   /* cumulative probability distribution vector */
141   Tensor cum_dist = at::empty({n_categories}, self.options().dtype(kFloat));
142 
143   const scalar_t* const self_ptr = self.const_data_ptr<scalar_t>();
144   float* const cum_dist_ptr = cum_dist.data_ptr<float>();
145   int64_t* const result_ptr = result.data_ptr<int64_t>();
146 
147   auto self_stride_0 = self.dim() > 1 ? self.stride(-2) : 0;
148   auto self_stride_1 = self.stride(-1);
149 
150   auto cum_dist_stride_0 = cum_dist.stride(0);
151 
152   auto result_dist_stride_0 = result.dim() > 1 ? result.stride(-2) : 0;
153   auto result_dist_stride_1 = result.stride(-1);
154 
155   for (const auto i : c10::irange(n_dist)) {
156     /* Get normalized cumulative distribution from prob distribution */
157     float sum = 0;
158     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
159     float val;
160     for (const auto j : c10::irange(n_categories)) {
161       val = self_ptr[i * self_stride_0 + j * self_stride_1];
162       TORCH_CHECK(
163           val >= 0,
164           "invalid multinomial distribution (encountering probability entry < 0)");
165 // NB: std::isfinite doesn't bode well with libc++ for half datatypes,
166 // so we manually cast it to a double and perform the check.
167 #if defined(_LIBCPP_VERSION)
168       TORCH_CHECK(
169           std::isfinite(static_cast<double>(val)),
170           "invalid multinomial distribution (encountering probability entry = infinity or NaN)");
171 #else
172       TORCH_CHECK(
173           std::isfinite(val),
174           "invalid multinomial distribution (encountering probability entry = infinity or NaN)");
175 #endif
176 
177       sum += val;
178       cum_dist_ptr[j * cum_dist_stride_0] = sum;
179     }
180 
181     TORCH_CHECK(
182         sum > 0,
183         "invalid multinomial distribution (sum of probabilities <= 0)");
184 
185     /* normalize cumulative probability distribution so that last val is 1
186     i.e. doesn't assume original self row sums to one */
187     if ((sum > 0) || ((sum < 1.00001) && (sum > 0.99999))) {
188       for (const auto j : c10::irange(n_categories)) {
189         cum_dist_ptr[j * cum_dist_stride_0] /= sum;
190       }
191     }
192 
193     for (const auto j : c10::irange(n_sample)) {
194       /* sample a probability mass from a uniform distribution */
195       at::uniform_real_distribution<double> uniform(0, 1);
196       double uniform_sample = uniform(gen);
197       /* Do a binary search for the slot in which the prob falls
198       ie cum_dist[row][slot-1] < uniform_prob < cum_distr[row][slot] */
199       int left_pointer = 0;
200       int right_pointer = n_categories;
201       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
202       int mid_pointer;
203       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
204       float cum_prob;
205       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
206       int sample_idx;
207       /* Make sure the last cumulative distribution bucket sums to 1 */
208       cum_dist_ptr[(n_categories - 1) * cum_dist_stride_0] = 1;
209 
210       while (right_pointer - left_pointer > 0) {
211         mid_pointer = left_pointer + (right_pointer - left_pointer) / 2;
212         cum_prob = cum_dist_ptr[mid_pointer * cum_dist_stride_0];
213         if (cum_prob < uniform_sample) {
214           left_pointer = mid_pointer + 1;
215         } else {
216           right_pointer = mid_pointer;
217         }
218       }
219       sample_idx = left_pointer;
220 
221       /* store in result tensor (will be incremented for lua compat by wrapper)
222        */
223       result_ptr[i * result_dist_stride_0 + j * result_dist_stride_1] =
224           sample_idx;
225     }
226   }
227 }
228 
multinomial_with_replacement_kernel_impl(Tensor & result,const Tensor & self,const int64_t n_sample,std::optional<Generator> gen)229 static void multinomial_with_replacement_kernel_impl(
230     Tensor& result,
231     const Tensor& self,
232     const int64_t n_sample,
233     std::optional<Generator> gen) {
234   AT_DISPATCH_FLOATING_TYPES_AND2(
235       kHalf, kBFloat16, self.scalar_type(), "multinomial", [&] {
236         multinomial_with_replacement_apply<scalar_t>(
237             result, self, n_sample, gen);
238       });
239 }
240 } // namespace
241 
242 REGISTER_DISPATCH(
243     multinomial_with_replacement_stub,
244     &multinomial_with_replacement_kernel_impl);
245 } // namespace at::native
246