xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/tf_session.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2020 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker 
17*14675a02SAndroid Build Coastguard Worker #include "fcp/tensorflow/tf_session.h"
18*14675a02SAndroid Build Coastguard Worker 
19*14675a02SAndroid Build Coastguard Worker #include <cstdio>
20*14675a02SAndroid Build Coastguard Worker #include <fstream>
21*14675a02SAndroid Build Coastguard Worker #include <iostream>
22*14675a02SAndroid Build Coastguard Worker #include <string>
23*14675a02SAndroid Build Coastguard Worker #include <utility>
24*14675a02SAndroid Build Coastguard Worker 
25*14675a02SAndroid Build Coastguard Worker #include "absl/strings/cord.h"
26*14675a02SAndroid Build Coastguard Worker #include "fcp/base/platform.h"
27*14675a02SAndroid Build Coastguard Worker #include "fcp/base/process_unique_id.h"
28*14675a02SAndroid Build Coastguard Worker #include "fcp/base/result.h"
29*14675a02SAndroid Build Coastguard Worker #include "fcp/tensorflow/status.h"
30*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/protobuf/saver.pb.h"
31*14675a02SAndroid Build Coastguard Worker 
32*14675a02SAndroid Build Coastguard Worker namespace fcp {
33*14675a02SAndroid Build Coastguard Worker 
34*14675a02SAndroid Build Coastguard Worker #define TF_STATUS_EXPECT_OK(tf_status) \
35*14675a02SAndroid Build Coastguard Worker   Result(ConvertFromTensorFlowStatus(tf_status)).Then(ExpectOk())
36*14675a02SAndroid Build Coastguard Worker 
37*14675a02SAndroid Build Coastguard Worker using CheckpointOp = google::internal::federated::plan::CheckpointOp;
38*14675a02SAndroid Build Coastguard Worker 
TfSession(const std::filesystem::path & tmp_dir,const absl::Cord & graph)39*14675a02SAndroid Build Coastguard Worker TfSession::TfSession(const std::filesystem::path& tmp_dir,
40*14675a02SAndroid Build Coastguard Worker                      const absl::Cord& graph)
41*14675a02SAndroid Build Coastguard Worker     : tmp_dir_(StripTrailingPathSeparator(tmp_dir.c_str())),
42*14675a02SAndroid Build Coastguard Worker       session_(tensorflow::NewSession(tensorflow::SessionOptions{})) {
43*14675a02SAndroid Build Coastguard Worker   // Parse GraphDef.
44*14675a02SAndroid Build Coastguard Worker   tensorflow::GraphDef graph_def;
45*14675a02SAndroid Build Coastguard Worker   // TODO(team): Replace with ParseFromCord (check if it is available).
46*14675a02SAndroid Build Coastguard Worker   std::string graph_str;
47*14675a02SAndroid Build Coastguard Worker   absl::CopyCordToString(graph, &graph_str);
48*14675a02SAndroid Build Coastguard Worker   if (!graph_def.ParseFromString(graph_str)) {
49*14675a02SAndroid Build Coastguard Worker     session_status_ = FCP_STATUS(INVALID_ARGUMENT)
50*14675a02SAndroid Build Coastguard Worker                       << "Could not parse GraphDef.";
51*14675a02SAndroid Build Coastguard Worker     return;
52*14675a02SAndroid Build Coastguard Worker   }
53*14675a02SAndroid Build Coastguard Worker   session_status_ = ConvertFromTensorFlowStatus(session_->Create(graph_def));
54*14675a02SAndroid Build Coastguard Worker }
55*14675a02SAndroid Build Coastguard Worker 
TfSession(const std::filesystem::path & tmp_dir,absl::string_view graph)56*14675a02SAndroid Build Coastguard Worker TfSession::TfSession(const std::filesystem::path& tmp_dir,
57*14675a02SAndroid Build Coastguard Worker                      absl::string_view graph)
58*14675a02SAndroid Build Coastguard Worker     : TfSession(tmp_dir, absl::Cord(graph)) {}
59*14675a02SAndroid Build Coastguard Worker 
Ready()60*14675a02SAndroid Build Coastguard Worker Result<Unit> TfSession::Ready() {
61*14675a02SAndroid Build Coastguard Worker   return Result(session_status_).Then(ExpectOk());
62*14675a02SAndroid Build Coastguard Worker }
63*14675a02SAndroid Build Coastguard Worker 
RunOp(absl::string_view op)64*14675a02SAndroid Build Coastguard Worker Result<Unit> TfSession::RunOp(absl::string_view op) {
65*14675a02SAndroid Build Coastguard Worker   FCP_TRY(Ready());
66*14675a02SAndroid Build Coastguard Worker   if (op.empty()) {
67*14675a02SAndroid Build Coastguard Worker     return Unit{};
68*14675a02SAndroid Build Coastguard Worker   }
69*14675a02SAndroid Build Coastguard Worker   TracingSpan<RunTfOp> span(op);
70*14675a02SAndroid Build Coastguard Worker   std::vector<std::string> target_node_names;
71*14675a02SAndroid Build Coastguard Worker   target_node_names.emplace_back(op);
72*14675a02SAndroid Build Coastguard Worker   FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(
73*14675a02SAndroid Build Coastguard Worker       /*inputs=*/{},
74*14675a02SAndroid Build Coastguard Worker       /*output_tensor_names=*/{}, target_node_names,
75*14675a02SAndroid Build Coastguard Worker       /*outputs=*/nullptr)));
76*14675a02SAndroid Build Coastguard Worker   return Unit{};
77*14675a02SAndroid Build Coastguard Worker }
78*14675a02SAndroid Build Coastguard Worker 
RunOp(const NamedTensorList & inputs,absl::string_view op)79*14675a02SAndroid Build Coastguard Worker Result<Unit> TfSession::RunOp(const NamedTensorList& inputs,
80*14675a02SAndroid Build Coastguard Worker                               absl::string_view op) {
81*14675a02SAndroid Build Coastguard Worker   FCP_TRY(Ready());
82*14675a02SAndroid Build Coastguard Worker   if (op.empty()) {
83*14675a02SAndroid Build Coastguard Worker     return Unit{};
84*14675a02SAndroid Build Coastguard Worker   }
85*14675a02SAndroid Build Coastguard Worker   std::vector<std::string> target_node_names;
86*14675a02SAndroid Build Coastguard Worker   target_node_names.emplace_back(op);
87*14675a02SAndroid Build Coastguard Worker   FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(inputs,
88*14675a02SAndroid Build Coastguard Worker                                             /*output_tensor_names=*/{},
89*14675a02SAndroid Build Coastguard Worker                                             target_node_names,
90*14675a02SAndroid Build Coastguard Worker                                             /*outputs=*/nullptr)));
91*14675a02SAndroid Build Coastguard Worker   return Unit{};
92*14675a02SAndroid Build Coastguard Worker }
93*14675a02SAndroid Build Coastguard Worker 
GetOutputs(std::unique_ptr<std::vector<std::string>> output_names)94*14675a02SAndroid Build Coastguard Worker Result<std::unique_ptr<TfSession::NamedTensorMap>> TfSession::GetOutputs(
95*14675a02SAndroid Build Coastguard Worker     std::unique_ptr<std::vector<std::string>> output_names) {
96*14675a02SAndroid Build Coastguard Worker   FCP_TRY(Ready());
97*14675a02SAndroid Build Coastguard Worker   auto outputs = std::make_unique<TfSession::NamedTensorMap>();
98*14675a02SAndroid Build Coastguard Worker   if (output_names->empty()) {
99*14675a02SAndroid Build Coastguard Worker     return std::move(outputs);
100*14675a02SAndroid Build Coastguard Worker   }
101*14675a02SAndroid Build Coastguard Worker   std::vector<tensorflow::Tensor> output_list;
102*14675a02SAndroid Build Coastguard Worker   FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(
103*14675a02SAndroid Build Coastguard Worker       /*inputs=*/{}, *output_names,
104*14675a02SAndroid Build Coastguard Worker       /*target_tensor_names=*/{}, &output_list)));
105*14675a02SAndroid Build Coastguard Worker   FCP_CHECK(output_names->size() == output_list.size());
106*14675a02SAndroid Build Coastguard Worker   for (int i = 0; i < output_names->size(); i++) {
107*14675a02SAndroid Build Coastguard Worker     outputs->emplace(std::move((*output_names)[i]), std::move(output_list[i]));
108*14675a02SAndroid Build Coastguard Worker   }
109*14675a02SAndroid Build Coastguard Worker   return std::move(outputs);
110*14675a02SAndroid Build Coastguard Worker }
111*14675a02SAndroid Build Coastguard Worker 
DeleteTmpFile(const std::string & tmp_file_name)112*14675a02SAndroid Build Coastguard Worker void DeleteTmpFile(const std::string& tmp_file_name) {
113*14675a02SAndroid Build Coastguard Worker   if (std::remove(tmp_file_name.c_str()) > 0) {
114*14675a02SAndroid Build Coastguard Worker     Trace<TmpFileNotDeleted>(tmp_file_name);
115*14675a02SAndroid Build Coastguard Worker   }
116*14675a02SAndroid Build Coastguard Worker }
117*14675a02SAndroid Build Coastguard Worker 
SaveState(const CheckpointOp & op)118*14675a02SAndroid Build Coastguard Worker Result<absl::Cord> TfSession::SaveState(const CheckpointOp& op) {
119*14675a02SAndroid Build Coastguard Worker   FCP_TRY(Ready());
120*14675a02SAndroid Build Coastguard Worker   TracingSpan<SaveToCheckpoint> span(
121*14675a02SAndroid Build Coastguard Worker       op.before_save_op(),
122*14675a02SAndroid Build Coastguard Worker       op.has_saver_def() ? op.saver_def().save_tensor_name() : "",
123*14675a02SAndroid Build Coastguard Worker       op.after_save_op());
124*14675a02SAndroid Build Coastguard Worker   FCP_TRY(RunOp(op.before_save_op()));
125*14675a02SAndroid Build Coastguard Worker   Result<absl::Cord> res = absl::Cord("");
126*14675a02SAndroid Build Coastguard Worker   if (op.has_saver_def()) {
127*14675a02SAndroid Build Coastguard Worker     const tensorflow::SaverDef& def = op.saver_def();
128*14675a02SAndroid Build Coastguard Worker     absl::string_view save_op = def.save_tensor_name();
129*14675a02SAndroid Build Coastguard Worker     // TODO(team): Workaround due to difference between python and c++
130*14675a02SAndroid Build Coastguard Worker     //  TensorFlow APIs.
131*14675a02SAndroid Build Coastguard Worker     save_op = absl::StripSuffix(save_op, ":0");
132*14675a02SAndroid Build Coastguard Worker     std::string tmp_file_name = GetTmpCheckpointFileName("save_checkpoint");
133*14675a02SAndroid Build Coastguard Worker     res =
134*14675a02SAndroid Build Coastguard Worker         RunOp({{def.filename_tensor_name(), tensorflow::Tensor(tmp_file_name)}},
135*14675a02SAndroid Build Coastguard Worker               save_op)
136*14675a02SAndroid Build Coastguard Worker             .Then([&tmp_file_name](Unit u) -> Result<StatusOr<absl::Cord>> {
137*14675a02SAndroid Build Coastguard Worker               return Result(fcp::ReadFileToCord(tmp_file_name));
138*14675a02SAndroid Build Coastguard Worker             })
139*14675a02SAndroid Build Coastguard Worker             .Then(ExpectOk());
140*14675a02SAndroid Build Coastguard Worker     DeleteTmpFile(tmp_file_name);
141*14675a02SAndroid Build Coastguard Worker   }
142*14675a02SAndroid Build Coastguard Worker   FCP_TRY(RunOp(op.after_save_op()));
143*14675a02SAndroid Build Coastguard Worker   return res;
144*14675a02SAndroid Build Coastguard Worker }
145*14675a02SAndroid Build Coastguard Worker 
RestoreState(const CheckpointOp & op,const absl::Cord & checkpoint)146*14675a02SAndroid Build Coastguard Worker Result<Unit> TfSession::RestoreState(const CheckpointOp& op,
147*14675a02SAndroid Build Coastguard Worker                                      const absl::Cord& checkpoint) {
148*14675a02SAndroid Build Coastguard Worker   FCP_TRY(Ready());
149*14675a02SAndroid Build Coastguard Worker   TracingSpan<RestoreFromCheckpoint> span(
150*14675a02SAndroid Build Coastguard Worker       op.before_restore_op(),
151*14675a02SAndroid Build Coastguard Worker       op.has_saver_def() ? op.saver_def().restore_op_name() : "",
152*14675a02SAndroid Build Coastguard Worker       op.after_restore_op());
153*14675a02SAndroid Build Coastguard Worker   FCP_TRY(RunOp(op.before_restore_op()));
154*14675a02SAndroid Build Coastguard Worker   Result<Unit> res = Unit{};
155*14675a02SAndroid Build Coastguard Worker   if (op.has_saver_def()) {
156*14675a02SAndroid Build Coastguard Worker     const tensorflow::SaverDef& def = op.saver_def();
157*14675a02SAndroid Build Coastguard Worker     std::string tmp_file_name = GetTmpCheckpointFileName("restore_checkpoint");
158*14675a02SAndroid Build Coastguard Worker     res = Result(fcp::WriteCordToFile(tmp_file_name, checkpoint))
159*14675a02SAndroid Build Coastguard Worker               .Then(ExpectOk())
160*14675a02SAndroid Build Coastguard Worker               .Then([this, &def, &tmp_file_name](Unit u) -> Result<Unit> {
161*14675a02SAndroid Build Coastguard Worker                 return RunOp({{def.filename_tensor_name(),
162*14675a02SAndroid Build Coastguard Worker                                tensorflow::Tensor(tmp_file_name)}},
163*14675a02SAndroid Build Coastguard Worker                              def.restore_op_name());
164*14675a02SAndroid Build Coastguard Worker               });
165*14675a02SAndroid Build Coastguard Worker     DeleteTmpFile(tmp_file_name);
166*14675a02SAndroid Build Coastguard Worker   }
167*14675a02SAndroid Build Coastguard Worker   FCP_TRY(RunOp(op.after_restore_op()));
168*14675a02SAndroid Build Coastguard Worker   return res;
169*14675a02SAndroid Build Coastguard Worker }
170*14675a02SAndroid Build Coastguard Worker 
RestoreState(const CheckpointOp & op,const NamedTensorList & restore_inputs)171*14675a02SAndroid Build Coastguard Worker Result<Unit> TfSession::RestoreState(const CheckpointOp& op,
172*14675a02SAndroid Build Coastguard Worker                                      const NamedTensorList& restore_inputs) {
173*14675a02SAndroid Build Coastguard Worker   FCP_TRY(Ready());
174*14675a02SAndroid Build Coastguard Worker   TracingSpan<RestoreFromTensors> span(op.before_restore_op(),
175*14675a02SAndroid Build Coastguard Worker                                        op.after_restore_op());
176*14675a02SAndroid Build Coastguard Worker   if (op.has_saver_def()) {
177*14675a02SAndroid Build Coastguard Worker     return TraceError<InvalidCheckpointOp>(
178*14675a02SAndroid Build Coastguard Worker         "saver_def",
179*14675a02SAndroid Build Coastguard Worker         "Cannot call RestoreState with a list of named tensors with a "
180*14675a02SAndroid Build Coastguard Worker         "checkpoint op containing a SaverDef.");
181*14675a02SAndroid Build Coastguard Worker   }
182*14675a02SAndroid Build Coastguard Worker   FCP_TRY(RunOp(restore_inputs, op.before_restore_op()));
183*14675a02SAndroid Build Coastguard Worker   return RunOp(op.after_restore_op());
184*14675a02SAndroid Build Coastguard Worker }
185*14675a02SAndroid Build Coastguard Worker 
GetTmpCheckpointFileName(absl::string_view name)186*14675a02SAndroid Build Coastguard Worker std::string TfSession::GetTmpCheckpointFileName(absl::string_view name) {
187*14675a02SAndroid Build Coastguard Worker   return ConcatPath(
188*14675a02SAndroid Build Coastguard Worker       tmp_dir_, absl::StrCat(name, ProcessUniqueId::Next().value(), ".ckp"));
189*14675a02SAndroid Build Coastguard Worker }
190*14675a02SAndroid Build Coastguard Worker 
191*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
192