Commit

Merge pull request #67 from HelmholtzAI-Consultants-Munich/random_forest
Random forest
christinab12 authored Jan 26, 2024
2 parents 683e0e4 + ee70715 commit 6999988
Showing 10 changed files with 387 additions and 85 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/test.yml
@@ -77,7 +77,7 @@ jobs:
strategy:
matrix:
platform: [ubuntu-latest, windows-latest, macos-latest]
python-version: [3.8, 3.9, "3.10"]
python-version: [3.9, "3.10"]

steps:
- name: Checkout Repository
@@ -91,8 +91,10 @@
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install setuptools
python -m pip install --upgrade setuptools
pip install numpy
pip install pytest
pip install wheel
pip install coverage
pip install -e ".[testing]"
working-directory: src/server
1 change: 1 addition & 0 deletions src/server/dcp_server/config.cfg
@@ -18,6 +18,7 @@
"model_type": "cyto"
},
"classifier":{
"model_class": "RandomForest",
"in_channels": 1,
"num_classes": 3,
"features":[64,128,256,512],
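
With this change, the "model_class" key in the classifier section of config.cfg selects the classifier backend. A sketch of the resulting block, using only the keys visible in the diff (any keys beyond the visible lines are not shown):

"classifier": {
    "model_class": "RandomForest",
    "in_channels": 1,
    "num_classes": 3,
    "features": [64, 128, 256, 512]
}

Per the selection logic added to models.py below, "FCNN" keeps the convolutional classifier and "RandomForest" switches to the new shallow model.
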
128 changes: 102 additions & 26 deletions src/server/dcp_server/models.py
@@ -3,17 +3,22 @@
from torch import nn
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
from torchmetrics import F1Score
from copy import deepcopy
from tqdm import tqdm
import numpy as np
from scipy.ndimage import label

from cellpose.metrics import aggregated_jaccard_index
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, log_loss
from sklearn.exceptions import NotFittedError

from cellpose.metrics import aggregated_jaccard_index
from cellpose.dynamics import labels_to_flows
#from segment_anything import SamPredictor, sam_model_registry
#from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator

from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset
from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset, create_dataset_for_rf

class CustomCellposeModel(models.CellposeModel, nn.Module):
"""Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing
@@ -39,6 +44,7 @@ def __init__(self, model_config, train_config, eval_config, model_name):
self.mkldnn = False # otherwise we get error with saving model
self.train_config = train_config
self.eval_config = eval_config
self.loss = 1e6
self.model_name = model_name

def update_configs(self, train_config, eval_config):
@@ -71,12 +77,27 @@ def train(self, imgs, masks):

if masks[0].shape[0] == 2:
masks = list(masks[:,0,...])

super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"])

# compute loss and metric
true_bin_masks = [mask>0 for mask in masks] # get binary masks
true_flows = labels_to_flows(masks) # get cellpose flows
# get predicted flows and cell probability
pred_masks = []
pred_flows = []
true_lbl = []
for idx, img in enumerate(imgs):
mask, flows, _ = super().eval(x=img, **self.eval_config["segmentor"])
pred_masks.append(mask)
pred_flows.append(np.stack([flows[1][0], flows[1][1], flows[2]])) # stack cell probability map, horizontal and vertical flow
true_lbl.append(np.stack([true_bin_masks[idx], true_flows[idx][2], true_flows[idx][3]]))

pred_masks = [self.eval(img) for img in masks]
self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) # TODO move metric computation
# self.loss = self.loss_fn(masks, pred_masks)
true_lbl = np.stack(true_lbl)
pred_flows=np.stack(pred_flows)
pred_flows = torch.from_numpy(pred_flows).float().to('cpu')
# compute loss, combination of mse for flows and bce for cell probability
self.loss = self.loss_fn(true_lbl, pred_flows)
self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks))

def masks_to_outlines(self, mask):
""" get outlines of masks as a 0-1 array
@@ -105,8 +126,8 @@ class CellClassifierFCNN(nn.Module):
def __init__(self, model_config, train_config, eval_config):
super().__init__()

self.in_channels = model_config["classifier"]["in_channels"]
self.num_classes = model_config["classifier"]["num_classes"]
self.in_channels = model_config["classifier"].get("in_channels",1)
self.num_classes = model_config["classifier"].get("num_classes",3)

self.train_config = train_config["classifier"]
self.eval_config = eval_config["classifier"]
@@ -134,6 +155,8 @@ def __init__(self, model_config, train_config, eval_config):
self.final_conv = nn.Conv2d(128, self.num_classes, 1)
self.pooling = nn.AdaptiveMaxPool2d(1)

self.metric_fn = F1Score(num_classes=self.num_classes, task="multiclass")

def update_configs(self, train_config, eval_config):
self.train_config = train_config
self.eval_config = eval_config
@@ -180,7 +203,7 @@ def train (self, imgs, labels):
# TODO check if we should replace self.parameters with super.parameters()

for _ in tqdm(range(epochs), desc="Running CellClassifierFCNN training"):
self.loss = 0
self.loss, self.metric = 0, 0
for data in train_dataloader:
imgs, labels = data

@@ -192,7 +215,10 @@
optimizer.step()
self.loss += l.item()

self.metric += self.metric_fn(preds, labels)

self.loss /= len(train_dataloader)
self.metric /= len(train_dataloader)

def eval(self, img):
"""
@@ -224,15 +250,24 @@ def __init__(self, model_config, train_config, eval_config, model_name):
self.eval_config = eval_config
self.model_name = model_name

self.classifier_class = self.model_config.get("classifier").get("model_class", "CellClassifierFCNN")

# Initialize the cellpose model and the classifier
self.segmentor = CustomCellposeModel(self.model_config,
self.train_config,
self.eval_config,
"Cellpose")
self.classifier = CellClassifierFCNN(self.model_config,
self.train_config,
self.eval_config)


if self.classifier_class == "FCNN":
self.classifier = CellClassifierFCNN(self.model_config,
self.train_config,
self.eval_config)

elif self.classifier_class == "RandomForest":
self.classifier = CellClassifierShallowModel(self.model_config,
self.train_config,
self.eval_config)

def update_configs(self, train_config, eval_config):
self.train_config = train_config
self.eval_config = eval_config
Expand All @@ -249,19 +284,25 @@ def train(self, imgs, masks):
# train cellpose
masks = np.array(masks)
masks_instances = list(masks[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks
self.segmentor.train(imgs, masks_instances)
self.segmentor.train(deepcopy(imgs), masks_instances)
# create patch dataset to train classifier
masks_classes = list(masks[:,1,...]) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks]
patches, labels = create_patch_dataset(imgs,
masks_classes,
masks_instances,
noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"],
max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"])
patches, patch_masks, labels = create_patch_dataset(imgs,
masks_classes,
masks_instances,
noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"],
max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"])
x = patches
if self.classifier_class == "RandomForest":
x = create_dataset_for_rf(patches, patch_masks)
# train classifier
self.classifier.train(patches, labels)
self.classifier.train(x, labels)
# and compute metric and loss
self.metric = (self.segmentor.metric + self.classifier.metric) / 2
self.loss = (self.segmentor.loss + self.classifier.loss)/2

def eval(self, img):
# TBD we assume image is either 2D [H, W] (see fsimage storage)
# TBD we assume image is 2D [H, W] (see fsimage storage)
# The final mask which is returned should have
# first channel the output of cellpose and the rest are the class channels
with torch.no_grad():
@@ -275,22 +316,57 @@ def eval(self, img):
noise_intensity = self.eval_config["classifier"]["data"]["noise_intensity"]

# get patches centered around detected objects
patches, instance_labels, _ = get_centered_patches(img,
patches, patch_masks, instance_labels, _ = get_centered_patches(img,
instance_mask,
max_patch_size,
noise_intensity=noise_intensity)
x = patches
if self.classifier_class == "RandomForest":
x = create_dataset_for_rf(patches, patch_masks)
# loop over patches and create classification mask
for idx, patch in enumerate(patches):
patch_class = self.classifier.eval(patch) # patch size should be HxWxC, e.g. 64,64,3
for idx in range(len(x)):
patch_class = self.classifier.eval(x[idx])
# Assign predicted class to corresponding location in final_mask
class_mask[instance_mask==instance_labels[idx]] = patch_class.item() + 1
patch_class = patch_class.item() if isinstance(patch_class, torch.Tensor) else patch_class
class_mask[instance_mask==instance_labels[idx]] = patch_class + 1
# Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0
#class_mask = class_mask * (instance_mask > 0)#.long())
final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) # size 2xHxW

return final_mask

class CellClassifierShallowModel:

def __init__(self, model_config, train_config, eval_config):

self.model_config = model_config
self.train_config = train_config
self.eval_config = eval_config

self.model = RandomForestClassifier() # TODO change config so RandomForestClassifier accepts input params


def train(self, X_train, y_train):

self.model.fit(X_train,y_train)

y_hat = self.model.predict(X_train)
y_hat_proba = self.model.predict_proba(X_train)

self.metric = f1_score(y_train, y_hat, average='micro')
# Binary Cross Entropy Loss
self.loss = log_loss(y_train, y_hat_proba)


def eval(self, X_test):

X_test = X_test.reshape(1,-1)

try:
y_hat = self.model.predict(X_test)
except NotFittedError as e:
y_hat = np.zeros(X_test.shape[0])

return y_hat

class UNet(nn.Module):

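
For reference, a minimal usage sketch of the new shallow classifier path. Here create_dataset_for_rf is stood in for by a plain flattening of the patches, which is an assumption made purely for illustration; the real helper lives in dcp_server.utils.

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, log_loss

patches = np.random.rand(30, 16, 16)   # toy patches, one per detected object
labels = np.arange(30) % 3             # three classes, matching num_classes in config.cfg

# stand-in for create_dataset_for_rf: one flat feature vector per patch (assumption)
X = patches.reshape(len(patches), -1)

clf = RandomForestClassifier()
clf.fit(X, labels)
print("train F1 (micro):", f1_score(labels, clf.predict(X), average="micro"))
print("train log loss:", log_loss(labels, clf.predict_proba(X)))
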
16 changes: 10 additions & 6 deletions src/server/dcp_server/serviceclasses.py
@@ -48,7 +48,7 @@ def evaluate(self, img: np.ndarray) -> np.ndarray:
def check_and_load_model(self):
bento_model_list = [model.tag.name for model in bentoml.models.list()]
if self.save_model_path in bento_model_list:
loaded_model = bentoml.pytorch.load_model(self.save_model_path+":latest")
loaded_model = bentoml.picklable_model.load_model(self.save_model_path+":latest")
assert loaded_model.__class__.__name__ == self.model.__class__.__name__, 'Check your config, loaded model and model to use not the same!'
self.model = loaded_model

@@ -65,11 +65,15 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str:
"""
self.model.train(imgs, masks)
# Save the bentoml model
#bentoml.picklable_model.save_model(self.save_model_path, self.model)
bentoml.pytorch.save_model(self.save_model_path, # Model name in the local Model Store
self.model, # Model instance being saved
external_modules=[DCPModels]
)
bentoml.picklable_model.save_model(
self.save_model_path,
self.model,
external_modules=[DCPModels],
)
# bentoml.pytorch.save_model(self.save_model_path, # Model name in the local Model Store
# self.model, # Model instance being saved
# external_modules=[DCPModels]
# )

return self.save_model_path

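
For reference, a short sketch of the save and load round trip implied by the switch from bentoml.pytorch to bentoml.picklable_model. The model name is a placeholder and any picklable instance works:

import bentoml
from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier()

# save under a name in the local Model Store ("dcp_demo_model" is a placeholder)
bentoml.picklable_model.save_model("dcp_demo_model", model)

# later, load the latest version back, mirroring check_and_load_model above
loaded = bentoml.picklable_model.load_model("dcp_demo_model:latest")
assert loaded.__class__.__name__ == model.__class__.__name__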