xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This program ahead-of-time compiles an XLA program that computes
// 2 * SumStructElements(x) for an opaque parameter x, and writes the
// resulting object file to stdout.
18 
19 #include <iostream>
20 #include <vector>
21 
22 #include "llvm/ADT/Triple.h"
23 #include "llvm/Support/Host.h"
24 #include "tensorflow/compiler/xla/client/client_library.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
28 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/util.h"
31 #include "tensorflow/core/platform/init_main.h"
32 #include "tensorflow/core/platform/logging.h"
33 
34 namespace {
35 
36 using std::string;
37 
Doubler()38 xla::XlaComputation Doubler() {
39   xla::XlaBuilder builder("doubler");
40   auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {});
41   auto x = xla::Parameter(&builder, 0, r0f32, "x");
42   xla::Mul(x, xla::ConstantR0<float>(&builder, 2.0));
43   return std::move(builder.Build().ValueOrDie());
44 }
45 
46 }  // namespace
47 
main(int argc,char ** argv)48 int main(int argc, char** argv) {
49   tensorflow::port::InitMain(argv[0], &argc, &argv);
50 
51   auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie();
52 
53   xla::XlaBuilder builder("aot_test_helper");
54   auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
55   auto opaque_param = Parameter(&builder, 0, opaque_shape, "x");
56   auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {});
57   auto sum = CustomCall(&builder, "SumStructElements", {opaque_param}, r0f32);
58   Call(&builder, Doubler(), {sum});
59 
60   if (argc != 2) {
61     LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU";
62   }
63 
64   std::string triple_string;
65   std::string target_cpu = argv[1];
66   if (target_cpu == "k8") {
67     triple_string = "x86_64-none-linux-gnu";
68   } else if (target_cpu == "darwin") {
69     triple_string = "x86_64-apple-macosx";
70   } else if ((target_cpu == "arm") || (target_cpu == "aarch64")) {
71     triple_string = "aarch64-none-linux-gnu";
72   } else if (target_cpu == "x64_windows") {
73     triple_string = "x86_64-pc-windows-msvc19";
74   } else if (target_cpu == "ppc") {
75     triple_string = "ppc64le-ibm-linux-gnu";
76   } else if (target_cpu == "s390x") {
77     triple_string = "systemz-none-linux-gnu";
78   } else if (target_cpu == "local") {
79     triple_string = llvm::sys::getDefaultTargetTriple();
80   } else {
81     LOG(FATAL) << "unsupported TARGET_CPU: " << target_cpu;
82   }
83 
84   llvm::Triple triple(triple_string);
85 
86   xla::XlaComputation computation = builder.Build().value();
87   xla::CompileOnlyClient::AotXlaComputationInstance instance{
88       &computation, /*argument_layouts=*/{&opaque_shape}, &r0f32};
89 
90   xla::cpu::CpuAotCompilationOptions options(
91       triple_string,
92       /*cpu_name=*/"", /*features=*/"", "SumAndDouble",
93       xla::cpu::CpuAotCompilationOptions::RelocationModel::Static);
94 
95   auto results = client->CompileAheadOfTime({instance}, options).value();
96   auto result = xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
97       std::move(results.front()));
98   // It's lame to hard-code the buffer assignments, but we need
99   // local_client_aot_test.cc to be able to easily invoke the function.
100   CHECK_EQ(result->result_buffer_index(), 1);
101   CHECK_EQ(result->buffer_infos().size(), 3);
102   CHECK(result->buffer_infos()[0].is_entry_parameter());      // param buffer
103   CHECK_EQ(result->buffer_infos()[1].size(), sizeof(float));  // result buffer
104   CHECK(result->buffer_infos()[2].is_constant());             // const buffer
105   if (triple.isOSBinFormatELF()) {
106     // Check the ELF magic.
107     CHECK_EQ(result->object_file_data()[0], 0x7F);
108     CHECK_EQ(result->object_file_data()[1], 'E');
109     CHECK_EQ(result->object_file_data()[2], 'L');
110     CHECK_EQ(result->object_file_data()[3], 'F');
111     // Check the ELF class.
112     CHECK_EQ(result->object_file_data()[4], triple.isArch32Bit() ? 1 : 2);
113     // Check the ELF endianness: it should be little.
114     CHECK_EQ(result->object_file_data()[5], triple.isLittleEndian() ? 1 : 2);
115     // Check the ELF version: it should be 1.
116     CHECK_EQ(result->object_file_data()[6], 1);
117   }
118 
119   const std::vector<char>& object_file_data = result->object_file_data();
120   std::cout.write(object_file_data.data(), object_file_data.size());
121 
122   return 0;
123 }
124