xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/literal_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Utilities for creating and converting xla::Literal objects.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
20 
#include <array>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
31 
32 #include "absl/strings/string_view.h"
33 #include "absl/types/span.h"
34 #include "tensorflow/compiler/xla/array2d.h"
35 #include "tensorflow/compiler/xla/array3d.h"
36 #include "tensorflow/compiler/xla/array4d.h"
37 #include "tensorflow/compiler/xla/index_util.h"
38 #include "tensorflow/compiler/xla/layout_util.h"
39 #include "tensorflow/compiler/xla/literal.h"
40 #include "tensorflow/compiler/xla/primitive_util.h"
41 #include "tensorflow/compiler/xla/shape_util.h"
42 #include "tensorflow/compiler/xla/status_macros.h"
43 #include "tensorflow/compiler/xla/types.h"
44 #include "tensorflow/compiler/xla/util.h"
45 #include "tensorflow/compiler/xla/xla_data.pb.h"
46 #include "tensorflow/core/lib/core/bitmap.h"
47 #include "tensorflow/core/lib/core/status.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/protobuf.h"
50 
51 namespace xla {
52 
53 class LiteralUtil {
54  public:
55   LiteralUtil() = delete;
56 
57   // Returns a literal scalar representing the first element.
58   static Literal GetFirstScalarLiteral(const LiteralSlice& literal);
59   // Returns a literal scalar representing the element at `multi_index`.
60   static Literal GetScalarLiteral(const LiteralBase& literal,
61                                   absl::Span<const int64_t> multi_index);
62   // Sets the value of the element at `multi_index` with a scalar literal.
63   static void SetScalarLiteral(MutableLiteralBase& literal,
64                                absl::Span<const int64_t> multi_index,
65                                const LiteralBase& scalar);
66 
67   // Creates a new literal of a given rank. To minimize ambiguity (for users
68   // and the compiler) these CreateR[0-2] methods should explicitly specify the
69   // native type. For example:
70   //
71   //  CreateR1<float>({1.0, 42.0});
72   //  CreateR2<uint32_t>({{1, 2}, {3, 4}});
73   //
74   // The variants not ending with WithLayout use the default XLA layout for the
75   // literal's linear representation in memory.
76   template <typename NativeT>
77   static Literal CreateR0(NativeT value);
78   template <typename NativeT>
79   static Literal CreateR1(absl::Span<const NativeT> values);
80   static Literal CreateR1(const tensorflow::core::Bitmap& values);
81   template <typename NativeT>
82   static Literal CreateR2(
83       std::initializer_list<std::initializer_list<NativeT>> values);
84   template <typename NativeT>
85   static Literal CreateR2WithLayout(
86       std::initializer_list<std::initializer_list<NativeT>> values,
87       const Layout& layout);
88   template <typename NativeT>
89   static Literal CreateR3(std::initializer_list<
90                           std::initializer_list<std::initializer_list<NativeT>>>
91                               values);
92   template <typename NativeT>
93   static Literal CreateR3WithLayout(
94       std::initializer_list<
95           std::initializer_list<std::initializer_list<NativeT>>>
96           values,
97       const Layout& layout);
98   template <typename NativeT>
99   static Literal CreateR4(
100       std::initializer_list<std::initializer_list<
101           std::initializer_list<std::initializer_list<NativeT>>>>
102           values);
103   template <typename NativeT>
104   static Literal CreateR4WithLayout(
105       std::initializer_list<std::initializer_list<
106           std::initializer_list<std::initializer_list<NativeT>>>>
107           values,
108       const Layout& layout);
109 
110   // Creates a scalar literal value zero of the given primitive type.
111   static Literal Zero(PrimitiveType primitive_type);
112   // Creates a scalar literal value one of the given primitive type.
113   static Literal One(PrimitiveType primitive_type);
114   // Creates a scalar literal value containing the minimum value of the given
115   // primitive type. For floating-point types, returns -inf.
116   static Literal MinValue(PrimitiveType primitive_type);
117   // Creates a scalar literal value containing the maximum value of the given
118   // primitive type. For floating-point types, returns inf.
119   static Literal MaxValue(PrimitiveType primitive_type);
120   // Creates a scalar literal value containing the NaN value of the given
121   // primitive type. Fail for non-inexact types. For complex types, returns a
122   // nan + nan * j value.
123   static StatusOr<Literal> NanValue(PrimitiveType primitive_type);
124   // Creates a literal of the given shape where each element is `value`.
125   template <typename NativeT>
126   static Literal CreateFullWithDescendingLayout(
127       absl::Span<const int64_t> dimensions, NativeT value);
128 
129   // Creates a new literal from an Array type. The variants not ending with
130   // WithLayout use the default XLA layout for the literal's linear
131   // representation in memory.
132   template <typename NativeT>
133   static Literal CreateFromArray(const Array<NativeT>& values);
134   template <typename NativeT>
135   static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
136                                            const Layout& layout);
137   template <typename NativeT>
138   static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
139   template <typename NativeT>
140   static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
141                                                const Layout& layout);
142   template <typename NativeT>
143   static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
144   template <typename NativeT>
145   static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
146                                                const Layout& layout);
147   template <typename NativeT>
148   static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
149   template <typename NativeT>
150   static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
151                                                const Layout& layout);
152 
153   // Creates a new vector of U8s literal value from a string.
154   static Literal CreateR1U8(absl::string_view value);
155 
156   // Creates a linspace-populated literal with the given number of rows and
157   // columns.
158   static Literal CreateR2F32Linspace(float from, float to, int64_t rows,
159                                      int64_t cols);
160 
161   // Creates a literal that projects the (x, y) dimensions given in values into
162   // the z dimension given by "projection".
163   template <typename NativeT>
164   static Literal CreateR3Projected(
165       std::initializer_list<std::initializer_list<NativeT>> values,
166       int64_t projection);
167 
168   // Creates a literal that projects the (x, y) dimensions given in values into
169   // the z and p dimensions given.
170   template <typename NativeT>
171   static Literal CreateR4Projected(
172       std::initializer_list<std::initializer_list<NativeT>> values,
173       int64_t projection_p, int64_t projection_z);
174 
175   // Returns an identity matrix (rank 2) with the given row and column count.
176   template <typename NativeT>
177   static Literal MakeIdentityR2(int64_t size);
178 
179   // Returns a tuple literal composed of given literals. Data is copied from the
180   // given elements into the returned literal.
181   static Literal MakeTuple(absl::Span<const Literal* const> elements);
182 
183   static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
184 
185   // As above, but intended to be invoked with move semantics; i.e.
186   //
187   //  std::vector<Literal> elements = ...;
188   //  auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
189   //
190   // This would have been declared as an overload, but there is ambiguity
191   // in invocation between the above signature and this one.
192   static Literal MakeTupleOwned(std::vector<Literal> elements);
193 
194   // This overload lets you pass a list of Literals to MakeTupleOwned:
195   //
196   //   LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
197   //
198   // Simply relying on the MakeTupleOwned(std::vector<Literal>)
199   // overload doesn't work because std::initializer_list's elements are always
200   // const.
201   //
202   // The arguments to this function must all be Literal.
203   template <typename... Ts>
MakeTupleOwned(Ts...elements)204   static Literal MakeTupleOwned(Ts... elements) {
205     std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
206     std::vector<Literal> v;
207     v.insert(v.begin(), std::make_move_iterator(arr.begin()),
208              std::make_move_iterator(arr.end()));
209     return MakeTupleOwned(std::move(v));
210   }
211 
212   // Create a constant token literal. Token types have no value.
213   static Literal CreateToken();
214 
215   // Creates a new Literal object with its values havings the primitive_type
216   // type, and with dimensions defined by the dimensions parameter.
217   // The content of the literal values is the default value of the primitive
218   // type of literal itself (0 for numeric types, and false for predicates).
219   static Literal CreateFromDimensions(PrimitiveType primitive_type,
220                                       absl::Span<const int64_t> dimensions);
221 
222   // Convert<SrcType>To<DstType> family of functions:
223   // If the given literal's data type is <SrcType>, converts it to a <DstType>
224   // literal; otherwise, returns a copy of it. If the literal is a tuple,
225   // recursively converts its elements.
226   static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
227   static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal);
228   static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
229   static Literal ConvertF32ToF64(const LiteralSlice& f32_literal);
230   static Literal ConvertF64ToBF16(const LiteralSlice& f64_literal);
231   static Literal ConvertF64ToF32(const LiteralSlice& f64_literal);
232   static Literal ConvertS32ToF32(const LiteralSlice& s32_literal);
233 
234   // Creates a scalar literal whose value is the maximum value of a given
235   // literal slice.
236   static Literal MaxElement(const LiteralSlice& literal);
237 
238   // Creates a literal with a new shape with the given new dimensions using the
239   // data in the given input literal. For reshaping purposes the (flat) data
240   // buffer of the input literal is assumed to have the given minor_to_major
241   // layout order.
242   static Literal ReshapeSlice(absl::Span<const int64_t> new_dimensions,
243                               absl::Span<const int64_t> minor_to_major,
244                               const LiteralSlice& literal);
245 
246   // Creates a literal with the supplied shape, and uses the provided value
247   // generator to populate the literal's values.
248   // Returns the new literal object, or an error Status if failed.
249   template <
250       PrimitiveType type,
251       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
252   static StatusOr<Literal> CreateLiteralWithGenerator(
253       const Shape& shape,
254       const std::function<T(absl::Span<const int64_t>)>& generator);
255 
256   // Creates a literal with the supplied shape, and initializes the literal
257   // values using a normal distribution with given mean and stddev standard
258   // deviation, and using the engine as entropy generator.
259   // Returns the new literal object, or an error Status if failed.
260   template <
261       PrimitiveType type, typename E,
262       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
263   static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
264                                                T mean, T stddev);
265 
266   // Creates a literal with the supplied shape, and initializes the literal
267   // values using a normal distribution with given mean and stddev standard
268   // deviation.
269   // Returns the new literal object, or an error Status if failed.
270   template <
271       PrimitiveType type,
272       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
273   static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
274                                                T stddev);
275 
276   //
277   // End of factory methods.
278 
279   // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
280   // be returned for a 2-dimensional index with dimension 0 index equal to 7,
281   // dimension 1 equal to 8.
282   static std::string MultiIndexAsString(absl::Span<const int64_t> multi_index);
283 };
284 
285 std::ostream& operator<<(std::ostream& out, const Literal& literal);
286 
287 template <typename NativeT>
CreateR0(NativeT value)288 /* static */ Literal LiteralUtil::CreateR0(NativeT value) {
289   Literal literal(ShapeUtil::MakeShape(
290       primitive_util::NativeToPrimitiveType<NativeT>(), {}));
291   literal.Set({}, value);
292   return literal;
293 }
294 
295 template <typename NativeT>
CreateR1(absl::Span<const NativeT> values)296 /* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
297   Literal literal(
298       ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
299                            {static_cast<int64_t>(values.size())}));
300   literal.PopulateR1(values);
301   return literal;
302 }
303 
304 template <typename NativeT>
CreateR2WithLayout(std::initializer_list<std::initializer_list<NativeT>> values,const Layout & layout)305 /* static */ Literal LiteralUtil::CreateR2WithLayout(
306     std::initializer_list<std::initializer_list<NativeT>> values,
307     const Layout& layout) {
308   Literal literal(ShapeUtil::MakeShapeWithLayout(
309       primitive_util::NativeToPrimitiveType<NativeT>(),
310       {static_cast<int64_t>(values.size()),
311        static_cast<int64_t>(values.begin()->size())},
312       layout.minor_to_major()));
313   literal.PopulateR2(values);
314   return literal;
315 }
316 
317 template <typename NativeT>
CreateR2(std::initializer_list<std::initializer_list<NativeT>> values)318 /* static */ Literal LiteralUtil::CreateR2(
319     std::initializer_list<std::initializer_list<NativeT>> values) {
320   return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
321 }
322 
323 template <typename NativeT>
CreateR3WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values,const Layout & layout)324 /* static */ Literal LiteralUtil::CreateR3WithLayout(
325     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
326         values,
327     const Layout& layout) {
328   const int64_t d0 = values.size();
329   const int64_t d1 = values.begin()->size();
330   const int64_t d2 = values.begin()->begin()->size();
331   Array3D<NativeT> tmp(d0, d1, d2);
332   int64_t i0 = 0;
333   for (auto d1_values : values) {
334     int64_t i1 = 0;
335     for (auto d2_values : d1_values) {
336       int64_t i2 = 0;
337       for (auto value : d2_values) {
338         tmp(i0, i1, i2) = value;
339         ++i2;
340       }
341       ++i1;
342     }
343     ++i0;
344   }
345   return CreateR3FromArray3DWithLayout(tmp, layout);
346 }
347 
348 template <typename NativeT>
CreateR3(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values)349 /* static */ Literal LiteralUtil::CreateR3(
350     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
351         values) {
352   return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
353 }
354 
355 template <typename NativeT>
CreateR4WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values,const Layout & layout)356 /* static */ Literal LiteralUtil::CreateR4WithLayout(
357     std::initializer_list<std::initializer_list<
358         std::initializer_list<std::initializer_list<NativeT>>>>
359         values,
360     const Layout& layout) {
361   const int64_t d0 = values.size();
362   const int64_t d1 = values.begin()->size();
363   const int64_t d2 = values.begin()->begin()->size();
364   const int64_t d3 = values.begin()->begin()->begin()->size();
365   Array4D<NativeT> tmp(d0, d1, d2, d3);
366   int64_t i0 = 0;
367   for (auto d1_values : values) {
368     int64_t i1 = 0;
369     for (auto d2_values : d1_values) {
370       int64_t i2 = 0;
371       for (auto d3_values : d2_values) {
372         int64_t i3 = 0;
373         for (auto value : d3_values) {
374           tmp(i0, i1, i2, i3) = value;
375           ++i3;
376         }
377         ++i2;
378       }
379       ++i1;
380     }
381     ++i0;
382   }
383   return CreateR4FromArray4DWithLayout(tmp, layout);
384 }
385 
386 template <typename NativeT>
CreateR4(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values)387 /* static */ Literal LiteralUtil::CreateR4(
388     std::initializer_list<std::initializer_list<
389         std::initializer_list<std::initializer_list<NativeT>>>>
390         values) {
391   return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
392 }
393 
394 template <typename NativeT>
CreateFromArrayWithLayout(const Array<NativeT> & values,const Layout & layout)395 /* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
396     const Array<NativeT>& values, const Layout& layout) {
397   Literal literal(ShapeUtil::MakeShapeWithLayout(
398       primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
399       layout.minor_to_major()));
400   literal.PopulateFromArray(values);
401   return literal;
402 }
403 
404 template <typename NativeT>
CreateFromArray(const Array<NativeT> & values)405 /* static */ Literal LiteralUtil::CreateFromArray(
406     const Array<NativeT>& values) {
407   return CreateFromArrayWithLayout(
408       values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
409 }
410 
411 template <typename NativeT>
CreateR2FromArray2DWithLayout(const Array2D<NativeT> & values,const Layout & layout)412 /* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
413     const Array2D<NativeT>& values, const Layout& layout) {
414   return CreateFromArrayWithLayout(values, layout);
415 }
416 
417 template <typename NativeT>
CreateR2FromArray2D(const Array2D<NativeT> & values)418 /* static */ Literal LiteralUtil::CreateR2FromArray2D(
419     const Array2D<NativeT>& values) {
420   return CreateFromArray(values);
421 }
422 
423 template <typename NativeT>
CreateR3FromArray3DWithLayout(const Array3D<NativeT> & values,const Layout & layout)424 /* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
425     const Array3D<NativeT>& values, const Layout& layout) {
426   return CreateFromArrayWithLayout(values, layout);
427 }
428 
429 template <typename NativeT>
CreateR3FromArray3D(const Array3D<NativeT> & values)430 /* static */ Literal LiteralUtil::CreateR3FromArray3D(
431     const Array3D<NativeT>& values) {
432   return CreateFromArray(values);
433 }
434 
435 template <typename NativeT>
CreateR3Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64_t projection)436 /* static */ Literal LiteralUtil::CreateR3Projected(
437     std::initializer_list<std::initializer_list<NativeT>> values,
438     int64_t projection) {
439   int64_t dim0_size = projection;
440   int64_t dim1_size = values.size();
441   int64_t dim2_size = values.begin()->size();
442 
443   Array3D<NativeT> array(dim0_size, dim1_size, dim2_size);
444   for (int64_t dim0 = 0; dim0 < dim0_size; ++dim0) {
445     int64_t dim1 = 0;
446     for (auto inner_list : values) {
447       int64_t dim2 = 0;
448       for (auto value : inner_list) {
449         array(dim0, dim1, dim2) = value;
450         ++dim2;
451       }
452       CHECK_EQ(dim2_size, dim2);
453       ++dim1;
454     }
455     CHECK_EQ(dim1_size, dim1);
456   }
457   return CreateR3FromArray3D(array);
458 }
459 
460 template <typename NativeT>
CreateR4Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64_t projection_p,int64_t projection_z)461 /* static */ Literal LiteralUtil::CreateR4Projected(
462     std::initializer_list<std::initializer_list<NativeT>> values,
463     int64_t projection_p, int64_t projection_z) {
464   int64_t dim0_size = projection_p;
465   int64_t dim1_size = projection_z;
466   int64_t dim2_size = values.size();
467   int64_t dim3_size = values.begin()->size();
468 
469   Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size);
470   for (int64_t dim0 = 0; dim0 < dim0_size; ++dim0) {
471     for (int64_t dim1 = 0; dim1 < dim1_size; ++dim1) {
472       int64_t dim2 = 0;
473       for (auto inner_list : values) {
474         int64_t dim3 = 0;
475         for (auto value : inner_list) {
476           array(dim0, dim1, dim2, dim3) = value;
477           ++dim3;
478         }
479         CHECK_EQ(dim3_size, dim3);
480         ++dim2;
481       }
482       CHECK_EQ(dim2_size, dim2);
483     }
484   }
485   return CreateR4FromArray4D(array);
486 }
487 
488 template <typename NativeT>
CreateR4FromArray4D(const Array4D<NativeT> & values)489 /* static */ Literal LiteralUtil::CreateR4FromArray4D(
490     const Array4D<NativeT>& values) {
491   return CreateFromArray(values);
492 }
493 
494 template <typename NativeT>
CreateR4FromArray4DWithLayout(const Array4D<NativeT> & values,const Layout & layout)495 /* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
496     const Array4D<NativeT>& values, const Layout& layout) {
497   return CreateFromArrayWithLayout(values, layout);
498 }
499 
500 // Returns an identity matrix (rank 2) with the given row and column count.
501 template <typename NativeT>
MakeIdentityR2(int64_t size)502 /* static */ Literal LiteralUtil::MakeIdentityR2(int64_t size) {
503   Array2D<NativeT> array(size, size, 0);
504   for (int64_t i = 0; i < size; ++i) {
505     array(i, i) = 1;
506   }
507   return CreateR2FromArray2D(array);
508 }
509 
510 template <typename NativeT>
CreateFullWithDescendingLayout(absl::Span<const int64_t> dimensions,NativeT value)511 /* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
512     absl::Span<const int64_t> dimensions, NativeT value) {
513   Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
514       primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
515   literal.PopulateWithValue(value);
516   return literal;
517 }
518 
519 template <PrimitiveType type, typename T>
CreateLiteralWithGenerator(const Shape & shape,const std::function<T (absl::Span<const int64_t>)> & generator)520 /* static */ StatusOr<Literal> LiteralUtil::CreateLiteralWithGenerator(
521     const Shape& shape,
522     const std::function<T(absl::Span<const int64_t>)>& generator) {
523   using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
524   TF_RET_CHECK(shape.element_type() == type);
525   Literal literal(shape);
526   TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
527       [&](absl::Span<const int64_t> indexes) { return generator(indexes); }));
528   return std::move(literal);
529 }
530 
531 template <PrimitiveType type, typename E, typename T>
CreateRandomLiteral(const Shape & shape,E * engine,T mean,T stddev)532 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
533     const Shape& shape, E* engine, T mean, T stddev) {
534   using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
535   std::normal_distribution<NativeT> generator(mean, stddev);
536   return CreateLiteralWithGenerator<type, NativeT>(
537       shape, [&](absl::Span<const int64_t> /*indexes*/) {
538         return generator(*engine);
539       });
540 }
541 
542 template <PrimitiveType type, typename T>
CreateRandomLiteral(const Shape & shape,T mean,T stddev)543 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
544     const Shape& shape, T mean, T stddev) {
545   std::minstd_rand0 engine;
546   return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
547 }
548 
549 }  // namespace xla
550 
551 #endif  // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
552