-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ feat(dynamics): cluster mass solver
Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
- Loading branch information
Showing
8 changed files
with
371 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,25 @@ | ||
"""Cluster mass evolution.""" | ||
"""Cluster evolution.""" | ||
|
||
__all__ = [ | ||
"MassSolver", | ||
# Fields | ||
"MassVectorField", | ||
"AbstractMassField", | ||
"UserMassField", | ||
"ConstantMassField", | ||
# Events | ||
"MassBelowThreshold", | ||
# Functions | ||
"lagrange_points", | ||
"tidal_radius", | ||
] | ||
|
||
|
||
from .events import MassBelowThreshold | ||
from .fields import ( | ||
AbstractMassField, | ||
ConstantMassField, | ||
MassVectorField, | ||
UserMassField, | ||
) | ||
from .funcs import lagrange_points, tidal_radius | ||
from .solver import MassSolver |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
"""``galax`` dynamics.""" | ||
|
||
__all__ = ["MassBelowThreshold"] | ||
|
||
from typing import Any | ||
|
||
import equinox as eqx | ||
from jaxtyping import Array | ||
from plum import dispatch | ||
|
||
import unxt as u | ||
from unxt.quantity import AbstractQuantity | ||
|
||
|
||
class MassBelowThreshold(eqx.Module): # type: ignore[misc] | ||
"""Event to stop integration when the mass falls below a threshold. | ||
Instances can be used as the ``cond_fn`` argument of `diffrax.Event`. Since | ||
this returns a scalar (not a `bool`) the solve the solve will terminate on | ||
the step when the mass is below the threshold. | ||
With `diffrax.Event` this can be combined with a root-finder to find the | ||
exact time when the mass is below the threshold, rather than the step | ||
after. | ||
Example | ||
------- | ||
>>> import unxt as u | ||
>>> from galax.dynamics.cluster import MassBelowThreshold | ||
>>> cond_fn = MassBelowThreshold(u.Quantity(0.0, "Msun")) | ||
>>> args = {"units": u.unitsystems.galactic} | ||
>>> cond_fn(0.0, u.Quantity(1.0, "Msun"), args) | ||
Array(1., dtype=float64, weak_type=True) | ||
>>> cond_fn(0.0, u.Quantity(0.0, "Msun"), args) | ||
Array(0., dtype=float64, weak_type=True) | ||
TODO: example using it as a with `diffrax.Event`. | ||
""" | ||
|
||
#: Threshold mass at which to stop integration. | ||
threshold: AbstractQuantity | ||
|
||
@dispatch | ||
def __call__( | ||
self: "MassBelowThreshold", | ||
t: Any, # noqa: ARG002 | ||
y: Array, | ||
args: dict[str, Any], | ||
/, | ||
**kwargs: Any, # noqa: ARG002 | ||
) -> Array: | ||
return y - self.threshold.ustrip(args["units"]) | ||
|
||
@dispatch | ||
def __call__( | ||
self: "MassBelowThreshold", | ||
t: Any, # noqa: ARG002 | ||
y: AbstractQuantity, | ||
args: dict[str, Any], | ||
/, | ||
**kwargs: Any, # noqa: ARG002 | ||
) -> Array: | ||
return u.ustrip(args["units"], y - self.threshold) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
"""Fields for mass evolution.""" | ||
|
||
__all__ = [ | ||
"MassVectorField", | ||
"AbstractMassField", | ||
"UserMassField", | ||
"ConstantMassField", | ||
] | ||
|
||
from abc import abstractmethod | ||
from dataclasses import KW_ONLY | ||
from typing import Any, Protocol, TypeAlias, TypedDict, runtime_checkable | ||
|
||
import diffrax as dfx | ||
import equinox as eqx | ||
import jax.numpy as jnp | ||
from jaxtyping import Array, PyTree | ||
|
||
from galax.dynamics._src.fields import AbstractField | ||
|
||
Time: TypeAlias = Any | ||
ClusterMass: TypeAlias = Any | ||
|
||
|
||
class Args(TypedDict, total=False): | ||
units: Any | ||
# Add other optional keys here if needed | ||
|
||
|
||
@runtime_checkable | ||
class MassVectorField(Protocol): | ||
"""Protocol for mass vector field. | ||
This is a function that returns the derivative of the mass vector with | ||
respect to time. | ||
Examples | ||
-------- | ||
>>> from galax.dynamics.cluster import MassVectorField | ||
>>> def mass_deriv(t, Mc, args, **kwargs): pass | ||
>>> isinstance(mass_deriv, MassVectorField) | ||
True | ||
""" | ||
|
||
def __call__( | ||
self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any | ||
) -> Array: ... | ||
|
||
|
||
class AbstractMassField(AbstractField): | ||
"""ABC for mass fields. | ||
Methods | ||
------- | ||
__call__ : `galax.dynamics.cluster.MassVectorField` | ||
Compute the mass field. | ||
terms : the `diffrax.AbstractTerm` `jaxtyping.PyTree` for integration. | ||
""" | ||
|
||
@abstractmethod | ||
def __call__(self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any) -> Array: # type: ignore[override] | ||
raise NotImplementedError # pragma: no cover | ||
|
||
@AbstractField.terms.dispatch # type: ignore[misc] | ||
def terms( | ||
self: "AbstractMassField", _: dfx.AbstractSolver, / | ||
) -> PyTree[dfx.AbstractTerm]: | ||
"""Return diffeq terms for integration. | ||
Examples | ||
-------- | ||
>>> import diffrax as dfx | ||
>>> import galax.dynamics as gd | ||
>>> field = gd.cluster.ConstantMassField() | ||
>>> field.terms(dfx.Dopri8()) | ||
ODETerm( | ||
vector_field=_JitWrapper( fn='ConstantMassField.__call__', ... ) ) | ||
""" | ||
return dfx.ODETerm(eqx.filter_jit(self.__call__)) | ||
|
||
|
||
##################################################### | ||
|
||
|
||
class UserMassField(AbstractMassField): | ||
"""User-defined mass field. | ||
This takes a user-defined function of type | ||
`galax.dynamics.cluster.MassVectorField`. | ||
""" | ||
|
||
#: User-defined mass derivative function of type | ||
#: `galax.dynamics.cluster.MassVectorField` | ||
mass_deriv: MassVectorField | ||
|
||
_: KW_ONLY | ||
|
||
def __call__(self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any) -> Array: # type: ignore[override] | ||
return self.mass_deriv(t, Mc, args, **kwargs) | ||
|
||
|
||
##################################################### | ||
|
||
|
||
class ConstantMassField(AbstractMassField): | ||
"""Constant mass field. | ||
This is a constant mass field. | ||
""" | ||
|
||
def __call__(self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any) -> Array: # type: ignore[override] # noqa: ARG002 | ||
return jnp.zeros_like(Mc) | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
"""``galax`` dynamics.""" | ||
|
||
__all__ = ["MassSolver"] | ||
|
||
from dataclasses import KW_ONLY | ||
from typing import Any | ||
|
||
import diffrax as dfx | ||
import equinox as eqx | ||
import optimistix as optx | ||
from plum import dispatch | ||
|
||
import unxt as u | ||
from unxt.quantity import AbstractQuantity | ||
|
||
from .events import MassBelowThreshold | ||
from .fields import AbstractMassField | ||
from galax.dynamics._src.diffeq import DiffEqSolver | ||
from galax.dynamics._src.solver import AbstractSolver | ||
from galax.dynamics._src.utils import parse_saveat | ||
|
||
|
||
class MassSolver(AbstractSolver, strict=True): # type: ignore[call-arg] | ||
"""Solver for mass history.""" | ||
|
||
diffeqsolver: DiffEqSolver = eqx.field( | ||
default=DiffEqSolver( | ||
solver=dfx.Dopri8(), | ||
stepsize_controller=dfx.PIDController(rtol=1e-8, atol=1e-8), | ||
), | ||
converter=DiffEqSolver.from_, | ||
) | ||
# TODO: should this be incorporated into in `DiffEqSolver`? | ||
event: dfx.Event = eqx.field( | ||
default=dfx.Event( | ||
cond_fn=MassBelowThreshold(u.Quantity(0.0, "Msun")), | ||
root_finder=optx.Newton(1e-5, 1e-5, optx.rms_norm), | ||
) | ||
) | ||
|
||
_: KW_ONLY | ||
|
||
units: u.AbstractUnitSystem = eqx.field( | ||
default=u.unitsystems.galactic, converter=u.unitsystem, static=True | ||
) | ||
|
||
@dispatch.abstract | ||
def init( | ||
self: "MassSolver", terms: Any, t0: Any, t1: Any, y0: Any, args: Any | ||
) -> Any: | ||
# See dispatches below | ||
raise NotImplementedError # pragma: no cover | ||
|
||
@dispatch.abstract | ||
def step( | ||
self: "MassSolver", | ||
terms: Any, | ||
t0: Any, | ||
t1: Any, | ||
y0: Any, | ||
args: Any, | ||
**step_kwargs: Any, # e.g. solver_state, made_jump | ||
) -> Any: | ||
"""Step the state.""" | ||
# See dispatches below | ||
raise NotImplementedError # pragma: no cover | ||
|
||
# TODO: dispatch where the state from `init` is accepted | ||
@dispatch.abstract | ||
def solve( | ||
self: "MassSolver", | ||
field: Any, | ||
state: Any, | ||
t0: Any, | ||
t1: Any, | ||
/, | ||
args: Any = (), | ||
**solver_kw: Any, # TODO: TypedDict | ||
) -> dfx.Solution: | ||
"""Call `diffrax.diffeqsolve`.""" | ||
raise NotImplementedError # pragma: no cover | ||
|
||
|
||
# =================================================================== | ||
# Solve Dispatches | ||
|
||
|
||
default_saveat = dfx.SaveAt(t1=True) | ||
|
||
|
||
@MassSolver.solve.dispatch # type: ignore[misc] | ||
@eqx.filter_jit # type: ignore[misc] | ||
def solve( | ||
self: MassSolver, | ||
field: AbstractMassField, | ||
state: Any, | ||
t0: AbstractQuantity, | ||
t1: AbstractQuantity, | ||
/, | ||
saveat: Any = default_saveat, | ||
**solver_kw: Any, | ||
) -> dfx.Solution: | ||
# Units | ||
units = self.units | ||
utime = units["time"] | ||
|
||
# Initial conditions | ||
y0 = state.ustrip(units["mass"]) # Mc | ||
|
||
# Extra info | ||
args = {"units": units} | ||
|
||
# Solve the differential equation | ||
solver_kw.setdefault("dt0", None) | ||
saveat = parse_saveat(units, saveat, dense=solver_kw.pop("dense", None)) | ||
soln = self.diffeqsolver( | ||
field.terms(self.diffeqsolver), | ||
t0=t0.ustrip(utime), | ||
t1=t1.ustrip(utime), | ||
y0=y0, | ||
event=self.event, | ||
args=args, | ||
saveat=saveat, | ||
**solver_kw, | ||
) | ||
|
||
return soln # noqa: RET504 | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.