xref: /aosp_15_r20/external/federated-compute/fcp/secagg/shared/secagg_vector.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_SECAGG_SHARED_SECAGG_VECTOR_H_
18 #define FCP_SECAGG_SHARED_SECAGG_VECTOR_H_
19 
20 #include <cstdint>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/base/attributes.h"
27 #include "absl/container/node_hash_map.h"
28 #include "absl/numeric/bits.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/types/span.h"
31 #include "fcp/base/monitoring.h"
32 
33 // Represents an immutable vector of nonnegative integers, where each entry has
34 // the same specified bit width. This is used in the SecAgg package both to
35 // provide input to SecAggClient and by SecAggServer to provide its output (more
36 // specifically, inputs and outputs are of type
37 // unordered_map<std::string, SecAggVector>, where the key denotes a name
38 // associated with the vector).
39 //
40 // This class is backed by a packed byte representation of a uint64_t vector, in
41 // little endian order, where each consecutive bit_width sequence of bits of
// the packed vector corresponds to an integer value in the range
// [0, modulus - 1].
44 
45 namespace fcp {
46 namespace secagg {
47 
48 class SecAggVector {
49  public:
50   static constexpr uint64_t kMaxModulus = 1ULL << 62;  // max 62 bitwidth
51 
52   // Creates a SecAggVector of the specified modulus, using the specified
53   // span of uint64s. The integers are converted into a packed byte
54   // representation and stored in that format.
55   //
56   // Each element of span must be in [0, modulus-1].
57   //
58   // modulus itself must be > 1 and <= kMaxModulus.
59   SecAggVector(absl::Span<const uint64_t> span, uint64_t modulus,
60                bool branchless_codec = false);
61 
62   // Creates a SecAggVector from the given little-endian packed byte
63   // representation. The packed representation should have num_elements longs,
64   // each of bit_width length (in bits).
65   //
66   // packed_bytes must be in the same format as the output of GetAsPackedBytes.
67   //
68   // modulus must be > 1 and <= kMaxModulus.
69   //
70   // For large strings, copying may be avoided by specifying an rvalue for
71   // packed bytes, e.g. std::move(large_caller_string), which should move the
72   // contents.
73   SecAggVector(std::string packed_bytes, uint64_t modulus, size_t num_elements,
74                bool branchless_codec = false);
75 
76   // Disallow memory expensive copying of SecAggVector.
77   SecAggVector(const SecAggVector&) = delete;
78   SecAggVector& operator=(const SecAggVector&) = delete;
79 
80   // Enable move semantics.
SecAggVector(SecAggVector && other)81   SecAggVector(SecAggVector&& other) { other.MoveTo(this); }
82 
83   SecAggVector& operator=(SecAggVector&& other) {
84     other.MoveTo(this);
85     return *this;
86   }
87 
88   // Calculates bitwith for the specified modulus.
GetBitWidth(uint64_t modulus)89   inline static int GetBitWidth(uint64_t modulus) {
90     return static_cast<int>(absl::bit_width(modulus - 1ULL));
91   }
92 
modulus()93   ABSL_MUST_USE_RESULT inline uint64_t modulus() const { return modulus_; }
bit_width()94   ABSL_MUST_USE_RESULT inline size_t bit_width() const { return bit_width_; }
num_elements()95   ABSL_MUST_USE_RESULT inline size_t num_elements() const {
96     return num_elements_;
97   }
98 
99   // Produces and returns a representation of this SecAggVector as a vector of
100   // uint64_t. The returned vector is obtained by unpacking the stored packed
101   // representation of the vector.
102   ABSL_MUST_USE_RESULT std::vector<uint64_t> GetAsUint64Vector() const;
103 
104   // Returns the stored, compressed representation of the SecAggVector.
105   // The bytes are stored in little-endian order, using only bit_width bits to
106   // represent each element of the vector.
GetAsPackedBytes()107   ABSL_MUST_USE_RESULT inline const std::string& GetAsPackedBytes() const {
108     CheckHasValue();
109     return packed_bytes_;
110   }
111 
112   // Takes out the stored, compressed representation of the SecAggVector.
113   // This call "consumes" the SecAggVector instance, and after that it becomes
114   // invalid.
115   // The bytes are stored in little-endian order, using only bit_width bits to
116   // represent each element of the vector.
TakePackedBytes()117   ABSL_MUST_USE_RESULT inline std::string TakePackedBytes() && {
118     CheckHasValue();
119     modulus_ = 0;
120     bit_width_ = 0;
121     num_elements_ = 0;
122     return std::move(packed_bytes_);
123   }
124 
125   inline friend bool operator==(const SecAggVector& lhs,
126                                 const SecAggVector& rhs) {
127     return lhs.packed_bytes_ == rhs.packed_bytes_;
128   }
129 
130   // Decoder for unpacking SecAggVector values one by one.
131   class Decoder {
132    public:
Decoder(const SecAggVector & v)133     explicit Decoder(const SecAggVector& v)
134         : Decoder(v.packed_bytes_, v.modulus_) {}
135 
136     explicit Decoder(absl::string_view packed_bytes, uint64_t modulus);
137 
138     // Unpacks and returns the next value.
139     // Result of this operation is undetermined when the decoder has already
140     // decoded all values. For performance reasons ReadValue doesn't validate
141     // the state.
142     uint64_t ReadValue();
143 
144    private:
145     inline void ReadData();
146 
147     const char* read_cursor_;
148     const char* const cursor_sentinel_;
149     uint64_t cursor_read_value_;
150     uint64_t scratch_;
151     int read_cursor_bit_;
152     uint8_t bit_width_;
153     const uint64_t mask_;
154     uint64_t modulus_;
155   };
156 
157   // Coder for packing SecAggVector values one by one.
158   class Coder {
159    public:
160     explicit Coder(uint64_t modulus, int bit_width, size_t num_elements);
161 
162     // Pack and write value to packed buffer.
163     void WriteValue(uint64_t value);
164 
165     // Consumes the coder and creates SecAggVector with the packed buffer.
166     SecAggVector Create() &&;
167 
168    private:
169     std::string packed_bytes_;
170     int num_bytes_needed_;
171     uint64_t modulus_;
172     int bit_width_;
173     size_t num_elements_;
174     char* write_cursor_;
175     uint64_t target_cursor_value_;
176     uint8_t starting_bit_position_;
177   };
178 
179  private:
180   std::string packed_bytes_;
181   uint64_t modulus_;
182   int bit_width_;
183   size_t num_elements_;
184   bool branchless_codec_;
185 
186   // Moves this object's value to the target one and resets this object's state.
MoveTo(SecAggVector * target)187   inline void MoveTo(SecAggVector* target) {
188     target->modulus_ = modulus_;
189     target->bit_width_ = bit_width_;
190     target->num_elements_ = num_elements_;
191     target->branchless_codec_ = branchless_codec_;
192     target->packed_bytes_ = std::move(packed_bytes_);
193     modulus_ = 0;
194     bit_width_ = 0;
195     num_elements_ = 0;
196     branchless_codec_ = false;
197   }
198 
199   // Verifies that this SecAggVector value can't be accessed after swapping it
200   // with another SecAggVector via std::move().
CheckHasValue()201   void CheckHasValue() const {
202     FCP_CHECK(modulus_ > 0) << "SecAggVector has no value";
203   }
204 
205   void PackUint64IntoByteStringAt(int index, uint64_t element);
206   // A version without expensive branches or multiplies.
207   void PackUint64IntoByteStringBranchless(absl::Span<const uint64_t> span);
208 
209   static ABSL_MUST_USE_RESULT uint64_t UnpackUint64FromByteStringAt(
210       int index, int bit_width, const std::string& byte_string);
211   // A version without expensive branches or multiplies.
212   void UnpackByteStringToUint64VectorBranchless(
213       std::vector<uint64_t>* long_vector) const;
214 };  // class SecAggVector
215 
216 // This is equivalent to
217 // using SecAggVectorMap = absl::node_hash_map<std::string, SecAggVector>;
218 // except copy construction and assignment are explicitly prohibited.
219 class SecAggVectorMap : public absl::node_hash_map<std::string, SecAggVector> {
220  public:
221   using Base = absl::node_hash_map<std::string, SecAggVector>;
222   using Base::Base;
223   using Base::operator=;
224   SecAggVectorMap(const SecAggVectorMap&) = delete;
225   SecAggVectorMap& operator=(const SecAggVectorMap&) = delete;
226 };
227 
228 // Unpacked vector is simply a pair vector<uint64_t> and the modulus used with
229 // each element.
230 class SecAggUnpackedVector : public std::vector<uint64_t> {
231  public:
SecAggUnpackedVector(size_t size,uint64_t modulus)232   explicit SecAggUnpackedVector(size_t size, uint64_t modulus)
233       : vector(size), modulus_(modulus) {}
234 
SecAggUnpackedVector(std::vector<uint64_t> elements,uint64_t modulus)235   explicit SecAggUnpackedVector(std::vector<uint64_t> elements,
236                                 uint64_t modulus)
237       : vector(std::move(elements)), modulus_(modulus) {}
238 
modulus()239   ABSL_MUST_USE_RESULT inline uint64_t modulus() const { return modulus_; }
num_elements()240   ABSL_MUST_USE_RESULT inline size_t num_elements() const { return size(); }
241 
242   // Disallow memory expensive copying of SecAggVector.
243   SecAggUnpackedVector(const SecAggUnpackedVector&) = delete;
244   SecAggUnpackedVector& operator=(const SecAggUnpackedVector&) = delete;
245 
SecAggUnpackedVector(const SecAggVector & other)246   explicit SecAggUnpackedVector(const SecAggVector& other)
247       : vector(other.num_elements()), modulus_(other.modulus()) {
248     SecAggVector::Decoder decoder(other);
249     for (auto& v : *this) {
250       v = decoder.ReadValue();
251     }
252   }
253 
254   // Enable move semantics.
SecAggUnpackedVector(SecAggUnpackedVector && other)255   SecAggUnpackedVector(SecAggUnpackedVector&& other)
256       : vector(std::move(other)), modulus_(other.modulus_) {
257     other.modulus_ = 0;
258   }
259 
260   SecAggUnpackedVector& operator=(SecAggUnpackedVector&& other) {
261     modulus_ = other.modulus_;
262     other.modulus_ = 0;
263     vector::operator=(std::move(other));
264     return *this;
265   }
266 
267   // Combines this vector with another (packed) vector by adding elements of
268   // this vector to corresponding elements of the other vector.
269   // It is assumed that both vectors have the same modulus. The modulus is
270   // applied to each sum.
271   void Add(const SecAggVector& other);
272 
273  private:
274   uint64_t modulus_;
275 };
276 
277 // This is mostly equivalent to
278 // using SecAggUnpackedVectorMap =
279 //     absl::node_hash_map<std::string, SecAggUnpackedVector>;
280 // except copy construction and assignment are explicitly prohibited and
281 // Add method is added.
282 class SecAggUnpackedVectorMap
283     : public absl::node_hash_map<std::string, SecAggUnpackedVector> {
284  public:
285   using Base = absl::node_hash_map<std::string, SecAggUnpackedVector>;
286   using Base::Base;
287   using Base::operator=;
288   SecAggUnpackedVectorMap(const SecAggUnpackedVectorMap&) = delete;
289   SecAggUnpackedVectorMap& operator=(const SecAggUnpackedVectorMap&) = delete;
290 
SecAggUnpackedVectorMap(const SecAggVectorMap & other)291   explicit SecAggUnpackedVectorMap(const SecAggVectorMap& other) {
292     for (auto& [name, vector] : other) {
293       this->emplace(name, SecAggUnpackedVector(vector));
294     }
295   }
296 
297   // Combines this map with another (packed) map by adding all vectors in this
298   // map to corresponding vectors in the other map.
299   // It is assumed that names of vectors match in both maps.
300   void Add(const SecAggVectorMap& other);
301 
302   // Analogous to the above, as a static method. Also assumes that names of
303   // vectors match in both maps.
304   static std::unique_ptr<SecAggUnpackedVectorMap> AddMaps(
305       const SecAggUnpackedVectorMap& a, const SecAggUnpackedVectorMap& b);
306 };
307 
308 }  // namespace secagg
309 }  // namespace fcp
310 
311 #endif  // FCP_SECAGG_SHARED_SECAGG_VECTOR_H_
312