From 7d5f8936e22add2cfa431ebdbf1d87ed45d2a605 Mon Sep 17 00:00:00 2001 From: William Baker Date: Wed, 18 Oct 2023 10:32:32 +0100 Subject: [PATCH 1/2] consider value matrix shapes for Jax conversion For some programs, the value matrix could have a dimension larger than the largest key-query matrix resulting in a compliation error. By considering both ov and qk matrices when padding the Jax model we can resolve this --- tracr/compiler/assemble.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tracr/compiler/assemble.py b/tracr/compiler/assemble.py index 693a87f..a952bcd 100644 --- a/tracr/compiler/assemble.py +++ b/tracr/compiler/assemble.py @@ -130,7 +130,8 @@ def _get_model_config_and_module_names( if multi_attn_heads: num_heads = max(len(heads) for heads in multi_attn_heads) - key_size = max(max(head.w_qk.matrix.shape) for head in heads) + key_size = max([max(head.w_qk.matrix.shape) for head in heads] + + [max(head.w_ov.matrix.shape) for head in heads]) else: num_heads, key_size = 1, 1 From 98beaa846149abf724b92604297a4047b60d8201 Mon Sep 17 00:00:00 2001 From: William Baker Date: Mon, 22 Jan 2024 17:27:55 +0000 Subject: [PATCH 2/2] added ov test cases --- tracr/compiler/test_cases.py | 40 ++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tracr/compiler/test_cases.py b/tracr/compiler/test_cases.py index 9e3ac28..c7ef11e 100644 --- a/tracr/compiler/test_cases.py +++ b/tracr/compiler/test_cases.py @@ -345,6 +345,43 @@ # make_nary_sequencemap(f, *sops) +def make_ov_test_case_1(): + so2 = rasp.Map(lambda x: x - 2, rasp.indices) + so4 = rasp.SequenceMap(lambda x,y: x or y, so2, rasp.indices) + so1 = rasp.Map(lambda x: x > 1, rasp.tokens) + so3 = rasp.Map(lambda x: x < False, so1) + se1 = rasp.Select(so3, so3, rasp.Comparison.LEQ) + so6 = rasp.Aggregate(se1, so4) + return so6 + +def make_ov_test_case_2(): + se1 = rasp.Select(rasp.indices, rasp.indices, lambda x, y: x == y) + so2 = rasp.Aggregate(se1, rasp.indices) + se2 = rasp.Select(rasp.tokens, rasp.tokens, lambda x, y: x < y) + so3 = rasp.SelectorWidth(se2) + se3 = rasp.Select(so3, so2, lambda x, y: x!=y) + so6 = rasp.SequenceMap(lambda x,y: x-y, so3, so3) + so7 = rasp.Aggregate(se3, so6) + return so7 + +TEST_CASES += [ + dict( + testcase_name="ov_test_case_1", + program=make_ov_test_case_1(), + vocab={0, 1, 2, 3, 4, 5, 6, 7, 8}, + test_input=[0], + expected_output=[1, 2, 3, 4], + max_seq_len=6), + dict( + testcase_name="ov_test_case_2", + program=make_ov_test_case_2(), + vocab={0, 1, 2, 3, 4, 5, 6, 7, 8}, + test_input=[1,3], + expected_output=[0,0], + max_seq_len=6), +] + + CAUSAL_TEST_CASES = UNIVERSAL_TEST_CASES + [ dict( testcase_name="selector_width", @@ -355,3 +392,6 @@ expected_output=[1, 2, 3, 4], max_seq_len=5), ] + + +