xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/lib/prng.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/prng.h"
17 
18 #include <cmath>
19 #include <vector>
20 
21 #include "tensorflow/compiler/xla/client/lib/constants.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/util.h"
24 
25 namespace xla {
26 
ConcatScalars(xla::XlaBuilder * builder,absl::Span<const xla::XlaOp> scalars)27 xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
28                          absl::Span<const xla::XlaOp> scalars) {
29   std::vector<xla::XlaOp> vectors;
30   absl::c_transform(scalars, std::back_inserter(vectors),
31                     [](xla::XlaOp x) { return xla::Reshape(x, {1}); });
32   return ConcatInDim(builder, vectors, 0);
33 }
34 
35 namespace {
36 
37 // Rotates a 32-bit integer 'v' left by 'distance' bits.
RotateLeftU32(XlaOp v,int distance)38 XlaOp RotateLeftU32(XlaOp v, int distance) {
39   return (v << ConstantR0<uint32_t>(v.builder(), distance)) |
40          ShiftRightLogical(v, ConstantR0<uint32_t>(v.builder(), 32 - distance));
41 }
42 
43 // The internal state of the Three Fry implementation.
44 using ThreeFry2x32State = std::array<XlaOp, 2>;
45 
46 // Implements the ThreeFry counter-based PRNG algorithm.
47 // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
48 // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
ThreeFry2x32(ThreeFry2x32State input,ThreeFry2x32State key)49 ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
50   XlaBuilder* builder = input[0].builder();
51   key[0] = BitcastConvertType(key[0], U32);
52   key[1] = BitcastConvertType(key[1], U32);
53 
54   // Rotation distances specified by the Threefry2x32 algorithm.
55   constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24};
56   ThreeFry2x32State x;
57 
58   std::array<XlaOp, 3> ks;
59   // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
60   ks[2] = ConstantR0<uint32_t>(builder, 0x1BD11BDA);
61   for (int i = 0; i < 2; ++i) {
62     ks[i] = key[i];
63     x[i] = input[i];
64     ks[2] = ks[2] ^ key[i];
65   }
66 
67   x[0] = x[0] + ks[0];
68   x[1] = x[1] + ks[1];
69 
70   // Performs a single round of the Threefry2x32 algorithm, with a rotation
71   // amount 'rotation'.
72   auto round = [](ThreeFry2x32State v, int rotation) {
73     v[0] = v[0] + v[1];
74     v[1] = RotateLeftU32(v[1], rotation);
75     v[1] = v[0] ^ v[1];
76     return v;
77   };
78 
79   // There are no known statistical flaws with 13 rounds of Threefry2x32.
80   // We are conservative and use 20 rounds.
81   x = round(x, rotations[0]);
82   x = round(x, rotations[1]);
83   x = round(x, rotations[2]);
84   x = round(x, rotations[3]);
85   x[0] = x[0] + ks[1];
86   x[1] = x[1] + ks[2] + ConstantR0<uint32_t>(builder, 1);
87 
88   x = round(x, rotations[4]);
89   x = round(x, rotations[5]);
90   x = round(x, rotations[6]);
91   x = round(x, rotations[7]);
92   x[0] = x[0] + ks[2];
93   x[1] = x[1] + ks[0] + ConstantR0<uint32_t>(builder, 2);
94 
95   x = round(x, rotations[0]);
96   x = round(x, rotations[1]);
97   x = round(x, rotations[2]);
98   x = round(x, rotations[3]);
99   x[0] = x[0] + ks[0];
100   x[1] = x[1] + ks[1] + ConstantR0<uint32_t>(builder, 3);
101 
102   x = round(x, rotations[4]);
103   x = round(x, rotations[5]);
104   x = round(x, rotations[6]);
105   x = round(x, rotations[7]);
106   x[0] = x[0] + ks[1];
107   x[1] = x[1] + ks[2] + ConstantR0<uint32_t>(builder, 4);
108 
109   x = round(x, rotations[0]);
110   x = round(x, rotations[1]);
111   x = round(x, rotations[2]);
112   x = round(x, rotations[3]);
113   x[0] = x[0] + ks[2];
114   x[1] = x[1] + ks[0] + ConstantR0<uint32_t>(builder, 5);
115 
116   return x;
117 }
118 
119 // Converts a uint64_t to two uint32s.
Uint64ToUint32s(XlaOp u64)120 std::array<XlaOp, 2> Uint64ToUint32s(XlaOp u64) {
121   XlaBuilder* builder = u64.builder();
122   XlaOp const32 = ConstantR0WithType(builder, U64, 32);
123   XlaOp fst = ConvertElementType(u64, U32);
124   XlaOp snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
125   return {fst, snd};
126 }
127 
128 // Converts two uint32s to a uint64_t.
Uint32sToUint64(std::array<XlaOp,2> u32s)129 XlaOp Uint32sToUint64(std::array<XlaOp, 2> u32s) {
130   XlaBuilder* builder = u32s[0].builder();
131   return ConvertElementType(u32s[0], U64) |
132          ShiftLeft(ConvertElementType(u32s[1], U64),
133                    ConstantR0WithType(builder, U64, 32));
134 }
135 
136 // Given the initial state and the request shape of random numbers to be
137 // generated, returns the input for the random number generator and a new state.
GetThreeFryInputsAndUpdatedState(XlaOp initial_state,const Shape & shape)138 std::pair<ThreeFry2x32State, XlaOp> GetThreeFryInputsAndUpdatedState(
139     XlaOp initial_state, const Shape& shape) {
140   XlaBuilder* builder = initial_state.builder();
141   auto u64_shape = ShapeUtil::MakeShape(U64, shape.dimensions());
142   // initial_state is an R1, so reshape it to a scalar.
143   auto input_u64 = Broadcast(Reshape(initial_state, {}), shape.dimensions());
144   int64_t trailing_dims_product = 1;
145   for (int64_t i = shape.rank() - 1; i >= 0; --i) {
146     if (shape.dimensions(i) < 2) {
147       continue;
148     }
149     input_u64 =
150         input_u64 + (Iota(builder, u64_shape, i) *
151                      ConstantR0<uint64_t>(builder, trailing_dims_product));
152     trailing_dims_product *= shape.dimensions(i);
153   }
154   XlaOp new_state = initial_state +
155                     ConstantR0<uint64_t>(builder, ShapeUtil::ElementsIn(shape));
156   return std::make_pair(Uint64ToUint32s(input_u64), new_state);
157 }
158 
159 // Result for SplitShapeIntoHalves().
160 struct SplitShapePair {
161   Shape half_shape;
162   Shape concat_shape;
163   int64_t split_dim;
164   int64_t new_concat_dim;
165 };
166 
167 // Split the shape on a dimension > 1 into two halves.
SplitShapeIntoHalves(const Shape & shape)168 SplitShapePair SplitShapeIntoHalves(const Shape& shape) {
169   SplitShapePair pair;
170   if (shape.rank() == 0) {
171     pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), {1});
172     pair.concat_shape = ShapeUtil::MakeShape(shape.element_type(), {2});
173     pair.split_dim = 0;
174     pair.new_concat_dim = 0;
175     return pair;
176   }
177   pair.split_dim = -1;
178   for (int64_t i = 0; i < shape.rank(); ++i) {
179     if (shape.dimensions(i) % 2 == 0) {
180       pair.split_dim = i;
181       break;
182     }
183   }
184   if (pair.split_dim == -1) {
185     // No even dims. Find a dimension with maximum size.
186     for (int64_t i = 0; i < shape.rank(); ++i) {
187       if (pair.split_dim == -1 ||
188           shape.dimensions(i) > shape.dimensions(pair.split_dim)) {
189         pair.split_dim = i;
190       }
191     }
192   }
193   CHECK_GE(pair.split_dim, 0);
194   std::vector<int64_t> half_shape_dims;
195   std::vector<int64_t> concat_shape_dims;
196   const auto rank = shape.rank();
197   half_shape_dims.reserve(rank + 1);
198   concat_shape_dims.reserve(rank + 1);
199   for (int64_t i = 0; i < rank; ++i) {
200     if (i == pair.split_dim) {
201       // Create a new trivial dim for the later concat, which is more friendly
202       // to sharding propagation.
203       half_shape_dims.push_back(CeilOfRatio<int64_t>(shape.dimensions(i), 2));
204       half_shape_dims.push_back(1);
205       concat_shape_dims.push_back(half_shape_dims[i]);
206       concat_shape_dims.push_back(2);
207     } else {
208       half_shape_dims.push_back(shape.dimensions(i));
209       concat_shape_dims.push_back(shape.dimensions(i));
210     }
211   }
212   pair.new_concat_dim = pair.split_dim + 1;
213   pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), half_shape_dims);
214   pair.concat_shape =
215       ShapeUtil::MakeShape(shape.element_type(), concat_shape_dims);
216   return pair;
217 }
218 
219 // Combines a pair of split shapes. It works with scalar and non-scalar shapes.
CombineShapePair(absl::Span<const XlaOp> pair,const SplitShapePair & shape_pair,const Shape & original_shape)220 XlaOp CombineShapePair(absl::Span<const XlaOp> pair,
221                        const SplitShapePair& shape_pair,
222                        const Shape& original_shape) {
223   if (original_shape.rank() == 0) {
224     return Reshape(pair[0], {});
225   }
226   XlaBuilder* builder = pair[0].builder();
227   XlaOp result = ConcatInDim(builder, pair, shape_pair.new_concat_dim);
228   const int64_t pre_split_size =
229       original_shape.dimensions(shape_pair.split_dim);
230   std::vector<int64_t> reshape_dims(original_shape.dimensions().begin(),
231                                     original_shape.dimensions().end());
232   reshape_dims[shape_pair.split_dim] = RoundUpTo<int64_t>(pre_split_size, 2);
233   result = Reshape(result, reshape_dims);
234   if (reshape_dims[shape_pair.split_dim] != pre_split_size) {
235     result = Slice(result, std::vector<int64_t>(original_shape.rank(), 0),
236                    original_shape.dimensions(),
237                    std::vector<int64_t>(original_shape.rank(), 1));
238   }
239   return result;
240 }
241 
242 // Generates random 32bits with the given shape using the Three Fry
243 // implementation. Returns the random bits and the new state.
ThreeFryRngBit32(XlaOp key,XlaOp initial_state,const Shape & shape)244 RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
245   auto shape_pair = SplitShapeIntoHalves(shape);
246   std::pair<ThreeFry2x32State, XlaOp> inputs_state =
247       GetThreeFryInputsAndUpdatedState(initial_state, shape_pair.half_shape);
248   ThreeFry2x32State inputs = inputs_state.first;
249   ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
250   XlaOp result = CombineShapePair(outputs, shape_pair, shape);
251   return {result, inputs_state.second};
252 }
253 
254 // Generates random 64bits with the given shape using the Three Fry
255 // implementation. Returns the random bits and the new state.
ThreeFryRngBit64(XlaOp key,XlaOp initial_state,const Shape & shape)256 RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) {
257   std::pair<ThreeFry2x32State, XlaOp> inputs_state =
258       GetThreeFryInputsAndUpdatedState(initial_state, shape);
259   ThreeFry2x32State inputs = inputs_state.first;
260   ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
261   XlaOp result = Uint32sToUint64(outputs);
262   return {result, inputs_state.second};
263 }
264 
265 // The key of the Philox random number generator.
266 using Philox4x32Key = std::array<XlaOp, 2>;
267 // The internal state of the Philox random number generator.
268 using Philox4x32State = std::array<XlaOp, 4>;
269 
270 // Computes the Philox4x32 algorithm using 10 rounds.
Philox4x32(Philox4x32State state,Philox4x32Key key)271 Philox4x32State Philox4x32(Philox4x32State state, Philox4x32Key key) {
272   // Constants specified by the Philox algorithm.
273   static const uint32_t kPhiloxW32A = 0x9E3779B9;
274   static const uint32_t kPhiloxW32B = 0xBB67AE85;
275   static const uint32_t kPhiloxM4x32A = 0xD2511F53;
276   static const uint32_t kPhiloxM4x32B = 0xCD9E8D57;
277 
278   struct HighLowPair {
279     XlaOp high;
280     XlaOp low;
281   };
282 
283   // Compute the high and low words from multiplying two 32-bit integers.
284   auto mul_hi_low = [](XlaOp x, uint32_t k) {
285     auto product =
286         ConvertElementType(x, U64) * ConstantR0<uint64_t>(x.builder(), k);
287     auto low = ConvertElementType(product, U32);
288     auto high = ConvertElementType(
289         product >> ConstantR0<uint64_t>(x.builder(), 32), U32);
290     return HighLowPair{high, low};
291   };
292 
293   // Perform a single round of the Philox algorithm.
294   auto philox_round = [&](Philox4x32State x, Philox4x32Key key) {
295     auto product0 = mul_hi_low(x[0], kPhiloxM4x32A);
296     auto product1 = mul_hi_low(x[2], kPhiloxM4x32B);
297     return Philox4x32State{product1.high ^ x[1] ^ key[0], product1.low,
298                            product0.high ^ x[3] ^ key[1], product0.low};
299   };
300 
301   // Update the key after a round of Philox algorithm.
302   auto raise_key = [](Philox4x32Key key) {
303     XlaBuilder* builder = key[0].builder();
304     return Philox4x32Key{key[0] + ConstantR0<uint32_t>(builder, kPhiloxW32A),
305                          key[1] + ConstantR0<uint32_t>(builder, kPhiloxW32B)};
306   };
307 
308   static const int kNumRounds = 10;
309   for (int round = 0; round < kNumRounds; ++round, key = raise_key(key)) {
310     state = philox_round(state, key);
311   }
312   return state;
313 }
314 
315 // Scrambles the input key so that users don't need to worry about which part
316 // of the key needs to be strong.
ScramblePhiloxKey(Philox4x32Key key)317 std::pair<Philox4x32State, Philox4x32Key> ScramblePhiloxKey(Philox4x32Key key) {
318   XlaBuilder* builder = key[0].builder();
319   XlaOp key0 = ConvertElementType(key[0], U64);
320   XlaOp key1 = ConvertElementType(key[1], U64);
321 
322   Philox4x32State state = {
323       ConvertElementType(key0, U32),
324       ConvertElementType(key0 >> ScalarLike(key0, 32), U32),
325       ConvertElementType(key1, U32),
326       ConvertElementType(key1 >> ScalarLike(key1, 32), U32),
327   };
328   key = {ConstantR0<uint32_t>(builder, 0x3ec8f720),
329          ConstantR0<uint32_t>(builder, 0x02461e29)};
330   state = Philox4x32(state, key);
331   XlaOp zero = ConstantR0<uint32_t>(builder, 0);
332   return {Philox4x32State{zero, zero, state[2], state[3]},
333           Philox4x32Key{state[0], state[1]}};
334 }
335 
336 // Adds an U128 tensor with an U64 tensor. The U128 tensor is represented as two
337 // U64s with the low 64bits in the front. This routine supports explicit
338 // broadcasting of the U128 tensor, with `broadcast_sizes` representing the
339 // dimensions prepended to its shape.
Uint128AddUint64(const std::array<XlaOp,2> & u128,XlaOp u64,absl::Span<const int64_t> broadcast_sizes={})340 std::array<XlaOp, 2> Uint128AddUint64(
341     const std::array<XlaOp, 2>& u128, XlaOp u64,
342     absl::Span<const int64_t> broadcast_sizes = {}) {
343   auto u128_low = u128[0];
344   auto u128_high = u128[1];
345   XlaOp new_u128_low = u128_low + u64;
346   XlaOp one = ConstantR0<uint64_t>(u128[0].builder(), 1);
347   XlaOp new_u128_high = Select(Lt(new_u128_low, u128_low),
348                                Broadcast(u128_high + one, broadcast_sizes),
349                                Broadcast(u128_high, broadcast_sizes));
350   return {new_u128_low, new_u128_high};
351 }
352 
Uint32sToUint128(const std::array<XlaOp,4> & u32s)353 std::array<XlaOp, 2> Uint32sToUint128(const std::array<XlaOp, 4>& u32s) {
354   return {Uint32sToUint64({u32s[0], u32s[1]}),
355           Uint32sToUint64({u32s[2], u32s[3]})};
356 }
357 
Uint128ToUint32s(const std::array<XlaOp,2> & u128)358 std::array<XlaOp, 4> Uint128ToUint32s(const std::array<XlaOp, 2>& u128) {
359   std::array<XlaOp, 2> u128_low_32s = Uint64ToUint32s(u128[0]);
360   std::array<XlaOp, 2> u128_high_32s = Uint64ToUint32s(u128[1]);
361   return {u128_low_32s[0], u128_low_32s[1], u128_high_32s[0], u128_high_32s[1]};
362 }
363 
Uint128FromOp(XlaOp op)364 std::array<XlaOp, 2> Uint128FromOp(XlaOp op) {
365   auto u128_low = xla::Reshape(xla::Slice(op, {0}, {1}, {1}), {});
366   auto u128_high = xla::Reshape(xla::Slice(op, {1}, {2}, {1}), {});
367   return {u128_low, u128_high};
368 }
369 
Uint128ToOp(std::array<XlaOp,2> u128)370 XlaOp Uint128ToOp(std::array<XlaOp, 2> u128) {
371   return ConcatScalars(u128[0].builder(), {u128[0], u128[1]});
372 }
373 
374 // Returns the pair (state + [0, 1, ..., n-1], state + n), which should be used
375 // as the inputs fed to `Philox4x32` and the updated state. `state` is an U128
376 // represented as 4 U32s in the order from the least significant one to the most
377 // significant one.
GetPhiloxInputsAndUpdatedState(const Philox4x32State & state,int64_t n)378 std::pair<Philox4x32State, XlaOp> GetPhiloxInputsAndUpdatedState(
379     const Philox4x32State& state, int64_t n) {
380   XlaBuilder* builder = state[0].builder();
381   XlaOp iota = Iota(builder, U64, n);
382   auto state_u128 = Uint32sToUint128(state);
383   auto inputs = Uint128ToUint32s(Uint128AddUint64(state_u128, iota, {n}));
384   XlaOp new_state = Uint128ToOp(
385       Uint128AddUint64(state_u128, ConstantR0<uint64_t>(builder, n)));
386   return std::make_pair(inputs, new_state);
387 }
388 
389 // Generates CeilOfRatio(num_elems, 4)*4 32bit Philox random numbers, as Philox
390 // numbers are generated in the unit of 128bits.
GeneratePhiloxBits(int64_t num_elems,XlaOp initial_state,Philox4x32Key key)391 std::pair<Philox4x32State, XlaOp> GeneratePhiloxBits(int64_t num_elems,
392                                                      XlaOp initial_state,
393                                                      Philox4x32Key key) {
394   Philox4x32State state;
395   state = Uint128ToUint32s(Uint128FromOp(initial_state));
396   const int64_t num_vector4 = CeilOfRatio<int64_t>(num_elems, 4);
397   Philox4x32State inputs;
398   XlaOp new_state;
399   std::tie(inputs, new_state) =
400       GetPhiloxInputsAndUpdatedState(state, num_vector4);
401   auto outputs = Philox4x32(inputs, key);
402   return std::make_pair(outputs, new_state);
403 }
404 
405 // Generates an array of primitive type U32 with the given shape containing
406 // random bits generated by the Philox algorithm. Returns the array and the new
407 // state of the random number generator.
PhiloxRngBit32(XlaOp op_key,XlaOp initial_state,const Shape & shape)408 RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state,
409                          const Shape& shape) {
410   XlaBuilder* builder = op_key.builder();
411   const int64_t num_elems = ShapeUtil::ElementsIn(shape);
412 
413   Philox4x32Key key = Uint64ToUint32s(op_key);
414   Philox4x32State bits;
415   XlaOp new_state;
416   std::tie(bits, new_state) = GeneratePhiloxBits(num_elems, initial_state, key);
417   // Combining bits[i] in a round-robin fashion, to align with non-XLA
418   // implementations
419   int64_t bits_len = (num_elems + 3) / 4;
420   for (auto i = 0; i < 4; ++i) {
421     bits[i] = Reshape(bits[i], {bits_len, 1});
422   }
423   XlaOp numbers = ConcatInDim(builder, {bits[0], bits[1], bits[2], bits[3]},
424                               /*dimension=*/1);
425   numbers = Reshape(numbers, {bits_len * 4});
426   numbers = Slice(numbers, /*start_indices=*/{0},
427                   /*limit_indices=*/{num_elems},
428                   /*strides=*/{1});
429   return {Reshape(numbers, shape.dimensions()), new_state};
430 }
431 
432 // Generates an array of primitive type U64 with the given shape containing
433 // random bits generated by the Philox algorithm. Returns the array and the new
434 // state of the random number generator.
PhiloxRngBit64(XlaOp op_key,XlaOp initial_state,const Shape & shape)435 RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state,
436                          const Shape& shape) {
437   XlaBuilder* builder = op_key.builder();
438   const int64_t num_elems = ShapeUtil::ElementsIn(shape);
439 
440   Philox4x32Key key = Uint64ToUint32s(op_key);
441   Philox4x32State bits32;
442   XlaOp new_state;
443   std::tie(bits32, new_state) =
444       GeneratePhiloxBits(num_elems * 2, initial_state, key);
445 
446   std::array<XlaOp, 2> bits64;
447   bits64[0] = Uint32sToUint64({bits32[0], bits32[1]});
448   bits64[1] = Uint32sToUint64({bits32[2], bits32[3]});
449 
450   // Combining bits64[i] in a round-robin fashion, to align with non-XLA
451   // implementations
452   int64_t bits64_len = (num_elems + 1) / 2;
453   for (auto i = 0; i < 2; ++i) {
454     bits64[i] = Reshape(bits64[i], {bits64_len, 1});
455   }
456   XlaOp numbers = ConcatInDim(builder, {bits64[0], bits64[1]},
457                               /*dimension=*/1);
458   numbers = Reshape(numbers, {bits64_len * 2});
459   numbers = Slice(numbers, /*start_indices=*/{0},
460                   /*limit_indices=*/{num_elems},
461                   /*strides=*/{1});
462   return {Reshape(numbers, shape.dimensions()), new_state};
463 }
464 
ConvertRandomBitsToUniformFloatingPoint(XlaOp bits,XlaOp minval,XlaOp maxval)465 XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval,
466                                               XlaOp maxval) {
467   XlaBuilder* builder = bits.builder();
468   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
469     TF_ASSIGN_OR_RETURN(const Shape* minval_shape,
470                         builder->GetShapePtr(minval));
471     TF_ASSIGN_OR_RETURN(const Shape* bits_shape, builder->GetShapePtr(bits));
472     PrimitiveType value_type = minval_shape->element_type();
473     PrimitiveType bit_type = bits_shape->element_type();
474     CHECK((value_type == F32 && bit_type == U32) ||
475           (value_type == F64 && bit_type == U64));
476 
477     // Form random mantissa bits for float/double, with a leading 1 bit.
478     int num_float_bits = primitive_util::BitWidth(value_type);
479     // Subtract one as SignificandWidth includes the leading 1 bit.
480     int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1;
481 
482     // Ignore the exponent bits and convert the mantissa bits to the floating
483     // point type.
484     bits = ShiftRightLogical(
485         bits, ScalarLike(bits, num_float_bits - num_mantissa_bits));
486 
487     // We have an integer-valued floating point number in the range
488     // [0, 2**{num_mantissa_bits}).
489     XlaOp values = ConvertElementType(bits, value_type);
490 
491     // Divide by 2**{-num_mantissa_bits} to get a number in the range
492     // [0.0, 1.0).
493     values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits));
494 
495     // Multiply and add to shift to the range [minval, maxval).
496     return values * (maxval - minval) + minval;
497   });
498 }
499 
ConvertRandomBitsToUniformInt(XlaOp bits,XlaOp minval,XlaOp maxval,PrimitiveType type,PrimitiveType unsigned_type)500 XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
501                                     PrimitiveType type,
502                                     PrimitiveType unsigned_type) {
503   XlaBuilder* builder = bits.builder();
504   XlaOp range = BitcastConvertType(maxval, unsigned_type) -
505                 BitcastConvertType(minval, unsigned_type);
506   XlaOp dist = Rem(bits, range);
507   XlaOp dist_div_2 =
508       ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1));
509 
510   return minval + BitcastConvertType(dist_div_2, type) +
511          BitcastConvertType(dist - dist_div_2, type);
512 }
513 
514 // Implements the Box-Muller transform, which converts random floats in the
515 // range of [0, 1] from uniform distribution to normal distribution with mean 0
516 // and variance 1. For more detail on the Box-Muller transform, see
517 // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
BoxMullerTransform(XlaOp x0,XlaOp x1)518 std::pair<XlaOp, XlaOp> BoxMullerTransform(XlaOp x0, XlaOp x1) {
519   // Do not send a really small number to log().
520   XlaOp u1 = Max(x0, ScalarLike(x0, 1.0e-7f));
521 
522   XlaOp v1 = ScalarLike(x1, 2.0f * M_PI) * x1;
523   XlaOp u2 = Sqrt(ScalarLike(u1, -2.0f) * Log(u1));
524   return {Sin(v1) * u2, Cos(v1) * u2};
525 }
526 
527 }  // namespace
528 
PhiloxIncreaseCounter(XlaOp counter,XlaOp delta)529 XlaOp PhiloxIncreaseCounter(XlaOp counter, XlaOp delta) {
530   return Uint128ToOp(Uint128AddUint64(Uint128FromOp(counter), delta));
531 }
532 
ThreeFryBitGenerator(XlaOp key,XlaOp initial_state,const Shape & shape)533 RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
534                                const Shape& shape) {
535   PrimitiveType type = shape.element_type();
536   switch (type) {
537     case F32:
538     case U32:
539     case S32:
540       return ThreeFryRngBit32(key, initial_state, shape);
541     case F64:
542     case U64:
543     case S64:
544       return ThreeFryRngBit64(key, initial_state, shape);
545     default:
546       return {key.builder()->ReportError(Unimplemented(
547                   "Types other than F32, F64, U32, S32, U64 and S64 "
548                   "are not implemented by ThreeFryBitGenerator; got %s",
549                   primitive_util::LowercasePrimitiveTypeName(type))),
550               initial_state};
551   }
552 }
553 
PhiloxBitGenerator(XlaOp key,XlaOp initial_state,const Shape & shape)554 RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state,
555                              const Shape& shape) {
556   PrimitiveType type = shape.element_type();
557   switch (type) {
558     case F32:
559     case U32:
560     case S32:
561       return PhiloxRngBit32(key, initial_state, shape);
562     case F64:
563     case U64:
564     case S64:
565       return PhiloxRngBit64(key, initial_state, shape);
566     default:
567       return {key.builder()->ReportError(Unimplemented(
568                   "Types other than F32, F64, U32, S32, U64 and S64 "
569                   "are not implemented by PhiloxFryBitGenerator; got %s",
570                   primitive_util::LowercasePrimitiveTypeName(type))),
571               initial_state};
572   }
573 }
574 
ScramblePhiloxKey(XlaOp key)575 std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key) {
576   Philox4x32Key pkey = Uint64ToUint32s(key);
577   auto state_key = ScramblePhiloxKey(pkey);
578   return std::make_pair(Uint128ToOp(Uint32sToUint128(state_key.first)),
579                         Uint32sToUint64(state_key.second));
580 }
581 
UniformFloatingPointDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,XlaOp minval,XlaOp maxval,const Shape & shape)582 RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state,
583                                            BitGeneratorTy bit_generator,
584                                            XlaOp minval, XlaOp maxval,
585                                            const Shape& shape) {
586   RngOutput bits_state = bit_generator(key, initial_state, shape);
587   XlaOp bits = bits_state.value;
588   XlaOp new_state = bits_state.state;
589   return {ConvertRandomBitsToUniformFloatingPoint(bits, minval, maxval),
590           new_state};
591 }
592 
UniformIntDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,XlaOp minval,XlaOp maxval,const Shape & shape)593 RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
594                                  BitGeneratorTy bit_generator, XlaOp minval,
595                                  XlaOp maxval, const Shape& shape) {
596   RngOutput bits_state = bit_generator(key, initial_state, shape);
597   XlaOp bits = bits_state.value;
598   XlaOp new_state = bits_state.state;
599   PrimitiveType type = shape.element_type();
600   PrimitiveType unsigned_type;
601   if (type == U32 || type == S32) {
602     unsigned_type = U32;
603   } else {
604     DCHECK(type == U64 || type == S64);
605     unsigned_type = U64;
606   }
607   return {
608       ConvertRandomBitsToUniformInt(bits, minval, maxval, type, unsigned_type),
609       new_state};
610 }
611 
NormalFloatingPointDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,const Shape & shape)612 RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state,
613                                           BitGeneratorTy bit_generator,
614                                           const Shape& shape) {
615   PrimitiveType primitive_type = shape.element_type();
616   DCHECK(primitive_type == F32 || primitive_type == F64);
617 
618   XlaBuilder* builder = key.builder();
619   auto shape_pair = SplitShapeIntoHalves(shape);
620   RngOutput bits_state = UniformFloatingPointDistribution(
621       key, initial_state, bit_generator,
622       xla::ConstantR0WithType(builder, primitive_type, 0.0),
623       xla::ConstantR0WithType(builder, primitive_type, 1.0),
624       shape_pair.concat_shape);
625 
626   // Separate the bits into two groups to perform the Box-Muller transform.
627   XlaOp bits_0 = Slice(bits_state.value,
628                        std::vector<int64_t>(shape_pair.half_shape.rank(), 0),
629                        shape_pair.half_shape.dimensions(),
630                        std::vector<int64_t>(shape_pair.half_shape.rank(), 1));
631   std::vector<int64_t> bits_1_starts(shape_pair.half_shape.rank(), 0);
632   bits_1_starts[shape_pair.new_concat_dim] = 1;
633   XlaOp bits_1 = Slice(bits_state.value, bits_1_starts,
634                        shape_pair.concat_shape.dimensions(),
635                        std::vector<int64_t>(shape_pair.half_shape.rank(), 1));
636   std::tie(bits_0, bits_1) = BoxMullerTransform(bits_0, bits_1);
637 
638   // Put the numbers in the two groups back to form the requested shape.
639   XlaOp normal = CombineShapePair({bits_0, bits_1}, shape_pair, shape);
640   return {normal, bits_state.state};
641 }
642 
643 }  // namespace xla
644