Skip to content

Commit e070242

Browse files
committed
♻️ refactor(dynamics): evaluate_orbit uses DynamicsSolver
1 parent ce4d030 commit e070242

File tree

10 files changed

+199
-144
lines changed

10 files changed

+199
-144
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"coordinax>=0.19.0",
3333
"dataclassish>=0.7.1",
3434
"diffrax>=0.6",
35-
"diffraxtra>=1.0.2",
35+
"diffraxtra>=1.1.0",
3636
"equinox>=0.11.8",
3737
"is-annotated>=1.0",
3838
"jax>=0.4.35",

src/galax/dynamics/_src/integrate/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
"Interpolant",
1313
]
1414

15-
from .funcs import evaluate_orbit
1615
from .integrator import Integrator
1716
from .interp import Interpolant
1817
from .interp_psp import InterpolatedPhaseSpacePosition

src/galax/dynamics/_src/mockstream/mockstream_generator.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@
2020
from .core import MockStream, MockStreamArm
2121
from .df import AbstractStreamDF, ProgenitorMassCallable
2222
from .utils import cond_reverse
23-
from galax.dynamics._src.integrate.funcs import _default_integrator, evaluate_orbit
24-
from galax.dynamics._src.integrate.integrator import Integrator
25-
from galax.dynamics._src.orbit import Orbit
23+
from galax.dynamics._src.dynamics import DynamicsSolver
24+
from galax.dynamics._src.orbit import Orbit, evaluate_orbit
2625
from galax.potential import AbstractPotential
2726

2827
Carry: TypeAlias = tuple[gt.IntSz0, gt.SzN, gt.SzN]
2928

29+
_default_solver: DynamicsSolver = DynamicsSolver()
30+
3031

3132
@final
3233
class MockStreamGenerator(eqx.Module): # type: ignore[misc]
@@ -42,13 +43,11 @@ class MockStreamGenerator(eqx.Module): # type: ignore[misc]
4243
"""Potential in which the progenitor orbits and creates a stream."""
4344

4445
_: KW_ONLY
45-
progenitor_integrator: Integrator = eqx.field(
46-
default=_default_integrator, static=True
47-
)
48-
"""Integrator for the progenitor orbit."""
46+
progenitor_solver: DynamicsSolver = eqx.field(default=_default_solver, static=True)
47+
"""Solver for the progenitor orbit."""
4948

50-
stream_integrator: Integrator = eqx.field(default=_default_integrator, static=True)
51-
"""Integrator for the stream."""
49+
stream_solver: DynamicsSolver = eqx.field(default=_default_solver, static=True)
50+
"""Solver for the stream."""
5251

5352
@property
5453
def units(self) -> u.AbstractUnitSystem:
@@ -63,9 +62,7 @@ def _progenitor_trajectory(
6362
"""Integrate the progenitor orbit."""
6463
return cast(
6564
Orbit,
66-
evaluate_orbit(
67-
self.potential, w0, ts, integrator=self.progenitor_integrator
68-
),
65+
evaluate_orbit(self.potential, w0, ts, integrator=self.progenitor_solver),
6966
)
7067

7168
# ==========================================================================
@@ -103,7 +100,7 @@ def one_pt_intg(
103100
def integ_ics(ics: gt.Sz6) -> gt.SzN:
104101
# TODO: only return the final state
105102
return evaluate_orbit(
106-
self.potential, ics, tstep, integrator=self.stream_integrator
103+
self.potential, ics, tstep, integrator=self.stream_solver
107104
).w(units=self.units)[-1]
108105

109106
# vmap integration over leading and trailing arm
@@ -138,10 +135,10 @@ def one_pt_intg(
138135
) -> tuple[gt.Sz6, gt.Sz6]:
139136
tstep = jnp.asarray([ts[i], t_f])
140137
w_lead = evaluate_orbit(
141-
self.potential, w0_l_i, tstep, integrator=self.stream_integrator
138+
self.potential, w0_l_i, tstep, integrator=self.stream_solver
142139
).w(units=self.potential.units)[-1]
143140
w_trail = evaluate_orbit(
144-
self.potential, w0_t_i, tstep, integrator=self.stream_integrator
141+
self.potential, w0_t_i, tstep, integrator=self.stream_solver
145142
).w(units=self.potential.units)[-1]
146143
return w_lead, w_trail
147144

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Orbits. Private module."""
22

3-
__all__ = ["Orbit", "plot_components"]
3+
__all__ = ["Orbit", "evaluate_orbit", "plot_components"]
44

5+
from .funcs import evaluate_orbit
56
from .orbit import Orbit
67
from .plot_helper import plot_components

src/galax/dynamics/_src/integrate/funcs.py renamed to src/galax/dynamics/_src/orbit/funcs.py

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,8 @@
22

33
__all__ = ["evaluate_orbit"]
44

5-
from collections.abc import Callable
65
from typing import Any, Literal
76

8-
import jax
9-
from jaxtyping import Array
107
from plum import dispatch
118

129
import quaxed.numpy as jnp
@@ -16,17 +13,12 @@
1613
import galax.dynamics._src.custom_types as gdt
1714
import galax.potential as gp
1815
import galax.typing as gt
19-
from .integrator import Integrator
20-
from galax.dynamics._src.dynamics import HamiltonianField
21-
from galax.dynamics._src.orbit import Orbit
16+
from .interp import PhaseSpaceInterpolation
17+
from .orbit import Orbit
18+
from galax.dynamics._src.dynamics import DynamicsSolver, HamiltonianField
2219

2320
# TODO: enable setting the default integrator
24-
_default_integrator: Integrator = Integrator()
25-
26-
27-
_select_w0: Callable[[Array, Array, Array], Array] = jax.numpy.vectorize(
28-
jax.lax.select, signature="(),(6),(6)->(6)"
29-
)
21+
_default_solver: DynamicsSolver = DynamicsSolver()
3022

3123

3224
@dispatch
@@ -36,7 +28,7 @@ def evaluate_orbit(
3628
t: Any,
3729
/,
3830
*,
39-
integrator: Integrator | None = None,
31+
solver: DynamicsSolver | None = None,
4032
dense: Literal[True, False] = False,
4133
) -> Orbit:
4234
"""Compute an orbit in a potential.
@@ -119,14 +111,22 @@ def evaluate_orbit(
119111
... t=u.Quantity(-100, "Myr"))
120112
>>> ts = u.Quantity(jnp.linspace(0, 1, 4), "Gyr")
121113
>>> orbit = gd.evaluate_orbit(potential, w0, ts)
122-
>>> orbit
114+
>>> print(orbit)
123115
Orbit(
124-
q=CartesianPos3D(...), p=CartesianVel3D(...),
125-
t=Quantity['time'](Array(..., dtype=float64), unit='Myr'),
126-
frame=SimulationFrame(),
127-
potential=KeplerPotential(...),
128-
interpolant=None
129-
)
116+
q=<CartesianPos3D (x[kpc], y[kpc], z[kpc])
117+
[[ 8.953 -1.324 0. ]
118+
[ 9.035 1.276 0. ]
119+
[ 2.644 -2.021 0. ]
120+
[ 9.827 -0.562 0. ]]>,
121+
p=<CartesianVel3D (x[kpc / Myr], y[kpc / Myr], z[kpc / Myr])
122+
[[ 0.322 0.181 0. ]
123+
[-0.308 0.183 0. ]
124+
[ 1.336 -0.247 0. ]
125+
[ 0.126 0.201 0. ]]>,
126+
t=Quantity['time'](Array([...], dtype=float64), unit='Myr'),
127+
frame=SimulationFrame(),
128+
potential=KeplerPotential( ... ),
129+
interpolant=None)
130130
131131
Note how there are 4 points in the orbit, corresponding to the 4 requested
132132
return times. These are the times at which the orbit is evaluated, not the
@@ -166,7 +166,7 @@ def evaluate_orbit(
166166
>>> orbit
167167
Orbit(
168168
q=CartesianPos3D(
169-
x=Quantity[PhysicalType('length')](value=f64[2,10], unit=Unit("kpc")),
169+
x=Quantity[PhysicalType('length')](value=f64[10,2], unit=Unit("kpc")),
170170
...
171171
),
172172
p=CartesianVel3D(...),
@@ -199,7 +199,7 @@ def evaluate_orbit(
199199
"""
200200
# Setup
201201
units = pot.units
202-
integrator = _default_integrator if integrator is None else integrator
202+
solver = _default_solver if solver is None else solver
203203
t = jnp.atleast_1d(FastQ.from_(t, units["time"])) # ensure t units
204204

205205
field = HamiltonianField(pot)
@@ -208,23 +208,29 @@ def evaluate_orbit(
208208
tw0 = w0.t if (isinstance(w0, gc.PhaseSpacePosition) and w0.t is not None) else t[0]
209209

210210
# Initial integration `w0.t` to `t[0]`.
211-
# TODO: get diffrax's `solver_state` to speed the second integration.
212-
# TODO: get diffrax's `controller_state` to speed the second integration.
211+
# TODO: diffrax's `solver_state`, `controller_state`
213212
# TODO: `max_steps` as kwarg.
214-
qp0 = integrator(
215-
field,
216-
w0,
217-
tw0,
218-
jnp.full_like(tw0, fill_value=t[0]),
219-
dense=False,
220-
)
213+
fullt0 = jnp.full_like(tw0, fill_value=t[0])
214+
soln0 = solver.solve(field, w0, tw0, fullt0, dense=False, unbatch_time=True)
221215

222216
# Orbit integration `t[0]` to `t[-1]`
223217
# TODO: `max_steps` as kwarg.
224-
ws = integrator(field, qp0, t[0], t[-1], saveat=t, dense=dense)
218+
ys = (FastQ(soln0.ys[0], units["length"]), FastQ(soln0.ys[1], units["speed"]))
219+
soln = solver.solve(
220+
field, ys, t[0], t[-1], saveat=t, dense=dense, vectorize_interpolation=True
221+
)
225222

226223
# Return the orbit object
227-
return Orbit._from_psp(ws, t, pot) # noqa: SLF001
224+
return Orbit(
225+
q=FastQ(soln.ys[0], units["length"]),
226+
p=FastQ(soln.ys[1], units["speed"]),
227+
t=t,
228+
frame=getattr(w0, "frame", gc.frames.SimulationFrame()),
229+
potential=pot,
230+
interpolant=(
231+
PhaseSpaceInterpolation(soln.interpolation, units=units) if dense else None
232+
),
233+
)
228234

229235

230236
@dispatch
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""galax: Galactic Dynamix in Jax."""
2+
3+
__all__ = ["PhaseSpaceInterpolation"]
4+
5+
from typing import Any, cast
6+
from typing_extensions import override
7+
8+
import diffrax as dfx
9+
import equinox as eqx
10+
11+
import diffraxtra as dfxtra
12+
import unxt as u
13+
14+
import galax.coordinates as gc
15+
import galax.typing as gt
16+
17+
18+
class PhaseSpaceInterpolation(dfxtra.AbstractVectorizedDenseInterpolation): # type: ignore[misc]
19+
"""Phase-space interpolation for orbit evaluation."""
20+
21+
#: The vectorized interpolation object.
22+
interp: dfxtra.VectorizedDenseInterpolation = eqx.field(
23+
converter=dfxtra.VectorizedDenseInterpolation.from_
24+
)
25+
26+
#: The unit system for the interpolation.
27+
units: u.AbstractUnitSystem = eqx.field(static=True, converter=u.unitsystem)
28+
29+
@override
30+
@eqx.filter_jit # type: ignore[misc]
31+
def evaluate(self, ts: u.Quantity) -> gc.PhaseSpacePosition:
32+
# Parse the time
33+
t = u.Quantity.from_(ts, self.units["time"])
34+
35+
# Evaluate the interpolation
36+
ys = self.interp.evaluate(t.ustrip(self.units["time"]))
37+
38+
# Return as a phase-space position
39+
return gc.PhaseSpacePosition(
40+
q=u.Quantity(ys[0], self.units["length"]),
41+
p=u.Quantity(ys[1], self.units["speed"]),
42+
t=t,
43+
)
44+
45+
def __call__(self, *args: Any, **kwds: Any) -> gc.PhaseSpacePosition:
46+
return cast(gc.PhaseSpacePosition, self.evaluate(*args, **kwds))
47+
48+
@property
49+
def scalar_interpolation(self) -> dfx.DenseInterpolation:
50+
"""Return the scalar interpolation for the phase-space position."""
51+
return cast(dfx.DenseInterpolation, self.interp.scalar_interpolation)
52+
53+
@property
54+
def batch_shape(self) -> gt.Shape:
55+
"""Return the batch shape of the interpolation."""
56+
return cast(gt.Shape, self.interp.batch_shape)
57+
58+
@property
59+
def y0_shape(self) -> gt.Shape:
60+
"""Return the shape of the initial value."""
61+
return cast(gt.Shape, self.interp.y0_shape)

0 commit comments

Comments
 (0)