xref: /aosp_15_r20/external/private-join-and-compute/private_join_and_compute/crypto/mont_mul.cc (revision a6aa18fbfbf9cb5cd47356a9d1b057768998488c)
1 /*
2  * Copyright 2019 Google LLC.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "private_join_and_compute/crypto/mont_mul.h"
17 
18 #include <algorithm>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/log/check.h"
24 #include "private_join_and_compute/crypto/openssl.inc"
25 
26 namespace private_join_and_compute {
27 
// Deep-copies `other`: shares its (non-owned) Context and BN_MONT_CTX
// pointers and duplicates the underlying BIGNUM.
MontBigNum::MontBigNum(const MontBigNum& other)
    : ctx_(other.ctx_),
      mont_ctx_(other.mont_ctx_),
      bn_(BigNum::BignumPtr(BN_dup(other.bn_.get()))) {
  // BN_dup returns nullptr on allocation failure. A null source (e.g. a
  // moved-from object) legitimately produces a null copy, so only flag the
  // case where a non-null value failed to duplicate.
  CHECK(other.bn_ == nullptr || bn_ != nullptr) << "BN_dup failed";
}
32 
// Deep-copy assignment: shares `other`'s Context and BN_MONT_CTX pointers and
// replaces this object's BIGNUM with a duplicate of `other`'s. Safe under
// self-assignment because the duplicate is made before anything is replaced.
MontBigNum& MontBigNum::operator=(const MontBigNum& other) {
  BigNum::BignumPtr copy(BN_dup(other.bn_.get()));
  // BN_dup returns nullptr on allocation failure; a null source (e.g. a
  // moved-from object) is the only case where a null copy is expected.
  CHECK(other.bn_ == nullptr || copy != nullptr) << "BN_dup failed";
  ctx_ = other.ctx_;
  mont_ctx_ = other.mont_ctx_;
  bn_ = std::move(copy);
  return *this;
}
39 
MontBigNum(MontBigNum && other)40 MontBigNum::MontBigNum(MontBigNum&& other)
41     : ctx_(other.ctx_), mont_ctx_(other.mont_ctx_), bn_(std::move(other.bn_)) {}
42 
operator =(MontBigNum && other)43 MontBigNum& MontBigNum::operator=(MontBigNum&& other) {
44   ctx_ = other.ctx_;
45   mont_ctx_ = other.mont_ctx_;
46   bn_ = std::move(other.bn_);
47   return *this;
48 }
49 
50 // The reinterpret_cast is necessary to accept a string_view.
MontBigNum(Context * ctx,BN_MONT_CTX * mont_ctx,absl::string_view bytes)51 MontBigNum::MontBigNum(Context* ctx, BN_MONT_CTX* mont_ctx,
52                        absl::string_view bytes)
53     : MontBigNum(ctx, mont_ctx, BigNum::BignumPtr(BN_new())) {
54   CRYPTO_CHECK(nullptr !=
55                BN_bin2bn(reinterpret_cast<const unsigned char*>(bytes.data()),
56                          bytes.size(), bn_.get()));
57 }
58 
// Returns (*this) * mont_big_num as a new Montgomery-form value; *this is not
// modified. Both operands must share the same BN_MONT_CTX (checked in
// MulInPlace).
MontBigNum MontBigNum::Mul(const MontBigNum& mont_big_num) const {
  MontBigNum product(*this);  // deep copy via the copy constructor
  product.MulInPlace(mont_big_num);
  return product;
}
64 
// Replaces *this with the Montgomery product of *this and `mont_big_num`
// (BN_mod_mul_montgomery computes a * b * R^-1 mod N), so two Montgomery-form
// inputs yield a Montgomery-form result. Returns *this for chaining.
MontBigNum& MontBigNum::MulInPlace(const MontBigNum& mont_big_num) {
  // Operands from different BN_MONT_CTX instances cannot be combined.
  CHECK_EQ(mont_big_num.mont_ctx_, mont_ctx_);
  CRYPTO_CHECK(1 == BN_mod_mul_montgomery(
                        bn_.get(), bn_.get(), mont_big_num.bn_.get(),
                        mont_ctx_, ctx_->GetBnCtx()));
  return *this;
}
72 
operator ==(const MontBigNum & other) const73 bool MontBigNum::operator==(const MontBigNum& other) const {
74   CHECK_EQ(other.mont_ctx_, mont_ctx_);
75   return BN_cmp(bn_.get(), other.bn_.get()) == 0;
76 }
77 
PowTo2To(int64_t exponent) const78 MontBigNum MontBigNum::PowTo2To(int64_t exponent) const {
79   CHECK(exponent >= 0) << "MontBigNum::PowTo2To: exponent must be nonnegative";
80   MontBigNum r = *this;
81   for (int64_t i = 0; i < exponent; i++) {
82     CRYPTO_CHECK(1 == BN_mod_mul_montgomery(r.bn_.get(), r.bn_.get(),
83                                             r.bn_.get(), mont_ctx_,
84                                             ctx_->GetBnCtx()));
85   }
86   return r;
87 }
88 
89 // The reinterpret_cast is necessary to return a string.
ToBytes() const90 std::string MontBigNum::ToBytes() const {
91   int length = BN_num_bytes(bn_.get());
92   std::vector<unsigned char> bytes(length);
93   BN_bn2bin(bn_.get(), bytes.data());
94   return std::string(reinterpret_cast<char*>(bytes.data()), bytes.size());
95 }
96 
ToBigNum() const97 BigNum MontBigNum::ToBigNum() const {
98   BIGNUM* temp = BN_new();
99   CHECK_NE(temp, nullptr);
100   auto bn_ptr = BigNum::BignumPtr(temp);
101   CRYPTO_CHECK(1 == BN_from_montgomery(bn_ptr.get(), bn_.get(), mont_ctx_,
102                                        ctx_->GetBnCtx()));
103   return ctx_->CreateBigNum(std::move(bn_ptr));
104 }
105 
MontBigNum(Context * ctx,BN_MONT_CTX * mont_ctx,BigNum::BignumPtr bn)106 MontBigNum::MontBigNum(Context* ctx, BN_MONT_CTX* mont_ctx,
107                        BigNum::BignumPtr bn)
108     : ctx_(ctx), mont_ctx_(mont_ctx), bn_(std::move(bn)) {}
109 
CreateMontBigNum(const BigNum & big_num)110 MontBigNum MontContext::CreateMontBigNum(const BigNum& big_num) {
111   CHECK(big_num < modulus_);
112   BIGNUM* bn = BN_dup(big_num.GetConstBignumPtr());
113   CHECK_NE(bn, nullptr);
114   CRYPTO_CHECK(1 ==
115                BN_to_montgomery(bn, bn, mont_ctx_.get(), ctx_->GetBnCtx()));
116   return MontBigNum(ctx_, mont_ctx_.get(), BigNum::BignumPtr(bn));
117 }
118 
CreateMontBigNum(absl::string_view bytes)119 MontBigNum MontContext::CreateMontBigNum(absl::string_view bytes) {
120   return MontBigNum(ctx_, mont_ctx_.get(), bytes);
121 }
122 
MontContext(Context * ctx,const BigNum & modulus)123 MontContext::MontContext(Context* ctx, const BigNum& modulus)
124     : modulus_(modulus), ctx_(ctx), mont_ctx_(MontCtxPtr(BN_MONT_CTX_new())) {
125   CRYPTO_CHECK(1 == BN_MONT_CTX_set(mont_ctx_.get(),
126                                     modulus.GetConstBignumPtr(),
127                                     ctx_->GetBnCtx()));
128 }
129 
130 }  // namespace private_join_and_compute
131