diff --git a/.github/workflows/parity.yml b/.github/workflows/parity.yml new file mode 100644 index 0000000..238edfe --- /dev/null +++ b/.github/workflows/parity.yml @@ -0,0 +1,27 @@ +name: Parity check + +on: + push: + pull_request: + +jobs: + parity: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install deps (CPU only) + run: | + python -m pip install --upgrade pip + pip install --index-url https://download.pytorch.org/whl/cpu torch + pip install numpy "jax[cpu]" dm-haiku + pip install git+https://github.com/google-deepmind/tracr.git + + - name: Compile & export tracr + run: python scripts/compile_export.py + + - name: Verify parity (fail if mismatch) + run: python scripts/parity_check.py diff --git a/__pycache__/tracr_transformer_pt.cpython-313.pyc b/__pycache__/tracr_transformer_pt.cpython-313.pyc index 1f6d124..5a5b73c 100644 Binary files a/__pycache__/tracr_transformer_pt.cpython-313.pyc and b/__pycache__/tracr_transformer_pt.cpython-313.pyc differ diff --git a/artifacts/input_tokens.json b/artifacts/input_tokens.json new file mode 100644 index 0000000..3674a71 --- /dev/null +++ b/artifacts/input_tokens.json @@ -0,0 +1 @@ +["BOS", 1, 0, 1, 1, 0] \ No newline at end of file diff --git a/artifacts/token_to_id.json b/artifacts/token_to_id.json new file mode 100644 index 0000000..6966287 --- /dev/null +++ b/artifacts/token_to_id.json @@ -0,0 +1,6 @@ +{ + "BOS": 2, + "0": 0, + "1": 1, + "PAD": 3 +} \ No newline at end of file diff --git a/artifacts/tracr_majority_params.npz b/artifacts/tracr_majority_params.npz new file mode 100644 index 0000000..ce7aa3a Binary files /dev/null and b/artifacts/tracr_majority_params.npz differ diff --git a/artifacts/tracr_output.npy b/artifacts/tracr_output.npy new file mode 100644 index 0000000..a802065 Binary files /dev/null and b/artifacts/tracr_output.npy differ diff --git a/artifacts/tracr_param_keys.json b/artifacts/tracr_param_keys.json new file mode 100644 index 0000000..460b71c --- /dev/null +++ b/artifacts/tracr_param_keys.json @@ -0,0 +1,40 @@ +[ + "pos_embed__embeddings", + "token_embed__embeddings", + "transformer__layer_0__attn__key__b", + "transformer__layer_0__attn__key__w", + "transformer__layer_0__attn__linear__b", + "transformer__layer_0__attn__linear__w", + "transformer__layer_0__attn__query__b", + "transformer__layer_0__attn__query__w", + "transformer__layer_0__attn__value__b", + "transformer__layer_0__attn__value__w", + "transformer__layer_0__mlp__linear_1__b", + "transformer__layer_0__mlp__linear_1__w", + "transformer__layer_0__mlp__linear_2__b", + "transformer__layer_0__mlp__linear_2__w", + "transformer__layer_1__attn__key__b", + "transformer__layer_1__attn__key__w", + "transformer__layer_1__attn__linear__b", + "transformer__layer_1__attn__linear__w", + "transformer__layer_1__attn__query__b", + "transformer__layer_1__attn__query__w", + "transformer__layer_1__attn__value__b", + "transformer__layer_1__attn__value__w", + "transformer__layer_1__mlp__linear_1__b", + "transformer__layer_1__mlp__linear_1__w", + "transformer__layer_1__mlp__linear_2__b", + "transformer__layer_1__mlp__linear_2__w", + "transformer__layer_2__attn__key__b", + "transformer__layer_2__attn__key__w", + "transformer__layer_2__attn__linear__b", + "transformer__layer_2__attn__linear__w", + "transformer__layer_2__attn__query__b", + "transformer__layer_2__attn__query__w", + "transformer__layer_2__attn__value__b", + "transformer__layer_2__attn__value__w", + 
"transformer__layer_2__mlp__linear_1__b", + "transformer__layer_2__mlp__linear_1__w", + "transformer__layer_2__mlp__linear_2__b", + "transformer__layer_2__mlp__linear_2__w" +] \ No newline at end of file diff --git a/export_tracr_params.py b/export_tracr_params.py deleted file mode 100644 index 1b67dbc..0000000 --- a/export_tracr_params.py +++ /dev/null @@ -1,60 +0,0 @@ -# export_tracr_params.py -import numpy as np -import jax -from tracr.compiler import compiling -from tracr.rasp import rasp - -VOCAB = {0, 1} -MAX_SEQ_LEN = 10 -COMPILER_BOS = "BOS" - -def majority_score_program(): - all_pos = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE) - select_ones = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ) - num_ones = rasp.Aggregate(select_ones, rasp.tokens) - seq_len = rasp.Aggregate(all_pos, rasp.tokens * 0 + 1) - majority_score = num_ones - (seq_len - num_ones) - return majority_score - -print("Compiling RASP → JAX/Haiku transformer…") -compiled = compiling.compile_rasp_to_model( - majority_score_program(), - vocab=VOCAB, - max_seq_len=MAX_SEQ_LEN, - compiler_bos=COMPILER_BOS, -) -print("Done.") - -# --- Inspect param tree to learn names you must map --- -# Depending on Tracr version, compiled.params or compiled.weights exists. -# Print keys so we can map them into PyTorch: -try: - params = compiled.params -except AttributeError: - params = compiled.weights # fallback if older API - -flat, treedef = jax.tree_util.tree_flatten(params) -leaves_with_paths = [] - -def track_paths(path, node): - if isinstance(node, (dict,)): - for k,v in node.items(): - track_paths(path + (k,), v) - else: - leaves_with_paths.append(("/".join(path), node)) - -track_paths((), params) - -print("\n=== JAX PARAM KEYS (preview) ===") -for k, v in leaves_with_paths: - print(f"{k}: shape={np.array(v).shape}") -print("=== end ===\n") - -# --- Save to NPZ with slash->double_underscore to be filesystem-friendly --- -npz_dict = {} -for k, v in leaves_with_paths: - safe = k.replace("/", "__") - npz_dict[safe] = np.array(v) - -np.savez("tracr_majority_params.npz", **npz_dict) -print("Exported => tracr_majority_params.npz") diff --git a/graph.gv b/graph.gv deleted file mode 100644 index 2470339..0000000 --- a/graph.gv +++ /dev/null @@ -1,366 +0,0 @@ -// Computational graph for the feedforward sweep -digraph TracrTransformerPT { - graph [label=<TracrTransformerPT
77 tensors total (36.7 KB)
4656 params total (22.7 KB)
> labeljust=left labelloc=t ordering=out rankdir=BT] - node [ordering=out] - input_1 [label=<input_1
1x6 (176 B)
@input.token_ids> color=black fillcolor="#98FB98" fontcolor=black ordering=out shape=oval style="filled,solid"] - input_1 -> embedding_1_1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - embedding_1_1 [label=<embedding_1_1
1x6x24 (720 B)
params: 4x24
@token_emb> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - embedding_1_1 -> add_1_6 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - arange_1_2 [label=<arange_1_2
x6 (160 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,dashed"] - arange_1_2 -> unsqueeze_1_3 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=dashed] - unsqueeze_1_3 [label=<unsqueeze_1_3
1x6 (176 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,dashed"] - unsqueeze_1_3 -> expand_1_4 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=dashed] - expand_1_4 [label=<expand_1_4
1x6 (176 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,dashed"] - expand_1_4 -> embedding_2_5 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=dashed] - embedding_2_5 [label=<embedding_2_5
1x6x24 (720 B)
params: 11x24
@pos_emb> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,dashed"] - embedding_2_5 -> add_1_6 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=dashed] - add_1_6 [label=<add_1_6
1x6x24 (720 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - add_1_6 -> linear_1_7 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_1_6 -> linear_2_10 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_1_6 -> linear_3_13 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_1_6 -> add_2_25pass1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - linear_1_7 [label=<linear_1_7
1x6x12 (432 B)
params: 12x24, x12
@layers.0.attn.W_q> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - view_1_8 [label=<view_1_8
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_1_9 [label=<transpose_1_9
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - linear_2_10 [label=<linear_2_10
1x6x12 (432 B)
params: 12x24, x12
@layers.0.attn.W_k> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - view_2_11 [label=<view_2_11
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_2_12 [label=<transpose_2_12
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - linear_3_13 [label=<linear_3_13
1x6x12 (432 B)
params: 12x24, x12
@layers.0.attn.W_v> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - view_3_14 [label=<view_3_14
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_3_15 [label=<transpose_3_15
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_4_16 [label=<transpose_4_16
1x3x4x6 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - matmul_1_17 [label=<matmul_1_17
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - truediv_1_18 [label=<truediv_1_18
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - softmax_1_19 [label=<softmax_1_19
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - matmul_2_20 [label=<matmul_2_20
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_5_21 [label=<transpose_5_21
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - contiguous_1_22 [label=<contiguous_1_22
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - view_4_23 [label=<view_4_23
1x6x12 (432 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - linear_4_24 [label=<linear_4_24
1x6x24 (720 B)
params: 24x12, x24
@layers.0.attn.W_o> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - add_2_25pass1 [label=<add_2_25:1
1x6x24 (720 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - linear_5_26 [label=<linear_5_26
1x6x4 (240 B)
params: 4x24, x4
@layers.0.mlp.0> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - gelu_1_27 [label=<gelu_1_27
1x6x4 (240 B)
@layers.0.mlp.1> color=black fillcolor=white fontcolor=black ordering=out shape=box style="filled,solid"] - linear_6_28 [label=<linear_6_28
1x6x24 (720 B)
params: 24x4, x24
@layers.0.mlp.2> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - add_2_25pass2 [label=<add_2_25:2
1x6x24 (720 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - add_2_25pass2 -> linear_7_29 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_2_25pass2 -> linear_8_32 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_2_25pass2 -> linear_9_35 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_2_25pass2 -> add_3_47pass1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - linear_7_29 [label=<linear_7_29
1x6x12 (432 B)
params: 12x24, x12
@layers.1.attn.W_q> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - view_5_30 [label=<view_5_30
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_6_31 [label=<transpose_6_31
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - linear_8_32 [label=<linear_8_32
1x6x12 (432 B)
params: 12x24, x12
@layers.1.attn.W_k> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - view_6_33 [label=<view_6_33
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_7_34 [label=<transpose_7_34
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - linear_9_35 [label=<linear_9_35
1x6x12 (432 B)
params: 12x24, x12
@layers.1.attn.W_v> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - view_7_36 [label=<view_7_36
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_8_37 [label=<transpose_8_37
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_9_38 [label=<transpose_9_38
1x3x4x6 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - matmul_3_39 [label=<matmul_3_39
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - truediv_2_40 [label=<truediv_2_40
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - softmax_2_41 [label=<softmax_2_41
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - matmul_4_42 [label=<matmul_4_42
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_10_43 [label=<transpose_10_43
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - contiguous_2_44 [label=<contiguous_2_44
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - view_8_45 [label=<view_8_45
1x6x12 (432 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - linear_10_46 [label=<linear_10_46
1x6x24 (720 B)
params: 24x12, x24
@layers.1.attn.W_o> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - add_3_47pass1 [label=<add_3_47:1
1x6x24 (720 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - linear_11_48 [label=<linear_11_48
1x6x4 (240 B)
params: 4x24, x4
@layers.1.mlp.0> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - gelu_2_49 [label=<gelu_2_49
1x6x4 (240 B)
@layers.1.mlp.1> color=black fillcolor=white fontcolor=black ordering=out shape=box style="filled,solid"] - linear_12_50 [label=<linear_12_50
1x6x24 (720 B)
params: 24x4, x24
@layers.1.mlp.2> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - add_3_47pass2 [label=<add_3_47:2
1x6x24 (720 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - add_3_47pass2 -> linear_13_51 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_3_47pass2 -> linear_14_54 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_3_47pass2 -> linear_15_57 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_3_47pass2 -> add_4_69pass1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - linear_13_51 [label=<linear_13_51
1x6x12 (432 B)
params: 12x24, x12
@layers.2.attn.W_q> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - view_9_52 [label=<view_9_52
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_11_53 [label=<transpose_11_53
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - linear_14_54 [label=<linear_14_54
1x6x12 (432 B)
params: 12x24, x12
@layers.2.attn.W_k> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - view_10_55 [label=<view_10_55
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_12_56 [label=<transpose_12_56
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - linear_15_57 [label=<linear_15_57
1x6x12 (432 B)
params: 12x24, x12
@layers.2.attn.W_v> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - view_11_58 [label=<view_11_58
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_13_59 [label=<transpose_13_59
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_14_60 [label=<transpose_14_60
1x3x4x6 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - matmul_5_61 [label=<matmul_5_61
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - truediv_3_62 [label=<truediv_3_62
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - softmax_3_63 [label=<softmax_3_63
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - matmul_6_64 [label=<matmul_6_64
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - transpose_15_65 [label=<transpose_15_65
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - contiguous_3_66 [label=<contiguous_3_66
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - view_12_67 [label=<view_12_67
1x6x12 (432 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - linear_16_68 [label=<linear_16_68
1x6x24 (720 B)
params: 24x12, x24
@layers.2.attn.W_o> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - add_4_69pass1 [label=<add_4_69:1
1x6x24 (720 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - linear_17_70 [label=<linear_17_70
1x6x4 (240 B)
params: 4x24, x4
@layers.2.mlp.0> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - gelu_3_71 [label=<gelu_3_71
1x6x4 (240 B)
@layers.2.mlp.1> color=black fillcolor=white fontcolor=black ordering=out shape=box style="filled,solid"] - linear_18_72 [label=<linear_18_72
1x6x24 (720 B)
params: 24x4, x24
@layers.2.mlp.2> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"] - add_4_69pass2 [label=<add_4_69:2
1x6x24 (720 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"] - add_4_69pass2 -> output_1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - output_1 [label=<output_1
1x6x24 (720 B)
@output> color=black fillcolor="#ff9999" fontcolor=black ordering=out shape=oval style="filled,solid"] - { - rank=sink - output_1 - } - subgraph cluster_token_emb_pass1 { - fillcolor=white label=<@token_emb
(Embedding)
> labelloc=b penwidth=5.0 style="filled,dashed" - } - subgraph cluster_pos_emb_pass1 { - fillcolor=white label=<@pos_emb
(Embedding)
> labelloc=b penwidth=5.0 style="filled,dashed" - } - subgraph "cluster_layers.0_pass1" { - fillcolor=white label=<@layers.0
(EncoderBlock)
> labelloc=b penwidth=5.0 style="filled,solid" - linear_4_24 -> add_2_25pass1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_2_25pass1 -> linear_5_26 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_2_25pass1 -> add_2_25pass2 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - linear_6_28 -> add_2_25pass2 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - } - subgraph "cluster_layers.1_pass1" { - fillcolor=white label=<@layers.1
(EncoderBlock)
> labelloc=b penwidth=5.0 style="filled,solid" - linear_10_46 -> add_3_47pass1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_3_47pass1 -> linear_11_48 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_3_47pass1 -> add_3_47pass2 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - linear_12_50 -> add_3_47pass2 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - } - subgraph "cluster_layers.2_pass1" { - fillcolor=white label=<@layers.2
(EncoderBlock)
> labelloc=b penwidth=5.0 style="filled,solid" - linear_16_68 -> add_4_69pass1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_4_69pass1 -> linear_17_70 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - add_4_69pass1 -> add_4_69pass2 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - linear_18_72 -> add_4_69pass2 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - } - subgraph "cluster_layers.0_pass1" { - subgraph "cluster_layers.0.attn_pass1" { - fillcolor=white label=<@layers.0.attn
(MultiheadSelfAttention)
> labelloc=b penwidth=3.5 style="filled,solid" - linear_1_7 -> view_1_8 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - view_1_8 -> transpose_1_9 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_1_9 -> matmul_1_17 [label=<arg 0> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - linear_2_10 -> view_2_11 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - view_2_11 -> transpose_2_12 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_2_12 -> transpose_4_16 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - linear_3_13 -> view_3_14 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - view_3_14 -> transpose_3_15 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_3_15 -> matmul_2_20 [label=<arg 1> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_4_16 -> matmul_1_17 [label=<arg 1> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - matmul_1_17 -> truediv_1_18 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - truediv_1_18 -> softmax_1_19 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - softmax_1_19 -> matmul_2_20 [label=<arg 0> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - matmul_2_20 -> transpose_5_21 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_5_21 -> contiguous_1_22 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - contiguous_1_22 -> view_4_23 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - view_4_23 -> linear_4_24 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - } - } - subgraph "cluster_layers.0_pass1" { - subgraph "cluster_layers.0.mlp_pass1" { - fillcolor=white label=<@layers.0.mlp
(Sequential)
> labelloc=b penwidth=3.5 style="filled,solid" - linear_5_26 -> gelu_1_27 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - gelu_1_27 -> linear_6_28 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - } - } - subgraph "cluster_layers.1_pass1" { - subgraph "cluster_layers.1.attn_pass1" { - fillcolor=white label=<@layers.1.attn
(MultiheadSelfAttention)
> labelloc=b penwidth=3.5 style="filled,solid" - linear_7_29 -> view_5_30 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - view_5_30 -> transpose_6_31 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_6_31 -> matmul_3_39 [label=<arg 0> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - linear_8_32 -> view_6_33 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - view_6_33 -> transpose_7_34 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_7_34 -> transpose_9_38 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - linear_9_35 -> view_7_36 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - view_7_36 -> transpose_8_37 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_8_37 -> matmul_4_42 [label=<arg 1> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_9_38 -> matmul_3_39 [label=<arg 1> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - matmul_3_39 -> truediv_2_40 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - truediv_2_40 -> softmax_2_41 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - softmax_2_41 -> matmul_4_42 [label=<arg 0> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - matmul_4_42 -> transpose_10_43 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_10_43 -> contiguous_2_44 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - contiguous_2_44 -> view_8_45 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - view_8_45 -> linear_10_46 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - } - } - subgraph "cluster_layers.1_pass1" { - subgraph "cluster_layers.1.mlp_pass1" { - fillcolor=white label=<@layers.1.mlp
(Sequential)
> labelloc=b penwidth=3.5 style="filled,solid" - linear_11_48 -> gelu_2_49 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - gelu_2_49 -> linear_12_50 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - } - } - subgraph "cluster_layers.2_pass1" { - subgraph "cluster_layers.2.attn_pass1" { - fillcolor=white label=<@layers.2.attn
(MultiheadSelfAttention)
> labelloc=b penwidth=3.5 style="filled,solid" - linear_13_51 -> view_9_52 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - view_9_52 -> transpose_11_53 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_11_53 -> matmul_5_61 [label=<arg 0> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - linear_14_54 -> view_10_55 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - view_10_55 -> transpose_12_56 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_12_56 -> transpose_14_60 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - linear_15_57 -> view_11_58 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - view_11_58 -> transpose_13_59 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_13_59 -> matmul_6_64 [label=<arg 1> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_14_60 -> matmul_5_61 [label=<arg 1> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - matmul_5_61 -> truediv_3_62 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - truediv_3_62 -> softmax_3_63 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - softmax_3_63 -> matmul_6_64 [label=<arg 0> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - matmul_6_64 -> transpose_15_65 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - transpose_15_65 -> contiguous_3_66 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - contiguous_3_66 -> view_12_67 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - view_12_67 -> linear_16_68 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - } - } - subgraph "cluster_layers.2_pass1" { - subgraph "cluster_layers.2.mlp_pass1" { - fillcolor=white label=<@layers.2.mlp
(Sequential)
> labelloc=b penwidth=3.5 style="filled,solid" - linear_17_70 -> gelu_3_71 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - gelu_3_71 -> linear_18_72 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid] - } - } - subgraph "cluster_layers.0_pass1" { - subgraph "cluster_layers.0.attn_pass1" { - subgraph "cluster_layers.0.attn.W_q_pass1" { - fillcolor=white label=<@layers.0.attn.W_q
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.0_pass1" { - subgraph "cluster_layers.0.attn_pass1" { - subgraph "cluster_layers.0.attn.W_k_pass1" { - fillcolor=white label=<@layers.0.attn.W_k
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.0_pass1" { - subgraph "cluster_layers.0.attn_pass1" { - subgraph "cluster_layers.0.attn.W_v_pass1" { - fillcolor=white label=<@layers.0.attn.W_v
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.0_pass1" { - subgraph "cluster_layers.0.attn_pass1" { - subgraph "cluster_layers.0.attn.W_o_pass1" { - fillcolor=white label=<@layers.0.attn.W_o
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.0_pass1" { - subgraph "cluster_layers.0.mlp_pass1" { - subgraph "cluster_layers.0.mlp.0_pass1" { - fillcolor=white label=<@layers.0.mlp.0
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.0_pass1" { - subgraph "cluster_layers.0.mlp_pass1" { - subgraph "cluster_layers.0.mlp.1_pass1" { - fillcolor=white label=<@layers.0.mlp.1
(GELU)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.0_pass1" { - subgraph "cluster_layers.0.mlp_pass1" { - subgraph "cluster_layers.0.mlp.2_pass1" { - fillcolor=white label=<@layers.0.mlp.2
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.1_pass1" { - subgraph "cluster_layers.1.attn_pass1" { - subgraph "cluster_layers.1.attn.W_q_pass1" { - fillcolor=white label=<@layers.1.attn.W_q
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.1_pass1" { - subgraph "cluster_layers.1.attn_pass1" { - subgraph "cluster_layers.1.attn.W_k_pass1" { - fillcolor=white label=<@layers.1.attn.W_k
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.1_pass1" { - subgraph "cluster_layers.1.attn_pass1" { - subgraph "cluster_layers.1.attn.W_v_pass1" { - fillcolor=white label=<@layers.1.attn.W_v
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.1_pass1" { - subgraph "cluster_layers.1.attn_pass1" { - subgraph "cluster_layers.1.attn.W_o_pass1" { - fillcolor=white label=<@layers.1.attn.W_o
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.1_pass1" { - subgraph "cluster_layers.1.mlp_pass1" { - subgraph "cluster_layers.1.mlp.0_pass1" { - fillcolor=white label=<@layers.1.mlp.0
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.1_pass1" { - subgraph "cluster_layers.1.mlp_pass1" { - subgraph "cluster_layers.1.mlp.1_pass1" { - fillcolor=white label=<@layers.1.mlp.1
(GELU)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.1_pass1" { - subgraph "cluster_layers.1.mlp_pass1" { - subgraph "cluster_layers.1.mlp.2_pass1" { - fillcolor=white label=<@layers.1.mlp.2
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.2_pass1" { - subgraph "cluster_layers.2.attn_pass1" { - subgraph "cluster_layers.2.attn.W_q_pass1" { - fillcolor=white label=<@layers.2.attn.W_q
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.2_pass1" { - subgraph "cluster_layers.2.attn_pass1" { - subgraph "cluster_layers.2.attn.W_k_pass1" { - fillcolor=white label=<@layers.2.attn.W_k
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.2_pass1" { - subgraph "cluster_layers.2.attn_pass1" { - subgraph "cluster_layers.2.attn.W_v_pass1" { - fillcolor=white label=<@layers.2.attn.W_v
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.2_pass1" { - subgraph "cluster_layers.2.attn_pass1" { - subgraph "cluster_layers.2.attn.W_o_pass1" { - fillcolor=white label=<@layers.2.attn.W_o
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.2_pass1" { - subgraph "cluster_layers.2.mlp_pass1" { - subgraph "cluster_layers.2.mlp.0_pass1" { - fillcolor=white label=<@layers.2.mlp.0
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.2_pass1" { - subgraph "cluster_layers.2.mlp_pass1" { - subgraph "cluster_layers.2.mlp.1_pass1" { - fillcolor=white label=<@layers.2.mlp.1
(GELU)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } - subgraph "cluster_layers.2_pass1" { - subgraph "cluster_layers.2.mlp_pass1" { - subgraph "cluster_layers.2.mlp.2_pass1" { - fillcolor=white label=<@layers.2.mlp.2
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed" - } - } - } -} diff --git a/graph.gv.pdf b/graph.gv.pdf deleted file mode 100644 index 308a93f..0000000 Binary files a/graph.gv.pdf and /dev/null differ diff --git a/input_tokens.json b/input_tokens.json new file mode 100644 index 0000000..3674a71 --- /dev/null +++ b/input_tokens.json @@ -0,0 +1 @@ +["BOS", 1, 0, 1, 1, 0] \ No newline at end of file diff --git a/load_and_visualize_with_torchlens.py b/load_and_visualize_with_torchlens.py deleted file mode 100644 index 4b0b404..0000000 --- a/load_and_visualize_with_torchlens.py +++ /dev/null @@ -1,70 +0,0 @@ -import os, sys -sys.path.append(os.path.dirname(__file__)) # add current folder to module search path -from tracr_transformer_pt import TracrTransformerPT -# load_and_visualize_with_torchlens.py -import numpy as np -import torch -import torchlens as tl -from tracr_transformer_pt import TracrTransformerPT - -# Instantiate the mirror model with your inferred hyperparams -model = TracrTransformerPT( - vocab_size=4, - max_seq_len=11, - d_model=24, - n_heads=3, - head_dim=4, # ★ critical — this makes 3*4 = 12 projection dim - n_layers=3, - d_mlp=4 -) - - -# Load NPZ exported from your JAX/Haiku model -npz = np.load("tracr_majority_params.npz") - -# Helper: copy (with optional transpose) into a torch parameter tensor -def copy_(pt_tensor, arr, transpose=False): - t = torch.tensor(arr) - if transpose: - t = t.T - assert tuple(t.shape) == tuple(pt_tensor.shape), f"Shape mismatch: src {t.shape} != dst {pt_tensor.shape}" - with torch.no_grad(): - pt_tensor.copy_(t) - -sd = model.state_dict() - -# ---- Embeddings (same layout as PyTorch, no transpose) ---- -copy_(sd["token_emb.weight"], npz["token_embed__embeddings"]) -copy_(sd["pos_emb.weight"], npz["pos_embed__embeddings"]) - -# ---- Per-layer mappings (notice: JAX Linear 'w' is (in, out); PyTorch is (out, in) -> transpose=True) ---- -for i in range(3): # layers 0..2 - # Attention projections - copy_(sd[f"layers.{i}.attn.W_q.weight"], npz[f"transformer__layer_{i}__attn__query__w"], transpose=True) # (24,12)->(12,24) - copy_(sd[f"layers.{i}.attn.W_q.bias"], npz[f"transformer__layer_{i}__attn__query__b"]) - copy_(sd[f"layers.{i}.attn.W_k.weight"], npz[f"transformer__layer_{i}__attn__key__w"], transpose=True) - copy_(sd[f"layers.{i}.attn.W_k.bias"], npz[f"transformer__layer_{i}__attn__key__b"]) - copy_(sd[f"layers.{i}.attn.W_v.weight"], npz[f"transformer__layer_{i}__attn__value__w"], transpose=True) - copy_(sd[f"layers.{i}.attn.W_v.bias"], npz[f"transformer__layer_{i}__attn__value__b"]) - - # Attention output projection ("linear") - copy_(sd[f"layers.{i}.attn.W_o.weight"], npz[f"transformer__layer_{i}__attn__linear__w"], transpose=True) # (12,24)->(24,12) - copy_(sd[f"layers.{i}.attn.W_o.bias"], npz[f"transformer__layer_{i}__attn__linear__b"]) - - # MLP 24->4->24 - copy_(sd[f"layers.{i}.mlp.0.weight"], npz[f"transformer__layer_{i}__mlp__linear_1__w"], transpose=True) # (24,4)->(4,24) - copy_(sd[f"layers.{i}.mlp.0.bias"], npz[f"transformer__layer_{i}__mlp__linear_1__b"]) - copy_(sd[f"layers.{i}.mlp.2.weight"], npz[f"transformer__layer_{i}__mlp__linear_2__w"], transpose=True) # (4,24)->(24,4) - copy_(sd[f"layers.{i}.mlp.2.bias"], npz[f"transformer__layer_{i}__mlp__linear_2__b"]) - -# Commit weights to the module -model.load_state_dict(sd) - -# ---- TorchLens over the SAME model ---- -# If you have your exact token-indexing, encode it here. For a quick diagram, dummy IDs in [0..3] work. 
-x = torch.randint(low=0, high=4, size=(1, 6)) # (B=1, T=6) e.g., [BOS, 1,0,1,1,0] with the right indices if you prefer -model.eval() - -# 1) Save full forward history AND 2) render the layered graph (unrolled by layer): -log = tl.log_forward_pass(model, x, vis_opt="unrolled") -tl.show_model_graph(model, (x,), vis_opt="unrolled", file_name="torchlens_majority_graph") diff --git a/my_majority_program.py b/my_majority_program.py deleted file mode 100644 index bc06cf4..0000000 --- a/my_majority_program.py +++ /dev/null @@ -1,183 +0,0 @@ -import os, sys -tracr_path = os.path.join(os.path.dirname(__file__), "Tracr", "tracr") -sys.path.insert(0, tracr_path) - -import numpy as np -import jax -import jax.random as random - -from tracr.compiler import compiling -from tracr.rasp import rasp - -# --- The robust function to import show_model --- -# --- The robust function to import show_model (returns None if not found) --- -def _get_show_model(): - import importlib, pkgutil, tracr - for modpath in ("tracr.compiler.visualization", "tracr.visualization"): - try: - mod = importlib.import_module(modpath) - fn = getattr(mod, "show_model", None) - if callable(fn): - return fn - except Exception: - pass - for _, name, _ in pkgutil.walk_packages(tracr.__path__, tracr.__name__ + "."): - try: - mod = importlib.import_module(name) - fn = getattr(mod, "show_model", None) - if callable(fn): - return fn - except Exception: - continue - return None # <- do NOT raise here -# --- Fallback: render a clean Tracr-style diagram from compiled params --- -def render_block_diagram_from_compiled(compiled_model, out_basename="tracr_majority_graph"): - import re - from graphviz import Digraph - # get params from the compiled model - try: - params = compiled_model.params - except AttributeError: - params = compiled_model.weights - - # flatten nested dict into "path" -> ndarray - flat = {} - def walk(path, node): - if isinstance(node, dict): - for k, v in node.items(): - walk(path + (k,), v) - else: - flat["/".join(path)] = np.array(node) - walk((), params) - - # read shapes - tok_key = next(k for k in flat if k.endswith("token_embed/embeddings")) - pos_key = next(k for k in flat if k.endswith("pos_embed/embeddings")) - vocab_size, d_model = flat[tok_key].shape - max_seq_len = flat[pos_key].shape[0] - - layer_nums = sorted({int(m.group(1)) - for k in flat - for m in [re.search(r"transformer/layer_(\d+)/", k)] - if m}) - # attn proj and mlp hidden from layer 0 - proj_dim = flat[f"transformer/layer_{layer_nums[0]}/attn/query/w"].shape[1] - mlp_hidden = flat[f"transformer/layer_{layer_nums[0]}/mlp/linear_1/w"].shape[1] - - dot = Digraph("tracr_majority_transformer", format="pdf") - dot.attr(rankdir="LR", fontsize="12", labelloc="t", - label=f"Tracr-compiled Majority Transformer\n" - f"vocab={vocab_size}, d_model={d_model}, layers={len(layer_nums)}, " - f"proj_dim={proj_dim}, mlp_hidden={mlp_hidden}") - - with dot.subgraph(name="cluster_embed") as c: - c.attr(label="Embeddings") - c.node("tok_emb", f"Token Embedding\n({vocab_size}×{d_model})", shape="box") - c.node("pos_emb", f"Positional Embedding\n({max_seq_len}×{d_model})", shape="box") - c.node("sum", "Add", shape="circle") - c.edges([("tok_emb", "sum"), ("pos_emb", "sum")]) - - prev = "sum" - for i in layer_nums: - with dot.subgraph(name=f"cluster_layer_{i}") as c: - c.attr(label=f"Encoder Block {i}") - c.node(f"attn_{i}", f"MHA proj {proj_dim}", shape="box") - c.node(f"add_attn_{i}", "Add", shape="circle") - c.node(f"mlp_{i}", f"MLP {d_model}→{mlp_hidden}→{d_model}", 
shape="box") - c.node(f"add_mlp_{i}", "Add", shape="circle") - dot.edge(prev, f"attn_{i}") - dot.edge(f"attn_{i}", f"add_attn_{i}") - dot.edge(f"add_attn_{i}", f"mlp_{i}") - dot.edge(f"mlp_{i}", f"add_mlp_{i}") - prev = f"add_mlp_{i}" - - dot.node("out", f"Output\n(seq_len×{d_model})", shape="box") - dot.edge(prev, "out") - out_path = dot.render(out_basename, cleanup=True) - print(f"Saved {out_path}") - - -show_model = _get_show_model() - - - -VOCAB = {0, 1} -MAX_SEQ_LEN = 10 -COMPILER_BOS = "BOS" - -# --- Majority Score Program --- -def majority_score_program(): - """ - RASP program that outputs whether 1s or 0s are the majority. - Positive = majority of 1s, Negative = majority of 0s, 0 = tie. - """ - # Select all positions - all_positions = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE) - - # Count number of 1s - select_ones = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ) - num_ones = rasp.Aggregate(select_ones, rasp.tokens) - - # Count number of 0s (total length - num_ones) - seq_length = rasp.Aggregate(all_positions, rasp.tokens * 0 + 1) # sum of 1s over all positions - num_zeros = seq_length - num_ones - - # Majority score = (#1s - #0s) - majority_score = num_ones - num_zeros - - return majority_score - -# --- Compile --- -print("Compiling majority RASP program to transformer model...") -compiled_model = compiling.compile_rasp_to_model( - majority_score_program(), - vocab=VOCAB, - max_seq_len=MAX_SEQ_LEN, - compiler_bos=COMPILER_BOS, -) -print("Compilation complete!\n") - -# --- Save transformer diagram --- -print("Generating transformer visualization...") -show_model = _get_show_model() -if show_model is not None: - graph = show_model(compiled_model, max_seq_len=MAX_SEQ_LEN, return_graph=True) - graph.render("tracr_majority_graph", format="pdf", cleanup=True) -else: - print("Tracr show_model not found — using fallback renderer.") - render_block_diagram_from_compiled(compiled_model, out_basename="tracr_majority_graph") -print("Diagram saved as tracr_majority_graph.pdf ✅\n") - - -# --- Example --- -example_input_sequence = [1, 0, 1, 1, 0] - -print(f"Raw input sequence (no BOS): {example_input_sequence}") - -# Prepend BOS manually, since Tracr expects it -input_with_bos = [COMPILER_BOS] + example_input_sequence -print(f"Input sequence with BOS: {input_with_bos}") - -# Run model -output_logits = compiled_model.apply(input_with_bos) - -# Interpret output -vocab_list = sorted(list(VOCAB)) + [COMPILER_BOS] -predicted_tokens = [vocab_list[np.argmax(logits)] for logits in output_logits] - -print("\n--- Model Output ---") -print("Raw logits:\n", output_logits) -print("Predicted tokens:", predicted_tokens) - -# --- Run RASP directly --- -rasp_output = majority_score_program()(example_input_sequence) -print("Raw RASP output:", rasp_output) -print("\nRASP execution output:", rasp_output) - -majority_val = rasp_output[0] -if majority_val > 0: - print("Majority element: 1 ✅") -elif majority_val < 0: - print("Majority element: 0 ✅") -else: - print("Tie between 0 and 1 🤝") diff --git a/my_majority_torchlens.py b/my_majority_torchlens.py deleted file mode 100644 index 904ead7..0000000 --- a/my_majority_torchlens.py +++ /dev/null @@ -1,56 +0,0 @@ -import torch -import torch.nn as nn -import torchlens as tl # ★ import as a module; tl.show_model_graph used below - -class MyTransformer(nn.Module): - def __init__(self, input_dim=128, hidden_dim=256, num_layers=2, num_heads=4): - super().__init__() - self.embedding = nn.Linear(input_dim, hidden_dim) - encoder_layer = 
nn.TransformerEncoderLayer(
-            d_model=hidden_dim, nhead=num_heads, batch_first=True
-        )
-        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
-        self.decoder = nn.Linear(hidden_dim, 10)
-
-    def forward(self, x):
-        x = self.embedding(x)
-        x = self.encoder(x)
-        return self.decoder(x)
-
-# Optional wrapper not required for TorchLens; it will trace submodules anyway.
-transformer = MyTransformer()
-
-# Example input (B, T, D_in)
-x = torch.randn(2, 5, 128)
-
-# Option A: one-liner that logs activations and ALSO renders the graph.
-log = tl.log_forward_pass(transformer, x, layers_to_save=None, vis_opt="unrolled")
-
-
-# --- Make a wrapper that exposes internals clearly ---
-class TransformerWrapper(nn.Module):
-    def __init__(self, transformer):
-        super().__init__()
-        self.model = transformer
-
-    def forward(self, x):
-        # Instead of doing everything in one call, we explicitly call submodules
-        x = self.model.embedding(x)
-        # Go through each encoder layer explicitly so TorchLens can track them
-        for i, layer in enumerate(self.model.encoder.layers):
-            x = layer(x)
-        x = self.model.decoder(x)
-        return x
-
-# --- Instantiate and wrap the model ---
-transformer = MyTransformer()
-wrapped_model = TransformerWrapper(transformer)
-
-# --- Dummy input ---
-x = torch.randn(2, 5, 128)
-
-# --- Run TorchLens logging ---
-log = log_forward_pass(wrapped_model, x)
-
-# --- Visualize graph ---
-show_model_graph(log)
diff --git a/scripts/compile_export.py b/scripts/compile_export.py
new file mode 100644
index 0000000..f92caf2
--- /dev/null
+++ b/scripts/compile_export.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3
+# scripts/compile_export.py
+import os, sys, json
+import numpy as np
+from pathlib import Path
+import sys
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+
+# Add local tracr paths if they exist; otherwise rely on pip-installed package
+for p in [
+    REPO_ROOT / "external" / "Tracr" / "tracr",  # optional vendored location
+    REPO_ROOT / "Tracr" / "tracr",               # older local layout
+    REPO_ROOT / "tracr",                         # just in case
+]:
+    if p.is_dir():
+        sys.path.insert(0, str(p))
+        break
+
+# Now import (works with either local path OR pip-installed package)
+from tracr.compiler import compiling
+from tracr.rasp import rasp
+
+
+# -------- Config --------
+VOCAB = {0, 1}
+MAX_SEQ_LEN = 10
+COMPILER_BOS = "BOS"
+COMPILER_PAD = "PAD"
+CAUSAL = True
+EXAMPLE = [1, 0, 1, 1, 0]
+
+ART = REPO_ROOT / "artifacts"
+ART.mkdir(exist_ok=True)
+
+def majority_program():
+    # majority = 2 * (#ones) - seq_len
+    all_pos = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.EQ)
+    ones_vec = rasp.Map(lambda t: t, rasp.tokens)  # 0/1 numeric
+    num_ones = rasp.Aggregate(all_pos, ones_vec)
+    ones_const = rasp.Map(lambda _: 1, rasp.tokens)
+    seq_len = rasp.Aggregate(all_pos, ones_const)
+    return (2 * num_ones) - seq_len
+
+def compile_tracr(prog):
+    kw = dict(vocab=VOCAB, max_seq_len=MAX_SEQ_LEN, compiler_bos=COMPILER_BOS)
+    try:
+        return compiling.compile_rasp_to_model(prog, compiler_pad=COMPILER_PAD, causal=CAUSAL, **kw)
+    except TypeError:
+        pass
+    try:
+        return compiling.compile_rasp_to_model(prog, causal=CAUSAL, **kw)
+    except TypeError:
+        return compiling.compile_rasp_to_model(prog, **kw)
+
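+# NOTE: the try/except TypeError ladder above tolerates tracr versions whose
+# compile_rasp_to_model does not accept compiler_pad and/or causal keywords.
+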
None), "vocabulary", None), + ] + for obj in cands: + if obj is None: continue + if isinstance(obj, dict) and obj: return obj + if hasattr(obj, "token_to_id") and isinstance(obj.token_to_id, dict): + return obj.token_to_id + if hasattr(obj, "id_to_token") and isinstance(obj.id_to_token, dict): + return {tok: int(i) for i, tok in obj.id_to_token.items()} + # fallback mapping (deterministic) + ordered = [COMPILER_BOS] + sorted(list(VOCAB), key=lambda x: repr(x)) + [COMPILER_PAD] + mapping = {tok: i for i, tok in enumerate(ordered)} + print("[WARN] Using fallback token mapping:", mapping) + return mapping + +def export_params_npz(compiled, out_path: Path, keys_path: Path): + try: + params = compiled.params + except AttributeError: + params = compiled.weights + + leaves = [] + def walk(path, node): + if isinstance(node, dict): + for k, v in node.items(): walk(path + (k,), v) + elif isinstance(node, (list, tuple)): + for i, v in enumerate(node): walk(path + (str(i),), v) + else: + leaves.append(("/".join(path), np.array(node))) + + walk((), params) + npz_dict = {k.replace("/", "__"): v for k, v in leaves} + np.savez(out_path, **npz_dict) + keys_path.write_text(json.dumps(sorted(npz_dict.keys()), indent=2)) + print(f"Exported params -> {out_path} (keys -> {keys_path})") + +def main(): + print("Compiling RASP → Tracr transformer…") + compiled = compile_tracr(majority_program()) + print("Done.\n") + + # tokenizer mapping: only write fallback if no existing discovered mapping + tok_map_path = ART / "token_to_id.json" + tok2id = get_tok2id_or_fallback(compiled) + if not tok_map_path.exists(): + tok_map_path.write_text(json.dumps(tok2id)) + print("Saved token_to_id.json") + else: + print("token_to_id.json exists; keeping discovered mapping.") + + # exact tokens used for the reference pass + tokens = [COMPILER_BOS] + EXAMPLE + (ART / "input_tokens.json").write_text(json.dumps(tokens)) + print("Saved input_tokens.json") + + # forward pass on assembled model + out = compiled.apply(tokens) + arr = np.array(getattr(out, "transformer_output", out), dtype=np.float32) + if arr.ndim == 2: arr = arr[None, ...] 
+    np.save(ART / "tracr_output.npy", arr)
+    print(f"Saved tracr_output.npy with shape {arr.shape} (dtype={arr.dtype})")
+
+    # export params from THIS compiled model
+    export_params_npz(compiled, ART / "tracr_majority_params.npz", ART / "tracr_param_keys.json")
+
+    print("\nNow run:")
+    print(" python scripts/parity_check.py")
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/parity_check.py b/scripts/parity_check.py
new file mode 100644
index 0000000..f6383e1
--- /dev/null
+++ b/scripts/parity_check.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+# scripts/parity_check.py
+import json, itertools, numpy as np, torch
+from pathlib import Path
+import sys
+
+ROOT = Path(__file__).resolve().parents[1]
+ART = ROOT / "artifacts"
+sys.path.append(str(ROOT / "src"))
+
+from tracr_pt_model import TracrTransformerPT
+
+# ---- Load NPZ & infer dims ----
+npz = np.load(ART / "tracr_majority_params.npz")
+def get(k): return npz.get(k, npz[k.replace("/", "__")])
+
+vocab_size, d_model = get("token_embed/embeddings").shape
+max_len = get("pos_embed/embeddings").shape[0]
+proj_dim = int(get("transformer/layer_0/attn/query/w").shape[1])  # JAX (in,out)
+d_mlp = int(get("transformer/layer_0/mlp/linear_1/w").shape[1])
+n_layers = sum((f"transformer/layer_{i}/attn/query/w".replace("/", "__") in npz) for i in range(64))
+
+# Your matched config
+n_heads, head_dim = 2, proj_dim // 2
+
+print(f"Inferred -> d_model={d_model}, vocab={vocab_size}, max_seq_len={max_len}, "
+      f"layers={n_layers}, proj_dim={proj_dim}, d_mlp={d_mlp}, heads={n_heads}, head_dim={head_dim}")
+
+# ---- Build PT model & load weights ----
+model = TracrTransformerPT(vocab_size, max_len, int(d_model), int(n_layers), int(d_mlp),
+                           n_heads=int(n_heads), head_dim=int(head_dim)).eval()
+
+def load_linear(L, w_key, b_key):
+    w = torch.from_numpy(get(w_key)).float().t().contiguous()  # (in,out) -> (out,in)
+    b = torch.from_numpy(get(b_key)).float()
+    with torch.no_grad(): L.weight.copy_(w); L.bias.copy_(b)
+
+with torch.no_grad():
+    model.token_emb.weight.copy_(torch.from_numpy(get("token_embed/embeddings")).float())
+    model.pos_emb.weight.copy_(torch.from_numpy(get("pos_embed/embeddings")).float())
+
+for i in range(n_layers):
+    P = f"transformer/layer_{i}"
+    blk = model.layers[i]
+    load_linear(blk.attn.W_q, f"{P}/attn/query/w", f"{P}/attn/query/b")
+    load_linear(blk.attn.W_k, f"{P}/attn/key/w", f"{P}/attn/key/b")
+    load_linear(blk.attn.W_v, f"{P}/attn/value/w", f"{P}/attn/value/b")
+    load_linear(blk.attn.W_o, f"{P}/attn/linear/w", f"{P}/attn/linear/b")
+    load_linear(blk.mlp.fc1, f"{P}/mlp/linear_1/w", f"{P}/mlp/linear_1/b")
+    load_linear(blk.mlp.fc2, f"{P}/mlp/linear_2/w", f"{P}/mlp/linear_2/b")
+
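+# NOTE: splitting the fused (d_model -> n_heads*head_dim) projections into
+# n_heads=2 heads of size proj_dim//2 assumes this matches the head layout the
+# Tracr model was compiled with; the per-head softmax depends on that split,
+# and a mismatch will surface as a parity failure below.
+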
---") +print(f"tok2id={tok2id}, max_diff={md:.6g}, match={ok}") +print("\n--- Sanity Check ---") +print(f"Outputs match: {ok}") +print(f"Max abs diff: {md:.6g}") + +sys.exit(0 if ok else 1) \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/__pycache__/tracr_pt_model.cpython-313.pyc b/src/__pycache__/tracr_pt_model.cpython-313.pyc new file mode 100644 index 0000000..d15c121 Binary files /dev/null and b/src/__pycache__/tracr_pt_model.cpython-313.pyc differ diff --git a/src/tracr_pt_model.py b/src/tracr_pt_model.py new file mode 100644 index 0000000..03870d3 --- /dev/null +++ b/src/tracr_pt_model.py @@ -0,0 +1,66 @@ +# src/tracr_pt_model.py +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +class MultiheadSelfAttention(nn.Module): + def __init__(self, d_model, n_heads, head_dim, causal=True): + super().__init__() + self.n_heads = n_heads + self.head_dim = head_dim + self.causal = causal + proj = n_heads * head_dim + self.W_q = nn.Linear(d_model, proj, bias=True) + self.W_k = nn.Linear(d_model, proj, bias=True) + self.W_v = nn.Linear(d_model, proj, bias=True) + self.W_o = nn.Linear(proj, d_model, bias=True) + + def forward(self, x): + B, T, _ = x.shape + def split(L): + y = L(x).view(B, T, self.n_heads, self.head_dim) + return y.permute(0, 2, 1, 3) # (B, H, T, D) + q, k, v = split(self.W_q), split(self.W_k), split(self.W_v) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) + if self.causal: + mask = torch.triu(torch.ones(T, T, device=x.device), 1).bool() + scores = scores.masked_fill(mask, float("-inf")) + attn = torch.softmax(scores, dim=-1) + ctx = torch.matmul(attn, v).permute(0, 2, 1, 3).contiguous().view(B, T, self.n_heads * self.head_dim) + return self.W_o(ctx) + +class MLP(nn.Module): + def __init__(self, d_model, d_mlp): + super().__init__() + self.fc1 = nn.Linear(d_model, d_mlp, bias=True) + self.fc2 = nn.Linear(d_mlp, d_model, bias=True) + def forward(self, x): + return self.fc2(F.relu(self.fc1(x))) + +class EncoderBlock(nn.Module): + def __init__(self, d_model, n_heads, head_dim, d_mlp): + super().__init__() + self.attn = MultiheadSelfAttention(d_model, n_heads, head_dim, causal=True) + self.mlp = MLP(d_model, d_mlp) + def forward(self, x): + x = x + self.attn(x) # Attn → MLP (sequential residuals) + x = x + self.mlp(x) + return x + +class TracrTransformerPT(nn.Module): + def __init__(self, vocab_size, max_seq_len, d_model, n_layers, d_mlp, n_heads=2, head_dim=12): + super().__init__() + self.token_emb = nn.Embedding(vocab_size, d_model) + self.pos_emb = nn.Embedding(max_seq_len, d_model) + self.layers = nn.ModuleList([ + EncoderBlock(d_model, n_heads, head_dim, d_mlp) for _ in range(n_layers) + ]) + + def forward(self, token_ids): + B, T = token_ids.shape + pos = torch.arange(T, device=token_ids.device) # positions start at 0 + x = self.token_emb(token_ids) + self.pos_emb(pos)[None, :, :] + for blk in self.layers: + x = blk(x) + return x diff --git a/token_to_id.json b/token_to_id.json new file mode 100644 index 0000000..6966287 --- /dev/null +++ b/token_to_id.json @@ -0,0 +1,6 @@ +{ + "BOS": 2, + "0": 0, + "1": 1, + "PAD": 3 +} \ No newline at end of file diff --git a/tracr_majority_graph.pdf b/tracr_majority_graph.pdf deleted file mode 100644 index 88f2c1a..0000000 Binary files a/tracr_majority_graph.pdf and /dev/null differ diff --git a/tracr_majority_params.npz b/tracr_majority_params.npz index 437d52d..ce7aa3a 100644 Binary files 
a/tracr_majority_params.npz and b/tracr_majority_params.npz differ diff --git a/tracr_output.npy b/tracr_output.npy new file mode 100644 index 0000000..a802065 Binary files /dev/null and b/tracr_output.npy differ diff --git a/tracr_param_keys.json b/tracr_param_keys.json new file mode 100644 index 0000000..460b71c --- /dev/null +++ b/tracr_param_keys.json @@ -0,0 +1,40 @@ +[ + "pos_embed__embeddings", + "token_embed__embeddings", + "transformer__layer_0__attn__key__b", + "transformer__layer_0__attn__key__w", + "transformer__layer_0__attn__linear__b", + "transformer__layer_0__attn__linear__w", + "transformer__layer_0__attn__query__b", + "transformer__layer_0__attn__query__w", + "transformer__layer_0__attn__value__b", + "transformer__layer_0__attn__value__w", + "transformer__layer_0__mlp__linear_1__b", + "transformer__layer_0__mlp__linear_1__w", + "transformer__layer_0__mlp__linear_2__b", + "transformer__layer_0__mlp__linear_2__w", + "transformer__layer_1__attn__key__b", + "transformer__layer_1__attn__key__w", + "transformer__layer_1__attn__linear__b", + "transformer__layer_1__attn__linear__w", + "transformer__layer_1__attn__query__b", + "transformer__layer_1__attn__query__w", + "transformer__layer_1__attn__value__b", + "transformer__layer_1__attn__value__w", + "transformer__layer_1__mlp__linear_1__b", + "transformer__layer_1__mlp__linear_1__w", + "transformer__layer_1__mlp__linear_2__b", + "transformer__layer_1__mlp__linear_2__w", + "transformer__layer_2__attn__key__b", + "transformer__layer_2__attn__key__w", + "transformer__layer_2__attn__linear__b", + "transformer__layer_2__attn__linear__w", + "transformer__layer_2__attn__query__b", + "transformer__layer_2__attn__query__w", + "transformer__layer_2__attn__value__b", + "transformer__layer_2__attn__value__w", + "transformer__layer_2__mlp__linear_1__b", + "transformer__layer_2__mlp__linear_1__w", + "transformer__layer_2__mlp__linear_2__b", + "transformer__layer_2__mlp__linear_2__w" +] \ No newline at end of file diff --git a/tracr_transformer_pt.py b/tracr_transformer_pt.py deleted file mode 100644 index 2a9a656..0000000 --- a/tracr_transformer_pt.py +++ /dev/null @@ -1,60 +0,0 @@ -# tracr_transformer_pt.py (replace your attention + block signatures with this) - -import math -import torch -import torch.nn as nn - -class MultiheadSelfAttention(nn.Module): - def __init__(self, d_model=24, n_heads=3, head_dim=4, bias=True): - super().__init__() - self.d_model = d_model - self.n_heads = n_heads - self.head_dim = head_dim - self.proj_dim = n_heads * head_dim # = 12 (matches your JAX dump) - - self.W_q = nn.Linear(d_model, n_heads * head_dim, bias=bias) # 24 -> 12 - self.W_k = nn.Linear(d_model, n_heads * head_dim, bias=bias) # 24 -> 12 - self.W_v = nn.Linear(d_model, n_heads * head_dim, bias=bias) # 24 -> 12 - self.W_o = nn.Linear(n_heads * head_dim, d_model, bias=bias) # 12 -> 24 - - - def forward(self, x): - B, T, C = x.shape - q = self.W_q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B,H,T,4) - k = self.W_k(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) - v = self.W_v(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) - - att = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # (B,H,T,T) - att = torch.softmax(att, dim=-1) - y = torch.matmul(att, v) # (B,H,T,4) - y = y.transpose(1, 2).contiguous().view(B, T, self.proj_dim) # (B,T,12) - return self.W_o(y) # (B,T,24) - -class EncoderBlock(nn.Module): - def __init__(self, d_model=24, n_heads=3, head_dim=4, d_mlp=4): - super().__init__() - 
self.attn = MultiheadSelfAttention(d_model=d_model, n_heads=n_heads, head_dim=head_dim, bias=True) - self.mlp = nn.Sequential( - nn.Linear(d_model, d_mlp, bias=True), # 24 -> 4 - nn.GELU(), - nn.Linear(d_mlp, d_model, bias=True), # 4 -> 24 - ) - - def forward(self, x): - x = x + self.attn(x) - x = x + self.mlp(x) - return x - -class TracrTransformerPT(nn.Module): - def __init__(self, vocab_size=4, max_seq_len=11, d_model=24, n_heads=3, head_dim=4, n_layers=3, d_mlp=4): - super().__init__() - self.token_emb = nn.Embedding(vocab_size, d_model) - self.pos_emb = nn.Embedding(max_seq_len, d_model) - self.layers = nn.ModuleList([EncoderBlock(d_model, n_heads, head_dim, d_mlp) for _ in range(n_layers)]) - - def forward(self, token_ids): - B, T = token_ids.shape - x = self.token_emb(token_ids) + self.pos_emb(torch.arange(T, device=token_ids.device).unsqueeze(0).expand(B, T)) - for blk in self.layers: - x = blk(x) - return x # (B, T, 24)
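
For reference, the weight-copy convention that `load_linear` in scripts/parity_check.py (and `copy_` in the deleted load_and_visualize_with_torchlens.py) relies on can be checked in isolation: Haiku's `hk.Linear` stores its weight as (in, out) and computes y = x @ w + b, whereas `torch.nn.Linear` stores (out, in) and computes y = x @ W.T + b, so copying the transpose reproduces identical outputs. Below is a minimal sketch of that check; the file name parity_convention_demo.py is hypothetical and the snippet is not part of this diff or the CI workflow.

# parity_convention_demo.py (hypothetical helper, not referenced by the workflow)
import numpy as np
import torch

rng = np.random.default_rng(0)
d_in, d_out, T = 24, 12, 6

w = rng.normal(size=(d_in, d_out)).astype(np.float32)  # Haiku layout: (in, out)
b = rng.normal(size=(d_out,)).astype(np.float32)
x = rng.normal(size=(1, T, d_in)).astype(np.float32)

y_haiku_style = x @ w + b  # what hk.Linear computes with this weight layout

lin = torch.nn.Linear(d_in, d_out, bias=True)
with torch.no_grad():
    lin.weight.copy_(torch.from_numpy(w).t())  # (in, out) -> (out, in), as in load_linear
    lin.bias.copy_(torch.from_numpy(b))
y_torch = lin(torch.from_numpy(x)).detach().numpy()

print("max |diff| =", np.abs(y_haiku_style - y_torch).max())
assert np.allclose(y_haiku_style, y_torch, atol=1e-5)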