1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker * Copyright 2019 Google LLC
3*14675a02SAndroid Build Coastguard Worker *
4*14675a02SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker *
8*14675a02SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker *
10*14675a02SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker */
16*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/tf_wrapper.h"
17*14675a02SAndroid Build Coastguard Worker
18*14675a02SAndroid Build Coastguard Worker #include <functional>
19*14675a02SAndroid Build Coastguard Worker #include <memory>
20*14675a02SAndroid Build Coastguard Worker #include <string>
21*14675a02SAndroid Build Coastguard Worker #include <utility>
22*14675a02SAndroid Build Coastguard Worker
23*14675a02SAndroid Build Coastguard Worker #include "google/protobuf/any.pb.h"
24*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
25*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h"
26*14675a02SAndroid Build Coastguard Worker #include "fcp/client/diag_codes.pb.h"
27*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/plan_engine_helpers.h"
28*14675a02SAndroid Build Coastguard Worker #include "fcp/client/interruptible_runner.h"
29*14675a02SAndroid Build Coastguard Worker
30*14675a02SAndroid Build Coastguard Worker namespace fcp {
31*14675a02SAndroid Build Coastguard Worker namespace client {
32*14675a02SAndroid Build Coastguard Worker namespace engine {
33*14675a02SAndroid Build Coastguard Worker
34*14675a02SAndroid Build Coastguard Worker using ::google::protobuf::Any;
35*14675a02SAndroid Build Coastguard Worker
36*14675a02SAndroid Build Coastguard Worker // If `external_config_proto` contains a non-empty config proto, use that.
37*14675a02SAndroid Build Coastguard Worker // Otherwise initializes a config proto from a set of defaults.
38*14675a02SAndroid Build Coastguard Worker absl::StatusOr<tensorflow::ConfigProto>
InitializeConfigProto(const Any & external_config_proto)39*14675a02SAndroid Build Coastguard Worker TensorFlowWrapper::InitializeConfigProto(const Any& external_config_proto) {
40*14675a02SAndroid Build Coastguard Worker // Previously, we specified a hardcoded set of options in the ConfigProto by
41*14675a02SAndroid Build Coastguard Worker // default. However, if a non-empty ConfigProto is now provided as a
42*14675a02SAndroid Build Coastguard Worker // parameter, then we should use it as-is, without overriding any of the
43*14675a02SAndroid Build Coastguard Worker // options (otherwise we prevent the caller from having control over the
44*14675a02SAndroid Build Coastguard Worker // parameters we set by default).
45*14675a02SAndroid Build Coastguard Worker if (external_config_proto.ByteSizeLong() > 0) {
46*14675a02SAndroid Build Coastguard Worker // Unpack the external_config_proto parameter if one is provided. In this
47*14675a02SAndroid Build Coastguard Worker // case it must be a packed ConfigProto (anything else is an error).
48*14675a02SAndroid Build Coastguard Worker // Accordingly, UnpackTo will return false if parsing fails or if the Any is
49*14675a02SAndroid Build Coastguard Worker // not of a compatible type.
50*14675a02SAndroid Build Coastguard Worker tensorflow::ConfigProto unpacked_config_proto;
51*14675a02SAndroid Build Coastguard Worker if (!external_config_proto.UnpackTo(&unpacked_config_proto)) {
52*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError("Could not parse ConfigProto.");
53*14675a02SAndroid Build Coastguard Worker }
54*14675a02SAndroid Build Coastguard Worker if (unpacked_config_proto.ByteSizeLong() > 0) {
55*14675a02SAndroid Build Coastguard Worker // The caller-provided, unpacked ConfigProto was not empty, so we use it
56*14675a02SAndroid Build Coastguard Worker // in the SessionOptions and we do not specify our default config options
57*14675a02SAndroid Build Coastguard Worker // anymore.
58*14675a02SAndroid Build Coastguard Worker return unpacked_config_proto;
59*14675a02SAndroid Build Coastguard Worker }
60*14675a02SAndroid Build Coastguard Worker // We purposely fall through to the next block if the unpacked_config_proto
61*14675a02SAndroid Build Coastguard Worker // was empty.
62*14675a02SAndroid Build Coastguard Worker }
63*14675a02SAndroid Build Coastguard Worker
64*14675a02SAndroid Build Coastguard Worker // Only if the provided ConfigProto was empty (or if none was provided) do we
65*14675a02SAndroid Build Coastguard Worker // still set hardcoded options (this is our "old" behavior, equivalent to what
66*14675a02SAndroid Build Coastguard Worker // we did before we supported caller-specified ConfigProtos).
67*14675a02SAndroid Build Coastguard Worker //
68*14675a02SAndroid Build Coastguard Worker // WARNING: If the need for tuning configuration options further arises again
69*14675a02SAndroid Build Coastguard Worker // in the future, we ideally shouldn't update any of the hardcoded ConfigProto
70*14675a02SAndroid Build Coastguard Worker // values here anymore. Instead, we should expect our callers to specify any
71*14675a02SAndroid Build Coastguard Worker // ConfigProto values they want to use. We only maintain this block of code
72*14675a02SAndroid Build Coastguard Worker // for compatibility with callers that don't provide any ConfigProto at all
73*14675a02SAndroid Build Coastguard Worker // (yet).
74*14675a02SAndroid Build Coastguard Worker //
75*14675a02SAndroid Build Coastguard Worker tensorflow::ConfigProto config_proto;
76*14675a02SAndroid Build Coastguard Worker config_proto.mutable_graph_options()->set_place_pruned_graph(true);
77*14675a02SAndroid Build Coastguard Worker auto mutable_experimental = config_proto.mutable_experimental();
78*14675a02SAndroid Build Coastguard Worker mutable_experimental->set_optimize_for_static_graph(true);
79*14675a02SAndroid Build Coastguard Worker mutable_experimental->set_disable_output_partition_graphs(true);
80*14675a02SAndroid Build Coastguard Worker return config_proto;
81*14675a02SAndroid Build Coastguard Worker }
82*14675a02SAndroid Build Coastguard Worker
Create(const std::string & graph,const Any & config_proto,std::function<bool ()> should_abort,const InterruptibleRunner::TimingConfig & timing_config,LogManager * log_manager)83*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::unique_ptr<TensorFlowWrapper>> TensorFlowWrapper::Create(
84*14675a02SAndroid Build Coastguard Worker const std::string& graph, const Any& config_proto,
85*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort,
86*14675a02SAndroid Build Coastguard Worker const InterruptibleRunner::TimingConfig& timing_config,
87*14675a02SAndroid Build Coastguard Worker LogManager* log_manager) {
88*14675a02SAndroid Build Coastguard Worker // Create a tensorflow::Session.
89*14675a02SAndroid Build Coastguard Worker tensorflow::Session* session_ptr;
90*14675a02SAndroid Build Coastguard Worker std::unique_ptr<tensorflow::Session> session;
91*14675a02SAndroid Build Coastguard Worker tensorflow::SessionOptions session_options;
92*14675a02SAndroid Build Coastguard Worker FCP_ASSIGN_OR_RETURN(session_options.config,
93*14675a02SAndroid Build Coastguard Worker InitializeConfigProto(config_proto));
94*14675a02SAndroid Build Coastguard Worker
95*14675a02SAndroid Build Coastguard Worker tensorflow::Status status =
96*14675a02SAndroid Build Coastguard Worker tensorflow::NewSession(session_options, &session_ptr);
97*14675a02SAndroid Build Coastguard Worker if (!status.ok()) {
98*14675a02SAndroid Build Coastguard Worker return ToFcpStatus(status, "Error in tensorflow::NewSession()");
99*14675a02SAndroid Build Coastguard Worker }
100*14675a02SAndroid Build Coastguard Worker session = absl::WrapUnique(session_ptr);
101*14675a02SAndroid Build Coastguard Worker
102*14675a02SAndroid Build Coastguard Worker // Parse GraphDef.
103*14675a02SAndroid Build Coastguard Worker tensorflow::GraphDef graph_def;
104*14675a02SAndroid Build Coastguard Worker bool parse_result = graph_def.ParseFromString(graph);
105*14675a02SAndroid Build Coastguard Worker if (parse_result == false) {
106*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError("Could not parse GraphDef.");
107*14675a02SAndroid Build Coastguard Worker }
108*14675a02SAndroid Build Coastguard Worker // Load graph.
109*14675a02SAndroid Build Coastguard Worker status = session->Create(std::move(graph_def));
110*14675a02SAndroid Build Coastguard Worker if (!status.ok()) {
111*14675a02SAndroid Build Coastguard Worker return ToFcpStatus(status, "Error in Session::Create()");
112*14675a02SAndroid Build Coastguard Worker }
113*14675a02SAndroid Build Coastguard Worker
114*14675a02SAndroid Build Coastguard Worker // Create an InterruptibleRunner to execute TF calls in a background thread,
115*14675a02SAndroid Build Coastguard Worker // allowing us to abort them if need be.
116*14675a02SAndroid Build Coastguard Worker auto interruptible_runner = std::make_unique<InterruptibleRunner>(
117*14675a02SAndroid Build Coastguard Worker log_manager, should_abort, timing_config,
118*14675a02SAndroid Build Coastguard Worker InterruptibleRunner::DiagnosticsConfig{
119*14675a02SAndroid Build Coastguard Worker .interrupted =
120*14675a02SAndroid Build Coastguard Worker ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION,
121*14675a02SAndroid Build Coastguard Worker .interrupt_timeout = ProdDiagCode::
122*14675a02SAndroid Build Coastguard Worker BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT,
123*14675a02SAndroid Build Coastguard Worker .interrupted_extended = ProdDiagCode::
124*14675a02SAndroid Build Coastguard Worker BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED,
125*14675a02SAndroid Build Coastguard Worker .interrupt_timeout_extended = ProdDiagCode::
126*14675a02SAndroid Build Coastguard Worker BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT});
127*14675a02SAndroid Build Coastguard Worker auto wrapper = absl::WrapUnique(new TensorFlowWrapper(
128*14675a02SAndroid Build Coastguard Worker std::move(session), std::move(interruptible_runner), log_manager));
129*14675a02SAndroid Build Coastguard Worker return wrapper;
130*14675a02SAndroid Build Coastguard Worker }
131*14675a02SAndroid Build Coastguard Worker
~TensorFlowWrapper()132*14675a02SAndroid Build Coastguard Worker TensorFlowWrapper::~TensorFlowWrapper() { FCP_CHECK(CloseAndRelease().ok()); }
133*14675a02SAndroid Build Coastguard Worker
ToFcpStatus(tensorflow::Status s,const std::string & message_prefix)134*14675a02SAndroid Build Coastguard Worker absl::Status TensorFlowWrapper::ToFcpStatus(tensorflow::Status s,
135*14675a02SAndroid Build Coastguard Worker const std::string& message_prefix) {
136*14675a02SAndroid Build Coastguard Worker if (s.ok()) {
137*14675a02SAndroid Build Coastguard Worker return absl::OkStatus();
138*14675a02SAndroid Build Coastguard Worker } else if (s.code() == tensorflow::error::OUT_OF_RANGE) {
139*14675a02SAndroid Build Coastguard Worker return absl::OutOfRangeError("");
140*14675a02SAndroid Build Coastguard Worker } else {
141*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError(
142*14675a02SAndroid Build Coastguard Worker absl::StrCat(message_prefix, ": ", s.ToString()));
143*14675a02SAndroid Build Coastguard Worker }
144*14675a02SAndroid Build Coastguard Worker }
145*14675a02SAndroid Build Coastguard Worker
Run(const std::vector<std::pair<std::string,tensorflow::Tensor>> & inputs,const std::vector<std::string> & output_tensor_names,const std::vector<std::string> & target_node_names,std::vector<tensorflow::Tensor> * outputs)146*14675a02SAndroid Build Coastguard Worker absl::Status TensorFlowWrapper::Run(
147*14675a02SAndroid Build Coastguard Worker const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs,
148*14675a02SAndroid Build Coastguard Worker const std::vector<std::string>& output_tensor_names,
149*14675a02SAndroid Build Coastguard Worker const std::vector<std::string>& target_node_names,
150*14675a02SAndroid Build Coastguard Worker std::vector<tensorflow::Tensor>* outputs) {
151*14675a02SAndroid Build Coastguard Worker FCP_CHECK(!session_closed_) << "Run() called after session close!";
152*14675a02SAndroid Build Coastguard Worker
153*14675a02SAndroid Build Coastguard Worker auto tensorflow_runnable = [&inputs, &output_tensor_names, &target_node_names,
154*14675a02SAndroid Build Coastguard Worker &outputs, this]() -> absl::Status {
155*14675a02SAndroid Build Coastguard Worker tensorflow::Status status = this->session_->Run(inputs, output_tensor_names,
156*14675a02SAndroid Build Coastguard Worker target_node_names, outputs);
157*14675a02SAndroid Build Coastguard Worker if (!status.ok()) {
158*14675a02SAndroid Build Coastguard Worker return ToFcpStatus(status, "Error in Session::Run()");
159*14675a02SAndroid Build Coastguard Worker }
160*14675a02SAndroid Build Coastguard Worker return absl::OkStatus();
161*14675a02SAndroid Build Coastguard Worker };
162*14675a02SAndroid Build Coastguard Worker auto abort_tensorflow = [this]() {
163*14675a02SAndroid Build Coastguard Worker absl::MutexLock _(&session_lock_);
164*14675a02SAndroid Build Coastguard Worker // Errors from Close() are expected when interrupting ongoing calls. We
165*14675a02SAndroid Build Coastguard Worker // don't call CloseAndRelease() here because that would free the TensorFlow
166*14675a02SAndroid Build Coastguard Worker // session while other TensorFlow worker threads may still be using it.
167*14675a02SAndroid Build Coastguard Worker session_->Close().IgnoreError();
168*14675a02SAndroid Build Coastguard Worker session_closed_ = true;
169*14675a02SAndroid Build Coastguard Worker };
170*14675a02SAndroid Build Coastguard Worker return interruptible_runner_->Run(tensorflow_runnable, abort_tensorflow);
171*14675a02SAndroid Build Coastguard Worker }
172*14675a02SAndroid Build Coastguard Worker
CloseAndRelease()173*14675a02SAndroid Build Coastguard Worker absl::Status TensorFlowWrapper::CloseAndRelease() {
174*14675a02SAndroid Build Coastguard Worker absl::MutexLock _(&session_lock_);
175*14675a02SAndroid Build Coastguard Worker // If the TensorFlow session hasn't been closed yet, close it.
176*14675a02SAndroid Build Coastguard Worker if (!session_closed_) {
177*14675a02SAndroid Build Coastguard Worker FCP_ENGINE_RETURN_IF_ERROR(
178*14675a02SAndroid Build Coastguard Worker ToFcpStatus(session_->Close(), "Could not close TF session"));
179*14675a02SAndroid Build Coastguard Worker session_closed_ = true;
180*14675a02SAndroid Build Coastguard Worker }
181*14675a02SAndroid Build Coastguard Worker // If the TensorflowSession hasn't been released yet, release it.
182*14675a02SAndroid Build Coastguard Worker if (session_) {
183*14675a02SAndroid Build Coastguard Worker session_.reset();
184*14675a02SAndroid Build Coastguard Worker }
185*14675a02SAndroid Build Coastguard Worker return absl::OkStatus();
186*14675a02SAndroid Build Coastguard Worker }
187*14675a02SAndroid Build Coastguard Worker
188*14675a02SAndroid Build Coastguard Worker } // namespace engine
189*14675a02SAndroid Build Coastguard Worker } // namespace client
190*14675a02SAndroid Build Coastguard Worker } // namespace fcp
191