Skip to content

Commit

Permalink
Add PYCRYPTODOME_DISABLE_GMP flag and simplify to_bytes() for GMP
Browse files Browse the repository at this point in the history
  • Loading branch information
Legrandin committed Jan 12, 2024
1 parent b6ab946 commit 28e2ef4
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 45 deletions.
5 changes: 5 additions & 0 deletions lib/Crypto/Math/Numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@

__all__ = ["Integer"]

import os

try:
if os.getenv("PYCRYPTODOME_DISABLE_GMP"):
raise ImportError()

from Crypto.Math._IntegerGMP import IntegerGMP as Integer
from Crypto.Math._IntegerGMP import implementation as _implementation
except (ImportError, OSError, AttributeError):
Expand Down
105 changes: 66 additions & 39 deletions lib/Crypto/Math/_IntegerGMP.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@
# ===================================================================

import sys
import struct

from Crypto.Util.py3compat import tobytes, is_native_int
from Crypto.Util.py3compat import is_native_int

from Crypto.Util._raw_api import (backend, load_lib,
get_raw_buffer, get_c_string,
null_pointer, create_string_buffer,
c_ulong, c_size_t, c_uint8_ptr)

from ._IntegerBase import IntegerBase
Expand Down Expand Up @@ -92,6 +91,9 @@
int __gmpz_invert (mpz_t rop, const mpz_t op1, const mpz_t op2);
int __gmpz_divisible_p (const mpz_t n, const mpz_t d);
int __gmpz_divisible_ui_p (const mpz_t n, UNIX_ULONG d);
size_t __gmpz_size (const mpz_t op);
UNIX_ULONG __gmpz_getlimbn (const mpz_t op, size_t n);
"""

if sys.platform == "win32":
Expand All @@ -103,6 +105,25 @@
if hasattr(lib, "__mpir_version"):
raise ImportError("MPIR library detected")


# Lazy creation of GMP methods
class _GMP(object):

def __getattr__(self, name):
if name.startswith("mpz_"):
func_name = "__gmpz_" + name[4:]
elif name.startswith("gmp_"):
func_name = "__gmp_" + name[4:]
else:
raise AttributeError("Attribute %s is invalid" % name)
func = getattr(lib, func_name)
setattr(self, name, func)
return func


_gmp = _GMP()


# In order to create a function that returns a pointer to
# a new MPZ structure, we need to break the abstraction
# and know exactly what ffi backend we have
Expand All @@ -117,6 +138,8 @@ class _MPZ(Structure):
def new_mpz():
return byref(_MPZ())

_gmp.mpz_getlimbn.restype = c_ulong

else:
# We are using CFFI
from Crypto.Util._raw_api import ffi
Expand All @@ -125,22 +148,8 @@ def new_mpz():
return ffi.new("MPZ*")


# Lazy creation of GMP methods
class _GMP(object):

def __getattr__(self, name):
if name.startswith("mpz_"):
func_name = "__gmpz_" + name[4:]
elif name.startswith("gmp_"):
func_name = "__gmp_" + name[4:]
else:
raise AttributeError("Attribute %s is invalid" % name)
func = getattr(lib, func_name)
setattr(self, name, func)
return func


_gmp = _GMP()
# Size of a native word
_sys_bits = 8 * struct.calcsize("N")


class IntegerGMP(IntegerBase):
Expand Down Expand Up @@ -190,7 +199,6 @@ def __init__(self, value):
else:
raise NotImplementedError


# Conversions
def __int__(self):
tmp = new_mpz()
Expand Down Expand Up @@ -248,31 +256,50 @@ def to_bytes(self, block_size=0, byteorder='big'):
if self < 0:
raise ValueError("Conversion only valid for non-negative numbers")

buf_len = (_gmp.mpz_sizeinbase(self._mpz_p, 2) + 7) // 8
if buf_len > block_size > 0:
raise ValueError("Number is too big to convert to byte string"
" of prescribed length")
buf = create_string_buffer(buf_len)

num_limbs = _gmp.mpz_size(self._mpz_p)
if num_limbs == 0:
limbs = [0]
else:
limbs = [_gmp.mpz_getlimbn(self._mpz_p, i) for i in range(num_limbs)]

_gmp.mpz_export(
buf,
null_pointer, # Ignore countp
1, # Big endian
c_size_t(1), # Each word is 1 byte long
0, # Endianess within a word - not relevant
c_size_t(0), # No nails
self._mpz_p)
if _sys_bits == 32:
spchar = "L"
elif _sys_bits == 64:
spchar = "Q"
else:
raise ValueError("Unknown limb size")

result = b'\x00' * max(0, block_size - buf_len) + get_raw_buffer(buf)
if byteorder == 'big':
pass

result = struct.pack(">" + spchar * num_limbs, *limbs[::-1])
cutoff_len = len(result) - block_size
if block_size == 0:
result = result.lstrip(b'\x00')
elif cutoff_len > 0:
if result[:cutoff_len] != b'\x00' * (cutoff_len):
raise ValueError("Number is too big to convert to "
"byte string of prescribed length")
result = result[cutoff_len:]
elif cutoff_len < 0:
result = b'\x00' * (-cutoff_len) + result

elif byteorder == 'little':
result = bytearray(result)
result.reverse()
result = bytes(result)

result = struct.pack("<" + spchar * num_limbs, *limbs)
cutoff_len = len(result) - block_size
if block_size == 0:
result = result.rstrip(b'\x00')
elif cutoff_len > 0:
if result[-cutoff_len:] != b'\x00' * cutoff_len:
raise ValueError("Number is too big to convert to "
"byte string of prescribed length")
result = result[:-cutoff_len]
elif cutoff_len < 0:
result = result + b'\x00' * (-cutoff_len)

else:
raise ValueError("Incorrect byteorder")

return result

@staticmethod
Expand Down
10 changes: 5 additions & 5 deletions lib/Crypto/SelfTest/Math/test_Numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,15 @@ def test_conversion_to_bytes(self):
Integer = self.Integer

v1 = Integer(0x17)
self.assertEqual(b("\x17"), v1.to_bytes())
self.assertEqual(b"\x17", v1.to_bytes())

v2 = Integer(0xFFFE)
self.assertEqual(b("\xFF\xFE"), v2.to_bytes())
self.assertEqual(b("\x00\xFF\xFE"), v2.to_bytes(3))
self.assertEqual(b"\xFF\xFE", v2.to_bytes())
self.assertEqual(b"\x00\xFF\xFE", v2.to_bytes(3))
self.assertRaises(ValueError, v2.to_bytes, 1)

self.assertEqual(b("\xFE\xFF"), v2.to_bytes(byteorder='little'))
self.assertEqual(b("\xFE\xFF\x00"), v2.to_bytes(3, byteorder='little'))
self.assertEqual(b"\xFE\xFF", v2.to_bytes(byteorder='little'))
self.assertEqual(b"\xFE\xFF\x00", v2.to_bytes(3, byteorder='little'))

v3 = Integer(-90)
self.assertRaises(ValueError, v3.to_bytes)
Expand Down
2 changes: 1 addition & 1 deletion lib/Crypto/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ['Cipher', 'Hash', 'Protocol', 'PublicKey', 'Util', 'Signature',
'IO', 'Math']

version_info = (3, 20, '0')
version_info = (3, 21, '0b0')

__version__ = ".".join([str(x) for x in version_info])

0 comments on commit 28e2ef4

Please sign in to comment.