Implementation of R-NCE (Residual Noise Contrastive Estimation) for learning disease-related biomarkers from brain MRI patches.
R-NCE combines two learning objectives:
- Metadata Prediction: Encoder representations (h) learn to predict known features (e.g., cortical thickness, subcortical volume)
- Residualized Contrastive Learning: Contrastive loss applied to residualized embeddings that are orthogonal to known features
The key innovation:
- Representations (h) from the encoder predict metadata features via residualization loss
- Residualized embeddings are obtained by removing the metadata signal from the encoder representations (h)
- Contrastive loss is applied to the residualized embeddings, encouraging augmentation-invariant features orthogonal to metadata
This results in representations containing both known biological features AND novel, augmentation-invariant information.
Installation:

```bash
git clone https://github.com/yourusername/rnce.git
cd rnce
pip install -e .
```

Requirements: Python >= 3.8, PyTorch >= 2.0.0, TorchIO >= 0.18.0
R-NCE expects pre-extracted brain patches as .npy or .pt files, plus a metadata CSV.
Data structure:

```
data/
├── patches/
│   ├── left_cingulate/
│   │   ├── subject_0001.npy   # Shape: [1, 40, 122, 82]
│   │   ├── subject_0002.npy
│   │   └── ...
└── metadata.csv
```
Example brain region patch sizes (voxels):
| Region | Left | Right |
|---|---|---|
| MTL | 57 x 107 x 64 | 57 x 107 x 64 |
| LTL | 70 x 115 x 90 | 70 x 115 x 90 |
| Occipital | 79 x 87 x 82 | 80 x 87 x 82 |
| Parietal | 86 x 108 x 94 | 87 x 108 x 94 |
| Subcortical | 63 x 90 x 75 | 63 x 90 x 75 |
| Cingulate | 40 x 122 x 82 | 41 x 122 x 82 |
| Medial Frontal | 50 x 138 x 128 | 50 x 138 x 128 |
| Posterior Frontal | 78 x 98 x 99 | 78 x 98 x 99 |
| Anterior Frontal | 74 x 84 x 104 | 74 x 84 x 104 |
metadata.csv format:

```csv
subject_id,patch_path,cortical_thickness,subcortical_volume,surface_area
0001,data/patches/left_cingulate/subject_0001.npy,2.45,7500.2,1850.3
0002,data/patches/left_cingulate/subject_0002.npy,2.38,7200.8,1820.1
...
```

Metadata columns should contain FreeSurfer features (cortical thickness, subcortical volumes, surface area) that will be predicted by the encoder representations.
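For a quick sanity check of this layout, here is a minimal sketch using only numpy/pandas (paths and column names follow the CSV above):

```python
import numpy as np
import pandas as pd

# Load the metadata table and one subject's patch
meta = pd.read_csv("data/metadata.csv")
row = meta.iloc[0]

patch = np.load(row["patch_path"])
print(patch.shape)   # e.g. (1, 40, 122, 82) for a left cingulate patch
print(row[["cortical_thickness", "subcortical_volume", "surface_area"]])
```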
See configs/left_cingulate.yaml:
```yaml
model:
  rep_dim: 64
  proj_dim: 128

data:
  metadata_csv: "data/metadata.csv"
  metadata_columns:
    - lh_caudalanteriorcingulate_thickness
    - lh_isthmuscingulate_thickness
    - lh_posteriorcingulate_thickness
    - lh_rostralanteriorcingulate_thickness
  train_val_split: 0.85

training:
  batch_size: 128
  num_epochs: 400
  lr: 0.00003
  optimizer: 'Adam'
  temp: 0.5
  scale_contrastive: 1.0
  scale_residual: 1.0
```

Training:

```bash
python scripts/train.py --config configs/left_cingulate.yaml --device cuda
```

Extract embeddings:

```bash
python scripts/extract_embeddings.py \
    --checkpoint checkpoints/best_model.pt \
    --metadata_csv data/metadata.csv \
    --output embeddings.pt
```

Python API:

```python
from rnce import RNCEModel, RNCETrainer, load_config
config = load_config('my_config.yaml')

# train_loader/val_loader, train_metadata/val_metadata, and the optimizer are
# assumed to be built beforehand from your patches and metadata.csv
model = RNCEModel(
    rep_dim=config.model.rep_dim,
    proj_dim=config.model.proj_dim,
    metadata_train=train_metadata,
    metadata_val=val_metadata
)

trainer = RNCETrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    config=config,
    train_metadata=train_metadata,
    val_metadata=val_metadata
)

history = trainer.train(num_epochs=100)
```

Custom augmentations:

```python
from rnce.data import get_augmentation_transform
transform = get_augmentation_transform({
    'affine': {'degrees': 5},           # random affine rotation range (degrees)
    'blur': {'std_range': [0, 0.3]},    # Gaussian blur std range
    'noise': {'std_range': [0, 0.05]},  # additive noise std range
})
```

Inference:

```python
import torch
from rnce import RNCEModel
# `metadata`, `patch`, and `sample_indices` are assumed to be prepared as in training
checkpoint = torch.load('checkpoints/best_model.pt')
model = RNCEModel(rep_dim=64, proj_dim=128, metadata_train=metadata)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

with torch.no_grad():
    out = model(patch, sample_indices)
    h = out['h']  # Encoder representations (metadata info + novel features)
```

Encoder: 3D CNN architecture (see the sketch after this list):
- Input: [1, D, H, W] brain patch
- 2 conv layers (1→8→8)
- 3 blocks (8→16→32→64)
- 2 conv layers (64→128→rep_dim)
- MaxPool + FC
- Output: [rep_dim] (default 64)
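A minimal sketch of an encoder matching this description; kernel sizes, strides, and normalization are assumptions, not the repository's exact layers:

```python
import torch
import torch.nn as nn

class Encoder3D(nn.Module):
    """Sketch of the 3D CNN encoder described above (layer details assumed)."""

    def __init__(self, rep_dim: int = 64):
        super().__init__()

        def block(c_in, c_out):
            # One downsampling block: stride-2 conv -> norm -> ReLU
            return nn.Sequential(
                nn.Conv3d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm3d(c_out),
                nn.ReLU(inplace=True),
            )

        self.net = nn.Sequential(
            nn.Conv3d(1, 8, 3, padding=1), nn.ReLU(inplace=True),     # 2 conv layers (1→8→8)
            nn.Conv3d(8, 8, 3, padding=1), nn.ReLU(inplace=True),
            block(8, 16), block(16, 32), block(32, 64),               # 3 blocks (8→16→32→64)
            nn.Conv3d(64, 128, 3, padding=1), nn.ReLU(inplace=True),  # 2 conv layers (64→128→rep_dim)
            nn.Conv3d(128, rep_dim, 3, padding=1),
            nn.AdaptiveMaxPool3d(1),                                  # MaxPool
            nn.Flatten(),
        )
        self.fc = nn.Linear(rep_dim, rep_dim)                         # FC → [rep_dim]

    def forward(self, x):  # x: [B, 1, D, H, W]
        return self.fc(self.net(x))
```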
Residualization removes the metadata signal from the encoder representations (see the sketch after this list):
- Precomputes (X_meta^T X_meta)^-1
- Applies: h_resid = (I - batch_ratio * X_meta @ inv @ X_meta^T) @ h
- Creates residualized representations orthogonal to metadata, used for contrastive learning
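A minimal sketch of this step, directly implementing the formula above (variable names are illustrative; the exact definition of `batch_ratio` follows the repository config):

```python
import torch

def residualize(h, x_meta, inv, batch_ratio):
    """Project h onto the orthogonal complement of the metadata span.

    h:           [B, rep_dim] encoder representations for the batch
    x_meta:      [B, m]       metadata features for the same samples
    inv:         [m, m]       precomputed (X_meta^T X_meta)^-1
    batch_ratio: scalar scaling factor from the formula above
    """
    proj = x_meta @ inv @ x_meta.T                 # [B, B] projection onto metadata span
    eye = torch.eye(h.shape[0], device=h.device)
    return (eye - batch_ratio * proj) @ h          # h_resid, used for contrastive learning
```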
Projection head: 2-layer MLP, Linear(rep_dim → proj_dim) → ReLU → Linear(proj_dim → proj_dim).
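In PyTorch this head is simply:

```python
import torch.nn as nn

def projection_head(rep_dim: int = 64, proj_dim: int = 128) -> nn.Module:
    # 2-layer MLP projection head, as described above
    return nn.Sequential(
        nn.Linear(rep_dim, proj_dim),
        nn.ReLU(inplace=True),
        nn.Linear(proj_dim, proj_dim),
    )
```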
Contrastive loss: InfoNCE applied to the residualized embeddings z:

L_contrastive = -log( exp(sim(z_i, z_j) / τ) / Σ_{k≠i} exp(sim(z_i, z_k) / τ) )

where (z_i, z_j) are the two augmented views of the same patch and the sum runs over the other embeddings in the batch.
Encourages augmentation-invariant features orthogonal to metadata.
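A minimal sketch of this loss in its standard NT-Xent form (`temp` corresponds to the `temp` config field; the batch layout is an assumption):

```python
import torch
import torch.nn.functional as F

def info_nce(z1, z2, temp=0.5):
    """InfoNCE/NT-Xent over two views; row i of z1 and z2 form the positive pair.

    z1, z2: [B, proj_dim] residualized, projected embeddings of two augmented views
    """
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)         # [2B, proj_dim]
    sim = z @ z.T / temp                                       # cosine similarity / τ
    n = z.shape[0]
    mask = torch.eye(n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, float('-inf'))                 # exclude k = i from the denominator
    targets = torch.arange(n, device=z.device).roll(n // 2)    # positive of sample i is i ± B
    return F.cross_entropy(sim, targets)
```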
Residualization loss: encourages the encoder representations h to predict the metadata:
L_resid = ||residual||_F
where residual = (I - H @ pinv(H)) @ X_meta
Minimized when h captures metadata information.
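A minimal sketch, directly implementing the formula above:

```python
import torch

def residual_loss(h, x_meta):
    """Frobenius norm of the part of X_meta not explained by H.

    h:      [B, rep_dim] encoder representations
    x_meta: [B, m]       metadata features
    """
    eye = torch.eye(h.shape[0], device=h.device)
    proj = h @ torch.linalg.pinv(h)                  # projection onto the column span of H
    return torch.linalg.norm((eye - proj) @ x_meta)  # Frobenius norm by default for matrices
```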
L_total = α * L_contrastive(z) + β * L_resid(h)
Default: α = 1.0, β = 1.0 (the `scale_contrastive` and `scale_residual` config fields).
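Composing the sketches above, one hypothetical training step could compute the total loss as follows (`augment` stands in for applying the augmentation transform to each patch; this is an illustration, not the repository's trainer):

```python
def rnce_loss(encoder, projector, augment, patches, x_meta, inv, batch_ratio,
              alpha=1.0, beta=1.0, temp=0.5):
    v1, v2 = augment(patches), augment(patches)   # two augmented views, [B, 1, D, H, W]
    h1, h2 = encoder(v1), encoder(v2)             # encoder representations h, [B, rep_dim]
    z1 = projector(residualize(h1, x_meta, inv, batch_ratio))
    z2 = projector(residualize(h2, x_meta, inv, batch_ratio))
    # α · contrastive loss on residualized embeddings + β · metadata-prediction loss on h
    return alpha * info_nce(z1, z2, temp) + beta * residual_loss(h1, x_meta)
```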