1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test_benchmark.h"
36
37 namespace op = xla::testing::opcode_matchers;
38
39 namespace xla {
40 namespace {
41
42 class DynamicDimensionInferenceTest : public HloTestBase {
43 protected:
DynamicDimensionInferenceTest()44 DynamicDimensionInferenceTest() : HloTestBase() {
45 module_ = CreateNewVerifiedModule();
46 }
47
RunInference(DynamicDimensionInference::CustomCallInferenceHandler handler=nullptr,DynamicDimensionInference::ShapeCheckMode shape_check_mode=DynamicDimensionInference::ShapeCheckMode::kIgnore,const DynamicDimensionInference::AssertionGenerator & assertion_generator=nullptr)48 Status RunInference(
49 DynamicDimensionInference::CustomCallInferenceHandler handler = nullptr,
50 DynamicDimensionInference::ShapeCheckMode shape_check_mode =
51 DynamicDimensionInference::ShapeCheckMode::kIgnore,
52 const DynamicDimensionInference::AssertionGenerator& assertion_generator =
53 nullptr) {
54 TF_ASSIGN_OR_RETURN(
55 DynamicDimensionInference inference,
56 DynamicDimensionInference::Run(module_.get(), handler, shape_check_mode,
57 assertion_generator));
58
59 inference_ = std::make_unique<DynamicDimensionInference>(inference);
60 return OkStatus();
61 }
62
GetAdd()63 HloComputation* GetAdd() {
64 auto embedded_builder = HloComputation::Builder("add");
65 auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
66 0, ShapeUtil::MakeShape(F32, {}), "lhs"));
67 auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
68 1, ShapeUtil::MakeShape(F32, {}), "rhs"));
69 embedded_builder.AddInstruction(
70 HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
71 return module_->AddEmbeddedComputation(embedded_builder.Build());
72 }
73
GetAddTuple()74 HloComputation* GetAddTuple() {
75 auto embedded_builder = HloComputation::Builder("add");
76 auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
77 0, ShapeUtil::MakeShape(F32, {}), "lhs"));
78 auto lhs_1 =
79 embedded_builder.AddInstruction(HloInstruction::CreateParameter(
80 1, ShapeUtil::MakeShape(F32, {}), "lhs.1"));
81 auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
82 2, ShapeUtil::MakeShape(F32, {}), "rhs"));
83 auto rhs_1 =
84 embedded_builder.AddInstruction(HloInstruction::CreateParameter(
85 3, ShapeUtil::MakeShape(F32, {}), "rhs.1"));
86 auto add = embedded_builder.AddInstruction(
87 HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
88 auto add_1 = embedded_builder.AddInstruction(HloInstruction::CreateBinary(
89 lhs->shape(), HloOpcode::kAdd, lhs_1, rhs_1));
90 embedded_builder.AddInstruction(HloInstruction::CreateTuple({add, add_1}));
91 return module_->AddEmbeddedComputation(embedded_builder.Build());
92 }
93
GetGe()94 HloComputation* GetGe() {
95 auto embedded_builder = HloComputation::Builder("ge");
96 auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
97 0, ShapeUtil::MakeShape(F32, {}), "lhs"));
98 auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
99 1, ShapeUtil::MakeShape(F32, {}), "rhs"));
100 embedded_builder.AddInstruction(HloInstruction::CreateCompare(
101 ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGe));
102 return module_->AddEmbeddedComputation(embedded_builder.Build());
103 }
104
105 std::unique_ptr<HloModule> module_;
106 std::unique_ptr<DynamicDimensionInference> inference_;
107 const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {});
108 };
109
TEST_F(DynamicDimensionInferenceTest,ParamTest)110 TEST_F(DynamicDimensionInferenceTest, ParamTest) {
111 auto builder = HloComputation::Builder(TestName());
112 auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
113
114 auto param = builder.AddInstruction(
115 HloInstruction::CreateParameter(0, input_shape, "param"));
116 auto param2 = builder.AddInstruction(
117 HloInstruction::CreateParameter(1, scalar_shape_, "param"));
118
119 module_->AddEntryComputation(builder.Build());
120 SCOPED_TRACE(module_->ToString());
121
122 // Set up dynamic parameter binding.
123 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
124 DynamicParameterBinding::DynamicParameter{1, {}},
125 DynamicParameterBinding::DynamicDimension{0, {}, 1}));
126
127 TF_ASSERT_OK(RunInference());
128 EXPECT_EQ(inference_->GetDynamicSize(param, {}, 1), param2);
129 EXPECT_EQ(inference_->GetDynamicSize(param, {}, 0), nullptr);
130 EXPECT_EQ(inference_->GetDynamicSize(param2, {}, 0), nullptr);
131 }
132
TEST_F(DynamicDimensionInferenceTest,ParamTestTuple)133 TEST_F(DynamicDimensionInferenceTest, ParamTestTuple) {
134 auto builder = HloComputation::Builder(TestName());
135 auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
136
137 auto param = builder.AddInstruction(HloInstruction::CreateParameter(
138 0, ShapeUtil::MakeTupleShape({input_shape, scalar_shape_}), "param"));
139
140 module_->AddEntryComputation(builder.Build());
141 // Set up dynamic parameter binding.
142 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
143 DynamicParameterBinding::DynamicParameter{0, {1}},
144 DynamicParameterBinding::DynamicDimension{0, {0}, 1}));
145
146 SCOPED_TRACE(module_->ToString());
147 TF_ASSERT_OK(RunInference());
148 EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1),
149 op::GetTupleElement(param, 1));
150
151 EXPECT_EQ(inference_->GetDynamicSize(param, {0}, 0), nullptr);
152 }
153
TEST_F(DynamicDimensionInferenceTest,GetTupleElement)154 TEST_F(DynamicDimensionInferenceTest, GetTupleElement) {
155 // When data flows through GTE, the dynamic dimension size keeps the
156 // same, and the index has its front popped.
157 auto builder = HloComputation::Builder(TestName());
158 auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
159
160 auto param = builder.AddInstruction(HloInstruction::CreateParameter(
161 0, ShapeUtil::MakeTupleShape({input_shape, scalar_shape_}), "param"));
162
163 auto gte = builder.AddInstruction(
164 HloInstruction::CreateGetTupleElement(input_shape, param, 0));
165
166 module_->AddEntryComputation(builder.Build());
167 // Set up dynamic parameter binding.
168 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
169 DynamicParameterBinding::DynamicParameter{0, {1}},
170 DynamicParameterBinding::DynamicDimension{0, {0}, 1}));
171
172 SCOPED_TRACE(module_->ToString());
173 TF_ASSERT_OK(RunInference());
174 EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1),
175 op::GetTupleElement(param, 1));
176
177 EXPECT_THAT(inference_->GetDynamicSize(gte, {}, 1),
178 op::GetTupleElement(param, 1));
179
180 EXPECT_EQ(inference_->GetDynamicSize(param, {0}, 0), nullptr);
181 }
182
TEST_F(DynamicDimensionInferenceTest,ElementwiseTest)183 TEST_F(DynamicDimensionInferenceTest, ElementwiseTest) {
184 // When data flows through elementwise, the dynamic dimension size keeps the
185 // same.
186 auto builder = HloComputation::Builder(TestName());
187 auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
188
189 auto data_param = builder.AddInstruction(
190 HloInstruction::CreateParameter(0, input_shape, "data_param"));
191 auto size_param = builder.AddInstruction(
192 HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
193
194 auto* negate = builder.AddInstruction(
195 HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param));
196
197 module_->AddEntryComputation(builder.Build());
198 // Set up dynamic parameter binding.
199 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
200 DynamicParameterBinding::DynamicParameter{1, {}},
201 DynamicParameterBinding::DynamicDimension{0, {}, 1}));
202
203 SCOPED_TRACE(module_->ToString());
204 TF_ASSERT_OK(RunInference());
205 EXPECT_EQ(inference_->GetDynamicSize(negate, {}, 1), size_param);
206 }
207
TEST_F(DynamicDimensionInferenceTest,ReduceTestI)208 TEST_F(DynamicDimensionInferenceTest, ReduceTestI) {
209 auto builder = HloComputation::Builder(TestName());
210 auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
211 auto reduce_shape = ShapeUtil::MakeShape(F32, {2});
212
213 auto data_param = builder.AddInstruction(
214 HloInstruction::CreateParameter(0, input_shape, "data_param"));
215 auto size_param = builder.AddInstruction(
216 HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
217
218 auto negate = builder.AddInstruction(
219 HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param));
220
221 auto init = builder.AddInstruction(
222 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
223
224 auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
225 reduce_shape, negate, init, {0, 2}, GetAdd()));
226
227 module_->AddEntryComputation(builder.Build());
228
229 // Set up dynamic parameter binding.
230 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
231 DynamicParameterBinding::DynamicParameter{1, {}},
232 DynamicParameterBinding::DynamicDimension{0, {}, 1}));
233
234 SCOPED_TRACE(module_->ToString());
235 TF_ASSERT_OK(RunInference());
236 EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), size_param);
237 }
238
TEST_F(DynamicDimensionInferenceTest,ReduceTestII)239 TEST_F(DynamicDimensionInferenceTest, ReduceTestII) {
240 // Same as ReduceTestI, but only reduce one dimension.
241 auto builder = HloComputation::Builder(TestName());
242 auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
243 auto reduce_shape = ShapeUtil::MakeShape(F32, {1, 2});
244
245 auto data_param = builder.AddInstruction(
246 HloInstruction::CreateParameter(0, input_shape, "data_param"));
247 auto size_param = builder.AddInstruction(
248 HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
249
250 auto negate = builder.AddInstruction(
251 HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param));
252
253 auto init = builder.AddInstruction(
254 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
255
256 auto reduce = builder.AddInstruction(
257 HloInstruction::CreateReduce(reduce_shape, negate, init, {1}, GetAdd()));
258
259 module_->AddEntryComputation(builder.Build());
260
261 // Set up dynamic parameter binding.
262 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
263 DynamicParameterBinding::DynamicParameter{1, {}},
264 DynamicParameterBinding::DynamicDimension{0, {}, 2}));
265
266 SCOPED_TRACE(module_->ToString());
267 TF_ASSERT_OK(RunInference());
268 EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 1), size_param);
269 EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), nullptr);
270 }
271
TEST_F(DynamicDimensionInferenceTest,VariadicReduce)272 TEST_F(DynamicDimensionInferenceTest, VariadicReduce) {
273 // Handle variadic reduce where output is a tuple.
274 auto builder = HloComputation::Builder(TestName());
275 auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
276 auto reduce_shape = ShapeUtil::MakeShape(F32, {1, 2});
277
278 auto data_param_dynamic = builder.AddInstruction(
279 HloInstruction::CreateParameter(0, input_shape, "data_param"));
280 auto data_param_static = builder.AddInstruction(
281 HloInstruction::CreateParameter(1, input_shape, "data_param.2"));
282 auto size_param = builder.AddInstruction(
283 HloInstruction::CreateParameter(2, scalar_shape_, "size_param"));
284
285 // Set up dynamic parameter binding.
286 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
287 DynamicParameterBinding::DynamicParameter{2, {}},
288 DynamicParameterBinding::DynamicDimension{0, {}, 2}));
289
290 auto dynamic_negate = builder.AddInstruction(HloInstruction::CreateUnary(
291 input_shape, HloOpcode::kNegate, data_param_dynamic));
292
293 auto static_negate = builder.AddInstruction(HloInstruction::CreateUnary(
294 input_shape, HloOpcode::kNegate, data_param_static));
295
296 auto init = builder.AddInstruction(
297 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
298
299 auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
300 ShapeUtil::MakeTupleShape({reduce_shape, reduce_shape}),
301 {dynamic_negate, static_negate}, {init, init}, {1}, GetAddTuple()));
302
303 module_->AddEntryComputation(builder.Build());
304
305 SCOPED_TRACE(module_->ToString());
306 TF_ASSERT_OK(RunInference());
307 EXPECT_EQ(inference_->GetDynamicSize(reduce, {0}, 1), size_param);
308 EXPECT_EQ(inference_->GetDynamicSize(reduce, {1}, 1), size_param);
309 EXPECT_EQ(inference_->GetDynamicSize(reduce, {0}, 0), nullptr);
310 EXPECT_EQ(inference_->GetDynamicSize(reduce, {1}, 0), nullptr);
311 }
312
TEST_F(DynamicDimensionInferenceTest,DotTest)313 TEST_F(DynamicDimensionInferenceTest, DotTest) {
314 auto builder = HloComputation::Builder(TestName());
315 constexpr int xdim = 3;
316 constexpr int ydim = 2;
317 constexpr int zdim = 1;
318 auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
319 auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
320 auto xz_shape = ShapeUtil::MakeShape(F32, {xdim, zdim});
321
322 auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
323 /*parameter_number=*/0, xy_shape, "A"));
324 auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
325 /*parameter_number=*/1, yz_shape, "B"));
326 auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
327 /*parameter_number=*/2, scalar_shape_, "size_param"));
328
329 DotDimensionNumbers dot_dnums;
330 dot_dnums.add_lhs_contracting_dimensions(1);
331 dot_dnums.add_rhs_contracting_dimensions(0);
332 auto dot = builder.AddInstruction(
333 HloInstruction::CreateDot(xz_shape, a_param, b_param, dot_dnums,
334 HloTestBase::DefaultPrecisionConfig(2)));
335
336 module_->AddEntryComputation(builder.Build());
337
338 // Set up dynamic parameter binding for non-contracting dimension.
339 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
340 DynamicParameterBinding::DynamicParameter{2, {}},
341 DynamicParameterBinding::DynamicDimension{0, {}, 0}));
342
343 // Set up binding for contracting dimensions.
344 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
345 DynamicParameterBinding::DynamicParameter{2, {}},
346 DynamicParameterBinding::DynamicDimension{0, {}, 1}));
347 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
348 DynamicParameterBinding::DynamicParameter{2, {}},
349 DynamicParameterBinding::DynamicDimension{1, {}, 0}));
350
351 SCOPED_TRACE(module_->ToString());
352 TF_ASSERT_OK(RunInference());
353 EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size_param);
354 EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
355 }
356
TEST_F(DynamicDimensionInferenceTest,DotTestBatch)357 TEST_F(DynamicDimensionInferenceTest, DotTestBatch) {
358 auto builder = HloComputation::Builder(TestName());
359 auto lhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
360 auto rhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
361 auto output_shape = ShapeUtil::MakeShape(F32, {4, 2, 128, 128});
362
363 auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
364 /*parameter_number=*/0, lhs_shape, "A"));
365 auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
366 /*parameter_number=*/1, rhs_shape, "B"));
367 auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
368 /*parameter_number=*/2, scalar_shape_, "size_param"));
369
370 DotDimensionNumbers dot_dnums;
371 dot_dnums.add_lhs_contracting_dimensions(3);
372 dot_dnums.add_rhs_contracting_dimensions(3);
373 dot_dnums.add_lhs_batch_dimensions(0);
374 dot_dnums.add_lhs_batch_dimensions(2);
375 dot_dnums.add_rhs_batch_dimensions(0);
376 dot_dnums.add_rhs_batch_dimensions(2);
377 auto dot = builder.AddInstruction(
378 HloInstruction::CreateDot(output_shape, a_param, b_param, dot_dnums,
379 HloTestBase::DefaultPrecisionConfig(2)));
380
381 module_->AddEntryComputation(builder.Build());
382
383 // Set up dynamic parameter binding for batch dimension.
384 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
385 DynamicParameterBinding::DynamicParameter{2, {}},
386 DynamicParameterBinding::DynamicDimension{0, {}, 0}));
387
388 SCOPED_TRACE(module_->ToString());
389 TF_ASSERT_OK(RunInference());
390 EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size_param);
391 EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
392 EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 2), nullptr);
393 EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 3), nullptr);
394 }
395
TEST_F(DynamicDimensionInferenceTest,DotTestMultiContracting)396 TEST_F(DynamicDimensionInferenceTest, DotTestMultiContracting) {
397 auto builder = HloComputation::Builder(TestName());
398 auto lhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 8, 64});
399 auto rhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 512});
400 auto output_shape = ShapeUtil::MakeShape(F32, {8, 64, 512});
401
402 auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
403 /*parameter_number=*/0, lhs_shape, "A"));
404 auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
405 /*parameter_number=*/1, rhs_shape, "B"));
406 builder.AddInstruction(HloInstruction::CreateParameter(
407 /*parameter_number=*/2, scalar_shape_, "size_param"));
408
409 DotDimensionNumbers dot_dnums;
410 dot_dnums.add_lhs_contracting_dimensions(0);
411 dot_dnums.add_lhs_contracting_dimensions(1);
412 dot_dnums.add_rhs_contracting_dimensions(0);
413 dot_dnums.add_rhs_contracting_dimensions(1);
414 auto dot = builder.AddInstruction(
415 HloInstruction::CreateDot(output_shape, a_param, b_param, dot_dnums,
416 HloTestBase::DefaultPrecisionConfig(2)));
417
418 module_->AddEntryComputation(builder.Build());
419
420 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
421 DynamicParameterBinding::DynamicParameter{2, {}},
422 DynamicParameterBinding::DynamicDimension{0, {}, 0}));
423
424 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
425 DynamicParameterBinding::DynamicParameter{2, {}},
426 DynamicParameterBinding::DynamicDimension{0, {}, 1}));
427 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
428 DynamicParameterBinding::DynamicParameter{2, {}},
429 DynamicParameterBinding::DynamicDimension{1, {}, 0}));
430
431 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
432 DynamicParameterBinding::DynamicParameter{2, {}},
433 DynamicParameterBinding::DynamicDimension{1, {}, 1}));
434
435 SCOPED_TRACE(module_->ToString());
436 TF_ASSERT_OK(RunInference());
437 // Nothing is dynamic in the output.
438 EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), nullptr);
439 EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
440 EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 2), nullptr);
441 }
442
TEST_F(DynamicDimensionInferenceTest,ConvolutionTest)443 TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
444 auto builder = HloComputation::Builder(TestName());
445 constexpr int xdim = 3;
446 constexpr int ydim = 2;
447 constexpr int zdim = 1;
448 auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
449 auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
450 auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim});
451
452 auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
453 /*parameter_number=*/0, xy_shape, "A"));
454 auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
455 /*parameter_number=*/1, yz_shape, "B"));
456 auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
457 /*parameter_number=*/2, scalar_shape_, "size_param"));
458
459 auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0);
460
461 dnums.set_kernel_input_feature_dimension(0);
462 dnums.set_kernel_output_feature_dimension(1);
463 dnums.set_input_batch_dimension(0);
464 dnums.set_output_batch_dimension(1);
465 dnums.set_output_feature_dimension(0);
466
467 Window window;
468
469 auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
470 zx_shape, a_param, b_param, /*feature_group_count=*/1,
471 /*batch_group_count=*/1, window, dnums,
472 HloTestBase::DefaultPrecisionConfig(2)));
473
474 module_->AddEntryComputation(builder.Build());
475
476 // Set up dynamic parameter binding for non-contracting dimension.
477 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
478 DynamicParameterBinding::DynamicParameter{2, {}},
479 DynamicParameterBinding::DynamicDimension{0, {}, 0}));
480
481 // Set up binding for contracting dimensions.
482 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
483 DynamicParameterBinding::DynamicParameter{2, {}},
484 DynamicParameterBinding::DynamicDimension{0, {}, 1}));
485
486 SCOPED_TRACE(module_->ToString());
487 TF_ASSERT_OK(RunInference());
488 EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 1), size_param);
489 EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 0), nullptr);
490 }
491
TEST_F(DynamicDimensionInferenceTest,TransposeTest)492 TEST_F(DynamicDimensionInferenceTest, TransposeTest) {
493 // Test the ability to trace unmodified dimensions
494 auto builder = HloComputation::Builder(TestName());
495 auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
496 auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 1});
497
498 auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
499 /*parameter_number=*/0, input_shape, "A"));
500 auto* size_param_1 = builder.AddInstruction(HloInstruction::CreateParameter(
501 /*parameter_number=*/1, scalar_shape_, "size_param"));
502 auto* size_param_2 = builder.AddInstruction(HloInstruction::CreateParameter(
503 /*parameter_number=*/2, scalar_shape_, "size_param"));
504 auto* size_param_3 = builder.AddInstruction(HloInstruction::CreateParameter(
505 /*parameter_number=*/3, scalar_shape_, "size_param"));
506
507 auto* transpose = builder.AddInstruction(
508 HloInstruction::CreateTranspose(output_shape, a_param, {2, 1, 0}));
509
510 module_->AddEntryComputation(builder.Build());
511
512 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
513 DynamicParameterBinding::DynamicParameter{1, {}},
514 DynamicParameterBinding::DynamicDimension{0, {}, 0}));
515
516 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
517 DynamicParameterBinding::DynamicParameter{2, {}},
518 DynamicParameterBinding::DynamicDimension{0, {}, 1}));
519
520 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
521 DynamicParameterBinding::DynamicParameter{3, {}},
522 DynamicParameterBinding::DynamicDimension{0, {}, 2}));
523
524 SCOPED_TRACE(module_->ToString());
525 TF_ASSERT_OK(RunInference());
526 EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3);
527 EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_2);
528 EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_1);
529 }
530
TEST_F(DynamicDimensionInferenceTest,NonDescendingTransposeTest)531 TEST_F(DynamicDimensionInferenceTest, NonDescendingTransposeTest) {
532 // Test the ability to trace unmodified dimensions
533 auto builder = HloComputation::Builder(TestName());
534 auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
535 auto output_shape = ShapeUtil::MakeShape(F32, {3, 1, 2});
536
537 auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
538 /*parameter_number=*/0, input_shape, "A"));
539 auto* size_param_1 = builder.AddInstruction(HloInstruction::CreateParameter(
540 /*parameter_number=*/1, scalar_shape_, "size_param"));
541 auto* size_param_2 = builder.AddInstruction(HloInstruction::CreateParameter(
542 /*parameter_number=*/2, scalar_shape_, "size_param"));
543 auto* size_param_3 = builder.AddInstruction(HloInstruction::CreateParameter(
544 /*parameter_number=*/3, scalar_shape_, "size_param"));
545
546 auto* transpose = builder.AddInstruction(
547 HloInstruction::CreateTranspose(output_shape, a_param, {2, 0, 1}));
548
549 module_->AddEntryComputation(builder.Build());
550
551 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
552 DynamicParameterBinding::DynamicParameter{1, {}},
553 DynamicParameterBinding::DynamicDimension{0, {}, 0}));
554
555 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
556 DynamicParameterBinding::DynamicParameter{2, {}},
557 DynamicParameterBinding::DynamicDimension{0, {}, 1}));
558
559 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
560 DynamicParameterBinding::DynamicParameter{3, {}},
561 DynamicParameterBinding::DynamicDimension{0, {}, 2}));
562
563 SCOPED_TRACE(module_->ToString());
564 TF_ASSERT_OK(RunInference());
565 EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3);
566 EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_1);
567 EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_2);
568 }
569
TEST_F(DynamicDimensionInferenceTest,ReshapeTest)570 TEST_F(DynamicDimensionInferenceTest, ReshapeTest) {
571 // Test the ability to trace unmodified reshape dimensions.
572 auto builder = HloComputation::Builder(TestName());
573 auto input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6});
574 auto output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3});
575
576 auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
577 /*parameter_number=*/0, input_shape, "A"));
578 auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
579 /*parameter_number=*/1, scalar_shape_, "size_param"));
580
581 auto* reshape = builder.AddInstruction(
582 HloInstruction::CreateReshape(output_shape, a_param));
583
584 module_->AddEntryComputation(builder.Build());
585
586 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
587 DynamicParameterBinding::DynamicParameter{1, {}},
588 DynamicParameterBinding::DynamicDimension{0, {}, 2}));
589
590 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
591 DynamicParameterBinding::DynamicParameter{1, {}},
592 DynamicParameterBinding::DynamicDimension{0, {}, 3}));
593
594 SCOPED_TRACE(module_->ToString());
595 TF_ASSERT_OK(RunInference());
596 EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 0), nullptr);
597 EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 1), size_param);
598 EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 2), nullptr);
599 EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 3), size_param);
600 EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 4), nullptr);
601 EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 5), nullptr);
602 }
603
TEST_F(DynamicDimensionInferenceTest,ReshapeInferredDimensionTest)604 TEST_F(DynamicDimensionInferenceTest, ReshapeInferredDimensionTest) {
605 // Test the ability to trace inferred dimension when output is bigger than
606 // input.
607 auto builder = HloComputation::Builder(TestName());
608 auto input_shape = ShapeUtil::MakeShape(F32, {4, 5});
609 auto output_shape = ShapeUtil::MakeShape(F32, {1, 4, 5});
610
611 auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
612 /*parameter_number=*/0, input_shape, "A"));
613 builder.AddInstruction(HloInstruction::CreateParameter(
614 /*parameter_number=*/1, scalar_shape_, "size_param"));
615
616 auto* reshape = builder.AddInstruction(HloInstruction::CreateReshape(
617 output_shape, a_param, /*inferred_dimension=*/0));
618
619 module_->AddEntryComputation(builder.Build());
620
621 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
622 DynamicParameterBinding::DynamicParameter{1, {}},
623 DynamicParameterBinding::DynamicDimension{0, {}, 0}));
624
625 SCOPED_TRACE(module_->ToString());
626 TF_ASSERT_OK(RunInference());
627 EXPECT_NE(inference_->GetDynamicSize(reshape, {}, 0), nullptr);
628 }
629
TEST_F(DynamicDimensionInferenceTest,ReshapeTestMajorDimension)630 TEST_F(DynamicDimensionInferenceTest, ReshapeTestMajorDimension) {
631 // Test the ability to trace dimension combining.
632 auto builder = HloComputation::Builder(TestName());
633 auto input_shape = ShapeUtil::MakeShape(F32, {32, 10, 4});
634 auto output_shape = ShapeUtil::MakeShape(F32, {320, 4});
635
636 auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
637 /*parameter_number=*/0, input_shape, "A"));
638
639 builder.AddInstruction(HloInstruction::CreateParameter(
640 /*parameter_number=*/1, scalar_shape_, "size_param"));
641
642 auto* reshape = builder.AddInstruction(
643 HloInstruction::CreateReshape(output_shape, a_param));
644
645 module_->AddEntryComputation(builder.Build());
646
647 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
648 DynamicParameterBinding::DynamicParameter{1, {}},
649 DynamicParameterBinding::DynamicDimension{0, {}, 0}));
650
651 SCOPED_TRACE(module_->ToString());
652 Status status = RunInference();
653 EXPECT_NE(inference_->GetDynamicSize(reshape, {}, 0), nullptr);
654 }
655
TEST_F(DynamicDimensionInferenceTest,ReshapeIntoScalar)656 TEST_F(DynamicDimensionInferenceTest, ReshapeIntoScalar) {
657 // Test the ability to a reshape into scalar.
658 auto builder = HloComputation::Builder(TestName());
659 auto input_shape = ShapeUtil::MakeShape(F32, {1});
660 auto output_shape = ShapeUtil::MakeShape(F32, {});
661
662 auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
663 /*parameter_number=*/0, input_shape, "A"));
664
665 builder.AddInstruction(HloInstruction::CreateParameter(
666 /*parameter_number=*/1, scalar_shape_, "size_param"));
667
668 builder.AddInstruction(HloInstruction::CreateReshape(output_shape, a_param));
669
670 module_->AddEntryComputation(builder.Build());
671
672 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
673 DynamicParameterBinding::DynamicParameter{1, {}},
674 DynamicParameterBinding::DynamicDimension{0, {}, 0}));
675
676 SCOPED_TRACE(module_->ToString());
677 TF_CHECK_OK(RunInference());
678 }
679
TEST_F(DynamicDimensionInferenceTest,GatherTest)680 TEST_F(DynamicDimensionInferenceTest, GatherTest) {
681 const std::string hlo_text = R"(
682 HloModule TensorFlowGatherV2
683
684 ENTRY main {
685 operand = s32[20,10]{1,0} parameter(0)
686 indices = s32[32,20] parameter(1)
687 dynamic_size = s32[] parameter(2)
688 ROOT gather = s32[32,20,10]{2,1,0} gather(%operand, %indices),
689 offset_dims={2},
690 collapsed_slice_dims={0},
691 start_index_map={0},
692 index_vector_dim=2,
693 slice_sizes={1,10}
694 }
695 )";
696
697 TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
698 TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
699 DynamicParameterBinding::DynamicParameter{2, {}},
700 DynamicParameterBinding::DynamicDimension{1, {}, 0}));
701 SCOPED_TRACE(module_->ToString());
702 TF_ASSERT_OK(RunInference());
703 EXPECT_EQ(inference_->GetDynamicSize(
704 module_->entry_computation()->root_instruction(), {}, 0),
705 module_->entry_computation()->parameter_instruction(2));
706 }
707
TEST_F(DynamicDimensionInferenceTest, BroadcastTest) {
  // Broadcasting f32[2] into dimension 1 of f32[3,2,4]: the operand's dynamic
  // dimension 0 must surface as dimension 1 of the broadcast, while the newly
  // introduced dimensions 0 and 2 stay static.
  auto builder = HloComputation::Builder(TestName());
  const auto operand_shape = ShapeUtil::MakeShape(F32, {2});
  const auto broadcast_shape = ShapeUtil::MakeShape(F32, {3, 2, 4});

  auto* operand = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, operand_shape, "A"));
  auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  // Operand dimension 0 maps to result dimension 1.
  auto* broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      broadcast_shape, operand, /*broadcast_dimensions=*/{1}));

  module_->AddEntryComputation(builder.Build());

  // "size_param" carries the runtime size of dimension 0 of "A".
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 1), size_param);
  EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 2), nullptr);
}
734
TEST_F(DynamicDimensionInferenceTest, WhileTest) {
  // Test the ability to trace into while loops. The dynamic sizes bound to
  // both elements of the loop's tuple operand must be visible inside the body
  // and on the elements of the loop result.
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  auto tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});

  // Body:
  //
  //       Param
  //        |  |
  //     GTE1  GTE2
  //        |  |
  //         ADD
  //          |
  //    Tuple(ADD, ADD)
  auto body_builder = HloComputation::Builder("body");
  auto body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  auto gte_0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(input_shape, body_param, 0));
  auto gte_1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(input_shape, body_param, 1));
  auto add = body_builder.AddInstruction(
      HloInstruction::CreateBinary(input_shape, HloOpcode::kAdd, gte_0, gte_1));
  body_builder.AddInstruction(HloInstruction::CreateTuple({add, add}));

  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  // Condition: returns the constant `false`.
  auto cond_builder = HloComputation::Builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Entry:
  //
  //   Param
  //     |
  //   While
  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, tuple_shape, "A"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));
  builder.AddInstruction(
      HloInstruction::CreateWhile(tuple_shape, condition, body, a_param));

  module_->AddEntryComputation(builder.Build());

  // Dimension 0 of both tuple elements of "A" is dynamic, sized by
  // "size_param".
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {0}, 0}));

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {1}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  // Inference replaces the while instruction; find the new one.
  HloInstruction* while_hlo = nullptr;
  for (HloInstruction* inst : module_->entry_computation()->instructions()) {
    if (inst->opcode() == HloOpcode::kWhile) {
      while_hlo = inst;
    }
  }
  ASSERT_NE(while_hlo, nullptr);
  // The original while tuple has 2 elements. Each of the two dynamic
  // dimensions appends one size element, giving 4 elements (the two
  // identical size operands are not deduplicated).
  EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 4);

  HloInstruction* add_inst = nullptr;
  for (HloInstruction* inst : while_hlo->while_body()->instructions()) {
    if (inst->opcode() == HloOpcode::kAdd) {
      add_inst = inst;
    }
  }
  // Fatal: the lookup below is meaningless without the add instruction.
  ASSERT_NE(add_inst, nullptr);
  EXPECT_NE(inference_->GetDynamicSize(add_inst, {}, 0), nullptr);
  EXPECT_NE(inference_->GetDynamicSize(
                module_->entry_computation()->root_instruction(), {0}, 0),
            nullptr);
  EXPECT_NE(inference_->GetDynamicSize(
                module_->entry_computation()->root_instruction(), {1}, 0),
            nullptr);
}
819
TEST_F(DynamicDimensionInferenceTest, ConditionalInputTest) {
  // Test the ability to trace into conditionals whose branches take operands
  // of different shapes. The dynamic sizes bound on each branch operand must
  // be visible inside that branch and on the conditional's result.
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  // The two branches take operands of different tuple shapes.
  auto tuple_shape_1 = ShapeUtil::MakeTupleShape({input_shape});
  auto tuple_shape_2 = ShapeUtil::MakeTupleShape({input_shape, input_shape});
  auto tuple_shape_3 =
      ShapeUtil::MakeTupleShape({input_shape, input_shape, input_shape});

  // true branch (operand: tuple_shape_2, result: tuple_shape_1):
  //
  //        Param
  //        |   |
  //   GTE(0)   GTE(1)
  //        |   |
  //      Tuple(ADD)
  auto true_builder = HloComputation::Builder("true");
  {
    auto true_param = true_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tuple_shape_2, "param"));
    auto gte_0 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(input_shape, true_param, 0));
    auto gte_1 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(input_shape, true_param, 1));
    auto add = true_builder.AddInstruction(HloInstruction::CreateBinary(
        input_shape, HloOpcode::kAdd, gte_0, gte_1));
    true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
  }
  HloComputation* true_branch =
      module_->AddEmbeddedComputation(true_builder.Build());

  // false branch (operand: tuple_shape_3, result: tuple_shape_1; tuple
  // element 0 is never read):
  //
  //        Param
  //        |   |
  //   GTE(1)   GTE(2)
  //        |   |
  //      Tuple(ADD)
  auto false_builder = HloComputation::Builder("false");
  {
    auto false_param = false_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tuple_shape_3, "param"));
    auto gte_0 = false_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(input_shape, false_param, 1));
    auto gte_1 = false_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(input_shape, false_param, 2));
    auto add = false_builder.AddInstruction(HloInstruction::CreateBinary(
        input_shape, HloOpcode::kAdd, gte_0, gte_1));
    false_builder.AddInstruction(HloInstruction::CreateTuple({add}));
  }
  HloComputation* false_branch =
      module_->AddEmbeddedComputation(false_builder.Build());

  // Entry:
  //
  //  Param(pred)  Param(tuple_2)  Param(tuple_3)
  //       |             |               |
  //       +---------Conditional---------+
  auto* pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeScalarShape(PRED), "pred"));

  auto* tuple_2_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, tuple_shape_2, "tuple_2_param"));
  auto* tuple_3_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, tuple_shape_3, "tuple_3_param"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/3, scalar_shape_, "size_param"));
  builder.AddInstruction(HloInstruction::CreateConditional(
      tuple_shape_1, pred_param, tuple_2_param, true_branch, tuple_3_param,
      false_branch));

  module_->AddEntryComputation(builder.Build());

  // Parameter 3 ("size_param") sizes dimension 0 of every tuple element a
  // branch reads: elements {0} and {1} of parameter 1 and elements {1} and
  // {2} of parameter 2.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{1, {0}, 0}));
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{1, {1}, 0}));
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{2, {1}, 0}));
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{2, {2}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  // Inference replaces the conditional instruction; find the new one.
  HloInstruction* conditional_hlo = nullptr;
  for (HloInstruction* inst : module_->entry_computation()->instructions()) {
    if (inst->opcode() == HloOpcode::kConditional) {
      conditional_hlo = inst;
    }
  }
  ASSERT_NE(conditional_hlo, nullptr);
  // The original conditional result has 1 tuple element. The dynamic size
  // passed out of the branch computations appends another element.
  EXPECT_EQ(conditional_hlo->shape().tuple_shapes_size(), 2);

  HloInstruction* add_true_branch = nullptr;
  for (HloInstruction* inst :
       conditional_hlo->true_computation()->instructions()) {
    if (inst->opcode() == HloOpcode::kAdd) {
      add_true_branch = inst;
    }
  }
  // Fatal: the lookup below is meaningless without the add instruction.
  ASSERT_NE(add_true_branch, nullptr);
  EXPECT_NE(inference_->GetDynamicSize(add_true_branch, {}, 0), nullptr);

  HloInstruction* add_false_branch = nullptr;
  for (HloInstruction* inst :
       conditional_hlo->false_computation()->instructions()) {
    if (inst->opcode() == HloOpcode::kAdd) {
      add_false_branch = inst;
    }
  }
  ASSERT_NE(add_false_branch, nullptr);
  EXPECT_NE(inference_->GetDynamicSize(add_false_branch, {}, 0), nullptr);

  EXPECT_NE(inference_->GetDynamicSize(conditional_hlo, {0}, 0), nullptr);
}
942
TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) {
  // A reduce-window whose window leaves dimension 0 untouched (size 1,
  // stride 1) must forward the operand's dynamic dimension 0 to the result.
  auto builder = HloComputation::Builder(TestName());
  const auto operand_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  const auto result_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});

  // Appends one window dimension with no padding and no dilation.
  Window window;
  auto append_window_dim = [&window](int64_t size, int64_t stride) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(size);
    dim->set_stride(stride);
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  };
  append_window_dim(/*size=*/1, /*stride=*/1);  // Dimension 0: unchanged.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Dimension 1: halved.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Dimension 2: halved.

  auto* operand = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, operand_shape, "A"));
  auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  auto* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));

  auto* reduce_window = builder.AddInstruction(
      HloInstruction::CreateReduceWindow(result_shape, operand, zero, window,
                                         GetAdd()));

  module_->AddEntryComputation(builder.Build());

  // "size_param" carries the runtime size of dimension 0 of "A".
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param);
}
992
TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) {
  // Select-and-scatter over a window that leaves dimension 0 untouched. When
  // dimension 0 of both the operand ("A") and the source ("B") is dynamic,
  // the result's dimension 0 must carry that dynamic size.
  auto builder = HloComputation::Builder(TestName());
  const auto operand_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  const auto source_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});

  // Appends one window dimension with no padding and no dilation.
  Window window;
  auto append_window_dim = [&window](int64_t size, int64_t stride) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(size);
    dim->set_stride(stride);
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  };
  append_window_dim(/*size=*/1, /*stride=*/1);  // Dimension 0: unchanged.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Dimension 1: halved.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Dimension 2: halved.

  auto* operand = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, operand_shape, "A"));
  auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));
  auto* source = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, source_shape, "B"));

  auto* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));

  auto* sns = builder.AddInstruction(HloInstruction::CreateSelectAndScatter(
      operand_shape, operand, GetGe(), window, source, zero, GetAdd()));

  module_->AddEntryComputation(builder.Build());

  // "size_param" sizes dimension 0 of parameter 0 ("A") and parameter 2
  // ("B"), in that order.
  for (int64_t data_param_num : {0, 2}) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{1, {}},
        DynamicParameterBinding::DynamicDimension{data_param_num, {}, 0}));
  }

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(sns, {}, 0), size_param);
}
1046
TEST_F(DynamicDimensionInferenceTest, ConcateTest) {
  // Concatenates f32[5,7] and f32[5,8] along dimension 1. Dimension 0 of both
  // operands is dynamic with the same size parameter, so the result's
  // dimension 0 must carry that size.
  auto builder = HloComputation::Builder(TestName());

  auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param_1"));
  auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {5, 8}), "data_param_2"));
  auto size_param = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_shape_, "size_param"));

  auto* concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
      ShapeUtil::MakeShape(F32, {5, 15}), {lhs, rhs}, /*dimension=*/1));

  module_->AddEntryComputation(builder.Build());

  // Parameter 2 sizes dimension 0 of parameters 0 and 1, in that order.
  for (int64_t operand_param_num : {0, 1}) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{2, {}},
        DynamicParameterBinding::DynamicDimension{operand_param_num, {}, 0}));
  }

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(concat, {}, 0), size_param);
}
1074
TEST_F(DynamicDimensionInferenceTest, SliceTest) {
  // A full-extent slice ([0:5], [0:7], unit strides) must keep the dynamic
  // size of the operand's dimension 1.
  auto builder = HloComputation::Builder(TestName());
  const auto data_shape = ShapeUtil::MakeShape(F32, {5, 7});

  auto data_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data_param"));
  auto size_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  auto* slice = builder.AddInstruction(HloInstruction::CreateSlice(
      data_shape, data_param, /*start_indices=*/{0, 0},
      /*limit_indices=*/{5, 7}, /*strides=*/{1, 1}));

  module_->AddEntryComputation(builder.Build());

  // "size_param" carries the runtime size of dimension 1 of "data_param".
  const DynamicParameterBinding::DynamicParameter size_source{1, {}};
  const DynamicParameterBinding::DynamicDimension data_dim1{0, {}, 1};
  TF_CHECK_OK(
      module_->dynamic_parameter_binding().Bind(size_source, data_dim1));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 1), size_param);
}
1096
TEST_F(DynamicDimensionInferenceTest, DynamicSliceTest) {
  // A dynamic-slice that keeps the full extent (5) of the dynamic dimension 0
  // must keep that dimension's dynamic size.
  auto builder = HloComputation::Builder(TestName());

  auto data_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  auto size_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // Parameters 2 and 3: one scalar start index per operand dimension.
  std::vector<HloInstruction*> start_indices;
  for (int64_t param_num = 2; param_num <= 3; ++param_num) {
    start_indices.push_back(
        builder.AddInstruction(HloInstruction::CreateParameter(
            param_num, ShapeUtil::MakeShape(S32, {}), "slice_indices")));
  }

  auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      ShapeUtil::MakeShape(F32, {5, 1}), data_param, start_indices,
      /*slice_sizes=*/{5, 1}));

  module_->AddEntryComputation(builder.Build());

  // "size_param" carries the runtime size of dimension 0 of "data_param".
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 0), size_param);
}
1124
TEST_F(DynamicDimensionInferenceTest, SortTest) {
  // Sorting along dimension 1 must keep the dynamic size of dimension 0.
  auto builder = HloComputation::Builder(TestName());
  const auto data_shape = ShapeUtil::MakeShape(F32, {5, 7});

  auto data_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data_param"));
  auto size_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // Comparator over two f32 scalars that returns the constant `false`.
  const auto f32_scalar = ShapeUtil::MakeShape(F32, {});
  auto compare_builder = HloComputation::Builder("condition");
  compare_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param1"));
  compare_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "param2"));
  compare_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* compare =
      module_->AddEmbeddedComputation(compare_builder.Build());

  auto* sort = builder.AddInstruction(HloInstruction::CreateSort(
      data_shape, /*dimension=*/1, {data_param}, compare,
      /*is_stable=*/false));

  module_->AddEntryComputation(builder.Build());

  // "size_param" carries the runtime size of dimension 0 of "data_param".
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(sort, {}, 0), size_param);
}
1156
TEST_F(DynamicDimensionInferenceTest, MultiValueSortTest) {
  // A two-operand sort along dimension 1 yields a tuple; each tuple element
  // must keep the dynamic size of dimension 0.
  auto builder = HloComputation::Builder(TestName());

  auto shape = ShapeUtil::MakeShape(F32, {5, 7});

  auto data_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "data_param"));
  auto size_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // Comparator over four f32 scalars (one pair per sorted operand) that
  // returns the constant `false`.
  const auto f32_scalar = ShapeUtil::MakeShape(F32, {});
  auto compare_builder = HloComputation::Builder("condition");
  int64_t compare_param_num = 0;
  for (const char* param_name : {"param1", "param2", "param3", "param4"}) {
    compare_builder.AddInstruction(HloInstruction::CreateParameter(
        compare_param_num++, f32_scalar, param_name));
  }
  compare_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* compare =
      module_->AddEmbeddedComputation(compare_builder.Build());

  auto* sort = builder.AddInstruction(HloInstruction::CreateSort(
      ShapeUtil::MakeTupleShape({shape, shape}), /*dimension=*/1,
      {data_param, data_param}, compare,
      /*is_stable=*/false));

  module_->AddEntryComputation(builder.Build());

  // "size_param" carries the runtime size of dimension 0 of "data_param".
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(sort, {0}, 0), size_param);
  EXPECT_EQ(inference_->GetDynamicSize(sort, {1}, 0), size_param);
}
1196
TEST_F(DynamicDimensionInferenceTest, DynamicSliceSingleElementTest) {
  // Taking a slice of size 1 along the dynamic dimension 0 ends that
  // dimension's dynamism: the result has no dynamic size there.
  auto builder = HloComputation::Builder(TestName());

  auto data_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // Parameters 2 and 3: one scalar start index per operand dimension.
  std::vector<HloInstruction*> start_indices;
  for (int64_t param_num = 2; param_num <= 3; ++param_num) {
    start_indices.push_back(
        builder.AddInstruction(HloInstruction::CreateParameter(
            param_num, ShapeUtil::MakeShape(S32, {}), "slice_indices")));
  }

  auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      ShapeUtil::MakeShape(F32, {1, 1}), data_param, start_indices,
      /*slice_sizes=*/{1, 1}));

  module_->AddEntryComputation(builder.Build());

  // "size_param" carries the runtime size of dimension 0 of "data_param".
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 0), nullptr);
}
1226
TEST_F(DynamicDimensionInferenceTest, InfersCustomOp) {
  // RunInference is given a custom-call handler. Verify that the handler is
  // invoked with a custom-call instruction and a non-null inference object.
  auto builder = HloComputation::Builder(TestName());

  auto data_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  builder.AddInstruction(HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {1, 1}), {data_param}, "MyCustomOp", ""));

  module_->AddEntryComputation(builder.Build());

  // "size_param" carries the runtime size of dimension 0 of "data_param".
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  bool custom_call_seen = false;
  auto custom_call_handler = [&custom_call_seen](
                                 HloInstruction* hlo,
                                 DynamicDimensionInference* inference) {
    CHECK(inference != nullptr);
    CHECK(Cast<HloCustomCallInstruction>(hlo) != nullptr);
    custom_call_seen = true;
    return OkStatus();
  };
  TF_ASSERT_OK(RunInference(custom_call_handler));

  EXPECT_TRUE(custom_call_seen);
}
1257
TEST_F(DynamicDimensionInferenceTest, DynamicReshapeOp) {
  // A dynamic-reshape of [<=9] into [3, <=3]. The size operand of the dynamic
  // output dimension (1) must become its dynamic size, and the static output
  // dimension (0) must have none.
  auto builder = HloComputation::Builder(TestName());
  auto data = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {9}), "data_input"));
  auto six = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(6)));
  // Bounded input of shape [<=9] whose runtime size is 6.
  auto bounded_data =
      builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
          ShapeUtil::MakeShape(F32, {9}, {true}), data, six, 0));
  auto cols_size = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(S32, {}), "size_param"));
  auto rows_size = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(3)));

  auto reshaped = builder.AddInstruction(HloInstruction::CreateDynamicReshape(
      ShapeUtil::MakeShape(F32, {3, 3}, {false, true}), bounded_data,
      {rows_size, cols_size}));

  module_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reshaped, {}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(reshaped, {}, 1), cols_size);
}
1286
TEST_F(DynamicDimensionInferenceTest, ReshapeOpWithMultipleDynamicDimensions) {
  // Reshaping [<=9, <=2] into [<=9, 1, <=2] only inserts a unit dimension, so
  // each original dynamic size must follow its dimension: 6 to output
  // dimension 0 and 1 to output dimension 2, with none on the new dimension 1.
  auto builder = HloComputation::Builder(TestName());
  auto static_data = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {9, 2}), "data_input"));
  auto six = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(6)));
  auto rows_dynamic =
      builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
          ShapeUtil::MakeShape(F32, {9, 2}, {true, false}), static_data, six,
          0));
  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
  auto both_dynamic =
      builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
          ShapeUtil::MakeShape(F32, {9, 2}, {true, true}), rows_dynamic, one,
          1));

  auto reshaped = builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {9, 1, 2}, {true, false, true}),
      both_dynamic));

  module_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reshaped, {}, 0), six);
  EXPECT_EQ(inference_->GetDynamicSize(reshaped, {}, 1), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(reshaped, {}, 2), one);
}
1312
TEST_F(DynamicDimensionInferenceTest, HandleMapInDynamicDimensionInference) {
  // Runs inference over a while-body computation in which a `map` consumes a
  // dynamic-slice of a bounded (c128[<=1]) tuple element. The test only
  // checks that inference succeeds on this unverified module.
  constexpr char kWhileBodyWithMap[] = R"(
HloModule test_module

%scatter-combiner.285 (p0.286: c128[], p1.287: c128[]) -> c128[] {
  %p0.286 = c128[] parameter(0)
  %p1.287 = c128[] parameter(1)
  ROOT %add.288 = c128[] add(c128[] %p0.286, c128[] %p1.287)
}

%while_body {
  %reshape.8 = s32[] parameter(4)
  %reshape.7 = c128[1]{0} parameter(3)
  %reduce = pred[] parameter(2)
  %concatenate = s32[1]{0} parameter(1)
  %slice.4 = s32[1]{0} slice(s32[1]{0} %concatenate), slice={[0 : 1]}
  %broadcast.7 = pred[1]{0} broadcast(pred[] %reduce), dimensions={}
  %param.1 = (s32[],c128[<=1]{0},s32[1]{0},c128[1]{0}) parameter(0)
  %get-tuple-element.2 = c128[<=1]{0} get-tuple-element((s32[],c128[<=1]{0},s32[1]{0},c128[1]{0}) %param.1), index=1
  %dynamic-slice.2 = c128[1]{0} dynamic-slice(c128[<=1]{0} %get-tuple-element.2,s32[] %reshape.8), dynamic_slice_sizes={1}
  %map = c128[1]{0} map(c128[1]{0} %dynamic-slice.2,c128[1]{0} %reshape.7), dimensions={0}, to_apply=%scatter-combiner.285
  %select = c128[1]{0} select(pred[1]{0} %broadcast.7,c128[1]{0} %map,c128[1]{0} %dynamic-slice.2)
  %reshape.9 = s32[] reshape(s32[1]{0} %slice.4)
  %dynamic-update-slice = c128[<=1]{0} dynamic-update-slice(c128[<=1]{0} %get-tuple-element.2,c128[1]{0} %select,s32[] %reshape.9)
})";
  TF_ASSERT_OK_AND_ASSIGN(module_,
                          ParseAndReturnUnverifiedModule(kWhileBodyWithMap));
  TF_ASSERT_OK(RunInference());
}
1341
TEST_F(DynamicDimensionInferenceTest, RuntimeShapeCheck) {
  // With ShapeCheckMode::kRuntime, adding two operands whose dynamic sizes
  // come from different parameters should make inference:
  //   1. compare the corresponding sizes for equality,
  //   2. AND the comparisons together, and
  //   3. hand the predicate to the assertion generator.
  // Here the generator wraps the predicate in an "__xla__assert" custom call.
  constexpr char kAddOfDynamicOperands[] = R"(
HloModule module

ENTRY computation {
  a = f32[20,20] parameter(0)
  a_size_1 = s32[] parameter(1)
  a_size_2 = s32[] parameter(2)
  b = f32[20,20] parameter(3)
  b_size_1 = s32[] parameter(4)
  b_size_2 = s32[] parameter(5)
  ROOT f = add(a, b)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(module_,
                          ParseAndReturnVerifiedModule(kAddOfDynamicOperands));

  // Binds parameter `size_param` as the runtime size of dimension `dim` of
  // parameter `data_param`.
  auto bind_size = [this](int64_t size_param, int64_t data_param,
                          int64_t dim) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{size_param, {}},
        DynamicParameterBinding::DynamicDimension{data_param, {}, dim}));
  };
  bind_size(/*size_param=*/1, /*data_param=*/0, /*dim=*/0);  // a_size_1
  bind_size(/*size_param=*/2, /*data_param=*/0, /*dim=*/1);  // a_size_2
  bind_size(/*size_param=*/4, /*data_param=*/3, /*dim=*/0);  // b_size_1
  bind_size(/*size_param=*/5, /*data_param=*/3, /*dim=*/1);  // b_size_2

  // Wraps each runtime constraint in an "__xla__assert" custom call.
  auto add_assert_call = [](HloInstruction* constraint) {
    constraint->parent()->AddInstruction(HloInstruction::CreateCustomCall(
        ShapeUtil::MakeTokenShape(), {constraint},
        /*custom_call_target=*/"__xla__assert",
        /*opaque=*/std::string{}, API_VERSION_STATUS_RETURNING));
  };
  TF_ASSERT_OK(RunInference(
      /*handler=*/nullptr, DynamicDimensionInference::ShapeCheckMode::kRuntime,
      /*assertion_generator=*/add_assert_call));

  constexpr char kExpectedAssertion[] = R"(
// CHECK: compare = pred[] compare(s32[] %a_size_1, s32[] %b_size_1), direction=EQ
// CHECK: compare.5 = pred[] compare(s32[] %a_size_2, s32[] %b_size_2), direction=EQ
// CHECK: and.2 = pred[] and(pred[] %compare, pred[] %compare.5)
// CHECK: custom-call(pred[] %and.2), custom_call_target="__xla__assert"
)";
  StatusOr<bool> check_passed =
      RunFileCheck(module_->ToString({}), kExpectedAssertion);
  TF_ASSERT_OK(check_passed.status());
  EXPECT_TRUE(*check_passed);
}
1390
1391 } // namespace
1392 } // namespace xla
1393