xref: /aosp_15_r20/external/cronet/third_party/anonymous_tokens/src/anonymous_tokens/cpp/crypto/rsa_blinder.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //    https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "anonymous_tokens/cpp/crypto/rsa_blinder.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "anonymous_tokens/cpp/crypto/constants.h"
#include "anonymous_tokens/cpp/crypto/crypto_utils.h"
#include "anonymous_tokens/cpp/shared/status_utils.h"
#include <openssl/bn.h>
#include <openssl/digest.h>
#include <openssl/rsa.h>
31 
32 namespace anonymous_tokens {
33 
New(absl::string_view rsa_modulus,absl::string_view rsa_public_exponent,const EVP_MD * signature_hash_function,const EVP_MD * mgf1_hash_function,int salt_length,const bool use_rsa_public_exponent,std::optional<absl::string_view> public_metadata)34 absl::StatusOr<std::unique_ptr<RsaBlinder>> RsaBlinder::New(
35     absl::string_view rsa_modulus, absl::string_view rsa_public_exponent,
36     const EVP_MD* signature_hash_function, const EVP_MD* mgf1_hash_function,
37     int salt_length, const bool use_rsa_public_exponent,
38     std::optional<absl::string_view> public_metadata) {
39   bssl::UniquePtr<RSA> rsa_public_key;
40 
41   if (!public_metadata.has_value()) {
42     ANON_TOKENS_ASSIGN_OR_RETURN(
43         rsa_public_key, CreatePublicKeyRSA(rsa_modulus, rsa_public_exponent));
44   } else {
45     // If public metadata is passed, RsaBlinder will compute a new public
46     // exponent using the public metadata.
47     //
48     // Empty string is a valid public metadata value.
49     ANON_TOKENS_ASSIGN_OR_RETURN(
50         rsa_public_key, CreatePublicKeyRSAWithPublicMetadata(
51                             rsa_modulus, rsa_public_exponent, *public_metadata,
52                             use_rsa_public_exponent));
53   }
54 
55   ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> r, NewBigNum());
56   ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> r_inv_mont, NewBigNum());
57 
58   // Limit r between [2, n) so that an r of 1 never happens. An r of 1 doesn't
59   // blind.
60   if (BN_rand_range_ex(r.get(), 2, RSA_get0_n(rsa_public_key.get())) !=
61       kBsslSuccess) {
62     return absl::InternalError(
63         "BN_rand_range_ex failed when called from RsaBlinder::New.");
64   }
65 
66   bssl::UniquePtr<BN_CTX> bn_ctx(BN_CTX_new());
67   if (!bn_ctx) {
68     return absl::InternalError("BN_CTX_new failed.");
69   }
70 
71   bssl::UniquePtr<BN_MONT_CTX> bn_mont_ctx(BN_MONT_CTX_new_for_modulus(
72       RSA_get0_n(rsa_public_key.get()), bn_ctx.get()));
73   if (!bn_mont_ctx) {
74     return absl::InternalError("BN_MONT_CTX_new_for_modulus failed.");
75   }
76 
77   // We wish to compute r^-1 in the Montgomery domain, or r^-1 R mod n. This is
78   // can be done with BN_mod_inverse_blinded followed by BN_to_montgomery, but
79   // it is equivalent and slightly more efficient to first compute r R^-1 mod n
80   // with BN_from_montgomery, and then inverting that to give r^-1 R mod n.
81   int is_r_not_invertible = 0;
82   if (BN_from_montgomery(r_inv_mont.get(), r.get(), bn_mont_ctx.get(),
83                          bn_ctx.get()) != kBsslSuccess ||
84       BN_mod_inverse_blinded(r_inv_mont.get(), &is_r_not_invertible,
85                              r_inv_mont.get(), bn_mont_ctx.get(),
86                              bn_ctx.get()) != kBsslSuccess) {
87     return absl::InternalError(
88         absl::StrCat("BN_mod_inverse failed when called from RsaBlinder::New, "
89                      "is_r_not_invertible = ",
90                      is_r_not_invertible));
91   }
92 
93   return absl::WrapUnique(new RsaBlinder(
94       salt_length, public_metadata, signature_hash_function, mgf1_hash_function,
95       std::move(rsa_public_key), std::move(r), std::move(r_inv_mont),
96       std::move(bn_mont_ctx)));
97 }
98 
RsaBlinder(int salt_length,std::optional<absl::string_view> public_metadata,const EVP_MD * sig_hash,const EVP_MD * mgf1_hash,bssl::UniquePtr<RSA> rsa_public_key,bssl::UniquePtr<BIGNUM> r,bssl::UniquePtr<BIGNUM> r_inv_mont,bssl::UniquePtr<BN_MONT_CTX> mont_n)99 RsaBlinder::RsaBlinder(int salt_length,
100                        std::optional<absl::string_view> public_metadata,
101                        const EVP_MD* sig_hash, const EVP_MD* mgf1_hash,
102                        bssl::UniquePtr<RSA> rsa_public_key,
103                        bssl::UniquePtr<BIGNUM> r,
104                        bssl::UniquePtr<BIGNUM> r_inv_mont,
105                        bssl::UniquePtr<BN_MONT_CTX> mont_n)
106     : salt_length_(salt_length),
107       public_metadata_(public_metadata),
108       sig_hash_(sig_hash),
109       mgf1_hash_(mgf1_hash),
110       rsa_public_key_(std::move(rsa_public_key)),
111       r_(std::move(r)),
112       r_inv_mont_(std::move(r_inv_mont)),
113       mont_n_(std::move(mont_n)),
114       blinder_state_(RsaBlinder::BlinderState::kCreated) {}
115 
Blind(const absl::string_view message)116 absl::StatusOr<std::string> RsaBlinder::Blind(const absl::string_view message) {
117   // Check that the blinder state was kCreated
118   if (blinder_state_ != RsaBlinder::BlinderState::kCreated) {
119     return absl::FailedPreconditionError(
120         "RsaBlinder is in wrong state to blind message.");
121   }
122   std::string augmented_message(message);
123   if (public_metadata_.has_value()) {
124     augmented_message = EncodeMessagePublicMetadata(message, *public_metadata_);
125   }
126   ANON_TOKENS_ASSIGN_OR_RETURN(std::string digest_str,
127                                ComputeHash(augmented_message, *sig_hash_));
128   std::vector<uint8_t> digest(digest_str.begin(), digest_str.end());
129 
130   // Construct the PSS padded message, using the same workflow as BoringSSL's
131   // RSA_sign_pss_mgf1 for processing the message (but not signing the message):
132   // google3/third_party/openssl/boringssl/src/crypto/fipsmodule/rsa/rsa.c?l=557
133   if (digest.size() != EVP_MD_size(sig_hash_)) {
134     return absl::InternalError("Invalid input message length.");
135   }
136 
137   // Allocate for padded length
138   const int padded_len = BN_num_bytes(RSA_get0_n(rsa_public_key_.get()));
139   std::vector<uint8_t> padded(padded_len);
140 
141   // The |md| and |mgf1_md| arguments identify the hash used to calculate
142   // |digest| and the MGF1 hash, respectively. If |mgf1_md| is NULL, |md| is
143   // used. |salt_len| specifies the expected salt length in bytes. If |salt_len|
144   // is -1, then the salt length is the same as the hash length. If -2, then the
145   // salt length is maximal given the size of |rsa|. If unsure, use -1.
146   if (RSA_padding_add_PKCS1_PSS_mgf1(
147           /*rsa=*/rsa_public_key_.get(), /*EM=*/padded.data(),
148           /*mHash=*/digest.data(), /*Hash=*/sig_hash_, /*mgf1Hash=*/mgf1_hash_,
149           /*sLen=*/salt_length_) != kBsslSuccess) {
150     return absl::InternalError(
151         "RSA_padding_add_PKCS1_PSS_mgf1 failed when called from "
152         "RsaBlinder::Blind");
153   }
154 
155   bssl::UniquePtr<BN_CTX> bn_ctx(BN_CTX_new());
156   if (!bn_ctx) {
157     return absl::InternalError("BN_CTX_new failed.");
158   }
159 
160   std::string encoded_message(padded.begin(), padded.end());
161   ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> encoded_message_bn,
162                                StringToBignum(encoded_message));
163 
164   // Take `r^e mod n`. This is an equivalent operation to RSA_encrypt, without
165   // extra encode/decode trips.
166   ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> rE, NewBigNum());
167   if (BN_mod_exp_mont(rE.get(), r_.get(), RSA_get0_e(rsa_public_key_.get()),
168                       RSA_get0_n(rsa_public_key_.get()), bn_ctx.get(),
169                       mont_n_.get()) != kBsslSuccess) {
170     return absl::InternalError(
171         "BN_mod_exp_mont failed when called from RsaBlinder::Blind.");
172   }
173 
174   // Do `encoded_message*r^e mod n`.
175   //
176   // To avoid leaking side channels, we use Montgomery reduction. This would be
177   // FromMontgomery(ModMulMontgomery(ToMontgomery(m), ToMontgomery(r^e))).
178   // However, this is equivalent to ModMulMontgomery(m, ToMontgomery(r^e)).
179   // Each BN_mod_mul_montgomery removes a factor of R, so by having only one
180   // input in the Montgomery domain, we save a To/FromMontgomery pair.
181   //
182   // Internally, BN_mod_exp_mont actually computes r^e in the Montgomery domain
183   // and converts it out, but there is no public API for this, so we perform an
184   // extra conversion.
185   ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> multiplication_res,
186                                NewBigNum());
187   if (BN_to_montgomery(multiplication_res.get(), rE.get(), mont_n_.get(),
188                        bn_ctx.get()) != kBsslSuccess ||
189       BN_mod_mul_montgomery(multiplication_res.get(), encoded_message_bn.get(),
190                             multiplication_res.get(), mont_n_.get(),
191                             bn_ctx.get()) != kBsslSuccess) {
192     return absl::InternalError(
193         "BN_mod_mul failed when called from RsaBlinder::Blind.");
194   }
195 
196   absl::StatusOr<std::string> blinded_msg =
197       BignumToString(*multiplication_res, padded_len);
198 
199   // Update RsaBlinder state to kBlinded
200   blinder_state_ = RsaBlinder::BlinderState::kBlinded;
201 
202   return blinded_msg;
203 }
204 
205 // Unblinds `blind_signature`.
Unblind(const absl::string_view blind_signature)206 absl::StatusOr<std::string> RsaBlinder::Unblind(
207     const absl::string_view blind_signature) {
208   if (blinder_state_ != RsaBlinder::BlinderState::kBlinded) {
209     return absl::FailedPreconditionError(
210         "RsaBlinder is in wrong state to unblind signature.");
211   }
212   const unsigned int mod_size = BN_num_bytes(RSA_get0_n(rsa_public_key_.get()));
213   // Parse the signed_blinded_data as BIGNUM.
214   if (blind_signature.size() != mod_size) {
215     return absl::InternalError(absl::StrCat(
216         "Expected blind signature size = ", mod_size,
217         " actual blind signature size = ", blind_signature.size(), " bytes."));
218   }
219 
220   bssl::UniquePtr<BN_CTX> bn_ctx(BN_CTX_new());
221   if (!bn_ctx) {
222     return absl::InternalError("BN_CTX_new failed.");
223   }
224 
225   ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> signed_big_num,
226                                StringToBignum(blind_signature));
227   ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> unblinded_sig_big,
228                                NewBigNum());
229   // Do `signed_message*r^-1 mod n`.
230   //
231   // To avoid leaking side channels, we use Montgomery reduction. This would be
232   // FromMontgomery(ModMulMontgomery(ToMontgomery(m), ToMontgomery(r^-1))).
233   // However, this is equivalent to ModMulMontgomery(m, ToMontgomery(r^-1)).
234   // Each BN_mod_mul_montgomery removes a factor of R, so by having only one
235   // input in the Montgomery domain, we save a To/FromMontgomery pair.
236   if (BN_mod_mul_montgomery(unblinded_sig_big.get(), signed_big_num.get(),
237                             r_inv_mont_.get(), mont_n_.get(),
238                             bn_ctx.get()) != kBsslSuccess) {
239     return absl::InternalError(
240         "BN_mod_mul failed when called from RsaBlinder::Unblind.");
241   }
242   absl::StatusOr<std::string> unblinded_signed_message =
243       BignumToString(*unblinded_sig_big, /*output_len=*/mod_size);
244   blinder_state_ = RsaBlinder::BlinderState::kUnblinded;
245   return unblinded_signed_message;
246 }
247 
Verify(absl::string_view signature,absl::string_view message)248 absl::Status RsaBlinder::Verify(absl::string_view signature,
249                                 absl::string_view message) {
250   std::string augmented_message(message);
251   if (public_metadata_.has_value()) {
252     augmented_message = EncodeMessagePublicMetadata(message, *public_metadata_);
253   }
254   return RsaBlindSignatureVerify(salt_length_, sig_hash_, mgf1_hash_, signature,
255                                  augmented_message, rsa_public_key_.get());
256 }
257 
258 }  // namespace anonymous_tokens
259