1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/engine/common.h"
17
18 #include <string>
19
20 #include "fcp/base/monitoring.h"
21 #include "tensorflow/core/protobuf/struct.pb.h"
22
23 namespace fcp {
24 namespace client {
25 namespace engine {
26
27 using ::google::internal::federated::plan::TensorflowSpec;
28
PlanResult(PlanOutcome outcome,absl::Status status)29 PlanResult::PlanResult(PlanOutcome outcome, absl::Status status)
30 : outcome(outcome), original_status(std::move(status)) {
31 if (outcome == PlanOutcome::kSuccess) {
32 FCP_CHECK(original_status.ok());
33 }
34 }
35
namespace {

// Checks that every TensorSpecProto in `specs` names a tensor contained in
// `expected_names`, and that no tensor name is covered by more than one spec.
// Together with a caller-side check that `specs` and `expected_names` have the
// same size, this guarantees each expected name has exactly one spec.
// `kind` (e.g. "input" / "output") is only used to build error messages.
template <typename SpecList>
absl::Status ValidateTensorSpecNames(
    const SpecList& specs,
    const absl::flat_hash_set<std::string>& expected_names,
    const char* kind) {
  absl::flat_hash_set<std::string> seen_names;
  for (const tensorflow::TensorSpecProto& spec : specs) {
    if (!expected_names.contains(spec.name())) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Missing expected TensorSpecProto for ", kind, " ", spec.name()));
    }
    // A repeated spec name would otherwise mask a missing one while still
    // satisfying the size check, e.g. expected {a, b} vs. specs [a, a].
    if (!seen_names.insert(spec.name()).second) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Duplicate TensorSpecProto for ", kind, " ", spec.name()));
    }
  }
  return absl::OkStatus();
}

}  // namespace

// Validates that `tensorflow_spec` declares exactly one TensorSpecProto for
// each expected input tensor name and for each (deduplicated) entry of
// `output_names`.
//
// Returns InvalidArgumentError if:
//   - the number of input/output specs differs from the number of expected
//     names,
//   - a spec names a tensor that is not expected, or
//   - two specs share the same tensor name.
// Returns OkStatus otherwise.
absl::Status ValidateTensorflowSpec(
    const TensorflowSpec& tensorflow_spec,
    const absl::flat_hash_set<std::string>& expected_input_tensor_names_set,
    const std::vector<std::string>& output_names) {
  // Check that all inputs have corresponding TensorSpecProtos. The *_size()
  // accessors return int; cast to avoid a signed/unsigned comparison.
  if (expected_input_tensor_names_set.size() !=
      static_cast<size_t>(tensorflow_spec.input_tensor_specs_size())) {
    return absl::InvalidArgumentError(
        "Unexpected number of input_tensor_specs");
  }
  absl::Status input_status =
      ValidateTensorSpecNames(tensorflow_spec.input_tensor_specs(),
                              expected_input_tensor_names_set, "input");
  if (!input_status.ok()) {
    return input_status;
  }

  // Check that all outputs have corresponding TensorSpecProtos. Duplicate
  // entries in `output_names` collapse into a single expected name.
  absl::flat_hash_set<std::string> expected_output_tensor_names_set(
      output_names.begin(), output_names.end());
  if (expected_output_tensor_names_set.size() !=
      static_cast<size_t>(tensorflow_spec.output_tensor_specs_size())) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unexpected number of output_tensor_specs: ",
                     expected_output_tensor_names_set.size(), " vs. ",
                     tensorflow_spec.output_tensor_specs_size()));
  }
  return ValidateTensorSpecNames(tensorflow_spec.output_tensor_specs(),
                                 expected_output_tensor_names_set, "output");
}
74
ConvertPlanOutcomeToPhaseOutcome(PlanOutcome plan_outcome)75 PhaseOutcome ConvertPlanOutcomeToPhaseOutcome(PlanOutcome plan_outcome) {
76 switch (plan_outcome) {
77 case PlanOutcome::kSuccess:
78 return PhaseOutcome::COMPLETED;
79 case PlanOutcome::kInterrupted:
80 return PhaseOutcome::INTERRUPTED;
81 case PlanOutcome::kTensorflowError:
82 case PlanOutcome::kInvalidArgument:
83 case PlanOutcome::kExampleIteratorError:
84 return PhaseOutcome::ERROR;
85 }
86 }
87
ConvertPlanOutcomeToStatus(PlanOutcome outcome)88 absl::Status ConvertPlanOutcomeToStatus(PlanOutcome outcome) {
89 switch (outcome) {
90 case PlanOutcome::kSuccess:
91 return absl::OkStatus();
92 case PlanOutcome::kTensorflowError:
93 case PlanOutcome::kInvalidArgument:
94 case PlanOutcome::kExampleIteratorError:
95 return absl::InternalError("");
96 case PlanOutcome::kInterrupted:
97 return absl::CancelledError("");
98 }
99 }
100
101 } // namespace engine
102 } // namespace client
103 } // namespace fcp
104