1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
17
18 #include "absl/strings/str_replace.h"
19 #include "tensorflow/compiler/xla/client/xla_builder.h"
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
22 #include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_dce.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/service/hlo_parser.h"
30 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
31 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
32 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/test.h"
36 #include "tensorflow/compiler/xla/test_helpers.h"
37 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
38 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
39 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
40 #include "tensorflow/compiler/xla/tests/llvm_irgen_test_base.h"
41 #include "tensorflow/compiler/xla/tests/test_macros.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/compiler/xla/xla_data.pb.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/platform/test_benchmark.h"
46 #include "tensorflow/core/protobuf/error_codes.pb.h"
47
48 namespace xla {
49 namespace {
50
51 namespace m = ::xla::match;
52 namespace op = xla::testing::opcode_matchers;
53
OpHasDynamismSupport(HloInstruction * hlo)54 OpDynamismSupport OpHasDynamismSupport(HloInstruction* hlo) {
55 if (hlo->opcode() != HloOpcode::kCustomCall) {
56 return OpDynamismSupport::kNoSupport;
57 }
58 if (hlo->custom_call_target() == "OpWithDynamicLowering") {
59 return OpDynamismSupport::kRequired;
60 }
61 return OpDynamismSupport::kNoSupport;
62 }
63
CustomCallDynamicDimensionInference(HloInstruction * hlo,DynamicDimensionInference * inferencer)64 Status CustomCallDynamicDimensionInference(
65 HloInstruction* hlo, DynamicDimensionInference* inferencer) {
66 if (hlo->custom_call_target() == "OpWithDynamicLowering") {
67 if (hlo->shape().IsTuple()) {
68 // Use the operand's dynamic size as output dynamic size.
69 HloInstruction* dynamic_size =
70 inferencer->GetDynamicSize(hlo->mutable_operand(0), {1}, 0);
71 inferencer->SetDynamicSize(hlo, {1}, 0, dynamic_size);
72 } else {
73 // Use the operand's dynamic size as output dynamic size.
74 HloInstruction* dynamic_size =
75 inferencer->GetDynamicSize(hlo->mutable_operand(0), {}, 0);
76 inferencer->SetDynamicSize(hlo, {}, 0, dynamic_size);
77 }
78 }
79
80 return OkStatus();
81 }
82
83 class DynamicPadderTest : public HloTestBase {
84 protected:
DynamicPadderTest()85 DynamicPadderTest() : HloTestBase() { module_ = CreateNewVerifiedModule(); }
86
GetHloModule(const std::string & hlo_text)87 std::unique_ptr<HloModule> GetHloModule(const std::string& hlo_text) {
88 std::unique_ptr<HloModule> module =
89 ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
90 return module;
91 }
92
RunPadder(bool slice_dynamic_output=false)93 StatusOr<bool> RunPadder(bool slice_dynamic_output = false) {
94 DynamicPadderOptions options;
95 options.slice_dynamic_output = slice_dynamic_output;
96 options.custom_call_handler = CustomCallDynamicDimensionInference;
97 options.op_supports_dynamism_handler = OpHasDynamismSupport;
98 DynamicPadder padder(std::move(options));
99 return RunHloPass(&padder, module_.get());
100 }
101
ExpectPadded(const HloInstruction * inst)102 void ExpectPadded(const HloInstruction* inst) {
103 EXPECT_THAT(inst,
104 op::Select(op::Lt(op::Iota(), op::Broadcast(op::Parameter())),
105 ::testing::_, op::Broadcast()));
106 }
107
GetScalarAddComputation()108 HloComputation* GetScalarAddComputation() {
109 auto embedded_builder = HloComputation::Builder("add");
110 auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
111 0, ShapeUtil::MakeShape(F32, {}), "lhs"));
112 auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
113 1, ShapeUtil::MakeShape(F32, {}), "rhs"));
114 embedded_builder.AddInstruction(
115 HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
116 return module_->AddEmbeddedComputation(embedded_builder.Build());
117 }
118
119 std::unique_ptr<HloModule> module_;
120 const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {});
121 };
122
123 class MemoryAlignmentTest : public HloTestBase {};
124
// Test that dynamic padder will not cause memory misalignment in CUDA
// when the read or write address is not aligned with 32 bits.
// TODO(b/203599920): Disabled on CPU due to ASAN test failure.
TEST_F(MemoryAlignmentTest, DISABLED_ON_CPU(TestDataTypeFP16)) {
  const std::string hlo_text = R"(
    HloModule TestDataTypeFP16

    update_add (p0: f16[], p1: f16[]) -> f16[] {
      p0 = f16[] parameter(0)
      p1 = f16[] parameter(1)
      ROOT out = f16[] add(p0, p1)
    }

    ENTRY main () -> f16[<=1,1] {
      c1 = s32[1]{0} constant({1})
      c2 = f16[1,1]{1,0} constant({ {0.099976} })
      shape = s32[] reshape(s32[1]{0} c1)
      dim_size = f16[<=1,1]{1,0} set-dimension-size(f16[1,1]{1,0} c2, s32[] shape),
        dimensions={0}
      ROOT out = f16[<=1,1]{1,0} scatter(f16[<=1,1]{1,0} dim_size, s32[1]{0} c1, f16[1,1]{1,0} c2),
        update_window_dims={1},
        inserted_window_dims={0},
        scatter_dims_to_operand_dims={0},
        index_vector_dim=1,
        to_apply=update_add
    }
  )";
  // Compare the padded execution against an interpreter reference within a
  // small fp tolerance.
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
154
// Reducing along a dynamic dimension must mask out the padded elements, so
// the reduce's operand is rewritten into the select(iota < size, ...) form.
TEST_F(DynamicPadderTest, ReduceTest) {
  auto builder = HloComputation::Builder(TestName());
  const auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const auto reduce_shape = ShapeUtil::MakeShape(F32, {2});

  HloInstruction* data_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "data_param"));
  builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param));
  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      reduce_shape, negate, init, {0, 2}, GetScalarAddComputation()));
  EXPECT_FALSE(module_->is_dynamic());
  module_->AddEntryComputation(builder.Build());

  // Dimension 2 of parameter 0 is dynamic; its real size is parameter 1.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 2}));

  TF_ASSERT_OK(RunPadder().status());

  ExpectPadded(reduce->operand(0));
  EXPECT_TRUE(module_->is_dynamic());
}
186
// Custom calls that require dynamic lowering keep their dynamic form: the
// padder inserts SliceToDynamic before them and PadToStatic after them, while
// ops without dynamic support (negate) run on the static form.
TEST_F(DynamicPadderTest, DynamicLoweringTest) {
  const std::string hlo_text = R"(
HloModule DynamicLowering

ENTRY main {
  param = s32[5] parameter(0)
  const = s32[] constant(3)
  param_padded = s32[<=5] set-dimension-size(param, const),
                dimensions={0}
  custom-call.1 = s32[<=5] custom-call(param_padded),
    custom_call_target="OpWithDynamicLowering"
  custom-call.2 = s32[<=5] custom-call(custom-call.1),
    custom_call_target="OpWithDynamicLowering"
  // Negate doesn't support dynamic lowering.
  ROOT negate = s32[<=5] negate(custom-call.2)
}
)";

  module_ = GetHloModule(hlo_text);

  TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());
  // Expected graph after the rewrite:
  //   param -> SliceToDynamic -> custom-call.1 -> custom-call.2
  //         -> PadToStatic -> Negate -> SliceToDynamic (root)
  // The root requires the dynamic form of the tensor.
  auto custom_call_1 =
      module_->entry_computation()->GetInstructionWithName("custom-call.1");
  auto custom_call_2 =
      module_->entry_computation()->GetInstructionWithName("custom-call.2");
  // The input to the first custom call is converted to dynamic form.
  HloInstruction* slice_to_dynamic = custom_call_1->mutable_operand(0);
  ASSERT_EQ(slice_to_dynamic->opcode(), HloOpcode::kCustomCall);
  ASSERT_EQ(slice_to_dynamic->custom_call_target(), "SliceToDynamic");
  // The output of the last custom call is padded back to static form.
  ASSERT_EQ(custom_call_2->user_count(), 1);
  HloInstruction* pad_to_static = custom_call_2->users()[0];
  ASSERT_EQ(pad_to_static->opcode(), HloOpcode::kCustomCall);
  ASSERT_EQ(pad_to_static->custom_call_target(), "PadToStatic");
  // The root is sliced back to its dynamic shape.
  slice_to_dynamic = module_->entry_computation()->root_instruction();
  ASSERT_EQ(slice_to_dynamic->opcode(), HloOpcode::kCustomCall);
  ASSERT_EQ(slice_to_dynamic->custom_call_target(), "SliceToDynamic");
}
239
TEST_F(DynamicPadderTest, DynamicLoweringTestTupleInput) {
  const std::string hlo_text = R"(
HloModule DynamicLowering

ENTRY main {
  param = s32[5] parameter(0)
  const = s32[] constant(3)
  param_padded = s32[<=5] set-dimension-size(param, const),
                dimensions={0}
  // Create a tuple with static and dynamic componenet.
  tuple_arg = (s32[], s32[<=5]) tuple(const, param_padded)
  custom-call.1 = (s32[], s32[<=5]) custom-call(tuple_arg),
    custom_call_target="OpWithDynamicLowering"
  custom-call.2 = (s32[], s32[<=5]) custom-call(custom-call.1),
    custom_call_target="OpWithDynamicLowering"
  data = s32[<=5]{0} get-tuple-element(custom-call.2), index=1
  // Negate doesn't support dynamic lowering.
  ROOT negate = s32[<=5] negate(data)
}
)";

  module_ = GetHloModule(hlo_text);

  TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());
  // After rewrite, we should have :
  //
  //   param
  //     |
  //  SliceToDynamic
  //     |
  //    Tuple
  //     |
  // OpWithDynamicLowering (custom_call_1)
  //     |
  // OpWithDynamicLowering (custom_call_2)
  //     |
  //    GTE
  //     |
  //  PadToStatic
  //     |
  //   Negate
  //     |
  //  SliceToDynamic // Root require dynamic form tensor.

  // Root is sliced back to dynamic form; the dynamic size arrives as a
  // constant operand of SliceToDynamic.
  auto* root = module_->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::CustomCall("SliceToDynamic", op::Negate(), op::Constant()));
  // Negate consumes the statically-shaped element extracted from the
  // PadToStatic wrapping the second dynamic custom call's tuple output.
  HloInstruction* negate = root->mutable_operand(0);
  EXPECT_THAT(
      negate,
      op::Negate(op::GetTupleElement(op::CustomCall(
          "PadToStatic", op::GetTupleElement(op::CustomCall(
                             "OpWithDynamicLowering", ::testing::_))))));
  // The tuple fed to the first custom call has only its dynamic element
  // converted by SliceToDynamic; the static element passes through.
  auto custom_call_1 =
      module_->entry_computation()->GetInstructionWithName("custom-call.1");
  EXPECT_THAT(custom_call_1,
              op::CustomCall("OpWithDynamicLowering",
                             op::Tuple(op::GetTupleElement(),
                                       op::CustomCall("SliceToDynamic"))));
}
300
TEST_F(DynamicPadderTest, DynamicOutputNestedTuple) {
  const std::string hlo_text = R"(
HloModule DynamicLowering

ENTRY main {
  param = s32[5] parameter(0)
  const = s32[] constant(3)
  const2 = s32[] constant(4)
  param_padded = s32[<=5] set-dimension-size(param, const),
                dimensions={0}
  // Create a tuple with static and dynamic componenet.
  tuple0 = (s32[], s32[<=5]) tuple(const, param_padded)
  ROOT tuple1 = (s32[], (s32[], s32[<=5])) tuple(const2, tuple0)
}
)";

  module_ = GetHloModule(hlo_text);

  TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());
  TF_ASSERT_OK(TupleSimplifier().Run(module_.get()).status());
  XLA_LOG_LINES(0, module_->ToString());

  // Only the dynamic leaf of the nested tuple output should be rewritten to
  // SliceToDynamic; the static components stay untouched.
  auto* root = module_->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Constant(), op::Tuple()));
  HloInstruction* nested_tuple = root->mutable_operand(1);
  EXPECT_THAT(nested_tuple,
              op::Tuple(op::Constant(), op::CustomCall("SliceToDynamic")));
}
329
// A convolution whose contracting (feature) dimension is dynamic must have
// the padded elements masked out, since they would contribute to the result.
TEST_F(DynamicPadderTest, ConvolutionTest) {
  auto builder = HloComputation::Builder(TestName());
  constexpr int xdim = 3;
  constexpr int ydim = 2;
  constexpr int zdim = 1;
  const auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
  const auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
  const auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim});

  HloInstruction* a_param = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/0, xy_shape, "A"));
  HloInstruction* b_param = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/1, yz_shape, "B"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(1);
  dnums.set_output_feature_dimension(0);

  Window window;
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      zx_shape, a_param, b_param, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums,
      HloTestBase::DefaultPrecisionConfig(2)));
  module_->AddEntryComputation(builder.Build());

  // Bind the contracting dimension (dim 1 of parameter 0) to parameter 2.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_ASSERT_OK(RunPadder().status());

  ExpectPadded(conv->operand(0));
}
372
// A convolution whose dynamic dimension is non-contracting needs no masking:
// padded elements never mix with real ones, so the operand is left alone.
TEST_F(DynamicPadderTest, ConvolutionNoPad) {
  auto builder = HloComputation::Builder(TestName());
  constexpr int xdim = 3;
  constexpr int ydim = 2;
  constexpr int zdim = 1;
  const auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
  const auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
  const auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim});

  HloInstruction* a_param = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/0, xy_shape, "A"));
  HloInstruction* b_param = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/1, yz_shape, "B"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(1);
  dnums.set_output_feature_dimension(0);

  Window window;
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      zx_shape, a_param, b_param, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums,
      HloTestBase::DefaultPrecisionConfig(2)));
  module_->AddEntryComputation(builder.Build());

  // Bind a non-contracting dimension (dim 0 of parameter 0) to parameter 2.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunPadder().status());

  // The operand must reach the convolution unchanged.
  EXPECT_THAT(conv->operand(0), op::Parameter());
}
415
// A reduce-window whose window along the dynamic dimension has size 1 never
// mixes padded with real elements, so no masking rewrite is needed.
TEST_F(DynamicPadderTest, ReduceWindowNoPadForTrivialWindow) {
  auto builder = HloComputation::Builder(TestName());
  const auto input_shape = ShapeUtil::MakeShape(F32, {4, 5});
  const auto reduce_shape = ShapeUtil::MakeShape(F32, {3, 5});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "input"));
  builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  TF_ASSERT_OK_AND_ASSIGN(Window window, ParseWindow("size=2x1 pad=0_0x0_0"));
  HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateReduceWindow(
          reduce_shape, input, init, window, GetScalarAddComputation()));
  module_->AddEntryComputation(builder.Build());

  // Dimension 1 of parameter 0 is dynamic; the window along it is trivial.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_ASSERT_OK(RunPadder().status());

  // The input must be fed to the reduce-window unchanged.
  EXPECT_THAT(output->operand(0), op::Parameter());
}
442
// Variadic flavor of the trivial-window case: a multi-operand reduce-window
// with window size 1 along the dynamic dimension needs no padding rewrite on
// any of its inputs.
TEST_F(DynamicPadderTest, VariadicReduceWindowNoPadForTrivialWindow) {
  const std::string hlo_text = R"(
HloModule VariadicReduceWindowNoPadForTrivialWindow

add_f32 (a: f32[], b: s32[], c: f32[], d: s32[]) -> (f32[], s32[]) {
  a = f32[] parameter(0)
  b = s32[] parameter(1)
  c = f32[] parameter(2)
  d = s32[] parameter(3)
  add.0 = f32[] add(a, c)
  add.1 = s32[] add(b, d)
  ROOT out = tuple(add.0, add.1)
}

ENTRY main {
  input.0 = f32[4, 5] parameter(0)
  input.1 = s32[4, 5] parameter(1)
  size_param.0 = s32[] parameter(2)
  size_param.1 = s32[] parameter(3)
  init.0 = f32[] constant(0.0)
  init.1 = s32[] constant(0)
  ROOT output = (f32[3, 5], s32[3, 5]) reduce-window(input.0, input.1, init.0, init.1), window={size=2x1 pad=0_0x0_0}, to_apply=add_f32
}
)";

  constexpr int kNumParams = 2;
  module_ = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // Dimension 1 of each input is dynamic, sized by parameter 2.
  for (int param_index = 0; param_index < kNumParams; ++param_index) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{2, {}},
        DynamicParameterBinding::DynamicDimension{param_index, {}, 1}));
  }

  TF_ASSERT_OK(RunPadder().status());

  // Neither input should have been rewritten.
  const HloInstruction* root = module_->entry_computation()->root_instruction();
  for (int param_index = 0; param_index < kNumParams; ++param_index) {
    EXPECT_THAT(root->operand(param_index), op::Parameter());
  }
}
484
// An s8 x s8 -> s32 dot with a dynamic non-contracting dimension: the dot is
// computed on the static (padded) operands and the result is sliced back to
// its dynamic shape by SliceToDynamic.
TEST_F(DynamicPadderTest, PadS8ToS32Dot) {
  const std::string hlo_text = R"(
HloModule test
ENTRY test {
  a = s8[<=16,32] parameter(0)
  b = s8[32,64] parameter(1)
  ROOT root = s32[<=16,64] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  module_ = GetHloModule(hlo_text);
  TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());

  // The dot itself keeps static shapes (s8 in, s32 out); the root becomes a
  // SliceToDynamic carrying the dynamic-size operands.
  EXPECT_THAT(module_->entry_computation()->root_instruction(),
              GmockMatch(m::CustomCall("SliceToDynamic",
                                       m::Dot(m::Op().WithShape(S8, {16, 32}),
                                              m::Op().WithShape(S8, {32, 64}))
                                           .WithShape(S32, {16, 64}),
                                       m::Op(), m::Op())));
}
504
// An unknown custom call has no dynamism support, so the padder wraps its
// output in PadToStatic and slices the root back to dynamic form.
TEST_F(DynamicPadderTest, PadToStaticForCustomCall) {
  const std::string hlo_text = R"(
HloModule test
ENTRY test {
  a = f32[64] parameter(0)
  ROOT c = f32[<=128] custom-call(a),
    custom_call_target="UnknownOp"
}
)";

  module_ = GetHloModule(hlo_text);
  TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());

  EXPECT_THAT(
      module_->entry_computation()->root_instruction(),
      GmockMatch(m::CustomCall("SliceToDynamic",
                               m::GetTupleElement(m::CustomCall(
                                   "PadToStatic", m::CustomCall("UnknownOp"))),
                               m::Op())));
}
525
526 // Test that dynamic padder has the same result as if not padded.
527 class ExecutionTest : public HloTestBase {
528 protected:
GetHloModule(const std::string & hlo_text)529 std::unique_ptr<HloModule> GetHloModule(const std::string& hlo_text) {
530 std::unique_ptr<HloModule> module =
531 ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
532 return module;
533 }
PadAndExecute(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments,bool slice_dynamic_output=true)534 Literal PadAndExecute(std::unique_ptr<HloModule> module,
535 absl::Span<Literal* const> arguments,
536 bool slice_dynamic_output = true) {
537 if (!slice_dynamic_output) {
538 auto new_config = module->config();
539 new_config.mutable_entry_computation_layout()
540 ->mutable_result_layout()
541 ->ClearDynamicShape();
542 module->set_config(new_config);
543 }
544 DynamicPadderOptions options;
545 options.slice_dynamic_output = slice_dynamic_output;
546 DynamicPadder padder(options);
547 TF_CHECK_OK(padder.Run(module.get()).status());
548 HloDCE dce;
549 TF_CHECK_OK(dce.Run(module.get()).status());
550 return ExecuteAndTransfer(std::move(module), arguments);
551 }
552 };
553
XLA_TEST_F(ExecutionTest, ScatterUpdate) {
  // Test that scattering on indices=[2] is same as scattering on indices=[4]
  // and dynamic dimension = 2
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[INDICES_BOUND] parameter(1)
  updates = s32[INDICES_BOUND,3] parameter(2)
  dynamic_size = s32[] parameter(3)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1

}
)";
  const std::string hlo_text_not_padded =
      absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "2"}});
  auto module_not_padded = GetHloModule(hlo_text_not_padded);

  Literal operand =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal scatter_indices = LiteralUtil::CreateR1<int32_t>({0, 2});
  Literal updates =
      LiteralUtil::CreateR2<int32_t>({{10, 20, 30}, {70, 80, 90}});
  Literal dynamic_size = LiteralUtil::CreateR0<int32_t>(2);

  // Reference result: execute the statically-shaped module directly.
  Literal not_padded =
      ExecuteAndTransfer(std::move(module_not_padded),
                         {&operand, &scatter_indices, &updates, &dynamic_size});

  // Pad input to 4.
  const std::string hlo_text_padded =
      absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "4"}});
  auto module_padded = GetHloModule(hlo_text_padded);
  // Set up dynamic parameter binding: parameter 3 holds the real size of
  // dimension 0 of both the indices (parameter 1) and updates (parameter 2).
  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{1, {}, 0}));
  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{2, {}, 0}));
  // Pad the rest of input with garbage data.
  Literal scatter_indices_padded = LiteralUtil::CreateR1<int32_t>({0, 2, 0, 4});
  Literal updates_padded = LiteralUtil::CreateR2<int32_t>(
      {{10, 20, 30}, {70, 80, 90}, {30, 22, 11}, {-1, 20, -1}});
  // Note: PadAndExecute runs DynamicPadder itself, so the pass must not be
  // run on module_padded here as well (previously it was run twice).
  Literal padded = PadAndExecute(
      std::move(module_padded),
      {&operand, &scatter_indices_padded, &updates_padded, &dynamic_size});

  EXPECT_EQ(padded, not_padded);
}
617
// Scatter where the dynamic dimension appears in the update window dims: the
// padded update rows must not leak into the scattered result.
XLA_TEST_F(ExecutionTest, ScatterUpdateWindowDim) {
  const std::string hlo_text = R"(
HloModule ScatterUpdateWindowDim

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[1,2,3] parameter(0)
  indices = s32[1] parameter(1)
  updates = s32[2,3,1] parameter(2)
  dynamic_size = s32[] constant(1)
  operand_dynamic = s32[1, <=2, 3] set-dimension-size(operand, dynamic_size),
      dimensions={1}
  updates_dynamic = s32[<=2, 3, 1] set-dimension-size(updates, dynamic_size),
      dimensions={0}
  ROOT scatter = s32[1, <=2, 3] scatter(operand_dynamic, indices, updates_dynamic),
      to_apply=update_s32,
      update_window_dims={0, 1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1

}
)";
  auto hlo_module = GetHloModule(hlo_text);

  Literal operand = LiteralUtil::CreateR3<int32_t>({{{0, 0, 0}, {0, 0, 0}}});
  Literal scatter_indices = LiteralUtil::CreateR1<int32_t>({0});
  Literal updates =
      LiteralUtil::CreateR3<int32_t>({{{10}, {20}, {30}}, {{70}, {80}, {90}}});

  // Execute without slicing the output so the full static result is returned.
  Literal result = PadAndExecute(std::move(hlo_module),
                                 {&operand, &scatter_indices, &updates},
                                 /*slice_dynamic_output=*/false);
  Literal expected =
      LiteralUtil::CreateR3<int32_t>({{{10, 20, 30}, {70, 80, 90}}});
  EXPECT_EQ(result, expected);
}
658
// Scatter with f32 updates and dynamic size 1: only the first index/update
// pair takes effect; the second one is padding and must be ignored.
XLA_TEST_F(ExecutionTest, ScatterUpdateF32) {
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  ROOT rhs = f32[] parameter(1)
}

ENTRY main {
  operand = f32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = f32[2,3] parameter(2)
  dynamic_size = s32[] parameter(3)
  ROOT scatter = f32[3,3] scatter(operand, indices, updates),
      to_apply=update_f32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1

}
)";

  Literal operand = LiteralUtil::CreateR2<float>(
      {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}});
  Literal scatter_indices = LiteralUtil::CreateR1<int32_t>({0, 2});
  Literal updates =
      LiteralUtil::CreateR2<float>({{10.0, 20.0, 30.0}, {70.0, 80.0, 90.0}});
  // Dynamic Size is 1, pad to 2
  Literal dynamic_size = LiteralUtil::CreateR0<int32_t>(1);

  auto module_padded = GetHloModule(hlo_text);
  // Set up dynamic parameter binding: parameter 3 holds the real size of
  // dimension 0 of both the indices (parameter 1) and updates (parameter 2).
  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{1, {}, 0}));
  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{2, {}, 0}));
  // Note: PadAndExecute runs DynamicPadder itself; the pass must not be run
  // here as well (previously it was run twice, and an unused second module
  // was parsed and discarded).
  Literal result =
      PadAndExecute(std::move(module_padded),
                    {&operand, &scatter_indices, &updates, &dynamic_size});
  // Although we have two indices, only the first element is updated because of
  // padding.
  EXPECT_EQ(LiteralUtil::CreateR2<float>(
                {{10.0, 20.0, 30.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}}),
            result);
}
714
XLA_TEST_F(ExecutionTest, WholeDimensionGather) {
  // Second dimension (size 2) is dynamic, assuming real size is 1 and padded to
  // 2:
  //
  // [[1, 2]
  //  [3, 4]
  //  [5, 6]]
  //
  // Gathering the second dimension out creates:
  //
  // [3, 4]
  //
  // Reducing this gives us 3 (4 is padded value so ignored)
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[3, 2, 1] parameter(0)
  size = s32[] constant(1)
  param_padded = s32[3, 2, 1] set-dimension-size(param, size), dimensions={1}
  index = s32[] constant(1)
  gather = s32[2,1]{1,0} gather(param_padded, index),
              offset_dims={0,1},
              collapsed_slice_dims={0},
              start_index_map={0},
              index_vector_dim=0,
              slice_sizes={1,2,1}
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(gather, init),
      dimensions={0, 1},
      to_apply=update_s32
}
)";
  // Slicing out entire dimension propagates the dimension
  Literal operand =
      LiteralUtil::CreateR3<int32_t>({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
  auto module = GetHloModule(hlo_text);
  // Note: PadAndExecute runs DynamicPadder itself; the pass must not be run
  // here as well (previously it was run twice on the same module).
  Literal result = PadAndExecute(std::move(module), {&operand});

  // Only first element will be reduced.
  Literal expected = LiteralUtil::CreateR0<int32_t>(3);

  EXPECT_EQ(result, expected);
}
767
XLA_TEST_F(ExecutionTest, TwoDimensionReduce) {
  // Test that reducing on operand=[2,2] is same as reducing on operand=[4,4]
  // and dynamic dimension = 2
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[INDICES_BOUND, INDICES_BOUND] parameter(0)
  dynamic_size = s32[] parameter(1)
  const = s32[] constant(0)
  ROOT reduce = s32[] reduce(param, const),
      dimensions={0, 1},
      to_apply=update_s32
}
)";
  const std::string hlo_text_not_padded =
      absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "2"}});
  auto module_not_padded = GetHloModule(hlo_text_not_padded);

  Literal operand = LiteralUtil::CreateR2<int32_t>({{1, 2}, {4, 5}});
  Literal dynamic_size = LiteralUtil::CreateR0<int32_t>(2);

  // Reference result: execute the statically-shaped module directly.
  Literal not_padded = ExecuteAndTransfer(std::move(module_not_padded),
                                          {&operand, &dynamic_size});

  // Pad input to 4.
  const std::string hlo_text_padded =
      absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "4"}});
  auto module_padded = GetHloModule(hlo_text_padded);
  // Both dimensions of parameter 0 are dynamic, sized by parameter 1.
  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));
  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));
  // Pad the rest of input with garbage data.
  Literal operand_padded = LiteralUtil::CreateR2<int32_t>(
      {{1, 2, 3, 4}, {4, 5, 6, 7}, {1, 2, 3, 4}, {4, 5, 6, 7}});
  // Note: PadAndExecute runs DynamicPadder itself; the pass must not be run
  // here as well (previously it was run twice on the same module).
  Literal padded =
      PadAndExecute(std::move(module_padded), {&operand_padded, &dynamic_size});

  EXPECT_EQ(padded, not_padded);
}
820
// Clamp is elementwise, so the dynamic dimension propagates through it; the
// subsequent reduce must only see the first three (real) elements.
XLA_TEST_F(ExecutionTest, DynamicDimensionClamp) {
  const std::string hlo_text = R"(
HloModule TensorFlowTenaryV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[5] parameter(0)
  const = s32[] constant(3)
  param_padded = s32[5] set-dimension-size(param, const), dimensions={0}
  clamp = s32[5] clamp(param_padded, param_padded, param_padded)
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(clamp, init),
      dimensions={0},
      to_apply=update_s32
}
)";

  // Input has upper bound of 5, dynamic dimension is 3.
  Literal input = LiteralUtil::CreateR1<int32_t>({1, 2, 3, 4, 5});
  auto hlo_module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(hlo_module), {&input});

  // Only the first 3 elements (1 + 2 + 3 = 6) participate in the reduce.
  EXPECT_EQ(result, LiteralUtil::CreateR0<int32_t>(6));
}
854
XLA_TEST_F(ExecutionTest, DynamicConcat) {
  // Concatting a list of {dynamic_operand, static_operand, dynamic_operand}.
  // Each operand has upper bound 3; the two dynamic ones carry runtime size 2,
  // so the result holds 2 + 3 + 2 = 7 valid elements out of the static bound
  // of 9.
  const std::string hlo_text = R"(
HloModule DynamicConcat

ENTRY main {
  param_0 = s32[3] parameter(0)
  param_1 = s32[3] parameter(1)
  param_2 = s32[3] parameter(2)
  size = s32[] constant(2)
  param_padded_0 = s32[<=3] set-dimension-size(param_0, size), dimensions={0}
  param_padded_2 = s32[<=3] set-dimension-size(param_2, size), dimensions={0}
  ROOT %concatenate = s32[9]
    concatenate(s32[<=3] param_padded_0, s32[<=3] param_1, s32[<=3] param_padded_2),
    dimensions={0}
}
)";

  // Input has upper bound of 3, dynamic dimension is 2. Using -1 as padding.
  Literal operand_0 =
      LiteralUtil::CreateR1<int32_t>({1, 2, -1});  // Dynamic operand.
  Literal operand_1 =
      LiteralUtil::CreateR1<int32_t>({3, 4, 5});  // Static operand.
  Literal operand_2 =
      LiteralUtil::CreateR1<int32_t>({6, 7, -1});  // Dynamic operand.
  auto module = GetHloModule(hlo_text);

  // Run without slicing out dynamic padding so the raw buffer is returned,
  // then mark the first 7 elements as the valid extent before comparing.
  Literal result = PadAndExecute(std::move(module),
                                 {&operand_0, &operand_1, &operand_2}, false);
  result.SetDynamicSize(0, 7);
  Literal expected = LiteralUtil::CreateR1<int32_t>({1, 2, 3, 4, 5, 6, 7});

  EXPECT_EQ(result, expected);
}
889
XLA_TEST_F(ExecutionTest, DynamicReverseSingleDim) {
  // Reverses a rank-1 operand whose only dimension is dynamic: bound 3,
  // runtime size 2. Only the two valid elements participate in the reverse.
  const std::string hlo = R"(
HloModule DynamicConcat

ENTRY main {
  param_0 = s32[3] parameter(0)
  size = s32[] constant(2)
  param_padded_0 = s32[<=3] set-dimension-size(param_0, size), dimensions={0}
  ROOT %reverse = s32[<=3]
    reverse(s32[<=3] param_padded_0),
    dimensions={0}
}
)";

  // Valid data is {1, 2}; the trailing -1 is padding.
  Literal input = LiteralUtil::CreateR1<int32_t>({1, 2, -1});
  auto hlo_module = GetHloModule(hlo);

  // Execute without slicing dynamic dimensions, then mark the valid extent.
  Literal actual = PadAndExecute(std::move(hlo_module), {&input}, false);
  actual.SetDynamicSize(0, 2);

  // Reversing {1, 2} yields {2, 1}.
  EXPECT_EQ(actual, LiteralUtil::CreateR1<int32_t>({2, 1}));
}
915
XLA_TEST_F(ExecutionTest, DynamicReverseMultiDims) {
  // Reverses both dimensions of a 2D operand where each dimension has upper
  // bound 3 and runtime size 2; only the valid 2x2 region is reversed.
  const std::string hlo_text = R"(
HloModule DynamicConcat

ENTRY main {
  param_0 = s32[3, 3] parameter(0)
  size = s32[] constant(2)
  param_padded_0 = s32[<=3, 3] set-dimension-size(param_0, size), dimensions={0}
  param_padded_1 = s32[<=3, <=3] set-dimension-size(param_padded_0, size),
    dimensions={1}
  ROOT %reverse = s32[<=3, <=3]
    reverse(s32[<=3, <=3] param_padded_1),
    dimensions={0, 1}
}
)";

  // Input has upper bound of 3, dynamic dimension is 2. Using -1 as padding.
  Literal operand_0 = LiteralUtil::CreateR2<int32_t>(
      {{1, 2, -1}, {3, 4, -1}, {-1, -1, -1}});  // Dynamic operand.
  auto module = GetHloModule(hlo_text);

  // Execute without slicing out padding, then mark the valid 2x2 region.
  Literal result = PadAndExecute(std::move(module), {&operand_0}, false);
  result.SetDynamicSize(0, 2);
  result.SetDynamicSize(1, 2);
  // Reversing [[1, 2], [3, 4]] on both dims gives [[4, 3], [2, 1]].
  Literal expected = LiteralUtil::CreateR2<int32_t>({{4, 3}, {2, 1}});

  EXPECT_EQ(result, expected);
}
944
XLA_TEST_F(ExecutionTest, DynamicDimensionReduce) {
  // Reduce-sums a rank-1 operand with bound 5 and runtime size 3; elements
  // past the dynamic size must not contribute to the sum.
  const std::string hlo = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[5] parameter(0)
  const = s32[] constant(3)
  param_padded = s32[<=5] set-dimension-size(param, const), dimensions={0}
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(param_padded, init),
    dimensions={0},
    to_apply=update_s32
}
)";

  // Five elements supplied; only the first three are valid.
  Literal input = LiteralUtil::CreateR1<int32_t>({1, 2, 3, 4, 5});
  auto hlo_module = GetHloModule(hlo);

  Literal actual = PadAndExecute(std::move(hlo_module), {&input});

  // 1 + 2 + 3 == 6; the padded 4 and 5 are dropped.
  EXPECT_EQ(actual, LiteralUtil::CreateR0<int32_t>(6));
}
977
XLA_TEST_F(ExecutionTest, InputMinorDimensionReshape) {
  // Reshapes a 4D operand whose dimension 2 is dynamic into rank 1, then
  // reduce-sums the result; padded elements must be dropped by the padder.
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[1, 2, 5, 1] parameter(0)
  const = s32[] constant(3)
  param_padded = s32[1, 2, 5, 1] set-dimension-size(param, const), dimensions={2}
  reshaped = s32[10] reshape(param_padded)
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(reshaped, init),
    dimensions={0},
    to_apply=update_s32
}
)";

  // The third dimension has upper bound of 5, dynamic dimension is 3.
  Literal operand = LiteralUtil::CreateR4<int32_t>(
      {{{{1}, {2}, {3}, {4}, {5}}, {{2}, {4}, {6}, {7}, {8}}}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // Only the first 6 elements will be reduced: (1+2+3) + (2+4+6) = 18.
  Literal expected = LiteralUtil::CreateR0<int32_t>(18);

  EXPECT_EQ(result, expected);
}
1012
XLA_TEST_F(ExecutionTest, SliceSingleElement) {
  // Slicing out a single element of a dynamic operand is supported: the
  // slice [0:1] lies inside the valid region, so padding never matters.
  const std::string hlo = R"(
HloModule Slicing

ENTRY main {
  param = s32[5] parameter(0)
  const = s32[] constant(3)
  param_padded = s32[5] set-dimension-size(param, const), dimensions={0}
  ROOT slice = s32[1]{0} slice(param_padded), slice={[0:1]}
}
)";

  // Upper bound is 5; runtime size is 3.
  Literal input = LiteralUtil::CreateR1<int32_t>({0, 1, 2, 3, 4});
  auto hlo_module = GetHloModule(hlo);

  Literal actual = PadAndExecute(std::move(hlo_module), {&input});

  // The first element of the input is 0.
  EXPECT_EQ(actual, LiteralUtil::CreateR1<int32_t>({0}));
}
1036
XLA_TEST_F(ExecutionTest, OutputMinorDimensionReshape) {
  // Reshapes a dynamic rank-1 operand into rank 3 where the dynamic size is
  // carried by the inferred middle output dimension, then reduces it away.
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[12] parameter(0)
  const = s32[] constant(8)
  param_padded = s32[12] set-dimension-size(param, const), dimensions={0}
  // Second dimension is dynamic.
  reshaped = s32[2, 3, 2] reshape(param_padded), inferred_dimension=1
  init = s32[] constant(0)
  ROOT reduce = s32[2, 2] reduce(reshaped, init),
    dimensions={1},
    to_apply=update_s32
}
)";

  // The input dimension has upper bound of 12, dynamic size is 8.
  Literal operand =
      LiteralUtil::CreateR1<int32_t>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // After padding and reshape we have
  //
  // [[[0, 1],
  //   [2, 3]
  //   [P, P]]
  //  [[4, 5],
  //   [6, 7],
  //   [P, P]]]
  // Reducing on the second dimension gives us
  //  [0+2, 1+3]
  //  [4+6, 5+7]
  //
  Literal expected = LiteralUtil::CreateR2<int32_t>({{2, 4}, {10, 12}});

  EXPECT_EQ(result, expected);
}
1083
XLA_TEST_F(ExecutionTest, OutputMinorDimensionReshapeWithUnchangedDimMajor) {
  // Reshape where the major dimension (2) is unchanged and the dynamic input
  // dimension 1 is split; the inferred minor output dimension is dynamic.
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[2, 6] parameter(0)
  const = s32[] constant(4)
  param_padded = s32[2, 6] set-dimension-size(param, const), dimensions={1}
  // Third dimension is dynamic.
  reshaped = s32[2, 2, 3] reshape(param_padded), inferred_dimension=2
  init = s32[] constant(0)
  ROOT reduce = s32[2, 2] reduce(reshaped, init),
    dimensions={2},
    to_apply=update_s32
}
)";

  // The second dimension has upper bound of 6, dynamic size is 4.
  Literal operand = LiteralUtil::CreateR2<int32_t>(
      {{0, 1, 2, 3, 4, 5}, {6, 7, 8, 9, 10, 11}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // After padding and reshape we have
  //
  // [[[0, 1, P],
  //   [2, 3, P]],
  //  [[6, 7, P],
  //   [8, 9, P]]]
  // Reducing on the third dimension gives us
  //  [0+1, 2+3]
  //  [6+7, 8+9]
  //
  Literal expected = LiteralUtil::CreateR2<int32_t>({{1, 5}, {13, 17}});

  EXPECT_EQ(result, expected);
}
1128
XLA_TEST_F(ExecutionTest, OutputMinorDimensionReshapeWithUnchangedDimMinor) {
  // Reshape where the minor dimension (2) is unchanged and the dynamic input
  // dimension 0 is split; the inferred middle output dimension is dynamic.
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[6, 2] parameter(0)
  const = s32[] constant(4)
  param_padded = s32[6, 2] set-dimension-size(param, const), dimensions={0}
  // Second dimension is dynamic.
  reshaped = s32[2, 3, 2] reshape(param_padded), inferred_dimension=1
  init = s32[] constant(0)
  ROOT reduce = s32[2, 2] reduce(reshaped, init),
    dimensions={1},
    to_apply=update_s32
}
)";

  // The first dimension has upper bound of 6, dynamic size is 4.
  Literal operand = LiteralUtil::CreateR2<int32_t>(
      {{0, 1}, {2, 3}, {4, 5}, {6, 7}, {8, 9}, {10, 11}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // After padding and reshape we have
  //
  // [[[0, 1],
  //   [2, 3]
  //   [P, P]],
  //  [[4, 5],
  //   [6, 7],
  //   [P, P]]]
  // Reducing on the second dimension gives us
  //  [0+2, 1+3]
  //  [4+6, 5+7]
  //
  Literal expected = LiteralUtil::CreateR2<int32_t>({{2, 4}, {10, 12}});

  EXPECT_EQ(result, expected);
}
1175
XLA_TEST_F(ExecutionTest, DynamicInputFeature) {
  // Convolution through the DynamicConvolutionForward custom call where the
  // input-feature dimension of the activation is dynamic (bound 5, size 5).
  const std::string hlo_text = R"(
HloModule DynamicInputFeature

ENTRY main {
  param = f32[1, 1, 5] parameter(0)
  const = s32[] constant(5)
  one = f32[] constant(1)
  kernel = f32[1,5,1]{2,1,0} broadcast(f32[] one), dimensions={}
  param_dynamic = f32[1,1,<=5] set-dimension-size(param, const), dimensions={2}
  ROOT conv = f32[1, 1, 1]{2,1,0} custom-call(f32[1, 1, <=5] param_dynamic, f32[1,5,1]{2,1,0} kernel),
    window={size=1 pad=0_0},
    dim_labels=b0f_0io->b0f,
    padding_type=PADDING_VALID,
    custom_call_target="DynamicConvolutionForward"
}
)";

  Literal operand = LiteralUtil::CreateR3<float>({{{1, 2, 3, 4, 5}}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // All five features are valid: 1+2+3+4+5 = 15 against an all-ones kernel.
  Literal expected = LiteralUtil::CreateR3<float>({{{15}}});

  EXPECT_EQ(result, expected);
}
1203
XLA_TEST_F(LlvmIrGenTestBase, LargeDynamicInput) {
  // Codegen-only check: a reduce over an 8-D, all-dynamic-bounded input must
  // lower to valid optimized LLVM IR. Skipped on non-GPU backends.
#ifndef XLA_TEST_BACKEND_GPU
  GTEST_SKIP();
#endif
  const std::string hlo_text = R"( // NOLINT: Will be executed for GPU.
HloModule LargeDynamicInput

add (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY main {
  param = f32[<=20,<=20,<=20,<=20,<=20,<=20,<=20,<=20] parameter(0)
  zero = f32[] constant(0)
  ROOT out = reduce(param, zero), to_apply=add, dimensions={0,1,2,3,4,5,6,7}
}
)";

  // Only compilation success is asserted; any IR with a return matches.
  CompileAndVerifyIr(hlo_text, R"(
CHECK: ret void
)",
                     /*match_optimized_ir=*/true);
}
1229
XLA_TEST_F(ExecutionTest, DynamicDimensionReshapeUnchanged) {
  // The dynamic dimension 2 (bound 5, runtime size 3) survives the reshape
  // as output dimension 1; each row's reduce must only sum 3 valid elements.
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[1, 2, 5, 1] parameter(0)
  const = s32[] constant(3)
  param_padded = s32[1, 2, 5, 1] set-dimension-size(param, const), dimensions={2}
  reshaped = s32[2, 5] reshape(param_padded)
  init = s32[] constant(0)
  ROOT reduce = s32[2] reduce(reshaped, init),
    dimensions={1},
    to_apply=update_s32
}
)";

  // Test dynamic padder in unchanged dimension reshape.
  Literal operand = LiteralUtil::CreateR4<int32_t>(
      {{{{1}, {2}, {3}, {4}, {5}}, {{2}, {4}, {6}, {7}, {8}}}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // Row sums over the valid prefixes: 1+2+3 = 6 and 2+4+6 = 12.
  Literal expected = LiteralUtil::CreateR1<int32_t>({6, 12});

  EXPECT_EQ(result, expected);
}
1263
XLA_TEST_F(ExecutionTest, DegeneratedDimension) {
  // A dynamic dimension with runtime size 0: every element becomes padding,
  // so the reduce over the reshaped buffer must produce the init value 0.
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[1, 2, 5, 1] parameter(0)
  size = s32[] constant(0)
  // First dimension is dynamic.
  param_padded = s32[1, 2, 5, 1] set-dimension-size(param, size),
    dimensions={0}
  reshaped = s32[10] reshape(param_padded)
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(reshaped, init),
    dimensions={0},
    to_apply=update_s32
}
)";

  // First dimension (1) is dynamic. Since dynamic size is 0, result is also 0.
  Literal operand = LiteralUtil::CreateR4<int32_t>(
      {{{{1}, {2}, {3}, {4}, {5}}, {{2}, {4}, {6}, {7}, {8}}}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  Literal expected = LiteralUtil::CreateR0<int32_t>(0);

  EXPECT_EQ(result, expected);
}
1299
XLA_TEST_F(ExecutionTest, ReshapeSplitCombineSameTime) {
  // [<=4, 2, <=2]
  //       |
  //    Reshape
  //       |
  // [2, <=2, <=4]
  //
  // Split one input dynamic dim to multiple output dims while combining two
  // dimensions together.
  //
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[4, 2, 2] parameter(0)
  two = s32[] constant(2)
  one = s32[] constant(1)
  param_padded_partial = s32[<=4, 2, 2] set-dimension-size(param, two),
    dimensions={0}

  param_padded_dynamic = s32[<=4, 2, <=2] set-dimension-size(param_padded_partial,
                                                             one),
    dimensions={2}
  reshaped = s32[2, <=2, <=4] reshape(param_padded_dynamic),
    inferred_dimension=1
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(reshaped, init),
    dimensions={0, 1, 2},
    to_apply=update_s32
}
)";

  // First and last dims are dynamic. Padded data are expressed as -1.
  // Dim 0 has runtime size 2 (bound 4); dim 2 has runtime size 1 (bound 2).
  Literal operand = LiteralUtil::CreateR3<int32_t>({{{0, -1}, {1, -1}},
                                                    {{2, -1}, {3, -1}},
                                                    {{-1, -1}, {-1, -1}},
                                                    {{-1, -1}, {-1, -1}}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // Reshaping (with correct reshape rewriting) produces:
  // [[[0, 1, -1, -1], [-1, -1, -1, -1]], [[2, 3, -1, -1], [-1, -1, -1, -1]]]
  //
  // Dynamic padder auto pads -1 with 0.
  //
  // Reducing it produces 0 + 1 + 2 + 3 = 6

  Literal expected = LiteralUtil::CreateR0<int32_t>(6);

  EXPECT_EQ(result, expected);
}
1358
XLA_TEST_F(ExecutionTest, ReshapeComplicated) {
  // [2, <=4, 4]
  //      |
  //   Reshape
  //      |
  // [<=16, 2]
  //
  // Reshape that is not a composition of splitting one input dim to multiple
  // output dims or combining multiple input dimensions to one output dimension.
  //
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[2, 4, 4] parameter(0)
  two = s32[] constant(2)
  param_padded_dynamic = s32[2, <=4, 4] set-dimension-size(param, two),
    dimensions={1}
  reshaped = s32[<=16, 2] reshape(param_padded_dynamic), inferred_dimension=0
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(reshaped, init),
    dimensions={0, 1},
    to_apply=update_s32
}
)";

  // Dimension 1 is dynamic with runtime size 2 (bound 4). Padded data are
  // expressed as -1.
  Literal operand = LiteralUtil::CreateR3<int32_t>(
      {{{1, 2, 3, 4}, {5, 6, 7, 8}, {-1, -1, -1, -1}, {-1, -1, -1, -1}},
       {{9, 10, 11, 12},
        {13, 14, 15, 16},
        {-1, -1, -1, -1},
        {-1, -1, -1, -1}}});
  auto module = GetHloModule(hlo_text);
  Literal result = PadAndExecute(std::move(module), {&operand});

  // Reshaping (with correct reshape rewriting) produces:
  // [[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]],
  //  [[-1, -1], [-1, -1], ...]]
  //
  // Dynamic padder auto pads -1 with 0.
  //
  // Reducing it produces 1 + 2 + 3 + ... + 16 = 136

  Literal expected = LiteralUtil::CreateR0<int32_t>(136);
  EXPECT_EQ(result, expected);
}
1412
XLA_TEST_F(ExecutionTest, WhileLoopStack) {
  // Push into a dynamic sized stack with iteration number:
  // init:
  // [[P, P],
  //  [P, P],
  //  [P, P],
  //  [P, P]]
  // First iteration i = 0:
  // [[0, 0],
  //  [P, P],
  //  [P, P],
  //  [P, P]]
  // Second iteration i = 1:
  // [[0, 0],
  //  [1, 1],
  //  [P, P],
  //  [P, P]]
  // Third iteration i = 2:
  // [[0, 0],
  //  [1, 1],
  //  [2, 2],
  //  [P, P]]

  const std::string hlo_text = R"(
HloModule module

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

body {
  stack = (s32[<=4,2]) parameter(0)
  stack_buffer = s32[<=4, 2] get-tuple-element(stack), index=0
  stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0}
  zero = s32[] constant(0)
  one = s32[] constant(1)
  // content of the stack is the stack index broadcasted.
  new_data = s32[1, 2] broadcast(s32[] stack_size), dimensions={}
  new_stack_buffer = s32[<=4, 2] dynamic-update-slice(stack_buffer, new_data, stack_size, zero)
  new_stack_size = s32[] add(stack_size, one)
  new_stack_buffer_dynamic = s32[<=4, 2]set-dimension-size(new_stack_buffer, new_stack_size), dimensions={0}
  ROOT new_stack = (s32[<=4,2]) tuple(new_stack_buffer_dynamic)
}

condition {
  stack = (s32[<=4,2]) parameter(0)
  stack_buffer = s32[<=4, 2] get-tuple-element(stack), index=0
  stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0}
  three = s32[] constant(3)
  ROOT less-than = pred[] compare(s32[] stack_size, s32[] three), direction=LT
}

ENTRY entry {
  zero = s32[] constant(0)
  pad = s32[] constant(-1)
  stack_buffer_input = s32[4, 2] broadcast(s32[] pad), dimensions={}
  stack_buffer_input_dynamic = s32[<=4, 2] set-dimension-size(stack_buffer_input, zero), dimensions={0}
  input_tuple = (s32[<=4 ,2]) tuple(stack_buffer_input_dynamic)
  while = (s32[<=4, 2]) while(input_tuple), body=body, condition=condition
  stack_buffer = s32[<=4, 2] get-tuple-element(while), index=0
  ROOT reduce = s32[2] reduce(stack_buffer, zero),
    dimensions={0},
    to_apply=update_s32
}
)";

  auto module = GetHloModule(hlo_text);

  // No parameters: the stack is built entirely inside the module.
  Literal result = PadAndExecute(std::move(module), {});

  // Stack has three valid items in it:
  // [[0, 0],
  //  [1, 1],
  //  [2, 2],
  //  [P, P]]
  //
  // Reducing along major dimension gives us [3, 3]
  Literal expected = LiteralUtil::CreateR1<int32_t>({3, 3});

  EXPECT_EQ(result, expected);
}
1496
XLA_TEST_F(ExecutionTest, DynamicAddWithImplicitBroadcast) {
  // Adds two operands whose dynamic sizes differ (1 vs 3); the size-1 operand
  // is implicitly broadcast so the result carries three valid rows.
  const std::string hlo_text = R"(
HloModule module

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY entry {
  zero = s32[] constant(0)
  one = s32[] constant(1)
  two = s32[] constant(2)
  three = s32[] constant(3)
  input1 = s32[4, 2] iota(), iota_dimension=0
  ones = s32[4, 2] broadcast(one), dimensions={}
  input1_added = s32[4, 2] add(input1, ones)
  input1_dynamic = s32[<=4, 2] set-dimension-size(input1_added, one), dimensions={0}
  input2 = s32[4, 2] broadcast(two), dimensions={}
  input2_dynamic = s32[<=4, 2] set-dimension-size(input2, three), dimensions={0}
  add = s32[<=4, 2] add(input1_dynamic, input2_dynamic)
  ROOT reduce = s32[2] reduce(add, zero),
    dimensions={0},
    to_apply=update_s32
}
)";

  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {});

  // Array has three valid items in it (each 1 + 2 = 3):
  // [[3, 3],
  //  [3, 3],
  //  [3, 3],
  //  [P, P]]
  // Reducing them gives us [9, 9]
  Literal expected = LiteralUtil::CreateR1<int32_t>({9, 9});

  EXPECT_EQ(result, expected);
}
1539
XLA_TEST_F(ExecutionTest, DynamicAddWithImplicitSlice) {
  // Adds two operands whose dynamic sizes differ (3 vs 2); the result is
  // implicitly sliced down to the smaller size of two valid rows.
  const std::string hlo_text = R"(
HloModule module

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY entry {
  zero = s32[] constant(0)
  one = s32[] constant(1)
  two = s32[] constant(2)
  three = s32[] constant(3)
  input1 = s32[4, 2] broadcast(one), dimensions={}
  input1_dynamic = s32[<=4, 2] set-dimension-size(input1, three), dimensions={0}
  input2 = s32[4, 2] broadcast(two), dimensions={}
  input2_dynamic = s32[<=4, 2] set-dimension-size(input2, two), dimensions={0}
  add = s32[<=4, 2] add(input1_dynamic, input2_dynamic)
  ROOT reduce = s32[2] reduce(add, zero),
    dimensions={0},
    to_apply=update_s32
}
)";

  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {});

  // Array has two valid items in it (each 1 + 2 = 3):
  // [[3, 3],
  //  [3, 3],
  //  [P, P],
  //  [P, P]]
  // Reducing them gives us [6, 6]
  Literal expected = LiteralUtil::CreateR1<int32_t>({6, 6});

  EXPECT_EQ(result, expected);
}
1580
XLA_TEST_F(ExecutionTest, DynamicStackPop) {
  // This tests the case where a static sized stack is popped by a dynamic
  // number of times.

  // In the beginning the stack has static size that has 4 elements:
  // [[1, 1],
  //  [1, 1],
  //  [1, 1],
  //  [1, 1]]
  //
  // Popping this stack using set-dimension-size in a loop creates a dynamic
  // result depending on how many times we pop it. The loop body runs while
  // the observed size is >= 2, so sizes 4, 3 and 2 each trigger a pop (three
  // pops in total) and the loop exits with a single valid element left.

  const std::string hlo_text = R"(
HloModule module

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

body {
  param_tuple = (s32[<=4,2]) parameter(0)
  param = s32[<=4, 2] get-tuple-element(param_tuple), index=0
  one = s32[] constant(1)
  size = s32[] get-dimension-size(param), dimensions={0}
  new_size = s32[] subtract(size, one)
  output = s32[<=4, 2] set-dimension-size(param, new_size), dimensions={0}
  ROOT root = (s32[<=4, 2]) tuple(output)
}

condition {
  stack = (s32[<=4,2]) parameter(0)
  stack_buffer = s32[<=4,2] get-tuple-element(stack), index=0
  stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0}
  two = s32[] constant(2)
  ROOT greater-than = pred[] compare(s32[] stack_size, s32[] two), direction=GE
}

ENTRY entry {
  one = s32[] constant(1)
  zero = s32[] constant(0)
  stack_buffer_input = s32[4, 2] broadcast(s32[] one), dimensions={}
  input_tuple = (s32[4, 2]) tuple(stack_buffer_input)
  while = (s32[4, 2]) while(input_tuple), body=body, condition=condition
  stack_buffer = s32[<=4, 2] get-tuple-element(while), index=0
  ROOT reduce = s32[2] reduce(stack_buffer, zero),
    dimensions={0},
    to_apply=update_s32
}
)";

  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {});

  // Stack has one valid item left in it:
  // [[1, 1],
  //  [P, P],
  //  [P, P],
  //  [P, P]]
  // Reducing it gives us [1, 1]
  Literal expected = LiteralUtil::CreateR1<int32_t>({1, 1});

  EXPECT_EQ(result, expected);
}
1648
XLA_TEST_F(ExecutionTest, DoubleDynamicDimension) {
  // Two dynamic dimensions on the same operand, both flattened away by a
  // static reshape; the reduce must only see the valid 2x2x2 region.
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[2, 3, 3] parameter(0)
  size = s32[] constant(2)
  param_padded_partial = s32[2, 3, 3] set-dimension-size(param, size),
    dimensions={1}
  param_padded = s32[2, 3, 3] set-dimension-size(param_padded_partial, size),
    dimensions={2}
  reshaped = s32[18] reshape(param_padded)
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(reshaped, init),
    dimensions={0},
    to_apply=update_s32
}
)";

  // Dimensions 1 and 2 are dynamic with runtime size 2 (bound 3 each).
  Literal operand = LiteralUtil::CreateR3<int32_t>(
      {{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // Padded data looks like this (P is padding which is ignored).
  // [[0, 1, P]
  //  [3, 4, P]
  //  [P, P, P]]
  //
  // [[0, 1, P]
  //  [3, 4, P]
  //  [P, P, P]]
  //
  // Reshaping (with correct reshape rewriting) produces:
  // [0, 1, 3, 4, 0, 1, 3, 4, P, P, P, P, P, P, P, P, P, P]
  //
  // Reducing it produces 16

  Literal expected = LiteralUtil::CreateR0<int32_t>(16);

  EXPECT_EQ(result, expected);
}
1699
XLA_TEST_F(ExecutionTest, DynamicReshapeDoubleDynamicDimensions) {
  // dynamic-reshape flattening two dynamic input dims into one dynamic output
  // dim whose runtime size (8) is supplied explicitly as an operand.
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

ENTRY main {
  param = s32[2, 3, 3] parameter(0)
  size = s32[] constant(2)
  param_padded_partial = s32[2, <=3, 3] set-dimension-size(param, size),
    dimensions={1}
  param_padded = s32[2, <=3, <=3] set-dimension-size(param_padded_partial, size),
    dimensions={2}
  result_size = s32[] constant(8)
  ROOT reshaped = s32[<=18] dynamic-reshape(param_padded, result_size)
}
)";

  // Dimensions 1 and 2 are dynamic with runtime size 2 (bound 3 each).
  Literal operand = LiteralUtil::CreateR3<int32_t>(
      {{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}});
  auto module = GetHloModule(hlo_text);

  // Execute without slicing padding, then mark the 8 valid elements.
  Literal result = PadAndExecute(std::move(module), {&operand}, false);
  result.SetDynamicSize(0, 8);
  // Padded data looks like this (P is padding which is ignored).
  // [[0, 1, P]
  //  [3, 4, P]
  //  [P, P, P]]
  //
  // [[0, 1, P]
  //  [3, 4, P]
  //  [P, P, P]]
  //
  // Reshaping (with correct reshape rewriting) produces:
  // [0, 1, 3, 4, 0, 1, 3, 4]
  Literal expected = LiteralUtil::CreateR1<int32_t>({0, 1, 3, 4, 0, 1, 3, 4});

  EXPECT_EQ(result, expected);
}
1738
XLA_TEST_F(ExecutionTest, DynamicReshapeOutputDoubleDynamicDimensions) {
  // dynamic-reshape expanding one dynamic input dim (size 8 of bound 18)
  // into multiple output dims, two of which are dynamic with size 2.
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

ENTRY main {
  param = s32[18] parameter(0)
  eight = s32[] constant(8)
  param_dynamic = s32[<=18] set-dimension-size(param, eight), dimensions={0}
  two = s32[] constant(2)
  // every dimension has dynamic size two.
  ROOT reshaped = s32[2, <=3, <=3] dynamic-reshape(param_dynamic, two, two, two)
}
)";
  // Only the first 8 elements are valid; -1 marks padding.
  Literal operand = LiteralUtil::CreateR1<int32_t>(
      {0, 1, 3, 4, 0, 1, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1});

  auto module = GetHloModule(hlo_text);

  // Execute without slicing padding, then mark the valid extents of the two
  // dynamic output dimensions.
  Literal result = PadAndExecute(std::move(module), {&operand}, false);
  VLOG(1) << " result: " << result.ToString();
  result.SetDynamicSize(1, 2);
  result.SetDynamicSize(2, 2);
  // Padded operand is:
  // [0, 1, 3, 4, 0, 1, 3, 4, P, P ....]
  //
  // Reshaping it should produce:
  // [[0, 1, P]
  //  [3, 4, P]
  //  [P, P, P]]
  //
  // [[0, 1, P]
  //  [3, 4, P]
  //  [P, P, P]]
  Literal expected =
      LiteralUtil::CreateR3<int32_t>({{{0, 1}, {3, 4}}, {{0, 1}, {3, 4}}});
  EXPECT_EQ(result, expected);
}
1776
XLA_TEST_F(ExecutionTest, DynamicReshapeComplicated) {
  // dynamic-reshape where neither the split nor the combine of dimensions is
  // clean: all three input dims are dynamic and both output dims are dynamic.
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

ENTRY main {
  param = s32[3, 4, 4] parameter(0)
  two = s32[] constant(2)
  param_dynamic = s32[<=3, 4, 4] set-dimension-size(param, two), dimensions={0}
  three = s32[] constant(3)
  param_dynamic1 = s32[<=3, <=4, 4] set-dimension-size(param_dynamic, three), dimensions={1}
  param_dynamic2 = s32[<=3, <=4, <=4] set-dimension-size(param_dynamic1, three), dimensions={2}
  six = s32[] constant(6)

  // Static reshape is from [3, 4, 4] to [6, 8].
  // Dynamic reshape is from [2, 3, 3] to [3, 6].
  ROOT reshaped = s32[<=6, <=8] dynamic-reshape(param_dynamic2, three, six)
}
)";
  // Valid region is 2x3x3; -1 marks padding.
  Literal operand = LiteralUtil::CreateR3<int32_t>(
      {{{0, 1, 2, -1}, {3, 4, 5, -1}, {6, 7, 8, -1}, {-1, -1, -1, -1}},
       {{9, 8, 7, -1}, {6, 5, 4, -1}, {3, 2, 1, -1}, {-1, -1, -1, -1}},
       {{-1, -1, -1, -1},
        {-1, -1, -1, -1},
        {-1, -1, -1, -1},
        {-1, -1, -1, -1}}});

  auto module = GetHloModule(hlo_text);

  // Execute without slicing padding, then mark the valid 3x6 output region.
  Literal result = PadAndExecute(std::move(module), {&operand}, false);
  result.SetDynamicSize(0, 3);
  result.SetDynamicSize(1, 6);
  // The 18 valid elements, flattened in order, reshaped to [3, 6].
  Literal expected = LiteralUtil::CreateR2<int32_t>(
      {{0, 1, 2, 3, 4, 5}, {6, 7, 8, 9, 8, 7}, {6, 5, 4, 3, 2, 1}});
  EXPECT_EQ(result, expected);
}
1812
XLA_TEST_F(ExecutionTest, SetGetDimensionSize) {
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

ENTRY main {
  param = s32[3] parameter(0)
  size = s32[] constant(2)
  param_dynamic_size = s32[3] set-dimension-size(param, size),
    dimensions={0}
  ROOT gds = s32[] get-dimension-size(param_dynamic_size),
    dimensions={0}
}
)";

  // Dimension 0 is set to dynamic size 2 (static bound 3); get-dimension-size
  // should observe the dynamic size rather than the static bound.
  Literal operand = LiteralUtil::CreateR1<int32_t>({1, 2, 3});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // Should return the size 2 instead of 3.
  Literal expected = LiteralUtil::CreateR0<int32_t>(2);

  EXPECT_EQ(result, expected);
}
1838
XLA_TEST_F(ExecutionTest, DynamicSort) {
  const std::string hlo_text = R"(
HloModule TEST

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

%compare-greater-than (lhs: s32[], rhs: s32[]) -> pred[] {
  %lhs = s32[] parameter(0)
  %rhs = s32[] parameter(1)
  ROOT %compare = pred[] compare(s32[] %lhs, s32[] %rhs), direction=GT
}

ENTRY main {
  param = s32[4] parameter(0)
  size = s32[] constant(3)
  param_dynamic_size = s32[4] set-dimension-size(param, size),
    dimensions={0}
  ROOT sort = s32[4]{0} sort(s32[4]{0} %param_dynamic_size),
    dimensions={0}, is_stable=false, to_apply=%compare-greater-than
}
)";

  // Only the first three elements {1, 4, 3} are live; sorting them in
  // descending order gives {4, 3, 1}, and the padded element 2 is untouched.
  Literal input = LiteralUtil::CreateR1<int32_t>({1, 4, 3, 2});
  auto hlo_module = GetHloModule(hlo_text);

  Literal output = PadAndExecute(std::move(hlo_module), {&input},
                                 /*slice_dynamic_output=*/false);
  Literal expected_output = LiteralUtil::CreateR1<int32_t>({4, 3, 1, 2});
  EXPECT_EQ(output, expected_output);
}
1874
XLA_TEST_F(ExecutionTest, DynamicPad) {
  const std::string hlo_text = R"(
HloModule TEST

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[4] parameter(0)
  size = s32[] constant(3)
  padding = s32[] constant(2)
  param_dynamic = s32[<=4] set-dimension-size(param, size),
    dimensions={0}
  // Pad head and tail with 2
  pad = s32[<=6] pad(param_dynamic, padding), padding=1_1

  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(pad, init),
    dimensions={0},
    to_apply=update_s32
}
)";

  // Only {1, 4, 3} are live; padding head and tail with "2" makes the
  // effective data [2, 1, 4, 3, 2], which reduces to 12.
  Literal input = LiteralUtil::CreateR1<int32_t>({1, 4, 3, 5});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  Literal output = PadAndExecute(std::move(hlo_module), {&input},
                                 /*slice_dynamic_output=*/false);
  Literal expected_output = LiteralUtil::CreateR0<int32_t>(12);
  EXPECT_EQ(output, expected_output);
}
1913
XLA_TEST_F(ExecutionTest, DynamicPadInteriorPadding) {
  const std::string hlo_text = R"(
HloModule TEST

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[4] parameter(0)
  size = s32[] constant(3)
  padding = s32[] constant(2)
  param_dynamic = s32[<=4] set-dimension-size(param, size),
    dimensions={0}
  // Pad interior with constant 2.
  pad = s32[<=7] pad(param_dynamic, padding), padding=0_0_1

  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(pad, init),
    dimensions={0},
    to_apply=update_s32
}
)";

  // Only the first three elements {1, 4, 3} are live; interior padding with
  // "2" makes the effective data [1, 2, 4, 2, 3], which reduces to 12.
  Literal input = LiteralUtil::CreateR1<int32_t>({1, 4, 3, 5});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  Literal output = PadAndExecute(std::move(hlo_module), {&input},
                                 /*slice_dynamic_output=*/false);
  Literal expected_output = LiteralUtil::CreateR0<int32_t>(12);
  EXPECT_EQ(output, expected_output);
}
1952
XLA_TEST_F(ExecutionTest, DynamicConditionalDimension) {
  const std::string hlo_text = R"(
HloModule module

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

true_branch {
  true_param = (s32[<=3,2]) parameter(0)
  param = s32[<=3, 2] get-tuple-element(true_param), index=0
  add = s32[<=3,2] add(param, param)
  ROOT true_tuple = (s32[<=3,2], s32[<=3,2]) tuple(add, add)
}

false_branch {
  false_param = (s32[<=3,2]) parameter(0)
  param = s32[<=3, 2] get-tuple-element(false_param), index=0
  add = s32[<=3,2] add(param, param)
  ROOT false_tuple = (s32[<=3,2], s32[<=3,2]) tuple(add, add)
}

ENTRY entry {
  param0 = s32[3,2] parameter(0)
  size = s32[] constant(2)
  branch = pred[] constant(false)
  param_dynamic = s32[<=3, 2] set-dimension-size(param0, size), dimensions={0}
  param_tuple = (s32[<=3 ,2]) tuple(param_dynamic)
  conditional = (s32[<=3, 2], s32[<=3, 2]) conditional(branch, param_tuple, param_tuple),
    true_computation=true_branch, false_computation=false_branch
  gte0 = s32[<=3,2] get-tuple-element(conditional), index=1
  init = s32[] constant(0)
  ROOT reduce = s32[2] reduce(gte0, init),
    dimensions={0},
    to_apply=update_s32
}
)";

  // Dimension 0 is dynamic with size 2, so only rows {0, 1} and {2, 3} are
  // live. The taken (false) branch doubles them to {0, 2} and {4, 6}; the
  // reduce then sums over dimension 0, yielding {4, 8}. This checks that the
  // dynamic dimension propagates through the conditional's branch tuples.
  Literal operand = LiteralUtil::CreateR2<int32_t>({{0, 1}, {2, 3}, {4, 5}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand},
                                 /*slice_dynamic_output=*/false);
  Literal expected = LiteralUtil::CreateR1<int32_t>({4, 8});

  EXPECT_EQ(result, expected);
}
2002
XLA_TEST_F(ExecutionTest, DynamicTupleSort) {
  const std::string hlo_text = R"(
HloModule TEST

%compare-greater-than (lhs: s32[], rhs: s32[], lhs_2: s32[], lhs_2: s32[]) -> pred[] {
  %lhs = s32[] parameter(0)
  %rhs = s32[] parameter(1)
  %lhs_2 = s32[] parameter(2)
  %rhs_2 = s32[] parameter(3)
  ROOT %compare = pred[] compare(s32[] %lhs, s32[] %rhs), direction=GT
}

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[3] parameter(0)
  size = s32[] constant(2)
  param_dynamic_size = s32[3] set-dimension-size(param, size),
    dimensions={0}
  sort = (s32[3]{0}, s32[3]{0}) sort(s32[3]{0} %param_dynamic_size,
                                     s32[3]{0} %param_dynamic_size),
    dimensions={0}, is_stable=true, to_apply=%compare-greater-than
  ROOT get-tuple-element = s32[3]{0} get-tuple-element((s32[3]{0}, s32[3]{0}) %sort),
    index=0
}
)";

  // Only the first two elements {0, 4} are live; the tuple sort orders them
  // descending to {4, 0} while leaving the padded element 2 in place.
  Literal input = LiteralUtil::CreateR1<int32_t>({0, 4, 2});
  auto hlo_module = GetHloModule(hlo_text);

  Literal output = PadAndExecute(std::move(hlo_module), {&input},
                                 /*slice_dynamic_output=*/false);
  Literal expected_output = LiteralUtil::CreateR1<int32_t>({4, 0, 2});
  EXPECT_EQ(output, expected_output);
}
2043
// Shorthand for the HLO opcode matchers used by the pass-level tests below.
namespace op = xla::testing::opcode_matchers;
2045
2046 class HloDimensionSizeLegalizerTest : public HloTestBase {
2047 protected:
HloDimensionSizeLegalizerTest()2048 HloDimensionSizeLegalizerTest() {}
2049 };
2050
TEST_F(HloDimensionSizeLegalizerTest, Ok) {
  // Both get-dimension-size ops query static dimensions of p, so the pass
  // should fold each of them into a constant.
  const char* const kHlo = R"(
HloModule _
ENTRY gds {
  p = s32[3,4] parameter(0)
  size0 = s32[] get-dimension-size(p), dimensions={0}
  size1 = s32[] get-dimension-size(p), dimensions={1}
  ROOT mul = s32[] multiply(size0, size1)
})";
  auto module = ParseAndReturnVerifiedModule(kHlo).ValueOrDie();
  DynamicPadder pass;
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Multiply(op::Constant(), op::Constant()));
}
2066
TEST_F(HloDimensionSizeLegalizerTest, GetSetSetDimensionSizeRewriter) {
  // size1 reads a dimension whose dynamic size was set from size0; after the
  // pass, both operands of the multiply should be constants.
  const char* const kHlo = R"(
HloModule _
ENTRY gds {
  p = s32[3,4] parameter(0)
  size0 = s32[] get-dimension-size(p), dimensions={0}
  p_copy = s32[3,4] copy(p)
  p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0}
  size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0}
  ROOT mul = s32[] multiply(size0, size1)
})";
  auto module = ParseAndReturnVerifiedModule(kHlo).ValueOrDie();
  DynamicPadder pass;
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Multiply(op::Constant(), op::Constant()));
}
2084
TEST_F(HloDimensionSizeLegalizerTest, IllegalType) {
  // A get-dimension-size with an s64 result type is rejected by the pass.
  const char* const kHlo = R"(
HloModule _
ENTRY gds {
  p = s32[3]{0} parameter(0)
  ROOT gds = s64[] get-dimension-size(p), dimensions={0}
})";
  auto module = ParseAndReturnUnverifiedModule(kHlo).ValueOrDie();
  DynamicPadder pass;
  EXPECT_FALSE(pass.Run(module.get()).ok());
}
2096
TEST_F(HloDimensionSizeLegalizerTest, IllegalDimension) {
  // Dimension index 2 is out of range for the rank-2 operand, so the pass
  // must fail.
  const char* const kHlo = R"(
HloModule _
ENTRY gds {
  p = f32[2,5] parameter(0)
  ROOT gds = s32[] get-dimension-size(p), dimensions={2}
})";
  auto module = ParseAndReturnUnverifiedModule(kHlo).ValueOrDie();
  DynamicPadder pass;
  EXPECT_FALSE(pass.Run(module.get()).ok());
}
2108
2109 class SizeCheckTest : public HloTestBase {
2110 protected:
SizeCheckTest()2111 SizeCheckTest() {}
2112 };
2113
TEST_F(SizeCheckTest, CompileTimeCheckBinaryOpFail) {
  // The add's two operands carry dynamic sizes from two distinct parameters,
  // so the compile-time shape check cannot prove them equal and the pass is
  // expected to fail with INVALID_ARGUMENT.
  auto module = ParseAndReturnUnverifiedModule(R"(
HloModule _
ENTRY gds {
  size_0 = s32[] parameter(0)
  size_1 = s32[] parameter(1)
  arg = s32[4]{0} parameter(2)
  dynamic_arg_0 = s32[<=4] set-dimension-size(arg, size_0), dimensions={0}
  dynamic_arg_1 = s32[<=4] set-dimension-size(arg, size_1), dimensions={0}
  ROOT add = s32[<=4] add(dynamic_arg_0, dynamic_arg_1)
})")
                    .ValueOrDie();
  auto options = DynamicPadderOptions();
  options.shape_check_mode =
      DynamicDimensionInference::ShapeCheckMode::kCompileTime;
  DynamicPadder pass(options);
  auto status = pass.Run(module.get()).status();
  // EXPECT_EQ instead of EXPECT_THAT with a bare value: the latter only works
  // through the implicit value->Eq matcher conversion; EXPECT_EQ states the
  // equality directly and prints both codes on failure.
  EXPECT_EQ(status.code(), tensorflow::error::INVALID_ARGUMENT);
}
2133
TEST_F(SizeCheckTest, CompileTimeCheckBinaryOpPass) {
  // size_1 is size_0 routed through a pair of reshapes: the two size operands
  // are distinct instructions but provably the same value once the
  // DynamicDimensionSimplifier folds the reshapes, so the compile-time shape
  // check should pass.
  auto module = ParseAndReturnUnverifiedModule(R"(
HloModule _
ENTRY gds {
  size_0 = s32[] parameter(0)
  size_0_reshape = s32[1] reshape(size_0)
  size_1 = s32[] reshape(size_0_reshape)
  arg = s32[4]{0} parameter(1)
  dynamic_arg_0 = s32[<=4] set-dimension-size(arg, size_0), dimensions={0}
  dynamic_arg_1 = s32[<=4] set-dimension-size(arg, size_1), dimensions={0}
  ROOT add = s32[<=4] add(dynamic_arg_0, dynamic_arg_1)
})")
                    .ValueOrDie();
  auto options = DynamicPadderOptions();
  options.shape_check_mode =
      DynamicDimensionInference::ShapeCheckMode::kCompileTime;
  // Fold the reshape chain first so the equality of size_0/size_1 is visible
  // to the padder's compile-time check.
  DynamicDimensionSimplifier simplifier;
  EXPECT_TRUE(simplifier.Run(module.get()).ok());
  DynamicPadder pass(options);
  auto status = pass.Run(module.get()).status();
  EXPECT_TRUE(status.ok());
}
2157
2158 } // namespace
2159 } // namespace xla
2160