xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/common.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2021 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/common.h"
17*14675a02SAndroid Build Coastguard Worker 
18*14675a02SAndroid Build Coastguard Worker #include <string>
19*14675a02SAndroid Build Coastguard Worker 
20*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
21*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/protobuf/struct.pb.h"
22*14675a02SAndroid Build Coastguard Worker 
23*14675a02SAndroid Build Coastguard Worker namespace fcp {
24*14675a02SAndroid Build Coastguard Worker namespace client {
25*14675a02SAndroid Build Coastguard Worker namespace engine {
26*14675a02SAndroid Build Coastguard Worker 
27*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::TensorflowSpec;
28*14675a02SAndroid Build Coastguard Worker 
PlanResult(PlanOutcome outcome,absl::Status status)29*14675a02SAndroid Build Coastguard Worker PlanResult::PlanResult(PlanOutcome outcome, absl::Status status)
30*14675a02SAndroid Build Coastguard Worker     : outcome(outcome), original_status(std::move(status)) {
31*14675a02SAndroid Build Coastguard Worker   if (outcome == PlanOutcome::kSuccess) {
32*14675a02SAndroid Build Coastguard Worker     FCP_CHECK(original_status.ok());
33*14675a02SAndroid Build Coastguard Worker   }
34*14675a02SAndroid Build Coastguard Worker }
35*14675a02SAndroid Build Coastguard Worker 
ValidateTensorflowSpec(const TensorflowSpec & tensorflow_spec,const absl::flat_hash_set<std::string> & expected_input_tensor_names_set,const std::vector<std::string> & output_names)36*14675a02SAndroid Build Coastguard Worker absl::Status ValidateTensorflowSpec(
37*14675a02SAndroid Build Coastguard Worker     const TensorflowSpec& tensorflow_spec,
38*14675a02SAndroid Build Coastguard Worker     const absl::flat_hash_set<std::string>& expected_input_tensor_names_set,
39*14675a02SAndroid Build Coastguard Worker     const std::vector<std::string>& output_names) {
40*14675a02SAndroid Build Coastguard Worker   // Check that all inputs have corresponding TensorSpecProtos.
41*14675a02SAndroid Build Coastguard Worker   if (expected_input_tensor_names_set.size() !=
42*14675a02SAndroid Build Coastguard Worker       tensorflow_spec.input_tensor_specs_size()) {
43*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError(
44*14675a02SAndroid Build Coastguard Worker         "Unexpected number of input_tensor_specs");
45*14675a02SAndroid Build Coastguard Worker   }
46*14675a02SAndroid Build Coastguard Worker 
47*14675a02SAndroid Build Coastguard Worker   for (const tensorflow::TensorSpecProto& it :
48*14675a02SAndroid Build Coastguard Worker        tensorflow_spec.input_tensor_specs()) {
49*14675a02SAndroid Build Coastguard Worker     if (!expected_input_tensor_names_set.contains(it.name())) {
50*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError(absl::StrCat(
51*14675a02SAndroid Build Coastguard Worker           "Missing expected TensorSpecProto for input ", it.name()));
52*14675a02SAndroid Build Coastguard Worker     }
53*14675a02SAndroid Build Coastguard Worker   }
54*14675a02SAndroid Build Coastguard Worker   // Check that all outputs have corresponding TensorSpecProtos.
55*14675a02SAndroid Build Coastguard Worker   absl::flat_hash_set<std::string> expected_output_tensor_names_set(
56*14675a02SAndroid Build Coastguard Worker       output_names.begin(), output_names.end());
57*14675a02SAndroid Build Coastguard Worker   if (expected_output_tensor_names_set.size() !=
58*14675a02SAndroid Build Coastguard Worker       tensorflow_spec.output_tensor_specs_size()) {
59*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError(
60*14675a02SAndroid Build Coastguard Worker         absl::StrCat("Unexpected number of output_tensor_specs: ",
61*14675a02SAndroid Build Coastguard Worker                      expected_output_tensor_names_set.size(), " vs. ",
62*14675a02SAndroid Build Coastguard Worker                      tensorflow_spec.output_tensor_specs_size()));
63*14675a02SAndroid Build Coastguard Worker   }
64*14675a02SAndroid Build Coastguard Worker   for (const tensorflow::TensorSpecProto& it :
65*14675a02SAndroid Build Coastguard Worker        tensorflow_spec.output_tensor_specs()) {
66*14675a02SAndroid Build Coastguard Worker     if (!expected_output_tensor_names_set.count(it.name())) {
67*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError(absl::StrCat(
68*14675a02SAndroid Build Coastguard Worker           "Missing expected TensorSpecProto for output ", it.name()));
69*14675a02SAndroid Build Coastguard Worker     }
70*14675a02SAndroid Build Coastguard Worker   }
71*14675a02SAndroid Build Coastguard Worker 
72*14675a02SAndroid Build Coastguard Worker   return absl::OkStatus();
73*14675a02SAndroid Build Coastguard Worker }
74*14675a02SAndroid Build Coastguard Worker 
ConvertPlanOutcomeToPhaseOutcome(PlanOutcome plan_outcome)75*14675a02SAndroid Build Coastguard Worker PhaseOutcome ConvertPlanOutcomeToPhaseOutcome(PlanOutcome plan_outcome) {
76*14675a02SAndroid Build Coastguard Worker   switch (plan_outcome) {
77*14675a02SAndroid Build Coastguard Worker     case PlanOutcome::kSuccess:
78*14675a02SAndroid Build Coastguard Worker       return PhaseOutcome::COMPLETED;
79*14675a02SAndroid Build Coastguard Worker     case PlanOutcome::kInterrupted:
80*14675a02SAndroid Build Coastguard Worker       return PhaseOutcome::INTERRUPTED;
81*14675a02SAndroid Build Coastguard Worker     case PlanOutcome::kTensorflowError:
82*14675a02SAndroid Build Coastguard Worker     case PlanOutcome::kInvalidArgument:
83*14675a02SAndroid Build Coastguard Worker     case PlanOutcome::kExampleIteratorError:
84*14675a02SAndroid Build Coastguard Worker       return PhaseOutcome::ERROR;
85*14675a02SAndroid Build Coastguard Worker   }
86*14675a02SAndroid Build Coastguard Worker }
87*14675a02SAndroid Build Coastguard Worker 
ConvertPlanOutcomeToStatus(PlanOutcome outcome)88*14675a02SAndroid Build Coastguard Worker absl::Status ConvertPlanOutcomeToStatus(PlanOutcome outcome) {
89*14675a02SAndroid Build Coastguard Worker   switch (outcome) {
90*14675a02SAndroid Build Coastguard Worker     case PlanOutcome::kSuccess:
91*14675a02SAndroid Build Coastguard Worker       return absl::OkStatus();
92*14675a02SAndroid Build Coastguard Worker     case PlanOutcome::kTensorflowError:
93*14675a02SAndroid Build Coastguard Worker     case PlanOutcome::kInvalidArgument:
94*14675a02SAndroid Build Coastguard Worker     case PlanOutcome::kExampleIteratorError:
95*14675a02SAndroid Build Coastguard Worker       return absl::InternalError("");
96*14675a02SAndroid Build Coastguard Worker     case PlanOutcome::kInterrupted:
97*14675a02SAndroid Build Coastguard Worker       return absl::CancelledError("");
98*14675a02SAndroid Build Coastguard Worker   }
99*14675a02SAndroid Build Coastguard Worker }
100*14675a02SAndroid Build Coastguard Worker 
101*14675a02SAndroid Build Coastguard Worker }  // namespace engine
102*14675a02SAndroid Build Coastguard Worker }  // namespace client
103*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
104