Skip to content

Commit

Permalink
First functional roundtrip for nifti images
Browse files Browse the repository at this point in the history
  • Loading branch information
ktsitsi committed Oct 7, 2024
1 parent 2e6ffc6 commit 6258617
Showing 1 changed file with 36 additions and 15 deletions.
51 changes: 36 additions & 15 deletions tiledb/bioimg/converters/nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import nibabel as nib
import numpy as np
from black.trans import defaultdict
from nibabel import Nifti1Image
from numpy._typing import NDArray

Expand Down Expand Up @@ -51,7 +50,12 @@ def __init__(
self._vfs = VFS(config=self._source_cfg, ctx=self._source_ctx)
self._vfs_fh = self._vfs.open(input_path, mode="rb")
self._nib_image = Nifti1Image.from_stream(self._vfs_fh)
self._metadata: Dict[str, Any] = self._serialize_header(self.nifti1_hdr_2_dict())
self._metadata: Dict[str, Any] = self._serialize_header(
self.nifti1_hdr_2_dict()
)
self._binary_header = base64.b64encode(
self._nib_image.header.binaryblock
).decode("utf-8")
self._mode = "".join(self._nib_image.dataobj.dtype.names)

def __enter__(self) -> NiftiReader:
Expand All @@ -74,7 +78,7 @@ def logger(self) -> Optional[logging.Logger]:

@property
def group_metadata(self) -> Dict[str, Any]:
writer_kwargs = dict(metadata=self._metadata)
writer_kwargs = dict(metadata=self._metadata, binaryblock=self._binary_header)
self._logger.debug(f"Group metadata: {writer_kwargs}")
return {"json_write_kwargs": json.dumps(writer_kwargs)}

Expand Down Expand Up @@ -176,6 +180,7 @@ def level_image(
) -> np.ndarray:

unscaled_img = self._nib_image.dataobj.get_unscaled()
self._metadata["original_mode"] = self._mode
raw_data_contiguous = np.ascontiguousarray(unscaled_img)
numerical_data = np.frombuffer(raw_data_contiguous, dtype=self.level_dtype())
numerical_data = numerical_data.reshape(self.level_shape())
Expand All @@ -202,6 +207,7 @@ def optimal_reader(
# raise ValueError("chunk_size must be set for chunked reading.")
#
# array = self._nib_image.get_fdata()
# array = self._nib_image.get_fdata()
# total_slices = array.shape[-1]
# for i in range(0, total_slices, self.chunk_size):
# chunk = array[..., i : i + self.chunk_size]
Expand All @@ -216,15 +222,15 @@ def nifti1_hdr_2_dict(self) -> Dict[str, Any]:
}

@staticmethod
def _serialize_header(header_dict: [Dict, Any]) -> Dict[str, Any]:
serialized_header = defaultdict(dict)
for k,v in header_dict.items():
if isinstance(v, np.ndarray):
serialized_header[k] = v.tolist()
if isinstance(serialized_header[k], bytes):
serialized_header[k] = base64.b64encode(serialized_header[k]).decode('utf-8')
else:
serialized_header[k] = v
def _serialize_header(header_dict: Mapping[str, Any]) -> Dict[str, Any]:
serialized_header = {
k: (
base64.b64encode(v.tolist()).decode("utf-8")
if isinstance(v, np.ndarray) and isinstance(v.tolist(), bytes)
else v.tolist() if isinstance(v, np.ndarray) else v
)
for k, v in header_dict.items()
}
return serialized_header


Expand All @@ -233,6 +239,8 @@ def __init__(self, output_path: str, logger: logging.Logger):
self._logger = logger
self._output_path = output_path
self._group_metadata: Dict[str, Any] = {}
self._nifti1header = partial(nib.Nifti1Header)
self._original_mode = None
self._writer = partial(nib.Nifti1Image)

def __enter__(self) -> NiftiWriter:
Expand All @@ -249,20 +257,33 @@ def compute_level_metadata(
) -> Mapping[str, Any]:

writer_metadata: Dict[str, Any] = {}
original_mode = group_metadata.get("original_mode", "RGB")
writer_metadata["mode"] = original_mode
self._original_mode = group_metadata.get("original_mode", "RGB")
writer_metadata["mode"] = self._original_mode
self._logger.debug(f"Writer metadata: {writer_metadata}")
return writer_metadata

def write_group_metadata(self, metadata: Mapping[str, Any]) -> None:
self._group_metadata = json.loads(metadata["json_write_kwargs"])

def _structured_dtype(self) -> np.dtype:
if self._original_mode == "RGB":
return np.dtype([("R", "u1"), ("G", "u1"), ("B", "u1")])

def write_level_image(
self,
image: np.ndarray,
metadata: Mapping[str, Any],
) -> None:
nib_image = self._writer(image, metadata["affine"])
header = self._nifti1header(
binaryblock=base64.b64decode(self._group_metadata["binaryblock"])
)
contiguous_image = np.ascontiguousarray(image)
structured_arr = contiguous_image.view(dtype=self._structured_dtype()).reshape(
*image.shape[:-1]
)
nib_image = self._writer(
structured_arr, header=header, affine=header.get_best_affine()
)
nib.save(nib_image, self._output_path)

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
Expand Down

0 comments on commit 6258617

Please sign in to comment.