Skip to content

Commit fb39a31

Browse files
committed
✨ feat(dynamics): add solve namespace
Also, moving some stuff out of the integrate namespace
1 parent ce4d030 commit fb39a31

File tree

10 files changed

+52
-31
lines changed

10 files changed

+52
-31
lines changed

src/galax/_interop/galax_interop_astropy/dynamics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def evaluate_orbit(
2525
t: APYQuantity,
2626
/,
2727
*,
28-
solver: gd.integrate.DynamicsSolver | None = None,
28+
solver: gd.DynamicsSolver | None = None,
2929
dense: Literal[True, False] = False,
3030
) -> gd.Orbit:
3131
"""Compute an orbit in a potential.

src/galax/dynamics/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
__all__ = [
44
# Modules
55
"fields",
6-
"integrate",
6+
"solve",
7+
"integrate", # TODO: deprecate
78
"mockstream",
89
"plot",
910
"cluster",
10-
# integrate
11+
# solve
1112
"evaluate_orbit",
1213
"Orbit",
14+
"AbstractSolver",
15+
"DynamicsSolver",
1316
# mockstream
1417
"MockStreamArm",
1518
"MockStream",
@@ -35,7 +38,6 @@
3538
from ._src.api import omega, specific_angular_momentum
3639
from ._src.cluster.funcs import lagrange_points, tidal_radius
3740
from ._src.orbit import Orbit
38-
from .integrate import evaluate_orbit
3941
from .mockstream import (
4042
AbstractStreamDF,
4143
ChenStreamDF,
@@ -44,6 +46,7 @@
4446
MockStreamArm,
4547
MockStreamGenerator,
4648
)
49+
from .solve import AbstractSolver, DynamicsSolver, evaluate_orbit
4750

4851
#
4952
# isort: split

src/galax/dynamics/_src/dynamics/field_hamiltonian.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class HamiltonianField(AbstractDynamicsField, strict=True): # type: ignore[call
7777
Just to continue the example, we can use this field to integrate the
7878
equations of motion:
7979
80-
>>> solver = gd.integrate.DynamicsSolver() # defaults to Dopri8
80+
>>> solver = gd.DynamicsSolver() # defaults to Dopri8
8181
>>> w0 = gc.PhaseSpacePosition(
8282
... q=u.Quantity([[8, 0, 9], [9, 0, 3]], "kpc"),
8383
... p=u.Quantity([0, 220, 0], "km/s"),
@@ -268,7 +268,7 @@ def terms(
268268
269269
For completeness we'll integrate the EoM.
270270
271-
>>> dynamics_solver = gd.integrate.DynamicsSolver(solver)
271+
>>> dynamics_solver = gd.DynamicsSolver(solver)
272272
>>> w0 = gc.PhaseSpacePosition(
273273
... q=u.Quantity([8., 0, 0], "kpc"),
274274
... p=u.Quantity([0, 220, 0], "km/s"),

src/galax/dynamics/_src/dynamics/field_nbody.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class NBodyField(AbstractDynamicsField, strict=True): # type: ignore[call-arg]
4040
>>> q = u.Quantity([[-1, 0, 0], [1, 0, 0]], "AU") / 2
4141
>>> p = u.Quantity([[0, -1, 0], [0, 1, 0]], "km/s") * 25
4242
43-
>>> solver = gd.integrate.DynamicsSolver()
43+
>>> solver = gd.DynamicsSolver()
4444
4545
>>> field = gd.fields.NBodyField(
4646
... masses=u.Quantity([1, 1], "Msun"),

src/galax/dynamics/_src/dynamics/register_gc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def from_(
3838
>>> pot = gp.HernquistPotential(m_tot=u.Quantity(1e12, "Msun"),
3939
... r_s=u.Quantity(5, "kpc"), units="galactic")
4040
>>> field = gd.fields.HamiltonianField(pot)
41-
>>> solver = gd.integrate.DynamicsSolver() # defaults to Dopri8
41+
>>> solver = gd.DynamicsSolver() # defaults to Dopri8
4242
>>> w0 = gc.PhaseSpacePosition(
4343
... q=u.Quantity([[8, 0, 9], [9, 0, 3]], "kpc"),
4444
... p=u.Quantity([0, 220, 0], "km/s"),

src/galax/dynamics/_src/dynamics/solver.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class DynamicsSolver(AbstractSolver, strict=True): # type: ignore[call-arg]
5050
>>> import galax.potential as gp
5151
>>> import galax.dynamics as gd
5252
53-
>>> solver = gd.integrate.DynamicsSolver() # defaults to Dopri8
53+
>>> solver = gd.DynamicsSolver() # defaults to Dopri8
5454
5555
Define the vector field. In this example it's to solve Hamilton's EoM in a
5656
gravitational potential.
@@ -92,9 +92,9 @@ class DynamicsSolver(AbstractSolver, strict=True): # type: ignore[call-arg]
9292
setting the `diffrax.AbstractSolver`,
9393
`diffrax.AbstractStepSizeController`, etc.
9494
95-
>>> diffeqsolver = gd.integrate.DiffEqSolver(dfx.Dopri8(),
95+
>>> diffeqsolver = gd.solve.DiffEqSolver(dfx.Dopri8(),
9696
... stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))
97-
>>> solver = gd.integrate.DynamicsSolver(diffeqsolver)
97+
>>> solver = gd.DynamicsSolver(diffeqsolver)
9898
>>> solver
9999
DynamicsSolver(
100100
diffeqsolver=DiffEqSolver(
@@ -107,7 +107,7 @@ class DynamicsSolver(AbstractSolver, strict=True): # type: ignore[call-arg]
107107
2. A `dict` of keyword arguments that are passed to
108108
`galax.dynamics.integrate.DiffEqSolver`.
109109
110-
>>> solver = gd.integrate.DynamicsSolver({
110+
>>> solver = gd.DynamicsSolver({
111111
... "solver": dfx.Dopri8(), "stepsize_controller": dfx.ConstantStepSize()})
112112
>>> solver
113113
DynamicsSolver(
@@ -187,7 +187,7 @@ def solve(
187187
>>> import galax.potential as gp
188188
>>> import galax.dynamics as gd
189189
190-
>>> solver = gd.integrate.DynamicsSolver()
190+
>>> solver = gd.DynamicsSolver()
191191
192192
Specify the vector field.
193193
@@ -658,7 +658,7 @@ def terms(
658658
>>> import galax.potential as gp
659659
>>> import galax.dynamics as gd
660660
661-
>>> solver = gd.integrate.DynamicsSolver(dfx.Dopri8())
661+
>>> solver = gd.DynamicsSolver(dfx.Dopri8())
662662
663663
>>> pot = gp.KeplerPotential(m_tot=u.Quantity(1e12, "Msun"), units="galactic")
664664
>>> field = gd.fields.HamiltonianField(pot)

src/galax/dynamics/_src/fields.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def terms(
4343
>>> import galax.potential as gp
4444
>>> import galax.dynamics as gd
4545
46-
>>> solver = gd.integrate.DiffEqSolver(dfx.Dopri8())
46+
>>> solver = gd.solve.DiffEqSolver(dfx.Dopri8())
4747
4848
>>> pot = gp.KeplerPotential(m_tot=u.Quantity(1e12, "Msun"), units="galactic")
4949
>>> field = gd.fields.HamiltonianField(pot)

src/galax/dynamics/_src/solver.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ def from_(cls: type[AbstractSolver], solver: AbstractSolver) -> AbstractSolver:
6262
6363
Examples
6464
--------
65-
>>> from galax.dynamics.integrate import DynamicsSolver
66-
>>> solver = DynamicsSolver()
67-
>>> new_solver = DynamicsSolver.from_(solver)
65+
>>> import galax.dynamics as gd
66+
>>> solver = gd.DynamicsSolver()
67+
>>> new_solver = gd.DynamicsSolver.from_(solver)
6868
>>> new_solver is solver
6969
True
7070
@@ -91,7 +91,7 @@ def from_(cls: type[AbstractSolver], obj: Any) -> AbstractSolver:
9191
Examples
9292
--------
9393
>>> import diffrax as dfx
94-
>>> from galax.dynamics.integrate import DynamicsSolver, DiffEqSolver
94+
>>> from galax.dynamics.solve import DynamicsSolver, DiffEqSolver
9595
9696
>>> DynamicsSolver.from_( DiffEqSolver(dfx.Dopri5()))
9797
DynamicsSolver(
@@ -118,16 +118,16 @@ def from_(cls: type[AbstractSolver], obj: Mapping[str, Any]) -> AbstractSolver:
118118
Examples
119119
--------
120120
>>> import diffrax as dfx
121-
>>> from galax.dynamics.integrate import DynamicsSolver
121+
>>> import galax.dynamics as gd
122122
123-
>>> DynamicsSolver.from_({})
123+
>>> gd.DynamicsSolver.from_({})
124124
DynamicsSolver(
125125
diffeqsolver=DiffEqSolver(
126126
solver=Dopri8(scan_kind=None),
127127
stepsize_controller=PIDController( ... ),
128128
adjoint=RecursiveCheckpointAdjoint(checkpoints=None) ) )
129129
130-
>>> DynamicsSolver.from_({"diffeqsolver": dfx.Dopri5()})
130+
>>> gd.DynamicsSolver.from_({"diffeqsolver": dfx.Dopri5()})
131131
DynamicsSolver(
132132
diffeqsolver=DiffEqSolver(
133133
solver=Dopri5(scan_kind=None),

src/galax/dynamics/integrate.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,21 @@
55
"Integrator",
66
"Interpolant",
77
"parse_time_specification",
8-
"AbstractSolver",
9-
"DynamicsSolver",
108
"InterpolatedPhaseSpacePosition",
11-
# Diffraxtra external library
12-
"DiffEqSolver",
13-
"VectorizedDenseInterpolation",
149
]
1510

1611
from jaxtyping import install_import_hook
1712

1813
from galax.setup_package import RUNTIME_TYPECHECKER
1914

2015
with install_import_hook("galax.dynamics.integrate", RUNTIME_TYPECHECKER):
21-
from diffraxtra import DiffEqSolver, VectorizedDenseInterpolation
22-
23-
from ._src.dynamics import DynamicsSolver, parse_time_specification
16+
from ._src.dynamics import parse_time_specification
2417
from ._src.integrate import (
2518
Integrator,
2619
Interpolant,
2720
InterpolatedPhaseSpacePosition,
2821
evaluate_orbit,
2922
)
30-
from ._src.solver import AbstractSolver
3123

3224
# Cleanup
3325
del install_import_hook, RUNTIME_TYPECHECKER

src/galax/dynamics/solve.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
""":mod:`galax.dynamics.solve`."""
2+
3+
__all__ = [
4+
"evaluate_orbit",
5+
"parse_time_specification",
6+
"AbstractSolver",
7+
"DynamicsSolver",
8+
"DiffEqSolver",
9+
"VectorizedDenseInterpolation",
10+
"Orbit",
11+
]
12+
13+
from jaxtyping import install_import_hook
14+
15+
from galax.setup_package import RUNTIME_TYPECHECKER
16+
17+
with install_import_hook("galax.dynamics.integrate", RUNTIME_TYPECHECKER):
18+
from diffraxtra import DiffEqSolver, VectorizedDenseInterpolation
19+
20+
from ._src.dynamics import DynamicsSolver, parse_time_specification
21+
from ._src.integrate import evaluate_orbit
22+
from ._src.orbit import Orbit
23+
from ._src.solver import AbstractSolver
24+
25+
# Cleanup
26+
del install_import_hook, RUNTIME_TYPECHECKER

0 commit comments

Comments
 (0)