Skip to content

Commit 2756188

Browse files
committed
handle ndim>2 in maker utils
1 parent 6f464a7 commit 2756188

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

src/roman_datamodels/maker_utils/_datamodels.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -450,8 +450,7 @@ def mk_source_catalog(*, filepath=None, **kwargs):
450450
source_catalog = stnode.SourceCatalog()
451451

452452
source_catalog["source_catalog"] = kwargs.get("source_catalog", Table([range(3), range(3)], names=["a", "b"]))
453-
source_catalog["meta"] = mk_common_meta()
454-
source_catalog["meta"].update(kwargs.get("meta", dict(segmentation_map='')))
453+
source_catalog["meta"] = mk_common_meta(**kwargs.get("meta", {}))
455454

456455
return save_node(source_catalog, filepath=filepath)
457456

@@ -466,15 +465,24 @@ def mk_segmentation_map(*, filepath=None, shape=(4096, 4096), **kwargs):
466465
filepath
467466
(optional, keyword-only) File name and path to write model to.
468467
468+
shape
469+
(optional, keyword-only) Shape of arrays in the model.
470+
469471
Returns
470472
-------
471473
roman_datamodels.stnode.SegmentationMap
472474
"""
473-
segmentation_map = stnode.SegmentationMap()
475+
if len(shape) > 2:
476+
shape = shape[1:3]
474477

478+
warnings.warn(
479+
f"{MESSAGE} assuming the first entry is n_groups followed by y, x. The remaining is thrown out!", UserWarning
480+
)
481+
482+
segmentation_map = stnode.SegmentationMap()
475483
segmentation_map["data"] = kwargs.get("data", np.zeros(shape, dtype=np.uint32))
476484
segmentation_map["meta"] = mk_common_meta()
477-
segmentation_map["meta"].update(kwargs.get("meta", dict(filename='')))
485+
segmentation_map["meta"].update(kwargs.get("meta", {}))
478486

479487
return save_node(segmentation_map, filepath=filepath)
480488

tests/test_models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,13 @@ def test_make_source_catalog():
715715
assert isinstance(source_catalog_model.source_catalog, Table)
716716

717717

718+
def test_make_segmentation_map():
719+
segmentation_map = utils.mk_segmentation_map()
720+
segmentation_map_model = datamodels.SegmentationMapModel(segmentation_map)
721+
722+
assert isinstance(segmentation_map_model.data, np.ndarray)
723+
724+
718725
def test_datamodel_info_search(capsys):
719726
wfi_science_raw = utils.mk_level1_science_raw(shape=(2, 8, 8))
720727
af = asdf.AsdfFile()

tests/test_open.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,6 @@ def test_node_round_trip(tmp_path, node_class):
226226
@pytest.mark.filterwarnings("ignore:This function assumes shape is 2D")
227227
@pytest.mark.filterwarnings("ignore:Input shape must be 5D")
228228
def test_opening_model(tmp_path, node_class):
229-
if node_class == stnode.SourceCatalog:
230-
pytest.xfail("SourceCatalog does not have a meta attribute yet")
231-
232229
file_path = tmp_path / "test.asdf"
233230

234231
# Create a node and write it to disk

0 commit comments

Comments
 (0)