xref: /aosp_15_r20/external/federated-compute/fcp/secagg/client/secagg_client_state.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/client/secagg_client_state.h"
18 
19 #include <string>
20 #include <utility>
21 
22 #include "fcp/base/monitoring.h"
23 #include "fcp/secagg/client/state_transition_listener_interface.h"
24 #include "fcp/secagg/shared/input_vector_specification.h"
25 #include "fcp/secagg/shared/secagg_messages.pb.h"
26 #include "fcp/secagg/shared/secagg_vector.h"
27 
28 // The methods implemented here should be overridden by state classes from
29 // which these transitions are valid, and inherited by state classes from which
30 // they are invalid. For example, only round 0 classes should override the Start
31 // method.
32 //
// Methods that return booleans should only be overridden by state classes for
// which they will return true.
35 
36 namespace fcp {
37 namespace secagg {
38 
SecAggClientState(std::unique_ptr<SendToServerInterface> sender,std::unique_ptr<StateTransitionListenerInterface> transition_listener,ClientState state)39 SecAggClientState::SecAggClientState(
40     std::unique_ptr<SendToServerInterface> sender,
41     std::unique_ptr<StateTransitionListenerInterface> transition_listener,
42     ClientState state)
43     : sender_(std::move(sender)),
44       transition_listener_(std::move(transition_listener)),
45       state_(state) {
46   transition_listener_->Transition(state_);
47 }
48 
Start()49 StatusOr<std::unique_ptr<SecAggClientState> > SecAggClientState::Start() {
50   return FCP_STATUS(FAILED_PRECONDITION)
51          << "An illegal start transition was attempted from state "
52          << StateName();
53 }
54 
HandleMessage(const ServerToClientWrapperMessage & message)55 StatusOr<std::unique_ptr<SecAggClientState> > SecAggClientState::HandleMessage(
56     const ServerToClientWrapperMessage& message) {
57   if (message.message_content_case() ==
58       ServerToClientWrapperMessage::MESSAGE_CONTENT_NOT_SET) {
59     return FCP_STATUS(FAILED_PRECONDITION)
60            << "Client received a message of unknown type but was in state "
61            << StateName();
62   } else {
63     return FCP_STATUS(FAILED_PRECONDITION)
64            << "Client received a message of type "
65            << message.message_content_case() << " but was in state "
66            << StateName();
67   }
68 }
69 
SetInput(std::unique_ptr<SecAggVectorMap> input_map)70 StatusOr<std::unique_ptr<SecAggClientState> > SecAggClientState::SetInput(
71     std::unique_ptr<SecAggVectorMap> input_map) {
72   return FCP_STATUS(FAILED_PRECONDITION)
73          << "An illegal input transition was attempted from state "
74          << StateName();
75 }
76 
Abort(const std::string & reason)77 StatusOr<std::unique_ptr<SecAggClientState> > SecAggClientState::Abort(
78     const std::string& reason) {
79   return FCP_STATUS(FAILED_PRECONDITION)
80          << "The client was already in terminal state " << StateName()
81          << " but received an abort with message: " << reason;
82 }
83 
IsAborted() const84 bool SecAggClientState::IsAborted() const { return false; }
85 
IsCompletedSuccessfully() const86 bool SecAggClientState::IsCompletedSuccessfully() const { return false; }
87 
ErrorMessage() const88 StatusOr<std::string> SecAggClientState::ErrorMessage() const {
89   return FCP_STATUS(FAILED_PRECONDITION)
90          << "Error message requested, but client is in state " << StateName();
91 }
92 
ValidateInput(const SecAggVectorMap & input_map,const std::vector<InputVectorSpecification> & input_vector_specs)93 bool SecAggClientState::ValidateInput(
94     const SecAggVectorMap& input_map,
95     const std::vector<InputVectorSpecification>& input_vector_specs) {
96   if (input_map.size() != input_vector_specs.size()) {
97     return false;
98   }
99   for (const auto& vector_spec : input_vector_specs) {
100     auto input_vec = input_map.find(vector_spec.name());
101     if (input_vec == input_map.end() ||
102         input_vec->second.modulus() != vector_spec.modulus() ||
103         input_vec->second.num_elements() != vector_spec.length()) {
104       return false;
105     }
106   }
107 
108   return true;
109 }
110 
111 }  // namespace secagg
112 }  // namespace fcp
113