1 /*
2 * Copyright 2018 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/server/secagg_server_prng_running_state.h"
18
19 #include <functional>
20 #include <memory>
21 #include <optional>
22 #include <string>
23 #include <utility>
24
25 #include "absl/synchronization/mutex.h"
26 #include "fcp/base/monitoring.h"
27 #include "fcp/secagg/server/secagg_server_completed_state.h"
28 #include "fcp/tracing/tracing_span.h"
29
30 namespace fcp {
31 namespace secagg {
32
SecAggServerPrngRunningState(std::unique_ptr<SecAggServerProtocolImpl> impl,int number_of_clients_failed_after_sending_masked_input,int number_of_clients_failed_before_sending_masked_input,int number_of_clients_terminated_without_unmasking)33 SecAggServerPrngRunningState::SecAggServerPrngRunningState(
34 std::unique_ptr<SecAggServerProtocolImpl> impl,
35 int number_of_clients_failed_after_sending_masked_input,
36 int number_of_clients_failed_before_sending_masked_input,
37 int number_of_clients_terminated_without_unmasking)
38 : SecAggServerState(number_of_clients_failed_after_sending_masked_input,
39 number_of_clients_failed_before_sending_masked_input,
40 number_of_clients_terminated_without_unmasking,
41 SecAggServerStateKind::PRNG_RUNNING, std::move(impl)),
42 completion_status_(std::nullopt) {}
43
~SecAggServerPrngRunningState()44 SecAggServerPrngRunningState::~SecAggServerPrngRunningState() {}
45
HandleMessage(uint32_t client_id,const ClientToServerWrapperMessage & message)46 Status SecAggServerPrngRunningState::HandleMessage(
47 uint32_t client_id, const ClientToServerWrapperMessage& message) {
48 MessageReceived(message, false); // Messages are always unexpected here.
49 if (message.has_abort()) {
50 AbortClient(client_id, "", ClientDropReason::SENT_ABORT_MESSAGE,
51 /*notify=*/false);
52 } else {
53 AbortClient(client_id, "Non-abort message sent during PrngUnmasking step.",
54 ClientDropReason::UNEXPECTED_MESSAGE_TYPE);
55 }
56 return FCP_STATUS(OK);
57 }
58
HandleAbort()59 void SecAggServerPrngRunningState::HandleAbort() {
60 if (cancellation_token_) {
61 cancellation_token_->Cancel();
62 }
63 }
64
65 StatusOr<SecAggServerProtocolImpl::PrngWorkItems>
Initialize()66 SecAggServerPrngRunningState::Initialize() {
67 // Shamir reconstruction part of PRNG
68 absl::Time reconstruction_start = absl::Now();
69 FCP_ASSIGN_OR_RETURN(auto shamir_reconstruction_result,
70 impl()->HandleShamirReconstruction());
71 auto elapsed_millis =
72 absl::ToInt64Milliseconds(absl::Now() - reconstruction_start);
73 if (metrics()) {
74 metrics()->ShamirReconstructionTimes(elapsed_millis);
75 }
76 Trace<ShamirReconstruction>(elapsed_millis);
77
78 // Generating workitems for PRNG computation.
79 return impl()->InitializePrng(std::move(shamir_reconstruction_result));
80 }
81
EnterState()82 void SecAggServerPrngRunningState::EnterState() {
83 auto initialize_result = Initialize();
84
85 if (!initialize_result.ok()) {
86 absl::MutexLock lock(&mutex_);
87 completion_status_ = initialize_result.status();
88 return;
89 }
90
91 auto work_items = std::move(initialize_result).value();
92
93 // Scheduling workitems to run.
94 prng_started_time_ = absl::Now();
95
96 cancellation_token_ = impl()->StartPrng(
97 work_items, [this](Status status) { this->PrngRunnerFinished(status); });
98 }
99
SetAsyncCallback(std::function<void ()> async_callback)100 bool SecAggServerPrngRunningState::SetAsyncCallback(
101 std::function<void()> async_callback) {
102 absl::MutexLock lock(&mutex_);
103 FCP_CHECK(async_callback != nullptr) << "async_callback is expected";
104
105 if (completion_status_.has_value()) {
106 // PRNG computation has already finished.
107 impl()->scheduler()->ScheduleCallback(async_callback);
108 } else {
109 prng_done_callback_ = async_callback;
110 }
111 return true;
112 }
113
PrngRunnerFinished(Status final_status)114 void SecAggServerPrngRunningState::PrngRunnerFinished(Status final_status) {
115 auto elapsed_millis =
116 absl::ToInt64Milliseconds(absl::Now() - prng_started_time_);
117 if (metrics()) {
118 metrics()->PrngExpansionTimes(elapsed_millis);
119 }
120 Trace<PrngExpansion>(elapsed_millis);
121
122 std::function<void()> prng_done_callback;
123 {
124 absl::MutexLock lock(&mutex_);
125 completion_status_ = final_status;
126 prng_done_callback = prng_done_callback_;
127 }
128
129 if (prng_done_callback) {
130 prng_done_callback();
131 }
132 }
133
HandleAbortClient(uint32_t client_id,ClientDropReason reason_code)134 void SecAggServerPrngRunningState::HandleAbortClient(
135 uint32_t client_id, ClientDropReason reason_code) {
136 set_client_status(client_id,
137 ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED);
138 }
139
140 StatusOr<std::unique_ptr<SecAggServerState>>
ProceedToNextRound()141 SecAggServerPrngRunningState::ProceedToNextRound() {
142 // Block if StartPrng is still being called. That done to ensure that
143 // StartPrng doesn't use *this* object after it has been destroyed by
144 // the code that called ProceedToNextRound.
145 absl::MutexLock lock(&mutex_);
146
147 if (!completion_status_.has_value()) {
148 return FCP_STATUS(UNAVAILABLE);
149 }
150
151 // Don't send any messages; every client either got an "early success"
152 // notification at the end of Round 3, marked itself completed after sending
153 // its Round 3 message, or was already aborted.
154 if (completion_status_.value().ok()) {
155 return std::make_unique<SecAggServerCompletedState>(
156 ExitState(StateTransition::kSuccess),
157 number_of_clients_failed_after_sending_masked_input_,
158 number_of_clients_failed_before_sending_masked_input_,
159 number_of_clients_terminated_without_unmasking_);
160 } else {
161 return AbortState(std::string(completion_status_.value().message()),
162 SecAggServerOutcome::UNHANDLED_ERROR);
163 }
164 }
165
ReadyForNextRound() const166 bool SecAggServerPrngRunningState::ReadyForNextRound() const {
167 absl::MutexLock lock(&mutex_);
168 return completion_status_.has_value();
169 }
170
NumberOfIncludedInputs() const171 int SecAggServerPrngRunningState::NumberOfIncludedInputs() const {
172 return total_number_of_clients() -
173 number_of_clients_failed_before_sending_masked_input_;
174 }
175
IsNumberOfIncludedInputsCommitted() const176 bool SecAggServerPrngRunningState::IsNumberOfIncludedInputsCommitted() const {
177 return true;
178 }
179
180 } // namespace secagg
181 } // namespace fcp
182