1# Copyright 2019 Google LLC.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16"""Make available access to openssl library and bn functions."""
17
18import ctypes.util
19from functools import total_ordering
20import hashlib
21import math
22from typing import Union
23
24from absl import logging
25from private_join_and_compute.py.crypto_util import converters
26from private_join_and_compute.py.crypto_util.supported_hashes import HashType
27import six
28
# Module-level handle to the OpenSSL crypto library; it is rebound to the
# loaded library below and later wrapped in an SSLProxy.
ssl = None

try:
  # NOTE(review): ctypes.util.find_library returns None when the library is
  # not found; that case is not handled separately here.
  ssl_libpath = ctypes.util.find_library('crypto')
  ssl = ctypes.cdll.LoadLibrary(ssl_libpath)
except (OSError, IOError) as e:
  logging.fatal('Could not load the ssl library.\n%s', e)
36
# ctypes signatures for the error-reporting functions used by
# SSLProxy._DebugInfo to build failure messages.
ssl.ERR_error_string_n.restype = ctypes.c_void_p
ssl.ERR_error_string_n.argtypes = [
    ctypes.c_long,
    ctypes.c_char_p,
    ctypes.c_size_t,
]
ssl.ERR_get_error.restype = ctypes.c_long
ssl.ERR_get_error.argtypes = []
45
# ctypes signatures for the BIGNUM (BN_*) functions. All pointers are declared
# as c_void_p. SSLProxy uses the declared restype to decide how to check a
# call: c_void_p results must not be None and c_int results must equal 1,
# except for the functions named in SSLProxy._funcs_to_skip.
ssl.BN_new.restype = ctypes.c_void_p
ssl.BN_new.argtypes = []
ssl.BN_free.argtypes = [ctypes.c_void_p]
ssl.BN_num_bits.restype = ctypes.c_int
ssl.BN_num_bits.argtypes = [ctypes.c_void_p]
ssl.BN_bin2bn.restype = ctypes.c_void_p
ssl.BN_bin2bn.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p]
ssl.BN_bn2bin.restype = ctypes.c_int
ssl.BN_bn2bin.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
ssl.BN_CTX_new.restype = ctypes.c_void_p
ssl.BN_CTX_new.argtypes = []
ssl.BN_CTX_free.restype = ctypes.c_int
ssl.BN_CTX_free.argtypes = [ctypes.c_void_p]
ssl.BN_mod_exp.restype = ctypes.c_int
ssl.BN_mod_exp.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_mod_mul.restype = ctypes.c_int
ssl.BN_mod_mul.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
# NOTE(review): the next two assignments repeat the BN_CTX_new / BN_CTX_free
# argtypes already set above; they are redundant.
ssl.BN_CTX_new.argtypes = []
ssl.BN_CTX_free.argtypes = [ctypes.c_void_p]
ssl.BN_generate_prime_ex.restype = ctypes.c_int
ssl.BN_generate_prime_ex.argtypes = [
    ctypes.c_void_p,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_is_prime_ex.restype = ctypes.c_int
ssl.BN_is_prime_ex.argtypes = [
    ctypes.c_void_p,
    ctypes.c_int,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_mul.restype = ctypes.c_int
ssl.BN_mul.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_div.restype = ctypes.c_int
ssl.BN_div.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_exp.restype = ctypes.c_int
ssl.BN_exp.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.RAND_seed.restype = ctypes.c_int
ssl.RAND_seed.argtypes = [ctypes.c_void_p, ctypes.c_int]
ssl.BN_gcd.restype = ctypes.c_int
ssl.BN_gcd.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
# BN_mod_inverse and BN_mod_sqrt are declared as returning a pointer, so
# SSLProxy treats a None (NULL) result as the failure signal for them.
ssl.BN_mod_inverse.restype = ctypes.c_void_p
ssl.BN_mod_inverse.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_mod_sqrt.restype = ctypes.c_void_p
ssl.BN_mod_sqrt.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_add.restype = ctypes.c_int
ssl.BN_add.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
ssl.BN_sub.restype = ctypes.c_int
ssl.BN_sub.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
ssl.BN_nnmod.restype = ctypes.c_int
ssl.BN_nnmod.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_rand_range.restype = ctypes.c_int
ssl.BN_rand_range.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
ssl.BN_lshift.restype = ctypes.c_int
ssl.BN_lshift.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
ssl.BN_rshift.restype = ctypes.c_int
ssl.BN_rshift.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
ssl.BN_cmp.restype = ctypes.c_int
ssl.BN_cmp.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
ssl.BN_is_bit_set.restype = ctypes.c_int
ssl.BN_is_bit_set.argtypes = [ctypes.c_void_p, ctypes.c_int]
159
# ctypes signatures for EVP_PKEY and EC_KEY handling. Pointers are declared as
# c_void_p; restype None marks functions without a return value.
ssl.EVP_PKEY_new.argtypes = []
ssl.EVP_PKEY_new.restype = ctypes.c_void_p

ssl.EC_KEY_new.restype = ctypes.c_void_p
ssl.EC_KEY_new.argtypes = []
ssl.EC_KEY_free.argtypes = [ctypes.c_void_p]
ssl.EC_KEY_new_by_curve_name.restype = ctypes.c_void_p
ssl.EC_KEY_new_by_curve_name.argtypes = [ctypes.c_int]
ssl.EC_KEY_generate_key.restype = ctypes.c_int
ssl.EC_KEY_generate_key.argtypes = [ctypes.c_void_p]
ssl.EC_KEY_set_asn1_flag.restype = None
ssl.EC_KEY_set_asn1_flag.argtypes = [ctypes.c_void_p, ctypes.c_int]

ssl.EC_KEY_get0_public_key.restype = ctypes.c_void_p
ssl.EC_KEY_get0_public_key.argtypes = [ctypes.c_void_p]

ssl.EC_KEY_set_public_key.restype = ctypes.c_int
ssl.EC_KEY_set_public_key.argtypes = [ctypes.c_void_p, ctypes.c_void_p]

ssl.EC_KEY_get0_private_key.restype = ctypes.c_void_p
ssl.EC_KEY_get0_private_key.argtypes = [ctypes.c_void_p]

ssl.EC_KEY_set_private_key.restype = ctypes.c_int
ssl.EC_KEY_set_private_key.argtypes = [ctypes.c_void_p, ctypes.c_void_p]

ssl.EC_KEY_check_key.restype = ctypes.c_int
ssl.EC_KEY_check_key.argtypes = [ctypes.c_void_p]

ssl.EVP_PKEY_free.argtypes = [ctypes.c_void_p]
ssl.EVP_PKEY_free.restype = None

ssl.EVP_PKEY_get1_EC_KEY.restype = ctypes.c_void_p
ssl.EVP_PKEY_get1_EC_KEY.argtypes = [ctypes.c_void_p]
193
# ctypes signatures for elliptic-curve group (EC_GROUP_*) and point
# (EC_POINT_*) functions. Pointers are declared as c_void_p.
ssl.EC_GROUP_free.argtypes = [ctypes.c_void_p]
ssl.EC_GROUP_get_order.restype = ctypes.c_int
ssl.EC_GROUP_get_order.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.EC_GROUP_new_by_curve_name.restype = ctypes.c_void_p
ssl.EC_GROUP_new_by_curve_name.argtypes = [ctypes.c_int]
ssl.EC_GROUP_get0_generator.restype = ctypes.c_void_p
ssl.EC_GROUP_get0_generator.argtypes = [ctypes.c_void_p]

ssl.EC_POINT_new.argtypes = [ctypes.c_void_p]
ssl.EC_POINT_new.restype = ctypes.c_void_p
ssl.EC_POINT_dup.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
ssl.EC_POINT_dup.restype = ctypes.c_void_p

ssl.EC_POINT_free.argtypes = [ctypes.c_void_p]

ssl.EC_POINT_mul.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.EC_POINT_mul.restype = ctypes.c_int

ssl.EC_POINT_add.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.EC_POINT_add.restype = ctypes.c_int

# Serialization of points to and from octet strings.
ssl.EC_POINT_point2oct.restype = ctypes.c_int
ssl.EC_POINT_point2oct.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_int,
    ctypes.c_void_p,
    ctypes.c_int,
    ctypes.c_void_p,
]
ssl.EC_POINT_oct2point.restype = ctypes.c_int
ssl.EC_POINT_oct2point.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_int,
    ctypes.c_void_p,
]

# Point predicates and comparison.
ssl.EC_POINT_is_on_curve.restype = ctypes.c_int
ssl.EC_POINT_is_on_curve.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.EC_POINT_is_at_infinity.restype = ctypes.c_int
ssl.EC_POINT_is_at_infinity.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
ssl.EC_POINT_set_to_infinity.restype = ctypes.c_int
ssl.EC_POINT_set_to_infinity.argtypes = [ctypes.c_void_p, ctypes.c_void_p]

ssl.EC_POINT_cmp.restype = ctypes.c_int
ssl.EC_POINT_cmp.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
268
# ctypes signatures for PEM (de)serialization of keys and for attaching an
# EC_KEY to an EVP_PKEY.
ssl.PEM_write_PUBKEY.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
# The attribute ctypes reads is `restype`; the former spelling `restypes` only
# created an unused attribute, so this declaration never took effect.
ssl.PEM_write_PUBKEY.restype = ctypes.c_int

ssl.PEM_write_PrivateKey.restype = ctypes.c_int
ssl.PEM_write_PrivateKey.argtypes = [ctypes.c_void_p, ctypes.c_void_p]

ssl.PEM_read_PrivateKey.restype = ctypes.c_void_p
ssl.PEM_read_PrivateKey.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]

ssl.EVP_PKEY_set1_EC_KEY.restype = ctypes.c_int
ssl.EVP_PKEY_set1_EC_KEY.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
285
# ctypes signatures for reading curve parameters and setting affine point
# coordinates (the *_GFp variants).
ssl.EC_GROUP_get_curve_GFp.restype = ctypes.c_int
ssl.EC_GROUP_get_curve_GFp.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]

ssl.EC_POINT_set_affine_coordinates_GFp.restype = ctypes.c_int
ssl.EC_POINT_set_affine_coordinates_GFp.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
303
# ctypes signatures for Montgomery-form arithmetic (BN_MONT_CTX_*,
# BN_*_montgomery) and for copying/duplicating BIGNUMs.
ssl.BN_MONT_CTX_new.restype = ctypes.c_void_p
ssl.BN_MONT_CTX_new.argtypes = []
ssl.BN_MONT_CTX_set.restype = ctypes.c_int
ssl.BN_MONT_CTX_set.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_MONT_CTX_free.argtypes = [ctypes.c_void_p]
ssl.BN_mod_mul_montgomery.restype = ctypes.c_int
ssl.BN_mod_mul_montgomery.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_to_montgomery.restype = ctypes.c_int
ssl.BN_to_montgomery.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_from_montgomery.restype = ctypes.c_int
ssl.BN_from_montgomery.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_copy.restype = ctypes.c_void_p
ssl.BN_copy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
ssl.BN_dup.restype = ctypes.c_void_p
ssl.BN_dup.argtypes = [ctypes.c_void_p]
339
# Module-level aliases for the ctypes helpers of the same names.
pointer = ctypes.pointer
cast = ctypes.cast
342
343
class SSLProxy(object):
  """Wrapper (a pass-through with error checking) for the loaded ssl library.

  Attribute access returns a callable for the library function of that name.
  Unless the function is listed in the skip set, the callable checks the
  result according to the function's declared ctypes restype:

    * restype c_void_p: the result must not be None (a NULL pointer).
    * restype c_int: the result must be exactly 1.

  A failed check raises AssertionError carrying the OpenSSL error string. The
  checks are explicit raises rather than assert statements so that they are
  still performed when Python runs with -O.
  """

  def __init__(self, ssl_lib):
    """Creates the proxy.

    Args:
      ssl_lib: the ctypes library object whose functions are wrapped.
    """
    self._ssl = ssl_lib
    # Maps a function name to its checking wrapper so each is built once.
    self._cache = {}
    # Functions without a return value or having a return value that is already
    # explicitly checked in the code.
    self._funcs_to_skip = {
        'BN_free',
        'BN_CTX_free',
        'BN_cmp',
        'BN_num_bits',
        'BN_bn2bin',
        'EC_POINT_is_at_infinity',
        'EC_POINT_cmp',
        'EC_POINT_free',
        'EC_KEY_free',
        'BN_MONT_CTX_free',
        'BN_is_bit_set',
        'EC_GROUP_free',
        'BN_is_prime_ex',
        'EC_POINT_point2oct',
    }

  def _DebugInfo(self):
    """Returns the last error message from the OpenSSL library."""
    err = ctypes.create_string_buffer(256)
    self._ssl.ERR_error_string_n(self._ssl.ERR_get_error(), err, 256)
    # err.value is bytes; decode it so the message does not render as b'...'.
    return '\nOpenSSL Error: {}'.format(err.value.decode('utf-8', 'replace'))

  def __getattr__(self, name):
    """Returns the (possibly checked) library function called name.

    Args:
      name: name of the function in the wrapped library.

    Returns:
      The raw library function if name is in the skip set, otherwise a cached
      wrapper that validates the return value.

    Raises:
      AttributeError: if one of the proxy's own attributes is requested before
        __init__ has set it.
    """
    # __getattr__ only runs when normal lookup fails. If the proxy's own state
    # is missing (e.g. the instance was created without __init__), looking it
    # up below would re-enter __getattr__ forever.
    if name in ('_ssl', '_cache', '_funcs_to_skip'):
      raise AttributeError(name)
    if name in self._funcs_to_skip:
      return getattr(self._ssl, name)
    if name not in self._cache:

      def WrapperFunc(*args):
        # Resolved on every call, so a missing function only fails when used.
        func = getattr(self._ssl, name)
        ret = func(*args)
        if func.restype is ctypes.c_void_p:
          if ret is None:
            raise AssertionError('ret is None{}'.format(self._DebugInfo()))
        elif func.restype is ctypes.c_int:
          if ret != 1:
            raise AssertionError(
                'ret is not 1, ret: {}{}'.format(ret, self._DebugInfo())
            )
        return ret

      self._cache[name] = WrapperFunc
    return self._cache[name]
398
399
# From here on `ssl` is the checking proxy, not the raw ctypes library.
ssl = SSLProxy(ssl)
401
402
def LongtoBn(bn_r: int, a: int) -> int:
  """Stores the integer a in the already-allocated BIGNUM bn_r.

  Args:
    bn_r: pointer to a preallocated OpenSSL BIGNUM that receives the value.
    a: the integer to convert.

  Returns:
    Whatever BN_bin2bn returns for the call.
  """
  encoded = converters.LongToBytes(a)
  byte_count = len(encoded)
  return ssl.BN_bin2bn(encoded, byte_count, bn_r)
407
408
def BnToLong(bn_a: int) -> int:
  """Converts the BIGNUM pointed to by bn_a into a Python integer.

  Args:
    bn_a: pointer to an OpenSSL BIGNUM.

  Returns:
    The integer obtained by decoding the bytes written by BN_bn2bin.
  """
  # Round the bit count up to whole bytes.
  byte_count = (ssl.BN_num_bits(bn_a) + 7) // 8
  out_buf = ctypes.create_string_buffer(byte_count)
  ssl.BN_bn2bin(bn_a, out_buf)
  return converters.BytesToLong(out_buf.raw)
416
417
def BnToBytes(bn_a: int) -> bytes:
  """Converts the BIGNUM pointed to by bn_a into its byte representation.

  Args:
    bn_a: pointer to an OpenSSL BIGNUM.

  Returns:
    The bytes written by BN_bn2bin, sized from BN_num_bits rounded up to whole
    bytes.
  """
  byte_count = (ssl.BN_num_bits(bn_a) + 7) // 8
  out_buf = ctypes.create_string_buffer(byte_count)
  ssl.BN_bn2bin(bn_a, out_buf)
  return out_buf.raw
425
426
def BytesToBn(bytes_a: bytes) -> int:
  """Allocates a new BIGNUM and loads it with the value encoded in bytes_a.

  The BIGNUM is allocated here and not freed, so the caller is responsible
  for releasing it.

  Args:
    bytes_a: the bytes to convert.

  Returns:
    Pointer to the newly allocated BIGNUM.
  """
  new_bn = ssl.BN_new()
  byte_count = len(bytes_a)
  ssl.BN_bin2bn(bytes_a, byte_count, new_bn)
  return new_bn
432
433
def GetRandomInRange(long_start: int, long_end: int) -> int:
  """Returns a random integer in the range [long_start, long_end).

  Args:
    long_start: inclusive lower bound.
    long_end: exclusive upper bound.

  Returns:
    long_start plus a value drawn by BN_rand_range over the interval width.
  """
  width = long_end - long_start
  with TempBNs(rand=None, interval=width) as tmp:
    ssl.BN_rand_range(tmp.rand, tmp.interval)
    offset = BnToLong(tmp.rand)
  return offset + long_start
439
440
def ModExp(g: int, x: int, n: int) -> int:
  """Computes g^x mod n using OpenSSL's BN_mod_exp.

  Args:
    g: the base.
    x: the exponent.
    n: the modulus.

  Returns:
    The result as a Python integer.
  """
  with TempBNs(r=None, g=g, x=x, n=n) as tmp:
    ssl.BN_mod_exp(tmp.r, tmp.g, tmp.x, tmp.n, OpenSSLHelper().ctx)
    result = BnToLong(tmp.r)
  return result
446
447
def ModInverse(x: int, n: int) -> int:
  """Computes 1/x mod n using OpenSSL's BN_mod_inverse.

  Args:
    x: the number to invert.
    n: the modulus.

  Returns:
    The inverse as a Python integer.
  """
  with TempBNs(r=None, x=x, n=n) as tmp:
    ssl.BN_mod_inverse(tmp.r, tmp.x, tmp.n, OpenSSLHelper().ctx)
    inverse = BnToLong(tmp.r)
  return inverse
453
454
class TempBNs(object):
  """Class for creating temporary openssl bignums by using 'with' clause.

  Every bignum is allocated in __init__ and freed in __exit__. If __init__
  fails part-way, the bignums allocated so far are freed before the error
  propagates.
  """

  # Disable pytype attribute checking.
  _HAS_DYNAMIC_ATTRIBUTES = True

  def __init__(self, **kwargs):
    r"""Initializes and assigns all temporary bignums.

    Usage:
      with TempBNs(x=5, y=[10,11]) as bn:
        # bn.x is the temporary bignum holding the value 5 within this scope.
        # bn.y is the temporary list of bignum holding the value 10 and 11
        # within this scope.

    or it can be used for assigning temporary results into bignums as follows:
      with TempBNs(result=None, x=5) as bn:
        # bn.result is an empty temporary bignum within this scope.
        # bn.x is the same as before.

    or bytes can be given as well as longs:
      with TempBNs(x=5, y=[b'\001', b'\002']) as bn:
        # bn.x is the temporary bignum holding the value 5 within this scope.
        # bn.y is the temporary list of bignum holding the value 1 and 2 within
        # this scope.

    Args:
      **kwargs: key (variable), value pairs. A value is an int, bytes, None,
        or a non-empty list of those. A falsy scalar value (None, 0, b'')
        leaves the freshly allocated bignum untouched.

    Raises:
      AssertionError: if a key clashes with an existing attribute or a list
        value is empty.
      TypeError: if a value has an unsupported type.
    """
    self._args = []
    try:
      for key, value in kwargs.items():
        assert not hasattr(self, key), '{} already exists.'.format(key)
        if isinstance(value, list):
          assert value, 'Cannot declare empty list in TempBNs.'
          for v in value:
            self._args.append(ssl.BN_new())
            self._BytesOrLongToBn(self._args[-1], v)
          setattr(self, key, self._args[-len(value) :])
        else:
          self._args.append(ssl.BN_new())
          setattr(self, key, self._args[-1])
          if value:
            self._BytesOrLongToBn(self._args[-1], value)
    except BaseException:
      # __exit__ will never run for a failed constructor, so release here.
      self._FreeAll()
      raise

  @classmethod
  def _BytesOrLongToBn(cls, bn, val) -> None:
    """Loads val into the preallocated bignum bn.

    Args:
      bn: pointer to the bignum to fill.
      val: an int, a bytes/bytearray value, a text string (treated as a legacy
        byte string, one byte per character), or None to leave bn untouched.

    Raises:
      TypeError: if val is of any other type.
    """
    if isinstance(val, int):
      LongtoBn(bn, val)
    elif isinstance(val, (bytes, bytearray)):
      # On Python 3 bytes is not str, so it needs its own branch.
      raw = bytes(val)
      ssl.BN_bin2bn(raw, len(raw), bn)
    elif isinstance(val, str):
      # Passing text straight to ctypes would hand OpenSSL a wide-character
      # buffer; encode it to one byte per character first.
      raw = val.encode('latin-1')
      ssl.BN_bin2bn(raw, len(raw), bn)
    elif val is not None:
      raise TypeError(
          'Cannot convert value of type {} to a bignum.'.format(type(val))
      )

  def _FreeAll(self):
    """Frees every bignum owned by this object; safe to call repeatedly."""
    # Detach the list first so a pointer can never be freed twice.
    owned, self._args = self._args, []
    for bn in owned:
      ssl.BN_free(bn)

  def __enter__(self, *args):
    return self

  def __exit__(self, some_type, value, traceback):
    self._FreeAll()
512
513
def RandomOracle(
    x: Union[int, bytes],
    max_value: int,
    hash_type: Union[type(None), HashType] = None,
) -> int:
  """Deterministically maps x to an integer in [0, max_value).

  The construction follows the example in the last paragraph of Chapter 6 of
  [1]: the output is expanded by hashing the concatenation of a counter
  (starting from 1) with the input, once per output block, and joining the
  digests.

  [1] Bellare, Mihir, and Phillip Rogaway. "Random oracles are practical:
  A paradigm for designing efficient protocols." Proceedings of the 1st ACM
  conference on Computer and communications security. ACM, 1993.

  Args:
    x: the input, either an integer or bytes.
    max_value: exclusive upper bound of the output domain.
    hash_type: the hash function to use, as a HashType. None selects
      HashType.SHA512.

  Returns:
    An integer from the set [0, max_value).

  Raises:
    ValueError: if more than 255 hash invocations would be needed, i.e. the
      bit length of max_value exceeds hash_type.bit_length * 254. The counter
      is meant to occupy a fixed 8 bits; a larger counter would feed inputs of
      varying length to the hash and could make the output non-uniform. The
      output is generated hash_type.bit_length bits longer than max_value to
      reduce the bias of the final modular reduction when max_value is not a
      power of 2.
  """
  if hash_type is None:
    hash_type = HashType.SHA512
  digest_bits = hash_type.bit_length
  output_bit_length = max_value.bit_length() + digest_bits
  iter_count = int(math.ceil(float(output_bit_length) / digest_bits))
  if iter_count > 255:
    raise ValueError(
        'The domain bit length must not be greater than H * 254. '
        'Given bit length: {}'.format(output_bit_length)
    )

  if isinstance(x, bytes):
    bytes_x = x
  else:
    bytes_x = converters.LongToBytes(x)

  # Concatenate the digests of counter||x for counter = 1..iter_count.
  hash_output = 0
  for counter in range(1, iter_count + 1):
    block_input = six.ensure_binary(converters.LongToBytes(counter) + bytes_x)
    hash_output = (hash_output << digest_bits) | hash_type.hash(block_input)

  # Drop the surplus low bits so exactly output_bit_length bits remain.
  excess_bit_count = iter_count * digest_bits - output_bit_length
  return (hash_output >> excess_bit_count) % max_value
567
568
class PRNG(object):
  """Hash based counter mode pseudorandom number generator.

  Output bytes are produced by hashing (SHA-512) a fixed-width counter
  concatenated with the seed, incrementing the counter for each block. The
  technique is the same as the one used in the RandomOracle function.
  """

  def __init__(self, seed, counter_byte_len=4):
    """Creates the PRNG with the given seed.

    Args:
      seed: at least 32 byte number or string.
      counter_byte_len: the max number of counter bytes to use. After exceeding
        the counter, this PRNG should not be used.

    Raises:
      ValueError: when the seed is not at least 32 bytes.
    """
    self.seed = (
        seed if isinstance(seed, bytes) else converters.LongToBytes(seed)
    )
    if len(self.seed) < 32:
      raise ValueError(
          'seed needs to be at least 32 bytes, the given bytes: {}'.format(
              self.seed
          )
      )
    # Counter value used for the next hash block.
    self.cur_pad = 0
    # Generated bytes that have not been handed out yet.
    self.cur_bytes = b''
    self.cur_byte_len = counter_byte_len
    # First counter value that no longer fits in cur_byte_len bytes.
    self.limit = 1 << (self.cur_byte_len * 8)

  def _GetMore(self):
    """Appends one more hash block to the buffer of unused bytes.

    Raises:
      AssertionError: if the counter no longer fits in its fixed width.
    """
    # Explicit raise so the limit is still enforced when running with -O.
    if self.cur_pad >= self.limit:
      raise AssertionError('Limit has been reached.')
    hash_output = six.ensure_binary(
        hashlib.sha512(
            six.ensure_binary(self._PaddedCountBytes() + self.seed)
        ).digest()
    )
    self.cur_pad += 1
    self.cur_bytes += hash_output

  def _PaddedCountBytes(self):
    """Returns the current counter left-padded with zero bytes."""
    counter_bytes = converters.LongToBytes(self.cur_pad)
    # Although we could use {:\x004}.format, Python seems to print space when
    # doing this way for the null character.
    return b'\000' * (self.cur_byte_len - len(counter_bytes)) + counter_bytes

  def _GetNBitRand(self, n):
    """Gets a random number in [0, 2^n) interval."""
    byte_len = (n + 7) >> 3
    # Number of surplus bits in the last byte, dropped by the shift below.
    excess_len = (8 - (n % 8)) % 8
    while len(self.cur_bytes) < byte_len:
      self._GetMore()
    self.cur_bytes, rand = (
        self.cur_bytes[byte_len:],
        self.cur_bytes[:byte_len],
    )
    rand_num = converters.BytesToLong(rand) >> excess_len
    return rand_num

  def GetRand(self, upper_limit):
    """Gets a random number in [0, upper_limit) interval.

    Candidates of the right bit length are drawn until one falls below
    upper_limit (rejection sampling).

    Args:
      upper_limit: exclusive upper bound; must be positive.

    Returns:
      An integer in [0, upper_limit).

    Raises:
      ValueError: if upper_limit is not positive. Without this check no
        candidate could ever be accepted and the loop would only stop once
        the counter is exhausted.
    """
    if upper_limit <= 0:
      raise ValueError(
          'upper_limit must be positive, given: {}'.format(upper_limit)
      )
    bit_len = (upper_limit - 1).bit_length()
    rand_num = self._GetNBitRand(bit_len)
    while rand_num >= upper_limit:
      rand_num = self._GetNBitRand(bit_len)
    return rand_num
637
638
class OpenSSLHelper(object):
  """A singleton wrapper class for openssl ctx and seeding its rand.

  Context is used for caching already allocated big nums. Each openssl operation
  requires a context to be passed to Get temporary big nums avoiding allocating
  new big nums for these temporary nums thus making big num operations use
  memory resources more efficiently. Usage in openssl library:

  BN_CTX_start(ctx)
  ....
  temp = BN_CTX_get(ctx)
  ....
  BN_CTX_end(ctx)
  Please note that BN_CTX_start and BN_CTX_end is not implemented here as this
  is only passed to various openssl big num operations.
  """

  _instance = None

  def __new__(cls, *args, **kwargs):
    """Returns the single shared instance, creating it on first use."""
    if cls._instance is None:
      # object.__new__ rejects extra arguments, so none are forwarded.
      cls._instance = super(OpenSSLHelper, cls).__new__(cls)
    return cls._instance

  def __init__(self):
    """Allocates the shared BN_CTX the first time the singleton is built.

    Python calls __init__ on every OpenSSLHelper() expression, even though
    __new__ hands back the same object. Without the guard below each call
    would allocate a new BN_CTX and drop the previous one without freeing it.
    """
    if getattr(self, '_ctx', None) is not None:
      return
    # So that garbage collection doesn't collect ssl before this object.
    self.ssl = ssl
    self._ctx = ssl.BN_CTX_new()

  def __del__(self):
    # clean up; _ctx is absent if allocation never succeeded.
    ctx = getattr(self, '_ctx', None)
    if ctx is not None:
      self.ssl.BN_CTX_free(ctx)

  @property
  def ctx(self):
    """The BN_CTX pointer to pass to openssl big num operations."""
    return self._ctx
675
676
677@total_ordering
678class BigNum(object):
679  """A wrapper class for openssl bn numbers.
680
681  Used for arithmetic operations on long numbers.
682  """
683
684  _ZERO = None
685  _ONE = None
686  _TWO = None
687
688  def __init__(self, bn_num):
689    self._bn_num = bn_num
690    self._helper = OpenSSLHelper()
691    self.immutable = True
692    # So that garbage collection doesn't collect ssl before this object.
693    self.ssl = ssl
694
695  @classmethod
696  def Zero(cls):
697    if not cls._ZERO:
698      cls._ZERO = cls.FromLongNumber(0)
699    return cls._ZERO
700
701  @classmethod
702  def One(cls):
703    if not cls._ONE:
704      cls._ONE = cls.FromLongNumber(1)
705    return cls._ONE
706
707  @classmethod
708  def Two(cls):
709    if not cls._TWO:
710      cls._TWO = cls.FromLongNumber(2)
711    return cls._TWO
712
713  @classmethod
714  def FromLongNumber(cls, long_number: int) -> 'BigNum':
715    """Returns a BigNum constructed from the given long number."""
716    bytes_num = converters.LongToBytes(long_number)
717    return cls.FromBytes(bytes_num)
718
719  @classmethod
720  def FromBytes(cls, number_in_bytes):
721    """Returns a BigNum constructed from the given long number."""
722    bn_num = ssl.BN_new()
723    ssl.BN_bin2bn(number_in_bytes, len(number_in_bytes), bn_num)
724    return BigNum(bn_num)
725
726  @classmethod
727  def GenerateSafePrime(cls, prime_length):
728    """Returns a safe prime BigNum with the given bit-length."""
729    bn_prime_num = ssl.BN_new()
730    ssl.BN_generate_prime_ex(bn_prime_num, prime_length, 1, None, None, None)
731    return BigNum(bn_prime_num)
732
733  @classmethod
734  def GeneratePrime(cls, prime_length: int) -> 'BigNum':
735    """Returns a prime BigNum with the given bit-length."""
736    bn_prime_num = ssl.BN_new()
737    ssl.BN_generate_prime_ex(bn_prime_num, prime_length, 0, None, None, None)
738    return BigNum(bn_prime_num)
739
740  def GeneratePrimeForSubGroup(self, prime_length: int) -> 'BigNum':
741    """Returns a prime BigNum, p, satisfying p = (self * k) + 1 for some k.
742
743    Args:
744      prime_length: the bit length of the returned prime.
745
746    Returns:
747      a prime BigNum, p = (self * k) + 1 for some k.
748    """
749    bn_prime_num = ssl.BN_new()
750    ssl.BN_generate_prime_ex(
751        bn_prime_num, prime_length, 0, self._bn_num, None, None
752    )
753    return BigNum(bn_prime_num)
754
755  def IsPrime(self, error_probability=1e-6):
756    """Returns True if this big num is prime, False otherwise."""
757    rounds = int(math.ceil(-math.log(error_probability) / math.log(4)))
758    return ssl.BN_is_prime_ex(self._bn_num, rounds, self._helper.ctx, None) != 0
759
760  def IsSafePrime(self, error_probability=1e-6):
761    """Returns True if this big num is a safe prime, False otherwise."""
762    return self.IsPrime(error_probability) and (
763        (self - self.One()) / self.Two()
764    ).IsPrime(error_probability)
765
766  def IsBitSet(self, n):
767    """Returns True if the n-th bit is set, False otherwise."""
768    return ssl.BN_is_bit_set(self._bn_num, n)
769
770  def BitLength(self):
771    return ssl.BN_num_bits(self._bn_num)
772
773  def Clone(self):
774    """Clones this big num."""
775    return BigNum(ssl.BN_dup(self._bn_num))
776
777  def Mutable(self):
778    """Sets this BigNum to mutable so that it can be changed."""
779    self.immutable = False
780    return self
781
782  def __hash__(self):
783    return hash((self._bn_num, self.immutable))
784
785  def __del__(self):
786    self.ssl.BN_free(self._bn_num)
787
788  def __add__(self, other):
789    return self._ComputeResult(ssl.BN_add, None, other)
790
791  def __iadd__(self, other):
792    return self._ComputeResultInPlace(ssl.BN_add, None, other)
793
794  def __sub__(self, other):
795    return self._ComputeResult(ssl.BN_sub, None, other)
796
797  def __isub__(self, other):
798    return self._ComputeResultInPlace(ssl.BN_sub, None, other)
799
800  def __mul__(self, other):
801    return self._ComputeResult(ssl.BN_mul, self._helper.ctx, other)
802
803  def __imul__(self, other):
804    return self._ComputeResultInPlace(ssl.BN_mul, self._helper.ctx, other)
805
806  def __mod__(self, modulus):
807    return self._ComputeResult(ssl.BN_nnmod, self._helper.ctx, modulus)
808
809  def __imod__(self, modulus):
810    return self._ComputeResultInPlace(ssl.BN_nnmod, self._helper.ctx, modulus)
811
812  def __pow__(self, other):
813    return self._ComputeResult(ssl.BN_exp, self._helper.ctx, other)
814
815  def __ipow__(self, other):
816    return self._ComputeResultInPlace(ssl.BN_exp, self._helper.ctx, other)
817
818  def __rshift__(self, n):
819    bn_num = ssl.BN_new()
820    ssl.BN_rshift(bn_num, self._bn_num, n)
821    return BigNum(bn_num)
822
823  def __irshift__(self, n):
824    ssl.BN_rshift(self._bn_num, self._bn_num, n)
825    return self
826
827  def __lshift__(self, n):
828    bn_num = ssl.BN_new()
829    ssl.BN_lshift(bn_num, self._bn_num, n)
830    return BigNum(bn_num)
831
832  def __ilshift__(self, n):
833    ssl.BN_lshift(self._bn_num, self._bn_num, n)
834    return self
835
836  def __div__(self, other):
837    return self._Div(BigNum(ssl.BN_new()), self, other)
838
839  def __truediv__(self, other):
840    return self._Div(BigNum(ssl.BN_new()), self, other)
841
842  def __idiv__(self, other):
843    return self._Div(self, self, other)
844
845  def _Div(self, big_result, big_num, other_big_num):
846    """Divides two bignums.
847
848    Args:
849      big_result: The bignum where the result is stored.
850      big_num:  The numerator.
851      other_big_num:  The denominator.
852
853    Returns:
854      big_result.
855
856    Raises:
857      ValueError:  If the remainder is non-zero.
858    """
859    if isinstance(other_big_num, self.__class__):
860      bn_remainder = ssl.BN_new()
861      ssl.BN_div(
862          big_result._bn_num,
863          bn_remainder,
864          big_num._bn_num,
865          other_big_num._bn_num,
866          self._helper.ctx,
867      )
868      try:
869        if ssl.BN_cmp(bn_remainder, self.Zero()._bn_num) != 0:
870          raise ValueError(
871              'There is a remainder in division of {} and {}'.format(
872                  big_num.GetAsLong(), other_big_num.GetAsLong()
873              )
874          )
875        return big_result
876      finally:
877        ssl.BN_free(bn_remainder)
878
879  def ModMul(self, other, modulus):
880    """Modular multiplies this with other based on the modulus.
881
882    For efficiency, please use Montgomery multiplication module if this is done
883    multiple times with the same modulus.
884
885    Args:
886      other: the other BigNum
887      modulus: the modulus of the operation
888
889    Returns:
890      a new BigNum holding the result.
891    """
892    return self._ComputeResult(ssl.BN_mod_mul, self._helper.ctx, other, modulus)
893
894  def IModMul(self, other, modulus):
895    """Modular multiplies this with other based on the modulus.
896
897    Stores the result in this BigNum.
898    For efficiency, please use Montgomery multiplication module if this is done
899    multiple times with the same modulus.
900
901    Args:
902      other: the other BigNum
903      modulus: the modulus of the operation
904
905    Returns:
906      self
907    """
908    return self._ComputeResultInPlace(
909        ssl.BN_mod_mul, self._helper.ctx, other, modulus
910    )
911
912  def ModExp(self, other, modulus):
913    """Modular exponentiates this with other based on the modulus.
914
915    Args:
916      other: the other BigNum
917      modulus: the modulus of the operation
918
919    Returns:
920      a new BigNum holding the result.
921    """
922    return self._ComputeResult(ssl.BN_mod_exp, self._helper.ctx, other, modulus)
923
924  def IModExp(self, other, modulus):
925    """Modular exponentiates this with other based on the modulus.
926
927    Args:
928      other: the other BigNum
929      modulus: the modulus of the operation
930
931    Returns:
932      self
933    """
934    return self._ComputeResultInPlace(
935        ssl.BN_mod_exp, self._helper.ctx, other, modulus
936    )
937
938  def GCD(self, other):
939    """Gets gcd as a BigNum."""
940    return self._ComputeResult(ssl.BN_gcd, self._helper.ctx, other)
941
942  def ModInverse(self, modulus):
943    """Gets the inverse of this BigNum in mod modulus."""
944    try:
945      return self._ComputeResult(ssl.BN_mod_inverse, self._helper.ctx, modulus)
946    except AssertionError as a:
947      raise ValueError(
948          'This big num {} and modulus {} are not relatively '
949          'primes.\nThe Assertion Error: {}'.format(
950              self.GetAsLong(), modulus.GetAsLong(), a
951          )
952      )
953
954  def ModSqrt(self, modulus):
955    """Gets the sqrt of this BigNum in mod modulus.
956
957    Args:
958      modulus: the modulus of the operation
959
960    Returns:
961      a new BigNum holding the result.
962    """
963    big_num_result = self._ComputeResult(
964        ssl.BN_mod_sqrt, self._helper.ctx, modulus
965    )
966    return big_num_result
967
968  def IModSqrt(self, modulus):
969    """Gets the sqrt of this BigNum in mod modulus.
970
971    Args:
972      modulus: the modulus of the operation
973
974    Returns:
975      self
976    """
977    return self._ComputeResultInPlace(
978        ssl.BN_mod_sqrt, self._helper.ctx, modulus
979    )
980
981  def GenerateRand(self):
982    """Generates a cryptographically strong pseudo-random between 0 & self.
983
984    Returns:
985      A BigNum in [0, self._big_num) range.
986    """
987    bn_rand = ssl.BN_new()
988    ssl.BN_rand_range(bn_rand, self._bn_num)
989    return BigNum(bn_rand)
990
991  def GenerateRandWithStart(self, start_big_num):
992    """Generates a cryptographically strong pseudo-random between start & self.
993
994    Args:
995      start_big_num: start BigNum value of the interval.
996
997    Returns:
998      A BigNum in [start, self._big_num) range.
999    """
1000    return (self - start_big_num).GenerateRand() + start_big_num
1001
1002  def ModNegate(self, modulus):
1003    return modulus - (self % modulus)
1004
1005  def AddOne(self):
1006    return self + self.One()
1007
1008  def SubtractOne(self):
1009    return self - self.One()
1010
1011  def __str__(self):
1012    return str(self.GetAsLong())
1013
1014  def __eq__(self, other):
1015    # pylint: disable=protected-access
1016    if isinstance(other, self.__class__):
1017      return ssl.BN_cmp(self._bn_num, other._bn_num) == 0
1018    raise ValueError('Cannot compare BigNum with type {}'.format(type(other)))
1019    # pylint: enable=protected-access
1020
1021  def __ne__(self, other):
1022    return not self == other
1023
1024  def __lt__(self, other):
1025    # pylint: disable=protected-access
1026    if isinstance(other, self.__class__):
1027      return ssl.BN_cmp(self._bn_num, other._bn_num) == -1
1028    raise ValueError('Cannot compare BigNum with type {}'.format(type(other)))
1029    # pylint: enable=protected-access
1030
1031  def _ComputeResult(self, func, ctx, *args):
1032    return self._ComputeResultIntoBigNum(
1033        BigNum(ssl.BN_new()), func, ctx, self, *args
1034    )
1035
1036  def _ComputeResultInPlace(self, func, ctx, *args):
1037    if self.immutable:
1038      raise ValueError(
1039          'This operation will change this immutable BigNum. Call '
1040          'Mutable method to change it.'
1041      )
1042    return self._ComputeResultIntoBigNum(self, func, ctx, self, *args)
1043
1044  @classmethod
1045  def _ComputeResultIntoBigNum(cls, big_num_result, func, ctx, *args):
1046    # pylint: disable=protected-access
1047    if all(isinstance(big_num, cls) for big_num in args):
1048      args = [big_num._bn_num for big_num in args]
1049      if ctx:
1050        args.append(ctx)
1051      func(big_num_result._bn_num, *args)
1052      return big_num_result
1053    return NotImplemented
1054    # pylint: enable=protected-access
1055
1056  def GetAsLong(self):
1057    """Gets the long number in this BigNum."""
1058    return converters.BytesToLong(self.GetAsBytes())
1059
1060  def GetAsBytes(self):
1061    """Gets the long number as bytes in this BigNum."""
1062    num_bits = ssl.BN_num_bits(self._bn_num)
1063    num_bytes = int(math.ceil(num_bits / 8.0))
1064    bytes_num = ctypes.create_string_buffer(num_bytes)
1065    ssl.BN_bn2bin(self._bn_num, bytes_num)
1066    return bytes_num.raw
1067
1068
class BigNumCache(object):
  """A process-wide cache mapping small numbers to their BigNum objects.

  Only one instance ever exists: every construction returns the same object.
  NOTE: because __new__ hands back the shared instance, __init__ runs again
  on each construction and replaces both the cache dict and max_count.
  """

  # The single shared instance, created lazily by __new__.
  _instance = None

  def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
    shared = cls._instance
    if not shared:
      shared = super().__new__(cls)
      cls._instance = shared
    return shared

  def __init__(self, max_count: int):
    # Upper bound on the number of entries kept in the cache.
    self._max_count = max_count
    # Maps the requested number to its BigNum.
    self._cache = {}

  def Get(self, num: int) -> BigNum:
    """Returns the BigNum for num, reusing a cached one when available.

    Once the cache holds max_count entries, numbers that are not already
    cached get a new BigNum that is returned without being stored.

    Args:
      num: the long or integer to convert to BigNum.

    Returns:
      a BigNum for the given num.
    """
    try:
      return self._cache[num]
    except KeyError:
      pass
    cache_is_full = len(self._cache) >= self._max_count
    created = BigNum.FromLongNumber(num)
    if not cache_is_full:
      self._cache[num] = created
    return created
1099