Skip to content

Commit

Permalink
fix: fix parse_prefix_free with allow_incomplete=True
Browse files Browse the repository at this point in the history
  • Loading branch information
rwnobrega committed Dec 31, 2024
1 parent da3668e commit 8495bea
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 18 deletions.
18 changes: 7 additions & 11 deletions src/komm/_lossless_coding/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,13 @@ def parse_prefix_free(
allow_incomplete: bool,
) -> npt.NDArray[np.integer]:
output: list[int] = []
i = 0
while i < len(input):
j = 1
while i + j <= len(input):
try:
key = tuple(input[i : i + j])
output.extend(dictionary[key])
break
except KeyError:
j += 1
i += j
i, j = 0, 0
while j < len(input):
j += 1
key = tuple(input[i:j])
if key in dictionary:
output.extend(dictionary[key])
i = j

if i == len(input):
return np.asarray(output)
Expand Down
16 changes: 10 additions & 6 deletions tests/lossless_coding/test_huffman_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,14 @@ def test_huffman_code_invalid_call():
komm.HuffmanCode([0.5, 0.5], policy="unknown") # type: ignore


@pytest.mark.parametrize("n", range(2, 16))
def test_huffman_code_encode_decode(n):
integers = np.random.randint(0, 100, n)
@pytest.mark.parametrize("source_cardinality", [2, 3, 4, 5, 6])
@pytest.mark.parametrize("source_block_size", [1, 2])
@pytest.mark.parametrize("policy", ["high", "low"])
def test_huffman_code_encode_decode(source_cardinality, source_block_size, policy):
integers = np.random.randint(0, 100, source_cardinality)
pmf = integers / integers.sum()
code = komm.HuffmanCode(pmf)
x = np.random.randint(0, n - 1, 1000)
assert np.array_equal(code.decode(code.encode(x)), x)
dms = komm.DiscreteMemorylessSource(pmf)
code = komm.HuffmanCode(pmf, source_block_size=source_block_size, policy=policy)
x = dms(1000 * source_block_size)
x_hat = code.decode(code.encode(x))
assert np.array_equal(x_hat, x)
16 changes: 15 additions & 1 deletion tests/lossless_coding/test_tunstall_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,25 @@ def test_tunstall_code_invalid_init():
@pytest.mark.parametrize("source_cardinality", range(2, 10))
@pytest.mark.parametrize("target_block_size", range(1, 7))
def test_random_tunstall_code(source_cardinality, target_block_size):
if 2**target_block_size < source_cardinality: # Target block size too low.
if 2**target_block_size < source_cardinality: # target block size too low
return
for _ in range(10):
pmf = np.random.rand(source_cardinality)
pmf /= pmf.sum()
code = komm.TunstallCode(pmf, target_block_size)
assert code.is_prefix_free()
assert code.is_fully_covering()


@pytest.mark.parametrize("source_cardinality", range(2, 10))
@pytest.mark.parametrize("target_block_size", range(1, 7))
def test_tunstall_code_encode_decode(source_cardinality, target_block_size):
if 2**target_block_size < source_cardinality: # target block size too low
return
integers = np.random.randint(0, 100, source_cardinality)
pmf = integers / integers.sum()
dms = komm.DiscreteMemorylessSource(pmf)
code = komm.TunstallCode(pmf, target_block_size=target_block_size)
x = dms(1000)
x_hat = code.decode(code.encode(x))[: len(x)]
assert np.array_equal(x_hat, x)

0 comments on commit 8495bea

Please sign in to comment.