xref: /aosp_15_r20/external/private-join-and-compute/private_join_and_compute/crypto/ec_group.cc (revision a6aa18fbfbf9cb5cd47356a9d1b057768998488c)
1 /*
2  * Copyright 2019 Google LLC.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "private_join_and_compute/crypto/ec_group.h"
17 
18 #include <utility>
19 
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/string_view.h"
22 #include "private_join_and_compute/crypto/ec_point.h"
23 #include "private_join_and_compute/crypto/openssl.inc"
24 #include "private_join_and_compute/util/status.inc"
25 
26 namespace private_join_and_compute {
27 
28 namespace {
29 
30 // Returns a group using the predefined underlying operations suggested by
31 // OpenSSL.
CreateGroup(int curve_id)32 StatusOr<ECGroup::ECGroupPtr> CreateGroup(int curve_id) {
33   auto ec_group_ptr = EC_GROUP_new_by_curve_name(curve_id);
34   // If this fails, this is usually due to an invalid curve id.
35   if (ec_group_ptr == nullptr) {
36     return InvalidArgumentError(
37         absl::StrCat("ECGroup::CreateGroup() - Could not create group. ",
38                      OpenSSLErrorString()));
39   }
40   return ECGroup::ECGroupPtr(ec_group_ptr);
41 }
42 
43 // Returns the order of the group. For more information, see
44 // https://en.wikipedia.org/wiki/Elliptic-curve_cryptography#Domain_parameters.
CreateOrder(const EC_GROUP * group,Context * context)45 StatusOr<BigNum> CreateOrder(const EC_GROUP* group, Context* context) {
46   BIGNUM* bn = BN_new();
47   if (bn == nullptr) {
48     return InternalError(
49         absl::StrCat("ECGroup::CreateOrder - Could not create BIGNUM. ",
50                      OpenSSLErrorString()));
51   }
52   BigNum::BignumPtr order = BigNum::BignumPtr(bn);
53   if (EC_GROUP_get_order(group, order.get(), context->GetBnCtx()) != 1) {
54     return InternalError(absl::StrCat(
55         "ECGroup::CreateOrder - Could not get order. ", OpenSSLErrorString()));
56   }
57   return context->CreateBigNum(std::move(order));
58 }
59 
60 // Returns the cofactor of the group.
CreateCofactor(const EC_GROUP * group,Context * context)61 StatusOr<BigNum> CreateCofactor(const EC_GROUP* group, Context* context) {
62   BIGNUM* bn = BN_new();
63   if (bn == nullptr) {
64     return InternalError(
65         absl::StrCat("ECGroup::CreateCofactor - Could not create BIGNUM. ",
66                      OpenSSLErrorString()));
67   }
68   BigNum::BignumPtr cofactor = BigNum::BignumPtr(bn);
69   if (EC_GROUP_get_cofactor(group, cofactor.get(), context->GetBnCtx()) != 1) {
70     return InternalError(
71         absl::StrCat("ECGroup::CreateCofactor - Could not get cofactor. ",
72                      OpenSSLErrorString()));
73   }
74   return context->CreateBigNum(std::move(cofactor));
75 }
76 
77 // Returns the parameters that define the curve. For more information, see
78 // https://en.wikipedia.org/wiki/Elliptic-curve_cryptography#Domain_parameters.
CreateCurveParams(const EC_GROUP * group,Context * context)79 StatusOr<ECGroup::CurveParams> CreateCurveParams(const EC_GROUP* group,
80                                                  Context* context) {
81   BIGNUM* bn1 = BN_new();
82   BIGNUM* bn2 = BN_new();
83   BIGNUM* bn3 = BN_new();
84   if (bn1 == nullptr || bn2 == nullptr || bn3 == nullptr) {
85     return InternalError(
86         absl::StrCat("ECGroup::CreateCurveParams - Could not create BIGNUM. ",
87                      OpenSSLErrorString()));
88   }
89   BigNum::BignumPtr p = BigNum::BignumPtr(bn1);
90   BigNum::BignumPtr a = BigNum::BignumPtr(bn2);
91   BigNum::BignumPtr b = BigNum::BignumPtr(bn3);
92   if (EC_GROUP_get_curve_GFp(group, p.get(), a.get(), b.get(),
93                              context->GetBnCtx()) != 1) {
94     return InternalError(
95         absl::StrCat("ECGroup::CreateCurveParams - Could not get params. ",
96                      OpenSSLErrorString()));
97   }
98   BigNum p_bn = context->CreateBigNum(std::move(p));
99   if (!p_bn.IsPrime()) {
100     return InternalError(absl::StrCat(
101         "ECGroup::CreateCurveParams - p is not prime. ", OpenSSLErrorString()));
102   }
103   return ECGroup::CurveParams{std::move(p_bn),
104                               context->CreateBigNum(std::move(a)),
105                               context->CreateBigNum(std::move(b))};
106 }
107 
108 // Returns (p - 1) / 2 where p is a curve-defining parameter.
GetPMinusOneOverTwo(const ECGroup::CurveParams & curve_params,Context * context)109 BigNum GetPMinusOneOverTwo(const ECGroup::CurveParams& curve_params,
110                            Context* context) {
111   return (curve_params.p - context->One()) / context->Two();
112 }
113 
114 }  // namespace
115 
ECGroup(Context * context,ECGroupPtr group,BigNum order,BigNum cofactor,CurveParams curve_params,BigNum p_minus_one_over_two)116 ECGroup::ECGroup(Context* context, ECGroupPtr group, BigNum order,
117                  BigNum cofactor, CurveParams curve_params,
118                  BigNum p_minus_one_over_two)
119     : context_(context),
120       group_(std::move(group)),
121       order_(std::move(order)),
122       cofactor_(std::move(cofactor)),
123       curve_params_(std::move(curve_params)),
124       p_minus_one_over_two_(std::move(p_minus_one_over_two)) {}
125 
Create(int curve_id,Context * context)126 StatusOr<ECGroup> ECGroup::Create(int curve_id, Context* context) {
127   ASSIGN_OR_RETURN(ECGroupPtr g, CreateGroup(curve_id));
128   ASSIGN_OR_RETURN(BigNum order, CreateOrder(g.get(), context));
129   ASSIGN_OR_RETURN(BigNum cofactor, CreateCofactor(g.get(), context));
130   ASSIGN_OR_RETURN(CurveParams params, CreateCurveParams(g.get(), context));
131   BigNum p_minus_one_over_two = GetPMinusOneOverTwo(params, context);
132   return ECGroup(context, std::move(g), std::move(order), std::move(cofactor),
133                  std::move(params), std::move(p_minus_one_over_two));
134 }
135 
GeneratePrivateKey() const136 BigNum ECGroup::GeneratePrivateKey() const {
137   return context_->GenerateRandBetween(context_->One(), order_);
138 }
139 
CheckPrivateKey(const BigNum & priv_key) const140 Status ECGroup::CheckPrivateKey(const BigNum& priv_key) const {
141   if (context_->Zero() >= priv_key || priv_key >= order_) {
142     return InvalidArgumentError(
143         "The given key is out of bounds, needs to be in [1, order) instead.");
144   }
145   return OkStatus();
146 }
147 
GetPointByHashingToCurveInternal(const BigNum & x) const148 StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveInternal(
149     const BigNum& x) const {
150   BigNum mod_x = x.Mod(curve_params_.p);
151   BigNum y2 = ComputeYSquare(mod_x);
152   if (IsSquare(y2)) {
153     BigNum sqrt = y2.ModSqrt(curve_params_.p);
154     if (sqrt.IsBitSet(0)) {
155       return CreateECPoint(mod_x, sqrt.ModNegate(curve_params_.p));
156     }
157     return CreateECPoint(mod_x, sqrt);
158   }
159   return InternalError("Could not hash x to the curve.");
160 }
161 
GetPointByHashingToCurveSha256(absl::string_view m) const162 StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveSha256(
163     absl::string_view m) const {
164   BigNum x = context_->RandomOracleSha256(m, curve_params_.p);
165   while (true) {
166     auto status_or_point = GetPointByHashingToCurveInternal(x);
167     if (status_or_point.ok()) {
168       return status_or_point;
169     }
170     x = context_->RandomOracleSha256(x.ToBytes(), curve_params_.p);
171   }
172 }
173 
GetPointByHashingToCurveSha384(absl::string_view m) const174 StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveSha384(
175     absl::string_view m) const {
176   BigNum x = context_->RandomOracleSha384(m, curve_params_.p);
177   while (true) {
178     auto status_or_point = GetPointByHashingToCurveInternal(x);
179     if (status_or_point.ok()) {
180       return status_or_point;
181     }
182     x = context_->RandomOracleSha384(x.ToBytes(), curve_params_.p);
183   }
184 }
185 
GetPointByHashingToCurveSha512(absl::string_view m) const186 StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveSha512(
187     absl::string_view m) const {
188   BigNum x = context_->RandomOracleSha512(m, curve_params_.p);
189   while (true) {
190     auto status_or_point = GetPointByHashingToCurveInternal(x);
191     if (status_or_point.ok()) {
192       return status_or_point;
193     }
194     x = context_->RandomOracleSha512(x.ToBytes(), curve_params_.p);
195   }
196 }
197 
GetPointByHashingToCurveSswuRo(absl::string_view m,absl::string_view dst) const198 StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveSswuRo(
199     absl::string_view m, absl::string_view dst) const {
200   ASSIGN_OR_RETURN(ECPoint out, GetPointAtInfinity());
201   int curve_id = GetCurveId();
202   if (curve_id == NID_X9_62_prime256v1) {
203     if (EC_hash_to_curve_p256_xmd_sha256_sswu(
204             group_.get(), out.point_.get(),
205             reinterpret_cast<const uint8_t*>(dst.data()), dst.length(),
206             reinterpret_cast<const uint8_t*>(m.data()), m.length()) != 1) {
207       return InternalError(OpenSSLErrorString());
208     }
209   } else if (curve_id == NID_secp384r1) {
210     if (EC_hash_to_curve_p384_xmd_sha384_sswu(
211             group_.get(), out.point_.get(),
212             reinterpret_cast<const uint8_t*>(dst.data()), dst.length(),
213             reinterpret_cast<const uint8_t*>(m.data()), m.length()) != 1) {
214       return InternalError(OpenSSLErrorString());
215     }
216   } else {
217     return InvalidArgumentError("Curve does not support HashToCurveSswuRo.");
218   }
219   return out;
220 }
221 
// Evaluates the right-hand side of the curve equation, x^3 + a*x + b, and
// reduces it mod p.
BigNum ECGroup::ComputeYSquare(const BigNum& x) const {
  BigNum x_cubed = x.Exp(context_->Three());
  BigNum a_times_x = curve_params_.a * x;
  BigNum rhs = x_cubed + a_times_x + curve_params_.b;
  return rhs.Mod(curve_params_.p);
}
226 
IsValid(const ECPoint & point) const227 bool ECGroup::IsValid(const ECPoint& point) const {
228   if (!IsOnCurve(point) || IsAtInfinity(point)) {
229     return false;
230   }
231   return true;
232 }
233 
IsOnCurve(const ECPoint & point) const234 bool ECGroup::IsOnCurve(const ECPoint& point) const {
235   return 1 == EC_POINT_is_on_curve(group_.get(), point.point_.get(),
236                                    context_->GetBnCtx());
237 }
238 
IsAtInfinity(const ECPoint & point) const239 bool ECGroup::IsAtInfinity(const ECPoint& point) const {
240   return 1 == EC_POINT_is_at_infinity(group_.get(), point.point_.get());
241 }
242 
IsSquare(const BigNum & q) const243 bool ECGroup::IsSquare(const BigNum& q) const {
244   return q.ModExp(p_minus_one_over_two_, curve_params_.p).IsOne();
245 }
246 
GetFixedGenerator() const247 StatusOr<ECPoint> ECGroup::GetFixedGenerator() const {
248   const EC_POINT* ssl_generator = EC_GROUP_get0_generator(group_.get());
249   EC_POINT* dup_ssl_generator = EC_POINT_dup(ssl_generator, group_.get());
250   if (dup_ssl_generator == nullptr) {
251     return InternalError(OpenSSLErrorString());
252   }
253   return ECPoint(group_.get(), context_->GetBnCtx(),
254                  ECPoint::ECPointPtr(dup_ssl_generator));
255 }
256 
GetRandomGenerator() const257 StatusOr<ECPoint> ECGroup::GetRandomGenerator() const {
258   ASSIGN_OR_RETURN(ECPoint generator, GetFixedGenerator());
259   return generator.Mul(context_->GenerateRandBetween(context_->One(), order_));
260 }
261 
CreateECPoint(const BigNum & x,const BigNum & y) const262 StatusOr<ECPoint> ECGroup::CreateECPoint(const BigNum& x,
263                                          const BigNum& y) const {
264   ECPoint point = ECPoint(group_.get(), context_->GetBnCtx(), x, y);
265   if (!IsValid(point)) {
266     return InvalidArgumentError(
267         "ECGroup::CreateECPoint(x,y) - The point is not valid.");
268   }
269   return std::move(point);
270 }
271 
CreateECPoint(absl::string_view bytes) const272 StatusOr<ECPoint> ECGroup::CreateECPoint(absl::string_view bytes) const {
273   auto raw_ec_point_ptr = EC_POINT_new(group_.get());
274   if (raw_ec_point_ptr == nullptr) {
275     return InternalError("ECGroup::CreateECPoint: Failed to create point.");
276   }
277   ECPoint::ECPointPtr point(raw_ec_point_ptr);
278   if (EC_POINT_oct2point(group_.get(), point.get(),
279                          reinterpret_cast<const unsigned char*>(bytes.data()),
280                          bytes.size(), context_->GetBnCtx()) != 1) {
281     return InvalidArgumentError(
282         absl::StrCat("ECGroup::CreateECPoint(string) - Could not decode point.",
283                      "\n", OpenSSLErrorString()));
284   }
285 
286   ECPoint ec_point(group_.get(), context_->GetBnCtx(), std::move(point));
287   if (!IsValid(ec_point)) {
288     return InvalidArgumentError(
289         "ECGroup::CreateECPoint(string) - Decoded point is not valid.");
290   }
291   return std::move(ec_point);
292 }
293 
GetPointAtInfinity() const294 StatusOr<ECPoint> ECGroup::GetPointAtInfinity() const {
295   EC_POINT* new_point = EC_POINT_new(group_.get());
296   if (new_point == nullptr) {
297     return InternalError(
298         "ECGroup::GetPointAtInfinity() - Could not create new point.");
299   }
300   ECPoint::ECPointPtr point(new_point);
301   if (EC_POINT_set_to_infinity(group_.get(), point.get()) != 1) {
302     return InternalError(
303         "ECGroup::GetPointAtInfinity() - Could not get point at infinity.");
304   }
305   ECPoint ec_point(group_.get(), context_->GetBnCtx(), std::move(point));
306   return std::move(ec_point);
307 }
308 
309 }  // namespace private_join_and_compute
310