Add files via upload
Dynamo13 authored Sep 1, 2023
1 parent 89c9409 commit 3bfda05
Showing 5 changed files with 305 additions and 0 deletions.
84 changes: 84 additions & 0 deletions SE_UResNet_tf.ipynb
@@ -0,0 +1,84 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "c0744c9f",
"metadata": {},
"outputs": [],
"source": [
"from dataloader import *\n",
"from model import *\n",
"from main import *\n",
"from utils import *\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04766fc7",
"metadata": {},
"outputs": [],
"source": [
"image_height=256\n",
"image_width=256\n",
"image_channel=3"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dbd05bca",
"metadata": {},
"outputs": [],
"source": [
"img_dir='E:/Data_test/images'\n",
"mask_dir='E:/Data_test/mask'\n",
"weight_dir='E:/Data_test/weights'\n",
"\n",
"main(img_dir,mask_dir,weight_dir)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8393f0f",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"path = r'E:\\Data_test\\images\\CHNCXR_0001_0.png'\n",
"model=SE_UResNet((image_height,image_width,image_channel),num_classes, dropout_rate=0.0, batch_norm=True)\n",
"model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])\n",
"model.load_weights(str(weight_dir)+'/weights.h5')\n",
"\n",
"test_image = cv2.resize(cv2.imread(path),(image_height,image_width))\n",
"predicted_image=model.predict(test_image.reshape(1,image_height,image_width,image_channel))\n",
"plt.imshow(predicted_image[0]>0.5,cmap='gray')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
33 changes: 33 additions & 0 deletions dataloader.py
@@ -0,0 +1,33 @@
import keras
import numpy as np
import cv2

class Loader(keras.utils.Sequence):
"""Helper to iterate over the data (as Numpy arrays)."""

def __init__(self, batch_size, img_size, input_img_paths, mask_img_paths,image_channel,num_classes):
self.batch_size = batch_size
self.img_size = img_size
self.input_img_paths = input_img_paths
self.mask_img_paths = mask_img_paths
self.num_classes = num_classes
self.image_channel=image_channel


def __len__(self):
return len(self.mask_img_paths) // self.batch_size

def __getitem__(self, idx):
"""Returns tuple (input, target) correspond to batch #idx."""
i = idx * self.batch_size
batch_input_img_paths = self.input_img_paths[i : i + self.batch_size]
batch_mask_img_paths = self.mask_img_paths[i : i + self.batch_size]
x = np.zeros((self.batch_size,) + self.img_size +(self.image_channel,) , dtype="uint8")
for j, path in enumerate(batch_input_img_paths):
img = cv2.resize(cv2.imread(path),self.img_size)
x[j]=img
y = np.zeros((self.batch_size,) + self.img_size + (self.num_classes,), dtype="uint8")
for j, path in enumerate(batch_mask_img_paths):
            # Masks load as 3-channel BGR; keep the first channel and binarise to {0, 1}
            msk = cv2.resize(cv2.imread(path), self.img_size)
            y[j] = np.expand_dims((msk[:, :, 0] / 255).astype('uint8'), axis=-1)
return x, y
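
For reference, a minimal usage sketch of this Loader; the folder paths below are placeholders, not part of the repository:

import os
from dataloader import Loader

# Placeholder directories; substitute your own image and mask folders
img_dir = "data/images"
mask_dir = "data/mask"
input_img_paths = sorted(os.path.join(img_dir, f) for f in os.listdir(img_dir))
mask_img_paths = sorted(os.path.join(mask_dir, f) for f in os.listdir(mask_dir))

# 8 image/mask pairs per batch, 256x256 inputs, 3 channels, single-class masks
train_gen = Loader(batch_size=8, img_size=(256, 256),
                   input_img_paths=input_img_paths, mask_img_paths=mask_img_paths,
                   image_channel=3, num_classes=1)
x, y = train_gen[0]  # x: (8, 256, 256, 3) uint8, y: (8, 256, 256, 1) uint8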
65 changes: 65 additions & 0 deletions main.py
@@ -0,0 +1,65 @@
from dataloader import *
from model import *
import random
import os
import sys
import tensorflow as tf
def main(input_dir, mask_dir,weight_dir,
image_height=256,
image_width=256,
image_channel=3,
img_size = (256,256),
num_classes = 1,
batch_size = 8,
epochs=100,
val_samples = 40,):

img_size = (image_height,image_width)
input_img_paths = sorted(
[
os.path.join(input_dir, fname)
for fname in os.listdir(input_dir)
]
)
mask_img_paths = sorted(
[
os.path.join(mask_dir, fname)
for fname in os.listdir(mask_dir)
]
)
random.Random(1337).shuffle(input_img_paths)
random.Random(1337).shuffle(mask_img_paths)
train_input_img_paths = input_img_paths[:-val_samples]
train_mask_img_paths = mask_img_paths[:-val_samples]
val_input_img_paths = input_img_paths[-val_samples:]
val_mask_img_paths = mask_img_paths[-val_samples:]

# Instantiate data Sequences for each split
train_gen = Loader(
batch_size, img_size, train_input_img_paths, train_mask_img_paths,image_channel,num_classes
)
val_gen = Loader(batch_size, img_size, val_input_img_paths, val_mask_img_paths,image_channel,num_classes)

model=SE_UResNet((image_height,image_width,image_channel),num_classes, dropout_rate=0.0, batch_norm=True)

model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(weight_dir, "weights.h5"),
        save_weights_only=True,
        monitor='val_loss',
        mode='min',
        save_best_only=True
    )
rlp =tf.keras.callbacks.ReduceLROnPlateau(
monitor='val_loss',
factor=0.001,
patience=10,
verbose=1,
mode='auto',
min_delta=0.00005)
#es=tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=15)
history = model.fit(train_gen,validation_data=val_gen,epochs=epochs,callbacks=[checkpoint_callback,rlp])

if __name__ == "__main__":
img_dir = sys.argv[1]
mask_dir = sys.argv[2]
weight_dir = sys.argv[3]
main(img_dir,mask_dir,weight_dir)
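
Usage sketch for main.py: the script expects the image, mask and weight directories as positional arguments. The paths below are placeholders, and the keyword arguments simply restate the defaults defined above.

# From the command line:
#   python main.py path/to/images path/to/masks path/to/weights
#
# Or from Python:
from main import main

main("path/to/images", "path/to/masks", "path/to/weights",
     image_height=256, image_width=256, image_channel=3,
     num_classes=1, batch_size=8, epochs=100, val_samples=40)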
68 changes: 68 additions & 0 deletions model.py
@@ -0,0 +1,68 @@
from utils import *
import tensorflow as tf
def SE_UResNet(input_shape, NUM_CLASSES=1, dropout_rate=0.0, batch_norm=True):
    '''
    SE-UResNet: a U-Net-style encoder/decoder with squeeze-and-excitation (SE)
    blocks on the skip connections and a residual "W-net" bridge.
    '''
# network structure
FILTER_NUM = 32 # number of basic filters for the first layer
FILTER_SIZE = 3 # size of the convolutional filter
UP_SAMP_SIZE = 2 # size of upsampling filters

inputs = layers.Input(input_shape, dtype=tf.float32)

# Downsampling layers
# DownRes 1, convolution + pooling
conv_128 = conv_block(inputs, FILTER_SIZE, FILTER_NUM, 0.2, 1, batch_norm)
pool_64 = layers.MaxPooling2D(pool_size=(2,2))(conv_128)
# DownRes 2
conv_64 = conv_block(pool_64, FILTER_SIZE, 2*FILTER_NUM, 0.2, 2, batch_norm)
pool_32 = layers.MaxPooling2D(pool_size=(2,2))(conv_64)
# DownRes 3
conv_32 = conv_block(pool_32, FILTER_SIZE, 4*FILTER_NUM,0.2, 3, batch_norm)
pool_16 = layers.MaxPooling2D(pool_size=(2,2))(conv_32)
# DownRes 4
conv_16 = conv_block(pool_16, FILTER_SIZE, 8*FILTER_NUM,0.2,4, batch_norm)
pool_8 = layers.MaxPooling2D(pool_size=(2,2))(conv_16)
# DownRes 5, convolution only
conv_8 = conv_block(pool_8, FILTER_SIZE, 16*FILTER_NUM, 0.2,5, batch_norm)

# W-net layers
attw_16 = se_block(conv_16, 8*FILTER_NUM)
upw_16 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(conv_8)
upw_16 = layers.concatenate([upw_16, attw_16], axis=3)
up_convw_16 = resb(upw_16, FILTER_SIZE, 8*FILTER_NUM, 0.2,6, batch_norm)

poolw_8 = layers.MaxPooling2D(pool_size=(2,2))(up_convw_16)
convw_16 = conv_block(poolw_8, FILTER_SIZE, 16*FILTER_NUM, 0.2,7, batch_norm)

# UpRes 6, attention gated concatenation + upsampling + double residual convolution
att_16 = se_block(up_convw_16, 8*FILTER_NUM)
up_16 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(convw_16)
up_16 = layers.concatenate([up_16, att_16], axis=3)
up_conv_16 = conv_block(up_16, FILTER_SIZE, 8*FILTER_NUM, 0.2,8, batch_norm)
# UpRes 7
att_32 = se_block(conv_32, 4*FILTER_NUM)
up_32 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_16)
up_32 = layers.concatenate([up_32, att_32], axis=3)
up_conv_32 =conv_block(up_32, FILTER_SIZE, 4*FILTER_NUM, dropout_rate,9, batch_norm)
# UpRes 8
att_64 = se_block(conv_64, 2*FILTER_NUM)
up_64 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_32)
up_64 = layers.concatenate([up_64, att_64], axis=3)
up_conv_64 = conv_block(up_64, FILTER_SIZE, 2*FILTER_NUM, 0.2,10, batch_norm)
# UpRes 9
att_128 = se_block(conv_128, FILTER_NUM)
up_128 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_64)
up_128 = layers.concatenate([up_128, att_128], axis=3)
up_conv_128 = conv_block(up_128, FILTER_SIZE, FILTER_NUM, 0.2,11, batch_norm)

# 1*1 convolutional layers
conv_final = layers.Conv2D(NUM_CLASSES, kernel_size=(1,1))(up_conv_128)
conv_final = layers.BatchNormalization(axis=3)(conv_final)
conv_final = layers.Activation('sigmoid')(conv_final) #Change to softmax for multichannel

# Model integration
model = models.Model(inputs, conv_final, name="Attention_UWNet")
return model
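
A minimal sketch of building the network on its own, using the same 256x256 RGB input and single output class as the notebook above:

from model import SE_UResNet

model = SE_UResNet((256, 256, 3), NUM_CLASSES=1, dropout_rate=0.0, batch_norm=True)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()  # final layer output shape: (None, 256, 256, 1), sigmoid activation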
55 changes: 55 additions & 0 deletions utils.py
@@ -0,0 +1,55 @@
from tensorflow.keras import models, layers, regularizers
from tensorflow.keras import backend as K
import tensorflow as tf

def conv_block(x, filter_size, size, dropout,num, batch_norm=False):

conv = layers.Conv2D(size, (filter_size, filter_size), padding="same")(x)
if batch_norm is True:
conv = layers.BatchNormalization(axis=3)(conv)
conv = layers.Activation("relu")(conv)

conv = layers.Conv2D(size, (filter_size, filter_size), padding="same",name="conv"+str(num))(conv)
if batch_norm is True:
conv = layers.BatchNormalization(axis=3)(conv)
conv = layers.Activation("relu")(conv)

if dropout > 0:
conv = layers.Dropout(dropout)(conv)

return conv

def se_block(x, r):
    # Squeeze-and-excitation: global average pooling, a bottleneck of
    # (channels // r) units, then a sigmoid gate applied channel-wise.
    copy = x
    gap = layers.GlobalAveragePooling2D()(x)
    flat = layers.Flatten()(gap)
    dense = layers.Dense(flat.shape[-1] // r, activation='relu')(gap)
    dense = layers.Dense(flat.shape[-1], activation='sigmoid')(dense)
    # Keras broadcasts the (batch, channels) gate over the spatial dimensions.
    m = layers.multiply([dense, copy])
    return m

def resb(x, filter_size, size, dropout,num, batch_norm=False):
# copy tensor to variable called x_skip
x_skip = x
x_skip=layers.Conv2D(1, (1, 1), padding="same")(x_skip)
print(x_skip.shape)
# Layer 1
x = layers.Conv2D(size, (filter_size, filter_size), padding="same")(x)
if batch_norm is True:
x = layers.BatchNormalization(axis=3)(x)
x = layers.Activation("relu")(x)
if dropout > 0:
x = layers.Dropout(dropout)(x)
print(x.shape)
# Layer 2
x = layers.Conv2D(size, (filter_size, filter_size), padding="same")(x)
if batch_norm is True:
x = layers.BatchNormalization(axis=3)(x)
x = layers.Activation("relu")(x)
if dropout > 0:
x = layers.Dropout(dropout)(x)
print(x.shape)
# Add Residue
x = tf.keras.layers.Add()([x, x_skip])
x = tf.keras.layers.Activation('relu')(x)
return x
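
A quick sketch of how the three building blocks compose on a dummy input; the input size and filter counts below are assumptions for this toy example, not values used by the repository:

import tensorflow as tf
from utils import conv_block, se_block, resb

inp = tf.keras.layers.Input((64, 64, 3))
c = conv_block(inp, filter_size=3, size=32, dropout=0.2, num=1, batch_norm=True)  # (None, 64, 64, 32)
s = se_block(c, 8)           # channel gate with a 32 // 8 = 4-unit bottleneck; shape unchanged
r = resb(s, 3, 32, 0.2, 2)   # residual block; the skip path is a 1x1, single-channel projection
print(c.shape, s.shape, r.shape)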
