Skip to content

Commit

Permalink
✨ feat(dynamics): cluster mass solver
Browse files Browse the repository at this point in the history
Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 28, 2025
1 parent c949e71 commit c1616ac
Show file tree
Hide file tree
Showing 8 changed files with 371 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/galax/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"integrate",
"mockstream",
"plot",
"cluster",
# integrate
"evaluate_orbit",
"Orbit",
Expand All @@ -29,7 +30,7 @@
from galax.setup_package import RUNTIME_TYPECHECKER

with install_import_hook("galax.dynamics", RUNTIME_TYPECHECKER):
from . import fields, integrate, mockstream, plot
from . import cluster, fields, integrate, mockstream, plot
from ._src.cluster.funcs import lagrange_points, tidal_radius
from ._src.funcs import specific_angular_momentum
from ._src.orbit import Orbit
Expand Down
20 changes: 18 additions & 2 deletions src/galax/dynamics/_src/cluster/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
"""Cluster mass evolution."""
"""Cluster evolution."""

__all__ = [
"MassSolver",
# Fields
"MassVectorField",
"AbstractMassField",
"UserMassField",
"ConstantMassField",
# Events
"MassBelowThreshold",
# Functions
"lagrange_points",
"tidal_radius",
]


from .events import MassBelowThreshold
from .fields import (
AbstractMassField,
ConstantMassField,
MassVectorField,
UserMassField,
)
from .funcs import lagrange_points, tidal_radius
from .solver import MassSolver
67 changes: 67 additions & 0 deletions src/galax/dynamics/_src/cluster/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""``galax`` dynamics."""

__all__ = ["MassBelowThreshold"]

from typing import Any

import equinox as eqx
from jaxtyping import Array
from plum import dispatch

import unxt as u
from unxt.quantity import AbstractQuantity


class MassBelowThreshold(eqx.Module): # type: ignore[misc]
"""Event to stop integration when the mass falls below a threshold.
Instances can be used as the ``cond_fn`` argument of `diffrax.Event`. Since
this returns a scalar (not a `bool`) the solve the solve will terminate on
the step when the mass is below the threshold.
With `diffrax.Event` this can be combined with a root-finder to find the
exact time when the mass is below the threshold, rather than the step
after.
Example
-------
>>> import unxt as u
>>> from galax.dynamics.cluster import MassBelowThreshold
>>> cond_fn = MassBelowThreshold(u.Quantity(0.0, "Msun"))
>>> args = {"units": u.unitsystems.galactic}
>>> cond_fn(0.0, u.Quantity(1.0, "Msun"), args)
Array(1., dtype=float64, weak_type=True)
>>> cond_fn(0.0, u.Quantity(0.0, "Msun"), args)
Array(0., dtype=float64, weak_type=True)
TODO: example using it as a with `diffrax.Event`.
"""

#: Threshold mass at which to stop integration.
threshold: AbstractQuantity

@dispatch
def __call__(
self: "MassBelowThreshold",
t: Any, # noqa: ARG002
y: Array,
args: dict[str, Any],
/,
**kwargs: Any, # noqa: ARG002
) -> Array:
return y - self.threshold.ustrip(args["units"])

Check warning on line 56 in src/galax/dynamics/_src/cluster/events.py

View check run for this annotation

Codecov / codecov/patch

src/galax/dynamics/_src/cluster/events.py#L56

Added line #L56 was not covered by tests

@dispatch
def __call__(
self: "MassBelowThreshold",
t: Any, # noqa: ARG002
y: AbstractQuantity,
args: dict[str, Any],
/,
**kwargs: Any, # noqa: ARG002
) -> Array:
return u.ustrip(args["units"], y - self.threshold)
122 changes: 122 additions & 0 deletions src/galax/dynamics/_src/cluster/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""Fields for mass evolution."""

__all__ = [
"MassVectorField",
"AbstractMassField",
"UserMassField",
"ConstantMassField",
]

from abc import abstractmethod
from dataclasses import KW_ONLY
from typing import Any, Protocol, TypeAlias, TypedDict, runtime_checkable

import diffrax as dfx
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, PyTree

from galax.dynamics._src.fields import AbstractField

Time: TypeAlias = Any
ClusterMass: TypeAlias = Any


class Args(TypedDict, total=False):
units: Any
# Add other optional keys here if needed


@runtime_checkable
class MassVectorField(Protocol):
"""Protocol for mass vector field.
This is a function that returns the derivative of the mass vector with
respect to time.
Examples
--------
>>> from galax.dynamics.cluster import MassVectorField
>>> def mass_deriv(t, Mc, args, **kwargs): pass
>>> isinstance(mass_deriv, MassVectorField)
True
"""

def __call__(
self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any
) -> Array: ...


class AbstractMassField(AbstractField):
"""ABC for mass fields.
Methods
-------
__call__ : `galax.dynamics.cluster.MassVectorField`
Compute the mass field.
terms : the `diffrax.AbstractTerm` `jaxtyping.PyTree` for integration.
"""

@abstractmethod
def __call__(self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any) -> Array: # type: ignore[override]
raise NotImplementedError # pragma: no cover

@AbstractField.terms.dispatch # type: ignore[misc]
def terms(
self: "AbstractMassField", _: dfx.AbstractSolver, /
) -> PyTree[dfx.AbstractTerm]:
"""Return diffeq terms for integration.
Examples
--------
>>> import diffrax as dfx
>>> import galax.dynamics as gd
>>> field = gd.cluster.ConstantMassField()
>>> field.terms(dfx.Dopri8())
ODETerm(
vector_field=_JitWrapper( fn='ConstantMassField.__call__', ... ) )
"""
return dfx.ODETerm(eqx.filter_jit(self.__call__))


#####################################################


class UserMassField(AbstractMassField):
"""User-defined mass field.
This takes a user-defined function of type
`galax.dynamics.cluster.MassVectorField`.
"""

#: User-defined mass derivative function of type
#: `galax.dynamics.cluster.MassVectorField`
mass_deriv: MassVectorField

_: KW_ONLY

def __call__(self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any) -> Array: # type: ignore[override]
return self.mass_deriv(t, Mc, args, **kwargs)

Check warning on line 106 in src/galax/dynamics/_src/cluster/fields.py

View check run for this annotation

Codecov / codecov/patch

src/galax/dynamics/_src/cluster/fields.py#L106

Added line #L106 was not covered by tests


#####################################################


class ConstantMassField(AbstractMassField):
"""Constant mass field.
This is a constant mass field.
"""

def __call__(self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any) -> Array: # type: ignore[override] # noqa: ARG002
return jnp.zeros_like(Mc)

Check warning on line 122 in src/galax/dynamics/_src/cluster/fields.py

View check run for this annotation

Codecov / codecov/patch

src/galax/dynamics/_src/cluster/fields.py#L122

Added line #L122 was not covered by tests
127 changes: 127 additions & 0 deletions src/galax/dynamics/_src/cluster/solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""``galax`` dynamics."""

__all__ = ["MassSolver"]

from dataclasses import KW_ONLY
from typing import Any

import diffrax as dfx
import equinox as eqx
import optimistix as optx
from plum import dispatch

import unxt as u
from unxt.quantity import AbstractQuantity

from .events import MassBelowThreshold
from .fields import AbstractMassField
from galax.dynamics._src.diffeq import DiffEqSolver
from galax.dynamics._src.solver import AbstractSolver
from galax.dynamics._src.utils import parse_saveat


class MassSolver(AbstractSolver, strict=True): # type: ignore[call-arg]
"""Solver for mass history."""

diffeqsolver: DiffEqSolver = eqx.field(
default=DiffEqSolver(
solver=dfx.Dopri8(),
stepsize_controller=dfx.PIDController(rtol=1e-8, atol=1e-8),
),
converter=DiffEqSolver.from_,
)
# TODO: should this be incorporated into in `DiffEqSolver`?
event: dfx.Event = eqx.field(
default=dfx.Event(
cond_fn=MassBelowThreshold(u.Quantity(0.0, "Msun")),
root_finder=optx.Newton(1e-5, 1e-5, optx.rms_norm),
)
)

_: KW_ONLY

units: u.AbstractUnitSystem = eqx.field(
default=u.unitsystems.galactic, converter=u.unitsystem, static=True
)

@dispatch.abstract
def init(
self: "MassSolver", terms: Any, t0: Any, t1: Any, y0: Any, args: Any
) -> Any:
# See dispatches below
raise NotImplementedError # pragma: no cover

@dispatch.abstract
def step(
self: "MassSolver",
terms: Any,
t0: Any,
t1: Any,
y0: Any,
args: Any,
**step_kwargs: Any, # e.g. solver_state, made_jump
) -> Any:
"""Step the state."""
# See dispatches below
raise NotImplementedError # pragma: no cover

# TODO: dispatch where the state from `init` is accepted
@dispatch.abstract
def solve(
self: "MassSolver",
field: Any,
state: Any,
t0: Any,
t1: Any,
/,
args: Any = (),
**solver_kw: Any, # TODO: TypedDict
) -> dfx.Solution:
"""Call `diffrax.diffeqsolve`."""
raise NotImplementedError # pragma: no cover


# ===================================================================
# Solve Dispatches


default_saveat = dfx.SaveAt(t1=True)


@MassSolver.solve.dispatch # type: ignore[misc]
@eqx.filter_jit # type: ignore[misc]
def solve(
self: MassSolver,
field: AbstractMassField,
state: Any,
t0: AbstractQuantity,
t1: AbstractQuantity,
/,
saveat: Any = default_saveat,
**solver_kw: Any,
) -> dfx.Solution:
# Units
units = self.units
utime = units["time"]

Check warning on line 105 in src/galax/dynamics/_src/cluster/solver.py

View check run for this annotation

Codecov / codecov/patch

src/galax/dynamics/_src/cluster/solver.py#L104-L105

Added lines #L104 - L105 were not covered by tests

# Initial conditions
y0 = state.ustrip(units["mass"]) # Mc

Check warning on line 108 in src/galax/dynamics/_src/cluster/solver.py

View check run for this annotation

Codecov / codecov/patch

src/galax/dynamics/_src/cluster/solver.py#L108

Added line #L108 was not covered by tests

# Extra info
args = {"units": units}

Check warning on line 111 in src/galax/dynamics/_src/cluster/solver.py

View check run for this annotation

Codecov / codecov/patch

src/galax/dynamics/_src/cluster/solver.py#L111

Added line #L111 was not covered by tests

# Solve the differential equation
solver_kw.setdefault("dt0", None)
saveat = parse_saveat(units, saveat, dense=solver_kw.pop("dense", None))
soln = self.diffeqsolver(

Check warning on line 116 in src/galax/dynamics/_src/cluster/solver.py

View check run for this annotation

Codecov / codecov/patch

src/galax/dynamics/_src/cluster/solver.py#L114-L116

Added lines #L114 - L116 were not covered by tests
field.terms(self.diffeqsolver),
t0=t0.ustrip(utime),
t1=t1.ustrip(utime),
y0=y0,
event=self.event,
args=args,
saveat=saveat,
**solver_kw,
)

return soln # noqa: RET504

Check warning on line 127 in src/galax/dynamics/_src/cluster/solver.py

View check run for this annotation

Codecov / codecov/patch

src/galax/dynamics/_src/cluster/solver.py#L127

Added line #L127 was not covered by tests
5 changes: 0 additions & 5 deletions src/galax/dynamics/_src/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,12 @@
from jaxtyping import PyTree
from plum import dispatch

import unxt as u

from galax.dynamics._src.diffeq import DiffEqSolver


class AbstractField(eqx.Module, strict=True): # type: ignore[misc,call-arg]
"""Abstract base class for fields."""

#: unit system of the field.
units: eqx.AbstractVar[u.AbstractUnitSystem]

@abstractmethod
def __call__(self, t: Any, *args: Any, **kwargs: Any) -> Any:
"""Evaluate the field at time `t`."""
Expand Down
Loading

0 comments on commit c1616ac

Please sign in to comment.