Code Repository for R-NCE (Residual Noise Contrastive Estimation)

batmanlab/R-NCE

R-NCE: Residual Noise Contrastive Estimation

Implementation of R-NCE (Residual Noise Contrastive Estimation) for learning disease-related biomarkers from brain MRI patches.

Overview

R-NCE combines two learning objectives:

  1. Metadata Prediction: Encoder representations (h) learn to predict known features (e.g., cortical thickness, subcortical volume)
  2. Residualized Contrastive Learning: Contrastive loss applied to residualized embeddings that are orthogonal to known features

The key innovation:

  • Representations (h) from the encoder predict metadata features via residualization loss
  • Residualized embeddings (z) are obtained by removing the metadata signal from h
  • Contrastive loss is applied to z, encouraging augmentation-invariant features orthogonal to metadata

This results in representations containing both known biological features AND novel, augmentation-invariant information.

Installation

git clone https://github.com/batmanlab/R-NCE.git
cd R-NCE
pip install -e .

Requirements: Python >= 3.8, PyTorch >= 2.0.0, TorchIO >= 0.18.0

Quick Start

1. Prepare Data

R-NCE expects pre-extracted brain patches as .npy or .pt files, plus a metadata CSV.

Data structure:

data/
├── patches/
│   ├── left_cingulate/
│   │   ├── subject_0001.npy      # Shape: [1, 40, 122, 82]
│   │   ├── subject_0002.npy
│   │   └── ...
└── metadata.csv

Example brain region patch sizes:

Region              Left              Right
MTL                 57 x 107 x 64     57 x 107 x 64
LTL                 70 x 115 x 90     70 x 115 x 90
Occipital           79 x 87 x 82      80 x 87 x 82
Parietal            86 x 108 x 94     87 x 108 x 94
Subcortical         63 x 90 x 75      63 x 90 x 75
Cingulate           40 x 122 x 82     41 x 122 x 82
Medial Frontal      50 x 138 x 128    50 x 138 x 128
Posterior Frontal   78 x 98 x 99      78 x 98 x 99
Anterior Frontal    74 x 84 x 104     74 x 84 x 104

metadata.csv format:

subject_id,patch_path,cortical_thickness,subcortical_volume,surface_area
0001,data/patches/left_cingulate/subject_0001.npy,2.45,7500.2,1850.3
0002,data/patches/left_cingulate/subject_0002.npy,2.38,7200.8,1820.1
...

Metadata columns should contain FreeSurfer features (cortical thickness, subcortical volumes, surface area) that will be predicted by encoder representations.
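Before training, it can help to confirm that every row of the CSV points at a loadable patch with the expected layout. The following is a minimal sketch, not part of the repository: `check_dataset` is a hypothetical helper, it assumes the `subject_id` and `patch_path` columns shown in the example above, treats every other column as a numeric feature, and only handles `.npy` patches.

```python
import csv
from pathlib import Path

import numpy as np


def check_dataset(metadata_csv, expected_shape=None):
    """Return a list of human-readable problems found in the dataset.

    Hypothetical helper: assumes 'subject_id' and 'patch_path' columns as in
    the example CSV; every other column is treated as a numeric feature.
    Only .npy patches are handled here.
    """
    problems = []
    with open(metadata_csv, newline="") as f:
        rows = list(csv.DictReader(f))

    for row in rows:
        sid = row["subject_id"]
        path = Path(row["patch_path"])
        if not path.exists():
            problems.append(f"{sid}: missing patch file {path}")
            continue

        # Patches are expected as [1, D, H, W] arrays
        patch = np.load(path)
        if patch.ndim != 4 or patch.shape[0] != 1:
            problems.append(f"{sid}: shape {patch.shape} is not [1, D, H, W]")
        elif expected_shape is not None and patch.shape != tuple(expected_shape):
            problems.append(f"{sid}: shape {patch.shape} != {tuple(expected_shape)}")

        # Every remaining column should parse as a finite number
        for name, value in row.items():
            if name in ("subject_id", "patch_path"):
                continue
            try:
                ok = np.isfinite(float(value))
            except (TypeError, ValueError):
                ok = False
            if not ok:
                problems.append(f"{sid}: non-numeric value in column {name!r}")
    return problems
```

For the left cingulate example, a call such as `check_dataset("data/metadata.csv", expected_shape=(1, 40, 122, 82))` would return an empty list when all patches and feature values look consistent.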

2. Configure Training

See configs/left_cingulate.yaml:

model:
  rep_dim: 64
  proj_dim: 128

data:
  metadata_csv: "data/metadata.csv"
  metadata_columns:
    - lh_caudalanteriorcingulate_thickness
    - lh_isthmuscingulate_thickness
    - lh_posteriorcingulate_thickness
    - lh_rostralanteriorcingulate_thickness
  train_val_split: 0.85

training:
  batch_size: 128
  num_epochs: 400
  lr: 0.00003
  optimizer: 'Adam'
  temp: 0.5
  scale_contrastive: 1.0
  scale_residual: 1.0

3. Train

python scripts/train.py --config configs/left_cingulate.yaml --device cuda

4. Extract Embeddings

python scripts/extract_embeddings.py \
    --checkpoint checkpoints/best_model.pt \
    --metadata_csv data/metadata.csv \
    --output embeddings.pt

Usage Examples

Training with Custom Config

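from rnce import RNCEModel, RNCETrainer, load_config

# train_loader, val_loader, optimizer, train_metadata, and val_metadata are
# assumed to have been created beforehand; this snippet does not define them.
config = load_config('my_config.yaml')

# Build the model with the representation and projection sizes from the config
model = RNCEModel(
    rep_dim=config.model.rep_dim,
    proj_dim=config.model.proj_dim,
    metadata_train=train_metadata,
    metadata_val=val_metadata
)

trainer = RNCETrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    config=config,
    train_metadata=train_metadata,
    val_metadata=val_metadata
)

# Train for 100 epochs and keep the returned history
history = trainer.train(num_epochs=100)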

Custom Augmentation

from rnce.data import get_augmentation_transform

transform = get_augmentation_transform({
    'affine': {'degrees': 5},
    'blur': {'std_range': [0, 0.3]},
    'noise': {'std_range': [0, 0.05]}
})

Inference

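import torch
from rnce import RNCEModel

# metadata, patch, and sample_indices are assumed to be prepared by the caller.
# map_location='cpu' lets the checkpoint load on machines without a GPU.
checkpoint = torch.load('checkpoints/best_model.pt', map_location='cpu')
# rep_dim and proj_dim must match the values used during training
model = RNCEModel(rep_dim=64, proj_dim=128, metadata_train=metadata)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

with torch.no_grad():
    out = model(patch, sample_indices)
    h = out['h']  # Encoder representations (metadata info + novel features)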

Architecture

PatchConvNet Encoder

3D CNN architecture:

  • Input: [1, D, H, W] brain patch
  • 2 conv layers (1→8→8)
  • 3 blocks (8→16→32→64)
  • 2 conv layers (64→128→rep_dim)
  • MaxPool + FC
  • Output: [rep_dim] (default 64)

Residualization Head

Removes metadata signal from encoder representations:

  • Precomputes inv = (X_meta^T X_meta)^-1
  • Applies: h_resid = (I - batch_ratio * X_meta @ inv @ X_meta^T) @ h
  • Creates residualized representations orthogonal to metadata, used for contrastive learning
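The projection above can be illustrated with a small NumPy sketch. This is not the repository's implementation: `residualize` is a hypothetical function, the data are random, and `batch_ratio` is fixed to 1.0 because its definition is not given here. With that setting the operation is an ordinary least-squares projection onto the orthogonal complement of the metadata columns.

```python
import numpy as np


def residualize(h, x_meta, batch_ratio=1.0):
    """Remove the part of h that is linearly predictable from x_meta.

    h:      [N, rep_dim] encoder representations
    x_meta: [N, n_features] metadata matrix
    """
    inv = np.linalg.inv(x_meta.T @ x_meta)      # (X^T X)^-1, can be precomputed
    projection = x_meta @ inv @ x_meta.T        # [N, N] projection onto span(X)
    return (np.eye(len(h)) - batch_ratio * projection) @ h


rng = np.random.default_rng(0)
x_meta = rng.normal(size=(32, 4))
# Representations that deliberately contain a metadata component
h = rng.normal(size=(32, 8)) + x_meta @ rng.normal(size=(4, 8))

h_resid = residualize(h, x_meta)

# With batch_ratio = 1.0, every metadata column is orthogonal to h_resid,
# so this maximum is zero up to floating-point error.
print(np.abs(x_meta.T @ h_resid).max())
```

Forming the full N x N projection matrix mirrors the formula but is wasteful for large N; computing `h - x_meta @ (inv @ (x_meta.T @ h))` gives the same result for `batch_ratio = 1.0` without it.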

Projection Head

2-layer MLP: Linear(rep_dim → proj_dim) → ReLU → Linear(proj_dim → proj_dim)

Loss Functions

Contrastive Loss

InfoNCE loss applied to residualized embeddings z:

L_contrastive = -log(exp(sim(z_i, z_j) / τ) / Σ_k exp(sim(z_i, z_k) / τ))
where (z_i, z_j) is a positive pair and τ is the temperature (temp in the training config)

Encourages augmentation-invariant features orthogonal to metadata.
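As a rough illustration of the formula, here is a minimal NumPy sketch of a one-directional InfoNCE loss. Several choices are assumptions rather than facts about this repository: `info_nce` is a hypothetical name, `sim` is taken to be cosine similarity, `z_a` and `z_b` hold embeddings of two augmented views of the same batch, and the candidates for each anchor are all second-view embeddings in the batch.

```python
import numpy as np


def info_nce(z_a, z_b, temperature=0.5):
    """One-directional InfoNCE: row i of z_a should match row i of z_b.

    z_a, z_b: [N, d] embeddings of two views of the same N samples.
    """
    # Cosine similarity via L2-normalized dot products (an assumed choice of sim)
    z_a = z_a / np.linalg.norm(z_a, axis=1, keepdims=True)
    z_b = z_b / np.linalg.norm(z_b, axis=1, keepdims=True)
    logits = z_a @ z_b.T / temperature                      # [N, N]

    # Log-softmax over candidates, shifted for numerical stability
    logits = logits - logits.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

    # Positives sit on the diagonal; average the negative log-probability
    return -np.mean(np.diag(log_prob))
```

The loss is smallest when each embedding is most similar to its own second view, so passing correctly paired views gives a lower value than passing the same embeddings with the pairing scrambled.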

Residualization Loss

Encourages encoder representations h to predict metadata:

L_resid = ||residual||_F
where residual = (I - H @ pinv(H)) @ X_meta and H stacks the encoder representations h row-wise (one row per sample)

Minimized when h captures metadata information.
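To make the behavior of this loss concrete, the sketch below evaluates it with NumPy on synthetic data. It is an illustration under stated assumptions, not the repository's code: `residualization_loss` is a hypothetical name and the inputs are random. The loss is the Frobenius norm of whatever part of X_meta cannot be written as a linear function of H.

```python
import numpy as np


def residualization_loss(h, x_meta):
    """Frobenius norm of the part of x_meta not linearly predictable from h.

    h:      [N, rep_dim] encoder representations
    x_meta: [N, n_features] metadata matrix
    """
    # Equivalent to (I - H @ pinv(H)) @ X_meta without forming the N x N matrix
    residual = x_meta - h @ (np.linalg.pinv(h) @ x_meta)
    return np.linalg.norm(residual, ord="fro")


rng = np.random.default_rng(0)
x_meta = rng.normal(size=(32, 4))

# Representations that contain the metadata exactly, plus unrelated dimensions
h_good = np.hstack([x_meta, rng.normal(size=(32, 4))])
# Representations unrelated to the metadata
h_bad = rng.normal(size=(32, 8))

loss_good = residualization_loss(h_good, x_meta)  # close to zero
loss_bad = residualization_loss(h_bad, x_meta)    # clearly positive
print(loss_good, loss_bad)
```

When the metadata lies in the column space of H, as for `h_good`, the loss vanishes up to floating-point error; unrelated representations leave most of the metadata unexplained.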

Combined Loss

L_total = α * L_contrastive(z) + β * L_resid(h)

Default: α = 1.0, β = 1.0 (scale_contrastive and scale_residual in the training config)
