1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
18
19 #include <algorithm>
20 #include <cstddef>
21 #include <functional>
22 #include <iterator>
23 #include <memory>
24 #include <string>
25 #include <utility>
26 #include <vector>
27
28 #include "absl/container/node_hash_map.h"
29 #include "absl/status/status.h"
30 #include "fcp/base/monitoring.h"
31 #include "fcp/secagg/server/experiments_names.h"
32 #include "fcp/secagg/server/secagg_scheduler.h"
33 #include "fcp/secagg/shared/map_of_masks.h"
34 #include "fcp/secagg/shared/math.h"
35 #include "fcp/secagg/shared/secagg_vector.h"
36
37 namespace {
38
AddReduce(std::vector<std::unique_ptr<fcp::secagg::SecAggVectorMap>> vector_of_maps)39 std::unique_ptr<fcp::secagg::SecAggUnpackedVectorMap> AddReduce(
40 std::vector<std::unique_ptr<fcp::secagg::SecAggVectorMap>> vector_of_maps) {
41 FCP_CHECK(!vector_of_maps.empty());
42 // Initialize result
43 auto result = std::make_unique<fcp::secagg::SecAggUnpackedVectorMap>(
44 *vector_of_maps[0]);
45 // Reduce vector of maps
46 for (int i = 1; i < vector_of_maps.size(); ++i) {
47 result->Add(*vector_of_maps[i]);
48 }
49 return result;
50 }
51
52 // Initializes a SecAggUnpackedVectorMap object according to a provided input
53 // vector specification
InitializeVectorMap(const std::vector<fcp::secagg::InputVectorSpecification> & input_vector_specs)54 std::unique_ptr<fcp::secagg::SecAggUnpackedVectorMap> InitializeVectorMap(
55 const std::vector<fcp::secagg::InputVectorSpecification>&
56 input_vector_specs) {
57 auto vector_map = std::make_unique<fcp::secagg::SecAggUnpackedVectorMap>();
58 for (const fcp::secagg::InputVectorSpecification& vector_spec :
59 input_vector_specs) {
60 vector_map->emplace(vector_spec.name(),
61 fcp::secagg::SecAggUnpackedVector(
62 vector_spec.length(), vector_spec.modulus()));
63 }
64 return vector_map;
65 }
66
67 } // namespace
68
69 namespace fcp {
70 namespace secagg {
71
// Upper bound on how many PRNG keys a single scheduled mask-generation job
// expands in StartPrng; longer key lists are split across several jobs.
constexpr int kPrngBatchSize = 32;
74
75 std::shared_ptr<Accumulator<SecAggUnpackedVectorMap>>
SetupMaskedInputCollection()76 AesSecAggServerProtocolImpl::SetupMaskedInputCollection() {
77 if (!experiments()->IsEnabled(kSecAggAsyncRound2Experiment)) {
78 // Prepare the sum of masked input vectors with all zeroes.
79 masked_input_ = InitializeVectorMap(input_vector_specs());
80 } else {
81 auto initial_value = InitializeVectorMap(input_vector_specs());
82 masked_input_accumulator_ =
83 scheduler()->CreateAccumulator<SecAggUnpackedVectorMap>(
84 std::move(initial_value), SecAggUnpackedVectorMap::AddMaps);
85 }
86 return masked_input_accumulator_;
87 }
88
89 std::vector<std::unique_ptr<SecAggVectorMap>>
TakeMaskedInputQueue()90 AesSecAggServerProtocolImpl::TakeMaskedInputQueue() {
91 absl::MutexLock lock(&mutex_);
92 return std::move(masked_input_queue_);
93 }
94
HandleMaskedInputCollectionResponse(std::unique_ptr<MaskedInputCollectionResponse> masked_input_response)95 Status AesSecAggServerProtocolImpl::HandleMaskedInputCollectionResponse(
96 std::unique_ptr<MaskedInputCollectionResponse> masked_input_response) {
97 FCP_CHECK(masked_input_response);
98 // Make sure the received vectors match the specification.
99 if (masked_input_response->vectors().size() != input_vector_specs().size()) {
100 return ::absl::InvalidArgumentError(
101 "Masked input does not match input vector specification - "
102 "wrong number of vectors.");
103 }
104 auto& input_vectors = *masked_input_response->mutable_vectors();
105 auto checked_masked_vectors = std::make_unique<SecAggVectorMap>();
106 for (const InputVectorSpecification& vector_spec : input_vector_specs()) {
107 auto masked_vector = input_vectors.find(vector_spec.name());
108 if (masked_vector == input_vectors.end()) {
109 return ::absl::InvalidArgumentError(
110 "Masked input does not match input vector specification - wrong "
111 "vector names.");
112 }
113 // TODO(team): This does not appear to be properly covered by unit
114 // tests.
115 int bit_width = SecAggVector::GetBitWidth(vector_spec.modulus());
116 if (masked_vector->second.encoded_vector().size() !=
117 DivideRoundUp(vector_spec.length() * bit_width, 8)) {
118 return ::absl::InvalidArgumentError(
119 "Masked input does not match input vector specification - vector is "
120 "wrong size.");
121 }
122 checked_masked_vectors->emplace(
123 vector_spec.name(),
124 SecAggVector(std::move(*masked_vector->second.mutable_encoded_vector()),
125 vector_spec.modulus(), vector_spec.length()));
126 }
127
128 if (experiments()->IsEnabled(kSecAggAsyncRound2Experiment)) {
129 // If async processing is enabled we queue the client message. Moreover, if
130 // the queue we found was empty this means that it has been taken by an
131 // asynchronous aggregation task. In that case, we schedule an aggregation
132 // task to process the queue that we just initiated, which will happen
133 // eventually.
134 size_t is_queue_empty;
135 {
136 absl::MutexLock lock(&mutex_);
137 is_queue_empty = masked_input_queue_.empty();
138 masked_input_queue_.emplace_back(std::move(checked_masked_vectors));
139 }
140 if (is_queue_empty) {
141 // TODO(team): Abort should handle the situation where `this` has
142 // been destructed while the schedule task is still not running, and
143 // message_queue_ can't be moved.
144 Trace<Round2AsyncWorkScheduled>();
145 masked_input_accumulator_->Schedule([&] {
146 auto queue = TakeMaskedInputQueue();
147 Trace<Round2MessageQueueTaken>(queue.size());
148 return AddReduce(std::move(queue));
149 });
150 }
151 } else {
152 // Sequential processing
153 FCP_CHECK(masked_input_);
154 masked_input_->Add(*checked_masked_vectors);
155 }
156
157 return ::absl::OkStatus();
158 }
159
FinalizeMaskedInputCollection()160 void AesSecAggServerProtocolImpl::FinalizeMaskedInputCollection() {
161 if (experiments()->IsEnabled(kSecAggAsyncRound2Experiment)) {
162 FCP_CHECK(masked_input_accumulator_->IsIdle());
163 masked_input_ = masked_input_accumulator_->GetResultAndCancel();
164 }
165 }
166
StartPrng(const PrngWorkItems & work_items,std::function<void (Status)> done_callback)167 CancellationToken AesSecAggServerProtocolImpl::StartPrng(
168 const PrngWorkItems& work_items,
169 std::function<void(Status)> done_callback) {
170 FCP_CHECK(done_callback);
171 FCP_CHECK(masked_input_);
172 auto generators =
173 std::vector<std::function<std::unique_ptr<SecAggUnpackedVectorMap>()>>();
174
175 // Break the keys to add or subtract into vectors of size kPrngBatchSize (or
176 // less for the last one) and schedule them as tasks.
177 for (auto it = work_items.prng_keys_to_add.begin();
178 it < work_items.prng_keys_to_add.end(); it += kPrngBatchSize) {
179 std::vector<AesKey> batch_prng_keys_to_add;
180 std::copy(it,
181 std::min(it + kPrngBatchSize, work_items.prng_keys_to_add.end()),
182 std::back_inserter(batch_prng_keys_to_add));
183 generators.emplace_back([=]() {
184 return UnpackedMapOfMasks(batch_prng_keys_to_add, std::vector<AesKey>(),
185 input_vector_specs(), session_id(),
186 *prng_factory());
187 });
188 }
189
190 for (auto it = work_items.prng_keys_to_subtract.begin();
191 it < work_items.prng_keys_to_subtract.end(); it += kPrngBatchSize) {
192 std::vector<AesKey> batch_prng_keys_to_subtract;
193 std::copy(
194 it,
195 std::min(it + kPrngBatchSize, work_items.prng_keys_to_subtract.end()),
196 std::back_inserter(batch_prng_keys_to_subtract));
197 generators.emplace_back([=]() {
198 return UnpackedMapOfMasks(
199 std::vector<AesKey>(), batch_prng_keys_to_subtract,
200 input_vector_specs(), session_id(), *prng_factory());
201 });
202 }
203
204 auto accumulator = scheduler()->CreateAccumulator<SecAggUnpackedVectorMap>(
205 std::move(masked_input_), SecAggUnpackedVectorMap::AddMaps);
206 for (const auto& generator : generators) {
207 accumulator->Schedule(generator);
208 }
209 accumulator->SetAsyncObserver([=, accumulator = accumulator.get()]() {
210 auto unpacked_map = accumulator->GetResultAndCancel();
211 auto packed_map = std::make_unique<SecAggVectorMap>();
212 for (auto& entry : *unpacked_map) {
213 uint64_t modulus = entry.second.modulus();
214 packed_map->emplace(entry.first,
215 SecAggVector(std::move(entry.second), modulus));
216 }
217 SetResult(std::move(packed_map));
218 done_callback(absl::OkStatus());
219 });
220 return accumulator;
221 }
222 } // namespace secagg
223 } // namespace fcp
224