
Refactoring to allow adding multiple Assets (dynamic tiff sequences) #718


Merged: 9 commits, Apr 18, 2024
2 changes: 1 addition & 1 deletion docs/source/explanations/caching.md
@@ -76,7 +76,7 @@ from cachetools import Cache
from tiled.adapters.resource_cache import set_resource_cache

cache = Cache(maxsize=1)
set_resouurce_cache(cache)
set_resource_cache(cache)
```

Any object satisfying the `cachetools.Cache` interface is acceptable.
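
For example, to hold more than one resource at a time, an `LRUCache` from the same package could be swapped in. This is a sketch following the snippet above, not an excerpt from the Tiled documentation:

```python
from cachetools import LRUCache

from tiled.adapters.resource_cache import set_resource_cache

# Keep up to 100 cached resources, evicting the least recently used first.
set_resource_cache(LRUCache(maxsize=100))
```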
18 changes: 11 additions & 7 deletions tiled/_tests/test_tiff.py
Expand Up @@ -21,7 +21,7 @@ def client(tmpdir_module):
sequence_directory.mkdir()
filepaths = []
for i in range(3):
data = numpy.random.random((5, 7))
data = numpy.random.random((5, 7, 4))
filepath = sequence_directory / f"temp{i:05}.tif"
tf.imwrite(filepath, data)
filepaths.append(filepath)
@@ -46,19 +46,23 @@ def client(tmpdir_module):
@pytest.mark.parametrize(
"slice_input, correct_shape",
[
(None, (3, 5, 7)),
(0, (5, 7)),
(slice(0, 3, 2), (2, 5, 7)),
((1, slice(0, 3), slice(0, 3)), (3, 3)),
((slice(0, 3), slice(0, 3), slice(0, 3)), (3, 3, 3)),
(None, (3, 5, 7, 4)),
(0, (5, 7, 4)),
(slice(0, 3, 2), (2, 5, 7, 4)),
((1, slice(0, 3), slice(0, 3)), (3, 3, 4)),
((slice(0, 3), slice(0, 3), slice(0, 3)), (3, 3, 3, 4)),
((..., 0, 0, 0), (3,)),
((0, slice(0, 1), slice(0, 2), ...), (1, 2, 4)),
((0, ..., slice(0, 2)), (5, 7, 2)),
((..., slice(0, 1)), (3, 5, 7, 1)),
],
)
def test_tiff_sequence(client, slice_input, correct_shape):
arr = client["sequence"].read(slice=slice_input)
assert arr.shape == correct_shape


@pytest.mark.parametrize("block_input, correct_shape", [((0, 0, 0), (1, 5, 7))])
@pytest.mark.parametrize("block_input, correct_shape", [((0, 0, 0, 0), (1, 5, 7, 4))])
def test_tiff_sequence_block(client, block_input, correct_shape):
arr = client["sequence"].read_block(block_input)
assert arr.shape == correct_shape
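
For reference, the parametrized shapes above match plain NumPy indexing on an array of shape (3, 5, 7, 4): three files, each 5x7 with 4 color channels. A standalone sanity check, independent of the `client` fixture:

```python
import numpy as np

arr = np.random.random((3, 5, 7, 4))  # 3 files, each 5 x 7 with 4 channels

assert arr[0].shape == (5, 7, 4)
assert arr[0:3:2].shape == (2, 5, 7, 4)
assert arr[1, 0:3, 0:3].shape == (3, 3, 4)
assert arr[0:3, 0:3, 0:3].shape == (3, 3, 3, 4)
assert arr[..., 0, 0, 0].shape == (3,)
assert arr[0, 0:1, 0:2, ...].shape == (1, 2, 4)
assert arr[0, ..., 0:2].shape == (5, 7, 2)
assert arr[..., 0:1].shape == (3, 5, 7, 1)
```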
66 changes: 23 additions & 43 deletions tiled/adapters/tiff.py
@@ -1,5 +1,6 @@
import builtins

import numpy as np
import tifffile

from ..structures.array import ArrayStructure, BuiltinDtype
@@ -123,7 +124,7 @@ def __init__(
structure = ArrayStructure(
shape=shape,
# one chunk per underlying TIFF file
chunks=((1,) * shape[0], (shape[1],), (shape[2],)),
chunks=((1,) * shape[0], *[(i,) for i in shape[1:]]),
# Assume all files have the same data type
data_type=BuiltinDtype.from_numpy_dtype(self.read(slice=0).dtype),
)
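
As a quick illustration of the new chunking expression (using the (3, 5, 7, 4) shape from the tests; this snippet is not part of the adapter), each file becomes one chunk along the first axis and every remaining axis is covered by a single chunk:

```python
shape = (3, 5, 7, 4)  # 3 TIFF files, each 5 x 7 with 4 color channels
chunks = ((1,) * shape[0], *[(i,) for i in shape[1:]])
assert chunks == ((1, 1, 1), (5,), (7,), (4,))
```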
@@ -133,66 +134,45 @@ def metadata(self):
# TODO How to deal with the many headers?
return self._provided_metadata

def read(self, slice=None):
def read(self, slice=Ellipsis):
"""Return a numpy array

Receives a sequence of values to select from a collection of tiff files that were saved in a folder
The input order is defined as files --> X slice --> Y slice
The input order is defined as: files --> vertical slice --> horizontal slice --> color slice --> ...
read() can receive one value or one slice to select all the data from one file or a sequence of files;
or it can receive a tuple of up to three values (int or slice) to select a more specific sequence of pixels
of a group of images
or it can receive a tuple (int or slice) to select a more specific sequence of pixels of a group of images.
"""

if slice is None:
if slice is Ellipsis:
return self._seq.asarray()
if isinstance(slice, int):
# e.g. read(slice=0)
# e.g. read(slice=0) -- return an entire image
return tifffile.TiffFile(self._seq.files[slice]).asarray()
# e.g. read(slice=(...))
if isinstance(slice, builtins.slice):
# e.g. read(slice=(...)) -- return a slice along the image axis
return tifffile.TiffSequence(self._seq.files[slice]).asarray()
if isinstance(slice, tuple):
if len(slice) == 0:
return self._seq.asarray()
if len(slice) == 1:
return self.read(slice=slice[0])
image_axis, *the_rest = slice
# Could be int or slice
# (0, slice(...)) or (0,....) are converted to a list
# Could be int or slice (0, slice(...)) or (0,....); the_rest is converted to a list
if isinstance(image_axis, int):
# e.g. read(slice=(0, ....))
return tifffile.TiffFile(self._seq.files[image_axis]).asarray()
if isinstance(image_axis, builtins.slice):
if image_axis.start is None:
slice_start = 0
else:
slice_start = image_axis.start
if image_axis.step is None:
slice_step = 1
else:
slice_step = image_axis.step

arr = tifffile.TiffSequence(
self._seq.files[
slice_start : image_axis.stop : slice_step # noqa: E203
]
).asarray()
arr = arr[tuple(the_rest)]
return arr
if isinstance(slice, builtins.slice):
# Check for start and step which can be optional
if slice.start is None:
slice_start = 0
else:
slice_start = slice.start
if slice.step is None:
slice_step = 1
else:
slice_step = slice.step

arr = tifffile.TiffSequence(
self._seq.files[slice_start : slice.stop : slice_step] # noqa: E203
).asarray()
arr = tifffile.TiffFile(self._seq.files[image_axis]).asarray()
elif image_axis is Ellipsis:
# Return all images
arr = tifffile.TiffSequence(self._seq.files).asarray()
the_rest.insert(0, Ellipsis) # Include any leading dimensions
elif isinstance(image_axis, builtins.slice):
arr = self.read(slice=image_axis)
arr = np.atleast_1d(arr[tuple(the_rest)])
return arr
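
A standalone sketch of the underlying tifffile calls that read() dispatches to, using throwaway temporary files (the adapter class itself is not constructed here):

```python
import pathlib
import tempfile

import numpy as np
import tifffile

# Write three small images, mirroring the test fixture.
directory = pathlib.Path(tempfile.mkdtemp())
files = []
for i in range(3):
    path = directory / f"temp{i:05}.tif"
    tifffile.imwrite(path, np.random.random((5, 7, 4)))
    files.append(str(path))

# slice=Ellipsis -> the whole stack, one leading axis entry per file
print(tifffile.TiffSequence(files).asarray().shape)         # (3, 5, 7, 4)
# slice=0 -> one entire image
print(tifffile.TiffFile(files[0]).asarray().shape)          # (5, 7, 4)
# slice=slice(0, 3, 2) -> a subset of files along the image axis
print(tifffile.TiffSequence(files[0:3:2]).asarray().shape)  # (2, 5, 7, 4)
```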

def read_block(self, block, slice=None):
if block[1:] != (0, 0):
if any(block[1:]):
# e.g. block[1:] != [0,0, ..., 0]
raise IndexError(block)
arr = self.read(builtins.slice(block[0], block[0] + 1))
if slice is not None:
81 changes: 49 additions & 32 deletions tiled/catalog/adapter.py
@@ -561,14 +561,8 @@ async def get_distinct(self, metadata, structure_families, specs, counts):

return data

async def create_node(
self,
structure_family,
metadata,
key=None,
specs=None,
data_sources=None,
):
@property
def insert(self):
Comment (Contributor Author): I have factored out this bit into its own method because it's reused a couple of times throughout the class. Maybe call it `_insert`?

# The only way to do "insert if does not exist" i.e. ON CONFLICT
# is to invoke dialect-specific insert.
if self.context.engine.dialect.name == "sqlite":
@@ -578,6 +572,16 @@
else:
assert False # future-proofing

return insert
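
For context, here is a minimal standalone sketch of the same dialect dispatch with a hypothetical helper and model (`pick_insert` and `MyModel` are illustrative names, not part of this module); both dialect-specific `insert` constructs support `on_conflict_do_nothing`, which is what makes "insert if it does not exist" possible:

```python
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert


def pick_insert(dialect_name):
    # Same dispatch as the property above, written as a free function.
    if dialect_name == "sqlite":
        return sqlite_insert
    if dialect_name == "postgresql":
        return pg_insert
    raise NotImplementedError(dialect_name)


# Hypothetical usage:
# statement = pick_insert(engine.dialect.name)(MyModel).values(id=1).on_conflict_do_nothing()
# await db.execute(statement)
```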

async def create_node(
self,
structure_family,
metadata,
key=None,
specs=None,
data_sources=None,
):
key = key or self.context.key_maker()
data_sources = data_sources or []

@@ -636,6 +640,7 @@ async def create_node(
"is not one that the Tiled server knows how to read."
),
)

if data_source.structure is None:
structure_id = None
else:
@@ -646,7 +651,7 @@
)
structure_id = compute_structure_id(structure)
statement = (
insert(orm.Structure).values(
self.insert(orm.Structure).values(
id=structure_id,
structure=structure,
)
@@ -663,20 +668,7 @@
node.data_sources.append(data_source_orm)
await db.flush() # Get data_source_orm.id.
for asset in data_source.assets:
# Find an asset_id if it exists, otherwise create a new one
statement = select(orm.Asset.id).where(
orm.Asset.data_uri == asset.data_uri
)
result = await db.execute(statement)
if row := result.fetchone():
(asset_id,) = row
else:
statement = insert(orm.Asset).values(
data_uri=asset.data_uri,
is_directory=asset.is_directory,
)
result = await db.execute(statement)
(asset_id,) = result.inserted_primary_key
asset_id = await self._put_asset(db, asset)
assoc_orm = orm.DataSourceAssetAssociation(
asset_id=asset_id,
data_source_id=data_source_orm.id,
@@ -701,23 +693,31 @@
self.context, refreshed_node, access_policy=self.access_policy
)

async def _put_asset(self, db, asset):
# Find an asset_id if it exists, otherwise create a new one
statement = select(orm.Asset.id).where(orm.Asset.data_uri == asset.data_uri)
result = await db.execute(statement)
if row := result.fetchone():
(asset_id,) = row
else:
statement = self.insert(orm.Asset).values(
data_uri=asset.data_uri,
is_directory=asset.is_directory,
)
result = await db.execute(statement)
(asset_id,) = result.inserted_primary_key

return asset_id

async def put_data_source(self, data_source):
# Obtain and hash the canonical (RFC 8785) representation of
# the JSON structure.
structure = _prepare_structure(
data_source.structure_family, data_source.structure
)
structure_id = compute_structure_id(structure)
# The only way to do "insert if does not exist" i.e. ON CONFLICT
# is to invoke dialect-specific insert.
if self.context.engine.dialect.name == "sqlite":
from sqlalchemy.dialects.sqlite import insert
elif self.context.engine.dialect.name == "postgresql":
from sqlalchemy.dialects.postgresql import insert
else:
assert False # future-proofing
statement = (
insert(orm.Structure).values(
self.insert(orm.Structure).values(
id=structure_id,
structure=structure,
)
@@ -741,6 +741,23 @@ async def put_data_source(self, data_source):
status_code=404,
detail=f"No data_source {data_source.id} on this node.",
)
# Add assets and associate them with the data_source
for asset in data_source.assets:
asset_id = await self._put_asset(db, asset)
statement = select(orm.DataSourceAssetAssociation).where(
(orm.DataSourceAssetAssociation.data_source_id == data_source.id)
& (orm.DataSourceAssetAssociation.asset_id == asset_id)
)
result = await db.execute(statement)
if not result.fetchone():
assoc_orm = orm.DataSourceAssetAssociation(
asset_id=asset_id,
data_source_id=data_source.id,
parameter=asset.parameter,
num=asset.num,
)
db.add(assoc_orm)

await db.commit()

# async def patch_node(datasources=None):