1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
17
18 #include <memory>
19
20 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/core/platform/test.h"
26
27 namespace xla {
28 namespace {
29
30 class HloCreationUtilsTest : public HloTestBase {
31 protected:
CreateModuleWithProgramShape(PrimitiveType primitive_type,absl::Span<const int64_t> input_shape_dims,absl::Span<const int64_t> output_shape_dims,HloInstruction ** param,HloComputation ** entry_computation)32 std::unique_ptr<VerifiedHloModule> CreateModuleWithProgramShape(
33 PrimitiveType primitive_type, absl::Span<const int64_t> input_shape_dims,
34 absl::Span<const int64_t> output_shape_dims, HloInstruction** param,
35 HloComputation** entry_computation) {
36 Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
37 Shape output_shape =
38 ShapeUtil::MakeShape(primitive_type, output_shape_dims);
39 auto module = CreateNewVerifiedModule("test");
40 *entry_computation = module->AddEntryComputation(
41 CreateComputationWithSignature({&input_shape}, output_shape, "entry")
42 .ValueOrDie());
43 *param = (*entry_computation)->parameter_instruction(0);
44 return module;
45 }
46
CreateModuleWithProgramShape(PrimitiveType primitive_type,absl::Span<const int64_t> input_shape_dims,absl::Span<const int64_t> output_shape_dims,HloInstruction ** param,HloComputation ** entry_computation,PrimitiveType primitive_type_output)47 std::unique_ptr<VerifiedHloModule> CreateModuleWithProgramShape(
48 PrimitiveType primitive_type, absl::Span<const int64_t> input_shape_dims,
49 absl::Span<const int64_t> output_shape_dims, HloInstruction** param,
50 HloComputation** entry_computation, PrimitiveType primitive_type_output) {
51 Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
52 Shape output_shape =
53 ShapeUtil::MakeShape(primitive_type_output, output_shape_dims);
54 auto module = CreateNewVerifiedModule("test");
55 *entry_computation = module->AddEntryComputation(
56 CreateComputationWithSignature({&input_shape}, output_shape, "entry")
57 .ValueOrDie());
58 *param = (*entry_computation)->parameter_instruction(0);
59 return module;
60 }
61 };
62
TEST_F(HloCreationUtilsTest,CollapseFirst1Dim)63 TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
64 HloInstruction* param;
65 HloComputation* entry_computation;
66
67 auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
68 /*output_shape_dims=*/{2}, ¶m,
69 &entry_computation);
70
71 TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed,
72 CollapseFirstNDims(param, 1));
73 entry_computation->set_root_instruction(first_1_dims_collapsed);
74
75 HloEvaluator evaluator;
76 TF_ASSERT_OK_AND_ASSIGN(
77 Literal result_literal,
78 evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32_t>({3, 4})}));
79 CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32_t>({3, 4}));
80 }
81
TEST_F(HloCreationUtilsTest,CollapseFirst2Dims)82 TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
83 HloInstruction* param;
84 HloComputation* entry_computation;
85
86 auto module = CreateModuleWithProgramShape(
87 S32, /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, ¶m,
88 &entry_computation);
89
90 TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_2_dims_collapsed,
91 CollapseFirstNDims(param, 2));
92 entry_computation->set_root_instruction(first_2_dims_collapsed);
93
94 HloEvaluator evaluator;
95 TF_ASSERT_OK_AND_ASSIGN(
96 Literal result_literal,
97 evaluator.Evaluate(*module, {LiteralUtil::CreateR3<int32_t>(
98 {{{1, 2}, {3, 4}, {5, 6}},
99 {{-1, -2}, {-3, -4}, {-5, -6}}})}));
100 CHECK_EQ(result_literal,
101 LiteralUtil::CreateR2<int32_t>(
102 {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
103 }
104
TEST_F(HloCreationUtilsTest,Prepend1DegenerateDim)105 TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
106 HloInstruction* param;
107 HloComputation* entry_computation;
108
109 auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
110 /*output_shape_dims=*/{1, 2},
111 ¶m, &entry_computation);
112
113 TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended,
114 PrependDegenerateDims(param, 1));
115 entry_computation->set_root_instruction(with_1_degenerate_dim_prepended);
116
117 HloEvaluator evaluator;
118 TF_ASSERT_OK_AND_ASSIGN(
119 Literal result_literal,
120 evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32_t>({9, 10})}));
121 CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32_t>({{9, 10}}));
122 }
123
TEST_F(HloCreationUtilsTest,Prepend2DegenerateDims)124 TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
125 HloInstruction* param;
126 HloComputation* entry_computation;
127
128 auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
129 /*output_shape_dims=*/{1, 1, 2},
130 ¶m, &entry_computation);
131
132 TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended,
133 PrependDegenerateDims(param, 2));
134 entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);
135
136 HloEvaluator evaluator;
137 TF_ASSERT_OK_AND_ASSIGN(
138 Literal result_literal,
139 evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32_t>({9, 10})}));
140 CHECK_EQ(result_literal, LiteralUtil::CreateR3<int32_t>({{{9, 10}}}));
141 }
142
TEST_F(HloCreationUtilsTest,Prepend2DegenerateDimsToScalar)143 TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
144 HloInstruction* param;
145 HloComputation* entry_computation;
146
147 auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
148 /*output_shape_dims=*/{1, 1},
149 ¶m, &entry_computation);
150
151 TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended,
152 PrependDegenerateDims(param, 2));
153 entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);
154
155 HloEvaluator evaluator;
156 TF_ASSERT_OK_AND_ASSIGN(
157 Literal result_literal,
158 evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32_t>(9)}));
159 CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32_t>({{9}}));
160 }
161
TEST_F(HloCreationUtilsTest,ExpandFirstDimInto3Dims)162 TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
163 HloInstruction* param;
164 HloComputation* entry_computation;
165
166 auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{6},
167 /*output_shape_dims=*/{3, 1, 2},
168 ¶m, &entry_computation);
169
170 TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_dim_expanded,
171 ExpandFirstDimIntoNDims(param, {3, 1, 2}));
172 entry_computation->set_root_instruction(first_dim_expanded);
173
174 HloEvaluator evaluator;
175 TF_ASSERT_OK_AND_ASSIGN(
176 Literal result_literal,
177 evaluator.Evaluate(*module,
178 {LiteralUtil::CreateR1<int32_t>({1, 2, 3, 4, 5, 6})}));
179 CHECK_EQ(result_literal,
180 LiteralUtil::CreateR3<int32_t>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
181 }
182
TEST_F(HloCreationUtilsTest,PadVectorWithZeros)183 TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
184 HloInstruction* param;
185 HloComputation* entry_computation;
186
187 auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
188 /*output_shape_dims=*/{6}, ¶m,
189 &entry_computation);
190
191 TF_ASSERT_OK_AND_ASSIGN(
192 HloInstruction * zero_padded_param,
193 PadVectorWithZeros(param, /*zeros_to_prepend=*/3, /*zeros_to_append=*/1));
194 entry_computation->set_root_instruction(zero_padded_param);
195
196 HloEvaluator evaluator;
197 TF_ASSERT_OK_AND_ASSIGN(
198 Literal result_literal,
199 evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32_t>({3, 4})}));
200 CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32_t>({0, 0, 0, 3, 4, 0}));
201 }
202
TEST_F(HloCreationUtilsTest,BroadcastZeros_S32)203 TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
204 HloInstruction* param;
205 HloComputation* entry_computation;
206
207 auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
208 /*output_shape_dims=*/{2, 2},
209 ¶m, &entry_computation);
210
211 HloInstruction* zeros =
212 BroadcastZeros(module->entry_computation(), S32, {2, 2});
213 entry_computation->set_root_instruction(zeros);
214
215 HloEvaluator evaluator;
216 TF_ASSERT_OK_AND_ASSIGN(
217 Literal result_literal,
218 evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32_t>(0)}));
219 CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32_t>({{0, 0}, {0, 0}}));
220 }
221
TEST_F(HloCreationUtilsTest,BroadcastZeros_F32)222 TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
223 HloInstruction* param;
224 HloComputation* entry_computation;
225
226 auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{},
227 /*output_shape_dims=*/{2, 2},
228 ¶m, &entry_computation);
229
230 HloInstruction* zeros =
231 BroadcastZeros(module->entry_computation(), F32, {2, 2});
232 entry_computation->set_root_instruction(zeros);
233
234 HloEvaluator evaluator;
235 TF_ASSERT_OK_AND_ASSIGN(
236 Literal result_literal,
237 evaluator.Evaluate(*module, {LiteralUtil::CreateR0<float>(0.0f)}));
238 CHECK_EQ(result_literal,
239 LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
240 }
241
TEST_F(HloCreationUtilsTest,MakeBitcastConvertToHlo_S32)242 TEST_F(HloCreationUtilsTest, MakeBitcastConvertToHlo_S32) {
243 HloInstruction* param;
244 HloComputation* entry_computation;
245
246 auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2, 2},
247 /*output_shape_dims=*/{2, 2},
248 ¶m, &entry_computation, F32);
249 auto* input = module->entry_computation()->AddInstruction(
250 HloInstruction::CreateConstant(
251 LiteralUtil::CreateR2<int32_t>({{0, 0}, {0, 0}})));
252
253 HloInstruction* output = MakeBitcastConvertToHlo(input, F32);
254 entry_computation->set_root_instruction(output);
255
256 HloEvaluator evaluator;
257 TF_ASSERT_OK_AND_ASSIGN(
258 Literal result_literal,
259 evaluator.Evaluate(*module,
260 {LiteralUtil::CreateR2<int32_t>({{0, 0}, {0, 0}})}));
261 CHECK_EQ(result_literal,
262 LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
263 }
264
TEST_F(HloCreationUtilsTest,MakeIotaHlo_I32)265 TEST_F(HloCreationUtilsTest, MakeIotaHlo_I32) {
266 HloInstruction* param;
267 HloComputation* entry_computation;
268
269 auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
270 /*output_shape_dims=*/{2, 2},
271 ¶m, &entry_computation, F32);
272 HloInstruction* output = MakeIotaHlo(module->entry_computation(),
273 ShapeUtil::MakeShape(F32, {2, 2}), 0);
274 entry_computation->set_root_instruction(output);
275
276 HloEvaluator evaluator;
277 TF_ASSERT_OK_AND_ASSIGN(
278 Literal result_literal,
279 evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32_t>(0.0)}));
280 CHECK_EQ(result_literal,
281 LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {1.0f, 1.0f}}));
282 }
283
TEST_F(HloCreationUtilsTest,MakeBroadcast_F32)284 TEST_F(HloCreationUtilsTest, MakeBroadcast_F32) {
285 HloInstruction* param;
286 HloComputation* entry_computation;
287
288 auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{},
289 /*output_shape_dims=*/{2, 2},
290 ¶m, &entry_computation);
291 auto* input = MakeR0ConstantHlo<float>(module->entry_computation(), 0);
292 HloInstruction* output = MakeBroadcastHlo(input, {}, {2, 2});
293 entry_computation->set_root_instruction(output);
294
295 HloEvaluator evaluator;
296 TF_ASSERT_OK_AND_ASSIGN(
297 Literal result_literal,
298 evaluator.Evaluate(*module, {LiteralUtil::CreateR0<float>(0.0f)}));
299 CHECK_EQ(result_literal,
300 LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
301 }
302
TEST_F(HloCreationUtilsTest,MakeBroadcast_Shape_I32)303 TEST_F(HloCreationUtilsTest, MakeBroadcast_Shape_I32) {
304 HloInstruction* param;
305 HloComputation* entry_computation;
306
307 auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
308 /*output_shape_dims=*/{2, 2},
309 ¶m, &entry_computation);
310 auto* input = MakeR0ConstantHlo<int32_t>(module->entry_computation(), 0);
311 HloInstruction* output =
312 MakeBroadcastHlo(input, {}, ShapeUtil::MakeShape(S32, {2, 2}));
313 entry_computation->set_root_instruction(output);
314
315 HloEvaluator evaluator;
316 TF_ASSERT_OK_AND_ASSIGN(
317 Literal result_literal,
318 evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32_t>(0.0)}));
319 CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32_t>({{0, 0}, {0, 0}}));
320 }
321
TEST_F(HloCreationUtilsTest,MaybeMakeTupleCrashesWithEmptyOperands)322 TEST_F(HloCreationUtilsTest, MaybeMakeTupleCrashesWithEmptyOperands) {
323 EXPECT_DEATH(MaybeMakeTuple({}), "");
324 }
325
TEST_F(HloCreationUtilsTest,MaybeMakeTupleForwardsSingleElement)326 TEST_F(HloCreationUtilsTest, MaybeMakeTupleForwardsSingleElement) {
327 HloInstruction* param;
328 HloComputation* entry_computation;
329
330 auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2, 2},
331 /*output_shape_dims=*/{2, 2},
332 ¶m, &entry_computation);
333 HloInstruction* output = MaybeMakeTuple({param});
334 entry_computation->set_root_instruction(output);
335
336 HloEvaluator evaluator;
337 TF_ASSERT_OK_AND_ASSIGN(
338 Literal result_literal,
339 evaluator.Evaluate(*module,
340 {LiteralUtil::CreateR2<int32_t>({{0, 0}, {0, 0}})}));
341 EXPECT_EQ(result_literal, LiteralUtil::CreateR2<int32_t>({{0, 0}, {0, 0}}));
342 }
343
TEST_F(HloCreationUtilsTest,MaybeMakeTupleTuplizesMultipleOperands)344 TEST_F(HloCreationUtilsTest, MaybeMakeTupleTuplizesMultipleOperands) {
345 Shape input_shape0 = ShapeUtil::MakeShape(S32, {2});
346 Shape input_shape1 = ShapeUtil::MakeShape(F32, {3, 3});
347 Shape output_shape =
348 ShapeUtil::MakeTupleShapeWithPtrs({&input_shape1, &input_shape0});
349 auto module = CreateNewVerifiedModule("test");
350 HloComputation* entry_computation = module->AddEntryComputation(
351 CreateComputationWithSignature({&input_shape0, &input_shape1},
352 output_shape, "entry")
353 .ValueOrDie());
354 HloInstruction* output =
355 MaybeMakeTuple({entry_computation->parameter_instruction(1),
356 entry_computation->parameter_instruction(0)});
357 entry_computation->set_root_instruction(output);
358
359 HloEvaluator evaluator;
360 Literal input0 = LiteralUtil::CreateR1<int32_t>({{2, 4}});
361 Literal input1 =
362 LiteralUtil::CreateR2<float>({{3, 2, 1}, {4, 5, 6}, {9, 8, 7}});
363 TF_ASSERT_OK_AND_ASSIGN(
364 Literal result_literal,
365 evaluator.Evaluate(*module, {input0.Clone(), input1.Clone()}));
366 Literal expected_result = LiteralUtil::MakeTuple({&input1, &input0});
367 EXPECT_EQ(result_literal, expected_result);
368 }
369
370 } // namespace
371 } // namespace xla
372