xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/tf_session.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/tensorflow/tf_session.h"
18 
#include <cerrno>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
24 
25 #include "absl/strings/cord.h"
26 #include "fcp/base/platform.h"
27 #include "fcp/base/process_unique_id.h"
28 #include "fcp/base/result.h"
29 #include "fcp/tensorflow/status.h"
30 #include "tensorflow/core/protobuf/saver.pb.h"
31 
32 namespace fcp {
33 
34 #define TF_STATUS_EXPECT_OK(tf_status) \
35   Result(ConvertFromTensorFlowStatus(tf_status)).Then(ExpectOk())
36 
37 using CheckpointOp = google::internal::federated::plan::CheckpointOp;
38 
TfSession(const std::filesystem::path & tmp_dir,const absl::Cord & graph)39 TfSession::TfSession(const std::filesystem::path& tmp_dir,
40                      const absl::Cord& graph)
41     : tmp_dir_(StripTrailingPathSeparator(tmp_dir.c_str())),
42       session_(tensorflow::NewSession(tensorflow::SessionOptions{})) {
43   // Parse GraphDef.
44   tensorflow::GraphDef graph_def;
45   // TODO(team): Replace with ParseFromCord (check if it is available).
46   std::string graph_str;
47   absl::CopyCordToString(graph, &graph_str);
48   if (!graph_def.ParseFromString(graph_str)) {
49     session_status_ = FCP_STATUS(INVALID_ARGUMENT)
50                       << "Could not parse GraphDef.";
51     return;
52   }
53   session_status_ = ConvertFromTensorFlowStatus(session_->Create(graph_def));
54 }
55 
TfSession(const std::filesystem::path & tmp_dir,absl::string_view graph)56 TfSession::TfSession(const std::filesystem::path& tmp_dir,
57                      absl::string_view graph)
58     : TfSession(tmp_dir, absl::Cord(graph)) {}
59 
Ready()60 Result<Unit> TfSession::Ready() {
61   return Result(session_status_).Then(ExpectOk());
62 }
63 
RunOp(absl::string_view op)64 Result<Unit> TfSession::RunOp(absl::string_view op) {
65   FCP_TRY(Ready());
66   if (op.empty()) {
67     return Unit{};
68   }
69   TracingSpan<RunTfOp> span(op);
70   std::vector<std::string> target_node_names;
71   target_node_names.emplace_back(op);
72   FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(
73       /*inputs=*/{},
74       /*output_tensor_names=*/{}, target_node_names,
75       /*outputs=*/nullptr)));
76   return Unit{};
77 }
78 
RunOp(const NamedTensorList & inputs,absl::string_view op)79 Result<Unit> TfSession::RunOp(const NamedTensorList& inputs,
80                               absl::string_view op) {
81   FCP_TRY(Ready());
82   if (op.empty()) {
83     return Unit{};
84   }
85   std::vector<std::string> target_node_names;
86   target_node_names.emplace_back(op);
87   FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(inputs,
88                                             /*output_tensor_names=*/{},
89                                             target_node_names,
90                                             /*outputs=*/nullptr)));
91   return Unit{};
92 }
93 
GetOutputs(std::unique_ptr<std::vector<std::string>> output_names)94 Result<std::unique_ptr<TfSession::NamedTensorMap>> TfSession::GetOutputs(
95     std::unique_ptr<std::vector<std::string>> output_names) {
96   FCP_TRY(Ready());
97   auto outputs = std::make_unique<TfSession::NamedTensorMap>();
98   if (output_names->empty()) {
99     return std::move(outputs);
100   }
101   std::vector<tensorflow::Tensor> output_list;
102   FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(
103       /*inputs=*/{}, *output_names,
104       /*target_tensor_names=*/{}, &output_list)));
105   FCP_CHECK(output_names->size() == output_list.size());
106   for (int i = 0; i < output_names->size(); i++) {
107     outputs->emplace(std::move((*output_names)[i]), std::move(output_list[i]));
108   }
109   return std::move(outputs);
110 }
111 
DeleteTmpFile(const std::string & tmp_file_name)112 void DeleteTmpFile(const std::string& tmp_file_name) {
113   if (std::remove(tmp_file_name.c_str()) > 0) {
114     Trace<TmpFileNotDeleted>(tmp_file_name);
115   }
116 }
117 
SaveState(const CheckpointOp & op)118 Result<absl::Cord> TfSession::SaveState(const CheckpointOp& op) {
119   FCP_TRY(Ready());
120   TracingSpan<SaveToCheckpoint> span(
121       op.before_save_op(),
122       op.has_saver_def() ? op.saver_def().save_tensor_name() : "",
123       op.after_save_op());
124   FCP_TRY(RunOp(op.before_save_op()));
125   Result<absl::Cord> res = absl::Cord("");
126   if (op.has_saver_def()) {
127     const tensorflow::SaverDef& def = op.saver_def();
128     absl::string_view save_op = def.save_tensor_name();
129     // TODO(team): Workaround due to difference between python and c++
130     //  TensorFlow APIs.
131     save_op = absl::StripSuffix(save_op, ":0");
132     std::string tmp_file_name = GetTmpCheckpointFileName("save_checkpoint");
133     res =
134         RunOp({{def.filename_tensor_name(), tensorflow::Tensor(tmp_file_name)}},
135               save_op)
136             .Then([&tmp_file_name](Unit u) -> Result<StatusOr<absl::Cord>> {
137               return Result(fcp::ReadFileToCord(tmp_file_name));
138             })
139             .Then(ExpectOk());
140     DeleteTmpFile(tmp_file_name);
141   }
142   FCP_TRY(RunOp(op.after_save_op()));
143   return res;
144 }
145 
RestoreState(const CheckpointOp & op,const absl::Cord & checkpoint)146 Result<Unit> TfSession::RestoreState(const CheckpointOp& op,
147                                      const absl::Cord& checkpoint) {
148   FCP_TRY(Ready());
149   TracingSpan<RestoreFromCheckpoint> span(
150       op.before_restore_op(),
151       op.has_saver_def() ? op.saver_def().restore_op_name() : "",
152       op.after_restore_op());
153   FCP_TRY(RunOp(op.before_restore_op()));
154   Result<Unit> res = Unit{};
155   if (op.has_saver_def()) {
156     const tensorflow::SaverDef& def = op.saver_def();
157     std::string tmp_file_name = GetTmpCheckpointFileName("restore_checkpoint");
158     res = Result(fcp::WriteCordToFile(tmp_file_name, checkpoint))
159               .Then(ExpectOk())
160               .Then([this, &def, &tmp_file_name](Unit u) -> Result<Unit> {
161                 return RunOp({{def.filename_tensor_name(),
162                                tensorflow::Tensor(tmp_file_name)}},
163                              def.restore_op_name());
164               });
165     DeleteTmpFile(tmp_file_name);
166   }
167   FCP_TRY(RunOp(op.after_restore_op()));
168   return res;
169 }
170 
RestoreState(const CheckpointOp & op,const NamedTensorList & restore_inputs)171 Result<Unit> TfSession::RestoreState(const CheckpointOp& op,
172                                      const NamedTensorList& restore_inputs) {
173   FCP_TRY(Ready());
174   TracingSpan<RestoreFromTensors> span(op.before_restore_op(),
175                                        op.after_restore_op());
176   if (op.has_saver_def()) {
177     return TraceError<InvalidCheckpointOp>(
178         "saver_def",
179         "Cannot call RestoreState with a list of named tensors with a "
180         "checkpoint op containing a SaverDef.");
181   }
182   FCP_TRY(RunOp(restore_inputs, op.before_restore_op()));
183   return RunOp(op.after_restore_op());
184 }
185 
GetTmpCheckpointFileName(absl::string_view name)186 std::string TfSession::GetTmpCheckpointFileName(absl::string_view name) {
187   return ConcatPath(
188       tmp_dir_, absl::StrCat(name, ProcessUniqueId::Next().value(), ".ckp"));
189 }
190 
191 }  // namespace fcp
192