Skip to content

Commit

Permalink
clean up steps
Browse files Browse the repository at this point in the history
  • Loading branch information
rhayes777 committed Nov 29, 2024
1 parent 4e9912c commit e8db9e2
Showing 1 changed file with 15 additions and 83 deletions.
98 changes: 15 additions & 83 deletions test_autolens/point/triangles/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,12 @@

import autolens as al
import autogalaxy as ag
from autoarray.structures.triangles.abstract import HEIGHT_FACTOR
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
from autoarray.structures.triangles.jax_coordinate_array import (
CoordinateArrayTriangles as JAXTriangles,
)
from autoarray.structures.triangles.shape import Point
from autolens.mock import NullTracer
from autolens.point.solver import PointSolver
from autolens.point.visualise import visualise, plot_triangles_compare, plot_triangles


@pytest.fixture
Expand Down Expand Up @@ -86,96 +83,31 @@ def triangle_set(triangles):
}


def test_real_example(grid, tracer):
solver = PointSolver.for_grid(
grid=grid,
pixel_scale_precision=0.001,
)
jax_solver = PointSolver.for_grid(
grid=grid,
pixel_scale_precision=0.001,
array_triangles_cls=JAXTriangles,
)

point = Point(0.07, 0.07)

for step, jax_step in zip(
solver.steps(tracer=tracer, shape=point),
jax_solver.steps(tracer=tracer, shape=point),
):
initial_triangles = step.initial_triangles
jax_initial_triangles = jax_step.initial_triangles

initial_triangle_set = triangle_set(initial_triangles)
jax_initial_triangle_set = triangle_set(jax_initial_triangles)

print(
"difference in initial",
initial_triangle_set.difference(jax_initial_triangle_set),
)

print("Difference in vertices")
print(
{
tuple(map(float, np.round(v, 3))) for v in initial_triangles.vertices
}.difference(
{
tuple(map(float, np.round(v, 3)))
for v in jax_initial_triangles.vertices
if not np.isnan(v).any()
}
)
)

source_triangles = triangle_set(step.source_triangles)
jax_source_triangles = triangle_set(jax_step.source_triangles)

print(
"in source but not jax", source_triangles.difference(jax_source_triangles)
)
print(
"in jax but not source", jax_source_triangles.difference(source_triangles)
)

if step.number == 2:
break


def test_real_example_jax_only(grid, tracer):
def test_real_example_jax(grid, tracer):
jax_solver = PointSolver.for_grid(
grid=grid,
pixel_scale_precision=0.001,
array_triangles_cls=JAXTriangles,
)

for step in jax_solver.steps(
result = jax_solver.solve(
tracer=tracer,
shape=Point(
0.07,
0.07,
),
):
triangles = step.initial_triangles
source_plane_coordinate=(0.07, 0.07),
)

print(triangles)
visualise(step)
assert len(result) == 5


def test_broken_step(grid, tracer):
solver = PointSolver(
scale=0.5,
def test_real_example_normal(grid, tracer):
jax_solver = PointSolver.for_grid(
grid=grid,
pixel_scale_precision=0.001,
initial_triangles=JAXTriangles(
coordinates=np.array([[6.0, 3.0]]),
side_length=0.5,
flipped=True,
y_offset=-0.25 * HEIGHT_FACTOR,
),
array_triangles_cls=CoordinateArrayTriangles,
)
step = next(
solver.steps(
tracer=tracer,
shape=Point(0.07, 0.07),
)

result = jax_solver.solve(
tracer=tracer,
source_plane_coordinate=(0.07, 0.07),
)
visualise(step)

assert len(result) == 5

0 comments on commit e8db9e2

Please sign in to comment.