1# Copyright 2019 Google LLC.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Module for elliptic curve related classes."""
16
17import ctypes
18from typing import Optional, Union
19
20from private_join_and_compute.py.crypto_util import converters
21from private_join_and_compute.py.crypto_util import ssl_util
22from private_join_and_compute.py.crypto_util.ssl_util import BigNum
23from private_join_and_compute.py.crypto_util.ssl_util import OpenSSLHelper
24from private_join_and_compute.py.crypto_util.ssl_util import TempBNs
25from private_join_and_compute.py.crypto_util.supported_curves import SupportedCurve
26from private_join_and_compute.py.crypto_util.supported_hashes import HashType
27import six
28
# Point encoding form passed as the `form` argument of EC_POINT_point2oct in
# ECPoint.GetAsBytes. The value 2 is presumably OpenSSL's
# POINT_CONVERSION_COMPRESSED member of point_conversion_form_t -- confirm
# against the OpenSSL headers in use.
POINT_CONVERSION_COMPRESSED = 2
30
31
class ECPoint(object):
  """A point on an elliptic curve, wrapping an OpenSSL EC_POINT handle.

  The wrapped EC_POINT is released in __del__. The EC_GROUP handle is stored
  for use in the OpenSSL calls but is never freed by this class.
  """

  def __init__(self, group, ec_point_bn):
    """Wraps an existing EC_POINT handle.

    Args:
      group: the EC_GROUP handle of the curve the point belongs to.
      ec_point_bn: the EC_POINT handle to wrap; it is freed when this object
        is destroyed.
    """
    # Assigned first so that __del__ can always release the point, even when a
    # later statement raises. Holding the reference also ensures that garbage
    # collection doesn't collect ssl before this object.
    self.ssl = ssl_util.ssl
    self._group = group
    self._point = ec_point_bn
    self.ctx = OpenSSLHelper().ctx

  @classmethod
  def FromPoint(cls, group: int, x: int, y: int):
    """Creates an EC_POINT object with the given x, y affine coordinates.

    Args:
      group: the EC_GROUP for the given point's elliptic curve
      x: the x coordinate of the point as long value
      y: the y coordinate of the point as long value

    Returns:
      <x, y> ECPoint on the elliptic curve defined by group

    Raises:
      TypeError: If the x, y coordinates are not of type long (presumably
        raised by TempBNs -- confirm against ssl_util).
      ValueError: If the resulting point fails CheckValidity.
    """
    ec_point = cls._EmptyPoint(group)
    with TempBNs(x=x, y=y) as bn:
      # pylint: disable=protected-access
      ssl_util.ssl.EC_POINT_set_affine_coordinates_GFp(
          group, ec_point._point, bn.x, bn.y, None
      )
      # pylint: enable=protected-access
    ec_point.CheckValidity()
    return ec_point

  @classmethod
  def FromLongOrBytes(cls, group: int, point_long_or_bytes: Union[int, bytes]):
    """Creates an EC_POINT object from its serialized bytes representation.

    Args:
      group: the EC_GROUP for the point's elliptic curve.
      point_long_or_bytes: the serialized representation of the point, either
        as bytes or as an int that is first converted with
        converters.LongToBytes.

    Returns:
      The point encoded by point_long_or_bytes

    Raises:
      ValueError: if point_long_or_bytes is not a valid encoding of a point
      from the EC group.
    """
    ec_point = cls._EmptyPoint(group)
    if isinstance(point_long_or_bytes, int):
      point_long_or_bytes = converters.LongToBytes(point_long_or_bytes)
    # pylint: disable=protected-access
    ssl_util.ssl.EC_POINT_oct2point(
        group,
        ec_point._point,
        point_long_or_bytes,
        len(point_long_or_bytes),
        None,
    )
    # pylint: enable=protected-access
    # The return value of the decoding call is not inspected; the validity
    # check below is what rejects points that are off the curve or at infinity.
    ec_point.CheckValidity()
    return ec_point

  @classmethod
  def GetPointAtInfinity(cls, group):
    """Returns a new ECPoint set to the point at infinity of group.

    Unlike the other constructors, no validity check is run, since
    CheckValidity rejects the point at infinity.
    """
    p = ssl_util.ssl.EC_POINT_new(group)
    ssl_util.ssl.EC_POINT_set_to_infinity(group, p)
    return ECPoint(group, p)

  @classmethod
  def _EmptyPoint(cls, group):
    """Returns an ECPoint wrapping a freshly allocated EC_POINT for group."""
    return ECPoint(group, ssl_util.ssl.EC_POINT_new(group))

  @staticmethod
  def _RawScalar(scalar):
    """Returns the handle to pass to OpenSSL for the given scalar.

    BigNum wrappers are unwrapped to their underlying handle; any other value
    is passed through unchanged.
    """
    if isinstance(scalar, BigNum):
      # pylint: disable=protected-access
      return scalar._bn_num
      # pylint: enable=protected-access
    return scalar

  def __del__(self):
    # __del__ also runs for instances whose __init__ did not complete, so the
    # attributes may be missing. Clearing _point makes a repeated call a no-op
    # instead of a double free.
    ssl_lib = getattr(self, 'ssl', None)
    point = getattr(self, '_point', None)
    if ssl_lib is None or point is None:
      return
    self._point = None
    ssl_lib.EC_POINT_free(point)

  def CheckValidity(self):
    """Checks if this point is valid and can be multiplied with the key.

    If the point is corrupted as a result of a faulty computation, this might
    leak data about the key.

    Raises:
      ValueError: If the point is not on the curve or if the point is the
      neutral element.
    """
    if not self.IsOnCurve():
      raise ValueError('The point is not on the curve.')

    if self.IsAtInfinity():
      raise ValueError('The point is the neutral element.')

  def __mul__(self, scalar):
    """Returns a new ECPoint holding the result of EC_POINT_mul on this point.

    Args:
      scalar: a BigNum, or a value that is handed to OpenSSL unchanged.
    """
    new_ec_point = self._EmptyPoint(self._group)
    # pylint: disable=protected-access
    ssl_util.ssl.EC_POINT_mul(
        self._group,
        new_ec_point._point,
        None,
        self._point,
        self._RawScalar(scalar),
        self.ctx,
    )
    # pylint: enable=protected-access
    return new_ec_point

  def __imul__(self, scalar):
    """Multiplies this point by scalar in place and returns self.

    Args:
      scalar: a BigNum, or a value that is handed to OpenSSL unchanged.
    """
    ssl_util.ssl.EC_POINT_mul(
        self._group,
        self._point,
        None,
        self._point,
        self._RawScalar(scalar),
        self.ctx,
    )
    return self

  def __add__(self, ec_point):
    """Returns a new ECPoint holding the sum of this point and ec_point."""
    new_ec_point = self._EmptyPoint(self._group)
    # pylint: disable=protected-access
    ssl_util.ssl.EC_POINT_add(
        self._group, new_ec_point._point, self._point, ec_point._point, self.ctx
    )
    # pylint: enable=protected-access
    return new_ec_point

  def __iadd__(self, ec_point):
    """Adds ec_point to this point in place and returns self."""
    # pylint: disable=protected-access
    ssl_util.ssl.EC_POINT_add(
        self._group, self._point, self._point, ec_point._point, self.ctx
    )
    # pylint: enable=protected-access
    return self

  def IsOnCurve(self) -> bool:
    """Returns True iff EC_POINT_is_on_curve reports 1 for this point."""
    return 1 == ssl_util.ssl.EC_POINT_is_on_curve(
        self._group, self._point, None
    )

  def IsAtInfinity(self) -> bool:
    """Returns True iff EC_POINT_is_at_infinity reports 1 for this point."""
    return 1 == ssl_util.ssl.EC_POINT_is_at_infinity(self._group, self._point)

  def GetAsLong(self) -> int:
    """Returns the serialized point (see GetAsBytes) converted to an int."""
    return converters.BytesToLong(self.GetAsBytes())

  def GetAsBytes(self) -> bytes:
    """Returns the point serialized with POINT_CONVERSION_COMPRESSED."""
    # First call with a null buffer only queries the required length.
    # NOTE(review): a returned length of 0 (presumably an OpenSSL error) is not
    # checked and results in an empty bytes value -- confirm this is intended.
    buf_len = ssl_util.ssl.EC_POINT_point2oct(
        self._group, self._point, POINT_CONVERSION_COMPRESSED, None, 0, None
    )
    buf = ctypes.create_string_buffer(buf_len)
    ssl_util.ssl.EC_POINT_point2oct(
        self._group,
        self._point,
        POINT_CONVERSION_COMPRESSED,
        buf,
        buf_len,
        None,
    )
    return six.ensure_binary(buf.raw)

  def __eq__(self, other: 'ECPoint'):
    """Compares two points with EC_POINT_cmp.

    Raises:
      ValueError: If other is not an instance of this point's class.
    """
    # pylint: disable=protected-access
    if isinstance(other, self.__class__):
      return 0 == ssl_util.ssl.EC_POINT_cmp(
          self._group, self._point, other._point, self.ctx
      )
    raise ValueError('Cannot compare ECPoint with type {}'.format(type(other)))
    # pylint: enable=protected-access

  def __ne__(self, other: 'ECPoint'):
    return not self.__eq__(other)

  def __str__(self):
    return str(self.GetAsLong())
215
216
class EllipticCurve(object):
  """Class for representing the elliptic curve.

  Wraps an OpenSSL EC_GROUP, which is freed in __del__, and caches the curve
  parameters p, a, b and the group order as Python integers.
  """

  def __init__(
      self,
      curve_id: Union[int, SupportedCurve],
      hash_type: Optional[HashType] = None,
  ):
    """Creates the curve identified by curve_id.

    Args:
      curve_id: the curve id passed to EC_GROUP_new_by_curve_name, or a
        SupportedCurve whose id is used.
      hash_type: the hash used by HashToCurve; defaults to HashType.SHA512.

    Raises:
      ValueError: If the group cannot be created or if p is not a prime.
    """
    # Assigned first so that __del__ can free the group even when the
    # validation below raises. Holding the reference also ensures that garbage
    # collection doesn't collect ssl before this object.
    self.ssl = ssl_util.ssl
    if isinstance(curve_id, SupportedCurve):
      curve_id = curve_id.id
    if hash_type is None:
      hash_type = HashType.SHA512
    self._hash_type = hash_type
    self._group = ssl_util.ssl.EC_GROUP_new_by_curve_name(curve_id)
    if not self._group:
      raise ValueError(
          'Could not create an EC group for curve id {}.'.format(curve_id)
      )
    with TempBNs(p=None, a=None, b=None, order=None) as bn:
      ssl_util.ssl.EC_GROUP_get_curve_GFp(self._group, bn.p, bn.a, bn.b, None)
      ssl_util.ssl.EC_GROUP_get_order(
          self._group, bn.order, OpenSSLHelper().ctx
      )
      self._order = ssl_util.BnToLong(bn.order)
      self._p = ssl_util.BnToLong(bn.p)
      self._p_bn = BigNum.FromLongNumber(self._p)
      if not self._p_bn.IsPrime():
        raise ValueError(
            'Wrong curve parameters: p must be a prime. p: {}'.format(self._p)
        )
      self._a = ssl_util.BnToLong(bn.a)
      self._b = ssl_util.BnToLong(bn.b)
      # Exponent (p - 1) / 2 used by the quadratic residue test in HashToCurve.
      self._p_sub_one_div_by_two = (self._p - 1) >> 1

  def __del__(self):
    # __del__ also runs for instances whose __init__ did not complete, so the
    # attributes may be missing. Clearing _group makes a repeated call a no-op
    # instead of a double free.
    ssl_lib = getattr(self, 'ssl', None)
    group = getattr(self, '_group', None)
    if ssl_lib is None or not group:
      return
    self._group = None
    ssl_lib.EC_GROUP_free(group)

  def GetPointByHashingToCurve(self, m: Union[int, bytes]) -> ECPoint:
    """Hashes m into the elliptic curve."""
    return ECPoint.FromPoint(self.group, *self.HashToCurve(m))

  def GetPointFromLong(self, m_long: int) -> ECPoint:
    """Converts the given compressed point (m_long) into ECPoint."""
    return ECPoint.FromLongOrBytes(self.group, m_long)

  def GetPointFromBytes(self, m_bytes: bytes) -> ECPoint:
    """Converts the given compressed point (m_bytes) into ECPoint."""
    return ECPoint.FromLongOrBytes(self.group, m_bytes)

  def GetPointAtInfinity(self) -> ECPoint:
    """Gets a point at the infinity."""
    return ECPoint.GetPointAtInfinity(self.group)

  def GetRandomGenerator(self):
    """Returns the group's generator multiplied by a random scalar.

    The scalar comes from GenerateRandWithStart(BigNum.One()) on the group
    order; presumably it lies in [1, order) -- confirm against BigNum.
    """
    ssl_point = ssl_util.ssl.EC_GROUP_get0_generator(self.group)
    # Work on a duplicate so the in-place multiplication below does not modify
    # the generator stored in the group.
    generator = ECPoint(
        self.group, ssl_util.ssl.EC_POINT_dup(ssl_point, self.group)
    )
    generator *= BigNum.FromLongNumber(self.order).GenerateRandWithStart(
        BigNum.One()
    )
    return generator

  def ComputeYSquare(self, x: int):
    """Returns y^2 calculated with x^3 + ax + b."""
    return (x**3 + self._a * x + self._b) % self._p

  def HashToCurve(self, m: Union[int, bytes]):
    """Hash m to a point on the elliptic curve y^2 = x^3 + ax + b.

    To hash m to a point on the curve, the algorithm first computes an integer
    hash value x = h(m) and determines whether x is the abscissa of a point on
    the elliptic curve y^2 = x^3 + ax + b. If not, set x = h(x) and try again.

    Security:
    The number of operations required to hash a message m depends on m, which
    could lead to a timing attack.

    Args:
      m: int or bytes input

    Returns:
      A point (x, y) on this elliptic curve. Of the two square roots of y^2,
      the one whose lowest bit is not set is returned.
    """
    # Re-hash until x yields a y^2 that is a quadratic residue modulo p. A
    # loop is used instead of recursion so the number of retries is not
    # limited by the interpreter's recursion depth.
    while True:
      x = ssl_util.RandomOracle(m, self._p, hash_type=self._hash_type)
      y2 = self.ComputeYSquare(x)
      # y2 is a quadratic residue if y2^(p-1)/2 = 1
      if 1 == ssl_util.ModExp(y2, self._p_sub_one_div_by_two, self._p):
        break
      m = x

    y_bn = ssl_util.BigNum.FromLongNumber(y2).Mutable()
    y_bn.IModSqrt(self._p_bn)
    if y_bn.IsBitSet(0):
      return (x, y_bn.ModNegate(self._p_bn).GetAsLong())
    return (x, y_bn.GetAsLong())

  def __eq__(self, other):
    """Compares the curve parameters p, a and b of two curves.

    Raises:
      ValueError: If other is not an instance of this curve's class.
    """
    # pylint: disable=protected-access
    if isinstance(other, self.__class__):
      return self._p == other._p and self._a == other._a and self._b == other._b
    raise ValueError(
        'Cannot compare EllipticCurve with type {}'.format(type(other))
    )
    # pylint: enable=protected-access

  @property
  def group(self):
    """The underlying EC_GROUP handle; owned and freed by this object."""
    return self._group

  @property
  def order(self):
    """The group order as a Python integer."""
    return self._order
328
329
class ECKey(object):
  """Class representing the elliptic curve key.

  Wraps an OpenSSL EC_KEY, which is freed in __del__, and exposes the private
  key in several representations together with its modular inverse.
  """

  def __init__(
      self,
      curve_id: Union[int, SupportedCurve],
      priv_key_bytes: Optional[bytes] = None,
      hash_type: Optional[HashType] = None,
  ):
    """Creates a key on the given curve.

    Args:
      curve_id: the curve id passed to EC_KEY_new_by_curve_name, or a
        SupportedCurve whose id is used.
      priv_key_bytes: the private key to use. When None or empty, a new key is
        generated and checked instead.
      hash_type: forwarded to the EllipticCurve created for this key.

    Raises:
      ValueError: If the key cannot be created, the private key cannot be set,
        or the generated key fails its check.
      Exception: If key generation fails.
    """
    # Assigned first so that __del__ can free the key even when a later step
    # raises. Holding the reference also ensures that garbage collection
    # doesn't collect ssl before this object.
    self.ssl = ssl_util.ssl
    if isinstance(curve_id, SupportedCurve):
      curve_id = curve_id.id
    self._curve_id = curve_id
    self._key = ssl_util.ssl.EC_KEY_new_by_curve_name(curve_id)
    if not self._key:
      raise ValueError(
          'Could not create an EC key for curve id {}.'.format(curve_id)
      )
    if priv_key_bytes:
      # NOTE(review): the BIGNUM returned by BytesToBn is never freed here; if
      # EC_KEY_set_private_key copies its argument this leaks -- confirm
      # against ssl_util and the OpenSSL docs.
      # NOTE(review): unlike the generated-key path, no _Check() is run on a
      # caller-supplied key -- confirm this is intended.
      if 1 != ssl_util.ssl.EC_KEY_set_private_key(
          self._key, ssl_util.BytesToBn(priv_key_bytes)
      ):
        raise ValueError('Setting the EC private key failed.')
    else:
      if 1 != ssl_util.ssl.EC_KEY_generate_key(self._key):
        raise Exception('EC key generation failed.')
      self._Check()
    self._priv_key_bn = ssl_util.ssl.EC_KEY_get0_private_key(self._key)
    self._priv_key_bytes = ssl_util.BnToBytes(self._priv_key_bn)
    self._priv_key_bignum = BigNum.FromBytes(self._priv_key_bytes)
    self._ec = EllipticCurve(curve_id, hash_type=hash_type)
    # Inverse of the private key modulo the curve order.
    self._decrypt_key = self._priv_key_bignum.ModInverse(
        BigNum.FromLongNumber(self._ec.order)
    )

  def __del__(self):
    # __del__ also runs for instances whose __init__ did not complete, so the
    # attributes may be missing. Clearing _key makes a repeated call a no-op
    # instead of a double free.
    ssl_lib = getattr(self, 'ssl', None)
    key = getattr(self, '_key', None)
    if ssl_lib is None or not key:
      return
    self._key = None
    ssl_lib.EC_KEY_free(key)

  def _Check(self):
    """Raises ValueError if EC_KEY_check_key reports 0 for the key."""
    if 0 == ssl_util.ssl.EC_KEY_check_key(self._key):
      raise ValueError('The ECKey checks has failed.')

  @property
  def priv_key_bytes(self):
    """The private key as returned by ssl_util.BnToBytes."""
    return self._priv_key_bytes

  @property
  def priv_key_bn(self):
    """The handle returned by EC_KEY_get0_private_key for this key.

    Presumably an internal pointer owned by the EC_KEY that must not be freed
    by the caller and is invalid once this object is destroyed -- confirm
    against the OpenSSL docs.
    """
    return self._priv_key_bn

  @property
  def priv_key_bignum(self):
    """The private key as a BigNum."""
    return self._priv_key_bignum

  @property
  def decrypt_key_bignum(self):
    """The inverse of the private key modulo the curve order, as a BigNum."""
    return self._decrypt_key

  @property
  def elliptic_curve(self):
    """The EllipticCurve created for this key's curve id."""
    return self._ec

  @property
  def curve_id(self):
    """The integer curve id this key was created with."""
    return self._curve_id
391