diff --git a/parcels/field.py b/parcels/field.py index 91a4ae91b..1c3980eb8 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -1449,6 +1449,7 @@ def computeTimeChunk(self, data, tindex): interp_method=self.interp_method, data_full_zdim=self.data_full_zdim, chunksize=self.chunksize, + cast_data_dtype=self.cast_data_dtype, rechunk_callback_fields=rechunk_callback_fields, chunkdims_name_map=self.netcdf_chunkdims_name_map, netcdf_decodewarning=self.netcdf_decodewarning) diff --git a/parcels/fieldfilebuffer.py b/parcels/fieldfilebuffer.py index 85e8339d0..8fef570bd 100644 --- a/parcels/fieldfilebuffer.py +++ b/parcels/fieldfilebuffer.py @@ -22,6 +22,7 @@ def __init__(self, filename, dimensions, indices, timestamp=None, self.indices = indices self.dataset = None self.timestamp = timestamp + self.cast_data_dtype = kwargs.pop('cast_data_dtype', np.float32) self.ti = None self.interp_method = interp_method self.data_full_zdim = data_full_zdim @@ -207,7 +208,7 @@ def data_access(self): data = self.dataset[self.name] ti = range(data.shape[0]) if self.ti is None else self.ti data = self._apply_indices(data, ti) - return np.array(data) + return np.array(data, dtype=self.cast_data_dtype) @property def time(self): @@ -740,7 +741,7 @@ def data_access(self): self.rechunk_callback_fields() self.chunking_finalized = True - return data + return data.astype(self.cast_data_dtype) class DeferredDaskFileBuffer(DaskFileBuffer): diff --git a/parcels/particlefile/baseparticlefile.py b/parcels/particlefile/baseparticlefile.py index 755d2b1ce..db6f3594c 100644 --- a/parcels/particlefile/baseparticlefile.py +++ b/parcels/particlefile/baseparticlefile.py @@ -201,16 +201,16 @@ def write_once(self, var): def _extend_zarr_dims(self, Z, store, dtype, axis): if axis == 1: - a = np.full((Z.shape[0], self.chunks[1]), np.nan, dtype=dtype) + a = np.full((Z.shape[0], self.chunks[1]), self.fill_value_map[dtype], dtype=dtype) obs = zarr.group(store=store, overwrite=False)["obs"] if len(obs) == Z.shape[1]: obs.append(np.arange(self.chunks[1])+obs[-1]+1) else: extra_trajs = max(self.maxids - Z.shape[0], self.chunks[0]) if len(Z.shape) == 2: - a = np.full((extra_trajs, Z.shape[1]), np.nan, dtype=dtype) + a = np.full((extra_trajs, Z.shape[1]), self.fill_value_map[dtype], dtype=dtype) else: - a = np.full((extra_trajs,), np.nan, dtype=dtype) + a = np.full((extra_trajs,), self.fill_value_map[dtype], dtype=dtype) Z.append(a, axis=axis) zarr.consolidate_metadata(store) @@ -281,11 +281,11 @@ def write(self, pset, time, deleted_only=False): varout = self._convert_varout_name(var) if varout not in ['trajectory']: # because 'trajectory' is written as coordinate if self.write_once(var): - data = np.full((arrsize[0],), np.nan, dtype=self.vars_to_write[var]) + data = np.full((arrsize[0],), self.fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var]) data[ids_once] = pset.collection.getvardata(var, indices_to_write_once) dims = ["trajectory"] else: - data = np.full(arrsize, np.nan, dtype=self.vars_to_write[var]) + data = np.full(arrsize, self.fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var]) data[ids, 0] = pset.collection.getvardata(var, indices_to_write) dims = ["trajectory", "obs"] ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])