xref: /aosp_15_r20/external/private-join-and-compute/private_join_and_compute/crypto/context.cc (revision a6aa18fbfbf9cb5cd47356a9d1b057768998488c)
1 /*
2  * Copyright 2019 Google LLC.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "private_join_and_compute/crypto/context.h"
17 
18 #include <cmath>
19 #include <memory>
20 #include <string>
21 
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/string_view.h"
24 #include "private_join_and_compute/crypto/openssl_init.h"
25 
26 namespace private_join_and_compute {
27 
OpenSSLErrorString()28 std::string OpenSSLErrorString() {
29   char buf[256];
30   ERR_error_string_n(ERR_get_error(), buf, sizeof(buf));
31   return buf;
32 }
33 
Context()34 Context::Context()
35     : bn_ctx_(BN_CTX_new()),
36       evp_md_ctx_(EVP_MD_CTX_create()),
37       zero_bn_(CreateBigNum(0)),
38       one_bn_(CreateBigNum(1)),
39       two_bn_(CreateBigNum(2)),
40       three_bn_(CreateBigNum(3)) {
41   OpenSSLInit();
42   CHECK(RAND_status()) << "OpenSSL PRNG is not properly seeded.";
43   HMAC_CTX_init(&hmac_ctx_);
44 }
45 
~Context()46 Context::~Context() { HMAC_CTX_cleanup(&hmac_ctx_); }
47 
GetBnCtx()48 BN_CTX* Context::GetBnCtx() { return bn_ctx_.get(); }
49 
CreateBigNum(absl::string_view bytes)50 BigNum Context::CreateBigNum(absl::string_view bytes) {
51   return BigNum(bn_ctx_.get(), bytes);
52 }
53 
CreateBigNum(uint64_t number)54 BigNum Context::CreateBigNum(uint64_t number) {
55   return BigNum(bn_ctx_.get(), number);
56 }
57 
CreateBigNum(BigNum::BignumPtr bn)58 BigNum Context::CreateBigNum(BigNum::BignumPtr bn) {
59   return BigNum(bn_ctx_.get(), std::move(bn));
60 }
61 
Sha256String(absl::string_view bytes)62 std::string Context::Sha256String(absl::string_view bytes) {
63   unsigned char hash[EVP_MAX_MD_SIZE];
64   CRYPTO_CHECK(1 ==
65                EVP_DigestInit_ex(evp_md_ctx_.get(), EVP_sha256(), nullptr));
66   CRYPTO_CHECK(
67       1 == EVP_DigestUpdate(evp_md_ctx_.get(), bytes.data(), bytes.length()));
68   unsigned int md_len;
69   CRYPTO_CHECK(1 == EVP_DigestFinal_ex(evp_md_ctx_.get(), hash, &md_len));
70   return std::string(reinterpret_cast<char*>(hash), md_len);
71 }
72 
Sha384String(absl::string_view bytes)73 std::string Context::Sha384String(absl::string_view bytes) {
74   unsigned char hash[EVP_MAX_MD_SIZE];
75   CRYPTO_CHECK(1 ==
76                EVP_DigestInit_ex(evp_md_ctx_.get(), EVP_sha384(), nullptr));
77   CRYPTO_CHECK(
78       1 == EVP_DigestUpdate(evp_md_ctx_.get(), bytes.data(), bytes.length()));
79   unsigned int md_len;
80   CRYPTO_CHECK(1 == EVP_DigestFinal_ex(evp_md_ctx_.get(), hash, &md_len));
81   return std::string(reinterpret_cast<char*>(hash), md_len);
82 }
83 
Sha512String(absl::string_view bytes)84 std::string Context::Sha512String(absl::string_view bytes) {
85   unsigned char hash[EVP_MAX_MD_SIZE];
86   CRYPTO_CHECK(1 ==
87                EVP_DigestInit_ex(evp_md_ctx_.get(), EVP_sha512(), nullptr));
88   CRYPTO_CHECK(
89       1 == EVP_DigestUpdate(evp_md_ctx_.get(), bytes.data(), bytes.length()));
90   unsigned int md_len;
91   CRYPTO_CHECK(1 == EVP_DigestFinal_ex(evp_md_ctx_.get(), hash, &md_len));
92   return std::string(reinterpret_cast<char*>(hash), md_len);
93 }
94 
RandomOracle(absl::string_view x,const BigNum & max_value,RandomOracleHashType hash_type)95 BigNum Context::RandomOracle(absl::string_view x, const BigNum& max_value,
96                              RandomOracleHashType hash_type) {
97   int hash_output_length = 256;
98   if (hash_type == SHA512) {
99     hash_output_length = 512;
100   } else if (hash_type == SHA384) {
101     hash_output_length = 384;
102   }
103   int output_bit_length = max_value.BitLength() + hash_output_length;
104   int iter_count =
105       std::ceil(static_cast<float>(output_bit_length) / hash_output_length);
106   CHECK(iter_count * hash_output_length < 130048)
107       << "The domain bit length must not be greater than "
108          "130048. Desired bit length: "
109       << output_bit_length;
110   int excess_bit_count = (iter_count * hash_output_length) - output_bit_length;
111   BigNum hash_output = CreateBigNum(0);
112   for (int i = 1; i < iter_count + 1; i++) {
113     hash_output = hash_output.Lshift(hash_output_length);
114     std::string bignum_bytes = absl::StrCat(CreateBigNum(i).ToBytes(), x);
115     std::string hashed_string;
116     if (hash_type == SHA512) {
117       hashed_string = Sha512String(bignum_bytes);
118     } else if (hash_type == SHA384) {
119       hashed_string = Sha384String(bignum_bytes);
120     } else {
121       hashed_string = Sha256String(bignum_bytes);
122     }
123     hash_output = hash_output + CreateBigNum(hashed_string);
124   }
125   return hash_output.Rshift(excess_bit_count).Mod(max_value);
126 }
127 
RandomOracleSha512(absl::string_view x,const BigNum & max_value)128 BigNum Context::RandomOracleSha512(absl::string_view x,
129                                    const BigNum& max_value) {
130   return RandomOracle(x, max_value, SHA512);
131 }
132 
RandomOracleSha384(absl::string_view x,const BigNum & max_value)133 BigNum Context::RandomOracleSha384(absl::string_view x,
134                                    const BigNum& max_value) {
135   return RandomOracle(x, max_value, SHA384);
136 }
137 
RandomOracleSha256(absl::string_view x,const BigNum & max_value)138 BigNum Context::RandomOracleSha256(absl::string_view x,
139                                    const BigNum& max_value) {
140   return RandomOracle(x, max_value, SHA256);
141 }
142 
PRF(absl::string_view key,absl::string_view data,const BigNum & max_value)143 BigNum Context::PRF(absl::string_view key, absl::string_view data,
144                     const BigNum& max_value) {
145   CHECK_GE(key.size() * 8, 80);
146   CHECK_LE(max_value.BitLength(), 512)
147       << "The requested output length is not supported. The maximum "
148          "supported output length is 512. The requested output length is "
149       << max_value.BitLength();
150   CRYPTO_CHECK(1 == HMAC_Init_ex(&hmac_ctx_, key.data(), key.size(),
151                                  EVP_sha512(), nullptr));
152   CRYPTO_CHECK(1 ==
153                HMAC_Update(&hmac_ctx_,
154                            reinterpret_cast<const unsigned char*>(data.data()),
155                            data.size()));
156   unsigned int md_len;
157   unsigned char hash[EVP_MAX_MD_SIZE];
158   CRYPTO_CHECK(1 == HMAC_Final(&hmac_ctx_, hash, &md_len));
159   BigNum hash_bn(bn_ctx_.get(), hash, md_len);
160   BigNum hash_bn_reduced = hash_bn.GetLastNBits(max_value.BitLength());
161   if (hash_bn_reduced < max_value) {
162     return hash_bn_reduced;
163   } else {
164     return Context::PRF(key, hash_bn.ToBytes(), max_value);
165   }
166 }
167 
GenerateSafePrime(int prime_length)168 BigNum Context::GenerateSafePrime(int prime_length) {
169   BigNum r(bn_ctx_.get());
170   CRYPTO_CHECK(1 == BN_generate_prime_ex(r.bn_.get(), prime_length, 1, nullptr,
171                                          nullptr, nullptr));
172   return r;
173 }
174 
GeneratePrime(int prime_length)175 BigNum Context::GeneratePrime(int prime_length) {
176   BigNum r(bn_ctx_.get());
177   CRYPTO_CHECK(1 == BN_generate_prime_ex(r.bn_.get(), prime_length, 0, nullptr,
178                                          nullptr, nullptr));
179   return r;
180 }
181 
// Returns a random value in [0, max_value), drawn with BN_rand_range.
// Aborts if OpenSSL fails.
BigNum Context::GenerateRandLessThan(const BigNum& max_value) {
  BigNum sample(bn_ctx_.get());
  auto* const destination = sample.bn_.get();
  auto* const upper_bound = max_value.bn_.get();
  CRYPTO_CHECK(1 == BN_rand_range(destination, upper_bound));
  return sample;
}
187 
// Returns a random value in [start, end), obtained by shifting a sample from
// GenerateRandLessThan(end - start). CHECK-fails unless start < end.
BigNum Context::GenerateRandBetween(const BigNum& start, const BigNum& end) {
  CHECK(start < end);
  const BigNum width = end - start;
  return GenerateRandLessThan(width) + start;
}
192 
GenerateRandomBytes(int num_bytes)193 std::string Context::GenerateRandomBytes(int num_bytes) {
194   CHECK_GE(num_bytes, 0) << "num_bytes must be nonnegative, provided value was "
195                          << num_bytes << ".";
196   std::unique_ptr<unsigned char[]> bytes(new unsigned char[num_bytes]);
197   CRYPTO_CHECK(1 == RAND_bytes(bytes.get(), num_bytes));
198   return std::string(reinterpret_cast<char*>(bytes.get()), num_bytes);
199 }
200 
// Returns a random value in [0, num) whose gcd with `num` is at most one.
// It retries fresh samples from GenerateRandLessThan until one qualifies.
BigNum Context::RelativelyPrimeRandomLessThan(const BigNum& num) {
  for (;;) {
    BigNum candidate = GenerateRandLessThan(num);
    if (!(candidate.Gcd(num) > One())) {
      return candidate;
    }
  }
}
208 
209 }  // namespace private_join_and_compute
210