A PyTorch implementation of retinal vessel segmentation.
The DRIVE dataset is provided in the ./data directory.
Three U-Net variants are defined in the ./models directory:
- original_unet: implemented as the paper describes
- net2: adds padding to the convolution/deconvolution kernels so that the output keeps the input shape
- net_improve: adds batch normalization to each layer for faster convergence
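The effect of the padding in net2 can be checked with the standard convolution output-size formula. The sketch below is only an illustration: the helper name conv_out and the 48-pixel patch size are assumptions, not values taken from this repository.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Spatial size after a convolution (dilation 1):
    floor((size + 2 * padding - kernel) / stride) + 1
    """
    return (size + 2 * padding - kernel) // stride + 1


# Two successive 3x3 convolutions, as in one U-Net encoder block.
# The 48-pixel input size is a hypothetical patch size.
unpadded = conv_out(conv_out(48, 3), 3)                        # 48 -> 46 -> 44
padded = conv_out(conv_out(48, 3, padding=1), 3, padding=1)    # 48 -> 48 -> 48

print(unpadded, padded)
```

Without padding, every 3x3 convolution trims two pixels, so the output map is smaller than the input. With padding=1 the size is preserved, which is what lets the prediction line up with the input patch pixel for pixel.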
A data prefetcher is used to speed up data loading. patch_per_img is set to 19000 here, so loading on the CPU can be very slow.
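The general idea of prefetching is to prepare the next batches in the background while the current one is being consumed. The sketch below shows that idea with standard-library threads only. It is not this repository's prefetcher, which may work differently (for example by overlapping host-to-GPU copies); the class name BackgroundPrefetcher and the depth parameter are hypothetical.

```python
import queue
import threading


class BackgroundPrefetcher:
    """Iterate over `source`, preparing up to `depth` items ahead in a worker thread."""

    _DONE = object()  # sentinel that marks the end of the source iterable

    def __init__(self, source, depth=2):
        self._queue = queue.Queue(maxsize=depth)
        self._thread = threading.Thread(
            target=self._fill, args=(source,), daemon=True
        )
        self._thread.start()

    def _fill(self, source):
        for item in source:
            self._queue.put(item)      # blocks while the queue is full
        self._queue.put(self._DONE)

    def __iter__(self):
        return self

    def __next__(self):
        item = self._queue.get()       # blocks until the worker has an item ready
        if item is self._DONE:
            raise StopIteration
        return item


# Any iterable works as the source; range(5) stands in for a batch loader here.
batches = list(BackgroundPrefetcher(range(5)))
print(batches)
```

In a training loop the source would be the batch loader, so reading the next batch overlaps with the forward and backward pass of the current one. This sketch does not forward exceptions raised inside the worker thread; a real implementation would need to handle that.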
Generate the training dataset:
- python generate_train_dataset.py
The HDF5 files will be saved in the hdf5 folder.
Set the number of epochs, the device, and other options in config.py:
- save_pth: name of the folder where checkpoints are saved
- batch_size
- device: a PyTorch device string (cpu, cuda, cuda:0, etc.)
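As a rough illustration, the relevant part of config.py might look like the fragment below. Only save_pth, batch_size, and device are names taken from this README; the name epochs and all of the values are assumptions, so check the actual file before editing.

```python
# config.py (illustrative values only)
epochs = 100               # hypothetical name and value for the epoch count
save_pth = "checkpoints"   # folder where checkpoints are saved (hypothetical value)
batch_size = 64            # hypothetical value
device = "cuda:0"          # PyTorch device string: "cpu", "cuda", "cuda:0", ...
```

Setting device to "cpu" lets the code run without a GPU, at the cost of much slower training.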
Then run:
- python main.py