// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
17 
18 #include <memory>
19 
20 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/core/platform/test.h"
26 
27 namespace xla {
28 namespace {
29 
30 class HloCreationUtilsTest : public HloTestBase {
31  protected:
CreateModuleWithProgramShape(PrimitiveType primitive_type,absl::Span<const int64_t> input_shape_dims,absl::Span<const int64_t> output_shape_dims,HloInstruction ** param,HloComputation ** entry_computation)32   std::unique_ptr<VerifiedHloModule> CreateModuleWithProgramShape(
33       PrimitiveType primitive_type, absl::Span<const int64_t> input_shape_dims,
34       absl::Span<const int64_t> output_shape_dims, HloInstruction** param,
35       HloComputation** entry_computation) {
36     Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
37     Shape output_shape =
38         ShapeUtil::MakeShape(primitive_type, output_shape_dims);
39     auto module = CreateNewVerifiedModule("test");
40     *entry_computation = module->AddEntryComputation(
41         CreateComputationWithSignature({&input_shape}, output_shape, "entry")
42             .ValueOrDie());
43     *param = (*entry_computation)->parameter_instruction(0);
44     return module;
45   }
46 
CreateModuleWithProgramShape(PrimitiveType primitive_type,absl::Span<const int64_t> input_shape_dims,absl::Span<const int64_t> output_shape_dims,HloInstruction ** param,HloComputation ** entry_computation,PrimitiveType primitive_type_output)47   std::unique_ptr<VerifiedHloModule> CreateModuleWithProgramShape(
48       PrimitiveType primitive_type, absl::Span<const int64_t> input_shape_dims,
49       absl::Span<const int64_t> output_shape_dims, HloInstruction** param,
50       HloComputation** entry_computation, PrimitiveType primitive_type_output) {
51     Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
52     Shape output_shape =
53         ShapeUtil::MakeShape(primitive_type_output, output_shape_dims);
54     auto module = CreateNewVerifiedModule("test");
55     *entry_computation = module->AddEntryComputation(
56         CreateComputationWithSignature({&input_shape}, output_shape, "entry")
57             .ValueOrDie());
58     *param = (*entry_computation)->parameter_instruction(0);
59     return module;
60   }
61 };
62 
// Collapsing the first 1 dimension of an s32[2] value leaves the shape at
// s32[2] and passes the values through unchanged.
TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
                                             /*output_shape_dims=*/{2}, &param,
                                             &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed,
                          CollapseFirstNDims(param, 1));
  entry_computation->set_root_instruction(first_1_dims_collapsed);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32_t>({3, 4})}));
  // EXPECT_EQ rather than CHECK_EQ: a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal, LiteralUtil::CreateR1<int32_t>({3, 4}));
}
81 
// Collapsing the first 2 dimensions of s32[2,3,2] yields s32[6,2]. The
// expected literal shows the six rows kept in their original order.
TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(
      S32, /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, &param,
      &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_2_dims_collapsed,
                          CollapseFirstNDims(param, 2));
  entry_computation->set_root_instruction(first_2_dims_collapsed);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR3<int32_t>(
                                      {{{1, 2}, {3, 4}, {5, 6}},
                                       {{-1, -2}, {-3, -4}, {-5, -6}}})}));
  // EXPECT_EQ rather than CHECK_EQ: a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal,
            LiteralUtil::CreateR2<int32_t>(
                {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
}
104 
// Prepending one degenerate (size-1) dimension turns s32[2] into s32[1,2].
TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
                                             /*output_shape_dims=*/{1, 2},
                                             &param, &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended,
                          PrependDegenerateDims(param, 1));
  entry_computation->set_root_instruction(with_1_degenerate_dim_prepended);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32_t>({9, 10})}));
  // EXPECT_EQ rather than CHECK_EQ: a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal, LiteralUtil::CreateR2<int32_t>({{9, 10}}));
}
123 
// Prepending two degenerate (size-1) dimensions turns s32[2] into s32[1,1,2].
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
                                             /*output_shape_dims=*/{1, 1, 2},
                                             &param, &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended,
                          PrependDegenerateDims(param, 2));
  entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32_t>({9, 10})}));
  // EXPECT_EQ rather than CHECK_EQ: a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal, LiteralUtil::CreateR3<int32_t>({{{9, 10}}}));
}
142 
// Prepending two degenerate dimensions to a scalar s32[] yields s32[1,1].
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
                                             /*output_shape_dims=*/{1, 1},
                                             &param, &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended,
                          PrependDegenerateDims(param, 2));
  entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32_t>(9)}));
  // EXPECT_EQ rather than CHECK_EQ: a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal, LiteralUtil::CreateR2<int32_t>({{9}}));
}
161 
// Expanding the single dimension of s32[6] into {3, 1, 2} yields s32[3,1,2].
// The expected literal shows the elements kept in their original order.
TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{6},
                                             /*output_shape_dims=*/{3, 1, 2},
                                             &param, &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_dim_expanded,
                          ExpandFirstDimIntoNDims(param, {3, 1, 2}));
  entry_computation->set_root_instruction(first_dim_expanded);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module,
                         {LiteralUtil::CreateR1<int32_t>({1, 2, 3, 4, 5, 6})}));
  // EXPECT_EQ rather than CHECK_EQ: a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal,
            LiteralUtil::CreateR3<int32_t>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
}
182 
// Padding s32[2] with 3 leading and 1 trailing zero yields s32[6].
TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
                                             /*output_shape_dims=*/{6}, &param,
                                             &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction * zero_padded_param,
      PadVectorWithZeros(param, /*zeros_to_prepend=*/3, /*zeros_to_append=*/1));
  entry_computation->set_root_instruction(zero_padded_param);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32_t>({3, 4})}));
  // EXPECT_EQ rather than CHECK_EQ: a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal, LiteralUtil::CreateR1<int32_t>({0, 0, 0, 3, 4, 0}));
}
202 
// BroadcastZeros produces an s32[2,2] of zeros. The new root does not use the
// entry parameter, so the argument passed to the evaluator is not read.
TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
                                             /*output_shape_dims=*/{2, 2},
                                             &param, &entry_computation);

  HloInstruction* zeros =
      BroadcastZeros(module->entry_computation(), S32, {2, 2});
  entry_computation->set_root_instruction(zeros);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32_t>(0)}));
  // EXPECT_EQ rather than CHECK_EQ: a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal, LiteralUtil::CreateR2<int32_t>({{0, 0}, {0, 0}}));
}
221 
// BroadcastZeros produces an f32[2,2] of zeros. The new root does not use the
// entry parameter, so the argument passed to the evaluator is not read.
TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{},
                                             /*output_shape_dims=*/{2, 2},
                                             &param, &entry_computation);

  HloInstruction* zeros =
      BroadcastZeros(module->entry_computation(), F32, {2, 2});
  entry_computation->set_root_instruction(zeros);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR0<float>(0.0f)}));
  // EXPECT_EQ rather than CHECK_EQ: a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal,
            LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
241 
// Bitcast-converting an s32[2,2] constant of zeros to F32 yields f32 zeros.
// The root is built from a constant, not the entry parameter, so the argument
// passed to the evaluator is not read.
TEST_F(HloCreationUtilsTest, MakeBitcastConvertToHlo_S32) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2, 2},
                                             /*output_shape_dims=*/{2, 2},
                                             &param, &entry_computation, F32);
  auto* input = module->entry_computation()->AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<int32_t>({{0, 0}, {0, 0}})));

  HloInstruction* output = MakeBitcastConvertToHlo(input, F32);
  entry_computation->set_root_instruction(output);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module,
                         {LiteralUtil::CreateR2<int32_t>({{0, 0}, {0, 0}})}));
  // EXPECT_EQ rather than CHECK_EQ: a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal,
            LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
264 
// An f32[2,2] iota along dimension 0 fills each row with its row index. The
// root does not use the s32 entry parameter, so the argument passed to the
// evaluator is not read.
TEST_F(HloCreationUtilsTest, MakeIotaHlo_I32) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
                                             /*output_shape_dims=*/{2, 2},
                                             &param, &entry_computation, F32);
  HloInstruction* output = MakeIotaHlo(module->entry_computation(),
                                       ShapeUtil::MakeShape(F32, {2, 2}), 0);
  entry_computation->set_root_instruction(output);

  HloEvaluator evaluator;
  // The parameter is s32, so pass an integer literal rather than the double
  // 0.0 that would be implicitly converted.
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32_t>(0)}));
  // EXPECT_EQ rather than CHECK_EQ: a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal,
            LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {1.0f, 1.0f}}));
}
283 
// Broadcasting an f32 scalar constant 0 to dimensions {2, 2} yields f32 zeros.
// The root does not use the entry parameter.
TEST_F(HloCreationUtilsTest, MakeBroadcast_F32) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{},
                                             /*output_shape_dims=*/{2, 2},
                                             &param, &entry_computation);
  // Use a float literal to match the <float> template argument.
  auto* input = MakeR0ConstantHlo<float>(module->entry_computation(), 0.0f);
  HloInstruction* output = MakeBroadcastHlo(input, {}, {2, 2});
  entry_computation->set_root_instruction(output);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR0<float>(0.0f)}));
  // EXPECT_EQ rather than CHECK_EQ: a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal,
            LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
302 
// Broadcasting an s32 scalar constant 0 to the shape s32[2,2] (Shape overload
// of MakeBroadcastHlo) yields s32 zeros. The root does not use the entry
// parameter.
TEST_F(HloCreationUtilsTest, MakeBroadcast_Shape_I32) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
                                             /*output_shape_dims=*/{2, 2},
                                             &param, &entry_computation);
  auto* input = MakeR0ConstantHlo<int32_t>(module->entry_computation(), 0);
  HloInstruction* output =
      MakeBroadcastHlo(input, {}, ShapeUtil::MakeShape(S32, {2, 2}));
  entry_computation->set_root_instruction(output);

  HloEvaluator evaluator;
  // The parameter is s32, so pass an integer literal rather than the double
  // 0.0 that would be implicitly converted.
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32_t>(0)}));
  // EXPECT_EQ rather than CHECK_EQ: a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal, LiteralUtil::CreateR2<int32_t>({{0, 0}, {0, 0}}));
}
321 
// Calling MaybeMakeTuple with an empty operand list must terminate the
// process. The empty matcher accepts any death message.
TEST_F(HloCreationUtilsTest, MaybeMakeTupleCrashesWithEmptyOperands) {
  EXPECT_DEATH(
      {
        MaybeMakeTuple({});
      },
      "");
}
325 
// With exactly one operand, MaybeMakeTuple should hand back that operand
// itself rather than wrapping it in a tuple. This is checked both by
// instruction identity and by the evaluated value.
TEST_F(HloCreationUtilsTest, MaybeMakeTupleForwardsSingleElement) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2, 2},
                                             /*output_shape_dims=*/{2, 2},
                                             &param, &entry_computation);
  HloInstruction* output = MaybeMakeTuple({param});
  // "Forwards" means the very same instruction is returned.
  EXPECT_EQ(output, param);
  entry_computation->set_root_instruction(output);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module,
                         {LiteralUtil::CreateR2<int32_t>({{0, 0}, {0, 0}})}));
  EXPECT_EQ(result_literal, LiteralUtil::CreateR2<int32_t>({{0, 0}, {0, 0}}));
}
343 
// With two operands, MaybeMakeTuple builds a tuple whose elements appear in
// the order the operands were passed. Here the order is (parameter 1,
// parameter 0), matching the declared (f32[3,3], s32[2]) result shape.
TEST_F(HloCreationUtilsTest, MaybeMakeTupleTuplizesMultipleOperands) {
  Shape input_shape0 = ShapeUtil::MakeShape(S32, {2});
  Shape input_shape1 = ShapeUtil::MakeShape(F32, {3, 3});
  Shape output_shape =
      ShapeUtil::MakeTupleShapeWithPtrs({&input_shape1, &input_shape0});
  auto module = CreateNewVerifiedModule("test");
  HloComputation* entry_computation = module->AddEntryComputation(
      CreateComputationWithSignature({&input_shape0, &input_shape1},
                                     output_shape, "entry")
          .ValueOrDie());
  HloInstruction* output =
      MaybeMakeTuple({entry_computation->parameter_instruction(1),
                      entry_computation->parameter_instruction(0)});
  entry_computation->set_root_instruction(output);

  HloEvaluator evaluator;
  // Rank-1 literal: a single brace level ({2, 4}), not a nested list.
  Literal input0 = LiteralUtil::CreateR1<int32_t>({2, 4});
  Literal input1 =
      LiteralUtil::CreateR2<float>({{3, 2, 1}, {4, 5, 6}, {9, 8, 7}});
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {input0.Clone(), input1.Clone()}));
  Literal expected_result = LiteralUtil::MakeTuple({&input1, &input0});
  EXPECT_EQ(result_literal, expected_result);
}
369 
370 }  // namespace
371 }  // namespace xla
372