Skip to content

Commit eb600a1

Browse files
Fix for model type
1 parent 9393be8 commit eb600a1

File tree

4 files changed

+33
-8
lines changed

4 files changed

+33
-8
lines changed

CHANGES.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
- Allow assignment to or creation of node attributes using dot notation of object instances
55
with validation. [#284]
66

7+
- Bugfix for ``meta.model_type`` not being set to match the model writing the file. [#296]
8+
79
0.18.0 (2023-11-06)
810
===================
911

src/roman_datamodels/datamodels/_datamodels.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,29 +25,42 @@ class _DataModel(DataModel):
2525
def __init_subclass__(cls, **kwargs):
2626
"""Register each subclass in the __all__ for this module"""
2727
super().__init_subclass__(**kwargs)
28+
29+
# Don't register private classes
30+
if cls.__name__.startswith("_"):
31+
return
32+
2833
if cls.__name__ in __all__:
2934
raise ValueError(f"Duplicate model type {cls.__name__}")
3035

3136
__all__.append(cls.__name__)
3237

3338

34-
class MosaicModel(_DataModel):
39+
class _RomanDataModel(_DataModel):
40+
def __init__(self, init=None, **kwargs):
41+
super().__init__(init, **kwargs)
42+
43+
if init is not None:
44+
self.meta.model_type = self.__class__.__name__
45+
46+
47+
class MosaicModel(_RomanDataModel):
3548
_node_type = stnode.WfiMosaic
3649

3750

38-
class ImageModel(_DataModel):
51+
class ImageModel(_RomanDataModel):
3952
_node_type = stnode.WfiImage
4053

4154

42-
class ScienceRawModel(_DataModel):
55+
class ScienceRawModel(_RomanDataModel):
4356
_node_type = stnode.WfiScienceRaw
4457

4558

46-
class MsosStackModel(_DataModel):
59+
class MsosStackModel(_RomanDataModel):
4760
_node_type = stnode.MsosStack
4861

4962

50-
class RampModel(_DataModel):
63+
class RampModel(_RomanDataModel):
5164
_node_type = stnode.Ramp
5265

5366
@classmethod
@@ -86,7 +99,7 @@ def from_science_raw(cls, model):
8699
raise ValueError("Input model must be a ScienceRawModel or RampModel")
87100

88101

89-
class RampFitOutputModel(_DataModel):
102+
class RampFitOutputModel(_RomanDataModel):
90103
_node_type = stnode.RampFitOutput
91104

92105

@@ -107,7 +120,7 @@ def is_association(cls, asn_data):
107120
return isinstance(asn_data, dict) and "asn_id" in asn_data and "asn_pool" in asn_data
108121

109122

110-
class GuidewindowModel(_DataModel):
123+
class GuidewindowModel(_RomanDataModel):
111124
_node_type = stnode.Guidewindow
112125

113126

tests/test_maker_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from astropy.time import Time
88

99
from roman_datamodels import datamodels, maker_utils, stnode
10+
from roman_datamodels.datamodels._datamodels import _RomanDataModel
1011
from roman_datamodels.maker_utils import _ref_files as ref_files
1112
from roman_datamodels.testing import assert_node_equal
1213

@@ -109,7 +110,8 @@ def test_datamodel_maker(model_class):
109110
assert isinstance(model, model_class)
110111
model.validate()
111112

112-
assert model.meta.model_type == model_class.__name__
113+
if issubclass(model_class, _RomanDataModel):
114+
assert model.meta.model_type == model_class.__name__
113115

114116

115117
@pytest.mark.parametrize("node_class", [node for node in datamodels.MODEL_REGISTRY])

tests/test_models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,14 @@ def test_ramp_from_science_raw():
772772
if isinstance(ramp_value, np.ndarray):
773773
assert_array_equal(ramp_value, raw_value.astype(ramp_value.dtype))
774774

775+
elif key == "meta":
776+
for meta_key in ramp_value:
777+
if meta_key == "model_type":
778+
ramp_value[meta_key] = ramp.__class__.__name__
779+
raw_value[meta_key] = raw.__class__.__name__
780+
continue
781+
assert_node_equal(ramp_value[meta_key], raw_value[meta_key])
782+
775783
elif isinstance(ramp_value, stnode.DNode):
776784
assert_node_equal(ramp_value, raw_value)
777785

0 commit comments

Comments
 (0)