diff --git a/lib/Crypto/Math/Numbers.py b/lib/Crypto/Math/Numbers.py index c2c4483d..d9218f3d 100644 --- a/lib/Crypto/Math/Numbers.py +++ b/lib/Crypto/Math/Numbers.py @@ -30,7 +30,12 @@ __all__ = ["Integer"] +import os + try: + if os.getenv("PYCRYPTODOME_DISABLE_GMP"): + raise ImportError() + from Crypto.Math._IntegerGMP import IntegerGMP as Integer from Crypto.Math._IntegerGMP import implementation as _implementation except (ImportError, OSError, AttributeError): diff --git a/lib/Crypto/Math/_IntegerGMP.py b/lib/Crypto/Math/_IntegerGMP.py index e1b6d66c..21a092a3 100644 --- a/lib/Crypto/Math/_IntegerGMP.py +++ b/lib/Crypto/Math/_IntegerGMP.py @@ -29,12 +29,11 @@ # =================================================================== import sys +import struct -from Crypto.Util.py3compat import tobytes, is_native_int +from Crypto.Util.py3compat import is_native_int from Crypto.Util._raw_api import (backend, load_lib, - get_raw_buffer, get_c_string, - null_pointer, create_string_buffer, c_ulong, c_size_t, c_uint8_ptr) from ._IntegerBase import IntegerBase @@ -92,6 +91,9 @@ int __gmpz_invert (mpz_t rop, const mpz_t op1, const mpz_t op2); int __gmpz_divisible_p (const mpz_t n, const mpz_t d); int __gmpz_divisible_ui_p (const mpz_t n, UNIX_ULONG d); + + size_t __gmpz_size (const mpz_t op); + UNIX_ULONG __gmpz_getlimbn (const mpz_t op, size_t n); """ if sys.platform == "win32": @@ -103,6 +105,25 @@ if hasattr(lib, "__mpir_version"): raise ImportError("MPIR library detected") + +# Lazy creation of GMP methods +class _GMP(object): + + def __getattr__(self, name): + if name.startswith("mpz_"): + func_name = "__gmpz_" + name[4:] + elif name.startswith("gmp_"): + func_name = "__gmp_" + name[4:] + else: + raise AttributeError("Attribute %s is invalid" % name) + func = getattr(lib, func_name) + setattr(self, name, func) + return func + + +_gmp = _GMP() + + # In order to create a function that returns a pointer to # a new MPZ structure, we need to break the abstraction # and know exactly what ffi backend we have @@ -117,6 +138,8 @@ class _MPZ(Structure): def new_mpz(): return byref(_MPZ()) + _gmp.mpz_getlimbn.restype = c_ulong + else: # We are using CFFI from Crypto.Util._raw_api import ffi @@ -125,22 +148,8 @@ def new_mpz(): return ffi.new("MPZ*") -# Lazy creation of GMP methods -class _GMP(object): - - def __getattr__(self, name): - if name.startswith("mpz_"): - func_name = "__gmpz_" + name[4:] - elif name.startswith("gmp_"): - func_name = "__gmp_" + name[4:] - else: - raise AttributeError("Attribute %s is invalid" % name) - func = getattr(lib, func_name) - setattr(self, name, func) - return func - - -_gmp = _GMP() +# Size of a native word +_sys_bits = 8 * struct.calcsize("N") class IntegerGMP(IntegerBase): @@ -190,7 +199,6 @@ def __init__(self, value): else: raise NotImplementedError - # Conversions def __int__(self): tmp = new_mpz() @@ -248,31 +256,50 @@ def to_bytes(self, block_size=0, byteorder='big'): if self < 0: raise ValueError("Conversion only valid for non-negative numbers") - buf_len = (_gmp.mpz_sizeinbase(self._mpz_p, 2) + 7) // 8 - if buf_len > block_size > 0: - raise ValueError("Number is too big to convert to byte string" - " of prescribed length") - buf = create_string_buffer(buf_len) - + num_limbs = _gmp.mpz_size(self._mpz_p) + if num_limbs == 0: + limbs = [0] + else: + limbs = [_gmp.mpz_getlimbn(self._mpz_p, i) for i in range(num_limbs)] - _gmp.mpz_export( - buf, - null_pointer, # Ignore countp - 1, # Big endian - c_size_t(1), # Each word is 1 byte long - 0, # Endianess within a word - not relevant - c_size_t(0), # No nails - self._mpz_p) + if _sys_bits == 32: + spchar = "L" + elif _sys_bits == 64: + spchar = "Q" + else: + raise ValueError("Unknown limb size") - result = b'\x00' * max(0, block_size - buf_len) + get_raw_buffer(buf) if byteorder == 'big': - pass + + result = struct.pack(">" + spchar * num_limbs, *limbs[::-1]) + cutoff_len = len(result) - block_size + if block_size == 0: + result = result.lstrip(b'\x00') + elif cutoff_len > 0: + if result[:cutoff_len] != b'\x00' * (cutoff_len): + raise ValueError("Number is too big to convert to " + "byte string of prescribed length") + result = result[cutoff_len:] + elif cutoff_len < 0: + result = b'\x00' * (-cutoff_len) + result + elif byteorder == 'little': - result = bytearray(result) - result.reverse() - result = bytes(result) + + result = struct.pack("<" + spchar * num_limbs, *limbs) + cutoff_len = len(result) - block_size + if block_size == 0: + result = result.rstrip(b'\x00') + elif cutoff_len > 0: + if result[-cutoff_len:] != b'\x00' * cutoff_len: + raise ValueError("Number is too big to convert to " + "byte string of prescribed length") + result = result[:-cutoff_len] + elif cutoff_len < 0: + result = result + b'\x00' * (-cutoff_len) + else: raise ValueError("Incorrect byteorder") + return result @staticmethod diff --git a/lib/Crypto/SelfTest/Math/test_Numbers.py b/lib/Crypto/SelfTest/Math/test_Numbers.py index 7609485a..7628880a 100644 --- a/lib/Crypto/SelfTest/Math/test_Numbers.py +++ b/lib/Crypto/SelfTest/Math/test_Numbers.py @@ -112,15 +112,15 @@ def test_conversion_to_bytes(self): Integer = self.Integer v1 = Integer(0x17) - self.assertEqual(b("\x17"), v1.to_bytes()) + self.assertEqual(b"\x17", v1.to_bytes()) v2 = Integer(0xFFFE) - self.assertEqual(b("\xFF\xFE"), v2.to_bytes()) - self.assertEqual(b("\x00\xFF\xFE"), v2.to_bytes(3)) + self.assertEqual(b"\xFF\xFE", v2.to_bytes()) + self.assertEqual(b"\x00\xFF\xFE", v2.to_bytes(3)) self.assertRaises(ValueError, v2.to_bytes, 1) - self.assertEqual(b("\xFE\xFF"), v2.to_bytes(byteorder='little')) - self.assertEqual(b("\xFE\xFF\x00"), v2.to_bytes(3, byteorder='little')) + self.assertEqual(b"\xFE\xFF", v2.to_bytes(byteorder='little')) + self.assertEqual(b"\xFE\xFF\x00", v2.to_bytes(3, byteorder='little')) v3 = Integer(-90) self.assertRaises(ValueError, v3.to_bytes) diff --git a/lib/Crypto/__init__.py b/lib/Crypto/__init__.py index c33481e2..127f6b7b 100644 --- a/lib/Crypto/__init__.py +++ b/lib/Crypto/__init__.py @@ -1,6 +1,6 @@ __all__ = ['Cipher', 'Hash', 'Protocol', 'PublicKey', 'Util', 'Signature', 'IO', 'Math'] -version_info = (3, 20, '0') +version_info = (3, 21, '0b0') __version__ = ".".join([str(x) for x in version_info])