Adding zarr chunk funcs for skeleton downloading
brianreicher committed Oct 11, 2023
1 parent 33f7865 commit 9e10124
Showing 1 changed file with 91 additions and 0 deletions.
src/raygun/webknossos_utils/wkw_seg_to_zarr.py (91 additions, 0 deletions)
@@ -8,8 +8,15 @@
import tempfile
from glob import glob
import os
import logging
from funlib.persistence import open_ds, prepare_ds
from funlib.geometry import Roi, Coordinate
import numpy as np
from skimage.draw import line_nd


logger = logging.getLogger(__name__)

def download_wk_skeleton(
    annotation_ID,
    save_path,
@@ -52,6 +59,90 @@ def download_wk_skeleton(
return zip_path


def parse_skeleton(zip_path) -> dict:
    """Parse a webKnossos skeleton zip into {tree_id: [[start_pos, end_pos], ...]} edge coordinates."""
    fin = zip_path
    if not fin.endswith(".zip"):
        try:
            fin = get_updated_skeleton(zip_path)
            assert fin.endswith(".zip"), "Skeleton zip file not found."
        except Exception:
            raise NotImplementedError("CATMAID skeletons are not supported.")

    wk_skels = wk.skeleton.Skeleton.load(fin)

    skel_coor = {}
    for tree in wk_skels.trees:
        skel_coor[tree.id] = []
        for start, end in tree.edges.keys():
            start_pos = start.position.to_np()
            end_pos = end.position.to_np()
            skel_coor[tree.id].append([start_pos, end_pos])

    return skel_coor


def get_updated_skeleton(zip_path) -> str:
    """Return the most recent skeleton zip near zip_path, downloading one if none exists."""
    if not os.path.exists(zip_path):
        path = os.path.dirname(os.path.realpath(zip_path))
        search_path = os.path.join(path, "skeletons/*")
        files = glob(search_path)
        if len(files) == 0:
            skel_file = download_wk_skeleton()
        else:
            # pick the most recently created skeleton export
            skel_file = max(files, key=os.path.getctime)
    else:
        skel_file = zip_path
    skel_file = os.path.abspath(skel_file)

    return skel_file

def rasterize_skeleton(zip_path: str, raw_file: str, raw_ds: str) -> np.ndarray:
    """Rasterize skeleton edges into a volume aligned with raw_file/raw_ds and save it to zarr."""
    logger.info("Rasterizing skeleton...")

    skel_coor = parse_skeleton(zip_path)

    # Initialize rasterized skeleton image
    raw = open_ds(raw_file, raw_ds)

    dataset_shape = raw.data.shape
    print(dataset_shape)
    voxel_size = raw.voxel_size
    offset = raw.roi.begin  # TODO: unhardcode for nonzero offset
    image = np.zeros(dataset_shape, dtype=np.uint8)  # tree ids are written as uint8 labels

    def adjust(coor):
        # shift by the dataset offset and clamp to the last valid voxel index
        ds_under = [x - 1 for x in dataset_shape]
        return np.min([coor - offset, ds_under], 0)

    print("adjusting . . .")
    for id, tree in skel_coor.items():
        # iterate over every edge and write the tree id along the connecting line
        for start, end in tree:
            line = line_nd(adjust(start), adjust(end))
            image[line] = id

    # Save GT rasterization  # TODO: implement daisy blockwise option
    total_roi = Roi(
        Coordinate(offset) * Coordinate(voxel_size),
        Coordinate(dataset_shape) * Coordinate(voxel_size),
    )

    print("saving . . .")
    out_ds = prepare_ds(
        raw_file,
        "volumes/training_rasters",
        total_roi,
        voxel_size,
        image.dtype,
        delete=True,
    )
    out_ds[out_ds.roi] = image

    return image


def get_wk_mask(
    annotation_ID,
    save_path,  # TODO: Add mkdtemp() as default
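As a rough sketch of how the new helpers fit together (not part of the commit; the import path assumes the package installs as raygun.webknossos_utils, and the file paths, dataset name, and zip location below are placeholders):

from raygun.webknossos_utils.wkw_seg_to_zarr import parse_skeleton, rasterize_skeleton

# Hypothetical inputs for illustration only
zip_path = "/path/to/annotation_skeleton.zip"  # webKnossos skeleton export
raw_file = "/path/to/data.zarr"                # zarr container with the raw volume
raw_ds = "volumes/raw"                         # dataset the rasters should align to

# Inspect the parsed edges: {tree_id: [[start_pos, end_pos], ...]}
skel_coor = parse_skeleton(zip_path)
print(len(skel_coor), "skeleton trees parsed")

# Full pipeline: rasterize_skeleton re-parses the zip internally, burns the edges
# into a uint8 volume, and writes it to "volumes/training_rasters" in raw_file
image = rasterize_skeleton(zip_path, raw_file, raw_ds)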
