Skip to content

Commit

Permalink
Broadcasting Fixes (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobdparker authored Jul 29, 2024
1 parent 1a58915 commit 7e4a472
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 55 deletions.
53 changes: 29 additions & 24 deletions regridding/_conservative_ramshaw.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
# cache=True,
)
def _conservative_ramshaw(
# values_input: np.ndarray,
# values_output: np.ndarray,
grid_input: tuple[np.ndarray, np.ndarray],
grid_output: tuple[np.ndarray, np.ndarray],
epsilon: float = 1e-10,
Expand All @@ -27,14 +25,9 @@ def _conservative_ramshaw(
weights.append((0., 0., 0.))

input_x, input_y = grid_input
# output_x, output_y = grid_output

shape_input = input_x.shape
# shape_output = np.broadcast_shapes(output_x.shape, output_y.shape)

grids_sweep = grid_input, grid_output
grids_static = grid_output, grid_input
grids_input = "sweep", "static"
axes = 0, 1

# k = slice(None, -1)
Expand All @@ -55,23 +48,35 @@ def _conservative_ramshaw(

# values_input = values_input / area_input

for grid_sweep, grid_static, grid_input in zip(grids_sweep, grids_static, grids_input):
grid_static_x, grid_static_y = grid_static
grid_sweep_x, grid_sweep_y = grid_sweep
for axis in axes:
_sweep_axis(
# values_input=values_input,
# values_output=values_output,
area_input=area_input,
grid_sweep_x=grid_sweep_x,
grid_sweep_y=grid_sweep_y,
grid_static_x=grid_static_x,
grid_static_y=grid_static_y,
axis=axis,
grid_input=grid_input,
epsilon=epsilon,
weights=weights,
)
grid_static_x, grid_static_y = grid_output
grid_sweep_x, grid_sweep_y = grid_input
for axis in axes:
_sweep_axis(
area_input=area_input,
grid_sweep_x=grid_sweep_x,
grid_sweep_y=grid_sweep_y,
grid_static_x=grid_static_x,
grid_static_y=grid_static_y,
axis=axis,
grid_input="sweep",
epsilon=epsilon,
weights=weights,
)

grid_static_x, grid_static_y = grid_input
grid_sweep_x, grid_sweep_y = grid_output
for axis in axes:
_sweep_axis(
area_input=area_input,
grid_sweep_x=grid_sweep_x,
grid_sweep_y=grid_sweep_y,
grid_static_x=grid_static_x,
grid_static_y=grid_static_y,
axis=axis,
grid_input="static",
epsilon=epsilon,
weights=weights,
)

# return values_output
return weights
Expand Down
17 changes: 12 additions & 5 deletions regridding/_regrid/_regrid_from_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,17 @@ def regrid_from_weights(
:func:`regridding.regrid_from_weights`
"""

shape_input = np.broadcast_shapes(shape_input, values_input.shape)
values_input = np.broadcast_to(values_input, shape=shape_input, subok=True)
shape_input = np.broadcast_shapes(values_input.shape, shape_input)

ndim_input = len(shape_input)
axis_input = _util._normalize_axis(axis_input, ndim=ndim_input)

shape_orthogonal = (
1 if i in axis_input else shape_input[i] for i in range(-len(shape_input), 0)
)
weights = np.broadcast_to(np.array(weights), shape_orthogonal, subok=True)
values_input = np.broadcast_to(values_input, shape_input, subok=True)

if values_output is None:
shape_output = np.broadcast_shapes(
shape_output,
Expand Down Expand Up @@ -85,10 +91,11 @@ def regrid_from_weights(

shape_output_tmp = values_output.shape

weights = numba.typed.List(weights.reshape(-1))
values_input = values_input.reshape(-1, *shape_input_numba)
values_output = values_output.reshape(-1, *shape_output_numba)

weights = numba.typed.List(weights.reshape(-1))

values_input = np.ascontiguousarray(values_input)
values_output = np.ascontiguousarray(values_output)

Expand All @@ -105,17 +112,17 @@ def regrid_from_weights(
return values_output


@numba.njit()
@numba.njit(parallel=True)
def _regrid_from_weights(
weights: numba.typed.List,
values_input: np.ndarray,
values_output: np.ndarray,
) -> None:

for d in numba.prange(len(weights)):
weights_d = weights[d]
values_input_d = values_input[d].reshape(-1)
values_output_d = values_output[d].reshape(-1)

for w in range(len(weights_d)):
i_input, i_output, weight = weights_d[w]
values_output_d[int(i_output)] += weight * values_input_d[int(i_input)]
95 changes: 76 additions & 19 deletions regridding/_regrid/_tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,29 @@
import numpy as np
import regridding

x = np.linspace(-1, 1, num=10)
y = np.linspace(-1, 1, num=11)
x_broadcasted, y_broadcasted = np.meshgrid(
x,
y,
indexing="ij",
)

new_y = np.linspace(-1, 1, num=5)
new_x = np.linspace(-1, 1, num=6)

new_x_broadcasted, new_y_broadcasted = np.meshgrid(
x,
new_y,
indexing="ij",
)

new_x_broadcasted_2, new_y_broadcasted_2 = np.meshgrid(
new_x,
y,
indexing="ij",
)


@pytest.mark.parametrize(
argnames="coordinates_input,coordinates_output,values_input,values_output,axis_input,axis_output,result_expected",
Expand All @@ -24,6 +47,33 @@
None,
np.square(np.linspace(-1, 1, num=11)),
),
(
(y,),
(new_y,),
x_broadcasted + y_broadcasted,
None,
(~0,),
(~0,),
new_x_broadcasted + new_y_broadcasted,
),
(
(x[..., np.newaxis],),
(new_x[..., np.newaxis],),
x_broadcasted + y_broadcasted,
None,
(0,),
(0,),
new_x_broadcasted_2 + new_y_broadcasted_2,
),
(
(x[..., np.newaxis],),
(0.1 * new_x[..., np.newaxis] + 0.001 * new_y,),
x[..., np.newaxis],
None,
(0,),
(0,),
0.1 * new_x[..., np.newaxis] + 0.001 * new_y,
),
],
)
def test_regrid_multilinear_1d(
Expand All @@ -46,35 +96,34 @@ def test_regrid_multilinear_1d(
)
assert isinstance(result, np.ndarray)
assert np.issubdtype(result.dtype, float)
assert np.all(result == result_expected)
assert np.allclose(result, result_expected)


@pytest.mark.parametrize(
argnames="coordinates_input, values_input, axis_input",
argnames="coordinates_input, values_input, axis_input, coordinates_output, values_output, axis_output",
argvalues=[
(
np.meshgrid(
np.linspace(-1, 1, num=10),
np.linspace(-1, 1, num=11),
indexing="ij",
),
(x_broadcasted, y_broadcasted),
np.random.normal(size=(10 - 1, 11 - 1)),
None,
(1.1 * x_broadcasted + 0.01, 1.2 * y_broadcasted + 0.01),
None,
None,
),
],
)
@pytest.mark.parametrize(
argnames="coordinates_output, values_output, axis_output",
argvalues=[
(
np.meshgrid(
1.1 * np.linspace(-1, 1, num=10) + 0.001,
1.2 * np.linspace(-1, 1, num=11) + 0.001,
indexing="ij",
(
x_broadcasted[..., np.newaxis] + np.array([0, 0.001]),
y_broadcasted[..., np.newaxis] + np.array([0, 0.001]),
),
np.random.normal(size=(x.shape[0] - 1, y.shape[0] - 1, 2)),
(0, 1),
(
1.1 * (x_broadcasted[..., np.newaxis] + np.array([0, 0.001])) + 0.01,
1.2 * (y_broadcasted[..., np.newaxis] + np.array([0, 0.01])) + 0.001,
),
None,
None,
)
(0, 1),
),
],
)
def test_regrid_conservative_2d(
Expand All @@ -95,6 +144,14 @@ def test_regrid_conservative_2d(
method="conservative",
)

result_shape = np.array(np.broadcast(*coordinates_output).shape)

if axis_input is None:
result_shape = result_shape - 1
else:
for ax in axis_input:
result_shape[ax] = result_shape[ax] - 1

assert np.issubdtype(result.dtype, float)
assert result.shape == tuple(np.array(np.broadcast(*coordinates_output).shape) - 1)
assert result.shape == tuple(result_shape)
assert np.isclose(result.sum(), values_input.sum())
14 changes: 7 additions & 7 deletions regridding/_weights/_weights_conservative.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@ def _weights_conservative(
weights = np.empty(shape_orthogonal, dtype=numba.typed.List)

for index in np.ndindex(*shape_orthogonal):
index_vertices_input = list(index)
index_vertices_input = list(reversed(index))

for ax in axis_input:
index_vertices_input.insert(ax, slice(None))
index_vertices_input = tuple(index_vertices_input)
index_vertices_input.insert(~ax, slice(None))
index_vertices_input = tuple(reversed(index_vertices_input))

index_vertices_output = list(index)
index_vertices_output = list(reversed(index))
for ax in axis_output:
index_vertices_output.insert(ax, slice(None))
index_vertices_output = tuple(index_vertices_output)
index_vertices_output.insert(~ax, slice(None))
index_vertices_output = tuple(reversed(index_vertices_output))

if len(axis_input) == 1:
raise NotImplementedError("1D regridding not supported")

elif len(axis_input) == 2:
coordinates_input_x, coordinates_input_y = coordinates_input
coordinates_output_x, coordinates_output_y = coordinates_output

weights[index] = _conservative_ramshaw(
grid_input=(
coordinates_input_x[index_vertices_input],
Expand Down

0 comments on commit 7e4a472

Please sign in to comment.