Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR]: Improving regrid2 performance #533

Merged
merged 27 commits into from
Mar 8, 2024
Merged
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
5b1f3e5
Fixes regrid2 operating on numpy arrays
jasonb5 Aug 16, 2023
6ed8990
Merge branch 'main' into regrid2_performance
jasonb5 Oct 11, 2023
f50cf24
Removes regrid output mask
jasonb5 Oct 26, 2023
2e58cd7
Fixes extracting input data variable
jasonb5 Oct 26, 2023
df2d23a
Updates regrid2
jasonb5 Nov 7, 2023
3d5d0d9
Overrides dtype to match input
jasonb5 Nov 8, 2023
5f5e1be
Fixes wrapping when using np.take
jasonb5 Nov 8, 2023
dd671b4
Fixes copying variable attributes
jasonb5 Nov 8, 2023
c6e65b7
Fixes typing errors
jasonb5 Nov 8, 2023
a68bf9f
Fixes correcting dtype before mapping/regridding
jasonb5 Dec 7, 2023
53989b7
Fixes tests
jasonb5 Dec 7, 2023
e6fe7c9
Fixes latitude that was wrapping
jasonb5 Dec 7, 2023
934e028
Fixes copying coordinates when name is missmatched
jasonb5 Dec 7, 2023
62fe454
Fixes shifting longitude
jasonb5 Dec 7, 2023
d5d7f49
Merge branch 'main' into regrid2_performance
jasonb5 Dec 7, 2023
a2eb927
Fixes missmatched coordinate names
jasonb5 Dec 7, 2023
ea19fb0
Fixes reshape/ordering output data
jasonb5 Dec 9, 2023
b67fe3a
Adds more optimizations
jasonb5 Dec 9, 2023
bcdc652
Fixes masking
jasonb5 Jan 19, 2024
3312be3
Fixes failing tests
jasonb5 Jan 19, 2024
9377a88
Merge branch 'main' into regrid2_performance
jasonb5 Jan 19, 2024
9d49dbe
Fixes tests
jasonb5 Jan 20, 2024
ad61067
Merge branch 'main' into regrid2_performance
jasonb5 Feb 12, 2024
96eaae2
Fixes docstrings and adds comments
jasonb5 Feb 23, 2024
7695e2b
Adds comments
jasonb5 Feb 28, 2024
a5bf57c
Merge branch 'main' into regrid2_performance
jasonb5 Feb 28, 2024
e2125f9
Replaces cf.axes with get_dim_keys
jasonb5 Mar 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fixes typing errors
  • Loading branch information
jasonb5 committed Nov 8, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit c6e65b717064322b71230efc49a56a5dad5414d9
23 changes: 12 additions & 11 deletions xcdat/regridder/regrid2.py
Original file line number Diff line number Diff line change
@@ -90,10 +90,10 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:

def _regrid(
input_data_var: xr.DataArray,
src_lat_bnds: list,
src_lon_bnds: list,
dst_lat_bnds: list,
dst_lon_bnds: list,
src_lat_bnds: np.ndarray,
src_lon_bnds: np.ndarray,
dst_lat_bnds: np.ndarray,
dst_lon_bnds: np.ndarray,
) -> np.ndarray:
lat_mapping, lat_weights = _map_latitude(src_lat_bnds, dst_lat_bnds)
lon_mapping, lon_weights = _map_longitude(src_lon_bnds, dst_lon_bnds)
@@ -124,7 +124,7 @@ def _regrid(
y_index = dims.index(y_name)
x_index = dims.index(x_name)

output_data = []
output_points = []

target_dtype = input_data.dtype

@@ -145,16 +145,16 @@ def _regrid(
dtype=target_dtype,
) / np.sum(cell_weight, dtype=target_dtype)

output_data.append(cell_value)
output_points.append(cell_value)

output_data = np.asarray(output_data, dtype=target_dtype)
output_data = np.array(output_points, dtype=target_dtype)
output_data = output_data.reshape(tuple(data_shape.values()))

return output_data


def _build_dataset(
ds: xr.DataArray,
ds: xr.Dataset,
data_var: str,
output_data: np.ndarray,
dst_lat_bnds,
@@ -164,8 +164,8 @@ def _build_dataset(
) -> xr.Dataset:
input_data_var = ds[data_var]

output_coords = {}
output_data_vars = {}
output_coords: dict[str, xr.DataArray] = {}
output_data_vars: dict[str, xr.DataArray] = {}
output_bnds = {
"Y": dst_lat_bnds,
"X": dst_lon_bnds,
@@ -197,9 +197,10 @@ def _build_dataset(
dims=input_data_var.dims,
coords=output_coords,
attrs=ds[data_var].attrs.copy(),
name=data_var,
)

output_data_vars[input_data_var.name] = output_da
output_data_vars[data_var] = output_da

output_ds = xr.Dataset(
output_data_vars,