/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/common.h"
17*14675a02SAndroid Build Coastguard Worker
18*14675a02SAndroid Build Coastguard Worker #include <string>
19*14675a02SAndroid Build Coastguard Worker
20*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
21*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/protobuf/struct.pb.h"
22*14675a02SAndroid Build Coastguard Worker
23*14675a02SAndroid Build Coastguard Worker namespace fcp {
24*14675a02SAndroid Build Coastguard Worker namespace client {
25*14675a02SAndroid Build Coastguard Worker namespace engine {
26*14675a02SAndroid Build Coastguard Worker
27*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::TensorflowSpec;
28*14675a02SAndroid Build Coastguard Worker
PlanResult(PlanOutcome outcome,absl::Status status)29*14675a02SAndroid Build Coastguard Worker PlanResult::PlanResult(PlanOutcome outcome, absl::Status status)
30*14675a02SAndroid Build Coastguard Worker : outcome(outcome), original_status(std::move(status)) {
31*14675a02SAndroid Build Coastguard Worker if (outcome == PlanOutcome::kSuccess) {
32*14675a02SAndroid Build Coastguard Worker FCP_CHECK(original_status.ok());
33*14675a02SAndroid Build Coastguard Worker }
34*14675a02SAndroid Build Coastguard Worker }
35*14675a02SAndroid Build Coastguard Worker
namespace {

// Checks that `specs` (a repeated field of TensorSpecProto) contains exactly
// one TensorSpecProto for every name in `expected_names` and nothing else.
//
// `field_name` names the TensorflowSpec field being checked (used in the
// count-mismatch error) and `kind` is "input" or "output" (used in the
// per-name errors).
//
// Returns InvalidArgumentError when:
//   - the number of specs differs from the number of expected names,
//   - a spec carries a name that is not expected, or
//   - a spec name appears more than once. Given the size check, a duplicate
//     means some expected name has no spec, which would otherwise go
//     undetected.
template <typename TensorSpecs>
absl::Status ValidateTensorSpecNames(
    const TensorSpecs& specs,
    const absl::flat_hash_set<std::string>& expected_names,
    const char* field_name, const char* kind) {
  // The repeated-field size() is an int; cast to avoid a signed/unsigned
  // comparison with the set's size_t.
  if (expected_names.size() != static_cast<size_t>(specs.size())) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unexpected number of ", field_name, ": ",
                     expected_names.size(), " vs. ", specs.size()));
  }
  absl::flat_hash_set<std::string> seen_names;
  for (const tensorflow::TensorSpecProto& spec : specs) {
    if (!expected_names.contains(spec.name())) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Missing expected TensorSpecProto for ", kind, " ", spec.name()));
    }
    if (!seen_names.insert(spec.name()).second) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Duplicate TensorSpecProto for ", kind, " ", spec.name()));
    }
  }
  return absl::OkStatus();
}

}  // namespace

// Validates that `tensorflow_spec` declares exactly one TensorSpecProto per
// expected input tensor name and per output name.
//
// `expected_input_tensor_names_set` is the set of input names the caller will
// feed. `output_names` lists the outputs the caller will fetch; duplicate
// entries are collapsed before comparison.
//
// Returns OkStatus on success, or InvalidArgumentError describing the first
// mismatch found (wrong count, unexpected name, or duplicated name).
absl::Status ValidateTensorflowSpec(
    const TensorflowSpec& tensorflow_spec,
    const absl::flat_hash_set<std::string>& expected_input_tensor_names_set,
    const std::vector<std::string>& output_names) {
  // Check that all inputs have corresponding TensorSpecProtos.
  absl::Status input_status = ValidateTensorSpecNames(
      tensorflow_spec.input_tensor_specs(), expected_input_tensor_names_set,
      "input_tensor_specs", "input");
  if (!input_status.ok()) {
    return input_status;
  }

  // Check that all outputs have corresponding TensorSpecProtos.
  const absl::flat_hash_set<std::string> expected_output_tensor_names_set(
      output_names.begin(), output_names.end());
  return ValidateTensorSpecNames(tensorflow_spec.output_tensor_specs(),
                                 expected_output_tensor_names_set,
                                 "output_tensor_specs", "output");
}
74*14675a02SAndroid Build Coastguard Worker
ConvertPlanOutcomeToPhaseOutcome(PlanOutcome plan_outcome)75*14675a02SAndroid Build Coastguard Worker PhaseOutcome ConvertPlanOutcomeToPhaseOutcome(PlanOutcome plan_outcome) {
76*14675a02SAndroid Build Coastguard Worker switch (plan_outcome) {
77*14675a02SAndroid Build Coastguard Worker case PlanOutcome::kSuccess:
78*14675a02SAndroid Build Coastguard Worker return PhaseOutcome::COMPLETED;
79*14675a02SAndroid Build Coastguard Worker case PlanOutcome::kInterrupted:
80*14675a02SAndroid Build Coastguard Worker return PhaseOutcome::INTERRUPTED;
81*14675a02SAndroid Build Coastguard Worker case PlanOutcome::kTensorflowError:
82*14675a02SAndroid Build Coastguard Worker case PlanOutcome::kInvalidArgument:
83*14675a02SAndroid Build Coastguard Worker case PlanOutcome::kExampleIteratorError:
84*14675a02SAndroid Build Coastguard Worker return PhaseOutcome::ERROR;
85*14675a02SAndroid Build Coastguard Worker }
86*14675a02SAndroid Build Coastguard Worker }
87*14675a02SAndroid Build Coastguard Worker
ConvertPlanOutcomeToStatus(PlanOutcome outcome)88*14675a02SAndroid Build Coastguard Worker absl::Status ConvertPlanOutcomeToStatus(PlanOutcome outcome) {
89*14675a02SAndroid Build Coastguard Worker switch (outcome) {
90*14675a02SAndroid Build Coastguard Worker case PlanOutcome::kSuccess:
91*14675a02SAndroid Build Coastguard Worker return absl::OkStatus();
92*14675a02SAndroid Build Coastguard Worker case PlanOutcome::kTensorflowError:
93*14675a02SAndroid Build Coastguard Worker case PlanOutcome::kInvalidArgument:
94*14675a02SAndroid Build Coastguard Worker case PlanOutcome::kExampleIteratorError:
95*14675a02SAndroid Build Coastguard Worker return absl::InternalError("");
96*14675a02SAndroid Build Coastguard Worker case PlanOutcome::kInterrupted:
97*14675a02SAndroid Build Coastguard Worker return absl::CancelledError("");
98*14675a02SAndroid Build Coastguard Worker }
99*14675a02SAndroid Build Coastguard Worker }
100*14675a02SAndroid Build Coastguard Worker
101*14675a02SAndroid Build Coastguard Worker } // namespace engine
102*14675a02SAndroid Build Coastguard Worker } // namespace client
103*14675a02SAndroid Build Coastguard Worker } // namespace fcp
104