Skip to content

Commit

Permalink
Merge pull request #16 from ctr26/spline_contour_code
Browse files Browse the repository at this point in the history
spline contour code
  • Loading branch information
ctr26 authored Jan 8, 2024
2 parents e369195 + 9786013 commit f54cdc1
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 61 deletions.
55 changes: 55 additions & 0 deletions bioimage_embed/shapes/contours.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import numpy as np
from scipy.interpolate import interp1d, splprep, splev


def cart2pol(x, y):
return (np.sqrt(x**2 + y**2), np.arctan2(y, x))


def pol2cart(rho, phi):
return (rho * np.cos(phi), rho * np.sin(phi))


def cubic_polar_resample_contour(contour: np.array, size: int) -> np.array:
"""Star convex resampling of a contour using cubic interpolation
Args:
contour (np.Array): scikit image contour
size (int): control points to interpolate to
Returns:
np.Array: new contour
"""
contour_y, contour_x = contour[0][:, 0], contour[0][:, 1]
rho, phi = cart2pol(contour_x, contour_y)

rho_interp = interp1d(np.linspace(0, 1, len(rho)), rho, kind="cubic")(
np.linspace(0, 1, size)
)
phi_interp = interp1d(np.linspace(0, 1, len(phi)), phi, kind="cubic")(
np.linspace(0, 1, size)
)

xii, yii = pol2cart(rho_interp, phi_interp)
return np.array([xii, yii])


def contour_to_xy(contour: np.array):
return contour[0][:, 0], contour[0][:, 1]


def uniform_spline_resample_contour(contour: np.array, size: int) -> np.array:
"""Resample a contour using a uniform spline
Author: @afoix
Args:
contour (np.array): scikit image contour
size (int): Control points to interpolate to
Returns:
np.Array: new contour
"""
contour_y, contour_x = contour_to_xy(contour)
tck, u = splprep([contour_x, contour_y], s=0)
u_new = np.linspace(u.min(), u.max(), size)
return np.array(splev(u_new, tck))
69 changes: 8 additions & 61 deletions bioimage_embed/shapes/transforms.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,19 @@
import sys
import numpy as np
from skimage.draw import polygon2mask
import matplotlib.pyplot as plt

from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist
import pyro
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
import glob

# Note - you must have torchvision installed for this example
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from skimage.measure import regionprops
from torchvision.transforms.functional import crop
from scipy import ndimage
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from pytorch_lightning import loggers as pl_loggers
import torchvision
from sklearn.manifold import MDS
from sklearn.metrics.pairwise import euclidean_distances
from scipy.ndimage import convolve, sobel
from skimage.measure import find_contours
from scipy.interpolate import interp1d
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

from . import contours


class cropCentroid(torch.nn.Module):
Expand Down Expand Up @@ -180,7 +157,6 @@ def __repr__(self):
return self.__class__.__name__ + f"(size={self.size})"

def get_distogram(self, coords, matrix_normalised=False):

xii, yii = coords
# distograms.append(euclidean_distances(np.array([xii,yii]).T))
distance_matrix = euclidean_distances(np.array([xii, yii]).T) / self.size**0.5
Expand All @@ -192,7 +168,6 @@ def get_distogram(self, coords, matrix_normalised=False):
return distance_matrix / norm
return distance_matrix

# import numpy as np

class ImageToCoords(torch.nn.Module):
def __init__(self, size):
Expand Down Expand Up @@ -224,39 +199,12 @@ def get_coords_C(
coords_list.append(self.get_coords(image, size))
return torch.tensor(np.array(coords_list))

def get_coords(self, image, size):
coords = []
np_image = np.array(image)
scaling = np.linalg.norm(np_image.shape)

# for i in range(np_image_full.shape[0]):
# np_image = np_image_full[i]
# im_height, im_width = np_image.shape

contour = find_contours(np_image)
contour_y, contour_x = contour[0][:, 0], contour[0][:, 1]
# plt.scatter(contour_x,contour_y)
# plt.show()
# %%
rho, phi = self.cart2pol(contour_x, contour_y)

rho_interp = interp1d(np.linspace(0, 1, len(rho)), rho, kind="cubic")(
np.linspace(0, 1, size)
)
phi_interp = interp1d(np.linspace(0, 1, len(phi)), phi, kind="cubic")(
np.linspace(0, 1, size)
)

xii, yii = np.divide(self.pol2cart(rho_interp, phi_interp), scaling)
xii, yii = self.pol2cart(rho_interp, phi_interp)
# distograms.append(euclidean_distances(np.array([xii,yii]).T))
return np.array([xii, yii])

def cart2pol(self, x, y):
return (np.sqrt(x**2 + y**2), np.arctan2(y, x))

def pol2cart(self, rho, phi):
return (rho * np.cos(phi), rho * np.sin(phi))
def get_coords(self, image, size, method="uniform_spline", contour_level=0.8):
contour = find_contours(np.array(image), contour_level)
if method == "uniform_spline":
return contours.uniform_spline_resample_contour(contour=contour, size=size)
if method == "cubic_polar":
return contours.cubic_polar_resample_contour(contour=contour, size=size)


class VerticesToMask(torch.nn.Module):
Expand Down Expand Up @@ -284,7 +232,6 @@ def vertices_to_mask(self, vertices, mask_shape=(128, 128)):
return torch.tensor(np.array(mask_list))

def vertices_to_mask_BC(self, vertices, mask_shape=(128, 128)):

flat = np.reshape(vertices, (-1, vertices.shape[-2], vertices.shape[-1]))
masks = np.stack([polygon2mask(mask_shape, arr) for arr in flat]).reshape(
*vertices.shape[-4:-2], *mask_shape
Expand Down

0 comments on commit f54cdc1

Please sign in to comment.