xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/value_inference_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/value_inference.h"
17 
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/strings/match.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/client/client_library.h"
25 #include "tensorflow/compiler/xla/client/global_data.h"
26 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
27 #include "tensorflow/compiler/xla/client/lib/prng.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/client/xla_computation.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/statusor.h"
35 #include "tensorflow/compiler/xla/test.h"
36 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
37 #include "tensorflow/compiler/xla/tests/test_macros.h"
38 #include "tensorflow/compiler/xla/tests/test_utils.h"
39 #include "tensorflow/compiler/xla/xla_data.pb.h"
40 #include "tensorflow/core/lib/core/status_test_util.h"
41 #include "tensorflow/core/platform/status.h"
42 #include "tensorflow/core/platform/statusor.h"
43 
44 namespace xla {
45 namespace {
46 
47 class ValueInferenceTest : public ::testing::Test {
48  public:
TestName() const49   std::string TestName() const {
50     return ::testing::UnitTest::GetInstance()->current_test_info()->name();
51   }
52 };
53 
54 class DynamismInferenceTest : public ValueInferenceTest {
55  public:
DynamismInferenceTest(se::Platform * platform=nullptr)56   explicit DynamismInferenceTest(se::Platform* platform = nullptr)
57       : platform_(platform) {}
58 
ComputeDynamismLiteral(XlaOp operand,XlaBuilder * builder,Layout * output_layout=nullptr)59   StatusOr<Literal> ComputeDynamismLiteral(XlaOp operand, XlaBuilder* builder,
60                                            Layout* output_layout = nullptr) {
61     TF_RETURN_IF_ERROR(builder->first_error());
62     ValueInference value_inference(builder);
63     TF_ASSIGN_OR_RETURN(auto literal_slice,
64                         value_inference.AnalyzeIsDynamic(operand));
65     return literal_slice.Clone();
66   }
67 
ComputeDynamismScalar(XlaOp operand,XlaBuilder * builder,ShapeIndex index={})68   StatusOr<bool> ComputeDynamismScalar(XlaOp operand, XlaBuilder* builder,
69                                        ShapeIndex index = {}) {
70     TF_ASSIGN_OR_RETURN(auto literal,
71                         ComputeDynamismLiteral(operand, builder, nullptr));
72     return literal.Get<bool>({}, index);
73   }
74 
75   se::Platform* platform_;
76 };
77 
TEST_F(DynamismInferenceTest,ScalarInt32Literal)78 TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
79   XlaBuilder b(TestName());
80   auto computation = ConstantR0<int32_t>(&b, 42);
81 
82   auto value = ComputeDynamismScalar(computation, &b);
83   ASSERT_TRUE(value.ok()) << value.status();
84   // A constant is not dynamic.
85   EXPECT_EQ(value.ValueOrDie(), false);
86 }
87 
TEST_F(DynamismInferenceTest,Iota)88 TEST_F(DynamismInferenceTest, Iota) {
89   // The output of iota are consistened static.
90   XlaBuilder b(TestName());
91   auto computation = Iota(&b, S32, 2);
92   // Iota is not dynamic.
93   EXPECT_FALSE(
94       ComputeDynamismLiteral(computation, &b).ValueOrDie().Get<bool>({0}));
95 }
96 
TEST_F(DynamismInferenceTest,TupleSimple)97 TEST_F(DynamismInferenceTest, TupleSimple) {
98   XlaBuilder b(TestName());
99   auto c = ConstantR0<int32_t>(&b, 42);
100   auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
101 
102   auto tuple = Tuple(&b, {c, p});
103   EXPECT_EQ(ComputeDynamismScalar(tuple, &b, {0}).ValueOrDie(), false);
104   EXPECT_EQ(ComputeDynamismScalar(tuple, &b, {1}).ValueOrDie(), true);
105 }
106 
TEST_F(DynamismInferenceTest,TupleGteKeepsDynamism)107 TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
108   XlaBuilder b(TestName());
109   auto c = ConstantR0<int32_t>(&b, 42);
110   auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
111 
112   auto tuple = Tuple(&b, {c, p});
113   auto gte0 = GetTupleElement(tuple, 0);
114   auto gte1 = GetTupleElement(tuple, 1);
115   auto tuple_2 = Tuple(&b, {gte0, gte1});
116   EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
117   EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
118 }
119 
TEST_F(DynamismInferenceTest,PredValueUsedTwice)120 TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
121   XlaBuilder b(TestName());
122   auto c = ConstantR0<int32_t>(&b, 42);
123   auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
124   auto pred = Eq(c, p);
125   auto result = Select(pred, p, c);
126   EXPECT_EQ(ComputeDynamismScalar(result, &b, {}).ValueOrDie(), true);
127 }
128 
TEST_F(DynamismInferenceTest,ReduceUsedTwice)129 TEST_F(DynamismInferenceTest, ReduceUsedTwice) {
130   XlaBuilder b(TestName());
131   auto c = ConstantR0<int32_t>(&b, 42);
132   auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
133   auto zero = ConstantR0<int32_t>(&b, 0);
134   XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
135   auto reduce = Reduce(p, zero, add_s32, {0});
136   auto pred = Eq(c, reduce);
137   auto result = Select(pred, reduce, c);
138   EXPECT_EQ(ComputeDynamismScalar(result, &b, {}).ValueOrDie(), true);
139 }
140 
TEST_F(DynamismInferenceTest,VariadicReduce)141 TEST_F(DynamismInferenceTest, VariadicReduce) {
142   XlaBuilder b(TestName());
143   auto c = ConstantR2<int32_t>(&b, {{0, 0}});
144   auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {1, 2}), "p0");
145   // half_dynamic[0] is static, half_dynamic[0] is dynamic.
146   auto half_dynamic = ConcatInDim(&b, {c, p}, 0);
147   XlaBuilder reduce_add("reduce_add");
148   auto p0 = Parameter(&reduce_add, 0, ShapeUtil::MakeScalarShape(S32), "p");
149   auto p1 = Parameter(&reduce_add, 1, ShapeUtil::MakeScalarShape(S32), "p");
150   auto p2 = Parameter(&reduce_add, 2, ShapeUtil::MakeScalarShape(S32), "p");
151   auto p3 = Parameter(&reduce_add, 3, ShapeUtil::MakeScalarShape(S32), "p");
152   auto reduce_result = p0;
153   reduce_result = Add(reduce_result, p1);
154   reduce_result = Add(reduce_result, p2);
155   reduce_result = Add(reduce_result, p3);
156   Tuple(&reduce_add, {reduce_result, reduce_result});
157   auto init = ConstantR0<int32_t>(&b, 0);
158   auto variadic_reduce = Reduce(&b, {half_dynamic, half_dynamic}, {init, init},
159                                 reduce_add.Build().value(), {1});
160   auto result = GetTupleElement(variadic_reduce, 0);
161 
162   // result[0] should be static; result[1] should be dynamic.
163   EXPECT_FALSE(ComputeDynamismLiteral(result, &b).ValueOrDie().Get<bool>({0}));
164   EXPECT_TRUE(ComputeDynamismLiteral(result, &b).ValueOrDie().Get<bool>({1}));
165 }
166 
TEST_F(DynamismInferenceTest,DynamicSelectorWithMixedValues)167 TEST_F(DynamismInferenceTest, DynamicSelectorWithMixedValues) {
168   XlaBuilder b(TestName());
169   auto constant_pred = ConstantR1<bool>(&b, {true});
170   auto dynamic_pred = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {1}), "p0");
171   auto concat = ConcatInDim(&b, {constant_pred, dynamic_pred}, 0);
172   auto constant_values = ConstantR1<bool>(&b, {true, true});
173   auto result = Select(concat, constant_values, constant_values);
174   // First result is static (selector is constant, both values are constant).
175   // Iota is not dynamic.
176   EXPECT_FALSE(ComputeDynamismLiteral(result, &b).ValueOrDie().Get<bool>({0}));
177   // Second result is dynamic (selector is dynamic).
178   EXPECT_TRUE(ComputeDynamismLiteral(result, &b).ValueOrDie().Get<bool>({1}));
179 }
180 
TEST_F(DynamismInferenceTest,ConcatSliceReshapeKeepsDynamism)181 TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
182   XlaBuilder b(TestName());
183   auto c = ConstantR0<int32_t>(&b, 42);
184   auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
185 
186   auto concat = ConcatScalars(&b, {c, p});
187   auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
188   auto reshape0 = Reshape(slice0, {});
189   auto slice1 = SliceInDim(concat, 1, 2, 1, 0);
190   auto reshape1 = Reshape(slice1, {});
191   auto tuple_2 = Tuple(&b, {reshape0, reshape1});
192   EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
193   EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
194 }
195 
TEST_F(DynamismInferenceTest,ParameterIsDynamic)196 TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
197   XlaBuilder b(TestName());
198   auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
199 
200   auto value = ComputeDynamismScalar(computation, &b);
201   ASSERT_TRUE(value.ok()) << value.status();
202   // A parameter is considered dynamic.
203   EXPECT_EQ(value.ValueOrDie(), true);
204 }
205 
TEST_F(DynamismInferenceTest,UnaryOpKeepsDynamism)206 TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
207   XlaBuilder b(TestName());
208   auto c = ConstantR0<int32_t>(&b, 42);
209   auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
210 
211   auto neg0 = Neg(c);
212   auto neg1 = Neg(p);
213   auto tuple_2 = Tuple(&b, {neg0, neg1});
214   EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
215   EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
216 }
217 
TEST_F(DynamismInferenceTest,ParameterWithToken)218 TEST_F(DynamismInferenceTest, ParameterWithToken) {
219   // Test that token shape can be handled in a parameter.
220   XlaBuilder b(TestName());
221   auto p =
222       Parameter(&b, 0,
223                 ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape(),
224                                            ShapeUtil::MakeScalarShape(S32)}),
225                 "p0");
226   EXPECT_EQ(ComputeDynamismScalar(p, &b, {0}).ValueOrDie(), true);
227   EXPECT_EQ(ComputeDynamismScalar(p, &b, {1}).ValueOrDie(), true);
228 }
229 
TEST_F(DynamismInferenceTest,BinaryOpsOrsDynamism)230 TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
231   XlaBuilder b(TestName());
232   auto c = ConstantR0<int32_t>(&b, 42);
233   auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
234 
235   // Static value + static value = static
236   auto add1 = Add(c, c);
237   // Dynamic value + dynamic value = dynamic
238   auto add2 = Add(p, c);
239   auto tuple_2 = Tuple(&b, {add1, add2});
240   EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
241   EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
242 }
243 
TEST_F(DynamismInferenceTest,GetDimensionSize)244 TEST_F(DynamismInferenceTest, GetDimensionSize) {
245   XlaBuilder b(TestName());
246   // param = Param([<=2, 3])
247   // get_dimension_size(param, 0) is dynamic
248   // get_dimension_size(param, 1) is static
249   auto p =
250       Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");
251 
252   auto gds0 = GetDimensionSize(p, 0);
253   auto gds1 = GetDimensionSize(p, 1);
254   auto tuple_2 = Tuple(&b, {gds0, gds1});
255   EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), true);
256   EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), false);
257 }
258 
TEST_F(DynamismInferenceTest,DynamicSliceWithConstantOperands)259 TEST_F(DynamismInferenceTest, DynamicSliceWithConstantOperands) {
260   XlaBuilder b(TestName());
261 
262   auto constant = ConstantR1<int32_t>(&b, {0, 1, 2, 3});
263   auto slice_start = ConstantR0(&b, 1);
264   auto dynamic_slice = DynamicSlice(constant, {slice_start}, {1});
265   EXPECT_FALSE(
266       ComputeDynamismLiteral(dynamic_slice, &b).ValueOrDie().Get<bool>({0}));
267 }
268 
TEST_F(DynamismInferenceTest,GatherWithCommonParent)269 TEST_F(DynamismInferenceTest, GatherWithCommonParent) {
270   XlaBuilder b(TestName());
271   // Test the analysis on a gather where first operand and second operand have
272   // common parents.
273   Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
274 
275   auto operand1 = Parameter(&b, 0, indices_shape, "p1");
276   auto operand2 = Parameter(&b, 1, indices_shape, "p2");
277   auto indices = Sub(operand1, operand2);
278   GatherDimensionNumbers dim_numbers;
279   dim_numbers.add_offset_dims(1);
280   dim_numbers.add_start_index_map(0);
281   dim_numbers.set_index_vector_dim(1);
282   auto gather = Gather(operand1, indices, dim_numbers, {1});
283   ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
284   EXPECT_TRUE(
285       ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
286 }
287 
TEST_F(DynamismInferenceTest,GatherWithConstantParent)288 TEST_F(DynamismInferenceTest, GatherWithConstantParent) {
289   XlaBuilder b(TestName());
290   // Test the analysis on a gather.
291   Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
292   auto data_operand = ConstantR1<int32_t>(&b, {1, 2});
293   auto indices = ConstantR1<int32_t>(&b, {1, 2});
294   GatherDimensionNumbers dim_numbers;
295   dim_numbers.add_offset_dims(1);
296   dim_numbers.add_start_index_map(0);
297   dim_numbers.set_index_vector_dim(1);
298   auto gather = Gather(data_operand, indices, dim_numbers, {1});
299   ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
300   // Everything is constant, result is also contant.
301   EXPECT_FALSE(
302       ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
303 }
304 
TEST_F(DynamismInferenceTest,GatherWithSharedConstantParent)305 TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) {
306   XlaBuilder b(TestName());
307   // Test the analysis on a gather.
308   Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
309   auto operand1 = ConstantR1<int32_t>(&b, {1, 2});
310   auto operand2 = ConstantR1<int32_t>(&b, {1, 2});
311   auto indices = Sub(operand1, operand2);
312   GatherDimensionNumbers dim_numbers;
313   dim_numbers.add_offset_dims(1);
314   dim_numbers.add_start_index_map(0);
315   dim_numbers.set_index_vector_dim(1);
316   auto gather = Gather(operand1, indices, dim_numbers, {1});
317   ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
318   // Everything is constant, result is also contant.
319   EXPECT_FALSE(
320       ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
321 }
322 
TEST_F(DynamismInferenceTest,InferThroughPad)323 TEST_F(DynamismInferenceTest, InferThroughPad) {
324   XlaBuilder b(TestName());
325   // Test the analysis on a gather.
326   auto operand1 = ConstantR1<int32_t>(&b, {1, 2});
327   auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
328   PaddingConfig padding_config;
329   padding_config.add_dimensions()->set_edge_padding_high(1);
330   // After pad the value is [constant, constant, parameter].
331   auto pad = Pad(operand1, parameter, padding_config);
332   ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
333   // Everything is constant, result is also contant.
334   EXPECT_FALSE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({0}));
335   EXPECT_FALSE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({1}));
336   EXPECT_TRUE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({2}));
337 }
338 
TEST_F(DynamismInferenceTest,InferThroughConditionalBranchesAreSame)339 TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreSame) {
340   // The result of following conditional is static.
341   // pred = .. # a dynamic value
342   // if (pred) {
343   //  return (1) # both branches return the same value
344   // } else {
345   //  return (1)
346   // }
347   //
348 
349   auto s32_shape = ShapeUtil::MakeShape(S32, {});
350   auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape});
351   XlaBuilder true_builder("true");
352   Parameter(&true_builder, 0, s32_shape, "cond_param");
353   Tuple(&true_builder, {ConstantR0<int32_t>(&true_builder, 1)});
354   auto true_computation = true_builder.Build().ValueOrDie();
355 
356   XlaBuilder false_builder("false");
357   Parameter(&false_builder, 0, s32_shape, "cond_param");
358   Tuple(&false_builder, {ConstantR0<int32_t>(&false_builder, 1)});
359   auto false_computation = false_builder.Build().ValueOrDie();
360 
361   XlaBuilder b(TestName());
362   auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p0");
363   auto constant = ConstantR0<int32_t>(&b, 0);
364   auto cond = Conditional(parameter, constant, true_computation, constant,
365                           false_computation);
366   auto gte = GetTupleElement(cond, 0);
367   ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
368   // Result is not dynamic.
369   EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
370 }
371 
TEST_F(DynamismInferenceTest,InferThroughCall)372 TEST_F(DynamismInferenceTest, InferThroughCall) {
373   // The result of following call instruction is static.
374   //
375   // Callee:
376   //   p = param
377   //   return p
378   //
379   // Entry:
380   //   c = constant(3)
381   //   return call(c), callee
382   //
383   //
384 
385   auto s32_shape = ShapeUtil::MakeShape(S32, {});
386   XlaBuilder call_builder("call");
387   Parameter(&call_builder, 0, s32_shape, "call_param");
388   auto call_computation = call_builder.Build().ValueOrDie();
389 
390   XlaBuilder b(TestName());
391   auto constant = ConstantR0<int32_t>(&b, 3);
392   auto call = Call(&b, call_computation, {constant});
393   ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
394   // Result is static.
395   EXPECT_EQ(ComputeDynamismScalar(call, &b, {}).ValueOrDie(), false);
396 }
397 
TEST_F(DynamismInferenceTest,InferThroughConditionalBranchesAreNotSame)398 TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreNotSame) {
399   // The result of following conditional is dynamic.
400   // pred = .. # a dynamic value
401   // if (pred) {
402   //  return (1) # These two branches return different values.
403   // } else {
404   //  return (2)
405   // }
406   //
407 
408   auto s32_shape = ShapeUtil::MakeShape(S32, {});
409   auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape});
410   XlaBuilder true_builder("true");
411   Parameter(&true_builder, 0, s32_shape, "cond_param");
412   Tuple(&true_builder, {ConstantR0<int32_t>(&true_builder, 1)});
413   auto true_computation = true_builder.Build().ValueOrDie();
414 
415   XlaBuilder false_builder("false");
416   Parameter(&false_builder, 0, s32_shape, "cond_param");
417   Tuple(&false_builder, {ConstantR0<int32_t>(&false_builder, 2)});
418   auto false_computation = false_builder.Build().ValueOrDie();
419 
420   XlaBuilder b(TestName());
421   auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p0");
422   auto constant = ConstantR0<int32_t>(&b, 0);
423   auto cond = Conditional(parameter, constant, true_computation, constant,
424                           false_computation);
425   auto gte = GetTupleElement(cond, 0);
426   ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
427   // Result is dynamic.
428   EXPECT_TRUE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
429 }
430 
TEST_F(DynamismInferenceTest,InferThroughConditionalPredIsConstantTrueBranch)431 TEST_F(DynamismInferenceTest, InferThroughConditionalPredIsConstantTrueBranch) {
432   // The result of following conditional is static.
433   // pred = true
434   // if (pred) {
435   //  return (1)
436   // } else {
437   //  return (..dynamic_value...)
438   // }
439   //
440 
441   auto s32_shape = ShapeUtil::MakeShape(S32, {});
442   auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape});
443   XlaBuilder true_builder("true");
444   Parameter(&true_builder, 0, s32_shape, "cond_param");
445   Tuple(&true_builder, {ConstantR0<int32_t>(&true_builder, 0)});
446   auto true_computation = true_builder.Build().ValueOrDie();
447 
448   XlaBuilder false_builder("false");
449   Tuple(&false_builder,
450         {Parameter(&false_builder, 0, s32_shape, "cond_param")});
451   auto false_computation = false_builder.Build().ValueOrDie();
452 
453   XlaBuilder b(TestName());
454   auto pred = ConstantR0<bool>(&b, true);
455   auto constant = ConstantR0<int32_t>(&b, 0);
456   auto cond = Conditional(pred, constant, true_computation, constant,
457                           false_computation);
458   auto gte = GetTupleElement(cond, 0);
459   ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
460   // Result is not dynamic.
461   EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
462 }
463 
TEST_F(DynamismInferenceTest,InferThroughConditionalPredIsConstantFalseBranch)464 TEST_F(DynamismInferenceTest,
465        InferThroughConditionalPredIsConstantFalseBranch) {
466   // The result of following conditional is dynamic.
467   // pred = false
468   // if (pred) {
469   //  return (1)
470   // } else {
471   //  return (..dynamic_value...)
472   // }
473   //
474 
475   auto s32_shape = ShapeUtil::MakeShape(S32, {});
476   auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape});
477   XlaBuilder true_builder("true");
478   Parameter(&true_builder, 0, s32_shape, "cond_param");
479   Tuple(&true_builder, {ConstantR0<int32_t>(&true_builder, 0)});
480   auto true_computation = true_builder.Build().ValueOrDie();
481 
482   XlaBuilder false_builder("false");
483   Tuple(&false_builder,
484         {Parameter(&false_builder, 0, s32_shape, "cond_param")});
485   auto false_computation = false_builder.Build().ValueOrDie();
486 
487   XlaBuilder b(TestName());
488   auto param = Parameter(&b, 0, s32_shape, "param");
489   auto pred = ConstantR0<bool>(&b, false);
490   auto constant = ConstantR0<int32_t>(&b, 0);
491   auto cond =
492       Conditional(pred, constant, true_computation, param, false_computation);
493   auto gte = GetTupleElement(cond, 0);
494   ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
495   // Result is dynamic.
496   EXPECT_TRUE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
497 }
498 
TEST_F(DynamismInferenceTest,ArgumentForwardingNestedTuple)499 TEST_F(DynamismInferenceTest, ArgumentForwardingNestedTuple) {
500   // The result of following conditional is considered static.
501   // pred = .. dynamic value..
502   //
503   // op = 1
504   // if (pred) {
505   //   if (pred) {
506   //     return op
507   //   } else {
508   //     return op
509   //   }
510   // } else {
511   //   if (pred) {
512   //     return op
513   //   } else {
514   //     return op
515   //   }
516   // }
517   //
518   auto pred_shape = ShapeUtil::MakeShape(PRED, {});
519   auto s32_shape = ShapeUtil::MakeShape(S32, {});
520   auto tuple_shape = ShapeUtil::MakeTupleShape({pred_shape, s32_shape});
521   auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape});
522   XlaBuilder inner_true_builder("inner_true");
523   Parameter(&inner_true_builder, 0, s32_shape, "cond_param");
524   Tuple(&inner_true_builder, {ConstantR0<int32_t>(&inner_true_builder, 0)});
525   auto inner_true_computation = inner_true_builder.Build().ValueOrDie();
526 
527   XlaBuilder inner_false_builder("inner_false");
528   Tuple(&inner_false_builder,
529         {Parameter(&inner_false_builder, 0, s32_shape, "cond_param")});
530   auto inner_false_computation = inner_false_builder.Build().ValueOrDie();
531 
532   XlaBuilder true_builder("true");
533   {
534     auto param = Parameter(&true_builder, 0, tuple_shape, "param");
535     auto op = GetTupleElement(param, 1);
536     auto pred = GetTupleElement(param, 0);
537     Conditional(pred, op, inner_true_computation, op, inner_false_computation);
538   }
539   auto true_computation = true_builder.Build().ValueOrDie();
540   XlaBuilder false_builder("false");
541   {
542     auto param = Parameter(&false_builder, 0, tuple_shape, "param");
543     auto op = GetTupleElement(param, 1);
544     auto pred = GetTupleElement(param, 0);
545     Conditional(pred, op, inner_true_computation, op, inner_false_computation);
546   }
547   auto false_computation = false_builder.Build().ValueOrDie();
548   XlaBuilder b(TestName());
549   auto constant = ConstantR0<int32_t>(&b, 0);
550   auto pred = Parameter(&b, 0, pred_shape, "param");
551   auto param = Tuple(&b, {pred, constant});
552   auto cond =
553       Conditional(pred, param, true_computation, param, false_computation);
554   auto gte = GetTupleElement(cond, 0);
555   ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
556   // Result is static.
557   EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
558 }
559 
560 class UpperBoundInferenceTest : public ValueInferenceTest {
561  public:
UpperBoundInferenceTest(se::Platform * platform=nullptr)562   explicit UpperBoundInferenceTest(se::Platform* platform = nullptr)
563       : platform_(platform) {}
564 
ComputeUpperBoundLiteral(XlaOp operand,XlaBuilder * builder,Layout * output_layout=nullptr)565   StatusOr<OptionalLiteral> ComputeUpperBoundLiteral(
566       XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) {
567     ValueInference value_inference(builder);
568     TF_ASSIGN_OR_RETURN(auto literal,
569                         value_inference.AnalyzeConstant(
570                             operand, ValueInferenceMode::kUpperBound));
571     return literal;
572   }
573 
574   se::Platform* platform_;
575 };
576 
TEST_F(UpperBoundInferenceTest,GetDimensionSize)577 TEST_F(UpperBoundInferenceTest, GetDimensionSize) {
578   XlaBuilder b(TestName());
579   auto p =
580       Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");
581 
582   auto gds0 = GetDimensionSize(p, 0);
583   auto gds1 = GetDimensionSize(p, 1);
584   auto tuple_2 = Tuple(&b, {gds0, gds1});
585   EXPECT_EQ(
586       ComputeUpperBoundLiteral(tuple_2, &b).ValueOrDie().Get<int32_t>({}, {0}),
587       2);
588   EXPECT_EQ(
589       ComputeUpperBoundLiteral(tuple_2, &b).ValueOrDie().Get<int32_t>({}, {1}),
590       3);
591 }
592 
TEST_F(UpperBoundInferenceTest,GetDimensionSizeSub)593 TEST_F(UpperBoundInferenceTest, GetDimensionSizeSub) {
594   XlaBuilder b(TestName());
595   auto p =
596       Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");
597 
598   // The range of the first dimension is [0, 2]
599   auto gds0 = GetDimensionSize(p, 0);
600   // The range of the second dimension is [3, 3]
601   auto gds1 = GetDimensionSize(p, 1);
602   // Upper bound of `second_dimension - first_dimension` is 3 - 0 = 3
603   auto sub = Sub(gds1, gds0);
604   EXPECT_EQ(ComputeUpperBoundLiteral(sub, &b).ValueOrDie().Get<int32_t>({}), 3);
605 }
606 
TEST_F(UpperBoundInferenceTest,GetDimensionSizeDiv)607 TEST_F(UpperBoundInferenceTest, GetDimensionSizeDiv) {
608   XlaBuilder b(TestName());
609   auto p =
610       Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");
611   // The range of the first dimension is [0, 2]
612   auto gds0 = GetDimensionSize(p, 0);
613   // The range of the second dimension is [3, 3]
614   auto gds1 = GetDimensionSize(p, 1);
615   // Upper bound of `second_dimension / first_dimension` is 3 / 1 = 3. Notice we
616   // don't use 0 as the lower bound as it would create divide-by-zero error.
617   auto div = Div(gds1, gds0);
618   EXPECT_EQ(ComputeUpperBoundLiteral(div, &b).ValueOrDie().Get<int32_t>({}), 3);
619 }
620 
TEST_F(UpperBoundInferenceTest,SumSubtract)621 TEST_F(UpperBoundInferenceTest, SumSubtract) {
622   // If x = a, y = b - a
623   // upperbound(x + y) should be upperbound(b)
624   XlaBuilder b(TestName());
625   auto p =
626       Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, true}), "p0");
627   // The range of the first dimension is [0, 2]
628   auto gds0 = GetDimensionSize(p, 0);
629   // The range of the second dimension is [0, 3]
630   auto gds1 = GetDimensionSize(p, 1);
631   auto sub = Sub(gds1, gds0);
632   auto add = Add(sub, gds0);
633   EXPECT_EQ(ComputeUpperBoundLiteral(add, &b).ValueOrDie().Get<int32_t>({}), 3);
634   auto add2 = Add(gds1, gds0);
635   // upperbound(gds1 - gds0 + gds1 + gds0) ==> upperbound(2 * gds1)
636   auto add3 = Add(sub, add2);
637   EXPECT_EQ(ComputeUpperBoundLiteral(add3, &b).ValueOrDie().Get<int32_t>({}),
638             6);
639 }
640 
TEST_F(UpperBoundInferenceTest,SumSubtractWithDataShuffling)641 TEST_F(UpperBoundInferenceTest, SumSubtractWithDataShuffling) {
642   // Similar to the test above, but with some data shuffling ops in it
643   // (broadcast, slice, reshape, identity convert, etc).
644   XlaBuilder b(TestName());
645   auto p =
646       Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, true}), "p0");
647   // The range of the first dimension is [0, 2]
648   auto gds0 = GetDimensionSize(p, 0);
649   // The range of the second dimension is [0, 3]
650   auto gds1 = GetDimensionSize(p, 1);
651   auto broadcast = Broadcast(gds0, {1, 10});
652   auto convert = ConvertElementType(broadcast, S32);  // Identity convert.
653   auto slice = SliceInDim(convert, /*start_index=*/0, /*limit_index=*/1,
654                           /*stride=*/1, /*dimno=*/1);
655   gds0 = Reshape(slice, {});
656   auto sub = Sub(gds1, gds0);
657   auto add = Add(sub, gds0);
658   EXPECT_EQ(ComputeUpperBoundLiteral(add, &b).ValueOrDie().Get<int32_t>({}), 3);
659   auto add2 = Add(gds1, gds0);
660   // upperbound(gds1 - gds0 + gds1 + gds0) ==> upperbound(2 * gds1)
661   auto add3 = Add(sub, add2);
662   EXPECT_EQ(ComputeUpperBoundLiteral(add3, &b).ValueOrDie().Get<int32_t>({}),
663             6);
664 }
665 
TEST_F(UpperBoundInferenceTest,SumSubtractEquivalentGetDimensionSize)666 TEST_F(UpperBoundInferenceTest, SumSubtractEquivalentGetDimensionSize) {
667   XlaBuilder b(TestName());
668   auto p =
669       Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, true}), "p0");
670   // The range of the first dimension is [0, 2]
671   auto gds0 = GetDimensionSize(p, 0);
672   // The range of the second dimension is [0, 3]
673   auto gds1 = GetDimensionSize(p, 1);
674   // gds2 is equivalent to gds0
675   auto gds2 = GetDimensionSize(p, 0);
676   auto sub = Sub(gds1, gds2);
677   auto add = Add(sub, gds0);
678   // upperbound(gds0 + gds1 - gds2) is equal to upperbound(gds1) if gds0 ==
679   // gds2.
680   EXPECT_EQ(ComputeUpperBoundLiteral(add, &b).ValueOrDie().Get<int32_t>({}), 3);
681 }
682 
TEST_F(UpperBoundInferenceTest,ParamCantInferBound)683 TEST_F(UpperBoundInferenceTest, ParamCantInferBound) {
684   // We can infer a parameter's dimension's bound, but not the parameter value's
685   // bound.
686   XlaBuilder b(TestName());
687   auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}, {true}), "p0");
688   auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}, {}), "p1");
689   auto gds = GetDimensionSize(p0, 0);
690   auto sub = Div(gds, p1);
691   EXPECT_FALSE(ComputeUpperBoundLiteral(sub, &b)
692                    .ValueOrDie()
693                    .Get<int32_t>({})
694                    .has_value());
695 }
696 
TEST_F(UpperBoundInferenceTest,KeyValueSort)697 TEST_F(UpperBoundInferenceTest, KeyValueSort) {
698   XlaBuilder comparator_b("comparator");
699   auto p0 = Parameter(&comparator_b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
700   auto p1 = Parameter(&comparator_b, 1, ShapeUtil::MakeShape(S32, {}), "p1");
701   Parameter(&comparator_b, 2, ShapeUtil::MakeShape(S32, {}), "p2");
702   Parameter(&comparator_b, 3, ShapeUtil::MakeShape(S32, {}), "p3");
703   Compare(p0, p1, ComparisonDirection::kGe);
704   TF_ASSERT_OK_AND_ASSIGN(auto comparator, comparator_b.Build());
705 
706   int64_t elem_count = 17;
707   XlaBuilder b(TestName());
708   auto param = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {elem_count}), "p0");
709   auto iota = Iota(&b, S32, elem_count);
710   auto sort = Sort({param, iota}, comparator);
711   auto gte = GetTupleElement(sort, 1);
712 
713   for (int64_t i = 0; i < elem_count; ++i) {
714     auto result_first_elem =
715         ComputeUpperBoundLiteral(gte, &b).ValueOrDie().Get<int32_t>({i});
716     // We can infer the bound of sort.
717     EXPECT_TRUE(result_first_elem.has_value());
718     // The bound of the sort result is the max value in the input.
719     EXPECT_EQ(result_first_elem.value(), elem_count - 1);
720   }
721 }
722 
723 class ConstValueInferenceTest : public ValueInferenceTest {
724  public:
ConstValueInferenceTest(se::Platform * platform=nullptr)725   explicit ConstValueInferenceTest(se::Platform* platform = nullptr)
726       : platform_(platform) {}
727 
ComputeConstantValueLiteral(XlaOp operand,XlaBuilder * builder,Layout * output_layout=nullptr)728   StatusOr<OptionalLiteral> ComputeConstantValueLiteral(
729       XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) {
730     ValueInference value_inference(builder);
731     TF_ASSIGN_OR_RETURN(auto literal, value_inference.AnalyzeConstant(
732                                           operand, ValueInferenceMode::kValue));
733     return literal;
734   }
735 
736   se::Platform* platform_;
737 };
738 
TEST_F(ConstValueInferenceTest,ConstValuePassThroughSetBound)739 TEST_F(ConstValueInferenceTest, ConstValuePassThroughSetBound) {
740   XlaBuilder b(TestName());
741   auto p0 = ConstantR0<int32_t>(&b, 32);
742   Shape shape = ShapeUtil::MakeShape(S32, {});
743   xla::Literal dynamism = xla::LiteralUtil::CreateR0<bool>(false);
744   xla::Literal bound = xla::LiteralUtil::CreateR0<int32_t>(32);
745   xla::Literal tuple =
746       xla::LiteralUtil::MakeTupleOwned(std::move(bound), std::move(dynamism));
747   auto set_bound =
748       CustomCall(&b, "SetBound", {p0}, shape, "", false, {}, &tuple);
749   auto result =
750       ComputeConstantValueLiteral(set_bound, &b).ValueOrDie().Get<int32_t>({});
751   EXPECT_TRUE(result.has_value());
752   EXPECT_EQ(result.value(), 32);
753 }
754 
755 }  // namespace
756 }  // namespace xla
757