Skip to content

Commit

Permalink
make MBItem._determine_name() an abstract method to children can bett…
Browse files Browse the repository at this point in the history
…er handle their naming
  • Loading branch information
rocco8773 committed Jan 15, 2025
1 parent 9279489 commit 24aa199
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 30 deletions.
3 changes: 3 additions & 0 deletions bapsf_motion/motion_builder/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ def _build_initial_ds(self):

return ds

def _determine_name(self):
return self.base_name

def add_layer(self, ly_type: str, **settings):
"""
Add a "point" layer to the motion builder.
Expand Down
36 changes: 23 additions & 13 deletions bapsf_motion/motion_builder/exclusions/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module that defines the `BaseExclusion` abstract class."""
__all__ = ["BaseExclusion", "GovernExclusion"]

import ast
import numpy as np
import re
import xarray as xr
Expand All @@ -11,7 +12,7 @@
from bapsf_motion.motion_builder.item import MBItem


class BaseExclusion(ABC, MBItem):
class BaseExclusion(MBItem):
"""
Abstract base class for :term:`motion exclusion` classes.
Expand Down Expand Up @@ -133,18 +134,6 @@ def inputs(self) -> Dict[str, Any]:
"""
return self._inputs

@MBItem.name.setter
def name(self, name: str):
if not self.skip_ds_add:
# The exclusion name is a part of the Dataset management,
# so we can NOT/ should NOT rename it
return
elif not isinstance(name, str):
return

self._name = name
self._name_pattern = re.compile(rf"{name}(?P<number>[0-9]+)")

@abstractmethod
def _generate_exclusion(self) -> Union[np.ndarray, xr.DataArray]:
"""
Expand All @@ -161,6 +150,27 @@ def _validate_inputs(self) -> None:
"""
...

def _determine_name(self):
try:
return self.name
except AttributeError:
# self._name has not been defined yet
pass

names = set(self._ds.data_vars.keys())
ids = []
for name in names:
_match = self.name_pattern.fullmatch(name)
if _match is not None:
ids.append(
ast.literal_eval(_match.group("number"))
)

ids = list(set(ids))
_id = 0 if not ids else ids[-1] + 1

return f"{self.base_name}{_id:d}"

def is_excluded(self, point):
"""
Check if ``point`` resides in an excluded region defined by
Expand Down
20 changes: 5 additions & 15 deletions bapsf_motion/motion_builder/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import re
import xarray as xr

from abc import ABC, abstractmethod
from typing import Hashable, Tuple

try:
Expand All @@ -16,7 +17,7 @@
ErrorOptions = str


class MBItem:
class MBItem(ABC):
r"""
A base class for any :term:`motion builder` class that will interact
with the `xarray` `~xarray.Dataset` containing the
Expand Down Expand Up @@ -155,7 +156,8 @@ def _validate_ds(ds: xr.Dataset) -> xr.Dataset:

return ds

def _determine_name(self):
@abstractmethod
def _determine_name(self) -> str:
"""
Determine the name for the motion builder item that will be used
in the `~xarray.Dataset`. This is generally the name of the
Expand All @@ -165,19 +167,7 @@ def _determine_name(self):
:attr:`name_pattern` and generate a unique :attr:`name` for
the motion builder item.
"""
try:
return self.name
except AttributeError:
# self._name has not been defined yet
pass

names = set(self._ds.data_vars.keys())
n_existing = 0
for name in names:
if self.name_pattern.fullmatch(name) is not None:
n_existing += 1

return f"{self.base_name}{n_existing + 1:d}"
...

def drop_vars(self, names: str, *, errors: ErrorOptions = "raise"):
new_ds = self._ds.drop_vars(names, errors=errors)
Expand Down
26 changes: 24 additions & 2 deletions bapsf_motion/motion_builder/layers/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
"""Module that defines the `BaseLayer` abstract class."""
__all__ = ["BaseLayer"]

import ast
import re
import numpy as np
import xarray as xr

from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import Any, Dict, List, Union

from bapsf_motion.motion_builder.item import MBItem


class BaseLayer(ABC, MBItem):
class BaseLayer(MBItem):
"""
Abstract base class for :term:`motion layer` classes.
Expand Down Expand Up @@ -142,6 +143,27 @@ def _validate_inputs(self) -> None:
"""
...

def _determine_name(self):
try:
return self.name
except AttributeError:
# self._name has not been defined yet
pass

names = set(self._ds.data_vars.keys())
ids = []
for name in names:
_match = self.name_pattern.fullmatch(name)
if _match is not None:
ids.append(
ast.literal_eval(_match.group("number"))
)

ids = list(set(ids))
_id = 0 if not ids else ids[-1] + 1

return f"{self.base_name}{_id:d}"

def _generate_point_matrix_da(self):
"""
Generate the :term:`motion layer` array/matrix and add it to
Expand Down

0 comments on commit 24aa199

Please sign in to comment.