xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/simple_plan_engine.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/client/engine/simple_plan_engine.h"
17 
18 #include <functional>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "google/protobuf/any.pb.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/status/status.h"
26 #include "absl/status/statusor.h"
27 #include "absl/strings/str_cat.h"
28 #include "fcp/base/monitoring.h"
29 #include "fcp/client/engine/plan_engine_helpers.h"
30 #include "fcp/client/simple_task_environment.h"
31 #include "tensorflow/core/framework/tensor.h"
32 #include "tensorflow/core/framework/tensor.pb.h"
33 #include "tensorflow/core/protobuf/struct.pb.h"
34 
35 namespace fcp {
36 namespace client {
37 namespace engine {
38 
39 using ::fcp::client::opstats::OpStatsLogger;
40 using ::google::internal::federated::plan::TensorflowSpec;
41 
SimplePlanEngine(std::vector<ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const InterruptibleRunner::TimingConfig * timing_config,const bool support_constant_tf_inputs)42 SimplePlanEngine::SimplePlanEngine(
43     std::vector<ExampleIteratorFactory*> example_iterator_factories,
44     std::function<bool()> should_abort, LogManager* log_manager,
45     OpStatsLogger* opstats_logger,
46     const InterruptibleRunner::TimingConfig* timing_config,
47     const bool support_constant_tf_inputs)
48     : example_iterator_factories_(example_iterator_factories),
49       should_abort_(should_abort),
50       log_manager_(log_manager),
51       opstats_logger_(opstats_logger),
52       timing_config_(timing_config),
53       support_constant_tf_inputs_(support_constant_tf_inputs) {}
54 
RunPlan(const TensorflowSpec & tensorflow_spec,const std::string & graph,const::google::protobuf::Any & config_proto,std::unique_ptr<std::vector<std::pair<std::string,tensorflow::Tensor>>> inputs,const std::vector<std::string> & output_names)55 PlanResult SimplePlanEngine::RunPlan(
56     const TensorflowSpec& tensorflow_spec, const std::string& graph,
57     const ::google::protobuf::Any& config_proto,
58     std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
59         inputs,
60     const std::vector<std::string>& output_names) {
61   // Check that all inputs have corresponding TensorSpecProtos.
62   absl::flat_hash_set<std::string> expected_input_tensor_names_set;
63   for (const std::pair<std::string, tensorflow::Tensor>& input : *inputs) {
64     expected_input_tensor_names_set.insert(input.first);
65   }
66   absl::Status validity_checks = ValidateTensorflowSpec(
67       tensorflow_spec, expected_input_tensor_names_set, output_names);
68   if (!validity_checks.ok()) {
69     FCP_LOG(ERROR) << validity_checks.message();
70     return PlanResult(PlanOutcome::kInvalidArgument,
71                       std::move(validity_checks));
72   }
73 
74   absl::StatusOr<std::unique_ptr<TensorFlowWrapper>> tf_wrapper_or =
75       TensorFlowWrapper::Create(graph, config_proto, should_abort_,
76                                 *timing_config_, log_manager_);
77   if (!tf_wrapper_or.ok()) {
78     return PlanResult(PlanOutcome::kTensorflowError, tf_wrapper_or.status());
79   }
80 
81   std::unique_ptr<TensorFlowWrapper> tf_wrapper =
82       std::move(tf_wrapper_or.value());
83   std::atomic<int> total_example_count = 0;
84   std::atomic<int64_t> total_example_size_bytes = 0;
85   ExampleIteratorStatus example_iterator_status;
86   auto tf_result =
87       RunPlanInternal(tf_wrapper.get(), tensorflow_spec, std::move(inputs),
88                       output_names, &total_example_count,
89                       &total_example_size_bytes, &example_iterator_status);
90   FCP_CHECK(tf_wrapper->CloseAndRelease().ok());
91 
92   switch (tf_result.status().code()) {
93     case absl::StatusCode::kOk: {
94       PlanResult plan_result(PlanOutcome::kSuccess, absl::OkStatus());
95       plan_result.output_names = output_names;
96       plan_result.output_tensors = std::move(tf_result).value();
97       plan_result.example_stats = {
98           .example_count = total_example_count,
99           .example_size_bytes = total_example_size_bytes};
100       return plan_result;
101     }
102     case absl::StatusCode::kCancelled:
103       return PlanResult(PlanOutcome::kInterrupted, tf_result.status());
104     case absl::StatusCode::kInvalidArgument:
105       return CreateComputationErrorPlanResult(
106           example_iterator_status.GetStatus(), tf_result.status());
107     default:
108       FCP_LOG(FATAL) << "unexpected status code: " << tf_result.status().code();
109   }
110   // Unreachable, but clang doesn't get it.
111   return PlanResult(PlanOutcome::kTensorflowError, absl::InternalError(""));
112 }
113 
114 absl::StatusOr<std::vector<tensorflow::Tensor>>
RunPlanInternal(TensorFlowWrapper * tf_wrapper,const google::internal::federated::plan::TensorflowSpec & tensorflow_spec,std::unique_ptr<std::vector<std::pair<std::string,tensorflow::Tensor>>> inputs,const std::vector<std::string> & output_names,std::atomic<int> * total_example_count,std::atomic<int64_t> * total_example_size_bytes,ExampleIteratorStatus * example_iterator_status)115 SimplePlanEngine::RunPlanInternal(
116     TensorFlowWrapper* tf_wrapper,
117     const google::internal::federated::plan::TensorflowSpec& tensorflow_spec,
118     std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
119         inputs,
120     const std::vector<std::string>& output_names,
121     std::atomic<int>* total_example_count,
122     std::atomic<int64_t>* total_example_size_bytes,
123     ExampleIteratorStatus* example_iterator_status) {
124   // Populate input tensor vector
125   // AddDatasetTokenToInputs first registers a DatasetProvider with the global
126   // ExternalDatasetProviderRegistry and then returns a HostObjectRegistration
127   // object. Hold onto the HostObjectRegistration object since it de-registers
128   // upon destruction.
129   HostObjectRegistration host_registration = AddDatasetTokenToInputs(
130       example_iterator_factories_, opstats_logger_, inputs.get(),
131       tensorflow_spec.dataset_token_tensor_name(), total_example_count,
132       total_example_size_bytes, example_iterator_status);
133 
134   std::vector<std::string> target_names;
135   for (const std::string& target_node_name :
136        tensorflow_spec.target_node_names()) {
137     target_names.push_back(target_node_name);
138   }
139   if (support_constant_tf_inputs_ &&
140       !tensorflow_spec.constant_inputs().empty()) {
141     // If the server-side constant inputs message is provided, copy over these
142     // values to the set of input tensors.
143     for (const auto& [name, tensor_proto] : tensorflow_spec.constant_inputs()) {
144       tensorflow::Tensor input_tensor;
145       if (!input_tensor.FromProto(tensor_proto)) {
146         return absl::InvalidArgumentError(
147             absl::StrCat("unable to convert constant_input to tensor: %s",
148                          tensor_proto.DebugString()));
149       }
150       inputs->push_back({name, std::move(input_tensor)});
151     }
152   }
153 
154   FCP_ASSIGN_OR_RETURN(
155       auto result,
156       RunTensorFlowInternal(tf_wrapper, *inputs, output_names, target_names));
157   return result;
158 }
159 
160 absl::StatusOr<std::vector<tensorflow::Tensor>>
RunTensorFlowInternal(TensorFlowWrapper * tf_wrapper,const std::vector<std::pair<std::string,tensorflow::Tensor>> & inputs,const std::vector<std::string> & output_tensor_names,const std::vector<std::string> & target_node_names)161 SimplePlanEngine::RunTensorFlowInternal(
162     TensorFlowWrapper* tf_wrapper,
163     const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs,
164     const std::vector<std::string>& output_tensor_names,
165     const std::vector<std::string>& target_node_names) {
166   std::vector<tensorflow::Tensor> outputs;
167   absl::Status status =
168       tf_wrapper->Run(inputs, output_tensor_names, target_node_names, &outputs);
169   switch (status.code()) {
170     case absl::StatusCode::kCancelled:
171     case absl::StatusCode::kInvalidArgument:
172       return status;
173     case absl::StatusCode::kOutOfRange:
174     case absl::StatusCode::kOk:
175       break;
176     default:
177       FCP_CHECK_STATUS(status);
178   }
179   return outputs;
180 }
181 
182 }  // namespace engine
183 }  // namespace client
184 }  // namespace fcp
185