1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "absl/strings/string_view.h"
22 #include "absl/strings/substitute.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/client/padding.h"
25 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/test_helpers.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31
32 namespace xla {
33 namespace {
34
35 using ::testing::ContainsRegex;
36 using ::testing::HasSubstr;
37
// Common fixture providing frequently used scalar, vector, and matrix
// shapes for the shape-inference tests below. All members are immutable
// and shared by the derived fixtures.
class ShapeInferenceTest : public ::testing::Test {
 protected:
  // Some handy scalar shapes.
  const Shape s32_ = ShapeUtil::MakeShape(S32, {});
  const Shape f16_ = ShapeUtil::MakeShape(F16, {});
  const Shape f32_ = ShapeUtil::MakeShape(F32, {});
  const Shape f64_ = ShapeUtil::MakeShape(F64, {});
  const Shape pred_ = ShapeUtil::MakeShape(PRED, {});

  // Some handy vector and matrix shapes of F32 type.
  // Suffix: vector_length_, matrix_rows_cols_
  const Shape vector_32_ = ShapeUtil::MakeShape(F32, {32});
  const Shape vector_64_ = ShapeUtil::MakeShape(F32, {64});
  const Shape matrix_32_48_ = ShapeUtil::MakeShape(F32, {32, 48});
  const Shape matrix_32_64_ = ShapeUtil::MakeShape(F32, {32, 64});
  const Shape matrix_64_48_ = ShapeUtil::MakeShape(F32, {64, 48});

  // Some handy S32 arrays.
  const Shape s32matrix_64_64_ = ShapeUtil::MakeShape(S32, {64, 64});
};
58
59 // Subclass for testing InferReduceShape.
60 class ReduceShapeInferenceTest : public ShapeInferenceTest {
61 protected:
62 // Helper that runs reduce shape inference with the input 'arg' and given
63 // dimensions to reduce, and checks the inferred shape is as expected. The
64 // element type here is hard-coded to F32.
ExpectInferredReduceShape(const Shape & expected_inferred_shape,const Shape & arg,absl::Span<const int64_t> dimensions_to_reduce)65 void ExpectInferredReduceShape(
66 const Shape& expected_inferred_shape, const Shape& arg,
67 absl::Span<const int64_t> dimensions_to_reduce) {
68 ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
69 auto inferred_status = ShapeInference::InferReduceShape(
70 {&arg, &f32_}, dimensions_to_reduce, to_apply);
71 EXPECT_IS_OK(inferred_status.status());
72 EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
73 inferred_status.ValueOrDie()));
74 }
75 };
76
77 // Subclass for testing InferSelectAndScatterShape.
78 class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
79 protected:
SelectAndScatterShapeInferenceTest()80 SelectAndScatterShapeInferenceTest() {
81 operand_shape_ = ShapeUtil::MakeShape(F32, {8, 16});
82 source_shape_ = ShapeUtil::MakeShape(F32, {4, 8});
83 WindowDimension dim;
84 dim.set_size(2);
85 dim.set_stride(2);
86 dim.set_padding_low(0);
87 dim.set_padding_high(0);
88 dim.set_window_dilation(1);
89 dim.set_base_dilation(1);
90 *window_.add_dimensions() = dim;
91 *window_.add_dimensions() = dim;
92 init_value_shape_ = ShapeUtil::MakeShape(F32, {});
93 select_program_shape_ = ShapeUtil::MakeProgramShape(
94 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
95 scatter_program_shape_ = ShapeUtil::MakeProgramShape(
96 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
97 }
98
99 Shape operand_shape_;
100 Shape source_shape_;
101 Window window_;
102 Shape init_value_shape_;
103 ProgramShape select_program_shape_;
104 ProgramShape scatter_program_shape_;
105 };
106
// Negate is shape-preserving on arrays.
TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  const auto result =
      ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, result.ValueOrDie()));
}
114
// Select requires array (non-tuple) operands.
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
  const Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_});
  const auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, tuple, tuple);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Expected array argument for select"));
}
123
// A scalar pred combined with array operands is rejected: the predicate
// must match the operands' shape.
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
  const auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(
      result.status().error_message(),
      HasSubstr("Operands to select and predicate must be the same shape"));
}
132
// An elementwise select with a PRED array matching the operands infers the
// operand shape.
TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
  const Shape pred_array = ShapeUtil::MakeShape(PRED, {64, 48});
  const auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_array, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, result.ValueOrDie()));
}
140
// Invalid select argument combinations and the diagnostics they produce.
TEST_F(ShapeInferenceTest, SelectBadShapes) {
  // Mismatched on-true/on-false operand shapes.
  const auto error1 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_);
  ASSERT_FALSE(error1.ok());
  ASSERT_THAT(error1.status().error_message(),
              HasSubstr("Operands to select must be the same shape"));

  // The predicate must have PRED element type.
  const auto error2 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(error2.ok());
  ASSERT_THAT(error2.status().error_message(),
              HasSubstr("pred operand must have PRED"));

  // The predicate's dimensions must match the operands'.
  const auto error3 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_,
      matrix_64_48_);
  ASSERT_FALSE(error3.ok());
  ASSERT_THAT(
      error3.status().error_message(),
      HasSubstr("Operands to select and predicate must be the same shape"));

  // Tuples have a TUPLE element type and cannot be the pred of a select.
  const auto error4 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}),
      ShapeUtil::MakeTupleShape({f32_, f32_}),
      ShapeUtil::MakeTupleShape({f32_, f32_}));
  ASSERT_FALSE(error4.ok());
  ASSERT_THAT(error4.status().error_message(),
              HasSubstr("Expected array argument for select pred"));
}
171
// Clamp with three identically shaped matrices infers that shape.
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
  const auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, result.ValueOrDie()));
}
178
// Clamp with three scalars infers a scalar.
TEST_F(ShapeInferenceTest, ClampAllScalar) {
  const auto result =
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, result.ValueOrDie()));
}
185
// A scalar min with matrix operand/max is rejected.
TEST_F(ShapeInferenceTest, ClampMinScalar) {
  const auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
193
// A scalar max with matrix min/operand is rejected.
TEST_F(ShapeInferenceTest, ClampMaxScalar) {
  const auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
201
// A scalar operand with matrix min/max is rejected.
TEST_F(ShapeInferenceTest, ClampOperandScalar) {
  const auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
209
// A matrix min with scalar operand/max is rejected.
TEST_F(ShapeInferenceTest, ClampMinMatrix) {
  const auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, f32_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
217
// A matrix max with scalar min/operand is rejected.
TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
  const auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, f32_, matrix_64_48_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
225
// A matrix operand with scalar min/max is rejected.
TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
  const auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, f32_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
233
// Clamp requires matching element types and compatible dimensions across
// min, operand, and max.
TEST_F(ShapeInferenceTest, ClampBadShapes) {
  // Returns true iff clamp shape inference rejects (min, operand, max).
  const auto clamp_fails = [&](const Shape& min, const Shape& operand,
                               const Shape& max) {
    return !ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, min,
                                                operand, max)
                .ok();
  };
  // Type mismatch
  ASSERT_TRUE(clamp_fails(s32_, f32_, f32_));
  ASSERT_TRUE(clamp_fails(f32_, s32_, f32_));
  ASSERT_TRUE(clamp_fails(f32_, f32_, s32_));
  // Dimension mismatch
  ASSERT_TRUE(clamp_fails(vector_64_, vector_32_, vector_32_));
  ASSERT_TRUE(clamp_fails(vector_32_, vector_64_, vector_32_));
  ASSERT_TRUE(clamp_fails(vector_32_, vector_32_, vector_64_));
  // Dimension mismatch, where one operand is a scalar
  ASSERT_TRUE(clamp_fails(vector_64_, vector_32_, f32_));
  ASSERT_TRUE(clamp_fails(vector_64_, f32_, vector_32_));
  ASSERT_TRUE(clamp_fails(f32_, vector_64_, vector_32_));
}
266
// Shape inference for kComplex: component types must be matching F32/F64;
// broadcasting between scalar/vector/matrix components is supported.
TEST_F(ShapeInferenceTest, Complex) {
  auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
                           absl::Span<const int64_t> bcast) {
    return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
                                              bcast);
  };
  // Inputs must be FP.
  ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
  ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
  // Component types must match.
  ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
  // Only F32->C64 and F64->C128 supported.
  ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok());
  // Validate correct uses.
  Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
  TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  // Fixed: this case previously duplicated the (vector, scalar) assertion
  // above; the (vector, vector) combination was evidently intended.
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));

  Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(vector_64_, matrix_32_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, vector_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, matrix_32_64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));

  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {})));
}
307
// kTuple over (s32, f32) infers the corresponding tuple shape.
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
  const StatusOr<Shape> result =
      ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
                               ShapeUtil::MakeTupleShape({s32_, f32_})));
}
315
// An 8x8 input reduced with a 2x2 window at stride 2 yields a 4x4 result.
// Fixed: removed unused locals 'window_shape' and 'float_scalar'.
TEST_F(ShapeInferenceTest, ReduceWindowInHalf) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8});
  Window window;
  WindowDimension dim;
  dim.set_size(2);
  dim.set_stride(2);
  dim.set_padding_low(0);
  dim.set_padding_high(0);
  dim.set_window_dilation(1);
  dim.set_base_dilation(1);
  *window.add_dimensions() = dim;
  *window.add_dimensions() = dim;
  Shape init_value_shape = ShapeUtil::MakeShape(F32, {});
  // Reducer computation: (f32, f32) -> f32.
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      matrix_shape, init_value_shape, window, to_apply);

  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred));
}
340
// A well-formed select-and-scatter infers the operand's shape.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) {
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, result.ValueOrDie()));
}
349
// The source shape must agree with the windowed operand shape.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
  const Shape bad_source_shape = ShapeUtil::MakeShape(F32, {4, 6});
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, bad_source_shape,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Source shape does not match"));
}
359
// The select computation must be binary.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
  const ProgramShape bad_select =
      ShapeUtil::MakeProgramShape({ShapeUtil::MakeShape(F32, {})}, pred_);
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function must take 2 parameters"));
}
370
// The select computation must return a rank-0 PRED, not F32.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
  const ProgramShape bad_select = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function must have rank-0 PRED"));
}
381
// The select computation's first parameter must match the operand type.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
  const ProgramShape bad_select = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function's first parameter"));
}
392
// The select computation's second parameter must match the operand type.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
  const ProgramShape bad_select = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(U32, {})}, pred_);
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function's second parameter"));
}
403
// All-gather-start yields a (operand, gathered-result) tuple; the gather
// dimension is multiplied by the shard count.
TEST_F(ShapeInferenceTest, AllGatherStart) {
  const Shape operand = ShapeUtil::MakeShape(F32, {1, 8, 4});
  const Shape expected_shape = ShapeUtil::MakeTupleShape(
      {operand, ShapeUtil::MakeShape(F32, {8, 8, 4})});

  const auto result = ShapeInference::InferAllGatherStartShape(
      {&operand}, /*all_gather_dimension=*/0, /*shard_count=*/8);
  EXPECT_TRUE(result.ok());
  EXPECT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), expected_shape));
}
414
// All-gather-done unwraps the start op's tuple and returns the gathered
// result shape.
TEST_F(ShapeInferenceTest, AllGatherDone) {
  const Shape input_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1, 8, 4}),
                                 ShapeUtil::MakeShape(F32, {8, 8, 4})});
  const Shape expected_shape = ShapeUtil::MakeShape(F32, {8, 8, 4});

  const auto result = ShapeInference::InferAllGatherDoneShape(input_shape);
  EXPECT_TRUE(result.ok());
  EXPECT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), expected_shape));
}
427
// Basic 2D convolution with non-trivial dimension numbers, stride, and
// padding.
TEST_F(ShapeInferenceTest, Convolve) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  WindowDimension* win0 = window.add_dimensions();
  win0->set_size(3);
  win0->set_stride(2);
  win0->set_padding_low(1);
  win0->set_padding_high(1);
  win0->set_window_dilation(1);
  win0->set_base_dilation(1);
  WindowDimension* win1 = window.add_dimensions();
  win1->set_size(2);
  win1->set_stride(1);
  win1->set_padding_low(0);
  win1->set_padding_high(0);
  win1->set_window_dilation(1);
  win1->set_base_dilation(1);

  const auto result = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/std::nullopt);
  ASSERT_IS_OK(result.status());
  const Shape inferred_shape = result.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
                               inferred_shape));
}
472
// Convolution where the kernel window is dilated in both spatial dimensions.
TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  WindowDimension* win0 = window.add_dimensions();
  win0->set_size(3);
  win0->set_stride(3);
  win0->set_padding_low(0);
  win0->set_padding_high(0);
  win0->set_window_dilation(6);
  win0->set_base_dilation(1);

  WindowDimension* win1 = window.add_dimensions();
  win1->set_size(2);
  win1->set_stride(1);
  win1->set_padding_low(2);
  win1->set_padding_high(1);
  win1->set_window_dilation(2);
  win1->set_base_dilation(1);

  const auto result = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/std::nullopt);
  ASSERT_IS_OK(result.status());
  const Shape inferred_shape = result.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
                               inferred_shape));
}
518
// Convolution where the input (base) is dilated in both spatial dimensions.
TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  WindowDimension* win0 = window.add_dimensions();
  win0->set_size(4);
  win0->set_stride(3);
  win0->set_padding_low(0);
  win0->set_padding_high(0);
  win0->set_window_dilation(1);
  win0->set_base_dilation(6);

  WindowDimension* win1 = window.add_dimensions();
  win1->set_size(2);
  win1->set_stride(1);
  win1->set_padding_low(2);
  win1->set_padding_high(1);
  win1->set_window_dilation(1);
  win1->set_base_dilation(2);

  const auto result = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/std::nullopt);
  ASSERT_IS_OK(result.status());
  const Shape inferred_shape = result.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
                               inferred_shape));
}
564
// Kernel dimension numbers that reuse a dimension are rejected.
TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
  // Dimension order for this test: batch, feature, x0, x1
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2});

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(3);
  dnums.set_output_batch_dimension(3);
  dnums.set_input_feature_dimension(2);
  dnums.set_output_feature_dimension(2);
  dnums.add_input_spatial_dimensions(0);
  dnums.add_output_spatial_dimensions(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(1);
  dnums.set_kernel_input_feature_dimension(0);  // duplicated with kernel_x0
  dnums.set_kernel_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);

  Window window;
  WindowDimension* win0 = window.add_dimensions();
  win0->set_size(2);
  win0->set_stride(1);
  win0->set_padding_low(0);
  win0->set_padding_high(0);
  WindowDimension* win1 = window.add_dimensions();
  win1->set_size(3);
  win1->set_stride(2);
  win1->set_padding_low(1);
  win1->set_padding_high(1);

  const auto result = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/std::nullopt);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("each dimension exactly once"));
}
602
// The kernel output-feature count must be a multiple of batch_group_count.
TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) {
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(2);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(3);
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {60, 38, 17, 13});
  // 10 output features is not a multiple of batch_group_count == 6.
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {38, 10, 4, 4});
  Window window;
  WindowDimension* win0 = window.add_dimensions();
  win0->set_size(4);
  win0->set_padding_low(0);
  win0->set_padding_high(2);
  win0->set_stride(1);
  win0->set_window_dilation(3);
  WindowDimension* win1 = window.add_dimensions();
  win1->set_size(4);
  win1->set_padding_low(2);
  win1->set_padding_high(1);
  win1->set_stride(1);
  win1->set_window_dilation(2);
  const auto result = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6,
      window, dnums, /*preferred_element_type=*/std::nullopt);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("to be a multiple of batch group count"));
}
639
// Bundle of inputs for InferConvolveShape tests: operand shapes, the
// convolution dimension numbers, and the window configuration.
struct ConvolveArgs {
  Shape lhs_shape;
  Shape rhs_shape;
  ConvolutionDimensionNumbers dnums;
  Window window;
};
646
MakeConvolveArgs(PrimitiveType lhs_type,PrimitiveType rhs_type)647 ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) {
648 ConvolveArgs args;
649 ConvolutionDimensionNumbers& dnums = args.dnums;
650
651 // Dimension order: batch, feature, x0, x1
652 args.lhs_shape = ShapeUtil::MakeShape(lhs_type, {10, 11, 3, 4});
653 dnums.set_input_batch_dimension(0);
654 dnums.set_output_batch_dimension(0);
655 dnums.set_input_feature_dimension(1);
656 dnums.set_output_feature_dimension(1);
657 dnums.add_input_spatial_dimensions(2);
658 dnums.add_output_spatial_dimensions(2);
659 dnums.add_input_spatial_dimensions(3);
660 dnums.add_output_spatial_dimensions(3);
661
662 // Dimension order: x1, batch, feature, x0
663 args.rhs_shape = ShapeUtil::MakeShape(rhs_type, {2, 12, 11, 3});
664 dnums.set_kernel_input_feature_dimension(2);
665 dnums.set_kernel_output_feature_dimension(1);
666 dnums.add_kernel_spatial_dimensions(3);
667 dnums.add_kernel_spatial_dimensions(0);
668
669 auto dim0 = args.window.add_dimensions();
670 auto dim1 = args.window.add_dimensions();
671 dim0->set_size(3);
672 dim0->set_stride(2);
673 dim0->set_padding_low(1);
674 dim0->set_padding_high(1);
675 dim0->set_window_dilation(1);
676 dim0->set_base_dilation(1);
677 dim1->set_size(2);
678 dim1->set_stride(1);
679 dim1->set_padding_low(0);
680 dim1->set_padding_high(0);
681 dim1->set_window_dilation(1);
682 dim1->set_base_dilation(1);
683 return args;
684 }
685
// Mixed BF16/F16 inputs infer a BF16 result when no preferred element type
// is given. Fixed: added the missing ';' terminating the macro statement.
TEST_F(ShapeInferenceTest, ConvolveWithBF16_F16) {
  ConvolveArgs args = MakeConvolveArgs(BF16, F16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/std::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}
697
// Mixed F16/BF16 inputs (reversed order) also infer a BF16 result.
// Fixed: added the missing ';' terminating the macro statement.
TEST_F(ShapeInferenceTest, ConvolveWithF16_BF16) {
  ConvolveArgs args = MakeConvolveArgs(F16, BF16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/std::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}
709
// Mixed S32/U32 inputs infer an S32 result.
// Fixed: added the missing ';' terminating the macro statement.
TEST_F(ShapeInferenceTest, ConvolveWithS32_U32) {
  ConvolveArgs args = MakeConvolveArgs(S32, U32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/std::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}
721
// Mixed U32/S32 inputs (reversed order) also infer an S32 result.
// Fixed: added the missing ';' terminating the macro statement.
TEST_F(ShapeInferenceTest, ConvolveWithU32_S32) {
  ConvolveArgs args = MakeConvolveArgs(U32, S32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/std::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}
733
// An explicit preferred_element_type (S16) overrides the result element type
// for an S8xS16 convolution.
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S16));  // added missing ';'
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S16, {10, 12, 2, 3}),
                               inferred_shape));
}
745
// Passing a preferred_element_type of S32 for an S8xS16 convolution is a
// no-op per the test name — presumably S32 is also the default inferred
// accumulator type here (TODO(review): confirm against InferConvolveShape).
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementTypeSameAsInferredType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S32));  // added missing ';'
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}
757
// For floating-point inputs, a *narrower* preferred_element_type (BF16 for
// F32 operands) is allowed and taken as the result element type.
TEST_F(ShapeInferenceTest,
       FloatingPointConvolveWithNarrowerPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(F32, F32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/BF16));  // added missing ';'
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}
770
// A floating-point (BF16) convolution may request an integral (S32)
// preferred_element_type; the result takes the requested type.
TEST_F(ShapeInferenceTest,
       FloatingPointConvolveWithIntegralPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(BF16, BF16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape result_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S32));
  const Shape expected = ShapeUtil::MakeShape(S32, {10, 12, 2, 3});
  ASSERT_TRUE(ShapeUtil::Equal(expected, result_shape));
}
783
// An integral (S8xS16) convolution may request a floating-point (F32)
// preferred_element_type; the result takes the requested type.
TEST_F(ShapeInferenceTest,
       IntegralConvolveWithFloatingPointPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape result_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/F32));
  const Shape expected = ShapeUtil::MakeShape(F32, {10, 12, 2, 3});
  ASSERT_TRUE(ShapeUtil::Equal(expected, result_shape));
}
796
// A preferred_element_type whose signedness differs from the operands'
// (U32 for signed inputs) is accepted and used for the result.
TEST_F(ShapeInferenceTest,
       ConvolveWithPreferredElementTypeWithDifferentSignedness) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape result_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/U32));
  const Shape expected = ShapeUtil::MakeShape(U32, {10, 12, 2, 3});
  ASSERT_TRUE(ShapeUtil::Equal(expected, result_shape));
}
809
// For integral operands, a preferred_element_type narrower than the operand
// type (S8 for S8xS16 inputs) is rejected.
TEST_F(ShapeInferenceTest, ConvolveWithNarrowerPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  const auto status =
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S8)
          .status();
  ASSERT_FALSE(status.ok());
  ASSERT_THAT(status.error_message(),
              HasSubstr("must not be narrower than the original type"));
}
822
823 namespace fft {
824
825 static const char* unsupported_rank = "only supports ranks 1-3";
826 static const char* invalid_rank = "requires input of at least same rank";
827 static const char* requires_complex_input = "requires complex input type";
828 static const char* requires_f32_input = "requires F32 or F64 input type";
829 static const char* dimensions_match = "innermost dimensions match fft_length";
830 static const char* innermost_dimension_matches =
831 "innermost dimension matches fft_length/2+1";
832
Pass(const Shape & shape,FftType type,absl::Span<const int64_t> length,const Shape & expected_shape)833 static void Pass(const Shape& shape, FftType type,
834 absl::Span<const int64_t> length,
835 const Shape& expected_shape) {
836 auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
837 ASSERT_IS_OK(inferred_status.status());
838 Shape inferred_shape = inferred_status.ValueOrDie();
839 ASSERT_TRUE(ShapeUtil::Equal(inferred_shape, expected_shape));
840 }
841
Fail(const Shape & shape,FftType type,absl::Span<const int64_t> length,absl::string_view message)842 static void Fail(const Shape& shape, FftType type,
843 absl::Span<const int64_t> length, absl::string_view message) {
844 auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
845 ASSERT_FALSE(inferred_status.ok());
846 ASSERT_THAT(inferred_status.status().error_message(),
847 HasSubstr(std::string(message)));
848 }
849
850 } // namespace fft
851
// FFT accepts fft_length of rank 1-3 that fits within the input rank.
TEST_F(ShapeInferenceTest, InferFftShapeTestFftRanks) {
  const FftType type = FftType::FFT;
  const Shape input = ShapeUtil::MakeShape(C64, {16, 8});
  fft::Fail(input, type, {}, fft::unsupported_rank);
  fft::Pass(input, type, {8}, input);
  fft::Pass(input, type, {16, 8}, input);
  fft::Fail(input, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(input, type, {64, 32, 16, 8}, fft::unsupported_rank);
}
861
// FFT requires a complex input type; F32 is rejected, C128 passes through.
TEST_F(ShapeInferenceTest, InferFftShapeTestFftTypes) {
  const FftType type = FftType::FFT;
  const Shape real_input = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape complex_input = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(real_input, type, {16, 8}, fft::requires_complex_input);
  fft::Pass(complex_input, type, {16, 8}, complex_input);
}
869
// IFFT rank handling mirrors FFT: fft_length rank 1-3, within input rank.
TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) {
  const FftType type = FftType::IFFT;
  const Shape input = ShapeUtil::MakeShape(C64, {16, 8});
  fft::Fail(input, type, {}, fft::unsupported_rank);
  fft::Pass(input, type, {8}, input);
  fft::Pass(input, type, {16, 8}, input);
  fft::Fail(input, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(input, type, {64, 32, 16, 8}, fft::unsupported_rank);
}
879
// IFFT requires a complex input type; F32 is rejected, C128 passes through.
TEST_F(ShapeInferenceTest, InferFftShapeTestIfftTypes) {
  const FftType type = FftType::IFFT;
  const Shape real_input = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape complex_input = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(real_input, type, {16, 8}, fft::requires_complex_input);
  fft::Pass(complex_input, type, {16, 8}, complex_input);
}
887
// RFFT maps a real [16,8] input to a complex output whose innermost
// dimension is fft_length/2+1 (here 5); rank rules mirror FFT.
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) {
  const FftType type = FftType::RFFT;
  const Shape input = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape output = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Fail(input, type, {}, fft::unsupported_rank);
  fft::Pass(input, type, {8}, output);
  fft::Pass(input, type, {16, 8}, output);
  fft::Fail(input, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(input, type, {64, 32, 16, 8}, fft::unsupported_rank);
}
898
// RFFT requires the input's innermost dimensions to match fft_length;
// zero-sized and odd innermost dimensions are exercised as well.
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftDimensions) {
  const FftType type = FftType::RFFT;
  const Shape input = ShapeUtil::MakeShape(F32, {16, 8});
  fft::Fail(input, type, {4}, fft::dimensions_match);
  fft::Fail(input, type, {16, 4}, fft::dimensions_match);
  fft::Fail(input, type, {8, 8}, fft::dimensions_match);
  fft::Fail(input, type, {8, 16}, fft::dimensions_match);

  // Zero-length innermost dimension is legal and stays zero-length.
  const Shape zero_in = ShapeUtil::MakeShape(F32, {16, 0});
  const Shape zero_out = ShapeUtil::MakeShape(C64, {16, 0});
  fft::Pass(zero_in, type, {0}, zero_out);
  fft::Pass(zero_in, type, {16, 0}, zero_out);

  // Both even (8) and odd (9) innermost dims map to floor(n/2)+1 == 5.
  const Shape even_in = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape odd_in = ShapeUtil::MakeShape(F32, {16, 9});
  const Shape output = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Pass(even_in, type, {16, 8}, output);
  fft::Pass(odd_in, type, {16, 9}, output);
}
918
// RFFT requires a real (F32/F64) input; complex inputs are rejected.
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftTypes) {
  const FftType type = FftType::RFFT;
  const Shape c64_input = ShapeUtil::MakeShape(C64, {16, 8});
  const Shape c128_input = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(c64_input, type, {16, 8}, fft::requires_f32_input);
  fft::Fail(c128_input, type, {16, 8}, fft::requires_f32_input);
}
926
// IRFFT maps a complex [16,5] input back to a real [16,8] output; rank
// rules mirror the other FFT variants.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftRanks) {
  const FftType type = FftType::IRFFT;
  const Shape input = ShapeUtil::MakeShape(C64, {16, 5});
  const Shape output = ShapeUtil::MakeShape(F32, {16, 8});
  fft::Fail(input, type, {}, fft::unsupported_rank);
  fft::Pass(input, type, {8}, output);
  fft::Pass(input, type, {16, 8}, output);
  fft::Fail(input, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(input, type, {64, 32, 16, 8}, fft::unsupported_rank);
}
937
// IRFFT requires the input's innermost dimension to be fft_length/2+1 and
// the outer dimensions to match fft_length.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftDimensions) {
  const FftType type = FftType::IRFFT;
  const Shape input = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Fail(input, type, {5}, fft::innermost_dimension_matches);
  fft::Fail(input, type, {16, 5}, fft::innermost_dimension_matches);
  fft::Fail(input, type, {8, 8}, fft::dimensions_match);
  fft::Fail(input, type, {8, 9}, fft::dimensions_match);

  // Zero-length innermost dimension round-trips to zero-length output.
  const Shape zero_in = ShapeUtil::MakeShape(C64, {16, 0});
  const Shape zero_out = ShapeUtil::MakeShape(F32, {16, 0});
  fft::Pass(zero_in, type, {0}, zero_out);
  fft::Pass(zero_in, type, {16, 0}, zero_out);

  // A [16,5] complex input can come from either an even (8) or odd (9)
  // length real signal; fft_length disambiguates.
  const Shape even_out = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape odd_out = ShapeUtil::MakeShape(F32, {16, 9});
  fft::Pass(input, type, {16, 8}, even_out);
  fft::Pass(input, type, {16, 9}, odd_out);
}
956
// IRFFT requires a complex input; a C128 input yields an F64 output.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) {
  const FftType type = FftType::IRFFT;
  const Shape real_input = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape c128_input = ShapeUtil::MakeShape(C128, {16, 5});
  const Shape f64_output = ShapeUtil::MakeShape(F64, {16, 8});
  fft::Fail(real_input, type, {16, 8}, fft::requires_complex_input);
  fft::Pass(c128_input, type, {16, 8}, f64_output);
}
965
// Map's result element type comes from the mapped computation (F32 -> S32),
// while the operand dimensions are preserved.
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
  const Shape operand = ShapeUtil::MakeShape(F32, {20});
  const ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_);
  const auto statusor =
      ShapeInference::InferMapShape({&operand}, to_apply, {0});
  EXPECT_IS_OK(statusor.status());
  const Shape expected = ShapeUtil::MakeShape(S32, {20});
  EXPECT_TRUE(ShapeUtil::Equal(expected, statusor.ValueOrDie()));
}
974
// Exercises InferMapShape: valid unary/binary/ternary maps, then each of the
// diagnostics for malformed operand lists and mapped computations.
TEST_F(ShapeInferenceTest, Map) {
  // Binary map over two rank-1 F32 operands.
  auto r1f32 = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  EXPECT_IS_OK(r1f32.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, r1f32.ValueOrDie()));

  // A single operand is fine as long as the applied arity matches (this
  // degenerates to a Map).
  auto r1f32_one = ShapeInference::InferMapShape(
      {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0});
  EXPECT_IS_OK(r1f32_one.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, r1f32_one.ValueOrDie()));

  // Ternary map over rank-2 S32 operands.
  auto r2s32 = ShapeInference::InferMapShape(
      {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_},
      ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1});
  EXPECT_IS_OK(r2s32.status());
  EXPECT_TRUE(ShapeUtil::Equal(s32matrix_64_64_, r2s32.ValueOrDie()));

  // Error: no operands at all.
  auto no_args = ShapeInference::InferMapShape(
      {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {});
  ASSERT_FALSE(no_args.ok());
  ASSERT_THAT(no_args.status().error_message(),
              HasSubstr("expects at least one argument"));

  // Error: operand shapes disagree.
  auto mismatched_args = ShapeInference::InferMapShape(
      {&vector_32_, &vector_64_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(mismatched_args.ok());
  ASSERT_THAT(mismatched_args.status().error_message(),
              HasSubstr("requires all operands to have the same shape"));

  // Error: operand count differs from the computation's arity.
  auto bad_arity = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_),
      {0});
  ASSERT_FALSE(bad_arity.ok());
  ASSERT_THAT(bad_arity.status().error_message(),
              HasSubstr("function arity must match"));

  // Error: mapped computation must return a scalar.
  auto bad_result = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0});
  ASSERT_FALSE(bad_result.ok());
  ASSERT_THAT(bad_result.status().error_message(),
              HasSubstr("result has to be a scalar"));

  // Error: mapped computation parameters must be scalars.
  auto bad_param = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0});
  ASSERT_FALSE(bad_param.ok());
  ASSERT_THAT(bad_param.status().error_message(),
              HasSubstr("parameter has to be a scalar"));

  // Error: parameter element type must match the operand element type.
  auto bad_param_type = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0});
  ASSERT_FALSE(bad_param_type.ok());
  ASSERT_THAT(bad_param_type.status().error_message(),
              HasSubstr("parameter type has to match argument"));

  // Identity-shaped unary map succeeds.
  Shape arg = ShapeUtil::MakeShape(F32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
  auto unary = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  EXPECT_IS_OK(unary.status());
  EXPECT_TRUE(ShapeUtil::Equal(arg, unary.ValueOrDie()));

  // The same diagnostics apply with a single operand.
  auto error1 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(error1.ok());
  ASSERT_THAT(error1.status().error_message(),
              HasSubstr("arity must match number of arguments"));

  auto error2 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0});
  ASSERT_FALSE(error2.ok());
  ASSERT_THAT(error2.status().error_message(),
              HasSubstr("has to be a scalar"));

  auto error3 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0});
  ASSERT_FALSE(error3.ok());
  ASSERT_THAT(error3.status().error_message(),
              HasSubstr("has to be a scalar"));

  auto error5 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0});
  ASSERT_FALSE(error5.ok());
  ASSERT_THAT(error5.status().error_message(),
              HasSubstr("parameter type has to match argument"));
}
1068
// Operands of different element types (F32, S32) are accepted when the
// mapped computation's parameters match them positionally.
TEST_F(ShapeInferenceTest, MapWithDifferentInputTypes) {
  const Shape operand0 = ShapeUtil::MakeShape(F32, {20});
  const Shape operand1 = ShapeUtil::MakeShape(S32, {20});
  const ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, s32_}, s32_);
  const auto statusor =
      ShapeInference::InferMapShape({&operand0, &operand1}, to_apply, {0});
  EXPECT_IS_OK(statusor.status());
  const Shape expected = ShapeUtil::MakeShape(S32, {20});
  EXPECT_TRUE(ShapeUtil::Equal(expected, statusor.ValueOrDie()));
}
1079
// Reducing a rank-1 [128] operand over its only dimension yields a scalar.
TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128});
  ExpectInferredReduceShape(f32_, operand, /*dimensions_to_reduce=*/{0});
}
1084
// Reducing a [2,3,4] cube over dimension 0 yields [3,4].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstDimension) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {3, 4});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{0});
}
1090
// Reducing a [2,3,4] cube over dimension 1 yields [2,4].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongMiddleDimension) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {2, 4});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{1});
}
1096
// Reducing a [2,3,4] cube over dimensions {0,1} yields [4].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstTwoDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {4});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{0, 1});
}
1102
// Reducing a [2,3,4] cube over dimensions {1,2} yields [2].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongLastTwoDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {2});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{1, 2});
}
1108
// Reducing a [2,3,4] cube over dimensions {0,2} yields [3]; the order in
// which the reduction dimensions are listed must not matter.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstAndLastDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {3});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{0, 2});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{2, 0});
}
1119
// Reducing a [2,3,4] cube over all of its dimensions yields a scalar.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(f32_, cube, /*dimensions_to_reduce=*/{0, 1, 2});
}
1124
// Variadic reduce over an (F32, S32) operand pair produces a tuple of the
// two scalar accumulator shapes when all dimensions are reduced.
TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const Shape result_tuple = ShapeUtil::MakeTupleShape({f32_, s32_});
  const ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, result_tuple);
  const auto statusor = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_IS_OK(statusor.status());
  EXPECT_TRUE(ShapeUtil::Equal(result_tuple, statusor.ValueOrDie()));
}
1136
// Variadic reduce-window over (F32, S32) operands: a {1,2,4} valid-padded
// window over [5,3,1] operands yields a tuple of [5,2,0] shapes.
TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) {
  Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> operands = {&f32_operand, &s32_operand};
  std::vector<const Shape*> init_values = {&f32_, &s32_};
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  std::vector<int64_t> window_dimensions = {1, 2, 4};
  std::vector<int64_t> window_strides = {1, 1, 1};
  std::vector<std::pair<int64_t, int64_t>> padding_values =
      MakePadding(f32_operand.dimensions(), window_dimensions, window_strides,
                  Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_dimensions, window_strides, padding_values, {}, {}));
  auto statusor = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(operands), absl::MakeSpan(init_values), window, to_apply);
  VLOG(2) << statusor.ValueOrDie().ToString() << "\n";
  EXPECT_IS_OK(statusor.status());
  EXPECT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {5, 2, 0}),
                                 ShapeUtil::MakeShape(S32, {5, 2, 0})}),
      statusor.ValueOrDie()));
}
1162
// Variadic reduce: the reducer must take exactly 2 * (number of operands)
// parameters; a 6-parameter reducer for two operands is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_},
                                  ShapeUtil::MakeTupleShape({f32_, s32_}));
  const auto statusor = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("must take 4 parameters, but takes 6 parameter(s)"));
}
1175
// Variadic reduce: a reducer parameter whose type (s32) disagrees with the
// corresponding result element (f32) is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  const auto statusor = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "parameter shape differs from the result shape: s32[] vs f32[]"));
}
1189
// Variadic reduce: an empty operand list (no operands, no init values) is
// rejected outright.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
  const ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  const auto statusor =
      ShapeInference::InferReduceShape({}, {0, 1}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("must have at least 2 arguments, has 0"));
}
1198
// Variadic reduce-window: a reducer expecting all-f32 parameters mismatches
// the S32 operand and is rejected with an f32[]/s32[] diagnostic.
TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) {
  Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> operands = {&f32_operand, &s32_operand};
  std::vector<const Shape*> init_values = {&f32_, &s32_};
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  std::vector<int64_t> window_dimensions = {1, 2, 4};
  std::vector<int64_t> window_strides = {1, 1, 1};
  std::vector<std::pair<int64_t, int64_t>> padding_values =
      MakePadding(f32_operand.dimensions(), window_dimensions, window_strides,
                  Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_dimensions, window_strides, padding_values, {}, {}));
  auto statusor = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(operands), absl::MakeSpan(init_values), window, to_apply);
  EXPECT_FALSE(statusor.status().ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("f32[] vs s32[]"));
}
1221
// Variadic reduce over two operands: a reducer returning a bare scalar
// instead of a 2-element tuple is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_);
  const auto statusor = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but produces a scalar"));
}
1234
// Variadic reduce over two operands: a reducer returning a 3-element tuple
// is rejected — the tuple arity must equal the operand count.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_}));
  const auto statusor = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but has 3 elements"));
}
1247
// Variadic reduce: an all-s32 reducer applied to an f32 init value is
// rejected — the accumulator must match the init_value shape.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_}));
  const auto statusor = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("accumulator shape at index 0 differs from the "
                        "init_value shape: s32[] vs f32[]"));
}
1260
// Reduction dimensions beyond the operand's rank are rejected.
TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
  const ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const auto statusor = ShapeInference::InferReduceShape(
      {&operand, &f32_},
      /*dimensions_to_reduce=*/{3, 4}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("out-of-bounds dimension"));
}
1271
// A three-parameter reducer for a single-operand reduce is rejected: it
// must take exactly two parameters (accumulator and element).
TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
  const ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const auto statusor =
      ShapeInference::InferReduceShape({&operand, &f32_},
                                       /*dimensions_to_reduce=*/{0}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("take 2 parameters"));
}
1282
// A reducer whose result type (s32) disagrees with its f32 parameters is
// rejected.
TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
  const ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const auto statusor =
      ShapeInference::InferReduceShape({&operand, &f32_},
                                       /*dimensions_to_reduce=*/{0}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("0-th parameter shape differs"));
}
1293
// Listing the same reduction dimension twice is rejected.
TEST_F(ReduceShapeInferenceTest, ReduceWithRepeatedReduceDimension) {
  const ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const auto statusor = ShapeInference::InferReduceShape(
      {&operand, &f32_},
      /*dimensions_to_reduce=*/{0, 0}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Duplicate reduction dimension: 0"));
}
1304
// Slicing rows [32,64) of a [128,64] matrix with unit strides yields
// a [32,64] result.
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  const auto statusor =
      ShapeInference::InferSliceShape(operand, {32, 0}, {64, 64}, {1, 1});
  ASSERT_IS_OK(statusor.status());
  const Shape expected = ShapeUtil::MakeShape(F32, {32, 64});
  ASSERT_TRUE(ShapeUtil::Equal(expected, statusor.ValueOrDie()));
}
1313
// Slicing an operand with dynamic dimensions: a size-1 slice of dimension 0
// becomes static, while the fully-kept dimension 1 stays dynamic.
TEST_F(ShapeInferenceTest, InferSliceWithDynamicDimensions) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64}, {true, true});
  const auto statusor =
      ShapeInference::InferSliceShape(operand, {32, 0}, {33, 64}, {1, 1});
  ASSERT_IS_OK(statusor.status());
  const Shape expected = ShapeUtil::MakeShape(F32, {1, 64}, {false, true});
  ASSERT_TRUE(ShapeUtil::Equal(expected, statusor.ValueOrDie()));
}
1323
// Strided slice: [32,64) step 2 and [0,64) step 4 give a [16,16] result.
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  const auto statusor =
      ShapeInference::InferSliceShape(operand, {32, 0}, {64, 64}, {2, 4});
  ASSERT_IS_OK(statusor.status());
  const Shape expected = ShapeUtil::MakeShape(F32, {16, 16});
  ASSERT_TRUE(ShapeUtil::Equal(expected, statusor.ValueOrDie()));
}
1332
// Strided slice whose extent is not a multiple of the stride: sizes round
// up, so [15,20) step 2 -> 3 and [0,13) step 4 -> 4.
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  const auto statusor =
      ShapeInference::InferSliceShape(operand, {15, 0}, {20, 13}, {2, 4});
  ASSERT_IS_OK(statusor.status());
  const Shape expected = ShapeUtil::MakeShape(F32, {3, 4});
  ASSERT_TRUE(ShapeUtil::Equal(expected, statusor.ValueOrDie()));
}
1341
// A zero stride is invalid and yields INVALID_ARGUMENT.
TEST_F(ShapeInferenceTest, InferInvalidStride) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  const auto statusor =
      ShapeInference::InferSliceShape(operand, {127, 0}, {129, 2}, {0, 1});
  ASSERT_FALSE(statusor.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, statusor.status().code());
}
1350
// A slice limit past the operand bound (129 > 128) yields INVALID_ARGUMENT.
TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  const auto statusor =
      ShapeInference::InferSliceShape(operand, {127, 0}, {129, 2}, {1, 1});
  ASSERT_FALSE(statusor.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, statusor.status().code());
}
1359
// Slicing [2,4) of a rank-1 [17] vector yields a [2] vector.
TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
  const Shape operand = ShapeUtil::MakeShape(F32, {17});
  const auto statusor =
      ShapeInference::InferSliceShape(operand, {2}, {4}, {1});
  ASSERT_TRUE(statusor.ok());
  const Shape expected = ShapeUtil::MakeShape(F32, {2});
  ASSERT_TRUE(ShapeUtil::Equal(statusor.ValueOrDie(), expected));
}
1368
TEST_F(ShapeInferenceTest, InferConstIndexShape) {
  // get-tuple-element on (f32[], s32[]) recovers each element shape.
  const Shape tuple = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto element0 = ShapeInference::InferGetTupleElementShape(tuple, 0);
  auto element1 = ShapeInference::InferGetTupleElementShape(tuple, 1);
  ASSERT_IS_OK(element0.status());
  ASSERT_IS_OK(element1.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, element0.ValueOrDie()));
  ASSERT_TRUE(ShapeUtil::Equal(s32_, element1.ValueOrDie()));
}
1380
TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) {
  // Indices outside [0, tuple_size) fail: both -1 and 2 on a 2-tuple.
  const Shape tuple = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto below_range = ShapeInference::InferGetTupleElementShape(tuple, -1);
  auto above_range = ShapeInference::InferGetTupleElementShape(tuple, 2);
  ASSERT_FALSE(below_range.ok());
  ASSERT_FALSE(above_range.ok());
  EXPECT_THAT(below_range.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
  EXPECT_THAT(above_range.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
}
1394
TEST_F(ShapeInferenceTest, InferPowShape) {
  // power(f32[10], f32[]) broadcasts the scalar and keeps the vector shape.
  const Shape ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto result = ShapeInference::InferBinaryOpShape(HloOpcode::kPower,
                                                   ten_floats, f32_, {});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(ten_floats, result.ValueOrDie()));
}
1402
TEST_F(ShapeInferenceTest, InferCompareShape) {
  // compare(f32[10], f32[]) yields a predicate vector of the same extent.
  const Shape ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto result = ShapeInference::InferBinaryOpShape(HloOpcode::kCompare,
                                                   ten_floats, f32_, {});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), result.ValueOrDie()));
}
1411
TEST_F(ShapeInferenceTest, InferReshapeDegenerateCombine) {
  // [1, <=1]
  //   | reshape
  // [<=1]
  //
  // Either output dimension could carry the dynamic bound; with
  // inferred_dimension unset (-1) the dynamic one wins the tie-break.
  const Shape input = ShapeUtil::MakeShape(F32, {1, 1}, {false, true});
  auto result = ShapeInference::InferReshapeShape(input, {1, 0}, {1},
                                                  /*inferred_dimension=*/-1);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {1}, {true}), result.ValueOrDie());
}
1423
TEST_F(ShapeInferenceTest, InferReshapeSplit) {
  // [<=10]
  //   | reshape
  // [1, 10]
  //
  // Either output dimension could be dynamic; inferred_dimension=0 picks
  // dimension 0 as the dynamic one.
  const Shape input = ShapeUtil::MakeShape(F32, {10}, {true});
  auto result = ShapeInference::InferReshapeShape(input, {0}, {1, 10},
                                                  /*inferred_dimension=*/0);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {1, 10}, {true, false}),
            result.ValueOrDie());
}
1436
TEST_F(ShapeInferenceTest, InferReshapeCombine) {
  // [6, <=10]
  //   | reshape
  // [<=60]
  //
  // Folding a dynamic dimension into a static one keeps the result dynamic.
  const Shape input = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
  auto result = ShapeInference::InferReshapeShape(input, {1, 0}, {60},
                                                  /*inferred_dimension=*/-11);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {60}, {true}), result.ValueOrDie());
}
1446
TEST_F(ShapeInferenceTest, UnchangedDimension) {
  // [6, <=10]
  //   | reshape
  // [2, 3, <=10]
  //
  // The trailing dynamic dimension is untouched by the reshape, so its
  // dynamism carries through unchanged.
  const Shape input = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
  auto result = ShapeInference::InferReshapeShape(input, {1, 0}, {2, 3, 10},
                                                  /*inferred_dimension=*/-11);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {2, 3, 10}, {false, false, true}),
            result.ValueOrDie());
}
1457
TEST_F(ShapeInferenceTest, InferDynamicBroadcast) {
  // Broadcasting f32[<=15] into a new leading dimension of size 15 gives
  // f32[15, <=15]: the new dimension is static, the operand's dynamic
  // dimension stays dynamic.
  const Shape operand = ShapeUtil::MakeShape(F32, {15}, {true});
  auto result = ShapeInference::InferBroadcastShape(operand, {15});
  ASSERT_IS_OK(result.status());
  const Shape broadcast = result.ValueOrDie();
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {15, 15}, {false, true}), broadcast);
}
1469
TEST_F(ShapeInferenceTest, BroadcastScalar) {
  // Broadcasts from scalar / 1d operands, across several element types.
  for (auto type : {F32, U32, S8}) {
    const Shape scalar = ShapeUtil::MakeShape(type, {});
    const Shape vec3 = ShapeUtil::MakeShape(type, {3});
    const Shape mat2x3 = ShapeUtil::MakeShape(type, {2, 3});

    {  // no-op scalar broadcast
      auto result = ShapeInference::InferBroadcastShape(scalar, {});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(scalar, result.ValueOrDie()));
    }
    {  // scalar -> 1d broadcast
      auto result = ShapeInference::InferBroadcastShape(scalar, {3});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(vec3, result.ValueOrDie()));
    }
    {  // no-op 1d broadcast
      auto result = ShapeInference::InferBroadcastShape(vec3, {});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(vec3, result.ValueOrDie()));
    }
    {  // scalar -> 2d broadcast
      auto result = ShapeInference::InferBroadcastShape(scalar, {2, 3});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(mat2x3, result.ValueOrDie()));
    }
    {  // 1d -> 2d broadcast
      auto result = ShapeInference::InferBroadcastShape(vec3, {2});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(mat2x3, result.ValueOrDie()));
    }
  }
}
1502
1503 // scalar <dot> vector: ok
TEST_F(ShapeInferenceTest, ScalarDotVector) {
  // With no contracting/batch dimensions, f32[] dot f32[32] is f32[32].
  DotDimensionNumbers dnums;
  auto result = ShapeInference::InferDotOpShape(
      f32_, vector_32_, dnums, /*preferred_element_type=*/std::nullopt);
  EXPECT_TRUE(result.ok());
  EXPECT_EQ(result.ValueOrDie(), vector_32_);
}
1511
// 3D <dot> 2D: ok (the free lhs dimensions pass through to the result)
TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
  // f32[32,32,32] dot f32[32,64], contracting lhs dim 1 with rhs dim 0:
  // the free lhs dims {0,2} and rhs dim {1} form f32[32,32,64].
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  auto result = ShapeInference::InferDotOpShape(
      ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dnums,
      /*preferred_element_type=*/std::nullopt);
  EXPECT_TRUE(result.ok());
  EXPECT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
                               ShapeUtil::MakeShape(F32, {32, 32, 64})));
}
1524
1525 // vector <dot> vector -> scalar
TEST_F(ShapeInferenceTest, VectorDotVector) {
  // Equal-length vectors contract to a scalar; mismatched lengths fail.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(0);
  dnums.add_rhs_contracting_dimensions(0);
  auto matched =
      ShapeInference::InferDotOpShape(vector_64_, vector_64_, dnums,
                                      /*preferred_element_type=*/std::nullopt);
  ASSERT_IS_OK(matched.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, matched.ValueOrDie()));
  auto mismatched =
      ShapeInference::InferDotOpShape(vector_64_, vector_32_, dnums,
                                      /*preferred_element_type=*/std::nullopt);
  ASSERT_FALSE(mismatched.ok());
}
1540
1541 // matrix <dot> vector -> vector
TEST_F(ShapeInferenceTest, MatrixDotVector) {
  // f32[32,64] dot f32[64] contracts to f32[32]; a length mismatch fails.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  auto matched =
      ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dnums,
                                      /*preferred_element_type=*/std::nullopt);
  ASSERT_IS_OK(matched.status());
  ASSERT_TRUE(ShapeUtil::Equal(matched.ValueOrDie(), vector_32_));
  auto mismatched =
      ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dnums,
                                      /*preferred_element_type=*/std::nullopt);
  ASSERT_FALSE(mismatched.ok());
}
1556
1557 // vector <dot> matrix -> vector
TEST_F(ShapeInferenceTest, VectorDotMatrix) {
  // f32[32] dot f32[32,64] contracts to f32[64]; a length mismatch fails.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(0);
  dnums.add_rhs_contracting_dimensions(0);
  auto matched =
      ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dnums,
                                      /*preferred_element_type=*/std::nullopt);
  ASSERT_IS_OK(matched.status());
  ASSERT_TRUE(ShapeUtil::Equal(matched.ValueOrDie(), vector_64_));
  auto mismatched =
      ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dnums,
                                      /*preferred_element_type=*/std::nullopt);
  ASSERT_FALSE(mismatched.ok());
}
1572
1573 // matrix <dot> matrix -> matrix
TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
  // f32[32,64] dot f32[64,48] contracts lhs dim 1 with rhs dim 0 to give
  // f32[32,48]; contracting dims with mismatched sizes must fail.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto inferred_status_match =
      ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums,
                                      /*preferred_element_type=*/std::nullopt);
  ASSERT_IS_OK(inferred_status_match.status());
  ASSERT_TRUE(
      ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
      << "inferred: "
      << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
      // Fixed: the diagnostic previously printed matrix_64_48_ even though
      // the assertion compares against matrix_32_48_.
      << " expected: " << ShapeUtil::HumanString(matrix_32_48_);
  auto inferred_status_mismatch =
      ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums,
                                      /*preferred_element_type=*/std::nullopt);
  ASSERT_FALSE(inferred_status_mismatch.ok());
}
1592
1593 // BatchMatMul with two batch dimensions and one contracting dimension.
TEST_F(ShapeInferenceTest, DotGeneral) {
  // Batched matmul: batch dims {0,1} pass through, lhs dim 2 and rhs dim 3
  // are the free dims, and dim 3 (lhs) contracts with dim 2 (rhs).
  const Shape lhs = ShapeUtil::MakeShape(F32, {5, 2, 11, 3});
  const Shape rhs = ShapeUtil::MakeShape(F32, {5, 2, 3, 14});
  const Shape expected = ShapeUtil::MakeShape(F32, {5, 2, 11, 14});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(3);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_lhs_batch_dimensions(1);
  dnums.add_rhs_contracting_dimensions(2);
  dnums.add_rhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(1);

  auto result =
      ShapeInference::InferDotOpShape(lhs, rhs, dnums,
                                      /*preferred_element_type=*/std::nullopt);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), expected))
      << "inferred: " << ShapeUtil::HumanString(result.ValueOrDie())
      << " expected: " << ShapeUtil::HumanString(expected);
}
1618
1619 // BatchMatMul with two contracting dimensions fails.
TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
  // Two contracting dims on the lhs but only one on the rhs is an error.
  const Shape lhs = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
  const Shape rhs = ShapeUtil::MakeShape(F32, {2, 3, 14});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_lhs_contracting_dimensions(3);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_rhs_batch_dimensions(0);

  auto result =
      ShapeInference::InferDotOpShape(lhs, rhs, dnums,
                                      /*preferred_element_type=*/std::nullopt);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Must specify the same number of contracting "
                        "dimensions for lhs and rhs."));
}
1640
TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
  // Matching pairs of contracting dims ({2,3} vs {1,2}) are allowed; both
  // contracted pairs disappear from the output.
  const Shape lhs = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
  const Shape rhs = ShapeUtil::MakeShape(F32, {2, 3, 2, 14});
  const Shape expected = ShapeUtil::MakeShape(F32, {2, 11, 14});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_lhs_contracting_dimensions(3);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(2);
  dnums.add_rhs_batch_dimensions(0);

  auto result =
      ShapeInference::InferDotOpShape(lhs, rhs, dnums,
                                      /*preferred_element_type=*/std::nullopt);
  EXPECT_TRUE(result.ok());
  EXPECT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), expected));
}
1661
TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) {
  // set-dimension-size requires a scalar size operand; s32[1] is rejected.
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape size_operand = ShapeUtil::MakeShape(S32, {1});
  auto result = ShapeInference::InferSetDimensionSizeShape(
      operand, size_operand, /*dimension=*/0);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("value has to be S32 scalar"));
}
1672
TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) {
  // set-dimension-size requires an S32 size operand; a U32 scalar is rejected.
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape size_operand = ShapeUtil::MakeShape(U32, {});
  auto result = ShapeInference::InferSetDimensionSizeShape(
      operand, size_operand, /*dimension=*/0);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("value has to be S32 scalar"));
}
1683
1684 // BatchMatMul with different batch dimension sizes fails.
TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
  // lhs batch dim has size 2 but rhs batch dim has size 3: rejected.
  const Shape lhs = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs = ShapeUtil::MakeShape(F32, {3, 3, 14});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_rhs_batch_dimensions(0);

  auto result =
      ShapeInference::InferDotOpShape(lhs, rhs, dnums,
                                      /*preferred_element_type=*/std::nullopt);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Batch dimension sizes must match"));
}
1703
1704 // BatchMatMul with different batch dimension numbers passes
TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) {
  // Batch dims need not occupy the same dimension number on each side, as
  // long as their sizes agree (2 here, at lhs dim 0 and rhs dim 1).
  const Shape lhs = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs = ShapeUtil::MakeShape(F32, {3, 2, 14});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_contracting_dimensions(0);
  dnums.add_rhs_batch_dimensions(1);

  auto result =
      ShapeInference::InferDotOpShape(lhs, rhs, dnums,
                                      /*preferred_element_type=*/std::nullopt);
  ASSERT_TRUE(result.ok());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
                               ShapeUtil::MakeShape(F32, {2, 11, 14})));
}
1723
1724 // BatchMatMul with out-of-range dimension numbers fails.
TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
  // Contracting dim 3 does not exist on a rank-3 lhs: rejected.
  const Shape lhs = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs = ShapeUtil::MakeShape(F32, {2, 3, 14});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(3);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_contracting_dimensions(0);
  dnums.add_rhs_batch_dimensions(1);

  auto result =
      ShapeInference::InferDotOpShape(lhs, rhs, dnums,
                                      /*preferred_element_type=*/std::nullopt);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("A dimension number is out of range"));
}
1743
1744 // BatchMatMul with non-unique dimension numbers fails.
TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
  // lhs dim 0 is listed both as contracting and as batch: rejected.
  const Shape lhs = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs = ShapeUtil::MakeShape(F32, {2, 3, 14});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(0);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_contracting_dimensions(0);
  dnums.add_rhs_batch_dimensions(1);

  auto result =
      ShapeInference::InferDotOpShape(lhs, rhs, dnums,
                                      /*preferred_element_type=*/std::nullopt);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("A dimension number is not unique"));
}
1763
TEST_F(ShapeInferenceTest, DotWithIntegralPreferredElementType) {
  // s8 dot s16 with preferred_element_type S32 widens the result to S32.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(Shape result,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(S8, {32, 32}),
                              ShapeUtil::MakeShape(S16, {32, 32}), dnums,
                              /*preferred_element_type=*/S32));
  EXPECT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(S32, {32, 32})));
}
1776
TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeSameAsInferredType) {
  // bf16 dot f32 already infers F32; a matching preferred type is a no-op.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(Shape result,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(BF16, {32, 32}),
                              ShapeUtil::MakeShape(F32, {32, 32}), dnums,
                              /*preferred_element_type=*/F32));
  EXPECT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(F32, {32, 32})));
}
1789
TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) {
  // Preferring BF16 for a bf16 x f32 dot narrows the floating-point result.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(Shape result,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(BF16, {32, 32}),
                              ShapeUtil::MakeShape(F32, {32, 32}), dnums,
                              /*preferred_element_type=*/BF16));
  EXPECT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(BF16, {32, 32})));
}
1802
TEST_F(ShapeInferenceTest, FloatingPointDotWithIntegralPreferredElementType) {
  // An integral preferred type may be applied to a floating-point dot.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(Shape result,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(BF16, {32, 32}),
                              ShapeUtil::MakeShape(BF16, {32, 32}), dnums,
                              /*preferred_element_type=*/S32));
  EXPECT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(S32, {32, 32})));
}
1815
TEST_F(ShapeInferenceTest, IntegralDotWithFloatingPointPreferredElementType) {
  // A floating-point preferred type may be applied to an integral dot.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(Shape result,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(S8, {32, 32}),
                              ShapeUtil::MakeShape(S16, {32, 32}), dnums,
                              /*preferred_element_type=*/F32));
  EXPECT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(F32, {32, 32})));
}
1828
TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeWithDifferentSignedness) {
  // The preferred type may flip signedness (signed operands, U32 result).
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(Shape result,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(S8, {32, 32}),
                              ShapeUtil::MakeShape(S16, {32, 32}), dnums,
                              /*preferred_element_type=*/U32));
  EXPECT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(U32, {32, 32})));
}
1841
TEST_F(ShapeInferenceTest, DotWithNarrowerPreferredElementType) {
  // An integral preferred type narrower than an operand type (S8 < S16)
  // is rejected.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  auto status = ShapeInference::InferDotOpShape(
                    ShapeUtil::MakeShape(S8, {32, 32}),
                    ShapeUtil::MakeShape(S16, {32, 32}), dnums,
                    /*preferred_element_type=*/S8)
                    .status();
  ASSERT_FALSE(status.ok());
  ASSERT_THAT(status.error_message(),
              HasSubstr("must not be narrower than the original type"));
}
1855
TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
  // Adding a vector to a matrix: the broadcast dimension must name the
  // matrix dimension whose size matches the vector's length.
  const Shape mat = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
  const Shape vec16 = ShapeUtil::MakeShape(F32, {16});

  // vec8 lines up with matrix dimension 1 (size 8)...
  auto result = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), mat));

  // ...but not with dimension 0 (size 16).
  auto bad = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0});
  ASSERT_FALSE(bad.ok());

  // vec16 lines up with dimension 0 (size 16)...
  result = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), mat));

  // ...but not with dimension 1 (size 8).
  bad = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1});
  ASSERT_FALSE(bad.ok());
}
1881
TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) {
  // Adding a matrix to a rank-3 operand: each pair of broadcast dimensions
  // must name the cube dimensions matching the matrix's sizes.
  const Shape cube = ShapeUtil::MakeShape(F32, {16, 8, 4});
  const Shape mat8_4 = ShapeUtil::MakeShape(F32, {8, 4});
  const Shape mat16_4 = ShapeUtil::MakeShape(F32, {16, 4});
  const Shape mat16_8 = ShapeUtil::MakeShape(F32, {16, 8});

  auto result = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube,
                                                   mat8_4, {1, 2});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), cube));

  result = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube, mat16_4,
                                              {0, 2});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), cube));

  result = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube, mat16_8,
                                              {0, 1});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), cube));
}
1904
TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
  // Each invalid form of broadcast_dimensions produces a distinct error.
  const Shape tensor = ShapeUtil::MakeShape(F32, {16, 8, 4});
  const Shape tensor8_8_8 = ShapeUtil::MakeShape(F32, {8, 8, 8});
  const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
  const Shape mat8_4 = ShapeUtil::MakeShape(F32, {8, 4});
  const Shape mat8_8 = ShapeUtil::MakeShape(F32, {8, 8});

  // Empty broadcast_dimensions with mismatched ranks ("magical" broadcast)
  // is rejected.
  auto err_empty_dims =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {});
  ASSERT_FALSE(err_empty_dims.ok());
  ASSERT_THAT(err_empty_dims.status().error_message(),
              HasSubstr("Shapes must be equal rank"));

  // Broadcast dimension beyond the tensor's rank.
  auto err_dim_oob =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3});
  ASSERT_FALSE(err_dim_oob.ok());
  ASSERT_THAT(err_dim_oob.status().error_message(),
              ContainsRegex("Broadcast dimension number .* too large"));

  // Broadcast dimension whose size doesn't match the operand's.
  auto err_size_mismatch =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0});
  ASSERT_FALSE(err_size_mismatch.ok());
  ASSERT_THAT(err_size_mismatch.status().error_message(),
              HasSubstr("Broadcast dimension 0 mismatch"));

  // More broadcast dimensions than the smaller operand has.
  auto err_too_many = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor, mat8_4, {0, 1, 2});
  ASSERT_FALSE(err_too_many.ok());
  ASSERT_THAT(err_too_many.status().error_message(),
              HasSubstr("broadcast_dimensions has to match"));

  // One dimension number above the tensor's rank.
  auto err_one_oob = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd,
                                                        tensor, mat8_4, {3, 0});
  ASSERT_FALSE(err_one_oob.ok());
  ASSERT_THAT(err_one_oob.status().error_message(),
              ContainsRegex("dimension number .* too large"));

  // Dimensions whose sizes don't match in the given order.
  auto err_order_mismatch = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor, mat8_4, {2, 1});
  ASSERT_FALSE(err_order_mismatch.ok());
  ASSERT_THAT(err_order_mismatch.status().error_message(),
              HasSubstr("dimension 0 mismatch"));

  // Broadcast dimensions must be strictly increasing, even when the
  // lower-rank operand matches the higher-rank one in several ways.
  auto err_repeated = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor8_8_8, mat8_8, {0, 0});
  ASSERT_FALSE(err_repeated.ok());
  ASSERT_THAT(err_repeated.status().error_message(),
              HasSubstr("dimensions order is wrong"));

  auto err_decreasing = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor8_8_8, mat8_8, {1, 0});
  ASSERT_FALSE(err_decreasing.ok());
  ASSERT_THAT(err_decreasing.status().error_message(),
              HasSubstr("dimensions order is wrong"));
}
1970
1971 // Tests for the while instruction with proper shapes.
TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) {
  // while: cond (T) -> pred, body (T) -> T, init T; result is T.
  const Shape state = ShapeUtil::MakeTupleShape({s32_, vector_32_});
  const ProgramShape cond = ShapeUtil::MakeProgramShape({state}, pred_);
  const ProgramShape body = ShapeUtil::MakeProgramShape({state}, state);
  auto result = ShapeInference::InferWhileShape(cond, body, state);
  ASSERT_IS_OK(result.status());
  const Shape inferred = result.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(state, inferred));
}
1982
1983 // Tests for the while instruction with wrong shapes.
TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
  // Each way a while's cond/body signature can be malformed yields a
  // distinct error message.
  const Shape state = ShapeUtil::MakeTupleShape({s32_, vector_32_});
  const ProgramShape cond = ShapeUtil::MakeProgramShape({state}, pred_);
  const ProgramShape body = ShapeUtil::MakeProgramShape({state}, state);

  // Condition with two parameters.
  const auto two_param_cond =
      ShapeUtil::MakeProgramShape({s32_, state}, pred_);
  auto err_cond_arity =
      ShapeInference::InferWhileShape(two_param_cond, body, state);
  ASSERT_FALSE(err_cond_arity.ok());
  ASSERT_THAT(err_cond_arity.status().error_message(),
              HasSubstr("Condition must take 1 arguments"));

  // Body with two parameters.
  const auto two_param_body =
      ShapeUtil::MakeProgramShape({s32_, state}, state);
  auto err_body_arity =
      ShapeInference::InferWhileShape(cond, two_param_body, state);
  ASSERT_FALSE(err_body_arity.ok());
  ASSERT_THAT(err_body_arity.status().error_message(),
              HasSubstr("Body must take 1 arguments"));

  // Condition returning s32 instead of pred.
  const auto s32_cond = ShapeUtil::MakeProgramShape({state}, s32_);
  auto err_cond_result =
      ShapeInference::InferWhileShape(s32_cond, body, state);
  ASSERT_FALSE(err_cond_result.ok());
  ASSERT_THAT(err_cond_result.status().error_message(),
              HasSubstr("Condition must return a boolean"));

  // Body whose result disagrees with the loop state.
  const auto vec_body = ShapeUtil::MakeProgramShape({state}, vector_32_);
  auto err_body_result =
      ShapeInference::InferWhileShape(cond, vec_body, state);
  ASSERT_FALSE(err_body_result.ok());
  ASSERT_THAT(err_body_result.status().error_message(),
              HasSubstr("parameter of condition and body"));
}
2018
2019 // Tests for the concatenate instruction with dynamic shapes.
TEST_F(ShapeInferenceTest, ConcatenateWithDynamicShapes) {
  // Concatenating along dim 0 sums the sizes; a dynamic bound on either
  // operand (in any dimension) keeps that dimension dynamic in the result.
  const Shape lhs =
      ShapeUtil::MakeShape(F32, {32, 160, 10}, {true, false, false});
  const Shape rhs =
      ShapeUtil::MakeShape(F32, {32, 160, 10}, {false, true, false});
  auto result =
      ShapeInference::InferConcatOpShape({&lhs, &rhs}, /*dimension=*/0);
  ASSERT_IS_OK(result.status());
  const Shape concatenated = result.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeShape(F32, {64, 160, 10}, {true, true, false}),
      concatenated));
}
2032
2033 // Tests for the concatenate instruction with proper shapes.
TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) {
  // Concatenation sums sizes along the chosen dimension, for two or more
  // operands, in both rank-1 and rank-2 cases.
  auto two_vectors = ShapeInference::InferConcatOpShape(
      {&vector_32_, &vector_64_}, /*dimension=*/0);
  ASSERT_IS_OK(two_vectors.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}),
                               two_vectors.ValueOrDie()));

  auto three_vectors = ShapeInference::InferConcatOpShape(
      {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0);
  ASSERT_IS_OK(three_vectors.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}),
                               three_vectors.ValueOrDie()));

  auto three_matrices = ShapeInference::InferConcatOpShape(
      {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1);
  ASSERT_IS_OK(three_matrices.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}),
                               three_matrices.ValueOrDie()));
}
2054
2055 // Tests for the concatenate instruction with wrong shapes.
TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
  // No operands at all.
  auto inferred_status_error1 =
      ShapeInference::InferConcatOpShape({}, /*dimension=*/0);
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("Concatenate expects at least one argument"));

  // Negative concat dimension.
  auto inferred_status_error2 =
      ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1);
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("dimension out of bounds: -1"));

  // Concat dimension >= rank of the operands (rank 1 here).
  auto inferred_status_error3 =
      ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1);
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("dimension out of bounds: 1"));

  // A tuple operand is not an array.
  Shape tuple = ShapeUtil::MakeTupleShape({vector_32_});
  auto inferred_status_error4 = ShapeInference::InferConcatOpShape(
      {&vector_32_, &tuple}, /*dimension=*/0);
  ASSERT_FALSE(inferred_status_error4.ok());
  ASSERT_THAT(
      inferred_status_error4.status().error_message(),
      HasSubstr("Expected array argument for operand of concatenation"));

  // Element types must match across operands (F32 vs S32).
  const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
  auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
      {&vector_32_, &vector_s32}, /*dimension=*/0);
  ASSERT_FALSE(inferred_status_error5.ok());
  ASSERT_THAT(inferred_status_error5.status().error_message(),
              HasSubstr("concatenate arrays with different element types"));

  // Non-concat dimensions must agree (48 vs 64 in dim 1 while
  // concatenating along dim 0).
  auto inferred_status_error6 = ShapeInference::InferConcatOpShape(
      {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0);
  ASSERT_FALSE(inferred_status_error6.ok());
  ASSERT_THAT(inferred_status_error6.status().error_message(),
              HasSubstr("concatenate arrays that differ in "
                        "dimensions other than the one being "
                        "concatenated"));
}
2098
TEST_F(ShapeInferenceTest, Pad) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
  Shape padding_value_shape = ShapeUtil::MakeShape(F32, {});
  // Padding for dimension 0: {low: 0, high: 2, interior: 3}
  // Padding for dimension 1: {low: 1, high: 5, interior: 0}
  PaddingConfig padding_config;
  auto dimension0 = padding_config.add_dimensions();
  dimension0->set_edge_padding_low(0);
  dimension0->set_edge_padding_high(2);
  dimension0->set_interior_padding(3);
  auto dimension1 = padding_config.add_dimensions();
  dimension1->set_edge_padding_low(1);
  dimension1->set_edge_padding_high(5);
  dimension1->set_interior_padding(0);

  auto inferred_status = ShapeInference::InferPadShape(
      input_shape, padding_value_shape, padding_config);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  // Dim 0: 10 + 0 + 2 + 3*(10-1) = 39; dim 1: 25 + 1 + 5 + 0 = 31.
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), inferred_shape));

  // Negative edge padding that shrinks dim 1 below zero must be rejected.
  dimension1->set_edge_padding_low(-20);
  dimension1->set_edge_padding_high(-10);
  auto negative_dimension_size = ShapeInference::InferPadShape(
      input_shape, padding_value_shape, padding_config);
  ASSERT_FALSE(negative_dimension_size.ok());
  ASSERT_THAT(negative_dimension_size.status().error_message(),
              HasSubstr("negative size for dimension 1"));
}
2129
// Reversing an array along (all of) its dimensions leaves the shape intact.
TEST_F(ShapeInferenceTest, Reverse) {
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {10, 25});

  auto statusor = ShapeInference::InferReverseShape(operand_shape, {0, 1});
  ASSERT_IS_OK(statusor.status());
  const Shape result_shape = statusor.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(operand_shape, result_shape));
}
2138
TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});

  // Dimension 2 does not exist on a rank-2 shape.
  auto inferred_status_error0 =
      ShapeInference::InferReverseShape(input_shape, {0, 2});
  ASSERT_FALSE(inferred_status_error0.ok());
  ASSERT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("out-of-bounds"));

  // Negative dimensions are rejected too.
  auto inferred_status_error1 =
      ShapeInference::InferReverseShape(input_shape, {0, -1});
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("out-of-bounds"));

  // A dimension may appear at most once in the reverse list.
  auto inferred_status_error2 =
      ShapeInference::InferReverseShape(input_shape, {0, 0});
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("duplicated"));

  // Reverse only applies to arrays, not tuples.
  Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});
  auto inferred_status_error3 =
      ShapeInference::InferReverseShape(tuple_shape, {0});
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("Expected array argument"));
}
2167
TEST_F(ShapeInferenceTest, Call) {
  // Zero-argument call: result is the computation's result shape.
  auto inferred_status0 =
      ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_));
  EXPECT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  // Multi-argument call with matching parameter shapes.
  auto inferred_status1 = ShapeInference::InferCallShape(
      {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_},
      ShapeUtil::MakeProgramShape(
          {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_));
  EXPECT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(s32matrix_64_64_, inferred_status1.ValueOrDie()));

  // Fewer arguments than parameters.
  auto inferred_status_error0 = ShapeInference::InferCallShape(
      {}, ShapeUtil::MakeProgramShape({f32_}, f32_));
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("arity must match"));

  // More arguments than parameters.
  auto inferred_status_error1 = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({}, f32_));
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("arity must match"));

  // Argument shape mismatching the parameter shape (F32 vs S32).
  auto inferred_status_error2 = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_));
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("parameter must match argument"));
}
2200
// Transposing a rank-4 array reorders its dimensions per the permutation.
TEST_F(ShapeInferenceTest, Transpose) {
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
  auto statusor =
      ShapeInference::InferTransposeShape(operand_shape, {1, 2, 3, 0});
  EXPECT_IS_OK(statusor);
  const Shape result_shape = statusor.ValueOrDie();
  // Permutation {1,2,3,0} maps {2,3,4,5} to {3,4,5,2}.
  EXPECT_TRUE(ShapeUtil::Compatible(result_shape,
                                    ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
}
2210
// The identity permutation on a rank-1 array is a no-op on the shape.
TEST_F(ShapeInferenceTest, Rank1Transpose) {
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {5});
  auto statusor = ShapeInference::InferTransposeShape(operand_shape, {0});
  EXPECT_IS_OK(statusor);
  const Shape result_shape = statusor.ValueOrDie();
  EXPECT_TRUE(ShapeUtil::Compatible(result_shape,
                                    ShapeUtil::MakeShape(F32, {5})));
}
2220
// Conditional with a PRED predicate (classic two-branch if/else form).
TEST_F(ShapeInferenceTest, ConditionalPred) {
  // Both branches return f32 scalars, so the conditional does too.
  auto inferred_status0 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  // Branch parameter shapes may differ as long as each matches its operand.
  auto inferred_status1 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
       ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)},
      {matrix_32_48_, vector_32_});
  EXPECT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));

  // A tuple-shaped branch operand is fine when passed as a single parameter.
  auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
  auto inferred_status2 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
      {matrix_32_48_, tuple_f32_v32});
  EXPECT_IS_OK(inferred_status2.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));

  // The predicate must be PRED (or S32 for indexed form), not F32.
  auto inferred_status_error0 = ShapeInference::InferConditionalShape(
      f32_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("must be bool or int32_t"));

  // Each branch computation takes exactly one parameter; branch 0 here
  // declares two.
  auto inferred_status_error1 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
      {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_});
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("branch computation 0 must take 1 argument"));

  // Branch 0's operand (32-vector) mismatches its parameter (64-vector).
  auto inferred_status_error2 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_64_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("branch operand 0 must match the shape of the only "
                        "parameter of branch computation 0"));

  // Same arity error but on branch 1.
  auto inferred_status_error3 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)},
      {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})});
  EXPECT_FALSE(inferred_status_error3.ok());
  EXPECT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("branch computation 1 must take 1 argument"));

  // Branch 1's operand (64-vector) mismatches its parameter (32-vector).
  auto inferred_status_error4 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error4.ok());
  EXPECT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("branch operand 1 must match the shape of the only "
                        "parameter of branch computation 1"));

  // Branch result shapes must agree (f32 scalar vs 32-vector).
  auto inferred_status_error5 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error5.ok());
  EXPECT_THAT(inferred_status_error5.status().error_message(),
              HasSubstr("the result of branch 0 computation and branch 1 "
                        "computation must have the same shape"));
}
2304
// Conditional in indexed (N-branch) form, selected by an S32 scalar.
TEST_F(ShapeInferenceTest, ConditionalIndexed) {
  auto r0s32 = ShapeUtil::MakeShape(S32, {});
  // Three branches, all returning f32 scalars.
  auto inferred_status0 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_, vector_64_});
  EXPECT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  // Branch parameter shapes may differ per branch.
  auto inferred_status1 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
       ShapeUtil::MakeProgramShape({vector_32_}, vector_64_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)},
      {matrix_32_48_, vector_32_, matrix_32_48_});
  EXPECT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));

  // Single-branch conditional with a tuple-shaped operand.
  auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
  auto inferred_status2 = ShapeInference::InferConditionalShape(
      r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
      {tuple_f32_v32});
  EXPECT_IS_OK(inferred_status2.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));

  // A PRED selector requires exactly two branches; three were given.
  auto inferred_status_error0 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("2 == branch_computations.size()"));

  // Branch 1 declares two parameters; every branch must take exactly one.
  auto inferred_status_error1 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
      {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}),
       matrix_32_48_});
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("branch computation 1 must take 1 argument"));

  // Branch 2's operand (64-vector) mismatches its parameter (32-vector).
  auto inferred_status_error2 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({r0s32}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
      {r0s32, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("branch operand 2 must match the shape of the only "
                        "parameter of branch computation 2"));

  // Branch 3's result shape differs from branch 0's.
  auto inferred_status_error3 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
      {vector_32_, vector_32_, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error3.ok());
  EXPECT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("the result of branch 0 computation and branch 3 "
                        "computation must have the same shape"));

  // At least one branch is required.
  auto inferred_status_error4 =
      ShapeInference::InferConditionalShape(r0s32, {}, {});
  EXPECT_FALSE(inferred_status_error4.ok());
  EXPECT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("!branch_computations.empty()"));
}
2382
// Branch results mixing static and dynamic dimensions: the inferred result
// carries the dynamic dimension regardless of which branch is dynamic.
TEST_F(ShapeInferenceTest, ConditionalDynamic) {
  auto r0s32 = ShapeUtil::MakeShape(S32, {});
  auto static_shape = ShapeUtil::MakeShape(S32, {4}, {false});
  auto dynamic_shape = ShapeUtil::MakeShape(S32, {4}, {true});
  // Static result in branch 0, dynamic in branches 1 and 2.
  auto inferred_status0 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, static_shape),
       ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape),
       ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)},
      {vector_32_, vector_64_, vector_64_});
  EXPECT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, inferred_status0.ValueOrDie()));

  // Same, but the static result sits in the middle branch.
  auto inferred_status1 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, dynamic_shape),
       ShapeUtil::MakeProgramShape({vector_64_}, static_shape),
       ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)},
      {vector_32_, vector_64_, vector_64_});
  EXPECT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, inferred_status1.ValueOrDie()));
}
2405
// Slicing past the end of a dimension (limit 5 on a size-4 array) fails.
TEST_F(ShapeInferenceTest, BadSlice) {
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {4});
  StatusOr<Shape> slice_statusor =
      ShapeInference::InferSliceShape(operand_shape, {0}, {5}, {1});
  ASSERT_FALSE(slice_statusor.ok());

  LOG(INFO) << slice_statusor.status();

  // The error should name both the violated constraint and the shape.
  EXPECT_THAT(slice_statusor.status().error_message(),
              HasSubstr("less than or equal to dimension size"))
      << slice_statusor.status();
  EXPECT_THAT(slice_statusor.status().error_message(),
              HasSubstr("argument shape"))
      << slice_statusor.status();
}
2420
// Sort requires the keys and values operands to have matching dimensions.
TEST_F(ShapeInferenceTest, BadSort) {
  const Shape keys_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape values_shape = ShapeUtil::MakeShape(F32, {5});
  StatusOr<Shape> sort_statusor = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys_shape, &values_shape});
  EXPECT_FALSE(sort_statusor.ok());
  EXPECT_THAT(sort_statusor.status().error_message(),
              HasSubstr("dimensions must match"))
      << sort_statusor.status();
}
2431
// With several values operands, every one must match the keys' dimensions;
// a single mismatching operand is enough to fail.
TEST_F(ShapeInferenceTest, BadSortValuesMismatch) {
  auto keys = ShapeUtil::MakeShape(F32, {4});
  auto values_good = ShapeUtil::MakeShape(F32, {4});
  auto values_bad = ShapeUtil::MakeShape(F32, {5});
  StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &values_good, &values_bad});
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("dimensions must match"))
      << statusor.status();
}
2443
// Sorting keys plus multiple values operands (of different element types)
// yields a tuple of all operand shapes in order.
TEST_F(ShapeInferenceTest, SortManyValues) {
  auto keys = ShapeUtil::MakeShape(F32, {4});
  auto values_s32 = ShapeUtil::MakeShape(S32, {4});
  auto values_u32 = ShapeUtil::MakeShape(U32, {4});
  StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &values_s32, &values_u32});
  EXPECT_IS_OK(statusor);
  Shape inferred_shape = statusor.ValueOrDie();
  EXPECT_TRUE(ShapeUtil::Compatible(
      inferred_shape,
      ShapeUtil::MakeTupleShape({keys, values_s32, values_u32})));
}
2456
// Fixture adding gather-specific shapes on top of ShapeInferenceTest.
// Naming scheme: <type>_<rank>d_tensor_<dim0>_<dim1>_..._ — e.g.
// s64_4d_tensor_10_9_8_7_1_ is an S64 array of shape [10,9,8,7,1].
class GatherShapeInferenceTest : public ShapeInferenceTest {
 protected:
  const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
  const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5});
  const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
  const Shape s64_4d_tensor_10_9_8_7_1_ =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
  const Shape s64_4d_tensor_10_9_8_7_5_ =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
  const Shape s64_4d_tensor_5_10_9_7_6_ =
      ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6});
  const Shape s64_4d_tensor_10_9_5_7_6_ =
      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  const Shape f32_5d_tensor_50_49_48_47_46_ =
      ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  // Used to check that tuple-shaped inputs/indices are rejected.
  const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
      {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
};
2475
// Gather equivalent to TensorFlow's tf.gather along axis 1:
// 32 indices select columns from a [64, 48] matrix, yielding [64, 32].
TEST_F(GatherShapeInferenceTest, TensorFlowGather) {
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              matrix_64_48_, s64_vector_32_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{0},
                                  /*collapsed_slice_dims=*/{1},
                                  /*start_index_map=*/{1},
                                  /*index_vector_dim=*/1),
                              /*slice_sizes=*/{64, 1}));
  EXPECT_TRUE(
      ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
      << ShapeUtil::HumanString(gather_shape);
}
2490
// Gather equivalent to tf.gather along axis 0:
// 32 indices select rows of a [64, 48] matrix, yielding [32, 48].
TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) {
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              matrix_64_48_, s64_vector_32_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{1},
                                  /*collapsed_slice_dims=*/{0},
                                  /*start_index_map=*/{0},
                                  /*index_vector_dim=*/1),
                              /*slice_sizes=*/{1, 48}));
  EXPECT_TRUE(
      ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
      << ShapeUtil::HumanString(gather_shape);
}
2505
// Gather equivalent to tf.gather_nd: a [10,9,8,7,1] index tensor (index
// vector in the last dim) selects rows of [64, 48], giving [10,9,8,7,48].
TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) {
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{4},
                                  /*collapsed_slice_dims=*/{0},
                                  /*start_index_map=*/{0},
                                  /*index_vector_dim=*/4),
                              /*slice_sizes=*/{1, 48}));
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
                               ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
      << ShapeUtil::HumanString(gather_shape);
}
2520
// Gather acting as a batched dynamic-slice: full 5-element index vectors
// slice [30,29,28,27,26] windows out of the rank-5 operand; the [10,9,8,7]
// batch dims of the indices prefix the result.
TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/4),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));
  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(gather_shape);
}
2537
// When a slice covers an entire dynamic operand dimension (dim 1 here),
// the corresponding result dimension stays dynamic.
TEST_F(GatherShapeInferenceTest, DynamicGatherEntireDimension) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          ShapeUtil::MakeShape(F32, {3, 2, 1}, {false, true, false}),
          ShapeUtil::MakeShape(S64, {}),
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{0, 1},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{1, 2, 1}));
  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape, ShapeUtil::MakeShape(F32, {2, 1}, {true, false})))
      << ShapeUtil::HumanString(gather_shape);
}
2554
// A dynamic operand dimension that is collapsed away (dim 0 here) does not
// contribute any dynamic dimension to the result.
TEST_F(GatherShapeInferenceTest, DynamicGatherCollapsedDimension) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          ShapeUtil::MakeShape(F32, {3, 2, 1}, {true, false, false}),
          ShapeUtil::MakeShape(S64, {}),
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{0, 1},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{1, 2, 1}));
  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape, ShapeUtil::MakeShape(F32, {2, 1}, {false, false})))
      << ShapeUtil::HumanString(gather_shape);
}
2571
// A dynamic batch dimension of the indices (dim 1 here) propagates to the
// corresponding batch dimension of the gather result.
TEST_F(GatherShapeInferenceTest, DynamicIndices) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          ShapeUtil::MakeShape(F32, {3, 2, 2}),
          ShapeUtil::MakeShape(S64, {3, 4, 2}, {false, true, false}),
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{2, 3},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0, 1},
              /*index_vector_dim=*/2),
          /*slice_sizes=*/{1, 2, 2}));
  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape,
      ShapeUtil::MakeShape(F32, {3, 4, 2, 2}, {false, true, false, false})))
      << ShapeUtil::HumanString(gather_shape);
}
2589
// The index vector may live in a middle dimension of the indices tensor
// (dim 2 of [10,9,5,7,6]); the remaining dims {10,9,7,6} become batch dims.
TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/2),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape,
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(gather_shape);
}
2607
// Same as _A but with the index vector in the leading dimension (dim 0 of
// [5,10,9,7,6]); batch dims {10,9,7,6} again prefix the result.
TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape,
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(gather_shape);
}
2625
TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) {
  // This is equivalent to a dynamic slice: a single 5-element index vector
  // with no batch dims, so the result is just the slice shape.
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{0, 1, 2, 3, 4},
                                  /*collapsed_slice_dims=*/{},
                                  /*start_index_map=*/{0, 1, 2, 3, 4},
                                  /*index_vector_dim=*/0),
                              /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
                               ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(gather_shape);
}
2642
TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) {
  // The gather indices "tensor" is a scalar S here that's used to slice out
  // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result.
  // index_vector_dim == rank of the (scalar) indices, so the scalar is
  // treated as a 1-element index vector; dim 0 is collapsed away.
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{0, 1, 2, 3},
                                  /*collapsed_slice_dims=*/{0},
                                  /*start_index_map=*/{0},
                                  /*index_vector_dim=*/0),
                              /*slice_sizes=*/{1, 30, 29, 28, 27}));

  EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
                               ShapeUtil::MakeShape(F32, {30, 29, 28, 27})))
      << ShapeUtil::HumanString(gather_shape);
}
2660
// Gather rejects a tuple-shaped operand: input must be an array.
TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      tuple_shape_, s64_vector_32_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/1),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for input"))
      << statusor.status();
}
2675
// Gather rejects tuple-shaped indices: they must also be an array.
TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      s64_vector_32_, tuple_shape_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/0),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for gather indices"))
      << statusor.status();
}
2690
// Gather indices must be integral; F32 indices (vector_32_) are rejected.
TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      s64_vector_32_, vector_32_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/0),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Gather indices parameter must be an integral tensor"))
      << statusor.status();
}
2705
// offset_dims must be sorted ascending; {..., 8, 7} is out of order.
TEST_F(GatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingWindowIndices) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 8, 7},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Output window dimensions in gather op must be ascending"))
      << statusor.status();
}
2722
// offset_dims must not contain duplicates; 7 appears twice.
TEST_F(GatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowIndices) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 7},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Output window dimensions in gather op must not repeat"))
      << statusor.status();
}
2739
// Offset dims far beyond the output rank (99, 100, 101) are rejected; the
// first offending entry (index 2 in offset_dims) is named in the error.
TEST_F(GatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 99, 100, 101},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Offset dimension 2 in gather op is out of bounds"))
      << statusor.status();
}
2755
// An offset dim just one past the valid range (9, when 8 is the max) still
// fails — checks the boundary of the bounds check.
TEST_F(GatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 9},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Offset dimension 4 in gather op is out of bounds"))
      << statusor.status();
}
2771
// Every operand dimension must be accounted for as either an offset
// dimension or a collapsed slice dimension; here dim 4 appears in both
// roles, leaving the bookkeeping inconsistent.
TEST_F(GatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{4},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("All components of the offset index in a gather op must either "
                "be a offset dimension or explicitly collapsed"))
      << inferred.status();
}
2789
// collapsed_slice_dims entries must index into the operand's dimensions;
// 19 is outside the operand's rank-5 range.
TEST_F(GatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{0, 1, 2, 3, 19},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Invalid collapsed_slice_dims set in gather op; valid "
                        "range is [0, 5), got: 19"))
      << inferred.status();
}
2806
// A duplicated entry in collapsed_slice_dims must be rejected.
TEST_F(GatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{0, 1, 2, 3, 3},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Repeated dimensions not allowed in "
                        "collapsed_slice_dims in gather op"))
      << inferred.status();
}
2823
// start_index_map must have exactly as many entries as the bound of the
// index_vector_dim of start_indices (5 here); 4 entries is a mismatch.
TEST_F(GatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Gather op has 4 elements in start_index_map and "
                        "the bound of dimension index_vector_dim=4 of "
                        "start_indices is 5. These two numbers must be equal."))
      << inferred.status();
}
2841
// start_index_map entries must index into the operand's dimensions; the
// mapping 4->7 points past the operand's rank-5 range.
TEST_F(GatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 7},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7"))
      << inferred.status();
}
2857
// A duplicated entry in start_index_map must be rejected.
TEST_F(GatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 3},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Repeated dimensions are not allowed in start_index_map"))
      << inferred.status();
}
2874
// collapsed_slice_dims must be sorted in ascending order; {2, 1} is not.
TEST_F(GatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{2, 1},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{1, 1, 28, 27, 26});
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("collapsed_slice_dims in gather op must be sorted"))
      << inferred.status();
}
2890
// A slice size larger than the corresponding operand dimension (300 > 47,
// reported against the [0, 48) valid range) must be rejected.
TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) {
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7},
          /*collapsed_slice_dims=*/{2},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 1, 300, 26});
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Slice size at index 3 in gather op is out of range, "
                        "must be within [0, 48), got 300."))
      << inferred.status();
}
2906
// slice_sizes must supply one bound per operand dimension; only four are
// given for the rank-5 operand.
TEST_F(GatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 26});
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Gather op must have one slice size for every input dimension"))
      << inferred.status();
}
2923
// A collapsed slice dimension must have slice size 1 (or 0); dim 1 has
// slice size 29 here and must be rejected.
TEST_F(GatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 26, 20});
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Gather op can only collapse slice dims with bound 1 or 0, "
                "but bound is 29 for index 1 at position 0."))
      << inferred.status();
}
2941
// index_vector_dim must lie within [0, rank(start_indices) + 1); 32 is far
// outside that range for a rank-5 start_indices tensor.
TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/32),
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Gather index leaf dimension must be within [0, "
                        "rank(start_indices) + 1)"))
      << inferred.status();
}
2958
2959 class ScatterShapeInferenceTest
2960 : public ShapeInferenceTest,
2961 public ::testing::WithParamInterface<std::vector<PrimitiveType>> {
2962 protected:
2963 struct ScatterShapes {
Addxla::__anon5c819e1b0111::ScatterShapeInferenceTest::ScatterShapes2964 void Add(Shape shape) {
2965 shapes.push_back(std::move(shape));
2966 ptrs.push_back(&shapes.back());
2967 }
2968 std::vector<Shape> shapes;
2969 std::vector<const Shape*> ptrs;
2970 };
CreateShapes(absl::Span<const int64_t> operand_dims,const Shape & scatter_indices_shape,absl::Span<const int64_t> update_dims,absl::Span<const PrimitiveType> types)2971 static ScatterShapes CreateShapes(absl::Span<const int64_t> operand_dims,
2972 const Shape& scatter_indices_shape,
2973 absl::Span<const int64_t> update_dims,
2974 absl::Span<const PrimitiveType> types) {
2975 CHECK(!types.empty());
2976 size_t size = types.size() * 2 + 1;
2977 ScatterShapes shapes;
2978 shapes.shapes.reserve(size);
2979 shapes.ptrs.reserve(size);
2980 for (PrimitiveType type : types) {
2981 shapes.Add(ShapeUtil::MakeShape(type, operand_dims));
2982 }
2983 shapes.Add(scatter_indices_shape);
2984 for (PrimitiveType type : types) {
2985 shapes.Add(ShapeUtil::MakeShape(type, update_dims));
2986 }
2987 return shapes;
2988 }
Collate(absl::Span<const int64_t> dims,absl::Span<const PrimitiveType> types)2989 static Shape Collate(absl::Span<const int64_t> dims,
2990 absl::Span<const PrimitiveType> types) {
2991 CHECK(!types.empty());
2992 if (types.size() == 1) {
2993 return ShapeUtil::MakeShape(types[0], dims);
2994 }
2995 std::vector<Shape> shapes;
2996 for (PrimitiveType type : types) {
2997 shapes.push_back(ShapeUtil::MakeShape(type, dims));
2998 }
2999 return ShapeUtil::MakeTupleShape(shapes);
3000 }
scalar(PrimitiveType type)3001 static Shape scalar(PrimitiveType type) {
3002 return ShapeUtil::MakeShape(type, {});
3003 }
s64_vector(int dim)3004 static Shape s64_vector(int dim) { return ShapeUtil::MakeShape(S64, {dim}); }
s64_tensor(absl::Span<const int64_t> dims)3005 static Shape s64_tensor(absl::Span<const int64_t> dims) {
3006 return ShapeUtil::MakeShape(S64, dims);
3007 }
to_apply(absl::Span<const PrimitiveType> types)3008 static ProgramShape to_apply(absl::Span<const PrimitiveType> types) {
3009 CHECK(!types.empty());
3010 ProgramShape program_shape;
3011 Shape& result = *program_shape.mutable_result();
3012 result = ShapeUtil::MakeNil();
3013 result.mutable_tuple_shapes()->reserve(types.size());
3014 program_shape.mutable_parameters()->reserve(types.size() * 2);
3015 for (PrimitiveType type : types) {
3016 *program_shape.add_parameters() = scalar(type);
3017 *result.add_tuple_shapes() = scalar(type);
3018 }
3019 for (PrimitiveType type : types) {
3020 *program_shape.add_parameters() = scalar(type);
3021 }
3022 return program_shape;
3023 }
types() const3024 std::vector<PrimitiveType> types() const { return GetParam(); }
3025 };
3026
// tf.scatter-style update covering a full column slab; the result keeps the
// operand's shape.
TEST_P(ScatterShapeInferenceTest, TfScatterWithFullUpdates) {
  auto shapes = CreateShapes({64, 48}, s64_vector(32), {64, 32}, types());
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              shapes.ptrs, to_apply(types()),
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{0},
                                  /*inserted_window_dims=*/{1},
                                  /*scatter_dims_to_operand_dims=*/{1},
                                  /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}
3040
// Same as TfScatterWithFullUpdates but scattering along the row dimension.
TEST_P(ScatterShapeInferenceTest, TfScatterWithFullUpdatesV2) {
  auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 48}, types());
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              shapes.ptrs, to_apply(types()),
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{1},
                                  /*inserted_window_dims=*/{0},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}
3054
// Updates smaller than the operand's window dimension (10 < 64) are valid;
// the result shape still matches the operand.
TEST_P(ScatterShapeInferenceTest, TfScatterWithPartialUpdates) {
  auto shapes = CreateShapes({64, 48}, s64_vector(32), {10, 32}, types());
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              shapes.ptrs, to_apply(types()),
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{0},
                                  /*inserted_window_dims=*/{1},
                                  /*scatter_dims_to_operand_dims=*/{1},
                                  /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}
3068
// Partial updates along the other operand dimension (8 < 48) are valid too.
TEST_P(ScatterShapeInferenceTest, TfScatterWithPartialUpdatesV2) {
  auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 8}, types());
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              shapes.ptrs, to_apply(types()),
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{1},
                                  /*inserted_window_dims=*/{0},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}
3082
// An update window larger than the operand dimension (65 > 64) must be
// rejected.
TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) {
  auto shapes = CreateShapes({64, 48}, s64_vector(32), {65, 32}, types());
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << inferred.status();
}
3099
// Same over-large-window failure along the other dimension (49 > 48).
TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) {
  auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 49}, types());
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{1},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << inferred.status();
}
3116
// The scatter dimension of updates (31) must match the corresponding
// scatter-indices dimension (32); a mismatch must be rejected.
TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndices) {
  auto shapes = CreateShapes({64, 48}, s64_vector(32), {64, 31}, types());
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << inferred.status();
}
3134
// Same indices-mismatch failure with the scatter dimension in position 0.
TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndicesV2) {
  auto shapes = CreateShapes({64, 48}, s64_vector(32), {31, 48}, types());
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{1},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << inferred.status();
}
3152
// tf.scatter_nd-style update with full rows; the result keeps the operand's
// shape.
TEST_P(ScatterShapeInferenceTest, TfScatterNdWithFullUpdates) {
  auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
                             {10, 9, 8, 7, 48}, types());
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              shapes.ptrs, to_apply(types()),
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{4},
                                  /*inserted_window_dims=*/{0},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}
3167
// scatter_nd with full columns instead of full rows.
TEST_P(ScatterShapeInferenceTest, TfScatterNdWithFullUpdatesV2) {
  auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
                             {10, 9, 8, 7, 64}, types());
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              shapes.ptrs, to_apply(types()),
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{4},
                                  /*inserted_window_dims=*/{1},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}
3182
// scatter_nd with updates narrower than the operand row (10 < 48) is valid.
TEST_P(ScatterShapeInferenceTest, TfScatterNdWithPartialUpdates) {
  auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
                             {10, 9, 8, 7, 10}, types());
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              shapes.ptrs, to_apply(types()),
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{4},
                                  /*inserted_window_dims=*/{0},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}
3197
// scatter_nd with updates narrower than the operand column (12 < 64).
TEST_P(ScatterShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) {
  auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
                             {10, 9, 8, 7, 12}, types());
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              shapes.ptrs, to_apply(types()),
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{4},
                                  /*inserted_window_dims=*/{1},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}
3212
// scatter_nd with an update window exceeding the operand (65 > 64) must be
// rejected.
TEST_P(ScatterShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) {
  auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
                             {10, 9, 8, 7, 65}, types());
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << inferred.status();
}
3230
// scatter_nd where a scatter dimension of updates (9) disagrees with the
// scatter-indices dimension (10); must be rejected.
TEST_P(ScatterShapeInferenceTest, TfScatterNdWithUpdatesNotMatchingIndices) {
  auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
                             {9, 9, 8, 7, 64}, types());
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << inferred.status();
}
3249
// A scatter where every operand dimension is a window dimension — the
// batched analogue of dynamic-update-slice.
TEST_P(ScatterShapeInferenceTest, TfBatchDynamicUpdateSlice) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28, 27, 26}, types());
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          shapes.ptrs, to_apply(types()),
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(
      ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}
3266
// index_vector_dim may be an interior dimension of scatter_indices (2 here)
// rather than the trailing one.
TEST_P(ScatterShapeInferenceTest, NonDefaultScatterIndicesLeafDim) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 5, 7, 6}),
                             {10, 9, 7, 6, 30, 29, 28, 27, 26}, types());
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          shapes.ptrs, to_apply(types()),
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/2)));

  EXPECT_TRUE(
      ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}
3284
// index_vector_dim may also be the leading dimension of scatter_indices.
TEST_P(ScatterShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({5, 10, 9, 7, 6}),
                             {10, 9, 7, 6, 30, 29, 28, 27, 26}, types());
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          shapes.ptrs, to_apply(types()),
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0)));

  EXPECT_TRUE(
      ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}
3302
// With index_vector_dim == rank(scatter_indices) - 1 consuming the whole
// index vector, there are no scatter dimensions in updates; this is
// equivalent to a dynamic update slice.
TEST_P(ScatterShapeInferenceTest, NoUpdateScatterDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_vector(5),
                             {30, 29, 28, 27, 26}, types());
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          shapes.ptrs, to_apply(types()),
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{0, 1, 2, 3, 4},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0)));

  EXPECT_TRUE(
      ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}
3321
// A scalar scatter_indices "tensor" selects a single position S at which a
// [30,29,28,27]-shaped update is written into the operand.
TEST_P(ScatterShapeInferenceTest, ScalarScatterIndices) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, scalar(S64),
                             {30, 29, 28, 27}, types());
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              shapes.ptrs, to_apply(types()),
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{0, 1, 2, 3},
                                  /*inserted_window_dims=*/{0},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/0)));

  EXPECT_TRUE(
      ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}
3340
// A tuple-shaped operand (rather than an array) must be rejected.
TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedTensorInput) {
  Shape tuple_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}),
                                 ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})});
  Shape s64_vector_32 = s64_vector(32);
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      {&tuple_shape, &s64_vector_32, &s64_vector_32}, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Expected array argument for operand"))
      << inferred.status();
}
3358
// Tuple-shaped scatter indices must be rejected.
TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedScatterIndicesInput) {
  Shape tuple_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}),
                                 ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})});
  Shape s64_vector_32 = s64_vector(32);
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      {&s64_vector_32, &tuple_shape, &s64_vector_32}, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Expected array argument for scatter indices"))
      << inferred.status();
}
3376
// Tuple-shaped updates must be rejected.
TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) {
  Shape tuple_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}),
                                 ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})});
  Shape s64_vector_32 = s64_vector(32);
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      {&s64_vector_32, &s64_vector_32, &tuple_shape}, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Expected array argument for updates"))
      << inferred.status();
}
3394
// Scatter indices must be integral; the F32 vector fixture shape must be
// rejected.
TEST_P(ScatterShapeInferenceTest, FloatingPointScatterIndicesInput) {
  Shape s64_vector_32 = s64_vector(32);
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      {&s64_vector_32, &vector_32_, &s64_vector_32}, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Scatter indices parameter must be an integral tensor"))
      << inferred.status();
}
3409
// index_vector_dim must lie within [0, rank(scatter_indices) + 1); 10 is
// outside that range for the rank-5 indices tensor.
TEST_P(ScatterShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28}, types());
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/10));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Scatter index leaf dimension must be within [0, "
                        "rank(scatter_indices) + 1)"))
      << inferred.status();
}
3426
// Updates of the wrong rank (8 instead of the required 7) must be rejected.
TEST_P(ScatterShapeInferenceTest, InvalidUpdates) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28, 50}, types());
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Updates tensor must be of rank 7; got 8."))
      << inferred.status();
}
3442
// The update computation must take 2 * (number of operands) parameters; a
// single-parameter computation must be rejected. The expected message is
// built with Substitute since the count depends on the test parameter.
TEST_P(ScatterShapeInferenceTest, InvalidUpdateComputation) {
  const ProgramShape invalid_update_computation =
      ShapeUtil::MakeProgramShape({f32_}, f32_);
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28}, types());
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      shapes.ptrs, invalid_update_computation,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr(absl::Substitute(
                  "Reduction function must take $0 parameters, but takes 1",
                  2 * types().size())))
      << inferred.status();
}
3462
// Verifies that scatter shape inference rejects update_window_dims that are
// not in ascending order ({4, 5, 6, 8, 7}).
TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28, 27, 26}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 8, 7},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("update_window_dims in scatter op must be sorted"))
      << statusor.status();
}
3479
// Verifies that scatter shape inference rejects update_window_dims containing
// a repeated dimension ({4, 5, 6, 7, 7}).
TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedUpdateWindowDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28, 27, 26}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 7, 7},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("update_window_dims in scatter op must not repeat"))
      << statusor.status();
}
3496
// Verifies that scatter shape inference rejects an update_window_dims entry
// (9) that is outside the valid range [0, rank(updates)) == [0, 9).
TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28, 27, 26}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 7, 9},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Invalid update_window_dims set in scatter op; valid "
                        "range is [0, 9)"))
      << statusor.status();
}
3514
// Verifies that scatter shape inference rejects inserted_window_dims that are
// not in ascending order ({2, 1}).
TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{2, 1},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("inserted_window_dims in scatter op must be sorted"))
      << statusor.status();
}
3531
// Verifies that scatter shape inference rejects inserted_window_dims
// containing a repeated dimension ({1, 1}).
TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedInsertedWindowDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 1},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("inserted_window_dims in scatter op must not repeat"))
      << statusor.status();
}
3548
// Verifies that scatter shape inference rejects an inserted_window_dims entry
// (5) that is outside the valid range [0, rank(operand)) == [0, 5).
TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 5},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Invalid inserted_window_dims set in scatter op; valid "
                        "range is [0, 5)"))
      << statusor.status();
}
3566
// Verifies that scatter shape inference rejects a scatter_dims_to_operand_dims
// whose size (4) differs from the bound of the index_vector_dim dimension of
// scatter_indices (5).
TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and "
                "the bound of dimension index_vector_dim=4 of scatter_indices "
                "is 5. These two numbers must be equal"))
      << statusor.status();
}
3586
// Verifies that scatter shape inference rejects a scatter_dims_to_operand_dims
// mapping whose target (10) lies outside the operand's dimension domain
// [0, 5).
TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain "
                        "is [0, 5), got: 4->10"))
      << statusor.status();
}
3604
// Verifies that scatter shape inference rejects scatter_dims_to_operand_dims
// containing a repeated target dimension ({0, 1, 2, 2, 3}).
TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Repeated dimensions not allowed in scatter_dims_to_operand_dims"))
      << statusor.status();
}
3623
// Verifies that scatter shape inference rejects dim numbers whose combined
// window (4 update window dims + 0 inserted window dims) does not cover all 5
// operand dimensions.
TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_InsufficientWindowDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, scalar(S64),
                             {30, 29, 28, 27}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0, 1, 2, 3},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Scatter op has window of size 4; doesn't match operand of rank 5."))
      << statusor.status();
}
3642
3643 struct ScatterTestName {
operator ()xla::__anon5c819e1b0111::ScatterTestName3644 std::string operator()(
3645 const ::testing::TestParamInfo<std::vector<PrimitiveType>>& info) const {
3646 return absl::StrJoin(info.param, "_", absl::StreamFormatter());
3647 }
3648 };
3649
// Runs every ScatterShapeInferenceTest twice: once with a single operand
// element type (F32) and once with two (F32, BF16).
INSTANTIATE_TEST_SUITE_P(All, ScatterShapeInferenceTest,
                         ::testing::Values(std::vector<PrimitiveType>{F32},
                                           std::vector<PrimitiveType>{F32,
                                                                      BF16}),
                         ScatterTestName());
3655
3656 } // namespace
3657 } // namespace xla
3658