1 /* 2 * Copyright 2018 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_SECAGG_SHARED_SECAGG_VECTOR_H_ 18 #define FCP_SECAGG_SHARED_SECAGG_VECTOR_H_ 19 20 #include <cstdint> 21 #include <memory> 22 #include <string> 23 #include <utility> 24 #include <vector> 25 26 #include "absl/base/attributes.h" 27 #include "absl/container/node_hash_map.h" 28 #include "absl/numeric/bits.h" 29 #include "absl/strings/string_view.h" 30 #include "absl/types/span.h" 31 #include "fcp/base/monitoring.h" 32 33 // Represents an immutable vector of nonnegative integers, where each entry has 34 // the same specified bit width. This is used in the SecAgg package both to 35 // provide input to SecAggClient and by SecAggServer to provide its output (more 36 // specifically, inputs and outputs are of type 37 // unordered_map<std::string, SecAggVector>, where the key denotes a name 38 // associated with the vector). 39 // 40 // This class is backed by a packed byte representation of a uint64_t vector, in 41 // little endian order, where each consecutive bit_width sequence of bits of 42 // the packed vector corresponds to an integer value between 43 // 0 and modulus. 44 45 namespace fcp { 46 namespace secagg { 47 48 class SecAggVector { 49 public: 50 static constexpr uint64_t kMaxModulus = 1ULL << 62; // max 62 bitwidth 51 52 // Creates a SecAggVector of the specified modulus, using the specified 53 // span of uint64s. The integers are converted into a packed byte 54 // representation and stored in that format. 55 // 56 // Each element of span must be in [0, modulus-1]. 57 // 58 // modulus itself must be > 1 and <= kMaxModulus. 59 SecAggVector(absl::Span<const uint64_t> span, uint64_t modulus, 60 bool branchless_codec = false); 61 62 // Creates a SecAggVector from the given little-endian packed byte 63 // representation. The packed representation should have num_elements longs, 64 // each of bit_width length (in bits). 65 // 66 // packed_bytes must be in the same format as the output of GetAsPackedBytes. 67 // 68 // modulus must be > 1 and <= kMaxModulus. 69 // 70 // For large strings, copying may be avoided by specifying an rvalue for 71 // packed bytes, e.g. std::move(large_caller_string), which should move the 72 // contents. 73 SecAggVector(std::string packed_bytes, uint64_t modulus, size_t num_elements, 74 bool branchless_codec = false); 75 76 // Disallow memory expensive copying of SecAggVector. 77 SecAggVector(const SecAggVector&) = delete; 78 SecAggVector& operator=(const SecAggVector&) = delete; 79 80 // Enable move semantics. SecAggVector(SecAggVector && other)81 SecAggVector(SecAggVector&& other) { other.MoveTo(this); } 82 83 SecAggVector& operator=(SecAggVector&& other) { 84 other.MoveTo(this); 85 return *this; 86 } 87 88 // Calculates bitwith for the specified modulus. GetBitWidth(uint64_t modulus)89 inline static int GetBitWidth(uint64_t modulus) { 90 return static_cast<int>(absl::bit_width(modulus - 1ULL)); 91 } 92 modulus()93 ABSL_MUST_USE_RESULT inline uint64_t modulus() const { return modulus_; } bit_width()94 ABSL_MUST_USE_RESULT inline size_t bit_width() const { return bit_width_; } num_elements()95 ABSL_MUST_USE_RESULT inline size_t num_elements() const { 96 return num_elements_; 97 } 98 99 // Produces and returns a representation of this SecAggVector as a vector of 100 // uint64_t. The returned vector is obtained by unpacking the stored packed 101 // representation of the vector. 102 ABSL_MUST_USE_RESULT std::vector<uint64_t> GetAsUint64Vector() const; 103 104 // Returns the stored, compressed representation of the SecAggVector. 105 // The bytes are stored in little-endian order, using only bit_width bits to 106 // represent each element of the vector. GetAsPackedBytes()107 ABSL_MUST_USE_RESULT inline const std::string& GetAsPackedBytes() const { 108 CheckHasValue(); 109 return packed_bytes_; 110 } 111 112 // Takes out the stored, compressed representation of the SecAggVector. 113 // This call "consumes" the SecAggVector instance, and after that it becomes 114 // invalid. 115 // The bytes are stored in little-endian order, using only bit_width bits to 116 // represent each element of the vector. TakePackedBytes()117 ABSL_MUST_USE_RESULT inline std::string TakePackedBytes() && { 118 CheckHasValue(); 119 modulus_ = 0; 120 bit_width_ = 0; 121 num_elements_ = 0; 122 return std::move(packed_bytes_); 123 } 124 125 inline friend bool operator==(const SecAggVector& lhs, 126 const SecAggVector& rhs) { 127 return lhs.packed_bytes_ == rhs.packed_bytes_; 128 } 129 130 // Decoder for unpacking SecAggVector values one by one. 131 class Decoder { 132 public: Decoder(const SecAggVector & v)133 explicit Decoder(const SecAggVector& v) 134 : Decoder(v.packed_bytes_, v.modulus_) {} 135 136 explicit Decoder(absl::string_view packed_bytes, uint64_t modulus); 137 138 // Unpacks and returns the next value. 139 // Result of this operation is undetermined when the decoder has already 140 // decoded all values. For performance reasons ReadValue doesn't validate 141 // the state. 142 uint64_t ReadValue(); 143 144 private: 145 inline void ReadData(); 146 147 const char* read_cursor_; 148 const char* const cursor_sentinel_; 149 uint64_t cursor_read_value_; 150 uint64_t scratch_; 151 int read_cursor_bit_; 152 uint8_t bit_width_; 153 const uint64_t mask_; 154 uint64_t modulus_; 155 }; 156 157 // Coder for packing SecAggVector values one by one. 158 class Coder { 159 public: 160 explicit Coder(uint64_t modulus, int bit_width, size_t num_elements); 161 162 // Pack and write value to packed buffer. 163 void WriteValue(uint64_t value); 164 165 // Consumes the coder and creates SecAggVector with the packed buffer. 166 SecAggVector Create() &&; 167 168 private: 169 std::string packed_bytes_; 170 int num_bytes_needed_; 171 uint64_t modulus_; 172 int bit_width_; 173 size_t num_elements_; 174 char* write_cursor_; 175 uint64_t target_cursor_value_; 176 uint8_t starting_bit_position_; 177 }; 178 179 private: 180 std::string packed_bytes_; 181 uint64_t modulus_; 182 int bit_width_; 183 size_t num_elements_; 184 bool branchless_codec_; 185 186 // Moves this object's value to the target one and resets this object's state. MoveTo(SecAggVector * target)187 inline void MoveTo(SecAggVector* target) { 188 target->modulus_ = modulus_; 189 target->bit_width_ = bit_width_; 190 target->num_elements_ = num_elements_; 191 target->branchless_codec_ = branchless_codec_; 192 target->packed_bytes_ = std::move(packed_bytes_); 193 modulus_ = 0; 194 bit_width_ = 0; 195 num_elements_ = 0; 196 branchless_codec_ = false; 197 } 198 199 // Verifies that this SecAggVector value can't be accessed after swapping it 200 // with another SecAggVector via std::move(). CheckHasValue()201 void CheckHasValue() const { 202 FCP_CHECK(modulus_ > 0) << "SecAggVector has no value"; 203 } 204 205 void PackUint64IntoByteStringAt(int index, uint64_t element); 206 // A version without expensive branches or multiplies. 207 void PackUint64IntoByteStringBranchless(absl::Span<const uint64_t> span); 208 209 static ABSL_MUST_USE_RESULT uint64_t UnpackUint64FromByteStringAt( 210 int index, int bit_width, const std::string& byte_string); 211 // A version without expensive branches or multiplies. 212 void UnpackByteStringToUint64VectorBranchless( 213 std::vector<uint64_t>* long_vector) const; 214 }; // class SecAggVector 215 216 // This is equivalent to 217 // using SecAggVectorMap = absl::node_hash_map<std::string, SecAggVector>; 218 // except copy construction and assignment are explicitly prohibited. 219 class SecAggVectorMap : public absl::node_hash_map<std::string, SecAggVector> { 220 public: 221 using Base = absl::node_hash_map<std::string, SecAggVector>; 222 using Base::Base; 223 using Base::operator=; 224 SecAggVectorMap(const SecAggVectorMap&) = delete; 225 SecAggVectorMap& operator=(const SecAggVectorMap&) = delete; 226 }; 227 228 // Unpacked vector is simply a pair vector<uint64_t> and the modulus used with 229 // each element. 230 class SecAggUnpackedVector : public std::vector<uint64_t> { 231 public: SecAggUnpackedVector(size_t size,uint64_t modulus)232 explicit SecAggUnpackedVector(size_t size, uint64_t modulus) 233 : vector(size), modulus_(modulus) {} 234 SecAggUnpackedVector(std::vector<uint64_t> elements,uint64_t modulus)235 explicit SecAggUnpackedVector(std::vector<uint64_t> elements, 236 uint64_t modulus) 237 : vector(std::move(elements)), modulus_(modulus) {} 238 modulus()239 ABSL_MUST_USE_RESULT inline uint64_t modulus() const { return modulus_; } num_elements()240 ABSL_MUST_USE_RESULT inline size_t num_elements() const { return size(); } 241 242 // Disallow memory expensive copying of SecAggVector. 243 SecAggUnpackedVector(const SecAggUnpackedVector&) = delete; 244 SecAggUnpackedVector& operator=(const SecAggUnpackedVector&) = delete; 245 SecAggUnpackedVector(const SecAggVector & other)246 explicit SecAggUnpackedVector(const SecAggVector& other) 247 : vector(other.num_elements()), modulus_(other.modulus()) { 248 SecAggVector::Decoder decoder(other); 249 for (auto& v : *this) { 250 v = decoder.ReadValue(); 251 } 252 } 253 254 // Enable move semantics. SecAggUnpackedVector(SecAggUnpackedVector && other)255 SecAggUnpackedVector(SecAggUnpackedVector&& other) 256 : vector(std::move(other)), modulus_(other.modulus_) { 257 other.modulus_ = 0; 258 } 259 260 SecAggUnpackedVector& operator=(SecAggUnpackedVector&& other) { 261 modulus_ = other.modulus_; 262 other.modulus_ = 0; 263 vector::operator=(std::move(other)); 264 return *this; 265 } 266 267 // Combines this vector with another (packed) vector by adding elements of 268 // this vector to corresponding elements of the other vector. 269 // It is assumed that both vectors have the same modulus. The modulus is 270 // applied to each sum. 271 void Add(const SecAggVector& other); 272 273 private: 274 uint64_t modulus_; 275 }; 276 277 // This is mostly equivalent to 278 // using SecAggUnpackedVectorMap = 279 // absl::node_hash_map<std::string, SecAggUnpackedVector>; 280 // except copy construction and assignment are explicitly prohibited and 281 // Add method is added. 282 class SecAggUnpackedVectorMap 283 : public absl::node_hash_map<std::string, SecAggUnpackedVector> { 284 public: 285 using Base = absl::node_hash_map<std::string, SecAggUnpackedVector>; 286 using Base::Base; 287 using Base::operator=; 288 SecAggUnpackedVectorMap(const SecAggUnpackedVectorMap&) = delete; 289 SecAggUnpackedVectorMap& operator=(const SecAggUnpackedVectorMap&) = delete; 290 SecAggUnpackedVectorMap(const SecAggVectorMap & other)291 explicit SecAggUnpackedVectorMap(const SecAggVectorMap& other) { 292 for (auto& [name, vector] : other) { 293 this->emplace(name, SecAggUnpackedVector(vector)); 294 } 295 } 296 297 // Combines this map with another (packed) map by adding all vectors in this 298 // map to corresponding vectors in the other map. 299 // It is assumed that names of vectors match in both maps. 300 void Add(const SecAggVectorMap& other); 301 302 // Analogous to the above, as a static method. Also assumes that names of 303 // vectors match in both maps. 304 static std::unique_ptr<SecAggUnpackedVectorMap> AddMaps( 305 const SecAggUnpackedVectorMap& a, const SecAggUnpackedVectorMap& b); 306 }; 307 308 } // namespace secagg 309 } // namespace fcp 310 311 #endif // FCP_SECAGG_SHARED_SECAGG_VECTOR_H_ 312