# PonderTTT

Paper | Project Page | [Setup](#setup) | [Usage](#usage) | [Citation](#citation)
PonderTTT applies selective TTT updates based on input difficulty, using the reconstruction loss as a training-free gating signal. A single scalar threshold, calibrated on unlabeled data and adapted during inference, governs the update frequency. On GPT-2 models from 124M to 1.5B parameters, the method achieves 82–89% Oracle Recovery while remaining fully training-free.
| Model | SKIP | Oracle | Ours | Recovery |
|---|---|---|---|---|
| Small (124M) | 2.324 | 1.935 | 1.977 | 89.2% |
| Medium (355M) | 1.909 | 1.653 | 1.697 | 82.8% |
| Large (774M) | 2.005 | 1.580 | 1.656 | 82.1% |
| XL (1.5B) | 1.875 | 1.518 | 1.576 | 83.8% |
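Oracle Recovery is not spelled out in this excerpt, but the table (whose SKIP/Oracle/Ours columns appear to be evaluation losses, lower is better) is consistent with the natural reading: the fraction of the Oracle's loss reduction over always skipping that the gated method achieves.

$$
\text{Recovery} = \frac{\text{SKIP} - \text{Ours}}{\text{SKIP} - \text{Oracle}},
\qquad \text{e.g. Small: } \frac{2.324 - 1.977}{2.324 - 1.935} \approx 89.2\%.
$$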
## Setup

This codebase is implemented in JAX and has been tested on both GPUs and Cloud TPU VMs.

```bash
# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh
# Install the project
uv pip install -e . # CPU
uv pip install -e . --group gpu # CUDA 13
uv pip install -e . --group tpu  # TPU
```

## Usage

A forward pass with `use_ttt=True` exposes the TTT reconstruction loss, which gates the update decision:

```python
output = model(input_ids, use_ttt=True)
recon_loss = output["ttt_stats"]["ttt_loss_step_0"]
if recon_loss > threshold:
    # UPDATE: re-forward with updated weights
    pass
else:
    # SKIP: use current weights
    pass
```
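The snippet above assumes a calibrated `threshold`. The exact calibration and adaptation procedure is not shown in this README, so the following is a minimal, self-contained sketch under assumed mechanics: a quantile rule on reconstruction losses from unlabeled calibration data, plus a simple online control update at inference time. `calibrate_threshold` and `adapt_threshold` are illustrative names, not the repo's API.

```python
import numpy as np

def calibrate_threshold(recon_losses, update_rate=0.5):
    """Set the threshold to the (1 - update_rate)-quantile of reconstruction
    losses on unlabeled data, so roughly `update_rate` of inputs exceed it
    and trigger an UPDATE (assumed quantile rule)."""
    return float(np.quantile(np.asarray(recon_losses), 1.0 - update_rate))

def adapt_threshold(threshold, did_update, update_rate=0.5, step=1e-2):
    """Online adaptation: nudge the threshold so the realized UPDATE
    frequency tracks the target rate. Updating too often raises the
    threshold; skipping lowers it (assumed control rule)."""
    return threshold + step * (float(did_update) - update_rate)

# Demo with synthetic reconstruction losses standing in for model outputs.
rng = np.random.default_rng(0)
threshold = calibrate_threshold(rng.normal(2.0, 0.3, size=512))
for recon_loss in rng.normal(2.0, 0.3, size=100):
    did_update = recon_loss > threshold   # UPDATE vs. SKIP decision
    threshold = adapt_threshold(threshold, did_update)
```

Under these assumptions, the quantile rule ties the threshold directly to a compute budget (the target fraction of UPDATE steps), which is one plausible way to realize calibration on unlabeled data.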
Run the experiment scripts:

```bash
./scripts/run_all_experiments.sh          # All models
./scripts/run_all_experiments.sh --small  # Small (124M)
./scripts/run_all_experiments.sh --xl     # XL (1.5B)
```

## Citation

```bibtex
@article{sim2025ponderttt,
title={When to Ponder: Adaptive Compute Allocation for Code Generation via Test-Time Training},
author={Sim, Gihyeon},
journal={arXiv preprint arXiv:2601.00894},
year={2025}
}
```

## License

This project is licensed under the MIT License.