A PyTorch implementation of MergeDNA, a hierarchical autoencoder for DNA sequences that learns adaptive tokenization through token merging.
A video explanation of MergeDNA can be found here.
```
Input DNA (N nucleotides)
        |
        v
[Local Encoder]  --> Windowed attention + merge (N -> L)
        |
        v
[Latent Encoder] --> Full attention + merge (L -> K)
        |
        v
[Unmerge K -> L]
        |
        v
[Latent Decoder] --> Refine at L level
        |
        v
[Unmerge L -> N]
        |
        v
[Local Decoder]  --> Refine at N level
        |
        v
Output Logits (N x 4)
```
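The merge/unmerge shape flow above can be sketched with plain tensor ops. Note this is only a shape-level illustration: the fixed-window mean pooling and repeat-based unmerging below are stand-ins for the model's learned, attention-guided token merging, and the merge ratios are made up for the example.

```python
import torch

B, N, D = 2, 64, 16   # batch, input length, feature dim (illustrative)
merge_n_to_l = 4      # local merge ratio (N -> L), illustrative
merge_l_to_k = 2      # latent merge ratio (L -> K), illustrative

x = torch.randn(B, N, D)  # N-level features

# Merge N -> L: mean over fixed windows stands in for learned merging.
L = N // merge_n_to_l
x_l = x.view(B, L, merge_n_to_l, D).mean(dim=2)    # [B, L, D]

# Merge L -> K at the latent level.
K = L // merge_l_to_k
x_k = x_l.view(B, K, merge_l_to_k, D).mean(dim=2)  # [B, K, D]

# Unmerge K -> L and L -> N: broadcast each merged token back over
# the positions it came from, restoring the original length.
y_l = x_k.repeat_interleave(merge_l_to_k, dim=1)   # [B, L, D]
y_n = y_l.repeat_interleave(merge_n_to_l, dim=1)   # [B, N, D]
```

After unmerging, the latent and local decoders refine these broadcast features at the L and N levels before producing per-nucleotide logits.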
```python
import torch

from merge_dna import MergeDNA, MergeDNALoss

# Create model
model = MergeDNA(d_model=1024)

# Input: batch of DNA sequences as token IDs (0=A, 1=C, 2=G, 3=T)
batch_size, seq_len = 2, 1024
x_ids = torch.randint(0, 4, (batch_size, seq_len))

# Forward pass
out = model(x_ids)
logits = out["logits"]  # [batch, seq_len, 4]

# Training with all 3 losses (MTR + Latent MTR + AMTM)
loss_module = MergeDNALoss(model)
total_loss, loss_dict = loss_module(x_ids, targets=x_ids)
```