Skip to content

Commit

Permalink
refactor: upd read tensorstore routines
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Feb 5, 2024
1 parent da83bda commit 75d6650
Showing 1 changed file with 17 additions and 19 deletions.
36 changes: 17 additions & 19 deletions src/deep_neurographs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import zarr

ANISOTROPY = np.array([0.748, 0.748, 1.0])
SUPPORTED_DRIVERS = ["neuroglancer_precomputed", "zarr"]
SUPPORTED_DRIVERS = ["neuroglancer_precomputed", "n5", "zarr"]


# --- dictionary utils ---
Expand Down Expand Up @@ -304,35 +304,33 @@ def open_tensorstore(path, driver):
ts_arr = ts_arr[0, 0, :, :, :]
ts_arr = ts_arr[ts.d[0].transpose[2]]
ts_arr = ts_arr[ts.d[0].transpose[1]]
return ts_arr
return ts_arr


def read_img_chunk(img, xyz, shape):
start, end = get_start_end(xyz, shape)
start, end = get_start_end(xyz, shape, from_center=from_center)
return img[
start[2]: end[2], start[1]: end[1], start[0]: end[0]
].transpose(2, 1, 0)


def get_chunk(arr, xyz, shape):
start, end = get_start_end(xyz, shape)
return deepcopy(
arr[start[0]: end[0], start[1]: end[1], start[2]: end[2]]
)
def get_chunk(arr, xyz, shape, from_center=True):
start, end = get_start_end(xyz, shape, from_center=from_center)
return deepcopy(arr[start[0]: end[0], start[1]: end[1], start[2]: end[2]])


def read_tensorstore(ts_arr, xyz, shape):
start, end = get_start_end(xyz, shape)
return (
ts_arr[start[0]: end[0], start[1]: end[1], start[2]: end[2]]
.read()
.result()
)
def read_tensorstore(arr, xyz, shape, from_center=True):
chunk = get_chunk(arr, xyz, shape, from_center=from_center)
return np.swapaxes(chunk.read().result(), 0, 2)


def get_start_end(xyz, shape):
start = [xyz[i] - shape[i] // 2 for i in range(3)]
end = [xyz[i] + shape[i] // 2 for i in range(3)]
def get_start_end(xyz, shape, from_center=True):
if from_center:
start = [xyz[i] - shape[i] // 2 for i in range(3)]
end = [xyz[i] + shape[i] // 2 for i in range(3)]
else:
start = xyz
end = [xyz[i] + shape[i] for i in range(3)]
return start, end


Expand All @@ -341,7 +339,7 @@ def get_superchunks(img_path, label_path, xyz, shape, from_center=True):
img_job = executor.submit(
get_superchunk,
img_path,
"zarr",
"n5" if ".n5" in img_path else "zarr",
xyz,
shape,
from_center=from_center,
Expand Down

0 comments on commit 75d6650

Please sign in to comment.