xref: /aosp_15_r20/external/federated-compute/fcp/secagg/shared/shamir_secret_sharing.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/shared/shamir_secret_sharing.h"
18 
19 #include <cstdint>
20 #include <string>
21 #include <vector>
22 
23 #include "absl/status/statusor.h"
24 #include "absl/strings/numbers.h"
25 #include "fcp/secagg/shared/math.h"
26 #include "openssl/rand.h"
27 
28 namespace fcp {
29 namespace secagg {
30 
31 const uint64_t ShamirSecretSharing::kPrime;
32 constexpr size_t kSubsecretSize = sizeof(uint32_t);
33 
ShamirSecretSharing()34 ShamirSecretSharing::ShamirSecretSharing() {}
35 
Share(int threshold,int num_shares,const std::string & to_share)36 std::vector<ShamirShare> ShamirSecretSharing::Share(
37     int threshold, int num_shares, const std::string& to_share) {
38   FCP_CHECK(!to_share.empty()) << "to_share must not be empty";
39   FCP_CHECK(num_shares > 1) << "num_shares must be greater than 1";
40   FCP_CHECK(2 <= threshold && threshold <= num_shares)
41       << "threshold must be at least 2 and at most num_shares";
42 
43   std::vector<uint32_t> subsecrets = DivideIntoSubsecrets(to_share);
44 
45   // Each ShamirShare is specified as a string of length 4 * subsecrets.size().
46   // The first four characters of the ShamirShare are the share of the first
47   // subsecret stored in big-endian order, and so on.
48   std::vector<ShamirShare> shares(num_shares);
49   for (auto& share : shares) {
50     share.data.reserve(kSubsecretSize * subsecrets.size());
51   }
52 
53   for (uint32_t subsecret : subsecrets) {
54     std::vector<uint32_t> coefficients;
55     coefficients.reserve(threshold);
56     coefficients.push_back(subsecret);
57 
58     for (int i = 1; i < threshold; ++i) {
59       coefficients.push_back(RandomFieldElement());
60     }
61 
62     for (int i = 0; i < num_shares; ++i) {
63       // The client with id x gets the share of the polynomial evaluated at x+1.
64       uint32_t subshare = EvaluatePolynomial(coefficients, i + 1);
65       // Big-endian encoding
66       shares[i].data += IntToByteString(subshare);
67     }
68   }
69   return shares;
70 }
71 
Reconstruct(int threshold,const std::vector<ShamirShare> & shares,int secret_length)72 StatusOr<std::string> ShamirSecretSharing::Reconstruct(
73     int threshold, const std::vector<ShamirShare>& shares, int secret_length) {
74   FCP_CHECK(threshold > 1) << "threshold must be at least 2";
75   FCP_CHECK(secret_length > 0) << "secret_length must be positive";
76   FCP_CHECK(static_cast<int>(shares.size()) >= threshold)
77       << "A vector of size " << shares.size()
78       << " was provided, but threshold was specified as " << threshold;
79 
80   // The max possible number of subsecrets is based on the secret_length.
81   int max_num_subsecrets =
82       ((8 * secret_length) + kBitsPerSubsecret - 1) / kBitsPerSubsecret;
83   // The number of subsecrets may be different due to compatibility with the
84   // legacy Java implementation and may be smaller than max_num_subsecrets.
85   // The actual number is determined below.
86   int num_subsecrets = 0;
87 
88   // The X values of the participating clients' shares. The i-th share will be
89   // given an X value of i+1, to account for the fact that shares are 0-indexed.
90   // We want exactly threshold participating clients.
91   std::vector<int> x_values;
92 
93   for (int i = 0; i < static_cast<int>(shares.size()) &&
94                   static_cast<int>(x_values.size()) < threshold;
95        ++i) {
96     if (shares[i].data.empty()) {
97       continue;
98     }
99 
100     FCP_CHECK(shares[i].data.size() % kSubsecretSize == 0)
101         << "Share with index " << i << " is invalid: a share of size "
102         << shares[i].data.size() << " was provided but a multiple of "
103         << kSubsecretSize << " is expected";
104     if (num_subsecrets == 0) {
105       num_subsecrets = static_cast<int>(shares[i].data.size() / kSubsecretSize);
106       FCP_CHECK(num_subsecrets > 0 && num_subsecrets <= max_num_subsecrets)
107           << "Share with index " << i << " is invalid: "
108           << "the number of subsecrets is " << num_subsecrets
109           << " but between 1 and " << max_num_subsecrets << " is expected";
110     } else {
111       FCP_CHECK(shares[i].data.size() == num_subsecrets * kSubsecretSize)
112           << "Share with index " << i << " is invalid: "
113           << "all shares must match sizes: "
114           << "shares[i].data.size() = " << shares[i].data.size()
115           << ", num_subsecrets = " << num_subsecrets;
116     }
117     x_values.push_back(i + 1);
118   }
119   if (static_cast<int>(x_values.size()) < threshold) {
120     return FCP_STATUS(FAILED_PRECONDITION)
121            << "Only " << x_values.size()
122            << " valid shares were provided, but threshold was specified as "
123            << threshold;
124   }
125 
126   // Recover the sharing polynomials using Lagrange polynomial interpolation.
127   std::vector<uint32_t> coefficients = LagrangeCoefficients(x_values);
128   std::vector<uint32_t> subsecrets;
129   for (int i = 0; i < num_subsecrets; ++i) {
130     subsecrets.push_back(0);
131     for (int j = 0; j < static_cast<int>(x_values.size()); ++j) {
132       int share_index = x_values[j] - 1;
133       uint32_t subshare = 0;
134       // Big-endian decoding
135       for (int k = 0; k < kSubsecretSize; ++k) {
136         subshare <<= 8;
137         subshare += static_cast<uint8_t>(
138             shares[share_index].data[kSubsecretSize * i + k]);
139       }
140       subsecrets[i] += MultiplyMod(subshare, coefficients[j], kPrime);
141       subsecrets[i] %= kPrime;
142     }
143   }
144 
145   return RebuildFromSubsecrets(subsecrets, secret_length);
146 }
147 
148 // Helper function for ModInverse.
ModPow(uint32_t x,uint32_t y)149 static uint32_t ModPow(uint32_t x, uint32_t y) {
150   if (y == 0) {
151     return 1;
152   }
153   uint32_t p = ModPow(x, y / 2) % ShamirSecretSharing::kPrime;
154   uint32_t q = MultiplyMod(p, p, ShamirSecretSharing::kPrime);
155   return ((y & 0x01) == 0) ? q : MultiplyMod(x, q, ShamirSecretSharing::kPrime);
156 }
157 
ModInverse(uint32_t n)158 uint32_t ShamirSecretSharing::ModInverse(uint32_t n) {
159   FCP_CHECK(n > 0 && n < kPrime) << "Invalid value " << n << " for ModInverse";
160   while (inverses_.size() < n) {
161     // Fermat's Little Theorem guarantees n^-1 = n^(P-2) mod P.
162     inverses_.push_back(ModPow(inverses_.size() + 1, kPrime - 2));
163   }
164   return inverses_[n - 1];
165 }
166 
LagrangeCoefficients(const std::vector<int> & x_values)167 std::vector<uint32_t> ShamirSecretSharing::LagrangeCoefficients(
168     const std::vector<int>& x_values) {
169   FCP_CHECK(x_values.size() > 1) << "Must have at least 2 x_values";
170   for (int x : x_values) {
171     FCP_CHECK(x > 0) << "x_values must all be positive, but got a value of "
172                      << x;
173   }
174 
175   if (x_values == last_lc_input_) {
176     return last_lc_output_;
177   }
178   last_lc_input_ = x_values;
179   last_lc_output_.clear();
180 
181   for (int i = 0; i < static_cast<int>(x_values.size()); ++i) {
182     last_lc_output_.push_back(1);
183     for (int j = 0; j < static_cast<int>(x_values.size()); ++j) {
184       if (i == j) {
185         continue;
186       }
187       last_lc_output_[i] = MultiplyMod(last_lc_output_[i], x_values[j], kPrime);
188       if (x_values[j] > x_values[i]) {
189         last_lc_output_[i] = MultiplyMod(
190             last_lc_output_[i], ModInverse(x_values[j] - x_values[i]), kPrime);
191       } else {
192         // Factor out -1 (mod kPrime)
193         last_lc_output_[i] =
194             MultiplyMod(last_lc_output_[i], kPrime - 1, kPrime);
195         last_lc_output_[i] = MultiplyMod(
196             last_lc_output_[i], ModInverse(x_values[i] - x_values[j]), kPrime);
197       }
198     }
199   }
200 
201   return last_lc_output_;
202 }
203 
DivideIntoSubsecrets(const std::string & to_share)204 std::vector<uint32_t> ShamirSecretSharing::DivideIntoSubsecrets(
205     const std::string& to_share) {
206   std::vector<uint32_t> secret_parts(DivideRoundUp(
207       static_cast<uint32_t>(to_share.size()) * 8, kBitsPerSubsecret));
208 
209   int bits_done = 0;
210   auto current_subsecret = secret_parts.rbegin();
211 
212   // This is a packing of the bits in to_share into the bits in secret_parts.
213   // The last 31 bits in to_share are kept in the same order and placed into
214   // the last element of secret_parts, the second-to-last 31 bits are placed in
215   // the second-to-last element, and so on. The high-order bit of every element
216   // of secret_parts is 0. And the first element of secret_parts will contain
217   // the remaining bits at the front of to_share.
218   for (int i = to_share.size() - 1; i >= 0; --i) {
219     // Ensure high-order characters are treated consistently
220     uint8_t current_byte = static_cast<uint8_t>(to_share[i]);
221     if (kBitsPerSubsecret - bits_done > 8) {
222       *current_subsecret |= static_cast<uint32_t>(current_byte) << bits_done;
223       bits_done += 8;
224     } else {
225       uint8_t current_byte_right =
226           current_byte & (0xFF >> (8 - (kBitsPerSubsecret - bits_done)));
227       *current_subsecret |= static_cast<uint32_t>(current_byte_right)
228                             << bits_done;
229       // Make sure we're not in the edge case where we're exactly done.
230       if (!(i == 0 && bits_done + 8 == kBitsPerSubsecret)) {
231         bits_done = (bits_done + 8) % kBitsPerSubsecret;
232         ++current_subsecret;
233         *current_subsecret |= current_byte >> (8 - bits_done);
234       }
235     }
236   }
237   // We should have been working on the 0th element of the vector.
238   FCP_CHECK(current_subsecret + 1 == secret_parts.rend());
239   return secret_parts;
240 }
241 
RebuildFromSubsecrets(const std::vector<uint32_t> & secret_parts,int secret_length)242 std::string ShamirSecretSharing::RebuildFromSubsecrets(
243     const std::vector<uint32_t>& secret_parts, int secret_length) {
244   std::string secret(secret_length, 0);
245   int bits_done = 0;
246   auto subsecret = secret_parts.crbegin();
247   // Exactly reverse the process in DivideIntoSubsecrets.
248   for (int i = static_cast<int>(secret.size()) - 1;
249        i >= 0 && subsecret != secret_parts.crend(); --i) {
250     if (kBitsPerSubsecret - bits_done > 8) {
251       secret[i] = static_cast<uint8_t>((*subsecret >> bits_done) & 0xFF);
252       bits_done += 8;
253     } else {
254       uint8_t next_low_bits = static_cast<uint8_t>(*subsecret >> bits_done);
255       ++subsecret;
256       if (subsecret != secret_parts.crend()) {
257         secret[i] = static_cast<uint8_t>(
258             *subsecret & (0xFF >> (kBitsPerSubsecret - bits_done)));
259       }
260       bits_done = (bits_done + 8) % kBitsPerSubsecret;
261       secret[i] <<= 8 - bits_done;
262       secret[i] |= next_low_bits;
263     }
264   }
265 
266   return secret;
267 }
268 
EvaluatePolynomial(const std::vector<uint32_t> & polynomial,uint32_t x) const269 uint32_t ShamirSecretSharing::EvaluatePolynomial(
270     const std::vector<uint32_t>& polynomial, uint32_t x) const {
271   uint64_t sum = 0;
272 
273   for (int i = polynomial.size() - 1; i > 0; --i) {
274     sum += polynomial[i];
275     sum *= x;
276     sum %= kPrime;
277   }
278 
279   sum += polynomial[0];
280   sum %= kPrime;
281 
282   return static_cast<uint32_t>(sum);
283 }
284 
RandomFieldElement()285 uint32_t ShamirSecretSharing::RandomFieldElement() {
286   uint32_t rand = 0;
287   do {
288     rand = 0;
289     RAND_bytes(reinterpret_cast<uint8_t*>(&rand), sizeof(uint32_t));
290   } while (rand >= kPrime);
291   return rand;
292 }
293 
294 }  // namespace secagg
295 }  // namespace fcp
296