diff --git a/src/galax/dynamics/__init__.py b/src/galax/dynamics/__init__.py index 64ee8788..e5e74abe 100644 --- a/src/galax/dynamics/__init__.py +++ b/src/galax/dynamics/__init__.py @@ -6,6 +6,7 @@ "integrate", "mockstream", "plot", + "cluster", # integrate "evaluate_orbit", "Orbit", @@ -29,7 +30,7 @@ from galax.setup_package import RUNTIME_TYPECHECKER with install_import_hook("galax.dynamics", RUNTIME_TYPECHECKER): - from . import fields, integrate, mockstream, plot + from . import cluster, fields, integrate, mockstream, plot from ._src.cluster.funcs import lagrange_points, tidal_radius from ._src.funcs import specific_angular_momentum from ._src.orbit import Orbit diff --git a/src/galax/dynamics/_src/cluster/__init__.py b/src/galax/dynamics/_src/cluster/__init__.py index 7f466a23..1a33955c 100644 --- a/src/galax/dynamics/_src/cluster/__init__.py +++ b/src/galax/dynamics/_src/cluster/__init__.py @@ -1,9 +1,25 @@ -"""Cluster mass evolution.""" +"""Cluster evolution.""" __all__ = [ + "MassSolver", + # Fields + "MassVectorField", + "AbstractMassField", + "UserMassField", + "ConstantMassField", + # Events + "MassBelowThreshold", + # Functions "lagrange_points", "tidal_radius", ] - +from .events import MassBelowThreshold +from .fields import ( + AbstractMassField, + ConstantMassField, + MassVectorField, + UserMassField, +) from .funcs import lagrange_points, tidal_radius +from .solver import MassSolver diff --git a/src/galax/dynamics/_src/cluster/events.py b/src/galax/dynamics/_src/cluster/events.py new file mode 100644 index 00000000..f6c1a31a --- /dev/null +++ b/src/galax/dynamics/_src/cluster/events.py @@ -0,0 +1,67 @@ +"""``galax`` dynamics.""" + +__all__ = ["MassBelowThreshold"] + +from typing import Any + +import equinox as eqx +from jaxtyping import Array +from plum import dispatch + +import unxt as u +from unxt.quantity import AbstractQuantity + + +class MassBelowThreshold(eqx.Module): # type: ignore[misc] + """Event to stop integration when the mass falls below a threshold. + + Instances can be used as the ``cond_fn`` argument of `diffrax.Event`. Since + this returns a scalar (not a `bool`) the solve the solve will terminate on + the step when the mass is below the threshold. + + With `diffrax.Event` this can be combined with a root-finder to find the + exact time when the mass is below the threshold, rather than the step + after. + + Example + ------- + >>> import unxt as u + >>> from galax.dynamics.cluster import MassBelowThreshold + + >>> cond_fn = MassBelowThreshold(u.Quantity(0.0, "Msun")) + >>> args = {"units": u.unitsystems.galactic} + + >>> cond_fn(0.0, u.Quantity(1.0, "Msun"), args) + Array(1., dtype=float64, weak_type=True) + + >>> cond_fn(0.0, u.Quantity(0.0, "Msun"), args) + Array(0., dtype=float64, weak_type=True) + + TODO: example using it as a with `diffrax.Event`. + + """ + + #: Threshold mass at which to stop integration. + threshold: AbstractQuantity + + @dispatch + def __call__( + self: "MassBelowThreshold", + t: Any, # noqa: ARG002 + y: Array, + args: dict[str, Any], + /, + **kwargs: Any, # noqa: ARG002 + ) -> Array: + return y - self.threshold.ustrip(args["units"]) + + @dispatch + def __call__( + self: "MassBelowThreshold", + t: Any, # noqa: ARG002 + y: AbstractQuantity, + args: dict[str, Any], + /, + **kwargs: Any, # noqa: ARG002 + ) -> Array: + return u.ustrip(args["units"], y - self.threshold) diff --git a/src/galax/dynamics/_src/cluster/fields.py b/src/galax/dynamics/_src/cluster/fields.py new file mode 100644 index 00000000..51df95f0 --- /dev/null +++ b/src/galax/dynamics/_src/cluster/fields.py @@ -0,0 +1,122 @@ +"""Fields for mass evolution.""" + +__all__ = [ + "MassVectorField", + "AbstractMassField", + "UserMassField", + "ConstantMassField", +] + +from abc import abstractmethod +from dataclasses import KW_ONLY +from typing import Any, Protocol, TypeAlias, TypedDict, runtime_checkable + +import diffrax as dfx +import equinox as eqx +import jax.numpy as jnp +from jaxtyping import Array, PyTree + +from galax.dynamics._src.fields import AbstractField + +Time: TypeAlias = Any +ClusterMass: TypeAlias = Any + + +class Args(TypedDict, total=False): + units: Any + # Add other optional keys here if needed + + +@runtime_checkable +class MassVectorField(Protocol): + """Protocol for mass vector field. + + This is a function that returns the derivative of the mass vector with + respect to time. + + Examples + -------- + >>> from galax.dynamics.cluster import MassVectorField + + >>> def mass_deriv(t, Mc, args, **kwargs): pass + + >>> isinstance(mass_deriv, MassVectorField) + True + + """ + + def __call__( + self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any + ) -> Array: ... + + +class AbstractMassField(AbstractField): + """ABC for mass fields. + + Methods + ------- + __call__ : `galax.dynamics.cluster.MassVectorField` + Compute the mass field. + terms : the `diffrax.AbstractTerm` `jaxtyping.PyTree` for integration. + + """ + + @abstractmethod + def __call__(self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any) -> Array: # type: ignore[override] + raise NotImplementedError # pragma: no cover + + @AbstractField.terms.dispatch # type: ignore[misc] + def terms( + self: "AbstractMassField", _: dfx.AbstractSolver, / + ) -> PyTree[dfx.AbstractTerm]: + """Return diffeq terms for integration. + + Examples + -------- + >>> import diffrax as dfx + >>> import galax.dynamics as gd + + >>> field = gd.cluster.ConstantMassField() + >>> field.terms(dfx.Dopri8()) + ODETerm( + vector_field=_JitWrapper( fn='ConstantMassField.__call__', ... ) ) + + """ + return dfx.ODETerm(eqx.filter_jit(self.__call__)) + + +##################################################### + + +class UserMassField(AbstractMassField): + """User-defined mass field. + + This takes a user-defined function of type + `galax.dynamics.cluster.MassVectorField`. + + """ + + #: User-defined mass derivative function of type + #: `galax.dynamics.cluster.MassVectorField` + mass_deriv: MassVectorField + + _: KW_ONLY + + def __call__(self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any) -> Array: # type: ignore[override] + return self.mass_deriv(t, Mc, args, **kwargs) + + +##################################################### + + +class ConstantMassField(AbstractMassField): + """Constant mass field. + + This is a constant mass field. + + + + """ + + def __call__(self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any) -> Array: # type: ignore[override] # noqa: ARG002 + return jnp.zeros_like(Mc) diff --git a/src/galax/dynamics/_src/cluster/solver.py b/src/galax/dynamics/_src/cluster/solver.py new file mode 100644 index 00000000..f4be8929 --- /dev/null +++ b/src/galax/dynamics/_src/cluster/solver.py @@ -0,0 +1,127 @@ +"""``galax`` dynamics.""" + +__all__ = ["MassSolver"] + +from dataclasses import KW_ONLY +from typing import Any + +import diffrax as dfx +import equinox as eqx +import optimistix as optx +from plum import dispatch + +import unxt as u +from unxt.quantity import AbstractQuantity + +from .events import MassBelowThreshold +from .fields import AbstractMassField +from galax.dynamics._src.diffeq import DiffEqSolver +from galax.dynamics._src.solver import AbstractSolver +from galax.dynamics._src.utils import parse_saveat + + +class MassSolver(AbstractSolver, strict=True): # type: ignore[call-arg] + """Solver for mass history.""" + + diffeqsolver: DiffEqSolver = eqx.field( + default=DiffEqSolver( + solver=dfx.Dopri8(), + stepsize_controller=dfx.PIDController(rtol=1e-8, atol=1e-8), + ), + converter=DiffEqSolver.from_, + ) + # TODO: should this be incorporated into in `DiffEqSolver`? + event: dfx.Event = eqx.field( + default=dfx.Event( + cond_fn=MassBelowThreshold(u.Quantity(0.0, "Msun")), + root_finder=optx.Newton(1e-5, 1e-5, optx.rms_norm), + ) + ) + + _: KW_ONLY + + units: u.AbstractUnitSystem = eqx.field( + default=u.unitsystems.galactic, converter=u.unitsystem, static=True + ) + + @dispatch.abstract + def init( + self: "MassSolver", terms: Any, t0: Any, t1: Any, y0: Any, args: Any + ) -> Any: + # See dispatches below + raise NotImplementedError # pragma: no cover + + @dispatch.abstract + def step( + self: "MassSolver", + terms: Any, + t0: Any, + t1: Any, + y0: Any, + args: Any, + **step_kwargs: Any, # e.g. solver_state, made_jump + ) -> Any: + """Step the state.""" + # See dispatches below + raise NotImplementedError # pragma: no cover + + # TODO: dispatch where the state from `init` is accepted + @dispatch.abstract + def solve( + self: "MassSolver", + field: Any, + state: Any, + t0: Any, + t1: Any, + /, + args: Any = (), + **solver_kw: Any, # TODO: TypedDict + ) -> dfx.Solution: + """Call `diffrax.diffeqsolve`.""" + raise NotImplementedError # pragma: no cover + + +# =================================================================== +# Solve Dispatches + + +default_saveat = dfx.SaveAt(t1=True) + + +@MassSolver.solve.dispatch # type: ignore[misc] +@eqx.filter_jit # type: ignore[misc] +def solve( + self: MassSolver, + field: AbstractMassField, + state: Any, + t0: AbstractQuantity, + t1: AbstractQuantity, + /, + saveat: Any = default_saveat, + **solver_kw: Any, +) -> dfx.Solution: + # Units + units = self.units + utime = units["time"] + + # Initial conditions + y0 = state.ustrip(units["mass"]) # Mc + + # Extra info + args = {"units": units} + + # Solve the differential equation + solver_kw.setdefault("dt0", None) + saveat = parse_saveat(units, saveat, dense=solver_kw.pop("dense", None)) + soln = self.diffeqsolver( + field.terms(self.diffeqsolver), + t0=t0.ustrip(utime), + t1=t1.ustrip(utime), + y0=y0, + event=self.event, + args=args, + saveat=saveat, + **solver_kw, + ) + + return soln # noqa: RET504 diff --git a/src/galax/dynamics/_src/fields.py b/src/galax/dynamics/_src/fields.py index 82475f06..b27a9c8a 100644 --- a/src/galax/dynamics/_src/fields.py +++ b/src/galax/dynamics/_src/fields.py @@ -10,17 +10,12 @@ from jaxtyping import PyTree from plum import dispatch -import unxt as u - from galax.dynamics._src.diffeq import DiffEqSolver class AbstractField(eqx.Module, strict=True): # type: ignore[misc,call-arg] """Abstract base class for fields.""" - #: unit system of the field. - units: eqx.AbstractVar[u.AbstractUnitSystem] - @abstractmethod def __call__(self, t: Any, *args: Any, **kwargs: Any) -> Any: """Evaluate the field at time `t`.""" diff --git a/src/galax/dynamics/cluster.py b/src/galax/dynamics/cluster.py new file mode 100644 index 00000000..7b287f00 --- /dev/null +++ b/src/galax/dynamics/cluster.py @@ -0,0 +1,34 @@ +""":mod:`galax.dynamics.cluster`.""" + +__all__ = [ + "MassSolver", + # Fields + "MassVectorField", + "AbstractMassField", + "UserMassField", + "ConstantMassField", + # Events + "MassBelowThreshold", + # Functions + "lagrange_points", + "tidal_radius", +] + +from jaxtyping import install_import_hook + +from galax.setup_package import RUNTIME_TYPECHECKER + +with install_import_hook("galax.dynamics.fields", RUNTIME_TYPECHECKER): + from ._src.cluster import ( + AbstractMassField, + ConstantMassField, + MassBelowThreshold, + MassSolver, + MassVectorField, + UserMassField, + lagrange_points, + tidal_radius, + ) + +# Cleanup +del install_import_hook, RUNTIME_TYPECHECKER diff --git a/tests/smoke/dynamics/test_package.py b/tests/smoke/dynamics/test_package.py index 231292e9..f0783c73 100644 --- a/tests/smoke/dynamics/test_package.py +++ b/tests/smoke/dynamics/test_package.py @@ -9,6 +9,7 @@ def test_all() -> None: # modules "fields", "integrate", + "cluster", "mockstream", "plot", # core