Skip to content

Commit 24bacb1

Browse files
committed
✨ feat(dynamics): cluster mass solver
Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
1 parent d805ac8 commit 24bacb1

File tree

9 files changed

+372
-9
lines changed

9 files changed

+372
-9
lines changed

src/galax/dynamics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"integrate",
77
"mockstream",
88
"plot",
9+
"cluster",
910
# integrate
1011
"evaluate_orbit",
1112
"Orbit",
@@ -29,7 +30,7 @@
2930
from galax.setup_package import RUNTIME_TYPECHECKER
3031

3132
with install_import_hook("galax.dynamics", RUNTIME_TYPECHECKER):
32-
from . import fields, integrate, mockstream, plot
33+
from . import cluster, fields, integrate, mockstream, plot
3334
from ._src.cluster.funcs import lagrange_points, tidal_radius
3435
from ._src.funcs import specific_angular_momentum
3536
from ._src.orbit import Orbit
Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,25 @@
1-
"""Cluster mass evolution."""
1+
"""Cluster evolution."""
22

33
__all__ = [
4+
"MassSolver",
5+
# Fields
6+
"MassVectorField",
7+
"AbstractMassField",
8+
"UserMassField",
9+
"ConstantMassField",
10+
# Events
11+
"MassBelowThreshold",
12+
# Functions
413
"lagrange_points",
514
"tidal_radius",
615
]
716

8-
17+
from .events import MassBelowThreshold
18+
from .fields import (
19+
AbstractMassField,
20+
ConstantMassField,
21+
MassVectorField,
22+
UserMassField,
23+
)
924
from .funcs import lagrange_points, tidal_radius
25+
from .solver import MassSolver
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""``galax`` dynamics."""
2+
3+
__all__ = ["MassBelowThreshold"]
4+
5+
from typing import Any
6+
7+
import equinox as eqx
8+
from jaxtyping import Array
9+
from plum import dispatch
10+
11+
import unxt as u
12+
from unxt.quantity import AbstractQuantity
13+
14+
15+
class MassBelowThreshold(eqx.Module): # type: ignore[misc]
16+
"""Event to stop integration when the mass falls below a threshold.
17+
18+
Instances can be used as the ``cond_fn`` argument of `diffrax.Event`. Since
19+
this returns a scalar (not a `bool`) the solve the solve will terminate on
20+
the step when the mass is below the threshold.
21+
22+
With `diffrax.Event` this can be combined with a root-finder to find the
23+
exact time when the mass is below the threshold, rather than the step
24+
after.
25+
26+
Example
27+
-------
28+
>>> import unxt as u
29+
>>> from galax.dynamics.cluster import MassBelowThreshold
30+
31+
>>> cond_fn = MassBelowThreshold(u.Quantity(0.0, "Msun"))
32+
>>> args = {"units": u.unitsystems.galactic}
33+
34+
>>> cond_fn(0.0, u.Quantity(1.0, "Msun"), args)
35+
Array(1., dtype=float64, weak_type=True)
36+
37+
>>> cond_fn(0.0, u.Quantity(0.0, "Msun"), args)
38+
Array(0., dtype=float64, weak_type=True)
39+
40+
TODO: example using it as a with `diffrax.Event`.
41+
42+
"""
43+
44+
#: Threshold mass at which to stop integration.
45+
threshold: AbstractQuantity
46+
47+
@dispatch
48+
def __call__(
49+
self: "MassBelowThreshold",
50+
t: Any, # noqa: ARG002
51+
y: Array,
52+
args: dict[str, Any],
53+
/,
54+
**kwargs: Any, # noqa: ARG002
55+
) -> Array:
56+
return y - self.threshold.ustrip(args["units"])
57+
58+
@dispatch
59+
def __call__(
60+
self: "MassBelowThreshold",
61+
t: Any, # noqa: ARG002
62+
y: AbstractQuantity,
63+
args: dict[str, Any],
64+
/,
65+
**kwargs: Any, # noqa: ARG002
66+
) -> Array:
67+
return u.ustrip(args["units"], y - self.threshold)
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""Fields for mass evolution."""
2+
3+
__all__ = [
4+
"MassVectorField",
5+
"AbstractMassField",
6+
"UserMassField",
7+
"ConstantMassField",
8+
]
9+
10+
from abc import abstractmethod
11+
from dataclasses import KW_ONLY
12+
from typing import Any, Protocol, TypeAlias, TypedDict, runtime_checkable
13+
14+
import diffrax as dfx
15+
import equinox as eqx
16+
import jax.numpy as jnp
17+
from jaxtyping import Array, PyTree
18+
19+
from galax.dynamics._src.fields import AbstractField
20+
21+
Time: TypeAlias = Any
22+
ClusterMass: TypeAlias = Any
23+
24+
25+
class Args(TypedDict, total=False):
26+
units: Any
27+
# Add other optional keys here if needed
28+
29+
30+
@runtime_checkable
31+
class MassVectorField(Protocol):
32+
"""Protocol for mass vector field.
33+
34+
This is a function that returns the derivative of the mass vector with
35+
respect to time.
36+
37+
Examples
38+
--------
39+
>>> from galax.dynamics.cluster import MassVectorField
40+
41+
>>> def mass_deriv(t, Mc, args, **kwargs): pass
42+
43+
>>> isinstance(mass_deriv, MassVectorField)
44+
True
45+
46+
"""
47+
48+
def __call__(
49+
self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any
50+
) -> Array: ...
51+
52+
53+
class AbstractMassField(AbstractField):
54+
"""ABC for mass fields.
55+
56+
Methods
57+
-------
58+
__call__ : `galax.dynamics.cluster.MassVectorField`
59+
Compute the mass field.
60+
terms : the `diffrax.AbstractTerm` `jaxtyping.PyTree` for integration.
61+
62+
"""
63+
64+
@abstractmethod
65+
def __call__(self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any) -> Array: # type: ignore[override]
66+
raise NotImplementedError # pragma: no cover
67+
68+
@AbstractField.terms.dispatch # type: ignore[misc]
69+
def terms(
70+
self: "AbstractMassField", _: dfx.AbstractSolver, /
71+
) -> PyTree[dfx.AbstractTerm]:
72+
"""Return diffeq terms for integration.
73+
74+
Examples
75+
--------
76+
>>> import diffrax as dfx
77+
>>> import galax.dynamics as gd
78+
79+
>>> field = gd.cluster.ConstantMassField()
80+
>>> field.terms(dfx.Dopri8())
81+
ODETerm(
82+
vector_field=_JitWrapper( fn='ConstantMassField.__call__', ... ) )
83+
84+
"""
85+
return dfx.ODETerm(eqx.filter_jit(self.__call__))
86+
87+
88+
#####################################################
89+
90+
91+
class UserMassField(AbstractMassField):
92+
"""User-defined mass field.
93+
94+
This takes a user-defined function of type
95+
`galax.dynamics.cluster.MassVectorField`.
96+
97+
"""
98+
99+
#: User-defined mass derivative function of type
100+
#: `galax.dynamics.cluster.MassVectorField`
101+
mass_deriv: MassVectorField
102+
103+
_: KW_ONLY
104+
105+
def __call__(self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any) -> Array: # type: ignore[override]
106+
return self.mass_deriv(t, Mc, args, **kwargs)
107+
108+
109+
#####################################################
110+
111+
112+
class ConstantMassField(AbstractMassField):
113+
"""Constant mass field.
114+
115+
This is a constant mass field.
116+
117+
118+
119+
"""
120+
121+
def __call__(self, t: Time, Mc: ClusterMass, args: Args, /, **kwargs: Any) -> Array: # type: ignore[override] # noqa: ARG002
122+
return jnp.zeros_like(Mc)
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""``galax`` dynamics."""
2+
3+
__all__ = ["MassSolver"]
4+
5+
from dataclasses import KW_ONLY
6+
from typing import Any
7+
8+
import diffrax as dfx
9+
import equinox as eqx
10+
import optimistix as optx
11+
from plum import dispatch
12+
13+
import unxt as u
14+
from unxt.quantity import AbstractQuantity
15+
16+
from .events import MassBelowThreshold
17+
from .fields import AbstractMassField
18+
from galax.dynamics._src.diffeq import DiffEqSolver
19+
from galax.dynamics._src.solver import AbstractSolver
20+
from galax.dynamics._src.utils import parse_saveat
21+
22+
23+
class MassSolver(AbstractSolver, strict=True): # type: ignore[call-arg]
24+
"""Solver for mass history."""
25+
26+
diffeqsolver: DiffEqSolver = eqx.field(
27+
default=DiffEqSolver(
28+
solver=dfx.Dopri8(),
29+
stepsize_controller=dfx.PIDController(rtol=1e-8, atol=1e-8),
30+
),
31+
converter=DiffEqSolver.from_,
32+
)
33+
# TODO: should this be incorporated into in `DiffEqSolver`?
34+
event: dfx.Event = eqx.field(
35+
default=dfx.Event(
36+
cond_fn=MassBelowThreshold(u.Quantity(0.0, "Msun")),
37+
root_finder=optx.Newton(1e-5, 1e-5, optx.rms_norm),
38+
)
39+
)
40+
41+
_: KW_ONLY
42+
43+
units: u.AbstractUnitSystem = eqx.field(
44+
default=u.unitsystems.galactic, converter=u.unitsystem, static=True
45+
)
46+
47+
@dispatch.abstract
48+
def init(
49+
self: "MassSolver", terms: Any, t0: Any, t1: Any, y0: Any, args: Any
50+
) -> Any:
51+
# See dispatches below
52+
raise NotImplementedError # pragma: no cover
53+
54+
@dispatch.abstract
55+
def step(
56+
self: "MassSolver",
57+
terms: Any,
58+
t0: Any,
59+
t1: Any,
60+
y0: Any,
61+
args: Any,
62+
**step_kwargs: Any, # e.g. solver_state, made_jump
63+
) -> Any:
64+
"""Step the state."""
65+
# See dispatches below
66+
raise NotImplementedError # pragma: no cover
67+
68+
# TODO: dispatch where the state from `init` is accepted
69+
@dispatch.abstract
70+
def solve(
71+
self: "MassSolver",
72+
field: Any,
73+
state: Any,
74+
t0: Any,
75+
t1: Any,
76+
/,
77+
args: Any = (),
78+
**solver_kw: Any, # TODO: TypedDict
79+
) -> dfx.Solution:
80+
"""Call `diffrax.diffeqsolve`."""
81+
raise NotImplementedError # pragma: no cover
82+
83+
84+
# ===================================================================
85+
# Solve Dispatches
86+
87+
88+
default_saveat = dfx.SaveAt(t1=True)
89+
90+
91+
@MassSolver.solve.dispatch # type: ignore[misc]
92+
@eqx.filter_jit # type: ignore[misc]
93+
def solve(
94+
self: MassSolver,
95+
field: AbstractMassField,
96+
state: Any,
97+
t0: AbstractQuantity,
98+
t1: AbstractQuantity,
99+
/,
100+
saveat: Any = default_saveat,
101+
**solver_kw: Any,
102+
) -> dfx.Solution:
103+
# Units
104+
units = self.units
105+
utime = units["time"]
106+
107+
# Initial conditions
108+
y0 = state.ustrip(units["mass"]) # Mc
109+
110+
# Extra info
111+
args = {"units": units}
112+
113+
# Solve the differential equation
114+
solver_kw.setdefault("dt0", None)
115+
saveat = parse_saveat(units, saveat, dense=solver_kw.pop("dense", None))
116+
soln = self.diffeqsolver(
117+
field.terms(self.diffeqsolver),
118+
t0=t0.ustrip(utime),
119+
t1=t1.ustrip(utime),
120+
y0=y0,
121+
event=self.event,
122+
args=args,
123+
saveat=saveat,
124+
**solver_kw,
125+
)
126+
127+
return soln # noqa: RET504

src/galax/dynamics/_src/fields.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,12 @@
1010
from jaxtyping import PyTree
1111
from plum import dispatch
1212

13-
import unxt as u
14-
1513
from galax.dynamics._src.diffeq import DiffEqSolver
1614

1715

1816
class AbstractField(eqx.Module, strict=True): # type: ignore[misc,call-arg]
1917
"""Abstract base class for fields."""
2018

21-
#: unit system of the field.
22-
units: eqx.AbstractVar[u.AbstractUnitSystem]
23-
2419
@abstractmethod
2520
def __call__(self, t: Any, *args: Any, **kwargs: Any) -> Any:
2621
"""Evaluate the field at time `t`."""

0 commit comments

Comments
 (0)