xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/dynamic_padder_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
17 
18 #include "absl/strings/str_replace.h"
19 #include "tensorflow/compiler/xla/client/xla_builder.h"
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
22 #include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_dce.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/service/hlo_parser.h"
30 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
31 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
32 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/test.h"
36 #include "tensorflow/compiler/xla/test_helpers.h"
37 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
38 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
39 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
40 #include "tensorflow/compiler/xla/tests/llvm_irgen_test_base.h"
41 #include "tensorflow/compiler/xla/tests/test_macros.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/compiler/xla/xla_data.pb.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/platform/test_benchmark.h"
46 #include "tensorflow/core/protobuf/error_codes.pb.h"
47 
48 namespace xla {
49 namespace {
50 
51 namespace m = ::xla::match;
52 namespace op = xla::testing::opcode_matchers;
53 
OpHasDynamismSupport(HloInstruction * hlo)54 OpDynamismSupport OpHasDynamismSupport(HloInstruction* hlo) {
55   if (hlo->opcode() != HloOpcode::kCustomCall) {
56     return OpDynamismSupport::kNoSupport;
57   }
58   if (hlo->custom_call_target() == "OpWithDynamicLowering") {
59     return OpDynamismSupport::kRequired;
60   }
61   return OpDynamismSupport::kNoSupport;
62 }
63 
CustomCallDynamicDimensionInference(HloInstruction * hlo,DynamicDimensionInference * inferencer)64 Status CustomCallDynamicDimensionInference(
65     HloInstruction* hlo, DynamicDimensionInference* inferencer) {
66   if (hlo->custom_call_target() == "OpWithDynamicLowering") {
67     if (hlo->shape().IsTuple()) {
68       // Use the operand's dynamic size as output dynamic size.
69       HloInstruction* dynamic_size =
70           inferencer->GetDynamicSize(hlo->mutable_operand(0), {1}, 0);
71       inferencer->SetDynamicSize(hlo, {1}, 0, dynamic_size);
72     } else {
73       // Use the operand's dynamic size as output dynamic size.
74       HloInstruction* dynamic_size =
75           inferencer->GetDynamicSize(hlo->mutable_operand(0), {}, 0);
76       inferencer->SetDynamicSize(hlo, {}, 0, dynamic_size);
77     }
78   }
79 
80   return OkStatus();
81 }
82 
// Fixture for unit tests that run the DynamicPadder pass and inspect the
// rewritten HLO graph (without executing it).
class DynamicPadderTest : public HloTestBase {
 protected:
  DynamicPadderTest() : HloTestBase() { module_ = CreateNewVerifiedModule(); }

  // Parses `hlo_text` into a verified HloModule; CHECK-fails on parse errors.
  std::unique_ptr<HloModule> GetHloModule(const std::string& hlo_text) {
    std::unique_ptr<HloModule> module =
        ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
    return module;
  }

  // Runs DynamicPadder over module_, wired up with the file-local custom-call
  // dynamism/inference handlers above. Returns whether the pass changed the
  // module.
  StatusOr<bool> RunPadder(bool slice_dynamic_output = false) {
    DynamicPadderOptions options;
    options.slice_dynamic_output = slice_dynamic_output;
    options.custom_call_handler = CustomCallDynamicDimensionInference;
    options.op_supports_dynamism_handler = OpHasDynamismSupport;
    DynamicPadder padder(std::move(options));
    return RunHloPass(&padder, module_.get());
  }

  // Asserts that `inst` has the padded (masked) form the padder produces:
  //   select(iota < broadcast(dynamic_size), <original value>, broadcast(...)).
  void ExpectPadded(const HloInstruction* inst) {
    EXPECT_THAT(inst,
                op::Select(op::Lt(op::Iota(), op::Broadcast(op::Parameter())),
                           ::testing::_, op::Broadcast()));
  }

  // Builds and registers a scalar f32 add computation, used as the reduction
  // function by the tests below.
  HloComputation* GetScalarAddComputation() {
    auto embedded_builder = HloComputation::Builder("add");
    auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
        0, ShapeUtil::MakeShape(F32, {}), "lhs"));
    auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
        1, ShapeUtil::MakeShape(F32, {}), "rhs"));
    embedded_builder.AddInstruction(
        HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
    return module_->AddEmbeddedComputation(embedded_builder.Build());
  }

  // Module under test; recreated for every test case.
  std::unique_ptr<HloModule> module_;
  // Scalar s32 shape used for dynamic-size parameters.
  const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {});
};
122 
123 class MemoryAlignmentTest : public HloTestBase {};
124 
// Test that dynamic padder will not cause memory misalignment in CUDA
// when the read or write address is not aligned with 32 bits.
// TODO(b/203599920): Disabled on CPU due to ASAN test failure.
TEST_F(MemoryAlignmentTest, DISABLED_ON_CPU(TestDataTypeFP16)) {
  // f16 elements are 2 bytes, so a <=1,1 dynamic scatter exercises sub-word
  // (not 32-bit-aligned) reads/writes.
  const std::string hlo_text = R"(
    HloModule TestDataTypeFP16

    update_add (p0: f16[], p1: f16[]) -> f16[] {
      p0 = f16[] parameter(0)
      p1 = f16[] parameter(1)
      ROOT out = f16[] add(p0, p1)
    }

    ENTRY main () -> f16[<=1,1] {
      c1 = s32[1]{0} constant({1})
      c2 = f16[1,1]{1,0} constant({ {0.099976} })
      shape = s32[] reshape(s32[1]{0} c1)
      dim_size = f16[<=1,1]{1,0} set-dimension-size(f16[1,1]{1,0} c2, s32[] shape),
          dimensions={0}
      ROOT out = f16[<=1,1]{1,0} scatter(f16[<=1,1]{1,0} dim_size, s32[1]{0} c1, f16[1,1]{1,0} c2),
          update_window_dims={1},
          inserted_window_dims={0},
          scatter_dims_to_operand_dims={0},
          index_vector_dim=1,
          to_apply=update_add
    }
  )";
  // Compare against the reference interpreter result within tolerance.
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
154 
// Reducing over a dynamic dimension requires the padder to mask the padded
// elements with the reduce's identity value first.
TEST_F(DynamicPadderTest, ReduceTest) {
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  auto reduce_shape = ShapeUtil::MakeShape(F32, {2});

  auto data_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "data_param"));
  // Parameter 1 is only referenced via the dynamic-parameter binding below.
  builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param));

  auto init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));

  // Reduce over dims {0, 2}; dim 2 is the one made dynamic below.
  auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      reduce_shape, negate, init, {0, 2}, GetScalarAddComputation()));
  EXPECT_FALSE(module_->is_dynamic());
  module_->AddEntryComputation(builder.Build());

  // Set up dynamic parameter binding.
  // Dimension 2 of parameter 0 is dynamic; its size comes from parameter 1.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 2}));

  TF_ASSERT_OK(RunPadder().status());

  // The reduced operand (the negate) must now be masked with the init value.
  ExpectPadded(reduce->operand(0));
  EXPECT_TRUE(module_->is_dynamic());
}
186 
// Ops that support dynamic lowering keep dynamic-form inputs; the first
// unsupported consumer (negate) forces a PadToStatic, and a dynamic root is
// re-sliced to dynamic form.
TEST_F(DynamicPadderTest, DynamicLoweringTest) {
  const std::string hlo_text = R"(
HloModule DynamicLowering

ENTRY main {
  param = s32[5] parameter(0)
  const = s32[] constant(3)
  param_padded = s32[<=5] set-dimension-size(param, const),
                dimensions={0}
  custom-call.1 = s32[<=5] custom-call(param_padded),
    custom_call_target="OpWithDynamicLowering"
  custom-call.2 = s32[<=5] custom-call(custom-call.1),
    custom_call_target="OpWithDynamicLowering"
  // Negate doesn't support dynamic lowering.
  ROOT negate = s32[<=5] negate(custom-call.2)
}
)";

  module_ = GetHloModule(hlo_text);

  TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());
  // After rewrite, we should have :
  //
  //   param
  //     |
  //  SliceToDynamic
  //     |
  //  OpWithDynamicLowering (custom_call_1)
  //     |
  //  OpWithDynamicLowering (custom_call_2)
  //     |
  //  PadToStatic
  //     |
  //   Negate
  //     |
  //   SliceToDynamic // Root require dynamic form tensor.
  auto custom_call_1 =
      module_->entry_computation()->GetInstructionWithName("custom-call.1");
  auto custom_call_2 =
      module_->entry_computation()->GetInstructionWithName("custom-call.2");
  // Test that the input to custom-call.1 is a SliceToDynamic custom call.
  // Note: ASSERT_EQ (not ASSERT_THAT with a bare value) for value equality,
  // matching the ASSERT_EQ on user_count below.
  HloInstruction* slice_to_dynamic = custom_call_1->mutable_operand(0);
  ASSERT_EQ(slice_to_dynamic->opcode(), HloOpcode::kCustomCall);
  ASSERT_EQ(slice_to_dynamic->custom_call_target(), "SliceToDynamic");
  // custom-call.2 feeds exactly one consumer: the PadToStatic inserted for
  // the negate, which does not support dynamic lowering.
  ASSERT_EQ(custom_call_2->user_count(), 1);
  HloInstruction* pad_to_static = custom_call_2->users()[0];
  ASSERT_EQ(pad_to_static->opcode(), HloOpcode::kCustomCall);
  ASSERT_EQ(pad_to_static->custom_call_target(), "PadToStatic");
  // The root must be re-sliced to the dynamic output form.
  slice_to_dynamic = module_->entry_computation()->root_instruction();
  ASSERT_EQ(slice_to_dynamic->opcode(), HloOpcode::kCustomCall);
  ASSERT_EQ(slice_to_dynamic->custom_call_target(), "SliceToDynamic");
}
239 
// Same as DynamicLoweringTest, but the dynamic-lowering custom calls take and
// produce a tuple of a static scalar and a dynamic array.
TEST_F(DynamicPadderTest, DynamicLoweringTestTupleInput) {
  const std::string hlo_text = R"(
HloModule DynamicLowering

ENTRY main {
  param = s32[5] parameter(0)
  const = s32[] constant(3)
  param_padded = s32[<=5] set-dimension-size(param, const),
                dimensions={0}
  // Create a tuple with static and dynamic componenet.
  tuple_arg = (s32[], s32[<=5]) tuple(const, param_padded)
  custom-call.1 = (s32[], s32[<=5]) custom-call(tuple_arg),
    custom_call_target="OpWithDynamicLowering"
  custom-call.2 = (s32[], s32[<=5]) custom-call(custom-call.1),
    custom_call_target="OpWithDynamicLowering"
  data = s32[<=5]{0} get-tuple-element(custom-call.2), index=1
  // Negate doesn't support dynamic lowering.
  ROOT negate = s32[<=5] negate(data)
}
)";

  module_ = GetHloModule(hlo_text);

  TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());
  // After rewrite, we should have :
  //
  //   param
  //     |
  //  SliceToDynamic
  //     |
  //    Tuple
  //     |
  //  OpWithDynamicLowering (custom_call_1)
  //     |
  //  OpWithDynamicLowering (custom_call_2)
  //     |
  //   GTE
  //     |
  //  PadToStatic
  //     |
  //   Negate
  //     |
  //   SliceToDynamic // Root require dynamic form tensor.

  auto* root = module_->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::CustomCall("SliceToDynamic", op::Negate(), op::Constant()));
  // The negate consumes the statically-padded element extracted from the
  // custom call's tuple output via PadToStatic.
  HloInstruction* negate = root->mutable_operand(0);
  EXPECT_THAT(
      negate,
      op::Negate(op::GetTupleElement(op::CustomCall(
          "PadToStatic", op::GetTupleElement(op::CustomCall(
                             "OpWithDynamicLowering", ::testing::_))))));
  // The dynamic-supporting custom call receives its dynamic tuple element in
  // dynamic (SliceToDynamic) form.
  auto custom_call_1 =
      module_->entry_computation()->GetInstructionWithName("custom-call.1");
  EXPECT_THAT(custom_call_1,
              op::CustomCall("OpWithDynamicLowering",
                             op::Tuple(op::GetTupleElement(),
                                       op::CustomCall("SliceToDynamic"))));
}
300 
// A dynamic array nested inside a tuple-of-tuples root must still be sliced
// to dynamic form; the static leaves remain plain constants.
TEST_F(DynamicPadderTest, DynamicOutputNestedTuple) {
  const std::string hlo_text = R"(
HloModule DynamicLowering

ENTRY main {
  param = s32[5] parameter(0)
  const = s32[] constant(3)
  const2 = s32[] constant(4)
  param_padded = s32[<=5] set-dimension-size(param, const),
                dimensions={0}
  // Create a tuple with static and dynamic componenet.
  tuple0 = (s32[], s32[<=5]) tuple(const, param_padded)
  ROOT tuple1 = (s32[], (s32[], s32[<=5])) tuple(const2, tuple0)
}
)";

  module_ = GetHloModule(hlo_text);

  TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());
  TF_ASSERT_OK(TupleSimplifier().Run(module_.get()).status());
  XLA_LOG_LINES(0, module_->ToString());

  // Match the whole nested structure in one shot: the inner tuple's dynamic
  // leaf is now a SliceToDynamic custom call.
  HloInstruction* root = module_->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      op::Tuple(op::Constant(),
                op::Tuple(op::Constant(), op::CustomCall("SliceToDynamic"))));
}
329 
// A dynamic contracting dimension of a convolution must be masked so padded
// elements do not contribute to the contraction.
TEST_F(DynamicPadderTest, ConvolutionTest) {
  auto builder = HloComputation::Builder(TestName());
  constexpr int xdim = 3;
  constexpr int ydim = 2;
  constexpr int zdim = 1;
  auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
  auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
  auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, xy_shape, "A"));
  auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, yz_shape, "B"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0);

  // Dimension 1 of A is the (kernel input) feature dimension being
  // contracted; it is made dynamic via the binding below.
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(1);
  dnums.set_output_feature_dimension(0);

  // Default-constructed window: no spatial dimensions.
  Window window;

  auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      zx_shape, a_param, b_param, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums,
      HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  // Set up binding for contracting dimensions.
  // Parameter 2 holds the dynamic size of dimension 1 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_ASSERT_OK(RunPadder().status());

  // The LHS operand must now be in masked (select) form.
  ExpectPadded(conv->operand(0));
}
372 
// A dynamic non-contracting (batch) dimension needs no masking for a
// convolution, so the operand should be left untouched by the padder.
TEST_F(DynamicPadderTest, ConvolutionNoPad) {
  auto builder = HloComputation::Builder(TestName());
  constexpr int kXDim = 3;
  constexpr int kYDim = 2;
  constexpr int kZDim = 1;
  const Shape xy_shape = ShapeUtil::MakeShape(F32, {kXDim, kYDim});
  const Shape yz_shape = ShapeUtil::MakeShape(F32, {kYDim, kZDim});
  const Shape zx_shape = ShapeUtil::MakeShape(F32, {kZDim, kXDim});

  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, xy_shape, "A"));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, yz_shape, "B"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(1);
  dnums.set_output_feature_dimension(0);

  // Default-constructed window: no spatial dimensions.
  Window window;

  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      zx_shape, lhs, rhs, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums,
      HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  // Set up dynamic parameter binding for non-contracting dimension:
  // parameter 2 is the dynamic size of dimension 0 (batch) of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunPadder().status());

  // No masking expected; the operand is still the raw parameter.
  EXPECT_THAT(conv->operand(0), op::Parameter());
}
415 
// A reduce-window whose window is trivial (size 1, no padding) along the
// dynamic dimension needs no masking: each output element only ever sees
// real data.
TEST_F(DynamicPadderTest, ReduceWindowNoPadForTrivialWindow) {
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {4, 5});
  auto reduce_shape = ShapeUtil::MakeShape(F32, {3, 5});

  auto input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "input"));
  // Parameter 1 is only referenced via the dynamic-parameter binding below.
  builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  auto init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  // Window 2x1: trivial (size 1, pad 0) along dimension 1 — the dynamic one.
  TF_ASSERT_OK_AND_ASSIGN(Window window, ParseWindow("size=2x1 pad=0_0x0_0"));
  auto output = builder.AddInstruction(HloInstruction::CreateReduceWindow(
      reduce_shape, input, init, window, GetScalarAddComputation()));

  module_->AddEntryComputation(builder.Build());

  // Set up dynamic parameter binding.
  // Dimension 1 of parameter 0 is dynamic; its size comes from parameter 1.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_ASSERT_OK(RunPadder().status());

  // No masking expected; the operand is still the raw parameter.
  EXPECT_THAT(output->operand(0), op::Parameter());
}
442 
// Variadic (multi-operand) reduce-window with a window that is trivial along
// the dynamic dimension: none of the inputs should be masked.
TEST_F(DynamicPadderTest, VariadicReduceWindowNoPadForTrivialWindow) {
  const std::string hlo_text = R"(
HloModule VariadicReduceWindowNoPadForTrivialWindow

add_f32 (a: f32[], b: s32[], c: f32[], d: s32[]) -> (f32[], s32[]) {
  a = f32[] parameter(0)
  b = s32[] parameter(1)
  c = f32[] parameter(2)
  d = s32[] parameter(3)
  add.0 = f32[] add(a, c)
  add.1 = s32[] add(b, d)
  ROOT out = tuple(add.0, add.1)
}

ENTRY main {
  input.0 = f32[4, 5] parameter(0)
  input.1 = s32[4, 5] parameter(1)
  size_param.0 = s32[] parameter(2)
  size_param.1 = s32[] parameter(3)
  init.0 = f32[] constant(0.0)
  init.1 = s32[] constant(0)
  ROOT output = (f32[3, 5], s32[3, 5]) reduce-window(input.0, input.1, init.0, init.1), window={size=2x1 pad=0_0x0_0}, to_apply=add_f32
}
)";

  const int kNumParams = 2;
  module_ = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // Set up dynamic parameter binding.
  // NOTE(review): both inputs bind their dynamic dimension 1 to parameter 2
  // (size_param.0); size_param.1 is never bound — presumably intentional
  // since both inputs share the same dynamic size, but confirm.
  for (int i = 0; i < kNumParams; ++i) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{2, {}},
        DynamicParameterBinding::DynamicDimension{i, {}, 1}));
  }

  TF_ASSERT_OK(RunPadder().status());

  // Both reduce-window inputs should remain unmasked parameters.
  for (int i = 0; i < kNumParams; ++i) {
    EXPECT_THAT(module_->entry_computation()->root_instruction()->operand(i),
                op::Parameter());
  }
}
484 
// A mixed-precision (s8 inputs, s32 output) dot with a dynamic row bound:
// after padding, the dot itself must be fully static and the result wrapped
// in SliceToDynamic.
TEST_F(DynamicPadderTest, PadS8ToS32Dot) {
  const std::string hlo_text = R"(
HloModule test
ENTRY test {
  a = s8[<=16,32] parameter(0)
  b = s8[32,64] parameter(1)
  ROOT root = s32[<=16,64] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  module_ = GetHloModule(hlo_text);
  TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());

  // The operands and result of the dot are static at their bounds; the root
  // re-slices the s32 result to its dynamic size.
  HloInstruction* root = module_->entry_computation()->root_instruction();
  auto static_dot = m::Dot(m::Op().WithShape(S8, {16, 32}),
                           m::Op().WithShape(S8, {32, 64}))
                        .WithShape(S32, {16, 64});
  EXPECT_THAT(root, GmockMatch(m::CustomCall("SliceToDynamic", static_dot,
                                             m::Op(), m::Op())));
}
504 
// An unknown custom call gets no dynamism support, so the padder must wrap
// its dynamic-bounded output in PadToStatic before downstream use, then
// re-slice the root to dynamic form.
TEST_F(DynamicPadderTest, PadToStaticForCustomCall) {
  const std::string hlo_text = R"(
HloModule test
ENTRY test {
  a = f32[64] parameter(0)
  ROOT c = f32[<=128] custom-call(a),
    custom_call_target="UnknownOp"
}
)";

  module_ = GetHloModule(hlo_text);
  TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());

  // Expected chain: UnknownOp -> PadToStatic -> GTE(data) -> SliceToDynamic.
  EXPECT_THAT(
      module_->entry_computation()->root_instruction(),
      GmockMatch(m::CustomCall("SliceToDynamic",
                               m::GetTupleElement(m::CustomCall(
                                   "PadToStatic", m::CustomCall("UnknownOp"))),
                               m::Op())));
}
525 
// Test that dynamic padder has the same result as if not padded.
// Fixture for end-to-end tests: runs the DynamicPadder on a module, executes
// it, and lets tests compare against an unpadded/static reference.
class ExecutionTest : public HloTestBase {
 protected:
  // Parses `hlo_text` into a verified HloModule; CHECK-fails on parse errors.
  std::unique_ptr<HloModule> GetHloModule(const std::string& hlo_text) {
    std::unique_ptr<HloModule> module =
        ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
    return module;
  }
  // Runs DynamicPadder (followed by DCE) on `module`, executes it with
  // `arguments`, and returns the transferred result literal. When
  // `slice_dynamic_output` is false, the entry layout's dynamic result
  // dimensions are cleared first so the output keeps its padded static shape.
  Literal PadAndExecute(std::unique_ptr<HloModule> module,
                        absl::Span<Literal* const> arguments,
                        bool slice_dynamic_output = true) {
    if (!slice_dynamic_output) {
      auto new_config = module->config();
      new_config.mutable_entry_computation_layout()
          ->mutable_result_layout()
          ->ClearDynamicShape();
      module->set_config(new_config);
    }
    DynamicPadderOptions options;
    options.slice_dynamic_output = slice_dynamic_output;
    DynamicPadder padder(options);
    TF_CHECK_OK(padder.Run(module.get()).status());
    // DCE cleans up instructions orphaned by the rewrite before execution.
    HloDCE dce;
    TF_CHECK_OK(dce.Run(module.get()).status());
    return ExecuteAndTransfer(std::move(module), arguments);
  }
};
553 
// Test that scattering on indices=[2] is the same as scattering on
// indices=[4] with dynamic dimension = 2: the garbage tail must not affect
// the result.
XLA_TEST_F(ExecutionTest, ScatterUpdate) {
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[INDICES_BOUND] parameter(1)
  updates = s32[INDICES_BOUND,3] parameter(2)
  dynamic_size = s32[] parameter(3)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1

}
)";
  // Reference run: static indices of size 2, no padding involved.
  const std::string hlo_text_not_padded =
      absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "2"}});
  auto module_not_padded = GetHloModule(hlo_text_not_padded);

  Literal operand =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal scatter_indices = LiteralUtil::CreateR1<int32_t>({0, 2});
  Literal updates =
      LiteralUtil::CreateR2<int32_t>({{10, 20, 30}, {70, 80, 90}});
  Literal dynamic_size = LiteralUtil::CreateR0<int32_t>(2);

  Literal not_padded =
      ExecuteAndTransfer(std::move(module_not_padded),
                         {&operand, &scatter_indices, &updates, &dynamic_size});

  // Pad input to 4.
  const std::string hlo_text_padded =
      absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "4"}});
  auto module_padded = GetHloModule(hlo_text_padded);
  // Set up dynamic parameter binding: parameter 3 is the dynamic size of
  // dimension 0 of both indices (parameter 1) and updates (parameter 2).
  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{1, {}, 0}));
  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{2, {}, 0}));
  // Pad the rest of input with garbage data.
  Literal scatter_indices_padded = LiteralUtil::CreateR1<int32_t>({0, 2, 0, 4});
  Literal updates_padded = LiteralUtil::CreateR2<int32_t>(
      {{10, 20, 30}, {70, 80, 90}, {30, 22, 11}, {-1, 20, -1}});
  // PadAndExecute runs the DynamicPadder itself; no manual pass run needed.
  Literal padded = PadAndExecute(
      std::move(module_padded),
      {&operand, &scatter_indices_padded, &updates_padded, &dynamic_size});

  EXPECT_EQ(padded, not_padded);
}
617 
// Scatter where a dynamic dimension appears in the update window (not the
// index dimension): with dynamic size 1, only the first update slice should
// land in the (dynamic) operand.
XLA_TEST_F(ExecutionTest, ScatterUpdateWindowDim) {
  const std::string hlo_text = R"(
HloModule ScatterUpdateWindowDim

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[1,2,3] parameter(0)
  indices = s32[1] parameter(1)
  updates = s32[2,3,1] parameter(2)
  dynamic_size = s32[] constant(1)
  operand_dynamic = s32[1, <=2, 3] set-dimension-size(operand, dynamic_size),
      dimensions={1}
  updates_dynamic = s32[<=2, 3, 1] set-dimension-size(updates, dynamic_size),
      dimensions={0}
  ROOT scatter = s32[1, <=2, 3] scatter(operand_dynamic, indices, updates_dynamic),
      to_apply=update_s32,
      update_window_dims={0, 1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1

}
)";
  auto hlo_module = GetHloModule(hlo_text);

  Literal operand = LiteralUtil::CreateR3<int32_t>({{{0, 0, 0}, {0, 0, 0}}});
  Literal scatter_indices = LiteralUtil::CreateR1<int32_t>({0});
  Literal updates =
      LiteralUtil::CreateR3<int32_t>({{{10}, {20}, {30}}, {{70}, {80}, {90}}});

  // slice_dynamic_output=false: keep the padded static output shape so the
  // full [1,2,3] result can be compared directly.
  Literal padded = PadAndExecute(std::move(hlo_module),
                                 {&operand, &scatter_indices, &updates}, false);
  Literal expected =
      LiteralUtil::CreateR3<int32_t>({{{10, 20, 30}, {70, 80, 90}}});
  EXPECT_EQ(padded, expected);
}
658 
// Scatter with a dynamic number of indices: the dynamic size is 1, so even
// though two indices/updates are supplied, only the first may be applied —
// the second row is padding.
XLA_TEST_F(ExecutionTest, ScatterUpdateF32) {
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  ROOT rhs = f32[] parameter(1)
}

ENTRY main {
  operand = f32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = f32[2,3] parameter(2)
  dynamic_size = s32[] parameter(3)
  ROOT scatter = f32[3,3] scatter(operand, indices, updates),
      to_apply=update_f32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1

}
)";

  Literal operand = LiteralUtil::CreateR2<float>(
      {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}});
  Literal scatter_indices = LiteralUtil::CreateR1<int32_t>({0, 2});
  Literal updates =
      LiteralUtil::CreateR2<float>({{10.0, 20.0, 30.0}, {70.0, 80.0, 90.0}});
  // Dynamic Size is 1, pad to 2
  Literal dynamic_size = LiteralUtil::CreateR0<int32_t>(1);

  auto module_padded = GetHloModule(hlo_text);
  // Set up dynamic parameter binding: parameter 3 is the dynamic size of
  // dimension 0 of both indices (parameter 1) and updates (parameter 2).
  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{1, {}, 0}));
  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{2, {}, 0}));
  // PadAndExecute runs the DynamicPadder itself; no manual pass run needed.
  Literal not_padded =
      PadAndExecute(std::move(module_padded),
                    {&operand, &scatter_indices, &updates, &dynamic_size});
  // Although we have two indices, only the first element is updated because of
  // padding.
  EXPECT_EQ(LiteralUtil::CreateR2<float>(
                {{10.0, 20.0, 30.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}}),
            not_padded);
}
714 
// Second dimension (size 2) is dynamic, assuming real size is 1 and padded to
// 2:
//
// [[1, 2]
//  [3, 4]
//  [5, 6]]
//
// Gathering the second dimension out creates:
//
// [3, 4]
//
// Reducing this gives us 3 (4 is padded value so ignored)
XLA_TEST_F(ExecutionTest, WholeDimensionGather) {
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[3, 2, 1] parameter(0)
  size = s32[] constant(1)
  param_padded = s32[3, 2, 1] set-dimension-size(param, size), dimensions={1}
  index = s32[] constant(1)
  gather = s32[2,1]{1,0} gather(param_padded, index),
              offset_dims={0,1},
              collapsed_slice_dims={0},
              start_index_map={0},
              index_vector_dim=0,
              slice_sizes={1,2,1}
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(gather, init),
      dimensions={0, 1},
      to_apply=update_s32
}
)";
  // Slicing out entire dimension propagates the dimension
  Literal operand =
      LiteralUtil::CreateR3<int32_t>({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
  auto module = GetHloModule(hlo_text);
  // PadAndExecute runs the DynamicPadder itself; no manual pass run needed.
  Literal result = PadAndExecute(std::move(module), {&operand});

  // Only first element will be reduced.
  Literal expected = LiteralUtil::CreateR0<int32_t>(3);

  EXPECT_EQ(result, expected);
}
767 
XLA_TEST_F(ExecutionTest, TwoDimensionReduce) {
  // Test that reducing on operand=[2,2] is same as reducing on operand=[4,4]
  // and dynamic dimension = 2
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[INDICES_BOUND, INDICES_BOUND] parameter(0)
  dynamic_size = s32[] parameter(1)
  const = s32[] constant(0)
  ROOT reduce = s32[] reduce(param, const),
      dimensions={0, 1},
      to_apply=update_s32
}
)";
  // Reference run: a static [2,2] input with no padding involved.
  const std::string hlo_text_not_padded =
      absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "2"}});
  auto module_not_padded = GetHloModule(hlo_text_not_padded);

  Literal operand = LiteralUtil::CreateR2<int32_t>({{1, 2}, {4, 5}});
  Literal dynamic_size = LiteralUtil::CreateR0<int32_t>(2);

  Literal not_padded = ExecuteAndTransfer(std::move(module_not_padded),
                                          {&operand, &dynamic_size});

  // Pad input to 4.
  const std::string hlo_text_padded =
      absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "4"}});
  auto module_padded = GetHloModule(hlo_text_padded);
  // Set up dynamic parameter binding: scalar parameter 1 supplies the runtime
  // size of both dimension 0 and dimension 1 of parameter 0.
  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));
  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));
  // Pad the rest of input with garbage data; the padder must mask it out so
  // only the top-left 2x2 region contributes to the sum.
  Literal operand_padded = LiteralUtil::CreateR2<int32_t>(
      {{1, 2, 3, 4}, {4, 5, 6, 7}, {1, 2, 3, 4}, {4, 5, 6, 7}});
  DynamicPadder padder;
  TF_CHECK_OK(padder.Run(module_padded.get()).status());
  Literal padded =
      PadAndExecute(std::move(module_padded), {&operand_padded, &dynamic_size});

  EXPECT_EQ(padded, not_padded);
}
820 
XLA_TEST_F(ExecutionTest, DynamicDimensionClamp) {
  const std::string hlo_string = R"(
HloModule TensorFlowTenaryV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[5] parameter(0)
  const = s32[] constant(3)
  param_padded = s32[5] set-dimension-size(param, const), dimensions={0}
  clamp = s32[5] clamp(param_padded, param_padded, param_padded)
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(clamp, init),
      dimensions={0},
      to_apply=update_s32
}
)";

  auto hlo_module = GetHloModule(hlo_string);
  // The operand's bound is 5 while its runtime size is 3, so the last two
  // elements are padding.
  Literal input = LiteralUtil::CreateR1<int32_t>({1, 2, 3, 4, 5});

  Literal output = PadAndExecute(std::move(hlo_module), {&input});

  // Only the three valid elements are summed: 1 + 2 + 3 = 6.
  Literal expected = LiteralUtil::CreateR0<int32_t>(6);

  EXPECT_EQ(output, expected);
}
854 
XLA_TEST_F(ExecutionTest, DynamicConcat) {
  // Concatenates {dynamic_operand, static_operand, dynamic_operand} along
  // dimension 0.
  const std::string hlo_string = R"(
HloModule DynamicConcat

ENTRY main {
  param_0 = s32[3] parameter(0)
  param_1 = s32[3] parameter(1)
  param_2 = s32[3] parameter(2)
  size = s32[] constant(2)
  param_padded_0 = s32[<=3] set-dimension-size(param_0, size), dimensions={0}
  param_padded_2 = s32[<=3] set-dimension-size(param_2, size), dimensions={0}
  ROOT %concatenate = s32[9]
    concatenate(s32[<=3] param_padded_0, s32[<=3] param_1, s32[<=3] param_padded_2),
    dimensions={0}
}
)";

  auto hlo_module = GetHloModule(hlo_string);
  // Each operand's bound is 3; the dynamic ones carry 2 real values and use
  // -1 as padding.
  Literal dynamic_input_0 = LiteralUtil::CreateR1<int32_t>({1, 2, -1});
  Literal static_input = LiteralUtil::CreateR1<int32_t>({3, 4, 5});
  Literal dynamic_input_1 = LiteralUtil::CreateR1<int32_t>({6, 7, -1});

  Literal output =
      PadAndExecute(std::move(hlo_module),
                    {&dynamic_input_0, &static_input, &dynamic_input_1}, false);
  // 2 + 3 + 2 valid elements in total.
  output.SetDynamicSize(0, 7);
  Literal expected = LiteralUtil::CreateR1<int32_t>({1, 2, 3, 4, 5, 6, 7});

  EXPECT_EQ(output, expected);
}
889 
XLA_TEST_F(ExecutionTest, DynamicReverseSingleDim) {
  const std::string hlo_string = R"(
HloModule DynamicConcat

ENTRY main {
  param_0 = s32[3] parameter(0)
  size = s32[] constant(2)
  param_padded_0 = s32[<=3] set-dimension-size(param_0, size), dimensions={0}
  ROOT %reverse = s32[<=3]
    reverse(s32[<=3] param_padded_0),
    dimensions={0}
}
)";

  auto hlo_module = GetHloModule(hlo_string);
  // Bound is 3 and runtime size is 2; -1 marks the padded slot.
  Literal input = LiteralUtil::CreateR1<int32_t>({1, 2, -1});

  Literal output = PadAndExecute(std::move(hlo_module), {&input}, false);
  output.SetDynamicSize(0, 2);
  // Only the two valid elements are reversed.
  Literal expected = LiteralUtil::CreateR1<int32_t>({2, 1});

  EXPECT_EQ(output, expected);
}
915 
XLA_TEST_F(ExecutionTest, DynamicReverseMultiDims) {
  const std::string hlo_string = R"(
HloModule DynamicConcat

ENTRY main {
  param_0 = s32[3, 3] parameter(0)
  size = s32[] constant(2)
  param_padded_0 = s32[<=3, 3] set-dimension-size(param_0, size), dimensions={0}
  param_padded_1 = s32[<=3, <=3] set-dimension-size(param_padded_0, size),
    dimensions={1}
  ROOT %reverse = s32[<=3, <=3]
    reverse(s32[<=3, <=3] param_padded_1),
    dimensions={0, 1}
}
)";

  auto hlo_module = GetHloModule(hlo_string);
  // Both dimensions have bound 3 and runtime size 2; -1 marks padding.
  Literal input = LiteralUtil::CreateR2<int32_t>(
      {{1, 2, -1}, {3, 4, -1}, {-1, -1, -1}});

  Literal output = PadAndExecute(std::move(hlo_module), {&input}, false);
  output.SetDynamicSize(0, 2);
  output.SetDynamicSize(1, 2);
  // Only the valid 2x2 region gets reversed along both dimensions.
  Literal expected = LiteralUtil::CreateR2<int32_t>({{4, 3}, {2, 1}});

  EXPECT_EQ(output, expected);
}
944 
XLA_TEST_F(ExecutionTest, DynamicDimensionReduce) {
  const std::string hlo_string = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[5] parameter(0)
  const = s32[] constant(3)
  param_padded = s32[<=5] set-dimension-size(param, const), dimensions={0}
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(param_padded, init),
      dimensions={0},
      to_apply=update_s32
}
)";

  auto hlo_module = GetHloModule(hlo_string);
  // Bound 5, runtime size 3: the last two elements are padding.
  Literal input = LiteralUtil::CreateR1<int32_t>({1, 2, 3, 4, 5});

  Literal output = PadAndExecute(std::move(hlo_module), {&input});

  // The padder masks out the padded tail, so the sum is 1 + 2 + 3 = 6.
  Literal expected = LiteralUtil::CreateR0<int32_t>(6);

  EXPECT_EQ(output, expected);
}
977 
XLA_TEST_F(ExecutionTest, InputMinorDimensionReshape) {
  const std::string hlo_string = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[1, 2, 5, 1] parameter(0)
  const = s32[] constant(3)
  param_padded = s32[1, 2, 5, 1] set-dimension-size(param, const), dimensions={2}
  reshaped = s32[10] reshape(param_padded)
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(reshaped, init),
      dimensions={0},
      to_apply=update_s32
}
)";

  auto hlo_module = GetHloModule(hlo_string);
  // Dimension 2 has an upper bound of 5 and a runtime size of 3.
  Literal input = LiteralUtil::CreateR4<int32_t>(
      {{{{1}, {2}, {3}, {4}, {5}}, {{2}, {4}, {6}, {7}, {8}}}});

  Literal output = PadAndExecute(std::move(hlo_module), {&input});

  // Three valid elements per row survive the flattening reshape:
  // (1 + 2 + 3) + (2 + 4 + 6) = 18.
  Literal expected = LiteralUtil::CreateR0<int32_t>(18);

  EXPECT_EQ(output, expected);
}
1012 
XLA_TEST_F(ExecutionTest, SliceSingleElement) {
  // Slicing a single element out of a dynamic dimension is supported.
  const std::string hlo_string = R"(
HloModule Slicing

ENTRY main {
  param = s32[5] parameter(0)
  const = s32[] constant(3)
  param_padded = s32[5] set-dimension-size(param, const), dimensions={0}
  ROOT slice = s32[1]{0} slice(param_padded), slice={[0:1]}
}
)";

  auto hlo_module = GetHloModule(hlo_string);
  // Bound is 5; the runtime size is 3.
  Literal input = LiteralUtil::CreateR1<int32_t>({0, 1, 2, 3, 4});

  Literal output = PadAndExecute(std::move(hlo_module), {&input});

  Literal expected = LiteralUtil::CreateR1<int32_t>({0});

  EXPECT_EQ(output, expected);
}
1036 
XLA_TEST_F(ExecutionTest, OutputMinorDimensionReshape) {
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[12] parameter(0)
  const = s32[] constant(8)
  param_padded = s32[12] set-dimension-size(param, const), dimensions={0}
  // Second dimension is dynamic.
  reshaped = s32[2, 3, 2] reshape(param_padded), inferred_dimension=1
  init = s32[] constant(0)
  ROOT reduce = s32[2, 2] reduce(reshaped, init),
      dimensions={1},
      to_apply=update_s32
}
)";

  // The input dimension has an upper bound of 12 and a dynamic size of 8;
  // inferred_dimension=1 makes the middle output dimension the dynamic one.
  Literal operand =
      LiteralUtil::CreateR1<int32_t>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // After padding and reshape we have
  //
  // [[[0, 1],
  //   [2, 3]
  //   [P, P]]
  //  [[4, 5],
  //   [6, 7],
  //   [P, P]]]
  // Reducing on the second dimension gives us
  //  [0+2, 1+3]
  //  [4+6, 5+7]
  //
  Literal expected = LiteralUtil::CreateR2<int32_t>({{2, 4}, {10, 12}});

  EXPECT_EQ(result, expected);
}
1083 
XLA_TEST_F(ExecutionTest, OutputMinorDimensionReshapeWithUnchangedDimMajor) {
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[2, 6] parameter(0)
  const = s32[] constant(4)
  param_padded = s32[2, 6] set-dimension-size(param, const), dimensions={1}
  // Third dimension is dynamic.
  reshaped = s32[2, 2, 3] reshape(param_padded), inferred_dimension=2
  init = s32[] constant(0)
  ROOT reduce = s32[2, 2] reduce(reshaped, init),
      dimensions={2},
      to_apply=update_s32
}
)";

  // The second input dimension has an upper bound of 6 and a dynamic size of
  // 4; the major dimension (size 2) is unchanged by the reshape.
  Literal operand = LiteralUtil::CreateR2<int32_t>(
      {{0, 1, 2, 3, 4, 5}, {6, 7, 8, 9, 10, 11}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // After padding and reshape we have
  //
  // [[[0, 1, P],
  //   [2, 3, P]],
  //  [[6, 7, P],
  //   [8, 9, P]]]
  // Reducing on the third dimension gives us
  //  [0+1, 2+3]
  //  [6+7, 8+9]
  //
  Literal expected = LiteralUtil::CreateR2<int32_t>({{1, 5}, {13, 17}});

  EXPECT_EQ(result, expected);
}
1128 
XLA_TEST_F(ExecutionTest, OutputMinorDimensionReshapeWithUnchangedDimMinor) {
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[6, 2] parameter(0)
  const = s32[] constant(4)
  param_padded = s32[6, 2] set-dimension-size(param, const), dimensions={0}
  // Second dimension is dynamic.
  reshaped = s32[2, 3, 2] reshape(param_padded), inferred_dimension=1
  init = s32[] constant(0)
  ROOT reduce = s32[2, 2] reduce(reshaped, init),
      dimensions={1},
      to_apply=update_s32
}
)";

  // The first input dimension has an upper bound of 6 and a dynamic size of
  // 4; the minor dimension (size 2) is unchanged by the reshape.
  Literal operand = LiteralUtil::CreateR2<int32_t>(
      {{0, 1}, {2, 3}, {4, 5}, {6, 7}, {8, 9}, {10, 11}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // After padding and reshape we have
  //
  // [[[0, 1],
  //   [2, 3]
  //   [P, P]],
  //  [[4, 5],
  //   [6, 7],
  //   [P, P]]]
  // Reducing on the second dimension gives us
  //  [0+2, 1+3]
  //  [4+6, 5+7]
  //
  Literal expected = LiteralUtil::CreateR2<int32_t>({{2, 4}, {10, 12}});

  EXPECT_EQ(result, expected);
}
1175 
XLA_TEST_F(ExecutionTest, DynamicInputFeature) {
  const std::string hlo_string = R"(
HloModule DynamicInputFeature

ENTRY main {
  param = f32[1, 1, 5] parameter(0)
  const = s32[] constant(5)
  one = f32[] constant(1)
  kernel = f32[1,5,1]{2,1,0} broadcast(f32[] one), dimensions={}
  param_dynamic = f32[1,1,<=5] set-dimension-size(param, const), dimensions={2}
  ROOT conv = f32[1, 1, 1]{2,1,0} custom-call(f32[1, 1, <=5] param_dynamic, f32[1,5,1]{2,1,0} kernel),
                             window={size=1 pad=0_0},
                             dim_labels=b0f_0io->b0f,
                             padding_type=PADDING_VALID,
                             custom_call_target="DynamicConvolutionForward"
}
)";

  auto hlo_module = GetHloModule(hlo_string);
  Literal input = LiteralUtil::CreateR3<float>({{{1, 2, 3, 4, 5}}});

  Literal output = PadAndExecute(std::move(hlo_module), {&input});

  // An all-ones kernel sums the feature dimension: 1+2+3+4+5 = 15. The
  // dynamic size equals the bound (5), so every element is valid.
  Literal expected = LiteralUtil::CreateR3<float>({{{15}}});

  EXPECT_EQ(output, expected);
}
1203 
XLA_TEST_F(LlvmIrGenTestBase, LargeDynamicInput) {
  // This test only verifies IR generation on the GPU backend; skip elsewhere.
#ifndef XLA_TEST_BACKEND_GPU
  GTEST_SKIP();
#endif
  const std::string hlo_text = R"( // NOLINT: Will be executed for GPU.
HloModule LargeDynamicInput

add (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY main {
  param = f32[<=20,<=20,<=20,<=20,<=20,<=20,<=20,<=20] parameter(0)
  zero = f32[] constant(0)
  ROOT out = reduce(param, zero), to_apply=add, dimensions={0,1,2,3,4,5,6,7}
}
)";

  // Compiling a full reduce over a large input with every dimension dynamic
  // must succeed; the CHECK only requires that optimized IR was emitted.
  CompileAndVerifyIr(hlo_text, R"(
CHECK: ret void
)",
                     /*match_optimized_ir=*/true);
}
1229 
XLA_TEST_F(ExecutionTest, DynamicDimensionReshapeUnchanged) {
  const std::string hlo_string = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[1, 2, 5, 1] parameter(0)
  const = s32[] constant(3)
  param_padded = s32[1, 2, 5, 1] set-dimension-size(param, const), dimensions={2}
  reshaped = s32[2, 5] reshape(param_padded)
  init = s32[] constant(0)
  ROOT reduce = s32[2] reduce(reshaped, init),
      dimensions={1},
      to_apply=update_s32
}
)";

  // The dynamic dimension (runtime size 3 of bound 5) passes through the
  // reshape without being split or combined.
  auto hlo_module = GetHloModule(hlo_string);
  Literal input = LiteralUtil::CreateR4<int32_t>(
      {{{{1}, {2}, {3}, {4}, {5}}, {{2}, {4}, {6}, {7}, {8}}}});

  Literal output = PadAndExecute(std::move(hlo_module), {&input});

  // Per-row sums over the three valid elements: {1+2+3, 2+4+6}.
  Literal expected = LiteralUtil::CreateR1<int32_t>({6, 12});

  EXPECT_EQ(output, expected);
}
1263 
XLA_TEST_F(ExecutionTest, DegeneratedDimension) {
  const std::string hlo_string = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[1, 2, 5, 1] parameter(0)
  size = s32[] constant(0)
// First dimension is dynamic.
  param_padded = s32[1, 2, 5, 1] set-dimension-size(param, size),
    dimensions={0}
  reshaped = s32[10] reshape(param_padded)
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(reshaped, init),
      dimensions={0},
      to_apply=update_s32
}
)";

  // Dimension 0 is dynamic with runtime size 0, so no element is valid and
  // the reduce yields its init value.
  auto hlo_module = GetHloModule(hlo_string);
  Literal input = LiteralUtil::CreateR4<int32_t>(
      {{{{1}, {2}, {3}, {4}, {5}}, {{2}, {4}, {6}, {7}, {8}}}});

  Literal output = PadAndExecute(std::move(hlo_module), {&input});

  Literal expected = LiteralUtil::CreateR0<int32_t>(0);

  EXPECT_EQ(output, expected);
}
1299 
XLA_TEST_F(ExecutionTest, ReshapeSplitCombineSameTime) {
  // [<=4, 2, <=2]
  //       |
  //    Reshape
  //       |
  // [2, <=2, <=4]
  //
  // Split one input dynamic dim to multiple output dims while combining two
  // dimensions together.
  //
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[4, 2, 2] parameter(0)
  two = s32[] constant(2)
  one = s32[] constant(1)
  param_padded_partial = s32[<=4, 2, 2] set-dimension-size(param, two),
    dimensions={0}

  param_padded_dynamic = s32[<=4, 2, <=2] set-dimension-size(param_padded_partial,
                                                             one),
    dimensions={2}
  reshaped = s32[2, <=2, <=4] reshape(param_padded_dynamic),
    inferred_dimension=1
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(reshaped, init),
      dimensions={0, 1, 2},
      to_apply=update_s32
}
)";

  // First and last dims are dynamic. Padded data are expressed as -1.
  // Dim 0 has runtime size 2 (bound 4); dim 2 has runtime size 1 (bound 2),
  // so only the four values 0..3 are valid.
  Literal operand = LiteralUtil::CreateR3<int32_t>({{{0, -1}, {1, -1}},
                                                    {{2, -1}, {3, -1}},
                                                    {{-1, -1}, {-1, -1}},
                                                    {{-1, -1}, {-1, -1}}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // Reshaping (with correct reshape rewriting) produces:
  // [[[0, 1, -1, -1], [-1, -1, -1, -1]], [[2, 3, -1, -1], [-1, -1, -1, -1]]]
  //
  //  Dynamic padder auto pads -1 with 0.
  //
  // Reducing it produces 0 + 1 + 2 + 3 = 6

  Literal expected = LiteralUtil::CreateR0<int32_t>(6);

  EXPECT_EQ(result, expected);
}
1358 
XLA_TEST_F(ExecutionTest, ReshapeComplicated) {
  // [2, <=4, 4]
  //       |
  //    Reshape
  //       |
  // [<=16, 2]
  //
  // Reshape that is not a composition of splitting one input dim to multiple
  // output dims or combining multiple input dimensions to one output dimension.
  //
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[2, 4, 4] parameter(0)
  two = s32[] constant(2)
  param_padded_dynamic = s32[2, <=4, 4] set-dimension-size(param, two),
    dimensions={1}
  reshaped = s32[<=16, 2] reshape(param_padded_dynamic), inferred_dimension=0
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(reshaped, init),
      dimensions={0, 1},
      to_apply=update_s32
}
)";

  // The middle dim (bound 4) is dynamic with runtime size 2. Padded data are
  // expressed as -1.
  Literal operand = LiteralUtil::CreateR3<int32_t>(
      {{{1, 2, 3, 4}, {5, 6, 7, 8}, {-1, -1, -1, -1}, {-1, -1, -1, -1}},
       {{9, 10, 11, 12},
        {13, 14, 15, 16},
        {-1, -1, -1, -1},
        {-1, -1, -1, -1}}});
  auto module = GetHloModule(hlo_text);
  Literal result = PadAndExecute(std::move(module), {&operand});

  // Reshaping (with correct reshape rewriting) produces:
  // [[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]],
  //  [[-1, -1], [-1, -1], ...]]
  //
  //  Dynamic padder auto pads -1 with 0.
  //
  // Reducing it produces 1 + 2 + 3 + ... + 16 = 136

  Literal expected = LiteralUtil::CreateR0<int32_t>(136);
  EXPECT_EQ(result, expected);
}
1412 
XLA_TEST_F(ExecutionTest, WhileLoopStack) {
  // Push into a dynamic sized stack with iteration number:
  // init:
  // [[P, P],
  //  [P, P],
  //  [P, P],
  //  [P, P]]
  // First iteration i = 0:
  // [[0, 0],
  //  [P, P],
  //  [P, P],
  //  [P, P]]
  // Second iteration i = 1:
  // [[0, 0],
  //  [1, 1],
  //  [P, P],
  //  [P, P]]
  // Third iteration i = 2:
  // [[0, 0],
  //  [1, 1],
  //  [2, 2],
  //  [P, P]]

  const std::string hlo_text = R"(
HloModule module

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

body {
  stack = (s32[<=4,2]) parameter(0)
  stack_buffer = s32[<=4, 2] get-tuple-element(stack), index=0
  stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0}
  zero = s32[] constant(0)
  one = s32[] constant(1)
  // content of the stack is the stack index broadcasted.
  new_data = s32[1, 2] broadcast(s32[] stack_size), dimensions={}
  new_stack_buffer = s32[<=4, 2] dynamic-update-slice(stack_buffer, new_data, stack_size, zero)
  new_stack_size = s32[] add(stack_size, one)
  new_stack_buffer_dynamic = s32[<=4, 2]set-dimension-size(new_stack_buffer, new_stack_size), dimensions={0}
  ROOT new_stack = (s32[<=4,2]) tuple(new_stack_buffer_dynamic)
}

condition {
  stack = (s32[<=4,2]) parameter(0)
  stack_buffer = s32[<=4, 2] get-tuple-element(stack), index=0
  stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0}
  three = s32[] constant(3)
  ROOT less-than = pred[] compare(s32[] stack_size, s32[] three), direction=LT
}

ENTRY entry {
  zero = s32[] constant(0)
  pad = s32[] constant(-1)
  stack_buffer_input = s32[4, 2] broadcast(s32[] pad), dimensions={}
  stack_buffer_input_dynamic = s32[<=4, 2] set-dimension-size(stack_buffer_input, zero), dimensions={0}
  input_tuple = (s32[<=4 ,2]) tuple(stack_buffer_input_dynamic)
  while = (s32[<=4, 2]) while(input_tuple), body=body, condition=condition
  stack_buffer = s32[<=4, 2] get-tuple-element(while), index=0
  ROOT reduce = s32[2] reduce(stack_buffer, zero),
    dimensions={0},
    to_apply=update_s32
}
)";

  auto module = GetHloModule(hlo_text);

  // The entry computation takes no parameters; everything is driven by
  // constants inside the module.
  Literal result = PadAndExecute(std::move(module), {});

  // Stack has three valid items in it:
  // [[0, 0],
  //  [1, 1],
  //  [2, 2],
  //  [P, P]]
  //
  // Reducing along major dimension gives us [3, 3]
  Literal expected = LiteralUtil::CreateR1<int32_t>({{3, 3}});

  EXPECT_EQ(result, expected);
}
1496 
XLA_TEST_F(ExecutionTest, DynamicAddWithImplicitBroadcast) {
  const std::string hlo_text = R"(
HloModule module

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY entry {
  zero = s32[] constant(0)
  one = s32[] constant(1)
  two = s32[] constant(2)
  three = s32[] constant(3)
  input1 = s32[4, 2] iota(), iota_dimension=0
  ones = s32[4, 2] broadcast(one), dimensions={}
  input1_added = s32[4, 2] add(input1, ones)
  input1_dynamic = s32[<=4, 2] set-dimension-size(input1_added, one), dimensions={0}
  input2 = s32[4, 2] broadcast(two), dimensions={}
  input2_dynamic = s32[<=4, 2] set-dimension-size(input2, three), dimensions={0}
  add = s32[<=4, 2] add(input1_dynamic, input2_dynamic)
  ROOT reduce = s32[2] reduce(add, zero),
    dimensions={0},
    to_apply=update_s32
}
)";

  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {});

  // input1 has dynamic size 1 and input2 has dynamic size 3; the single valid
  // row of input1 ([1, 1]) is implicitly broadcast against input2's three
  // valid rows of [2, 2], so the sum has three valid rows:
  // [[3, 3],
  //  [3, 3],
  //  [3, 3],
  //  [P, P]]
  // Reducing them gives us [9, 9]
  Literal expected = LiteralUtil::CreateR1<int32_t>({{9, 9}});

  EXPECT_EQ(result, expected);
}
1539 
XLA_TEST_F(ExecutionTest, DynamicAddWithImplicitSlice) {
  const std::string hlo_text = R"(
HloModule module

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY entry {
  zero = s32[] constant(0)
  one = s32[] constant(1)
  two = s32[] constant(2)
  three = s32[] constant(3)
  input1 = s32[4, 2] broadcast(one), dimensions={}
  input1_dynamic = s32[<=4, 2] set-dimension-size(input1, three), dimensions={0}
  input2 = s32[4, 2] broadcast(two), dimensions={}
  input2_dynamic = s32[<=4, 2] set-dimension-size(input2, two), dimensions={0}
  add = s32[<=4, 2] add(input1_dynamic, input2_dynamic)
  ROOT reduce = s32[2] reduce(add, zero),
    dimensions={0},
    to_apply=update_s32
}
)";

  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {});

  // input1 has dynamic size 3 and input2 has dynamic size 2; the sum is
  // implicitly sliced down to the smaller size, leaving two valid rows:
  // [[3, 3],
  //  [3, 3],
  //  [P, P],
  //  [P, P]]
  // Reducing them gives us [6, 6]
  Literal expected = LiteralUtil::CreateR1<int32_t>({{6, 6}});

  EXPECT_EQ(result, expected);
}
1580 
XLA_TEST_F(ExecutionTest, DynamicStackPop) {
  // This tests the case where a static sized stack is popped by a dynamic
  // number of times.

  // In the beginning the stack has static size that has 4 elements:
  // [[1, 1],
  //  [1, 1],
  //  [1, 1],
  //  [1, 1]]
  //
  // Popping this stack using set-dimension-size in a loop creates a dynamic
  // result depending on how many times we pop it.

  const std::string hlo_text = R"(
HloModule module

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

body {
  param_tuple = (s32[<=4,2]) parameter(0)
  param = s32[<=4, 2] get-tuple-element(param_tuple), index=0
  one = s32[] constant(1)
  size = s32[] get-dimension-size(param), dimensions={0}
  new_size = s32[] subtract(size, one)
  output = s32[<=4, 2] set-dimension-size(param, new_size), dimensions={0}
  ROOT root = (s32[<=4, 2]) tuple(output)
}

condition {
  stack = (s32[<=4,2]) parameter(0)
  stack_buffer = s32[<=4,2] get-tuple-element(stack), index=0
  stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0}
  two = s32[] constant(2)
  ROOT greater-than = pred[] compare(s32[] stack_size, s32[] two), direction=GE
}

ENTRY entry {
  one = s32[] constant(1)
  zero = s32[] constant(0)
  stack_buffer_input = s32[4, 2] broadcast(s32[] one), dimensions={}
  input_tuple = (s32[4, 2]) tuple(stack_buffer_input)
  while = (s32[4, 2]) while(input_tuple), body=body, condition=condition
  stack_buffer = s32[<=4, 2] get-tuple-element(while), index=0
  ROOT reduce = s32[2] reduce(stack_buffer, zero),
    dimensions={0},
    to_apply=update_s32
}
)";

  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {});

  // The loop keeps popping (shrinking the dynamic size) while the size is
  // >= 2, so it exits with a single valid item left in the stack:
  // [[1, 1],
  //  [P, P],
  //  [P, P],
  //  [P, P]]
  // Reducing it gives us [1, 1]
  Literal expected = LiteralUtil::CreateR1<int32_t>({{1, 1}});

  EXPECT_EQ(result, expected);
}
1648 
XLA_TEST_F(ExecutionTest, DoubleDynamicDimension) {
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[2, 3, 3] parameter(0)
  size = s32[] constant(2)
  param_padded_partial = s32[2, 3, 3] set-dimension-size(param, size),
    dimensions={1}
  param_padded = s32[2, 3, 3] set-dimension-size(param_padded_partial, size),
    dimensions={2}
  reshaped = s32[18] reshape(param_padded)
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(reshaped, init),
      dimensions={0},
      to_apply=update_s32
}
)";

  // Dimensions 1 and 2 are both dynamic, each with a runtime size of 2
  // (bound 3).
  Literal operand = LiteralUtil::CreateR3<int32_t>(
      {{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // Padded data looks like this (P is padding which is ignored).
  // [[0, 1, P]
  // [3, 4, P]
  // [P, P, P]]
  //
  // [[0, 1, P]
  // [3, 4, P]
  // [P, P, P]]
  //
  // Reshaping (with correct reshape rewriting) produces:
  // [0, 1, 3, 4, 0, 1, 3, 4, P, P, P, P, P, P, P, P, P, P]
  //
  // Reducing it produces 16

  Literal expected = LiteralUtil::CreateR0<int32_t>(16);

  EXPECT_EQ(result, expected);
}
1699 
// Tests dynamic-reshape where the operand has two dynamic dimensions
// ({1} and {2}, each of dynamic size 2) and the output is a dynamic R1 of
// size 8: only the valid elements appear in the output prefix.
XLA_TEST_F(ExecutionTest, DynamicReshapeDoubleDynamicDimensions) {
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

ENTRY main {
  param = s32[2, 3, 3] parameter(0)
  size = s32[] constant(2)
  param_padded_partial = s32[2, <=3, 3] set-dimension-size(param, size),
    dimensions={1}
  param_padded = s32[2, <=3, <=3] set-dimension-size(param_padded_partial, size),
    dimensions={2}
  result_size = s32[] constant(8)
  ROOT reshaped = s32[<=18] dynamic-reshape(param_padded, result_size)
}
)";

  // Dimensions 1 and 2 are dynamic with size 2, so only the top-left 2x2
  // corner of each 3x3 slice holds valid data.
  Literal operand = LiteralUtil::CreateR3<int32_t>(
      {{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand}, false);
  result.SetDynamicSize(0, 8);
  // Padded data looks like this (P is padding which is ignored).
  // [[0, 1, P]
  // [3, 4, P]
  // [P, P, P]]
  //
  // [[0, 1, P]
  // [3, 4, P]
  // [P, P, P]]
  //
  // Reshaping (with correct reshape rewriting) produces:
  // [0, 1, 3, 4, 0, 1, 3, 4]
  Literal expected = LiteralUtil::CreateR1<int32_t>({0, 1, 3, 4, 0, 1, 3, 4});

  EXPECT_EQ(result, expected);
}
1738 
// Reshapes a dynamic R1 input (8 valid of 18 elements) into an output whose
// two minor dimensions are both dynamic with size two; the valid data must
// land in the leading corner of each slice.
XLA_TEST_F(ExecutionTest, DynamicReshapeOutputDoubleDynamicDimensions) {
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

ENTRY main {
  param = s32[18] parameter(0)
  eight = s32[] constant(8)
  param_dynamic = s32[<=18] set-dimension-size(param, eight), dimensions={0}
  two = s32[] constant(2)
  // every dimension has dynamic size two.
  ROOT reshaped = s32[2, <=3, <=3] dynamic-reshape(param_dynamic, two, two, two)
}
)";
  // Only the first 8 elements are valid; the trailing -1s are padding.
  Literal input = LiteralUtil::CreateR1<int32_t>(
      {0, 1, 3, 4, 0, 1, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1});

  auto hlo_module = GetHloModule(hlo_text);
  Literal result = PadAndExecute(std::move(hlo_module), {&input}, false);
  VLOG(1) << " result: " << result.ToString();
  result.SetDynamicSize(1, 2);
  result.SetDynamicSize(2, 2);

  // Padded operand is:
  // [0, 1, 3, 4, 0, 1, 3, 4, P, P ....]
  //
  // Reshaping it should produce (P is ignored padding):
  // [[0, 1, P]
  //  [3, 4, P]
  //  [P, P, P]]
  //
  // [[0, 1, P]
  //  [3, 4, P]
  //  [P, P, P]]
  Literal expected =
      LiteralUtil::CreateR3<int32_t>({{{0, 1}, {3, 4}}, {{0, 1}, {3, 4}}});
  EXPECT_EQ(result, expected);
}
1776 
// Dynamic reshape where input and output dimensions do not compose
// trivially: static [3, 4, 4] -> [6, 8] with dynamic [2, 3, 3] -> [3, 6].
XLA_TEST_F(ExecutionTest, DynamicReshapeComplicated) {
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

ENTRY main {
  param = s32[3, 4, 4] parameter(0)
  two = s32[] constant(2)
  param_dynamic = s32[<=3, 4, 4] set-dimension-size(param, two), dimensions={0}
  three = s32[] constant(3)
  param_dynamic1 = s32[<=3, <=4, 4] set-dimension-size(param_dynamic, three), dimensions={1}
  param_dynamic2 = s32[<=3, <=4, <=4] set-dimension-size(param_dynamic1, three), dimensions={2}
  six = s32[] constant(6)

  // Static reshape is from [3, 4, 4] to [6, 8].
  // Dynamic reshape is from [2, 3, 3] to [3, 6].
  ROOT reshaped = s32[<=6, <=8] dynamic-reshape(param_dynamic2, three, six)
}
)";
  // Valid data is the leading [2, 3, 3] corner; -1 marks padding.
  Literal input = LiteralUtil::CreateR3<int32_t>(
      {{{0, 1, 2, -1}, {3, 4, 5, -1}, {6, 7, 8, -1}, {-1, -1, -1, -1}},
       {{9, 8, 7, -1}, {6, 5, 4, -1}, {3, 2, 1, -1}, {-1, -1, -1, -1}},
       {{-1, -1, -1, -1},
        {-1, -1, -1, -1},
        {-1, -1, -1, -1},
        {-1, -1, -1, -1}}});

  auto hlo_module = GetHloModule(hlo_text);
  Literal result = PadAndExecute(std::move(hlo_module), {&input}, false);
  result.SetDynamicSize(0, 3);
  result.SetDynamicSize(1, 6);

  // The 18 valid elements are linearized row-major into the [3, 6] result.
  Literal expected = LiteralUtil::CreateR2<int32_t>(
      {{0, 1, 2, 3, 4, 5}, {6, 7, 8, 9, 8, 7}, {6, 5, 4, 3, 2, 1}});
  EXPECT_EQ(result, expected);
}
1812 
// set-dimension-size followed by get-dimension-size must observe the value
// that was set (2), not the static bound of the dimension (3).
XLA_TEST_F(ExecutionTest, SetGetDimensionSize) {
  const std::string hlo_text = R"(
HloModule TensorFlowScatterV1

ENTRY main {
  param = s32[3] parameter(0)
  size = s32[] constant(2)
  param_dynamic_size = s32[3] set-dimension-size(param, size),
    dimensions={0}
  ROOT gds = s32[] get-dimension-size(param_dynamic_size),
    dimensions={0}
}
)";

  Literal operand = LiteralUtil::CreateR1<int32_t>({1, 2, 3});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand});

  // Should return the size 2 instead of 3.
  Literal expected = LiteralUtil::CreateR0<int32_t>(2);

  EXPECT_EQ(result, expected);
}
1838 
// Sorts an R1 with a dynamic dimension: only the three valid elements are
// ordered (descending); the padded tail element stays in place.
XLA_TEST_F(ExecutionTest, DynamicSort) {
  const std::string hlo_text = R"(
HloModule TEST

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

%compare-greater-than (lhs: s32[], rhs: s32[]) -> pred[] {
  %lhs = s32[] parameter(0)
  %rhs = s32[] parameter(1)
  ROOT %compare = pred[] compare(s32[] %lhs, s32[] %rhs), direction=GT
}

ENTRY main {
  param = s32[4] parameter(0)
  size = s32[] constant(3)
  param_dynamic_size = s32[4] set-dimension-size(param, size),
    dimensions={0}
  ROOT sort = s32[4]{0} sort(s32[4]{0} %param_dynamic_size),
    dimensions={0}, is_stable=false, to_apply=%compare-greater-than
}
)";

  auto module = GetHloModule(hlo_text);
  Literal input = LiteralUtil::CreateR1<int32_t>({1, 4, 3, 2});

  Literal result = PadAndExecute(std::move(module), {&input},
                                 /*slice_dynamic_output=*/false);
  // Valid prefix {1, 4, 3} sorts (GT) to {4, 3, 1}; the trailing 2 is
  // padding and is left untouched.
  Literal expected = LiteralUtil::CreateR1<int32_t>({4, 3, 1, 2});

  EXPECT_EQ(result, expected);
}
1874 
// Edge-pads a dynamic R1 with one element of value 2 on each side; the tail
// pad must be applied at the dynamic boundary, not at the static bound.
XLA_TEST_F(ExecutionTest, DynamicPad) {
  const std::string hlo_text = R"(
HloModule TEST

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[4] parameter(0)
  size = s32[] constant(3)
  padding = s32[] constant(2)
  param_dynamic = s32[<=4] set-dimension-size(param, size),
    dimensions={0}
  // Pad head and tail with 2
  pad = s32[<=6] pad(param_dynamic, padding), padding=1_1

  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(pad, init),
    dimensions={0},
    to_apply=update_s32
}
)";

  // Only the first 3 elements (1, 4, 3) are valid.
  Literal input = LiteralUtil::CreateR1<int32_t>({1, 4, 3, 5});
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  Literal result = PadAndExecute(std::move(module), {&input},
                                 /*slice_dynamic_output=*/false);
  // Effective data after padding head and tail with "2" is
  // [2, 1, 4, 3, 2], which reduces to 12.
  Literal expected = LiteralUtil::CreateR0<int32_t>(12);

  EXPECT_EQ(result, expected);
}
1913 
// Interior-pads a dynamic R1 with value 2: pad elements are inserted only
// between valid elements, so the padded-away tail contributes nothing.
XLA_TEST_F(ExecutionTest, DynamicPadInteriorPadding) {
  const std::string hlo_text = R"(
HloModule TEST

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[4] parameter(0)
  size = s32[] constant(3)
  padding = s32[] constant(2)
  param_dynamic = s32[<=4] set-dimension-size(param, size),
    dimensions={0}
  // Pad interior with constant 2.
  pad = s32[<=7] pad(param_dynamic, padding), padding=0_0_1

  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(pad, init),
    dimensions={0},
    to_apply=update_s32
}
)";

  // Only the first 3 elements (1, 4, 3) are valid.
  Literal input = LiteralUtil::CreateR1<int32_t>({1, 4, 3, 5});
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  Literal result = PadAndExecute(std::move(module), {&input},
                                 /*slice_dynamic_output=*/false);
  // After interior padding with "2" the effective data is
  // [1, 2, 4, 2, 3], which reduces to 12.
  Literal expected = LiteralUtil::CreateR0<int32_t>(12);

  EXPECT_EQ(result, expected);
}
1952 
// A dynamic dimension flows through both branches of a conditional; the
// reduce over the conditional's output must only sum the valid rows.
XLA_TEST_F(ExecutionTest, DynamicConditionalDimension) {
  const std::string hlo_text = R"(
HloModule module

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

true_branch {
  true_param = (s32[<=3,2]) parameter(0)
  param = s32[<=3, 2] get-tuple-element(true_param), index=0
  add = s32[<=3,2] add(param, param)
  ROOT true_tuple = (s32[<=3,2], s32[<=3,2]) tuple(add, add)
}

false_branch {
  false_param = (s32[<=3,2]) parameter(0)
  param = s32[<=3, 2] get-tuple-element(false_param), index=0
  add = s32[<=3,2] add(param, param)
  ROOT false_tuple = (s32[<=3,2], s32[<=3,2]) tuple(add, add)
}

ENTRY entry {
  param0 = s32[3,2] parameter(0)
  size = s32[] constant(2)
  branch = pred[] constant(false)
  param_dynamic = s32[<=3, 2] set-dimension-size(param0, size), dimensions={0}
  param_tuple = (s32[<=3 ,2]) tuple(param_dynamic)
  conditional = (s32[<=3, 2], s32[<=3, 2]) conditional(branch, param_tuple, param_tuple),
    true_computation=true_branch, false_computation=false_branch
  gte0 = s32[<=3,2] get-tuple-element(conditional), index=1
  init = s32[] constant(0)
  ROOT reduce = s32[2] reduce(gte0, init),
    dimensions={0},
    to_apply=update_s32
}
)";

  // Rows {0,1} and {2,3} are valid (dynamic size 2); {4,5} is padding.
  Literal input = LiteralUtil::CreateR2<int32_t>({{0, 1}, {2, 3}, {4, 5}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&input},
                                 /*slice_dynamic_output=*/false);
  // false_branch computes param + param, so the valid data becomes
  // {{0, 2}, {4, 6}}; reducing over dimension 0 yields {4, 8}.
  Literal expected = LiteralUtil::CreateR1<int32_t>({4, 8});

  EXPECT_EQ(result, expected);
}
2002 
// Variadic (tuple) sort over a dynamic dimension: only the first two
// elements take part in the stable descending sort; the padded element keeps
// its slot.  NOTE(review): the comparator's textual signature in the HLO
// below spells "lhs_2" twice where "rhs_2" is presumably intended — the
// parameter declarations themselves are correct; confirm before editing the
// HLO string.
XLA_TEST_F(ExecutionTest, DynamicTupleSort) {
  const std::string hlo_text = R"(
HloModule TEST

%compare-greater-than (lhs: s32[], rhs: s32[], lhs_2: s32[], lhs_2: s32[]) -> pred[] {
  %lhs = s32[] parameter(0)
  %rhs = s32[] parameter(1)
  %lhs_2 = s32[] parameter(2)
  %rhs_2 = s32[] parameter(3)
  ROOT %compare = pred[] compare(s32[] %lhs, s32[] %rhs), direction=GT
}

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[3] parameter(0)
  size = s32[] constant(2)
  param_dynamic_size = s32[3] set-dimension-size(param, size),
    dimensions={0}
  sort = (s32[3]{0}, s32[3]{0}) sort(s32[3]{0} %param_dynamic_size,
                                     s32[3]{0} %param_dynamic_size),
    dimensions={0}, is_stable=true, to_apply=%compare-greater-than
  ROOT get-tuple-element = s32[3]{0} get-tuple-element((s32[3]{0}, s32[3]{0}) %sort),
    index=0
}
)";

  auto module = GetHloModule(hlo_text);
  Literal input = LiteralUtil::CreateR1<int32_t>({0, 4, 2});

  Literal result = PadAndExecute(std::move(module), {&input},
                                 /*slice_dynamic_output=*/false);
  // Valid prefix {0, 4} sorts (GT, stable) to {4, 0}; trailing 2 is padding.
  Literal expected = LiteralUtil::CreateR1<int32_t>({4, 0, 2});

  EXPECT_EQ(result, expected);
}
2043 
2044 namespace op = xla::testing::opcode_matchers;
2045 
2046 class HloDimensionSizeLegalizerTest : public HloTestBase {
2047  protected:
HloDimensionSizeLegalizerTest()2048   HloDimensionSizeLegalizerTest() {}
2049 };
2050 
// get-dimension-size of static dimensions is rewritten into constants,
// which then feed the multiply directly.
TEST_F(HloDimensionSizeLegalizerTest, Ok) {
  auto module = ParseAndReturnVerifiedModule(R"(
HloModule _
ENTRY gds {
  p = s32[3,4] parameter(0)
  size0 = s32[] get-dimension-size(p), dimensions={0}
  size1 = s32[] get-dimension-size(p), dimensions={1}
  ROOT mul = s32[] multiply(size0, size1)
})")
                    .ValueOrDie();
  DynamicPadder padder;
  EXPECT_TRUE(padder.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Multiply(op::Constant(), op::Constant()));
}
2066 
// get-dimension-size reads back through set-dimension-size and a copy;
// both sizes still legalize to constants.
TEST_F(HloDimensionSizeLegalizerTest, GetSetSetDimensionSizeRewriter) {
  auto module = ParseAndReturnVerifiedModule(R"(
HloModule _
ENTRY gds {
  p = s32[3,4] parameter(0)
  size0 = s32[] get-dimension-size(p), dimensions={0}
  p_copy = s32[3,4] copy(p)
  p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0}
  size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0}
  ROOT mul = s32[] multiply(size0, size1)
})")
                    .ValueOrDie();
  DynamicPadder padder;
  EXPECT_TRUE(padder.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Multiply(op::Constant(), op::Constant()));
}
2084 
// The pass rejects a get-dimension-size whose result type is not s32.
TEST_F(HloDimensionSizeLegalizerTest, IllegalType) {
  auto module = ParseAndReturnUnverifiedModule(R"(
HloModule _
ENTRY gds {
  p = s32[3]{0} parameter(0)
  ROOT gds = s64[] get-dimension-size(p), dimensions={0}
})")
                    .ValueOrDie();
  DynamicPadder padder;
  EXPECT_FALSE(padder.Run(module.get()).ok());
}
2096 
// The pass rejects a get-dimension-size whose dimension index is out of
// range for the operand's rank.
TEST_F(HloDimensionSizeLegalizerTest, IllegalDimension) {
  auto module = ParseAndReturnUnverifiedModule(R"(
HloModule _
ENTRY gds {
  p = f32[2,5] parameter(0)
  ROOT gds = s32[] get-dimension-size(p), dimensions={2}
})")
                    .ValueOrDie();
  DynamicPadder padder;
  EXPECT_FALSE(padder.Run(module.get()).ok());
}
2108 
2109 class SizeCheckTest : public HloTestBase {
2110  protected:
SizeCheckTest()2111   SizeCheckTest() {}
2112 };
2113 
// Compile-time shape checks reject a binary op whose operands carry two
// unrelated dynamic sizes (size_0 and size_1 cannot be proven equal).
TEST_F(SizeCheckTest, CompileTimeCheckBinaryOpFail) {
  auto module = ParseAndReturnUnverifiedModule(R"(
HloModule _
ENTRY gds {
  size_0 = s32[] parameter(0)
  size_1 = s32[] parameter(1)
  arg = s32[4]{0} parameter(2)
  dynamic_arg_0 = s32[<=4] set-dimension-size(arg, size_0), dimensions={0}
  dynamic_arg_1 = s32[<=4] set-dimension-size(arg, size_1), dimensions={0}
  ROOT add = s32[<=4] add(dynamic_arg_0, dynamic_arg_1)
})")
                    .ValueOrDie();
  auto options = DynamicPadderOptions();
  options.shape_check_mode =
      DynamicDimensionInference::ShapeCheckMode::kCompileTime;
  DynamicPadder padder(options);
  auto status = padder.Run(module.get()).status();
  EXPECT_THAT(status.code(), tensorflow::error::INVALID_ARGUMENT);
}
2133 
// Compile-time shape checks pass when the two dynamic sizes, although
// produced by different instructions, are provably equal: the simplifier
// collapses reshape(reshape(size_0)) back to size_0 before the check runs.
TEST_F(SizeCheckTest, CompileTimeCheckBinaryOpPass) {
  auto module = ParseAndReturnUnverifiedModule(R"(
HloModule _
ENTRY gds {
  size_0 = s32[] parameter(0)
  size_0_reshape = s32[1] reshape(size_0)
  size_1 = s32[] reshape(size_0_reshape)
  arg = s32[4]{0} parameter(1)
  dynamic_arg_0 = s32[<=4] set-dimension-size(arg, size_0), dimensions={0}
  dynamic_arg_1 = s32[<=4] set-dimension-size(arg, size_1), dimensions={0}
  ROOT add = s32[<=4] add(dynamic_arg_0, dynamic_arg_1)
})")
                    .ValueOrDie();
  auto options = DynamicPadderOptions();
  options.shape_check_mode =
      DynamicDimensionInference::ShapeCheckMode::kCompileTime;
  // Run the simplifier first so size_1 is recognized as identical to size_0.
  DynamicDimensionSimplifier simplifier;
  EXPECT_TRUE(simplifier.Run(module.get()).ok());
  DynamicPadder pass(options);
  auto status = pass.Run(module.get()).status();
  EXPECT_TRUE(status.ok());
}
2157 
2158 }  // namespace
2159 }  // namespace xla
2160