From a7027c6825fd0661074a67cdf9b0406b8982600d Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 19 Dec 2023 10:37:48 -0700 Subject: [PATCH] fix the test case --- tests/cases/pad.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/cases/pad.py b/tests/cases/pad.py index 3b9939e7..2ee8968d 100644 --- a/tests/cases/pad.py +++ b/tests/cases/pad.py @@ -17,6 +17,8 @@ import pytest import numpy as np +from itertools import product + @pytest.mark.parametrize("mode", ["constant", "reflect"]) def test_output(mode): @@ -50,7 +52,20 @@ def test_output(mode): assert pipeline.spec[graph_key].roi == Roi((190, 10, 10), (1820, 200, 200)) batch = pipeline.request_batch( - BatchRequest({array_key: ArraySpec(Roi((180, 0, 0), (20, 20, 20)))}) + BatchRequest({array_key: ArraySpec(Roi((180, 0, 0), (40, 40, 40)))}) ) - assert np.sum(batch.arrays[array_key].data) == 1 * 10 * 10 + data = batch.arrays[array_key].data + if mode == "constant": + octants = [ + (1 * 10 * 10) if zi + yi + xi < 3 else 100 * 1 * 5 * 10 + for zi, yi, xi in product(range(2), range(2), range(2)) + ] + assert np.sum(data) == np.sum(octants), ( + np.sum(data), + np.sum(octants), + np.unique(data), + ) + elif mode == "reflect": + octants = [100 * 1 * 5 * 10 for _ in range(8)] + assert np.sum(data) == np.sum(octants), data.shape