xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/tpu_program_group.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_program_group.h"
16 
17 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
18 #include "tensorflow/compiler/xla/xla.pb.h"
19 #include "tensorflow/core/lib/gtl/cleanup.h"
20 #include "tensorflow/core/platform/casts.h"
21 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
22 #include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
23 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
24 #include "tensorflow/core/tpu/tpu_api.h"
25 #include "tensorflow/core/tpu/tpu_ops_c_api.h"
26 #include "tensorflow/stream_executor/tpu/proto_helper.h"
27 #include "tensorflow/stream_executor/tpu/status_helper.h"
28 
29 namespace tensorflow {
30 namespace tpu {
31 namespace {
32 namespace se_tpu = ::stream_executor::tpu;
33 using stream_executor::port::Status;
34 }  // namespace
35 
ConstructExecutableInfo(const XLA_TpuProgram * xla_tpu_program)36 TPUExecutableInfoProto TpuProgramGroup::ConstructExecutableInfo(
37     const XLA_TpuProgram* xla_tpu_program) {
38   VLOG(1) << "ConstructExecutableInfo";
39   TpuSerializedProto serialized_executable_info = {};
40   StatusHelper status;
41   OpsApiFn()->TpuProgram_GetExecutableInfoFn(
42       xla_tpu_program, &serialized_executable_info, status.c_status);
43   TPUExecutableInfoProto executable_info;
44   if (status.ok()) {
45     executable_info = se_tpu::DeserializeProto<TPUExecutableInfoProto>(
46         serialized_executable_info);
47     StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
48   }
49   return executable_info;
50 }
51 
ConstructHostTransferInfo(const XLA_TpuProgram * xla_tpu_program)52 TPUHostTransferInfoProto TpuProgramGroup::ConstructHostTransferInfo(
53     const XLA_TpuProgram* xla_tpu_program) {
54   VLOG(1) << "ConstructHostTransferInfo";
55   TpuSerializedProto serialized_host_transfer_info = {};
56   StatusHelper status;
57   OpsApiFn()->TpuProgram_GetHostTransferInfoFn(
58       xla_tpu_program, &serialized_host_transfer_info, status.c_status);
59   TPUHostTransferInfoProto host_transfer_info;
60   if (status.ok()) {
61     host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
62         serialized_host_transfer_info);
63     StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
64   }
65   return host_transfer_info;
66 }
67 
ConstructHloMetadata(const XLA_TpuProgram * xla_tpu_program)68 xla::HloProto TpuProgramGroup::ConstructHloMetadata(
69     const XLA_TpuProgram* xla_tpu_program) {
70   VLOG(1) << "ConstructHloMetadata";
71   TpuSerializedProto serialized_hlo_metadata = {};
72   StatusHelper status;
73   OpsApiFn()->TpuProgram_GetHloMetadataFn(
74       xla_tpu_program, &serialized_hlo_metadata, status.c_status);
75   xla::HloProto hlo_metadata;
76   if (status.ok()) {
77     hlo_metadata =
78         se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
79     StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
80   }
81   return hlo_metadata;
82 }
83 
Initialize(absl::Span<XLA_TpuProgram * const> xla_tpu_programs)84 void TpuProgramGroup::Initialize(
85     absl::Span<XLA_TpuProgram* const> xla_tpu_programs) {
86   CHECK_GT(xla_tpu_programs.size(), 0);
87   CHECK_EQ(program_count(), 0) << "Reinitialization of an existing "
88                                   "`TpuProgramGroup` instance is prohibited.";
89   set_tpu_programs(xla_tpu_programs);
90 
91   CHECK_EQ(tpu_program_fingerprints_.size(), 0);
92   set_fingerprints();
93 
94   std::vector<bool> may_modify_variables_array(tpu_programs_.size(), false);
95   std::vector<TPUExecutableInfoProto> executable_infos(tpu_programs_.size());
96   std::vector<TPUHostTransferInfoProto> host_transfer_infos(
97       tpu_programs_.size());
98   std::vector<xla::HloProto> hlo_metadatas(tpu_programs_.size());
99   for (size_t i = 0; i < tpu_programs_.size(); ++i) {
100     const XLA_TpuProgram* xla_tpu_program = tpu_programs_[i];
101     bool may_modify_variables;
102     OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_program,
103                                                    &may_modify_variables);
104     may_modify_variables_array[i] = may_modify_variables;
105     executable_infos[i] = ConstructExecutableInfo(xla_tpu_program);
106     host_transfer_infos[i] = ConstructHostTransferInfo(xla_tpu_program);
107     hlo_metadatas[i] = ConstructHloMetadata(xla_tpu_program);
108   }
109 
110   may_modify_variables_ = may_modify_variables_array;
111   executable_infos_ = executable_infos;
112   host_transfer_infos_ = host_transfer_infos;
113   hlo_metadatas_ = hlo_metadatas;
114   RefreshHloMetadatasPtrs();
115 }
116 
has_sharding_program() const117 bool TpuProgramGroup::has_sharding_program() const {
118   for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
119     if (!OpsApiFn()->TpuProgram_HasShardingFn(tpu_program)) {
120       return false;
121     }
122   }
123   return true;
124 }
125 
program_count() const126 size_t TpuProgramGroup::program_count() const { return tpu_programs_.size(); }
127 
program_size() const128 int64_t TpuProgramGroup::program_size() const {
129   int64_t total_size = 0;
130   for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
131     total_size += OpsApiFn()->TpuProgram_GetProgramSizeFn(tpu_program);
132   }
133   return total_size;
134 }
135 
LogProgramMemorySummary()136 bool TpuProgramGroup::LogProgramMemorySummary() {
137   bool success = true;
138   for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
139     success &= OpsApiFn()->TpuProgram_LogProgramMemorySummaryFn(tpu_program);
140   }
141   return success;
142 }
143 
UnloadAndDestroyPrograms()144 void TpuProgramGroup::UnloadAndDestroyPrograms() {
145   for (XLA_TpuProgram* tpu_program : tpu_programs_) {
146     StatusHelper status;
147     OpsApiFn()->TpuProgram_UnloadAndDestroyFn(tpu_program, status.c_status);
148     auto s = status.status();
149     if (!s.ok()) {
150       LOG(ERROR) << "TpuProgramGroup::UnloadPrograms(): " << s.ToString();
151     }
152   }
153   tpu_programs_.clear();
154 }
155 
TpuProgramGroup(TpuProgramGroup && other)156 TpuProgramGroup::TpuProgramGroup(TpuProgramGroup&& other)
157     : may_modify_variables_(std::move(other.may_modify_variables_)),
158       tpu_programs_(std::move(other.tpu_programs_)),
159       executable_infos_(std::move(other.executable_infos_)),
160       host_transfer_infos_(std::move(other.host_transfer_infos_)),
161       hlo_metadatas_(std::move(other.hlo_metadatas_)) {
162   RefreshHloMetadatasPtrs();
163 }
164 
set_hlo_metadatas(absl::Span<const xla::HloProto> hlo_metadatas)165 void TpuProgramGroup::set_hlo_metadatas(
166     absl::Span<const xla::HloProto> hlo_metadatas) {
167   hlo_metadatas_.resize(hlo_metadatas.size());
168   for (size_t i = 0; i < hlo_metadatas.size(); ++i) {
169     hlo_metadatas_[i] = hlo_metadatas[i];
170   }
171   RefreshHloMetadatasPtrs();
172 }
173 
hlo_metadatas() const174 absl::Span<const xla::HloProto* const> TpuProgramGroup::hlo_metadatas() const {
175   return hlo_metadatas_ptrs_;
176 }
177 
hlo_metadata(int index) const178 const xla::HloProto* TpuProgramGroup::hlo_metadata(int index) const {
179   CHECK_GE(index, 0);
180   CHECK_LT(index, hlo_metadatas_ptrs_.size());
181   return hlo_metadatas_ptrs_[index];
182 }
183 
RefreshHloMetadatasPtrs()184 void TpuProgramGroup::RefreshHloMetadatasPtrs() {
185   hlo_metadatas_ptrs_.reserve(hlo_metadatas_.size());
186   for (const auto& hlo_metadata_internal_ : hlo_metadatas_) {
187     hlo_metadatas_ptrs_.push_back(&hlo_metadata_internal_);
188   }
189 }
190 
may_modify_variables_list() const191 const std::vector<bool>& TpuProgramGroup::may_modify_variables_list() const {
192   return may_modify_variables_;
193 }
194 
set_may_modify_variables(const std::vector<bool> & may_modify_variables)195 void TpuProgramGroup::set_may_modify_variables(
196     const std::vector<bool>& may_modify_variables) {
197   may_modify_variables_ = may_modify_variables;
198 }
199 
may_modify_variables(int index) const200 bool TpuProgramGroup::may_modify_variables(int index) const {
201   CHECK_GE(index, 0);
202   CHECK_LT(index, tpu_programs_.size());
203   bool may_modify_variables;
204   OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index],
205                                                  &may_modify_variables);
206   return may_modify_variables;
207 }
208 
tpu_programs() const209 const std::vector<XLA_TpuProgram*>& TpuProgramGroup::tpu_programs() const {
210   return tpu_programs_;
211 }
212 
fingerprints() const213 const std::vector<std::string>& TpuProgramGroup::fingerprints() const {
214   return tpu_program_fingerprints_;
215 }
216 
set_fingerprints()217 void TpuProgramGroup::set_fingerprints() {
218   for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
219     TpuProgramFingerprint fingerprint =
220         OpsApiFn()->TpuProgram_GetFingerprintFn(tpu_program);
221     tpu_program_fingerprints_.emplace_back(
222         std::string(fingerprint.bytes, fingerprint.size));
223     OpsApiFn()->TpuProgram_DestroyFingerprintFn(fingerprint);
224   }
225 }
226 
fingerprint(int index) const227 const std::string& TpuProgramGroup::fingerprint(int index) const {
228   return fingerprints().at(index);
229 }
230 
tpu_program(int index) const231 const XLA_TpuProgram* TpuProgramGroup::tpu_program(int index) const {
232   CHECK_GE(index, 0);
233   CHECK_LT(index, tpu_programs_.size());
234   return tpu_programs_[index];
235 }
236 
set_tpu_programs(absl::Span<XLA_TpuProgram * const> tpu_programs)237 void TpuProgramGroup::set_tpu_programs(
238     absl::Span<XLA_TpuProgram* const> tpu_programs) {
239   tpu_programs_.resize(tpu_programs.size());
240   for (size_t i = 0; i < tpu_programs.size(); ++i) {
241     tpu_programs_[i] = tpu_programs[i];
242   }
243 }
244 
executable_info(int index) const245 const TPUExecutableInfoProto& TpuProgramGroup::executable_info(
246     int index) const {
247   CHECK_GE(index, 0);
248   CHECK_LT(index, executable_infos_.size());
249   return executable_infos_[index];
250 }
251 
host_transfer_info(int index) const252 const TPUHostTransferInfoProto& TpuProgramGroup::host_transfer_info(
253     int index) const {
254   CHECK_GE(index, 0);
255   CHECK_LT(index, host_transfer_infos_.size());
256   return host_transfer_infos_[index];
257 }
258 
259 /*static*/
CompileAndBuild(const TpuCompilationRequestProto & compilation_request,const XLA_TpuMeshState * mesh_state,TpuProgramGroupInterface * tpu_program_group_interface)260 Status TpuProgramGroup::CompileAndBuild(
261     const TpuCompilationRequestProto& compilation_request,
262     const XLA_TpuMeshState* mesh_state,
263     TpuProgramGroupInterface* tpu_program_group_interface) {
264   se_tpu::SerializedProto serialized_compilation_request =
265       se_tpu::SerializeProto(compilation_request);
266   auto cleanup = gtl::MakeCleanup([serialized_compilation_request] {
267     se_tpu::SerializedProto_Free(serialized_compilation_request);
268   });
269   size_t count = 0;
270   XLA_TpuProgram** xla_tpu_programs = nullptr;
271   StatusHelper status;
272   OpsApiFn()->TpuCompile_CompileAndBuildFn(serialized_compilation_request,
273                                            mesh_state, &xla_tpu_programs,
274                                            &count, status.c_status);
275   if (!status.ok()) {
276     VLOG(1) << "Run CompileAndBuild failed.";
277     return status.status();
278   }
279 
280   // SPMD could return 1 result for all partitions.
281   TF_RET_CHECK(count == 1 ||
282                count == compilation_request.metadata().num_cores_per_replica());
283 
284   VLOG(1) << "Initialize TpuProgramGroup.";
285   TpuProgramGroup* tpu_program_group =
286       tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
287   tpu_program_group->Initialize(
288       absl::MakeConstSpan(&xla_tpu_programs[0], count));
289   OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
290   return status.status();
291 }
292 
293 /*static*/
CompileAndBuild(const xrt::XLAComputation & xrt_computation_proto,const XLA_TpuMeshState * mesh_state,TpuProgramGroupInterface * tpu_program_group_interface)294 Status TpuProgramGroup::CompileAndBuild(
295     const xrt::XLAComputation& xrt_computation_proto,
296     const XLA_TpuMeshState* mesh_state,
297     TpuProgramGroupInterface* tpu_program_group_interface) {
298   se_tpu::SerializedProto serialized_compilation_request =
299       se_tpu::SerializeProto(xrt_computation_proto);
300   auto cleanup = gtl::MakeCleanup([serialized_compilation_request] {
301     se_tpu::SerializedProto_Free(serialized_compilation_request);
302   });
303   size_t count = 0;
304   XLA_TpuProgram** xla_tpu_programs = nullptr;
305   StatusHelper status;
306   OpsApiFn()->TpuCompile_XrtCompileAndBuildFn(serialized_compilation_request,
307                                               mesh_state, &xla_tpu_programs,
308                                               &count, status.c_status);
309   if (!status.ok()) {
310     VLOG(1) << "Run CompileAndBuild failed.";
311     return status.status();
312   }
313 
314   // SPMD could return 1 result for all partitions.
315   int num_cores_per_replica =
316       xrt_computation_proto.config().num_cores_per_replica()
317           ? xrt_computation_proto.config().num_cores_per_replica()
318           : 1;
319   TF_RET_CHECK(count == 1 || count == num_cores_per_replica);
320   VLOG(1) << "Initialize TpuProgramGroup.";
321   TpuProgramGroup* tpu_program_group =
322       tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
323   tpu_program_group->Initialize(
324       absl::MakeConstSpan(&xla_tpu_programs[0], count));
325   OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
326   return status.status();
327 }
328 
tpu_programs(TpuProgramShardingType sharding_type) const329 std::vector<XLA_TpuProgram*> TpuProgramGroup::tpu_programs(
330     TpuProgramShardingType sharding_type) const {
331   std::vector<XLA_TpuProgram*> tpu_programs;
332   tpu_programs.reserve(tpu_programs_.size());
333   for (size_t i = 0; i < tpu_programs_.size(); ++i) {
334     if (OpsApiFn()->TpuProgram_HasShardingFn(tpu_programs_[i])) {
335       tpu_programs.push_back(OpsApiFn()->TpuProgram_GetTpuProgramFn(
336           tpu_programs_[i], sharding_type));
337       CHECK_NE(tpu_programs[i], nullptr);
338     }
339   }
340   return tpu_programs;
341 }
342 
DeserializeFromRpcResponseProtos(const std::vector<TpuSerializedProto> & rpc_response_protos)343 Status TpuProgramGroup::DeserializeFromRpcResponseProtos(
344     const std::vector<TpuSerializedProto>& rpc_response_protos) {
345   std::vector<XLA_TpuProgram*> tpu_programs;
346   tpu_programs.resize(rpc_response_protos.size());
347 
348   for (size_t i = 0; i < rpc_response_protos.size(); ++i) {
349     StatusHelper status;
350     auto* xla_tpu_program = OpsApiFn()->TpuProgram_NewFn();
351     OpsApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn(
352         rpc_response_protos[i], xla_tpu_program, status.c_status);
353     if (!status.status().ok()) {
354       OpsApiFn()->TpuProgram_FreeFn(xla_tpu_program);
355       return status.status();
356     }
357     tpu_programs[i] = xla_tpu_program;
358   }
359 
360   Initialize(tpu_programs);
361   return OkStatus();
362 }
363 
SerializeExecutable(int index,TpuExecutableSerializedProto * executable) const364 Status TpuProgramGroup::SerializeExecutable(
365     int index, TpuExecutableSerializedProto* executable) const {
366   CHECK_GE(index, 0);
367   CHECK_LT(index, tpu_programs_.size());
368   StatusHelper status;
369   OpsApiFn()->TpuProgram_SerializeTpuExecutableFn(tpu_programs_[index],
370                                                   executable, status.c_status);
371   return status.status();
372 }
373 
SerializeCompilerMetadata(int index,CompilerMetadataSerializedProto * compiler_metadata) const374 Status TpuProgramGroup::SerializeCompilerMetadata(
375     int index, CompilerMetadataSerializedProto* compiler_metadata) const {
376   CHECK_GE(index, 0);
377   CHECK_LT(index, tpu_programs_.size());
378   StatusHelper status;
379   OpsApiFn()->TpuProgram_SerializeCompilerMetadataFn(
380       tpu_programs_[index], compiler_metadata, status.c_status);
381   return status.status();
382 }
383 }  // namespace tpu
384 }  // namespace tensorflow
385