Skip to content

Commit

Permalink
Mark object stats as str-typed
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege committed Jul 24, 2023
1 parent 65eb592 commit cd1d354
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
6 changes: 5 additions & 1 deletion pymc/backends/mcbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,13 @@ def make_runmeta_and_point_fn(
(-1 if s is None else s)
for s in (shape or [])
]
dt = np.dtype(dtype).name
# Object types will be pickled by the ChainRecordAdapter!
if dt == "object":
dt = "str"
svar = mcb.Variable(
name=sname,
dtype=np.dtype(dtype).name,
dtype=dt,
shape=sshape,
undefined_ndim=shape is None,
)
Expand Down
13 changes: 13 additions & 0 deletions tests/backends/test_mcbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,19 @@ def test_make_runmeta_and_point_fn(simple_model):
assert not vars["vector_interval__"].is_deterministic
assert vars["matrix"].is_deterministic
assert len(rmeta.sample_stats) == len(step.stats_dtypes[0])

with simple_model:
step = pm.NUTS()
rmeta, point_fn = make_runmeta_and_point_fn(
initial_point=simple_model.initial_point(),
step=step,
model=simple_model,
)
assert isinstance(rmeta, mcb.RunMeta)
svars = {s.name: s for s in rmeta.sample_stats}
# Unbeknownst to McBackend, object stats are pickled to str
assert "sampler_0__warning" in svars
assert svars["sampler_0__warning"].dtype == "str"
pass


Expand Down

0 comments on commit cd1d354

Please sign in to comment.