1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/value_inference.h"
17
18 #include <memory>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/strings/match.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/client/client_library.h"
25 #include "tensorflow/compiler/xla/client/global_data.h"
26 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
27 #include "tensorflow/compiler/xla/client/lib/prng.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/client/xla_computation.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/statusor.h"
35 #include "tensorflow/compiler/xla/test.h"
36 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
37 #include "tensorflow/compiler/xla/tests/test_macros.h"
38 #include "tensorflow/compiler/xla/tests/test_utils.h"
39 #include "tensorflow/compiler/xla/xla_data.pb.h"
40 #include "tensorflow/core/lib/core/status_test_util.h"
41 #include "tensorflow/core/platform/status.h"
42 #include "tensorflow/core/platform/statusor.h"
43
44 namespace xla {
45 namespace {
46
47 class ValueInferenceTest : public ::testing::Test {
48 public:
TestName() const49 std::string TestName() const {
50 return ::testing::UnitTest::GetInstance()->current_test_info()->name();
51 }
52 };
53
54 class DynamismInferenceTest : public ValueInferenceTest {
55 public:
DynamismInferenceTest(se::Platform * platform=nullptr)56 explicit DynamismInferenceTest(se::Platform* platform = nullptr)
57 : platform_(platform) {}
58
ComputeDynamismLiteral(XlaOp operand,XlaBuilder * builder,Layout * output_layout=nullptr)59 StatusOr<Literal> ComputeDynamismLiteral(XlaOp operand, XlaBuilder* builder,
60 Layout* output_layout = nullptr) {
61 TF_RETURN_IF_ERROR(builder->first_error());
62 ValueInference value_inference(builder);
63 TF_ASSIGN_OR_RETURN(auto literal_slice,
64 value_inference.AnalyzeIsDynamic(operand));
65 return literal_slice.Clone();
66 }
67
ComputeDynamismScalar(XlaOp operand,XlaBuilder * builder,ShapeIndex index={})68 StatusOr<bool> ComputeDynamismScalar(XlaOp operand, XlaBuilder* builder,
69 ShapeIndex index = {}) {
70 TF_ASSIGN_OR_RETURN(auto literal,
71 ComputeDynamismLiteral(operand, builder, nullptr));
72 return literal.Get<bool>({}, index);
73 }
74
75 se::Platform* platform_;
76 };
77
// A compile-time constant never carries dynamism.
TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
  XlaBuilder builder(TestName());
  auto constant = ConstantR0<int32_t>(&builder, 42);

  StatusOr<bool> is_dynamic = ComputeDynamismScalar(constant, &builder);
  ASSERT_TRUE(is_dynamic.ok()) << is_dynamic.status();
  EXPECT_FALSE(is_dynamic.ValueOrDie());
}
87
// Iota produces values that are fully determined at compile time, so no
// element of its output should be reported as dynamic.
TEST_F(DynamismInferenceTest, Iota) {
  XlaBuilder b(TestName());
  auto computation = Iota(&b, S32, 2);
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism,
                          ComputeDynamismLiteral(computation, &b));
  // Check every element of the iota, not just the first one.
  EXPECT_FALSE(dynamism.Get<bool>({0}));
  EXPECT_FALSE(dynamism.Get<bool>({1}));
}
96
// A tuple keeps per-element dynamism: the constant slot is static and the
// parameter slot is dynamic.
TEST_F(DynamismInferenceTest, TupleSimple) {
  XlaBuilder builder(TestName());
  auto static_elem = ConstantR0<int32_t>(&builder, 42);
  auto dynamic_elem =
      Parameter(&builder, 0, ShapeUtil::MakeScalarShape(S32), "p0");
  auto tuple = Tuple(&builder, {static_elem, dynamic_elem});

  EXPECT_FALSE(ComputeDynamismScalar(tuple, &builder, {0}).ValueOrDie());
  EXPECT_TRUE(ComputeDynamismScalar(tuple, &builder, {1}).ValueOrDie());
}
106
// Unpacking a tuple with GetTupleElement and re-packing the pieces must not
// change which elements are dynamic.
TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
  XlaBuilder builder(TestName());
  auto constant = ConstantR0<int32_t>(&builder, 42);
  auto param = Parameter(&builder, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  auto source = Tuple(&builder, {constant, param});
  auto repacked = Tuple(&builder, {GetTupleElement(source, 0),
                                   GetTupleElement(source, 1)});

  EXPECT_FALSE(ComputeDynamismScalar(repacked, &builder, {0}).ValueOrDie());
  EXPECT_TRUE(ComputeDynamismScalar(repacked, &builder, {1}).ValueOrDie());
}
119
// A parameter feeds both the Select predicate and one of its branches; the
// selected value is dynamic.
TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
  XlaBuilder builder(TestName());
  auto constant = ConstantR0<int32_t>(&builder, 42);
  auto param = Parameter(&builder, 0, ShapeUtil::MakeScalarShape(S32), "p0");
  auto is_equal = Eq(constant, param);
  auto selected = Select(is_equal, param, constant);
  EXPECT_TRUE(ComputeDynamismScalar(selected, &builder, {}).ValueOrDie());
}
128
// The sum of a parameter is used both in the Select predicate and as a branch
// value; the selected value is dynamic.
TEST_F(DynamismInferenceTest, ReduceUsedTwice) {
  XlaBuilder builder(TestName());
  auto constant = ConstantR0<int32_t>(&builder, 42);
  auto param = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
  auto init = ConstantR0<int32_t>(&builder, 0);
  XlaComputation add_s32 = CreateScalarAddComputation(S32, &builder);
  auto sum = Reduce(param, init, add_s32, {0});
  auto is_equal = Eq(constant, sum);
  auto selected = Select(is_equal, sum, constant);
  EXPECT_TRUE(ComputeDynamismScalar(selected, &builder, {}).ValueOrDie());
}
140
// A variadic reduce over rows keeps row-wise dynamism: the row coming from a
// constant reduces to a static value, the row coming from a parameter to a
// dynamic one.
TEST_F(DynamismInferenceTest, VariadicReduce) {
  XlaBuilder b(TestName());
  auto c = ConstantR2<int32_t>(&b, {{0, 0}});
  auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {1, 2}), "p0");
  // half_dynamic[0] (the constant row) is static; half_dynamic[1] (the
  // parameter row) is dynamic.
  auto half_dynamic = ConcatInDim(&b, {c, p}, 0);
  // Reducer for a two-operand reduce: takes four scalars and returns their sum
  // as both outputs.
  XlaBuilder reduce_add("reduce_add");
  auto p0 = Parameter(&reduce_add, 0, ShapeUtil::MakeScalarShape(S32), "p");
  auto p1 = Parameter(&reduce_add, 1, ShapeUtil::MakeScalarShape(S32), "p");
  auto p2 = Parameter(&reduce_add, 2, ShapeUtil::MakeScalarShape(S32), "p");
  auto p3 = Parameter(&reduce_add, 3, ShapeUtil::MakeScalarShape(S32), "p");
  auto reduce_result = p0;
  reduce_result = Add(reduce_result, p1);
  reduce_result = Add(reduce_result, p2);
  reduce_result = Add(reduce_result, p3);
  Tuple(&reduce_add, {reduce_result, reduce_result});
  auto init = ConstantR0<int32_t>(&b, 0);
  auto variadic_reduce = Reduce(&b, {half_dynamic, half_dynamic}, {init, init},
                                reduce_add.Build().value(), {1});
  auto result = GetTupleElement(variadic_reduce, 0);

  // Run the analysis once and inspect both elements.
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism, ComputeDynamismLiteral(result, &b));
  // result[0] should be static; result[1] should be dynamic.
  EXPECT_FALSE(dynamism.Get<bool>({0}));
  EXPECT_TRUE(dynamism.Get<bool>({1}));
}
166
// A Select whose selector is half constant and half parameter yields a result
// whose dynamism follows the selector element-wise.
TEST_F(DynamismInferenceTest, DynamicSelectorWithMixedValues) {
  XlaBuilder b(TestName());
  auto constant_pred = ConstantR1<bool>(&b, {true});
  auto dynamic_pred = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {1}), "p0");
  auto concat = ConcatInDim(&b, {constant_pred, dynamic_pred}, 0);
  auto constant_values = ConstantR1<bool>(&b, {true, true});
  auto result = Select(concat, constant_values, constant_values);
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism, ComputeDynamismLiteral(result, &b));
  // First result is static (selector is constant, both values are constant).
  EXPECT_FALSE(dynamism.Get<bool>({0}));
  // Second result is dynamic (selector is dynamic), even though both candidate
  // values are the same constant.
  EXPECT_TRUE(dynamism.Get<bool>({1}));
}
180
// Concatenating a constant with a parameter and then slicing each element back
// out (reshaped to a scalar) preserves per-element dynamism.
TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
  XlaBuilder builder(TestName());
  auto constant = ConstantR0<int32_t>(&builder, 42);
  auto param = Parameter(&builder, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  auto concat = ConcatScalars(&builder, {constant, param});
  auto first = Reshape(SliceInDim(concat, /*start_index=*/0, /*limit_index=*/1,
                                  /*stride=*/1, /*dimno=*/0),
                       {});
  auto second = Reshape(SliceInDim(concat, /*start_index=*/1, /*limit_index=*/2,
                                   /*stride=*/1, /*dimno=*/0),
                        {});
  auto pair = Tuple(&builder, {first, second});

  EXPECT_FALSE(ComputeDynamismScalar(pair, &builder, {0}).ValueOrDie());
  EXPECT_TRUE(ComputeDynamismScalar(pair, &builder, {1}).ValueOrDie());
}
195
// Runtime parameters are always reported as dynamic.
TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
  XlaBuilder builder(TestName());
  auto param = Parameter(&builder, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  StatusOr<bool> is_dynamic = ComputeDynamismScalar(param, &builder);
  ASSERT_TRUE(is_dynamic.ok()) << is_dynamic.status();
  EXPECT_TRUE(is_dynamic.ValueOrDie());
}
205
// A unary op (Neg) inherits the dynamism of its operand.
TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
  XlaBuilder builder(TestName());
  auto constant = ConstantR0<int32_t>(&builder, 42);
  auto param = Parameter(&builder, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  auto negated_constant = Neg(constant);
  auto negated_param = Neg(param);
  auto pair = Tuple(&builder, {negated_constant, negated_param});

  EXPECT_FALSE(ComputeDynamismScalar(pair, &builder, {0}).ValueOrDie());
  EXPECT_TRUE(ComputeDynamismScalar(pair, &builder, {1}).ValueOrDie());
}
217
// A parameter whose tuple shape contains a token must be analyzable; both of
// its leaves (token and s32) are reported as dynamic.
TEST_F(DynamismInferenceTest, ParameterWithToken) {
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeTokenShape(), ShapeUtil::MakeScalarShape(S32)});
  auto param = Parameter(&builder, 0, param_shape, "p0");

  EXPECT_TRUE(ComputeDynamismScalar(param, &builder, {0}).ValueOrDie());
  EXPECT_TRUE(ComputeDynamismScalar(param, &builder, {1}).ValueOrDie());
}
229
// The dynamism of an elementwise binary op is the OR of its operands'
// dynamism.
TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32_t>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  // Static value + static value = static
  auto add1 = Add(c, c);
  // Dynamic value + static value = dynamic
  auto add2 = Add(p, c);
  auto tuple_2 = Tuple(&b, {add1, add2});
  EXPECT_FALSE(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie());
  EXPECT_TRUE(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie());
}
243
// GetDimensionSize is dynamic exactly when the queried dimension is dynamic.
TEST_F(DynamismInferenceTest, GetDimensionSize) {
  XlaBuilder builder(TestName());
  // param = Param([<=2, 3]): dimension 0 is dynamic, dimension 1 is static.
  const Shape param_shape = ShapeUtil::MakeShape(S32, {2, 3}, {true, false});
  auto param = Parameter(&builder, 0, param_shape, "p0");

  auto sizes = Tuple(&builder, {GetDimensionSize(param, 0),
                                GetDimensionSize(param, 1)});
  // get_dimension_size(param, 0) is dynamic.
  EXPECT_TRUE(ComputeDynamismScalar(sizes, &builder, {0}).ValueOrDie());
  // get_dimension_size(param, 1) is static.
  EXPECT_FALSE(ComputeDynamismScalar(sizes, &builder, {1}).ValueOrDie());
}
258
// DynamicSlice over a constant array with a constant start index is static.
TEST_F(DynamismInferenceTest, DynamicSliceWithConstantOperands) {
  XlaBuilder builder(TestName());

  auto data = ConstantR1<int32_t>(&builder, {0, 1, 2, 3});
  auto start = ConstantR0(&builder, 1);
  auto sliced = DynamicSlice(data, {start}, {1});

  EXPECT_FALSE(
      ComputeDynamismLiteral(sliced, &builder).ValueOrDie().Get<bool>({0}));
}
268
// Gather whose data operand and indices share a parameter ancestor: the result
// is derived from parameters and is therefore dynamic.
TEST_F(DynamismInferenceTest, GatherWithCommonParent) {
  XlaBuilder builder(TestName());
  const Shape indices_shape = ShapeUtil::MakeShape(S32, {2});

  auto lhs = Parameter(&builder, 0, indices_shape, "p1");
  auto rhs = Parameter(&builder, 1, indices_shape, "p2");
  auto indices = Sub(lhs, rhs);

  GatherDimensionNumbers dnums;
  dnums.add_offset_dims(1);
  dnums.add_start_index_map(0);
  dnums.set_index_vector_dim(1);
  auto gathered = Gather(lhs, indices, dnums, {1});
  ASSERT_TRUE(builder.first_error().ok())
      << builder.first_error().error_message();

  EXPECT_TRUE(ComputeDynamismLiteral(gathered, &builder)
                  .ValueOrDie()
                  .Get<bool>({0, 0}));
}
287
// Gather whose data operand and indices are both constants yields a static
// result.
TEST_F(DynamismInferenceTest, GatherWithConstantParent) {
  XlaBuilder b(TestName());
  auto data_operand = ConstantR1<int32_t>(&b, {1, 2});
  auto indices = ConstantR1<int32_t>(&b, {1, 2});
  GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);
  dim_numbers.add_start_index_map(0);
  dim_numbers.set_index_vector_dim(1);
  auto gather = Gather(data_operand, indices, dim_numbers, {1});
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Everything is constant, so the result is also constant.
  EXPECT_FALSE(
      ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
}
304
// Gather whose indices are computed from the same constant used as the data
// operand: everything is constant, so the result is static.
TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) {
  XlaBuilder b(TestName());
  auto operand1 = ConstantR1<int32_t>(&b, {1, 2});
  auto operand2 = ConstantR1<int32_t>(&b, {1, 2});
  auto indices = Sub(operand1, operand2);
  GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);
  dim_numbers.add_start_index_map(0);
  dim_numbers.set_index_vector_dim(1);
  auto gather = Gather(operand1, indices, dim_numbers, {1});
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Everything is constant, so the result is also constant.
  EXPECT_FALSE(
      ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
}
322
// Pad propagates dynamism element-wise: elements copied from the operand keep
// the operand's dynamism, and padded elements take the padding value's.
TEST_F(DynamismInferenceTest, InferThroughPad) {
  XlaBuilder b(TestName());
  auto operand1 = ConstantR1<int32_t>(&b, {1, 2});
  auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
  PaddingConfig padding_config;
  // One element of high edge padding on the only dimension.
  padding_config.add_dimensions()->set_edge_padding_high(1);
  // After pad the value is [constant, constant, parameter].
  auto pad = Pad(operand1, parameter, padding_config);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism, ComputeDynamismLiteral(pad, &b));
  // The two elements copied from the constant operand are static...
  EXPECT_FALSE(dynamism.Get<bool>({0}));
  EXPECT_FALSE(dynamism.Get<bool>({1}));
  // ...while the padded element comes from the dynamic parameter.
  EXPECT_TRUE(dynamism.Get<bool>({2}));
}
338
TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreSame) {
  // The result of the following conditional is static.
  // pred = .. # a dynamic value
  // if (pred) {
  //   return (1)  # both branches return the same value
  // } else {
  //   return (1)
  // }
  //

  auto s32_shape = ShapeUtil::MakeShape(S32, {});
  XlaBuilder true_builder("true");
  Parameter(&true_builder, 0, s32_shape, "cond_param");
  Tuple(&true_builder, {ConstantR0<int32_t>(&true_builder, 1)});
  auto true_computation = true_builder.Build().ValueOrDie();

  XlaBuilder false_builder("false");
  Parameter(&false_builder, 0, s32_shape, "cond_param");
  Tuple(&false_builder, {ConstantR0<int32_t>(&false_builder, 1)});
  auto false_computation = false_builder.Build().ValueOrDie();

  XlaBuilder b(TestName());
  auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p0");
  auto constant = ConstantR0<int32_t>(&b, 0);
  auto cond = Conditional(parameter, constant, true_computation, constant,
                          false_computation);
  auto gte = GetTupleElement(cond, 0);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Result is not dynamic.
  EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
}
371
// A Call whose callee just returns its parameter forwards the argument's
// dynamism, so calling it with a constant yields a static result.
//
//   Callee:  p = param; return p
//   Entry:   c = constant(3); return call(c), callee
TEST_F(DynamismInferenceTest, InferThroughCall) {
  const Shape s32_shape = ShapeUtil::MakeShape(S32, {});
  XlaBuilder callee_builder("call");
  Parameter(&callee_builder, 0, s32_shape, "call_param");
  auto callee = callee_builder.Build().ValueOrDie();

  XlaBuilder builder(TestName());
  auto three = ConstantR0<int32_t>(&builder, 3);
  auto call = Call(&builder, callee, {three});
  ASSERT_TRUE(builder.first_error().ok())
      << builder.first_error().error_message();

  // The call result is static.
  EXPECT_FALSE(ComputeDynamismScalar(call, &builder, {}).ValueOrDie());
}
397
TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreNotSame) {
  // The result of the following conditional is dynamic.
  // pred = .. # a dynamic value
  // if (pred) {
  //   return (1)  # These two branches return different values.
  // } else {
  //   return (2)
  // }
  //

  auto s32_shape = ShapeUtil::MakeShape(S32, {});
  XlaBuilder true_builder("true");
  Parameter(&true_builder, 0, s32_shape, "cond_param");
  Tuple(&true_builder, {ConstantR0<int32_t>(&true_builder, 1)});
  auto true_computation = true_builder.Build().ValueOrDie();

  XlaBuilder false_builder("false");
  Parameter(&false_builder, 0, s32_shape, "cond_param");
  Tuple(&false_builder, {ConstantR0<int32_t>(&false_builder, 2)});
  auto false_computation = false_builder.Build().ValueOrDie();

  XlaBuilder b(TestName());
  auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p0");
  auto constant = ConstantR0<int32_t>(&b, 0);
  auto cond = Conditional(parameter, constant, true_computation, constant,
                          false_computation);
  auto gte = GetTupleElement(cond, 0);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Result is dynamic.
  EXPECT_TRUE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
}
430
TEST_F(DynamismInferenceTest, InferThroughConditionalPredIsConstantTrueBranch) {
  // The result of the following conditional is static.
  // pred = true
  // if (pred) {
  //   return (0)
  // } else {
  //   return (operand)  # forwards the false-branch operand
  // }
  //
  // Because pred is the constant true, the result is determined by the true
  // branch alone.

  auto s32_shape = ShapeUtil::MakeShape(S32, {});
  XlaBuilder true_builder("true");
  Parameter(&true_builder, 0, s32_shape, "cond_param");
  Tuple(&true_builder, {ConstantR0<int32_t>(&true_builder, 0)});
  auto true_computation = true_builder.Build().ValueOrDie();

  XlaBuilder false_builder("false");
  Tuple(&false_builder,
        {Parameter(&false_builder, 0, s32_shape, "cond_param")});
  auto false_computation = false_builder.Build().ValueOrDie();

  XlaBuilder b(TestName());
  auto pred = ConstantR0<bool>(&b, true);
  auto constant = ConstantR0<int32_t>(&b, 0);
  auto cond = Conditional(pred, constant, true_computation, constant,
                          false_computation);
  auto gte = GetTupleElement(cond, 0);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Result is not dynamic.
  EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
}
463
TEST_F(DynamismInferenceTest,
       InferThroughConditionalPredIsConstantFalseBranch) {
  // The result of the following conditional is dynamic.
  // pred = false
  // if (pred) {
  //   return (0)
  // } else {
  //   return (param)  # forwards a dynamic parameter
  // }
  //
  // Because pred is the constant false, the result comes from the false
  // branch, which forwards the dynamic entry parameter.

  auto s32_shape = ShapeUtil::MakeShape(S32, {});
  XlaBuilder true_builder("true");
  Parameter(&true_builder, 0, s32_shape, "cond_param");
  Tuple(&true_builder, {ConstantR0<int32_t>(&true_builder, 0)});
  auto true_computation = true_builder.Build().ValueOrDie();

  XlaBuilder false_builder("false");
  Tuple(&false_builder,
        {Parameter(&false_builder, 0, s32_shape, "cond_param")});
  auto false_computation = false_builder.Build().ValueOrDie();

  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, s32_shape, "param");
  auto pred = ConstantR0<bool>(&b, false);
  auto constant = ConstantR0<int32_t>(&b, 0);
  auto cond =
      Conditional(pred, constant, true_computation, param, false_computation);
  auto gte = GetTupleElement(cond, 0);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Result is dynamic.
  EXPECT_TRUE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
}
498
TEST_F(DynamismInferenceTest, ArgumentForwardingNestedTuple) {
  // The result of the following conditional is considered static.
  // pred = .. dynamic value..
  //
  // op = 0  # a constant
  // if (pred) {
  //   if (pred) {
  //     return 0
  //   } else {
  //     return op
  //   }
  // } else {
  //   if (pred) {
  //     return 0
  //   } else {
  //     return op
  //   }
  // }
  //
  // Every path returns either the constant 0 or the forwarded constant `op`
  // (also 0), so the result is static even though `pred` is dynamic.
  auto pred_shape = ShapeUtil::MakeShape(PRED, {});
  auto s32_shape = ShapeUtil::MakeShape(S32, {});
  auto tuple_shape = ShapeUtil::MakeTupleShape({pred_shape, s32_shape});
  // Inner true branch: returns the constant (0).
  XlaBuilder inner_true_builder("inner_true");
  Parameter(&inner_true_builder, 0, s32_shape, "cond_param");
  Tuple(&inner_true_builder, {ConstantR0<int32_t>(&inner_true_builder, 0)});
  auto inner_true_computation = inner_true_builder.Build().ValueOrDie();

  // Inner false branch: forwards its s32 argument as (op).
  XlaBuilder inner_false_builder("inner_false");
  Tuple(&inner_false_builder,
        {Parameter(&inner_false_builder, 0, s32_shape, "cond_param")});
  auto inner_false_computation = inner_false_builder.Build().ValueOrDie();

  // Both outer branches unpack (pred, op) and run the inner conditional.
  XlaBuilder true_builder("true");
  {
    auto param = Parameter(&true_builder, 0, tuple_shape, "param");
    auto op = GetTupleElement(param, 1);
    auto pred = GetTupleElement(param, 0);
    Conditional(pred, op, inner_true_computation, op, inner_false_computation);
  }
  auto true_computation = true_builder.Build().ValueOrDie();
  XlaBuilder false_builder("false");
  {
    auto param = Parameter(&false_builder, 0, tuple_shape, "param");
    auto op = GetTupleElement(param, 1);
    auto pred = GetTupleElement(param, 0);
    Conditional(pred, op, inner_true_computation, op, inner_false_computation);
  }
  auto false_computation = false_builder.Build().ValueOrDie();
  XlaBuilder b(TestName());
  auto constant = ConstantR0<int32_t>(&b, 0);
  auto pred = Parameter(&b, 0, pred_shape, "param");
  auto param = Tuple(&b, {pred, constant});
  auto cond =
      Conditional(pred, param, true_computation, param, false_computation);
  auto gte = GetTupleElement(cond, 0);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Result is static.
  EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
}
559
560 class UpperBoundInferenceTest : public ValueInferenceTest {
561 public:
UpperBoundInferenceTest(se::Platform * platform=nullptr)562 explicit UpperBoundInferenceTest(se::Platform* platform = nullptr)
563 : platform_(platform) {}
564
ComputeUpperBoundLiteral(XlaOp operand,XlaBuilder * builder,Layout * output_layout=nullptr)565 StatusOr<OptionalLiteral> ComputeUpperBoundLiteral(
566 XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) {
567 ValueInference value_inference(builder);
568 TF_ASSIGN_OR_RETURN(auto literal,
569 value_inference.AnalyzeConstant(
570 operand, ValueInferenceMode::kUpperBound));
571 return literal;
572 }
573
574 se::Platform* platform_;
575 };
576
// The upper bound of GetDimensionSize is the dimension's declared bound,
// whether or not that dimension is dynamic.
TEST_F(UpperBoundInferenceTest, GetDimensionSize) {
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeShape(S32, {2, 3}, {true, false});
  auto param = Parameter(&builder, 0, param_shape, "p0");

  auto sizes = Tuple(&builder, {GetDimensionSize(param, 0),
                                GetDimensionSize(param, 1)});
  EXPECT_EQ(ComputeUpperBoundLiteral(sizes, &builder)
                .ValueOrDie()
                .Get<int32_t>({}, {0}),
            2);
  EXPECT_EQ(ComputeUpperBoundLiteral(sizes, &builder)
                .ValueOrDie()
                .Get<int32_t>({}, {1}),
            3);
}
592
// Upper bound of a difference uses the minuend's upper bound and the
// subtrahend's lower bound.
TEST_F(UpperBoundInferenceTest, GetDimensionSizeSub) {
  XlaBuilder builder(TestName());
  auto param = Parameter(
      &builder, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");

  auto dim0 = GetDimensionSize(param, 0);  // range [0, 2]
  auto dim1 = GetDimensionSize(param, 1);  // range [3, 3]
  // upperbound(dim1 - dim0) = 3 - 0 = 3
  auto difference = Sub(dim1, dim0);
  EXPECT_EQ(
      ComputeUpperBoundLiteral(difference, &builder).ValueOrDie().Get<int32_t>(
          {}),
      3);
}
606
// Upper bound of a quotient divides by the divisor's smallest non-zero value.
TEST_F(UpperBoundInferenceTest, GetDimensionSizeDiv) {
  XlaBuilder builder(TestName());
  auto param = Parameter(
      &builder, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");

  auto dim0 = GetDimensionSize(param, 0);  // range [0, 2]
  auto dim1 = GetDimensionSize(param, 1);  // range [3, 3]
  // upperbound(dim1 / dim0) = 3 / 1 = 3. The divisor's lower bound is taken as
  // 1 rather than 0 so the bound computation cannot divide by zero.
  auto quotient = Div(dim1, dim0);
  EXPECT_EQ(
      ComputeUpperBoundLiteral(quotient, &builder).ValueOrDie().Get<int32_t>(
          {}),
      3);
}
620
// If x = a and y = b - a, then upperbound(x + y) should be upperbound(b);
// combining the two terms' bounds independently would give a looser bound.
TEST_F(UpperBoundInferenceTest, SumSubtract) {
  XlaBuilder builder(TestName());
  auto param = Parameter(
      &builder, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, true}), "p0");
  auto dim0 = GetDimensionSize(param, 0);  // range [0, 2]
  auto dim1 = GetDimensionSize(param, 1);  // range [0, 3]

  auto diff = Sub(dim1, dim0);
  // (dim1 - dim0) + dim0 == dim1, so the expected bound is 3.
  auto recovered = Add(diff, dim0);
  EXPECT_EQ(
      ComputeUpperBoundLiteral(recovered, &builder).ValueOrDie().Get<int32_t>(
          {}),
      3);

  // (dim1 - dim0) + (dim1 + dim0) == 2 * dim1, so the expected bound is 6.
  auto total = Add(dim1, dim0);
  auto doubled = Add(diff, total);
  EXPECT_EQ(
      ComputeUpperBoundLiteral(doubled, &builder).ValueOrDie().Get<int32_t>(
          {}),
      6);
}
640
// Same as SumSubtract, but one operand first passes through data-movement ops
// (broadcast, identity convert, slice, reshape) that must not hide the
// algebraic relationship between the terms.
TEST_F(UpperBoundInferenceTest, SumSubtractWithDataShuffling) {
  XlaBuilder builder(TestName());
  auto param = Parameter(
      &builder, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, true}), "p0");
  auto dim0 = GetDimensionSize(param, 0);  // range [0, 2]
  auto dim1 = GetDimensionSize(param, 1);  // range [0, 3]

  // Shuffle dim0 around and bring it back to a scalar.
  auto broadcasted = Broadcast(dim0, {1, 10});
  auto converted = ConvertElementType(broadcasted, S32);  // Identity convert.
  auto sliced = SliceInDim(converted, /*start_index=*/0, /*limit_index=*/1,
                           /*stride=*/1, /*dimno=*/1);
  dim0 = Reshape(sliced, {});

  auto diff = Sub(dim1, dim0);
  // (dim1 - dim0) + dim0 == dim1, so the expected bound is 3.
  auto recovered = Add(diff, dim0);
  EXPECT_EQ(
      ComputeUpperBoundLiteral(recovered, &builder).ValueOrDie().Get<int32_t>(
          {}),
      3);

  // (dim1 - dim0) + (dim1 + dim0) == 2 * dim1, so the expected bound is 6.
  auto total = Add(dim1, dim0);
  auto doubled = Add(diff, total);
  EXPECT_EQ(
      ComputeUpperBoundLiteral(doubled, &builder).ValueOrDie().Get<int32_t>(
          {}),
      6);
}
665
// Two separate GetDimensionSize ops on the same dimension should be treated as
// the same value, so dim0 + (dim1 - dim0_again) is bounded by upperbound(dim1).
TEST_F(UpperBoundInferenceTest, SumSubtractEquivalentGetDimensionSize) {
  XlaBuilder builder(TestName());
  auto param = Parameter(
      &builder, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, true}), "p0");
  auto dim0 = GetDimensionSize(param, 0);        // range [0, 2]
  auto dim1 = GetDimensionSize(param, 1);        // range [0, 3]
  auto dim0_again = GetDimensionSize(param, 0);  // same value as dim0

  auto diff = Sub(dim1, dim0_again);
  auto sum = Add(diff, dim0);
  EXPECT_EQ(
      ComputeUpperBoundLiteral(sum, &builder).ValueOrDie().Get<int32_t>({}), 3);
}
682
// The bound of a parameter's dynamic dimension size is known, but the bound of
// a parameter's *value* is not, so dividing by a parameter yields no bound.
TEST_F(UpperBoundInferenceTest, ParamCantInferBound) {
  XlaBuilder b(TestName());
  auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}, {true}), "p0");
  auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}, {}), "p1");
  auto gds = GetDimensionSize(p0, 0);
  // The op is a division, so name it accordingly (it was previously `sub`).
  auto quotient = Div(gds, p1);
  EXPECT_FALSE(ComputeUpperBoundLiteral(quotient, &b)
                   .ValueOrDie()
                   .Get<int32_t>({})
                   .has_value());
}
696
// Sorting (keys, iota) by key permutes the iota values, so every element of the
// permuted iota is bounded above by the largest iota value, elem_count - 1.
TEST_F(UpperBoundInferenceTest, KeyValueSort) {
  // Comparator with four scalar parameters; only parameters 0 and 1 (the two
  // keys being compared) decide the order.
  XlaBuilder comparator_b("comparator");
  auto p0 = Parameter(&comparator_b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
  auto p1 = Parameter(&comparator_b, 1, ShapeUtil::MakeShape(S32, {}), "p1");
  Parameter(&comparator_b, 2, ShapeUtil::MakeShape(S32, {}), "p2");
  Parameter(&comparator_b, 3, ShapeUtil::MakeShape(S32, {}), "p3");
  Compare(p0, p1, ComparisonDirection::kGe);
  TF_ASSERT_OK_AND_ASSIGN(auto comparator, comparator_b.Build());

  const int64_t elem_count = 17;
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {elem_count}), "p0");
  auto iota = Iota(&b, S32, elem_count);
  auto sort = Sort({param, iota}, comparator);
  // Element 1 of the sort's output tuple is the permuted iota.
  auto gte = GetTupleElement(sort, 1);

  // The bound literal does not depend on the loop index; compute it once.
  TF_ASSERT_OK_AND_ASSIGN(auto upper_bounds, ComputeUpperBoundLiteral(gte, &b));
  for (int64_t i = 0; i < elem_count; ++i) {
    auto element_bound = upper_bounds.Get<int32_t>({i});
    // We can infer the bound of sort. ASSERT (not EXPECT) so a missing bound
    // stops the test before the empty optional is dereferenced below.
    ASSERT_TRUE(element_bound.has_value()) << "no bound for element " << i;
    // The bound of the sort result is the max value in the input.
    EXPECT_EQ(*element_bound, elem_count - 1);
  }
}
722
723 class ConstValueInferenceTest : public ValueInferenceTest {
724 public:
ConstValueInferenceTest(se::Platform * platform=nullptr)725 explicit ConstValueInferenceTest(se::Platform* platform = nullptr)
726 : platform_(platform) {}
727
ComputeConstantValueLiteral(XlaOp operand,XlaBuilder * builder,Layout * output_layout=nullptr)728 StatusOr<OptionalLiteral> ComputeConstantValueLiteral(
729 XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) {
730 ValueInference value_inference(builder);
731 TF_ASSIGN_OR_RETURN(auto literal, value_inference.AnalyzeConstant(
732 operand, ValueInferenceMode::kValue));
733 return literal;
734 }
735
736 se::Platform* platform_;
737 };
738
// Wrapping a constant in a SetBound custom call must not hide the constant's
// value from value inference.
TEST_F(ConstValueInferenceTest, ConstValuePassThroughSetBound) {
  XlaBuilder b(TestName());
  auto p0 = ConstantR0<int32_t>(&b, 32);
  Shape shape = ShapeUtil::MakeShape(S32, {});
  // Literal attached to the SetBound custom call, built as a
  // (bound, dynamism) tuple.
  xla::Literal dynamism = xla::LiteralUtil::CreateR0<bool>(false);
  xla::Literal bound = xla::LiteralUtil::CreateR0<int32_t>(32);
  xla::Literal tuple =
      xla::LiteralUtil::MakeTupleOwned(std::move(bound), std::move(dynamism));
  auto set_bound =
      CustomCall(&b, "SetBound", {p0}, shape, "", false, {}, &tuple);
  auto result =
      ComputeConstantValueLiteral(set_bound, &b).ValueOrDie().Get<int32_t>({});
  // ASSERT (not EXPECT) so an empty optional is never dereferenced below.
  ASSERT_TRUE(result.has_value());
  EXPECT_EQ(*result, 32);
}
754
755 } // namespace
756 } // namespace xla
757