Skip to content

Commit

Permalink
documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
lfarv committed Jan 12, 2024
1 parent 11bff7e commit caf60df
Showing 1 changed file with 44 additions and 29 deletions.
73 changes: 44 additions & 29 deletions pyat/at/latticetools/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@


class _ModFun(object):
"""General and pickleable evaluation function"""

def __init__(self, fun, statfun):
self.fun = fun
self.statfun = statfun
Expand All @@ -110,6 +112,8 @@ def __call__(self, ring, *a):


class _ArrayAccess(object):
"""Access to selected items in an array"""

def __init__(self, index):
self.index = _all_rows(index)

Expand All @@ -119,7 +123,9 @@ def __call__(self, ring, data):


class _RecordAccess(object):
def __init__(self, fieldname, index):
"""Access to selected items in a record array"""

def __init__(self, fieldname: str, index):
self.index = index
self.fieldname = fieldname

Expand All @@ -129,7 +135,8 @@ def __call__(self, ring, data):
return data if index is None else data[self.index]


def _all_rows(index: RefIndex):
def _all_rows(index: Optional[RefIndex]):
"""Prepends "all rows" (":") to an index tuple"""
if index is None:
return None
if isinstance(index, tuple):
Expand All @@ -139,7 +146,9 @@ def _all_rows(index: RefIndex):


class _Tune(object):
def __init__(self, idx):
"""Get integer tune from the phase advance"""

def __init__(self, idx: RefIndex):
self.fun = _RecordAccess("mu", _all_rows(idx))

def __call__(self, ring, data):
Expand All @@ -148,6 +157,8 @@ def __call__(self, ring, data):


class _Ring(object):
"""Get an attribute of a lattice element"""

def __init__(self, attrname, index, refpts):
self.get_val = _RecordAccess(attrname, index)
self.refpts = refpts
Expand All @@ -157,17 +168,6 @@ def __call__(self, ring):
return np.array(vals)


def _flatten(vals, order="F"):
def check_none(vs):
for v in vs:
if v is None:
raise AtError("Evaluation failed")
else:
yield v

return np.concatenate([np.reshape(v, -1, order=order) for v in check_none(vals)])


class Need(Enum):
"""Defines the computation requirements for an :py:class:`Observable`."""

Expand Down Expand Up @@ -327,7 +327,7 @@ def _all_lines(self):
return "\n".join((self.name, values))

def _setup(self, ring: Lattice):
"""Setup function called wen the observable is added to a list"""
"""Setup function called when the observable is added to a list"""
pass

def evaluate(self, ring: Lattice, *data, initial: bool = False):
Expand Down Expand Up @@ -403,6 +403,7 @@ def residual(self):
"""residual, computed as
:pycode:`residual = ((value-target)/weight)**2`"""
dev = self.deviation
# Return a large value if the evaluation failed
return 1.0e6 if dev is None else (dev / self.w) ** 2

@staticmethod
Expand Down Expand Up @@ -991,7 +992,7 @@ def _scan(self, obs) -> Observable:
# The "needs" and reflists will be restored later
if hasattr(self, "ring"):
if not isinstance(obs, Observable):
raise TypeError("{} is not an Observable".format(obs))
raise TypeError(f"{obs!r} is not an Observable")
# noinspection PyProtectedMember
obs._setup(self.ring)
self.needs |= obs.needs
Expand All @@ -1000,10 +1001,9 @@ def _scan(self, obs) -> Observable:

def __iadd__(self, other: "ObservableList"):
if not isinstance(other, ObservableList):
mess = "Cannot add a {} to an Observable"
raise TypeError(mess.format(type(other)))
raise TypeError(f"Cannot add a {type(other)} to an ObservableList")
if other.ring is not self.ring:
raise TypeError("Observables must be based on the same lattice")
raise TypeError("ObservableLists must be based on the same lattice")
self.extend(other)
return self

Expand Down Expand Up @@ -1065,7 +1065,6 @@ def obseval(obs):
data.append(emdata)
if Need.GEOMETRY in obsneeds:
data.append(geodata[obsrefs])
print(obs.name)
obs.evaluate(ring, *data, initial=initial)

@frequency_control
Expand All @@ -1077,12 +1076,12 @@ def ringeval(
r_in: Orbit = None,
):
"""Optics computations"""

trajs = orbits = rgdata = eldata = emdata = mxdata = geodata = None
o0 = None
needs = self.needs

if Need.TRAJECTORY in needs:
# Trajectory computation
if r_in is None:
r_in = np.zeros(6)
r_out = internal_lpass(ring, r_in.copy(), 1, refpts=self.passrefs)
Expand All @@ -1095,6 +1094,7 @@ def ringeval(
Need.GLOBALOPTICS,
Need.EMITTANCE,
}):
# Closed orbit computation
orbit0 = self.kwargs.get("orbit", None)
try:
o0, orbits = ring.find_orbit(
Expand All @@ -1104,6 +1104,7 @@ def ringeval(
pass

if Need.MATRIX in needs and o0 is not None:
# Transfer matrix computation
if ring.is_6d:
# noinspection PyUnboundLocalVariable
_, mxdata = ring.find_m66(
Expand All @@ -1129,6 +1130,7 @@ def ringeval(
not needs.isdisjoint({Need.LOCALOPTICS, Need.GLOBALOPTICS})
and o0 is not None
):
# Linear optics computation
get_chrom = Need.CHROMATICITY in needs
twiss_in = self.kwargs.get("twiss_in", None)
method = self.kwargs.get("method", linopt6)
Expand Down Expand Up @@ -1156,12 +1158,14 @@ def ringeval(
eldata["mu"] = eldata["mu"] % (2.0 * np.pi)

if Need.EMITTANCE in needs and o0 is not None:
# Emittance computation
try:
emdata = ring.envelope_parameters(orbit=o0, keep_lattice=True)
except AtError:
pass

if Need.GEOMETRY in needs:
# Geometry computation
geodata, _ = ring.get_geometry()

return trajs, orbits, rgdata, eldata, emdata, mxdata, geodata
Expand All @@ -1188,6 +1192,17 @@ def exclude(self, obsname: str, excluded: Refpts):
for obs in self:
self._update_reflists(obs)

def _flatten(self, attrname, order='F'):
def check_none():
for obs in self:
v = getattr(obs, attrname)
if v is None:
raise AtError(f"Evaluation of {obs.name} failed")
else:
yield v

return np.concatenate([np.reshape(v, -1, order=order) for v in check_none()])

@property
def shapes(self) -> list:
"""Shapes of all values"""
Expand All @@ -1212,7 +1227,7 @@ def get_flat_values(self, order: str = "F") -> np.ndarray:
Args:
order: Ordering for reshaping. See :py:func:`~numpy.reshape`
"""
return _flatten((obs.value for obs in self), order=order)
return self._flatten("value", order=order)

@property
def weighted_values(self) -> list:
Expand All @@ -1225,7 +1240,7 @@ def get_flat_weighted_values(self, order: str = "F") -> np.ndarray:
Args:
order: Ordering for reshaping. See :py:func:`~numpy.reshape`
"""
return _flatten((obs.weighted_value for obs in self), order=order)
return self._flatten("weighted_values", order=order)

@property
def deviations(self) -> list:
Expand All @@ -1238,7 +1253,7 @@ def get_flat_deviations(self, order: str = "F") -> np.ndarray:
Args:
order: Ordering for reshaping. See :py:func:`~numpy.reshape`
"""
return _flatten((obs.deviation for obs in self), order=order)
return self._flatten("deviation", order=order)

@property
def weighted_deviations(self) -> list:
Expand All @@ -1251,7 +1266,7 @@ def get_flat_weighted_deviations(self, order: str = "F") -> np.ndarray:
Args:
order: Ordering for reshaping. See :py:func:`~numpy.reshape`
"""
return _flatten((obs.weighted_deviation for obs in self), order=order)
return self._flatten("weighted_deviation", order=order)

@property
def weights(self) -> list:
Expand All @@ -1264,15 +1279,16 @@ def get_flat_weights(self, order: str = "F") -> np.ndarray:
Args:
order: Ordering for reshaping. See :py:func:`~numpy.reshape`
"""
return _flatten((obs.weight for obs in self), order=order)
return self._flatten("weight", order=order)

@property
def residuals(self) -> list:
"""Residuals of all observable"""
return [obs.residual for obs in self]

def get_sum_residuals(self) -> float:
"""Return the sum of all residual values"""
@property
def sum_residuals(self) -> float:
"""Sum of all residual values"""
residuals = (obs.residual for obs in self)
return sum(np.sum(res) for res in residuals)

Expand All @@ -1288,7 +1304,6 @@ def get_sum_residuals(self) -> float:
doc="1-D array of weighted deviations from target values",
)
flat_weights = property(get_flat_weights, doc="1-D array of Observable weights")
sum_residuals = property(get_sum_residuals, doc="Sum of all residual values")


# noinspection PyPep8Naming
Expand Down

0 comments on commit caf60df

Please sign in to comment.