diff --git a/src/komm/_lossless_coding/util.py b/src/komm/_lossless_coding/util.py index 73748510..ff9461ce 100644 --- a/src/komm/_lossless_coding/util.py +++ b/src/komm/_lossless_coding/util.py @@ -63,17 +63,13 @@ def parse_prefix_free( allow_incomplete: bool, ) -> npt.NDArray[np.integer]: output: list[int] = [] - i = 0 - while i < len(input): - j = 1 - while i + j <= len(input): - try: - key = tuple(input[i : i + j]) - output.extend(dictionary[key]) - break - except KeyError: - j += 1 - i += j + i, j = 0, 0 + while j < len(input): + j += 1 + key = tuple(input[i:j]) + if key in dictionary: + output.extend(dictionary[key]) + i = j if i == len(input): return np.asarray(output) diff --git a/tests/lossless_coding/test_huffman_code.py b/tests/lossless_coding/test_huffman_code.py index d0758fd7..ce9b4002 100644 --- a/tests/lossless_coding/test_huffman_code.py +++ b/tests/lossless_coding/test_huffman_code.py @@ -81,10 +81,14 @@ def test_huffman_code_invalid_call(): komm.HuffmanCode([0.5, 0.5], policy="unknown") # type: ignore -@pytest.mark.parametrize("n", range(2, 16)) -def test_huffman_code_encode_decode(n): - integers = np.random.randint(0, 100, n) +@pytest.mark.parametrize("source_cardinality", [2, 3, 4, 5, 6]) +@pytest.mark.parametrize("source_block_size", [1, 2]) +@pytest.mark.parametrize("policy", ["high", "low"]) +def test_huffman_code_encode_decode(source_cardinality, source_block_size, policy): + integers = np.random.randint(0, 100, source_cardinality) pmf = integers / integers.sum() - code = komm.HuffmanCode(pmf) - x = np.random.randint(0, n - 1, 1000) - assert np.array_equal(code.decode(code.encode(x)), x) + dms = komm.DiscreteMemorylessSource(pmf) + code = komm.HuffmanCode(pmf, source_block_size=source_block_size, policy=policy) + x = dms(1000 * source_block_size) + x_hat = code.decode(code.encode(x)) + assert np.array_equal(x_hat, x) diff --git a/tests/lossless_coding/test_tunstall_code.py b/tests/lossless_coding/test_tunstall_code.py index bf09fe11..e62a3cbf 100644 --- a/tests/lossless_coding/test_tunstall_code.py +++ b/tests/lossless_coding/test_tunstall_code.py @@ -31,7 +31,7 @@ def test_tunstall_code_invalid_init(): @pytest.mark.parametrize("source_cardinality", range(2, 10)) @pytest.mark.parametrize("target_block_size", range(1, 7)) def test_random_tunstall_code(source_cardinality, target_block_size): - if 2**target_block_size < source_cardinality: # Target block size too low. + if 2**target_block_size < source_cardinality: # target block size too low return for _ in range(10): pmf = np.random.rand(source_cardinality) @@ -39,3 +39,17 @@ def test_random_tunstall_code(source_cardinality, target_block_size): code = komm.TunstallCode(pmf, target_block_size) assert code.is_prefix_free() assert code.is_fully_covering() + + +@pytest.mark.parametrize("source_cardinality", range(2, 10)) +@pytest.mark.parametrize("target_block_size", range(1, 7)) +def test_tunstall_code_encode_decode(source_cardinality, target_block_size): + if 2**target_block_size < source_cardinality: # target block size too low + return + integers = np.random.randint(0, 100, source_cardinality) + pmf = integers / integers.sum() + dms = komm.DiscreteMemorylessSource(pmf) + code = komm.TunstallCode(pmf, target_block_size=target_block_size) + x = dms(1000) + x_hat = code.decode(code.encode(x))[: len(x)] + assert np.array_equal(x_hat, x)