1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
17
18 #include <memory>
19 #include <string>
20
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_replace.h"
23 #include "llvm/ADT/Triple.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/GlobalVariable.h"
26 #include "llvm/IR/LLVMContext.h"
27 #include "llvm/IR/LegacyPassManager.h"
28 #include "llvm/IR/Module.h"
29 #include "llvm/MC/TargetRegistry.h"
30 #include "llvm/Target/TargetMachine.h"
31 #include "llvm/Target/TargetOptions.h"
32 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"
33 #include "tensorflow/compiler/xla/util.h"
34
35 namespace tensorflow {
36 namespace tfcompile {
37
38 using xla::llvm_ir::AsStringRef;
39
AddEmbeddedProtocolBufferToLlvmModule(llvm::Module * module,const::tensorflow::protobuf::MessageLite & proto,absl::string_view unique_identifier,string * protobuf_array_symbol_name,int64_t * protobuf_array_size)40 static void AddEmbeddedProtocolBufferToLlvmModule(
41 llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto,
42 absl::string_view unique_identifier, string* protobuf_array_symbol_name,
43 int64_t* protobuf_array_size) {
44 string protobuf_array_contents = proto.SerializeAsString();
45 *protobuf_array_symbol_name =
46 absl::StrCat(unique_identifier, "_protobuf_array_contents");
47 *protobuf_array_size = protobuf_array_contents.size();
48
49 llvm::Constant* protobuf_array_initializer =
50 llvm::ConstantDataArray::getString(module->getContext(),
51 AsStringRef(protobuf_array_contents),
52 /*AddNull=*/false);
53 new llvm::GlobalVariable(
54 *module, protobuf_array_initializer->getType(),
55 /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage,
56 protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name));
57 }
58
CreateCPPShimExpression(absl::string_view qualified_cpp_protobuf_name,absl::string_view protobuf_array_symbol_name,int64_t protobuf_array_size)59 static string CreateCPPShimExpression(
60 absl::string_view qualified_cpp_protobuf_name,
61 absl::string_view protobuf_array_symbol_name, int64_t protobuf_array_size) {
62 string code =
63 "[]() {\n"
64 " {{PROTOBUF_NAME}}* proto = new {{PROTOBUF_NAME}};\n"
65 " proto->ParseFromArray(&{{ARRAY_SYMBOL}}[0], {{ARRAY_SIZE}});\n"
66 " return proto;\n"
67 " }()";
68
69 return absl::StrReplaceAll(
70 code,
71 {
72 {"{{ARRAY_SYMBOL}}", absl::StrCat(protobuf_array_symbol_name)},
73 {"{{ARRAY_SIZE}}", absl::StrCat(protobuf_array_size)},
74 {"{{PROTOBUF_NAME}}", absl::StrCat(qualified_cpp_protobuf_name)},
75 });
76 }
77
CodegenModule(llvm::TargetMachine * target_machine,std::unique_ptr<llvm::Module> module)78 static StatusOr<string> CodegenModule(llvm::TargetMachine* target_machine,
79 std::unique_ptr<llvm::Module> module) {
80 llvm::SmallVector<char, 0> stream_buffer;
81 llvm::raw_svector_ostream ostream(stream_buffer);
82 llvm::legacy::PassManager codegen_passes;
83
84 if (target_machine->addPassesToEmitFile(codegen_passes, ostream, nullptr,
85 llvm::CGFT_ObjectFile)) {
86 return xla::InternalError(
87 "Could not create pass pipeline to generate object file");
88 }
89
90 codegen_passes.run(*module);
91
92 return string(stream_buffer.begin(), stream_buffer.end());
93 }
94
95 static StatusOr<std::unique_ptr<llvm::TargetMachine>>
GetTargetMachineFromTriple(absl::string_view target_triple)96 GetTargetMachineFromTriple(absl::string_view target_triple) {
97 std::string error;
98 std::string normalized_triple =
99 llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple)));
100 const llvm::Target* target =
101 llvm::TargetRegistry::lookupTarget(normalized_triple, error);
102 if (target == nullptr) {
103 return xla::InternalError("TargetRegistry::lookupTarget failed: %s",
104 error.c_str());
105 }
106
107 return absl::WrapUnique(target->createTargetMachine(
108 normalized_triple, /*CPU=*/"",
109 /*Features=*/"", llvm::TargetOptions(), llvm::None));
110 }
111
CreateEmbeddedProtocolBuffers(absl::string_view target_triple,absl::Span<const ProtobufToEmbed> protobufs_to_embed)112 StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
113 absl::string_view target_triple,
114 absl::Span<const ProtobufToEmbed> protobufs_to_embed) {
115 TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::TargetMachine> target_machine,
116 GetTargetMachineFromTriple(target_triple));
117
118 llvm::LLVMContext llvm_context;
119 std::unique_ptr<llvm::Module> module_with_serialized_proto =
120 absl::make_unique<llvm::Module>("embedded_data_module", llvm_context);
121
122 EmbeddedProtocolBuffers result;
123
124 for (const ProtobufToEmbed& protobuf_to_embed : protobufs_to_embed) {
125 string cpp_shim, cpp_variable_decl;
126 if (protobuf_to_embed.message) {
127 string protobuf_array_symbol_name;
128 int64_t protobuf_array_size;
129
130 AddEmbeddedProtocolBufferToLlvmModule(
131 module_with_serialized_proto.get(), *protobuf_to_embed.message,
132 protobuf_to_embed.symbol_prefix, &protobuf_array_symbol_name,
133 &protobuf_array_size);
134 cpp_shim = CreateCPPShimExpression(
135 protobuf_to_embed.qualified_cpp_protobuf_name,
136 protobuf_array_symbol_name, protobuf_array_size);
137
138 cpp_variable_decl =
139 absl::StrCat("extern \"C\" char ", protobuf_array_symbol_name, "[];");
140 } else {
141 cpp_shim = "nullptr";
142 }
143 result.cpp_shims.push_back({cpp_shim, cpp_variable_decl});
144 }
145
146 TF_ASSIGN_OR_RETURN(result.object_file_data,
147 CodegenModule(target_machine.get(),
148 std::move(module_with_serialized_proto)));
149 return result;
150 }
151
152 } // namespace tfcompile
153 } // namespace tensorflow
154