xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test_benchmark.h"
36 
37 namespace op = xla::testing::opcode_matchers;
38 
39 namespace xla {
40 namespace {
41 
42 class DynamicDimensionInferenceTest : public HloTestBase {
43  protected:
DynamicDimensionInferenceTest()44   DynamicDimensionInferenceTest() : HloTestBase() {
45     module_ = CreateNewVerifiedModule();
46   }
47 
RunInference(DynamicDimensionInference::CustomCallInferenceHandler handler=nullptr,DynamicDimensionInference::ShapeCheckMode shape_check_mode=DynamicDimensionInference::ShapeCheckMode::kIgnore,const DynamicDimensionInference::AssertionGenerator & assertion_generator=nullptr)48   Status RunInference(
49       DynamicDimensionInference::CustomCallInferenceHandler handler = nullptr,
50       DynamicDimensionInference::ShapeCheckMode shape_check_mode =
51           DynamicDimensionInference::ShapeCheckMode::kIgnore,
52       const DynamicDimensionInference::AssertionGenerator& assertion_generator =
53           nullptr) {
54     TF_ASSIGN_OR_RETURN(
55         DynamicDimensionInference inference,
56         DynamicDimensionInference::Run(module_.get(), handler, shape_check_mode,
57                                        assertion_generator));
58 
59     inference_ = std::make_unique<DynamicDimensionInference>(inference);
60     return OkStatus();
61   }
62 
GetAdd()63   HloComputation* GetAdd() {
64     auto embedded_builder = HloComputation::Builder("add");
65     auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
66         0, ShapeUtil::MakeShape(F32, {}), "lhs"));
67     auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
68         1, ShapeUtil::MakeShape(F32, {}), "rhs"));
69     embedded_builder.AddInstruction(
70         HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
71     return module_->AddEmbeddedComputation(embedded_builder.Build());
72   }
73 
GetAddTuple()74   HloComputation* GetAddTuple() {
75     auto embedded_builder = HloComputation::Builder("add");
76     auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
77         0, ShapeUtil::MakeShape(F32, {}), "lhs"));
78     auto lhs_1 =
79         embedded_builder.AddInstruction(HloInstruction::CreateParameter(
80             1, ShapeUtil::MakeShape(F32, {}), "lhs.1"));
81     auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
82         2, ShapeUtil::MakeShape(F32, {}), "rhs"));
83     auto rhs_1 =
84         embedded_builder.AddInstruction(HloInstruction::CreateParameter(
85             3, ShapeUtil::MakeShape(F32, {}), "rhs.1"));
86     auto add = embedded_builder.AddInstruction(
87         HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
88     auto add_1 = embedded_builder.AddInstruction(HloInstruction::CreateBinary(
89         lhs->shape(), HloOpcode::kAdd, lhs_1, rhs_1));
90     embedded_builder.AddInstruction(HloInstruction::CreateTuple({add, add_1}));
91     return module_->AddEmbeddedComputation(embedded_builder.Build());
92   }
93 
GetGe()94   HloComputation* GetGe() {
95     auto embedded_builder = HloComputation::Builder("ge");
96     auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
97         0, ShapeUtil::MakeShape(F32, {}), "lhs"));
98     auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
99         1, ShapeUtil::MakeShape(F32, {}), "rhs"));
100     embedded_builder.AddInstruction(HloInstruction::CreateCompare(
101         ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGe));
102     return module_->AddEmbeddedComputation(embedded_builder.Build());
103   }
104 
105   std::unique_ptr<HloModule> module_;
106   std::unique_ptr<DynamicDimensionInference> inference_;
107   const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {});
108 };
109 
// A binding on an entry parameter is reported verbatim: dimension 1 of the
// data parameter is sized by parameter 1. Every other (instruction,
// dimension) pair checked here is static.
TEST_F(DynamicDimensionInferenceTest, ParamTest) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});

  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "param"));
  HloInstruction* size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param"));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  // Scalar parameter 1 holds the runtime extent of dimension 1 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(data, {}, 1), size);
  EXPECT_EQ(inference_->GetDynamicSize(data, {}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(size, {}, 0), nullptr);
}
132 
// The dynamic size lives inside a tuple parameter. Element {1} of parameter 0
// sizes dimension 1 of element {0}, so the reported size is a
// get-tuple-element of that scalar.
TEST_F(DynamicDimensionInferenceTest, ParamTestTuple) {
  using Binding = DynamicParameterBinding;

  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({data_shape, scalar_shape_});

  HloInstruction* tuple_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  module_->AddEntryComputation(builder.Build());

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      Binding::DynamicParameter{0, {1}}, Binding::DynamicDimension{0, {0}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  EXPECT_THAT(inference_->GetDynamicSize(tuple_param, {0}, 1),
              op::GetTupleElement(tuple_param, 1));
  EXPECT_EQ(inference_->GetDynamicSize(tuple_param, {0}, 0), nullptr);
}
153 
// A get-tuple-element keeps the dynamic size of the element it extracts. Only
// the leading shape-index entry goes away: {0} on the tuple becomes {} on the
// extracted value.
TEST_F(DynamicDimensionInferenceTest, GetTupleElement) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});

  HloInstruction* tuple_param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeTupleShape({data_shape, scalar_shape_}), "param"));
  HloInstruction* element = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));
  module_->AddEntryComputation(builder.Build());

  // Tuple element {1} sizes dimension 1 of tuple element {0}.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {1}},
                   DynamicParameterBinding::DynamicDimension{0, {0}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  EXPECT_THAT(inference_->GetDynamicSize(tuple_param, {0}, 1),
              op::GetTupleElement(tuple_param, 1));
  EXPECT_THAT(inference_->GetDynamicSize(element, {}, 1),
              op::GetTupleElement(tuple_param, 1));
  EXPECT_EQ(inference_->GetDynamicSize(tuple_param, {0}, 0), nullptr);
}
182 
// An elementwise op (negate) passes its operand's dynamic size through
// unchanged.
TEST_F(DynamicDimensionInferenceTest, ElementwiseTest) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});

  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data_param"));
  HloInstruction* size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, data));
  module_->AddEntryComputation(builder.Build());

  // Parameter 1 gives the runtime extent of dimension 1 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(negated, {}, 1), size);
}
207 
// Reducing away dimensions 0 and 2 leaves the dynamic input dimension 1 as
// output dimension 0, which must still be sized by the bound parameter.
TEST_F(DynamicDimensionInferenceTest, ReduceTestI) {
  using Binding = DynamicParameterBinding;

  HloComputation::Builder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2});

  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "data_param"));
  HloInstruction* size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(operand_shape, HloOpcode::kNegate, data));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      result_shape, negated, zero, {0, 2}, GetAdd()));
  module_->AddEntryComputation(builder.Build());

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      Binding::DynamicParameter{1, {}}, Binding::DynamicDimension{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), size);
}
238 
// Like ReduceTestI, but only dimension 1 is reduced. The dynamic input
// dimension 2 becomes output dimension 1, and output dimension 0 stays static.
TEST_F(DynamicDimensionInferenceTest, ReduceTestII) {
  HloComputation::Builder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 2});

  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "data_param"));
  HloInstruction* size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(operand_shape, HloOpcode::kNegate, data));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      result_shape, negated, zero, {1}, GetAdd()));
  module_->AddEntryComputation(builder.Build());

  // Dimension 2 of parameter 0 is sized by parameter 1.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 2}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 1), size);
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), nullptr);
}
271 
// Variadic reduce whose result is a tuple. Only operand 0 is dynamic (its
// dimension 2), and dimension 1 is reduced. The test expects the dynamic size
// at output dimension 1 of BOTH tuple elements, including the one computed
// from the static operand.
TEST_F(DynamicDimensionInferenceTest, VariadicReduce) {
  HloComputation::Builder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 2});

  HloInstruction* dynamic_data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "data_param"));
  HloInstruction* static_data = builder.AddInstruction(
      HloInstruction::CreateParameter(1, operand_shape, "data_param.2"));
  HloInstruction* size = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_shape_, "size_param"));

  // Dimension 2 of parameter 0 is sized by parameter 2.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 2}));

  HloInstruction* dynamic_negated =
      builder.AddInstruction(HloInstruction::CreateUnary(
          operand_shape, HloOpcode::kNegate, dynamic_data));
  HloInstruction* static_negated =
      builder.AddInstruction(HloInstruction::CreateUnary(
          operand_shape, HloOpcode::kNegate, static_data));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      ShapeUtil::MakeTupleShape({result_shape, result_shape}),
      {dynamic_negated, static_negated}, {zero, zero}, {1}, GetAddTuple()));
  module_->AddEntryComputation(builder.Build());

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {0}, 1), size);
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {1}, 1), size);
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {0}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {1}, 0), nullptr);
}
312 
// Matrix product A[x, y] . B[y, z] -> [x, z] with every operand dimension
// bound to the same dynamic size. Only A's non-contracting dimension survives
// into the result (output dimension 0); output dimension 1 stays static.
TEST_F(DynamicDimensionInferenceTest, DotTest) {
  constexpr int xdim = 3;
  constexpr int ydim = 2;
  constexpr int zdim = 1;
  HloComputation::Builder builder(TestName());
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {xdim, zdim});

  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, lhs_shape, "A"));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, rhs_shape, "B"));
  HloInstruction* size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(1);
  contraction.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(
      HloInstruction::CreateDot(result_shape, lhs, rhs, contraction,
                                HloTestBase::DefaultPrecisionConfig(2)));
  module_->AddEntryComputation(builder.Build());

  // Marks dimension `dim` of entry parameter `operand` as dynamic, sized by
  // scalar parameter 2.
  auto bind_to_size_param = [this](int operand, int dim) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{2, {}},
        DynamicParameterBinding::DynamicDimension{operand, {}, dim}));
  };
  bind_to_size_param(/*operand=*/0, /*dim=*/0);  // A's non-contracting dim.
  bind_to_size_param(/*operand=*/0, /*dim=*/1);  // A's contracting dim.
  bind_to_size_param(/*operand=*/1, /*dim=*/0);  // B's contracting dim.

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size);
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
}
356 
// Batched dot with batch dimensions {0, 2} and contracting dimension 3 on both
// sides. A dynamic lhs batch dimension 0 must appear as output dimension 0;
// output dimensions 1..3 stay static.
TEST_F(DynamicDimensionInferenceTest, DotTestBatch) {
  HloComputation::Builder builder(TestName());
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {4, 2, 128, 128});

  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, lhs_shape, "A"));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, rhs_shape, "B"));
  HloInstruction* size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  DotDimensionNumbers dnums;
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_lhs_batch_dimensions(2);
  dnums.add_lhs_contracting_dimensions(3);
  dnums.add_rhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(2);
  dnums.add_rhs_contracting_dimensions(3);
  HloInstruction* dot = builder.AddInstruction(
      HloInstruction::CreateDot(result_shape, lhs, rhs, dnums,
                                HloTestBase::DefaultPrecisionConfig(2)));
  module_->AddEntryComputation(builder.Build());

  // Only the leading batch dimension of A is dynamic.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size);
  for (int dim = 1; dim < 4; ++dim) {
    EXPECT_EQ(inference_->GetDynamicSize(dot, {}, dim), nullptr);
  }
}
395 
// Dot contracting two dimensions ({0, 1}) on each side, with all contracted
// dimensions dynamic. They disappear from the result, so no output dimension
// is dynamic.
TEST_F(DynamicDimensionInferenceTest, DotTestMultiContracting) {
  HloComputation::Builder builder(TestName());
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 8, 64});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 512});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {8, 64, 512});

  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, lhs_shape, "A"));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, rhs_shape, "B"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(0);
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  dnums.add_rhs_contracting_dimensions(1);
  HloInstruction* dot = builder.AddInstruction(
      HloInstruction::CreateDot(result_shape, lhs, rhs, dnums,
                                HloTestBase::DefaultPrecisionConfig(2)));
  module_->AddEntryComputation(builder.Build());

  // Dimensions 0 and 1 of both A and B are sized by scalar parameter 2.
  auto bind_to_size_param = [this](int operand, int dim) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{2, {}},
        DynamicParameterBinding::DynamicDimension{operand, {}, dim}));
  };
  bind_to_size_param(0, 0);
  bind_to_size_param(0, 1);
  bind_to_size_param(1, 0);
  bind_to_size_param(1, 1);

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  for (int dim = 0; dim < 3; ++dim) {
    EXPECT_EQ(inference_->GetDynamicSize(dot, {}, dim), nullptr);
  }
}
442 
// Convolution built from CreateDefaultConvDimensionNumbers(0) with an empty
// window. A's dimension 0 is the input batch dimension, and the output batch
// dimension is 1, so A's dynamic dimension 0 must show up at output
// dimension 1. Output dimension 0 (the output feature dimension) is expected
// to stay static.
TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
  constexpr int xdim = 3;
  constexpr int ydim = 2;
  constexpr int zdim = 1;
  HloComputation::Builder builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
  const Shape kernel_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {zdim, xdim});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/0, input_shape,
                                      "A"));
  HloInstruction* kernel = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/1, kernel_shape,
                                      "B"));
  HloInstruction* size = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/2, scalar_shape_,
                                      "size_param"));

  auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(1);
  dnums.set_output_feature_dimension(0);

  Window window;
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums,
      HloTestBase::DefaultPrecisionConfig(2)));
  module_->AddEntryComputation(builder.Build());

  auto& binding = module_->dynamic_parameter_binding();
  // A's batch dimension (0), which is carried through to the output.
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 0}));
  // A's dimension 1, which the test expects to be contracted away.
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 1), size);
  EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 0), nullptr);
}
491 
// Every input dimension has its own dynamic size. A transpose by permutation
// {2, 1, 0} must report, for output dimension d, the size bound to input
// dimension permutation[d].
TEST_F(DynamicDimensionInferenceTest, TransposeTest) {
  HloComputation::Builder builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {3, 2, 1});
  const int64_t permutation[] = {2, 1, 0};

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/0, input_shape,
                                      "A"));
  // size_params[i] is parameter i + 1 and sizes input dimension i.
  HloInstruction* size_params[3] = {};
  for (int i = 0; i < 3; ++i) {
    size_params[i] = builder.AddInstruction(
        HloInstruction::CreateParameter(i + 1, scalar_shape_, "size_param"));
  }
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(output_shape, input, permutation));
  module_->AddEntryComputation(builder.Build());

  for (int dim = 0; dim < 3; ++dim) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{dim + 1, {}},
        DynamicParameterBinding::DynamicDimension{0, {}, dim}));
  }

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  for (int dim = 0; dim < 3; ++dim) {
    EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, dim),
              size_params[permutation[dim]]);
  }
}
530 
// Same as TransposeTest but with the non-reversing permutation {2, 0, 1}.
// Output dimension d must be sized by the parameter bound to input dimension
// permutation[d].
TEST_F(DynamicDimensionInferenceTest, NonDescendingTransposeTest) {
  HloComputation::Builder builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {3, 1, 2});
  const int64_t permutation[] = {2, 0, 1};

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/0, input_shape,
                                      "A"));
  // size_params[i] is parameter i + 1 and sizes input dimension i.
  HloInstruction* size_params[3] = {};
  for (int i = 0; i < 3; ++i) {
    size_params[i] = builder.AddInstruction(
        HloInstruction::CreateParameter(i + 1, scalar_shape_, "size_param"));
  }
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(output_shape, input, permutation));
  module_->AddEntryComputation(builder.Build());

  for (int dim = 0; dim < 3; ++dim) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{dim + 1, {}},
        DynamicParameterBinding::DynamicDimension{0, {}, dim}));
  }

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  for (int dim = 0; dim < 3; ++dim) {
    EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, dim),
              size_params[permutation[dim]]);
  }
}
569 
// Reshape [2,3,4,5,6] -> [6,4,1,5,2,3]. The dynamic input dimensions 2 (size
// 4) and 3 (size 5) are carried through unmodified and reappear as output
// dimensions 1 and 3. All other output dimensions stay static.
TEST_F(DynamicDimensionInferenceTest, ReshapeTest) {
  HloComputation::Builder builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/0, input_shape,
                                      "A"));
  HloInstruction* size = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/1, scalar_shape_,
                                      "size_param"));
  HloInstruction* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(output_shape, input));
  module_->AddEntryComputation(builder.Build());

  for (int dim : {2, 3}) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{1, {}},
        DynamicParameterBinding::DynamicDimension{0, {}, dim}));
  }

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  HloInstruction* const expected[] = {nullptr, size,    nullptr,
                                      size,    nullptr, nullptr};
  for (int dim = 0; dim < 6; ++dim) {
    EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, dim), expected[dim]);
  }
}
603 
// Reshape [4,5] -> [1,4,5] (same element count, higher rank) with output
// dimension 0 marked as the inferred dimension. When input dimension 0 is
// dynamic, the inferred output dimension 0 must get a dynamic size.
TEST_F(DynamicDimensionInferenceTest, ReshapeInferredDimensionTest) {
  HloComputation::Builder builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(F32, {4, 5});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {1, 4, 5});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/0, input_shape,
                                      "A"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));
  HloInstruction* reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(
          output_shape, input, /*inferred_dimension=*/0));
  module_->AddEntryComputation(builder.Build());

  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_NE(inference_->GetDynamicSize(reshape, {}, 0), nullptr);
}
629 
// Reshape [32,10,4] -> [320,4] merges input dimensions 0 and 1. When the major
// input dimension 0 is dynamic, the combined output dimension 0 must be
// dynamic too.
TEST_F(DynamicDimensionInferenceTest, ReshapeTestMajorDimension) {
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {32, 10, 4});
  auto output_shape = ShapeUtil::MakeShape(F32, {320, 4});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, input_shape, "A"));

  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  auto* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(output_shape, a_param));

  module_->AddEntryComputation(builder.Build());

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  // Assert on the status instead of discarding it: RunInference() only sets
  // `inference_` on success, so dereferencing it after a failure would crash
  // the test binary instead of reporting a test failure.
  TF_ASSERT_OK(RunInference());
  EXPECT_NE(inference_->GetDynamicSize(reshape, {}, 0), nullptr);
}
655 
TEST_F(DynamicDimensionInferenceTest, ReshapeIntoScalar) {
  // Test the ability to reshape an f32[1] whose only dimension is dynamic into
  // a scalar.
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {1});
  auto output_shape = ShapeUtil::MakeShape(F32, {});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, input_shape, "A"));

  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  builder.AddInstruction(HloInstruction::CreateReshape(output_shape, a_param));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 carries the runtime size of dimension 0 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  // Use TF_ASSERT_OK, as the other tests do, so an inference error is reported
  // as a failure of this test instead of a CHECK abort of the whole binary.
  TF_ASSERT_OK(RunInference());
}
679 
TEST_F(DynamicDimensionInferenceTest, GatherTest) {
  // The gather's batch of indices (parameter 1) has a dynamic leading
  // dimension whose size lives in parameter 2. The gather result's leading
  // dimension is expected to report that same size instruction.
  constexpr char kGatherModule[] = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[20,10]{1,0} parameter(0)
  indices = s32[32,20] parameter(1)
  dynamic_size = s32[] parameter(2)
  ROOT gather = s32[32,20,10]{2,1,0} gather(%operand, %indices),
                 offset_dims={2},
                 collapsed_slice_dims={0},
                 start_index_map={0},
                 index_vector_dim=2,
                 slice_sizes={1,10}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kGatherModule));

  // Dimension 0 of parameter 1 (the indices) is sized by parameter 2.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{1, {}, 0}));
  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  HloComputation* entry = module_->entry_computation();
  HloInstruction* gather_root = entry->root_instruction();
  HloInstruction* size_source = entry->parameter_instruction(2);
  EXPECT_EQ(inference_->GetDynamicSize(gather_root, {}, 0), size_source);
}
707 
TEST_F(DynamicDimensionInferenceTest, BroadcastTest) {
  // Broadcast f32[2] into f32[3,2,4], mapping the operand onto output
  // dimension 1. Only that output dimension derives from the dynamic operand
  // dimension, so only it should report a dynamic size.
  auto computation_builder = HloComputation::Builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2});
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {3, 2, 4});

  HloInstruction* operand =
      computation_builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/0, operand_shape, "A"));
  HloInstruction* dim_size =
      computation_builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/1, scalar_shape_, "size_param"));
  HloInstruction* broadcast = computation_builder.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, operand, {1}));

  module_->AddEntryComputation(computation_builder.Build());

  // The operand's single dimension is dynamic, sized by parameter 1.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  // Expected dynamic size for each output dimension, in order.
  HloInstruction* const expected_sizes[] = {nullptr, dim_size, nullptr};
  for (int64_t dim = 0; dim < 3; ++dim) {
    EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, dim),
              expected_sizes[dim]);
  }
}
734 
TEST_F(DynamicDimensionInferenceTest,WhileTest)735 TEST_F(DynamicDimensionInferenceTest, WhileTest) {
736   // Test the ability to trace into while loops.
737   auto builder = HloComputation::Builder(TestName());
738   auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
739   auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
740   auto tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});
741 
742   // Body:
743   //
744   //   Param
745   //   |  |
746   // GTE1 GTE2
747   //   |  |
748   //    ADD
749   auto body_builder = HloComputation::Builder("body");
750   auto body_param = body_builder.AddInstruction(
751       HloInstruction::CreateParameter(0, tuple_shape, "param"));
752   auto gte_0 = body_builder.AddInstruction(
753       HloInstruction::CreateGetTupleElement(input_shape, body_param, 0));
754   auto gte_1 = body_builder.AddInstruction(
755       HloInstruction::CreateGetTupleElement(input_shape, body_param, 1));
756   auto add = body_builder.AddInstruction(
757       HloInstruction::CreateBinary(input_shape, HloOpcode::kAdd, gte_0, gte_1));
758   body_builder.AddInstruction(HloInstruction::CreateTuple({add, add}));
759 
760   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
761 
762   auto cond_builder = HloComputation::Builder("condition");
763   cond_builder.AddInstruction(
764       HloInstruction::CreateParameter(0, tuple_shape, "param"));
765   cond_builder.AddInstruction(
766       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
767   HloComputation* condition =
768       module_->AddEmbeddedComputation(cond_builder.Build());
769 
770   // Entry:
771   //
772   //  Param
773   //   |
774   //  While
775   auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
776       /*parameter_number=*/0, tuple_shape, "A"));
777   builder.AddInstruction(HloInstruction::CreateParameter(
778       /*parameter_number=*/1, scalar_shape_, "size_param"));
779   builder.AddInstruction(
780       HloInstruction::CreateWhile(tuple_shape, condition, body, a_param));
781 
782   module_->AddEntryComputation(builder.Build());
783 
784   TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
785       DynamicParameterBinding::DynamicParameter{1, {}},
786       DynamicParameterBinding::DynamicDimension{0, {0}, 0}));
787 
788   TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
789       DynamicParameterBinding::DynamicParameter{1, {}},
790       DynamicParameterBinding::DynamicDimension{0, {1}, 0}));
791 
792   TF_ASSERT_OK(RunInference());
793   HloInstruction* while_hlo = nullptr;
794   // The while hlo has been replaced, find the new one.
795   for (HloInstruction* inst : module_->entry_computation()->instructions()) {
796     if (inst->opcode() == HloOpcode::kWhile) {
797       while_hlo = inst;
798     }
799   }
800   ASSERT_NE(while_hlo, nullptr);
801   // The original while shape has 2 parameters. With dynamic size, the tuple
802   // should have 4 elements (We don't deduplicate the arguments).
803   EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 4);
804   HloInstruction* add_inst = nullptr;
805   for (HloInstruction* inst : while_hlo->while_body()->instructions()) {
806     if (inst->opcode() == HloOpcode::kAdd) {
807       add_inst = inst;
808     }
809   }
810   EXPECT_NE(add_inst, nullptr);
811   EXPECT_NE(inference_->GetDynamicSize(add_inst, {}, 0), nullptr);
812   EXPECT_NE(inference_->GetDynamicSize(
813                 module_->entry_computation()->root_instruction(), {0}, 0),
814             nullptr);
815   EXPECT_NE(inference_->GetDynamicSize(
816                 module_->entry_computation()->root_instruction(), {1}, 0),
817             nullptr);
818 }
819 
TEST_F(DynamicDimensionInferenceTest,ConditionalInputTest)820 TEST_F(DynamicDimensionInferenceTest, ConditionalInputTest) {
821   // Test the ability to trace into contional loops.
822   auto builder = HloComputation::Builder(TestName());
823   auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
824   auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
825   // In this test we set inputs to different branches to different shapes.
826   auto tuple_shape_1 = ShapeUtil::MakeTupleShape({input_shape});
827   auto tuple_shape_2 = ShapeUtil::MakeTupleShape({input_shape, input_shape});
828   auto tuple_shape_3 =
829       ShapeUtil::MakeTupleShape({input_shape, input_shape, input_shape});
830 
831   // true branch:
832   //
833   //   Param
834   //   |  |
835   // GTE1 GTE2
836   //   |  |
837   // Tuple(ADD)
838   auto true_builder = HloComputation::Builder("true");
839   {
840     auto true_param = true_builder.AddInstruction(
841         HloInstruction::CreateParameter(0, tuple_shape_2, "param"));
842     auto gte_0 = true_builder.AddInstruction(
843         HloInstruction::CreateGetTupleElement(input_shape, true_param, 0));
844     auto gte_1 = true_builder.AddInstruction(
845         HloInstruction::CreateGetTupleElement(input_shape, true_param, 1));
846     auto add = true_builder.AddInstruction(HloInstruction::CreateBinary(
847         input_shape, HloOpcode::kAdd, gte_0, gte_1));
848     true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
849   }
850   HloComputation* true_branch =
851       module_->AddEmbeddedComputation(true_builder.Build());
852   // false branch:
853   //
854   //      Param
855   //  |     |    |
856   // GTE1  GTE2 GTE3
857   //        |     |
858   //       Tuple(ADD)
859   auto false_builder = HloComputation::Builder("false");
860   {
861     auto false_param = false_builder.AddInstruction(
862         HloInstruction::CreateParameter(0, tuple_shape_3, "param"));
863     auto gte_0 = false_builder.AddInstruction(
864         HloInstruction::CreateGetTupleElement(input_shape, false_param, 1));
865     auto gte_1 = false_builder.AddInstruction(
866         HloInstruction::CreateGetTupleElement(input_shape, false_param, 2));
867     auto add = false_builder.AddInstruction(HloInstruction::CreateBinary(
868         input_shape, HloOpcode::kAdd, gte_0, gte_1));
869     false_builder.AddInstruction(HloInstruction::CreateTuple({add}));
870   }
871   HloComputation* false_branch =
872       module_->AddEmbeddedComputation(false_builder.Build());
873 
874   // Entry:
875   //
876   //  Param(bool) Param2 (tuple_2) Param3(tuple_3)
877   //   |            |                 |
878   //   +---------Condition------------+
879   auto* pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
880       /*parameter_number=*/0, ShapeUtil::MakeScalarShape(PRED), "pred"));
881 
882   auto* tuple_2_param = builder.AddInstruction(HloInstruction::CreateParameter(
883       /*parameter_number=*/1, tuple_shape_2, "tuple_2_param"));
884   auto* tuple_3_param = builder.AddInstruction(HloInstruction::CreateParameter(
885       /*parameter_number=*/2, tuple_shape_3, "tuple_3_param"));
886   builder.AddInstruction(HloInstruction::CreateParameter(
887       /*parameter_number=*/3, scalar_shape_, "size_param"));
888   builder.AddInstruction(HloInstruction::CreateConditional(
889       tuple_shape_1, pred_param, tuple_2_param, true_branch, tuple_3_param,
890       false_branch));
891 
892   module_->AddEntryComputation(builder.Build());
893 
894   TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
895       DynamicParameterBinding::DynamicParameter{3, {}},
896       DynamicParameterBinding::DynamicDimension{1, {0}, 0}));
897   TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
898       DynamicParameterBinding::DynamicParameter{3, {}},
899       DynamicParameterBinding::DynamicDimension{1, {1}, 0}));
900   TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
901       DynamicParameterBinding::DynamicParameter{3, {}},
902       DynamicParameterBinding::DynamicDimension{2, {1}, 0}));
903   TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
904       DynamicParameterBinding::DynamicParameter{3, {}},
905       DynamicParameterBinding::DynamicDimension{2, {2}, 0}));
906 
907   TF_ASSERT_OK(RunInference());
908 
909   HloInstruction* conditional_hlo = nullptr;
910   // The while hlo has been replaced, find the new one.
911   for (HloInstruction* inst : module_->entry_computation()->instructions()) {
912     if (inst->opcode() == HloOpcode::kConditional) {
913       conditional_hlo = inst;
914     }
915   }
916   ASSERT_NE(conditional_hlo, nullptr);
917   // The original conditional shape has 1 parameters. With dynamic size passed
918   // out from the computation, another element is added to the tuple.
919   EXPECT_EQ(conditional_hlo->shape().tuple_shapes_size(), 2);
920   HloInstruction* add_true_branch = nullptr;
921   for (HloInstruction* inst :
922        conditional_hlo->true_computation()->instructions()) {
923     if (inst->opcode() == HloOpcode::kAdd) {
924       add_true_branch = inst;
925     }
926   }
927   EXPECT_NE(add_true_branch, nullptr);
928   EXPECT_NE(inference_->GetDynamicSize(add_true_branch, {}, 0), nullptr);
929 
930   HloInstruction* add_false_branch = nullptr;
931   for (HloInstruction* inst :
932        conditional_hlo->false_computation()->instructions()) {
933     if (inst->opcode() == HloOpcode::kAdd) {
934       add_false_branch = inst;
935     }
936   }
937   EXPECT_NE(add_false_branch, nullptr);
938   EXPECT_NE(inference_->GetDynamicSize(add_false_branch, {}, 0), nullptr);
939 
940   EXPECT_NE(inference_->GetDynamicSize(conditional_hlo, {0}, 0), nullptr);
941 }
942 
TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) {
  // A reduce-window that leaves the batch dimension (0) untouched and halves
  // dimensions 1 and 2. The dynamic batch size must pass through unchanged.
  auto b = HloComputation::Builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  const Shape reduced_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});

  Window window;
  // Appends one window dimension with no padding and no dilation.
  auto append_window_dim = [&window](int64_t size, int64_t stride) {
    WindowDimension* wd = window.add_dimensions();
    wd->set_size(size);
    wd->set_stride(stride);
    wd->set_padding_low(0);
    wd->set_padding_high(0);
    wd->set_window_dilation(1);
    wd->set_base_dilation(1);
  };
  append_window_dim(/*size=*/1, /*stride=*/1);  // Batch: unchanged.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Dimension 1: halved.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Dimension 2: halved.

  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, operand_shape, "A"));
  HloInstruction* batch_size =
      b.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/1, scalar_shape_, "size_param"));
  HloInstruction* zero = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* reduce_window =
      b.AddInstruction(HloInstruction::CreateReduceWindow(
          reduced_shape, operand, zero, window, GetAdd()));

  module_->AddEntryComputation(b.Build());

  // The batch dimension of parameter 0 is sized by parameter 1.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), batch_size);
}
992 
TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) {
  // Select-and-scatter over an f32[2,4,4] operand whose window leaves the batch
  // dimension (0) untouched and halves dimensions 1 and 2. Dimension 0 of the
  // result is expected to carry the batch size.
  auto b = HloComputation::Builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  const Shape scatter_source_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});

  Window window;
  // Appends one window dimension with no padding and no dilation.
  auto append_window_dim = [&window](int64_t size, int64_t stride) {
    WindowDimension* wd = window.add_dimensions();
    wd->set_size(size);
    wd->set_stride(stride);
    wd->set_padding_low(0);
    wd->set_padding_high(0);
    wd->set_window_dilation(1);
    wd->set_base_dilation(1);
  };
  append_window_dim(/*size=*/1, /*stride=*/1);  // Batch: unchanged.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Dimension 1: halved.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Dimension 2: halved.

  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, operand_shape, "A"));
  HloInstruction* dim_size = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));
  HloInstruction* scatter_source =
      b.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/2, scatter_source_shape, "B"));
  HloInstruction* zero = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* select_and_scatter =
      b.AddInstruction(HloInstruction::CreateSelectAndScatter(
          operand_shape, operand, GetGe(), window, scatter_source, zero,
          GetAdd()));

  module_->AddEntryComputation(b.Build());

  // Dimensions 0 and 2 of parameter 0 are both sized by parameter 1.
  for (int64_t dynamic_dim : {0, 2}) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{1, {}},
        DynamicParameterBinding::DynamicDimension{0, {}, dynamic_dim}));
  }

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(select_and_scatter, {}, 0), dim_size);
}
1046 
TEST_F(DynamicDimensionInferenceTest, ConcateTest) {
  // Concatenate f32[5,7] and f32[5,8] along dimension 1. Both inputs share a
  // dynamic dimension 0 sized by parameter 2, and the result should too.
  auto b = HloComputation::Builder(TestName());

  HloInstruction* lhs = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param_1"));
  HloInstruction* rhs = b.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {5, 8}), "data_param_2"));
  HloInstruction* row_count = b.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_shape_, "size_param"));

  HloInstruction* concat = b.AddInstruction(HloInstruction::CreateConcatenate(
      ShapeUtil::MakeShape(F32, {5, 15}), {lhs, rhs}, 1));

  module_->AddEntryComputation(b.Build());

  // Dimension 0 of each data parameter (0 and 1) is sized by parameter 2.
  for (int64_t data_param_number : {0, 1}) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{2, {}},
        DynamicParameterBinding::DynamicDimension{data_param_number, {}, 0}));
  }

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(concat, {}, 0), row_count);
}
1074 
TEST_F(DynamicDimensionInferenceTest, SliceTest) {
  // A slice covering the whole [0,5) x [0,7) range keeps dynamic dimension 1
  // intact, so its size should be forwarded to the slice result.
  auto b = HloComputation::Builder(TestName());
  const Shape full_shape = ShapeUtil::MakeShape(F32, {5, 7});

  HloInstruction* data = b.AddInstruction(
      HloInstruction::CreateParameter(0, full_shape, "data_param"));
  HloInstruction* column_count = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  HloInstruction* full_slice = b.AddInstruction(HloInstruction::CreateSlice(
      full_shape, data, /*start_indices=*/{0, 0},
      /*limit_indices=*/{5, 7}, /*strides=*/{1, 1}));

  module_->AddEntryComputation(b.Build());

  // Dimension 1 of parameter 0 is sized by parameter 1.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(full_slice, {}, 1), column_count);
}
1096 
TEST_F(DynamicDimensionInferenceTest, DynamicSliceTest) {
  // A dynamic slice of size {5, 1} keeps all of dynamic dimension 0, so the
  // result's dimension 0 should report the same size.
  auto b = HloComputation::Builder(TestName());

  HloInstruction* data = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  HloInstruction* row_count = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // One scalar start index per operand dimension (parameters 2 and 3).
  std::vector<HloInstruction*> start_indices;
  for (int64_t index_param_number : {2, 3}) {
    start_indices.push_back(b.AddInstruction(HloInstruction::CreateParameter(
        index_param_number, ShapeUtil::MakeShape(S32, {}), "slice_indices")));
  }

  HloInstruction* dynamic_slice =
      b.AddInstruction(HloInstruction::CreateDynamicSlice(
          ShapeUtil::MakeShape(F32, {5, 1}), data, start_indices,
          /*slice_sizes=*/{5, 1}));

  module_->AddEntryComputation(b.Build());

  // Dimension 0 of parameter 0 is sized by parameter 1.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(dynamic_slice, {}, 0), row_count);
}
1124 
TEST_F(DynamicDimensionInferenceTest, SortTest) {
  // Sorting along dimension 1 does not touch dimension 0, so the dynamic size
  // of dimension 0 should be forwarded to the sort result.
  auto b = HloComputation::Builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {5, 7});

  HloInstruction* data = b.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data_param"));
  HloInstruction* row_count = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // Comparator: two scalar f32 params, constant `false` result.
  auto comparator_builder = HloComputation::Builder("condition");
  comparator_builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {}), "param1"));
  comparator_builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {}), "param2"));
  comparator_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* comparator =
      module_->AddEmbeddedComputation(comparator_builder.Build());

  HloInstruction* sorted = b.AddInstruction(HloInstruction::CreateSort(
      data_shape, /*dimension=*/1, {data}, comparator, /*is_stable=*/false));

  module_->AddEntryComputation(b.Build());

  // Dimension 0 of parameter 0 is sized by parameter 1.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(sorted, {}, 0), row_count);
}
1156 
TEST_F(DynamicDimensionInferenceTest, MultiValueSortTest) {
  // A variadic sort of the same dynamic operand twice. Each element of the
  // resulting tuple should carry the operand's dynamic dimension-0 size.
  auto b = HloComputation::Builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {5, 7});

  HloInstruction* data = b.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data_param"));
  HloInstruction* row_count = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // Comparator: four scalar f32 params (two per sorted operand), constant
  // `false` result.
  auto comparator_builder = HloComputation::Builder("condition");
  const char* const kComparatorParamNames[] = {"param1", "param2", "param3",
                                               "param4"};
  for (int64_t i = 0; i < 4; ++i) {
    comparator_builder.AddInstruction(HloInstruction::CreateParameter(
        i, ShapeUtil::MakeShape(F32, {}), kComparatorParamNames[i]));
  }
  comparator_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* comparator =
      module_->AddEmbeddedComputation(comparator_builder.Build());

  HloInstruction* sorted = b.AddInstruction(HloInstruction::CreateSort(
      ShapeUtil::MakeTupleShape({data_shape, data_shape}), /*dimension=*/1,
      {data, data}, comparator, /*is_stable=*/false));

  module_->AddEntryComputation(b.Build());

  // Dimension 0 of parameter 0 is sized by parameter 1.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  for (int64_t tuple_index : {0, 1}) {
    EXPECT_EQ(inference_->GetDynamicSize(sorted, {tuple_index}, 0), row_count);
  }
}
1196 
TEST_F(DynamicDimensionInferenceTest, DynamicSliceSingleElementTest) {
  // Taking a single element (slice size 1) out of a dynamic dimension ends the
  // dynamism: the result's dimension 0 must have no dynamic size.
  auto b = HloComputation::Builder(TestName());

  HloInstruction* data = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // One scalar start index per operand dimension (parameters 2 and 3).
  std::vector<HloInstruction*> start_indices;
  for (int64_t index_param_number : {2, 3}) {
    start_indices.push_back(b.AddInstruction(HloInstruction::CreateParameter(
        index_param_number, ShapeUtil::MakeShape(S32, {}), "slice_indices")));
  }

  HloInstruction* single_element =
      b.AddInstruction(HloInstruction::CreateDynamicSlice(
          ShapeUtil::MakeShape(F32, {1, 1}), data, start_indices,
          /*slice_sizes=*/{1, 1}));

  module_->AddEntryComputation(b.Build());

  // Dimension 0 of parameter 0 is sized by parameter 1.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(single_element, {}, 0), nullptr);
}
1226 
TEST_F(DynamicDimensionInferenceTest, InfersCustomOp) {
  // A custom call consuming a dynamic operand should be handed to the
  // user-supplied custom-call inference handler.
  auto b = HloComputation::Builder(TestName());

  HloInstruction* data = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  b.AddInstruction(HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {1, 1}), {data}, "MyCustomOp", ""));

  module_->AddEntryComputation(b.Build());

  // Dimension 0 of parameter 0 is sized by parameter 1.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  bool custom_call_handled = false;
  auto custom_call_handler = [&custom_call_handled](
                                 HloInstruction* hlo,
                                 DynamicDimensionInference* inference) {
    CHECK(inference != nullptr);
    CHECK(Cast<HloCustomCallInstruction>(hlo) != nullptr);
    custom_call_handled = true;
    return OkStatus();
  };
  TF_ASSERT_OK(RunInference(custom_call_handler));

  EXPECT_TRUE(custom_call_handled);
}
1257 
TEST_F(DynamicDimensionInferenceTest, DynamicReshapeOp) {
  // Dynamic-reshape a [<=9] value into [3, <=3] with explicit output dimension
  // sizes {3, size_param}. Only output dimension 1 is dynamic, and its size
  // comes from `size_param`.
  auto b = HloComputation::Builder(TestName());
  HloInstruction* data = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {9}), "data_input"));
  HloInstruction* const_six = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(6)));
  // Bounded-dynamic input of shape [<=9] whose runtime size is 6.
  HloInstruction* bounded_data =
      b.AddInstruction(HloInstruction::CreateSetDimensionSize(
          ShapeUtil::MakeShape(F32, {9}, {true}), data, const_six, 0));
  HloInstruction* runtime_size =
      b.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(S32, {}), "size_param"));
  HloInstruction* const_three = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(3)));

  HloInstruction* dynamic_reshape =
      b.AddInstruction(HloInstruction::CreateDynamicReshape(
          ShapeUtil::MakeShape(F32, {3, 3}, {false, true}), bounded_data,
          {const_three, runtime_size}));

  module_->AddEntryComputation(b.Build());

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 1), runtime_size);
}
1286 
TEST_F(DynamicDimensionInferenceTest, ReshapeOpWithMultipleDynamicDimensions) {
  // Mark both dimensions of an f32[9,2] as dynamic (runtime sizes 6 and 1),
  // then reshape [<=9, <=2] into [<=9, 1, <=2]. The inserted unit dimension
  // stays static and the other two keep their original size instructions.
  auto b = HloComputation::Builder(TestName());
  HloInstruction* raw_input = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {9, 2}), "data_input"));
  HloInstruction* six = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(6)));
  HloInstruction* dim0_dynamic =
      b.AddInstruction(HloInstruction::CreateSetDimensionSize(
          ShapeUtil::MakeShape(F32, {9, 2}, {true, false}), raw_input, six,
          0));
  HloInstruction* one = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
  HloInstruction* both_dynamic =
      b.AddInstruction(HloInstruction::CreateSetDimensionSize(
          ShapeUtil::MakeShape(F32, {9, 2}, {true, true}), dim0_dynamic, one,
          1));

  HloInstruction* reshape = b.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {9, 1, 2}, {true, false, true}),
      both_dynamic));

  module_->AddEntryComputation(b.Build());

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 0), six);
  EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 1), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 2), one);
}
1312 
TEST_F(DynamicDimensionInferenceTest, HandleMapInDynamicDimensionInference) {
  // Inference must complete without error on a computation that applies `map`
  // to values sliced from a bounded-dynamic c128[<=1] tuple element. The text
  // has no explicit ENTRY computation and is parsed without verification.
  constexpr char kModuleText[] = R"(
HloModule test_module

%scatter-combiner.285 (p0.286: c128[], p1.287: c128[]) -> c128[] {
  %p0.286 = c128[] parameter(0)
  %p1.287 = c128[] parameter(1)
  ROOT %add.288 = c128[] add(c128[] %p0.286, c128[] %p1.287)
}

 %while_body  {
  %reshape.8 = s32[] parameter(4)
  %reshape.7 = c128[1]{0} parameter(3)
  %reduce = pred[] parameter(2)
  %concatenate = s32[1]{0} parameter(1)
  %slice.4 = s32[1]{0} slice(s32[1]{0} %concatenate), slice={[0 : 1]}
  %broadcast.7 = pred[1]{0} broadcast(pred[] %reduce), dimensions={}
  %param.1 = (s32[],c128[<=1]{0},s32[1]{0},c128[1]{0}) parameter(0)
  %get-tuple-element.2 = c128[<=1]{0} get-tuple-element((s32[],c128[<=1]{0},s32[1]{0},c128[1]{0}) %param.1), index=1
  %dynamic-slice.2 = c128[1]{0} dynamic-slice(c128[<=1]{0} %get-tuple-element.2,s32[] %reshape.8), dynamic_slice_sizes={1}
  %map = c128[1]{0} map(c128[1]{0} %dynamic-slice.2,c128[1]{0} %reshape.7), dimensions={0}, to_apply=%scatter-combiner.285
  %select = c128[1]{0} select(pred[1]{0} %broadcast.7,c128[1]{0} %map,c128[1]{0} %dynamic-slice.2)
  %reshape.9 = s32[] reshape(s32[1]{0} %slice.4)
  %dynamic-update-slice = c128[<=1]{0} dynamic-update-slice(c128[<=1]{0} %get-tuple-element.2,c128[1]{0} %select,s32[] %reshape.9)
})";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnUnverifiedModule(kModuleText));
  TF_ASSERT_OK(RunInference());
}
1341 
// Exercises ShapeCheckMode::kRuntime. `a` and `b` are f32[20,20] parameters
// whose two dimensions are each bound to their own s32[] size parameter, and
// the root adds them. The FileCheck patterns below verify the emitted check:
// - one EQ compare per dimension between the matching size parameters;
// - an `and` combining those compares;
// - that `and` passed to the test's assertion_generator, which wraps its
//   argument in an "__xla__assert" custom call.
TEST_F(DynamicDimensionInferenceTest, RuntimeShapeCheck) {
  const char* hlo = R"(
HloModule module

ENTRY computation {
  a = f32[20,20] parameter(0)
  a_size_1 = s32[] parameter(1)
  a_size_2 = s32[] parameter(2)
  b = f32[20,20] parameter(3)
  b_size_1 = s32[] parameter(4)
  b_size_2 = s32[] parameter(5)
  ROOT f = add(a, b)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo));

  // Parameters 1 and 2 carry the sizes of dims 0 and 1 of `a` (parameter 0).
  // Parameters 4 and 5 do the same for `b` (parameter 3). TF_ASSERT_OK (rather
  // than TF_CHECK_OK) reports a failed binding as a test failure instead of
  // aborting the entire test binary.
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{4, {}},
      DynamicParameterBinding::DynamicDimension{3, {}, 0}));
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{5, {}},
      DynamicParameterBinding::DynamicDimension{3, {}, 1}));

  TF_ASSERT_OK(RunInference(
      /*handler=*/nullptr, DynamicDimensionInference::ShapeCheckMode::kRuntime,
      // Captures nothing: the constraint's parent computation is reached
      // through the instruction itself.
      /*assertion_generator=*/[](HloInstruction* constraint) {
        constraint->parent()->AddInstruction(HloInstruction::CreateCustomCall(
            ShapeUtil::MakeTokenShape(), {constraint},
            /*custom_call_target=*/"__xla__assert",
            /*opaque=*/std::string{}, API_VERSION_STATUS_RETURNING));
      }));

  // Instruction names are bound with FileCheck pattern variables instead of
  // hard-coding uniquified suffixes such as "compare.5" / "and.2". Those
  // suffixes can shift when unrelated instructions are added. The checked
  // dataflow (both compares feed the `and`, which feeds the assert) is
  // unchanged.
  StatusOr<bool> filecheck_result = RunFileCheck(module_->ToString({}),
                                                 R"(
// CHECK: [[CMP_DIM0:compare[.0-9]*]] = pred[] compare(s32[] %a_size_1, s32[] %b_size_1), direction=EQ
// CHECK: [[CMP_DIM1:compare[.0-9]*]] = pred[] compare(s32[] %a_size_2, s32[] %b_size_2), direction=EQ
// CHECK: [[AND:and[.0-9]*]] = pred[] and(pred[] %[[CMP_DIM0]], pred[] %[[CMP_DIM1]])
// CHECK: custom-call(pred[] %[[AND]]), custom_call_target="__xla__assert"
                   )");
  TF_ASSERT_OK(filecheck_result.status());
  EXPECT_TRUE(*filecheck_result);
}
1389 }
1390 
1391 }  // namespace
1392 }  // namespace xla
1393