xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/secagg_server_prng_running_state.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/server/secagg_server_prng_running_state.h"
18 
19 #include <functional>
20 #include <memory>
21 #include <optional>
22 #include <string>
23 #include <utility>
24 
25 #include "absl/synchronization/mutex.h"
26 #include "fcp/base/monitoring.h"
27 #include "fcp/secagg/server/secagg_server_completed_state.h"
28 #include "fcp/tracing/tracing_span.h"
29 
30 namespace fcp {
31 namespace secagg {
32 
SecAggServerPrngRunningState(std::unique_ptr<SecAggServerProtocolImpl> impl,int number_of_clients_failed_after_sending_masked_input,int number_of_clients_failed_before_sending_masked_input,int number_of_clients_terminated_without_unmasking)33 SecAggServerPrngRunningState::SecAggServerPrngRunningState(
34     std::unique_ptr<SecAggServerProtocolImpl> impl,
35     int number_of_clients_failed_after_sending_masked_input,
36     int number_of_clients_failed_before_sending_masked_input,
37     int number_of_clients_terminated_without_unmasking)
38     : SecAggServerState(number_of_clients_failed_after_sending_masked_input,
39                         number_of_clients_failed_before_sending_masked_input,
40                         number_of_clients_terminated_without_unmasking,
41                         SecAggServerStateKind::PRNG_RUNNING, std::move(impl)),
42       completion_status_(std::nullopt) {}
43 
~SecAggServerPrngRunningState()44 SecAggServerPrngRunningState::~SecAggServerPrngRunningState() {}
45 
HandleMessage(uint32_t client_id,const ClientToServerWrapperMessage & message)46 Status SecAggServerPrngRunningState::HandleMessage(
47     uint32_t client_id, const ClientToServerWrapperMessage& message) {
48   MessageReceived(message, false);  // Messages are always unexpected here.
49   if (message.has_abort()) {
50     AbortClient(client_id, "", ClientDropReason::SENT_ABORT_MESSAGE,
51                 /*notify=*/false);
52   } else {
53     AbortClient(client_id, "Non-abort message sent during PrngUnmasking step.",
54                 ClientDropReason::UNEXPECTED_MESSAGE_TYPE);
55   }
56   return FCP_STATUS(OK);
57 }
58 
HandleAbort()59 void SecAggServerPrngRunningState::HandleAbort() {
60   if (cancellation_token_) {
61     cancellation_token_->Cancel();
62   }
63 }
64 
65 StatusOr<SecAggServerProtocolImpl::PrngWorkItems>
Initialize()66 SecAggServerPrngRunningState::Initialize() {
67   // Shamir reconstruction part of PRNG
68   absl::Time reconstruction_start = absl::Now();
69   FCP_ASSIGN_OR_RETURN(auto shamir_reconstruction_result,
70                        impl()->HandleShamirReconstruction());
71   auto elapsed_millis =
72       absl::ToInt64Milliseconds(absl::Now() - reconstruction_start);
73   if (metrics()) {
74     metrics()->ShamirReconstructionTimes(elapsed_millis);
75   }
76   Trace<ShamirReconstruction>(elapsed_millis);
77 
78   // Generating workitems for PRNG computation.
79   return impl()->InitializePrng(std::move(shamir_reconstruction_result));
80 }
81 
EnterState()82 void SecAggServerPrngRunningState::EnterState() {
83   auto initialize_result = Initialize();
84 
85   if (!initialize_result.ok()) {
86     absl::MutexLock lock(&mutex_);
87     completion_status_ = initialize_result.status();
88     return;
89   }
90 
91   auto work_items = std::move(initialize_result).value();
92 
93   // Scheduling workitems to run.
94   prng_started_time_ = absl::Now();
95 
96   cancellation_token_ = impl()->StartPrng(
97       work_items, [this](Status status) { this->PrngRunnerFinished(status); });
98 }
99 
SetAsyncCallback(std::function<void ()> async_callback)100 bool SecAggServerPrngRunningState::SetAsyncCallback(
101     std::function<void()> async_callback) {
102   absl::MutexLock lock(&mutex_);
103   FCP_CHECK(async_callback != nullptr) << "async_callback is expected";
104 
105   if (completion_status_.has_value()) {
106     // PRNG computation has already finished.
107     impl()->scheduler()->ScheduleCallback(async_callback);
108   } else {
109     prng_done_callback_ = async_callback;
110   }
111   return true;
112 }
113 
PrngRunnerFinished(Status final_status)114 void SecAggServerPrngRunningState::PrngRunnerFinished(Status final_status) {
115   auto elapsed_millis =
116       absl::ToInt64Milliseconds(absl::Now() - prng_started_time_);
117   if (metrics()) {
118     metrics()->PrngExpansionTimes(elapsed_millis);
119   }
120   Trace<PrngExpansion>(elapsed_millis);
121 
122   std::function<void()> prng_done_callback;
123   {
124     absl::MutexLock lock(&mutex_);
125     completion_status_ = final_status;
126     prng_done_callback = prng_done_callback_;
127   }
128 
129   if (prng_done_callback) {
130     prng_done_callback();
131   }
132 }
133 
HandleAbortClient(uint32_t client_id,ClientDropReason reason_code)134 void SecAggServerPrngRunningState::HandleAbortClient(
135     uint32_t client_id, ClientDropReason reason_code) {
136   set_client_status(client_id,
137                     ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED);
138 }
139 
140 StatusOr<std::unique_ptr<SecAggServerState>>
ProceedToNextRound()141 SecAggServerPrngRunningState::ProceedToNextRound() {
142   // Block if StartPrng is still being called. That done to ensure that
143   // StartPrng doesn't use *this* object after it has been destroyed by
144   // the code that called ProceedToNextRound.
145   absl::MutexLock lock(&mutex_);
146 
147   if (!completion_status_.has_value()) {
148     return FCP_STATUS(UNAVAILABLE);
149   }
150 
151   // Don't send any messages; every client either got an "early success"
152   // notification at the end of Round 3, marked itself completed after sending
153   // its Round 3 message, or was already aborted.
154   if (completion_status_.value().ok()) {
155     return std::make_unique<SecAggServerCompletedState>(
156         ExitState(StateTransition::kSuccess),
157         number_of_clients_failed_after_sending_masked_input_,
158         number_of_clients_failed_before_sending_masked_input_,
159         number_of_clients_terminated_without_unmasking_);
160   } else {
161     return AbortState(std::string(completion_status_.value().message()),
162                       SecAggServerOutcome::UNHANDLED_ERROR);
163   }
164 }
165 
ReadyForNextRound() const166 bool SecAggServerPrngRunningState::ReadyForNextRound() const {
167   absl::MutexLock lock(&mutex_);
168   return completion_status_.has_value();
169 }
170 
NumberOfIncludedInputs() const171 int SecAggServerPrngRunningState::NumberOfIncludedInputs() const {
172   return total_number_of_clients() -
173          number_of_clients_failed_before_sending_masked_input_;
174 }
175 
IsNumberOfIncludedInputsCommitted() const176 bool SecAggServerPrngRunningState::IsNumberOfIncludedInputsCommitted() const {
177   return true;
178 }
179 
180 }  // namespace secagg
181 }  // namespace fcp
182