xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/shape_inference_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "absl/strings/string_view.h"
22 #include "absl/strings/substitute.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/client/padding.h"
25 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/test_helpers.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 
32 namespace xla {
33 namespace {
34 
35 using ::testing::ContainsRegex;
36 using ::testing::HasSubstr;
37 
38 class ShapeInferenceTest : public ::testing::Test {
39  protected:
40   // Some handy scalar shapes.
41   const Shape s32_ = ShapeUtil::MakeShape(S32, {});
42   const Shape f16_ = ShapeUtil::MakeShape(F16, {});
43   const Shape f32_ = ShapeUtil::MakeShape(F32, {});
44   const Shape f64_ = ShapeUtil::MakeShape(F64, {});
45   const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
46 
47   // Some handy vector and matrix shapes of F32 type.
48   // Suffix: vector_length_, matrix_rows_cols_
49   const Shape vector_32_ = ShapeUtil::MakeShape(F32, {32});
50   const Shape vector_64_ = ShapeUtil::MakeShape(F32, {64});
51   const Shape matrix_32_48_ = ShapeUtil::MakeShape(F32, {32, 48});
52   const Shape matrix_32_64_ = ShapeUtil::MakeShape(F32, {32, 64});
53   const Shape matrix_64_48_ = ShapeUtil::MakeShape(F32, {64, 48});
54 
55   // Some handy S32 arrays.
56   const Shape s32matrix_64_64_ = ShapeUtil::MakeShape(S32, {64, 64});
57 };
58 
59 // Subclass for testing InferReduceShape.
60 class ReduceShapeInferenceTest : public ShapeInferenceTest {
61  protected:
62   // Helper that runs reduce shape inference with the input 'arg' and given
63   // dimensions to reduce, and checks the inferred shape is as expected. The
64   // element type here is hard-coded to F32.
ExpectInferredReduceShape(const Shape & expected_inferred_shape,const Shape & arg,absl::Span<const int64_t> dimensions_to_reduce)65   void ExpectInferredReduceShape(
66       const Shape& expected_inferred_shape, const Shape& arg,
67       absl::Span<const int64_t> dimensions_to_reduce) {
68     ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
69     auto inferred_status = ShapeInference::InferReduceShape(
70         {&arg, &f32_}, dimensions_to_reduce, to_apply);
71     EXPECT_IS_OK(inferred_status.status());
72     EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
73                                  inferred_status.ValueOrDie()));
74   }
75 };
76 
77 // Subclass for testing InferSelectAndScatterShape.
78 class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
79  protected:
SelectAndScatterShapeInferenceTest()80   SelectAndScatterShapeInferenceTest() {
81     operand_shape_ = ShapeUtil::MakeShape(F32, {8, 16});
82     source_shape_ = ShapeUtil::MakeShape(F32, {4, 8});
83     WindowDimension dim;
84     dim.set_size(2);
85     dim.set_stride(2);
86     dim.set_padding_low(0);
87     dim.set_padding_high(0);
88     dim.set_window_dilation(1);
89     dim.set_base_dilation(1);
90     *window_.add_dimensions() = dim;
91     *window_.add_dimensions() = dim;
92     init_value_shape_ = ShapeUtil::MakeShape(F32, {});
93     select_program_shape_ = ShapeUtil::MakeProgramShape(
94         {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
95     scatter_program_shape_ = ShapeUtil::MakeProgramShape(
96         {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
97   }
98 
99   Shape operand_shape_;
100   Shape source_shape_;
101   Window window_;
102   Shape init_value_shape_;
103   ProgramShape select_program_shape_;
104   ProgramShape scatter_program_shape_;
105 };
106 
TEST_F(ShapeInferenceTest,UnaryNegateMatrix)107 TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
108   Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
109   auto inferred_status =
110       ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape);
111   ASSERT_IS_OK(inferred_status.status());
112   ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie()));
113 }
114 
TEST_F(ShapeInferenceTest,SelectScalarPredBetweenTuples)115 TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
116   Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_});
117   auto inferred_status = ShapeInference::InferTernaryOpShape(
118       HloOpcode::kSelect, pred_, tuple, tuple);
119   ASSERT_FALSE(inferred_status.ok());
120   ASSERT_THAT(inferred_status.status().error_message(),
121               HasSubstr("Expected array argument for select"));
122 }
123 
TEST_F(ShapeInferenceTest,SelectScalarPredBetweenArrays)124 TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
125   auto inferred_status = ShapeInference::InferTernaryOpShape(
126       HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_);
127   ASSERT_FALSE(inferred_status.ok());
128   ASSERT_THAT(
129       inferred_status.status().error_message(),
130       HasSubstr("Operands to select and predicate must be the same shape"));
131 }
132 
TEST_F(ShapeInferenceTest,SelectArrayPredBetweenArrays)133 TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
134   auto predarray = ShapeUtil::MakeShape(PRED, {64, 48});
135   auto inferred_status = ShapeInference::InferTernaryOpShape(
136       HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_);
137   ASSERT_IS_OK(inferred_status.status());
138   ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
139 }
140 
TEST_F(ShapeInferenceTest,SelectBadShapes)141 TEST_F(ShapeInferenceTest, SelectBadShapes) {
142   auto inferred_status_error1 = ShapeInference::InferTernaryOpShape(
143       HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_);
144   ASSERT_FALSE(inferred_status_error1.ok());
145   ASSERT_THAT(inferred_status_error1.status().error_message(),
146               HasSubstr("Operands to select must be the same shape"));
147 
148   auto inferred_status_error2 = ShapeInference::InferTernaryOpShape(
149       HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_);
150   ASSERT_FALSE(inferred_status_error2.ok());
151   ASSERT_THAT(inferred_status_error2.status().error_message(),
152               HasSubstr("pred operand must have PRED"));
153 
154   auto inferred_status_error3 = ShapeInference::InferTernaryOpShape(
155       HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_,
156       matrix_64_48_);
157   ASSERT_FALSE(inferred_status_error3.ok());
158   ASSERT_THAT(
159       inferred_status_error3.status().error_message(),
160       HasSubstr("Operands to select and predicate must be the same shape"));
161 
162   // Tuples have a TUPLE element type and cannot be the pred of a select.
163   auto inferred_status_error4 = ShapeInference::InferTernaryOpShape(
164       HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}),
165       ShapeUtil::MakeTupleShape({f32_, f32_}),
166       ShapeUtil::MakeTupleShape({f32_, f32_}));
167   ASSERT_FALSE(inferred_status_error4.ok());
168   ASSERT_THAT(inferred_status_error4.status().error_message(),
169               HasSubstr("Expected array argument for select pred"));
170 }
171 
TEST_F(ShapeInferenceTest,ClampAllMatrix)172 TEST_F(ShapeInferenceTest, ClampAllMatrix) {
173   auto inferred_status = ShapeInference::InferTernaryOpShape(
174       HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_);
175   ASSERT_IS_OK(inferred_status.status());
176   ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
177 }
178 
TEST_F(ShapeInferenceTest,ClampAllScalar)179 TEST_F(ShapeInferenceTest, ClampAllScalar) {
180   auto inferred_status =
181       ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_);
182   ASSERT_IS_OK(inferred_status.status());
183   ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
184 }
185 
TEST_F(ShapeInferenceTest,ClampMinScalar)186 TEST_F(ShapeInferenceTest, ClampMinScalar) {
187   auto inferred_status = ShapeInference::InferTernaryOpShape(
188       HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_);
189   ASSERT_FALSE(inferred_status.ok());
190   ASSERT_THAT(inferred_status.status().error_message(),
191               HasSubstr("Clamp with different shapes"));
192 }
193 
TEST_F(ShapeInferenceTest,ClampMaxScalar)194 TEST_F(ShapeInferenceTest, ClampMaxScalar) {
195   auto inferred_status = ShapeInference::InferTernaryOpShape(
196       HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_);
197   ASSERT_FALSE(inferred_status.ok());
198   ASSERT_THAT(inferred_status.status().error_message(),
199               HasSubstr("Clamp with different shapes"));
200 }
201 
TEST_F(ShapeInferenceTest,ClampOperandScalar)202 TEST_F(ShapeInferenceTest, ClampOperandScalar) {
203   auto inferred_status = ShapeInference::InferTernaryOpShape(
204       HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_);
205   ASSERT_FALSE(inferred_status.ok());
206   ASSERT_THAT(inferred_status.status().error_message(),
207               HasSubstr("Clamp with different shapes"));
208 }
209 
TEST_F(ShapeInferenceTest,ClampMinMatrix)210 TEST_F(ShapeInferenceTest, ClampMinMatrix) {
211   auto inferred_status = ShapeInference::InferTernaryOpShape(
212       HloOpcode::kClamp, matrix_64_48_, f32_, f32_);
213   ASSERT_FALSE(inferred_status.ok());
214   ASSERT_THAT(inferred_status.status().error_message(),
215               HasSubstr("Clamp with different shapes"));
216 }
217 
TEST_F(ShapeInferenceTest,ClampMaxMatrix)218 TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
219   auto inferred_status = ShapeInference::InferTernaryOpShape(
220       HloOpcode::kClamp, f32_, f32_, matrix_64_48_);
221   ASSERT_FALSE(inferred_status.ok());
222   ASSERT_THAT(inferred_status.status().error_message(),
223               HasSubstr("Clamp with different shapes"));
224 }
225 
TEST_F(ShapeInferenceTest,ClampOperandMatrix)226 TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
227   auto inferred_status = ShapeInference::InferTernaryOpShape(
228       HloOpcode::kClamp, f32_, matrix_64_48_, f32_);
229   ASSERT_FALSE(inferred_status.ok());
230   ASSERT_THAT(inferred_status.status().error_message(),
231               HasSubstr("Clamp with different shapes"));
232 }
233 
TEST_F(ShapeInferenceTest,ClampBadShapes)234 TEST_F(ShapeInferenceTest, ClampBadShapes) {
235   // Type mismatch
236   ASSERT_FALSE(
237       ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, s32_, f32_, f32_)
238           .ok());
239   ASSERT_FALSE(
240       ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, s32_, f32_)
241           .ok());
242   ASSERT_FALSE(
243       ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, s32_)
244           .ok());
245   // Dimension mismatch
246   ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
247                    HloOpcode::kClamp, vector_64_, vector_32_, vector_32_)
248                    .ok());
249   ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
250                    HloOpcode::kClamp, vector_32_, vector_64_, vector_32_)
251                    .ok());
252   ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
253                    HloOpcode::kClamp, vector_32_, vector_32_, vector_64_)
254                    .ok());
255   // Dimension mismatch, where one operand is a scalar
256   ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp,
257                                                    vector_64_, vector_32_, f32_)
258                    .ok());
259   ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp,
260                                                    vector_64_, f32_, vector_32_)
261                    .ok());
262   ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_,
263                                                    vector_64_, vector_32_)
264                    .ok());
265 }
266 
TEST_F(ShapeInferenceTest,Complex)267 TEST_F(ShapeInferenceTest, Complex) {
268   auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
269                            absl::Span<const int64_t> bcast) {
270     return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
271                                               bcast);
272   };
273   // Inputs must be FP.
274   ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
275   ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
276   // Component types must match.
277   ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
278   // Only F32->C64 and F64->C128 supported.
279   ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok());
280   // Validate correct uses.
281   Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
282   TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
283   ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
284   TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
285   ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
286   TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
287   ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
288   TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
289   ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
290 
291   Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
292   TF_ASSERT_OK_AND_ASSIGN(result,
293                           complex_shape(vector_64_, matrix_32_64_, {1}));
294   ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
295   TF_ASSERT_OK_AND_ASSIGN(result,
296                           complex_shape(matrix_32_64_, vector_64_, {1}));
297   ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
298   TF_ASSERT_OK_AND_ASSIGN(result,
299                           complex_shape(matrix_32_64_, matrix_32_64_, {}));
300   ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
301   TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
302   ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
303 
304   TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {}));
305   ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {})));
306 }
307 
TEST_F(ShapeInferenceTest,VariadicOpTuplify)308 TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
309   StatusOr<Shape> result =
310       ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_});
311   ASSERT_IS_OK(result.status());
312   ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
313                                ShapeUtil::MakeTupleShape({s32_, f32_})));
314 }
315 
TEST_F(ShapeInferenceTest,ReduceWindowInHalf)316 TEST_F(ShapeInferenceTest, ReduceWindowInHalf) {
317   Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8});
318   Window window;
319   WindowDimension dim;
320   dim.set_size(2);
321   dim.set_stride(2);
322   dim.set_padding_low(0);
323   dim.set_padding_high(0);
324   dim.set_window_dilation(1);
325   dim.set_base_dilation(1);
326   *window.add_dimensions() = dim;
327   *window.add_dimensions() = dim;
328   Shape window_shape = ShapeUtil::MakeShape(F32, {2, 2});
329   Shape init_value_shape = ShapeUtil::MakeShape(F32, {});
330   Shape float_scalar = ShapeUtil::MakeShape(F32, {});
331   ProgramShape to_apply = ShapeUtil::MakeProgramShape(
332       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
333   auto inferred_status = ShapeInference::InferReduceWindowShape(
334       matrix_shape, init_value_shape, window, to_apply);
335 
336   ASSERT_IS_OK(inferred_status.status());
337   Shape inferred = inferred_status.ValueOrDie();
338   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred));
339 }
340 
TEST_F(SelectAndScatterShapeInferenceTest,SelectAndScatterProperShapes)341 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) {
342   auto inferred_status_ok = ShapeInference::InferSelectAndScatterShape(
343       operand_shape_, select_program_shape_, window_, source_shape_,
344       init_value_shape_, scatter_program_shape_);
345   ASSERT_IS_OK(inferred_status_ok.status());
346   Shape inferred = inferred_status_ok.ValueOrDie();
347   ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, inferred));
348 }
349 
TEST_F(SelectAndScatterShapeInferenceTest,SelectAndScatterWrongSourceShape)350 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
351   Shape source_shape_fail = ShapeUtil::MakeShape(F32, {4, 6});
352   auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
353       operand_shape_, select_program_shape_, window_, source_shape_fail,
354       init_value_shape_, scatter_program_shape_);
355   ASSERT_FALSE(inferred_status_fail.ok());
356   ASSERT_THAT(inferred_status_fail.status().error_message(),
357               HasSubstr("Source shape does not match"));
358 }
359 
TEST_F(SelectAndScatterShapeInferenceTest,SelectAndScatterWrongSelectShape1)360 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
361   ProgramShape select_program_shape_fail =
362       ShapeUtil::MakeProgramShape({ShapeUtil::MakeShape(F32, {})}, pred_);
363   auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
364       operand_shape_, select_program_shape_fail, window_, source_shape_,
365       init_value_shape_, scatter_program_shape_);
366   ASSERT_FALSE(inferred_status_fail.ok());
367   ASSERT_THAT(inferred_status_fail.status().error_message(),
368               HasSubstr("Select function must take 2 parameters"));
369 }
370 
TEST_F(SelectAndScatterShapeInferenceTest,SelectAndScatterWrongSelectShape2)371 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
372   ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
373       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
374   auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
375       operand_shape_, select_program_shape_fail, window_, source_shape_,
376       init_value_shape_, scatter_program_shape_);
377   ASSERT_FALSE(inferred_status_fail.ok());
378   ASSERT_THAT(inferred_status_fail.status().error_message(),
379               HasSubstr("Select function must have rank-0 PRED"));
380 }
381 
TEST_F(SelectAndScatterShapeInferenceTest,SelectAndScatterWrongSelectShape3)382 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
383   ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
384       {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
385   auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
386       operand_shape_, select_program_shape_fail, window_, source_shape_,
387       init_value_shape_, scatter_program_shape_);
388   ASSERT_FALSE(inferred_status_fail.ok());
389   ASSERT_THAT(inferred_status_fail.status().error_message(),
390               HasSubstr("Select function's first parameter"));
391 }
392 
TEST_F(SelectAndScatterShapeInferenceTest,SelectAndScatterWrongSelectShape4)393 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
394   ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
395       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(U32, {})}, pred_);
396   auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
397       operand_shape_, select_program_shape_fail, window_, source_shape_,
398       init_value_shape_, scatter_program_shape_);
399   ASSERT_FALSE(inferred_status_fail.ok());
400   ASSERT_THAT(inferred_status_fail.status().error_message(),
401               HasSubstr("Select function's second parameter"));
402 }
403 
TEST_F(ShapeInferenceTest,AllGatherStart)404 TEST_F(ShapeInferenceTest, AllGatherStart) {
405   const Shape operand = ShapeUtil::MakeShape(F32, {1, 8, 4});
406   const Shape expected_shape = ShapeUtil::MakeTupleShape(
407       {operand, ShapeUtil::MakeShape(F32, {8, 8, 4})});
408 
409   auto inferred_ag_shape = ShapeInference::InferAllGatherStartShape(
410       {&operand}, /*all_gather_dimension=*/0, /*shard_count=*/8);
411   EXPECT_TRUE(inferred_ag_shape.ok());
412   EXPECT_TRUE(ShapeUtil::Equal(inferred_ag_shape.ValueOrDie(), expected_shape));
413 }
414 
TEST_F(ShapeInferenceTest,AllGatherDone)415 TEST_F(ShapeInferenceTest, AllGatherDone) {
416   const Shape input_shape =
417       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1, 8, 4}),
418                                  ShapeUtil::MakeShape(F32, {8, 8, 4})});
419   const Shape expected_shape = ShapeUtil::MakeShape(F32, {8, 8, 4});
420 
421   auto inferred_ag_done_shape =
422       ShapeInference::InferAllGatherDoneShape(input_shape);
423   EXPECT_TRUE(inferred_ag_done_shape.ok());
424   EXPECT_TRUE(
425       ShapeUtil::Equal(inferred_ag_done_shape.ValueOrDie(), expected_shape));
426 }
427 
TEST_F(ShapeInferenceTest,Convolve)428 TEST_F(ShapeInferenceTest, Convolve) {
429   ConvolutionDimensionNumbers dnums;
430 
431   // Dimension order: batch, feature, x0, x1
432   Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
433   dnums.set_input_batch_dimension(0);
434   dnums.set_output_batch_dimension(0);
435   dnums.set_input_feature_dimension(1);
436   dnums.set_output_feature_dimension(1);
437   dnums.add_input_spatial_dimensions(2);
438   dnums.add_output_spatial_dimensions(2);
439   dnums.add_input_spatial_dimensions(3);
440   dnums.add_output_spatial_dimensions(3);
441 
442   // Dimension order: x1, batch, feature, x0
443   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
444   dnums.set_kernel_input_feature_dimension(2);
445   dnums.set_kernel_output_feature_dimension(1);
446   dnums.add_kernel_spatial_dimensions(3);
447   dnums.add_kernel_spatial_dimensions(0);
448 
449   Window window;
450   auto dim0 = window.add_dimensions();
451   auto dim1 = window.add_dimensions();
452   dim0->set_size(3);
453   dim0->set_stride(2);
454   dim0->set_padding_low(1);
455   dim0->set_padding_high(1);
456   dim0->set_window_dilation(1);
457   dim0->set_base_dilation(1);
458   dim1->set_size(2);
459   dim1->set_stride(1);
460   dim1->set_padding_low(0);
461   dim1->set_padding_high(0);
462   dim1->set_window_dilation(1);
463   dim1->set_base_dilation(1);
464   auto inferred_status = ShapeInference::InferConvolveShape(
465       lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
466       window, dnums, /*preferred_element_type=*/std::nullopt);
467   ASSERT_IS_OK(inferred_status.status());
468   Shape inferred_shape = inferred_status.ValueOrDie();
469   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
470                                inferred_shape));
471 }
472 
TEST_F(ShapeInferenceTest,ConvolveWithWindowDilation)473 TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
474   ConvolutionDimensionNumbers dnums;
475 
476   // Dimension order: batch, feature, x0, x1
477   Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4});
478   dnums.set_input_batch_dimension(0);
479   dnums.set_output_batch_dimension(0);
480   dnums.set_input_feature_dimension(1);
481   dnums.set_output_feature_dimension(1);
482   dnums.add_input_spatial_dimensions(2);
483   dnums.add_output_spatial_dimensions(2);
484   dnums.add_input_spatial_dimensions(3);
485   dnums.add_output_spatial_dimensions(3);
486 
487   // Dimension order: x1, batch, feature, x0
488   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
489   dnums.set_kernel_input_feature_dimension(2);
490   dnums.set_kernel_output_feature_dimension(1);
491   dnums.add_kernel_spatial_dimensions(3);
492   dnums.add_kernel_spatial_dimensions(0);
493 
494   Window window;
495   auto dim0 = window.add_dimensions();
496   dim0->set_size(3);
497   dim0->set_stride(3);
498   dim0->set_padding_low(0);
499   dim0->set_padding_high(0);
500   dim0->set_window_dilation(6);
501   dim0->set_base_dilation(1);
502 
503   auto dim1 = window.add_dimensions();
504   dim1->set_size(2);
505   dim1->set_stride(1);
506   dim1->set_padding_low(2);
507   dim1->set_padding_high(1);
508   dim1->set_window_dilation(2);
509   dim1->set_base_dilation(1);
510   auto inferred_status = ShapeInference::InferConvolveShape(
511       lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
512       window, dnums, /*preferred_element_type=*/std::nullopt);
513   ASSERT_IS_OK(inferred_status.status());
514   Shape inferred_shape = inferred_status.ValueOrDie();
515   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
516                                inferred_shape));
517 }
518 
TEST_F(ShapeInferenceTest,ConvolveWithBaseDilation)519 TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
520   ConvolutionDimensionNumbers dnums;
521 
522   // Dimension order: batch, feature, x0, x1
523   Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
524   dnums.set_input_batch_dimension(0);
525   dnums.set_output_batch_dimension(0);
526   dnums.set_input_feature_dimension(1);
527   dnums.set_output_feature_dimension(1);
528   dnums.add_input_spatial_dimensions(2);
529   dnums.add_output_spatial_dimensions(2);
530   dnums.add_input_spatial_dimensions(3);
531   dnums.add_output_spatial_dimensions(3);
532 
533   // Dimension order: x1, batch, feature, x0
534   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4});
535   dnums.set_kernel_input_feature_dimension(2);
536   dnums.set_kernel_output_feature_dimension(1);
537   dnums.add_kernel_spatial_dimensions(3);
538   dnums.add_kernel_spatial_dimensions(0);
539 
540   Window window;
541   auto dim0 = window.add_dimensions();
542   dim0->set_size(4);
543   dim0->set_stride(3);
544   dim0->set_padding_low(0);
545   dim0->set_padding_high(0);
546   dim0->set_window_dilation(1);
547   dim0->set_base_dilation(6);
548 
549   auto dim1 = window.add_dimensions();
550   dim1->set_size(2);
551   dim1->set_stride(1);
552   dim1->set_padding_low(2);
553   dim1->set_padding_high(1);
554   dim1->set_window_dilation(1);
555   dim1->set_base_dilation(2);
556   auto inferred_status = ShapeInference::InferConvolveShape(
557       lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
558       window, dnums, /*preferred_element_type=*/std::nullopt);
559   ASSERT_IS_OK(inferred_status.status());
560   Shape inferred_shape = inferred_status.ValueOrDie();
561   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
562                                inferred_shape));
563 }
564 
TEST_F(ShapeInferenceTest,ConvolveDimensionNumbersOverlapError)565 TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
566   // Dimension order for this test: batch, feature, x0, x1
567   Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
568   Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2});
569 
570   ConvolutionDimensionNumbers dnums;
571   dnums.set_input_batch_dimension(3);
572   dnums.set_output_batch_dimension(3);
573   dnums.set_input_feature_dimension(2);
574   dnums.set_output_feature_dimension(2);
575   dnums.add_input_spatial_dimensions(0);
576   dnums.add_output_spatial_dimensions(0);
577   dnums.add_input_spatial_dimensions(1);
578   dnums.add_output_spatial_dimensions(1);
579   dnums.set_kernel_input_feature_dimension(0);  // duplicated with kernel_x0
580   dnums.set_kernel_output_feature_dimension(3);
581   dnums.add_kernel_spatial_dimensions(0);
582   dnums.add_kernel_spatial_dimensions(1);
583 
584   Window window;
585   auto dim0 = window.add_dimensions();
586   auto dim1 = window.add_dimensions();
587   dim0->set_size(2);
588   dim0->set_stride(1);
589   dim0->set_padding_low(0);
590   dim0->set_padding_high(0);
591   dim1->set_size(3);
592   dim1->set_stride(2);
593   dim1->set_padding_low(1);
594   dim1->set_padding_high(1);
595   auto inferred_status = ShapeInference::InferConvolveShape(
596       lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
597       window, dnums, /*preferred_element_type=*/std::nullopt);
598   ASSERT_FALSE(inferred_status.ok());
599   ASSERT_THAT(inferred_status.status().error_message(),
600               HasSubstr("each dimension exactly once"));
601 }
602 
TEST_F(ShapeInferenceTest,ConvolveBatchGroupCountUnequalOutputFeature)603 TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) {
604   ConvolutionDimensionNumbers dnums;
605   dnums.set_input_batch_dimension(0);
606   dnums.set_input_feature_dimension(1);
607   dnums.add_input_spatial_dimensions(2);
608   dnums.add_input_spatial_dimensions(3);
609   dnums.set_kernel_input_feature_dimension(0);
610   dnums.set_kernel_output_feature_dimension(1);
611   dnums.add_kernel_spatial_dimensions(2);
612   dnums.add_kernel_spatial_dimensions(3);
613   dnums.set_output_batch_dimension(0);
614   dnums.set_output_feature_dimension(1);
615   dnums.add_output_spatial_dimensions(2);
616   dnums.add_output_spatial_dimensions(3);
617   Shape lhs_shape = ShapeUtil::MakeShape(F32, {60, 38, 17, 13});
618   Shape rhs_shape = ShapeUtil::MakeShape(F32, {38, 10, 4, 4});
619   Window window;
620   auto dim0 = window.add_dimensions();
621   auto dim1 = window.add_dimensions();
622   dim0->set_size(4);
623   dim1->set_size(4);
624   dim0->set_padding_low(0);
625   dim0->set_padding_high(2);
626   dim1->set_padding_low(2);
627   dim1->set_padding_high(1);
628   dim0->set_stride(1);
629   dim1->set_stride(1);
630   dim0->set_window_dilation(3);
631   dim1->set_window_dilation(2);
632   auto inferred_status = ShapeInference::InferConvolveShape(
633       lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6,
634       window, dnums, /*preferred_element_type=*/std::nullopt);
635   ASSERT_FALSE(inferred_status.ok());
636   ASSERT_THAT(inferred_status.status().error_message(),
637               HasSubstr("to be a multiple of batch group count"));
638 }
639 
640 struct ConvolveArgs {
641   Shape lhs_shape;
642   Shape rhs_shape;
643   ConvolutionDimensionNumbers dnums;
644   Window window;
645 };
646 
MakeConvolveArgs(PrimitiveType lhs_type,PrimitiveType rhs_type)647 ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) {
648   ConvolveArgs args;
649   ConvolutionDimensionNumbers& dnums = args.dnums;
650 
651   // Dimension order: batch, feature, x0, x1
652   args.lhs_shape = ShapeUtil::MakeShape(lhs_type, {10, 11, 3, 4});
653   dnums.set_input_batch_dimension(0);
654   dnums.set_output_batch_dimension(0);
655   dnums.set_input_feature_dimension(1);
656   dnums.set_output_feature_dimension(1);
657   dnums.add_input_spatial_dimensions(2);
658   dnums.add_output_spatial_dimensions(2);
659   dnums.add_input_spatial_dimensions(3);
660   dnums.add_output_spatial_dimensions(3);
661 
662   // Dimension order: x1, batch, feature, x0
663   args.rhs_shape = ShapeUtil::MakeShape(rhs_type, {2, 12, 11, 3});
664   dnums.set_kernel_input_feature_dimension(2);
665   dnums.set_kernel_output_feature_dimension(1);
666   dnums.add_kernel_spatial_dimensions(3);
667   dnums.add_kernel_spatial_dimensions(0);
668 
669   auto dim0 = args.window.add_dimensions();
670   auto dim1 = args.window.add_dimensions();
671   dim0->set_size(3);
672   dim0->set_stride(2);
673   dim0->set_padding_low(1);
674   dim0->set_padding_high(1);
675   dim0->set_window_dilation(1);
676   dim0->set_base_dilation(1);
677   dim1->set_size(2);
678   dim1->set_stride(1);
679   dim1->set_padding_low(0);
680   dim1->set_padding_high(0);
681   dim1->set_window_dilation(1);
682   dim1->set_base_dilation(1);
683   return args;
684 }
685 
TEST_F(ShapeInferenceTest,ConvolveWithBF16_F16)686 TEST_F(ShapeInferenceTest, ConvolveWithBF16_F16) {
687   ConvolveArgs args = MakeConvolveArgs(BF16, F16);
688   TF_ASSERT_OK_AND_ASSIGN(
689       Shape inferred_shape,
690       ShapeInference::InferConvolveShape(
691           args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
692           /*batch_group_count=*/1, args.window, args.dnums,
693           /*preferred_element_type=*/std::nullopt))
694   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
695                                inferred_shape));
696 }
697 
TEST_F(ShapeInferenceTest,ConvolveWithF16_BF16)698 TEST_F(ShapeInferenceTest, ConvolveWithF16_BF16) {
699   ConvolveArgs args = MakeConvolveArgs(F16, BF16);
700   TF_ASSERT_OK_AND_ASSIGN(
701       Shape inferred_shape,
702       ShapeInference::InferConvolveShape(
703           args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
704           /*batch_group_count=*/1, args.window, args.dnums,
705           /*preferred_element_type=*/std::nullopt))
706   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
707                                inferred_shape));
708 }
709 
TEST_F(ShapeInferenceTest,ConvolveWithS32_U32)710 TEST_F(ShapeInferenceTest, ConvolveWithS32_U32) {
711   ConvolveArgs args = MakeConvolveArgs(S32, U32);
712   TF_ASSERT_OK_AND_ASSIGN(
713       Shape inferred_shape,
714       ShapeInference::InferConvolveShape(
715           args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
716           /*batch_group_count=*/1, args.window, args.dnums,
717           /*preferred_element_type=*/std::nullopt))
718   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
719                                inferred_shape));
720 }
721 
TEST_F(ShapeInferenceTest,ConvolveWithU32_S32)722 TEST_F(ShapeInferenceTest, ConvolveWithU32_S32) {
723   ConvolveArgs args = MakeConvolveArgs(U32, S32);
724   TF_ASSERT_OK_AND_ASSIGN(
725       Shape inferred_shape,
726       ShapeInference::InferConvolveShape(
727           args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
728           /*batch_group_count=*/1, args.window, args.dnums,
729           /*preferred_element_type=*/std::nullopt))
730   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
731                                inferred_shape));
732 }
733 
TEST_F(ShapeInferenceTest,ConvolveWithPreferredElementType)734 TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) {
735   ConvolveArgs args = MakeConvolveArgs(S8, S16);
736   TF_ASSERT_OK_AND_ASSIGN(
737       Shape inferred_shape,
738       ShapeInference::InferConvolveShape(
739           args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
740           /*batch_group_count=*/1, args.window, args.dnums,
741           /*preferred_element_type=*/S16))
742   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S16, {10, 12, 2, 3}),
743                                inferred_shape));
744 }
745 
TEST_F(ShapeInferenceTest,ConvolveWithPreferredElementTypeSameAsInferredType)746 TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementTypeSameAsInferredType) {
747   ConvolveArgs args = MakeConvolveArgs(S8, S16);
748   TF_ASSERT_OK_AND_ASSIGN(
749       Shape inferred_shape,
750       ShapeInference::InferConvolveShape(
751           args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
752           /*batch_group_count=*/1, args.window, args.dnums,
753           /*preferred_element_type=*/S32))
754   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
755                                inferred_shape));
756 }
757 
TEST_F(ShapeInferenceTest,FloatingPointConvolveWithNarrowerPreferredElementType)758 TEST_F(ShapeInferenceTest,
759        FloatingPointConvolveWithNarrowerPreferredElementType) {
760   ConvolveArgs args = MakeConvolveArgs(F32, F32);
761   TF_ASSERT_OK_AND_ASSIGN(
762       Shape inferred_shape,
763       ShapeInference::InferConvolveShape(
764           args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
765           /*batch_group_count=*/1, args.window, args.dnums,
766           /*preferred_element_type=*/BF16))
767   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
768                                inferred_shape));
769 }
770 
TEST_F(ShapeInferenceTest,FloatingPointConvolveWithIntegralPreferredElementType)771 TEST_F(ShapeInferenceTest,
772        FloatingPointConvolveWithIntegralPreferredElementType) {
773   ConvolveArgs args = MakeConvolveArgs(BF16, BF16);
774   TF_ASSERT_OK_AND_ASSIGN(
775       Shape inferred_shape,
776       ShapeInference::InferConvolveShape(
777           args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
778           /*batch_group_count=*/1, args.window, args.dnums,
779           /*preferred_element_type=*/S32));
780   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
781                                inferred_shape));
782 }
783 
TEST_F(ShapeInferenceTest,IntegralConvolveWithFloatingPointPreferredElementType)784 TEST_F(ShapeInferenceTest,
785        IntegralConvolveWithFloatingPointPreferredElementType) {
786   ConvolveArgs args = MakeConvolveArgs(S8, S16);
787   TF_ASSERT_OK_AND_ASSIGN(
788       Shape inferred_shape,
789       ShapeInference::InferConvolveShape(
790           args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
791           /*batch_group_count=*/1, args.window, args.dnums,
792           /*preferred_element_type=*/F32));
793   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
794                                inferred_shape));
795 }
796 
TEST_F(ShapeInferenceTest,ConvolveWithPreferredElementTypeWithDifferentSignedness)797 TEST_F(ShapeInferenceTest,
798        ConvolveWithPreferredElementTypeWithDifferentSignedness) {
799   ConvolveArgs args = MakeConvolveArgs(S8, S16);
800   TF_ASSERT_OK_AND_ASSIGN(
801       Shape inferred_shape,
802       ShapeInference::InferConvolveShape(
803           args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
804           /*batch_group_count=*/1, args.window, args.dnums,
805           /*preferred_element_type=*/U32));
806   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(U32, {10, 12, 2, 3}),
807                                inferred_shape));
808 }
809 
TEST_F(ShapeInferenceTest,ConvolveWithNarrowerPreferredElementType)810 TEST_F(ShapeInferenceTest, ConvolveWithNarrowerPreferredElementType) {
811   ConvolveArgs args = MakeConvolveArgs(S8, S16);
812   auto inferred_status =
813       ShapeInference::InferConvolveShape(
814           args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
815           /*batch_group_count=*/1, args.window, args.dnums,
816           /*preferred_element_type=*/S8)
817           .status();
818   ASSERT_FALSE(inferred_status.ok());
819   ASSERT_THAT(inferred_status.error_message(),
820               HasSubstr("must not be narrower than the original type"));
821 }
822 
823 namespace fft {
824 
825 static const char* unsupported_rank = "only supports ranks 1-3";
826 static const char* invalid_rank = "requires input of at least same rank";
827 static const char* requires_complex_input = "requires complex input type";
828 static const char* requires_f32_input = "requires F32 or F64 input type";
829 static const char* dimensions_match = "innermost dimensions match fft_length";
830 static const char* innermost_dimension_matches =
831     "innermost dimension matches fft_length/2+1";
832 
Pass(const Shape & shape,FftType type,absl::Span<const int64_t> length,const Shape & expected_shape)833 static void Pass(const Shape& shape, FftType type,
834                  absl::Span<const int64_t> length,
835                  const Shape& expected_shape) {
836   auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
837   ASSERT_IS_OK(inferred_status.status());
838   Shape inferred_shape = inferred_status.ValueOrDie();
839   ASSERT_TRUE(ShapeUtil::Equal(inferred_shape, expected_shape));
840 }
841 
Fail(const Shape & shape,FftType type,absl::Span<const int64_t> length,absl::string_view message)842 static void Fail(const Shape& shape, FftType type,
843                  absl::Span<const int64_t> length, absl::string_view message) {
844   auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
845   ASSERT_FALSE(inferred_status.ok());
846   ASSERT_THAT(inferred_status.status().error_message(),
847               HasSubstr(std::string(message)));
848 }
849 
850 }  // namespace fft
851 
TEST_F(ShapeInferenceTest,InferFftShapeTestFftRanks)852 TEST_F(ShapeInferenceTest, InferFftShapeTestFftRanks) {
853   FftType type = FftType::FFT;
854   Shape shape = ShapeUtil::MakeShape(C64, {16, 8});
855   fft::Fail(shape, type, {}, fft::unsupported_rank);
856   fft::Pass(shape, type, {8}, shape);
857   fft::Pass(shape, type, {16, 8}, shape);
858   fft::Fail(shape, type, {32, 16, 8}, fft::invalid_rank);
859   fft::Fail(shape, type, {64, 32, 16, 8}, fft::unsupported_rank);
860 }
861 
TEST_F(ShapeInferenceTest,InferFftShapeTestFftTypes)862 TEST_F(ShapeInferenceTest, InferFftShapeTestFftTypes) {
863   FftType type = FftType::FFT;
864   Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
865   Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
866   fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
867   fft::Pass(shape_c128, type, {16, 8}, shape_c128);
868 }
869 
TEST_F(ShapeInferenceTest,InferFftShapeTestIfftRanks)870 TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) {
871   FftType type = FftType::IFFT;
872   Shape shape = ShapeUtil::MakeShape(C64, {16, 8});
873   fft::Fail(shape, type, {}, fft::unsupported_rank);
874   fft::Pass(shape, type, {8}, shape);
875   fft::Pass(shape, type, {16, 8}, shape);
876   fft::Fail(shape, type, {32, 16, 8}, fft::invalid_rank);
877   fft::Fail(shape, type, {64, 32, 16, 8}, fft::unsupported_rank);
878 }
879 
TEST_F(ShapeInferenceTest,InferFftShapeTestIfftTypes)880 TEST_F(ShapeInferenceTest, InferFftShapeTestIfftTypes) {
881   FftType type = FftType::IFFT;
882   Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
883   Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
884   fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
885   fft::Pass(shape_c128, type, {16, 8}, shape_c128);
886 }
887 
TEST_F(ShapeInferenceTest,InferFftShapeTestRfftRanks)888 TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) {
889   FftType type = FftType::RFFT;
890   Shape shape_in = ShapeUtil::MakeShape(F32, {16, 8});
891   Shape shape_out = ShapeUtil::MakeShape(C64, {16, 5});
892   fft::Fail(shape_in, type, {}, fft::unsupported_rank);
893   fft::Pass(shape_in, type, {8}, shape_out);
894   fft::Pass(shape_in, type, {16, 8}, shape_out);
895   fft::Fail(shape_in, type, {32, 16, 8}, fft::invalid_rank);
896   fft::Fail(shape_in, type, {64, 32, 16, 8}, fft::unsupported_rank);
897 }
898 
TEST_F(ShapeInferenceTest,InferFftShapeTestRfftDimensions)899 TEST_F(ShapeInferenceTest, InferFftShapeTestRfftDimensions) {
900   FftType type = FftType::RFFT;
901   Shape shape = ShapeUtil::MakeShape(F32, {16, 8});
902   fft::Fail(shape, type, {4}, fft::dimensions_match);
903   fft::Fail(shape, type, {16, 4}, fft::dimensions_match);
904   fft::Fail(shape, type, {8, 8}, fft::dimensions_match);
905   fft::Fail(shape, type, {8, 16}, fft::dimensions_match);
906 
907   Shape zero_shape_in = ShapeUtil::MakeShape(F32, {16, 0});
908   Shape zero_shape_out = ShapeUtil::MakeShape(C64, {16, 0});
909   fft::Pass(zero_shape_in, type, {0}, zero_shape_out);
910   fft::Pass(zero_shape_in, type, {16, 0}, zero_shape_out);
911 
912   Shape even_shape_in = ShapeUtil::MakeShape(F32, {16, 8});
913   Shape odd_shape_in = ShapeUtil::MakeShape(F32, {16, 9});
914   Shape shape_out = ShapeUtil::MakeShape(C64, {16, 5});
915   fft::Pass(even_shape_in, type, {16, 8}, shape_out);
916   fft::Pass(odd_shape_in, type, {16, 9}, shape_out);
917 }
918 
TEST_F(ShapeInferenceTest,InferFftShapeTestRfftTypes)919 TEST_F(ShapeInferenceTest, InferFftShapeTestRfftTypes) {
920   FftType type = FftType::RFFT;
921   Shape shape_c64 = ShapeUtil::MakeShape(C64, {16, 8});
922   Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
923   fft::Fail(shape_c64, type, {16, 8}, fft::requires_f32_input);
924   fft::Fail(shape_c128, type, {16, 8}, fft::requires_f32_input);
925 }
926 
TEST_F(ShapeInferenceTest,InferFftShapeTestIrfftRanks)927 TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftRanks) {
928   FftType type = FftType::IRFFT;
929   Shape shape_in = ShapeUtil::MakeShape(C64, {16, 5});
930   Shape shape_out = ShapeUtil::MakeShape(F32, {16, 8});
931   fft::Fail(shape_in, type, {}, fft::unsupported_rank);
932   fft::Pass(shape_in, type, {8}, shape_out);
933   fft::Pass(shape_in, type, {16, 8}, shape_out);
934   fft::Fail(shape_in, type, {32, 16, 8}, fft::invalid_rank);
935   fft::Fail(shape_in, type, {64, 32, 16, 8}, fft::unsupported_rank);
936 }
937 
TEST_F(ShapeInferenceTest,InferFftShapeTestIrfftDimensions)938 TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftDimensions) {
939   FftType type = FftType::IRFFT;
940   Shape shape = ShapeUtil::MakeShape(C64, {16, 5});
941   fft::Fail(shape, type, {5}, fft::innermost_dimension_matches);
942   fft::Fail(shape, type, {16, 5}, fft::innermost_dimension_matches);
943   fft::Fail(shape, type, {8, 8}, fft::dimensions_match);
944   fft::Fail(shape, type, {8, 9}, fft::dimensions_match);
945 
946   Shape zero_shape_in = ShapeUtil::MakeShape(C64, {16, 0});
947   Shape zero_shape_out = ShapeUtil::MakeShape(F32, {16, 0});
948   fft::Pass(zero_shape_in, type, {0}, zero_shape_out);
949   fft::Pass(zero_shape_in, type, {16, 0}, zero_shape_out);
950 
951   Shape even_shape_out = ShapeUtil::MakeShape(F32, {16, 8});
952   Shape odd_shape_out = ShapeUtil::MakeShape(F32, {16, 9});
953   fft::Pass(shape, type, {16, 8}, even_shape_out);
954   fft::Pass(shape, type, {16, 9}, odd_shape_out);
955 }
956 
TEST_F(ShapeInferenceTest,InferFftShapeTestIrfftTypes)957 TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) {
958   FftType type = FftType::IRFFT;
959   Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
960   Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 5});
961   Shape shape_f64_out = ShapeUtil::MakeShape(F64, {16, 8});
962   fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
963   fft::Pass(shape_c128, type, {16, 8}, shape_f64_out);
964 }
965 
TEST_F(ShapeInferenceTest,MapThatChangesElementType)966 TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
967   Shape arg = ShapeUtil::MakeShape(F32, {20});
968   ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_);
969   auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
970   EXPECT_IS_OK(inferred_status.status());
971   Shape expected = ShapeUtil::MakeShape(S32, {20});
972   EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie()));
973 }
974 
TEST_F(ShapeInferenceTest,Map)975 TEST_F(ShapeInferenceTest, Map) {
976   auto inferred_status_r1f32 = ShapeInference::InferMapShape(
977       {&vector_32_, &vector_32_},
978       ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
979   EXPECT_IS_OK(inferred_status_r1f32.status());
980   EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie()));
981 
982   // It's OK to provide a single argument, as long as the applied arity matches
983   // (this degenerates to a Map).
984   auto inferred_status_r1f32_one = ShapeInference::InferMapShape(
985       {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0});
986   EXPECT_IS_OK(inferred_status_r1f32_one.status());
987   EXPECT_TRUE(
988       ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie()));
989 
990   auto inferred_status_r2s32 = ShapeInference::InferMapShape(
991       {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_},
992       ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1});
993   EXPECT_IS_OK(inferred_status_r2s32.status());
994   EXPECT_TRUE(
995       ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie()));
996 
997   auto no_args_error = ShapeInference::InferMapShape(
998       {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {});
999   ASSERT_FALSE(no_args_error.ok());
1000   ASSERT_THAT(no_args_error.status().error_message(),
1001               HasSubstr("expects at least one argument"));
1002 
1003   auto args_diff_shapes_error = ShapeInference::InferMapShape(
1004       {&vector_32_, &vector_64_},
1005       ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
1006   ASSERT_FALSE(args_diff_shapes_error.ok());
1007   ASSERT_THAT(args_diff_shapes_error.status().error_message(),
1008               HasSubstr("requires all operands to have the same shape"));
1009 
1010   auto arity_error = ShapeInference::InferMapShape(
1011       {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_),
1012       {0});
1013   ASSERT_FALSE(arity_error.ok());
1014   ASSERT_THAT(arity_error.status().error_message(),
1015               HasSubstr("function arity must match"));
1016 
1017   auto output_shape_error = ShapeInference::InferMapShape(
1018       {&vector_32_, &vector_32_},
1019       ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0});
1020   ASSERT_FALSE(output_shape_error.ok());
1021   ASSERT_THAT(output_shape_error.status().error_message(),
1022               HasSubstr("result has to be a scalar"));
1023 
1024   auto param_shape_error = ShapeInference::InferMapShape(
1025       {&vector_32_, &vector_32_},
1026       ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0});
1027   ASSERT_FALSE(param_shape_error.ok());
1028   ASSERT_THAT(param_shape_error.status().error_message(),
1029               HasSubstr("parameter has to be a scalar"));
1030 
1031   auto param_element_type_error = ShapeInference::InferMapShape(
1032       {&vector_32_, &vector_32_},
1033       ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0});
1034   ASSERT_FALSE(param_element_type_error.ok());
1035   ASSERT_THAT(param_element_type_error.status().error_message(),
1036               HasSubstr("parameter type has to match argument"));
1037 
1038   Shape arg = ShapeUtil::MakeShape(F32, {20});
1039   ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
1040   auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
1041   EXPECT_IS_OK(inferred_status.status());
1042   EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie()));
1043 
1044   auto inferred_status_error1 = ShapeInference::InferMapShape(
1045       {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
1046   ASSERT_FALSE(inferred_status_error1.ok());
1047   ASSERT_THAT(inferred_status_error1.status().error_message(),
1048               HasSubstr("arity must match number of arguments"));
1049 
1050   auto inferred_status_error2 = ShapeInference::InferMapShape(
1051       {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0});
1052   ASSERT_FALSE(inferred_status_error2.ok());
1053   ASSERT_THAT(inferred_status_error2.status().error_message(),
1054               HasSubstr("has to be a scalar"));
1055 
1056   auto inferred_status_error3 = ShapeInference::InferMapShape(
1057       {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0});
1058   ASSERT_FALSE(inferred_status_error3.ok());
1059   ASSERT_THAT(inferred_status_error3.status().error_message(),
1060               HasSubstr("has to be a scalar"));
1061 
1062   auto inferred_status_error5 = ShapeInference::InferMapShape(
1063       {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0});
1064   ASSERT_FALSE(inferred_status_error5.ok());
1065   ASSERT_THAT(inferred_status_error5.status().error_message(),
1066               HasSubstr("parameter type has to match argument"));
1067 }
1068 
TEST_F(ShapeInferenceTest,MapWithDifferentInputTypes)1069 TEST_F(ShapeInferenceTest, MapWithDifferentInputTypes) {
1070   Shape arg0 = ShapeUtil::MakeShape(F32, {20});
1071   Shape arg1 = ShapeUtil::MakeShape(S32, {20});
1072   ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, s32_}, s32_);
1073   auto inferred_status =
1074       ShapeInference::InferMapShape({&arg0, &arg1}, to_apply, {0});
1075   EXPECT_IS_OK(inferred_status.status());
1076   Shape expected = ShapeUtil::MakeShape(S32, {20});
1077   EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie()));
1078 }
1079 
TEST_F(ReduceShapeInferenceTest,ReduceVectorToScalar)1080 TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
1081   ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {128}),
1082                             /*dimensions_to_reduce=*/{0});
1083 }
1084 
TEST_F(ReduceShapeInferenceTest,ReduceCubeAmongFirstDimension)1085 TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstDimension) {
1086   ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3, 4}),
1087                             ShapeUtil::MakeShape(F32, {2, 3, 4}),
1088                             /*dimensions_to_reduce=*/{0});
1089 }
1090 
TEST_F(ReduceShapeInferenceTest,ReduceCubeAmongMiddleDimension)1091 TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongMiddleDimension) {
1092   ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2, 4}),
1093                             ShapeUtil::MakeShape(F32, {2, 3, 4}),
1094                             /*dimensions_to_reduce=*/{1});
1095 }
1096 
TEST_F(ReduceShapeInferenceTest,ReduceCubeAmongFirstTwoDimensions)1097 TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstTwoDimensions) {
1098   ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {4}),
1099                             ShapeUtil::MakeShape(F32, {2, 3, 4}),
1100                             /*dimensions_to_reduce=*/{0, 1});
1101 }
1102 
TEST_F(ReduceShapeInferenceTest,ReduceCubeAmongLastTwoDimensions)1103 TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongLastTwoDimensions) {
1104   ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2}),
1105                             ShapeUtil::MakeShape(F32, {2, 3, 4}),
1106                             /*dimensions_to_reduce=*/{1, 2});
1107 }
1108 
TEST_F(ReduceShapeInferenceTest,ReduceCubeAmongFirstAndLastDimensions)1109 TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstAndLastDimensions) {
1110   ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}),
1111                             ShapeUtil::MakeShape(F32, {2, 3, 4}),
1112                             /*dimensions_to_reduce=*/{0, 2});
1113 
1114   // Check that the order of dimensions_to_reduce doesn't matter.
1115   ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}),
1116                             ShapeUtil::MakeShape(F32, {2, 3, 4}),
1117                             /*dimensions_to_reduce=*/{2, 0});
1118 }
1119 
TEST_F(ReduceShapeInferenceTest,ReduceCubeAmongAllDimensions)1120 TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
1121   ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {2, 3, 4}),
1122                             /*dimensions_to_reduce=*/{0, 1, 2});
1123 }
1124 
TEST_F(ReduceShapeInferenceTest,ReduceMultiOutput)1125 TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
1126   Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
1127   Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
1128   ProgramShape to_apply = ShapeUtil::MakeProgramShape(
1129       {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
1130   auto inferred_status = ShapeInference::InferReduceShape(
1131       {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
1132   EXPECT_IS_OK(inferred_status.status());
1133   EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}),
1134                                inferred_status.ValueOrDie()));
1135 }
1136 
TEST_F(ReduceShapeInferenceTest,ReduceWindowMultiOutput)1137 TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) {
1138   Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
1139   Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
1140   std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
1141   std::vector<const Shape*> inits = {&f32_, &s32_};
1142   ProgramShape to_apply = ShapeUtil::MakeProgramShape(
1143       {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
1144   std::vector<int64_t> window_dimensions = {1, 2, 4};
1145   std::vector<int64_t> window_strides = {1, 1, 1};
1146   std::vector<std::pair<int64_t, int64_t>> padding_values =
1147       MakePadding(f32_arg_shape.dimensions(), window_dimensions, window_strides,
1148                   Padding::kValid);
1149   TF_ASSERT_OK_AND_ASSIGN(
1150       Window window,
1151       ShapeInference::InferWindowFromDimensions(
1152           window_dimensions, window_strides, padding_values, {}, {}));
1153   auto inferred_status = ShapeInference::InferReduceWindowShape(
1154       absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
1155   VLOG(2) << inferred_status.ValueOrDie().ToString() << "\n";
1156   EXPECT_IS_OK(inferred_status.status());
1157   EXPECT_TRUE(ShapeUtil::Equal(
1158       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {5, 2, 0}),
1159                                  ShapeUtil::MakeShape(S32, {5, 2, 0})}),
1160       inferred_status.ValueOrDie()));
1161 }
1162 
TEST_F(ReduceShapeInferenceTest,ErrorMultiOutputBadReducerInput1)1163 TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
1164   Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
1165   Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
1166   ProgramShape to_apply =
1167       ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_},
1168                                   ShapeUtil::MakeTupleShape({f32_, s32_}));
1169   auto inferred_status = ShapeInference::InferReduceShape(
1170       {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
1171   EXPECT_FALSE(inferred_status.ok());
1172   EXPECT_THAT(inferred_status.status().error_message(),
1173               HasSubstr("must take 4 parameters, but takes 6 parameter(s)"));
1174 }
1175 
TEST_F(ReduceShapeInferenceTest,ErrorMultiOutputBadReducerInput2)1176 TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) {
1177   Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
1178   Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
1179   ProgramShape to_apply = ShapeUtil::MakeProgramShape(
1180       {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
1181   auto inferred_status = ShapeInference::InferReduceShape(
1182       {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
1183   EXPECT_FALSE(inferred_status.ok());
1184   EXPECT_THAT(
1185       inferred_status.status().error_message(),
1186       HasSubstr(
1187           "parameter shape differs from the result shape: s32[] vs f32[]"));
1188 }
1189 
TEST_F(ReduceShapeInferenceTest,ErrorMultiOutputBadReducerInput3)1190 TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
1191   ProgramShape to_apply = ShapeUtil::MakeProgramShape(
1192       {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
1193   auto inferred_status = ShapeInference::InferReduceShape({}, {0, 1}, to_apply);
1194   EXPECT_FALSE(inferred_status.ok());
1195   EXPECT_THAT(inferred_status.status().error_message(),
1196               HasSubstr("must have at least 2 arguments, has 0"));
1197 }
1198 
TEST_F(ReduceShapeInferenceTest,ErrorBadReduceWindowInput)1199 TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) {
1200   Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
1201   Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
1202   std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
1203   std::vector<const Shape*> inits = {&f32_, &s32_};
1204   ProgramShape to_apply = ShapeUtil::MakeProgramShape(
1205       {f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
1206   std::vector<int64_t> window_dimensions = {1, 2, 4};
1207   std::vector<int64_t> window_strides = {1, 1, 1};
1208   std::vector<std::pair<int64_t, int64_t>> padding_values =
1209       MakePadding(f32_arg_shape.dimensions(), window_dimensions, window_strides,
1210                   Padding::kValid);
1211   TF_ASSERT_OK_AND_ASSIGN(
1212       Window window,
1213       ShapeInference::InferWindowFromDimensions(
1214           window_dimensions, window_strides, padding_values, {}, {}));
1215   auto inferred_status = ShapeInference::InferReduceWindowShape(
1216       absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
1217   EXPECT_FALSE(inferred_status.status().ok());
1218   EXPECT_THAT(inferred_status.status().error_message(),
1219               HasSubstr("f32[] vs s32[]"));
1220 }
1221 
TEST_F(ReduceShapeInferenceTest,ErrorMultiOutputBadReducerOutput1)1222 TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
1223   Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
1224   Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
1225   ProgramShape to_apply =
1226       ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_);
1227   auto inferred_status = ShapeInference::InferReduceShape(
1228       {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
1229   EXPECT_FALSE(inferred_status.ok());
1230   EXPECT_THAT(
1231       inferred_status.status().error_message(),
1232       HasSubstr("must produce a tuple with 2 elements, but produces a scalar"));
1233 }
1234 
TEST_F(ReduceShapeInferenceTest,ErrorMultiOutputBadReducerOutput2)1235 TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) {
1236   Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
1237   Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
1238   ProgramShape to_apply = ShapeUtil::MakeProgramShape(
1239       {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_}));
1240   auto inferred_status = ShapeInference::InferReduceShape(
1241       {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
1242   EXPECT_FALSE(inferred_status.ok());
1243   EXPECT_THAT(
1244       inferred_status.status().error_message(),
1245       HasSubstr("must produce a tuple with 2 elements, but has 3 elements"));
1246 }
1247 
TEST_F(ReduceShapeInferenceTest,ErrorMultiOutputBadReducerBoth)1248 TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) {
1249   Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
1250   Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
1251   ProgramShape to_apply = ShapeUtil::MakeProgramShape(
1252       {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_}));
1253   auto inferred_status = ShapeInference::InferReduceShape(
1254       {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
1255   EXPECT_FALSE(inferred_status.ok());
1256   EXPECT_THAT(inferred_status.status().error_message(),
1257               HasSubstr("accumulator shape at index 0 differs from the "
1258                         "init_value shape: s32[] vs f32[]"));
1259 }
1260 
TEST_F(ReduceShapeInferenceTest,ErrorOutOfBoundsDimension)1261 TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
1262   ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
1263   Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
1264   auto inferred_status = ShapeInference::InferReduceShape(
1265       {&arg_shape, &f32_},
1266       /*dimensions_to_reduce=*/{3, 4}, to_apply);
1267   EXPECT_FALSE(inferred_status.ok());
1268   EXPECT_THAT(inferred_status.status().error_message(),
1269               HasSubstr("out-of-bounds dimension"));
1270 }
1271 
TEST_F(ReduceShapeInferenceTest,ErrorToApplyArity)1272 TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
1273   ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
1274   Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
1275   auto inferred_status =
1276       ShapeInference::InferReduceShape({&arg_shape, &f32_},
1277                                        /*dimensions_to_reduce=*/{0}, to_apply);
1278   EXPECT_FALSE(inferred_status.ok());
1279   EXPECT_THAT(inferred_status.status().error_message(),
1280               HasSubstr("take 2 parameters"));
1281 }
1282 
TEST_F(ReduceShapeInferenceTest,ErrorElementTypeVsApplyType)1283 TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
1284   ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);
1285   Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
1286   auto inferred_status =
1287       ShapeInference::InferReduceShape({&arg_shape, &f32_},
1288                                        /*dimensions_to_reduce=*/{0}, to_apply);
1289   EXPECT_FALSE(inferred_status.ok());
1290   EXPECT_THAT(inferred_status.status().error_message(),
1291               HasSubstr("0-th parameter shape differs"));
1292 }
1293 
TEST_F(ReduceShapeInferenceTest,ReduceWithRepeatedReduceDimension)1294 TEST_F(ReduceShapeInferenceTest, ReduceWithRepeatedReduceDimension) {
1295   ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
1296   Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
1297   auto inferred_status = ShapeInference::InferReduceShape(
1298       {&arg_shape, &f32_},
1299       /*dimensions_to_reduce=*/{0, 0}, to_apply);
1300   EXPECT_FALSE(inferred_status.ok());
1301   EXPECT_THAT(inferred_status.status().error_message(),
1302               HasSubstr("Duplicate reduction dimension: 0"));
1303 }
1304 
TEST_F(ShapeInferenceTest,InferSliceShapeRank2)1305 TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
1306   Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
1307   auto inferred_status =
1308       ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {1, 1});
1309   ASSERT_IS_OK(inferred_status.status());
1310   Shape inferred = inferred_status.ValueOrDie();
1311   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), inferred));
1312 }
1313 
TEST_F(ShapeInferenceTest,InferSliceWithDynamicDimensions)1314 TEST_F(ShapeInferenceTest, InferSliceWithDynamicDimensions) {
1315   Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}, {true, true});
1316   auto inferred_status =
1317       ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {33, 64}, {1, 1});
1318   ASSERT_IS_OK(inferred_status.status());
1319   Shape inferred = inferred_status.ValueOrDie();
1320   ASSERT_TRUE(ShapeUtil::Equal(
1321       ShapeUtil::MakeShape(F32, {1, 64}, {false, true}), inferred));
1322 }
1323 
TEST_F(ShapeInferenceTest,InferSliceShapeRank2WithStrides)1324 TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
1325   Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
1326   auto inferred_status =
1327       ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {2, 4});
1328   ASSERT_IS_OK(inferred_status.status());
1329   Shape inferred = inferred_status.ValueOrDie();
1330   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred));
1331 }
1332 
TEST_F(ShapeInferenceTest,InferSliceShapeRank2WithStridesNotIntegral)1333 TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
1334   Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
1335   auto inferred_status =
1336       ShapeInference::InferSliceShape(matrix_shape, {15, 0}, {20, 13}, {2, 4});
1337   ASSERT_IS_OK(inferred_status.status());
1338   Shape inferred = inferred_status.ValueOrDie();
1339   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), inferred));
1340 }
1341 
TEST_F(ShapeInferenceTest,InferInvalidStride)1342 TEST_F(ShapeInferenceTest, InferInvalidStride) {
1343   Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
1344   auto inferred_status =
1345       ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {0, 1});
1346   ASSERT_FALSE(inferred_status.ok());
1347   ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
1348             inferred_status.status().code());
1349 }
1350 
TEST_F(ShapeInferenceTest,InferOobSliceShapeRank2)1351 TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
1352   Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
1353   auto inferred_status =
1354       ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {1, 1});
1355   ASSERT_FALSE(inferred_status.ok());
1356   ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
1357             inferred_status.status().code());
1358 }
1359 
TEST_F(ShapeInferenceTest,InferSliceShapeRank1)1360 TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
1361   Shape vector_shape = ShapeUtil::MakeShape(F32, {17});
1362   auto inferred_status =
1363       ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1});
1364   ASSERT_TRUE(inferred_status.ok());
1365   Shape inferred = inferred_status.ValueOrDie();
1366   ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2})));
1367 }
1368 
TEST_F(ShapeInferenceTest,InferConstIndexShape)1369 TEST_F(ShapeInferenceTest, InferConstIndexShape) {
1370   Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_});
1371   auto inferred0_status =
1372       ShapeInference::InferGetTupleElementShape(tuple_shape, 0);
1373   auto inferred1_status =
1374       ShapeInference::InferGetTupleElementShape(tuple_shape, 1);
1375   ASSERT_IS_OK(inferred0_status.status());
1376   ASSERT_IS_OK(inferred1_status.status());
1377   ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred0_status.ValueOrDie()));
1378   ASSERT_TRUE(ShapeUtil::Equal(s32_, inferred1_status.ValueOrDie()));
1379 }
1380 
TEST_F(ShapeInferenceTest,InferTupleElementShapeOutOfBound)1381 TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) {
1382   Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_});
1383   auto inferredNegative_status =
1384       ShapeInference::InferGetTupleElementShape(tuple_shape, -1);
1385   auto inferred2_status =
1386       ShapeInference::InferGetTupleElementShape(tuple_shape, 2);
1387   ASSERT_FALSE(inferredNegative_status.ok());
1388   ASSERT_FALSE(inferred2_status.ok());
1389   EXPECT_THAT(inferredNegative_status.status().error_message(),
1390               HasSubstr("attempt to index out of tuple bounds"));
1391   EXPECT_THAT(inferred2_status.status().error_message(),
1392               HasSubstr("attempt to index out of tuple bounds"));
1393 }
1394 
TEST_F(ShapeInferenceTest,InferPowShape)1395 TEST_F(ShapeInferenceTest, InferPowShape) {
1396   auto ten_floats = ShapeUtil::MakeShape(F32, {10});
1397   auto inferred_status = ShapeInference::InferBinaryOpShape(
1398       HloOpcode::kPower, ten_floats, f32_, {});
1399   ASSERT_IS_OK(inferred_status.status());
1400   ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie()));
1401 }
1402 
TEST_F(ShapeInferenceTest,InferCompareShape)1403 TEST_F(ShapeInferenceTest, InferCompareShape) {
1404   auto ten_floats = ShapeUtil::MakeShape(F32, {10});
1405   auto inferred_status = ShapeInference::InferBinaryOpShape(
1406       HloOpcode::kCompare, ten_floats, f32_, {});
1407   ASSERT_IS_OK(inferred_status.status());
1408   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
1409                                inferred_status.ValueOrDie()));
1410 }
1411 
TEST_F(ShapeInferenceTest,InferReshapeDegenerateCombine)1412 TEST_F(ShapeInferenceTest, InferReshapeDegenerateCombine) {
1413   // [1, <=1]
1414   //   | reshape
1415   // [<=1]
1416   //
1417   // Both output dimension can be dynamic, use inferred_dimension to tie-break.
1418   auto operand = ShapeUtil::MakeShape(F32, {1, 1}, {false, true});
1419   auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {1},
1420                                                   /*inferred_dimension=*/-1);
1421   ASSERT_EQ(ShapeUtil::MakeShape(F32, {1}, {true}), status.ValueOrDie());
1422 }
1423 
TEST_F(ShapeInferenceTest,InferReshapeSplit)1424 TEST_F(ShapeInferenceTest, InferReshapeSplit) {
1425   // [<=10]
1426   //   | reshape
1427   // [1, 10]
1428   //
1429   // Both output dimension can be dynamic, use inferred_dimension to tie-break.
1430   auto operand = ShapeUtil::MakeShape(F32, {10}, {true});
1431   auto status = ShapeInference::InferReshapeShape(operand, {0}, {1, 10},
1432                                                   /*inferred_dimension=*/0);
1433   ASSERT_EQ(ShapeUtil::MakeShape(F32, {1, 10}, {true, false}),
1434             status.ValueOrDie());
1435 }
1436 
TEST_F(ShapeInferenceTest,InferReshapeCombine)1437 TEST_F(ShapeInferenceTest, InferReshapeCombine) {
1438   // [6, <=10]
1439   //   | reshape
1440   // [<=60]
1441   auto operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
1442   auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {60},
1443                                                   /*inferred_dimension=*/-11);
1444   ASSERT_EQ(ShapeUtil::MakeShape(F32, {60}, {true}), status.ValueOrDie());
1445 }
1446 
TEST_F(ShapeInferenceTest,UnchangedDimension)1447 TEST_F(ShapeInferenceTest, UnchangedDimension) {
1448   // [6, <=10]
1449   //   | reshape
1450   // [2, 3, <=10]
1451   auto operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
1452   auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {2, 3, 10},
1453                                                   /*inferred_dimension=*/-11);
1454   ASSERT_EQ(ShapeUtil::MakeShape(F32, {2, 3, 10}, {false, false, true}),
1455             status.ValueOrDie());
1456 }
1457 
TEST_F(ShapeInferenceTest,InferDynamicBroadcast)1458 TEST_F(ShapeInferenceTest, InferDynamicBroadcast) {
1459   // CHECK:
1460   // %broadcast = s32[15,<=15]{1,0} broadcast(s32[<=15]{0}), dimensions={1}
1461 
1462   auto operand_shape = ShapeUtil::MakeShape(F32, {15}, {true});
1463   auto inferred_status =
1464       ShapeInference::InferBroadcastShape(operand_shape, {15});
1465   ASSERT_IS_OK(inferred_status.status());
1466   Shape inferred = inferred_status.ValueOrDie();
1467   ASSERT_EQ(ShapeUtil::MakeShape(F32, {15, 15}, {false, true}), inferred);
1468 }
1469 
TEST_F(ShapeInferenceTest,BroadcastScalar)1470 TEST_F(ShapeInferenceTest, BroadcastScalar) {
1471   for (auto element_type : {F32, U32, S8}) {
1472     const Shape scalar_shape = ShapeUtil::MakeShape(element_type, {});
1473     {  // no-op scalar broadcast
1474       auto status = ShapeInference::InferBroadcastShape(scalar_shape, {});
1475       ASSERT_IS_OK(status.status());
1476       ASSERT_TRUE(ShapeUtil::Equal(scalar_shape, status.ValueOrDie()));
1477     }
1478     const Shape oned_shape = ShapeUtil::MakeShape(element_type, {3});
1479     {  // scalar -> 1d broadcast
1480       auto status = ShapeInference::InferBroadcastShape(scalar_shape, {3});
1481       ASSERT_IS_OK(status.status());
1482       ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.ValueOrDie()));
1483     }
1484     {  // no-op 1d broadcast
1485       auto status = ShapeInference::InferBroadcastShape(oned_shape, {});
1486       ASSERT_IS_OK(status.status());
1487       ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.ValueOrDie()));
1488     }
1489     const Shape twod_shape = ShapeUtil::MakeShape(element_type, {2, 3});
1490     {  // scalar -> 2d broadcast
1491       auto status = ShapeInference::InferBroadcastShape(scalar_shape, {2, 3});
1492       ASSERT_IS_OK(status.status());
1493       ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.ValueOrDie()));
1494     }
1495     {  // 1d -> 2d broadcast
1496       auto status = ShapeInference::InferBroadcastShape(oned_shape, {2});
1497       ASSERT_IS_OK(status.status());
1498       ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.ValueOrDie()));
1499     }
1500   }
1501 }
1502 
1503 // scalar <dot> vector: ok
TEST_F(ShapeInferenceTest,ScalarDotVector)1504 TEST_F(ShapeInferenceTest, ScalarDotVector) {
1505   DotDimensionNumbers dot_dnums;
1506   auto inferred_status = ShapeInference::InferDotOpShape(
1507       f32_, vector_32_, dot_dnums, /*preferred_element_type=*/std::nullopt);
1508   EXPECT_TRUE(inferred_status.ok());
1509   EXPECT_EQ(inferred_status.ValueOrDie(), vector_32_);
1510 }
1511 
1512 // 3D <dot> 2D: error
TEST_F(ShapeInferenceTest,DotWithRankHigherThanTwo)1513 TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
1514   DotDimensionNumbers dot_dnums;
1515   dot_dnums.add_lhs_contracting_dimensions(1);
1516   dot_dnums.add_rhs_contracting_dimensions(0);
1517   auto inferred_status = ShapeInference::InferDotOpShape(
1518       ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums,
1519       /*preferred_element_type=*/std::nullopt);
1520   EXPECT_TRUE(inferred_status.ok());
1521   EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
1522                                ShapeUtil::MakeShape(F32, {32, 32, 64})));
1523 }
1524 
1525 // vector <dot> vector -> scalar
TEST_F(ShapeInferenceTest,VectorDotVector)1526 TEST_F(ShapeInferenceTest, VectorDotVector) {
1527   DotDimensionNumbers dot_dnums;
1528   dot_dnums.add_lhs_contracting_dimensions(0);
1529   dot_dnums.add_rhs_contracting_dimensions(0);
1530   auto inferred_status =
1531       ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums,
1532                                       /*preferred_element_type=*/std::nullopt);
1533   ASSERT_IS_OK(inferred_status.status());
1534   ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
1535   auto inferred_status_mismatch =
1536       ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums,
1537                                       /*preferred_element_type=*/std::nullopt);
1538   ASSERT_FALSE(inferred_status_mismatch.ok());
1539 }
1540 
1541 // matrix <dot> vector -> vector
TEST_F(ShapeInferenceTest,MatrixDotVector)1542 TEST_F(ShapeInferenceTest, MatrixDotVector) {
1543   DotDimensionNumbers dot_dnums;
1544   dot_dnums.add_lhs_contracting_dimensions(1);
1545   dot_dnums.add_rhs_contracting_dimensions(0);
1546   auto inferred_status =
1547       ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums,
1548                                       /*preferred_element_type=*/std::nullopt);
1549   ASSERT_IS_OK(inferred_status.status());
1550   ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_));
1551   auto inferred_status_mismatch =
1552       ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums,
1553                                       /*preferred_element_type=*/std::nullopt);
1554   ASSERT_FALSE(inferred_status_mismatch.ok());
1555 }
1556 
1557 // vector <dot> matrix -> vector
TEST_F(ShapeInferenceTest,VectorDotMatrix)1558 TEST_F(ShapeInferenceTest, VectorDotMatrix) {
1559   DotDimensionNumbers dot_dnums;
1560   dot_dnums.add_lhs_contracting_dimensions(0);
1561   dot_dnums.add_rhs_contracting_dimensions(0);
1562   auto inferred_status =
1563       ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums,
1564                                       /*preferred_element_type=*/std::nullopt);
1565   ASSERT_IS_OK(inferred_status.status());
1566   ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_));
1567   auto inferred_status_mismatch =
1568       ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums,
1569                                       /*preferred_element_type=*/std::nullopt);
1570   ASSERT_FALSE(inferred_status_mismatch.ok());
1571 }
1572 
1573 // matrix <dot> matrix -> matrix
TEST_F(ShapeInferenceTest,MatrixDotMatrix)1574 TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
1575   DotDimensionNumbers dot_dnums;
1576   dot_dnums.add_lhs_contracting_dimensions(1);
1577   dot_dnums.add_rhs_contracting_dimensions(0);
1578   auto inferred_status_match =
1579       ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums,
1580                                       /*preferred_element_type=*/std::nullopt);
1581   ASSERT_IS_OK(inferred_status_match.status());
1582   ASSERT_TRUE(
1583       ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
1584       << "inferred: "
1585       << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
1586       << " expected: " << ShapeUtil::HumanString(matrix_64_48_);
1587   auto inferred_status_mismatch =
1588       ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums,
1589                                       /*preferred_element_type=*/std::nullopt);
1590   ASSERT_FALSE(inferred_status_mismatch.ok());
1591 }
1592 
1593 // BatchMatMul with two batch dimensions and one contracting dimension.
TEST_F(ShapeInferenceTest,DotGeneral)1594 TEST_F(ShapeInferenceTest, DotGeneral) {
1595   Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3});
1596   Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14});
1597   Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14});
1598 
1599   DotDimensionNumbers dot_dnums;
1600   dot_dnums.add_lhs_contracting_dimensions(3);
1601   dot_dnums.add_lhs_batch_dimensions(0);
1602   dot_dnums.add_lhs_batch_dimensions(1);
1603 
1604   dot_dnums.add_rhs_contracting_dimensions(2);
1605   dot_dnums.add_rhs_batch_dimensions(0);
1606   dot_dnums.add_rhs_batch_dimensions(1);
1607 
1608   auto inferred_status_match =
1609       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1610                                       /*preferred_element_type=*/std::nullopt);
1611   ASSERT_IS_OK(inferred_status_match.status());
1612   ASSERT_TRUE(
1613       ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape))
1614       << "inferred: "
1615       << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
1616       << " expected: " << ShapeUtil::HumanString(output_shape);
1617 }
1618 
1619 // BatchMatMul with two contracting dimensions fails.
TEST_F(ShapeInferenceTest,DotWithTwoContractingDimsFails)1620 TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
1621   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
1622   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
1623 
1624   DotDimensionNumbers dot_dnums;
1625   dot_dnums.add_lhs_contracting_dimensions(2);
1626   dot_dnums.add_lhs_contracting_dimensions(3);
1627   dot_dnums.add_lhs_batch_dimensions(0);
1628 
1629   dot_dnums.add_rhs_contracting_dimensions(1);
1630   dot_dnums.add_rhs_batch_dimensions(0);
1631 
1632   auto inferred_status =
1633       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1634                                       /*preferred_element_type=*/std::nullopt);
1635   ASSERT_FALSE(inferred_status.ok());
1636   ASSERT_THAT(inferred_status.status().error_message(),
1637               HasSubstr("Must specify the same number of contracting "
1638                         "dimensions for lhs and rhs."));
1639 }
1640 
TEST_F(ShapeInferenceTest,DotWithTwoContractingDimsPasses)1641 TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
1642   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
1643   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 14});
1644   Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14});
1645 
1646   DotDimensionNumbers dot_dnums;
1647   dot_dnums.add_lhs_contracting_dimensions(2);
1648   dot_dnums.add_lhs_contracting_dimensions(3);
1649   dot_dnums.add_lhs_batch_dimensions(0);
1650 
1651   dot_dnums.add_rhs_contracting_dimensions(1);
1652   dot_dnums.add_rhs_contracting_dimensions(2);
1653   dot_dnums.add_rhs_batch_dimensions(0);
1654 
1655   auto inferred_status =
1656       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1657                                       /*preferred_element_type=*/std::nullopt);
1658   EXPECT_TRUE(inferred_status.ok());
1659   EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape));
1660 }
1661 
TEST_F(ShapeInferenceTest,ErrorSetDimensionSize)1662 TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) {
1663   Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
1664   Shape val_shape = ShapeUtil::MakeShape(S32, {1});
1665   auto inferred_status = ShapeInference::InferSetDimensionSizeShape(
1666       arg_shape, val_shape, /*dimension=*/0);
1667 
1668   EXPECT_FALSE(inferred_status.ok());
1669   EXPECT_THAT(inferred_status.status().error_message(),
1670               HasSubstr("value has to be S32 scalar"));
1671 }
1672 
TEST_F(ShapeInferenceTest,ErrorSetDimensionSizeWrongType)1673 TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) {
1674   Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
1675   Shape val_shape = ShapeUtil::MakeShape(U32, {});
1676   auto inferred_status = ShapeInference::InferSetDimensionSizeShape(
1677       arg_shape, val_shape, /*dimension=*/0);
1678 
1679   EXPECT_FALSE(inferred_status.ok());
1680   EXPECT_THAT(inferred_status.status().error_message(),
1681               HasSubstr("value has to be S32 scalar"));
1682 }
1683 
1684 // BatchMatMul with different batch dimension sizes fails.
TEST_F(ShapeInferenceTest,DotWithMismatchedBatchDimSizesFails)1685 TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
1686   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1687   Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14});
1688 
1689   DotDimensionNumbers dot_dnums;
1690   dot_dnums.add_lhs_contracting_dimensions(2);
1691   dot_dnums.add_lhs_batch_dimensions(0);
1692 
1693   dot_dnums.add_rhs_contracting_dimensions(1);
1694   dot_dnums.add_rhs_batch_dimensions(0);
1695 
1696   auto inferred_status =
1697       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1698                                       /*preferred_element_type=*/std::nullopt);
1699   ASSERT_FALSE(inferred_status.ok());
1700   ASSERT_THAT(inferred_status.status().error_message(),
1701               HasSubstr("Batch dimension sizes must match"));
1702 }
1703 
1704 // BatchMatMul with different batch dimension numbers passes
TEST_F(ShapeInferenceTest,DotWithMismatchedBatchDimNumbersPasses)1705 TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) {
1706   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1707   Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14});
1708 
1709   DotDimensionNumbers dot_dnums;
1710   dot_dnums.add_lhs_contracting_dimensions(2);
1711   dot_dnums.add_lhs_batch_dimensions(0);
1712 
1713   dot_dnums.add_rhs_contracting_dimensions(0);
1714   dot_dnums.add_rhs_batch_dimensions(1);
1715 
1716   auto inferred_status =
1717       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1718                                       /*preferred_element_type=*/std::nullopt);
1719   ASSERT_TRUE(inferred_status.ok());
1720   ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
1721                                ShapeUtil::MakeShape(F32, {2, 11, 14})));
1722 }
1723 
1724 // BatchMatMul with out-of-range dimension numbers fails.
TEST_F(ShapeInferenceTest,DotWithContractingDimNumberOutOfRange)1725 TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
1726   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1727   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
1728 
1729   DotDimensionNumbers dot_dnums;
1730   dot_dnums.add_lhs_contracting_dimensions(3);
1731   dot_dnums.add_lhs_batch_dimensions(0);
1732 
1733   dot_dnums.add_rhs_contracting_dimensions(0);
1734   dot_dnums.add_rhs_batch_dimensions(1);
1735 
1736   auto inferred_status =
1737       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1738                                       /*preferred_element_type=*/std::nullopt);
1739   ASSERT_FALSE(inferred_status.ok());
1740   ASSERT_THAT(inferred_status.status().error_message(),
1741               HasSubstr("A dimension number is out of range"));
1742 }
1743 
1744 // BatchMatMul with non-unique dimension numbers fails.
TEST_F(ShapeInferenceTest,DotWithContractingNonUniqueDimNumber)1745 TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
1746   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1747   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
1748 
1749   DotDimensionNumbers dot_dnums;
1750   dot_dnums.add_lhs_contracting_dimensions(0);
1751   dot_dnums.add_lhs_batch_dimensions(0);
1752 
1753   dot_dnums.add_rhs_contracting_dimensions(0);
1754   dot_dnums.add_rhs_batch_dimensions(1);
1755 
1756   auto inferred_status =
1757       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1758                                       /*preferred_element_type=*/std::nullopt);
1759   ASSERT_FALSE(inferred_status.ok());
1760   ASSERT_THAT(inferred_status.status().error_message(),
1761               HasSubstr("A dimension number is not unique"));
1762 }
1763 
TEST_F(ShapeInferenceTest,DotWithIntegralPreferredElementType)1764 TEST_F(ShapeInferenceTest, DotWithIntegralPreferredElementType) {
1765   DotDimensionNumbers dot_dnums;
1766   dot_dnums.add_lhs_contracting_dimensions(1);
1767   dot_dnums.add_rhs_contracting_dimensions(0);
1768   TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
1769                           ShapeInference::InferDotOpShape(
1770                               ShapeUtil::MakeShape(S8, {32, 32}),
1771                               ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
1772                               /*preferred_element_type=*/S32));
1773   EXPECT_TRUE(
1774       ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(S32, {32, 32})));
1775 }
1776 
TEST_F(ShapeInferenceTest,DotWithPreferredElementTypeSameAsInferredType)1777 TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeSameAsInferredType) {
1778   DotDimensionNumbers dot_dnums;
1779   dot_dnums.add_lhs_contracting_dimensions(1);
1780   dot_dnums.add_rhs_contracting_dimensions(0);
1781   TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
1782                           ShapeInference::InferDotOpShape(
1783                               ShapeUtil::MakeShape(BF16, {32, 32}),
1784                               ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums,
1785                               /*preferred_element_type=*/F32));
1786   EXPECT_TRUE(
1787       ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32})));
1788 }
1789 
TEST_F(ShapeInferenceTest,FloatingPointDotWithNarrowerPreferredElementType)1790 TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) {
1791   DotDimensionNumbers dot_dnums;
1792   dot_dnums.add_lhs_contracting_dimensions(1);
1793   dot_dnums.add_rhs_contracting_dimensions(0);
1794   TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
1795                           ShapeInference::InferDotOpShape(
1796                               ShapeUtil::MakeShape(BF16, {32, 32}),
1797                               ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums,
1798                               /*preferred_element_type=*/BF16));
1799   EXPECT_TRUE(
1800       ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(BF16, {32, 32})));
1801 }
1802 
TEST_F(ShapeInferenceTest,FloatingPointDotWithIntegralPreferredElementType)1803 TEST_F(ShapeInferenceTest, FloatingPointDotWithIntegralPreferredElementType) {
1804   DotDimensionNumbers dot_dnums;
1805   dot_dnums.add_lhs_contracting_dimensions(1);
1806   dot_dnums.add_rhs_contracting_dimensions(0);
1807   TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
1808                           ShapeInference::InferDotOpShape(
1809                               ShapeUtil::MakeShape(BF16, {32, 32}),
1810                               ShapeUtil::MakeShape(BF16, {32, 32}), dot_dnums,
1811                               /*preferred_element_type=*/S32));
1812   EXPECT_TRUE(
1813       ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(S32, {32, 32})));
1814 }
1815 
TEST_F(ShapeInferenceTest,IntegralDotWithFloatingPointPreferredElementType)1816 TEST_F(ShapeInferenceTest, IntegralDotWithFloatingPointPreferredElementType) {
1817   DotDimensionNumbers dot_dnums;
1818   dot_dnums.add_lhs_contracting_dimensions(1);
1819   dot_dnums.add_rhs_contracting_dimensions(0);
1820   TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
1821                           ShapeInference::InferDotOpShape(
1822                               ShapeUtil::MakeShape(S8, {32, 32}),
1823                               ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
1824                               /*preferred_element_type=*/F32));
1825   EXPECT_TRUE(
1826       ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32})));
1827 }
1828 
TEST_F(ShapeInferenceTest,DotWithPreferredElementTypeWithDifferentSignedness)1829 TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeWithDifferentSignedness) {
1830   DotDimensionNumbers dot_dnums;
1831   dot_dnums.add_lhs_contracting_dimensions(1);
1832   dot_dnums.add_rhs_contracting_dimensions(0);
1833   TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
1834                           ShapeInference::InferDotOpShape(
1835                               ShapeUtil::MakeShape(S8, {32, 32}),
1836                               ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
1837                               /*preferred_element_type=*/U32));
1838   EXPECT_TRUE(
1839       ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(U32, {32, 32})));
1840 }
1841 
TEST_F(ShapeInferenceTest,DotWithNarrowerPreferredElementType)1842 TEST_F(ShapeInferenceTest, DotWithNarrowerPreferredElementType) {
1843   DotDimensionNumbers dot_dnums;
1844   dot_dnums.add_lhs_contracting_dimensions(1);
1845   dot_dnums.add_rhs_contracting_dimensions(0);
1846   auto inferred_status = ShapeInference::InferDotOpShape(
1847                              ShapeUtil::MakeShape(S8, {32, 32}),
1848                              ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
1849                              /*preferred_element_type=*/S8)
1850                              .status();
1851   ASSERT_FALSE(inferred_status.ok());
1852   ASSERT_THAT(inferred_status.error_message(),
1853               HasSubstr("must not be narrower than the original type"));
1854 }
1855 
TEST_F(ShapeInferenceTest,BinOpBroadcastMatrixVector)1856 TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
1857   // Test variations of broadcasting a vector for a binary add with a
1858   // matrix.
1859   const Shape mat = ShapeUtil::MakeShape(F32, {16, 8});
1860   const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
1861   const Shape vec16 = ShapeUtil::MakeShape(F32, {16});
1862 
1863   auto inferred_status_match =
1864       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1});
1865   ASSERT_IS_OK(inferred_status_match.status());
1866   ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat));
1867 
1868   auto inferred_status_mismatch =
1869       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0});
1870   ASSERT_FALSE(inferred_status_mismatch.ok());
1871 
1872   inferred_status_match =
1873       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0});
1874   ASSERT_IS_OK(inferred_status_match.status());
1875   ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat));
1876 
1877   inferred_status_mismatch =
1878       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1});
1879   ASSERT_FALSE(inferred_status_mismatch.ok());
1880 }
1881 
TEST_F(ShapeInferenceTest,BinOpBroadcastCubeMatrix)1882 TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) {
1883   // Test variations of broadcasting a matrix for a binary add with a cube.
1884   const Shape cube = ShapeUtil::MakeShape(F32, {16, 8, 4});
1885   const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
1886   const Shape matrix16_4 = ShapeUtil::MakeShape(F32, {16, 4});
1887   const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8});
1888 
1889   auto inferred_status_match = ShapeInference::InferBinaryOpShape(
1890       HloOpcode::kAdd, cube, matrix8_4, {1, 2});
1891   ASSERT_IS_OK(inferred_status_match.status());
1892   ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
1893 
1894   inferred_status_match = ShapeInference::InferBinaryOpShape(
1895       HloOpcode::kAdd, cube, matrix16_4, {0, 2});
1896   ASSERT_IS_OK(inferred_status_match.status());
1897   ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
1898 
1899   inferred_status_match = ShapeInference::InferBinaryOpShape(
1900       HloOpcode::kAdd, cube, matrix16_8, {0, 1});
1901   ASSERT_IS_OK(inferred_status_match.status());
1902   ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
1903 }
1904 
TEST_F(ShapeInferenceTest,BinOpBroadcastBadDimension)1905 TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
1906   // Test various errors with the broadcast argument.
1907   const Shape tensor = ShapeUtil::MakeShape(F32, {16, 8, 4});
1908   const Shape tensor8_8_8 = ShapeUtil::MakeShape(F32, {8, 8, 8});
1909   const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
1910   const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
1911   const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8});
1912 
1913   // "magical" broadcast rejected
1914   auto inferred_status_error1 =
1915       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {});
1916   ASSERT_FALSE(inferred_status_error1.ok());
1917   ASSERT_THAT(inferred_status_error1.status().error_message(),
1918               HasSubstr("Shapes must be equal rank"));
1919 
1920   // broadcast_dimension out of bounds for tensor's rank
1921   auto inferred_status_error2 =
1922       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3});
1923   ASSERT_FALSE(inferred_status_error2.ok());
1924   ASSERT_THAT(inferred_status_error2.status().error_message(),
1925               ContainsRegex("Broadcast dimension number .* too large"));
1926 
1927   // broadcast_dimension doesn't match corresponding dimension
1928   auto inferred_status_error3 =
1929       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0});
1930   ASSERT_FALSE(inferred_status_error3.ok());
1931   ASSERT_THAT(inferred_status_error3.status().error_message(),
1932               HasSubstr("Broadcast dimension 0 mismatch"));
1933 
1934   // broadcast_dimensions list too long
1935   auto inferred_status_error4 = ShapeInference::InferBinaryOpShape(
1936       HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2});
1937   ASSERT_FALSE(inferred_status_error4.ok());
1938   ASSERT_THAT(inferred_status_error4.status().error_message(),
1939               HasSubstr("broadcast_dimensions has to match"));
1940 
1941   // there's a dimension above the rank of the tensor
1942   auto inferred_status_error5 = ShapeInference::InferBinaryOpShape(
1943       HloOpcode::kAdd, tensor, matrix8_4, {3, 0});
1944   ASSERT_FALSE(inferred_status_error5.ok());
1945   ASSERT_THAT(inferred_status_error5.status().error_message(),
1946               ContainsRegex("dimension number .* too large"));
1947 
1948   // broadcasting dimensions don't match in this order
1949   auto inferred_status_error6 = ShapeInference::InferBinaryOpShape(
1950       HloOpcode::kAdd, tensor, matrix8_4, {2, 1});
1951   ASSERT_FALSE(inferred_status_error6.ok());
1952   ASSERT_THAT(inferred_status_error6.status().error_message(),
1953               HasSubstr("dimension 0 mismatch"));
1954 
1955   // The following two tests make sure that broadcasting dimensions are listed
1956   // in a proper (strictly increasing) order, even if the lower-rank array
1957   // matches the higher-rank array in many different ways.
1958   auto inferred_status_error7 = ShapeInference::InferBinaryOpShape(
1959       HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0});
1960   ASSERT_FALSE(inferred_status_error7.ok());
1961   ASSERT_THAT(inferred_status_error7.status().error_message(),
1962               HasSubstr("dimensions order is wrong"));
1963 
1964   auto inferred_status_error8 = ShapeInference::InferBinaryOpShape(
1965       HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0});
1966   ASSERT_FALSE(inferred_status_error8.ok());
1967   ASSERT_THAT(inferred_status_error8.status().error_message(),
1968               HasSubstr("dimensions order is wrong"));
1969 }
1970 
1971 // Tests for the while instruction with proper shapes.
TEST_F(ShapeInferenceTest,WhileWithCorrectShapes)1972 TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) {
1973   Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
1974   ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
1975   ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
1976   auto inferred_status =
1977       ShapeInference::InferWhileShape(cond, body, result_shape);
1978   ASSERT_IS_OK(inferred_status.status());
1979   Shape inferred = inferred_status.ValueOrDie();
1980   ASSERT_TRUE(ShapeUtil::Equal(result_shape, inferred));
1981 }
1982 
1983 // Tests for the while instruction with wrong shapes.
TEST_F(ShapeInferenceTest,WhileWithBadShapes)1984 TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
1985   Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
1986   ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
1987   ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
1988 
1989   auto bad_shape_1 = ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_);
1990   auto inferred_status_error1 =
1991       ShapeInference::InferWhileShape(bad_shape_1, body, result_shape);
1992   ASSERT_FALSE(inferred_status_error1.ok());
1993   ASSERT_THAT(inferred_status_error1.status().error_message(),
1994               HasSubstr("Condition must take 1 arguments"));
1995 
1996   auto bad_shape_2 =
1997       ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape);
1998   auto inferred_status_error2 =
1999       ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape);
2000   ASSERT_FALSE(inferred_status_error2.ok());
2001   ASSERT_THAT(inferred_status_error2.status().error_message(),
2002               HasSubstr("Body must take 1 arguments"));
2003 
2004   auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_);
2005   auto inferred_status_error3 =
2006       ShapeInference::InferWhileShape(bad_shape_3, body, result_shape);
2007   ASSERT_FALSE(inferred_status_error3.ok());
2008   ASSERT_THAT(inferred_status_error3.status().error_message(),
2009               HasSubstr("Condition must return a boolean"));
2010 
2011   auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_);
2012   auto inferred_status_error4 =
2013       ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape);
2014   ASSERT_FALSE(inferred_status_error4.ok());
2015   ASSERT_THAT(inferred_status_error4.status().error_message(),
2016               HasSubstr("parameter of condition and body"));
2017 }
2018 
2019 // Tests for the concatenate instruction with dynamic shapes.
TEST_F(ShapeInferenceTest,ConcatenateWithDynamicShapes)2020 TEST_F(ShapeInferenceTest, ConcatenateWithDynamicShapes) {
2021   auto dynamic_shape_1 =
2022       ShapeUtil::MakeShape(F32, {32, 160, 10}, {true, false, false});
2023   auto dynamic_shape_2 =
2024       ShapeUtil::MakeShape(F32, {32, 160, 10}, {false, true, false});
2025   auto inferred_status = ShapeInference::InferConcatOpShape(
2026       {&dynamic_shape_1, &dynamic_shape_2}, /*dimension=*/0);
2027   ASSERT_IS_OK(inferred_status.status());
2028   Shape inferred = inferred_status.ValueOrDie();
2029   ASSERT_TRUE(ShapeUtil::Equal(
2030       ShapeUtil::MakeShape(F32, {64, 160, 10}, {true, true, false}), inferred));
2031 }
2032 
2033 // Tests for the concatenate instruction with proper shapes.
TEST_F(ShapeInferenceTest,ConcatenateWithCorrectShapes)2034 TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) {
2035   auto inferred_status_1 = ShapeInference::InferConcatOpShape(
2036       {&vector_32_, &vector_64_}, /*dimension=*/0);
2037   ASSERT_IS_OK(inferred_status_1.status());
2038   Shape inferred_1 = inferred_status_1.ValueOrDie();
2039   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), inferred_1));
2040 
2041   auto inferred_status_2 = ShapeInference::InferConcatOpShape(
2042       {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0);
2043   ASSERT_IS_OK(inferred_status_2.status());
2044   Shape inferred_2 = inferred_status_2.ValueOrDie();
2045   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), inferred_2));
2046 
2047   auto inferred_status_3 = ShapeInference::InferConcatOpShape(
2048       {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1);
2049   ASSERT_IS_OK(inferred_status_3.status());
2050   Shape inferred_3 = inferred_status_3.ValueOrDie();
2051   ASSERT_TRUE(
2052       ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}), inferred_3));
2053 }
2054 
2055 // Tests for the concatenate instruction with wrong shapes.
TEST_F(ShapeInferenceTest,ConcatenateWithBadShapes)2056 TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
2057   auto inferred_status_error1 =
2058       ShapeInference::InferConcatOpShape({}, /*dimension=*/0);
2059   ASSERT_FALSE(inferred_status_error1.ok());
2060   ASSERT_THAT(inferred_status_error1.status().error_message(),
2061               HasSubstr("Concatenate expects at least one argument"));
2062 
2063   auto inferred_status_error2 =
2064       ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1);
2065   ASSERT_FALSE(inferred_status_error2.ok());
2066   ASSERT_THAT(inferred_status_error2.status().error_message(),
2067               HasSubstr("dimension out of bounds: -1"));
2068 
2069   auto inferred_status_error3 =
2070       ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1);
2071   ASSERT_FALSE(inferred_status_error3.ok());
2072   ASSERT_THAT(inferred_status_error3.status().error_message(),
2073               HasSubstr("dimension out of bounds: 1"));
2074 
2075   Shape tuple = ShapeUtil::MakeTupleShape({vector_32_});
2076   auto inferred_status_error4 = ShapeInference::InferConcatOpShape(
2077       {&vector_32_, &tuple}, /*dimension=*/0);
2078   ASSERT_FALSE(inferred_status_error4.ok());
2079   ASSERT_THAT(
2080       inferred_status_error4.status().error_message(),
2081       HasSubstr("Expected array argument for operand of concatenation"));
2082 
2083   const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
2084   auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
2085       {&vector_32_, &vector_s32}, /*dimension=*/0);
2086   ASSERT_FALSE(inferred_status_error5.ok());
2087   ASSERT_THAT(inferred_status_error5.status().error_message(),
2088               HasSubstr("concatenate arrays with different element types"));
2089 
2090   auto inferred_status_error6 = ShapeInference::InferConcatOpShape(
2091       {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0);
2092   ASSERT_FALSE(inferred_status_error6.ok());
2093   ASSERT_THAT(inferred_status_error6.status().error_message(),
2094               HasSubstr("concatenate arrays that differ in "
2095                         "dimensions other than the one being "
2096                         "concatenated"));
2097 }
2098 
TEST_F(ShapeInferenceTest,Pad)2099 TEST_F(ShapeInferenceTest, Pad) {
2100   Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
2101   Shape padding_value_shape = ShapeUtil::MakeShape(F32, {});
2102   // Padding for dimension 0: {low: 0, high: 2, interior: 3}
2103   // Padding for dimension 1: {low: 1, high: 5, interior: 0}
2104   PaddingConfig padding_config;
2105   auto dimension0 = padding_config.add_dimensions();
2106   dimension0->set_edge_padding_low(0);
2107   dimension0->set_edge_padding_high(2);
2108   dimension0->set_interior_padding(3);
2109   auto dimension1 = padding_config.add_dimensions();
2110   dimension1->set_edge_padding_low(1);
2111   dimension1->set_edge_padding_high(5);
2112   dimension1->set_interior_padding(0);
2113 
2114   auto inferred_status = ShapeInference::InferPadShape(
2115       input_shape, padding_value_shape, padding_config);
2116   ASSERT_IS_OK(inferred_status.status());
2117   Shape inferred_shape = inferred_status.ValueOrDie();
2118   ASSERT_TRUE(
2119       ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), inferred_shape));
2120 
2121   dimension1->set_edge_padding_low(-20);
2122   dimension1->set_edge_padding_high(-10);
2123   auto negative_dimension_size = ShapeInference::InferPadShape(
2124       input_shape, padding_value_shape, padding_config);
2125   ASSERT_FALSE(negative_dimension_size.ok());
2126   ASSERT_THAT(negative_dimension_size.status().error_message(),
2127               HasSubstr("negative size for dimension 1"));
2128 }
2129 
TEST_F(ShapeInferenceTest,Reverse)2130 TEST_F(ShapeInferenceTest, Reverse) {
2131   Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
2132 
2133   auto inferred_status = ShapeInference::InferReverseShape(input_shape, {0, 1});
2134   ASSERT_IS_OK(inferred_status.status());
2135   Shape inferred_shape = inferred_status.ValueOrDie();
2136   ASSERT_TRUE(ShapeUtil::Equal(input_shape, inferred_shape));
2137 }
2138 
TEST_F(ShapeInferenceTest,ReverseInvalidDimension)2139 TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
2140   Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
2141 
2142   auto inferred_status_error0 =
2143       ShapeInference::InferReverseShape(input_shape, {0, 2});
2144   ASSERT_FALSE(inferred_status_error0.ok());
2145   ASSERT_THAT(inferred_status_error0.status().error_message(),
2146               HasSubstr("out-of-bounds"));
2147 
2148   auto inferred_status_error1 =
2149       ShapeInference::InferReverseShape(input_shape, {0, -1});
2150   ASSERT_FALSE(inferred_status_error1.ok());
2151   ASSERT_THAT(inferred_status_error1.status().error_message(),
2152               HasSubstr("out-of-bounds"));
2153 
2154   auto inferred_status_error2 =
2155       ShapeInference::InferReverseShape(input_shape, {0, 0});
2156   ASSERT_FALSE(inferred_status_error2.ok());
2157   ASSERT_THAT(inferred_status_error2.status().error_message(),
2158               HasSubstr("duplicated"));
2159 
2160   Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});
2161   auto inferred_status_error3 =
2162       ShapeInference::InferReverseShape(tuple_shape, {0});
2163   ASSERT_FALSE(inferred_status_error3.ok());
2164   ASSERT_THAT(inferred_status_error3.status().error_message(),
2165               HasSubstr("Expected array argument"));
2166 }
2167 
TEST_F(ShapeInferenceTest,Call)2168 TEST_F(ShapeInferenceTest, Call) {
2169   auto inferred_status0 =
2170       ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_));
2171   EXPECT_IS_OK(inferred_status0.status());
2172   EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
2173 
2174   auto inferred_status1 = ShapeInference::InferCallShape(
2175       {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_},
2176       ShapeUtil::MakeProgramShape(
2177           {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_));
2178   EXPECT_IS_OK(inferred_status1.status());
2179   EXPECT_TRUE(
2180       ShapeUtil::Equal(s32matrix_64_64_, inferred_status1.ValueOrDie()));
2181 
2182   auto inferred_status_error0 = ShapeInference::InferCallShape(
2183       {}, ShapeUtil::MakeProgramShape({f32_}, f32_));
2184   EXPECT_FALSE(inferred_status_error0.ok());
2185   EXPECT_THAT(inferred_status_error0.status().error_message(),
2186               HasSubstr("arity must match"));
2187 
2188   auto inferred_status_error1 = ShapeInference::InferCallShape(
2189       {&f32_}, ShapeUtil::MakeProgramShape({}, f32_));
2190   EXPECT_FALSE(inferred_status_error1.ok());
2191   EXPECT_THAT(inferred_status_error1.status().error_message(),
2192               HasSubstr("arity must match"));
2193 
2194   auto inferred_status_error2 = ShapeInference::InferCallShape(
2195       {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_));
2196   EXPECT_FALSE(inferred_status_error2.ok());
2197   EXPECT_THAT(inferred_status_error2.status().error_message(),
2198               HasSubstr("parameter must match argument"));
2199 }
2200 
TEST_F(ShapeInferenceTest,Transpose)2201 TEST_F(ShapeInferenceTest, Transpose) {
2202   Shape a_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
2203   auto inferred_shape_and_status =
2204       ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0});
2205   EXPECT_IS_OK(inferred_shape_and_status);
2206   Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
2207   EXPECT_TRUE(ShapeUtil::Compatible(inferred_shape,
2208                                     ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
2209 }
2210 
TEST_F(ShapeInferenceTest,Rank1Transpose)2211 TEST_F(ShapeInferenceTest, Rank1Transpose) {
2212   Shape a_shape = ShapeUtil::MakeShape(F32, {5});
2213   auto inferred_shape_and_status =
2214       ShapeInference::InferTransposeShape(a_shape, {0});
2215   EXPECT_IS_OK(inferred_shape_and_status);
2216   Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
2217   EXPECT_TRUE(
2218       ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5})));
2219 }
2220 
TEST_F(ShapeInferenceTest,ConditionalPred)2221 TEST_F(ShapeInferenceTest, ConditionalPred) {
2222   auto inferred_status0 = ShapeInference::InferConditionalShape(
2223       pred_,
2224       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2225        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2226       {vector_32_, vector_64_});
2227   EXPECT_IS_OK(inferred_status0.status());
2228   EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
2229 
2230   auto inferred_status1 = ShapeInference::InferConditionalShape(
2231       pred_,
2232       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
2233        ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)},
2234       {matrix_32_48_, vector_32_});
2235   EXPECT_IS_OK(inferred_status1.status());
2236   EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));
2237 
2238   auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
2239   auto inferred_status2 = ShapeInference::InferConditionalShape(
2240       pred_,
2241       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
2242        ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
2243       {matrix_32_48_, tuple_f32_v32});
2244   EXPECT_IS_OK(inferred_status2.status());
2245   EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));
2246 
2247   auto inferred_status_error0 = ShapeInference::InferConditionalShape(
2248       f32_,
2249       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2250        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2251       {vector_32_, vector_64_});
2252   EXPECT_FALSE(inferred_status_error0.ok());
2253   EXPECT_THAT(inferred_status_error0.status().error_message(),
2254               HasSubstr("must be bool or int32_t"));
2255 
2256   auto inferred_status_error1 = ShapeInference::InferConditionalShape(
2257       pred_,
2258       {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
2259        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
2260       {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_});
2261   EXPECT_FALSE(inferred_status_error1.ok());
2262   EXPECT_THAT(inferred_status_error1.status().error_message(),
2263               HasSubstr("branch computation 0 must take 1 argument"));
2264 
2265   auto inferred_status_error2 = ShapeInference::InferConditionalShape(
2266       pred_,
2267       {ShapeUtil::MakeProgramShape({vector_64_}, f32_),
2268        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2269       {vector_32_, vector_64_});
2270   EXPECT_FALSE(inferred_status_error2.ok());
2271   EXPECT_THAT(inferred_status_error2.status().error_message(),
2272               HasSubstr("branch operand 0 must match the shape of the only "
2273                         "parameter of branch computation 0"));
2274 
2275   auto inferred_status_error3 = ShapeInference::InferConditionalShape(
2276       pred_,
2277       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
2278        ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)},
2279       {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})});
2280   EXPECT_FALSE(inferred_status_error3.ok());
2281   EXPECT_THAT(inferred_status_error3.status().error_message(),
2282               HasSubstr("branch computation 1 must take 1 argument"));
2283 
2284   auto inferred_status_error4 = ShapeInference::InferConditionalShape(
2285       pred_,
2286       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2287        ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
2288       {vector_32_, vector_64_});
2289   EXPECT_FALSE(inferred_status_error4.ok());
2290   EXPECT_THAT(inferred_status_error4.status().error_message(),
2291               HasSubstr("branch operand 1 must match the shape of the only "
2292                         "parameter of branch computation 1"));
2293 
2294   auto inferred_status_error5 = ShapeInference::InferConditionalShape(
2295       pred_,
2296       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2297        ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
2298       {vector_32_, vector_64_});
2299   EXPECT_FALSE(inferred_status_error5.ok());
2300   EXPECT_THAT(inferred_status_error5.status().error_message(),
2301               HasSubstr("the result of branch 0 computation and branch 1 "
2302                         "computation must have the same shape"));
2303 }
2304 
TEST_F(ShapeInferenceTest,ConditionalIndexed)2305 TEST_F(ShapeInferenceTest, ConditionalIndexed) {
2306   auto r0s32 = ShapeUtil::MakeShape(S32, {});
2307   auto inferred_status0 = ShapeInference::InferConditionalShape(
2308       r0s32,
2309       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2310        ShapeUtil::MakeProgramShape({vector_64_}, f32_),
2311        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2312       {vector_32_, vector_64_, vector_64_});
2313   EXPECT_IS_OK(inferred_status0.status());
2314   EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
2315 
2316   auto inferred_status1 = ShapeInference::InferConditionalShape(
2317       r0s32,
2318       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
2319        ShapeUtil::MakeProgramShape({vector_32_}, vector_64_),
2320        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)},
2321       {matrix_32_48_, vector_32_, matrix_32_48_});
2322   EXPECT_IS_OK(inferred_status1.status());
2323   EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));
2324 
2325   auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
2326   auto inferred_status2 = ShapeInference::InferConditionalShape(
2327       r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
2328       {tuple_f32_v32});
2329   EXPECT_IS_OK(inferred_status2.status());
2330   EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));
2331 
2332   auto inferred_status_error0 = ShapeInference::InferConditionalShape(
2333       pred_,
2334       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2335        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2336        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2337       {vector_32_, vector_32_, vector_64_});
2338   EXPECT_FALSE(inferred_status_error0.ok());
2339   EXPECT_THAT(inferred_status_error0.status().error_message(),
2340               HasSubstr("2 == branch_computations.size()"));
2341 
2342   auto inferred_status_error1 = ShapeInference::InferConditionalShape(
2343       r0s32,
2344       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
2345        ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
2346        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
2347       {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}),
2348        matrix_32_48_});
2349   EXPECT_FALSE(inferred_status_error1.ok());
2350   EXPECT_THAT(inferred_status_error1.status().error_message(),
2351               HasSubstr("branch computation 1 must take 1 argument"));
2352 
2353   auto inferred_status_error2 = ShapeInference::InferConditionalShape(
2354       r0s32,
2355       {ShapeUtil::MakeProgramShape({r0s32}, f32_),
2356        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2357        ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
2358       {r0s32, vector_32_, vector_64_});
2359   EXPECT_FALSE(inferred_status_error2.ok());
2360   EXPECT_THAT(inferred_status_error2.status().error_message(),
2361               HasSubstr("branch operand 2 must match the shape of the only "
2362                         "parameter of branch computation 2"));
2363 
2364   auto inferred_status_error3 = ShapeInference::InferConditionalShape(
2365       r0s32,
2366       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2367        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2368        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2369        ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
2370       {vector_32_, vector_32_, vector_32_, vector_64_});
2371   EXPECT_FALSE(inferred_status_error3.ok());
2372   EXPECT_THAT(inferred_status_error3.status().error_message(),
2373               HasSubstr("the result of branch 0 computation and branch 3 "
2374                         "computation must have the same shape"));
2375 
2376   auto inferred_status_error4 =
2377       ShapeInference::InferConditionalShape(r0s32, {}, {});
2378   EXPECT_FALSE(inferred_status_error4.ok());
2379   EXPECT_THAT(inferred_status_error4.status().error_message(),
2380               HasSubstr("!branch_computations.empty()"));
2381 }
2382 
TEST_F(ShapeInferenceTest,ConditionalDynamic)2383 TEST_F(ShapeInferenceTest, ConditionalDynamic) {
2384   auto r0s32 = ShapeUtil::MakeShape(S32, {});
2385   auto static_shape = ShapeUtil::MakeShape(S32, {4}, {false});
2386   auto dynamic_shape = ShapeUtil::MakeShape(S32, {4}, {true});
2387   auto inferred_status0 = ShapeInference::InferConditionalShape(
2388       r0s32,
2389       {ShapeUtil::MakeProgramShape({vector_32_}, static_shape),
2390        ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape),
2391        ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)},
2392       {vector_32_, vector_64_, vector_64_});
2393   EXPECT_IS_OK(inferred_status0.status());
2394   EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, inferred_status0.ValueOrDie()));
2395 
2396   auto inferred_status1 = ShapeInference::InferConditionalShape(
2397       r0s32,
2398       {ShapeUtil::MakeProgramShape({vector_32_}, dynamic_shape),
2399        ShapeUtil::MakeProgramShape({vector_64_}, static_shape),
2400        ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)},
2401       {vector_32_, vector_64_, vector_64_});
2402   EXPECT_IS_OK(inferred_status1.status());
2403   EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, inferred_status1.ValueOrDie()));
2404 }
2405 
TEST_F(ShapeInferenceTest,BadSlice)2406 TEST_F(ShapeInferenceTest, BadSlice) {
2407   auto arg = ShapeUtil::MakeShape(F32, {4});
2408   StatusOr<Shape> statusor =
2409       ShapeInference::InferSliceShape(arg, {0}, {5}, {1});
2410   ASSERT_FALSE(statusor.ok());
2411 
2412   LOG(INFO) << statusor.status();
2413 
2414   EXPECT_THAT(statusor.status().error_message(),
2415               HasSubstr("less than or equal to dimension size"))
2416       << statusor.status();
2417   EXPECT_THAT(statusor.status().error_message(), HasSubstr("argument shape"))
2418       << statusor.status();
2419 }
2420 
TEST_F(ShapeInferenceTest,BadSort)2421 TEST_F(ShapeInferenceTest, BadSort) {
2422   auto keys = ShapeUtil::MakeShape(F32, {4});
2423   auto values = ShapeUtil::MakeShape(F32, {5});
2424   StatusOr<Shape> statusor =
2425       ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values});
2426   EXPECT_FALSE(statusor.ok());
2427   EXPECT_THAT(statusor.status().error_message(),
2428               HasSubstr("dimensions must match"))
2429       << statusor.status();
2430 }
2431 
TEST_F(ShapeInferenceTest,BadSortValuesMismatch)2432 TEST_F(ShapeInferenceTest, BadSortValuesMismatch) {
2433   auto keys = ShapeUtil::MakeShape(F32, {4});
2434   auto values_good = ShapeUtil::MakeShape(F32, {4});
2435   auto values_bad = ShapeUtil::MakeShape(F32, {5});
2436   StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
2437       HloOpcode::kSort, {&keys, &values_good, &values_bad});
2438   EXPECT_FALSE(statusor.ok());
2439   EXPECT_THAT(statusor.status().error_message(),
2440               HasSubstr("dimensions must match"))
2441       << statusor.status();
2442 }
2443 
TEST_F(ShapeInferenceTest,SortManyValues)2444 TEST_F(ShapeInferenceTest, SortManyValues) {
2445   auto keys = ShapeUtil::MakeShape(F32, {4});
2446   auto values_s32 = ShapeUtil::MakeShape(S32, {4});
2447   auto values_u32 = ShapeUtil::MakeShape(U32, {4});
2448   StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
2449       HloOpcode::kSort, {&keys, &values_s32, &values_u32});
2450   EXPECT_IS_OK(statusor);
2451   Shape inferred_shape = statusor.ValueOrDie();
2452   EXPECT_TRUE(ShapeUtil::Compatible(
2453       inferred_shape,
2454       ShapeUtil::MakeTupleShape({keys, values_s32, values_u32})));
2455 }
2456 
2457 class GatherShapeInferenceTest : public ShapeInferenceTest {
2458  protected:
2459   const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
2460   const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5});
2461   const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
2462   const Shape s64_4d_tensor_10_9_8_7_1_ =
2463       ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
2464   const Shape s64_4d_tensor_10_9_8_7_5_ =
2465       ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
2466   const Shape s64_4d_tensor_5_10_9_7_6_ =
2467       ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6});
2468   const Shape s64_4d_tensor_10_9_5_7_6_ =
2469       ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
2470   const Shape f32_5d_tensor_50_49_48_47_46_ =
2471       ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
2472   const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
2473       {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
2474 };
2475 
TEST_F(GatherShapeInferenceTest,TensorFlowGather)2476 TEST_F(GatherShapeInferenceTest, TensorFlowGather) {
2477   TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
2478                           ShapeInference::InferGatherShape(
2479                               matrix_64_48_, s64_vector_32_,
2480                               HloGatherInstruction::MakeGatherDimNumbers(
2481                                   /*offset_dims=*/{0},
2482                                   /*collapsed_slice_dims=*/{1},
2483                                   /*start_index_map=*/{1},
2484                                   /*index_vector_dim=*/1),
2485                               /*slice_sizes=*/{64, 1}));
2486   EXPECT_TRUE(
2487       ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
2488       << ShapeUtil::HumanString(gather_shape);
2489 }
2490 
TEST_F(GatherShapeInferenceTest,TensorFlowGatherV2)2491 TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) {
2492   TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
2493                           ShapeInference::InferGatherShape(
2494                               matrix_64_48_, s64_vector_32_,
2495                               HloGatherInstruction::MakeGatherDimNumbers(
2496                                   /*offset_dims=*/{1},
2497                                   /*collapsed_slice_dims=*/{0},
2498                                   /*start_index_map=*/{0},
2499                                   /*index_vector_dim=*/1),
2500                               /*slice_sizes=*/{1, 48}));
2501   EXPECT_TRUE(
2502       ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
2503       << ShapeUtil::HumanString(gather_shape);
2504 }
2505 
TEST_F(GatherShapeInferenceTest,TensorFlowGatherNd)2506 TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) {
2507   TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
2508                           ShapeInference::InferGatherShape(
2509                               matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
2510                               HloGatherInstruction::MakeGatherDimNumbers(
2511                                   /*offset_dims=*/{4},
2512                                   /*collapsed_slice_dims=*/{0},
2513                                   /*start_index_map=*/{0},
2514                                   /*index_vector_dim=*/4),
2515                               /*slice_sizes=*/{1, 48}));
2516   EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
2517                                ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
2518       << ShapeUtil::HumanString(gather_shape);
2519 }
2520 
TEST_F(GatherShapeInferenceTest,TensorFlowBatchDynamicSlice)2521 TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
2522   TF_ASSERT_OK_AND_ASSIGN(
2523       Shape gather_shape,
2524       ShapeInference::InferGatherShape(
2525           f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2526           HloGatherInstruction::MakeGatherDimNumbers(
2527               /*offset_dims=*/{4, 5, 6, 7, 8},
2528               /*collapsed_slice_dims=*/{},
2529               /*start_index_map=*/{0, 1, 2, 3, 4},
2530               /*index_vector_dim=*/4),
2531           /*slice_sizes=*/{30, 29, 28, 27, 26}));
2532   EXPECT_TRUE(ShapeUtil::Equal(
2533       gather_shape,
2534       ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26})))
2535       << ShapeUtil::HumanString(gather_shape);
2536 }
2537 
TEST_F(GatherShapeInferenceTest,DynamicGatherEntireDimension)2538 TEST_F(GatherShapeInferenceTest, DynamicGatherEntireDimension) {
2539   TF_ASSERT_OK_AND_ASSIGN(
2540       Shape gather_shape,
2541       ShapeInference::InferGatherShape(
2542           ShapeUtil::MakeShape(F32, {3, 2, 1}, {false, true, false}),
2543           ShapeUtil::MakeShape(S64, {}),
2544           HloGatherInstruction::MakeGatherDimNumbers(
2545               /*offset_dims=*/{0, 1},
2546               /*collapsed_slice_dims=*/{0},
2547               /*start_index_map=*/{0},
2548               /*index_vector_dim=*/0),
2549           /*slice_sizes=*/{1, 2, 1}));
2550   EXPECT_TRUE(ShapeUtil::Equal(
2551       gather_shape, ShapeUtil::MakeShape(F32, {2, 1}, {true, false})))
2552       << ShapeUtil::HumanString(gather_shape);
2553 }
2554 
TEST_F(GatherShapeInferenceTest,DynamicGatherCollapsedDimension)2555 TEST_F(GatherShapeInferenceTest, DynamicGatherCollapsedDimension) {
2556   TF_ASSERT_OK_AND_ASSIGN(
2557       Shape gather_shape,
2558       ShapeInference::InferGatherShape(
2559           ShapeUtil::MakeShape(F32, {3, 2, 1}, {true, false, false}),
2560           ShapeUtil::MakeShape(S64, {}),
2561           HloGatherInstruction::MakeGatherDimNumbers(
2562               /*offset_dims=*/{0, 1},
2563               /*collapsed_slice_dims=*/{0},
2564               /*start_index_map=*/{0},
2565               /*index_vector_dim=*/0),
2566           /*slice_sizes=*/{1, 2, 1}));
2567   EXPECT_TRUE(ShapeUtil::Equal(
2568       gather_shape, ShapeUtil::MakeShape(F32, {2, 1}, {false, false})))
2569       << ShapeUtil::HumanString(gather_shape);
2570 }
2571 
TEST_F(GatherShapeInferenceTest,DynamicIndices)2572 TEST_F(GatherShapeInferenceTest, DynamicIndices) {
2573   TF_ASSERT_OK_AND_ASSIGN(
2574       Shape gather_shape,
2575       ShapeInference::InferGatherShape(
2576           ShapeUtil::MakeShape(F32, {3, 2, 2}),
2577           ShapeUtil::MakeShape(S64, {3, 4, 2}, {false, true, false}),
2578           HloGatherInstruction::MakeGatherDimNumbers(
2579               /*offset_dims=*/{2, 3},
2580               /*collapsed_slice_dims=*/{0},
2581               /*start_index_map=*/{0, 1},
2582               /*index_vector_dim=*/2),
2583           /*slice_sizes=*/{1, 2, 2}));
2584   EXPECT_TRUE(ShapeUtil::Equal(
2585       gather_shape,
2586       ShapeUtil::MakeShape(F32, {3, 4, 2, 2}, {false, true, false, false})))
2587       << ShapeUtil::HumanString(gather_shape);
2588 }
2589 
TEST_F(GatherShapeInferenceTest,NonDefaultGatherIndicesLeafDim_A)2590 TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
2591   TF_ASSERT_OK_AND_ASSIGN(
2592       Shape gather_shape,
2593       ShapeInference::InferGatherShape(
2594           f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
2595           HloGatherInstruction::MakeGatherDimNumbers(
2596               /*offset_dims=*/{4, 5, 6, 7, 8},
2597               /*collapsed_slice_dims=*/{},
2598               /*start_index_map=*/{0, 1, 2, 3, 4},
2599               /*index_vector_dim=*/2),
2600           /*slice_sizes=*/{30, 29, 28, 27, 26}));
2601 
2602   EXPECT_TRUE(ShapeUtil::Equal(
2603       gather_shape,
2604       ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
2605       << ShapeUtil::HumanString(gather_shape);
2606 }
2607 
TEST_F(GatherShapeInferenceTest,NonDefaultGatherIndicesLeafDim_B)2608 TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
2609   TF_ASSERT_OK_AND_ASSIGN(
2610       Shape gather_shape,
2611       ShapeInference::InferGatherShape(
2612           f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
2613           HloGatherInstruction::MakeGatherDimNumbers(
2614               /*offset_dims=*/{4, 5, 6, 7, 8},
2615               /*collapsed_slice_dims=*/{},
2616               /*start_index_map=*/{0, 1, 2, 3, 4},
2617               /*index_vector_dim=*/0),
2618           /*slice_sizes=*/{30, 29, 28, 27, 26}));
2619 
2620   EXPECT_TRUE(ShapeUtil::Equal(
2621       gather_shape,
2622       ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
2623       << ShapeUtil::HumanString(gather_shape);
2624 }
2625 
TEST_F(GatherShapeInferenceTest,NoOutputGatherDims)2626 TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) {
2627   // This is equivalent to a dynamic slice.
2628   TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
2629                           ShapeInference::InferGatherShape(
2630                               f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
2631                               HloGatherInstruction::MakeGatherDimNumbers(
2632                                   /*offset_dims=*/{0, 1, 2, 3, 4},
2633                                   /*collapsed_slice_dims=*/{},
2634                                   /*start_index_map=*/{0, 1, 2, 3, 4},
2635                                   /*index_vector_dim=*/0),
2636                               /*slice_sizes=*/{30, 29, 28, 27, 26}));
2637 
2638   EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
2639                                ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26})))
2640       << ShapeUtil::HumanString(gather_shape);
2641 }
2642 
TEST_F(GatherShapeInferenceTest,ScalarGatherIndices)2643 TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) {
2644   // The gather indices "tensor" is a scalar S here that's used to slice out
2645   // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result.
2646   TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
2647                           ShapeInference::InferGatherShape(
2648                               f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
2649                               HloGatherInstruction::MakeGatherDimNumbers(
2650                                   /*offset_dims=*/{0, 1, 2, 3},
2651                                   /*collapsed_slice_dims=*/{0},
2652                                   /*start_index_map=*/{0},
2653                                   /*index_vector_dim=*/0),
2654                               /*slice_sizes=*/{1, 30, 29, 28, 27}));
2655 
2656   EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
2657                                ShapeUtil::MakeShape(F32, {30, 29, 28, 27})))
2658       << ShapeUtil::HumanString(gather_shape);
2659 }
2660 
TEST_F(GatherShapeInferenceTest,TupleShapedTensorInput)2661 TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
2662   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2663       tuple_shape_, s64_vector_32_,
2664       HloGatherInstruction::MakeGatherDimNumbers(
2665           /*offset_dims=*/{0},
2666           /*collapsed_slice_dims=*/{1},
2667           /*start_index_map=*/{1},
2668           /*index_vector_dim=*/1),
2669       /*slice_sizes=*/{64, 1});
2670   ASSERT_FALSE(statusor.ok());
2671   EXPECT_THAT(statusor.status().error_message(),
2672               HasSubstr("Expected array argument for input"))
2673       << statusor.status();
2674 }
2675 
TEST_F(GatherShapeInferenceTest,TupleShapedGatherIndicesInput)2676 TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
2677   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2678       s64_vector_32_, tuple_shape_,
2679       HloGatherInstruction::MakeGatherDimNumbers(
2680           /*offset_dims=*/{0},
2681           /*collapsed_slice_dims=*/{1},
2682           /*start_index_map=*/{1},
2683           /*index_vector_dim=*/0),
2684       /*slice_sizes=*/{64, 1});
2685   ASSERT_FALSE(statusor.ok());
2686   EXPECT_THAT(statusor.status().error_message(),
2687               HasSubstr("Expected array argument for gather indices"))
2688       << statusor.status();
2689 }
2690 
TEST_F(GatherShapeInferenceTest,FloatingPointGatherIndicesInput)2691 TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
2692   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2693       s64_vector_32_, vector_32_,
2694       HloGatherInstruction::MakeGatherDimNumbers(
2695           /*offset_dims=*/{0},
2696           /*collapsed_slice_dims=*/{1},
2697           /*start_index_map=*/{1},
2698           /*index_vector_dim=*/0),
2699       /*slice_sizes=*/{64, 1});
2700   ASSERT_FALSE(statusor.ok());
2701   EXPECT_THAT(statusor.status().error_message(),
2702               HasSubstr("Gather indices parameter must be an integral tensor"))
2703       << statusor.status();
2704 }
2705 
TEST_F(GatherShapeInferenceTest,InvalidGatherDimNumbers_NonAscendingWindowIndices)2706 TEST_F(GatherShapeInferenceTest,
2707        InvalidGatherDimNumbers_NonAscendingWindowIndices) {
2708   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2709       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2710       HloGatherInstruction::MakeGatherDimNumbers(
2711           /*offset_dims=*/{4, 5, 6, 8, 7},
2712           /*collapsed_slice_dims=*/{},
2713           /*start_index_map=*/{0, 1, 2, 3, 4},
2714           /*index_vector_dim=*/4),
2715       /*slice_sizes=*/{30, 29, 28, 27, 26});
2716   ASSERT_FALSE(statusor.ok());
2717   EXPECT_THAT(
2718       statusor.status().error_message(),
2719       HasSubstr("Output window dimensions in gather op must be ascending"))
2720       << statusor.status();
2721 }
2722 
TEST_F(GatherShapeInferenceTest,InvalidGatherDimNumbers_RepeatedWindowIndices)2723 TEST_F(GatherShapeInferenceTest,
2724        InvalidGatherDimNumbers_RepeatedWindowIndices) {
2725   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2726       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2727       HloGatherInstruction::MakeGatherDimNumbers(
2728           /*offset_dims=*/{4, 5, 6, 7, 7},
2729           /*collapsed_slice_dims=*/{},
2730           /*start_index_map=*/{0, 1, 2, 3, 4},
2731           /*index_vector_dim=*/4),
2732       /*slice_sizes=*/{30, 29, 28, 27, 26});
2733   ASSERT_FALSE(statusor.ok());
2734   EXPECT_THAT(
2735       statusor.status().error_message(),
2736       HasSubstr("Output window dimensions in gather op must not repeat"))
2737       << statusor.status();
2738 }
2739 
TEST_F(GatherShapeInferenceTest,InvalidGatherDimNumbers_WindowIndexOutOfBounds)2740 TEST_F(GatherShapeInferenceTest,
2741        InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
2742   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2743       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2744       HloGatherInstruction::MakeGatherDimNumbers(
2745           /*offset_dims=*/{4, 5, 99, 100, 101},
2746           /*collapsed_slice_dims=*/{},
2747           /*start_index_map=*/{0, 1, 2, 3, 4},
2748           /*index_vector_dim=*/4),
2749       /*slice_sizes=*/{30, 29, 28, 27, 26});
2750   ASSERT_FALSE(statusor.ok());
2751   EXPECT_THAT(statusor.status().error_message(),
2752               HasSubstr("Offset dimension 2 in gather op is out of bounds"))
2753       << statusor.status();
2754 }
2755 
TEST_F(GatherShapeInferenceTest,InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds)2756 TEST_F(GatherShapeInferenceTest,
2757        InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
2758   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2759       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2760       HloGatherInstruction::MakeGatherDimNumbers(
2761           /*offset_dims=*/{4, 5, 6, 7, 9},
2762           /*collapsed_slice_dims=*/{},
2763           /*start_index_map=*/{0, 1, 2, 3, 4},
2764           /*index_vector_dim=*/4),
2765       /*slice_sizes=*/{30, 29, 28, 27, 26});
2766   ASSERT_FALSE(statusor.ok());
2767   EXPECT_THAT(statusor.status().error_message(),
2768               HasSubstr("Offset dimension 4 in gather op is out of bounds"))
2769       << statusor.status();
2770 }
2771 
TEST_F(GatherShapeInferenceTest,InvalidGatherDimNumbers_MismatchingElidedWindowDims)2772 TEST_F(GatherShapeInferenceTest,
2773        InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
2774   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2775       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2776       HloGatherInstruction::MakeGatherDimNumbers(
2777           /*offset_dims=*/{4, 5, 6, 7, 8},
2778           /*collapsed_slice_dims=*/{4},
2779           /*start_index_map=*/{0, 1, 2, 3, 4},
2780           /*index_vector_dim=*/4),
2781       /*slice_sizes=*/{30, 29, 28, 27, 26});
2782   ASSERT_FALSE(statusor.ok());
2783   EXPECT_THAT(
2784       statusor.status().error_message(),
2785       HasSubstr("All components of the offset index in a gather op must either "
2786                 "be a offset dimension or explicitly collapsed"))
2787       << statusor.status();
2788 }
2789 
TEST_F(GatherShapeInferenceTest,InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping)2790 TEST_F(GatherShapeInferenceTest,
2791        InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
2792   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2793       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2794       HloGatherInstruction::MakeGatherDimNumbers(
2795           /*offset_dims=*/{4, 5, 6, 7, 8},
2796           /*collapsed_slice_dims=*/{0, 1, 2, 3, 19},
2797           /*start_index_map=*/{0, 1, 2, 3, 4},
2798           /*index_vector_dim=*/4),
2799       /*slice_sizes=*/{30, 29, 28, 27, 26});
2800   ASSERT_FALSE(statusor.ok());
2801   EXPECT_THAT(statusor.status().error_message(),
2802               HasSubstr("Invalid collapsed_slice_dims set in gather op; valid "
2803                         "range is [0, 5), got: 19"))
2804       << statusor.status();
2805 }
2806 
TEST_F(GatherShapeInferenceTest,InvalidGatherDimNumbers_RepeatedWindowToInputMapping)2807 TEST_F(GatherShapeInferenceTest,
2808        InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
2809   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2810       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2811       HloGatherInstruction::MakeGatherDimNumbers(
2812           /*offset_dims=*/{4, 5, 6, 7, 8},
2813           /*collapsed_slice_dims=*/{0, 1, 2, 3, 3},
2814           /*start_index_map=*/{0, 1, 2, 3, 4},
2815           /*index_vector_dim=*/4),
2816       /*slice_sizes=*/{30, 29, 28, 27, 26});
2817   ASSERT_FALSE(statusor.ok());
2818   EXPECT_THAT(statusor.status().error_message(),
2819               HasSubstr("Repeated dimensions not allowed in "
2820                         "collapsed_slice_dims in gather op"))
2821       << statusor.status();
2822 }
2823 
TEST_F(GatherShapeInferenceTest,InvalidGatherDimNumbers_MismatchingGatherToInputMapping)2824 TEST_F(GatherShapeInferenceTest,
2825        InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
2826   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2827       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2828       HloGatherInstruction::MakeGatherDimNumbers(
2829           /*offset_dims=*/{4, 5, 6, 7, 8},
2830           /*collapsed_slice_dims=*/{},
2831           /*start_index_map=*/{0, 1, 2, 3},
2832           /*index_vector_dim=*/4),
2833       /*slice_sizes=*/{30, 29, 28, 27, 26});
2834   ASSERT_FALSE(statusor.ok());
2835   EXPECT_THAT(statusor.status().error_message(),
2836               HasSubstr("Gather op has 4 elements in start_index_map and "
2837                         "the bound of dimension index_vector_dim=4 of "
2838                         "start_indices is 5. These two numbers must be equal."))
2839       << statusor.status();
2840 }
2841 
TEST_F(GatherShapeInferenceTest,InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping)2842 TEST_F(GatherShapeInferenceTest,
2843        InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
2844   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2845       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2846       HloGatherInstruction::MakeGatherDimNumbers(
2847           /*offset_dims=*/{4, 5, 6, 7, 8},
2848           /*collapsed_slice_dims=*/{},
2849           /*start_index_map=*/{0, 1, 2, 3, 7},
2850           /*index_vector_dim=*/4),
2851       /*slice_sizes=*/{30, 29, 28, 27, 26});
2852   ASSERT_FALSE(statusor.ok());
2853   EXPECT_THAT(statusor.status().error_message(),
2854               HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7"))
2855       << statusor.status();
2856 }
2857 
TEST_F(GatherShapeInferenceTest,InvalidGatherDimNumbers_RepeatedGatherToInputMapping)2858 TEST_F(GatherShapeInferenceTest,
2859        InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
2860   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2861       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2862       HloGatherInstruction::MakeGatherDimNumbers(
2863           /*offset_dims=*/{4, 5, 6, 7, 8},
2864           /*collapsed_slice_dims=*/{},
2865           /*start_index_map=*/{0, 1, 2, 3, 3},
2866           /*index_vector_dim=*/4),
2867       /*slice_sizes=*/{30, 29, 28, 27, 26});
2868   ASSERT_FALSE(statusor.ok());
2869   EXPECT_THAT(
2870       statusor.status().error_message(),
2871       HasSubstr("Repeated dimensions are not allowed in start_index_map"))
2872       << statusor.status();
2873 }
2874 
TEST_F(GatherShapeInferenceTest,InvalidGatherDimNumbers_NonAscendingElidedWindowDims)2875 TEST_F(GatherShapeInferenceTest,
2876        InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
2877   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2878       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2879       HloGatherInstruction::MakeGatherDimNumbers(
2880           /*offset_dims=*/{4, 5, 6, 7, 8},
2881           /*collapsed_slice_dims=*/{2, 1},
2882           /*start_index_map=*/{0, 1, 2, 3, 4},
2883           /*index_vector_dim=*/4),
2884       /*slice_sizes=*/{1, 1, 28, 27, 26});
2885   ASSERT_FALSE(statusor.ok());
2886   EXPECT_THAT(statusor.status().error_message(),
2887               HasSubstr("collapsed_slice_dims in gather op must be sorted"))
2888       << statusor.status();
2889 }
2890 
TEST_F(GatherShapeInferenceTest,InvalidGatherDimNumbers_WindowBoundsTooLarge)2891 TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) {
2892   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2893       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2894       HloGatherInstruction::MakeGatherDimNumbers(
2895           /*offset_dims=*/{4, 5, 6, 7},
2896           /*collapsed_slice_dims=*/{2},
2897           /*start_index_map=*/{0, 1, 2, 3, 4},
2898           /*index_vector_dim=*/4),
2899       /*slice_sizes=*/{30, 29, 1, 300, 26});
2900   ASSERT_FALSE(statusor.ok());
2901   EXPECT_THAT(statusor.status().error_message(),
2902               HasSubstr("Slice size at index 3 in gather op is out of range, "
2903                         "must be within [0, 48), got 300."))
2904       << statusor.status();
2905 }
2906 
TEST_F(GatherShapeInferenceTest,InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds)2907 TEST_F(GatherShapeInferenceTest,
2908        InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
2909   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2910       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2911       HloGatherInstruction::MakeGatherDimNumbers(
2912           /*offset_dims=*/{4, 5, 6, 7, 8},
2913           /*collapsed_slice_dims=*/{},
2914           /*start_index_map=*/{0, 1, 2, 3, 4},
2915           /*index_vector_dim=*/4),
2916       /*slice_sizes=*/{30, 29, 28, 26});
2917   ASSERT_FALSE(statusor.ok());
2918   EXPECT_THAT(
2919       statusor.status().error_message(),
2920       HasSubstr("Gather op must have one slice size for every input dimension"))
2921       << statusor.status();
2922 }
2923 
TEST_F(GatherShapeInferenceTest,InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim)2924 TEST_F(GatherShapeInferenceTest,
2925        InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
2926   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2927       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
2928       HloGatherInstruction::MakeGatherDimNumbers(
2929           /*offset_dims=*/{4, 5, 6, 7},
2930           /*collapsed_slice_dims=*/{1},
2931           /*start_index_map=*/{0, 1, 2, 3, 4},
2932           /*index_vector_dim=*/4),
2933       /*slice_sizes=*/{30, 29, 28, 26, 20});
2934   ASSERT_FALSE(statusor.ok());
2935   EXPECT_THAT(
2936       statusor.status().error_message(),
2937       HasSubstr("Gather op can only collapse slice dims with bound 1 or 0, "
2938                 "but bound is 29 for index 1 at position 0."))
2939       << statusor.status();
2940 }
2941 
TEST_F(GatherShapeInferenceTest,OutOfBoundsGatherIndicesLeafDim)2942 TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
2943   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
2944       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
2945       HloGatherInstruction::MakeGatherDimNumbers(
2946           /*offset_dims=*/{4, 5, 6, 7, 8},
2947           /*collapsed_slice_dims=*/{},
2948           /*start_index_map=*/{0, 1, 2, 3, 4},
2949           /*index_vector_dim=*/32),
2950       /*slice_sizes=*/{30, 29, 28, 27, 26});
2951 
2952   ASSERT_FALSE(statusor.ok());
2953   EXPECT_THAT(statusor.status().error_message(),
2954               HasSubstr("Gather index leaf dimension must be within [0, "
2955                         "rank(start_indices) + 1)"))
2956       << statusor.status();
2957 }
2958 
2959 class ScatterShapeInferenceTest
2960     : public ShapeInferenceTest,
2961       public ::testing::WithParamInterface<std::vector<PrimitiveType>> {
2962  protected:
2963   struct ScatterShapes {
Addxla::__anon5c819e1b0111::ScatterShapeInferenceTest::ScatterShapes2964     void Add(Shape shape) {
2965       shapes.push_back(std::move(shape));
2966       ptrs.push_back(&shapes.back());
2967     }
2968     std::vector<Shape> shapes;
2969     std::vector<const Shape*> ptrs;
2970   };
CreateShapes(absl::Span<const int64_t> operand_dims,const Shape & scatter_indices_shape,absl::Span<const int64_t> update_dims,absl::Span<const PrimitiveType> types)2971   static ScatterShapes CreateShapes(absl::Span<const int64_t> operand_dims,
2972                                     const Shape& scatter_indices_shape,
2973                                     absl::Span<const int64_t> update_dims,
2974                                     absl::Span<const PrimitiveType> types) {
2975     CHECK(!types.empty());
2976     size_t size = types.size() * 2 + 1;
2977     ScatterShapes shapes;
2978     shapes.shapes.reserve(size);
2979     shapes.ptrs.reserve(size);
2980     for (PrimitiveType type : types) {
2981       shapes.Add(ShapeUtil::MakeShape(type, operand_dims));
2982     }
2983     shapes.Add(scatter_indices_shape);
2984     for (PrimitiveType type : types) {
2985       shapes.Add(ShapeUtil::MakeShape(type, update_dims));
2986     }
2987     return shapes;
2988   }
Collate(absl::Span<const int64_t> dims,absl::Span<const PrimitiveType> types)2989   static Shape Collate(absl::Span<const int64_t> dims,
2990                        absl::Span<const PrimitiveType> types) {
2991     CHECK(!types.empty());
2992     if (types.size() == 1) {
2993       return ShapeUtil::MakeShape(types[0], dims);
2994     }
2995     std::vector<Shape> shapes;
2996     for (PrimitiveType type : types) {
2997       shapes.push_back(ShapeUtil::MakeShape(type, dims));
2998     }
2999     return ShapeUtil::MakeTupleShape(shapes);
3000   }
scalar(PrimitiveType type)3001   static Shape scalar(PrimitiveType type) {
3002     return ShapeUtil::MakeShape(type, {});
3003   }
s64_vector(int dim)3004   static Shape s64_vector(int dim) { return ShapeUtil::MakeShape(S64, {dim}); }
s64_tensor(absl::Span<const int64_t> dims)3005   static Shape s64_tensor(absl::Span<const int64_t> dims) {
3006     return ShapeUtil::MakeShape(S64, dims);
3007   }
to_apply(absl::Span<const PrimitiveType> types)3008   static ProgramShape to_apply(absl::Span<const PrimitiveType> types) {
3009     CHECK(!types.empty());
3010     ProgramShape program_shape;
3011     Shape& result = *program_shape.mutable_result();
3012     result = ShapeUtil::MakeNil();
3013     result.mutable_tuple_shapes()->reserve(types.size());
3014     program_shape.mutable_parameters()->reserve(types.size() * 2);
3015     for (PrimitiveType type : types) {
3016       *program_shape.add_parameters() = scalar(type);
3017       *result.add_tuple_shapes() = scalar(type);
3018     }
3019     for (PrimitiveType type : types) {
3020       *program_shape.add_parameters() = scalar(type);
3021     }
3022     return program_shape;
3023   }
types() const3024   std::vector<PrimitiveType> types() const { return GetParam(); }
3025 };
3026 
TEST_P(ScatterShapeInferenceTest,TfScatterWithFullUpdates)3027 TEST_P(ScatterShapeInferenceTest, TfScatterWithFullUpdates) {
3028   auto shapes = CreateShapes({64, 48}, s64_vector(32), {64, 32}, types());
3029   TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
3030                           ShapeInference::InferScatterShape(
3031                               shapes.ptrs, to_apply(types()),
3032                               HloScatterInstruction::MakeScatterDimNumbers(
3033                                   /*update_window_dims=*/{0},
3034                                   /*inserted_window_dims=*/{1},
3035                                   /*scatter_dims_to_operand_dims=*/{1},
3036                                   /*index_vector_dim=*/1)));
3037   EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
3038       << ShapeUtil::HumanString(scatter_shape);
3039 }
3040 
TEST_P(ScatterShapeInferenceTest,TfScatterWithFullUpdatesV2)3041 TEST_P(ScatterShapeInferenceTest, TfScatterWithFullUpdatesV2) {
3042   auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 48}, types());
3043   TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
3044                           ShapeInference::InferScatterShape(
3045                               shapes.ptrs, to_apply(types()),
3046                               HloScatterInstruction::MakeScatterDimNumbers(
3047                                   /*update_window_dims=*/{1},
3048                                   /*inserted_window_dims=*/{0},
3049                                   /*scatter_dims_to_operand_dims=*/{0},
3050                                   /*index_vector_dim=*/1)));
3051   EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
3052       << ShapeUtil::HumanString(scatter_shape);
3053 }
3054 
TEST_P(ScatterShapeInferenceTest,TfScatterWithPartialUpdates)3055 TEST_P(ScatterShapeInferenceTest, TfScatterWithPartialUpdates) {
3056   auto shapes = CreateShapes({64, 48}, s64_vector(32), {10, 32}, types());
3057   TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
3058                           ShapeInference::InferScatterShape(
3059                               shapes.ptrs, to_apply(types()),
3060                               HloScatterInstruction::MakeScatterDimNumbers(
3061                                   /*update_window_dims=*/{0},
3062                                   /*inserted_window_dims=*/{1},
3063                                   /*scatter_dims_to_operand_dims=*/{1},
3064                                   /*index_vector_dim=*/1)));
3065   EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
3066       << ShapeUtil::HumanString(scatter_shape);
3067 }
3068 
TEST_P(ScatterShapeInferenceTest,TfScatterWithPartialUpdatesV2)3069 TEST_P(ScatterShapeInferenceTest, TfScatterWithPartialUpdatesV2) {
3070   auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 8}, types());
3071   TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
3072                           ShapeInference::InferScatterShape(
3073                               shapes.ptrs, to_apply(types()),
3074                               HloScatterInstruction::MakeScatterDimNumbers(
3075                                   /*update_window_dims=*/{1},
3076                                   /*inserted_window_dims=*/{0},
3077                                   /*scatter_dims_to_operand_dims=*/{0},
3078                                   /*index_vector_dim=*/1)));
3079   EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
3080       << ShapeUtil::HumanString(scatter_shape);
3081 }
3082 
TEST_P(ScatterShapeInferenceTest,TfScatterWithUpdatesBiggerThanInput)3083 TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) {
3084   auto shapes = CreateShapes({64, 48}, s64_vector(32), {65, 32}, types());
3085   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3086       shapes.ptrs, to_apply(types()),
3087       HloScatterInstruction::MakeScatterDimNumbers(
3088           /*update_window_dims=*/{0},
3089           /*inserted_window_dims=*/{1},
3090           /*scatter_dims_to_operand_dims=*/{1},
3091           /*index_vector_dim=*/1));
3092   ASSERT_FALSE(statusor.ok());
3093   EXPECT_THAT(
3094       statusor.status().error_message(),
3095       HasSubstr("Bounds of the window dimensions of updates must not exceed "
3096                 "the bounds of the corresponding dimensions of operand."))
3097       << statusor.status();
3098 }
3099 
TEST_P(ScatterShapeInferenceTest,TfScatterWithUpdatesBiggerThanInputV2)3100 TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) {
3101   auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 49}, types());
3102   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3103       shapes.ptrs, to_apply(types()),
3104       HloScatterInstruction::MakeScatterDimNumbers(
3105           /*update_window_dims=*/{1},
3106           /*inserted_window_dims=*/{0},
3107           /*scatter_dims_to_operand_dims=*/{1},
3108           /*index_vector_dim=*/1));
3109   ASSERT_FALSE(statusor.ok());
3110   EXPECT_THAT(
3111       statusor.status().error_message(),
3112       HasSubstr("Bounds of the window dimensions of updates must not exceed "
3113                 "the bounds of the corresponding dimensions of operand."))
3114       << statusor.status();
3115 }
3116 
TEST_P(ScatterShapeInferenceTest,TfScatterWithUpdatesNotMatchingIndices)3117 TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndices) {
3118   auto shapes = CreateShapes({64, 48}, s64_vector(32), {64, 31}, types());
3119   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3120       shapes.ptrs, to_apply(types()),
3121       HloScatterInstruction::MakeScatterDimNumbers(
3122           /*update_window_dims=*/{0},
3123           /*inserted_window_dims=*/{1},
3124           /*scatter_dims_to_operand_dims=*/{1},
3125           /*index_vector_dim=*/1));
3126   ASSERT_FALSE(statusor.ok());
3127   EXPECT_THAT(
3128       statusor.status().error_message(),
3129       HasSubstr(
3130           "Bounds of the scatter dimensions of updates must be same as the "
3131           "bounds of the corresponding dimensions of scatter indices."))
3132       << statusor.status();
3133 }
3134 
TEST_P(ScatterShapeInferenceTest,TfScatterWithUpdatesNotMatchingIndicesV2)3135 TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndicesV2) {
3136   auto shapes = CreateShapes({64, 48}, s64_vector(32), {31, 48}, types());
3137   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3138       shapes.ptrs, to_apply(types()),
3139       HloScatterInstruction::MakeScatterDimNumbers(
3140           /*update_window_dims=*/{1},
3141           /*inserted_window_dims=*/{0},
3142           /*scatter_dims_to_operand_dims=*/{1},
3143           /*index_vector_dim=*/1));
3144   ASSERT_FALSE(statusor.ok());
3145   EXPECT_THAT(
3146       statusor.status().error_message(),
3147       HasSubstr(
3148           "Bounds of the scatter dimensions of updates must be same as the "
3149           "bounds of the corresponding dimensions of scatter indices."))
3150       << statusor.status();
3151 }
3152 
TEST_P(ScatterShapeInferenceTest,TfScatterNdWithFullUpdates)3153 TEST_P(ScatterShapeInferenceTest, TfScatterNdWithFullUpdates) {
3154   auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
3155                              {10, 9, 8, 7, 48}, types());
3156   TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
3157                           ShapeInference::InferScatterShape(
3158                               shapes.ptrs, to_apply(types()),
3159                               HloScatterInstruction::MakeScatterDimNumbers(
3160                                   /*update_window_dims=*/{4},
3161                                   /*inserted_window_dims=*/{0},
3162                                   /*scatter_dims_to_operand_dims=*/{0},
3163                                   /*index_vector_dim=*/4)));
3164   EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
3165       << ShapeUtil::HumanString(scatter_shape);
3166 }
3167 
TEST_P(ScatterShapeInferenceTest,TfScatterNdWithFullUpdatesV2)3168 TEST_P(ScatterShapeInferenceTest, TfScatterNdWithFullUpdatesV2) {
3169   auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
3170                              {10, 9, 8, 7, 64}, types());
3171   TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
3172                           ShapeInference::InferScatterShape(
3173                               shapes.ptrs, to_apply(types()),
3174                               HloScatterInstruction::MakeScatterDimNumbers(
3175                                   /*update_window_dims=*/{4},
3176                                   /*inserted_window_dims=*/{1},
3177                                   /*scatter_dims_to_operand_dims=*/{0},
3178                                   /*index_vector_dim=*/4)));
3179   EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
3180       << ShapeUtil::HumanString(scatter_shape);
3181 }
3182 
TEST_P(ScatterShapeInferenceTest,TfScatterNdWithPartialUpdates)3183 TEST_P(ScatterShapeInferenceTest, TfScatterNdWithPartialUpdates) {
3184   auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
3185                              {10, 9, 8, 7, 10}, types());
3186   TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
3187                           ShapeInference::InferScatterShape(
3188                               shapes.ptrs, to_apply(types()),
3189                               HloScatterInstruction::MakeScatterDimNumbers(
3190                                   /*update_window_dims=*/{4},
3191                                   /*inserted_window_dims=*/{0},
3192                                   /*scatter_dims_to_operand_dims=*/{0},
3193                                   /*index_vector_dim=*/4)));
3194   EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
3195       << ShapeUtil::HumanString(scatter_shape);
3196 }
3197 
TEST_P(ScatterShapeInferenceTest,TfScatterNdWithPartialUpdatesV2)3198 TEST_P(ScatterShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) {
3199   auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
3200                              {10, 9, 8, 7, 12}, types());
3201   TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
3202                           ShapeInference::InferScatterShape(
3203                               shapes.ptrs, to_apply(types()),
3204                               HloScatterInstruction::MakeScatterDimNumbers(
3205                                   /*update_window_dims=*/{4},
3206                                   /*inserted_window_dims=*/{1},
3207                                   /*scatter_dims_to_operand_dims=*/{0},
3208                                   /*index_vector_dim=*/4)));
3209   EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
3210       << ShapeUtil::HumanString(scatter_shape);
3211 }
3212 
TEST_P(ScatterShapeInferenceTest,TfScatterNdWithUpdatesBiggerThanInput)3213 TEST_P(ScatterShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) {
3214   auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
3215                              {10, 9, 8, 7, 65}, types());
3216   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3217       shapes.ptrs, to_apply(types()),
3218       HloScatterInstruction::MakeScatterDimNumbers(
3219           /*update_window_dims=*/{4},
3220           /*inserted_window_dims=*/{1},
3221           /*scatter_dims_to_operand_dims=*/{0},
3222           /*index_vector_dim=*/4));
3223   ASSERT_FALSE(statusor.ok());
3224   EXPECT_THAT(
3225       statusor.status().error_message(),
3226       HasSubstr("Bounds of the window dimensions of updates must not exceed "
3227                 "the bounds of the corresponding dimensions of operand."))
3228       << statusor.status();
3229 }
3230 
TEST_P(ScatterShapeInferenceTest,TfScatterNdWithUpdatesNotMatchingIndices)3231 TEST_P(ScatterShapeInferenceTest, TfScatterNdWithUpdatesNotMatchingIndices) {
3232   auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
3233                              {9, 9, 8, 7, 64}, types());
3234   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3235       shapes.ptrs, to_apply(types()),
3236       HloScatterInstruction::MakeScatterDimNumbers(
3237           /*update_window_dims=*/{4},
3238           /*inserted_window_dims=*/{1},
3239           /*scatter_dims_to_operand_dims=*/{0},
3240           /*index_vector_dim=*/4));
3241   ASSERT_FALSE(statusor.ok());
3242   EXPECT_THAT(
3243       statusor.status().error_message(),
3244       HasSubstr(
3245           "Bounds of the scatter dimensions of updates must be same as the "
3246           "bounds of the corresponding dimensions of scatter indices."))
3247       << statusor.status();
3248 }
3249 
TEST_P(ScatterShapeInferenceTest,TfBatchDynamicUpdateSlice)3250 TEST_P(ScatterShapeInferenceTest, TfBatchDynamicUpdateSlice) {
3251   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
3252                              {10, 9, 8, 7, 30, 29, 28, 27, 26}, types());
3253   TF_ASSERT_OK_AND_ASSIGN(
3254       Shape scatter_shape,
3255       ShapeInference::InferScatterShape(
3256           shapes.ptrs, to_apply(types()),
3257           HloScatterInstruction::MakeScatterDimNumbers(
3258               /*update_window_dims=*/{4, 5, 6, 7, 8},
3259               /*inserted_window_dims=*/{},
3260               /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
3261               /*index_vector_dim=*/4)));
3262   EXPECT_TRUE(
3263       ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
3264       << ShapeUtil::HumanString(scatter_shape);
3265 }
3266 
TEST_P(ScatterShapeInferenceTest,NonDefaultScatterIndicesLeafDim)3267 TEST_P(ScatterShapeInferenceTest, NonDefaultScatterIndicesLeafDim) {
3268   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 5, 7, 6}),
3269                              {10, 9, 7, 6, 30, 29, 28, 27, 26}, types());
3270   TF_ASSERT_OK_AND_ASSIGN(
3271       Shape scatter_shape,
3272       ShapeInference::InferScatterShape(
3273           shapes.ptrs, to_apply(types()),
3274           HloScatterInstruction::MakeScatterDimNumbers(
3275               /*update_window_dims=*/{4, 5, 6, 7, 8},
3276               /*inserted_window_dims=*/{},
3277               /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
3278               /*index_vector_dim=*/2)));
3279 
3280   EXPECT_TRUE(
3281       ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
3282       << ShapeUtil::HumanString(scatter_shape);
3283 }
3284 
TEST_P(ScatterShapeInferenceTest,NonDefaultScatterIndicesLeafDimV2)3285 TEST_P(ScatterShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) {
3286   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({5, 10, 9, 7, 6}),
3287                              {10, 9, 7, 6, 30, 29, 28, 27, 26}, types());
3288   TF_ASSERT_OK_AND_ASSIGN(
3289       Shape scatter_shape,
3290       ShapeInference::InferScatterShape(
3291           shapes.ptrs, to_apply(types()),
3292           HloScatterInstruction::MakeScatterDimNumbers(
3293               /*update_window_dims=*/{4, 5, 6, 7, 8},
3294               /*inserted_window_dims=*/{},
3295               /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
3296               /*index_vector_dim=*/0)));
3297 
3298   EXPECT_TRUE(
3299       ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
3300       << ShapeUtil::HumanString(scatter_shape);
3301 }
3302 
TEST_P(ScatterShapeInferenceTest,NoUpdateScatterDims)3303 TEST_P(ScatterShapeInferenceTest, NoUpdateScatterDims) {
3304   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_vector(5),
3305                              {30, 29, 28, 27, 26}, types());
3306   // This is equivalent to a dynamic update slice.
3307   TF_ASSERT_OK_AND_ASSIGN(
3308       Shape scatter_shape,
3309       ShapeInference::InferScatterShape(
3310           shapes.ptrs, to_apply(types()),
3311           HloScatterInstruction::MakeScatterDimNumbers(
3312               /*update_window_dims=*/{0, 1, 2, 3, 4},
3313               /*inserted_window_dims=*/{},
3314               /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
3315               /*index_vector_dim=*/0)));
3316 
3317   EXPECT_TRUE(
3318       ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
3319       << ShapeUtil::HumanString(scatter_shape);
3320 }
3321 
TEST_P(ScatterShapeInferenceTest,ScalarScatterIndices)3322 TEST_P(ScatterShapeInferenceTest, ScalarScatterIndices) {
3323   auto shapes = CreateShapes({50, 49, 48, 47, 46}, scalar(S64),
3324                              {30, 29, 28, 27}, types());
3325   // The scalar indices "tensor" is a scalar S here that's used to update a
3326   // [30,29,28,27] shaped tensor within the operand at position S.
3327   TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
3328                           ShapeInference::InferScatterShape(
3329                               shapes.ptrs, to_apply(types()),
3330                               HloScatterInstruction::MakeScatterDimNumbers(
3331                                   /*update_window_dims=*/{0, 1, 2, 3},
3332                                   /*inserted_window_dims=*/{0},
3333                                   /*scatter_dims_to_operand_dims=*/{0},
3334                                   /*index_vector_dim=*/0)));
3335 
3336   EXPECT_TRUE(
3337       ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
3338       << ShapeUtil::HumanString(scatter_shape);
3339 }
3340 
TEST_P(ScatterShapeInferenceTest,ScatterWithTupleShapedTensorInput)3341 TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedTensorInput) {
3342   Shape tuple_shape =
3343       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}),
3344                                  ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})});
3345   Shape s64_vector_32 = s64_vector(32);
3346   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3347       {&tuple_shape, &s64_vector_32, &s64_vector_32}, to_apply(types()),
3348       HloScatterInstruction::MakeScatterDimNumbers(
3349           /*update_window_dims=*/{0},
3350           /*inserted_window_dims=*/{1},
3351           /*scatter_dims_to_operand_dims=*/{1},
3352           /*index_vector_dim=*/1));
3353   ASSERT_FALSE(statusor.ok());
3354   EXPECT_THAT(statusor.status().error_message(),
3355               HasSubstr("Expected array argument for operand"))
3356       << statusor.status();
3357 }
3358 
TEST_P(ScatterShapeInferenceTest,ScatterWithTupleShapedScatterIndicesInput)3359 TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedScatterIndicesInput) {
3360   Shape tuple_shape =
3361       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}),
3362                                  ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})});
3363   Shape s64_vector_32 = s64_vector(32);
3364   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3365       {&s64_vector_32, &tuple_shape, &s64_vector_32}, to_apply(types()),
3366       HloScatterInstruction::MakeScatterDimNumbers(
3367           /*update_window_dims=*/{0},
3368           /*inserted_window_dims=*/{1},
3369           /*scatter_dims_to_operand_dims=*/{1},
3370           /*index_vector_dim=*/0));
3371   ASSERT_FALSE(statusor.ok());
3372   EXPECT_THAT(statusor.status().error_message(),
3373               HasSubstr("Expected array argument for scatter indices"))
3374       << statusor.status();
3375 }
3376 
TEST_P(ScatterShapeInferenceTest,ScatterWithTupleShapedUpdatesInput)3377 TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) {
3378   Shape tuple_shape =
3379       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}),
3380                                  ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})});
3381   Shape s64_vector_32 = s64_vector(32);
3382   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3383       {&s64_vector_32, &s64_vector_32, &tuple_shape}, to_apply(types()),
3384       HloScatterInstruction::MakeScatterDimNumbers(
3385           /*update_window_dims=*/{0},
3386           /*inserted_window_dims=*/{1},
3387           /*scatter_dims_to_operand_dims=*/{1},
3388           /*index_vector_dim=*/0));
3389   ASSERT_FALSE(statusor.ok());
3390   EXPECT_THAT(statusor.status().error_message(),
3391               HasSubstr("Expected array argument for updates"))
3392       << statusor.status();
3393 }
3394 
TEST_P(ScatterShapeInferenceTest,FloatingPointScatterIndicesInput)3395 TEST_P(ScatterShapeInferenceTest, FloatingPointScatterIndicesInput) {
3396   Shape s64_vector_32 = s64_vector(32);
3397   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3398       {&s64_vector_32, &vector_32_, &s64_vector_32}, to_apply(types()),
3399       HloScatterInstruction::MakeScatterDimNumbers(
3400           /*update_window_dims=*/{0},
3401           /*inserted_window_dims=*/{1},
3402           /*scatter_dims_to_operand_dims=*/{1},
3403           /*index_vector_dim=*/0));
3404   ASSERT_FALSE(statusor.ok());
3405   EXPECT_THAT(statusor.status().error_message(),
3406               HasSubstr("Scatter indices parameter must be an integral tensor"))
3407       << statusor.status();
3408 }
3409 
TEST_P(ScatterShapeInferenceTest,OutOfBoundsScatterIndicesLeafDim)3410 TEST_P(ScatterShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) {
3411   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
3412                              {10, 9, 8, 7, 30, 29, 28}, types());
3413   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3414       shapes.ptrs, to_apply(types()),
3415       HloScatterInstruction::MakeScatterDimNumbers(
3416           /*update_window_dims=*/{4, 5, 6},
3417           /*inserted_window_dims=*/{1, 2},
3418           /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
3419           /*index_vector_dim=*/10));
3420   ASSERT_FALSE(statusor.ok());
3421   EXPECT_THAT(statusor.status().error_message(),
3422               HasSubstr("Scatter index leaf dimension must be within [0, "
3423                         "rank(scatter_indices) + 1)"))
3424       << statusor.status();
3425 }
3426 
TEST_P(ScatterShapeInferenceTest,InvalidUpdates)3427 TEST_P(ScatterShapeInferenceTest, InvalidUpdates) {
3428   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
3429                              {10, 9, 8, 7, 30, 29, 28, 50}, types());
3430   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3431       shapes.ptrs, to_apply(types()),
3432       HloScatterInstruction::MakeScatterDimNumbers(
3433           /*update_window_dims=*/{4, 5, 6},
3434           /*inserted_window_dims=*/{1, 2},
3435           /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
3436           /*index_vector_dim=*/4));
3437   ASSERT_FALSE(statusor.ok());
3438   EXPECT_THAT(statusor.status().error_message(),
3439               HasSubstr("Updates tensor must be of rank 7; got 8."))
3440       << statusor.status();
3441 }
3442 
TEST_P(ScatterShapeInferenceTest,InvalidUpdateComputation)3443 TEST_P(ScatterShapeInferenceTest, InvalidUpdateComputation) {
3444   const ProgramShape invalid_update_computation =
3445       ShapeUtil::MakeProgramShape({f32_}, f32_);
3446   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
3447                              {10, 9, 8, 7, 30, 29, 28}, types());
3448   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3449       shapes.ptrs, invalid_update_computation,
3450       HloScatterInstruction::MakeScatterDimNumbers(
3451           /*update_window_dims=*/{4, 5, 6},
3452           /*inserted_window_dims=*/{1, 2},
3453           /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
3454           /*index_vector_dim=*/4));
3455   ASSERT_FALSE(statusor.ok());
3456   EXPECT_THAT(statusor.status().error_message(),
3457               HasSubstr(absl::Substitute(
3458                   "Reduction function must take $0 parameters, but takes 1",
3459                   2 * types().size())))
3460       << statusor.status();
3461 }
3462 
TEST_P(ScatterShapeInferenceTest,InvalidScatterDimNumbers_NonAscendingUpdateWindowDims)3463 TEST_P(ScatterShapeInferenceTest,
3464        InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) {
3465   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
3466                              {10, 9, 8, 7, 30, 29, 28, 27, 26}, types());
3467   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3468       shapes.ptrs, to_apply(types()),
3469       HloScatterInstruction::MakeScatterDimNumbers(
3470           /*update_window_dims=*/{4, 5, 6, 8, 7},
3471           /*inserted_window_dims=*/{},
3472           /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
3473           /*index_vector_dim=*/4));
3474   ASSERT_FALSE(statusor.ok());
3475   EXPECT_THAT(statusor.status().error_message(),
3476               HasSubstr("update_window_dims in scatter op must be sorted"))
3477       << statusor.status();
3478 }
3479 
TEST_P(ScatterShapeInferenceTest,InvalidScatterDimNumbers_RepeatedUpdateWindowDims)3480 TEST_P(ScatterShapeInferenceTest,
3481        InvalidScatterDimNumbers_RepeatedUpdateWindowDims) {
3482   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
3483                              {10, 9, 8, 7, 30, 29, 28, 27, 26}, types());
3484   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3485       shapes.ptrs, to_apply(types()),
3486       HloScatterInstruction::MakeScatterDimNumbers(
3487           /*update_window_dims=*/{4, 5, 6, 7, 7},
3488           /*inserted_window_dims=*/{},
3489           /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
3490           /*index_vector_dim=*/4));
3491   ASSERT_FALSE(statusor.ok());
3492   EXPECT_THAT(statusor.status().error_message(),
3493               HasSubstr("update_window_dims in scatter op must not repeat"))
3494       << statusor.status();
3495 }
3496 
TEST_P(ScatterShapeInferenceTest,InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims)3497 TEST_P(ScatterShapeInferenceTest,
3498        InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) {
3499   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
3500                              {10, 9, 8, 7, 30, 29, 28, 27, 26}, types());
3501   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3502       shapes.ptrs, to_apply(types()),
3503       HloScatterInstruction::MakeScatterDimNumbers(
3504           /*update_window_dims=*/{4, 5, 6, 7, 9},
3505           /*inserted_window_dims=*/{},
3506           /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
3507           /*index_vector_dim=*/4));
3508   ASSERT_FALSE(statusor.ok());
3509   EXPECT_THAT(statusor.status().error_message(),
3510               HasSubstr("Invalid update_window_dims set in scatter op; valid "
3511                         "range is [0, 9)"))
3512       << statusor.status();
3513 }
3514 
TEST_P(ScatterShapeInferenceTest,InvalidScatterDimNumbers_NonAscendingInsertedWindowDims)3515 TEST_P(ScatterShapeInferenceTest,
3516        InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) {
3517   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
3518                              {10, 9, 8, 7, 30, 29, 28}, types());
3519   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3520       shapes.ptrs, to_apply(types()),
3521       HloScatterInstruction::MakeScatterDimNumbers(
3522           /*update_window_dims=*/{4, 5, 6},
3523           /*inserted_window_dims=*/{2, 1},
3524           /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
3525           /*index_vector_dim=*/4));
3526   ASSERT_FALSE(statusor.ok());
3527   EXPECT_THAT(statusor.status().error_message(),
3528               HasSubstr("inserted_window_dims in scatter op must be sorted"))
3529       << statusor.status();
3530 }
3531 
TEST_P(ScatterShapeInferenceTest,InvalidScatterDimNumbers_RepeatedInsertedWindowDims)3532 TEST_P(ScatterShapeInferenceTest,
3533        InvalidScatterDimNumbers_RepeatedInsertedWindowDims) {
3534   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
3535                              {10, 9, 8, 7, 30, 29, 28}, types());
3536   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3537       shapes.ptrs, to_apply(types()),
3538       HloScatterInstruction::MakeScatterDimNumbers(
3539           /*update_window_dims=*/{4, 5, 6},
3540           /*inserted_window_dims=*/{1, 1},
3541           /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
3542           /*index_vector_dim=*/4));
3543   ASSERT_FALSE(statusor.ok());
3544   EXPECT_THAT(statusor.status().error_message(),
3545               HasSubstr("inserted_window_dims in scatter op must not repeat"))
3546       << statusor.status();
3547 }
3548 
TEST_P(ScatterShapeInferenceTest,InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims)3549 TEST_P(ScatterShapeInferenceTest,
3550        InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) {
3551   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
3552                              {10, 9, 8, 7, 30, 29, 28}, types());
3553   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3554       shapes.ptrs, to_apply(types()),
3555       HloScatterInstruction::MakeScatterDimNumbers(
3556           /*update_window_dims=*/{4, 5, 6},
3557           /*inserted_window_dims=*/{1, 5},
3558           /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
3559           /*index_vector_dim=*/4));
3560   ASSERT_FALSE(statusor.ok());
3561   EXPECT_THAT(statusor.status().error_message(),
3562               HasSubstr("Invalid inserted_window_dims set in scatter op; valid "
3563                         "range is [0, 5)"))
3564       << statusor.status();
3565 }
3566 
TEST_P(ScatterShapeInferenceTest,InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims)3567 TEST_P(ScatterShapeInferenceTest,
3568        InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) {
3569   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
3570                              {10, 9, 8, 7, 30, 29, 28}, types());
3571   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3572       shapes.ptrs, to_apply(types()),
3573       HloScatterInstruction::MakeScatterDimNumbers(
3574           /*update_window_dims=*/{4, 5, 6},
3575           /*inserted_window_dims=*/{1, 2},
3576           /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3},
3577           /*index_vector_dim=*/4));
3578   ASSERT_FALSE(statusor.ok());
3579   EXPECT_THAT(
3580       statusor.status().error_message(),
3581       HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and "
3582                 "the bound of dimension index_vector_dim=4 of scatter_indices "
3583                 "is 5. These two numbers must be equal"))
3584       << statusor.status();
3585 }
3586 
TEST_P(ScatterShapeInferenceTest,InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims)3587 TEST_P(ScatterShapeInferenceTest,
3588        InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) {
3589   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
3590                              {10, 9, 8, 7, 30, 29, 28}, types());
3591   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3592       shapes.ptrs, to_apply(types()),
3593       HloScatterInstruction::MakeScatterDimNumbers(
3594           /*update_window_dims=*/{4, 5, 6},
3595           /*inserted_window_dims=*/{1, 2},
3596           /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10},
3597           /*index_vector_dim=*/4));
3598   ASSERT_FALSE(statusor.ok());
3599   EXPECT_THAT(statusor.status().error_message(),
3600               HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain "
3601                         "is [0, 5), got: 4->10"))
3602       << statusor.status();
3603 }
3604 
TEST_P(ScatterShapeInferenceTest,InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims)3605 TEST_P(ScatterShapeInferenceTest,
3606        InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) {
3607   auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
3608                              {10, 9, 8, 7, 30, 29, 28}, types());
3609   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3610       shapes.ptrs, to_apply(types()),
3611       HloScatterInstruction::MakeScatterDimNumbers(
3612           /*update_window_dims=*/{4, 5, 6},
3613           /*inserted_window_dims=*/{1, 2},
3614           /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3},
3615           /*index_vector_dim=*/4));
3616   ASSERT_FALSE(statusor.ok());
3617   EXPECT_THAT(
3618       statusor.status().error_message(),
3619       HasSubstr(
3620           "Repeated dimensions not allowed in scatter_dims_to_operand_dims"))
3621       << statusor.status();
3622 }
3623 
TEST_P(ScatterShapeInferenceTest,InvalidScatterDimNumbers_InsufficientWindowDims)3624 TEST_P(ScatterShapeInferenceTest,
3625        InvalidScatterDimNumbers_InsufficientWindowDims) {
3626   auto shapes = CreateShapes({50, 49, 48, 47, 46}, scalar(S64),
3627                              {30, 29, 28, 27}, types());
3628   StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
3629       shapes.ptrs, to_apply(types()),
3630       HloScatterInstruction::MakeScatterDimNumbers(
3631           /*update_window_dims=*/{0, 1, 2, 3},
3632           /*inserted_window_dims=*/{},
3633           /*scatter_dims_to_operand_dims=*/{0},
3634           /*index_vector_dim=*/0));
3635   ASSERT_FALSE(statusor.ok());
3636   EXPECT_THAT(
3637       statusor.status().error_message(),
3638       HasSubstr(
3639           "Scatter op has window of size 4; doesn't match operand of rank 5."))
3640       << statusor.status();
3641 }
3642 
3643 struct ScatterTestName {
operator ()xla::__anon5c819e1b0111::ScatterTestName3644   std::string operator()(
3645       const ::testing::TestParamInfo<std::vector<PrimitiveType>>& info) const {
3646     return absl::StrJoin(info.param, "_", absl::StreamFormatter());
3647   }
3648 };
3649 
3650 INSTANTIATE_TEST_SUITE_P(All, ScatterShapeInferenceTest,
3651                          ::testing::Values(std::vector<PrimitiveType>{F32},
3652                                            std::vector<PrimitiveType>{F32,
3653                                                                       BF16}),
3654                          ScatterTestName());
3655 
3656 }  // namespace
3657 }  // namespace xla
3658