/*
 * Copyright 2020 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*14675a02SAndroid Build Coastguard Worker
17*14675a02SAndroid Build Coastguard Worker #include "fcp/tensorflow/tf_session.h"
18*14675a02SAndroid Build Coastguard Worker
19*14675a02SAndroid Build Coastguard Worker #include <cstdio>
20*14675a02SAndroid Build Coastguard Worker #include <fstream>
21*14675a02SAndroid Build Coastguard Worker #include <iostream>
22*14675a02SAndroid Build Coastguard Worker #include <string>
23*14675a02SAndroid Build Coastguard Worker #include <utility>
24*14675a02SAndroid Build Coastguard Worker
25*14675a02SAndroid Build Coastguard Worker #include "absl/strings/cord.h"
26*14675a02SAndroid Build Coastguard Worker #include "fcp/base/platform.h"
27*14675a02SAndroid Build Coastguard Worker #include "fcp/base/process_unique_id.h"
28*14675a02SAndroid Build Coastguard Worker #include "fcp/base/result.h"
29*14675a02SAndroid Build Coastguard Worker #include "fcp/tensorflow/status.h"
30*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/protobuf/saver.pb.h"
31*14675a02SAndroid Build Coastguard Worker
32*14675a02SAndroid Build Coastguard Worker namespace fcp {
33*14675a02SAndroid Build Coastguard Worker
34*14675a02SAndroid Build Coastguard Worker #define TF_STATUS_EXPECT_OK(tf_status) \
35*14675a02SAndroid Build Coastguard Worker Result(ConvertFromTensorFlowStatus(tf_status)).Then(ExpectOk())
36*14675a02SAndroid Build Coastguard Worker
37*14675a02SAndroid Build Coastguard Worker using CheckpointOp = google::internal::federated::plan::CheckpointOp;
38*14675a02SAndroid Build Coastguard Worker
TfSession(const std::filesystem::path & tmp_dir,const absl::Cord & graph)39*14675a02SAndroid Build Coastguard Worker TfSession::TfSession(const std::filesystem::path& tmp_dir,
40*14675a02SAndroid Build Coastguard Worker const absl::Cord& graph)
41*14675a02SAndroid Build Coastguard Worker : tmp_dir_(StripTrailingPathSeparator(tmp_dir.c_str())),
42*14675a02SAndroid Build Coastguard Worker session_(tensorflow::NewSession(tensorflow::SessionOptions{})) {
43*14675a02SAndroid Build Coastguard Worker // Parse GraphDef.
44*14675a02SAndroid Build Coastguard Worker tensorflow::GraphDef graph_def;
45*14675a02SAndroid Build Coastguard Worker // TODO(team): Replace with ParseFromCord (check if it is available).
46*14675a02SAndroid Build Coastguard Worker std::string graph_str;
47*14675a02SAndroid Build Coastguard Worker absl::CopyCordToString(graph, &graph_str);
48*14675a02SAndroid Build Coastguard Worker if (!graph_def.ParseFromString(graph_str)) {
49*14675a02SAndroid Build Coastguard Worker session_status_ = FCP_STATUS(INVALID_ARGUMENT)
50*14675a02SAndroid Build Coastguard Worker << "Could not parse GraphDef.";
51*14675a02SAndroid Build Coastguard Worker return;
52*14675a02SAndroid Build Coastguard Worker }
53*14675a02SAndroid Build Coastguard Worker session_status_ = ConvertFromTensorFlowStatus(session_->Create(graph_def));
54*14675a02SAndroid Build Coastguard Worker }
55*14675a02SAndroid Build Coastguard Worker
TfSession(const std::filesystem::path & tmp_dir,absl::string_view graph)56*14675a02SAndroid Build Coastguard Worker TfSession::TfSession(const std::filesystem::path& tmp_dir,
57*14675a02SAndroid Build Coastguard Worker absl::string_view graph)
58*14675a02SAndroid Build Coastguard Worker : TfSession(tmp_dir, absl::Cord(graph)) {}
59*14675a02SAndroid Build Coastguard Worker
Ready()60*14675a02SAndroid Build Coastguard Worker Result<Unit> TfSession::Ready() {
61*14675a02SAndroid Build Coastguard Worker return Result(session_status_).Then(ExpectOk());
62*14675a02SAndroid Build Coastguard Worker }
63*14675a02SAndroid Build Coastguard Worker
RunOp(absl::string_view op)64*14675a02SAndroid Build Coastguard Worker Result<Unit> TfSession::RunOp(absl::string_view op) {
65*14675a02SAndroid Build Coastguard Worker FCP_TRY(Ready());
66*14675a02SAndroid Build Coastguard Worker if (op.empty()) {
67*14675a02SAndroid Build Coastguard Worker return Unit{};
68*14675a02SAndroid Build Coastguard Worker }
69*14675a02SAndroid Build Coastguard Worker TracingSpan<RunTfOp> span(op);
70*14675a02SAndroid Build Coastguard Worker std::vector<std::string> target_node_names;
71*14675a02SAndroid Build Coastguard Worker target_node_names.emplace_back(op);
72*14675a02SAndroid Build Coastguard Worker FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(
73*14675a02SAndroid Build Coastguard Worker /*inputs=*/{},
74*14675a02SAndroid Build Coastguard Worker /*output_tensor_names=*/{}, target_node_names,
75*14675a02SAndroid Build Coastguard Worker /*outputs=*/nullptr)));
76*14675a02SAndroid Build Coastguard Worker return Unit{};
77*14675a02SAndroid Build Coastguard Worker }
78*14675a02SAndroid Build Coastguard Worker
RunOp(const NamedTensorList & inputs,absl::string_view op)79*14675a02SAndroid Build Coastguard Worker Result<Unit> TfSession::RunOp(const NamedTensorList& inputs,
80*14675a02SAndroid Build Coastguard Worker absl::string_view op) {
81*14675a02SAndroid Build Coastguard Worker FCP_TRY(Ready());
82*14675a02SAndroid Build Coastguard Worker if (op.empty()) {
83*14675a02SAndroid Build Coastguard Worker return Unit{};
84*14675a02SAndroid Build Coastguard Worker }
85*14675a02SAndroid Build Coastguard Worker std::vector<std::string> target_node_names;
86*14675a02SAndroid Build Coastguard Worker target_node_names.emplace_back(op);
87*14675a02SAndroid Build Coastguard Worker FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(inputs,
88*14675a02SAndroid Build Coastguard Worker /*output_tensor_names=*/{},
89*14675a02SAndroid Build Coastguard Worker target_node_names,
90*14675a02SAndroid Build Coastguard Worker /*outputs=*/nullptr)));
91*14675a02SAndroid Build Coastguard Worker return Unit{};
92*14675a02SAndroid Build Coastguard Worker }
93*14675a02SAndroid Build Coastguard Worker
GetOutputs(std::unique_ptr<std::vector<std::string>> output_names)94*14675a02SAndroid Build Coastguard Worker Result<std::unique_ptr<TfSession::NamedTensorMap>> TfSession::GetOutputs(
95*14675a02SAndroid Build Coastguard Worker std::unique_ptr<std::vector<std::string>> output_names) {
96*14675a02SAndroid Build Coastguard Worker FCP_TRY(Ready());
97*14675a02SAndroid Build Coastguard Worker auto outputs = std::make_unique<TfSession::NamedTensorMap>();
98*14675a02SAndroid Build Coastguard Worker if (output_names->empty()) {
99*14675a02SAndroid Build Coastguard Worker return std::move(outputs);
100*14675a02SAndroid Build Coastguard Worker }
101*14675a02SAndroid Build Coastguard Worker std::vector<tensorflow::Tensor> output_list;
102*14675a02SAndroid Build Coastguard Worker FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(
103*14675a02SAndroid Build Coastguard Worker /*inputs=*/{}, *output_names,
104*14675a02SAndroid Build Coastguard Worker /*target_tensor_names=*/{}, &output_list)));
105*14675a02SAndroid Build Coastguard Worker FCP_CHECK(output_names->size() == output_list.size());
106*14675a02SAndroid Build Coastguard Worker for (int i = 0; i < output_names->size(); i++) {
107*14675a02SAndroid Build Coastguard Worker outputs->emplace(std::move((*output_names)[i]), std::move(output_list[i]));
108*14675a02SAndroid Build Coastguard Worker }
109*14675a02SAndroid Build Coastguard Worker return std::move(outputs);
110*14675a02SAndroid Build Coastguard Worker }
111*14675a02SAndroid Build Coastguard Worker
DeleteTmpFile(const std::string & tmp_file_name)112*14675a02SAndroid Build Coastguard Worker void DeleteTmpFile(const std::string& tmp_file_name) {
113*14675a02SAndroid Build Coastguard Worker if (std::remove(tmp_file_name.c_str()) > 0) {
114*14675a02SAndroid Build Coastguard Worker Trace<TmpFileNotDeleted>(tmp_file_name);
115*14675a02SAndroid Build Coastguard Worker }
116*14675a02SAndroid Build Coastguard Worker }
117*14675a02SAndroid Build Coastguard Worker
SaveState(const CheckpointOp & op)118*14675a02SAndroid Build Coastguard Worker Result<absl::Cord> TfSession::SaveState(const CheckpointOp& op) {
119*14675a02SAndroid Build Coastguard Worker FCP_TRY(Ready());
120*14675a02SAndroid Build Coastguard Worker TracingSpan<SaveToCheckpoint> span(
121*14675a02SAndroid Build Coastguard Worker op.before_save_op(),
122*14675a02SAndroid Build Coastguard Worker op.has_saver_def() ? op.saver_def().save_tensor_name() : "",
123*14675a02SAndroid Build Coastguard Worker op.after_save_op());
124*14675a02SAndroid Build Coastguard Worker FCP_TRY(RunOp(op.before_save_op()));
125*14675a02SAndroid Build Coastguard Worker Result<absl::Cord> res = absl::Cord("");
126*14675a02SAndroid Build Coastguard Worker if (op.has_saver_def()) {
127*14675a02SAndroid Build Coastguard Worker const tensorflow::SaverDef& def = op.saver_def();
128*14675a02SAndroid Build Coastguard Worker absl::string_view save_op = def.save_tensor_name();
129*14675a02SAndroid Build Coastguard Worker // TODO(team): Workaround due to difference between python and c++
130*14675a02SAndroid Build Coastguard Worker // TensorFlow APIs.
131*14675a02SAndroid Build Coastguard Worker save_op = absl::StripSuffix(save_op, ":0");
132*14675a02SAndroid Build Coastguard Worker std::string tmp_file_name = GetTmpCheckpointFileName("save_checkpoint");
133*14675a02SAndroid Build Coastguard Worker res =
134*14675a02SAndroid Build Coastguard Worker RunOp({{def.filename_tensor_name(), tensorflow::Tensor(tmp_file_name)}},
135*14675a02SAndroid Build Coastguard Worker save_op)
136*14675a02SAndroid Build Coastguard Worker .Then([&tmp_file_name](Unit u) -> Result<StatusOr<absl::Cord>> {
137*14675a02SAndroid Build Coastguard Worker return Result(fcp::ReadFileToCord(tmp_file_name));
138*14675a02SAndroid Build Coastguard Worker })
139*14675a02SAndroid Build Coastguard Worker .Then(ExpectOk());
140*14675a02SAndroid Build Coastguard Worker DeleteTmpFile(tmp_file_name);
141*14675a02SAndroid Build Coastguard Worker }
142*14675a02SAndroid Build Coastguard Worker FCP_TRY(RunOp(op.after_save_op()));
143*14675a02SAndroid Build Coastguard Worker return res;
144*14675a02SAndroid Build Coastguard Worker }
145*14675a02SAndroid Build Coastguard Worker
RestoreState(const CheckpointOp & op,const absl::Cord & checkpoint)146*14675a02SAndroid Build Coastguard Worker Result<Unit> TfSession::RestoreState(const CheckpointOp& op,
147*14675a02SAndroid Build Coastguard Worker const absl::Cord& checkpoint) {
148*14675a02SAndroid Build Coastguard Worker FCP_TRY(Ready());
149*14675a02SAndroid Build Coastguard Worker TracingSpan<RestoreFromCheckpoint> span(
150*14675a02SAndroid Build Coastguard Worker op.before_restore_op(),
151*14675a02SAndroid Build Coastguard Worker op.has_saver_def() ? op.saver_def().restore_op_name() : "",
152*14675a02SAndroid Build Coastguard Worker op.after_restore_op());
153*14675a02SAndroid Build Coastguard Worker FCP_TRY(RunOp(op.before_restore_op()));
154*14675a02SAndroid Build Coastguard Worker Result<Unit> res = Unit{};
155*14675a02SAndroid Build Coastguard Worker if (op.has_saver_def()) {
156*14675a02SAndroid Build Coastguard Worker const tensorflow::SaverDef& def = op.saver_def();
157*14675a02SAndroid Build Coastguard Worker std::string tmp_file_name = GetTmpCheckpointFileName("restore_checkpoint");
158*14675a02SAndroid Build Coastguard Worker res = Result(fcp::WriteCordToFile(tmp_file_name, checkpoint))
159*14675a02SAndroid Build Coastguard Worker .Then(ExpectOk())
160*14675a02SAndroid Build Coastguard Worker .Then([this, &def, &tmp_file_name](Unit u) -> Result<Unit> {
161*14675a02SAndroid Build Coastguard Worker return RunOp({{def.filename_tensor_name(),
162*14675a02SAndroid Build Coastguard Worker tensorflow::Tensor(tmp_file_name)}},
163*14675a02SAndroid Build Coastguard Worker def.restore_op_name());
164*14675a02SAndroid Build Coastguard Worker });
165*14675a02SAndroid Build Coastguard Worker DeleteTmpFile(tmp_file_name);
166*14675a02SAndroid Build Coastguard Worker }
167*14675a02SAndroid Build Coastguard Worker FCP_TRY(RunOp(op.after_restore_op()));
168*14675a02SAndroid Build Coastguard Worker return res;
169*14675a02SAndroid Build Coastguard Worker }
170*14675a02SAndroid Build Coastguard Worker
RestoreState(const CheckpointOp & op,const NamedTensorList & restore_inputs)171*14675a02SAndroid Build Coastguard Worker Result<Unit> TfSession::RestoreState(const CheckpointOp& op,
172*14675a02SAndroid Build Coastguard Worker const NamedTensorList& restore_inputs) {
173*14675a02SAndroid Build Coastguard Worker FCP_TRY(Ready());
174*14675a02SAndroid Build Coastguard Worker TracingSpan<RestoreFromTensors> span(op.before_restore_op(),
175*14675a02SAndroid Build Coastguard Worker op.after_restore_op());
176*14675a02SAndroid Build Coastguard Worker if (op.has_saver_def()) {
177*14675a02SAndroid Build Coastguard Worker return TraceError<InvalidCheckpointOp>(
178*14675a02SAndroid Build Coastguard Worker "saver_def",
179*14675a02SAndroid Build Coastguard Worker "Cannot call RestoreState with a list of named tensors with a "
180*14675a02SAndroid Build Coastguard Worker "checkpoint op containing a SaverDef.");
181*14675a02SAndroid Build Coastguard Worker }
182*14675a02SAndroid Build Coastguard Worker FCP_TRY(RunOp(restore_inputs, op.before_restore_op()));
183*14675a02SAndroid Build Coastguard Worker return RunOp(op.after_restore_op());
184*14675a02SAndroid Build Coastguard Worker }
185*14675a02SAndroid Build Coastguard Worker
GetTmpCheckpointFileName(absl::string_view name)186*14675a02SAndroid Build Coastguard Worker std::string TfSession::GetTmpCheckpointFileName(absl::string_view name) {
187*14675a02SAndroid Build Coastguard Worker return ConcatPath(
188*14675a02SAndroid Build Coastguard Worker tmp_dir_, absl::StrCat(name, ProcessUniqueId::Next().value(), ".ckp"));
189*14675a02SAndroid Build Coastguard Worker }
190*14675a02SAndroid Build Coastguard Worker
191*14675a02SAndroid Build Coastguard Worker } // namespace fcp
192