/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/kernels/random_poisson_op.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

Status GenerateKey(Tensor seed, random::PhiloxRandom::Key* out_key,
                   random::PhiloxRandom::ResultType* out_counter) {
  // Grab the two seeds
  uint64 seed0;
  uint64 seed1;
  if (seed.dtype() == DT_INT32) {
    const auto seed_vals = seed.flat<int32>();
    seed0 = internal::SubtleMustCopy(seed_vals(0));
    seed1 = internal::SubtleMustCopy(seed_vals(1));
  } else if (seed.dtype() == DT_INT64) {
    const auto seed_vals = seed.flat<int64_t>();
    seed0 = internal::SubtleMustCopy(seed_vals(0));
    seed1 = internal::SubtleMustCopy(seed_vals(1));
  } else {
    return errors::InvalidArgument("Invalid seed type: ",
                                   DataTypeString(seed.dtype()));
  }

  // Scramble the seeds so that the user doesn't need to worry about which
  // part of the seed needs to be strong.
  (*out_key)[0] = 0x3ec8f720;
  (*out_key)[1] = 0x02461e29;
  (*out_counter)[0] = static_cast<uint32>(seed0);
  (*out_counter)[1] = static_cast<uint32>(seed0 >> 32);
  (*out_counter)[2] = static_cast<uint32>(seed1);
  (*out_counter)[3] = static_cast<uint32>(seed1 >> 32);
  const auto mix = random::PhiloxRandom(*out_counter, *out_key)();
  (*out_key)[0] = mix[0];
  (*out_key)[1] = mix[1];
  (*out_counter)[0] = (*out_counter)[1] = 0;
  (*out_counter)[2] = mix[2];
  (*out_counter)[3] = mix[3];
  return OkStatus();
}
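
// Illustrative sketch (not part of the original file): deriving a generator
// from a shape-[2] seed tensor, mirroring how Compute() below consumes the
// result. `seed_tensor` is a placeholder name; error handling is abbreviated.
//
//   random::PhiloxRandom::Key key;
//   random::PhiloxRandom::ResultType counter;
//   Status s = GenerateKey(seed_tensor, &key, &counter);
//   if (s.ok()) {
//     random::PhiloxRandom gen(counter, key);
//     auto group = gen();  // each call produces a group of four uint32 values
//   }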

namespace {

class StatelessRandomOpBase : public OpKernel {
 public:
  explicit StatelessRandomOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Sanitize input
    const Tensor& shape_t = context->input(0);
    const Tensor& seed_t = context->input(1);
    TensorShape shape;
    OP_REQUIRES_OK(context, tensor::MakeShape(shape_t, &shape));
    OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_t.shape().DebugString()));

    // Allocate output
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
    if (shape.num_elements() == 0) return;

    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter));

    // Fill in the random numbers
    Fill(context, random::PhiloxRandom(counter, key), output);
  }

  // The part of Compute that depends on device, type, and distribution
  virtual void Fill(OpKernelContext* context, random::PhiloxRandom random,
                    Tensor* output) = 0;
};

template <typename Device, class Distribution>
class StatelessRandomOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* context, random::PhiloxRandom random,
            Tensor* output) override {
    typedef typename Distribution::ResultElementType T;
    auto flat = output->flat<T>();
    // Reuse the compute kernels from the stateful random ops
    functor::FillPhiloxRandom<Device, Distribution>()(
        context, context->eigen_device<Device>(), /*key=*/nullptr,
        /*counter=*/nullptr, random, flat.data(), flat.size(), Distribution());
  }
};
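
// Illustrative note (not part of the original file): the REGISTER macros at
// the bottom of this file instantiate this template; for example, the float
// CPU kernel for StatelessRandomNormal is
//   StatelessRandomOp<CPUDevice,
//                     random::NormalDistribution<random::PhiloxRandom, float>>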

template <typename Device, typename IntType>
class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* context, random::PhiloxRandom random,
            Tensor* output) override {
    const Tensor& minval = context->input(2);
    const Tensor& maxval = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval.shape().DebugString()));

    // Verify that minval < maxval.  Note that we'll never reach this point for
    // empty output.  Zero impossible things are fine.
    const auto lo = minval.scalar<IntType>()();
    const auto hi = maxval.scalar<IntType>()();
    OP_REQUIRES(
        context, lo < hi,
        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));

    // Build distribution
    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist(lo, hi);

    auto flat = output->flat<IntType>();
    // Reuse the compute kernels from the stateful random ops
    functor::FillPhiloxRandom<Device, Distribution>()(
        context, context->eigen_device<Device>(), /*key=*/nullptr,
        /*counter=*/nullptr, random, flat.data(), flat.size(), dist);
  }
};

template <typename Device, typename IntType>
class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* context, random::PhiloxRandom random,
            Tensor* output) override {
    // Build distribution
    typedef random::UniformFullIntDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist;

    auto flat = output->flat<IntType>();
    // Reuse the compute kernels from the stateful random ops
    functor::FillPhiloxRandom<Device, Distribution>()(
        context, context->eigen_device<Device>(), /*key=*/nullptr,
        /*counter=*/nullptr, random, flat.data(), flat.size(), dist);
  }
};

// Samples from one or more Poisson distributions.
template <typename T, typename U>
class StatelessRandomPoissonOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* ctx, random::PhiloxRandom random,
            Tensor* output) override {
    const Tensor& rate_t = ctx->input(2);

    TensorShape samples_shape = output->shape();
    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(samples_shape, rate_t.shape()),
                errors::InvalidArgument(
                    "Shape passed in must end with broadcasted shape."));

    const int64_t num_rate = rate_t.NumElements();
    const int64_t samples_per_rate = samples_shape.num_elements() / num_rate;
    const auto rate_flat = rate_t.flat<T>().data();
    auto samples_flat = output->flat<U>().data();

    functor::PoissonFunctor<CPUDevice, T, U>()(
        ctx, ctx->eigen_device<CPUDevice>(), rate_flat, num_rate,
        samples_per_rate, random, samples_flat);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomPoissonOp);
};
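
// Illustrative note (not part of the original file): for an output of shape
// [10, 3] and a rate tensor of shape [3], the EndsWith check passes, num_rate
// is 3, and samples_per_rate is 30 / 3 = 10, i.e. ten draws per rate value.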

#define REGISTER(DEVICE, TYPE)                                              \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessRandomUniform")                                        \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("seed")                                               \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<DEVICE##Device, random::UniformDistribution<        \
                                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessRandomNormal")                                         \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("seed")                                               \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<DEVICE##Device, random::NormalDistribution<         \
                                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessTruncatedNormal")                                      \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("seed")                                               \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<                                                    \
          DEVICE##Device,                                                   \
          random::TruncatedNormalDistribution<                              \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)
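
// Illustrative expansion (not part of the original file): REGISTER(CPU, float)
// registers, among others, the float CPU kernel for StatelessRandomUniform:
//
//   REGISTER_KERNEL_BUILDER(
//       Name("StatelessRandomUniform")
//           .Device(DEVICE_CPU)
//           .HostMemory("shape")
//           .HostMemory("seed")
//           .TypeConstraint<float>("dtype"),
//       StatelessRandomOp<CPUDevice, random::UniformDistribution<
//                                        random::PhiloxRandom, float> >);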

#define REGISTER_FULL_INT(DEVICE, TYPE)     \
  REGISTER_KERNEL_BUILDER(                  \
      Name("StatelessRandomUniformFullInt") \
          .Device(DEVICE_##DEVICE)          \
          .HostMemory("shape")              \
          .HostMemory("seed")               \
          .TypeConstraint<TYPE>("dtype"),   \
      StatelessRandomUniformFullIntOp<DEVICE##Device, TYPE>)

#define REGISTER_INT(DEVICE, TYPE)                            \
  REGISTER_FULL_INT(DEVICE, TYPE);                            \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt")   \
                              .Device(DEVICE_##DEVICE)        \
                              .HostMemory("shape")            \
                              .HostMemory("seed")             \
                              .HostMemory("minval")           \
                              .HostMemory("maxval")           \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatelessRandomUniformIntOp<DEVICE##Device, TYPE>)

#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
#define REGISTER_FULL_INT_CPU(TYPE) REGISTER_FULL_INT(CPU, TYPE)
#define REGISTER_FULL_INT_GPU(TYPE) REGISTER_FULL_INT(GPU, TYPE)

TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_INT_CPU);
TF_CALL_int64(REGISTER_INT_CPU);
TF_CALL_uint32(REGISTER_FULL_INT_CPU);
TF_CALL_uint64(REGISTER_FULL_INT_CPU);

#define REGISTER_POISSON(RATE_TYPE, OUT_TYPE)                     \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomPoisson")          \
                              .Device(DEVICE_CPU)                 \
                              .HostMemory("shape")                \
                              .HostMemory("seed")                 \
                              .HostMemory("lam")                  \
                              .TypeConstraint<RATE_TYPE>("Rtype") \
                              .TypeConstraint<OUT_TYPE>("dtype"), \
                          StatelessRandomPoissonOp<RATE_TYPE, OUT_TYPE>)

#define REGISTER_ALL_POISSON(RATE_TYPE)     \
  REGISTER_POISSON(RATE_TYPE, Eigen::half); \
  REGISTER_POISSON(RATE_TYPE, float);       \
  REGISTER_POISSON(RATE_TYPE, double);      \
  REGISTER_POISSON(RATE_TYPE, int32);       \
  REGISTER_POISSON(RATE_TYPE, int64_t)

TF_CALL_half(REGISTER_ALL_POISSON);
TF_CALL_float(REGISTER_ALL_POISSON);
TF_CALL_double(REGISTER_ALL_POISSON);
TF_CALL_int32(REGISTER_ALL_POISSON);
TF_CALL_int64(REGISTER_ALL_POISSON);

#undef REGISTER_ALL_POISSON
#undef REGISTER_POISSON

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_int32(REGISTER_INT_GPU);
TF_CALL_int64(REGISTER_INT_GPU);
TF_CALL_uint32(REGISTER_FULL_INT_GPU);
TF_CALL_uint64(REGISTER_FULL_INT_GPU);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER
#undef REGISTER_INT
#undef REGISTER_CPU
#undef REGISTER_GPU
#undef REGISTER_INT_CPU
#undef REGISTER_INT_GPU
#undef REGISTER_FULL_INT_CPU
#undef REGISTER_FULL_INT_GPU

}  // namespace

}  // namespace tensorflow