# Copyright 2019 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Provides ctypes access to the OpenSSL libcrypto library and BIGNUM helpers."""

import ctypes.util
from functools import total_ordering
import hashlib
import math
from typing import Union

from absl import logging
from private_join_and_compute.py.crypto_util import converters
from private_join_and_compute.py.crypto_util.supported_hashes import HashType
import six

ssl = None

try:
  ssl_libpath = ctypes.util.find_library('crypto')
  if ssl_libpath is None:
    # find_library() reports "not found" by returning None instead of raising,
    # and LoadLibrary(None) does not load libcrypto (on POSIX it opens the
    # running process itself), so the failure would otherwise only surface
    # later as an AttributeError on the first symbol lookup.
    raise OSError('ctypes.util.find_library could not locate libcrypto.')
  ssl = ctypes.cdll.LoadLibrary(ssl_libpath)
except (OSError, IOError) as e:
  logging.fatal('Could not load the ssl library.\n%s', e)

# Error reporting. In the OpenSSL API, ERR_get_error returns an unsigned long
# and ERR_error_string_n takes that value and returns void.
ssl.ERR_error_string_n.restype = None
ssl.ERR_error_string_n.argtypes = [
    ctypes.c_ulong,
    ctypes.c_char_p,
    ctypes.c_size_t,
]
ssl.ERR_get_error.restype = ctypes.c_ulong
ssl.ERR_get_error.argtypes = []

# BIGNUM allocation and (de)serialization. Void-returning functions get
# restype None explicitly, since ctypes otherwise assumes an int result.
ssl.BN_new.restype = ctypes.c_void_p
ssl.BN_new.argtypes = []
ssl.BN_free.restype = None
ssl.BN_free.argtypes = [ctypes.c_void_p]
ssl.BN_num_bits.restype = ctypes.c_int
ssl.BN_num_bits.argtypes = [ctypes.c_void_p]
ssl.BN_bin2bn.restype = ctypes.c_void_p
ssl.BN_bin2bn.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p]
ssl.BN_bn2bin.restype = ctypes.c_int
ssl.BN_bn2bin.argtypes = [ctypes.c_void_p, ctypes.c_void_p]

# BN_CTX scratch-space contexts.
ssl.BN_CTX_new.restype = ctypes.c_void_p
ssl.BN_CTX_new.argtypes = []
ssl.BN_CTX_free.restype = None
ssl.BN_CTX_free.argtypes = [ctypes.c_void_p]

# BIGNUM arithmetic.
ssl.BN_mod_exp.restype = ctypes.c_int
ssl.BN_mod_exp.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_mod_mul.restype = ctypes.c_int
ssl.BN_mod_mul.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_generate_prime_ex.restype = ctypes.c_int
ssl.BN_generate_prime_ex.argtypes = [
    ctypes.c_void_p,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_is_prime_ex.restype = ctypes.c_int
ssl.BN_is_prime_ex.argtypes = [
    ctypes.c_void_p,
    ctypes.c_int,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_mul.restype = ctypes.c_int
ssl.BN_mul.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_div.restype = ctypes.c_int
ssl.BN_div.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_exp.restype = ctypes.c_int
ssl.BN_exp.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.RAND_seed.restype = ctypes.c_int
ssl.RAND_seed.argtypes = [ctypes.c_void_p, ctypes.c_int]
ssl.BN_gcd.restype = ctypes.c_int
ssl.BN_gcd.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_mod_inverse.restype = ctypes.c_void_p
ssl.BN_mod_inverse.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_mod_sqrt.restype = ctypes.c_void_p
ssl.BN_mod_sqrt.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_add.restype = ctypes.c_int
ssl.BN_add.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
ssl.BN_sub.restype = ctypes.c_int
ssl.BN_sub.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
ssl.BN_nnmod.restype = ctypes.c_int
ssl.BN_nnmod.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_rand_range.restype = ctypes.c_int
ssl.BN_rand_range.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
ssl.BN_lshift.restype = ctypes.c_int
ssl.BN_lshift.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
ssl.BN_rshift.restype = ctypes.c_int
ssl.BN_rshift.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
ssl.BN_cmp.restype = ctypes.c_int
ssl.BN_cmp.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
ssl.BN_is_bit_set.restype = ctypes.c_int
ssl.BN_is_bit_set.argtypes = [ctypes.c_void_p, ctypes.c_int]

# EVP_PKEY allocation.
ssl.EVP_PKEY_new.argtypes = []
ssl.EVP_PKEY_new.restype = ctypes.c_void_p

# EC_KEY lifecycle and accessors. EC_KEY_free returns void, so its restype is
# None rather than ctypes' implicit c_int default.
ssl.EC_KEY_new.restype = ctypes.c_void_p
ssl.EC_KEY_new.argtypes = []
ssl.EC_KEY_free.restype = None
ssl.EC_KEY_free.argtypes = [ctypes.c_void_p]
ssl.EC_KEY_new_by_curve_name.restype = ctypes.c_void_p
ssl.EC_KEY_new_by_curve_name.argtypes = [ctypes.c_int]
ssl.EC_KEY_generate_key.restype = ctypes.c_int
ssl.EC_KEY_generate_key.argtypes = [ctypes.c_void_p]
ssl.EC_KEY_set_asn1_flag.restype = None
ssl.EC_KEY_set_asn1_flag.argtypes = [ctypes.c_void_p, ctypes.c_int]

ssl.EC_KEY_get0_public_key.restype = ctypes.c_void_p
ssl.EC_KEY_get0_public_key.argtypes = [ctypes.c_void_p]

ssl.EC_KEY_set_public_key.restype = ctypes.c_int
ssl.EC_KEY_set_public_key.argtypes = [ctypes.c_void_p, ctypes.c_void_p]

ssl.EC_KEY_get0_private_key.restype = ctypes.c_void_p
ssl.EC_KEY_get0_private_key.argtypes = [ctypes.c_void_p]

ssl.EC_KEY_set_private_key.restype = ctypes.c_int
ssl.EC_KEY_set_private_key.argtypes = [ctypes.c_void_p, ctypes.c_void_p]

ssl.EC_KEY_check_key.restype = ctypes.c_int
ssl.EC_KEY_check_key.argtypes = [ctypes.c_void_p]
ssl.EVP_PKEY_free.argtypes = [ctypes.c_void_p]
ssl.EVP_PKEY_free.restype = None

ssl.EVP_PKEY_get1_EC_KEY.restype = ctypes.c_void_p
ssl.EVP_PKEY_get1_EC_KEY.argtypes = [ctypes.c_void_p]

# Void-returning destructors get restype None (ctypes defaults to c_int).
ssl.EC_GROUP_free.restype = None
ssl.EC_GROUP_free.argtypes = [ctypes.c_void_p]
ssl.EC_GROUP_get_order.restype = ctypes.c_int
ssl.EC_GROUP_get_order.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.EC_GROUP_new_by_curve_name.restype = ctypes.c_void_p
ssl.EC_GROUP_new_by_curve_name.argtypes = [ctypes.c_int]
ssl.EC_GROUP_get0_generator.restype = ctypes.c_void_p
ssl.EC_GROUP_get0_generator.argtypes = [ctypes.c_void_p]

ssl.EC_POINT_new.argtypes = [ctypes.c_void_p]
ssl.EC_POINT_new.restype = ctypes.c_void_p
ssl.EC_POINT_dup.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
ssl.EC_POINT_dup.restype = ctypes.c_void_p

ssl.EC_POINT_free.restype = None
ssl.EC_POINT_free.argtypes = [ctypes.c_void_p]

ssl.EC_POINT_mul.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.EC_POINT_mul.restype = ctypes.c_int

ssl.EC_POINT_add.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.EC_POINT_add.restype = ctypes.c_int

# In the OpenSSL API the buffer length parameters of EC_POINT_point2oct and
# EC_POINT_oct2point, and the return value of EC_POINT_point2oct, are size_t.
ssl.EC_POINT_point2oct.restype = ctypes.c_size_t
ssl.EC_POINT_point2oct.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_int,
    ctypes.c_void_p,
    ctypes.c_size_t,
    ctypes.c_void_p,
]
ssl.EC_POINT_oct2point.restype = ctypes.c_int
ssl.EC_POINT_oct2point.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_size_t,
    ctypes.c_void_p,
]

ssl.EC_POINT_is_on_curve.restype = ctypes.c_int
ssl.EC_POINT_is_on_curve.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.EC_POINT_is_at_infinity.restype = ctypes.c_int
ssl.EC_POINT_is_at_infinity.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
ssl.EC_POINT_set_to_infinity.restype = ctypes.c_int
ssl.EC_POINT_set_to_infinity.argtypes = [ctypes.c_void_p, ctypes.c_void_p]

ssl.EC_POINT_cmp.restype = ctypes.c_int
ssl.EC_POINT_cmp.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]

ssl.PEM_write_PUBKEY.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
ssl.PEM_write_PUBKEY.restype = ctypes.c_int

# NOTE(review): the OpenSSL prototype of PEM_write_PrivateKey has seven
# parameters (fp, pkey, enc, kstr, klen, cb, u); only the first two are typed
# here, so call sites must pass the remaining five explicitly. Confirm callers.
ssl.PEM_write_PrivateKey.restype = ctypes.c_int
ssl.PEM_write_PrivateKey.argtypes = [ctypes.c_void_p, ctypes.c_void_p]

ssl.PEM_read_PrivateKey.restype = ctypes.c_void_p
ssl.PEM_read_PrivateKey.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]

ssl.EVP_PKEY_set1_EC_KEY.restype = ctypes.c_int
ssl.EVP_PKEY_set1_EC_KEY.argtypes = [ctypes.c_void_p, ctypes.c_void_p]

ssl.EC_GROUP_get_curve_GFp.restype = ctypes.c_int
ssl.EC_GROUP_get_curve_GFp.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]

ssl.EC_POINT_set_affine_coordinates_GFp.restype = ctypes.c_int
ssl.EC_POINT_set_affine_coordinates_GFp.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]

ssl.BN_MONT_CTX_new.restype = ctypes.c_void_p
ssl.BN_MONT_CTX_new.argtypes = []
ssl.BN_MONT_CTX_set.restype = ctypes.c_int
ssl.BN_MONT_CTX_set.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_MONT_CTX_free.restype = None
ssl.BN_MONT_CTX_free.argtypes = [ctypes.c_void_p]
ssl.BN_mod_mul_montgomery.restype = ctypes.c_int
ssl.BN_mod_mul_montgomery.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_to_montgomery.restype = ctypes.c_int
ssl.BN_to_montgomery.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_from_montgomery.restype = ctypes.c_int
ssl.BN_from_montgomery.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
]
ssl.BN_copy.restype = ctypes.c_void_p
ssl.BN_copy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
ssl.BN_dup.restype = ctypes.c_void_p
ssl.BN_dup.argtypes = [ctypes.c_void_p]

pointer = ctypes.pointer
cast = ctypes.cast


class SSLProxy(object):
  """Wrapper (a pass-through with error checking) for the loaded ssl library.

  A call to a function whose restype is ctypes.c_void_p raises AssertionError
  when it returns NULL (None). A call to a function whose restype is
  ctypes.c_int raises AssertionError when it returns anything other than 1.
  The message includes the latest OpenSSL error string. Functions listed in
  _funcs_to_skip are returned unwrapped.

  The checks are explicit raises rather than `assert` statements so they keep
  running when Python is started with -O.
  """

  def __init__(self, ssl_lib):
    self._ssl = ssl_lib
    self._cache = {}
    # Functions without a return value or having a return value that is already
    # explicitly checked in the code.
    self._funcs_to_skip = {
        'BN_free',
        'BN_CTX_free',
        'BN_cmp',
        'BN_num_bits',
        'BN_bn2bin',
        'EC_POINT_is_at_infinity',
        'EC_POINT_cmp',
        'EC_POINT_free',
        'EC_KEY_free',
        'BN_MONT_CTX_free',
        'BN_is_bit_set',
        'EC_GROUP_free',
        'BN_is_prime_ex',
        'EC_POINT_point2oct',
    }

  def _DebugInfo(self):
    """Returns the last error message from the OpenSSL library."""
    err = ctypes.create_string_buffer(256)
    self._ssl.ERR_error_string_n(self._ssl.ERR_get_error(), err, 256)
    # err.value is bytes; decode it so the message is not shown as b'...'.
    return '\nOpenSSL Error: {}'.format(err.value.decode('utf-8', 'replace'))

  def __getattr__(self, name):
    # No OpenSSL symbol starts with '_'. Rejecting such names also prevents
    # infinite recursion when this object is inspected before __init__ ran
    # (e.g. during unpickling), when self._funcs_to_skip does not exist yet.
    if name.startswith('_'):
      raise AttributeError(name)
    if name in self._funcs_to_skip:
      return getattr(self._ssl, name)
    if name not in self._cache:
      # Resolve eagerly so a missing symbol raises AttributeError here, which
      # keeps hasattr() meaningful, instead of failing on the first call.
      func = getattr(self._ssl, name)

      def WrapperFunc(*args):
        ret = func(*args)
        if func.restype is ctypes.c_void_p:
          if ret is None:
            raise AssertionError('ret is None{}'.format(self._DebugInfo()))
        elif func.restype is ctypes.c_int:
          if ret != 1:
            raise AssertionError(
                'ret is not 1, ret: {}{}'.format(ret, self._DebugInfo())
            )
        return ret

      self._cache[name] = WrapperFunc
    return self._cache[name]


ssl = SSLProxy(ssl)


def LongtoBn(bn_r: int, a: int) -> int:
  """Converts a to BigNum and stores in preallocated bn_r.

  Args:
    bn_r: pointer (as an int) to a preallocated OpenSSL BIGNUM.
    a: the integer value to store.

  Returns:
    The BIGNUM pointer returned by BN_bin2bn.
  """
  bytes_a = converters.LongToBytes(a)
  return ssl.BN_bin2bn(bytes_a, len(bytes_a), bn_r)


def BnToLong(bn_a: int) -> int:
  """Converts a BigNum to a Python int.

  BN_bn2bin encodes the absolute value, so the result is never negative.
  """
  return converters.BytesToLong(BnToBytes(bn_a))


def BnToBytes(bn_a: int) -> bytes:
  """Converts a BigNum to its big-endian byte representation.

  BN_bn2bin encodes the absolute value, so the sign of a negative BigNum is
  dropped. A zero BigNum (0 significant bits) yields b''.
  """
  num_bytes_in_a = (ssl.BN_num_bits(bn_a) + 7) // 8
  bytes_a = ctypes.create_string_buffer(num_bytes_in_a)
  ssl.BN_bn2bin(bn_a, bytes_a)
  return bytes_a.raw


def BytesToBn(bytes_a: bytes) -> int:
  """Converts big-endian bytes to a newly allocated BigNum.

  Returns:
    The BIGNUM pointer. The caller owns it and must release it with BN_free.
  """
  bn_r = ssl.BN_new()
  ssl.BN_bin2bn(bytes_a, len(bytes_a), bn_r)
  return bn_r


def GetRandomInRange(long_start: int, long_end: int) -> int:
  """Returns a random integer in the range [long_start, long_end).

  An empty interval (long_end == long_start) makes BN_rand_range fail, which
  the SSLProxy wrapper reports as an AssertionError.
  """
  with TempBNs(rand=None, interval=(long_end - long_start)) as bn:
    ssl.BN_rand_range(bn.rand, bn.interval)
    return BnToLong(bn.rand) + long_start


def ModExp(g: int, x: int, n: int) -> int:
  """Computes g^x mod n on Python ints using OpenSSL's BN_mod_exp."""
  with TempBNs(r=None, g=g, x=x, n=n) as bn:
    ssl.BN_mod_exp(bn.r, bn.g, bn.x, bn.n, OpenSSLHelper().ctx)
    return BnToLong(bn.r)


def ModInverse(x: int, n: int) -> int:
  """Computes 1/x mod n on Python ints using OpenSSL's BN_mod_inverse.

  Raises AssertionError (from SSLProxy) when OpenSSL reports no inverse.
  """
  with TempBNs(r=None, x=x, n=n) as bn:
    ssl.BN_mod_inverse(bn.r, bn.x, bn.n, OpenSSLHelper().ctx)
    return BnToLong(bn.r)


class TempBNs(object):
  """Class for creating temporary openssl bignums by using 'with' clause."""

  # Disable pytype attribute checking.
  _HAS_DYNAMIC_ATTRIBUTES = True

  def __init__(self, **kwargs):
    r"""Initializes and assigns all temporary bignums.

    Usage:
      with TempBNs(x=5, y=[10,11]) as bn:
        # bn.x is the temporary bignum holding the value 5 within this scope.
        # bn.y is the temporary list of bignum holding the value 10 and 11
        # within this scope.

    or it can be used for assigning temporary results into bignums as follows:
      with TempBNs(result=None, x=5) as bn:
        # bn.result is an empty temporary bignum within this scope.
        # bn.x is the same as before.

    or bytes can be given as well as longs:
      with TempBNs(x=5, y=[b'\001', b'\002']) as bn:
        # bn.x is the temporary bignum holding the value 5 within this scope.
        # bn.y is the temporary list of bignum holding the value 1 and 2 within
        # this scope.

    Args:
      **kwargs: key (variable name), value (int, bytes, str of byte values, a
        non-empty list of those, or None for an empty bignum) pairs.
    """
    self._args = []
    try:
      for key, value in kwargs.items():
        assert not hasattr(self, key), '{} already exists.'.format(key)
        if isinstance(value, list):
          assert value, 'Cannot declare empty list in TempBNs.'
          for v in value:
            self._args.append(ssl.BN_new())
            self._BytesOrLongToBn(self._args[-1], v)
          setattr(self, key, self._args[-len(value) :])
        else:
          self._args.append(ssl.BN_new())
          setattr(self, key, self._args[-1])
          if value:
            self._BytesOrLongToBn(self._args[-1], value)
    except BaseException:
      # __exit__ never runs when construction fails, so release what was
      # already allocated before propagating the error.
      self.__exit__(None, None, None)
      raise

  @classmethod
  def _BytesOrLongToBn(cls, bn, val) -> None:
    """Stores val into the preallocated BIGNUM bn.

    Args:
      bn: pointer to a preallocated BIGNUM.
      val: an int, big-endian bytes, or a str whose code points (0-255) are the
        byte values (Python 2 style byte string).
    """
    if isinstance(val, int):
      LongtoBn(bn, val)
    elif isinstance(val, bytes):
      ssl.BN_bin2bn(val, len(val), bn)
    elif isinstance(val, str):
      # latin-1 maps each code point 0-255 to exactly one byte.
      encoded = val.encode('latin-1')
      ssl.BN_bin2bn(encoded, len(encoded), bn)

  def __enter__(self, *args):
    return self

  def __exit__(self, some_type, value, traceback):
    for bn in self._args:
      ssl.BN_free(bn)
    # Forget the freed pointers so a second __exit__ cannot double-free them.
    self._args = []


def RandomOracle(
    x: Union[int, bytes],
    max_value: int,
    hash_type: Union[type(None), HashType] = None,
) -> int:
  """A random oracle function mapping x deterministically into a large domain.

  The random oracle is similar to the example given in the last paragraph of
  Chapter 6 of [1] where the output is expanded by successively hashing the
  concatenation of the input with a fixed sized counter starting from 1.

  [1] Bellare, Mihir, and Phillip Rogaway. "Random oracles are practical:
  A paradigm for designing efficient protocols." Proceedings of the 1st ACM
  conference on Computer and communications security. ACM, 1993.

  Args:
    x: long or bytes input
    max_value: the exclusive upper bound of the output domain; must be > 0.
    hash_type: the hash function to use, as a HashType. If 'None' is provided
      this defaults to HashType.SHA512.

  Returns:
    a long value from the set [0, max_value).

  Raises:
    ValueError: if max_value is not positive, or if the bit length of
      max_value is greater than hash_type.bit_length * 254. The counter used
      for expanding the output is a single byte (hard-coded), so a counter
      value greater than 255 would make the inputs passed to the underlying
      hash calls variable-length and might make this random oracle's output
      non-uniform across the output domain. The output length is increased by
      a security margin of hash_type.bit_length, which reduces the bias of
      selecting certain values more often than others when max_value is not a
      power of 2.
  """
  if max_value <= 0:
    raise ValueError(
        'max_value must be positive. Given max_value: {}'.format(max_value)
    )
  if hash_type is None:
    hash_type = HashType.SHA512
  output_bit_length = max_value.bit_length() + hash_type.bit_length
  # Integer ceiling division: number of hash blocks needed.
  iter_count = (
      output_bit_length + hash_type.bit_length - 1
  ) // hash_type.bit_length
  if iter_count > 255:
    raise ValueError(
        'The domain bit length must not be greater than H * 254. '
        'Given bit length: {}'.format(max_value.bit_length())
    )
  excess_bit_count = (iter_count * hash_type.bit_length) - output_bit_length
  hash_output = 0
  bytes_x = x if isinstance(x, bytes) else converters.LongToBytes(x)
  for i in range(1, iter_count + 1):
    hash_output <<= hash_type.bit_length
    hash_output |= hash_type.hash(
        six.ensure_binary(converters.LongToBytes(i) + bytes_x)
    )
  return (hash_output >> excess_bit_count) % max_value


class PRNG(object):
  """Hash based counter mode pseudorandom number generator.

  The technique used in this class is same as the one used in RandomOracle
  function: SHA-512 over (zero-padded counter || seed) for successive counters.
  """

  def __init__(self, seed, counter_byte_len=4):
    """Creates the PRNG with the given seed.

    Args:
      seed: at least 32 byte number or bytes.
      counter_byte_len: the max number of counter bytes to use. After exceeding
        the counter, this PRNG should not be used.

    Raises:
      ValueError: when the seed is not at least 32 bytes.
    """
    self.seed = (
        seed if isinstance(seed, bytes) else converters.LongToBytes(seed)
    )
    if len(self.seed) < 32:
      # Report only the length: the seed is secret material and must not end
      # up in exception messages or logs.
      raise ValueError(
          'seed needs to be at least 32 bytes, the given length: {}'.format(
              len(self.seed)
          )
      )
    self.cur_pad = 0
    self.cur_bytes = b''
    self.cur_byte_len = counter_byte_len
    self.limit = 1 << (self.cur_byte_len * 8)

  def _GetMore(self):
    """Appends one SHA-512 block for the current counter to the byte pool."""
    # Explicit raise (not assert) so the limit is enforced under -O too.
    if self.cur_pad >= self.limit:
      raise AssertionError('Limit has been reached.')
    hash_output = hashlib.sha512(
        six.ensure_binary(self._PaddedCountBytes() + self.seed)
    ).digest()
    self.cur_pad += 1
    self.cur_bytes += hash_output

  def _PaddedCountBytes(self):
    """Returns the counter as big-endian bytes left-padded to cur_byte_len."""
    counter_bytes = converters.LongToBytes(self.cur_pad)
    # Although we could use {:\x004}.format, Python seems to print space when
    # doing this way for the null character.
    return b'\000' * (self.cur_byte_len - len(counter_bytes)) + counter_bytes

  def _GetNBitRand(self, n):
    """Gets a random number in [0, 2^n) interval."""
    byte_len = (n + 7) >> 3
    excess_len = (8 - (n % 8)) % 8
    while len(self.cur_bytes) < byte_len:
      self._GetMore()
    self.cur_bytes, rand = (
        self.cur_bytes[byte_len:],
        self.cur_bytes[:byte_len],
    )
    return converters.BytesToLong(rand) >> excess_len

  def GetRand(self, upper_limit):
    """Gets a random number in [0, upper_limit) interval.

    Raises:
      ValueError: if upper_limit is not positive (the rejection loop below
        would otherwise never terminate).
    """
    if upper_limit <= 0:
      raise ValueError(
          'upper_limit must be positive, given: {}'.format(upper_limit)
      )
    bit_len = (upper_limit - 1).bit_length()
    # Rejection sampling keeps the result uniform over [0, upper_limit).
    rand_num = self._GetNBitRand(bit_len)
    while rand_num >= upper_limit:
      rand_num = self._GetNBitRand(bit_len)
    return rand_num


class OpenSSLHelper(object):
  """A singleton wrapper class for openssl ctx and seeding its rand.

  Context is used for caching already allocated big nums. Each openssl operation
  requires a context to be passed to get temporary big nums, avoiding allocating
  new big nums for these temporary nums and thus making big num operations use
  memory resources more efficiently. Usage in openssl library:

    BN_CTX_start(ctx)
    ....
    temp = BN_CTX_get(ctx)
    ....
    BN_CTX_end(ctx)
  Please note that BN_CTX_start and BN_CTX_end is not implemented here as this
  is only passed to various openssl big num operations.
  """

  _instance = None

  def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
    if cls._instance is None:
      instance = super(OpenSSLHelper, cls).__new__(cls)
      # So that garbage collection doesn't collect ssl before this object.
      instance.ssl = ssl
      # Allocate the shared BN_CTX exactly once. Python re-runs __init__ on
      # every OpenSSLHelper() call even though __new__ returns the same
      # object, so allocating there leaked one BN_CTX per call.
      instance._ctx = ssl.BN_CTX_new()  # pylint: disable=protected-access
      cls._instance = instance
    return cls._instance

  def __init__(self):
    # Intentionally empty: all state is created once in __new__.
    pass

  def __del__(self):
    # clean up; _ctx may be missing if BN_CTX_new failed in __new__.
    ctx = getattr(self, '_ctx', None)
    if ctx is not None:
      self.ssl.BN_CTX_free(ctx)

  @property
  def ctx(self):
    return self._ctx


@total_ordering
class BigNum(object):
  """A wrapper class for openssl bn numbers.

  Used for arithmetic operations on long numbers. Instances are immutable by
  default: in-place arithmetic (+=, -=, *=, %=, **=, IModMul, IModExp,
  IModSqrt) raises ValueError until Mutable() is called. The in-place shift
  operators and __idiv__ return a new BigNum instead when this BigNum is
  immutable, so shared values such as Zero()/One()/Two() are never modified.
  """

  _ZERO = None
  _ONE = None
  _TWO = None

  def __init__(self, bn_num):
    """Wraps bn_num, a BIGNUM pointer that this object takes ownership of.

    The BIGNUM is released with BN_free when this object is deleted.
    """
    self._bn_num = bn_num
    self._helper = OpenSSLHelper()
    self.immutable = True
    # So that garbage collection doesn't collect ssl before this object.
    self.ssl = ssl

  @classmethod
  def Zero(cls):
    """Returns a shared, immutable BigNum holding 0."""
    if cls._ZERO is None:
      cls._ZERO = cls.FromLongNumber(0)
    return cls._ZERO

  @classmethod
  def One(cls):
    """Returns a shared, immutable BigNum holding 1."""
    if cls._ONE is None:
      cls._ONE = cls.FromLongNumber(1)
    return cls._ONE

  @classmethod
  def Two(cls):
    """Returns a shared, immutable BigNum holding 2."""
    if cls._TWO is None:
      cls._TWO = cls.FromLongNumber(2)
    return cls._TWO

  @classmethod
  def FromLongNumber(cls, long_number: int) -> 'BigNum':
    """Returns a BigNum constructed from the given long number."""
    bytes_num = converters.LongToBytes(long_number)
    return cls.FromBytes(bytes_num)

  @classmethod
  def FromBytes(cls, number_in_bytes):
    """Returns a BigNum constructed from the given big-endian bytes."""
    # Wrap first so the BIGNUM is freed by GC even if BN_bin2bn fails.
    result = BigNum(ssl.BN_new())
    ssl.BN_bin2bn(number_in_bytes, len(number_in_bytes), result._bn_num)
    return result

  @classmethod
  def GenerateSafePrime(cls, prime_length):
    """Returns a safe prime BigNum with the given bit-length."""
    result = BigNum(ssl.BN_new())
    ssl.BN_generate_prime_ex(result._bn_num, prime_length, 1, None, None, None)
    return result

  @classmethod
  def GeneratePrime(cls, prime_length: int) -> 'BigNum':
    """Returns a prime BigNum with the given bit-length."""
    result = BigNum(ssl.BN_new())
    ssl.BN_generate_prime_ex(result._bn_num, prime_length, 0, None, None, None)
    return result

  def GeneratePrimeForSubGroup(self, prime_length: int) -> 'BigNum':
    """Returns a prime BigNum, p, satisfying p = (self * k) + 1 for some k.

    Args:
      prime_length: the bit length of the returned prime.

    Returns:
      a prime BigNum, p = (self * k) + 1 for some k.
    """
    result = BigNum(ssl.BN_new())
    # add=self, rem=NULL asks OpenSSL for p with p % self == 1.
    ssl.BN_generate_prime_ex(
        result._bn_num, prime_length, 0, self._bn_num, None, None
    )
    return result

  def IsPrime(self, error_probability=1e-6):
    """Returns True if this big num is prime, False otherwise.

    Raises:
      AssertionError: if OpenSSL reports an error (BN_is_prime_ex returns -1),
        instead of silently treating the error as "prime".
    """
    # Each Miller-Rabin round has error probability at most 1/4.
    rounds = int(math.ceil(-math.log(error_probability) / math.log(4)))
    ret = ssl.BN_is_prime_ex(self._bn_num, rounds, self._helper.ctx, None)
    if ret < 0:
      raise AssertionError('BN_is_prime_ex failed.')
    return ret == 1

  def IsSafePrime(self, error_probability=1e-6):
    """Returns True if this big num is a safe prime, False otherwise."""
    if not self.IsPrime(error_probability):
      return False
    # (p - 1) >> 1 equals (p - 1) / 2 for odd p, and unlike exact division it
    # does not raise for p == 2, where (p - 1) / 2 is not an integer.
    return ((self - self.One()) >> 1).IsPrime(error_probability)

  def IsBitSet(self, n):
    """Returns True if the n-th bit is set, False otherwise."""
    return ssl.BN_is_bit_set(self._bn_num, n) != 0

  def BitLength(self):
    """Returns the number of significant bits (0 for zero)."""
    return ssl.BN_num_bits(self._bn_num)

  def Clone(self):
    """Clones this big num into a new, immutable BigNum."""
    return BigNum(ssl.BN_dup(self._bn_num))

  def Mutable(self):
    """Sets this BigNum to mutable so that it can be changed."""
    self.immutable = False
    return self

  def __hash__(self):
    # Equal BigNums must hash equally (Python data model), so hash the value
    # rather than the BIGNUM pointer. BN_bn2bin drops the sign, which only adds
    # collisions between x and -x that __eq__ still distinguishes. Mutating a
    # BigNum while it is used as a dict key or set member breaks lookups.
    return hash(self.GetAsBytes())

  def __del__(self):
    self.ssl.BN_free(self._bn_num)

  def __add__(self, other):
    return self._ComputeResult(ssl.BN_add, None, other)

  def __iadd__(self, other):
    return self._ComputeResultInPlace(ssl.BN_add, None, other)

  def __sub__(self, other):
    return self._ComputeResult(ssl.BN_sub, None, other)

  def __isub__(self, other):
    return self._ComputeResultInPlace(ssl.BN_sub, None, other)

  def __mul__(self, other):
    return self._ComputeResult(ssl.BN_mul, self._helper.ctx, other)

  def __imul__(self, other):
    return self._ComputeResultInPlace(ssl.BN_mul, self._helper.ctx, other)

  def __mod__(self, modulus):
    return self._ComputeResult(ssl.BN_nnmod, self._helper.ctx, modulus)

  def __imod__(self, modulus):
    return self._ComputeResultInPlace(ssl.BN_nnmod, self._helper.ctx, modulus)

  def __pow__(self, other, modulo=None):
    """Returns self ** other, or self ** other mod modulo for pow(a, b, m)."""
    if modulo is not None:
      return self.ModExp(other, modulo)
    return self._ComputeResult(ssl.BN_exp, self._helper.ctx, other)

  def __ipow__(self, other):
    return self._ComputeResultInPlace(ssl.BN_exp, self._helper.ctx, other)

  def __rshift__(self, n):
    result = BigNum(ssl.BN_new())
    ssl.BN_rshift(result._bn_num, self._bn_num, n)
    return result

  def __irshift__(self, n):
    # Shifting an immutable BigNum in place would corrupt values shared with
    # other holders (e.g. the cached Zero()/One()/Two()), so rebind instead.
    if self.immutable:
      return self >> n
    ssl.BN_rshift(self._bn_num, self._bn_num, n)
    return self

  def __lshift__(self, n):
    result = BigNum(ssl.BN_new())
    ssl.BN_lshift(result._bn_num, self._bn_num, n)
    return result

  def __ilshift__(self, n):
    # See __irshift__: never modify an immutable BigNum in place.
    if self.immutable:
      return self << n
    ssl.BN_lshift(self._bn_num, self._bn_num, n)
    return self

  def __div__(self, other):
    return self._Div(BigNum(ssl.BN_new()), self, other)

  def __truediv__(self, other):
    return self._Div(BigNum(ssl.BN_new()), self, other)

  def __idiv__(self, other):
    # See __irshift__: never modify an immutable BigNum in place.
    if self.immutable:
      return self._Div(BigNum(ssl.BN_new()), self, other)
    return self._Div(self, self, other)

  def _Div(self, big_result, big_num, other_big_num):
    """Divides two bignums.

    Args:
      big_result: The bignum where the result is stored.
      big_num: The numerator.
      other_big_num: The denominator.

    Returns:
      big_result, or NotImplemented if other_big_num is not a BigNum.

    Raises:
      ValueError: If the remainder is non-zero.
    """
    if not isinstance(other_big_num, self.__class__):
      return NotImplemented
    bn_remainder = ssl.BN_new()
    # BN_div sits inside the try so the remainder is freed even when OpenSSL
    # reports an error (e.g. division by zero).
    try:
      ssl.BN_div(
          big_result._bn_num,
          bn_remainder,
          big_num._bn_num,
          other_big_num._bn_num,
          self._helper.ctx,
      )
      if ssl.BN_cmp(bn_remainder, self.Zero()._bn_num) != 0:
        raise ValueError(
            'There is a remainder in division of {} and {}'.format(
                big_num.GetAsLong(), other_big_num.GetAsLong()
            )
        )
      return big_result
    finally:
      ssl.BN_free(bn_remainder)

  def ModMul(self, other, modulus):
    """Modular multiplies this with other based on the modulus.

    For efficiency, please use Montgomery multiplication module if this is done
    multiple times with the same modulus.

    Args:
      other: the other BigNum
      modulus: the modulus of the operation

    Returns:
      a new BigNum holding the result.
    """
    return self._ComputeResult(ssl.BN_mod_mul, self._helper.ctx, other, modulus)

  def IModMul(self, other, modulus):
    """Modular multiplies this with other based on the modulus.

    Stores the result in this BigNum.
    For efficiency, please use Montgomery multiplication module if this is done
    multiple times with the same modulus.

    Args:
      other: the other BigNum
      modulus: the modulus of the operation

    Returns:
      self
    """
    return self._ComputeResultInPlace(
        ssl.BN_mod_mul, self._helper.ctx, other, modulus
    )

  def ModExp(self, other, modulus):
    """Modular exponentiates this with other based on the modulus.

    Args:
      other: the other BigNum
      modulus: the modulus of the operation

    Returns:
      a new BigNum holding the result.
    """
    return self._ComputeResult(ssl.BN_mod_exp, self._helper.ctx, other, modulus)

  def IModExp(self, other, modulus):
    """Modular exponentiates this with other based on the modulus.

    Args:
      other: the other BigNum
      modulus: the modulus of the operation

    Returns:
      self
    """
    return self._ComputeResultInPlace(
        ssl.BN_mod_exp, self._helper.ctx, other, modulus
    )

  def GCD(self, other):
    """Gets gcd as a BigNum."""
    return self._ComputeResult(ssl.BN_gcd, self._helper.ctx, other)

  def ModInverse(self, modulus):
    """Gets the inverse of this BigNum in mod modulus.

    Raises:
      ValueError: if OpenSSL cannot compute the inverse, e.g. because this
        BigNum and modulus are not relatively prime.
    """
    try:
      return self._ComputeResult(ssl.BN_mod_inverse, self._helper.ctx, modulus)
    except AssertionError as a:
      raise ValueError(
          'This big num {} and modulus {} are not relatively '
          'primes.\nThe Assertion Error: {}'.format(
              self.GetAsLong(), modulus.GetAsLong(), a
          )
      ) from a

  def ModSqrt(self, modulus):
    """Gets the sqrt of this BigNum in mod modulus.

    Per the BN_mod_sqrt documentation the modulus should be prime. Failures
    reported by OpenSSL (e.g. no square root exists) surface as AssertionError
    from SSLProxy.

    Args:
      modulus: the modulus of the operation

    Returns:
      a new BigNum holding the result.
    """
    return self._ComputeResult(ssl.BN_mod_sqrt, self._helper.ctx, modulus)

  def IModSqrt(self, modulus):
    """Gets the sqrt of this BigNum in mod modulus, storing it in this BigNum.

    Args:
      modulus: the modulus of the operation

    Returns:
      self
    """
    return self._ComputeResultInPlace(
        ssl.BN_mod_sqrt, self._helper.ctx, modulus
    )

  def GenerateRand(self):
    """Generates a cryptographically strong pseudo-random between 0 & self.

    Returns:
      A BigNum in [0, self._big_num) range.
    """
    result = BigNum(ssl.BN_new())
    ssl.BN_rand_range(result._bn_num, self._bn_num)
    return result

  def GenerateRandWithStart(self, start_big_num):
    """Generates a cryptographically strong pseudo-random between start & self.

    Args:
      start_big_num: start BigNum value of the interval.

    Returns:
      A BigNum in [start, self._big_num) range.
    """
    return (self - start_big_num).GenerateRand() + start_big_num

  def ModNegate(self, modulus):
    """Returns -self mod modulus as a value in [0, modulus)."""
    # The outer reduction maps the self % modulus == 0 case to 0 rather than
    # to modulus, which lies outside [0, modulus).
    return (modulus - (self % modulus)) % modulus

  def AddOne(self):
    return self + self.One()

  def SubtractOne(self):
    return self - self.One()

  def __str__(self):
    return str(self.GetAsLong())

  def __eq__(self, other):
    # pylint: disable=protected-access
    if isinstance(other, self.__class__):
      return ssl.BN_cmp(self._bn_num, other._bn_num) == 0
    raise ValueError('Cannot compare BigNum with type {}'.format(type(other)))
    # pylint: enable=protected-access

  def __ne__(self, other):
    return not self == other

  def __lt__(self, other):
    # pylint: disable=protected-access
    if isinstance(other, self.__class__):
      return ssl.BN_cmp(self._bn_num, other._bn_num) < 0
    raise ValueError('Cannot compare BigNum with type {}'.format(type(other)))
    # pylint: enable=protected-access

  def _ComputeResult(self, func, ctx, *args):
    """Applies func(result, self, *args[, ctx]) into a new BigNum."""
    return self._ComputeResultIntoBigNum(
        BigNum(ssl.BN_new()), func, ctx, self, *args
    )

  def _ComputeResultInPlace(self, func, ctx, *args):
    """Applies func(self, self, *args[, ctx]); self must be mutable."""
    if self.immutable:
      raise ValueError(
          'This operation will change this immutable BigNum. Call '
          'Mutable method to change it.'
      )
    return self._ComputeResultIntoBigNum(self, func, ctx, self, *args)

  @classmethod
  def _ComputeResultIntoBigNum(cls, big_num_result, func, ctx, *args):
    """Calls func(result, *args[, ctx]) on the underlying BIGNUM pointers.

    Returns big_num_result, or NotImplemented if any arg is not a BigNum.
    """
    # pylint: disable=protected-access
    if all(isinstance(big_num, cls) for big_num in args):
      args = [big_num._bn_num for big_num in args]
      if ctx:
        args.append(ctx)
      func(big_num_result._bn_num, *args)
      return big_num_result
    return NotImplemented
    # pylint: enable=protected-access

  def GetAsLong(self):
    """Gets the long number in this BigNum (sign is dropped)."""
    return converters.BytesToLong(self.GetAsBytes())

  def GetAsBytes(self):
    """Gets the absolute value of this BigNum as big-endian bytes."""
    return BnToBytes(self._bn_num)


class BigNumCache(object):
  """A singleton cache holding BigNum representations of small numbers.

  NOTE(review): __new__ always returns the shared instance, but Python still
  runs __init__ on every BigNumCache(...) call, which discards the cached
  entries and resets max_count. Callers presumably construct it once and keep
  the reference; confirm before relying on caching across constructions.
  """

  _instance = None

  def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
    if cls._instance is None:
      cls._instance = super(BigNumCache, cls).__new__(cls)
    return cls._instance

  def __init__(self, max_count: int):
    self._cache = {}
    self._max_count = max_count

  def Get(self, num: int) -> BigNum:
    """Gets the BigNum from the cache or creates a new BigNum.

    If max_count is reached, a new BigNum is created and returned without
    storing in the cache.

    Args:
      num: the long or integer to convert to BigNum.

    Returns:
      a BigNum for the given num.
    """
    if num not in self._cache:
      if len(self._cache) >= self._max_count:
        return BigNum.FromLongNumber(num)
      self._cache[num] = BigNum.FromLongNumber(num)
    return self._cache[num]