Skip to content

Commit

Permalink
position based strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
victorcaquilpan committed Jul 19, 2024
1 parent adaad81 commit 55bfb29
Show file tree
Hide file tree
Showing 5 changed files with 405 additions and 139 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
* **Refinement MaskGAN version**: An extended model, which refined images using a simple, yet effective multi-stage, multi-plane approach is develop to improve the volumetric definition of synthetic images.
* **Model enhancements**: We include selection strategies to choose similar MRI/CT matches based on the position of slices.


## MaskGAN Framework

A novel unsupervised MR-to-CT synthesis method that preserves the anatomy under the explicit supervision of coarse masks without using costly manual annotations. MaskGAN bypasses the need for precise annotations, replacing them with standard (unsupervised) image processing techniques, which can produce coarse anatomical masks.
A novel unsupervised MR-to-CT synthesis method that preserves the anatomy under the explicit supervision of coarse masks without using costly manual annotations. MaskGAN bypasses the need for precise annotations, replacing them with standard (unsupervised) image processing techniques, which can produce coarse anatomical masks.
Such masks, although imperfect, provide sufficient cues for MaskGAN to capture anatomical outlines and produce structurally consistent images.

![Framework](./imgs/maskgan_v2.svg)

## Comparsion with State-of-the-Art Methods on Paediatric MR-CT Synthesis
## Comparison with State-of-the-Art Methods on Paediatric MR-CT Synthesis
![Result](./imgs/results.jpg)


Expand Down
20 changes: 19 additions & 1 deletion data/unaligned_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ def __init__(self, opt):
# self.transform_maskA = get_transform(self.opt, grayscale=(self.input_nc == 1), mask=True)
# self.transform_maskB = get_transform(self.opt, grayscale=(self.output_nc == 1), mask=True)

# Save relative position of each img Input images should be given in the format xxxx_RELATIVEPOSITION.jpg
self.relative_pos_A = [int(img.split(".")[-2].split("_")[-1]) for img in self.A_paths]
self.relative_pos_B = [int(img.split(".")[-2].split("_")[-1]) for img in self.B_paths]
# Define range of adjacent slices to consider
if opt.phase == 'train':
self.position_based_range = opt.position_based_range*10


def __getitem__(self, index):
"""Return a data point and its metadata information.
Expand All @@ -66,7 +74,17 @@ def __getitem__(self, index):
if self.opt.serial_batches: # make sure index is within then range
index_B = index % self.B_size
else: # randomize the index for domain B to avoid fixed pairs.
index_B = random.randint(0, self.B_size - 1)
# Check the relative position of the image (Position based selection PBS)
A_path_spplited = A_path.split(".")
A_relative_position = A_path_spplited[-2].split("_")[-1]
# Convert to a number
A_relative_position = float(A_relative_position)
# Obtain the images in a similar range (Position based selection)
potential_indexes = [index for index, value in enumerate(self.relative_pos_B) if (A_relative_position-self.position_based_range) <= value <= (A_relative_position + self.position_based_range)]
# Define position of B image
potential_indexes = list(set(potential_indexes) & set(potential_indexes))
index_position = random.randint(0, len(potential_indexes) - 1)
index_B = potential_indexes[index_position]
B_path = self.B_paths[index_B]
maskA_path = self.maskA_paths[index_A]
maskB_path = self.maskB_paths[index_B]
Expand Down
1 change: 1 addition & 0 deletions options/train_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def initialize(self, parser):
parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
parser.add_argument('--position_based_range', type = int, default=3, help='Define the range for the position-based selection strategy (PBS). In percentage')

self.isTrain = True
return parser
47 changes: 31 additions & 16 deletions preprocess/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Preprocess MR-CT data and generate masks
- For simplicity, we assume the dataset have all pairs MRI-CT.
- The simplified code only has a single for-loop to partition 80/20/20 train/val/test.
- If you have an unpaired training set, i.e., source and target modalities do not match. You can simply separate the training data preprocessing and val/test data preprocessing by copy-paste the for-loop.
- For simplicity, we assume the dataset have MRI and CT scans.
- This code assumes that you have the raw data in three folders: **train**, **val** and **test**.
- train set contains only **unpaired images**, whereas val and set contain **paired images**.

## Environment installation
Setup using `pip install -r requirements.txt`
Expand All @@ -10,19 +10,30 @@ Setup using `pip install -r requirements.txt`
- Refer to your root folder as `root`. Assume your data structure is as follows
```bash
├── root/
│ ├── MRI/
│ │ ├── filename001.nii
│ │ ├── filename002.nii
│ │ └── ...
│ └── CT/
│ ├── filename001.nii
│ ├── filename002.nii
│ └── ...
```
- If your data structure is different, please modify the pattern matching expression at lines 138-139 in `preprocess/main.py`:
```python
root_a = f'{data_dir}/MRI/*.nii'
root_b = f'{data_dir}/CT/*.nii'
├── train/
| ├── MRI/
│ │ ├── filename001.nii
│ │ ├── filename002.nii
│ │ └── ...
| └── CT/
| ├── filename001.nii
| ├── filename002.nii
| └── ...
├── val/
| ├── MRI/
| | ├── filename003.nii
| | └── filename004.nii
| └── CT/
| ├── filename003.nii
| └── filename004.nii
└── test/
├── MRI/
| ├── filename005.nii
| └── filename006.nii
└── CT/
├── filename005.nii
└── filename006.nii

```

## Preprocess
Expand All @@ -32,3 +43,7 @@ root_b = f'{data_dir}/CT/*.nii'
- `--resample`: resample the resolution of the medical scans, default is [1.0, 1.0, 1.0] mm^3.

Since our paediatric scans have irregular sizes, we need to crop the depth and height dimensions in function `crop_scan()` at Ln 47. When running, the preprocessed 2D slice visualizations are saved under `vis` for your inspection. Use them to modify data augmentation `crop_scan()` as needed.

## Update!

Now, to use the based position selection strategy, our preprocessing stage generate files as: filename_XXX.jpg, where XXX corresponds to the relative position of the slice respect to the entire volumetric image. In that way, our model can choose slices of similar position.
Loading

0 comments on commit 55bfb29

Please sign in to comment.