1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_program_group.h"
16
17 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
18 #include "tensorflow/compiler/xla/xla.pb.h"
19 #include "tensorflow/core/lib/gtl/cleanup.h"
20 #include "tensorflow/core/platform/casts.h"
21 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
22 #include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
23 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
24 #include "tensorflow/core/tpu/tpu_api.h"
25 #include "tensorflow/core/tpu/tpu_ops_c_api.h"
26 #include "tensorflow/stream_executor/tpu/proto_helper.h"
27 #include "tensorflow/stream_executor/tpu/status_helper.h"
28
29 namespace tensorflow {
30 namespace tpu {
31 namespace {
32 namespace se_tpu = ::stream_executor::tpu;
33 using stream_executor::port::Status;
34 } // namespace
35
ConstructExecutableInfo(const XLA_TpuProgram * xla_tpu_program)36 TPUExecutableInfoProto TpuProgramGroup::ConstructExecutableInfo(
37 const XLA_TpuProgram* xla_tpu_program) {
38 VLOG(1) << "ConstructExecutableInfo";
39 TpuSerializedProto serialized_executable_info = {};
40 StatusHelper status;
41 OpsApiFn()->TpuProgram_GetExecutableInfoFn(
42 xla_tpu_program, &serialized_executable_info, status.c_status);
43 TPUExecutableInfoProto executable_info;
44 if (status.ok()) {
45 executable_info = se_tpu::DeserializeProto<TPUExecutableInfoProto>(
46 serialized_executable_info);
47 StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
48 }
49 return executable_info;
50 }
51
ConstructHostTransferInfo(const XLA_TpuProgram * xla_tpu_program)52 TPUHostTransferInfoProto TpuProgramGroup::ConstructHostTransferInfo(
53 const XLA_TpuProgram* xla_tpu_program) {
54 VLOG(1) << "ConstructHostTransferInfo";
55 TpuSerializedProto serialized_host_transfer_info = {};
56 StatusHelper status;
57 OpsApiFn()->TpuProgram_GetHostTransferInfoFn(
58 xla_tpu_program, &serialized_host_transfer_info, status.c_status);
59 TPUHostTransferInfoProto host_transfer_info;
60 if (status.ok()) {
61 host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
62 serialized_host_transfer_info);
63 StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
64 }
65 return host_transfer_info;
66 }
67
ConstructHloMetadata(const XLA_TpuProgram * xla_tpu_program)68 xla::HloProto TpuProgramGroup::ConstructHloMetadata(
69 const XLA_TpuProgram* xla_tpu_program) {
70 VLOG(1) << "ConstructHloMetadata";
71 TpuSerializedProto serialized_hlo_metadata = {};
72 StatusHelper status;
73 OpsApiFn()->TpuProgram_GetHloMetadataFn(
74 xla_tpu_program, &serialized_hlo_metadata, status.c_status);
75 xla::HloProto hlo_metadata;
76 if (status.ok()) {
77 hlo_metadata =
78 se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
79 StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
80 }
81 return hlo_metadata;
82 }
83
Initialize(absl::Span<XLA_TpuProgram * const> xla_tpu_programs)84 void TpuProgramGroup::Initialize(
85 absl::Span<XLA_TpuProgram* const> xla_tpu_programs) {
86 CHECK_GT(xla_tpu_programs.size(), 0);
87 CHECK_EQ(program_count(), 0) << "Reinitialization of an existing "
88 "`TpuProgramGroup` instance is prohibited.";
89 set_tpu_programs(xla_tpu_programs);
90
91 CHECK_EQ(tpu_program_fingerprints_.size(), 0);
92 set_fingerprints();
93
94 std::vector<bool> may_modify_variables_array(tpu_programs_.size(), false);
95 std::vector<TPUExecutableInfoProto> executable_infos(tpu_programs_.size());
96 std::vector<TPUHostTransferInfoProto> host_transfer_infos(
97 tpu_programs_.size());
98 std::vector<xla::HloProto> hlo_metadatas(tpu_programs_.size());
99 for (size_t i = 0; i < tpu_programs_.size(); ++i) {
100 const XLA_TpuProgram* xla_tpu_program = tpu_programs_[i];
101 bool may_modify_variables;
102 OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_program,
103 &may_modify_variables);
104 may_modify_variables_array[i] = may_modify_variables;
105 executable_infos[i] = ConstructExecutableInfo(xla_tpu_program);
106 host_transfer_infos[i] = ConstructHostTransferInfo(xla_tpu_program);
107 hlo_metadatas[i] = ConstructHloMetadata(xla_tpu_program);
108 }
109
110 may_modify_variables_ = may_modify_variables_array;
111 executable_infos_ = executable_infos;
112 host_transfer_infos_ = host_transfer_infos;
113 hlo_metadatas_ = hlo_metadatas;
114 RefreshHloMetadatasPtrs();
115 }
116
has_sharding_program() const117 bool TpuProgramGroup::has_sharding_program() const {
118 for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
119 if (!OpsApiFn()->TpuProgram_HasShardingFn(tpu_program)) {
120 return false;
121 }
122 }
123 return true;
124 }
125
// Returns the number of XLA_TpuProgram handles currently held by this group.
size_t TpuProgramGroup::program_count() const { return tpu_programs_.size(); }
127
program_size() const128 int64_t TpuProgramGroup::program_size() const {
129 int64_t total_size = 0;
130 for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
131 total_size += OpsApiFn()->TpuProgram_GetProgramSizeFn(tpu_program);
132 }
133 return total_size;
134 }
135
LogProgramMemorySummary()136 bool TpuProgramGroup::LogProgramMemorySummary() {
137 bool success = true;
138 for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
139 success &= OpsApiFn()->TpuProgram_LogProgramMemorySummaryFn(tpu_program);
140 }
141 return success;
142 }
143
UnloadAndDestroyPrograms()144 void TpuProgramGroup::UnloadAndDestroyPrograms() {
145 for (XLA_TpuProgram* tpu_program : tpu_programs_) {
146 StatusHelper status;
147 OpsApiFn()->TpuProgram_UnloadAndDestroyFn(tpu_program, status.c_status);
148 auto s = status.status();
149 if (!s.ok()) {
150 LOG(ERROR) << "TpuProgramGroup::UnloadPrograms(): " << s.ToString();
151 }
152 }
153 tpu_programs_.clear();
154 }
155
TpuProgramGroup(TpuProgramGroup && other)156 TpuProgramGroup::TpuProgramGroup(TpuProgramGroup&& other)
157 : may_modify_variables_(std::move(other.may_modify_variables_)),
158 tpu_programs_(std::move(other.tpu_programs_)),
159 executable_infos_(std::move(other.executable_infos_)),
160 host_transfer_infos_(std::move(other.host_transfer_infos_)),
161 hlo_metadatas_(std::move(other.hlo_metadatas_)) {
162 RefreshHloMetadatasPtrs();
163 }
164
set_hlo_metadatas(absl::Span<const xla::HloProto> hlo_metadatas)165 void TpuProgramGroup::set_hlo_metadatas(
166 absl::Span<const xla::HloProto> hlo_metadatas) {
167 hlo_metadatas_.resize(hlo_metadatas.size());
168 for (size_t i = 0; i < hlo_metadatas.size(); ++i) {
169 hlo_metadatas_[i] = hlo_metadatas[i];
170 }
171 RefreshHloMetadatasPtrs();
172 }
173
// Returns a non-owning view over the cached per-program HLO metadata
// pointers; valid until hlo_metadatas_ is next mutated.
absl::Span<const xla::HloProto* const> TpuProgramGroup::hlo_metadatas() const {
  return hlo_metadatas_ptrs_;
}
177
// Returns the HLO metadata proto for the program at `index`.
// CHECK-fails on an out-of-range index.
const xla::HloProto* TpuProgramGroup::hlo_metadata(int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, hlo_metadatas_ptrs_.size());
  return hlo_metadatas_ptrs_[index];
}
183
RefreshHloMetadatasPtrs()184 void TpuProgramGroup::RefreshHloMetadatasPtrs() {
185 hlo_metadatas_ptrs_.reserve(hlo_metadatas_.size());
186 for (const auto& hlo_metadata_internal_ : hlo_metadatas_) {
187 hlo_metadatas_ptrs_.push_back(&hlo_metadata_internal_);
188 }
189 }
190
// Returns the cached per-program may-modify-variables flags
// (populated by Initialize or set_may_modify_variables).
const std::vector<bool>& TpuProgramGroup::may_modify_variables_list() const {
  return may_modify_variables_;
}
194
// Overwrites the cached per-program may-modify-variables flags.
void TpuProgramGroup::set_may_modify_variables(
    const std::vector<bool>& may_modify_variables) {
  may_modify_variables_ = may_modify_variables;
}
199
may_modify_variables(int index) const200 bool TpuProgramGroup::may_modify_variables(int index) const {
201 CHECK_GE(index, 0);
202 CHECK_LT(index, tpu_programs_.size());
203 bool may_modify_variables;
204 OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index],
205 &may_modify_variables);
206 return may_modify_variables;
207 }
208
// Returns all program handles in the group (non-owning view for callers).
const std::vector<XLA_TpuProgram*>& TpuProgramGroup::tpu_programs() const {
  return tpu_programs_;
}
212
// Returns the cached per-program fingerprints (populated by
// set_fingerprints during Initialize).
const std::vector<std::string>& TpuProgramGroup::fingerprints() const {
  return tpu_program_fingerprints_;
}
216
set_fingerprints()217 void TpuProgramGroup::set_fingerprints() {
218 for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
219 TpuProgramFingerprint fingerprint =
220 OpsApiFn()->TpuProgram_GetFingerprintFn(tpu_program);
221 tpu_program_fingerprints_.emplace_back(
222 std::string(fingerprint.bytes, fingerprint.size));
223 OpsApiFn()->TpuProgram_DestroyFingerprintFn(fingerprint);
224 }
225 }
226
// Returns the fingerprint of the program at `index`; throws
// std::out_of_range (via vector::at) on a bad index.
const std::string& TpuProgramGroup::fingerprint(int index) const {
  return fingerprints().at(index);
}
230
// Returns the program handle at `index`; CHECK-fails on an out-of-range
// index. The returned pointer remains owned by this group.
const XLA_TpuProgram* TpuProgramGroup::tpu_program(int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, tpu_programs_.size());
  return tpu_programs_[index];
}
236
set_tpu_programs(absl::Span<XLA_TpuProgram * const> tpu_programs)237 void TpuProgramGroup::set_tpu_programs(
238 absl::Span<XLA_TpuProgram* const> tpu_programs) {
239 tpu_programs_.resize(tpu_programs.size());
240 for (size_t i = 0; i < tpu_programs.size(); ++i) {
241 tpu_programs_[i] = tpu_programs[i];
242 }
243 }
244
// Returns the cached executable info proto for the program at `index`;
// CHECK-fails on an out-of-range index.
const TPUExecutableInfoProto& TpuProgramGroup::executable_info(
    int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, executable_infos_.size());
  return executable_infos_[index];
}
251
// Returns the cached host transfer info proto for the program at `index`;
// CHECK-fails on an out-of-range index.
const TPUHostTransferInfoProto& TpuProgramGroup::host_transfer_info(
    int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, host_transfer_infos_.size());
  return host_transfer_infos_[index];
}
258
259 /*static*/
CompileAndBuild(const TpuCompilationRequestProto & compilation_request,const XLA_TpuMeshState * mesh_state,TpuProgramGroupInterface * tpu_program_group_interface)260 Status TpuProgramGroup::CompileAndBuild(
261 const TpuCompilationRequestProto& compilation_request,
262 const XLA_TpuMeshState* mesh_state,
263 TpuProgramGroupInterface* tpu_program_group_interface) {
264 se_tpu::SerializedProto serialized_compilation_request =
265 se_tpu::SerializeProto(compilation_request);
266 auto cleanup = gtl::MakeCleanup([serialized_compilation_request] {
267 se_tpu::SerializedProto_Free(serialized_compilation_request);
268 });
269 size_t count = 0;
270 XLA_TpuProgram** xla_tpu_programs = nullptr;
271 StatusHelper status;
272 OpsApiFn()->TpuCompile_CompileAndBuildFn(serialized_compilation_request,
273 mesh_state, &xla_tpu_programs,
274 &count, status.c_status);
275 if (!status.ok()) {
276 VLOG(1) << "Run CompileAndBuild failed.";
277 return status.status();
278 }
279
280 // SPMD could return 1 result for all partitions.
281 TF_RET_CHECK(count == 1 ||
282 count == compilation_request.metadata().num_cores_per_replica());
283
284 VLOG(1) << "Initialize TpuProgramGroup.";
285 TpuProgramGroup* tpu_program_group =
286 tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
287 tpu_program_group->Initialize(
288 absl::MakeConstSpan(&xla_tpu_programs[0], count));
289 OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
290 return status.status();
291 }
292
293 /*static*/
CompileAndBuild(const xrt::XLAComputation & xrt_computation_proto,const XLA_TpuMeshState * mesh_state,TpuProgramGroupInterface * tpu_program_group_interface)294 Status TpuProgramGroup::CompileAndBuild(
295 const xrt::XLAComputation& xrt_computation_proto,
296 const XLA_TpuMeshState* mesh_state,
297 TpuProgramGroupInterface* tpu_program_group_interface) {
298 se_tpu::SerializedProto serialized_compilation_request =
299 se_tpu::SerializeProto(xrt_computation_proto);
300 auto cleanup = gtl::MakeCleanup([serialized_compilation_request] {
301 se_tpu::SerializedProto_Free(serialized_compilation_request);
302 });
303 size_t count = 0;
304 XLA_TpuProgram** xla_tpu_programs = nullptr;
305 StatusHelper status;
306 OpsApiFn()->TpuCompile_XrtCompileAndBuildFn(serialized_compilation_request,
307 mesh_state, &xla_tpu_programs,
308 &count, status.c_status);
309 if (!status.ok()) {
310 VLOG(1) << "Run CompileAndBuild failed.";
311 return status.status();
312 }
313
314 // SPMD could return 1 result for all partitions.
315 int num_cores_per_replica =
316 xrt_computation_proto.config().num_cores_per_replica()
317 ? xrt_computation_proto.config().num_cores_per_replica()
318 : 1;
319 TF_RET_CHECK(count == 1 || count == num_cores_per_replica);
320 VLOG(1) << "Initialize TpuProgramGroup.";
321 TpuProgramGroup* tpu_program_group =
322 tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
323 tpu_program_group->Initialize(
324 absl::MakeConstSpan(&xla_tpu_programs[0], count));
325 OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
326 return status.status();
327 }
328
tpu_programs(TpuProgramShardingType sharding_type) const329 std::vector<XLA_TpuProgram*> TpuProgramGroup::tpu_programs(
330 TpuProgramShardingType sharding_type) const {
331 std::vector<XLA_TpuProgram*> tpu_programs;
332 tpu_programs.reserve(tpu_programs_.size());
333 for (size_t i = 0; i < tpu_programs_.size(); ++i) {
334 if (OpsApiFn()->TpuProgram_HasShardingFn(tpu_programs_[i])) {
335 tpu_programs.push_back(OpsApiFn()->TpuProgram_GetTpuProgramFn(
336 tpu_programs_[i], sharding_type));
337 CHECK_NE(tpu_programs[i], nullptr);
338 }
339 }
340 return tpu_programs;
341 }
342
DeserializeFromRpcResponseProtos(const std::vector<TpuSerializedProto> & rpc_response_protos)343 Status TpuProgramGroup::DeserializeFromRpcResponseProtos(
344 const std::vector<TpuSerializedProto>& rpc_response_protos) {
345 std::vector<XLA_TpuProgram*> tpu_programs;
346 tpu_programs.resize(rpc_response_protos.size());
347
348 for (size_t i = 0; i < rpc_response_protos.size(); ++i) {
349 StatusHelper status;
350 auto* xla_tpu_program = OpsApiFn()->TpuProgram_NewFn();
351 OpsApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn(
352 rpc_response_protos[i], xla_tpu_program, status.c_status);
353 if (!status.status().ok()) {
354 OpsApiFn()->TpuProgram_FreeFn(xla_tpu_program);
355 return status.status();
356 }
357 tpu_programs[i] = xla_tpu_program;
358 }
359
360 Initialize(tpu_programs);
361 return OkStatus();
362 }
363
// Serializes the executable of the program at `index` into `executable`
// via the TPU C API. CHECK-fails on an out-of-range index; returns the
// C API status otherwise.
Status TpuProgramGroup::SerializeExecutable(
    int index, TpuExecutableSerializedProto* executable) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, tpu_programs_.size());
  StatusHelper status;
  OpsApiFn()->TpuProgram_SerializeTpuExecutableFn(tpu_programs_[index],
                                                  executable, status.c_status);
  return status.status();
}
373
// Serializes the compiler metadata of the program at `index` into
// `compiler_metadata` via the TPU C API. CHECK-fails on an out-of-range
// index; returns the C API status otherwise.
Status TpuProgramGroup::SerializeCompilerMetadata(
    int index, CompilerMetadataSerializedProto* compiler_metadata) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, tpu_programs_.size());
  StatusHelper status;
  OpsApiFn()->TpuProgram_SerializeCompilerMetadataFn(
      tpu_programs_[index], compiler_metadata, status.c_status);
  return status.status();
}
383 } // namespace tpu
384 } // namespace tensorflow
385