xref: /aosp_15_r20/external/federated-compute/fcp/secagg/client/secagg_client.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/client/secagg_client.h"
18 
19 #include <memory>
20 #include <string>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/synchronization/mutex.h"
26 #include "fcp/base/monitoring.h"
27 #include "fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state.h"
28 #include "fcp/secagg/client/secagg_client_state.h"
29 #include "fcp/secagg/client/send_to_server_interface.h"
30 #include "fcp/secagg/client/state_transition_listener_interface.h"
31 #include "fcp/secagg/shared/aes_prng_factory.h"
32 #include "fcp/secagg/shared/async_abort.h"
33 #include "fcp/secagg/shared/input_vector_specification.h"
34 #include "fcp/secagg/shared/prng.h"
35 #include "fcp/secagg/shared/secagg_messages.pb.h"
36 #include "fcp/secagg/shared/secagg_vector.h"
37 
38 namespace fcp {
39 namespace secagg {
40 
SecAggClient(int max_neighbors_expected,int minimum_surviving_neighbors_for_reconstruction,std::vector<InputVectorSpecification> input_vector_specs,std::unique_ptr<SecurePrng> prng,std::unique_ptr<SendToServerInterface> sender,std::unique_ptr<StateTransitionListenerInterface> transition_listener,std::unique_ptr<AesPrngFactory> prng_factory,std::atomic<std::string * > * abort_signal_for_test)41 SecAggClient::SecAggClient(
42     int max_neighbors_expected,
43     int minimum_surviving_neighbors_for_reconstruction,
44     std::vector<InputVectorSpecification> input_vector_specs,
45     std::unique_ptr<SecurePrng> prng,
46     std::unique_ptr<SendToServerInterface> sender,
47     std::unique_ptr<StateTransitionListenerInterface> transition_listener,
48     std::unique_ptr<AesPrngFactory> prng_factory,
49     std::atomic<std::string*>* abort_signal_for_test)
50     : mu_(),
51       abort_signal_(nullptr),
52       async_abort_(abort_signal_for_test ? abort_signal_for_test
53                                          : &abort_signal_),
54       state_(std::make_unique<SecAggClientR0AdvertiseKeysInputNotSetState>(
55           max_neighbors_expected,
56           minimum_surviving_neighbors_for_reconstruction,
57           std::make_unique<std::vector<InputVectorSpecification> >(
58               std::move(input_vector_specs)),
59           std::move(prng), std::move(sender), std::move(transition_listener),
60           std::move(prng_factory), &async_abort_)) {}
61 
Start()62 Status SecAggClient::Start() {
63   absl::WriterMutexLock _(&mu_);
64   auto state_or_error = state_->Start();
65   if (state_or_error.ok()) {
66     state_ = std::move(state_or_error.value());
67   }
68   return state_or_error.status();
69 }
70 
Abort()71 Status SecAggClient::Abort() { return Abort("unknown reason"); }
72 
Abort(const std::string & reason)73 Status SecAggClient::Abort(const std::string& reason) {
74   async_abort_.Abort(reason);
75   absl::WriterMutexLock _(&mu_);
76   if (state_->IsAborted() || state_->IsCompletedSuccessfully())
77     return FCP_STATUS(OK);
78 
79   auto state_or_error = state_->Abort(reason);
80   if (state_or_error.ok()) {
81     state_ = std::move(state_or_error.value());
82   }
83   return state_or_error.status();
84 }
85 
SetInput(std::unique_ptr<SecAggVectorMap> input_map)86 Status SecAggClient::SetInput(std::unique_ptr<SecAggVectorMap> input_map) {
87   absl::WriterMutexLock _(&mu_);
88   auto state_or_error = state_->SetInput(std::move(input_map));
89   if (state_or_error.ok()) {
90     state_ = std::move(state_or_error.value());
91   }
92   return state_or_error.status();
93 }
94 
ReceiveMessage(const ServerToClientWrapperMessage & incoming)95 StatusOr<bool> SecAggClient::ReceiveMessage(
96     const ServerToClientWrapperMessage& incoming) {
97   absl::WriterMutexLock _(&mu_);
98   auto state_or_error = state_->HandleMessage(incoming);
99   if (state_or_error.ok()) {
100     state_ = std::move(state_or_error.value());
101     // Return true iff neither aborted nor completed.
102     return !(state_->IsAborted() || state_->IsCompletedSuccessfully());
103   } else {
104     return state_or_error.status();
105   }
106 }
107 
ErrorMessage() const108 StatusOr<std::string> SecAggClient::ErrorMessage() const {
109   absl::ReaderMutexLock _(&mu_);
110   return state_->ErrorMessage();
111 }
112 
IsAborted() const113 bool SecAggClient::IsAborted() const {
114   absl::ReaderMutexLock _(&mu_);
115   return state_->IsAborted();
116 }
117 
IsCompletedSuccessfully() const118 bool SecAggClient::IsCompletedSuccessfully() const {
119   absl::ReaderMutexLock _(&mu_);
120   return state_->IsCompletedSuccessfully();
121 }
122 
State() const123 std::string SecAggClient::State() const {
124   absl::ReaderMutexLock _(&mu_);
125   return state_->StateName();
126 }
127 
128 }  // namespace secagg
129 }  // namespace fcp
130