diff --git a/.github/workflows/parity.yml b/.github/workflows/parity.yml
new file mode 100644
index 0000000..238edfe
--- /dev/null
+++ b/.github/workflows/parity.yml
@@ -0,0 +1,27 @@
+name: Parity check
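+# Recompiles the Tracr majority model from scratch on CPU, then checks that the
+# PyTorch port in src/ reproduces its reference output (within atol=1e-5).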
+
+on:
+ push:
+ pull_request:
+
+jobs:
+ parity:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.11"
+
+ - name: Install deps (CPU only)
+ run: |
+ python -m pip install --upgrade pip
+ pip install --index-url https://download.pytorch.org/whl/cpu torch
+ pip install numpy "jax[cpu]" dm-haiku
+ pip install git+https://github.com/google-deepmind/tracr.git
+
+ - name: Compile & export tracr
+ run: python scripts/compile_export.py
+
+ - name: Verify parity (fail if mismatch)
+ run: python scripts/parity_check.py
diff --git a/__pycache__/tracr_transformer_pt.cpython-313.pyc b/__pycache__/tracr_transformer_pt.cpython-313.pyc
index 1f6d124..5a5b73c 100644
Binary files a/__pycache__/tracr_transformer_pt.cpython-313.pyc and b/__pycache__/tracr_transformer_pt.cpython-313.pyc differ
diff --git a/artifacts/input_tokens.json b/artifacts/input_tokens.json
new file mode 100644
index 0000000..3674a71
--- /dev/null
+++ b/artifacts/input_tokens.json
@@ -0,0 +1 @@
+["BOS", 1, 0, 1, 1, 0]
\ No newline at end of file
diff --git a/artifacts/token_to_id.json b/artifacts/token_to_id.json
new file mode 100644
index 0000000..6966287
--- /dev/null
+++ b/artifacts/token_to_id.json
@@ -0,0 +1,6 @@
+{
+ "BOS": 2,
+ "0": 0,
+ "1": 1,
+ "PAD": 3
+}
\ No newline at end of file
diff --git a/artifacts/tracr_majority_params.npz b/artifacts/tracr_majority_params.npz
new file mode 100644
index 0000000..ce7aa3a
Binary files /dev/null and b/artifacts/tracr_majority_params.npz differ
diff --git a/artifacts/tracr_output.npy b/artifacts/tracr_output.npy
new file mode 100644
index 0000000..a802065
Binary files /dev/null and b/artifacts/tracr_output.npy differ
diff --git a/artifacts/tracr_param_keys.json b/artifacts/tracr_param_keys.json
new file mode 100644
index 0000000..460b71c
--- /dev/null
+++ b/artifacts/tracr_param_keys.json
@@ -0,0 +1,40 @@
+[
+ "pos_embed__embeddings",
+ "token_embed__embeddings",
+ "transformer__layer_0__attn__key__b",
+ "transformer__layer_0__attn__key__w",
+ "transformer__layer_0__attn__linear__b",
+ "transformer__layer_0__attn__linear__w",
+ "transformer__layer_0__attn__query__b",
+ "transformer__layer_0__attn__query__w",
+ "transformer__layer_0__attn__value__b",
+ "transformer__layer_0__attn__value__w",
+ "transformer__layer_0__mlp__linear_1__b",
+ "transformer__layer_0__mlp__linear_1__w",
+ "transformer__layer_0__mlp__linear_2__b",
+ "transformer__layer_0__mlp__linear_2__w",
+ "transformer__layer_1__attn__key__b",
+ "transformer__layer_1__attn__key__w",
+ "transformer__layer_1__attn__linear__b",
+ "transformer__layer_1__attn__linear__w",
+ "transformer__layer_1__attn__query__b",
+ "transformer__layer_1__attn__query__w",
+ "transformer__layer_1__attn__value__b",
+ "transformer__layer_1__attn__value__w",
+ "transformer__layer_1__mlp__linear_1__b",
+ "transformer__layer_1__mlp__linear_1__w",
+ "transformer__layer_1__mlp__linear_2__b",
+ "transformer__layer_1__mlp__linear_2__w",
+ "transformer__layer_2__attn__key__b",
+ "transformer__layer_2__attn__key__w",
+ "transformer__layer_2__attn__linear__b",
+ "transformer__layer_2__attn__linear__w",
+ "transformer__layer_2__attn__query__b",
+ "transformer__layer_2__attn__query__w",
+ "transformer__layer_2__attn__value__b",
+ "transformer__layer_2__attn__value__w",
+ "transformer__layer_2__mlp__linear_1__b",
+ "transformer__layer_2__mlp__linear_1__w",
+ "transformer__layer_2__mlp__linear_2__b",
+ "transformer__layer_2__mlp__linear_2__w"
+]
\ No newline at end of file
diff --git a/export_tracr_params.py b/export_tracr_params.py
deleted file mode 100644
index 1b67dbc..0000000
--- a/export_tracr_params.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# export_tracr_params.py
-import numpy as np
-import jax
-from tracr.compiler import compiling
-from tracr.rasp import rasp
-
-VOCAB = {0, 1}
-MAX_SEQ_LEN = 10
-COMPILER_BOS = "BOS"
-
-def majority_score_program():
- all_pos = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE)
- select_ones = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
- num_ones = rasp.Aggregate(select_ones, rasp.tokens)
- seq_len = rasp.Aggregate(all_pos, rasp.tokens * 0 + 1)
- majority_score = num_ones - (seq_len - num_ones)
- return majority_score
-
-print("Compiling RASP → JAX/Haiku transformer…")
-compiled = compiling.compile_rasp_to_model(
- majority_score_program(),
- vocab=VOCAB,
- max_seq_len=MAX_SEQ_LEN,
- compiler_bos=COMPILER_BOS,
-)
-print("Done.")
-
-# --- Inspect param tree to learn names you must map ---
-# Depending on Tracr version, compiled.params or compiled.weights exists.
-# Print keys so we can map them into PyTorch:
-try:
- params = compiled.params
-except AttributeError:
- params = compiled.weights # fallback if older API
-
-flat, treedef = jax.tree_util.tree_flatten(params)
-leaves_with_paths = []
-
-def track_paths(path, node):
- if isinstance(node, (dict,)):
- for k,v in node.items():
- track_paths(path + (k,), v)
- else:
- leaves_with_paths.append(("/".join(path), node))
-
-track_paths((), params)
-
-print("\n=== JAX PARAM KEYS (preview) ===")
-for k, v in leaves_with_paths:
- print(f"{k}: shape={np.array(v).shape}")
-print("=== end ===\n")
-
-# --- Save to NPZ with slash->double_underscore to be filesystem-friendly ---
-npz_dict = {}
-for k, v in leaves_with_paths:
- safe = k.replace("/", "__")
- npz_dict[safe] = np.array(v)
-
-np.savez("tracr_majority_params.npz", **npz_dict)
-print("Exported => tracr_majority_params.npz")
diff --git a/graph.gv b/graph.gv
deleted file mode 100644
index 2470339..0000000
--- a/graph.gv
+++ /dev/null
@@ -1,366 +0,0 @@
-// Computational graph for the feedforward sweep
-digraph TracrTransformerPT {
- graph [label=<TracrTransformerPT
77 tensors total (36.7 KB)
4656 params total (22.7 KB)
> labeljust=left labelloc=t ordering=out rankdir=BT]
- node [ordering=out]
- input_1 [label=<input_1
1x6 (176 B)
@input.token_ids> color=black fillcolor="#98FB98" fontcolor=black ordering=out shape=oval style="filled,solid"]
- input_1 -> embedding_1_1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- embedding_1_1 [label=<embedding_1_1
1x6x24 (720 B)
params: 4x24
@token_emb> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- embedding_1_1 -> add_1_6 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- arange_1_2 [label=<arange_1_2
x6 (160 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,dashed"]
- arange_1_2 -> unsqueeze_1_3 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=dashed]
- unsqueeze_1_3 [label=<unsqueeze_1_3
1x6 (176 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,dashed"]
- unsqueeze_1_3 -> expand_1_4 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=dashed]
- expand_1_4 [label=<expand_1_4
1x6 (176 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,dashed"]
- expand_1_4 -> embedding_2_5 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=dashed]
- embedding_2_5 [label=<embedding_2_5
1x6x24 (720 B)
params: 11x24
@pos_emb> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,dashed"]
- embedding_2_5 -> add_1_6 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=dashed]
- add_1_6 [label=<add_1_6
1x6x24 (720 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- add_1_6 -> linear_1_7 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_1_6 -> linear_2_10 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_1_6 -> linear_3_13 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_1_6 -> add_2_25pass1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- linear_1_7 [label=<linear_1_7
1x6x12 (432 B)
params: 12x24, x12
@layers.0.attn.W_q> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- view_1_8 [label=<view_1_8
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_1_9 [label=<transpose_1_9
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- linear_2_10 [label=<linear_2_10
1x6x12 (432 B)
params: 12x24, x12
@layers.0.attn.W_k> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- view_2_11 [label=<view_2_11
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_2_12 [label=<transpose_2_12
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- linear_3_13 [label=<linear_3_13
1x6x12 (432 B)
params: 12x24, x12
@layers.0.attn.W_v> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- view_3_14 [label=<view_3_14
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_3_15 [label=<transpose_3_15
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_4_16 [label=<transpose_4_16
1x3x4x6 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- matmul_1_17 [label=<matmul_1_17
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- truediv_1_18 [label=<truediv_1_18
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- softmax_1_19 [label=<softmax_1_19
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- matmul_2_20 [label=<matmul_2_20
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_5_21 [label=<transpose_5_21
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- contiguous_1_22 [label=<contiguous_1_22
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- view_4_23 [label=<view_4_23
1x6x12 (432 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- linear_4_24 [label=<linear_4_24
1x6x24 (720 B)
params: 24x12, x24
@layers.0.attn.W_o> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- add_2_25pass1 [label=<add_2_25:1
1x6x24 (720 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- linear_5_26 [label=<linear_5_26
1x6x4 (240 B)
params: 4x24, x4
@layers.0.mlp.0> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- gelu_1_27 [label=<gelu_1_27
1x6x4 (240 B)
@layers.0.mlp.1> color=black fillcolor=white fontcolor=black ordering=out shape=box style="filled,solid"]
- linear_6_28 [label=<linear_6_28
1x6x24 (720 B)
params: 24x4, x24
@layers.0.mlp.2> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- add_2_25pass2 [label=<add_2_25:2
1x6x24 (720 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- add_2_25pass2 -> linear_7_29 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_2_25pass2 -> linear_8_32 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_2_25pass2 -> linear_9_35 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_2_25pass2 -> add_3_47pass1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- linear_7_29 [label=<linear_7_29
1x6x12 (432 B)
params: 12x24, x12
@layers.1.attn.W_q> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- view_5_30 [label=<view_5_30
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_6_31 [label=<transpose_6_31
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- linear_8_32 [label=<linear_8_32
1x6x12 (432 B)
params: 12x24, x12
@layers.1.attn.W_k> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- view_6_33 [label=<view_6_33
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_7_34 [label=<transpose_7_34
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- linear_9_35 [label=<linear_9_35
1x6x12 (432 B)
params: 12x24, x12
@layers.1.attn.W_v> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- view_7_36 [label=<view_7_36
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_8_37 [label=<transpose_8_37
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_9_38 [label=<transpose_9_38
1x3x4x6 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- matmul_3_39 [label=<matmul_3_39
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- truediv_2_40 [label=<truediv_2_40
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- softmax_2_41 [label=<softmax_2_41
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- matmul_4_42 [label=<matmul_4_42
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_10_43 [label=<transpose_10_43
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- contiguous_2_44 [label=<contiguous_2_44
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- view_8_45 [label=<view_8_45
1x6x12 (432 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- linear_10_46 [label=<linear_10_46
1x6x24 (720 B)
params: 24x12, x24
@layers.1.attn.W_o> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- add_3_47pass1 [label=<add_3_47:1
1x6x24 (720 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- linear_11_48 [label=<linear_11_48
1x6x4 (240 B)
params: 4x24, x4
@layers.1.mlp.0> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- gelu_2_49 [label=<gelu_2_49
1x6x4 (240 B)
@layers.1.mlp.1> color=black fillcolor=white fontcolor=black ordering=out shape=box style="filled,solid"]
- linear_12_50 [label=<linear_12_50
1x6x24 (720 B)
params: 24x4, x24
@layers.1.mlp.2> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- add_3_47pass2 [label=<add_3_47:2
1x6x24 (720 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- add_3_47pass2 -> linear_13_51 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_3_47pass2 -> linear_14_54 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_3_47pass2 -> linear_15_57 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_3_47pass2 -> add_4_69pass1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- linear_13_51 [label=<linear_13_51
1x6x12 (432 B)
params: 12x24, x12
@layers.2.attn.W_q> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- view_9_52 [label=<view_9_52
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_11_53 [label=<transpose_11_53
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- linear_14_54 [label=<linear_14_54
1x6x12 (432 B)
params: 12x24, x12
@layers.2.attn.W_k> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- view_10_55 [label=<view_10_55
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_12_56 [label=<transpose_12_56
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- linear_15_57 [label=<linear_15_57
1x6x12 (432 B)
params: 12x24, x12
@layers.2.attn.W_v> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- view_11_58 [label=<view_11_58
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_13_59 [label=<transpose_13_59
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_14_60 [label=<transpose_14_60
1x3x4x6 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- matmul_5_61 [label=<matmul_5_61
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- truediv_3_62 [label=<truediv_3_62
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- softmax_3_63 [label=<softmax_3_63
1x3x6x6 (592 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- matmul_6_64 [label=<matmul_6_64
1x3x6x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- transpose_15_65 [label=<transpose_15_65
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- contiguous_3_66 [label=<contiguous_3_66
1x6x3x4 (448 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- view_12_67 [label=<view_12_67
1x6x12 (432 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- linear_16_68 [label=<linear_16_68
1x6x24 (720 B)
params: 24x12, x24
@layers.2.attn.W_o> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- add_4_69pass1 [label=<add_4_69:1
1x6x24 (720 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- linear_17_70 [label=<linear_17_70
1x6x4 (240 B)
params: 4x24, x4
@layers.2.mlp.0> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- gelu_3_71 [label=<gelu_3_71
1x6x4 (240 B)
@layers.2.mlp.1> color=black fillcolor=white fontcolor=black ordering=out shape=box style="filled,solid"]
- linear_18_72 [label=<linear_18_72
1x6x24 (720 B)
params: 24x4, x24
@layers.2.mlp.2> color=black fillcolor="#E6E6E6" fontcolor=black ordering=out shape=box style="filled,solid"]
- add_4_69pass2 [label=<add_4_69:2
1x6x24 (720 B)> color=black fillcolor=white fontcolor=black ordering=out shape=oval style="filled,solid"]
- add_4_69pass2 -> output_1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- output_1 [label=<output_1
1x6x24 (720 B)
@output> color=black fillcolor="#ff9999" fontcolor=black ordering=out shape=oval style="filled,solid"]
- {
- rank=sink
- output_1
- }
- subgraph cluster_token_emb_pass1 {
- fillcolor=white label=<@token_emb
(Embedding)
> labelloc=b penwidth=5.0 style="filled,dashed"
- }
- subgraph cluster_pos_emb_pass1 {
- fillcolor=white label=<@pos_emb
(Embedding)
> labelloc=b penwidth=5.0 style="filled,dashed"
- }
- subgraph "cluster_layers.0_pass1" {
- fillcolor=white label=<@layers.0
(EncoderBlock)
> labelloc=b penwidth=5.0 style="filled,solid"
- linear_4_24 -> add_2_25pass1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_2_25pass1 -> linear_5_26 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_2_25pass1 -> add_2_25pass2 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- linear_6_28 -> add_2_25pass2 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- }
- subgraph "cluster_layers.1_pass1" {
- fillcolor=white label=<@layers.1
(EncoderBlock)
> labelloc=b penwidth=5.0 style="filled,solid"
- linear_10_46 -> add_3_47pass1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_3_47pass1 -> linear_11_48 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_3_47pass1 -> add_3_47pass2 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- linear_12_50 -> add_3_47pass2 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- }
- subgraph "cluster_layers.2_pass1" {
- fillcolor=white label=<@layers.2
(EncoderBlock)
> labelloc=b penwidth=5.0 style="filled,solid"
- linear_16_68 -> add_4_69pass1 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_4_69pass1 -> linear_17_70 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- add_4_69pass1 -> add_4_69pass2 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- linear_18_72 -> add_4_69pass2 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- }
- subgraph "cluster_layers.0_pass1" {
- subgraph "cluster_layers.0.attn_pass1" {
- fillcolor=white label=<@layers.0.attn
(MultiheadSelfAttention)
> labelloc=b penwidth=3.5 style="filled,solid"
- linear_1_7 -> view_1_8 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- view_1_8 -> transpose_1_9 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_1_9 -> matmul_1_17 [label=<arg 0> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- linear_2_10 -> view_2_11 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- view_2_11 -> transpose_2_12 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_2_12 -> transpose_4_16 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- linear_3_13 -> view_3_14 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- view_3_14 -> transpose_3_15 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_3_15 -> matmul_2_20 [label=<arg 1> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_4_16 -> matmul_1_17 [label=<arg 1> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- matmul_1_17 -> truediv_1_18 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- truediv_1_18 -> softmax_1_19 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- softmax_1_19 -> matmul_2_20 [label=<arg 0> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- matmul_2_20 -> transpose_5_21 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_5_21 -> contiguous_1_22 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- contiguous_1_22 -> view_4_23 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- view_4_23 -> linear_4_24 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- }
- }
- subgraph "cluster_layers.0_pass1" {
- subgraph "cluster_layers.0.mlp_pass1" {
- fillcolor=white label=<@layers.0.mlp
(Sequential)
> labelloc=b penwidth=3.5 style="filled,solid"
- linear_5_26 -> gelu_1_27 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- gelu_1_27 -> linear_6_28 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- }
- }
- subgraph "cluster_layers.1_pass1" {
- subgraph "cluster_layers.1.attn_pass1" {
- fillcolor=white label=<@layers.1.attn
(MultiheadSelfAttention)
> labelloc=b penwidth=3.5 style="filled,solid"
- linear_7_29 -> view_5_30 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- view_5_30 -> transpose_6_31 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_6_31 -> matmul_3_39 [label=<arg 0> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- linear_8_32 -> view_6_33 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- view_6_33 -> transpose_7_34 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_7_34 -> transpose_9_38 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- linear_9_35 -> view_7_36 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- view_7_36 -> transpose_8_37 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_8_37 -> matmul_4_42 [label=<arg 1> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_9_38 -> matmul_3_39 [label=<arg 1> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- matmul_3_39 -> truediv_2_40 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- truediv_2_40 -> softmax_2_41 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- softmax_2_41 -> matmul_4_42 [label=<arg 0> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- matmul_4_42 -> transpose_10_43 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_10_43 -> contiguous_2_44 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- contiguous_2_44 -> view_8_45 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- view_8_45 -> linear_10_46 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- }
- }
- subgraph "cluster_layers.1_pass1" {
- subgraph "cluster_layers.1.mlp_pass1" {
- fillcolor=white label=<@layers.1.mlp
(Sequential)
> labelloc=b penwidth=3.5 style="filled,solid"
- linear_11_48 -> gelu_2_49 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- gelu_2_49 -> linear_12_50 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- }
- }
- subgraph "cluster_layers.2_pass1" {
- subgraph "cluster_layers.2.attn_pass1" {
- fillcolor=white label=<@layers.2.attn
(MultiheadSelfAttention)
> labelloc=b penwidth=3.5 style="filled,solid"
- linear_13_51 -> view_9_52 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- view_9_52 -> transpose_11_53 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_11_53 -> matmul_5_61 [label=<arg 0> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- linear_14_54 -> view_10_55 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- view_10_55 -> transpose_12_56 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_12_56 -> transpose_14_60 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- linear_15_57 -> view_11_58 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- view_11_58 -> transpose_13_59 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_13_59 -> matmul_6_64 [label=<arg 1> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_14_60 -> matmul_5_61 [label=<arg 1> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- matmul_5_61 -> truediv_3_62 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- truediv_3_62 -> softmax_3_63 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- softmax_3_63 -> matmul_6_64 [label=<arg 0> arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- matmul_6_64 -> transpose_15_65 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- transpose_15_65 -> contiguous_3_66 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- contiguous_3_66 -> view_12_67 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- view_12_67 -> linear_16_68 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- }
- }
- subgraph "cluster_layers.2_pass1" {
- subgraph "cluster_layers.2.mlp_pass1" {
- fillcolor=white label=<@layers.2.mlp
(Sequential)
> labelloc=b penwidth=3.5 style="filled,solid"
- linear_17_70 -> gelu_3_71 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- gelu_3_71 -> linear_18_72 [arrowsize=.7 color=black fontcolor=black labelfontsize=8 style=solid]
- }
- }
- subgraph "cluster_layers.0_pass1" {
- subgraph "cluster_layers.0.attn_pass1" {
- subgraph "cluster_layers.0.attn.W_q_pass1" {
- fillcolor=white label=<@layers.0.attn.W_q
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.0_pass1" {
- subgraph "cluster_layers.0.attn_pass1" {
- subgraph "cluster_layers.0.attn.W_k_pass1" {
- fillcolor=white label=<@layers.0.attn.W_k
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.0_pass1" {
- subgraph "cluster_layers.0.attn_pass1" {
- subgraph "cluster_layers.0.attn.W_v_pass1" {
- fillcolor=white label=<@layers.0.attn.W_v
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.0_pass1" {
- subgraph "cluster_layers.0.attn_pass1" {
- subgraph "cluster_layers.0.attn.W_o_pass1" {
- fillcolor=white label=<@layers.0.attn.W_o
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.0_pass1" {
- subgraph "cluster_layers.0.mlp_pass1" {
- subgraph "cluster_layers.0.mlp.0_pass1" {
- fillcolor=white label=<@layers.0.mlp.0
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.0_pass1" {
- subgraph "cluster_layers.0.mlp_pass1" {
- subgraph "cluster_layers.0.mlp.1_pass1" {
- fillcolor=white label=<@layers.0.mlp.1
(GELU)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.0_pass1" {
- subgraph "cluster_layers.0.mlp_pass1" {
- subgraph "cluster_layers.0.mlp.2_pass1" {
- fillcolor=white label=<@layers.0.mlp.2
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.1_pass1" {
- subgraph "cluster_layers.1.attn_pass1" {
- subgraph "cluster_layers.1.attn.W_q_pass1" {
- fillcolor=white label=<@layers.1.attn.W_q
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.1_pass1" {
- subgraph "cluster_layers.1.attn_pass1" {
- subgraph "cluster_layers.1.attn.W_k_pass1" {
- fillcolor=white label=<@layers.1.attn.W_k
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.1_pass1" {
- subgraph "cluster_layers.1.attn_pass1" {
- subgraph "cluster_layers.1.attn.W_v_pass1" {
- fillcolor=white label=<@layers.1.attn.W_v
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.1_pass1" {
- subgraph "cluster_layers.1.attn_pass1" {
- subgraph "cluster_layers.1.attn.W_o_pass1" {
- fillcolor=white label=<@layers.1.attn.W_o
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.1_pass1" {
- subgraph "cluster_layers.1.mlp_pass1" {
- subgraph "cluster_layers.1.mlp.0_pass1" {
- fillcolor=white label=<@layers.1.mlp.0
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.1_pass1" {
- subgraph "cluster_layers.1.mlp_pass1" {
- subgraph "cluster_layers.1.mlp.1_pass1" {
- fillcolor=white label=<@layers.1.mlp.1
(GELU)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.1_pass1" {
- subgraph "cluster_layers.1.mlp_pass1" {
- subgraph "cluster_layers.1.mlp.2_pass1" {
- fillcolor=white label=<@layers.1.mlp.2
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.2_pass1" {
- subgraph "cluster_layers.2.attn_pass1" {
- subgraph "cluster_layers.2.attn.W_q_pass1" {
- fillcolor=white label=<@layers.2.attn.W_q
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.2_pass1" {
- subgraph "cluster_layers.2.attn_pass1" {
- subgraph "cluster_layers.2.attn.W_k_pass1" {
- fillcolor=white label=<@layers.2.attn.W_k
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.2_pass1" {
- subgraph "cluster_layers.2.attn_pass1" {
- subgraph "cluster_layers.2.attn.W_v_pass1" {
- fillcolor=white label=<@layers.2.attn.W_v
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.2_pass1" {
- subgraph "cluster_layers.2.attn_pass1" {
- subgraph "cluster_layers.2.attn.W_o_pass1" {
- fillcolor=white label=<@layers.2.attn.W_o
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.2_pass1" {
- subgraph "cluster_layers.2.mlp_pass1" {
- subgraph "cluster_layers.2.mlp.0_pass1" {
- fillcolor=white label=<@layers.2.mlp.0
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.2_pass1" {
- subgraph "cluster_layers.2.mlp_pass1" {
- subgraph "cluster_layers.2.mlp.1_pass1" {
- fillcolor=white label=<@layers.2.mlp.1
(GELU)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
- subgraph "cluster_layers.2_pass1" {
- subgraph "cluster_layers.2.mlp_pass1" {
- subgraph "cluster_layers.2.mlp.2_pass1" {
- fillcolor=white label=<@layers.2.mlp.2
(Linear)
> labelloc=b penwidth=2.0 style="filled,dashed"
- }
- }
- }
-}
diff --git a/graph.gv.pdf b/graph.gv.pdf
deleted file mode 100644
index 308a93f..0000000
Binary files a/graph.gv.pdf and /dev/null differ
diff --git a/input_tokens.json b/input_tokens.json
new file mode 100644
index 0000000..3674a71
--- /dev/null
+++ b/input_tokens.json
@@ -0,0 +1 @@
+["BOS", 1, 0, 1, 1, 0]
\ No newline at end of file
diff --git a/load_and_visualize_with_torchlens.py b/load_and_visualize_with_torchlens.py
deleted file mode 100644
index 4b0b404..0000000
--- a/load_and_visualize_with_torchlens.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import os, sys
-sys.path.append(os.path.dirname(__file__)) # add current folder to module search path
-from tracr_transformer_pt import TracrTransformerPT
-# load_and_visualize_with_torchlens.py
-import numpy as np
-import torch
-import torchlens as tl
-from tracr_transformer_pt import TracrTransformerPT
-
-# Instantiate the mirror model with your inferred hyperparams
-model = TracrTransformerPT(
- vocab_size=4,
- max_seq_len=11,
- d_model=24,
- n_heads=3,
- head_dim=4, # ★ critical — this makes 3*4 = 12 projection dim
- n_layers=3,
- d_mlp=4
-)
-
-
-# Load NPZ exported from your JAX/Haiku model
-npz = np.load("tracr_majority_params.npz")
-
-# Helper: copy (with optional transpose) into a torch parameter tensor
-def copy_(pt_tensor, arr, transpose=False):
- t = torch.tensor(arr)
- if transpose:
- t = t.T
- assert tuple(t.shape) == tuple(pt_tensor.shape), f"Shape mismatch: src {t.shape} != dst {pt_tensor.shape}"
- with torch.no_grad():
- pt_tensor.copy_(t)
-
-sd = model.state_dict()
-
-# ---- Embeddings (same layout as PyTorch, no transpose) ----
-copy_(sd["token_emb.weight"], npz["token_embed__embeddings"])
-copy_(sd["pos_emb.weight"], npz["pos_embed__embeddings"])
-
-# ---- Per-layer mappings (notice: JAX Linear 'w' is (in, out); PyTorch is (out, in) -> transpose=True) ----
-for i in range(3): # layers 0..2
- # Attention projections
- copy_(sd[f"layers.{i}.attn.W_q.weight"], npz[f"transformer__layer_{i}__attn__query__w"], transpose=True) # (24,12)->(12,24)
- copy_(sd[f"layers.{i}.attn.W_q.bias"], npz[f"transformer__layer_{i}__attn__query__b"])
- copy_(sd[f"layers.{i}.attn.W_k.weight"], npz[f"transformer__layer_{i}__attn__key__w"], transpose=True)
- copy_(sd[f"layers.{i}.attn.W_k.bias"], npz[f"transformer__layer_{i}__attn__key__b"])
- copy_(sd[f"layers.{i}.attn.W_v.weight"], npz[f"transformer__layer_{i}__attn__value__w"], transpose=True)
- copy_(sd[f"layers.{i}.attn.W_v.bias"], npz[f"transformer__layer_{i}__attn__value__b"])
-
- # Attention output projection ("linear")
- copy_(sd[f"layers.{i}.attn.W_o.weight"], npz[f"transformer__layer_{i}__attn__linear__w"], transpose=True) # (12,24)->(24,12)
- copy_(sd[f"layers.{i}.attn.W_o.bias"], npz[f"transformer__layer_{i}__attn__linear__b"])
-
- # MLP 24->4->24
- copy_(sd[f"layers.{i}.mlp.0.weight"], npz[f"transformer__layer_{i}__mlp__linear_1__w"], transpose=True) # (24,4)->(4,24)
- copy_(sd[f"layers.{i}.mlp.0.bias"], npz[f"transformer__layer_{i}__mlp__linear_1__b"])
- copy_(sd[f"layers.{i}.mlp.2.weight"], npz[f"transformer__layer_{i}__mlp__linear_2__w"], transpose=True) # (4,24)->(24,4)
- copy_(sd[f"layers.{i}.mlp.2.bias"], npz[f"transformer__layer_{i}__mlp__linear_2__b"])
-
-# Commit weights to the module
-model.load_state_dict(sd)
-
-# ---- TorchLens over the SAME model ----
-# If you have your exact token-indexing, encode it here. For a quick diagram, dummy IDs in [0..3] work.
-x = torch.randint(low=0, high=4, size=(1, 6)) # (B=1, T=6) e.g., [BOS, 1,0,1,1,0] with the right indices if you prefer
-model.eval()
-
-# 1) Save full forward history AND 2) render the layered graph (unrolled by layer):
-log = tl.log_forward_pass(model, x, vis_opt="unrolled")
-tl.show_model_graph(model, (x,), vis_opt="unrolled", file_name="torchlens_majority_graph")
diff --git a/my_majority_program.py b/my_majority_program.py
deleted file mode 100644
index bc06cf4..0000000
--- a/my_majority_program.py
+++ /dev/null
@@ -1,183 +0,0 @@
-import os, sys
-tracr_path = os.path.join(os.path.dirname(__file__), "Tracr", "tracr")
-sys.path.insert(0, tracr_path)
-
-import numpy as np
-import jax
-import jax.random as random
-
-from tracr.compiler import compiling
-from tracr.rasp import rasp
-
-# --- The robust function to import show_model ---
-# --- The robust function to import show_model (returns None if not found) ---
-def _get_show_model():
- import importlib, pkgutil, tracr
- for modpath in ("tracr.compiler.visualization", "tracr.visualization"):
- try:
- mod = importlib.import_module(modpath)
- fn = getattr(mod, "show_model", None)
- if callable(fn):
- return fn
- except Exception:
- pass
- for _, name, _ in pkgutil.walk_packages(tracr.__path__, tracr.__name__ + "."):
- try:
- mod = importlib.import_module(name)
- fn = getattr(mod, "show_model", None)
- if callable(fn):
- return fn
- except Exception:
- continue
- return None # <- do NOT raise here
-# --- Fallback: render a clean Tracr-style diagram from compiled params ---
-def render_block_diagram_from_compiled(compiled_model, out_basename="tracr_majority_graph"):
- import re
- from graphviz import Digraph
- # get params from the compiled model
- try:
- params = compiled_model.params
- except AttributeError:
- params = compiled_model.weights
-
- # flatten nested dict into "path" -> ndarray
- flat = {}
- def walk(path, node):
- if isinstance(node, dict):
- for k, v in node.items():
- walk(path + (k,), v)
- else:
- flat["/".join(path)] = np.array(node)
- walk((), params)
-
- # read shapes
- tok_key = next(k for k in flat if k.endswith("token_embed/embeddings"))
- pos_key = next(k for k in flat if k.endswith("pos_embed/embeddings"))
- vocab_size, d_model = flat[tok_key].shape
- max_seq_len = flat[pos_key].shape[0]
-
- layer_nums = sorted({int(m.group(1))
- for k in flat
- for m in [re.search(r"transformer/layer_(\d+)/", k)]
- if m})
- # attn proj and mlp hidden from layer 0
- proj_dim = flat[f"transformer/layer_{layer_nums[0]}/attn/query/w"].shape[1]
- mlp_hidden = flat[f"transformer/layer_{layer_nums[0]}/mlp/linear_1/w"].shape[1]
-
- dot = Digraph("tracr_majority_transformer", format="pdf")
- dot.attr(rankdir="LR", fontsize="12", labelloc="t",
- label=f"Tracr-compiled Majority Transformer\n"
- f"vocab={vocab_size}, d_model={d_model}, layers={len(layer_nums)}, "
- f"proj_dim={proj_dim}, mlp_hidden={mlp_hidden}")
-
- with dot.subgraph(name="cluster_embed") as c:
- c.attr(label="Embeddings")
- c.node("tok_emb", f"Token Embedding\n({vocab_size}×{d_model})", shape="box")
- c.node("pos_emb", f"Positional Embedding\n({max_seq_len}×{d_model})", shape="box")
- c.node("sum", "Add", shape="circle")
- c.edges([("tok_emb", "sum"), ("pos_emb", "sum")])
-
- prev = "sum"
- for i in layer_nums:
- with dot.subgraph(name=f"cluster_layer_{i}") as c:
- c.attr(label=f"Encoder Block {i}")
- c.node(f"attn_{i}", f"MHA proj {proj_dim}", shape="box")
- c.node(f"add_attn_{i}", "Add", shape="circle")
- c.node(f"mlp_{i}", f"MLP {d_model}→{mlp_hidden}→{d_model}", shape="box")
- c.node(f"add_mlp_{i}", "Add", shape="circle")
- dot.edge(prev, f"attn_{i}")
- dot.edge(f"attn_{i}", f"add_attn_{i}")
- dot.edge(f"add_attn_{i}", f"mlp_{i}")
- dot.edge(f"mlp_{i}", f"add_mlp_{i}")
- prev = f"add_mlp_{i}"
-
- dot.node("out", f"Output\n(seq_len×{d_model})", shape="box")
- dot.edge(prev, "out")
- out_path = dot.render(out_basename, cleanup=True)
- print(f"Saved {out_path}")
-
-
-show_model = _get_show_model()
-
-
-
-VOCAB = {0, 1}
-MAX_SEQ_LEN = 10
-COMPILER_BOS = "BOS"
-
-# --- Majority Score Program ---
-def majority_score_program():
- """
- RASP program that outputs whether 1s or 0s are the majority.
- Positive = majority of 1s, Negative = majority of 0s, 0 = tie.
- """
- # Select all positions
- all_positions = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE)
-
- # Count number of 1s
- select_ones = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
- num_ones = rasp.Aggregate(select_ones, rasp.tokens)
-
- # Count number of 0s (total length - num_ones)
- seq_length = rasp.Aggregate(all_positions, rasp.tokens * 0 + 1) # sum of 1s over all positions
- num_zeros = seq_length - num_ones
-
- # Majority score = (#1s - #0s)
- majority_score = num_ones - num_zeros
-
- return majority_score
-
-# --- Compile ---
-print("Compiling majority RASP program to transformer model...")
-compiled_model = compiling.compile_rasp_to_model(
- majority_score_program(),
- vocab=VOCAB,
- max_seq_len=MAX_SEQ_LEN,
- compiler_bos=COMPILER_BOS,
-)
-print("Compilation complete!\n")
-
-# --- Save transformer diagram ---
-print("Generating transformer visualization...")
-show_model = _get_show_model()
-if show_model is not None:
- graph = show_model(compiled_model, max_seq_len=MAX_SEQ_LEN, return_graph=True)
- graph.render("tracr_majority_graph", format="pdf", cleanup=True)
-else:
- print("Tracr show_model not found — using fallback renderer.")
- render_block_diagram_from_compiled(compiled_model, out_basename="tracr_majority_graph")
-print("Diagram saved as tracr_majority_graph.pdf ✅\n")
-
-
-# --- Example ---
-example_input_sequence = [1, 0, 1, 1, 0]
-
-print(f"Raw input sequence (no BOS): {example_input_sequence}")
-
-# Prepend BOS manually, since Tracr expects it
-input_with_bos = [COMPILER_BOS] + example_input_sequence
-print(f"Input sequence with BOS: {input_with_bos}")
-
-# Run model
-output_logits = compiled_model.apply(input_with_bos)
-
-# Interpret output
-vocab_list = sorted(list(VOCAB)) + [COMPILER_BOS]
-predicted_tokens = [vocab_list[np.argmax(logits)] for logits in output_logits]
-
-print("\n--- Model Output ---")
-print("Raw logits:\n", output_logits)
-print("Predicted tokens:", predicted_tokens)
-
-# --- Run RASP directly ---
-rasp_output = majority_score_program()(example_input_sequence)
-print("Raw RASP output:", rasp_output)
-print("\nRASP execution output:", rasp_output)
-
-majority_val = rasp_output[0]
-if majority_val > 0:
- print("Majority element: 1 ✅")
-elif majority_val < 0:
- print("Majority element: 0 ✅")
-else:
- print("Tie between 0 and 1 🤝")
diff --git a/my_majority_torchlens.py b/my_majority_torchlens.py
deleted file mode 100644
index 904ead7..0000000
--- a/my_majority_torchlens.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import torch
-import torch.nn as nn
-import torchlens as tl # ★ import as a module; tl.show_model_graph used below
-
-class MyTransformer(nn.Module):
- def __init__(self, input_dim=128, hidden_dim=256, num_layers=2, num_heads=4):
- super().__init__()
- self.embedding = nn.Linear(input_dim, hidden_dim)
- encoder_layer = nn.TransformerEncoderLayer(
- d_model=hidden_dim, nhead=num_heads, batch_first=True
- )
- self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
- self.decoder = nn.Linear(hidden_dim, 10)
-
- def forward(self, x):
- x = self.embedding(x)
- x = self.encoder(x)
- return self.decoder(x)
-
-# Optional wrapper not required for TorchLens; it will trace submodules anyway.
-transformer = MyTransformer()
-
-# Example input (B, T, D_in)
-x = torch.randn(2, 5, 128)
-
-# Option A: one-liner that logs activations and ALSO renders the graph.
-log = tl.log_forward_pass(transformer, x, layers_to_save=None, vis_opt="unrolled")
-
-
-# --- Make a wrapper that exposes internals clearly ---
-class TransformerWrapper(nn.Module):
- def __init__(self, transformer):
- super().__init__()
- self.model = transformer
-
- def forward(self, x):
- # Instead of doing everything in one call, we explicitly call submodules
- x = self.model.embedding(x)
- # Go through each encoder layer explicitly so TorchLens can track them
- for i, layer in enumerate(self.model.encoder.layers):
- x = layer(x)
- x = self.model.decoder(x)
- return x
-
-# --- Instantiate and wrap the model ---
-transformer = MyTransformer()
-wrapped_model = TransformerWrapper(transformer)
-
-# --- Dummy input ---
-x = torch.randn(2, 5, 128)
-
-# --- Run TorchLens logging ---
-log = log_forward_pass(wrapped_model, x)
-
-# --- Visualize graph ---
-show_model_graph(log)
diff --git a/scripts/compile_export.py b/scripts/compile_export.py
new file mode 100644
index 0000000..f92caf2
--- /dev/null
+++ b/scripts/compile_export.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3
+# scripts/compile_export.py
+import json
+import sys
+import numpy as np
+from pathlib import Path
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+
+# Add local tracr paths if they exist; otherwise rely on pip-installed package
+for p in [
+ REPO_ROOT / "external" / "Tracr" / "tracr", # optional vendored location
+ REPO_ROOT / "Tracr" / "tracr", # older local layout
+ REPO_ROOT / "tracr", # just in case
+]:
+ if p.is_dir():
+ sys.path.insert(0, str(p))
+ break
+
+# Now import (works with either local path OR pip-installed package)
+from tracr.compiler import compiling
+from tracr.rasp import rasp
+
+
+# -------- Config --------
+VOCAB = {0, 1}
+MAX_SEQ_LEN = 10
+COMPILER_BOS = "BOS"
+COMPILER_PAD = "PAD"
+CAUSAL = True
+EXAMPLE = [1, 0, 1, 1, 0]
+
+ART = REPO_ROOT / "artifacts"
+ART.mkdir(exist_ok=True)
+
+def majority_program():
+    # Majority sign: 2 * (#ones) - seq_len (see the worked example below).
+    all_pos = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE)  # attend to every position
+    ones_vec = rasp.Map(lambda t: t, rasp.tokens)  # tokens are already 0/1 numeric
+    num_ones = rasp.Aggregate(all_pos, ones_vec)
+    ones_const = rasp.Map(lambda _: 1, rasp.tokens)
+    seq_len = rasp.Aggregate(all_pos, ones_const)
+    return (2 * num_ones) - seq_len
+
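+# Worked example (RASP-level arithmetic on EXAMPLE; the compiler's BOS token is ignored):
+#   tokens = [1, 0, 1, 1, 0]
+#   rasp.Aggregate averages over the selected positions, so with the all-positions
+#   selector num_ones = mean(tokens) = 3/5 and seq_len aggregates the constant 1 to 1.
+#   The result at each position is 2 * (3/5) - 1 = 0.2 > 0, i.e. ones are the majority,
+#   matching the sign of the integer form 2 * (#ones) - len = 2*3 - 5 = 1.
+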
+def compile_tracr(prog):
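+    # Different tracr versions accept different keyword arguments here, so try the
+    # richest call first and fall back progressively on TypeError.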
+ kw = dict(vocab=VOCAB, max_seq_len=MAX_SEQ_LEN, compiler_bos=COMPILER_BOS)
+ try:
+ return compiling.compile_rasp_to_model(prog, compiler_pad=COMPILER_PAD, causal=CAUSAL, **kw)
+ except TypeError:
+ pass
+ try:
+ return compiling.compile_rasp_to_model(prog, causal=CAUSAL, **kw)
+ except TypeError:
+ return compiling.compile_rasp_to_model(prog, **kw)
+
+def get_tok2id_or_fallback(model):
+ cands = [
+ getattr(model, "tokenizer", None),
+ getattr(model, "vocab", None),
+ getattr(model, "vocabulary", None),
+ getattr(getattr(model, "transformer", None), "tokenizer", None),
+ getattr(getattr(model, "transformer", None), "vocab", None),
+ getattr(getattr(model, "transformer", None), "vocabulary", None),
+ ]
+ for obj in cands:
+ if obj is None: continue
+ if isinstance(obj, dict) and obj: return obj
+ if hasattr(obj, "token_to_id") and isinstance(obj.token_to_id, dict):
+ return obj.token_to_id
+ if hasattr(obj, "id_to_token") and isinstance(obj.id_to_token, dict):
+ return {tok: int(i) for i, tok in obj.id_to_token.items()}
+ # fallback mapping (deterministic)
+ ordered = [COMPILER_BOS] + sorted(list(VOCAB), key=lambda x: repr(x)) + [COMPILER_PAD]
+ mapping = {tok: i for i, tok in enumerate(ordered)}
+ print("[WARN] Using fallback token mapping:", mapping)
+ return mapping
+
+def export_params_npz(compiled, out_path: Path, keys_path: Path):
+ try:
+ params = compiled.params
+ except AttributeError:
+ params = compiled.weights
+
+ leaves = []
+ def walk(path, node):
+ if isinstance(node, dict):
+ for k, v in node.items(): walk(path + (k,), v)
+ elif isinstance(node, (list, tuple)):
+ for i, v in enumerate(node): walk(path + (str(i),), v)
+ else:
+ leaves.append(("/".join(path), np.array(node)))
+
+ walk((), params)
+ npz_dict = {k.replace("/", "__"): v for k, v in leaves}
+ np.savez(out_path, **npz_dict)
+ keys_path.write_text(json.dumps(sorted(npz_dict.keys()), indent=2))
+ print(f"Exported params -> {out_path} (keys -> {keys_path})")
+
+def main():
+ print("Compiling RASP → Tracr transformer…")
+ compiled = compile_tracr(majority_program())
+ print("Done.\n")
+
+    # Token mapping: write the discovered (or fallback) mapping only if nothing is on disk yet.
+ tok_map_path = ART / "token_to_id.json"
+ tok2id = get_tok2id_or_fallback(compiled)
+ if not tok_map_path.exists():
+ tok_map_path.write_text(json.dumps(tok2id))
+ print("Saved token_to_id.json")
+ else:
+        print("token_to_id.json already exists; keeping the mapping on disk.")
+
+ # exact tokens used for the reference pass
+ tokens = [COMPILER_BOS] + EXAMPLE
+ (ART / "input_tokens.json").write_text(json.dumps(tokens))
+ print("Saved input_tokens.json")
+
+ # forward pass on assembled model
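+    # compiled.apply(...) may return an output object exposing .transformer_output or a
+    # plain array depending on the tracr version; the getattr below covers both cases.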
+ out = compiled.apply(tokens)
+ arr = np.array(getattr(out, "transformer_output", out), dtype=np.float32)
+ if arr.ndim == 2: arr = arr[None, ...]
+ np.save(ART / "tracr_output.npy", arr)
+ print(f"Saved tracr_output.npy with shape {arr.shape} (dtype={arr.dtype})")
+
+ # export params from THIS compiled model
+ export_params_npz(compiled, ART / "tracr_majority_params.npz", ART / "tracr_param_keys.json")
+
+ print("\nNow run:")
+ print(" python scripts/parity_check.py")
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/parity_check.py b/scripts/parity_check.py
new file mode 100644
index 0000000..f6383e1
--- /dev/null
+++ b/scripts/parity_check.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+# scripts/parity_check.py
+import json, itertools, numpy as np, torch
+from pathlib import Path
+import sys
+
+ROOT = Path(__file__).resolve().parents[1]
+ART = ROOT / "artifacts"
+sys.path.append(str(ROOT / "src"))
+
+from tracr_pt_model import TracrTransformerPT
+
+# ---- Load NPZ & infer dims ----
+npz = np.load(ART / "tracr_majority_params.npz")
+def get(k):
+    key = k.replace("/", "__")  # NPZ keys use "__" where the JAX param tree used "/"
+    return npz[key] if key in npz else npz[k]
+
+vocab_size, d_model = get("token_embed/embeddings").shape
+max_len = get("pos_embed/embeddings").shape[0]
+proj_dim = int(get("transformer/layer_0/attn/query/w").shape[1]) # JAX (in,out)
+d_mlp = int(get("transformer/layer_0/mlp/linear_1/w").shape[1])
+n_layers = sum((f"transformer/layer_{i}/attn/query/w".replace("/", "__") in npz) for i in range(64))
+
+# Head split matched by hand to the compiled model (parity fails if it differs)
+n_heads, head_dim = 2, proj_dim // 2
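+# Sanity: n_heads * head_dim must tile proj_dim exactly, or the per-head reshape in
+# the PyTorch attention will not line up with the loaded Q/K/V weights.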
+
+print(f"Inferred -> d_model={d_model}, vocab={vocab_size}, max_seq_len={max_len}, "
+      f"layers={n_layers}, proj_dim={proj_dim}, d_mlp={d_mlp}, heads={n_heads}, head_dim={head_dim}")
+
+# ---- Build PT model & load weights ----
+model = TracrTransformerPT(vocab_size, max_len, int(d_model), int(n_layers), int(d_mlp),
+ n_heads=int(n_heads), head_dim=int(head_dim)).eval()
+
+def load_linear(L, w_key, b_key):
+    w = torch.from_numpy(get(w_key)).float().t().contiguous()  # JAX stores (in, out); PyTorch wants (out, in)
+    b = torch.from_numpy(get(b_key)).float()
+    with torch.no_grad():
+        L.weight.copy_(w)
+        L.bias.copy_(b)
+
+with torch.no_grad():
+ model.token_emb.weight.copy_(torch.from_numpy(get("token_embed/embeddings")).float())
+ model.pos_emb.weight.copy_(torch.from_numpy(get("pos_embed/embeddings")).float())
+
+for i in range(n_layers):
+ P = f"transformer/layer_{i}"
+ blk = model.layers[i]
+ load_linear(blk.attn.W_q, f"{P}/attn/query/w", f"{P}/attn/query/b")
+ load_linear(blk.attn.W_k, f"{P}/attn/key/w", f"{P}/attn/key/b")
+ load_linear(blk.attn.W_v, f"{P}/attn/value/w", f"{P}/attn/value/b")
+ load_linear(blk.attn.W_o, f"{P}/attn/linear/w", f"{P}/attn/linear/b")
+ load_linear(blk.mlp.fc1, f"{P}/mlp/linear_1/w", f"{P}/mlp/linear_1/b")
+ load_linear(blk.mlp.fc2, f"{P}/mlp/linear_2/w", f"{P}/mlp/linear_2/b")
+
+# ---- Read tokens & Tracr reference ----
+tokens = json.loads((ART / "input_tokens.json").read_text()) # ["BOS", 1, 0, 1, 1, 0]
+ref = torch.from_numpy(np.load(ART / "tracr_output.npy")).float()
+
+# ---- Discover BOS/0/1/PAD mapping once ----
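+# Rather than trusting an exported mapping, brute-force every id assignment (this
+# assumes the compiled vocab is exactly {BOS, 0, 1, PAD}) and keep the permutation
+# whose PyTorch output is closest to the Tracr reference.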
+TOKS = ["BOS", "0", "1", "PAD"]
+best = None
+for perm in itertools.permutations(range(vocab_size), vocab_size):
+ tok2id = {TOKS[i]: perm[i] for i in range(vocab_size)}
+ ids = [tok2id["BOS"]] + [tok2id[str(t)] for t in tokens[1:]]
+ ids = torch.tensor([ids], dtype=torch.long)
+ with torch.no_grad(): out = model(ids)
+ md = (out - ref).abs().max().item()
+ ok = torch.allclose(out, ref, atol=1e-5)
+ cand = (tok2id, md, ok)
+ if best is None or md < best[1]: best = cand
+ if ok: break
+
+tok2id, md, ok = best
+(ART / "token_to_id.json").write_text(json.dumps(tok2id, indent=2))
+
+print("\n--- Mapping search ---")
+print(f"tok2id={tok2id}, max_diff={md:.6g}, match={ok}")
+print("\n--- Sanity Check ---")
+print(f"Outputs match: {ok}")
+print(f"Max abs diff: {md:.6g}")
+
+sys.exit(0 if ok else 1)
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/__pycache__/tracr_pt_model.cpython-313.pyc b/src/__pycache__/tracr_pt_model.cpython-313.pyc
new file mode 100644
index 0000000..d15c121
Binary files /dev/null and b/src/__pycache__/tracr_pt_model.cpython-313.pyc differ
diff --git a/src/tracr_pt_model.py b/src/tracr_pt_model.py
new file mode 100644
index 0000000..03870d3
--- /dev/null
+++ b/src/tracr_pt_model.py
@@ -0,0 +1,66 @@
+# src/tracr_pt_model.py
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class MultiheadSelfAttention(nn.Module):
+ def __init__(self, d_model, n_heads, head_dim, causal=True):
+ super().__init__()
+ self.n_heads = n_heads
+ self.head_dim = head_dim
+ self.causal = causal
+ proj = n_heads * head_dim
+ self.W_q = nn.Linear(d_model, proj, bias=True)
+ self.W_k = nn.Linear(d_model, proj, bias=True)
+ self.W_v = nn.Linear(d_model, proj, bias=True)
+ self.W_o = nn.Linear(proj, d_model, bias=True)
+
+ def forward(self, x):
+ B, T, _ = x.shape
+ def split(L):
+ y = L(x).view(B, T, self.n_heads, self.head_dim)
+ return y.permute(0, 2, 1, 3) # (B, H, T, D)
+ q, k, v = split(self.W_q), split(self.W_k), split(self.W_v)
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+ if self.causal:
+ mask = torch.triu(torch.ones(T, T, device=x.device), 1).bool()
+ scores = scores.masked_fill(mask, float("-inf"))
+ attn = torch.softmax(scores, dim=-1)
+ ctx = torch.matmul(attn, v).permute(0, 2, 1, 3).contiguous().view(B, T, self.n_heads * self.head_dim)
+ return self.W_o(ctx)
+
+class MLP(nn.Module):
+ def __init__(self, d_model, d_mlp):
+ super().__init__()
+ self.fc1 = nn.Linear(d_model, d_mlp, bias=True)
+ self.fc2 = nn.Linear(d_mlp, d_model, bias=True)
+ def forward(self, x):
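+        # ReLU, not GELU: parity with the exported Tracr weights depends on matching
+        # the compiled model's MLP nonlinearity exactly.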
+ return self.fc2(F.relu(self.fc1(x)))
+
+class EncoderBlock(nn.Module):
+ def __init__(self, d_model, n_heads, head_dim, d_mlp):
+ super().__init__()
+ self.attn = MultiheadSelfAttention(d_model, n_heads, head_dim, causal=True)
+ self.mlp = MLP(d_model, d_mlp)
+ def forward(self, x):
+ x = x + self.attn(x) # Attn → MLP (sequential residuals)
+ x = x + self.mlp(x)
+ return x
+
+class TracrTransformerPT(nn.Module):
+ def __init__(self, vocab_size, max_seq_len, d_model, n_layers, d_mlp, n_heads=2, head_dim=12):
+ super().__init__()
+ self.token_emb = nn.Embedding(vocab_size, d_model)
+ self.pos_emb = nn.Embedding(max_seq_len, d_model)
+ self.layers = nn.ModuleList([
+ EncoderBlock(d_model, n_heads, head_dim, d_mlp) for _ in range(n_layers)
+ ])
+
+ def forward(self, token_ids):
+ B, T = token_ids.shape
+ pos = torch.arange(T, device=token_ids.device) # positions start at 0
+ x = self.token_emb(token_ids) + self.pos_emb(pos)[None, :, :]
+ for blk in self.layers:
+ x = blk(x)
+ return x
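+
+# Minimal smoke test (hypothetical dims; scripts/parity_check.py infers the real ones
+# from the exported NPZ):
+#   >>> m = TracrTransformerPT(vocab_size=4, max_seq_len=10, d_model=24,
+#   ...                        n_layers=3, d_mlp=4, n_heads=2, head_dim=6)
+#   >>> m(torch.zeros(1, 6, dtype=torch.long)).shape
+#   torch.Size([1, 6, 24])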
diff --git a/token_to_id.json b/token_to_id.json
new file mode 100644
index 0000000..6966287
--- /dev/null
+++ b/token_to_id.json
@@ -0,0 +1,6 @@
+{
+ "BOS": 2,
+ "0": 0,
+ "1": 1,
+ "PAD": 3
+}
\ No newline at end of file
diff --git a/tracr_majority_graph.pdf b/tracr_majority_graph.pdf
deleted file mode 100644
index 88f2c1a..0000000
Binary files a/tracr_majority_graph.pdf and /dev/null differ
diff --git a/tracr_majority_params.npz b/tracr_majority_params.npz
index 437d52d..ce7aa3a 100644
Binary files a/tracr_majority_params.npz and b/tracr_majority_params.npz differ
diff --git a/tracr_output.npy b/tracr_output.npy
new file mode 100644
index 0000000..a802065
Binary files /dev/null and b/tracr_output.npy differ
diff --git a/tracr_param_keys.json b/tracr_param_keys.json
new file mode 100644
index 0000000..460b71c
--- /dev/null
+++ b/tracr_param_keys.json
@@ -0,0 +1,40 @@
+[
+ "pos_embed__embeddings",
+ "token_embed__embeddings",
+ "transformer__layer_0__attn__key__b",
+ "transformer__layer_0__attn__key__w",
+ "transformer__layer_0__attn__linear__b",
+ "transformer__layer_0__attn__linear__w",
+ "transformer__layer_0__attn__query__b",
+ "transformer__layer_0__attn__query__w",
+ "transformer__layer_0__attn__value__b",
+ "transformer__layer_0__attn__value__w",
+ "transformer__layer_0__mlp__linear_1__b",
+ "transformer__layer_0__mlp__linear_1__w",
+ "transformer__layer_0__mlp__linear_2__b",
+ "transformer__layer_0__mlp__linear_2__w",
+ "transformer__layer_1__attn__key__b",
+ "transformer__layer_1__attn__key__w",
+ "transformer__layer_1__attn__linear__b",
+ "transformer__layer_1__attn__linear__w",
+ "transformer__layer_1__attn__query__b",
+ "transformer__layer_1__attn__query__w",
+ "transformer__layer_1__attn__value__b",
+ "transformer__layer_1__attn__value__w",
+ "transformer__layer_1__mlp__linear_1__b",
+ "transformer__layer_1__mlp__linear_1__w",
+ "transformer__layer_1__mlp__linear_2__b",
+ "transformer__layer_1__mlp__linear_2__w",
+ "transformer__layer_2__attn__key__b",
+ "transformer__layer_2__attn__key__w",
+ "transformer__layer_2__attn__linear__b",
+ "transformer__layer_2__attn__linear__w",
+ "transformer__layer_2__attn__query__b",
+ "transformer__layer_2__attn__query__w",
+ "transformer__layer_2__attn__value__b",
+ "transformer__layer_2__attn__value__w",
+ "transformer__layer_2__mlp__linear_1__b",
+ "transformer__layer_2__mlp__linear_1__w",
+ "transformer__layer_2__mlp__linear_2__b",
+ "transformer__layer_2__mlp__linear_2__w"
+]
\ No newline at end of file
diff --git a/tracr_transformer_pt.py b/tracr_transformer_pt.py
deleted file mode 100644
index 2a9a656..0000000
--- a/tracr_transformer_pt.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# tracr_transformer_pt.py (replace your attention + block signatures with this)
-
-import math
-import torch
-import torch.nn as nn
-
-class MultiheadSelfAttention(nn.Module):
- def __init__(self, d_model=24, n_heads=3, head_dim=4, bias=True):
- super().__init__()
- self.d_model = d_model
- self.n_heads = n_heads
- self.head_dim = head_dim
- self.proj_dim = n_heads * head_dim # = 12 (matches your JAX dump)
-
- self.W_q = nn.Linear(d_model, n_heads * head_dim, bias=bias) # 24 -> 12
- self.W_k = nn.Linear(d_model, n_heads * head_dim, bias=bias) # 24 -> 12
- self.W_v = nn.Linear(d_model, n_heads * head_dim, bias=bias) # 24 -> 12
- self.W_o = nn.Linear(n_heads * head_dim, d_model, bias=bias) # 12 -> 24
-
-
- def forward(self, x):
- B, T, C = x.shape
- q = self.W_q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B,H,T,4)
- k = self.W_k(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
- v = self.W_v(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
-
- att = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # (B,H,T,T)
- att = torch.softmax(att, dim=-1)
- y = torch.matmul(att, v) # (B,H,T,4)
- y = y.transpose(1, 2).contiguous().view(B, T, self.proj_dim) # (B,T,12)
- return self.W_o(y) # (B,T,24)
-
-class EncoderBlock(nn.Module):
- def __init__(self, d_model=24, n_heads=3, head_dim=4, d_mlp=4):
- super().__init__()
- self.attn = MultiheadSelfAttention(d_model=d_model, n_heads=n_heads, head_dim=head_dim, bias=True)
- self.mlp = nn.Sequential(
- nn.Linear(d_model, d_mlp, bias=True), # 24 -> 4
- nn.GELU(),
- nn.Linear(d_mlp, d_model, bias=True), # 4 -> 24
- )
-
- def forward(self, x):
- x = x + self.attn(x)
- x = x + self.mlp(x)
- return x
-
-class TracrTransformerPT(nn.Module):
- def __init__(self, vocab_size=4, max_seq_len=11, d_model=24, n_heads=3, head_dim=4, n_layers=3, d_mlp=4):
- super().__init__()
- self.token_emb = nn.Embedding(vocab_size, d_model)
- self.pos_emb = nn.Embedding(max_seq_len, d_model)
- self.layers = nn.ModuleList([EncoderBlock(d_model, n_heads, head_dim, d_mlp) for _ in range(n_layers)])
-
- def forward(self, token_ids):
- B, T = token_ids.shape
- x = self.token_emb(token_ids) + self.pos_emb(torch.arange(T, device=token_ids.device).unsqueeze(0).expand(B, T))
- for blk in self.layers:
- x = blk(x)
- return x # (B, T, 24)