1 /* 2 * Copyright 2021 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef FCP_CLIENT_ENGINE_COMMON_H_ 17 #define FCP_CLIENT_ENGINE_COMMON_H_ 18 19 #include <string> 20 #include <utility> 21 #include <vector> 22 23 #include "absl/container/flat_hash_set.h" 24 #include "absl/status/status.h" 25 #include "fcp/client/engine/engine.pb.h" 26 #include "fcp/client/stats.h" 27 #include "fcp/protos/plan.pb.h" 28 #include "tensorflow/core/framework/tensor.h" 29 30 namespace fcp { 31 namespace client { 32 namespace engine { 33 34 enum class PlanOutcome { 35 kSuccess, 36 // A TensorFlow error occurred. 37 kTensorflowError, 38 // Computation was interrupted. 39 kInterrupted, 40 // The input parameters are invalid. 41 kInvalidArgument, 42 // An example iterator error occurred. 43 kExampleIteratorError, 44 }; 45 46 // The result of a call to `SimplePlanEngine::RunPlan` or 47 // `TfLitePlanEngine::RunPlan`. 48 struct PlanResult { 49 explicit PlanResult(PlanOutcome outcome, absl::Status status); 50 51 // The outcome of the plan execution. 52 PlanOutcome outcome; 53 // Only set if `outcome` is `kSuccess`, otherwise this is empty. 54 std::vector<tensorflow::Tensor> output_tensors; 55 // Only set if `outcome` is `kSuccess`, otherwise this is empty. 56 std::vector<std::string> output_names; 57 // When the outcome is `kSuccess`, the status is ok. Otherwise, this status 58 // contain the original error status which leads to the PlanOutcome. 59 absl::Status original_status; 60 ::fcp::client::ExampleStats example_stats; 61 62 PlanResult(PlanResult&&) = default; 63 PlanResult& operator=(PlanResult&&) = default; 64 65 // Disallow copy and assign. 66 PlanResult(const PlanResult&) = delete; 67 PlanResult& operator=(const PlanResult&) = delete; 68 }; 69 70 // Validates that the input tensors match what's inside the TensorflowSpec. 71 absl::Status ValidateTensorflowSpec( 72 const google::internal::federated::plan::TensorflowSpec& tensorflow_spec, 73 const absl::flat_hash_set<std::string>& expected_input_tensor_names_set, 74 const std::vector<std::string>& output_names); 75 76 PhaseOutcome ConvertPlanOutcomeToPhaseOutcome(PlanOutcome plan_outcome); 77 78 absl::Status ConvertPlanOutcomeToStatus(engine::PlanOutcome outcome); 79 80 } // namespace engine 81 } // namespace client 82 } // namespace fcp 83 84 #endif // FCP_CLIENT_ENGINE_COMMON_H_ 85