xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/dynamic_ops_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <numeric>
17 #include <vector>
18 
19 #include "tensorflow/compiler/xla/array2d.h"
20 #include "tensorflow/compiler/xla/client/client_library.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/reference_util.h"
24 #include "tensorflow/compiler/xla/service/local_service.h"
25 #include "tensorflow/compiler/xla/service/platform_util.h"
26 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
27 #include "tensorflow/compiler/xla/service/transfer_manager.h"
28 #include "tensorflow/compiler/xla/test_helpers.h"
29 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/platform/test_benchmark.h"
36 #include "tensorflow/stream_executor/device_memory_allocator.h"
37 
38 namespace xla {
39 namespace {
40 
41 class DynamicSliceTest : public ClientLibraryTestBase {
42  protected:
43   template <typename IndexT, typename DataT>
TestR1()44   void TestR1() {
45     // Slice at dimension start.
46     RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {0}, {5}, {0, 1, 2, 3, 4});
47     // Slice in the middle.
48     RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {3}, {2, 3, 4});
49     // Slice at dimension boundaries.
50     RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {5}, {3}, {5, 6, 7});
51     // Zero element slice.
52     RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {0}, {});
53   }
54 
55   template <typename IndexT, typename DataT>
TestR1OOB()56   void TestR1OOB() {
57     // Slice at dimension boundaries, but with out of bounds indices.
58     RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {4, 5, 6, 7});
59   }
60 
61   template <typename IndexT, typename DataT>
TestR2()62   void TestR2() {
63     // Slice at dimension start.
64     RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {2, 2},
65                          {{1, 2}, {4, 5}});
66     // Slice in the middle.
67     RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {2, 1},
68                          {{5}, {8}});
69     // Slice at dimension boundaries.
70     RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {2, 1},
71                          {{5}, {8}});
72     // Zero element slice: 2x0.
73     RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {2, 0},
74                          {{}, {}});
75     // Zero element slice: 0x2.
76     RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {0, 2},
77                          Array2D<int>(0, 2));
78   }
79 
80   template <typename IndexT, typename DataT>
TestR2OOB()81   void TestR2OOB() {
82     // Slice at dimension boundaries, but with out of bounds indices.
83     RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3},
84                          {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
85   }
86 
87   template <typename IndexT, typename DataT>
TestR3()88   void TestR3() {
89     // R3 Shape: [2, 3, 2]
90     // clang-format off
91 
92     // Slice at dimension start.
93     RunR3<IndexT, DataT>(
94       {{{1, 2}, {3, 4}, {5, 6}},
95        {{7, 8}, {9, 10}, {11, 12}}},
96       {0, 0, 0}, {2, 1, 2},
97       {{{1, 2}}, {{7, 8}}});
98 
99     // Slice in the middle.
100     RunR3<IndexT, DataT>(
101       {{{1, 2}, {3, 4}, {5, 6}},
102        {{7, 8}, {9, 10}, {11, 12}}},
103       {0, 1, 1}, {2, 2, 1},
104       {{{4}, {6}}, {{10}, {12}}});
105     // clang-format on
106   }
107 
108   template <typename IndexT, typename DataT>
TestR3OOB()109   void TestR3OOB() {
110     // Slice at dimension boundaries, but with out of bounds indices.
111     RunR3<IndexT, DataT>(
112         {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1},
113         {2, 1, 2}, {{{5, 6}}, {{11, 12}}});
114   }
115 
116   template <typename IndexT, typename DataT>
RunR1(absl::Span<const int> input_values_int,const std::vector<IndexT> slice_starts,const std::vector<int64_t> & slice_sizes,absl::Span<const int> expected_values_int)117   void RunR1(absl::Span<const int> input_values_int,
118              const std::vector<IndexT> slice_starts,
119              const std::vector<int64_t>& slice_sizes,
120              absl::Span<const int> expected_values_int) {
121     // bfloat16 has explicit constructors, so it does not implicitly convert the
122     // way built-in types do, which is why we can't take the parameter as an
123     // Span<DataT>. We also can't convert it to a vector, because
124     // vector<bool> is special so that it cannot be a Span<bool>, which
125     // is what the code below wants. So instead we do this.
126     Literal input_values =
127         LiteralUtil::CreateR1(input_values_int)
128             .Convert(primitive_util::NativeToPrimitiveType<DataT>())
129             .ValueOrDie();
130     Literal expected_values =
131         std::move(LiteralUtil::CreateR1(expected_values_int)
132                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
133                       .ValueOrDie());
134 
135     XlaBuilder builder(TestName());
136     // Initialize and transfer dynamic slice start indices parameter.
137     XlaOp starts;
138     std::unique_ptr<GlobalData> start_data = CreateR0Parameter<IndexT>(
139         slice_starts[0], 0, "slice_starts", &builder, &starts);
140     // Build dynamic slice computation.
141     auto input = ConstantLiteral(&builder, input_values);
142     DynamicSlice(input, absl::Span<const XlaOp>({starts}), slice_sizes);
143     // Run computation and compare against expected values.
144     ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
145   }
146 
147   template <typename IndexT, typename DataT>
RunR2(const Array2D<int> & input_values_int,const std::vector<IndexT> slice_starts,const std::vector<int64_t> & slice_sizes,const Array2D<int> & expected_values_int)148   void RunR2(const Array2D<int>& input_values_int,
149              const std::vector<IndexT> slice_starts,
150              const std::vector<int64_t>& slice_sizes,
151              const Array2D<int>& expected_values_int) {
152     Literal input_values =
153         std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
154                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
155                       .ValueOrDie());
156     Literal expected_values =
157         std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
158                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
159                       .ValueOrDie());
160 
161     XlaBuilder builder(TestName());
162     // Initialize and transfer dynamic slice start indices parameter.
163     std::vector<XlaOp> starts(2);
164     std::vector<std::unique_ptr<GlobalData>> start_data(2);
165     for (int i = 0; i < 2; ++i) {
166       start_data[i] = CreateR0Parameter<IndexT>(
167           slice_starts[i], i, "slice_starts", &builder, &starts[i]);
168     }
169 
170     // Build dynamic slice computation.
171     auto input = ConstantLiteral(&builder, input_values);
172     DynamicSlice(input, starts, slice_sizes);
173     // Run computation and compare against expected values.
174     std::vector<GlobalData*> argument_ptrs;
175     absl::c_transform(start_data, std::back_inserter(argument_ptrs),
176                       [](const std::unique_ptr<GlobalData>& argument) {
177                         return argument.get();
178                       });
179     ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs);
180   }
181 
182   template <typename IndexT, typename DataT>
RunR3(const Array3D<int> & input_values_int,const std::vector<IndexT> slice_starts,const std::vector<int64_t> & slice_sizes,const Array3D<int> & expected_values_int)183   void RunR3(const Array3D<int>& input_values_int,
184              const std::vector<IndexT> slice_starts,
185              const std::vector<int64_t>& slice_sizes,
186              const Array3D<int>& expected_values_int) {
187     Literal input_values =
188         std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
189                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
190                       .ValueOrDie());
191     Literal expected_values =
192         std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
193                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
194                       .ValueOrDie());
195 
196     XlaBuilder builder(TestName());
197     // Initialize and transfer dynamic slice start indices parameter.
198     std::vector<XlaOp> starts(3);
199     std::vector<std::unique_ptr<GlobalData>> start_data(3);
200     for (int i = 0; i < 3; ++i) {
201       start_data[i] = CreateR0Parameter<IndexT>(
202           slice_starts[i], i, "slice_starts", &builder, &starts[i]);
203     }
204     // Build dynamic slice computation.
205     auto input = ConstantLiteral(&builder, input_values);
206     DynamicSlice(input, starts, slice_sizes);
207     // Run computation and compare against expected values.
208     std::vector<GlobalData*> argument_ptrs;
209     absl::c_transform(start_data, std::back_inserter(argument_ptrs),
210                       [](const std::unique_ptr<GlobalData>& argument) {
211                         return argument.get();
212                       });
213     ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs);
214   }
215 };
216 
// Rank-1 DynamicSlice across index types (first template argument) and
// element types (second template argument).
XLA_TEST_F(DynamicSliceTest, Int32R1BF16) {
  TestR1<int32_t, bfloat16>();
}
XLA_TEST_F(DynamicSliceTest, Int32R1) {
  TestR1<int32_t, int32_t>();
}
XLA_TEST_F(DynamicSliceTest, Int32R1OOB) {
  TestR1OOB<int32_t, int32_t>();
}
XLA_TEST_F(DynamicSliceTest, Int64R1) {
  TestR1<int64_t, float>();
}
XLA_TEST_F(DynamicSliceTest, UInt64R1) {
  TestR1<uint64_t, float>();
}
XLA_TEST_F(DynamicSliceTest, UInt32R1OOB) {
  // Start index 2^31 is far past the end of the 5 elements; the expected
  // result is the last 2-element window. Presumably this also guards against
  // the unsigned index being treated as a negative signed value.
  RunR1<uint32_t, int32_t>({0, 1, 2, 3, 4}, {2147483648u}, {2}, {3, 4});
}
225 
// Rank-2 DynamicSlice across index types (first template argument) and
// element types (second template argument).
XLA_TEST_F(DynamicSliceTest, Int32R2BF16) {
  TestR2<int32_t, bfloat16>();
}
XLA_TEST_F(DynamicSliceTest, Int32R2) {
  TestR2<int32_t, int32_t>();
}
XLA_TEST_F(DynamicSliceTest, Int32R2OOB) {
  TestR2OOB<int32_t, int32_t>();
}
XLA_TEST_F(DynamicSliceTest, Int64R2) {
  TestR2<int64_t, float>();
}
XLA_TEST_F(DynamicSliceTest, UInt64R2) {
  TestR2<uint64_t, int32_t>();
}
XLA_TEST_F(DynamicSliceTest, UInt32R2OOB) {
  // Row start 2^31 is out of bounds for a 2x2 operand; the expected 1x1
  // result is taken from the last row, column 0.
  RunR2<uint32_t, int32_t>({{0, 1}, {2, 3}}, {2147483648u, 0}, {1, 1}, {{2}});
}
234 
// Rank-3 DynamicSlice across index types (first template argument) and
// element types (second template argument).
XLA_TEST_F(DynamicSliceTest, Int32R3BF16) {
  TestR3<int32_t, bfloat16>();
}
XLA_TEST_F(DynamicSliceTest, Int32R3) {
  TestR3<int32_t, float>();
}
XLA_TEST_F(DynamicSliceTest, Int32R3OOB) {
  TestR3OOB<int32_t, float>();
}
XLA_TEST_F(DynamicSliceTest, Int64R3) {
  TestR3<int64_t, float>();
}
XLA_TEST_F(DynamicSliceTest, UInt64R3) {
  TestR3<uint64_t, float>();
}
XLA_TEST_F(DynamicSliceTest, UInt32R3OOB) {
  // Starts of 2^31 in dimensions 0 and 2 are out of bounds for a 2x2x2
  // operand; the expected element is at [1][0][1].
  RunR3<uint32_t, int32_t>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}},
                           {2147483648u, 0, 2147483648u}, {1, 1, 1}, {{{5}}});
}
244 
XLA_TEST_F(DynamicSliceTest, Int32R1Pred) {
  // Eight-element PRED operand shared by every case below; authored as int
  // because RunR1 takes int data and converts it to bool.
  const std::vector<int> operand = {true,  false, false, true,
                                    false, true,  true,  false};

  // Slice at dimension start.
  RunR1<int32_t, bool>(operand, {0}, {5}, {true, false, false, true, false});
  // Slice in the middle.
  RunR1<int32_t, bool>(operand, {2}, {3}, {false, true, false});
  // Slice at dimension boundaries.
  RunR1<int32_t, bool>(operand, {5}, {3}, {true, true, false});
  // Zero element slice.
  RunR1<int32_t, bool>(operand, {2}, {0}, {});
}
259 
XLA_TEST_F(DynamicSliceTest, Int32R2Pred) {
  // 3x3 PRED operand shared by every case below.
  const Array2D<int> operand{{true, false, true},
                             {false, false, true},
                             {true, true, false}};

  // Slice at dimension start.
  RunR2<int32_t, bool>(operand, {0, 0}, {2, 2},
                       {{true, false}, {false, false}});
  // Slice in the middle.
  RunR2<int32_t, bool>(operand, {1, 1}, {2, 1}, {{false}, {true}});
  // Slice at dimension boundaries: the 2x2 slice ends exactly on the last row
  // and the last column.
  RunR2<int32_t, bool>(operand, {1, 1}, {2, 2},
                       {{false, true}, {true, false}});
  // Zero element slice: 2x0.
  RunR2<int32_t, bool>(operand, {0, 0}, {2, 0}, {{}, {}});
  // Zero element slice: 0x2.
  RunR2<int32_t, bool>(operand, {0, 0}, {0, 2}, Array2D<int>(0, 2));
}
282 
XLA_TEST_F(DynamicSliceTest, Int32R3Pred) {
  // R3 Shape: [2, 3, 2]; shared by both cases below.
  // clang-format off
  const Array3D<int> operand{{{true, false}, {false, true}, {true, true}},
                             {{false, true}, {true, false}, {false, false}}};

  // Slice at dimension start.
  RunR3<int32_t, bool>(operand, {0, 0, 0}, {2, 1, 2},
                       {{{true, false}}, {{false, true}}});

  // Slice in the middle.
  RunR3<int32_t, bool>(operand, {0, 1, 1}, {2, 2, 1},
                       {{{true}, {true}}, {{false}, {false}}});
  // clang-format on
}
303 
304 class DynamicUpdateSliceTest : public ClientLibraryTestBase {
305  protected:
306   template <typename IndexT, typename DataT>
TestR0()307   void TestR0() {
308     // Disable algebraic simplifier, otherwise the op will be replaced by a
309     // constant.
310     execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
311         "algsimp");
312     RunR0<IndexT, DataT>(0, 123, {}, 123);
313   }
314 
315   template <typename IndexT, typename DataT>
TestR1()316   void TestR1() {
317     // Slice at dimension start.
318     RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {0},
319                          {8, 9, 10, 3, 4, 5, 6, 7});
320     // Slice in the middle.
321     RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {2},
322                          {0, 1, 8, 9, 10, 5, 6, 7});
323     // Slice at dimension boundaries.
324     RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {5},
325                          {0, 1, 2, 3, 4, 8, 9, 10});
326     // Zero-sized update.
327     RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {}, {2},
328                          {0, 1, 2, 3, 4, 5, 6, 7});
329   }
330 
331   template <typename IndexT, typename DataT>
TestR2()332   void TestR2() {
333     // Slice at dimension start.
334     RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {0, 0},
335                          {{10, 11, 3}, {4, 5, 6}, {7, 8, 9}});
336     // Slice in the middle.
337     RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {1, 1},
338                          {{1, 2, 3}, {4, 10, 11}, {7, 8, 9}});
339     // Slice at dimension boundaries.
340     RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 1},
341                          {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}});
342     // Zero-sized update.
343     RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{}}, {2, 1},
344                          {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
345   }
346 
347   template <typename IndexT, typename DataT>
TestR3()348   void TestR3() {
349     // R3 Shape: [2, 3, 2]
350     // Slice at dimension start.
351     RunR3<IndexT, DataT>(
352         {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}},
353         {{{13, 14}, {15, 16}}, {{17, 18}, {19, 20}}}, {0, 0, 0},
354         {{{13, 14}, {15, 16}, {5, 6}}, {{17, 18}, {19, 20}, {11, 12}}});
355     // Slice in the middle.
356     RunR3<IndexT, DataT>(
357         {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}},
358         {1, 1, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}});
359   }
360 
361   template <typename IndexT, typename DataT>
TestOOB()362   void TestOOB() {
363     // // Slice at dimension boundaries, but with out of bounds indices.
364     RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6},
365                          {0, 1, 2, 3, 4, 8, 9, 10});
366     // R2 Shape: [3, 3]
367     RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2},
368                          {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}});
369     // R3 Shape: [2, 3, 2]
370     RunR3<IndexT, DataT>(
371         {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}},
372         {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}});
373   }
374 
375   template <typename IndexT, typename DataT>
RunR0(int input_value_int,int update_value_int,const std::vector<IndexT> slice_starts,int expected_value_int)376   void RunR0(int input_value_int, int update_value_int,
377              const std::vector<IndexT> slice_starts, int expected_value_int) {
378     Literal input_value =
379         std::move(LiteralUtil::CreateR0(input_value_int)
380                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
381                       .ValueOrDie());
382     Literal update_value =
383         std::move(LiteralUtil::CreateR0(update_value_int)
384                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
385                       .ValueOrDie());
386     Literal expected_value =
387         std::move(LiteralUtil::CreateR0(expected_value_int)
388                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
389                       .ValueOrDie());
390 
391     XlaBuilder builder(TestName());
392     // Build dynamic slice computation.
393     auto input = ConstantLiteral(&builder, input_value);
394     auto update = ConstantLiteral(&builder, update_value);
395     DynamicUpdateSlice(input, update, absl::Span<const XlaOp>({}));
396     // Run computation and compare against expected values.
397     ComputeAndCompareLiteral(&builder, expected_value, {});
398   }
399 
400   template <typename IndexT, typename DataT>
RunR1(absl::Span<const int> input_values_int,absl::Span<const int> update_values_int,const std::vector<IndexT> slice_starts,absl::Span<const int> expected_values_int)401   void RunR1(absl::Span<const int> input_values_int,
402              absl::Span<const int> update_values_int,
403              const std::vector<IndexT> slice_starts,
404              absl::Span<const int> expected_values_int) {
405     Literal input_values =
406         std::move(LiteralUtil::CreateR1(input_values_int)
407                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
408                       .ValueOrDie());
409     Literal update_values =
410         std::move(LiteralUtil::CreateR1(update_values_int)
411                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
412                       .ValueOrDie());
413     Literal expected_values =
414         std::move(LiteralUtil::CreateR1(expected_values_int)
415                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
416                       .ValueOrDie());
417 
418     XlaBuilder builder(TestName());
419     // Initialize and transfer dynamic slice start indices parameter.
420     XlaOp starts;
421     std::unique_ptr<GlobalData> start_data = CreateR0Parameter<IndexT>(
422         slice_starts[0], 0, "slice_starts", &builder, &starts);
423     // Build dynamic slice computation.
424     auto input = ConstantLiteral(&builder, input_values);
425     auto update = ConstantLiteral(&builder, update_values);
426     DynamicUpdateSlice(input, update, absl::Span<const XlaOp>({starts}));
427     // Run computation and compare against expected values.
428     ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
429   }
430 
431   template <typename IndexT, typename DataT>
RunR2(const Array2D<int> & input_values_int,const Array2D<int> & update_values_int,const std::vector<IndexT> slice_starts,const Array2D<int> & expected_values_int)432   void RunR2(const Array2D<int>& input_values_int,
433              const Array2D<int>& update_values_int,
434              const std::vector<IndexT> slice_starts,
435              const Array2D<int>& expected_values_int) {
436     Literal input_values =
437         std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
438                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
439                       .ValueOrDie());
440     Literal update_values =
441         std::move(LiteralUtil::CreateR2FromArray2D(update_values_int)
442                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
443                       .ValueOrDie());
444     Literal expected_values =
445         std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
446                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
447                       .ValueOrDie());
448 
449     XlaBuilder builder(TestName());
450     // Initialize and transfer dynamic slice start indices parameter.
451     std::vector<XlaOp> starts(2);
452     std::vector<std::unique_ptr<GlobalData>> start_data(2);
453     for (int i = 0; i < 2; ++i) {
454       start_data[i] = CreateR0Parameter<IndexT>(
455           slice_starts[i], i, "slice_starts", &builder, &starts[i]);
456     }
457     // Build dynamic slice computation.
458     auto input = ConstantLiteral(&builder, input_values);
459     auto update = ConstantLiteral(&builder, update_values);
460     DynamicUpdateSlice(input, update, starts);
461     // Run computation and compare against expected values.
462     std::vector<GlobalData*> argument_ptrs;
463     absl::c_transform(start_data, std::back_inserter(argument_ptrs),
464                       [](const std::unique_ptr<GlobalData>& argument) {
465                         return argument.get();
466                       });
467     ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs);
468   }
469 
470   template <typename IndexT, typename DataT>
RunR3(const Array3D<int> & input_values_int,const Array3D<int> & update_values_int,const std::vector<IndexT> slice_starts,const Array3D<int> & expected_values_int)471   void RunR3(const Array3D<int>& input_values_int,
472              const Array3D<int>& update_values_int,
473              const std::vector<IndexT> slice_starts,
474              const Array3D<int>& expected_values_int) {
475     Literal input_values =
476         std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
477                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
478                       .ValueOrDie());
479     Literal update_values =
480         std::move(LiteralUtil::CreateR3FromArray3D(update_values_int)
481                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
482                       .ValueOrDie());
483     Literal expected_values =
484         std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
485                       .Convert(primitive_util::NativeToPrimitiveType<DataT>())
486                       .ValueOrDie());
487 
488     XlaBuilder builder(TestName());
489     // Initialize and transfer dynamic slice start indices parameter.
490     std::vector<XlaOp> starts(3);
491     std::vector<std::unique_ptr<GlobalData>> start_data(3);
492     for (int i = 0; i < 3; ++i) {
493       start_data[i] = CreateR0Parameter<IndexT>(
494           slice_starts[i], i, "slice_starts", &builder, &starts[i]);
495     }
496 
497     // Build dynamic slice computation.
498     auto input = ConstantLiteral(&builder, input_values);
499     auto update = ConstantLiteral(&builder, update_values);
500     DynamicUpdateSlice(input, update, starts);
501     // Run computation and compare against expected values.
502     std::vector<GlobalData*> argument_ptrs;
503     absl::c_transform(start_data, std::back_inserter(argument_ptrs),
504                       [](const std::unique_ptr<GlobalData>& argument) {
505                         return argument.get();
506                       });
507     ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs);
508   }
509 
510   template <class T>
RunR3Contiguous(std::vector<int32_t> operand_shape,int32_t index,int32_t size)511   void RunR3Contiguous(std::vector<int32_t> operand_shape, int32_t index,
512                        int32_t size) {
513     const int32_t kSeq = operand_shape[0];
514     const int32_t kBatch = operand_shape[1];
515     const int32_t kDim = operand_shape[2];
516     Array3D<T> input_values(kSeq, kBatch, kDim);
517     Array3D<T> update_values(size, kBatch, kDim);
518     Array3D<T> expected_values(kSeq, kBatch, kDim);
519     index = std::min(std::max(0, index), kSeq - size);
520 
521     input_values.FillIota(static_cast<T>(0));
522     T value = static_cast<T>(10);
523     update_values.FillIota(static_cast<T>(value));
524 
525     // TODO(b/34128753) Expected values may vary depending on backend when
526     // the indices are out of bounds.
527     expected_values.FillIota(static_cast<T>(0));
528     for (int i = 0; i < size; i++) {
529       for (int j = 0; j < kBatch; j++) {
530         for (int k = 0; k < kDim; k++) {
531           expected_values(index + i, j, k) = value++;
532         }
533       }
534     }
535     if (VLOG_IS_ON(1)) {
536       DumpArray<T>("input", input_values);
537       DumpArray<T>("update", update_values);
538       DumpArray<T>("expected", expected_values);
539     }
540 
541     // Build dynamic slice computation.
542     XlaBuilder builder(TestName());
543     // Initialize and transfer input parameter.
544     XlaOp input;
545     std::unique_ptr<GlobalData> input_data =
546         CreateR3Parameter<T>(input_values, 0, "input_values", &builder, &input);
547     // Initialize and transfer update parameter.
548     XlaOp update;
549     std::unique_ptr<GlobalData> update_data = CreateR3Parameter<T>(
550         update_values, 1, "update_values", &builder, &update);
551     auto constant_index = ConstantR0<int32_t>(&builder, index);
552     auto zero = ConstantR0<int32_t>(&builder, 0);
553     DynamicUpdateSlice(input, update, {constant_index, zero, zero});
554 
555     // Run computation and compare against expected values.
556     ComputeAndCompareR3<T>(&builder, expected_values,
557                            {input_data.get(), update_data.get()},
558                            ErrorSpec(0.000001));
559   }
560 
561   template <typename NativeT>
DumpArray(const std::string & name,const Array3D<NativeT> values)562   void DumpArray(const std::string& name, const Array3D<NativeT> values) {
563     Literal literal = LiteralUtil::CreateR3FromArray3D<NativeT>(values);
564     LOG(INFO) << name << ":" << literal.ToString();
565   }
566 };
567 
// Rank-0 DynamicUpdateSlice across index types (first template argument) and
// element types (second template argument).
XLA_TEST_F(DynamicUpdateSliceTest, Int32R0BF16) {
  TestR0<int32_t, bfloat16>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int32R0) {
  TestR0<int32_t, float>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) {
  TestR0<int64_t, float>();
}
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) {
  TestR0<uint64_t, float>();
}
572 
// Rank-1 DynamicUpdateSlice across index types (first template argument) and
// element types (second template argument).
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1BF16) {
  TestR1<int32_t, bfloat16>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) {
  TestR1<int32_t, float>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) {
  TestR1<int64_t, float>();
}
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) {
  TestR1<uint64_t, float>();
}
XLA_TEST_F(DynamicUpdateSliceTest, UInt32R1OOB) {
  // Start index 2^31 is far out of bounds; the expected result writes the
  // 2-element update at the last position where it fits (index 3).
  RunR1<uint32_t, int32_t>({0, 1, 2, 3, 4}, {5, 6}, {2147483648u},
                           {0, 1, 2, 5, 6});
}
581 
// Rank-2 DynamicUpdateSlice across index types (first template argument) and
// element types (second template argument).
XLA_TEST_F(DynamicUpdateSliceTest, Int32R2BF16) {
  TestR2<int32_t, bfloat16>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) {
  TestR2<int32_t, float>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) {
  TestR2<int64_t, int64_t>();
}
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) {
  TestR2<uint64_t, int32_t>();
}
XLA_TEST_F(DynamicUpdateSliceTest, UInt32R2OOB) {
  // Row start 2^31 is out of bounds for a 2x2 operand; the expected result
  // writes the 1x1 update into the last row, column 0.
  RunR2<uint32_t, int32_t>({{0, 1}, {2, 3}}, {{4}}, {2147483648u, 0},
                           {{0, 1}, {4, 3}});
}
590 
// Rank-3 DynamicUpdateSlice across index types (first template argument) and
// element types (second template argument).
XLA_TEST_F(DynamicUpdateSliceTest, Int32R3BF16) {
  TestR3<int32_t, bfloat16>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) {
  TestR3<int32_t, float>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) {
  TestR3<int64_t, int64_t>();
}
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) {
  TestR3<uint64_t, uint64_t>();
}
XLA_TEST_F(DynamicUpdateSliceTest, UInt32R3OOB) {
  // Starts of 2^31 in dimensions 0 and 2 are out of bounds for a 2x2x2
  // operand; the expected result writes the update at [1][0][1].
  RunR3<uint32_t, int32_t>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}, {{{8}}},
                           {2147483648u, 0, 2147483648u},
                           {{{0, 1}, {2, 3}}, {{4, 8}, {6, 7}}});
}
600 
// Out-of-bounds DynamicUpdateSlice across index types (first template
// argument) and element types (second template argument).
XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) {
  TestOOB<int32_t, bfloat16>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) {
  TestOOB<int32_t, float>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int64OOB) {
  TestOOB<int64_t, int64_t>();
}
XLA_TEST_F(DynamicUpdateSliceTest, UInt64OOB) {
  TestOOB<uint64_t, uint64_t>();
}
607 
// Rank-1 dynamic-update-slice on bool elements with int32 start indices.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) {
  // Update written at the very beginning of the operand.
  RunR1<int32_t, bool>(
      {false, false, true, true, false, true, true, false},
      {true, true, false},
      {0},
      {true, true, false, true, false, true, true, false});

  // Update written into the interior of the operand.
  RunR1<int32_t, bool>(
      {false, false, true, true, false, true, true, false},
      {false, true, true},
      {2},
      {false, false, false, true, true, true, true, false});

  // Update ending exactly on the operand's last element.
  RunR1<int32_t, bool>(
      {false, false, true, true, false, true, true, false},
      {false, true, true},
      {5},
      {false, false, true, true, false, false, true, true});

  // An empty update leaves the operand unchanged.
  RunR1<int32_t, bool>(
      {false, false, true, true, false, true, true, false},
      {},
      {2},
      {false, false, true, true, false, true, true, false});
}
626 
// Rank-2 dynamic-update-slice on a 3x3 bool operand with int32 start indices.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R2Pred) {
  // 1x2 update written at the top-left corner.
  RunR2<int32_t, bool>(
      {{false, true, false}, {true, false, true}, {false, true, true}},
      {{true, false}},
      {0, 0},
      {{true, false, false}, {true, false, true}, {false, true, true}});

  // 1x2 update written into the middle row, right-aligned.
  RunR2<int32_t, bool>(
      {{false, true, false}, {true, false, true}, {false, true, true}},
      {{true, false}},
      {1, 1},
      {{false, true, false}, {true, true, false}, {false, true, true}});

  // 1x2 update touching the last row and last column.
  RunR2<int32_t, bool>(
      {{false, true, false}, {true, false, true}, {false, true, true}},
      {{true, false}},
      {2, 1},
      {{false, true, false}, {true, false, true}, {false, true, false}});

  // An empty update leaves the operand unchanged.
  RunR2<int32_t, bool>(
      {{false, true, false}, {true, false, true}, {false, true, true}},
      {{}},
      {2, 1},
      {{false, true, false}, {true, false, true}, {false, true, true}});
}
648 
// Rank-3 dynamic-update-slice on a bool operand of shape [2, 3, 2] with int32
// start indices.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) {
  // [2, 2, 2] update written at the origin.
  RunR3<int32_t, bool>(
      {{{true, false}, {false, true}, {true, true}},
       {{false, false}, {false, true}, {true, false}}},
      {{{false, true}, {true, false}}, {{true, true}, {false, true}}},
      {0, 0, 0},
      {{{false, true}, {true, false}, {true, true}},
       {{true, true}, {false, true}, {true, false}}});

  // [1, 2, 1] update written at (1, 1, 1), away from the origin.
  RunR3<int32_t, bool>(
      {{{true, false}, {false, true}, {true, true}},
       {{false, false}, {false, true}, {true, false}}},
      {{{false}, {true}}},
      {1, 1, 1},
      {{{true, false}, {false, true}, {true, true}},
       {{false, false}, {false, false}, {true, true}}});
}
666 
667 // Tests for simple R3 case where the update is contiguous (i.e. the minor
668 // two dimensions are not sliced).
XLA_TEST_F(DynamicUpdateSliceTest,R3ContiguousSingleElement)669 XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) {
670   // Single element, index in-bounds
671   std::vector<int32_t> operand_shape({4, 5, 2});
672   RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/1);
673 }
674 
XLA_TEST_F(DynamicUpdateSliceTest,R3ContiguousSingleElementBF16)675 XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) {
676   // Single element, index in-bounds
677   std::vector<int32_t> operand_shape({4, 5, 2});
678   RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/1);
679 }
680 
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) {
  // Multiple elements, index in-bounds.
  std::vector<int32_t> operand_shape({4, 5, 2});
  RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/2);
}

XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) {
  // Multiple elements, index in-bounds.
  std::vector<int32_t> operand_shape({4, 5, 2});
  RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/2);
}
692 
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOB) {
  // Multiple elements, index out of bounds: start 3 plus size 2 overruns
  // what is presumably the sliced (major) dimension of extent 4.
  std::vector<int32_t> shape = {4, 5, 2};
  RunR3Contiguous<float>(shape, /*index=*/3, /*size=*/2);
}

XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOBBF16) {
  // bfloat16 variant of R3ContiguousMultipleOOB.
  std::vector<int32_t> shape = {4, 5, 2};
  RunR3Contiguous<bfloat16>(shape, /*index=*/3, /*size=*/2);
}
704 
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLarge) {
  // Multiple elements, update size larger than operand.
  // NOTE(review): index=5 already exceeds the leading extent 4 while size=2
  // does not; confirm which condition this case is meant to exercise.
  std::vector<int32_t> shape = {4, 5, 2};
  RunR3Contiguous<float>(shape, /*index=*/5, /*size=*/2);
}

XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLargeBF16) {
  // bfloat16 variant of R3ContiguousTooLarge.
  std::vector<int32_t> shape = {4, 5, 2};
  RunR3Contiguous<bfloat16>(shape, /*index=*/5, /*size=*/2);
}
716 
// Odd minor extents (123 x 247) so the contiguous rows are not nicely sized.
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) {
  std::vector<int32_t> shape = {3, 123, 247};
  RunR3Contiguous<float>(shape, /*index=*/1, /*size=*/1);
}

XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnalignedBF16) {
  std::vector<int32_t> shape = {3, 123, 247};
  RunR3Contiguous<bfloat16>(shape, /*index=*/1, /*size=*/1);
}
726 
727 // TODO(b/34134076) Disabled on GPU 2016-01-06 due to out-of-memory error.
XLA_TEST_F(DynamicUpdateSliceTest,DISABLED_ON_GPU (R3ContiguousLarger))728 XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLarger)) {
729   std::vector<int32_t> operand_shape({32, 128, 1024});
730   RunR3Contiguous<float>(operand_shape, /*index=*/7, /*size=*/1);
731 }
732 
XLA_TEST_F(DynamicUpdateSliceTest,DISABLED_ON_GPU (R3ContiguousLargerBF16))733 XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLargerBF16)) {
734   std::vector<int32_t> operand_shape({32, 128, 1024});
735   RunR3Contiguous<bfloat16>(operand_shape, /*index=*/7, /*size=*/1);
736 }
737 
738 // This test that buffer assignment does not alias constants with the output of
739 // dynamic update slice.
XLA_TEST_F(HloTestBase,AddOfDUS)740 XLA_TEST_F(HloTestBase, AddOfDUS) {
741   const char* hlo_string = R"(
742   HloModule m
743   test {
744     o = s32[6] constant({2,3,4,5,6,7})
745     i = s32[] parameter(0)
746     u = s32[2] parameter(1)
747     dus = s32[6] dynamic-update-slice(o,u,i)
748     a = s32[6] add(dus, dus)
749     j = s32[] parameter(2)
750     ROOT ds = s32[2] dynamic-slice(a, j), dynamic_slice_sizes={2}
751   }
752   )";
753   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
754 }
755 
BM_DynamicSlice(::testing::benchmark::State & state)756 void BM_DynamicSlice(::testing::benchmark::State& state) {
757   se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
758   auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
759   se::StreamExecutorMemoryAllocator allocator(platform, executors);
760   LocalClient* client =
761       ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
762   auto* transfer_manager =
763       TransferManager::GetForPlatform(platform).ValueOrDie();
764   int device_ordinal = client->default_device_ordinal();
765 
766   XlaBuilder builder("DynamicSlice");
767 
768   // Create input as a constant: shape [1, 2, 3, 4]
769   auto input_literal = LiteralUtil::CreateR4(
770       {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
771         {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
772   auto input = ConstantLiteral(&builder, input_literal);
773 
774   auto stream =
775       client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
776 
777   // Create dynamic slice start indices as a parameter: shape [4]
778   auto start_indices_shape = ShapeUtil::MakeShape(S32, {});
779   std::vector<XlaOp> start_indices(4);
780   std::vector<ScopedShapedBuffer> shaped_buffers;
781   std::vector<const Shape*> host_shapes(4);
782   for (int i = 0; i < 4; ++i) {
783     start_indices[i] =
784         Parameter(&builder, i, start_indices_shape, "start_indices");
785     auto start_index_literal = LiteralUtil::CreateR0<int32_t>(i + 1);
786     // Initialize and transfer parameter buffer.
787     shaped_buffers.emplace_back(
788         client->backend()
789             .transfer_manager()
790             ->AllocateScopedShapedBuffer(start_indices_shape, &allocator,
791                                          /*device_ordinal=*/0)
792             .value());
793     host_shapes[i] = &shaped_buffers[i].on_host_shape();
794     ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
795         stream.get(), start_index_literal, shaped_buffers[i]));
796   }
797 
798   // Add DynamicSlice op to the computation.
799   DynamicSlice(input, start_indices, {1, 1, 1, 1});
800   auto computation = builder.Build().value();
801 
802   TF_ASSERT_OK_AND_ASSIGN(
803       auto executables,
804       client->Compile(computation, host_shapes, ExecutableBuildOptions()));
805   auto executable = std::move(executables[0]);
806 
807   // Run some warm-up executions.
808   ExecutableRunOptions options;
809   options.set_allocator(&allocator);
810   const int kWarmups = 2;
811   std::vector<const ShapedBuffer*> shaped_buffer_ptrs;
812   absl::c_transform(shaped_buffers, std::back_inserter(shaped_buffer_ptrs),
813                     [](const ScopedShapedBuffer& buffer) { return &buffer; });
814 
815   for (int i = 0; i < kWarmups; ++i) {
816     auto result = executable->Run(shaped_buffer_ptrs, options);
817     ASSERT_TRUE(result.ok());
818   }
819 
820   // Run benchmark.
821   for (auto s : state) {
822     auto result = executable->Run(shaped_buffer_ptrs, options);
823     ASSERT_TRUE(result.ok());
824   }
825 }
826 BENCHMARK(BM_DynamicSlice);
827 
828 }  // namespace
829 }  // namespace xla
830