xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/shape_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/shape.h"
17 
18 #include <numeric>
19 
20 #include "absl/hash/hash_testing.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/compiler/xla/layout_util.h"
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/compiler/xla/test.h"
27 #include "tensorflow/compiler/xla/test_helpers.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/platform/test_benchmark.h"
32 
33 namespace xla {
34 namespace {
35 
36 class ShapeTest : public ::testing::Test {
37  protected:
38   const Shape opaque_ = ShapeUtil::MakeOpaqueShape();
39   const Shape token_ = ShapeUtil::MakeTokenShape();
40   const Shape scalar_ = ShapeUtil::MakeShape(F32, {});
41   const Shape scalar_with_tile_ =
42       ShapeUtil::MakeShapeWithLayout(F32, {}, {}, {}, {Tile({256})});
43   const Shape matrix_ = ShapeUtil::MakeShape(U32, {1, 2});
44   const Shape matrix2_ = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
45   const Shape tuple_ =
46       ShapeUtil::MakeTupleShape({opaque_, scalar_, matrix_, matrix2_});
47   const Shape nested_tuple_ =
48       ShapeUtil::MakeTupleShape({tuple_, matrix_, token_});
49   const Shape dynamic_matrix_ =
50       ShapeUtil::MakeShape(S32, {5, 2}, {true, false});
51 };
52 
// Serializes every fixture shape to a ShapeProto, rebuilds a Shape from that
// proto, and checks the result is equal to the original. scalar_with_tile_ is
// included so that a layout containing tiles is round-tripped as well, and
// dynamic_matrix_ covers dynamic-dimension flags.
TEST_F(ShapeTest, ShapeToFromProto) {
  for (const Shape& shape :
       {opaque_, token_, scalar_, scalar_with_tile_, matrix_, matrix2_, tuple_,
        nested_tuple_, dynamic_matrix_}) {
    Shape shape_copy(shape.ToProto());
    EXPECT_TRUE(ShapeUtil::Equal(shape, shape_copy))
        << shape << " != " << shape_copy;
  }
}
61 
// Checks Shape::ToString for each fixture shape, both with the default
// (no layout) and with print_layout=true. Each table row pairs a shape and a
// flag with the exact text expected.
TEST_F(ShapeTest, ShapeToString) {
  struct ToStringCase {
    const Shape* shape;
    bool print_layout;
    const char* expected;
  };
  const ToStringCase kCases[] = {
      // Layout omitted (the ToString() default).
      {&opaque_, false, "opaque[]"},
      {&token_, false, "token[]"},
      {&scalar_, false, "f32[]"},
      {&matrix_, false, "u32[1,2]"},
      {&matrix2_, false, "s32[3,4]"},
      {&tuple_, false, "(opaque[], f32[], u32[1,2], s32[3,4])"},
      {&nested_tuple_, false,
       "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])"},
      // Layout printed: arrays gain a brace-enclosed layout suffix (tiles
      // follow a colon), while opaque/token shapes print unchanged.
      {&opaque_, true, "opaque[]"},
      {&scalar_, true, "f32[]"},
      {&scalar_with_tile_, true, "f32[]{:T(256)}"},
      {&matrix_, true, "u32[1,2]{1,0}"},
      {&matrix2_, true, "s32[3,4]{0,1}"},
      {&tuple_, true, "(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})"},
      {&nested_tuple_, true,
       "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
       "token[])"},
  };
  for (const ToStringCase& c : kCases) {
    EXPECT_EQ(c.expected, c.shape->ToString(c.print_layout))
        << "print_layout=" << c.print_layout;
  }
}
85 
// Dynamic dimensions are rendered with a "<=" prefix on their bound; clearing
// the dynamic flag on a dimension removes that prefix.
TEST_F(ShapeTest, DynamicShapeToString) {
  Shape shape = ShapeUtil::MakeShape(F32, {23, 44, 55},
                                     /*dynamic_dimensions=*/{true, false, true});
  EXPECT_EQ("f32[<=23,44,<=55]", shape.ToString());

  // Make the last dimension static again.
  shape.set_dynamic_dimension(/*dimension=*/2, false);
  EXPECT_EQ("f32[<=23,44,55]", shape.ToString());
}
94 
// Shape equality (operator== / operator!=) must distinguish shapes that differ
// in exactly one property -- layout, dimensions, or element type -- and accept
// shapes that match in all of them.
TEST_F(ShapeTest, EqualityTest) {
  // Different layouts.
  EXPECT_NE(ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}),
            ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {0, 1}));

  // Different dims.
  EXPECT_NE(ShapeUtil::MakeShapeWithLayout(F32, {44, 23}, {1, 0}),
            ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}));

  // Different element types. Dims and layout are identical on both sides so
  // that the element type is the only property that can make them unequal.
  EXPECT_NE(ShapeUtil::MakeShapeWithLayout(S32, {23, 44}, {1, 0}),
            ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}));

  // Equal shapes.
  EXPECT_EQ(ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}),
            ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}));
}
112 
// A shape is static until some array dimension -- at any nesting depth inside
// a tuple -- is marked dynamic.
TEST_F(ShapeTest, IsStatic) {
  // None of these fixture shapes has a dynamic dimension.
  for (const Shape* static_shape :
       {&opaque_, &token_, &matrix_, &tuple_, &nested_tuple_}) {
    EXPECT_TRUE(static_shape->is_static()) << *static_shape;
  }

  // Marking one dimension of a top-level array dynamic.
  Shape matrix_copy = matrix_;
  EXPECT_TRUE(matrix_copy.is_static());
  matrix_copy.set_dynamic_dimension(1, true);
  EXPECT_FALSE(matrix_copy.is_static());

  // Marking one dimension of an array nested inside a tuple dynamic.
  Shape tuple_copy = tuple_;
  EXPECT_TRUE(tuple_copy.is_static());
  Shape* nested_matrix = ShapeUtil::GetMutableSubshape(&tuple_copy, {2});
  nested_matrix->set_dynamic_dimension(1, true);
  EXPECT_FALSE(tuple_copy.is_static());
}
131 
// is_dynamic_dimension(i) reports exactly the dimensions whose dynamic flag is
// set, for shapes built with dynamic flags, for top-level arrays mutated after
// construction, and for arrays nested inside a tuple.
TEST_F(ShapeTest, IsDynamicDimension) {
  // Flags supplied at construction time: {true, false}.
  EXPECT_TRUE(dynamic_matrix_.is_dynamic_dimension(0));
  EXPECT_FALSE(dynamic_matrix_.is_dynamic_dimension(1));

  // Flag set after construction on a top-level array.
  Shape dynamic_matrix = matrix_;
  dynamic_matrix.set_dynamic_dimension(1, true);
  EXPECT_FALSE(dynamic_matrix.is_dynamic_dimension(0));
  EXPECT_TRUE(dynamic_matrix.is_dynamic_dimension(1));

  // Flag set on the matrix_ element (index 2) of a tuple: the tuple stops
  // being static, and the flag lands on the intended subshape dimension only.
  Shape dynamic_tuple = tuple_;
  EXPECT_TRUE(dynamic_tuple.is_static());
  ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2})
      ->set_dynamic_dimension(1, true);
  EXPECT_FALSE(dynamic_tuple.is_static());
  const Shape& dynamic_subshape = ShapeUtil::GetSubshape(dynamic_tuple, {2});
  EXPECT_FALSE(dynamic_subshape.is_dynamic_dimension(0));
  EXPECT_TRUE(dynamic_subshape.is_dynamic_dimension(1));
}
144 
// Builds a ProgramShape with array, token, scalar and nested-tuple parameters,
// a result shape and parameter names, then checks that a copy rebuilt from
// ToProto() matches it in every parameter, the result, and every name.
TEST_F(ShapeTest, ProgramShapeToFromProto) {
  ProgramShape original;
  for (const Shape& parameter :
       {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeTokenShape(),
        ShapeUtil::MakeShape(S64, {}),
        ShapeUtil::MakeTupleShape(
            {ShapeUtil::MakeShape(S32, {}),
             ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}),
             ShapeUtil::MakeShape(F32, {42, 42})})}) {
    *original.add_parameters() = parameter;
  }
  *original.mutable_result() = ShapeUtil::MakeShape(F32, {7});
  for (const char* name : {"foo", "bar", "baz", "qux qux"}) {
    original.add_parameter_names(name);
  }

  // Round-trip through the proto representation.
  const ProgramShape round_tripped(original.ToProto());

  ASSERT_EQ(original.parameters_size(), round_tripped.parameters_size());
  for (int idx = 0; idx < original.parameters_size(); ++idx) {
    EXPECT_TRUE(ShapeUtil::Equal(original.parameters(idx),
                                 round_tripped.parameters(idx)))
        << "parameter " << idx;
  }

  EXPECT_TRUE(ShapeUtil::Equal(original.result(), round_tripped.result()));

  ASSERT_EQ(original.parameter_names_size(),
            round_tripped.parameter_names_size());
  for (int idx = 0; idx < original.parameter_names_size(); ++idx) {
    EXPECT_EQ(original.parameter_names(idx), round_tripped.parameter_names(idx));
  }
}
181 
// ProgramShape::ToString prints "(name: shape, ...) -> result". Parameters
// without a name are shown as "(unknown)"; once names are attached they are
// used instead.
TEST_F(ShapeTest, ProgramShapeToString) {
  ProgramShape prog = ShapeUtil::MakeProgramShape(
      {opaque_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_},
      nested_tuple_);

  const std::string expected_unnamed =
      "((unknown): opaque[], "
      "(unknown): f32[], "
      "(unknown): u32[1,2], "
      "(unknown): s32[3,4], "
      "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
      "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
      "-> "
      "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])";
  EXPECT_EQ(expected_unnamed, prog.ToString());

  // One name per parameter, in parameter order.
  for (const char* name :
       {"arg0", "scalar", "matrix", "matrix2", "tuple", "nested_tuple"}) {
    prog.add_parameter_names(name);
  }

  const std::string expected_named =
      "(arg0: opaque[], "
      "scalar: f32[], "
      "matrix: u32[1,2], "
      "matrix2: s32[3,4], "
      "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
      "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], "
      "token[])) "
      "-> "
      "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])";
  EXPECT_EQ(expected_named, prog.ToString());
}
215 
// Shape's AbslHashValue must be consistent with its equality operator across
// every fixture shape, including the tiled scalar and the dynamic matrix.
TEST_F(ShapeTest, SupportsAbslHash) {
  const std::vector<Shape> all_shapes = {
      opaque_,  token_,   scalar_, scalar_with_tile_, matrix_,
      matrix2_, tuple_, nested_tuple_, dynamic_matrix_};
  EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly(all_shapes));
}
221 
BM_ShapeCopy(::testing::benchmark::State & state)222 void BM_ShapeCopy(::testing::benchmark::State& state) {
223   // Create different shapes based on benchmark parameters:
224   Shape shape;
225   switch (state.range(0)) {
226     case 0: {
227       // Shape()
228       break;
229     }
230     case 1: {
231       // f32[1,2,2]{2,1,0}
232       shape = Shape(F32, {1, 2, 2}, {false, false, false}, {});
233       *shape.mutable_layout() = Layout({2, 1, 0});
234       break;
235     }
236     case 2: {
237       // f32[1,2,2]{2,1,0:T(2,128)}
238       shape = Shape(F32, {1, 2, 2}, {false, false, false}, {});
239       *shape.mutable_layout() = Layout({2, 1, 0}, {}, {Tile({2, 128})});
240       break;
241     }
242   }
243   state.SetLabel(shape.ToString(true));
244 
245   for (auto s : state) {
246     Shape copy(shape);
247   }
248 }
249 BENCHMARK(BM_ShapeCopy)->Arg(0)->Arg(1)->Arg(2);
250 
251 }  // namespace
252 }  // namespace xla
253