1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "fcp/tensorflow/tf_session.h"

#include <cerrno>
#include <cstddef>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/strip.h"
#include "fcp/base/platform.h"
#include "fcp/base/process_unique_id.h"
#include "fcp/base/result.h"
#include "fcp/tensorflow/status.h"
#include "tensorflow/core/protobuf/saver.pb.h"
31
32 namespace fcp {
33
// Wraps a tensorflow::Status-valued expression so it can be used with
// FCP_TRY(...): the status is translated with ConvertFromTensorFlowStatus,
// wrapped in a Result, and then passed through ExpectOk(), so the resulting
// Result is an error unless the TensorFlow status was OK.
#define TF_STATUS_EXPECT_OK(tf_status) \
  Result(ConvertFromTensorFlowStatus(tf_status)).Then(ExpectOk())

// Short alias for the plan proto message that names the before/after ops and
// the optional SaverDef used by SaveState/RestoreState below.
using CheckpointOp = google::internal::federated::plan::CheckpointOp;
38
TfSession(const std::filesystem::path & tmp_dir,const absl::Cord & graph)39 TfSession::TfSession(const std::filesystem::path& tmp_dir,
40 const absl::Cord& graph)
41 : tmp_dir_(StripTrailingPathSeparator(tmp_dir.c_str())),
42 session_(tensorflow::NewSession(tensorflow::SessionOptions{})) {
43 // Parse GraphDef.
44 tensorflow::GraphDef graph_def;
45 // TODO(team): Replace with ParseFromCord (check if it is available).
46 std::string graph_str;
47 absl::CopyCordToString(graph, &graph_str);
48 if (!graph_def.ParseFromString(graph_str)) {
49 session_status_ = FCP_STATUS(INVALID_ARGUMENT)
50 << "Could not parse GraphDef.";
51 return;
52 }
53 session_status_ = ConvertFromTensorFlowStatus(session_->Create(graph_def));
54 }
55
TfSession(const std::filesystem::path & tmp_dir,absl::string_view graph)56 TfSession::TfSession(const std::filesystem::path& tmp_dir,
57 absl::string_view graph)
58 : TfSession(tmp_dir, absl::Cord(graph)) {}
59
Ready()60 Result<Unit> TfSession::Ready() {
61 return Result(session_status_).Then(ExpectOk());
62 }
63
RunOp(absl::string_view op)64 Result<Unit> TfSession::RunOp(absl::string_view op) {
65 FCP_TRY(Ready());
66 if (op.empty()) {
67 return Unit{};
68 }
69 TracingSpan<RunTfOp> span(op);
70 std::vector<std::string> target_node_names;
71 target_node_names.emplace_back(op);
72 FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(
73 /*inputs=*/{},
74 /*output_tensor_names=*/{}, target_node_names,
75 /*outputs=*/nullptr)));
76 return Unit{};
77 }
78
RunOp(const NamedTensorList & inputs,absl::string_view op)79 Result<Unit> TfSession::RunOp(const NamedTensorList& inputs,
80 absl::string_view op) {
81 FCP_TRY(Ready());
82 if (op.empty()) {
83 return Unit{};
84 }
85 std::vector<std::string> target_node_names;
86 target_node_names.emplace_back(op);
87 FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(inputs,
88 /*output_tensor_names=*/{},
89 target_node_names,
90 /*outputs=*/nullptr)));
91 return Unit{};
92 }
93
GetOutputs(std::unique_ptr<std::vector<std::string>> output_names)94 Result<std::unique_ptr<TfSession::NamedTensorMap>> TfSession::GetOutputs(
95 std::unique_ptr<std::vector<std::string>> output_names) {
96 FCP_TRY(Ready());
97 auto outputs = std::make_unique<TfSession::NamedTensorMap>();
98 if (output_names->empty()) {
99 return std::move(outputs);
100 }
101 std::vector<tensorflow::Tensor> output_list;
102 FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(
103 /*inputs=*/{}, *output_names,
104 /*target_tensor_names=*/{}, &output_list)));
105 FCP_CHECK(output_names->size() == output_list.size());
106 for (int i = 0; i < output_names->size(); i++) {
107 outputs->emplace(std::move((*output_names)[i]), std::move(output_list[i]));
108 }
109 return std::move(outputs);
110 }
111
DeleteTmpFile(const std::string & tmp_file_name)112 void DeleteTmpFile(const std::string& tmp_file_name) {
113 if (std::remove(tmp_file_name.c_str()) > 0) {
114 Trace<TmpFileNotDeleted>(tmp_file_name);
115 }
116 }
117
SaveState(const CheckpointOp & op)118 Result<absl::Cord> TfSession::SaveState(const CheckpointOp& op) {
119 FCP_TRY(Ready());
120 TracingSpan<SaveToCheckpoint> span(
121 op.before_save_op(),
122 op.has_saver_def() ? op.saver_def().save_tensor_name() : "",
123 op.after_save_op());
124 FCP_TRY(RunOp(op.before_save_op()));
125 Result<absl::Cord> res = absl::Cord("");
126 if (op.has_saver_def()) {
127 const tensorflow::SaverDef& def = op.saver_def();
128 absl::string_view save_op = def.save_tensor_name();
129 // TODO(team): Workaround due to difference between python and c++
130 // TensorFlow APIs.
131 save_op = absl::StripSuffix(save_op, ":0");
132 std::string tmp_file_name = GetTmpCheckpointFileName("save_checkpoint");
133 res =
134 RunOp({{def.filename_tensor_name(), tensorflow::Tensor(tmp_file_name)}},
135 save_op)
136 .Then([&tmp_file_name](Unit u) -> Result<StatusOr<absl::Cord>> {
137 return Result(fcp::ReadFileToCord(tmp_file_name));
138 })
139 .Then(ExpectOk());
140 DeleteTmpFile(tmp_file_name);
141 }
142 FCP_TRY(RunOp(op.after_save_op()));
143 return res;
144 }
145
RestoreState(const CheckpointOp & op,const absl::Cord & checkpoint)146 Result<Unit> TfSession::RestoreState(const CheckpointOp& op,
147 const absl::Cord& checkpoint) {
148 FCP_TRY(Ready());
149 TracingSpan<RestoreFromCheckpoint> span(
150 op.before_restore_op(),
151 op.has_saver_def() ? op.saver_def().restore_op_name() : "",
152 op.after_restore_op());
153 FCP_TRY(RunOp(op.before_restore_op()));
154 Result<Unit> res = Unit{};
155 if (op.has_saver_def()) {
156 const tensorflow::SaverDef& def = op.saver_def();
157 std::string tmp_file_name = GetTmpCheckpointFileName("restore_checkpoint");
158 res = Result(fcp::WriteCordToFile(tmp_file_name, checkpoint))
159 .Then(ExpectOk())
160 .Then([this, &def, &tmp_file_name](Unit u) -> Result<Unit> {
161 return RunOp({{def.filename_tensor_name(),
162 tensorflow::Tensor(tmp_file_name)}},
163 def.restore_op_name());
164 });
165 DeleteTmpFile(tmp_file_name);
166 }
167 FCP_TRY(RunOp(op.after_restore_op()));
168 return res;
169 }
170
RestoreState(const CheckpointOp & op,const NamedTensorList & restore_inputs)171 Result<Unit> TfSession::RestoreState(const CheckpointOp& op,
172 const NamedTensorList& restore_inputs) {
173 FCP_TRY(Ready());
174 TracingSpan<RestoreFromTensors> span(op.before_restore_op(),
175 op.after_restore_op());
176 if (op.has_saver_def()) {
177 return TraceError<InvalidCheckpointOp>(
178 "saver_def",
179 "Cannot call RestoreState with a list of named tensors with a "
180 "checkpoint op containing a SaverDef.");
181 }
182 FCP_TRY(RunOp(restore_inputs, op.before_restore_op()));
183 return RunOp(op.after_restore_op());
184 }
185
GetTmpCheckpointFileName(absl::string_view name)186 std::string TfSession::GetTmpCheckpointFileName(absl::string_view name) {
187 return ConcatPath(
188 tmp_dir_, absl::StrCat(name, ProcessUniqueId::Next().value(), ".ckp"));
189 }
190
191 } // namespace fcp
192