Skip to content

Commit

Permalink
Fixed interpolation workflow to keep float32 results when input is fl…
Browse files Browse the repository at this point in the history
…oat32
  • Loading branch information
doc78 committed Dec 15, 2023
1 parent c94a218 commit a8ff078
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
8 changes: 8 additions & 0 deletions src/pyg2p/main/interpolation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,14 @@ def interpolate_scipy(self, latgrib, longrib, v, grid_id, grid_details=None):
self._log('Trying to interpolate without grib lat/lons. Probably a malformed grib!', 'ERROR')
raise ApplicationException.get_exc(5000)

# CR: to use float32 computations uncomment here:
# longrib=np.float32(longrib)
# latgrib=np.float32(latgrib)
# lonefas=np.float32(lonefas)
# latefas=np.float32(latefas)
# v=np.float32(v)
# self.mv_out=np.float32(self.mv_out)

self._log('\nInterpolating table not found\n Id: {}\nWill create file: {}'.format(intertable_id, intertable_name), 'WARN')
scipy_interpolation = ScipyInterpolation(longrib, latgrib, grid_details, v.ravel(), nnear, self.mv_out,
self._mv_grib, target_is_rotated=self._rotated_target_grid,
Expand Down
18 changes: 13 additions & 5 deletions src/pyg2p/main/interpolation/scipy_interpolation_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,9 @@ def interpolate(self, lonefas, latefas):
subset_size = lonefas.shape[0]//self.num_of_splits

# Initialize empty arrays to store the results
weights = np.empty((lonefas.shape[0]*lonefas.shape[1],self.nnear))
weights = np.empty((lonefas.shape[0]*lonefas.shape[1],self.nnear),dtype=lonefas.dtype)
indexes = np.empty((lonefas.shape[0]*lonefas.shape[1],self.nnear),dtype=int)
result = np.empty((lonefas.shape[0]*lonefas.shape[1]))
result = np.empty((lonefas.shape[0]*lonefas.shape[1]),dtype=lonefas.dtype)

# Iterate over the subsets of the arrays
for i in range(0, lonefas.shape[0], subset_size):
Expand Down Expand Up @@ -606,6 +606,8 @@ def interpolate_split(self, target_lons, target_lats):
x, y, z = self.to_3d(target_lons, target_lats, to_regular=self.target_grid_is_rotated)
efas_locations = np.vstack((x.ravel(), y.ravel(), z.ravel())).T
distances, indexes = self.tree.query(efas_locations, k=self.nnear, n_jobs=self.njobs)
if efas_locations.dtype==np.dtype('float32'):
distances=np.float32(distances)
checktime = time.time()
stdout.write('KDtree time (sec): {}\n'.format(checktime - start))

Expand Down Expand Up @@ -667,6 +669,12 @@ def to_3d(self, lons, lats, rotate=False, to_regular=False):
x = ne.evaluate('r * {x}'.format(x=x_formula))
y = ne.evaluate('r * {y}'.format(y=y_formula))
z = ne.evaluate('r * {z}'.format(z=z_formula))

if lons.dtype==np.dtype('float32'):
x=np.float32(x)
y=np.float32(y)
z=np.float32(z)

return x, y, z

def _build_nn(self, distances, indexes):
Expand Down Expand Up @@ -756,8 +764,8 @@ def _build_weights_invdist(self, distances, indexes, nnear, adw_type = None, use
else:
n_debug=11805340
z = self.z
result = mask_it(np.empty((len(distances),) + np.shape(z[0])), self._mv_target, 1)
weights = empty((len(distances),) + (nnear,))
result = mask_it(np.empty((len(distances),) + np.shape(z[0]),dtype=z.dtype), self._mv_target, 1)
weights = empty((len(distances),) + (nnear,),dtype=z.dtype)
idxs = empty((len(indexes),) + (nnear,), fill_value=z.size, dtype=int)
num_cells = result.size
back_char, _ = progress_step_and_backchar(num_cells)
Expand Down Expand Up @@ -1019,7 +1027,7 @@ def _build_weights_invdist(self, distances, indexes, nnear, adw_type = None, use
dist_leq_1e_10 = distances[:, 0] <= 1e-10

# distances <= 1e-10 : take exactly the point, weight = 1
onlyfirst_array = np.zeros(nnear)
onlyfirst_array = np.zeros(nnear, dtype=weights.dtype)
onlyfirst_array[0] = 1
weights[dist_leq_1e_10] = onlyfirst_array
idxs[dist_leq_1e_10] = indexes[dist_leq_1e_10]
Expand Down

0 comments on commit a8ff078

Please sign in to comment.