diff --git a/tests/cases/csv_points_source.py b/tests/cases/csv_points_source.py
index 1fec0833..66ba613a 100644
--- a/tests/cases/csv_points_source.py
+++ b/tests/cases/csv_points_source.py
@@ -2,6 +2,7 @@
 
 import numpy as np
 import pytest
+import csv
 
 from gunpowder import (
     BatchRequest,
@@ -22,7 +23,7 @@ def seeds():
 
 
 @pytest.fixture
-def test_points(tmpdir):
+def test_points_2d(tmpdir):
     random.seed(1234)
     np.random.seed(1234)
 
@@ -41,8 +42,31 @@ def test_points(tmpdir):
     yield fake_points_file, fake_points
 
 
-def test_pipeline3(test_points):
-    fake_points_file, fake_points = test_points
+@pytest.fixture
+def test_points_3d(tmpdir):
+    random.seed(1234)
+    np.random.seed(1234)
+
+    fake_points_file = tmpdir / "shift_test.csv"
+    fake_points = np.random.randint(0, 100, size=(3, 3)).astype(float)
+    with open(fake_points_file, "w") as f:
+        writer = csv.DictWriter(f, fieldnames=["x", "y", "z", "id"])
+        writer.writeheader()
+        for i, point in enumerate(fake_points):
+            pointdict = {"x": point[0], "y": point[1], "z": point[2], "id": i}
+            writer.writerow(pointdict)
+
+    # This fixture will run after seeds since it is set
+    # with autouse=True. So make sure to reset the seeds properly at the end
+    # of this fixture
+    random.seed(12345)
+    np.random.seed(12345)
+
+    yield fake_points_file, fake_points
+
+
+def test_pipeline_2d(test_points_2d):
+    fake_points_file, fake_points = test_points_2d
 
     points_key = GraphKey("TEST_POINTS")
 
@@ -67,3 +91,35 @@ def test_pipeline3(test_points):
     result_locs = [list(point.location) for point in result_points]
 
     assert sorted(result_locs) == sorted(target_locs)
+
+
+def test_pipeline_3d(test_points_3d):
+    fake_points_file, fake_points = test_points_3d
+
+    points_key = GraphKey("TEST_POINTS")
+    scale = 2
+    csv_source = CsvPointsSource(
+        fake_points_file,
+        points_key,
+        spatial_cols=[0, 2, 1],
+        delimiter=",",
+        id_col=3,
+        points_spec=GraphSpec(roi=Roi(shape=Coordinate((100, 100, 100)), offset=(0, 0, 0))),
+        scale=scale,
+    )
+
+    request = BatchRequest()
+    shape = Coordinate((100, 100, 100))
+    request.add(points_key, shape)
+
+    pipeline = csv_source
+    with build(pipeline) as b:
+        request = b.request_batch(request)
+
+    result_points = list(request[points_key].nodes)
+    for node in result_points:
+        orig_loc = fake_points[int(node.id)]
+        reordered_loc = orig_loc.copy()
+        reordered_loc[1] = orig_loc[2]
+        reordered_loc[2] = orig_loc[1]
+        assert list(node.location) == list(reordered_loc * scale)