1 /* 2 * Copyright 2020 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_TENSORFLOW_TF_SESSION_H_ 18 #define FCP_TENSORFLOW_TF_SESSION_H_ 19 20 #include <filesystem> 21 #include <string> 22 23 #include "absl/container/flat_hash_map.h" 24 #include "absl/strings/cord.h" 25 #include "absl/strings/string_view.h" 26 #include "fcp/base/result.h" 27 #include "fcp/protos/plan.pb.h" 28 #include "fcp/tensorflow/tracing_schema.h" 29 #include "fcp/tracing/tracing_span.h" 30 #include "tensorflow/core/public/session.h" 31 32 namespace fcp { 33 34 class TfSession { 35 public: 36 /** 37 * Starts a tensorflow client session with the provided graph def 38 * @param tmp_dir A directory in which to create tmp files used while saving 39 * or restoring checkpoints. This directory can be the same for multiple 40 * TfSessions created in the same process, even if they are running 41 * concurrently, but it must not be the same directory passed to a 42 * TfSession in a different process. 43 * @param graph Serialized graph describing how to aggregate client updates 44 * into a global model. Must be parseable into a tesnorflow::GraphDef 45 * proto. 46 */ 47 TfSession(const std::filesystem::path& tmp_dir, const absl::Cord& graph); 48 TfSession(const std::filesystem::path& tmp_dir, absl::string_view graph); 49 50 // TfSession is neither copyable nor movable. 51 TfSession(const TfSession&) = delete; 52 TfSession& operator=(const TfSession&) = delete; 53 54 using NamedTensorList = 55 std::vector<std::pair<std::string, tensorflow::Tensor>>; 56 using NamedTensorMap = absl::flat_hash_map<std::string, tensorflow::Tensor>; 57 58 // Returns Error if the TfSession is in a bad state (for example if the 59 // provided GraphDef was invalid.) Allows failing fast while recording a 60 // useful error for debugging. 61 // If Ready() returns Error, all other methods will return Error as well. 62 Result<Unit> Ready(); 63 64 // Run a single operation only if the operation is nonempty. The operation 65 // must be present in the GraphDef that was provided in the constructor. 66 Result<Unit> RunOp(absl::string_view op); 67 68 // Returns a map of name, output tensor pairs for the outputs specified by 69 // output_names. 70 Result<std::unique_ptr<NamedTensorMap>> GetOutputs( 71 std::unique_ptr<std::vector<std::string>> output_names); 72 73 /** 74 * Saves the current state of the session. 75 * @param op Contains instructions for how to save the session state. 76 * @return the state of the session as a serialized checkpoint. 77 */ 78 Result<absl::Cord> SaveState( 79 const google::internal::federated::plan::CheckpointOp& op); 80 81 /** 82 * Restores state into the session. 83 * @param op Contains instructions for operations to run to restore the 84 * state. 85 * @param checkpoint Serialized tensorflow checkpoint that should be loaded 86 * into the session. 87 */ 88 Result<Unit> RestoreState( 89 const google::internal::federated::plan::CheckpointOp& op, 90 const absl::Cord& checkpoint); 91 92 /** 93 * Restores state into the session. 94 * @param op Contains instructions for operations to run to restore the state. 95 * saver_def must not be set on the op. 96 * @param restore_inputs A collection of tensor variables that should be 97 * loaded into the session. 98 */ 99 Result<Unit> RestoreState( 100 const google::internal::federated::plan::CheckpointOp& op, 101 const NamedTensorList& restore_inputs); 102 103 private: 104 // Overload to allow providing inputs to operations. 105 Result<Unit> RunOp(const NamedTensorList& inputs, absl::string_view op); 106 std::string GetTmpCheckpointFileName(absl::string_view name); 107 108 std::string tmp_dir_; 109 std::unique_ptr<tensorflow::Session> session_; 110 fcp::Status session_status_; 111 }; 112 113 } // namespace fcp 114 115 #endif // FCP_TENSORFLOW_TF_SESSION_H_ 116