xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/tf_session.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_TENSORFLOW_TF_SESSION_H_
18 #define FCP_TENSORFLOW_TF_SESSION_H_
19 
#include <filesystem>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "fcp/base/result.h"
#include "fcp/protos/plan.pb.h"
#include "fcp/tensorflow/tracing_schema.h"
#include "fcp/tracing/tracing_span.h"
#include "tensorflow/core/public/session.h"
31 
32 namespace fcp {
33 
34 class TfSession {
35  public:
36   /**
37    * Starts a tensorflow client session with the provided graph def
38    * @param tmp_dir A directory in which to create tmp files used while saving
39    *    or restoring checkpoints. This directory can be the same for multiple
40    *    TfSessions created in the same process, even if they are running
41    *    concurrently, but it must not be the same directory passed to a
42    *    TfSession in a different process.
43    * @param graph Serialized graph describing how to aggregate client updates
44    *    into a global model. Must be parseable into a tesnorflow::GraphDef
45    *    proto.
46    */
47   TfSession(const std::filesystem::path& tmp_dir, const absl::Cord& graph);
48   TfSession(const std::filesystem::path& tmp_dir, absl::string_view graph);
49 
50   // TfSession is neither copyable nor movable.
51   TfSession(const TfSession&) = delete;
52   TfSession& operator=(const TfSession&) = delete;
53 
54   using NamedTensorList =
55       std::vector<std::pair<std::string, tensorflow::Tensor>>;
56   using NamedTensorMap = absl::flat_hash_map<std::string, tensorflow::Tensor>;
57 
58   // Returns Error if the TfSession is in a bad state (for example if the
59   // provided GraphDef was invalid.) Allows failing fast while recording a
60   // useful error for debugging.
61   // If Ready() returns Error, all other methods will return Error as well.
62   Result<Unit> Ready();
63 
64   // Run a single operation only if the operation is nonempty. The operation
65   // must be present in the GraphDef that was provided in the constructor.
66   Result<Unit> RunOp(absl::string_view op);
67 
68   // Returns a map of name, output tensor pairs for the outputs specified by
69   // output_names.
70   Result<std::unique_ptr<NamedTensorMap>> GetOutputs(
71       std::unique_ptr<std::vector<std::string>> output_names);
72 
73   /**
74    * Saves the current state of the session.
75    * @param op Contains instructions for how to save the session state.
76    * @return the state of the session as a serialized checkpoint.
77    */
78   Result<absl::Cord> SaveState(
79       const google::internal::federated::plan::CheckpointOp& op);
80 
81   /**
82    * Restores state into the session.
83    * @param op Contains instructions for operations to run to restore the
84    * state.
85    * @param checkpoint Serialized tensorflow checkpoint that should be loaded
86    * into the session.
87    */
88   Result<Unit> RestoreState(
89       const google::internal::federated::plan::CheckpointOp& op,
90       const absl::Cord& checkpoint);
91 
92   /**
93    * Restores state into the session.
94    * @param op Contains instructions for operations to run to restore the state.
95    * saver_def must not be set on the op.
96    * @param restore_inputs A collection of tensor variables that should be
97    * loaded into the session.
98    */
99   Result<Unit> RestoreState(
100       const google::internal::federated::plan::CheckpointOp& op,
101       const NamedTensorList& restore_inputs);
102 
103  private:
104   // Overload to allow providing inputs to operations.
105   Result<Unit> RunOp(const NamedTensorList& inputs, absl::string_view op);
106   std::string GetTmpCheckpointFileName(absl::string_view name);
107 
108   std::string tmp_dir_;
109   std::unique_ptr<tensorflow::Session> session_;
110   fcp::Status session_status_;
111 };
112 
113 }  // namespace fcp
114 
115 #endif  // FCP_TENSORFLOW_TF_SESSION_H_
116