Skip to content

Commit

Permalink
Merge pull request #1400 from OceanParcels/fix_casting_warnings
Browse files Browse the repository at this point in the history
Fix casting warnings
  • Loading branch information
erikvansebille authored Jul 19, 2023
2 parents 9573623 + 501d5cd commit bf98c97
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
1 change: 1 addition & 0 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,6 +1449,7 @@ def computeTimeChunk(self, data, tindex):
interp_method=self.interp_method,
data_full_zdim=self.data_full_zdim,
chunksize=self.chunksize,
+ cast_data_dtype=self.cast_data_dtype,
rechunk_callback_fields=rechunk_callback_fields,
chunkdims_name_map=self.netcdf_chunkdims_name_map,
netcdf_decodewarning=self.netcdf_decodewarning)
Expand Down
5 changes: 3 additions & 2 deletions parcels/fieldfilebuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self, filename, dimensions, indices, timestamp=None,
self.indices = indices
self.dataset = None
self.timestamp = timestamp
+ self.cast_data_dtype = kwargs.pop('cast_data_dtype', np.float32)
self.ti = None
self.interp_method = interp_method
self.data_full_zdim = data_full_zdim
Expand Down Expand Up @@ -207,7 +208,7 @@ def data_access(self):
data = self.dataset[self.name]
ti = range(data.shape[0]) if self.ti is None else self.ti
data = self._apply_indices(data, ti)
- return np.array(data)
+ return np.array(data, dtype=self.cast_data_dtype)

@property
def time(self):
Expand Down Expand Up @@ -740,7 +741,7 @@ def data_access(self):
self.rechunk_callback_fields()
self.chunking_finalized = True

- return data
+ return data.astype(self.cast_data_dtype)


class DeferredDaskFileBuffer(DaskFileBuffer):
Expand Down
10 changes: 5 additions & 5 deletions parcels/particlefile/baseparticlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,16 +201,16 @@ def write_once(self, var):

def _extend_zarr_dims(self, Z, store, dtype, axis):
if axis == 1:
- a = np.full((Z.shape[0], self.chunks[1]), np.nan, dtype=dtype)
+ a = np.full((Z.shape[0], self.chunks[1]), self.fill_value_map[dtype], dtype=dtype)
obs = zarr.group(store=store, overwrite=False)["obs"]
if len(obs) == Z.shape[1]:
obs.append(np.arange(self.chunks[1])+obs[-1]+1)
else:
extra_trajs = max(self.maxids - Z.shape[0], self.chunks[0])
if len(Z.shape) == 2:
- a = np.full((extra_trajs, Z.shape[1]), np.nan, dtype=dtype)
+ a = np.full((extra_trajs, Z.shape[1]), self.fill_value_map[dtype], dtype=dtype)
else:
- a = np.full((extra_trajs,), np.nan, dtype=dtype)
+ a = np.full((extra_trajs,), self.fill_value_map[dtype], dtype=dtype)
Z.append(a, axis=axis)
zarr.consolidate_metadata(store)

Expand Down Expand Up @@ -281,11 +281,11 @@ def write(self, pset, time, deleted_only=False):
varout = self._convert_varout_name(var)
if varout not in ['trajectory']: # because 'trajectory' is written as coordinate
if self.write_once(var):
- data = np.full((arrsize[0],), np.nan, dtype=self.vars_to_write[var])
+ data = np.full((arrsize[0],), self.fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var])
data[ids_once] = pset.collection.getvardata(var, indices_to_write_once)
dims = ["trajectory"]
else:
- data = np.full(arrsize, np.nan, dtype=self.vars_to_write[var])
+ data = np.full(arrsize, self.fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var])
data[ids, 0] = pset.collection.getvardata(var, indices_to_write)
dims = ["trajectory", "obs"]
ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])
Expand Down

0 comments on commit bf98c97

Please sign in to comment.