xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/common.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "fcp/client/engine/common.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "fcp/base/monitoring.h"
#include "tensorflow/core/protobuf/struct.pb.h"
22 
23 namespace fcp {
24 namespace client {
25 namespace engine {
26 
27 using ::google::internal::federated::plan::TensorflowSpec;
28 
PlanResult(PlanOutcome outcome,absl::Status status)29 PlanResult::PlanResult(PlanOutcome outcome, absl::Status status)
30     : outcome(outcome), original_status(std::move(status)) {
31   if (outcome == PlanOutcome::kSuccess) {
32     FCP_CHECK(original_status.ok());
33   }
34 }
35 
namespace {

// Checks that the TensorSpecProtos in `specs` are named by exactly the set
// `expected_names`:
// - the number of specs equals the number of expected names;
// - every spec name is an expected name;
// - no spec name appears more than once.
// `kind` ("input" or "output") is only used to build the error messages.
//
// `TensorSpecs` is any range of tensorflow::TensorSpecProto that also has a
// size() method, such as the repeated spec fields of TensorflowSpec.
//
// Returns absl::InvalidArgumentError describing the first mismatch found, or
// absl::OkStatus() if the names match.
template <typename TensorSpecs>
absl::Status ValidateTensorSpecNames(
    const TensorSpecs& specs,
    const absl::flat_hash_set<std::string>& expected_names,
    absl::string_view kind) {
  // The repeated field's size() is a non-negative int; cast it so the
  // comparison does not mix signed and unsigned types.
  if (expected_names.size() != static_cast<size_t>(specs.size())) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unexpected number of ", kind, "_tensor_specs: ",
                     expected_names.size(), " vs. ", specs.size()));
  }
  absl::flat_hash_set<std::string> seen_names;
  for (const tensorflow::TensorSpecProto& spec : specs) {
    if (!expected_names.contains(spec.name())) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Missing expected TensorSpecProto for ", kind, " ", spec.name()));
    }
    // Without this check a repeated spec name could make the counts match
    // while some other expected name has no TensorSpecProto at all.
    if (!seen_names.insert(spec.name()).second) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Duplicate TensorSpecProto for ", kind, " ", spec.name()));
    }
  }
  return absl::OkStatus();
}

}  // namespace

// Validates that `tensorflow_spec` declares a TensorSpecProto for exactly the
// names in `expected_input_tensor_names_set` (inputs) and for exactly the
// distinct names in `output_names` (outputs). Duplicate entries in
// `output_names` are collapsed into a set before comparison.
//
// Inputs are checked before outputs. Returns absl::InvalidArgumentError if a
// spec count differs from the number of expected names, if a spec names an
// unexpected tensor, or if a spec name is repeated; absl::OkStatus() otherwise.
absl::Status ValidateTensorflowSpec(
    const TensorflowSpec& tensorflow_spec,
    const absl::flat_hash_set<std::string>& expected_input_tensor_names_set,
    const std::vector<std::string>& output_names) {
  // Check that all inputs have corresponding TensorSpecProtos.
  absl::Status input_status = ValidateTensorSpecNames(
      tensorflow_spec.input_tensor_specs(), expected_input_tensor_names_set,
      "input");
  if (!input_status.ok()) {
    return input_status;
  }

  // Check that all outputs have corresponding TensorSpecProtos.
  const absl::flat_hash_set<std::string> expected_output_tensor_names_set(
      output_names.begin(), output_names.end());
  return ValidateTensorSpecNames(tensorflow_spec.output_tensor_specs(),
                                 expected_output_tensor_names_set, "output");
}
74 
ConvertPlanOutcomeToPhaseOutcome(PlanOutcome plan_outcome)75 PhaseOutcome ConvertPlanOutcomeToPhaseOutcome(PlanOutcome plan_outcome) {
76   switch (plan_outcome) {
77     case PlanOutcome::kSuccess:
78       return PhaseOutcome::COMPLETED;
79     case PlanOutcome::kInterrupted:
80       return PhaseOutcome::INTERRUPTED;
81     case PlanOutcome::kTensorflowError:
82     case PlanOutcome::kInvalidArgument:
83     case PlanOutcome::kExampleIteratorError:
84       return PhaseOutcome::ERROR;
85   }
86 }
87 
ConvertPlanOutcomeToStatus(PlanOutcome outcome)88 absl::Status ConvertPlanOutcomeToStatus(PlanOutcome outcome) {
89   switch (outcome) {
90     case PlanOutcome::kSuccess:
91       return absl::OkStatus();
92     case PlanOutcome::kTensorflowError:
93     case PlanOutcome::kInvalidArgument:
94     case PlanOutcome::kExampleIteratorError:
95       return absl::InternalError("");
96     case PlanOutcome::kInterrupted:
97       return absl::CancelledError("");
98   }
99 }
100 
101 }  // namespace engine
102 }  // namespace client
103 }  // namespace fcp
104