1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/prng.h"
17
18 #include <cmath>
19 #include <vector>
20
21 #include "tensorflow/compiler/xla/client/lib/constants.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/util.h"
24
25 namespace xla {
26
ConcatScalars(xla::XlaBuilder * builder,absl::Span<const xla::XlaOp> scalars)27 xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
28 absl::Span<const xla::XlaOp> scalars) {
29 std::vector<xla::XlaOp> vectors;
30 absl::c_transform(scalars, std::back_inserter(vectors),
31 [](xla::XlaOp x) { return xla::Reshape(x, {1}); });
32 return ConcatInDim(builder, vectors, 0);
33 }
34
35 namespace {
36
37 // Rotates a 32-bit integer 'v' left by 'distance' bits.
RotateLeftU32(XlaOp v,int distance)38 XlaOp RotateLeftU32(XlaOp v, int distance) {
39 return (v << ConstantR0<uint32_t>(v.builder(), distance)) |
40 ShiftRightLogical(v, ConstantR0<uint32_t>(v.builder(), 32 - distance));
41 }
42
43 // The internal state of the Three Fry implementation.
44 using ThreeFry2x32State = std::array<XlaOp, 2>;
45
46 // Implements the ThreeFry counter-based PRNG algorithm.
47 // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
48 // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
ThreeFry2x32(ThreeFry2x32State input,ThreeFry2x32State key)49 ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
50 XlaBuilder* builder = input[0].builder();
51 key[0] = BitcastConvertType(key[0], U32);
52 key[1] = BitcastConvertType(key[1], U32);
53
54 // Rotation distances specified by the Threefry2x32 algorithm.
55 constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24};
56 ThreeFry2x32State x;
57
58 std::array<XlaOp, 3> ks;
59 // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
60 ks[2] = ConstantR0<uint32_t>(builder, 0x1BD11BDA);
61 for (int i = 0; i < 2; ++i) {
62 ks[i] = key[i];
63 x[i] = input[i];
64 ks[2] = ks[2] ^ key[i];
65 }
66
67 x[0] = x[0] + ks[0];
68 x[1] = x[1] + ks[1];
69
70 // Performs a single round of the Threefry2x32 algorithm, with a rotation
71 // amount 'rotation'.
72 auto round = [](ThreeFry2x32State v, int rotation) {
73 v[0] = v[0] + v[1];
74 v[1] = RotateLeftU32(v[1], rotation);
75 v[1] = v[0] ^ v[1];
76 return v;
77 };
78
79 // There are no known statistical flaws with 13 rounds of Threefry2x32.
80 // We are conservative and use 20 rounds.
81 x = round(x, rotations[0]);
82 x = round(x, rotations[1]);
83 x = round(x, rotations[2]);
84 x = round(x, rotations[3]);
85 x[0] = x[0] + ks[1];
86 x[1] = x[1] + ks[2] + ConstantR0<uint32_t>(builder, 1);
87
88 x = round(x, rotations[4]);
89 x = round(x, rotations[5]);
90 x = round(x, rotations[6]);
91 x = round(x, rotations[7]);
92 x[0] = x[0] + ks[2];
93 x[1] = x[1] + ks[0] + ConstantR0<uint32_t>(builder, 2);
94
95 x = round(x, rotations[0]);
96 x = round(x, rotations[1]);
97 x = round(x, rotations[2]);
98 x = round(x, rotations[3]);
99 x[0] = x[0] + ks[0];
100 x[1] = x[1] + ks[1] + ConstantR0<uint32_t>(builder, 3);
101
102 x = round(x, rotations[4]);
103 x = round(x, rotations[5]);
104 x = round(x, rotations[6]);
105 x = round(x, rotations[7]);
106 x[0] = x[0] + ks[1];
107 x[1] = x[1] + ks[2] + ConstantR0<uint32_t>(builder, 4);
108
109 x = round(x, rotations[0]);
110 x = round(x, rotations[1]);
111 x = round(x, rotations[2]);
112 x = round(x, rotations[3]);
113 x[0] = x[0] + ks[2];
114 x[1] = x[1] + ks[0] + ConstantR0<uint32_t>(builder, 5);
115
116 return x;
117 }
118
119 // Converts a uint64_t to two uint32s.
Uint64ToUint32s(XlaOp u64)120 std::array<XlaOp, 2> Uint64ToUint32s(XlaOp u64) {
121 XlaBuilder* builder = u64.builder();
122 XlaOp const32 = ConstantR0WithType(builder, U64, 32);
123 XlaOp fst = ConvertElementType(u64, U32);
124 XlaOp snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
125 return {fst, snd};
126 }
127
128 // Converts two uint32s to a uint64_t.
Uint32sToUint64(std::array<XlaOp,2> u32s)129 XlaOp Uint32sToUint64(std::array<XlaOp, 2> u32s) {
130 XlaBuilder* builder = u32s[0].builder();
131 return ConvertElementType(u32s[0], U64) |
132 ShiftLeft(ConvertElementType(u32s[1], U64),
133 ConstantR0WithType(builder, U64, 32));
134 }
135
136 // Given the initial state and the request shape of random numbers to be
137 // generated, returns the input for the random number generator and a new state.
GetThreeFryInputsAndUpdatedState(XlaOp initial_state,const Shape & shape)138 std::pair<ThreeFry2x32State, XlaOp> GetThreeFryInputsAndUpdatedState(
139 XlaOp initial_state, const Shape& shape) {
140 XlaBuilder* builder = initial_state.builder();
141 auto u64_shape = ShapeUtil::MakeShape(U64, shape.dimensions());
142 // initial_state is an R1, so reshape it to a scalar.
143 auto input_u64 = Broadcast(Reshape(initial_state, {}), shape.dimensions());
144 int64_t trailing_dims_product = 1;
145 for (int64_t i = shape.rank() - 1; i >= 0; --i) {
146 if (shape.dimensions(i) < 2) {
147 continue;
148 }
149 input_u64 =
150 input_u64 + (Iota(builder, u64_shape, i) *
151 ConstantR0<uint64_t>(builder, trailing_dims_product));
152 trailing_dims_product *= shape.dimensions(i);
153 }
154 XlaOp new_state = initial_state +
155 ConstantR0<uint64_t>(builder, ShapeUtil::ElementsIn(shape));
156 return std::make_pair(Uint64ToUint32s(input_u64), new_state);
157 }
158
159 // Result for SplitShapeIntoHalves().
160 struct SplitShapePair {
161 Shape half_shape;
162 Shape concat_shape;
163 int64_t split_dim;
164 int64_t new_concat_dim;
165 };
166
167 // Split the shape on a dimension > 1 into two halves.
SplitShapeIntoHalves(const Shape & shape)168 SplitShapePair SplitShapeIntoHalves(const Shape& shape) {
169 SplitShapePair pair;
170 if (shape.rank() == 0) {
171 pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), {1});
172 pair.concat_shape = ShapeUtil::MakeShape(shape.element_type(), {2});
173 pair.split_dim = 0;
174 pair.new_concat_dim = 0;
175 return pair;
176 }
177 pair.split_dim = -1;
178 for (int64_t i = 0; i < shape.rank(); ++i) {
179 if (shape.dimensions(i) % 2 == 0) {
180 pair.split_dim = i;
181 break;
182 }
183 }
184 if (pair.split_dim == -1) {
185 // No even dims. Find a dimension with maximum size.
186 for (int64_t i = 0; i < shape.rank(); ++i) {
187 if (pair.split_dim == -1 ||
188 shape.dimensions(i) > shape.dimensions(pair.split_dim)) {
189 pair.split_dim = i;
190 }
191 }
192 }
193 CHECK_GE(pair.split_dim, 0);
194 std::vector<int64_t> half_shape_dims;
195 std::vector<int64_t> concat_shape_dims;
196 const auto rank = shape.rank();
197 half_shape_dims.reserve(rank + 1);
198 concat_shape_dims.reserve(rank + 1);
199 for (int64_t i = 0; i < rank; ++i) {
200 if (i == pair.split_dim) {
201 // Create a new trivial dim for the later concat, which is more friendly
202 // to sharding propagation.
203 half_shape_dims.push_back(CeilOfRatio<int64_t>(shape.dimensions(i), 2));
204 half_shape_dims.push_back(1);
205 concat_shape_dims.push_back(half_shape_dims[i]);
206 concat_shape_dims.push_back(2);
207 } else {
208 half_shape_dims.push_back(shape.dimensions(i));
209 concat_shape_dims.push_back(shape.dimensions(i));
210 }
211 }
212 pair.new_concat_dim = pair.split_dim + 1;
213 pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), half_shape_dims);
214 pair.concat_shape =
215 ShapeUtil::MakeShape(shape.element_type(), concat_shape_dims);
216 return pair;
217 }
218
219 // Combines a pair of split shapes. It works with scalar and non-scalar shapes.
CombineShapePair(absl::Span<const XlaOp> pair,const SplitShapePair & shape_pair,const Shape & original_shape)220 XlaOp CombineShapePair(absl::Span<const XlaOp> pair,
221 const SplitShapePair& shape_pair,
222 const Shape& original_shape) {
223 if (original_shape.rank() == 0) {
224 return Reshape(pair[0], {});
225 }
226 XlaBuilder* builder = pair[0].builder();
227 XlaOp result = ConcatInDim(builder, pair, shape_pair.new_concat_dim);
228 const int64_t pre_split_size =
229 original_shape.dimensions(shape_pair.split_dim);
230 std::vector<int64_t> reshape_dims(original_shape.dimensions().begin(),
231 original_shape.dimensions().end());
232 reshape_dims[shape_pair.split_dim] = RoundUpTo<int64_t>(pre_split_size, 2);
233 result = Reshape(result, reshape_dims);
234 if (reshape_dims[shape_pair.split_dim] != pre_split_size) {
235 result = Slice(result, std::vector<int64_t>(original_shape.rank(), 0),
236 original_shape.dimensions(),
237 std::vector<int64_t>(original_shape.rank(), 1));
238 }
239 return result;
240 }
241
242 // Generates random 32bits with the given shape using the Three Fry
243 // implementation. Returns the random bits and the new state.
ThreeFryRngBit32(XlaOp key,XlaOp initial_state,const Shape & shape)244 RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
245 auto shape_pair = SplitShapeIntoHalves(shape);
246 std::pair<ThreeFry2x32State, XlaOp> inputs_state =
247 GetThreeFryInputsAndUpdatedState(initial_state, shape_pair.half_shape);
248 ThreeFry2x32State inputs = inputs_state.first;
249 ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
250 XlaOp result = CombineShapePair(outputs, shape_pair, shape);
251 return {result, inputs_state.second};
252 }
253
254 // Generates random 64bits with the given shape using the Three Fry
255 // implementation. Returns the random bits and the new state.
ThreeFryRngBit64(XlaOp key,XlaOp initial_state,const Shape & shape)256 RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) {
257 std::pair<ThreeFry2x32State, XlaOp> inputs_state =
258 GetThreeFryInputsAndUpdatedState(initial_state, shape);
259 ThreeFry2x32State inputs = inputs_state.first;
260 ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
261 XlaOp result = Uint32sToUint64(outputs);
262 return {result, inputs_state.second};
263 }
264
265 // The key of the Philox random number generator.
266 using Philox4x32Key = std::array<XlaOp, 2>;
267 // The internal state of the Philox random number generator.
268 using Philox4x32State = std::array<XlaOp, 4>;
269
270 // Computes the Philox4x32 algorithm using 10 rounds.
Philox4x32(Philox4x32State state,Philox4x32Key key)271 Philox4x32State Philox4x32(Philox4x32State state, Philox4x32Key key) {
272 // Constants specified by the Philox algorithm.
273 static const uint32_t kPhiloxW32A = 0x9E3779B9;
274 static const uint32_t kPhiloxW32B = 0xBB67AE85;
275 static const uint32_t kPhiloxM4x32A = 0xD2511F53;
276 static const uint32_t kPhiloxM4x32B = 0xCD9E8D57;
277
278 struct HighLowPair {
279 XlaOp high;
280 XlaOp low;
281 };
282
283 // Compute the high and low words from multiplying two 32-bit integers.
284 auto mul_hi_low = [](XlaOp x, uint32_t k) {
285 auto product =
286 ConvertElementType(x, U64) * ConstantR0<uint64_t>(x.builder(), k);
287 auto low = ConvertElementType(product, U32);
288 auto high = ConvertElementType(
289 product >> ConstantR0<uint64_t>(x.builder(), 32), U32);
290 return HighLowPair{high, low};
291 };
292
293 // Perform a single round of the Philox algorithm.
294 auto philox_round = [&](Philox4x32State x, Philox4x32Key key) {
295 auto product0 = mul_hi_low(x[0], kPhiloxM4x32A);
296 auto product1 = mul_hi_low(x[2], kPhiloxM4x32B);
297 return Philox4x32State{product1.high ^ x[1] ^ key[0], product1.low,
298 product0.high ^ x[3] ^ key[1], product0.low};
299 };
300
301 // Update the key after a round of Philox algorithm.
302 auto raise_key = [](Philox4x32Key key) {
303 XlaBuilder* builder = key[0].builder();
304 return Philox4x32Key{key[0] + ConstantR0<uint32_t>(builder, kPhiloxW32A),
305 key[1] + ConstantR0<uint32_t>(builder, kPhiloxW32B)};
306 };
307
308 static const int kNumRounds = 10;
309 for (int round = 0; round < kNumRounds; ++round, key = raise_key(key)) {
310 state = philox_round(state, key);
311 }
312 return state;
313 }
314
315 // Scrambles the input key so that users don't need to worry about which part
316 // of the key needs to be strong.
ScramblePhiloxKey(Philox4x32Key key)317 std::pair<Philox4x32State, Philox4x32Key> ScramblePhiloxKey(Philox4x32Key key) {
318 XlaBuilder* builder = key[0].builder();
319 XlaOp key0 = ConvertElementType(key[0], U64);
320 XlaOp key1 = ConvertElementType(key[1], U64);
321
322 Philox4x32State state = {
323 ConvertElementType(key0, U32),
324 ConvertElementType(key0 >> ScalarLike(key0, 32), U32),
325 ConvertElementType(key1, U32),
326 ConvertElementType(key1 >> ScalarLike(key1, 32), U32),
327 };
328 key = {ConstantR0<uint32_t>(builder, 0x3ec8f720),
329 ConstantR0<uint32_t>(builder, 0x02461e29)};
330 state = Philox4x32(state, key);
331 XlaOp zero = ConstantR0<uint32_t>(builder, 0);
332 return {Philox4x32State{zero, zero, state[2], state[3]},
333 Philox4x32Key{state[0], state[1]}};
334 }
335
336 // Adds an U128 tensor with an U64 tensor. The U128 tensor is represented as two
337 // U64s with the low 64bits in the front. This routine supports explicit
338 // broadcasting of the U128 tensor, with `broadcast_sizes` representing the
339 // dimensions prepended to its shape.
Uint128AddUint64(const std::array<XlaOp,2> & u128,XlaOp u64,absl::Span<const int64_t> broadcast_sizes={})340 std::array<XlaOp, 2> Uint128AddUint64(
341 const std::array<XlaOp, 2>& u128, XlaOp u64,
342 absl::Span<const int64_t> broadcast_sizes = {}) {
343 auto u128_low = u128[0];
344 auto u128_high = u128[1];
345 XlaOp new_u128_low = u128_low + u64;
346 XlaOp one = ConstantR0<uint64_t>(u128[0].builder(), 1);
347 XlaOp new_u128_high = Select(Lt(new_u128_low, u128_low),
348 Broadcast(u128_high + one, broadcast_sizes),
349 Broadcast(u128_high, broadcast_sizes));
350 return {new_u128_low, new_u128_high};
351 }
352
Uint32sToUint128(const std::array<XlaOp,4> & u32s)353 std::array<XlaOp, 2> Uint32sToUint128(const std::array<XlaOp, 4>& u32s) {
354 return {Uint32sToUint64({u32s[0], u32s[1]}),
355 Uint32sToUint64({u32s[2], u32s[3]})};
356 }
357
Uint128ToUint32s(const std::array<XlaOp,2> & u128)358 std::array<XlaOp, 4> Uint128ToUint32s(const std::array<XlaOp, 2>& u128) {
359 std::array<XlaOp, 2> u128_low_32s = Uint64ToUint32s(u128[0]);
360 std::array<XlaOp, 2> u128_high_32s = Uint64ToUint32s(u128[1]);
361 return {u128_low_32s[0], u128_low_32s[1], u128_high_32s[0], u128_high_32s[1]};
362 }
363
Uint128FromOp(XlaOp op)364 std::array<XlaOp, 2> Uint128FromOp(XlaOp op) {
365 auto u128_low = xla::Reshape(xla::Slice(op, {0}, {1}, {1}), {});
366 auto u128_high = xla::Reshape(xla::Slice(op, {1}, {2}, {1}), {});
367 return {u128_low, u128_high};
368 }
369
Uint128ToOp(std::array<XlaOp,2> u128)370 XlaOp Uint128ToOp(std::array<XlaOp, 2> u128) {
371 return ConcatScalars(u128[0].builder(), {u128[0], u128[1]});
372 }
373
374 // Returns the pair (state + [0, 1, ..., n-1], state + n), which should be used
375 // as the inputs fed to `Philox4x32` and the updated state. `state` is an U128
376 // represented as 4 U32s in the order from the least significant one to the most
377 // significant one.
GetPhiloxInputsAndUpdatedState(const Philox4x32State & state,int64_t n)378 std::pair<Philox4x32State, XlaOp> GetPhiloxInputsAndUpdatedState(
379 const Philox4x32State& state, int64_t n) {
380 XlaBuilder* builder = state[0].builder();
381 XlaOp iota = Iota(builder, U64, n);
382 auto state_u128 = Uint32sToUint128(state);
383 auto inputs = Uint128ToUint32s(Uint128AddUint64(state_u128, iota, {n}));
384 XlaOp new_state = Uint128ToOp(
385 Uint128AddUint64(state_u128, ConstantR0<uint64_t>(builder, n)));
386 return std::make_pair(inputs, new_state);
387 }
388
389 // Generates CeilOfRatio(num_elems, 4)*4 32bit Philox random numbers, as Philox
390 // numbers are generated in the unit of 128bits.
GeneratePhiloxBits(int64_t num_elems,XlaOp initial_state,Philox4x32Key key)391 std::pair<Philox4x32State, XlaOp> GeneratePhiloxBits(int64_t num_elems,
392 XlaOp initial_state,
393 Philox4x32Key key) {
394 Philox4x32State state;
395 state = Uint128ToUint32s(Uint128FromOp(initial_state));
396 const int64_t num_vector4 = CeilOfRatio<int64_t>(num_elems, 4);
397 Philox4x32State inputs;
398 XlaOp new_state;
399 std::tie(inputs, new_state) =
400 GetPhiloxInputsAndUpdatedState(state, num_vector4);
401 auto outputs = Philox4x32(inputs, key);
402 return std::make_pair(outputs, new_state);
403 }
404
405 // Generates an array of primitive type U32 with the given shape containing
406 // random bits generated by the Philox algorithm. Returns the array and the new
407 // state of the random number generator.
PhiloxRngBit32(XlaOp op_key,XlaOp initial_state,const Shape & shape)408 RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state,
409 const Shape& shape) {
410 XlaBuilder* builder = op_key.builder();
411 const int64_t num_elems = ShapeUtil::ElementsIn(shape);
412
413 Philox4x32Key key = Uint64ToUint32s(op_key);
414 Philox4x32State bits;
415 XlaOp new_state;
416 std::tie(bits, new_state) = GeneratePhiloxBits(num_elems, initial_state, key);
417 // Combining bits[i] in a round-robin fashion, to align with non-XLA
418 // implementations
419 int64_t bits_len = (num_elems + 3) / 4;
420 for (auto i = 0; i < 4; ++i) {
421 bits[i] = Reshape(bits[i], {bits_len, 1});
422 }
423 XlaOp numbers = ConcatInDim(builder, {bits[0], bits[1], bits[2], bits[3]},
424 /*dimension=*/1);
425 numbers = Reshape(numbers, {bits_len * 4});
426 numbers = Slice(numbers, /*start_indices=*/{0},
427 /*limit_indices=*/{num_elems},
428 /*strides=*/{1});
429 return {Reshape(numbers, shape.dimensions()), new_state};
430 }
431
432 // Generates an array of primitive type U64 with the given shape containing
433 // random bits generated by the Philox algorithm. Returns the array and the new
434 // state of the random number generator.
PhiloxRngBit64(XlaOp op_key,XlaOp initial_state,const Shape & shape)435 RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state,
436 const Shape& shape) {
437 XlaBuilder* builder = op_key.builder();
438 const int64_t num_elems = ShapeUtil::ElementsIn(shape);
439
440 Philox4x32Key key = Uint64ToUint32s(op_key);
441 Philox4x32State bits32;
442 XlaOp new_state;
443 std::tie(bits32, new_state) =
444 GeneratePhiloxBits(num_elems * 2, initial_state, key);
445
446 std::array<XlaOp, 2> bits64;
447 bits64[0] = Uint32sToUint64({bits32[0], bits32[1]});
448 bits64[1] = Uint32sToUint64({bits32[2], bits32[3]});
449
450 // Combining bits64[i] in a round-robin fashion, to align with non-XLA
451 // implementations
452 int64_t bits64_len = (num_elems + 1) / 2;
453 for (auto i = 0; i < 2; ++i) {
454 bits64[i] = Reshape(bits64[i], {bits64_len, 1});
455 }
456 XlaOp numbers = ConcatInDim(builder, {bits64[0], bits64[1]},
457 /*dimension=*/1);
458 numbers = Reshape(numbers, {bits64_len * 2});
459 numbers = Slice(numbers, /*start_indices=*/{0},
460 /*limit_indices=*/{num_elems},
461 /*strides=*/{1});
462 return {Reshape(numbers, shape.dimensions()), new_state};
463 }
464
ConvertRandomBitsToUniformFloatingPoint(XlaOp bits,XlaOp minval,XlaOp maxval)465 XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval,
466 XlaOp maxval) {
467 XlaBuilder* builder = bits.builder();
468 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
469 TF_ASSIGN_OR_RETURN(const Shape* minval_shape,
470 builder->GetShapePtr(minval));
471 TF_ASSIGN_OR_RETURN(const Shape* bits_shape, builder->GetShapePtr(bits));
472 PrimitiveType value_type = minval_shape->element_type();
473 PrimitiveType bit_type = bits_shape->element_type();
474 CHECK((value_type == F32 && bit_type == U32) ||
475 (value_type == F64 && bit_type == U64));
476
477 // Form random mantissa bits for float/double, with a leading 1 bit.
478 int num_float_bits = primitive_util::BitWidth(value_type);
479 // Subtract one as SignificandWidth includes the leading 1 bit.
480 int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1;
481
482 // Ignore the exponent bits and convert the mantissa bits to the floating
483 // point type.
484 bits = ShiftRightLogical(
485 bits, ScalarLike(bits, num_float_bits - num_mantissa_bits));
486
487 // We have an integer-valued floating point number in the range
488 // [0, 2**{num_mantissa_bits}).
489 XlaOp values = ConvertElementType(bits, value_type);
490
491 // Divide by 2**{-num_mantissa_bits} to get a number in the range
492 // [0.0, 1.0).
493 values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits));
494
495 // Multiply and add to shift to the range [minval, maxval).
496 return values * (maxval - minval) + minval;
497 });
498 }
499
ConvertRandomBitsToUniformInt(XlaOp bits,XlaOp minval,XlaOp maxval,PrimitiveType type,PrimitiveType unsigned_type)500 XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
501 PrimitiveType type,
502 PrimitiveType unsigned_type) {
503 XlaBuilder* builder = bits.builder();
504 XlaOp range = BitcastConvertType(maxval, unsigned_type) -
505 BitcastConvertType(minval, unsigned_type);
506 XlaOp dist = Rem(bits, range);
507 XlaOp dist_div_2 =
508 ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1));
509
510 return minval + BitcastConvertType(dist_div_2, type) +
511 BitcastConvertType(dist - dist_div_2, type);
512 }
513
514 // Implements the Box-Muller transform, which converts random floats in the
515 // range of [0, 1] from uniform distribution to normal distribution with mean 0
516 // and variance 1. For more detail on the Box-Muller transform, see
517 // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
BoxMullerTransform(XlaOp x0,XlaOp x1)518 std::pair<XlaOp, XlaOp> BoxMullerTransform(XlaOp x0, XlaOp x1) {
519 // Do not send a really small number to log().
520 XlaOp u1 = Max(x0, ScalarLike(x0, 1.0e-7f));
521
522 XlaOp v1 = ScalarLike(x1, 2.0f * M_PI) * x1;
523 XlaOp u2 = Sqrt(ScalarLike(u1, -2.0f) * Log(u1));
524 return {Sin(v1) * u2, Cos(v1) * u2};
525 }
526
527 } // namespace
528
PhiloxIncreaseCounter(XlaOp counter,XlaOp delta)529 XlaOp PhiloxIncreaseCounter(XlaOp counter, XlaOp delta) {
530 return Uint128ToOp(Uint128AddUint64(Uint128FromOp(counter), delta));
531 }
532
ThreeFryBitGenerator(XlaOp key,XlaOp initial_state,const Shape & shape)533 RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
534 const Shape& shape) {
535 PrimitiveType type = shape.element_type();
536 switch (type) {
537 case F32:
538 case U32:
539 case S32:
540 return ThreeFryRngBit32(key, initial_state, shape);
541 case F64:
542 case U64:
543 case S64:
544 return ThreeFryRngBit64(key, initial_state, shape);
545 default:
546 return {key.builder()->ReportError(Unimplemented(
547 "Types other than F32, F64, U32, S32, U64 and S64 "
548 "are not implemented by ThreeFryBitGenerator; got %s",
549 primitive_util::LowercasePrimitiveTypeName(type))),
550 initial_state};
551 }
552 }
553
PhiloxBitGenerator(XlaOp key,XlaOp initial_state,const Shape & shape)554 RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state,
555 const Shape& shape) {
556 PrimitiveType type = shape.element_type();
557 switch (type) {
558 case F32:
559 case U32:
560 case S32:
561 return PhiloxRngBit32(key, initial_state, shape);
562 case F64:
563 case U64:
564 case S64:
565 return PhiloxRngBit64(key, initial_state, shape);
566 default:
567 return {key.builder()->ReportError(Unimplemented(
568 "Types other than F32, F64, U32, S32, U64 and S64 "
569 "are not implemented by PhiloxFryBitGenerator; got %s",
570 primitive_util::LowercasePrimitiveTypeName(type))),
571 initial_state};
572 }
573 }
574
ScramblePhiloxKey(XlaOp key)575 std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key) {
576 Philox4x32Key pkey = Uint64ToUint32s(key);
577 auto state_key = ScramblePhiloxKey(pkey);
578 return std::make_pair(Uint128ToOp(Uint32sToUint128(state_key.first)),
579 Uint32sToUint64(state_key.second));
580 }
581
UniformFloatingPointDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,XlaOp minval,XlaOp maxval,const Shape & shape)582 RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state,
583 BitGeneratorTy bit_generator,
584 XlaOp minval, XlaOp maxval,
585 const Shape& shape) {
586 RngOutput bits_state = bit_generator(key, initial_state, shape);
587 XlaOp bits = bits_state.value;
588 XlaOp new_state = bits_state.state;
589 return {ConvertRandomBitsToUniformFloatingPoint(bits, minval, maxval),
590 new_state};
591 }
592
UniformIntDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,XlaOp minval,XlaOp maxval,const Shape & shape)593 RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
594 BitGeneratorTy bit_generator, XlaOp minval,
595 XlaOp maxval, const Shape& shape) {
596 RngOutput bits_state = bit_generator(key, initial_state, shape);
597 XlaOp bits = bits_state.value;
598 XlaOp new_state = bits_state.state;
599 PrimitiveType type = shape.element_type();
600 PrimitiveType unsigned_type;
601 if (type == U32 || type == S32) {
602 unsigned_type = U32;
603 } else {
604 DCHECK(type == U64 || type == S64);
605 unsigned_type = U64;
606 }
607 return {
608 ConvertRandomBitsToUniformInt(bits, minval, maxval, type, unsigned_type),
609 new_state};
610 }
611
NormalFloatingPointDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,const Shape & shape)612 RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state,
613 BitGeneratorTy bit_generator,
614 const Shape& shape) {
615 PrimitiveType primitive_type = shape.element_type();
616 DCHECK(primitive_type == F32 || primitive_type == F64);
617
618 XlaBuilder* builder = key.builder();
619 auto shape_pair = SplitShapeIntoHalves(shape);
620 RngOutput bits_state = UniformFloatingPointDistribution(
621 key, initial_state, bit_generator,
622 xla::ConstantR0WithType(builder, primitive_type, 0.0),
623 xla::ConstantR0WithType(builder, primitive_type, 1.0),
624 shape_pair.concat_shape);
625
626 // Separate the bits into two groups to perform the Box-Muller transform.
627 XlaOp bits_0 = Slice(bits_state.value,
628 std::vector<int64_t>(shape_pair.half_shape.rank(), 0),
629 shape_pair.half_shape.dimensions(),
630 std::vector<int64_t>(shape_pair.half_shape.rank(), 1));
631 std::vector<int64_t> bits_1_starts(shape_pair.half_shape.rank(), 0);
632 bits_1_starts[shape_pair.new_concat_dim] = 1;
633 XlaOp bits_1 = Slice(bits_state.value, bits_1_starts,
634 shape_pair.concat_shape.dimensions(),
635 std::vector<int64_t>(shape_pair.half_shape.rank(), 1));
636 std::tie(bits_0, bits_1) = BoxMullerTransform(bits_0, bits_1);
637
638 // Put the numbers in the two groups back to form the requested shape.
639 XlaOp normal = CombineShapePair({bits_0, bits_1}, shape_pair, shape);
640 return {normal, bits_state.state};
641 }
642
643 } // namespace xla
644