15 changes: 15 additions & 0 deletions .gitignore
@@ -0,0 +1,15 @@
BLIP_MRI/project/hf_results
BLIP_MRI/project/hf_logs
BLIP_MRI/logs
BLIP_MRI/project/wandb
BLIP_MRI/project/dataset/__pycache__
BLIP_MRI/project/model/__pycache__
BLIP_MRI/project/utils/__pycache__
BLIP_MRI/project/*.json
BLIP_MRI/project/data/*.json
BLIP_MRI/project/check_metadata.py
BLIP_MRI/project/data/abcd
BLIP_MRI/project/data/gard
BLIP_MRI/project/data/ukb
__pycache__/
*.pyc
127 changes: 127 additions & 0 deletions BLIP_MRI/README.md
@@ -1 +1,128 @@
# LLaVA-NeXT-Interleave for Brain MRI Comparison Tasks

Multi-image comparison framework using LLaVA-NeXT-Interleave for brain MRI analysis.

---

## Overview

**Architecture:** LLaVA-NeXT-Interleave (Qwen-0.5b)
**Task:** Reference-augmented comparison
**Format:** Multi-turn conversation with interleaved images

**Example:**
```
Turn 1 (Reference):
User: "Here is a brain scan from a male participant. <image>"
Assistant: "Understood. I've analyzed the reference scan."

Turn 2 (Query):
User: "Compare this scan with the reference. <image> What is the sex?"
Assistant: "Based on the comparison, the subject is male."
```
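
A single record in the generated comparison JSON plausibly encodes this conversation as interleaved turns plus an image list. The sketch below is an assumption based on common LLaVA-style datasets; the key names (`images`, `conversations`, `from`, `value`) are not confirmed by this repo:

```python
# Hypothetical record layout (key names are assumptions, not the repo's schema).
record = {
    "images": [
        "data/abcd/reference_scan.nii.gz",  # Turn 1 <image>
        "data/abcd/query_scan.nii.gz",      # Turn 2 <image>
    ],
    "conversations": [
        {"from": "human", "value": "Here is a brain scan from a male participant. <image>"},
        {"from": "gpt", "value": "Understood. I've analyzed the reference scan."},
        {"from": "human", "value": "Compare this scan with the reference. <image> What is the sex?"},
        {"from": "gpt", "value": "Based on the comparison, the subject is male."},
    ],
}
```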

---

## 1. Data Preparation

### Generate Comparison JSON

#### Sex Comparison (or any categorical attribute)

```bash
python generate_json_comparison_split_general.py \
--study_sample ABCD \
--meta_path /path/to/phenotype.csv \
--img_dir /path/to/images \
--target_col sex \
--task_type categorical \
--output_dir ./ \
--num_pairs 3 \
--seed 1234
```

**Output:**
- `data/ABCD_sex_comparison_tasks_train.json`
- `data/ABCD_sex_comparison_tasks_val.json`
- `data/ABCD_sex_comparison_tasks_test.json`

#### BMI_sds Regression (or any numerical attribute)

```bash
python generate_json_comparison_split_general.py \
--study_sample ABCD \
--meta_path /path/to/phenotype.csv \
--img_dir /path/to/images \
--target_col BMI_sds \
--task_type numerical \
--output_dir ./ \
--num_pairs 3 \
--seed 1234
```

**Output:**
- `data/ABCD_BMI_sds_comparison_tasks_train.json`
- `data/ABCD_BMI_sds_comparison_tasks_val.json`
- `data/ABCD_BMI_sds_comparison_tasks_test.json`

**Key Parameters:**
- `--num_pairs`: Number of reference subjects paired with each query subject
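
Before training, it is worth sanity-checking a generated split. A minimal sketch, assuming the `images` key from the record layout sketched in the Overview (adjust the key name to the actual schema):

```python
# Quick look at a generated split (hypothetical helper, not part of the repo).
import json
from pathlib import Path

records = json.loads(Path("data/ABCD_sex_comparison_tasks_train.json").read_text())
print(f"{len(records)} records")

# Flag image paths referenced in the JSON that do not exist on disk.
missing = [p for r in records for p in r.get("images", []) if not Path(p).exists()]
print(f"{len(missing)} missing image paths")
for p in missing[:5]:
    print("  missing:", p)
```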

---

## 2. Data Split Logic

**Complete Separation:**
- **Inter-split:** Train/Val/Test subjects do NOT overlap
- **Intra-split:** Query and Reference pools do NOT overlap within each split

**Example (1000 subjects, 70/15/15 split):**
```
Train: 700 subjects
├─ Query: 350 subjects
└─ Reference: 350 subjects (different from query!)

Val: 150 subjects
├─ Query: 75 subjects
└─ Reference: 75 subjects

Test: 150 subjects
├─ Query: 75 subjects
└─ Reference: 75 subjects
```

**Why?** Test subjects never appear during training, even as references, so test performance reflects true generalization.
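
A minimal sketch of this split logic, assuming a simple shuffle-then-slice implementation (the repo's `generate_json_comparison_split_general.py` may differ in detail):

```python
# Illustrative split: disjoint train/val/test, then disjoint query/reference pools.
import random

def disjoint_splits(subjects, seed=1234, train=0.7, val=0.15):
    rng = random.Random(seed)
    subjects = sorted(subjects)  # deterministic base order before shuffling
    rng.shuffle(subjects)
    n_train, n_val = int(len(subjects) * train), int(len(subjects) * val)
    splits = {
        "train": subjects[:n_train],
        "val": subjects[n_train:n_train + n_val],
        "test": subjects[n_train + n_val:],
    }
    # Within each split, halve into disjoint query and reference pools.
    return {
        name: {"query": s[: len(s) // 2], "reference": s[len(s) // 2:]}
        for name, s in splits.items()
    }

pools = disjoint_splits([f"sub-{i:04d}" for i in range(1000)])
print({k: (len(v["query"]), len(v["reference"])) for k, v in pools.items()})
```

Run on 1,000 subject IDs, this reproduces the 350/350, 75/75, 75/75 pools shown above.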

---

## 3. Training

### Configure

Edit `config/Brain_LLaVa_train_Deepspeed_joint_multiturn_comparison.yaml`:

```yaml
dataset:
target_col: "sex" # or "age", "diagnosis", etc.

train_json:
- "./data/ABCD_sex_comparison_tasks_train.json"
val_json:
- "./data/ABCD_sex_comparison_tasks_val.json"
test_json:
- "./data/ABCD_sex_comparison_tasks_test.json"

```
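
Before launching training, a quick pre-flight check can confirm the config points at existing split files. A sketch assuming PyYAML is installed and the keys shown above (the training entry point itself is not shown here):

```python
# Pre-flight config check (hypothetical helper, not part of the repo).
from pathlib import Path
import yaml

cfg = yaml.safe_load(Path(
    "config/Brain_LLaVa_train_Deepspeed_joint_multiturn_comparison.yaml"
).read_text())

for key in ("train_json", "val_json", "test_json"):
    for p in cfg["dataset"][key]:
        assert Path(p).exists(), f"{key}: {p} not found"
print("All split files present; target_col =", cfg["dataset"]["target_col"])
```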

---

## 4. Troubleshooting

**Error: File not found**
- Check image paths in JSON match actual files

**Error: Image token mismatch**
- Ensure using updated `Trainer_LLaVA_Next_interleave.py`

**Low metrics**
- Check the data split: are the train/val/test label distributions balanced? (see the sketch below)
- Review generation logs for prediction quality
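
A minimal balance check, assuming the answer is the last assistant turn as in the record sketch from the Overview:

```python
# Label-balance check for a categorical task (hypothetical; adjust to the schema).
import json
from collections import Counter
from pathlib import Path

for split in ("train", "val", "test"):
    records = json.loads(
        Path(f"data/ABCD_sex_comparison_tasks_{split}.json").read_text()
    )
    labels = Counter(r["conversations"][-1]["value"] for r in records)
    print(split, dict(labels))
```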

6 changes: 3 additions & 3 deletions BLIP_MRI/environment_llava.yaml
@@ -1,4 +1,4 @@
-name: /pscratch/sd/h/heehaw/anaconda/BLIP_MRI_llava
+name: BLIP_MRI_llava
channels:
- conda-forge
dependencies:
@@ -125,7 +125,7 @@ dependencies:
- sympy==1.14.0
- threadpoolctl==3.6.0
- timm==0.4.12
-- tokenizers==0.13.3
+- tokenizers>=0.20
- torch==2.8.0
- torchvision==0.23.0
- tqdm==4.67.1
@@ -138,4 +138,4 @@ dependencies:
- wandb==0.17.0
- xxhash==3.5.0
- yarl==1.20.1
-prefix: /pscratch/sd/h/heehaw/anaconda/BLIP_MRI_llava
+prefix: /YOUR_DIRECTORY
@@ -0,0 +1,40 @@
wandb:
API_KEY: "YOUR_API_KEY"

seed: 1234

dataset:
# Target column
# Single-task:
target_col: "sex" # Options: 'sex', 'age', 'diagnosis', 'bmi', etc.

# Multi-task (future support):
# target_col: ["sex", "age"] # List for multiple targets

train_json:
- "./data/ABCD_sex_comparison_tasks_train.json"
val_json:
- "./data/ABCD_sex_comparison_tasks_val.json"
test_json:
- "./data/ABCD_sex_comparison_tasks_test.json"

# Image size
img_size: [120, 120, 120]


model:
hf_name: "llava-hf/llava-interleave-qwen-0.5b-hf"
patch_size: [10, 10, 10]
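  # Note (assumption): with img_size [120,120,120] and non-overlapping patches of
  # [10,10,10], each volume yields 12^3 = 1728 visual tokens.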


trainer:
max_epochs: 50
learning_rate: 0.00005
warmup_steps: 500
weight_decay: 0.01
per_device_batch_size: 1 # Multi-image (reference + query) requires more memory
gradient_accumulation_steps: 4
gradient_checkpointing: True
logging_steps: 1
ckpt_dir: "./hf_results/{}/last.ckpt"
resume_training: False