xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/tf_wrapper.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
/*
 * Copyright 2019 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/tf_wrapper.h"
17*14675a02SAndroid Build Coastguard Worker 
18*14675a02SAndroid Build Coastguard Worker #include <functional>
19*14675a02SAndroid Build Coastguard Worker #include <memory>
20*14675a02SAndroid Build Coastguard Worker #include <string>
21*14675a02SAndroid Build Coastguard Worker #include <utility>
22*14675a02SAndroid Build Coastguard Worker 
23*14675a02SAndroid Build Coastguard Worker #include "google/protobuf/any.pb.h"
24*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
25*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h"
26*14675a02SAndroid Build Coastguard Worker #include "fcp/client/diag_codes.pb.h"
27*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/plan_engine_helpers.h"
28*14675a02SAndroid Build Coastguard Worker #include "fcp/client/interruptible_runner.h"
29*14675a02SAndroid Build Coastguard Worker 
30*14675a02SAndroid Build Coastguard Worker namespace fcp {
31*14675a02SAndroid Build Coastguard Worker namespace client {
32*14675a02SAndroid Build Coastguard Worker namespace engine {
33*14675a02SAndroid Build Coastguard Worker 
34*14675a02SAndroid Build Coastguard Worker using ::google::protobuf::Any;
35*14675a02SAndroid Build Coastguard Worker 
36*14675a02SAndroid Build Coastguard Worker // If `external_config_proto` contains a non-empty config proto, use that.
37*14675a02SAndroid Build Coastguard Worker // Otherwise initializes a config proto from a set of defaults.
38*14675a02SAndroid Build Coastguard Worker absl::StatusOr<tensorflow::ConfigProto>
InitializeConfigProto(const Any & external_config_proto)39*14675a02SAndroid Build Coastguard Worker TensorFlowWrapper::InitializeConfigProto(const Any& external_config_proto) {
40*14675a02SAndroid Build Coastguard Worker   // Previously, we specified a hardcoded set of options in the ConfigProto by
41*14675a02SAndroid Build Coastguard Worker   // default. However, if a non-empty ConfigProto is now provided as a
42*14675a02SAndroid Build Coastguard Worker   // parameter, then we should use it as-is, without overriding any of the
43*14675a02SAndroid Build Coastguard Worker   // options (otherwise we prevent the caller from having control over the
44*14675a02SAndroid Build Coastguard Worker   // parameters we set by default).
45*14675a02SAndroid Build Coastguard Worker   if (external_config_proto.ByteSizeLong() > 0) {
46*14675a02SAndroid Build Coastguard Worker     // Unpack the external_config_proto parameter if one is provided. In this
47*14675a02SAndroid Build Coastguard Worker     // case it must be a packed ConfigProto (anything else is an error).
48*14675a02SAndroid Build Coastguard Worker     // Accordingly, UnpackTo will return false if parsing fails or if the Any is
49*14675a02SAndroid Build Coastguard Worker     // not of a compatible type.
50*14675a02SAndroid Build Coastguard Worker     tensorflow::ConfigProto unpacked_config_proto;
51*14675a02SAndroid Build Coastguard Worker     if (!external_config_proto.UnpackTo(&unpacked_config_proto)) {
52*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError("Could not parse ConfigProto.");
53*14675a02SAndroid Build Coastguard Worker     }
54*14675a02SAndroid Build Coastguard Worker     if (unpacked_config_proto.ByteSizeLong() > 0) {
55*14675a02SAndroid Build Coastguard Worker       // The caller-provided, unpacked ConfigProto was not empty, so we use it
56*14675a02SAndroid Build Coastguard Worker       // in the SessionOptions and we do not specify our default config options
57*14675a02SAndroid Build Coastguard Worker       // anymore.
58*14675a02SAndroid Build Coastguard Worker       return unpacked_config_proto;
59*14675a02SAndroid Build Coastguard Worker     }
60*14675a02SAndroid Build Coastguard Worker     // We purposely fall through to the next block if the unpacked_config_proto
61*14675a02SAndroid Build Coastguard Worker     // was empty.
62*14675a02SAndroid Build Coastguard Worker   }
63*14675a02SAndroid Build Coastguard Worker 
64*14675a02SAndroid Build Coastguard Worker   // Only if the provided ConfigProto was empty (or if none was provided) do we
65*14675a02SAndroid Build Coastguard Worker   // still set hardcoded options (this is our "old" behavior, equivalent to what
66*14675a02SAndroid Build Coastguard Worker   // we did before we supported caller-specified ConfigProtos).
67*14675a02SAndroid Build Coastguard Worker   //
68*14675a02SAndroid Build Coastguard Worker   // WARNING: If the need for tuning configuration options further arises again
69*14675a02SAndroid Build Coastguard Worker   // in the future, we ideally shouldn't update any of the hardcoded ConfigProto
70*14675a02SAndroid Build Coastguard Worker   // values here anymore. Instead, we should expect our callers to specify any
71*14675a02SAndroid Build Coastguard Worker   // ConfigProto values they want to use. We only maintain this block of code
72*14675a02SAndroid Build Coastguard Worker   // for compatibility with callers that don't provide any ConfigProto at all
73*14675a02SAndroid Build Coastguard Worker   // (yet).
74*14675a02SAndroid Build Coastguard Worker   //
75*14675a02SAndroid Build Coastguard Worker   tensorflow::ConfigProto config_proto;
76*14675a02SAndroid Build Coastguard Worker   config_proto.mutable_graph_options()->set_place_pruned_graph(true);
77*14675a02SAndroid Build Coastguard Worker   auto mutable_experimental = config_proto.mutable_experimental();
78*14675a02SAndroid Build Coastguard Worker   mutable_experimental->set_optimize_for_static_graph(true);
79*14675a02SAndroid Build Coastguard Worker   mutable_experimental->set_disable_output_partition_graphs(true);
80*14675a02SAndroid Build Coastguard Worker   return config_proto;
81*14675a02SAndroid Build Coastguard Worker }
82*14675a02SAndroid Build Coastguard Worker 
Create(const std::string & graph,const Any & config_proto,std::function<bool ()> should_abort,const InterruptibleRunner::TimingConfig & timing_config,LogManager * log_manager)83*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::unique_ptr<TensorFlowWrapper>> TensorFlowWrapper::Create(
84*14675a02SAndroid Build Coastguard Worker     const std::string& graph, const Any& config_proto,
85*14675a02SAndroid Build Coastguard Worker     std::function<bool()> should_abort,
86*14675a02SAndroid Build Coastguard Worker     const InterruptibleRunner::TimingConfig& timing_config,
87*14675a02SAndroid Build Coastguard Worker     LogManager* log_manager) {
88*14675a02SAndroid Build Coastguard Worker   // Create a tensorflow::Session.
89*14675a02SAndroid Build Coastguard Worker   tensorflow::Session* session_ptr;
90*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<tensorflow::Session> session;
91*14675a02SAndroid Build Coastguard Worker   tensorflow::SessionOptions session_options;
92*14675a02SAndroid Build Coastguard Worker   FCP_ASSIGN_OR_RETURN(session_options.config,
93*14675a02SAndroid Build Coastguard Worker                        InitializeConfigProto(config_proto));
94*14675a02SAndroid Build Coastguard Worker 
95*14675a02SAndroid Build Coastguard Worker   tensorflow::Status status =
96*14675a02SAndroid Build Coastguard Worker       tensorflow::NewSession(session_options, &session_ptr);
97*14675a02SAndroid Build Coastguard Worker   if (!status.ok()) {
98*14675a02SAndroid Build Coastguard Worker     return ToFcpStatus(status, "Error in tensorflow::NewSession()");
99*14675a02SAndroid Build Coastguard Worker   }
100*14675a02SAndroid Build Coastguard Worker   session = absl::WrapUnique(session_ptr);
101*14675a02SAndroid Build Coastguard Worker 
102*14675a02SAndroid Build Coastguard Worker   // Parse GraphDef.
103*14675a02SAndroid Build Coastguard Worker   tensorflow::GraphDef graph_def;
104*14675a02SAndroid Build Coastguard Worker   bool parse_result = graph_def.ParseFromString(graph);
105*14675a02SAndroid Build Coastguard Worker   if (parse_result == false) {
106*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError("Could not parse GraphDef.");
107*14675a02SAndroid Build Coastguard Worker   }
108*14675a02SAndroid Build Coastguard Worker   // Load graph.
109*14675a02SAndroid Build Coastguard Worker   status = session->Create(std::move(graph_def));
110*14675a02SAndroid Build Coastguard Worker   if (!status.ok()) {
111*14675a02SAndroid Build Coastguard Worker     return ToFcpStatus(status, "Error in Session::Create()");
112*14675a02SAndroid Build Coastguard Worker   }
113*14675a02SAndroid Build Coastguard Worker 
114*14675a02SAndroid Build Coastguard Worker   // Create an InterruptibleRunner to execute TF calls in a background thread,
115*14675a02SAndroid Build Coastguard Worker   // allowing us to abort them if need be.
116*14675a02SAndroid Build Coastguard Worker   auto interruptible_runner = std::make_unique<InterruptibleRunner>(
117*14675a02SAndroid Build Coastguard Worker       log_manager, should_abort, timing_config,
118*14675a02SAndroid Build Coastguard Worker       InterruptibleRunner::DiagnosticsConfig{
119*14675a02SAndroid Build Coastguard Worker           .interrupted =
120*14675a02SAndroid Build Coastguard Worker               ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION,
121*14675a02SAndroid Build Coastguard Worker           .interrupt_timeout = ProdDiagCode::
122*14675a02SAndroid Build Coastguard Worker               BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT,
123*14675a02SAndroid Build Coastguard Worker           .interrupted_extended = ProdDiagCode::
124*14675a02SAndroid Build Coastguard Worker               BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED,
125*14675a02SAndroid Build Coastguard Worker           .interrupt_timeout_extended = ProdDiagCode::
126*14675a02SAndroid Build Coastguard Worker               BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT});
127*14675a02SAndroid Build Coastguard Worker   auto wrapper = absl::WrapUnique(new TensorFlowWrapper(
128*14675a02SAndroid Build Coastguard Worker       std::move(session), std::move(interruptible_runner), log_manager));
129*14675a02SAndroid Build Coastguard Worker   return wrapper;
130*14675a02SAndroid Build Coastguard Worker }
131*14675a02SAndroid Build Coastguard Worker 
~TensorFlowWrapper()132*14675a02SAndroid Build Coastguard Worker TensorFlowWrapper::~TensorFlowWrapper() { FCP_CHECK(CloseAndRelease().ok()); }
133*14675a02SAndroid Build Coastguard Worker 
ToFcpStatus(tensorflow::Status s,const std::string & message_prefix)134*14675a02SAndroid Build Coastguard Worker absl::Status TensorFlowWrapper::ToFcpStatus(tensorflow::Status s,
135*14675a02SAndroid Build Coastguard Worker                                             const std::string& message_prefix) {
136*14675a02SAndroid Build Coastguard Worker   if (s.ok()) {
137*14675a02SAndroid Build Coastguard Worker     return absl::OkStatus();
138*14675a02SAndroid Build Coastguard Worker   } else if (s.code() == tensorflow::error::OUT_OF_RANGE) {
139*14675a02SAndroid Build Coastguard Worker     return absl::OutOfRangeError("");
140*14675a02SAndroid Build Coastguard Worker   } else {
141*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError(
142*14675a02SAndroid Build Coastguard Worker         absl::StrCat(message_prefix, ": ", s.ToString()));
143*14675a02SAndroid Build Coastguard Worker   }
144*14675a02SAndroid Build Coastguard Worker }
145*14675a02SAndroid Build Coastguard Worker 
Run(const std::vector<std::pair<std::string,tensorflow::Tensor>> & inputs,const std::vector<std::string> & output_tensor_names,const std::vector<std::string> & target_node_names,std::vector<tensorflow::Tensor> * outputs)146*14675a02SAndroid Build Coastguard Worker absl::Status TensorFlowWrapper::Run(
147*14675a02SAndroid Build Coastguard Worker     const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs,
148*14675a02SAndroid Build Coastguard Worker     const std::vector<std::string>& output_tensor_names,
149*14675a02SAndroid Build Coastguard Worker     const std::vector<std::string>& target_node_names,
150*14675a02SAndroid Build Coastguard Worker     std::vector<tensorflow::Tensor>* outputs) {
151*14675a02SAndroid Build Coastguard Worker   FCP_CHECK(!session_closed_) << "Run() called after session close!";
152*14675a02SAndroid Build Coastguard Worker 
153*14675a02SAndroid Build Coastguard Worker   auto tensorflow_runnable = [&inputs, &output_tensor_names, &target_node_names,
154*14675a02SAndroid Build Coastguard Worker                               &outputs, this]() -> absl::Status {
155*14675a02SAndroid Build Coastguard Worker     tensorflow::Status status = this->session_->Run(inputs, output_tensor_names,
156*14675a02SAndroid Build Coastguard Worker                                                     target_node_names, outputs);
157*14675a02SAndroid Build Coastguard Worker     if (!status.ok()) {
158*14675a02SAndroid Build Coastguard Worker       return ToFcpStatus(status, "Error in Session::Run()");
159*14675a02SAndroid Build Coastguard Worker     }
160*14675a02SAndroid Build Coastguard Worker     return absl::OkStatus();
161*14675a02SAndroid Build Coastguard Worker   };
162*14675a02SAndroid Build Coastguard Worker   auto abort_tensorflow = [this]() {
163*14675a02SAndroid Build Coastguard Worker     absl::MutexLock _(&session_lock_);
164*14675a02SAndroid Build Coastguard Worker     // Errors from Close() are expected when interrupting ongoing calls. We
165*14675a02SAndroid Build Coastguard Worker     // don't call CloseAndRelease() here because that would free the TensorFlow
166*14675a02SAndroid Build Coastguard Worker     // session while other TensorFlow worker threads may still be using it.
167*14675a02SAndroid Build Coastguard Worker     session_->Close().IgnoreError();
168*14675a02SAndroid Build Coastguard Worker     session_closed_ = true;
169*14675a02SAndroid Build Coastguard Worker   };
170*14675a02SAndroid Build Coastguard Worker   return interruptible_runner_->Run(tensorflow_runnable, abort_tensorflow);
171*14675a02SAndroid Build Coastguard Worker }
172*14675a02SAndroid Build Coastguard Worker 
CloseAndRelease()173*14675a02SAndroid Build Coastguard Worker absl::Status TensorFlowWrapper::CloseAndRelease() {
174*14675a02SAndroid Build Coastguard Worker   absl::MutexLock _(&session_lock_);
175*14675a02SAndroid Build Coastguard Worker   // If the TensorFlow session hasn't been closed yet, close it.
176*14675a02SAndroid Build Coastguard Worker   if (!session_closed_) {
177*14675a02SAndroid Build Coastguard Worker     FCP_ENGINE_RETURN_IF_ERROR(
178*14675a02SAndroid Build Coastguard Worker         ToFcpStatus(session_->Close(), "Could not close TF session"));
179*14675a02SAndroid Build Coastguard Worker     session_closed_ = true;
180*14675a02SAndroid Build Coastguard Worker   }
181*14675a02SAndroid Build Coastguard Worker   // If the TensorflowSession hasn't been released yet, release it.
182*14675a02SAndroid Build Coastguard Worker   if (session_) {
183*14675a02SAndroid Build Coastguard Worker     session_.reset();
184*14675a02SAndroid Build Coastguard Worker   }
185*14675a02SAndroid Build Coastguard Worker   return absl::OkStatus();
186*14675a02SAndroid Build Coastguard Worker }
187*14675a02SAndroid Build Coastguard Worker 
188*14675a02SAndroid Build Coastguard Worker }  // namespace engine
189*14675a02SAndroid Build Coastguard Worker }  // namespace client
190*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
191