Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gp/feat/aman arbitrary path #1057

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
11 changes: 11 additions & 0 deletions docs/axisman.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ The output of the ``wrap`` cal should be::
Note the boresight entry is marked with a ``*``, indicating that it's
an AxisManager rather than a numpy array.

To access data in an AxisManager, use a path-like syntax where
iparask marked this conversation as resolved.
Show resolved Hide resolved
attribute names are separated by dots::

>>> n, ofs = 1000, 0
>>> dets = ["det0", "det1", "det2"]
>>> aman = core.AxisManager(core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs))
>>> child = core.AxisManager(core.LabelAxis("dets", dets + ["det3"]),core.OffsetAxis("samps", n, ofs - n // 2),)
>>> aman.wrap("child", child)
>>> print(aman["child.dets"])
LabelAxis(3:'det0','det1','det2')

To slice this object, use the restrict() method. First, let's
restrict in the 'dets' axis. Since it's an Axis of type LabelAxis,
the restriction selector must be a list of strings::
Expand Down
69 changes: 52 additions & 17 deletions sotodlib/core/axisman.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,29 +349,60 @@ def move(self, name, new_name):
self._fields[new_name] = self._fields.pop(name)
self._assignments[new_name] = self._assignments.pop(name)
return self

def add_axis(self, a):
assert isinstance( a, AxisInterface)
self._axes[a.name] = a.copy()

def __contains__(self, name):
return name in self._fields or name in self._axes
attrs = name.split(".")
tmp_item = self
while attrs:
attr_name = attrs.pop(0)
if attr_name in tmp_item._fields:
tmp_item = tmp_item._fields[attr_name]
elif attr_name in tmp_item._axes:
tmp_item = tmp_item._axes[attr_name]
else:
return False
return True

def __getitem__(self, name):
if name in self._fields:
return self._fields[name]
if name in self._axes:
return self._axes[name]
raise KeyError(name)

# We want to support options like:
# aman.focal_plane.xi . aman['focal_plane.xi']
# We will safely assume that a getitem will always have '.' as the separator
attrs = name.split(".")
tmp_item = self
while attrs:
attr_name = attrs.pop(0)
if attr_name in tmp_item._fields:
tmp_item = tmp_item._fields[attr_name]
elif attr_name in tmp_item._axes:
tmp_item = tmp_item._axes[attr_name]
else:
raise KeyError(attr_name)
return tmp_item

def __setitem__(self, name, val):
if name in self._fields:
self._fields[name] = val
else:
raise KeyError(name)

last_pos = name.rfind(".")
val_key = name
tmp_item = self
if last_pos > -1:
val_key = name[last_pos + 1:]
attrs = name[:last_pos]
tmp_item = self[attrs]

if isinstance(val, AxisManager) and isinstance(tmp_item, AxisManager):
raise ValueError("Cannot assign AxisManager to AxisManager. Please use wrap method.")

tmp_item.__setattr__(val_key, val)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a slight regression -- consider:

a = AxisManager()
a.x = 1
a['y'] = 1

The previous behavior was that a.x = 1 would work, but a['y'] = 1 would raise a KeyError. Please restore that behavior.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, I think


def __setattr__(self, name, value):
# Assignment to members update those members
# We will assume that a path exists until the last member.
# If any member prior to that does not exist a keyerror is raised.
if "_fields" in self.__dict__ and name in self._fields.keys():
self._fields[name] = value
else:
Expand All @@ -381,7 +412,11 @@ def __setattr__(self, name, value):
def __getattr__(self, name):
# Prevent members from override special class members.
if name.startswith("__"): raise AttributeError(name)
return self[name]
try:
val = self[name]
except KeyError as ex:
raise AttributeError(name) from ex
return val

def __dir__(self):
return sorted(tuple(self.__dict__.keys()) + tuple(self.keys()))
Expand Down Expand Up @@ -514,27 +549,27 @@ def concatenate(items, axis=0, other_fields='exact'):
output.wrap(k, new_data[k], axis_map)
else:
if other_fields == "exact":
## if every item named k is a scalar
## if every item named k is a scalar
err_msg = (f"The field '{k}' does not share axis '{axis}'; "
f"{k} is not identical across all items "
f"pass other_fields='drop' or 'first' or else "
f"remove this field from the targets.")

if np.any([np.isscalar(i[k]) for i in items]):
if not np.all([np.isscalar(i[k]) for i in items]):
raise ValueError(err_msg)
if not np.all([np.array_equal(i[k], items[0][k], equal_nan=True) for i in items]):
raise ValueError(err_msg)
output.wrap(k, items[0][k], axis_map)
continue

elif not np.all([i[k].shape==items[0][k].shape for i in items]):
raise ValueError(err_msg)
elif not np.all([np.array_equal(i[k], items[0][k], equal_nan=True) for i in items]):
raise ValueError(err_msg)

output.wrap(k, items[0][k].copy(), axis_map)

elif other_fields == 'fail':
raise ValueError(
f"The field '{k}' does not share axis '{axis}'; "
Expand Down
2 changes: 1 addition & 1 deletion sotodlib/core/axisman_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def expand_RangesMatrix(flat_rm):
if shape[0] == 0:
return so3g.proj.RangesMatrix([], child_shape=shape[1:])
# Otherwise non-trivial
count = np.product(shape[:-1])
count = np.prod(shape[:-1])
start, stride = 0, count // shape[0]
for i in range(0, len(ends), stride):
_e = ends[i:i+stride] - start
Expand Down
61 changes: 56 additions & 5 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import shutil

from networkx import selfloop_edges
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused import?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

import numpy as np
import astropy.units as u
from sotodlib import core
Expand Down Expand Up @@ -66,7 +67,7 @@ def test_130_not_inplace(self):

# This should return a separate thing.
rman = aman.restrict('samps', (10, 30), in_place=False)
#self.assertNotEqual(aman.a1[0], 0.)
# self.assertNotEqual(aman.a1[0], 0.)
self.assertEqual(len(aman.a1), 100)
self.assertEqual(len(rman.a1), 20)
self.assertNotEqual(aman.a1[10], 0.)
Expand Down Expand Up @@ -190,23 +191,23 @@ def test_170_concat(self):

# ... other_fields="exact"
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

## add scalars
amanA.wrap("ans", 42)
amanB.wrap("ans", 42)
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

# ... other_fields="exact"
amanB.azimuth[:] = 2.
with self.assertRaises(ValueError):
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

# ... other_fields="exact" and arrays of different shapes
amanB.move("azimuth", None)
amanB.wrap("azimuth", np.array([43,5,2,3]))
with self.assertRaises(ValueError):
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

# ... other_fields="fail"
amanB.move("azimuth",None)
amanB.wrap_new('azimuth', shape=('samps',))[:] = 2.
Expand Down Expand Up @@ -298,6 +299,56 @@ def test_300_restrict(self):
self.assertNotEqual(aman.a1[0, 0, 0, 1], 0.)

# wrap of AxisManager, merge.
def test_get_set(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The naming convention here is test_NNN_description. I propose you call this test_190_get_set and move it upwards, before the # Multi-dimensional restrictions. section.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

dets = ["det0", "det1", "det2"]
n, ofs = 1000, 0
aman = core.AxisManager(
core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs)
)
child = core.AxisManager(
core.LabelAxis("dets", dets + ["det3"]),
core.OffsetAxis("samps", n, ofs - n // 2),
)

child2 = core.AxisManager(
core.LabelAxis("dets2", ["det4", "det5"]),
core.OffsetAxis("samps", n, ofs - n // 2),
)
child2.wrap("tod", np.zeros((2,1000)))
aman.wrap("child", child)
aman["child"].wrap("child2", child2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mhasself do we want support in wrap for these paths as well? ie aman.wrap('child.child2', ...)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am ok with that but let's leave it out for now.

self.assertEqual(aman["child.child2.dets2"].count, 2)
self.assertEqual(aman["child.dets"].name, "dets")
np.testing.assert_array_equal(aman["child.child2.dets2"].vals, np.array(["det4", "det5"]))
self.assertEqual(aman["child.child2.samps"].count, n // 2)
self.assertEqual(aman["child.child2.samps"].offset, 0)
self.assertEqual(aman["child.child2.samps"].count, aman.child.child2.samps.count)
self.assertEqual(aman["child.child2.samps"].offset, aman.child.child2.samps.offset)

np.testing.assert_array_equal(aman["child.child2.tod"], np.zeros((2,1000)))

with self.assertRaises(KeyError):
aman["child2"]

with self.assertRaises(AttributeError):
aman["child.dets.an_extra_layer"]

self.assertIn("child.dets", aman)
self.assertIn("child.dets2", aman) # I am not sure why this is true
iparask marked this conversation as resolved.
Show resolved Hide resolved
self.assertNotIn("child.child2.someentry", aman)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add test
self.assertNotIn("child.child2.someentry.someotherentry", aman)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


with self.assertRaises(ValueError):
aman["child"] = child2

new_tods = np.ones((2, 500))
aman.child.child2.tod = new_tods
np.testing.assert_array_equal(aman["child.child2.tod"], np.ones((2, 500)))
np.testing.assert_array_equal(aman.child.child2.tod, np.ones((2, 500)))

new_tods = np.ones((2, 1500))
aman["child.child2.tod"] = new_tods
np.testing.assert_array_equal(aman["child.child2.tod"], np.ones((2, 1500)))
np.testing.assert_array_equal(aman.child.child2.tod, np.ones((2, 1500)))

def test_400_child(self):
dets = ['det0', 'det1', 'det2']
Expand Down
Loading