Skip to content

Commit

Permalink
Addition of U-Net Model (deepchem#3919)
Browse files Browse the repository at this point in the history
* Adding U-Net Model

* Linting fixes

* Added tests

* Linting fix

* Modified test input size

* Modified test input size to base 2

* Modified test input size for restore

* Added sigmoid activation, fixed tests

* Model restore test fix

* Added docs and usage examples

* Added notes to docs, made test image smaller

* Added overfit test

* Doc fixes, added to model rst

* Added type annotations, added model to cheatsheet

* Fixed typo
  • Loading branch information
aaronrockmenezes authored Apr 19, 2024
1 parent 352f8ef commit 13b7099
Show file tree
Hide file tree
Showing 5 changed files with 298 additions and 0 deletions.
1 change: 1 addition & 0 deletions deepchem/models/torch_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from deepchem.models.torch_models.acnn import AtomConvModel
from deepchem.models.torch_models.progressive_multitask import ProgressiveMultitask, ProgressiveMultitaskModel
from deepchem.models.torch_models.text_cnn import TextCNNModel
from deepchem.models.torch_models.unet import UNet, UNetModel
try:
from deepchem.models.torch_models.dmpnn import DMPNN, DMPNNModel
from deepchem.models.torch_models.gnn import GNN, GNNHead, GNNModular
Expand Down
80 changes: 80 additions & 0 deletions deepchem/models/torch_models/tests/test_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import pytest
import numpy as np
import deepchem as dc
import tempfile


@pytest.mark.torch
def test_unet_forward():
    """Fit a UNetModel for one epoch and check the prediction shape.

    The predictions returned for the training dataset must have exactly the
    same shape as the target segmentation masks.
    """
    from deepchem.models.torch_models import UNetModel

    n_samples, height, width = 5, 16, 16

    # RGB input images paired with single-channel (grey scale) target masks.
    features = np.random.randn(n_samples, 3, height, width).astype(np.float32)
    masks = np.random.rand(n_samples, 1, height, width).astype(np.float32)

    # A NumpyDataset is used here; the restore test covers ImageDataset.
    dataset = dc.data.NumpyDataset(features, masks)

    model = UNetModel(in_channels=3, out_channels=1)
    model.fit(dataset, nb_epoch=1)

    assert model.predict(dataset).shape == masks.shape


@pytest.mark.torch
def test_unet_restore():
    """Check that a restored UNetModel reproduces the trained model's output.

    A model is fitted for one epoch with an explicit ``model_dir``; a second
    model built with the same arguments and ``model_dir`` calls ``restore()``
    and must then give predictions matching the first model within
    ``atol=1e-04``.
    """
    from deepchem.models.torch_models import UNetModel

    # 5 RGB 16x16 pixel input images and 5 grey scale 16x16 pixel output
    # segmentation masks
    input_samples = np.random.randn(5, 3, 16, 16).astype(np.float32)
    output_samples = np.random.rand(5, 1, 16, 16).astype(np.float32)

    # Using ImageDataset for testing
    image_dataset = dc.data.ImageDataset(input_samples, output_samples)

    # The context manager removes the checkpoint directory when the test
    # finishes, whether it passes or fails.
    with tempfile.TemporaryDirectory() as model_dir:
        unet_model = UNetModel(in_channels=3,
                               out_channels=1,
                               model_dir=model_dir)
        unet_model.fit(image_dataset, nb_epoch=1)

        reloaded_model = UNetModel(in_channels=3,
                                   out_channels=1,
                                   model_dir=model_dir)
        reloaded_model.restore()

        pred = unet_model.predict(image_dataset)
        reloaded_pred = reloaded_model.predict(image_dataset)

    # Compare full shapes first: np.allclose broadcasts, so a shape mismatch
    # could otherwise go unnoticed.
    assert pred.shape == reloaded_pred.shape
    assert np.allclose(pred, reloaded_pred, atol=1e-04)


@pytest.mark.torch
def test_unet_overfit():
    """Check that UNetModel can overfit a tiny random dataset.

    After 100 epochs on 5 samples, the mean squared error between the
    predictions and the targets must be below 0.05.
    """
    from deepchem.models.torch_models import UNetModel

    n_samples = 5

    # 5 RGB 16x16 pixel input images and 5 grey scale 16x16 pixel output
    # segmentation masks
    input_samples = np.random.randn(n_samples, 3, 16, 16).astype(np.float32)
    output_samples = np.random.rand(n_samples, 1, 16, 16).astype(np.float32)

    # Using NumpyDataset for testing
    np_dataset = dc.data.NumpyDataset(input_samples, output_samples)

    regression_metric = dc.metrics.Metric(dc.metrics.mean_squared_error,
                                          mode='regression')

    # The context manager removes the model directory once the test is done.
    with tempfile.TemporaryDirectory() as model_dir:
        unet_model = UNetModel(in_channels=3,
                               out_channels=1,
                               model_dir=model_dir)
        unet_model.fit(np_dataset, nb_epoch=100)
        pred = unet_model.predict(np_dataset)

    # Flatten each image so the metric sees one row per sample.
    score = regression_metric.compute_metric(
        np_dataset.y.reshape(n_samples, -1), pred.reshape(n_samples, -1))

    assert score < 0.05, "Failed to overfit"
210 changes: 210 additions & 0 deletions deepchem/models/torch_models/unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
from deepchem.models.torch_models.modular import TorchModel
from deepchem.models.losses import BinaryCrossEntropy
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet(nn.Module):
    """UNet model for image segmentation.

    UNet is a convolutional neural network architecture for fast and precise
    segmentation of images based on the works of Ronneberger et al. [1]_. The
    architecture consists of an encoder, a bottleneck, and a decoder. The
    encoder downsamples the input image to capture the context of the image.
    The bottleneck captures the most important features of the image. The
    decoder upsamples the image to generate the segmentation mask. The encoder
    and decoder are connected by skip connections to preserve spatial
    information.

    Examples
    --------
    Importing necessary modules

    >>> import numpy as np
    >>> import deepchem as dc
    >>> from deepchem.models.torch_models import UNetModel

    Creating a random dataset of 5 32x32 pixel RGB input images and 5 32x32
    pixel grey scale output images

    >>> x = np.random.randn(5, 3, 32, 32).astype(np.float32)
    >>> y = np.random.rand(5, 1, 32, 32).astype(np.float32)
    >>> dataset = dc.data.NumpyDataset(x, y)

    We will create a UNet model with 3 input channels and 1 output channel.
    We will then fit the model on the dataset for 5 epochs and predict the
    output images.

    >>> model = UNetModel(in_channels=3, out_channels=1)
    >>> loss = model.fit(dataset, nb_epoch=5)
    >>> predictions = model.predict(dataset)

    Notes
    -----
    1. This implementation of the UNet model makes some changes to the padding
       of the inputs to the convolutional layers. The padding is set to 'same'
       to ensure that the output size of the convolutional layers is the same
       as the input size. This is done to preserve the spatial information of
       the input image and to keep the output size of the encoder and decoder
       the same.
    2. The input image height and width must be divisible by 2^4 = 16 to
       ensure that the output size of the encoder and decoder is the same.
       ``forward`` raises a ``ValueError`` otherwise.

    References
    ----------
    .. [1] Ronneberger, O., Fischer, P., & Brox, T. (2015, May 18). U-NET:
       Convolutional Networks for Biomedical Image Segmentation. arXiv.org.
       https://arxiv.org/abs/1505.04597
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 1) -> None:
        """
        Parameters
        ----------
        in_channels: int (default 3)
            Number of input channels.
        out_channels: int (default 1)
            Number of output channels.
        """
        super(UNet, self).__init__()

        # Encoder
        self.encoder1 = self.conv_block(in_channels, 64)
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        self.encoder4 = self.conv_block(256, 512)

        # Bottleneck
        self.bottleneck = self.conv_block(512, 1024)

        # Decoder. Each stage receives the upsampled features concatenated
        # with the matching encoder output, hence the summed channel counts.
        self.decoder4 = self.conv_block(1024 + 512, 512)
        self.decoder3 = self.conv_block(512 + 256, 256)
        self.decoder2 = self.conv_block(256 + 128, 128)
        self.decoder1 = self.conv_block(128 + 64, 64)

        # Maxpooling (halves height and width)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Upsampling (doubles height and width)
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)

        # Output: 1x1 convolution mapping 64 features to out_channels
        self.output = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels: int, out_channels: int) -> nn.Sequential:
        """Build two 3x3 'same'-padded convolutions, each followed by ReLU.

        Parameters
        ----------
        in_channels: int
            Number of input channels.
        out_channels: int
            Number of output channels.

        Returns
        -------
        nn.Sequential
            The Conv2d -> ReLU -> Conv2d -> ReLU block.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3,
                      padding='same'),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3,
                      padding='same'),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, bottleneck and decoder on a batch of images.

        Parameters
        ----------
        x: Tensor
            Input tensor of shape (batch, in_channels, height, width), where
            height and width are non-zero multiples of 16.

        Returns
        -------
        x: Tensor
            Output tensor of shape (batch, out_channels, height, width) with
            values in [0, 1] (a sigmoid is applied to the final convolution).

        Raises
        ------
        ValueError
            If `x` is not 4-dimensional or its height/width is not a non-zero
            multiple of 16.
        """
        # Four 2x poolings shrink each spatial dimension by 2**4. Any other
        # size makes the upsampled decoder features disagree with the skip
        # connections, so fail early with a clear message.
        factor = 2**4
        if x.dim() != 4:
            raise ValueError(
                "Expected a 4D input of shape (batch, channels, height, "
                f"width), got a {x.dim()}D tensor")
        height, width = x.shape[-2], x.shape[-1]
        if height == 0 or width == 0 or height % factor or width % factor:
            raise ValueError(
                f"Input height and width must be non-zero multiples of "
                f"{factor}, got {height}x{width}")

        # Encoder: keep each stage's output for the skip connections.
        x1 = self.encoder1(x)
        x = self.maxpool(x1)
        x2 = self.encoder2(x)
        x = self.maxpool(x2)
        x3 = self.encoder3(x)
        x = self.maxpool(x3)
        x4 = self.encoder4(x)
        x = self.maxpool(x4)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder: upsample, then concatenate with the encoder output of the
        # same resolution along the channel dimension.
        x = self.upsample(x)
        x = self.decoder4(torch.cat([x, x4], 1))
        x = self.upsample(x)
        x = self.decoder3(torch.cat([x, x3], 1))
        x = self.upsample(x)
        x = self.decoder2(torch.cat([x, x2], 1))
        x = self.upsample(x)
        x = self.decoder1(torch.cat([x, x1], 1))

        # Output
        x = self.output(x)
        return torch.sigmoid(x)


class UNetModel(TorchModel):
    """UNet model for image segmentation.

    UNet is a convolutional neural network architecture for fast and precise
    segmentation of images based on the works of Ronneberger et al. [1]_. The
    architecture consists of an encoder, a bottleneck, and a decoder. The
    encoder downsamples the input image to capture the context of the image.
    The bottleneck captures the most important features of the image. The
    decoder upsamples the image to generate the segmentation mask. The encoder
    and decoder are connected by skip connections to preserve spatial
    information.

    Examples
    --------
    Importing necessary modules

    >>> import numpy as np
    >>> import deepchem as dc
    >>> from deepchem.models.torch_models import UNetModel

    Creating a random dataset of 5 32x32 pixel RGB input images and 5 32x32
    pixel grey scale output images

    >>> x = np.random.randn(5, 3, 32, 32).astype(np.float32)
    >>> y = np.random.rand(5, 1, 32, 32).astype(np.float32)
    >>> dataset = dc.data.NumpyDataset(x, y)

    We will create a UNet model with 3 input channels and 1 output channel.
    We will then fit the model on the dataset for 5 epochs and predict the
    output images.

    >>> model = UNetModel(in_channels=3, out_channels=1)
    >>> loss = model.fit(dataset, nb_epoch=5)
    >>> predictions = model.predict(dataset)

    Notes
    -----
    1. This implementation of the UNet model makes some changes to the padding
       of the inputs to the convolutional layers. The padding is set to 'same'
       to ensure that the output size of the convolutional layers is the same
       as the input size. This is done to preserve the spatial information of
       the input image and to keep the output size of the encoder and decoder
       the same.
    2. The input image size must be divisible by 2^4 = 16 to ensure that the
       output size of the encoder and decoder is the same.

    References
    ----------
    .. [1] Ronneberger, O., Fischer, P., & Brox, T. (2015, May 18). U-NET:
       Convolutional Networks for Biomedical Image Segmentation. arXiv.org.
       https://arxiv.org/abs/1505.04597
    """

    def __init__(self,
                 in_channels: int = 3,
                 out_channels: int = 1,
                 **kwargs) -> None:
        """
        Parameters
        ----------
        in_channels: int (default 3)
            Number of input channels.
        out_channels: int (default 1)
            Number of output channels.
        **kwargs:
            Additional keyword arguments forwarded to the ``TorchModel``
            constructor. If no ``loss`` is supplied, ``BinaryCrossEntropy()``
            is used.

        Raises
        ------
        ValueError
            If `in_channels` or `out_channels` is not greater than 0.
        """
        # Error messages name the actual keyword arguments so the caller can
        # tell which one to fix.
        if in_channels <= 0:
            raise ValueError("in_channels must be greater than 0")

        if out_channels <= 0:
            raise ValueError("out_channels must be greater than 0")

        model = UNet(in_channels=in_channels, out_channels=out_channels)

        # Only fill in the default loss; a caller-supplied loss is respected.
        if 'loss' not in kwargs:
            kwargs['loss'] = BinaryCrossEntropy()

        super(UNetModel, self).__init__(model, **kwargs)
1 change: 1 addition & 0 deletions docs/source/api_reference/general_purpose_models.csv
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ RobustMultitaskClassifier,`ref <https://pubs.acs.org/doi/abs/10.1021/acs.jcim.7b
RobustMultitaskRegressor,`ref <https://pubs.acs.org/doi/abs/10.1021/acs.jcim.7b00146>`_,Regressor,CircularFingerprint RDKitDescriptors CoulombMatrixEig RdkitGridFeaturizer BindingPocketFeaturizer ElementPropertyFingerprint,Keras,
SeqToSeq,`ref <https://arxiv.org/abs/1409.3215>`_,,,PyTorch,fit method: :code:`fit_sequences`
WGAN,`ref <https://arxiv.org/abs/1701.07875>`_,Adversarial,,Keras,fit method: :code:`fit_gan`
UNet,`ref <https://arxiv.org/abs/1505.04597v1>`_,Classifier/Regressor,,PyTorch,
6 changes: 6 additions & 0 deletions docs/source/api_reference/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,12 @@ TextCNNModel
.. autoclass:: deepchem.models.torch_models.TextCNNModel
:members:

UNetModel
------------

.. autoclass:: deepchem.models.torch_models.UNetModel
:members:

PyTorch Lightning Models
========================

Expand Down

0 comments on commit 13b7099

Please sign in to comment.