#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/core/DistributionsHelper.h>
#include <ATen/native/Copy.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

namespace at::native {
namespace {

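// Full-precision path (float/double): for each distribution row, build the
// row's cumulative distribution in `cum_dist`, normalize it, and draw each of
// the `n_sample` samples by inverting a uniform(0, 1) draw with a binary
// search over the cumulative distribution (inverse transform sampling).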
template <typename scalar_t>
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, void>
multinomial_with_replacement_apply(
    Tensor& result,
    const Tensor& self,
    const int64_t n_sample,
    std::optional<Generator> generator) {
  auto gen = get_generator_or_default<CPUGeneratorImpl>(
      generator, detail::getDefaultCPUGenerator());
  // See Note [Acquire lock when using random generators]
  std::lock_guard<std::mutex> lock(gen->mutex_);

  int64_t n_categories = self.size(-1);
  int64_t n_dist = self.dim() > 1 ? self.size(-2) : 1;

  /* cumulative probability distribution vector */
  Tensor cum_dist = at::empty({n_categories}, self.options());

  const scalar_t* const self_ptr = self.const_data_ptr<scalar_t>();
  scalar_t* const cum_dist_ptr = cum_dist.data_ptr<scalar_t>();
  int64_t* const result_ptr = result.data_ptr<int64_t>();

  auto self_stride_0 = self.dim() > 1 ? self.stride(-2) : 0;
  auto self_stride_1 = self.stride(-1);

  auto cum_dist_stride_0 = cum_dist.stride(0);

  auto result_dist_stride_0 = result.dim() > 1 ? result.stride(-2) : 0;
  auto result_dist_stride_1 = result.stride(-1);

  for (const auto i : c10::irange(n_dist)) {
    /* Get normalized cumulative distribution from prob distribution */
    scalar_t sum = 0;
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    scalar_t val;
    for (const auto j : c10::irange(n_categories)) {
      val = self_ptr[i * self_stride_0 + j * self_stride_1];
      TORCH_CHECK(
          val >= 0,
          "invalid multinomial distribution (encountering probability entry < 0)");
      // NB: std::isfinite does not handle half datatypes well under libc++,
      // so we manually cast to double before performing the check.
#if defined(_LIBCPP_VERSION)
      TORCH_CHECK(
          std::isfinite(static_cast<double>(val)),
          "invalid multinomial distribution (encountering probability entry = infinity or NaN)");
#else
      TORCH_CHECK(
          std::isfinite(val),
          "invalid multinomial distribution (encountering probability entry = infinity or NaN)");
#endif

      sum += val;
      cum_dist_ptr[j * cum_dist_stride_0] = sum;
    }

    TORCH_CHECK(
        sum > 0,
        "invalid multinomial distribution (sum of probabilities <= 0)");

    /* normalize the cumulative probability distribution so that the last value
       is 1, i.e. do not assume the original self row sums to one */
    if ((sum > 0) || ((sum < 1.00001) && (sum > 0.99999))) {
      for (const auto j : c10::irange(n_categories)) {
        cum_dist_ptr[j * cum_dist_stride_0] /= sum;
      }
    }

    for (const auto j : c10::irange(n_sample)) {
      /* sample a probability mass from a uniform distribution */
      at::uniform_real_distribution<double> uniform(0, 1);
      double uniform_sample = uniform(gen);
      /* Do a binary search for the slot in which the prob falls,
         i.e. cum_dist[row][slot - 1] < uniform_prob < cum_dist[row][slot] */
      int left_pointer = 0;
      int right_pointer = n_categories;
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      int mid_pointer;
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
      scalar_t cum_prob;
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      int sample_idx;
      /* Make sure the last cumulative distribution bucket sums to 1 */
      cum_dist_ptr[(n_categories - 1) * cum_dist_stride_0] = 1;

      while (right_pointer - left_pointer > 0) {
        mid_pointer = left_pointer + (right_pointer - left_pointer) / 2;
        cum_prob = cum_dist_ptr[mid_pointer * cum_dist_stride_0];
        if (cum_prob < uniform_sample) {
          left_pointer = mid_pointer + 1;
        } else {
          right_pointer = mid_pointer;
        }
      }
      sample_idx = left_pointer;

      /* store in result tensor (will be incremented for lua compat by wrapper) */
      result_ptr[i * result_dist_stride_0 + j * result_dist_stride_1] =
          sample_idx;
    }
  }
}

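// Reduced-precision path (Half/BFloat16): same algorithm as above, but the
// cumulative sums are accumulated and stored in float so that rounding error
// in the reduced-precision type does not distort the cumulative distribution.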
template <typename scalar_t>
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, void>
multinomial_with_replacement_apply(
    Tensor& result,
    const Tensor& self,
    const int64_t n_sample,
    std::optional<Generator> generator) {
  auto gen = get_generator_or_default<CPUGeneratorImpl>(
      generator, detail::getDefaultCPUGenerator());
  // See Note [Acquire lock when using random generators]
  std::lock_guard<std::mutex> lock(gen->mutex_);

  int64_t n_categories = self.size(-1);
  int64_t n_dist = self.dim() > 1 ? self.size(-2) : 1;

  /* cumulative probability distribution vector */
  Tensor cum_dist = at::empty({n_categories}, self.options().dtype(kFloat));

  const scalar_t* const self_ptr = self.const_data_ptr<scalar_t>();
  float* const cum_dist_ptr = cum_dist.data_ptr<float>();
  int64_t* const result_ptr = result.data_ptr<int64_t>();

  auto self_stride_0 = self.dim() > 1 ? self.stride(-2) : 0;
  auto self_stride_1 = self.stride(-1);

  auto cum_dist_stride_0 = cum_dist.stride(0);

  auto result_dist_stride_0 = result.dim() > 1 ? result.stride(-2) : 0;
  auto result_dist_stride_1 = result.stride(-1);

  for (const auto i : c10::irange(n_dist)) {
    /* Get normalized cumulative distribution from prob distribution */
    float sum = 0;
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    float val;
    for (const auto j : c10::irange(n_categories)) {
      val = self_ptr[i * self_stride_0 + j * self_stride_1];
      TORCH_CHECK(
          val >= 0,
          "invalid multinomial distribution (encountering probability entry < 0)");
      // NB: std::isfinite does not handle half datatypes well under libc++,
      // so we manually cast to double before performing the check.
#if defined(_LIBCPP_VERSION)
      TORCH_CHECK(
          std::isfinite(static_cast<double>(val)),
          "invalid multinomial distribution (encountering probability entry = infinity or NaN)");
#else
      TORCH_CHECK(
          std::isfinite(val),
          "invalid multinomial distribution (encountering probability entry = infinity or NaN)");
#endif

      sum += val;
      cum_dist_ptr[j * cum_dist_stride_0] = sum;
    }

    TORCH_CHECK(
        sum > 0,
        "invalid multinomial distribution (sum of probabilities <= 0)");

    /* normalize the cumulative probability distribution so that the last value
       is 1, i.e. do not assume the original self row sums to one */
    if ((sum > 0) || ((sum < 1.00001) && (sum > 0.99999))) {
      for (const auto j : c10::irange(n_categories)) {
        cum_dist_ptr[j * cum_dist_stride_0] /= sum;
      }
    }

    for (const auto j : c10::irange(n_sample)) {
      /* sample a probability mass from a uniform distribution */
      at::uniform_real_distribution<double> uniform(0, 1);
      double uniform_sample = uniform(gen);
      /* Do a binary search for the slot in which the prob falls,
         i.e. cum_dist[row][slot - 1] < uniform_prob < cum_dist[row][slot] */
      int left_pointer = 0;
      int right_pointer = n_categories;
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      int mid_pointer;
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
      float cum_prob;
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      int sample_idx;
      /* Make sure the last cumulative distribution bucket sums to 1 */
      cum_dist_ptr[(n_categories - 1) * cum_dist_stride_0] = 1;

      while (right_pointer - left_pointer > 0) {
        mid_pointer = left_pointer + (right_pointer - left_pointer) / 2;
        cum_prob = cum_dist_ptr[mid_pointer * cum_dist_stride_0];
        if (cum_prob < uniform_sample) {
          left_pointer = mid_pointer + 1;
        } else {
          right_pointer = mid_pointer;
        }
      }
      sample_idx = left_pointer;

      /* store in result tensor (will be incremented for lua compat by wrapper) */
      result_ptr[i * result_dist_stride_0 + j * result_dist_stride_1] =
          sample_idx;
    }
  }
}

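// Dispatches on the probability dtype; the enable_if overloads above select
// the float-accumulating variant for Half/BFloat16 and the native-precision
// variant for float/double.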
static void multinomial_with_replacement_kernel_impl(
    Tensor& result,
    const Tensor& self,
    const int64_t n_sample,
    std::optional<Generator> gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, self.scalar_type(), "multinomial", [&] {
        multinomial_with_replacement_apply<scalar_t>(
            result, self, n_sample, gen);
      });
}
} // namespace

REGISTER_DISPATCH(
    multinomial_with_replacement_stub,
    &multinomial_with_replacement_kernel_impl);
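
// Usage sketch (illustrative only, not part of this translation unit's API):
// on CPU, sampling with replacement through the public multinomial op is
// expected to reach the kernel registered above, roughly as:
//
//   at::Tensor probs = at::rand({5});   // non-negative weights per category
//   at::Tensor idx = at::multinomial(
//       probs, /*num_samples=*/8,
//       /*replacement=*/true);          // int64 tensor of category indices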
} // namespace at::native