1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/engine/simple_plan_engine.h"
17
18 #include <functional>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "google/protobuf/any.pb.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/status/status.h"
26 #include "absl/status/statusor.h"
27 #include "absl/strings/str_cat.h"
28 #include "fcp/base/monitoring.h"
29 #include "fcp/client/engine/plan_engine_helpers.h"
30 #include "fcp/client/simple_task_environment.h"
31 #include "tensorflow/core/framework/tensor.h"
32 #include "tensorflow/core/framework/tensor.pb.h"
33 #include "tensorflow/core/protobuf/struct.pb.h"
34
35 namespace fcp {
36 namespace client {
37 namespace engine {
38
39 using ::fcp::client::opstats::OpStatsLogger;
40 using ::google::internal::federated::plan::TensorflowSpec;
41
SimplePlanEngine(std::vector<ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const InterruptibleRunner::TimingConfig * timing_config,const bool support_constant_tf_inputs)42 SimplePlanEngine::SimplePlanEngine(
43 std::vector<ExampleIteratorFactory*> example_iterator_factories,
44 std::function<bool()> should_abort, LogManager* log_manager,
45 OpStatsLogger* opstats_logger,
46 const InterruptibleRunner::TimingConfig* timing_config,
47 const bool support_constant_tf_inputs)
48 : example_iterator_factories_(example_iterator_factories),
49 should_abort_(should_abort),
50 log_manager_(log_manager),
51 opstats_logger_(opstats_logger),
52 timing_config_(timing_config),
53 support_constant_tf_inputs_(support_constant_tf_inputs) {}
54
RunPlan(const TensorflowSpec & tensorflow_spec,const std::string & graph,const::google::protobuf::Any & config_proto,std::unique_ptr<std::vector<std::pair<std::string,tensorflow::Tensor>>> inputs,const std::vector<std::string> & output_names)55 PlanResult SimplePlanEngine::RunPlan(
56 const TensorflowSpec& tensorflow_spec, const std::string& graph,
57 const ::google::protobuf::Any& config_proto,
58 std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
59 inputs,
60 const std::vector<std::string>& output_names) {
61 // Check that all inputs have corresponding TensorSpecProtos.
62 absl::flat_hash_set<std::string> expected_input_tensor_names_set;
63 for (const std::pair<std::string, tensorflow::Tensor>& input : *inputs) {
64 expected_input_tensor_names_set.insert(input.first);
65 }
66 absl::Status validity_checks = ValidateTensorflowSpec(
67 tensorflow_spec, expected_input_tensor_names_set, output_names);
68 if (!validity_checks.ok()) {
69 FCP_LOG(ERROR) << validity_checks.message();
70 return PlanResult(PlanOutcome::kInvalidArgument,
71 std::move(validity_checks));
72 }
73
74 absl::StatusOr<std::unique_ptr<TensorFlowWrapper>> tf_wrapper_or =
75 TensorFlowWrapper::Create(graph, config_proto, should_abort_,
76 *timing_config_, log_manager_);
77 if (!tf_wrapper_or.ok()) {
78 return PlanResult(PlanOutcome::kTensorflowError, tf_wrapper_or.status());
79 }
80
81 std::unique_ptr<TensorFlowWrapper> tf_wrapper =
82 std::move(tf_wrapper_or.value());
83 std::atomic<int> total_example_count = 0;
84 std::atomic<int64_t> total_example_size_bytes = 0;
85 ExampleIteratorStatus example_iterator_status;
86 auto tf_result =
87 RunPlanInternal(tf_wrapper.get(), tensorflow_spec, std::move(inputs),
88 output_names, &total_example_count,
89 &total_example_size_bytes, &example_iterator_status);
90 FCP_CHECK(tf_wrapper->CloseAndRelease().ok());
91
92 switch (tf_result.status().code()) {
93 case absl::StatusCode::kOk: {
94 PlanResult plan_result(PlanOutcome::kSuccess, absl::OkStatus());
95 plan_result.output_names = output_names;
96 plan_result.output_tensors = std::move(tf_result).value();
97 plan_result.example_stats = {
98 .example_count = total_example_count,
99 .example_size_bytes = total_example_size_bytes};
100 return plan_result;
101 }
102 case absl::StatusCode::kCancelled:
103 return PlanResult(PlanOutcome::kInterrupted, tf_result.status());
104 case absl::StatusCode::kInvalidArgument:
105 return CreateComputationErrorPlanResult(
106 example_iterator_status.GetStatus(), tf_result.status());
107 default:
108 FCP_LOG(FATAL) << "unexpected status code: " << tf_result.status().code();
109 }
110 // Unreachable, but clang doesn't get it.
111 return PlanResult(PlanOutcome::kTensorflowError, absl::InternalError(""));
112 }
113
114 absl::StatusOr<std::vector<tensorflow::Tensor>>
RunPlanInternal(TensorFlowWrapper * tf_wrapper,const google::internal::federated::plan::TensorflowSpec & tensorflow_spec,std::unique_ptr<std::vector<std::pair<std::string,tensorflow::Tensor>>> inputs,const std::vector<std::string> & output_names,std::atomic<int> * total_example_count,std::atomic<int64_t> * total_example_size_bytes,ExampleIteratorStatus * example_iterator_status)115 SimplePlanEngine::RunPlanInternal(
116 TensorFlowWrapper* tf_wrapper,
117 const google::internal::federated::plan::TensorflowSpec& tensorflow_spec,
118 std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
119 inputs,
120 const std::vector<std::string>& output_names,
121 std::atomic<int>* total_example_count,
122 std::atomic<int64_t>* total_example_size_bytes,
123 ExampleIteratorStatus* example_iterator_status) {
124 // Populate input tensor vector
125 // AddDatasetTokenToInputs first registers a DatasetProvider with the global
126 // ExternalDatasetProviderRegistry and then returns a HostObjectRegistration
127 // object. Hold onto the HostObjectRegistration object since it de-registers
128 // upon destruction.
129 HostObjectRegistration host_registration = AddDatasetTokenToInputs(
130 example_iterator_factories_, opstats_logger_, inputs.get(),
131 tensorflow_spec.dataset_token_tensor_name(), total_example_count,
132 total_example_size_bytes, example_iterator_status);
133
134 std::vector<std::string> target_names;
135 for (const std::string& target_node_name :
136 tensorflow_spec.target_node_names()) {
137 target_names.push_back(target_node_name);
138 }
139 if (support_constant_tf_inputs_ &&
140 !tensorflow_spec.constant_inputs().empty()) {
141 // If the server-side constant inputs message is provided, copy over these
142 // values to the set of input tensors.
143 for (const auto& [name, tensor_proto] : tensorflow_spec.constant_inputs()) {
144 tensorflow::Tensor input_tensor;
145 if (!input_tensor.FromProto(tensor_proto)) {
146 return absl::InvalidArgumentError(
147 absl::StrCat("unable to convert constant_input to tensor: %s",
148 tensor_proto.DebugString()));
149 }
150 inputs->push_back({name, std::move(input_tensor)});
151 }
152 }
153
154 FCP_ASSIGN_OR_RETURN(
155 auto result,
156 RunTensorFlowInternal(tf_wrapper, *inputs, output_names, target_names));
157 return result;
158 }
159
160 absl::StatusOr<std::vector<tensorflow::Tensor>>
RunTensorFlowInternal(TensorFlowWrapper * tf_wrapper,const std::vector<std::pair<std::string,tensorflow::Tensor>> & inputs,const std::vector<std::string> & output_tensor_names,const std::vector<std::string> & target_node_names)161 SimplePlanEngine::RunTensorFlowInternal(
162 TensorFlowWrapper* tf_wrapper,
163 const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs,
164 const std::vector<std::string>& output_tensor_names,
165 const std::vector<std::string>& target_node_names) {
166 std::vector<tensorflow::Tensor> outputs;
167 absl::Status status =
168 tf_wrapper->Run(inputs, output_tensor_names, target_node_names, &outputs);
169 switch (status.code()) {
170 case absl::StatusCode::kCancelled:
171 case absl::StatusCode::kInvalidArgument:
172 return status;
173 case absl::StatusCode::kOutOfRange:
174 case absl::StatusCode::kOk:
175 break;
176 default:
177 FCP_CHECK_STATUS(status);
178 }
179 return outputs;
180 }
181
182 } // namespace engine
183 } // namespace client
184 } // namespace fcp
185