1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <numeric>
17 #include <vector>
18
19 #include "tensorflow/compiler/xla/array2d.h"
20 #include "tensorflow/compiler/xla/client/client_library.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/reference_util.h"
24 #include "tensorflow/compiler/xla/service/local_service.h"
25 #include "tensorflow/compiler/xla/service/platform_util.h"
26 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
27 #include "tensorflow/compiler/xla/service/transfer_manager.h"
28 #include "tensorflow/compiler/xla/test_helpers.h"
29 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/platform/test_benchmark.h"
36 #include "tensorflow/stream_executor/device_memory_allocator.h"
37
38 namespace xla {
39 namespace {
40
41 class DynamicSliceTest : public ClientLibraryTestBase {
42 protected:
43 template <typename IndexT, typename DataT>
TestR1()44 void TestR1() {
45 // Slice at dimension start.
46 RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {0}, {5}, {0, 1, 2, 3, 4});
47 // Slice in the middle.
48 RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {3}, {2, 3, 4});
49 // Slice at dimension boundaries.
50 RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {5}, {3}, {5, 6, 7});
51 // Zero element slice.
52 RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {0}, {});
53 }
54
55 template <typename IndexT, typename DataT>
TestR1OOB()56 void TestR1OOB() {
57 // Slice at dimension boundaries, but with out of bounds indices.
58 RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {4, 5, 6, 7});
59 }
60
61 template <typename IndexT, typename DataT>
TestR2()62 void TestR2() {
63 // Slice at dimension start.
64 RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {2, 2},
65 {{1, 2}, {4, 5}});
66 // Slice in the middle.
67 RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {2, 1},
68 {{5}, {8}});
69 // Slice at dimension boundaries.
70 RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {2, 1},
71 {{5}, {8}});
72 // Zero element slice: 2x0.
73 RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {2, 0},
74 {{}, {}});
75 // Zero element slice: 0x2.
76 RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {0, 2},
77 Array2D<int>(0, 2));
78 }
79
80 template <typename IndexT, typename DataT>
TestR2OOB()81 void TestR2OOB() {
82 // Slice at dimension boundaries, but with out of bounds indices.
83 RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3},
84 {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
85 }
86
87 template <typename IndexT, typename DataT>
TestR3()88 void TestR3() {
89 // R3 Shape: [2, 3, 2]
90 // clang-format off
91
92 // Slice at dimension start.
93 RunR3<IndexT, DataT>(
94 {{{1, 2}, {3, 4}, {5, 6}},
95 {{7, 8}, {9, 10}, {11, 12}}},
96 {0, 0, 0}, {2, 1, 2},
97 {{{1, 2}}, {{7, 8}}});
98
99 // Slice in the middle.
100 RunR3<IndexT, DataT>(
101 {{{1, 2}, {3, 4}, {5, 6}},
102 {{7, 8}, {9, 10}, {11, 12}}},
103 {0, 1, 1}, {2, 2, 1},
104 {{{4}, {6}}, {{10}, {12}}});
105 // clang-format on
106 }
107
108 template <typename IndexT, typename DataT>
TestR3OOB()109 void TestR3OOB() {
110 // Slice at dimension boundaries, but with out of bounds indices.
111 RunR3<IndexT, DataT>(
112 {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1},
113 {2, 1, 2}, {{{5, 6}}, {{11, 12}}});
114 }
115
116 template <typename IndexT, typename DataT>
RunR1(absl::Span<const int> input_values_int,const std::vector<IndexT> slice_starts,const std::vector<int64_t> & slice_sizes,absl::Span<const int> expected_values_int)117 void RunR1(absl::Span<const int> input_values_int,
118 const std::vector<IndexT> slice_starts,
119 const std::vector<int64_t>& slice_sizes,
120 absl::Span<const int> expected_values_int) {
121 // bfloat16 has explicit constructors, so it does not implicitly convert the
122 // way built-in types do, which is why we can't take the parameter as an
123 // Span<DataT>. We also can't convert it to a vector, because
124 // vector<bool> is special so that it cannot be a Span<bool>, which
125 // is what the code below wants. So instead we do this.
126 Literal input_values =
127 LiteralUtil::CreateR1(input_values_int)
128 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
129 .ValueOrDie();
130 Literal expected_values =
131 std::move(LiteralUtil::CreateR1(expected_values_int)
132 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
133 .ValueOrDie());
134
135 XlaBuilder builder(TestName());
136 // Initialize and transfer dynamic slice start indices parameter.
137 XlaOp starts;
138 std::unique_ptr<GlobalData> start_data = CreateR0Parameter<IndexT>(
139 slice_starts[0], 0, "slice_starts", &builder, &starts);
140 // Build dynamic slice computation.
141 auto input = ConstantLiteral(&builder, input_values);
142 DynamicSlice(input, absl::Span<const XlaOp>({starts}), slice_sizes);
143 // Run computation and compare against expected values.
144 ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
145 }
146
147 template <typename IndexT, typename DataT>
RunR2(const Array2D<int> & input_values_int,const std::vector<IndexT> slice_starts,const std::vector<int64_t> & slice_sizes,const Array2D<int> & expected_values_int)148 void RunR2(const Array2D<int>& input_values_int,
149 const std::vector<IndexT> slice_starts,
150 const std::vector<int64_t>& slice_sizes,
151 const Array2D<int>& expected_values_int) {
152 Literal input_values =
153 std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
154 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
155 .ValueOrDie());
156 Literal expected_values =
157 std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
158 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
159 .ValueOrDie());
160
161 XlaBuilder builder(TestName());
162 // Initialize and transfer dynamic slice start indices parameter.
163 std::vector<XlaOp> starts(2);
164 std::vector<std::unique_ptr<GlobalData>> start_data(2);
165 for (int i = 0; i < 2; ++i) {
166 start_data[i] = CreateR0Parameter<IndexT>(
167 slice_starts[i], i, "slice_starts", &builder, &starts[i]);
168 }
169
170 // Build dynamic slice computation.
171 auto input = ConstantLiteral(&builder, input_values);
172 DynamicSlice(input, starts, slice_sizes);
173 // Run computation and compare against expected values.
174 std::vector<GlobalData*> argument_ptrs;
175 absl::c_transform(start_data, std::back_inserter(argument_ptrs),
176 [](const std::unique_ptr<GlobalData>& argument) {
177 return argument.get();
178 });
179 ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs);
180 }
181
182 template <typename IndexT, typename DataT>
RunR3(const Array3D<int> & input_values_int,const std::vector<IndexT> slice_starts,const std::vector<int64_t> & slice_sizes,const Array3D<int> & expected_values_int)183 void RunR3(const Array3D<int>& input_values_int,
184 const std::vector<IndexT> slice_starts,
185 const std::vector<int64_t>& slice_sizes,
186 const Array3D<int>& expected_values_int) {
187 Literal input_values =
188 std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
189 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
190 .ValueOrDie());
191 Literal expected_values =
192 std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
193 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
194 .ValueOrDie());
195
196 XlaBuilder builder(TestName());
197 // Initialize and transfer dynamic slice start indices parameter.
198 std::vector<XlaOp> starts(3);
199 std::vector<std::unique_ptr<GlobalData>> start_data(3);
200 for (int i = 0; i < 3; ++i) {
201 start_data[i] = CreateR0Parameter<IndexT>(
202 slice_starts[i], i, "slice_starts", &builder, &starts[i]);
203 }
204 // Build dynamic slice computation.
205 auto input = ConstantLiteral(&builder, input_values);
206 DynamicSlice(input, starts, slice_sizes);
207 // Run computation and compare against expected values.
208 std::vector<GlobalData*> argument_ptrs;
209 absl::c_transform(start_data, std::back_inserter(argument_ptrs),
210 [](const std::unique_ptr<GlobalData>& argument) {
211 return argument.get();
212 });
213 ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs);
214 }
215 };
216
// Rank-1 DynamicSlice over several index/element type combinations.
XLA_TEST_F(DynamicSliceTest, Int32R1BF16) {
  TestR1<int32_t, bfloat16>();
}

XLA_TEST_F(DynamicSliceTest, Int32R1) {
  TestR1<int32_t, int32_t>();
}

XLA_TEST_F(DynamicSliceTest, Int32R1OOB) {
  TestR1OOB<int32_t, int32_t>();
}

XLA_TEST_F(DynamicSliceTest, Int64R1) {
  TestR1<int64_t, float>();
}

XLA_TEST_F(DynamicSliceTest, UInt64R1) {
  TestR1<uint64_t, float>();
}

// Unsigned start index of 2^31, far past the end of a five-element operand;
// the expected result is the last two elements.
XLA_TEST_F(DynamicSliceTest, UInt32R1OOB) {
  constexpr uint32_t kHugeStart = 2147483648u;
  RunR1<uint32_t, int32_t>(/*input_values_int=*/{0, 1, 2, 3, 4},
                           /*slice_starts=*/{kHugeStart},
                           /*slice_sizes=*/{2},
                           /*expected_values_int=*/{3, 4});
}
225
// Rank-2 DynamicSlice over several index/element type combinations.
XLA_TEST_F(DynamicSliceTest, Int32R2BF16) {
  TestR2<int32_t, bfloat16>();
}

XLA_TEST_F(DynamicSliceTest, Int32R2) {
  TestR2<int32_t, int32_t>();
}

XLA_TEST_F(DynamicSliceTest, Int32R2OOB) {
  TestR2OOB<int32_t, int32_t>();
}

XLA_TEST_F(DynamicSliceTest, Int64R2) {
  TestR2<int64_t, float>();
}

XLA_TEST_F(DynamicSliceTest, UInt64R2) {
  TestR2<uint64_t, int32_t>();
}

// Unsigned row start of 2^31 on a 2x2 operand; the expected result is the
// first element of the last row.
XLA_TEST_F(DynamicSliceTest, UInt32R2OOB) {
  constexpr uint32_t kHugeStart = 2147483648u;
  RunR2<uint32_t, int32_t>(/*input_values_int=*/{{0, 1}, {2, 3}},
                           /*slice_starts=*/{kHugeStart, 0},
                           /*slice_sizes=*/{1, 1},
                           /*expected_values_int=*/{{2}});
}
234
// Rank-3 DynamicSlice over several index/element type combinations.
XLA_TEST_F(DynamicSliceTest, Int32R3BF16) {
  TestR3<int32_t, bfloat16>();
}

XLA_TEST_F(DynamicSliceTest, Int32R3) {
  TestR3<int32_t, float>();
}

XLA_TEST_F(DynamicSliceTest, Int32R3OOB) {
  TestR3OOB<int32_t, float>();
}

XLA_TEST_F(DynamicSliceTest, Int64R3) {
  TestR3<int64_t, float>();
}

XLA_TEST_F(DynamicSliceTest, UInt64R3) {
  TestR3<uint64_t, float>();
}

// Unsigned starts of 2^31 in the major and minor dimensions of a 2x2x2
// operand; the expected result is element [1][0][1].
XLA_TEST_F(DynamicSliceTest, UInt32R3OOB) {
  constexpr uint32_t kHugeStart = 2147483648u;
  RunR3<uint32_t, int32_t>(
      /*input_values_int=*/{{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}},
      /*slice_starts=*/{kHugeStart, 0, kHugeStart},
      /*slice_sizes=*/{1, 1, 1},
      /*expected_values_int=*/{{{5}}});
}
244
// Rank-1 DynamicSlice on boolean (PRED) data with int32 indices.
XLA_TEST_F(DynamicSliceTest, Int32R1Pred) {
  const std::vector<int> operand = {true,  false, false, true,
                                    false, true,  true,  false};

  // Slice beginning at the first element.
  RunR1<int32_t, bool>(operand, /*slice_starts=*/{0}, /*slice_sizes=*/{5},
                       /*expected_values_int=*/
                       {true, false, false, true, false});
  // Slice taken from the interior.
  RunR1<int32_t, bool>(operand, /*slice_starts=*/{2}, /*slice_sizes=*/{3},
                       /*expected_values_int=*/{false, true, false});
  // Slice that ends exactly at the last element.
  RunR1<int32_t, bool>(operand, /*slice_starts=*/{5}, /*slice_sizes=*/{3},
                       /*expected_values_int=*/{true, true, false});
  // Slice with no elements.
  RunR1<int32_t, bool>(operand, /*slice_starts=*/{2}, /*slice_sizes=*/{0},
                       /*expected_values_int=*/{});
}
259
// Rank-2 DynamicSlice on boolean (PRED) data with int32 indices.
XLA_TEST_F(DynamicSliceTest, Int32R2Pred) {
  const Array2D<int> operand = {
      {true, false, true}, {false, false, true}, {true, true, false}};

  // Slice at dimension start.
  RunR2<int32_t, bool>(operand, {0, 0}, {2, 2},
                       {{true, false}, {false, false}});
  // Slice in the middle.
  RunR2<int32_t, bool>(operand, {1, 1}, {2, 1}, {{false}, {true}});
  // Slice at dimension boundaries: the 2x1 window ends on the last row and
  // the last column. (This case previously repeated the "middle" case
  // verbatim, so the boundary was never actually covered.)
  RunR2<int32_t, bool>(operand, {1, 2}, {2, 1}, {{true}, {false}});
  // Zero element slice: 2x0.
  RunR2<int32_t, bool>(operand, {0, 0}, {2, 0}, {{}, {}});
  // Zero element slice: 0x2. An explicit Array2D is needed because an empty
  // braced list cannot express the column count.
  RunR2<int32_t, bool>(operand, {0, 0}, {0, 2}, Array2D<int>(0, 2));
}
282
// Rank-3 DynamicSlice on boolean (PRED) data of shape [2, 3, 2] with int32
// indices.
XLA_TEST_F(DynamicSliceTest, Int32R3Pred) {
  const Array3D<int> operand = {
      {{true, false}, {false, true}, {true, true}},
      {{false, true}, {true, false}, {false, false}}};

  // Window anchored at the origin: first row of each plane.
  RunR3<int32_t, bool>(operand, /*slice_starts=*/{0, 0, 0},
                       /*slice_sizes=*/{2, 1, 2},
                       /*expected_values_int=*/
                       {{{true, false}}, {{false, true}}});

  // Window offset in the two minor dimensions: second column of the last
  // two rows of each plane.
  RunR3<int32_t, bool>(operand, /*slice_starts=*/{0, 1, 1},
                       /*slice_sizes=*/{2, 2, 1},
                       /*expected_values_int=*/
                       {{{true}, {true}}, {{false}, {false}}});
}
303
304 class DynamicUpdateSliceTest : public ClientLibraryTestBase {
305 protected:
306 template <typename IndexT, typename DataT>
TestR0()307 void TestR0() {
308 // Disable algebraic simplifier, otherwise the op will be replaced by a
309 // constant.
310 execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
311 "algsimp");
312 RunR0<IndexT, DataT>(0, 123, {}, 123);
313 }
314
315 template <typename IndexT, typename DataT>
TestR1()316 void TestR1() {
317 // Slice at dimension start.
318 RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {0},
319 {8, 9, 10, 3, 4, 5, 6, 7});
320 // Slice in the middle.
321 RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {2},
322 {0, 1, 8, 9, 10, 5, 6, 7});
323 // Slice at dimension boundaries.
324 RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {5},
325 {0, 1, 2, 3, 4, 8, 9, 10});
326 // Zero-sized update.
327 RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {}, {2},
328 {0, 1, 2, 3, 4, 5, 6, 7});
329 }
330
331 template <typename IndexT, typename DataT>
TestR2()332 void TestR2() {
333 // Slice at dimension start.
334 RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {0, 0},
335 {{10, 11, 3}, {4, 5, 6}, {7, 8, 9}});
336 // Slice in the middle.
337 RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {1, 1},
338 {{1, 2, 3}, {4, 10, 11}, {7, 8, 9}});
339 // Slice at dimension boundaries.
340 RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 1},
341 {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}});
342 // Zero-sized update.
343 RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{}}, {2, 1},
344 {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
345 }
346
347 template <typename IndexT, typename DataT>
TestR3()348 void TestR3() {
349 // R3 Shape: [2, 3, 2]
350 // Slice at dimension start.
351 RunR3<IndexT, DataT>(
352 {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}},
353 {{{13, 14}, {15, 16}}, {{17, 18}, {19, 20}}}, {0, 0, 0},
354 {{{13, 14}, {15, 16}, {5, 6}}, {{17, 18}, {19, 20}, {11, 12}}});
355 // Slice in the middle.
356 RunR3<IndexT, DataT>(
357 {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}},
358 {1, 1, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}});
359 }
360
361 template <typename IndexT, typename DataT>
TestOOB()362 void TestOOB() {
363 // // Slice at dimension boundaries, but with out of bounds indices.
364 RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6},
365 {0, 1, 2, 3, 4, 8, 9, 10});
366 // R2 Shape: [3, 3]
367 RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2},
368 {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}});
369 // R3 Shape: [2, 3, 2]
370 RunR3<IndexT, DataT>(
371 {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}},
372 {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}});
373 }
374
375 template <typename IndexT, typename DataT>
RunR0(int input_value_int,int update_value_int,const std::vector<IndexT> slice_starts,int expected_value_int)376 void RunR0(int input_value_int, int update_value_int,
377 const std::vector<IndexT> slice_starts, int expected_value_int) {
378 Literal input_value =
379 std::move(LiteralUtil::CreateR0(input_value_int)
380 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
381 .ValueOrDie());
382 Literal update_value =
383 std::move(LiteralUtil::CreateR0(update_value_int)
384 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
385 .ValueOrDie());
386 Literal expected_value =
387 std::move(LiteralUtil::CreateR0(expected_value_int)
388 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
389 .ValueOrDie());
390
391 XlaBuilder builder(TestName());
392 // Build dynamic slice computation.
393 auto input = ConstantLiteral(&builder, input_value);
394 auto update = ConstantLiteral(&builder, update_value);
395 DynamicUpdateSlice(input, update, absl::Span<const XlaOp>({}));
396 // Run computation and compare against expected values.
397 ComputeAndCompareLiteral(&builder, expected_value, {});
398 }
399
400 template <typename IndexT, typename DataT>
RunR1(absl::Span<const int> input_values_int,absl::Span<const int> update_values_int,const std::vector<IndexT> slice_starts,absl::Span<const int> expected_values_int)401 void RunR1(absl::Span<const int> input_values_int,
402 absl::Span<const int> update_values_int,
403 const std::vector<IndexT> slice_starts,
404 absl::Span<const int> expected_values_int) {
405 Literal input_values =
406 std::move(LiteralUtil::CreateR1(input_values_int)
407 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
408 .ValueOrDie());
409 Literal update_values =
410 std::move(LiteralUtil::CreateR1(update_values_int)
411 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
412 .ValueOrDie());
413 Literal expected_values =
414 std::move(LiteralUtil::CreateR1(expected_values_int)
415 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
416 .ValueOrDie());
417
418 XlaBuilder builder(TestName());
419 // Initialize and transfer dynamic slice start indices parameter.
420 XlaOp starts;
421 std::unique_ptr<GlobalData> start_data = CreateR0Parameter<IndexT>(
422 slice_starts[0], 0, "slice_starts", &builder, &starts);
423 // Build dynamic slice computation.
424 auto input = ConstantLiteral(&builder, input_values);
425 auto update = ConstantLiteral(&builder, update_values);
426 DynamicUpdateSlice(input, update, absl::Span<const XlaOp>({starts}));
427 // Run computation and compare against expected values.
428 ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
429 }
430
431 template <typename IndexT, typename DataT>
RunR2(const Array2D<int> & input_values_int,const Array2D<int> & update_values_int,const std::vector<IndexT> slice_starts,const Array2D<int> & expected_values_int)432 void RunR2(const Array2D<int>& input_values_int,
433 const Array2D<int>& update_values_int,
434 const std::vector<IndexT> slice_starts,
435 const Array2D<int>& expected_values_int) {
436 Literal input_values =
437 std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
438 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
439 .ValueOrDie());
440 Literal update_values =
441 std::move(LiteralUtil::CreateR2FromArray2D(update_values_int)
442 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
443 .ValueOrDie());
444 Literal expected_values =
445 std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
446 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
447 .ValueOrDie());
448
449 XlaBuilder builder(TestName());
450 // Initialize and transfer dynamic slice start indices parameter.
451 std::vector<XlaOp> starts(2);
452 std::vector<std::unique_ptr<GlobalData>> start_data(2);
453 for (int i = 0; i < 2; ++i) {
454 start_data[i] = CreateR0Parameter<IndexT>(
455 slice_starts[i], i, "slice_starts", &builder, &starts[i]);
456 }
457 // Build dynamic slice computation.
458 auto input = ConstantLiteral(&builder, input_values);
459 auto update = ConstantLiteral(&builder, update_values);
460 DynamicUpdateSlice(input, update, starts);
461 // Run computation and compare against expected values.
462 std::vector<GlobalData*> argument_ptrs;
463 absl::c_transform(start_data, std::back_inserter(argument_ptrs),
464 [](const std::unique_ptr<GlobalData>& argument) {
465 return argument.get();
466 });
467 ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs);
468 }
469
470 template <typename IndexT, typename DataT>
RunR3(const Array3D<int> & input_values_int,const Array3D<int> & update_values_int,const std::vector<IndexT> slice_starts,const Array3D<int> & expected_values_int)471 void RunR3(const Array3D<int>& input_values_int,
472 const Array3D<int>& update_values_int,
473 const std::vector<IndexT> slice_starts,
474 const Array3D<int>& expected_values_int) {
475 Literal input_values =
476 std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
477 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
478 .ValueOrDie());
479 Literal update_values =
480 std::move(LiteralUtil::CreateR3FromArray3D(update_values_int)
481 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
482 .ValueOrDie());
483 Literal expected_values =
484 std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
485 .Convert(primitive_util::NativeToPrimitiveType<DataT>())
486 .ValueOrDie());
487
488 XlaBuilder builder(TestName());
489 // Initialize and transfer dynamic slice start indices parameter.
490 std::vector<XlaOp> starts(3);
491 std::vector<std::unique_ptr<GlobalData>> start_data(3);
492 for (int i = 0; i < 3; ++i) {
493 start_data[i] = CreateR0Parameter<IndexT>(
494 slice_starts[i], i, "slice_starts", &builder, &starts[i]);
495 }
496
497 // Build dynamic slice computation.
498 auto input = ConstantLiteral(&builder, input_values);
499 auto update = ConstantLiteral(&builder, update_values);
500 DynamicUpdateSlice(input, update, starts);
501 // Run computation and compare against expected values.
502 std::vector<GlobalData*> argument_ptrs;
503 absl::c_transform(start_data, std::back_inserter(argument_ptrs),
504 [](const std::unique_ptr<GlobalData>& argument) {
505 return argument.get();
506 });
507 ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs);
508 }
509
510 template <class T>
RunR3Contiguous(std::vector<int32_t> operand_shape,int32_t index,int32_t size)511 void RunR3Contiguous(std::vector<int32_t> operand_shape, int32_t index,
512 int32_t size) {
513 const int32_t kSeq = operand_shape[0];
514 const int32_t kBatch = operand_shape[1];
515 const int32_t kDim = operand_shape[2];
516 Array3D<T> input_values(kSeq, kBatch, kDim);
517 Array3D<T> update_values(size, kBatch, kDim);
518 Array3D<T> expected_values(kSeq, kBatch, kDim);
519 index = std::min(std::max(0, index), kSeq - size);
520
521 input_values.FillIota(static_cast<T>(0));
522 T value = static_cast<T>(10);
523 update_values.FillIota(static_cast<T>(value));
524
525 // TODO(b/34128753) Expected values may vary depending on backend when
526 // the indices are out of bounds.
527 expected_values.FillIota(static_cast<T>(0));
528 for (int i = 0; i < size; i++) {
529 for (int j = 0; j < kBatch; j++) {
530 for (int k = 0; k < kDim; k++) {
531 expected_values(index + i, j, k) = value++;
532 }
533 }
534 }
535 if (VLOG_IS_ON(1)) {
536 DumpArray<T>("input", input_values);
537 DumpArray<T>("update", update_values);
538 DumpArray<T>("expected", expected_values);
539 }
540
541 // Build dynamic slice computation.
542 XlaBuilder builder(TestName());
543 // Initialize and transfer input parameter.
544 XlaOp input;
545 std::unique_ptr<GlobalData> input_data =
546 CreateR3Parameter<T>(input_values, 0, "input_values", &builder, &input);
547 // Initialize and transfer update parameter.
548 XlaOp update;
549 std::unique_ptr<GlobalData> update_data = CreateR3Parameter<T>(
550 update_values, 1, "update_values", &builder, &update);
551 auto constant_index = ConstantR0<int32_t>(&builder, index);
552 auto zero = ConstantR0<int32_t>(&builder, 0);
553 DynamicUpdateSlice(input, update, {constant_index, zero, zero});
554
555 // Run computation and compare against expected values.
556 ComputeAndCompareR3<T>(&builder, expected_values,
557 {input_data.get(), update_data.get()},
558 ErrorSpec(0.000001));
559 }
560
561 template <typename NativeT>
DumpArray(const std::string & name,const Array3D<NativeT> values)562 void DumpArray(const std::string& name, const Array3D<NativeT> values) {
563 Literal literal = LiteralUtil::CreateR3FromArray3D<NativeT>(values);
564 LOG(INFO) << name << ":" << literal.ToString();
565 }
566 };
567
// Scalar DynamicUpdateSlice over several index/element type combinations.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R0BF16) {
  TestR0<int32_t, bfloat16>();
}

XLA_TEST_F(DynamicUpdateSliceTest, Int32R0) {
  TestR0<int32_t, float>();
}

XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) {
  TestR0<int64_t, float>();
}

XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) {
  TestR0<uint64_t, float>();
}
572
// Rank-1 DynamicUpdateSlice over several index/element type combinations.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1BF16) {
  TestR1<int32_t, bfloat16>();
}

XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) {
  TestR1<int32_t, float>();
}

XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) {
  TestR1<int64_t, float>();
}

XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) {
  TestR1<uint64_t, float>();
}

// Unsigned start index of 2^31, far past the end of a five-element operand;
// the expected result has the update in the last two positions.
XLA_TEST_F(DynamicUpdateSliceTest, UInt32R1OOB) {
  constexpr uint32_t kHugeStart = 2147483648u;
  RunR1<uint32_t, int32_t>(/*input_values_int=*/{0, 1, 2, 3, 4},
                           /*update_values_int=*/{5, 6},
                           /*slice_starts=*/{kHugeStart},
                           /*expected_values_int=*/{0, 1, 2, 5, 6});
}
581
// Rank-2 DynamicUpdateSlice over several index/element type combinations.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R2BF16) {
  TestR2<int32_t, bfloat16>();
}

XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) {
  TestR2<int32_t, float>();
}

XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) {
  TestR2<int64_t, int64_t>();
}

XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) {
  TestR2<uint64_t, int32_t>();
}

// Unsigned row start of 2^31 on a 2x2 operand; the expected result has the
// 1x1 update in the first column of the last row.
XLA_TEST_F(DynamicUpdateSliceTest, UInt32R2OOB) {
  constexpr uint32_t kHugeStart = 2147483648u;
  RunR2<uint32_t, int32_t>(/*input_values_int=*/{{0, 1}, {2, 3}},
                           /*update_values_int=*/{{4}},
                           /*slice_starts=*/{kHugeStart, 0},
                           /*expected_values_int=*/{{0, 1}, {4, 3}});
}
590
// Rank-3 DynamicUpdateSlice over several index/element type combinations.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R3BF16) {
  TestR3<int32_t, bfloat16>();
}

XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) {
  TestR3<int32_t, float>();
}

XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) {
  TestR3<int64_t, int64_t>();
}

XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) {
  TestR3<uint64_t, uint64_t>();
}

// Unsigned starts of 2^31 in the major and minor dimensions of a 2x2x2
// operand; the expected result has the 1x1x1 update at element [1][0][1].
XLA_TEST_F(DynamicUpdateSliceTest, UInt32R3OOB) {
  constexpr uint32_t kHugeStart = 2147483648u;
  RunR3<uint32_t, int32_t>(
      /*input_values_int=*/{{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}},
      /*update_values_int=*/{{{8}}},
      /*slice_starts=*/{kHugeStart, 0, kHugeStart},
      /*expected_values_int=*/{{{0, 1}, {2, 3}}, {{4, 8}, {6, 7}}});
}
600
// Out-of-bounds DynamicUpdateSlice starts over several index/element type
// combinations.
XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) {
  TestOOB<int32_t, bfloat16>();
}

XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) {
  TestOOB<int32_t, float>();
}

XLA_TEST_F(DynamicUpdateSliceTest, Int64OOB) {
  TestOOB<int64_t, int64_t>();
}

XLA_TEST_F(DynamicUpdateSliceTest, UInt64OOB) {
  TestOOB<uint64_t, uint64_t>();
}
607
// Rank-1 DynamicUpdateSlice on boolean (PRED) data with int32 indices.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) {
  const std::vector<int> operand = {false, false, true, true,
                                    false, true,  true, false};

  // Update written at the first element.
  RunR1<int32_t, bool>(
      operand, /*update_values_int=*/{true, true, false},
      /*slice_starts=*/{0},
      /*expected_values_int=*/
      {true, true, false, true, false, true, true, false});
  // Update written into the interior.
  RunR1<int32_t, bool>(
      operand, /*update_values_int=*/{false, true, true},
      /*slice_starts=*/{2},
      /*expected_values_int=*/
      {false, false, false, true, true, true, true, false});
  // Update that ends exactly at the last element.
  RunR1<int32_t, bool>(
      operand, /*update_values_int=*/{false, true, true},
      /*slice_starts=*/{5},
      /*expected_values_int=*/
      {false, false, true, true, false, false, true, true});
  // Empty update: the operand must come back unchanged.
  RunR1<int32_t, bool>(
      operand, /*update_values_int=*/{},
      /*slice_starts=*/{2},
      /*expected_values_int=*/
      {false, false, true, true, false, true, true, false});
}
626
// Rank-2 DynamicUpdateSlice on boolean (PRED) data with int32 indices.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R2Pred) {
  const Array2D<int> operand = {
      {false, true, false}, {true, false, true}, {false, true, true}};
  const Array2D<int> update = {{true, false}};

  // Update written at the origin.
  RunR2<int32_t, bool>(
      operand, update, /*slice_starts=*/{0, 0},
      /*expected_values_int=*/
      {{true, false, false}, {true, false, true}, {false, true, true}});
  // Update written into the interior.
  RunR2<int32_t, bool>(
      operand, update, /*slice_starts=*/{1, 1},
      /*expected_values_int=*/
      {{false, true, false}, {true, true, false}, {false, true, true}});
  // Update that ends exactly at the last row and column.
  RunR2<int32_t, bool>(
      operand, update, /*slice_starts=*/{2, 1},
      /*expected_values_int=*/
      {{false, true, false}, {true, false, true}, {false, true, false}});
  // Empty update: the operand must come back unchanged.
  RunR2<int32_t, bool>(
      operand, /*update_values_int=*/{{}}, /*slice_starts=*/{2, 1},
      /*expected_values_int=*/
      {{false, true, false}, {true, false, true}, {false, true, true}});
}
648
// Rank-3 DynamicUpdateSlice on boolean (PRED) data of shape [2, 3, 2] with
// int32 indices.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) {
  const Array3D<int> operand = {
      {{true, false}, {false, true}, {true, true}},
      {{false, false}, {false, true}, {true, false}}};

  // A 2x2x2 update written at the origin overwrites the first two rows of
  // each plane.
  RunR3<int32_t, bool>(
      operand,
      /*update_values_int=*/
      {{{false, true}, {true, false}}, {{true, true}, {false, true}}},
      /*slice_starts=*/{0, 0, 0},
      /*expected_values_int=*/
      {{{false, true}, {true, false}, {true, true}},
       {{true, true}, {false, true}, {true, false}}});

  // A 1x2x1 update written at {1, 1, 1} overwrites the second column of the
  // last two rows of the second plane.
  RunR3<int32_t, bool>(
      operand,
      /*update_values_int=*/{{{false}, {true}}},
      /*slice_starts=*/{1, 1, 1},
      /*expected_values_int=*/
      {{{true, false}, {false, true}, {true, true}},
       {{false, false}, {false, false}, {true, true}}});
}
666
667 // Tests for simple R3 case where the update is contiguous (i.e. the minor
668 // two dimensions are not sliced).
// One major-dimension slot updated at an in-bounds index (float).
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) {
  const std::vector<int32_t> operand_shape = {4, 5, 2};
  RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/1);
}

// One major-dimension slot updated at an in-bounds index (bfloat16).
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) {
  const std::vector<int32_t> operand_shape = {4, 5, 2};
  RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/1);
}
680
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) {
  // float data: multiple elements (size 2) starting at index 1; the update
  // stays in bounds for a leading extent of 4.
  std::vector<int32_t> dims{4, 5, 2};
  RunR3Contiguous<float>(dims, /*index=*/1, /*size=*/2);
}
686
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) {
  // bfloat16 data: multiple elements (size 2) starting at index 1; the update
  // stays in bounds for a leading extent of 4.
  std::vector<int32_t> dims{4, 5, 2};
  RunR3Contiguous<bfloat16>(dims, /*index=*/1, /*size=*/2);
}
692
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOB) {
  // float data, out-of-bounds case: multiple elements (size 2) starting at
  // index 3, so index + size (5) exceeds the leading extent of 4.
  std::vector<int32_t> dims{4, 5, 2};
  RunR3Contiguous<float>(dims, /*index=*/3, /*size=*/2);
}
698
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOBBF16) {
  // bfloat16 data, out-of-bounds case: multiple elements (size 2) starting at
  // index 3, so index + size (5) exceeds the leading extent of 4.
  std::vector<int32_t> dims{4, 5, 2};
  RunR3Contiguous<bfloat16>(dims, /*index=*/3, /*size=*/2);
}
704
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLarge) {
  // float data, "too large" case: multiple elements (size 2) with index 5,
  // which is already past a leading extent of 4.
  // NOTE(review): the test name suggests the update is larger than the
  // operand, but the arguments only show an oversized index -- confirm
  // against RunR3Contiguous.
  std::vector<int32_t> dims{4, 5, 2};
  RunR3Contiguous<float>(dims, /*index=*/5, /*size=*/2);
}
710
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLargeBF16) {
  // bfloat16 data, "too large" case: multiple elements (size 2) with index 5,
  // which is already past a leading extent of 4.
  // NOTE(review): the test name suggests the update is larger than the
  // operand, but the arguments only show an oversized index -- confirm
  // against RunR3Contiguous.
  std::vector<int32_t> dims{4, 5, 2};
  RunR3Contiguous<bfloat16>(dims, /*index=*/5, /*size=*/2);
}
716
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) {
  // float data with odd trailing extents (123 and 247); a size-1 update at
  // index 1.
  std::vector<int32_t> dims{3, 123, 247};
  RunR3Contiguous<float>(dims, /*index=*/1, /*size=*/1);
}
721
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnalignedBF16) {
  // bfloat16 data with odd trailing extents (123 and 247); a size-1 update at
  // index 1.
  std::vector<int32_t> dims{3, 123, 247};
  RunR3Contiguous<bfloat16>(dims, /*index=*/1, /*size=*/1);
}
726
727 // TODO(b/34134076) Disabled on GPU 2016-01-06 due to out-of-memory error.
XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLarger)) {
  // float data on a larger [32, 128, 1024] operand; a size-1 update at
  // index 7. Wrapped in DISABLED_ON_GPU.
  std::vector<int32_t> dims{32, 128, 1024};
  RunR3Contiguous<float>(dims, /*index=*/7, /*size=*/1);
}
732
XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLargerBF16)) {
  // bfloat16 data on a larger [32, 128, 1024] operand; a size-1 update at
  // index 7. Wrapped in DISABLED_ON_GPU.
  std::vector<int32_t> dims{32, 128, 1024};
  RunR3Contiguous<bfloat16>(dims, /*index=*/7, /*size=*/1);
}
737
// This tests that buffer assignment does not alias constants with the output of
// dynamic update slice.
XLA_TEST_F(HloTestBase, AddOfDUS) {
  // The module updates a constant operand `o` with a dynamic-update-slice,
  // adds the result to itself, and returns a dynamic-slice of the sum.
  const ErrorSpec zero_error{0, 0};
  const char* const kHloText = R"(
  HloModule m
  test {
    o = s32[6] constant({2,3,4,5,6,7})
    i = s32[] parameter(0)
    u = s32[2] parameter(1)
    dus = s32[6] dynamic-update-slice(o,u,i)
    a = s32[6] add(dus, dus)
    j = s32[] parameter(2)
    ROOT ds = s32[2] dynamic-slice(a, j), dynamic_slice_sizes={2}
  }
)";
  // The comparison must succeed with an ErrorSpec of {0, 0}.
  EXPECT_TRUE(RunAndCompare(kHloText, zero_error));
}
755
BM_DynamicSlice(::testing::benchmark::State & state)756 void BM_DynamicSlice(::testing::benchmark::State& state) {
757 se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
758 auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
759 se::StreamExecutorMemoryAllocator allocator(platform, executors);
760 LocalClient* client =
761 ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
762 auto* transfer_manager =
763 TransferManager::GetForPlatform(platform).ValueOrDie();
764 int device_ordinal = client->default_device_ordinal();
765
766 XlaBuilder builder("DynamicSlice");
767
768 // Create input as a constant: shape [1, 2, 3, 4]
769 auto input_literal = LiteralUtil::CreateR4(
770 {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
771 {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
772 auto input = ConstantLiteral(&builder, input_literal);
773
774 auto stream =
775 client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
776
777 // Create dynamic slice start indices as a parameter: shape [4]
778 auto start_indices_shape = ShapeUtil::MakeShape(S32, {});
779 std::vector<XlaOp> start_indices(4);
780 std::vector<ScopedShapedBuffer> shaped_buffers;
781 std::vector<const Shape*> host_shapes(4);
782 for (int i = 0; i < 4; ++i) {
783 start_indices[i] =
784 Parameter(&builder, i, start_indices_shape, "start_indices");
785 auto start_index_literal = LiteralUtil::CreateR0<int32_t>(i + 1);
786 // Initialize and transfer parameter buffer.
787 shaped_buffers.emplace_back(
788 client->backend()
789 .transfer_manager()
790 ->AllocateScopedShapedBuffer(start_indices_shape, &allocator,
791 /*device_ordinal=*/0)
792 .value());
793 host_shapes[i] = &shaped_buffers[i].on_host_shape();
794 ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
795 stream.get(), start_index_literal, shaped_buffers[i]));
796 }
797
798 // Add DynamicSlice op to the computation.
799 DynamicSlice(input, start_indices, {1, 1, 1, 1});
800 auto computation = builder.Build().value();
801
802 TF_ASSERT_OK_AND_ASSIGN(
803 auto executables,
804 client->Compile(computation, host_shapes, ExecutableBuildOptions()));
805 auto executable = std::move(executables[0]);
806
807 // Run some warm-up executions.
808 ExecutableRunOptions options;
809 options.set_allocator(&allocator);
810 const int kWarmups = 2;
811 std::vector<const ShapedBuffer*> shaped_buffer_ptrs;
812 absl::c_transform(shaped_buffers, std::back_inserter(shaped_buffer_ptrs),
813 [](const ScopedShapedBuffer& buffer) { return &buffer; });
814
815 for (int i = 0; i < kWarmups; ++i) {
816 auto result = executable->Run(shaped_buffer_ptrs, options);
817 ASSERT_TRUE(result.ok());
818 }
819
820 // Run benchmark.
821 for (auto s : state) {
822 auto result = executable->Run(shaped_buffer_ptrs, options);
823 ASSERT_TRUE(result.ok());
824 }
825 }
826 BENCHMARK(BM_DynamicSlice);
827
828 } // namespace
829 } // namespace xla
830