diff --git a/netcdf_to_gltf_converter/netcdf/netcdf_data.py b/netcdf_to_gltf_converter/netcdf/netcdf_data.py index f773e68..92a5d5c 100644 --- a/netcdf_to_gltf_converter/netcdf/netcdf_data.py +++ b/netcdf_to_gltf_converter/netcdf/netcdf_data.py @@ -20,13 +20,13 @@ def get_coordinate_variables(data, standard_names: tuple) -> List[xr.DataArray]: -class VariableBase(ABC): +class DataVariable(): """Class that serves as a wrapper object for an xarray.DataArray. The wrapper allows for easier retrieval of relevant data. """ def __init__(self, data: xr.DataArray) -> None: - """Initialize a VariableBase with the specified data. + """Initialize a DataVariable with the specified data. Args: data (xr.DataArray): The variable data. @@ -130,20 +130,20 @@ def get_array(self, variable_name: str) -> xr.DataArray: self._raise_if_not_in_dataset(variable_name) return self._dataset[variable_name] - @abstractmethod - def get_variable(self, variable_name: str) -> VariableBase: + def get_variable(self, variable_name: str) -> DataVariable: """Get the variable with the specified name from the data set. Args: variable_name (str): The variable name. Returns: - VariableBase: The wrapper object for the variable. + DataVariable: The wrapper object for the variable. Raises: ValueError: When the dataset does not contain a variable with the name. """ - pass + data = self.get_array(variable_name) + return DataVariable(data) def _raise_if_not_in_dataset(self, name: str): if name not in self._dataset: diff --git a/netcdf_to_gltf_converter/netcdf/parser.py b/netcdf_to_gltf_converter/netcdf/parser.py index 9732994..0a43aa8 100644 --- a/netcdf_to_gltf_converter/netcdf/parser.py +++ b/netcdf_to_gltf_converter/netcdf/parser.py @@ -9,7 +9,7 @@ from netcdf_to_gltf_converter.netcdf.xbeach.xbeach_data import XBeachDataset from netcdf_to_gltf_converter.netcdf.netcdf_data import ( DatasetBase, - VariableBase, + DataVariable, ) from netcdf_to_gltf_converter.preprocessing.interpolation import ( NearestPointInterpolator, @@ -111,7 +111,7 @@ def _transform_grid(config: Config, dataset: DatasetBase): variables = [var.name for var in config.variables] dataset.scale_coordinates(config.scale_horizontal, config.scale_vertical, variables) - def _interpolate(self, data: VariableBase, time_index: int, dataset: DatasetBase): + def _interpolate(self, data: DataVariable, time_index: int, dataset: DatasetBase): return self._interpolator.interpolate( data.coordinates, data.get_data_at_time(time_index), dataset ) diff --git a/netcdf_to_gltf_converter/netcdf/ugrid/ugrid_data.py b/netcdf_to_gltf_converter/netcdf/ugrid/ugrid_data.py index a959bc6..09095c2 100644 --- a/netcdf_to_gltf_converter/netcdf/ugrid/ugrid_data.py +++ b/netcdf_to_gltf_converter/netcdf/ugrid/ugrid_data.py @@ -4,16 +4,7 @@ import xarray as xr import xugrid as xu -from netcdf_to_gltf_converter.netcdf.netcdf_data import ( - DatasetBase, - VariableBase, -) - -class UgridVariable(VariableBase): - """Class that serves as a wrapper object for an xarray.DataArray with UGrid conventions. - The wrapper allows for easier retrieval of relevant data. - """ - +from netcdf_to_gltf_converter.netcdf.netcdf_data import DatasetBase class UgridDataset(DatasetBase): """Class that serves as a wrapper object for an xarray.Dataset with UGrid conventions. @@ -50,21 +41,6 @@ def min_y(self) -> float: _, min_y, _,_ = self._grid.bounds return min_y - def get_variable(self, variable_name: str) -> UgridVariable: - """Get the variable with the specified name from the data set. - - Args: - variable_name (str): The variable name. - - Returns: - UgridVariable: A UgridVariable. - - Raises: - ValueError: When the dataset does not contain a variable with the name. - """ - data = self.get_array(variable_name) - return UgridVariable(data) - def _get_ugrid2d(self) -> xu.Ugrid2d: for grid in self._ugrid_data_set.grids: if isinstance(grid, xu.Ugrid2d): diff --git a/netcdf_to_gltf_converter/netcdf/xbeach/xbeach_data.py b/netcdf_to_gltf_converter/netcdf/xbeach/xbeach_data.py index 24701ef..a8f2a38 100644 --- a/netcdf_to_gltf_converter/netcdf/xbeach/xbeach_data.py +++ b/netcdf_to_gltf_converter/netcdf/xbeach/xbeach_data.py @@ -2,9 +2,8 @@ import numpy as np import xarray as xr -from netcdf_to_gltf_converter.netcdf.netcdf_data import DatasetBase, VariableBase, get_coordinate_variables +from netcdf_to_gltf_converter.netcdf.netcdf_data import DatasetBase, get_coordinate_variables from netcdf_to_gltf_converter.netcdf.xbeach import connectivity -from netcdf_to_gltf_converter.utils.arrays import uint32_array from xugrid.ugrid.conventions import X_STANDARD_NAMES, Y_STANDARD_NAMES @@ -36,12 +35,6 @@ def node_coordinates(self) -> np.ndarray: np.ndarray: An ndarray of floats with shape (n, 2). Each row represents one node and contains the x- and y-coordinate. """ return np.column_stack([self.node_x, self.node_y]) - -class XBeachVariable(VariableBase): - """Class that serves as a wrapper object for an xarray.DataArray for XBEACH output. - The wrapper allows for easier retrieval of relevant data. - """ - class XBeachDataset(DatasetBase): """Class that serves as a wrapper object for an xarray.Dataset with UGrid conventions. @@ -75,30 +68,18 @@ def min_y(self) -> float: float: A floating value with the smallest y-coordinate. """ return self._y_coord_vars[0].values.min() - - def get_variable(self, variable_name: str) -> XBeachVariable: - """Get the variable with the specified name from the data set. - - Args: - variable_name (str): The variable name. - - Returns: - UgridVariable: A UgridVariable. - - Raises: - ValueError: When the dataset does not contain a variable with the name. - """ - data = self.get_array(variable_name) - return XBeachVariable(data) def transform_coordinate_system(self, source_crs: int, target_crs: int): """Transform the coordinates to another coordinate system. + Args: source_crs (int): EPSG from the source coordinate system. target_crs (int): EPSG from the target coordinate system. + Raises: + NotImplementedError: Thrown because coordinate system transformation is not yet suupport for regular grids. """ - pass + raise NotImplementedError("Coordinate system transformation is not yet suupport for regular grids.") @property def face_node_connectivity(self) -> np.ndarray: