xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/tf_session.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2020 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker 
17*14675a02SAndroid Build Coastguard Worker #ifndef FCP_TENSORFLOW_TF_SESSION_H_
18*14675a02SAndroid Build Coastguard Worker #define FCP_TENSORFLOW_TF_SESSION_H_
19*14675a02SAndroid Build Coastguard Worker 
#include <filesystem>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "fcp/base/result.h"
#include "fcp/protos/plan.pb.h"
#include "fcp/tensorflow/tracing_schema.h"
#include "fcp/tracing/tracing_span.h"
#include "tensorflow/core/public/session.h"
31*14675a02SAndroid Build Coastguard Worker 
32*14675a02SAndroid Build Coastguard Worker namespace fcp {
33*14675a02SAndroid Build Coastguard Worker 
34*14675a02SAndroid Build Coastguard Worker class TfSession {
35*14675a02SAndroid Build Coastguard Worker  public:
36*14675a02SAndroid Build Coastguard Worker   /**
37*14675a02SAndroid Build Coastguard Worker    * Starts a tensorflow client session with the provided graph def
38*14675a02SAndroid Build Coastguard Worker    * @param tmp_dir A directory in which to create tmp files used while saving
39*14675a02SAndroid Build Coastguard Worker    *    or restoring checkpoints. This directory can be the same for multiple
40*14675a02SAndroid Build Coastguard Worker    *    TfSessions created in the same process, even if they are running
41*14675a02SAndroid Build Coastguard Worker    *    concurrently, but it must not be the same directory passed to a
42*14675a02SAndroid Build Coastguard Worker    *    TfSession in a different process.
43*14675a02SAndroid Build Coastguard Worker    * @param graph Serialized graph describing how to aggregate client updates
44*14675a02SAndroid Build Coastguard Worker    *    into a global model. Must be parseable into a tesnorflow::GraphDef
45*14675a02SAndroid Build Coastguard Worker    *    proto.
46*14675a02SAndroid Build Coastguard Worker    */
47*14675a02SAndroid Build Coastguard Worker   TfSession(const std::filesystem::path& tmp_dir, const absl::Cord& graph);
48*14675a02SAndroid Build Coastguard Worker   TfSession(const std::filesystem::path& tmp_dir, absl::string_view graph);
49*14675a02SAndroid Build Coastguard Worker 
50*14675a02SAndroid Build Coastguard Worker   // TfSession is neither copyable nor movable.
51*14675a02SAndroid Build Coastguard Worker   TfSession(const TfSession&) = delete;
52*14675a02SAndroid Build Coastguard Worker   TfSession& operator=(const TfSession&) = delete;
53*14675a02SAndroid Build Coastguard Worker 
54*14675a02SAndroid Build Coastguard Worker   using NamedTensorList =
55*14675a02SAndroid Build Coastguard Worker       std::vector<std::pair<std::string, tensorflow::Tensor>>;
56*14675a02SAndroid Build Coastguard Worker   using NamedTensorMap = absl::flat_hash_map<std::string, tensorflow::Tensor>;
57*14675a02SAndroid Build Coastguard Worker 
58*14675a02SAndroid Build Coastguard Worker   // Returns Error if the TfSession is in a bad state (for example if the
59*14675a02SAndroid Build Coastguard Worker   // provided GraphDef was invalid.) Allows failing fast while recording a
60*14675a02SAndroid Build Coastguard Worker   // useful error for debugging.
61*14675a02SAndroid Build Coastguard Worker   // If Ready() returns Error, all other methods will return Error as well.
62*14675a02SAndroid Build Coastguard Worker   Result<Unit> Ready();
63*14675a02SAndroid Build Coastguard Worker 
64*14675a02SAndroid Build Coastguard Worker   // Run a single operation only if the operation is nonempty. The operation
65*14675a02SAndroid Build Coastguard Worker   // must be present in the GraphDef that was provided in the constructor.
66*14675a02SAndroid Build Coastguard Worker   Result<Unit> RunOp(absl::string_view op);
67*14675a02SAndroid Build Coastguard Worker 
68*14675a02SAndroid Build Coastguard Worker   // Returns a map of name, output tensor pairs for the outputs specified by
69*14675a02SAndroid Build Coastguard Worker   // output_names.
70*14675a02SAndroid Build Coastguard Worker   Result<std::unique_ptr<NamedTensorMap>> GetOutputs(
71*14675a02SAndroid Build Coastguard Worker       std::unique_ptr<std::vector<std::string>> output_names);
72*14675a02SAndroid Build Coastguard Worker 
73*14675a02SAndroid Build Coastguard Worker   /**
74*14675a02SAndroid Build Coastguard Worker    * Saves the current state of the session.
75*14675a02SAndroid Build Coastguard Worker    * @param op Contains instructions for how to save the session state.
76*14675a02SAndroid Build Coastguard Worker    * @return the state of the session as a serialized checkpoint.
77*14675a02SAndroid Build Coastguard Worker    */
78*14675a02SAndroid Build Coastguard Worker   Result<absl::Cord> SaveState(
79*14675a02SAndroid Build Coastguard Worker       const google::internal::federated::plan::CheckpointOp& op);
80*14675a02SAndroid Build Coastguard Worker 
81*14675a02SAndroid Build Coastguard Worker   /**
82*14675a02SAndroid Build Coastguard Worker    * Restores state into the session.
83*14675a02SAndroid Build Coastguard Worker    * @param op Contains instructions for operations to run to restore the
84*14675a02SAndroid Build Coastguard Worker    * state.
85*14675a02SAndroid Build Coastguard Worker    * @param checkpoint Serialized tensorflow checkpoint that should be loaded
86*14675a02SAndroid Build Coastguard Worker    * into the session.
87*14675a02SAndroid Build Coastguard Worker    */
88*14675a02SAndroid Build Coastguard Worker   Result<Unit> RestoreState(
89*14675a02SAndroid Build Coastguard Worker       const google::internal::federated::plan::CheckpointOp& op,
90*14675a02SAndroid Build Coastguard Worker       const absl::Cord& checkpoint);
91*14675a02SAndroid Build Coastguard Worker 
92*14675a02SAndroid Build Coastguard Worker   /**
93*14675a02SAndroid Build Coastguard Worker    * Restores state into the session.
94*14675a02SAndroid Build Coastguard Worker    * @param op Contains instructions for operations to run to restore the state.
95*14675a02SAndroid Build Coastguard Worker    * saver_def must not be set on the op.
96*14675a02SAndroid Build Coastguard Worker    * @param restore_inputs A collection of tensor variables that should be
97*14675a02SAndroid Build Coastguard Worker    * loaded into the session.
98*14675a02SAndroid Build Coastguard Worker    */
99*14675a02SAndroid Build Coastguard Worker   Result<Unit> RestoreState(
100*14675a02SAndroid Build Coastguard Worker       const google::internal::federated::plan::CheckpointOp& op,
101*14675a02SAndroid Build Coastguard Worker       const NamedTensorList& restore_inputs);
102*14675a02SAndroid Build Coastguard Worker 
103*14675a02SAndroid Build Coastguard Worker  private:
104*14675a02SAndroid Build Coastguard Worker   // Overload to allow providing inputs to operations.
105*14675a02SAndroid Build Coastguard Worker   Result<Unit> RunOp(const NamedTensorList& inputs, absl::string_view op);
106*14675a02SAndroid Build Coastguard Worker   std::string GetTmpCheckpointFileName(absl::string_view name);
107*14675a02SAndroid Build Coastguard Worker 
108*14675a02SAndroid Build Coastguard Worker   std::string tmp_dir_;
109*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<tensorflow::Session> session_;
110*14675a02SAndroid Build Coastguard Worker   fcp::Status session_status_;
111*14675a02SAndroid Build Coastguard Worker };
112*14675a02SAndroid Build Coastguard Worker 
113*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
114*14675a02SAndroid Build Coastguard Worker 
115*14675a02SAndroid Build Coastguard Worker #endif  // FCP_TENSORFLOW_TF_SESSION_H_
116