1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Utilities for creating and converting xla::Literal values.
17
18 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
20
#include <array>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
50
51 namespace xla {
52
53 class LiteralUtil {
54 public:
55 LiteralUtil() = delete;
56
57 // Returns a literal scalar representing the first element.
58 static Literal GetFirstScalarLiteral(const LiteralSlice& literal);
59 // Returns a literal scalar representing the element at `multi_index`.
60 static Literal GetScalarLiteral(const LiteralBase& literal,
61 absl::Span<const int64_t> multi_index);
62 // Sets the value of the element at `multi_index` with a scalar literal.
63 static void SetScalarLiteral(MutableLiteralBase& literal,
64 absl::Span<const int64_t> multi_index,
65 const LiteralBase& scalar);
66
67 // Creates a new literal of a given rank. To minimize ambiguity (for users
68 // and the compiler) these CreateR[0-2] methods should explicitly specify the
69 // native type. For example:
70 //
71 // CreateR1<float>({1.0, 42.0});
72 // CreateR2<uint32_t>({{1, 2}, {3, 4}});
73 //
74 // The variants not ending with WithLayout use the default XLA layout for the
75 // literal's linear representation in memory.
76 template <typename NativeT>
77 static Literal CreateR0(NativeT value);
78 template <typename NativeT>
79 static Literal CreateR1(absl::Span<const NativeT> values);
80 static Literal CreateR1(const tensorflow::core::Bitmap& values);
81 template <typename NativeT>
82 static Literal CreateR2(
83 std::initializer_list<std::initializer_list<NativeT>> values);
84 template <typename NativeT>
85 static Literal CreateR2WithLayout(
86 std::initializer_list<std::initializer_list<NativeT>> values,
87 const Layout& layout);
88 template <typename NativeT>
89 static Literal CreateR3(std::initializer_list<
90 std::initializer_list<std::initializer_list<NativeT>>>
91 values);
92 template <typename NativeT>
93 static Literal CreateR3WithLayout(
94 std::initializer_list<
95 std::initializer_list<std::initializer_list<NativeT>>>
96 values,
97 const Layout& layout);
98 template <typename NativeT>
99 static Literal CreateR4(
100 std::initializer_list<std::initializer_list<
101 std::initializer_list<std::initializer_list<NativeT>>>>
102 values);
103 template <typename NativeT>
104 static Literal CreateR4WithLayout(
105 std::initializer_list<std::initializer_list<
106 std::initializer_list<std::initializer_list<NativeT>>>>
107 values,
108 const Layout& layout);
109
110 // Creates a scalar literal value zero of the given primitive type.
111 static Literal Zero(PrimitiveType primitive_type);
112 // Creates a scalar literal value one of the given primitive type.
113 static Literal One(PrimitiveType primitive_type);
114 // Creates a scalar literal value containing the minimum value of the given
115 // primitive type. For floating-point types, returns -inf.
116 static Literal MinValue(PrimitiveType primitive_type);
117 // Creates a scalar literal value containing the maximum value of the given
118 // primitive type. For floating-point types, returns inf.
119 static Literal MaxValue(PrimitiveType primitive_type);
120 // Creates a scalar literal value containing the NaN value of the given
121 // primitive type. Fail for non-inexact types. For complex types, returns a
122 // nan + nan * j value.
123 static StatusOr<Literal> NanValue(PrimitiveType primitive_type);
124 // Creates a literal of the given shape where each element is `value`.
125 template <typename NativeT>
126 static Literal CreateFullWithDescendingLayout(
127 absl::Span<const int64_t> dimensions, NativeT value);
128
129 // Creates a new literal from an Array type. The variants not ending with
130 // WithLayout use the default XLA layout for the literal's linear
131 // representation in memory.
132 template <typename NativeT>
133 static Literal CreateFromArray(const Array<NativeT>& values);
134 template <typename NativeT>
135 static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
136 const Layout& layout);
137 template <typename NativeT>
138 static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
139 template <typename NativeT>
140 static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
141 const Layout& layout);
142 template <typename NativeT>
143 static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
144 template <typename NativeT>
145 static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
146 const Layout& layout);
147 template <typename NativeT>
148 static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
149 template <typename NativeT>
150 static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
151 const Layout& layout);
152
153 // Creates a new vector of U8s literal value from a string.
154 static Literal CreateR1U8(absl::string_view value);
155
156 // Creates a linspace-populated literal with the given number of rows and
157 // columns.
158 static Literal CreateR2F32Linspace(float from, float to, int64_t rows,
159 int64_t cols);
160
161 // Creates a literal that projects the (x, y) dimensions given in values into
162 // the z dimension given by "projection".
163 template <typename NativeT>
164 static Literal CreateR3Projected(
165 std::initializer_list<std::initializer_list<NativeT>> values,
166 int64_t projection);
167
168 // Creates a literal that projects the (x, y) dimensions given in values into
169 // the z and p dimensions given.
170 template <typename NativeT>
171 static Literal CreateR4Projected(
172 std::initializer_list<std::initializer_list<NativeT>> values,
173 int64_t projection_p, int64_t projection_z);
174
175 // Returns an identity matrix (rank 2) with the given row and column count.
176 template <typename NativeT>
177 static Literal MakeIdentityR2(int64_t size);
178
179 // Returns a tuple literal composed of given literals. Data is copied from the
180 // given elements into the returned literal.
181 static Literal MakeTuple(absl::Span<const Literal* const> elements);
182
183 static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
184
185 // As above, but intended to be invoked with move semantics; i.e.
186 //
187 // std::vector<Literal> elements = ...;
188 // auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
189 //
190 // This would have been declared as an overload, but there is ambiguity
191 // in invocation between the above signature and this one.
192 static Literal MakeTupleOwned(std::vector<Literal> elements);
193
194 // This overload lets you pass a list of Literals to MakeTupleOwned:
195 //
196 // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
197 //
198 // Simply relying on the MakeTupleOwned(std::vector<Literal>)
199 // overload doesn't work because std::initializer_list's elements are always
200 // const.
201 //
202 // The arguments to this function must all be Literal.
203 template <typename... Ts>
MakeTupleOwned(Ts...elements)204 static Literal MakeTupleOwned(Ts... elements) {
205 std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
206 std::vector<Literal> v;
207 v.insert(v.begin(), std::make_move_iterator(arr.begin()),
208 std::make_move_iterator(arr.end()));
209 return MakeTupleOwned(std::move(v));
210 }
211
212 // Create a constant token literal. Token types have no value.
213 static Literal CreateToken();
214
215 // Creates a new Literal object with its values havings the primitive_type
216 // type, and with dimensions defined by the dimensions parameter.
217 // The content of the literal values is the default value of the primitive
218 // type of literal itself (0 for numeric types, and false for predicates).
219 static Literal CreateFromDimensions(PrimitiveType primitive_type,
220 absl::Span<const int64_t> dimensions);
221
222 // Convert<SrcType>To<DstType> family of functions:
223 // If the given literal's data type is <SrcType>, converts it to a <DstType>
224 // literal; otherwise, returns a copy of it. If the literal is a tuple,
225 // recursively converts its elements.
226 static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
227 static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal);
228 static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
229 static Literal ConvertF32ToF64(const LiteralSlice& f32_literal);
230 static Literal ConvertF64ToBF16(const LiteralSlice& f64_literal);
231 static Literal ConvertF64ToF32(const LiteralSlice& f64_literal);
232 static Literal ConvertS32ToF32(const LiteralSlice& s32_literal);
233
234 // Creates a scalar literal whose value is the maximum value of a given
235 // literal slice.
236 static Literal MaxElement(const LiteralSlice& literal);
237
238 // Creates a literal with a new shape with the given new dimensions using the
239 // data in the given input literal. For reshaping purposes the (flat) data
240 // buffer of the input literal is assumed to have the given minor_to_major
241 // layout order.
242 static Literal ReshapeSlice(absl::Span<const int64_t> new_dimensions,
243 absl::Span<const int64_t> minor_to_major,
244 const LiteralSlice& literal);
245
246 // Creates a literal with the supplied shape, and uses the provided value
247 // generator to populate the literal's values.
248 // Returns the new literal object, or an error Status if failed.
249 template <
250 PrimitiveType type,
251 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
252 static StatusOr<Literal> CreateLiteralWithGenerator(
253 const Shape& shape,
254 const std::function<T(absl::Span<const int64_t>)>& generator);
255
256 // Creates a literal with the supplied shape, and initializes the literal
257 // values using a normal distribution with given mean and stddev standard
258 // deviation, and using the engine as entropy generator.
259 // Returns the new literal object, or an error Status if failed.
260 template <
261 PrimitiveType type, typename E,
262 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
263 static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
264 T mean, T stddev);
265
266 // Creates a literal with the supplied shape, and initializes the literal
267 // values using a normal distribution with given mean and stddev standard
268 // deviation.
269 // Returns the new literal object, or an error Status if failed.
270 template <
271 PrimitiveType type,
272 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
273 static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
274 T stddev);
275
276 //
277 // End of factory methods.
278
279 // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
280 // be returned for a 2-dimensional index with dimension 0 index equal to 7,
281 // dimension 1 equal to 8.
282 static std::string MultiIndexAsString(absl::Span<const int64_t> multi_index);
283 };
284
285 std::ostream& operator<<(std::ostream& out, const Literal& literal);
286
287 template <typename NativeT>
CreateR0(NativeT value)288 /* static */ Literal LiteralUtil::CreateR0(NativeT value) {
289 Literal literal(ShapeUtil::MakeShape(
290 primitive_util::NativeToPrimitiveType<NativeT>(), {}));
291 literal.Set({}, value);
292 return literal;
293 }
294
295 template <typename NativeT>
CreateR1(absl::Span<const NativeT> values)296 /* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
297 Literal literal(
298 ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
299 {static_cast<int64_t>(values.size())}));
300 literal.PopulateR1(values);
301 return literal;
302 }
303
304 template <typename NativeT>
CreateR2WithLayout(std::initializer_list<std::initializer_list<NativeT>> values,const Layout & layout)305 /* static */ Literal LiteralUtil::CreateR2WithLayout(
306 std::initializer_list<std::initializer_list<NativeT>> values,
307 const Layout& layout) {
308 Literal literal(ShapeUtil::MakeShapeWithLayout(
309 primitive_util::NativeToPrimitiveType<NativeT>(),
310 {static_cast<int64_t>(values.size()),
311 static_cast<int64_t>(values.begin()->size())},
312 layout.minor_to_major()));
313 literal.PopulateR2(values);
314 return literal;
315 }
316
317 template <typename NativeT>
CreateR2(std::initializer_list<std::initializer_list<NativeT>> values)318 /* static */ Literal LiteralUtil::CreateR2(
319 std::initializer_list<std::initializer_list<NativeT>> values) {
320 return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
321 }
322
323 template <typename NativeT>
CreateR3WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values,const Layout & layout)324 /* static */ Literal LiteralUtil::CreateR3WithLayout(
325 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
326 values,
327 const Layout& layout) {
328 const int64_t d0 = values.size();
329 const int64_t d1 = values.begin()->size();
330 const int64_t d2 = values.begin()->begin()->size();
331 Array3D<NativeT> tmp(d0, d1, d2);
332 int64_t i0 = 0;
333 for (auto d1_values : values) {
334 int64_t i1 = 0;
335 for (auto d2_values : d1_values) {
336 int64_t i2 = 0;
337 for (auto value : d2_values) {
338 tmp(i0, i1, i2) = value;
339 ++i2;
340 }
341 ++i1;
342 }
343 ++i0;
344 }
345 return CreateR3FromArray3DWithLayout(tmp, layout);
346 }
347
348 template <typename NativeT>
CreateR3(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values)349 /* static */ Literal LiteralUtil::CreateR3(
350 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
351 values) {
352 return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
353 }
354
355 template <typename NativeT>
CreateR4WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values,const Layout & layout)356 /* static */ Literal LiteralUtil::CreateR4WithLayout(
357 std::initializer_list<std::initializer_list<
358 std::initializer_list<std::initializer_list<NativeT>>>>
359 values,
360 const Layout& layout) {
361 const int64_t d0 = values.size();
362 const int64_t d1 = values.begin()->size();
363 const int64_t d2 = values.begin()->begin()->size();
364 const int64_t d3 = values.begin()->begin()->begin()->size();
365 Array4D<NativeT> tmp(d0, d1, d2, d3);
366 int64_t i0 = 0;
367 for (auto d1_values : values) {
368 int64_t i1 = 0;
369 for (auto d2_values : d1_values) {
370 int64_t i2 = 0;
371 for (auto d3_values : d2_values) {
372 int64_t i3 = 0;
373 for (auto value : d3_values) {
374 tmp(i0, i1, i2, i3) = value;
375 ++i3;
376 }
377 ++i2;
378 }
379 ++i1;
380 }
381 ++i0;
382 }
383 return CreateR4FromArray4DWithLayout(tmp, layout);
384 }
385
386 template <typename NativeT>
CreateR4(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values)387 /* static */ Literal LiteralUtil::CreateR4(
388 std::initializer_list<std::initializer_list<
389 std::initializer_list<std::initializer_list<NativeT>>>>
390 values) {
391 return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
392 }
393
394 template <typename NativeT>
CreateFromArrayWithLayout(const Array<NativeT> & values,const Layout & layout)395 /* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
396 const Array<NativeT>& values, const Layout& layout) {
397 Literal literal(ShapeUtil::MakeShapeWithLayout(
398 primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
399 layout.minor_to_major()));
400 literal.PopulateFromArray(values);
401 return literal;
402 }
403
404 template <typename NativeT>
CreateFromArray(const Array<NativeT> & values)405 /* static */ Literal LiteralUtil::CreateFromArray(
406 const Array<NativeT>& values) {
407 return CreateFromArrayWithLayout(
408 values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
409 }
410
411 template <typename NativeT>
CreateR2FromArray2DWithLayout(const Array2D<NativeT> & values,const Layout & layout)412 /* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
413 const Array2D<NativeT>& values, const Layout& layout) {
414 return CreateFromArrayWithLayout(values, layout);
415 }
416
417 template <typename NativeT>
CreateR2FromArray2D(const Array2D<NativeT> & values)418 /* static */ Literal LiteralUtil::CreateR2FromArray2D(
419 const Array2D<NativeT>& values) {
420 return CreateFromArray(values);
421 }
422
423 template <typename NativeT>
CreateR3FromArray3DWithLayout(const Array3D<NativeT> & values,const Layout & layout)424 /* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
425 const Array3D<NativeT>& values, const Layout& layout) {
426 return CreateFromArrayWithLayout(values, layout);
427 }
428
429 template <typename NativeT>
CreateR3FromArray3D(const Array3D<NativeT> & values)430 /* static */ Literal LiteralUtil::CreateR3FromArray3D(
431 const Array3D<NativeT>& values) {
432 return CreateFromArray(values);
433 }
434
435 template <typename NativeT>
CreateR3Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64_t projection)436 /* static */ Literal LiteralUtil::CreateR3Projected(
437 std::initializer_list<std::initializer_list<NativeT>> values,
438 int64_t projection) {
439 int64_t dim0_size = projection;
440 int64_t dim1_size = values.size();
441 int64_t dim2_size = values.begin()->size();
442
443 Array3D<NativeT> array(dim0_size, dim1_size, dim2_size);
444 for (int64_t dim0 = 0; dim0 < dim0_size; ++dim0) {
445 int64_t dim1 = 0;
446 for (auto inner_list : values) {
447 int64_t dim2 = 0;
448 for (auto value : inner_list) {
449 array(dim0, dim1, dim2) = value;
450 ++dim2;
451 }
452 CHECK_EQ(dim2_size, dim2);
453 ++dim1;
454 }
455 CHECK_EQ(dim1_size, dim1);
456 }
457 return CreateR3FromArray3D(array);
458 }
459
460 template <typename NativeT>
CreateR4Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64_t projection_p,int64_t projection_z)461 /* static */ Literal LiteralUtil::CreateR4Projected(
462 std::initializer_list<std::initializer_list<NativeT>> values,
463 int64_t projection_p, int64_t projection_z) {
464 int64_t dim0_size = projection_p;
465 int64_t dim1_size = projection_z;
466 int64_t dim2_size = values.size();
467 int64_t dim3_size = values.begin()->size();
468
469 Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size);
470 for (int64_t dim0 = 0; dim0 < dim0_size; ++dim0) {
471 for (int64_t dim1 = 0; dim1 < dim1_size; ++dim1) {
472 int64_t dim2 = 0;
473 for (auto inner_list : values) {
474 int64_t dim3 = 0;
475 for (auto value : inner_list) {
476 array(dim0, dim1, dim2, dim3) = value;
477 ++dim3;
478 }
479 CHECK_EQ(dim3_size, dim3);
480 ++dim2;
481 }
482 CHECK_EQ(dim2_size, dim2);
483 }
484 }
485 return CreateR4FromArray4D(array);
486 }
487
488 template <typename NativeT>
CreateR4FromArray4D(const Array4D<NativeT> & values)489 /* static */ Literal LiteralUtil::CreateR4FromArray4D(
490 const Array4D<NativeT>& values) {
491 return CreateFromArray(values);
492 }
493
494 template <typename NativeT>
CreateR4FromArray4DWithLayout(const Array4D<NativeT> & values,const Layout & layout)495 /* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
496 const Array4D<NativeT>& values, const Layout& layout) {
497 return CreateFromArrayWithLayout(values, layout);
498 }
499
500 // Returns an identity matrix (rank 2) with the given row and column count.
501 template <typename NativeT>
MakeIdentityR2(int64_t size)502 /* static */ Literal LiteralUtil::MakeIdentityR2(int64_t size) {
503 Array2D<NativeT> array(size, size, 0);
504 for (int64_t i = 0; i < size; ++i) {
505 array(i, i) = 1;
506 }
507 return CreateR2FromArray2D(array);
508 }
509
510 template <typename NativeT>
CreateFullWithDescendingLayout(absl::Span<const int64_t> dimensions,NativeT value)511 /* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
512 absl::Span<const int64_t> dimensions, NativeT value) {
513 Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
514 primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
515 literal.PopulateWithValue(value);
516 return literal;
517 }
518
519 template <PrimitiveType type, typename T>
CreateLiteralWithGenerator(const Shape & shape,const std::function<T (absl::Span<const int64_t>)> & generator)520 /* static */ StatusOr<Literal> LiteralUtil::CreateLiteralWithGenerator(
521 const Shape& shape,
522 const std::function<T(absl::Span<const int64_t>)>& generator) {
523 using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
524 TF_RET_CHECK(shape.element_type() == type);
525 Literal literal(shape);
526 TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
527 [&](absl::Span<const int64_t> indexes) { return generator(indexes); }));
528 return std::move(literal);
529 }
530
531 template <PrimitiveType type, typename E, typename T>
CreateRandomLiteral(const Shape & shape,E * engine,T mean,T stddev)532 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
533 const Shape& shape, E* engine, T mean, T stddev) {
534 using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
535 std::normal_distribution<NativeT> generator(mean, stddev);
536 return CreateLiteralWithGenerator<type, NativeT>(
537 shape, [&](absl::Span<const int64_t> /*indexes*/) {
538 return generator(*engine);
539 });
540 }
541
542 template <PrimitiveType type, typename T>
CreateRandomLiteral(const Shape & shape,T mean,T stddev)543 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
544 const Shape& shape, T mean, T stddev) {
545 std::minstd_rand0 engine;
546 return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
547 }
548
549 } // namespace xla
550
551 #endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
552