1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/shape.h"
17
18 #include <numeric>
19
20 #include "absl/hash/hash_testing.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/compiler/xla/layout_util.h"
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/compiler/xla/test.h"
27 #include "tensorflow/compiler/xla/test_helpers.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/platform/test_benchmark.h"
32
33 namespace xla {
34 namespace {
35
36 class ShapeTest : public ::testing::Test {
37 protected:
38 const Shape opaque_ = ShapeUtil::MakeOpaqueShape();
39 const Shape token_ = ShapeUtil::MakeTokenShape();
40 const Shape scalar_ = ShapeUtil::MakeShape(F32, {});
41 const Shape scalar_with_tile_ =
42 ShapeUtil::MakeShapeWithLayout(F32, {}, {}, {}, {Tile({256})});
43 const Shape matrix_ = ShapeUtil::MakeShape(U32, {1, 2});
44 const Shape matrix2_ = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
45 const Shape tuple_ =
46 ShapeUtil::MakeTupleShape({opaque_, scalar_, matrix_, matrix2_});
47 const Shape nested_tuple_ =
48 ShapeUtil::MakeTupleShape({tuple_, matrix_, token_});
49 const Shape dynamic_matrix_ =
50 ShapeUtil::MakeShape(S32, {5, 2}, {true, false});
51 };
52
// Round-trips every fixture shape through ShapeProto and checks that the
// reconstructed shape compares equal to the original. scalar_with_tile_ is
// included so that layout tiles are exercised by the serialization path too.
TEST_F(ShapeTest, ShapeToFromProto) {
  for (const Shape& shape :
       {opaque_, token_, scalar_, scalar_with_tile_, matrix_, matrix2_,
        tuple_, nested_tuple_, dynamic_matrix_}) {
    Shape shape_copy(shape.ToProto());
    EXPECT_TRUE(ShapeUtil::Equal(shape, shape_copy))
        << shape << " != " << shape_copy;
  }
}
61
// Checks the textual rendering of the fixture shapes, both without and with
// the layout suffix.
TEST_F(ShapeTest, ShapeToString) {
  struct Rendering {
    const Shape* shape;
    bool print_layout;
    const char* expected;
  };
  const Rendering renderings[] = {
      // Layout-free renderings.
      {&opaque_, false, "opaque[]"},
      {&token_, false, "token[]"},
      {&scalar_, false, "f32[]"},
      {&matrix_, false, "u32[1,2]"},
      {&matrix2_, false, "s32[3,4]"},
      {&tuple_, false, "(opaque[], f32[], u32[1,2], s32[3,4])"},
      {&nested_tuple_, false,
       "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])"},
      // Renderings that include the layout.
      {&opaque_, true, "opaque[]"},
      {&scalar_, true, "f32[]"},
      {&scalar_with_tile_, true, "f32[]{:T(256)}"},
      {&matrix_, true, "u32[1,2]{1,0}"},
      {&matrix2_, true, "s32[3,4]{0,1}"},
      {&tuple_, true, "(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})"},
      {&nested_tuple_, true,
       "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
       "token[])"},
  };
  for (const Rendering& rendering : renderings) {
    EXPECT_EQ(rendering.expected,
              rendering.shape->ToString(rendering.print_layout))
        << "print_layout=" << rendering.print_layout;
  }
}
85
// Dynamic dimensions are rendered with a "<=" prefix before their bound.
TEST_F(ShapeTest, DynamicShapeToString) {
  Shape shape = ShapeUtil::MakeShape(F32, {23, 44, 55}, {true, false, true});
  EXPECT_EQ("f32[<=23,44,<=55]", shape.ToString());

  // Making the last dimension static drops its "<=" marker.
  shape.set_dynamic_dimension(2, false);
  EXPECT_EQ("f32[<=23,44,55]", shape.ToString());
}
94
// Checks Shape equality. Each inequality case changes exactly one property
// (layout, dimensions or element type) relative to a common reference shape,
// so a failure pinpoints which comparison is broken.
TEST_F(ShapeTest, EqualityTest) {
  const Shape reference = ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0});

  // Different layouts.
  EXPECT_NE(reference, ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {0, 1}));

  // Different dims.
  EXPECT_NE(ShapeUtil::MakeShapeWithLayout(F32, {44, 23}, {1, 0}), reference);

  // Different element types; dims and layout are identical so only the
  // element-type comparison can make these unequal.
  EXPECT_NE(ShapeUtil::MakeShapeWithLayout(S32, {23, 44}, {1, 0}), reference);

  // Equal shapes.
  EXPECT_EQ(ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}), reference);
}
112
// A shape is static until some array dimension within it is marked dynamic.
TEST_F(ShapeTest, IsStatic) {
  for (const Shape* shape :
       {&opaque_, &token_, &matrix_, &tuple_, &nested_tuple_}) {
    EXPECT_TRUE(shape->is_static()) << *shape;
  }

  // Marking one dimension of an array dynamic makes the array non-static.
  Shape matrix_copy = matrix_;
  EXPECT_TRUE(matrix_copy.is_static());
  matrix_copy.set_dynamic_dimension(1, true);
  EXPECT_FALSE(matrix_copy.is_static());

  // A dynamic dimension inside a tuple element makes the tuple non-static.
  Shape tuple_copy = tuple_;
  EXPECT_TRUE(tuple_copy.is_static());
  Shape* element = ShapeUtil::GetMutableSubshape(&tuple_copy, {2});
  element->set_dynamic_dimension(1, true);
  EXPECT_FALSE(tuple_copy.is_static());
}
131
// Checks the per-dimension dynamic flag, both on a top-level array and on an
// array nested inside a tuple.
TEST_F(ShapeTest, IsDynamicDimension) {
  // The fixture's dynamic_matrix_ is constructed with flags {true, false}.
  EXPECT_TRUE(dynamic_matrix_.is_dynamic_dimension(0));
  EXPECT_FALSE(dynamic_matrix_.is_dynamic_dimension(1));

  // Setting one dimension dynamic affects only that dimension.
  Shape dynamic_matrix = matrix_;
  EXPECT_FALSE(dynamic_matrix.is_dynamic_dimension(1));
  dynamic_matrix.set_dynamic_dimension(1, true);
  EXPECT_FALSE(dynamic_matrix.is_dynamic_dimension(0));
  EXPECT_TRUE(dynamic_matrix.is_dynamic_dimension(1));

  // Same for a tuple element: read the per-dimension flags back through the
  // enclosing tuple, not just the tuple-wide is_static() summary.
  Shape dynamic_tuple = tuple_;
  EXPECT_TRUE(dynamic_tuple.is_static());
  ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2})
      ->set_dynamic_dimension(1, true);
  const Shape& element = ShapeUtil::GetSubshape(dynamic_tuple, {2});
  EXPECT_FALSE(element.is_dynamic_dimension(0));
  EXPECT_TRUE(element.is_dynamic_dimension(1));
  EXPECT_FALSE(dynamic_tuple.is_static());
}
144
// Round-trips a ProgramShape through its proto form and checks that the
// parameters, the result and the parameter names all survive unchanged.
TEST_F(ShapeTest, ProgramShapeToFromProto) {
  // Parameters cover an array, a token, a scalar and a nested tuple.
  ProgramShape original;
  *original.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3});
  *original.add_parameters() = ShapeUtil::MakeTokenShape();
  *original.add_parameters() = ShapeUtil::MakeShape(S64, {});
  const Shape token_tuple =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()});
  *original.add_parameters() = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {}), token_tuple,
       ShapeUtil::MakeShape(F32, {42, 42})});
  *original.mutable_result() = ShapeUtil::MakeShape(F32, {7});
  for (const char* name : {"foo", "bar", "baz", "qux qux"}) {
    original.add_parameter_names(name);
  }

  const ProgramShape round_tripped(original.ToProto());

  ASSERT_EQ(original.parameters_size(), round_tripped.parameters_size());
  for (int index = 0; index < original.parameters_size(); ++index) {
    EXPECT_TRUE(ShapeUtil::Equal(original.parameters(index),
                                 round_tripped.parameters(index)));
  }

  EXPECT_TRUE(ShapeUtil::Equal(original.result(), round_tripped.result()));

  ASSERT_EQ(original.parameter_names_size(),
            round_tripped.parameter_names_size());
  for (int index = 0; index < original.parameter_names_size(); ++index) {
    EXPECT_EQ(original.parameter_names(index),
              round_tripped.parameter_names(index));
  }
}
181
// Checks ProgramShape::ToString with and without parameter names.
TEST_F(ShapeTest, ProgramShapeToString) {
  // Renderings of the two tuple shapes; both recur in the parameter list and
  // the nested one is also the result.
  const std::string tuple_str = "(opaque[], f32[], u32[1,2], s32[3,4])";
  const std::string nested_tuple_str =
      absl::StrCat("(", tuple_str, ", u32[1,2], token[])");

  ProgramShape prog = ShapeUtil::MakeProgramShape(
      {opaque_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_},
      nested_tuple_);

  // With no names set, every parameter is labelled "(unknown)".
  EXPECT_EQ(absl::StrCat("((unknown): opaque[], ",
                         "(unknown): f32[], ",
                         "(unknown): u32[1,2], ",
                         "(unknown): s32[3,4], ",
                         "(unknown): ", tuple_str, ", ",
                         "(unknown): ", nested_tuple_str, ") -> ",
                         nested_tuple_str),
            prog.ToString());

  for (const char* name :
       {"arg0", "scalar", "matrix", "matrix2", "tuple", "nested_tuple"}) {
    prog.add_parameter_names(name);
  }
  EXPECT_EQ(absl::StrCat("(arg0: opaque[], ",
                         "scalar: f32[], ",
                         "matrix: u32[1,2], ",
                         "matrix2: s32[3,4], ",
                         "tuple: ", tuple_str, ", ",
                         "nested_tuple: ", nested_tuple_str, ") -> ",
                         nested_tuple_str),
            prog.ToString());
}
215
// Verifies, via Abseil's hash-testing helper, that Shape's hashing is
// consistent with its equality across all fixture shapes.
TEST_F(ShapeTest, SupportsAbslHash) {
  const std::vector<Shape> shapes = {
      opaque_, token_,  scalar_,       scalar_with_tile_, matrix_,
      matrix2_, tuple_, nested_tuple_, dynamic_matrix_};
  EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly(shapes));
}
221
BM_ShapeCopy(::testing::benchmark::State & state)222 void BM_ShapeCopy(::testing::benchmark::State& state) {
223 // Create different shapes based on benchmark parameters:
224 Shape shape;
225 switch (state.range(0)) {
226 case 0: {
227 // Shape()
228 break;
229 }
230 case 1: {
231 // f32[1,2,2]{2,1,0}
232 shape = Shape(F32, {1, 2, 2}, {false, false, false}, {});
233 *shape.mutable_layout() = Layout({2, 1, 0});
234 break;
235 }
236 case 2: {
237 // f32[1,2,2]{2,1,0:T(2,128)}
238 shape = Shape(F32, {1, 2, 2}, {false, false, false}, {});
239 *shape.mutable_layout() = Layout({2, 1, 0}, {}, {Tile({2, 128})});
240 break;
241 }
242 }
243 state.SetLabel(shape.ToString(true));
244
245 for (auto s : state) {
246 Shape copy(shape);
247 }
248 }
249 BENCHMARK(BM_ShapeCopy)->Arg(0)->Arg(1)->Arg(2);
250
251 } // namespace
252 } // namespace xla
253