1*14675a02SAndroid Build Coastguard Worker /* 2*14675a02SAndroid Build Coastguard Worker * Copyright 2020 Google LLC 3*14675a02SAndroid Build Coastguard Worker * 4*14675a02SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License"); 5*14675a02SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License. 6*14675a02SAndroid Build Coastguard Worker * You may obtain a copy of the License at 7*14675a02SAndroid Build Coastguard Worker * 8*14675a02SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0 9*14675a02SAndroid Build Coastguard Worker * 10*14675a02SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software 11*14675a02SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS, 12*14675a02SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13*14675a02SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and 14*14675a02SAndroid Build Coastguard Worker * limitations under the License. 15*14675a02SAndroid Build Coastguard Worker */ 16*14675a02SAndroid Build Coastguard Worker 17*14675a02SAndroid Build Coastguard Worker #ifndef FCP_TENSORFLOW_TF_SESSION_H_ 18*14675a02SAndroid Build Coastguard Worker #define FCP_TENSORFLOW_TF_SESSION_H_ 19*14675a02SAndroid Build Coastguard Worker 20*14675a02SAndroid Build Coastguard Worker #include <filesystem> 21*14675a02SAndroid Build Coastguard Worker #include <string> 22*14675a02SAndroid Build Coastguard Worker 23*14675a02SAndroid Build Coastguard Worker #include "absl/container/flat_hash_map.h" 24*14675a02SAndroid Build Coastguard Worker #include "absl/strings/cord.h" 25*14675a02SAndroid Build Coastguard Worker #include "absl/strings/string_view.h" 26*14675a02SAndroid Build Coastguard Worker #include "fcp/base/result.h" 27*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/plan.pb.h" 28*14675a02SAndroid Build Coastguard Worker #include "fcp/tensorflow/tracing_schema.h" 29*14675a02SAndroid Build Coastguard Worker #include "fcp/tracing/tracing_span.h" 30*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/public/session.h" 31*14675a02SAndroid Build Coastguard Worker 32*14675a02SAndroid Build Coastguard Worker namespace fcp { 33*14675a02SAndroid Build Coastguard Worker 34*14675a02SAndroid Build Coastguard Worker class TfSession { 35*14675a02SAndroid Build Coastguard Worker public: 36*14675a02SAndroid Build Coastguard Worker /** 37*14675a02SAndroid Build Coastguard Worker * Starts a tensorflow client session with the provided graph def 38*14675a02SAndroid Build Coastguard Worker * @param tmp_dir A directory in which to create tmp files used while saving 39*14675a02SAndroid Build Coastguard Worker * or restoring checkpoints. This directory can be the same for multiple 40*14675a02SAndroid Build Coastguard Worker * TfSessions created in the same process, even if they are running 41*14675a02SAndroid Build Coastguard Worker * concurrently, but it must not be the same directory passed to a 42*14675a02SAndroid Build Coastguard Worker * TfSession in a different process. 43*14675a02SAndroid Build Coastguard Worker * @param graph Serialized graph describing how to aggregate client updates 44*14675a02SAndroid Build Coastguard Worker * into a global model. Must be parseable into a tesnorflow::GraphDef 45*14675a02SAndroid Build Coastguard Worker * proto. 46*14675a02SAndroid Build Coastguard Worker */ 47*14675a02SAndroid Build Coastguard Worker TfSession(const std::filesystem::path& tmp_dir, const absl::Cord& graph); 48*14675a02SAndroid Build Coastguard Worker TfSession(const std::filesystem::path& tmp_dir, absl::string_view graph); 49*14675a02SAndroid Build Coastguard Worker 50*14675a02SAndroid Build Coastguard Worker // TfSession is neither copyable nor movable. 51*14675a02SAndroid Build Coastguard Worker TfSession(const TfSession&) = delete; 52*14675a02SAndroid Build Coastguard Worker TfSession& operator=(const TfSession&) = delete; 53*14675a02SAndroid Build Coastguard Worker 54*14675a02SAndroid Build Coastguard Worker using NamedTensorList = 55*14675a02SAndroid Build Coastguard Worker std::vector<std::pair<std::string, tensorflow::Tensor>>; 56*14675a02SAndroid Build Coastguard Worker using NamedTensorMap = absl::flat_hash_map<std::string, tensorflow::Tensor>; 57*14675a02SAndroid Build Coastguard Worker 58*14675a02SAndroid Build Coastguard Worker // Returns Error if the TfSession is in a bad state (for example if the 59*14675a02SAndroid Build Coastguard Worker // provided GraphDef was invalid.) Allows failing fast while recording a 60*14675a02SAndroid Build Coastguard Worker // useful error for debugging. 61*14675a02SAndroid Build Coastguard Worker // If Ready() returns Error, all other methods will return Error as well. 62*14675a02SAndroid Build Coastguard Worker Result<Unit> Ready(); 63*14675a02SAndroid Build Coastguard Worker 64*14675a02SAndroid Build Coastguard Worker // Run a single operation only if the operation is nonempty. The operation 65*14675a02SAndroid Build Coastguard Worker // must be present in the GraphDef that was provided in the constructor. 66*14675a02SAndroid Build Coastguard Worker Result<Unit> RunOp(absl::string_view op); 67*14675a02SAndroid Build Coastguard Worker 68*14675a02SAndroid Build Coastguard Worker // Returns a map of name, output tensor pairs for the outputs specified by 69*14675a02SAndroid Build Coastguard Worker // output_names. 70*14675a02SAndroid Build Coastguard Worker Result<std::unique_ptr<NamedTensorMap>> GetOutputs( 71*14675a02SAndroid Build Coastguard Worker std::unique_ptr<std::vector<std::string>> output_names); 72*14675a02SAndroid Build Coastguard Worker 73*14675a02SAndroid Build Coastguard Worker /** 74*14675a02SAndroid Build Coastguard Worker * Saves the current state of the session. 75*14675a02SAndroid Build Coastguard Worker * @param op Contains instructions for how to save the session state. 76*14675a02SAndroid Build Coastguard Worker * @return the state of the session as a serialized checkpoint. 77*14675a02SAndroid Build Coastguard Worker */ 78*14675a02SAndroid Build Coastguard Worker Result<absl::Cord> SaveState( 79*14675a02SAndroid Build Coastguard Worker const google::internal::federated::plan::CheckpointOp& op); 80*14675a02SAndroid Build Coastguard Worker 81*14675a02SAndroid Build Coastguard Worker /** 82*14675a02SAndroid Build Coastguard Worker * Restores state into the session. 83*14675a02SAndroid Build Coastguard Worker * @param op Contains instructions for operations to run to restore the 84*14675a02SAndroid Build Coastguard Worker * state. 85*14675a02SAndroid Build Coastguard Worker * @param checkpoint Serialized tensorflow checkpoint that should be loaded 86*14675a02SAndroid Build Coastguard Worker * into the session. 87*14675a02SAndroid Build Coastguard Worker */ 88*14675a02SAndroid Build Coastguard Worker Result<Unit> RestoreState( 89*14675a02SAndroid Build Coastguard Worker const google::internal::federated::plan::CheckpointOp& op, 90*14675a02SAndroid Build Coastguard Worker const absl::Cord& checkpoint); 91*14675a02SAndroid Build Coastguard Worker 92*14675a02SAndroid Build Coastguard Worker /** 93*14675a02SAndroid Build Coastguard Worker * Restores state into the session. 94*14675a02SAndroid Build Coastguard Worker * @param op Contains instructions for operations to run to restore the state. 95*14675a02SAndroid Build Coastguard Worker * saver_def must not be set on the op. 96*14675a02SAndroid Build Coastguard Worker * @param restore_inputs A collection of tensor variables that should be 97*14675a02SAndroid Build Coastguard Worker * loaded into the session. 98*14675a02SAndroid Build Coastguard Worker */ 99*14675a02SAndroid Build Coastguard Worker Result<Unit> RestoreState( 100*14675a02SAndroid Build Coastguard Worker const google::internal::federated::plan::CheckpointOp& op, 101*14675a02SAndroid Build Coastguard Worker const NamedTensorList& restore_inputs); 102*14675a02SAndroid Build Coastguard Worker 103*14675a02SAndroid Build Coastguard Worker private: 104*14675a02SAndroid Build Coastguard Worker // Overload to allow providing inputs to operations. 105*14675a02SAndroid Build Coastguard Worker Result<Unit> RunOp(const NamedTensorList& inputs, absl::string_view op); 106*14675a02SAndroid Build Coastguard Worker std::string GetTmpCheckpointFileName(absl::string_view name); 107*14675a02SAndroid Build Coastguard Worker 108*14675a02SAndroid Build Coastguard Worker std::string tmp_dir_; 109*14675a02SAndroid Build Coastguard Worker std::unique_ptr<tensorflow::Session> session_; 110*14675a02SAndroid Build Coastguard Worker fcp::Status session_status_; 111*14675a02SAndroid Build Coastguard Worker }; 112*14675a02SAndroid Build Coastguard Worker 113*14675a02SAndroid Build Coastguard Worker } // namespace fcp 114*14675a02SAndroid Build Coastguard Worker 115*14675a02SAndroid Build Coastguard Worker #endif // FCP_TENSORFLOW_TF_SESSION_H_ 116