Skip to content

Commit 302b769

Browse files
committed
use xoflib
1 parent b4b2c01 commit 302b769

File tree

8 files changed

+82
-80
lines changed

8 files changed

+82
-80
lines changed

README.md

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,24 @@ deterministic CSRNG. The reference implementation uses
7777
AES256 CTR DRBG. I have implemented this in [`ase256_ctr_drbg.py`](src/dilithium_py/drbg/ase256_ctr_drbg.py).
7878
However, I have not implemented AES itself, instead I import this from `pycryptodome`.
7979

80-
To install dependencies, run `pip -r install requirements`.
80+
To install dependencies, run `pip install -r requirements.txt`.
8181

8282
If you're happy to use system randomness (`os.urandom`) then you don't need
8383
this dependency.
8484

85+
#### `xoflib`
86+
87+
There is an additional optional dependency of
88+
[`xoflib`](https://github.com/GiacomoPope/xoflib) which is a python package with
89+
bindings to many Rust implementations of eXtendable-Output Functions (XOFx). The
90+
creation of this package was inspired by this repository as Dilithium needs a streaming API from the shake XOFs which `hashlib` doesn't support.
91+
92+
`xoflib` can be installed by running `pip install xoflib` or by installing from requirements as above.
93+
94+
If you do not wish to install this dependency, then we include a small
95+
[`shake_wrapper`](src/dilithium_py/shake/shake_wrapper.py) to mimic `xoflib` but
96+
with a much higher memory consumption due to the limitations of `hashlib`.
97+
8598
## Using dilithium-py
8699

87100
### ML DSA
@@ -126,12 +139,12 @@ The above example would also work with the other NIST levels
126139

127140
Some very rough benchmarks to give an idea about performance:
128141

129-
| 500 Iterations | `ML_DSA_44` | `ML_DSA_65` | `ML_DSA_87` |
142+
| 1000 Iterations | `ML_DSA_44` | `ML_DSA_65` | `ML_DSA_87` |
130143
|--------------------------|--------------|--------------|--------------|
131-
| `KeyGen()` Median Time | 6 ms | 10 ms | 16 ms |
132-
| `Sign()` Median Time | 29 ms | 52 ms | 61 ms |
133-
| `Sign()` Average Time | 36 ms | 64 ms | 75 ms |
134-
| `Verify()` Median Time | 8 ms | 12 ms | 18 ms |
144+
| `KeyGen()` Median Time | 6 ms | 10 ms | 14 ms |
145+
| `Sign()` Median Time | 29 ms | 49 ms | 59 ms |
146+
| `Sign()` Average Time | 36 ms | 62 ms | 75 ms |
147+
| `Verify()` Median Time | 8 ms | 11 ms | 17 ms |
135148

136149
All times recorded using a Intel Core i7-9750H CPU averaged over 1000 calls.
137150

@@ -177,12 +190,12 @@ The above example would also work with the other NIST levels
177190

178191
Some very rough benchmarks to give an idea about performance:
179192

180-
| 500 Iterations | `Dilithium2` | `Dilithium3` | `Dilithium5` |
193+
| 1000 Iterations | `Dilithium2` | `Dilithium3` | `Dilithium5` |
181194
|--------------------------|---------------|--------------|--------------|
182-
| `KeyGen()` Median Time | 6 ms | 10 ms | 16 ms |
195+
| `KeyGen()` Median Time | 6 ms | 9 ms | 15 ms |
183196
| `Sign()` Median Time | 27 ms | 46 ms | 58 ms |
184197
| `Sign()` Average Time | 35 ms | 58 ms | 72 ms |
185-
| `Verify()` Median Time | 8 ms | 12 ms | 18 ms |
198+
| `Verify()` Median Time | 7 ms | 11 ms | 18 ms |
186199

187200
All times recorded using a Intel Core i7-9750H CPU averaged over 1000 calls.
188201

benchmarks/benchmark_dilithium.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def benchmark_dilithium(Dilithium, name, count):
7070
# I used 1000 calls for the README, but you might want to
7171
# shrink this down if you're playing
7272
count = 1000
73-
benchmark_dilithium(Dilithium2, "Dilithium2", count)
74-
benchmark_dilithium(Dilithium3, "Dilithium3", count)
75-
benchmark_dilithium(Dilithium5, "Dilithium5", count)
73+
# benchmark_dilithium(Dilithium2, "Dilithium2", count)
74+
# benchmark_dilithium(Dilithium3, "Dilithium3", count)
75+
# benchmark_dilithium(Dilithium5, "Dilithium5", count)
7676

77-
# profile_dilithium(Dilithium2)
77+
profile_dilithium(Dilithium2)

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
pycryptodome == 3.14.1
1+
pycryptodome == 3.14.1
2+
xoflib

src/dilithium_py/dilithium/dilithium.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import os
22
from ..modules.modules import ModuleDilithium
3-
from ..shake.shake_wrapper import Shake256
3+
4+
try:
5+
from xoflib import shake256
6+
except ImportError:
7+
from ..shake.shake_wrapper import shake256
48

59

610
class Dilithium:
@@ -56,7 +60,7 @@ def _h(input_bytes, length):
5660
"""
5761
H: B^* -> B^*
5862
"""
59-
return Shake256.digest(input_bytes, length)
63+
return shake256(input_bytes).read(length)
6064

6165
def _expand_matrix_from_seed(self, rho):
6266
"""

src/dilithium_py/ml_dsa/ml_dsa.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import os
22
from ..modules.modules import ModuleDilithium
3-
from ..shake.shake_wrapper import Shake256
3+
4+
try:
5+
from xoflib import shake256
6+
except ImportError:
7+
from ..shake.shake_wrapper import shake256
48

59

610
class ML_DSA:
@@ -57,7 +61,7 @@ def _h(input_bytes, length):
5761
"""
5862
H: B^* -> B^*
5963
"""
60-
return Shake256.digest(input_bytes, length)
64+
return shake256(input_bytes).read(length)
6165

6266
def _expand_matrix_from_seed(self, rho):
6367
"""

src/dilithium_py/polynomials/polynomials.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66
decompose,
77
check_norm_bound,
88
)
9-
from ..shake.shake_wrapper import Shake128, Shake256
109
from ..utilities.utils import make_hint, make_hint_optimised, use_hint
1110

11+
try:
12+
from xoflib import shake128, shake256
13+
except ImportError:
14+
from ..shake.shake_wrapper import shake128, shake256
15+
1216

1317
class PolynomialRingDilithium(PolynomialRing):
1418
def __init__(self):
@@ -53,19 +57,19 @@ def rejection_sample(i, xof):
5357
return j
5458

5559
# Initialise the XOF
56-
Shake256.absorb(seed)
60+
xof = shake256(seed)
5761

5862
# Set the first 8 bytes for the sign, and leave the rest for
5963
# sampling.
60-
sign_bytes = Shake256.read(8)
64+
sign_bytes = xof.read(8)
6165
sign_int = int.from_bytes(sign_bytes, "little")
6266

6367
# Set the list of coeffs to be 0
6468
coeffs = [0 for _ in range(256)]
6569

6670
# Now set tau values of coeffs to be ±1
6771
for i in range(256 - tau, 256):
68-
j = rejection_sample(i, Shake256)
72+
j = rejection_sample(i, xof)
6973
coeffs[i] = coeffs[j]
7074
coeffs[j] = 1 - 2 * (sign_int & 1)
7175
sign_int >>= 1
@@ -93,8 +97,8 @@ def rejection_sample(xof):
9397

9498
# Initialise the XOF
9599
seed = rho + bytes([j, i])
96-
Shake128.absorb(seed)
97-
coeffs = [rejection_sample(Shake128) for _ in range(256)]
100+
xof = shake128(seed)
101+
coeffs = [rejection_sample(xof) for _ in range(256)]
98102
return self(coeffs, is_ntt=True)
99103

100104
def rejection_bounded_poly(self, rho_prime, i, eta):
@@ -116,14 +120,14 @@ def coefficient_from_half_byte(j, eta):
116120

117121
# Initialise the XOF
118122
seed = rho_prime + int.to_bytes(i, 2, "little")
119-
Shake256.absorb(seed)
123+
xof = shake256(seed)
120124

121125
# Sample bytes for all n coeffs
122126
i = 0
123127
coeffs = [0 for _ in range(256)]
124128
while i < 256:
125129
# Consider two values for each byte (top and bottom four bits)
126-
j = Shake256.read(1)[0]
130+
j = xof.read(1)[0]
127131

128132
c0 = coefficient_from_half_byte(j % 16, eta)
129133
if c0 is not False:
@@ -151,7 +155,7 @@ def sample_mask_polynomial(self, rho_prime, i, kappa, gamma_1):
151155

152156
# Initialise the XOF
153157
seed = rho_prime + int.to_bytes(kappa + i, 2, "little")
154-
xof_bytes = Shake256.digest(seed, total_bytes)
158+
xof_bytes = shake256(seed).read(total_bytes)
155159
r = int.from_bytes(xof_bytes, "little")
156160
mask = (1 << bit_count) - 1
157161
coeffs = [gamma_1 - ((r >> bit_count * i) & mask) for i in range(self.n)]

src/dilithium_py/shake/shake_wrapper.py

Lines changed: 22 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,69 +17,45 @@ class Shake:
1717
def __init__(self, algorithm, block_length):
1818
self.algorithm = algorithm
1919
self.block_length = block_length
20-
self.index = 0
21-
self.read_blocks = 0
22-
self.bytes_left = 0
23-
self.read_data = b""
20+
self.buf = b""
21+
self.len_buf = 0
2422

2523
def absorb(self, input_bytes):
2624
"""
27-
Initialise the XOF with the seed
28-
and reset other init.
25+
Initialise the XOF with the seed and reset other init.
2926
"""
30-
self.read_data = b""
31-
self.read_blocks = 0
32-
self.bytes_left = 0
27+
# Initalize the buffer
3328
self.index = 0
34-
self.xof = self.algorithm(input_bytes)
3529

36-
def digest(self, input_bytes, length):
37-
"""
38-
Sometimes we just want n bytes, so rather than read
39-
them slowly, we can just pull them straight out.
40-
"""
41-
return self.algorithm(input_bytes).digest(length)
30+
# Set the reading method from hashlib digest
31+
self.xof_read = self.algorithm(input_bytes).digest
4232

43-
def get_n_blocks(self, n):
44-
"""
45-
Requests n blocks from Shake and stores them
46-
Ignores any bytes previously read
47-
"""
48-
# Because of hashlib we need to request ALL bytes even
49-
# if we only want 5 more blocks
50-
byte_count = self.block_length * (self.read_blocks + n)
51-
xof_data = self.xof.digest(byte_count)
52-
53-
# include the extra blocks and remove the read ones
54-
self.read_data = (
55-
self.read_data[self.index :] + xof_data[-self.block_length * n :]
56-
)
57-
self.read_blocks += n
58-
self.bytes_left += self.block_length * n
59-
self.index = 0
33+
# Start by requesting 5 blocks from the XOF
34+
self.buf = self.xof_read(5 * self.block_length)
35+
self.len_buf = 5 * self.block_length
6036

6137
def read(self, n):
6238
"""
63-
Rad n bytes from the XOF
39+
Read n bytes from the XOF
6440
"""
6541
# Make sure there are enough bytes to read
66-
if n > self.bytes_left:
67-
# If we don't need many bytes, just get 5 blocks
68-
if (n - self.bytes_left) < 5 * self.block_length:
69-
self.get_n_blocks(5)
70-
# Otherwise get as many as we need
71-
else:
72-
self.get_n_blocks(n // self.block_length + 1)
42+
while self.index + n > self.len_buf:
43+
# double the size of the buffer
44+
self.len_buf *= 2
45+
self.buf = self.xof_read(self.len_buf)
7346

7447
# Read from the buffer data the bytes requested
75-
send = self.read_data[self.index : self.index + n]
48+
send = self.buf[self.index : self.index + n]
7649

77-
# Store that we've read the bytes and shift the index
78-
self.bytes_left -= n
50+
# Shift the index along the buffer
7951
self.index += n
8052

8153
return send
8254

55+
def __call__(self, input_bytes):
56+
self.absorb(input_bytes)
57+
return self
58+
8359

84-
Shake128 = Shake(shake_128, 168)
85-
Shake256 = Shake(shake_256, 136)
60+
shake128 = Shake(shake_128, 168)
61+
shake256 = Shake(shake_256, 136)

tests/test_shake.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from hashlib import shake_128, shake_256
2-
from dilithium_py.shake.shake_wrapper import Shake128, Shake256
2+
from dilithium_py.shake.shake_wrapper import shake128, shake256
33
from Crypto.Hash.SHAKE128 import SHAKE128_XOF
44
from Crypto.Hash.SHAKE256 import SHAKE256_XOF
55

@@ -22,12 +22,12 @@ def hashlib_test_many_calls(self, Shake, shake_hashlib):
2222
self.assertEqual(shake_hashlib(absorb_bytes).digest(l), output)
2323

2424
def test_hashlib_shake128(self):
25-
self.hashlib_test_long_calls(Shake128, shake_128)
26-
self.hashlib_test_many_calls(Shake128, shake_128)
25+
self.hashlib_test_long_calls(shake128, shake_128)
26+
self.hashlib_test_many_calls(shake128, shake_128)
2727

2828
def test_hashlib_shake256(self):
29-
self.hashlib_test_long_calls(Shake256, shake_256)
30-
self.hashlib_test_many_calls(Shake256, shake_256)
29+
self.hashlib_test_long_calls(shake256, shake_256)
30+
self.hashlib_test_many_calls(shake256, shake_256)
3131

3232

3333
class TestShakeCrypto(unittest.TestCase):
@@ -40,5 +40,5 @@ def pycryptodome_test_read_chunks(self, Shake, ShakeCrypto):
4040
self.assertEqual(Shake.read(chunk), ShakeCrypto.read(chunk))
4141

4242
def test_pycryptodome_shake(self):
43-
self.pycryptodome_test_read_chunks(Shake128, SHAKE128_XOF())
44-
self.pycryptodome_test_read_chunks(Shake256, SHAKE256_XOF())
43+
self.pycryptodome_test_read_chunks(shake128, SHAKE128_XOF())
44+
self.pycryptodome_test_read_chunks(shake256, SHAKE256_XOF())

0 commit comments

Comments
 (0)