# Copyright 2019 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for elliptic curve related classes."""

import ctypes
from typing import Optional, Union

from private_join_and_compute.py.crypto_util import converters
from private_join_and_compute.py.crypto_util import ssl_util
from private_join_and_compute.py.crypto_util.ssl_util import BigNum
from private_join_and_compute.py.crypto_util.ssl_util import OpenSSLHelper
from private_join_and_compute.py.crypto_util.ssl_util import TempBNs
from private_join_and_compute.py.crypto_util.supported_curves import SupportedCurve
from private_join_and_compute.py.crypto_util.supported_hashes import HashType
import six

# OpenSSL point_conversion_form_t value POINT_CONVERSION_COMPRESSED; used as
# the serialization form passed to EC_POINT_point2oct.
POINT_CONVERSION_COMPRESSED = 2


class ECPoint(object):
  """A point on an elliptic curve, wrapping an OpenSSL EC_POINT.

  The wrapped EC_POINT is owned by this object and freed in __del__.
  """

  def __init__(self, group, ec_point_bn):
    """Wraps an existing EC_POINT.

    Args:
      group: the EC_GROUP the point belongs to.
      ec_point_bn: the EC_POINT handle. Ownership passes to this object, which
        frees it when garbage collected.
    """
    # Assigned first so that __del__ can still free the point if a later
    # statement here raises. Holding the reference also keeps garbage
    # collection from collecting ssl before this object.
    self.ssl = ssl_util.ssl
    self._group = group
    self._point = ec_point_bn
    self.ctx = OpenSSLHelper().ctx

  @classmethod
  def FromPoint(cls, group: int, x: int, y: int):
    """Creates an EC_POINT object with the given x, y affine coordinates.

    Args:
      group: the EC_GROUP for the given point's elliptic curve
      x: the x coordinate of the point as long value
      y: the y coordinate of the point as long value

    Returns:
      <x, y> ECPoint on the elliptic curve defined by group

    Raises:
      TypeError: If the x, y coordinates are not of type long.
      ValueError: If <x, y> is not on the curve or is the point at infinity.
    """
    ec_point = cls._EmptyPoint(group)
    with TempBNs(x=x, y=y) as bn:
      # pylint: disable=protected-access
      ssl_util.ssl.EC_POINT_set_affine_coordinates_GFp(
          group, ec_point._point, bn.x, bn.y, None
      )
      # pylint: enable=protected-access
    ec_point.CheckValidity()
    return ec_point

  @classmethod
  def FromLongOrBytes(cls, group: int, point_long_or_bytes: Union[int, bytes]):
    """Creates an EC_POINT object from its serialized bytes representation.

    Args:
      group: the EC_GROUP for the point's elliptic curve.
      point_long_or_bytes: the serialized bytes representations of the point,
        or the same bytes interpreted as an integer.

    Returns:
      The point encoded by point_long_or_bytes

    Raises:
      ValueError: if point_long_or_bytes is not a valid encoding of a point
        from the EC group.
    """
    ec_point = cls._EmptyPoint(group)
    if isinstance(point_long_or_bytes, int):
      point_long_or_bytes = converters.LongToBytes(point_long_or_bytes)
    # The return value of EC_POINT_oct2point is not checked; CheckValidity()
    # below is what rejects encodings that do not yield a valid point.
    # pylint: disable=protected-access
    ssl_util.ssl.EC_POINT_oct2point(
        group,
        ec_point._point,
        point_long_or_bytes,
        len(point_long_or_bytes),
        None,
    )
    # pylint: enable=protected-access
    ec_point.CheckValidity()
    return ec_point

  @classmethod
  def GetPointAtInfinity(cls, group):
    """Returns a new point set to infinity (the neutral element) of group."""
    p = ssl_util.ssl.EC_POINT_new(group)
    ssl_util.ssl.EC_POINT_set_to_infinity(group, p)
    return cls(group, p)

  @classmethod
  def _EmptyPoint(cls, group):
    """Returns a wrapper around a freshly allocated EC_POINT for group."""
    return cls(group, ssl_util.ssl.EC_POINT_new(group))

  def __del__(self):
    # getattr guards against partially constructed instances (e.g. __init__
    # raised, or the object was created via __new__ only).
    lib = getattr(self, 'ssl', None)
    point = getattr(self, '_point', None)
    if lib is not None and point:
      lib.EC_POINT_free(point)

  def CheckValidity(self):
    """Checks if this point is valid and can be multiplied with the key.

    If the point is corrupted as a result of a faulty computation, this might
    leak data about the key.

    Raises:
      ValueError: If the point is not on the curve or if the point is the
        neutral element.
    """
    if not self.IsOnCurve():
      raise ValueError('The point is not on the curve.')

    if self.IsAtInfinity():
      raise ValueError('The point is the neutral element.')

  @staticmethod
  def _ScalarArg(scalar):
    """Returns the multiplier argument to pass to EC_POINT_mul.

    BigNum instances are unwrapped to their underlying BIGNUM; any other value
    is passed through unchanged (presumably already a BIGNUM handle).
    """
    if isinstance(scalar, BigNum):
      return scalar._bn_num  # pylint: disable=protected-access
    return scalar

  def __mul__(self, scalar):
    """Returns a new point equal to scalar * self."""
    new_ec_point = self._EmptyPoint(self._group)
    # pylint: disable=protected-access
    ssl_util.ssl.EC_POINT_mul(
        self._group,
        new_ec_point._point,
        None,
        self._point,
        self._ScalarArg(scalar),
        self.ctx,
    )
    # pylint: enable=protected-access
    return new_ec_point

  def __imul__(self, scalar):
    """Replaces this point with scalar * self and returns self."""
    ssl_util.ssl.EC_POINT_mul(
        self._group,
        self._point,
        None,
        self._point,
        self._ScalarArg(scalar),
        self.ctx,
    )
    return self

  def __add__(self, ec_point):
    """Returns a new point equal to self + ec_point."""
    new_ec_point = self._EmptyPoint(self._group)
    # pylint: disable=protected-access
    ssl_util.ssl.EC_POINT_add(
        self._group, new_ec_point._point, self._point, ec_point._point, self.ctx
    )
    # pylint: enable=protected-access
    return new_ec_point

  def __iadd__(self, ec_point):
    """Replaces this point with self + ec_point and returns self."""
    # pylint: disable=protected-access
    ssl_util.ssl.EC_POINT_add(
        self._group, self._point, self._point, ec_point._point, self.ctx
    )
    # pylint: enable=protected-access
    return self

  def IsOnCurve(self) -> bool:
    """Returns True iff EC_POINT_is_on_curve reports 1 for this point."""
    return 1 == ssl_util.ssl.EC_POINT_is_on_curve(
        self._group, self._point, None
    )

  def IsAtInfinity(self) -> bool:
    """Returns True iff this point is the point at infinity."""
    return 1 == ssl_util.ssl.EC_POINT_is_at_infinity(self._group, self._point)

  def GetAsLong(self) -> int:
    """Returns the compressed encoding of this point as an integer."""
    return converters.BytesToLong(self.GetAsBytes())

  def GetAsBytes(self) -> bytes:
    """Returns the compressed octet-string encoding of this point."""
    # The first call passes a NULL buffer to obtain the required length.
    buf_len = ssl_util.ssl.EC_POINT_point2oct(
        self._group, self._point, POINT_CONVERSION_COMPRESSED, None, 0, None
    )
    buf = ctypes.create_string_buffer(buf_len)
    ssl_util.ssl.EC_POINT_point2oct(
        self._group,
        self._point,
        POINT_CONVERSION_COMPRESSED,
        buf,
        buf_len,
        None,
    )
    return six.ensure_binary(buf.raw)

  def __eq__(self, other: 'ECPoint'):
    """Returns True iff other is an equal point.

    Raises:
      ValueError: If other is not an instance of this point's class.
    """
    if not isinstance(other, self.__class__):
      raise ValueError(
          'Cannot compare ECPoint with type {}'.format(type(other))
      )
    other_point = other._point  # pylint: disable=protected-access
    return 0 == ssl_util.ssl.EC_POINT_cmp(
        self._group, self._point, other_point, self.ctx
    )

  def __ne__(self, other: 'ECPoint'):
    return not self.__eq__(other)

  def __str__(self):
    return str(self.GetAsLong())


class EllipticCurve(object):
  """Class for representing the elliptic curve.

  Wraps an OpenSSL EC_GROUP for a curve y^2 = x^3 + ax + b over GF(p), where p
  is checked to be prime at construction time.
  """

  def __init__(
      self,
      curve_id: Union[int, SupportedCurve],
      hash_type: Optional[HashType] = None,
  ):
    """Creates the curve identified by curve_id.

    Args:
      curve_id: OpenSSL curve id, or a SupportedCurve whose `id` is used.
      hash_type: hash used by HashToCurve; defaults to HashType.SHA512.

    Raises:
      ValueError: If no EC_GROUP can be created for curve_id, or if the
        curve's modulus p is not prime.
    """
    # Assigned first so that __del__ runs cleanly (and frees the group) even
    # if a later statement here raises. Holding the reference also keeps
    # garbage collection from collecting ssl before this object.
    self.ssl = ssl_util.ssl
    self._group = None
    if isinstance(curve_id, SupportedCurve):
      curve_id = curve_id.id
    if hash_type is None:
      hash_type = HashType.SHA512
    self._hash_type = hash_type
    self._group = ssl_util.ssl.EC_GROUP_new_by_curve_name(curve_id)
    if not self._group:
      raise ValueError('Unsupported curve id: {}'.format(curve_id))
    with TempBNs(p=None, a=None, b=None, order=None) as bn:
      ssl_util.ssl.EC_GROUP_get_curve_GFp(self._group, bn.p, bn.a, bn.b, None)
      ssl_util.ssl.EC_GROUP_get_order(
          self._group, bn.order, OpenSSLHelper().ctx
      )
      # Read every value inside the `with`; the temporary BIGNUMs are
      # presumably released by TempBNs on exit.
      self._order = ssl_util.BnToLong(bn.order)
      self._p = ssl_util.BnToLong(bn.p)
      self._a = ssl_util.BnToLong(bn.a)
      self._b = ssl_util.BnToLong(bn.b)
    self._p_bn = BigNum.FromLongNumber(self._p)
    if not self._p_bn.IsPrime():
      raise ValueError(
          'Wrong curve parameters: p must be a prime. p: {}'.format(self._p)
      )
    # (p - 1) / 2: the exponent of the quadratic-residue test in HashToCurve.
    self._p_sub_one_div_by_two = (self._p - 1) >> 1

  def __del__(self):
    # getattr guards against partially constructed instances.
    lib = getattr(self, 'ssl', None)
    group = getattr(self, '_group', None)
    if lib is not None and group:
      lib.EC_GROUP_free(group)

  def GetPointByHashingToCurve(self, m: Union[int, bytes]) -> ECPoint:
    """Hashes m into the elliptic curve."""
    return ECPoint.FromPoint(self.group, *self.HashToCurve(m))

  def GetPointFromLong(self, m_long: int) -> ECPoint:
    """Converts the given compressed point (m_long) into ECPoint."""
    return ECPoint.FromLongOrBytes(self.group, m_long)

  def GetPointFromBytes(self, m_bytes: bytes) -> ECPoint:
    """Converts the given compressed point (m_bytes) into ECPoint."""
    return ECPoint.FromLongOrBytes(self.group, m_bytes)

  def GetPointAtInfinity(self) -> ECPoint:
    """Gets a point at the infinity."""
    return ECPoint.GetPointAtInfinity(self.group)

  def GetRandomGenerator(self):
    """Returns k * G, where G is the group's generator and k is random.

    k comes from BigNum.GenerateRandWithStart with start 1 and bound equal to
    the group order (presumably uniform in [1, order) -- see BigNum).
    """
    # get0: the generator is borrowed from the group, so duplicate it before
    # handing ownership to an ECPoint.
    ssl_point = ssl_util.ssl.EC_GROUP_get0_generator(self.group)
    generator = ECPoint(
        self.group, ssl_util.ssl.EC_POINT_dup(ssl_point, self.group)
    )
    generator *= BigNum.FromLongNumber(self.order).GenerateRandWithStart(
        BigNum.One()
    )
    return generator

  def ComputeYSquare(self, x: int):
    """Returns y^2 calculated with x^3 + ax + b."""
    return (x**3 + self._a * x + self._b) % self._p

  def HashToCurve(self, m: Union[int, bytes]):
    """Hash m to a point on the elliptic curve y^2 = x^3 + ax + b.

    To hash m to a point on the curve, the algorithm first computes an integer
    hash value x = h(m) and determines whether x is the abscissa of a point on
    the elliptic curve y^2 = x^3 + ax + b. If not, set x = h(x) and try again.

    Security:
      The number of operations required to hash a message m depends on m, which
      could lead to a timing attack.

    Args:
      m: the message to hash (int or bytes).

    Returns:
      A point (x, y) on this elliptic curve.
    """
    x = ssl_util.RandomOracle(m, self._p, hash_type=self._hash_type)
    # Loop instead of recursing, so a long run of failed candidates cannot
    # exhaust Python's recursion limit.
    while True:
      y2 = self.ComputeYSquare(x)
      # y2 is a quadratic residue if y2^(p-1)/2 = 1
      if 1 == ssl_util.ModExp(y2, self._p_sub_one_div_by_two, self._p):
        y2_bn = ssl_util.BigNum.FromLongNumber(y2).Mutable()
        y2_bn.IModSqrt(self._p_bn)
        # If the root is odd, return its negation mod p instead, so the result
        # is deterministic.
        if y2_bn.IsBitSet(0):
          return (x, y2_bn.ModNegate(self._p_bn).GetAsLong())
        return (x, y2_bn.GetAsLong())
      x = ssl_util.RandomOracle(x, self._p, hash_type=self._hash_type)

  def __eq__(self, other):
    """Returns True iff other has the same p, a and b.

    Raises:
      ValueError: If other is not an instance of this curve's class.
    """
    if not isinstance(other, self.__class__):
      raise ValueError(
          'Cannot compare EllipticCurve with type {}'.format(type(other))
      )
    # pylint: disable=protected-access
    return self._p == other._p and self._a == other._a and self._b == other._b
    # pylint: enable=protected-access

  @property
  def group(self):
    """The underlying EC_GROUP handle, owned by this object."""
    return self._group

  @property
  def order(self):
    """The order of the group, as an int."""
    return self._order


class ECKey(object):
  """Class representing the elliptic curve key."""

  def __init__(
      self,
      curve_id: Union[int, SupportedCurve],
      priv_key_bytes: Optional[bytes] = None,
      hash_type: Optional[HashType] = None,
  ):
    """Creates an EC key on the given curve.

    Args:
      curve_id: OpenSSL curve id, or a SupportedCurve whose `id` is used.
      priv_key_bytes: serialized private key. If None or empty, a new key
        pair is generated instead.
      hash_type: forwarded to EllipticCurve for this key's curve.

    Raises:
      ValueError: If no EC_KEY can be created for curve_id, or if a generated
        key fails EC_KEY_check_key.
      RuntimeError: If key generation fails.
    """
    # Assigned first so that __del__ runs cleanly (and frees the key) even if
    # a later statement here raises. Holding the reference also keeps garbage
    # collection from collecting ssl before this object.
    self.ssl = ssl_util.ssl
    self._key = None
    if isinstance(curve_id, SupportedCurve):
      curve_id = curve_id.id
    self._curve_id = curve_id
    self._key = ssl_util.ssl.EC_KEY_new_by_curve_name(curve_id)
    if not self._key:
      raise ValueError('Unsupported curve id: {}'.format(curve_id))
    if priv_key_bytes:
      # NOTE(review): OpenSSL's EC_KEY_set_private_key stores a copy of the
      # BIGNUM; confirm whether the one returned by BytesToBn needs freeing.
      ssl_util.ssl.EC_KEY_set_private_key(
          self._key, ssl_util.BytesToBn(priv_key_bytes)
      )
    else:
      if 1 != ssl_util.ssl.EC_KEY_generate_key(self._key):
        raise RuntimeError('EC key generation failed.')
      # EC_KEY_check_key needs a public key, which only key generation sets.
      self._Check()
    # get0: the private key BIGNUM is borrowed from (owned by) self._key.
    self._priv_key_bn = ssl_util.ssl.EC_KEY_get0_private_key(self._key)
    self._priv_key_bytes = ssl_util.BnToBytes(self._priv_key_bn)
    self._priv_key_bignum = BigNum.FromBytes(self._priv_key_bytes)
    self._ec = EllipticCurve(curve_id, hash_type=hash_type)
    # Inverse of the private key modulo the group order.
    self._decrypt_key = self._priv_key_bignum.ModInverse(
        BigNum.FromLongNumber(self._ec.order)
    )

  def __del__(self):
    # getattr guards against partially constructed instances.
    lib = getattr(self, 'ssl', None)
    key = getattr(self, '_key', None)
    if lib is not None and key:
      lib.EC_KEY_free(key)

  def _Check(self):
    """Raises ValueError if EC_KEY_check_key rejects this key."""
    if 0 == ssl_util.ssl.EC_KEY_check_key(self._key):
      raise ValueError('The ECKey checks has failed.')

  @property
  def priv_key_bytes(self):
    return self._priv_key_bytes

  @property
  def priv_key_bn(self):
    """Borrowed BIGNUM handle; valid only while this ECKey is alive."""
    return self._priv_key_bn

  @property
  def priv_key_bignum(self):
    return self._priv_key_bignum

  @property
  def decrypt_key_bignum(self):
    """Inverse of the private key modulo the curve order."""
    return self._decrypt_key

  @property
  def elliptic_curve(self):
    return self._ec

  @property
  def curve_id(self):
    return self._curve_id