diff --git a/src/raygun/webknossos_utils/wkw_seg_to_zarr.py b/src/raygun/webknossos_utils/wkw_seg_to_zarr.py index f6a1244..57d335e 100644 --- a/src/raygun/webknossos_utils/wkw_seg_to_zarr.py +++ b/src/raygun/webknossos_utils/wkw_seg_to_zarr.py @@ -8,8 +8,15 @@ import tempfile from glob import glob import os +import logging +from funlib.persistence import open_ds, prepare_ds +from funlib.geometry import Roi, Coordinate +import numpy as np +from skimage.draw import line_nd +logger = logging.getLogger(__name__) + def download_wk_skeleton( annotation_ID, save_path, @@ -52,6 +59,90 @@ def download_wk_skeleton( return zip_path +def parse_skeleton(zip_path) -> dict: + fin = zip_path + if not fin.endswith(".zip"): + try: + fin = get_updated_skeleton(zip_path) + assert fin.endswith(".zip"), "Skeleton zip file not found." + except: + assert False, "CATMAID NOT IMPLEMENTED" + + wk_skels = wk.skeleton.Skeleton.load(fin) + # return wk_skels + skel_coor = {} + for tree in wk_skels.trees: + skel_coor[tree.id] = [] + for start, end in tree.edges.keys(): + start_pos = start.position.to_np() + end_pos = end.position.to_np() + skel_coor[tree.id].append([start_pos, end_pos]) + + return skel_coor + + +def get_updated_skeleton(zip_path) -> str: + if not os.path.exists(zip_path): + path = os.path.dirname(os.path.realpath(zip_path)) + search_path = os.path.join(path, "skeletons/*") + files = glob(search_path) + if len(files) == 0: + skel_file = download_wk_skeleton() + else: + skel_file = max(files, key=os.path.getctime) + skel_file = os.path.abspath(skel_file) + + return skel_file + +def rasterize_skeleton(zip_path:str, + raw_file:str, + raw_ds:str) -> np.ndarray: + + logger.info(f"Rasterizing skeleton...") + + skel_coor = parse_skeleton(zip_path) + + # Initialize rasterized skeleton image + raw = open_ds(raw_file, raw_ds) + + dataset_shape = raw.data.shape + print(dataset_shape) + voxel_size = raw.voxel_size + offset = raw.roi.begin # unhardcode for nonzero offset + image = np.zeros(dataset_shape, dtype=np.uint8) + + def adjust(coor): + ds_under = [x-1 for x in dataset_shape] + return np.min([coor - offset, ds_under], 0) + + print("adjusting . . .") + for id, tree in skel_coor.items(): + # iterates through ever node and assigns id to {image} + for start, end in tree: + line = line_nd(adjust(start), adjust(end)) + image[line] = id + + + # Save GT rasterization #TODO: implement daisy blockwise option + total_roi = Roi( + Coordinate(offset) * Coordinate(voxel_size), + Coordinate(dataset_shape) * Coordinate(voxel_size), + ) + + print("saving . . .") + out_ds = prepare_ds( + raw_file, + "volumes/training_rasters", + total_roi, + voxel_size, + image.dtype, + delete=True, + ) + out_ds[out_ds.roi] = image + + return image + + def get_wk_mask( annotation_ID, save_path, # TODO: Add mkdtemp() as default