diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 4516f24b1..91b7264de 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -38,6 +38,11 @@ jobs: integration-tests: runs-on: [self-hosted, gpu, db] steps: + - name: Export OpenMPI PATH + run: | + export PATH=/opt/openmpi-4.1.5/bin:$PATH + export LD_LIBRARY_PATH=/opt/openmpi-4.1.5/lib:$LD_LIBRARY_PATH + ompi_info - uses: actions/checkout@v3 - name: Install poetry run: pip install poetry @@ -49,6 +54,8 @@ jobs: poetry env use '3.10' source $(poetry env info --path)/bin/activate poetry install --with test + pip install cupy-cuda12x + env MPICC=/opt/openmpi-4.1.5/bin/mpicc python -m pip install git+https://github.com/mpi4py/mpi4py coverage run -m pytest -m integration_test && coverage xml && coverage report -m - name: Upload coverage to Codecov uses: Wandalen/wretry.action@v1.0.36 diff --git a/cyclops/evaluate/metrics/experimental/__init__.py b/cyclops/evaluate/metrics/experimental/__init__.py new file mode 100644 index 000000000..401e491f8 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/__init__.py @@ -0,0 +1,6 @@ +"""Array-API-compatible metrics.""" +from cyclops.evaluate.metrics.experimental.confusion_matrix import ( + BinaryConfusionMatrix, + MulticlassConfusionMatrix, + MultilabelConfusionMatrix, +) diff --git a/cyclops/evaluate/metrics/experimental/confusion_matrix.py b/cyclops/evaluate/metrics/experimental/confusion_matrix.py new file mode 100644 index 000000000..eddca4dc3 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/confusion_matrix.py @@ -0,0 +1,376 @@ +"""Confusion matrix.""" +from typing import Any, List, Optional, Tuple, Union + +from cyclops.evaluate.metrics.experimental.functional.confusion_matrix import ( + _binary_confusion_matrix_compute, + _binary_confusion_matrix_format_inputs, + _binary_confusion_matrix_update_state, + _binary_confusion_matrix_validate_args, + _binary_confusion_matrix_validate_arrays, + _multiclass_confusion_matrix_compute, + _multiclass_confusion_matrix_format_inputs, + _multiclass_confusion_matrix_update_state, + _multiclass_confusion_matrix_validate_args, + _multiclass_confusion_matrix_validate_arrays, + _multilabel_confusion_matrix_compute, + _multilabel_confusion_matrix_format_inputs, + _multilabel_confusion_matrix_update_state, + _multilabel_confusion_matrix_validate_args, + _multilabel_confusion_matrix_validate_arrays, +) +from cyclops.evaluate.metrics.experimental.metric import Metric +from cyclops.evaluate.metrics.experimental.utils.ops import dim_zero_cat +from cyclops.evaluate.metrics.experimental.utils.typing import Array + + +class _AbstractConfusionMatrix(Metric): + """Base class defining the common interface for confusion matrix classes.""" + + tp: Union[Array, List[Array]] + fp: Union[Array, List[Array]] + tn: Union[Array, List[Array]] + fn: Union[Array, List[Array]] + + def _create_state(self, size: int = 1) -> None: + """Create the state variables. + + Parameters + ---------- + size : int + The size of the default Array to create for the state variables + + Raises + ------ + RuntimeError + If ``size`` is not greater than 0. + + """ + if size <= 0: + raise RuntimeError( + f"Expected `size` to be greater than 0, got {size}.", + ) + dist_reduce_fn = "sum" + + def default(xp: Any) -> Array: + return xp.zeros(shape=size, dtype=xp.int64) + + self.add_state_default_factory("tp", default, dist_reduce_fn=dist_reduce_fn) # type: ignore + self.add_state_default_factory("fp", default, dist_reduce_fn=dist_reduce_fn) # type: ignore + self.add_state_default_factory("tn", default, dist_reduce_fn=dist_reduce_fn) # type: ignore + self.add_state_default_factory("fn", default, dist_reduce_fn=dist_reduce_fn) # type: ignore + + def _update_stat_scores(self, tp: Array, fp: Array, tn: Array, fn: Array) -> None: + """Update the stat scores.""" + self.tp += tp + self.fp += fp + self.tn += tn + self.fn += fn + + def _final_state(self) -> Tuple[Array, Array, Array, Array]: + """Return the final state variables.""" + tp = dim_zero_cat(self.tp) + fp = dim_zero_cat(self.fp) + tn = dim_zero_cat(self.tn) + fn = dim_zero_cat(self.fn) + return tp, fp, tn, fn + + +class BinaryConfusionMatrix( + _AbstractConfusionMatrix, + registry_key="binary_confusion_matrix", +): + """Confusion matrix for binary classification tasks. + + Parameters + ---------- + threshold : float, default=0.5 + The threshold value to use when binarizing the inputs. + normalize : {'true', 'pred', 'all', 'none' None}, optional, default=None + Normalizes confusion matrix over the true (rows), predicted (columns) + samples or all samples. If `None` or `'none'`, confusion matrix will + not be normalized. + ignore_index : int, optional, default=None + Specifies a target value that is ignored and does not contribute to + the confusion matrix. If `None`, all values are used. + **kwargs : Any + Additional keyword arguments common to all metrics. + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental import BinaryConfusionMatrix + >>> target = np.asarray([0, 1, 0, 1, 0, 1]) + >>> preds = np.asarray([0, 0, 1, 1, 0, 1]) + >>> metric = BinaryConfusionMatrix() + >>> metric(target, preds) + Array([[2, 1], + [1, 2]], dtype=int64) + >>> target = np.asarray([0, 1, 0, 1, 0, 1]) + >>> preds = np.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> metric = BinaryConfusionMatrix() + >>> metric(target, preds) + Array([[2, 1], + [1, 2]], dtype=int32) + >>> target = np.asarray([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = np.asarray([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]]) + + """ + + def __init__( + self, + threshold: float = 0.5, + normalize: Optional[str] = None, + ignore_index: Optional[int] = None, + **kwargs: Any, + ) -> None: + """Initialize the class.""" + super().__init__(**kwargs) + + _binary_confusion_matrix_validate_args( + threshold=threshold, + normalize=normalize, + ignore_index=ignore_index, + ) + + self.threshold = threshold + self.normalize = normalize + self.ignore_index = ignore_index + + self._create_state(size=1) + + def _update_state(self, target: Array, preds: Array) -> None: + """Update the state variables.""" + _binary_confusion_matrix_validate_arrays( + target, + preds, + ignore_index=self.ignore_index, + ) + target, preds = _binary_confusion_matrix_format_inputs( + target, + preds, + threshold=self.threshold, + ignore_index=self.ignore_index, + ) + + tn, fp, fn, tp = _binary_confusion_matrix_update_state(target, preds) + self._update_stat_scores(tp, fp, tn, fn) + + def _compute_metric(self) -> Array: + """Compute the confusion matrix.""" + tp, fp, tn, fn = self._final_state() + return _binary_confusion_matrix_compute( + tp=tp, + fp=fp, + tn=tn, + fn=fn, + normalize=self.normalize, + ) + + +class MulticlassConfusionMatrix(Metric, registry_key="multiclass_confusion_matrix"): + """Confusion matrix for multiclass classification tasks. + + Parameters + ---------- + num_classes : int + The number of classes. + normalize : {'true', 'pred', 'all', 'none' None}, optional, default=None + Normalizes confusion matrix over the true (rows), predicted (columns) + samples or all samples. If `None` or `'none'`, confusion matrix will + not be normalized. + ignore_index : int, optional, default=None + Specifies a target value that is ignored and does not contribute to + the confusion matrix. If `None`, all values are used. + **kwargs : Any + Additional keyword arguments common to all metrics. + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental import MulticlassConfusionMatrix + >>> target = np.asarray([2, 1, 0, 0]) + >>> preds = np.asarray([2, 1, 0, 1]) + >>> metric = MulticlassConfusionMatrix(num_classes=3) + >>> metric(target, preds) + Array([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]], dtype=int64) + >>> target = np.asarray([2, 1, 0, 0]) + >>> preds = np.asarray([[0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13]]) + >>> metric = MulticlassConfusionMatrix(num_classes=3) + >>> metric(target, preds) + Array([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]], dtype=int64) + + """ + + confmat: Array + + def __init__( + self, + num_classes: int, + normalize: Optional[str] = None, + ignore_index: Optional[Union[int, Tuple[int]]] = None, + **kwargs: Any, + ) -> None: + """Initialize the class.""" + super().__init__(**kwargs) + + _multiclass_confusion_matrix_validate_args( + num_classes, + normalize=normalize, + ignore_index=ignore_index, + ) + + self.num_classes = num_classes + self.normalize = normalize + self.ignore_index = ignore_index + + dist_reduce_fn = "sum" + + def default(xp: Any) -> Array: + return xp.zeros((num_classes,) * 2, dtype=xp.int64) + + self.add_state_default_factory("confmat", default, dist_reduce_fn=dist_reduce_fn) # type: ignore + + def _update_state(self, target: Array, preds: Array) -> None: + """Update the state variable.""" + _multiclass_confusion_matrix_validate_arrays( + target, + preds, + self.num_classes, + ignore_index=self.ignore_index, + ) + target, preds = _multiclass_confusion_matrix_format_inputs( + target, + preds, + ignore_index=self.ignore_index, + ) + confmat = _multiclass_confusion_matrix_update_state( + target, + preds, + self.num_classes, + ) + + self.confmat += confmat + + def _compute_metric(self) -> Array: + """Compute the confusion matrix.""" + confmat = self.confmat + return _multiclass_confusion_matrix_compute( + confmat, + normalize=self.normalize, + ) + + +class MultilabelConfusionMatrix( + _AbstractConfusionMatrix, + registry_key="multilabel_confusion_matrix", +): + """Confusion matrix for multilabel classification tasks. + + Parameters + ---------- + num_labels : int + The number of labels. + threshold : float, default=0.5 + The threshold value to use when binarizing the inputs. + normalize : {'true', 'pred', 'all', 'none' None}, optional, default=None + Normalizes confusion matrix over the true (rows), predicted (columns) + samples or all samples. If `None` or `'none'`, confusion matrix will + not be normalized. + ignore_index : int, optional, default=None + Specifies a target value that is ignored and does not contribute to + the confusion matrix. If `None`, all values are used. + **kwargs : Any + Additional keyword arguments common to all metrics. + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental import MultilabelConfusionMatrix + >>> target = np.asarray([[0, 1, 0], [1, 0, 1]]) + >>> preds = np.asarray([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelConfusionMatrix(num_labels=3) + >>> metric(target, preds) + Array([[[1, 0], + [0, 1]], + + [[1, 0], + [1, 0]], + + [[0, 1], + [0, 1]]], dtype=int64) + >>> target = np.asarray([[0, 1, 0], [1, 0, 1]]) + >>> preds = np.asarray([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelConfusionMatrix(num_labels=3) + >>> metric(target, preds) + Array([[[1, 0], + [0, 1]], + + [[1, 0], + [1, 0]], + + [[0, 1], + [0, 1]]], dtype=int64) + + """ + + def __init__( + self, + num_labels: int, + threshold: float = 0.5, + normalize: Optional[str] = None, + ignore_index: Optional[int] = None, + **kwargs: Any, + ) -> None: + """Initialize the class.""" + super().__init__(**kwargs) + + _multilabel_confusion_matrix_validate_args( + num_labels, + threshold=threshold, + normalize=normalize, + ignore_index=ignore_index, + ) + + self.num_labels = num_labels + self.threshold = threshold + self.normalize = normalize + self.ignore_index = ignore_index + + self._create_state(size=num_labels) + + def _update_state(self, target: Array, preds: Array) -> None: + """Update the state variables.""" + _multilabel_confusion_matrix_validate_arrays( + target, + preds, + self.num_labels, + ignore_index=self.ignore_index, + ) + target, preds = _multilabel_confusion_matrix_format_inputs( + target, + preds, + threshold=self.threshold, + ignore_index=self.ignore_index, + ) + tn, fp, fn, tp = _multilabel_confusion_matrix_update_state(target, preds) + self._update_stat_scores(tp, fp, tn, fn) + + def _compute_metric(self) -> Array: + """Compute the confusion matrix.""" + tp, fp, tn, fn = self._final_state() + return _multilabel_confusion_matrix_compute( + tp=tp, + fp=fp, + tn=tn, + fn=fn, + num_labels=self.num_labels, + normalize=self.normalize, + ) diff --git a/cyclops/evaluate/metrics/experimental/distributed_backends/__init__.py b/cyclops/evaluate/metrics/experimental/distributed_backends/__init__.py new file mode 100644 index 000000000..00da50a1f --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/distributed_backends/__init__.py @@ -0,0 +1,48 @@ +"""Distributed backends for distributed metric computation.""" +from cyclops.evaluate.metrics.experimental.distributed_backends.base import ( + _DISTRIBUTED_BACKEND_REGISTRY, + DistributedBackend, +) +from cyclops.evaluate.metrics.experimental.distributed_backends.mpi4py import ( + MPI4Py, +) +from cyclops.evaluate.metrics.experimental.distributed_backends.non_distributed import ( + NonDistributed, +) +from cyclops.evaluate.metrics.experimental.distributed_backends.torch_distributed import ( + TorchDistributed, +) + + +def get_backend(name: str) -> DistributedBackend: + """Return a registered distributed backend by name. + + Parameters + ---------- + name : str + Name of the distributed backend. + + Returns + ------- + backend : DistributedBackend + An instance of the distributed backend. + + Raises + ------ + ValueError + If the backend is not found in the registry. + + """ + if not isinstance(name, str): + raise TypeError( + f"Expected `name` to be a str, but got {type(name)}.", + ) + + name = name.lower() + backend = _DISTRIBUTED_BACKEND_REGISTRY.get(name) + if backend is None: + raise ValueError( + f"Backend `{name}` is not found. " + f"It should be one of {list(_DISTRIBUTED_BACKEND_REGISTRY.keys())}.", + ) + return backend() diff --git a/cyclops/evaluate/metrics/experimental/distributed_backends/base.py b/cyclops/evaluate/metrics/experimental/distributed_backends/base.py new file mode 100644 index 000000000..acb292ab5 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/distributed_backends/base.py @@ -0,0 +1,85 @@ +"""Distributed backend interface.""" +import logging +from abc import ABC, abstractmethod, abstractproperty +from typing import Any, List, Optional + +from cyclops.evaluate.metrics.experimental.utils.typing import Array +from cyclops.utils.log import setup_logging + + +_DISTRIBUTED_BACKEND_REGISTRY = {} +LOGGER = logging.getLogger(__name__) +setup_logging(print_level="WARN", logger=LOGGER) + + +class DistributedBackend(ABC): + """Abstract base class for implementing distributed communication backends. + + Parameters + ---------- + registry_key : str, optional + The key used to register the distributed backend. If not given, the class + name will be used as the key. + + """ + + def __init_subclass__( + cls, + registry_key: Optional[str] = None, + **kwargs: Any, + ): + """Register the distributed backend.""" + super().__init_subclass__(**kwargs) + + if ( # subclass has not implemented abstract methods/properties + cls.all_gather is not DistributedBackend.all_gather + and cls.is_initialized is not DistributedBackend.is_initialized + and cls.rank is not DistributedBackend.rank + and cls.world_size is not DistributedBackend.world_size + ) and cls.__name__ != "DistributedBackend": + if registry_key is None: + registry_key = cls.__name__ + elif registry_key in _DISTRIBUTED_BACKEND_REGISTRY: + LOGGER.warning( + "The given distributed backend %s has already been registered. " + "It will be overwritten by %s.", + registry_key, + cls.__name__, + ) + _DISTRIBUTED_BACKEND_REGISTRY[registry_key] = cls + else: + LOGGER.warning( + "The distributed backend %s is not registered because it does not " + "implement any abstract methods/properties.", + cls.__name__, + ) + + @abstractproperty + def is_initialized(self) -> bool: + """Return `True` if the distributed environment has been initialized.""" + + @abstractproperty + def rank(self) -> int: + """Return the rank of the current process.""" + + @abstractproperty + def world_size(self) -> int: + """Return the total number of processes.""" + + @abstractmethod + def all_gather(self, arr: Array) -> List[Array]: + """Gather Array object from all processes and return as a list. + + NOTE: This method must handle uneven array shapes. + + Parameters + ---------- + arr : Array + The array to be gathered. + + Returns + ------- + List[Array] + A list of data gathered from all processes. + + """ diff --git a/cyclops/evaluate/metrics/experimental/distributed_backends/mpi4py.py b/cyclops/evaluate/metrics/experimental/distributed_backends/mpi4py.py new file mode 100644 index 000000000..3ae9997d4 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/distributed_backends/mpi4py.py @@ -0,0 +1,107 @@ +"""mpi4py backend for synchronizing array-API-compatible objects.""" +import os +from typing import TYPE_CHECKING, List + +import array_api_compat as apc + +from cyclops.evaluate.metrics.experimental.distributed_backends.base import ( + DistributedBackend, +) +from cyclops.evaluate.metrics.experimental.utils.ops import flatten +from cyclops.evaluate.metrics.experimental.utils.typing import Array +from cyclops.utils.optional import import_optional_module + + +if TYPE_CHECKING: + from mpi4py import MPI +else: + MPI = import_optional_module("mpi4py.MPI", error="ignore") +# mypy: disable-error-code="no-any-return" + + +class MPI4Py(DistributedBackend, registry_key="mpi4py"): + """A distributed communication backend for mpi4py.""" + + def __init__(self) -> None: + """Initialize the MPI4Py backend.""" + super().__init__() + if MPI is None: + raise ImportError( + f"For availability of {self.__class__.__name__}," + " please install mpi4py first.", + ) + + @property + def is_initialized(self) -> bool: + """Return `True` if the distributed environment has been initialized.""" + return "OMPI_COMM_WORLD_SIZE" in os.environ + + @property + def rank(self) -> int: + """Return the rank of the current process.""" + comm = MPI.COMM_WORLD + return comm.Get_rank() + + @property + def world_size(self) -> int: + """Return the total number of processes.""" + comm = MPI.COMM_WORLD + return comm.Get_size() + + def all_gather(self, arr: Array) -> List[Array]: + """Gather Arrays from current proccess and return as a list. + + Parameters + ---------- + arr : Array + Any array-API-compatible object. + + Returns + ------- + List[Array] + A list of the gathered array-API-compatible objects. + """ + try: + xp = apc.array_namespace(arr) + except TypeError as e: + raise TypeError( + "The given array is not array-API-compatible. " + "Please use array-API-compatible objects.", + ) from e + + comm = MPI.COMM_WORLD + + # gather the shape and size of each array + local_shape = arr.shape + local_size = apc.size(arr) + all_shapes = comm.allgather(local_shape) + all_sizes = comm.allgather(local_size) + + # prepare displacements for `Allgatherv`` + displacements = [0] + for shape in all_shapes[:-1]: + shape_arr = xp.asarray(shape, dtype=xp.int32) + displacements.append(displacements[-1] + int(xp.prod(shape_arr))) + + # allocate memory for gathered data based on total size + total_size = sum( + [int(xp.prod(xp.asarray(shape, dtype=xp.int32))) for shape in all_shapes], + ) + gathered_data = xp.empty(total_size, dtype=arr.dtype) + + # gather data from all processes to all processes, accounting for uneven shapes + comm.Allgatherv( + flatten(arr), + [gathered_data, (all_sizes, displacements)], + ) + + # reshape gathered data back to original shape + reshaped_data = [] + for shape in all_shapes: + shape_arr = xp.asarray(shape, dtype=xp.int32) + reshaped_data.append( + xp.reshape(gathered_data[: xp.prod(shape_arr)], shape=shape), + ) + gathered_data = gathered_data[xp.prod(shape_arr) :] + + return reshaped_data diff --git a/cyclops/evaluate/metrics/experimental/distributed_backends/non_distributed.py b/cyclops/evaluate/metrics/experimental/distributed_backends/non_distributed.py new file mode 100644 index 000000000..fba82d187 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/distributed_backends/non_distributed.py @@ -0,0 +1,38 @@ +"""A dummy object for non-distributed environments.""" +from typing import Any, List + +from cyclops.evaluate.metrics.experimental.distributed_backends.base import ( + DistributedBackend, +) + + +class NonDistributed(DistributedBackend, registry_key="non_distributed"): + """A dummy distributed communication backend for non-distributed environments.""" + + @property + def is_initialized(self) -> bool: + """Return `True` if the distributed environment has been initialized. + + For a non-distributed environment, it is always `False`. + """ + return False + + @property + def rank(self) -> int: + """Return the rank of the current process. + + For a non-distributed environment, it is always 0 + """ + return 0 + + @property + def world_size(self) -> int: + """Return the total number of processes. + + For a non-distributed environment, it is always 1. + """ + return 1 + + def all_gather(self, arr: Any) -> List[Any]: + """Return the input as a list.""" + return [arr] diff --git a/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py b/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py new file mode 100644 index 000000000..b4f29eb26 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py @@ -0,0 +1,99 @@ +"""`torch.distributed` backend for synchronizing `torch.Tensor` objects.""" +from typing import TYPE_CHECKING, List, TypeVar + +from cyclops.evaluate.metrics.experimental.distributed_backends.base import ( + DistributedBackend, +) +from cyclops.utils.optional import import_optional_module + + +if TYPE_CHECKING: + import torch + import torch.distributed as torch_dist +else: + torch = import_optional_module("torch", error="ignore") + torch_dist = import_optional_module("torch.distributed", error="ignore") + +Tensor = TypeVar("Tensor", bound="torch.Tensor") + + +class TorchDistributed(DistributedBackend, registry_key="torch_distributed"): + """A distributed communication backend for torch.distributed.""" + + def __init__(self) -> None: + """Initialize the object.""" + super().__init__() + if torch is None: + raise ImportError( + f"For availability of {self.__class__.__name__}," + " please install pytorch first.", + ) + if not torch_dist.is_available(): + raise RuntimeError( + f"For availability of {self.__class__.__name__}," + " make sure torch.distributed is available.", + ) + + @property + def is_initialized(self) -> bool: + """Return `True` if the distributed environment has been initialized.""" + return torch_dist.is_initialized() + + @property + def rank(self) -> int: + """Return the rank of the current process group.""" + return torch_dist.get_rank() + + @property + def world_size(self) -> int: + """Return the world size of the current process group.""" + return torch_dist.get_world_size() + + def _simple_all_gather(self, data: torch.Tensor) -> List[torch.Tensor]: + """Gather tensors of the same shape from all processes.""" + gathered_data = [torch.zeros_like(data) for _ in range(self.world_size)] + torch_dist.all_gather(gathered_data, data) # type: ignore[no-untyped-call] + return gathered_data + + def all_gather(self, data: torch.Tensor) -> List[torch.Tensor]: + """Gather Arrays from current proccess and return as a list. + + Parameters + ---------- + arr : torch.Tensor + `torch.Tensor` object to be gathered. + + Returns + ------- + List[Array] + A list of the gathered `torch.Tensor` objects. + """ + data = data.contiguous() + + if data.ndim == 0: + return self._simple_all_gather(data) + + # gather sizes of all tensors + local_size = torch.tensor(data.shape, device=data.device) + local_sizes = [torch.zeros_like(local_size) for _ in range(self.world_size)] + torch_dist.all_gather(local_sizes, local_size) # type: ignore[no-untyped-call] + max_size = torch.stack(local_sizes).max(dim=0).values + all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) + + # if shapes are all the same, do a simple gather + if all_sizes_equal: + return self._simple_all_gather(data) + + # if not, pad each local tensor to maximum size, gather and then truncate + pad_dims = [] + pad_by = (max_size - local_size).detach().cpu() + for val in reversed(pad_by): + pad_dims.append(0) + pad_dims.append(val.item()) + data_padded = torch.nn.functional.pad(data, pad_dims) + gathered_data = [torch.zeros_like(data_padded) for _ in range(self.world_size)] + torch_dist.all_gather(gathered_data, data_padded) # type: ignore[no-untyped-call] + for idx, item_size in enumerate(local_sizes): + slice_param = [slice(dim_size) for dim_size in item_size] + gathered_data[idx] = gathered_data[idx][slice_param] + return gathered_data diff --git a/cyclops/evaluate/metrics/experimental/functional/__init__.py b/cyclops/evaluate/metrics/experimental/functional/__init__.py new file mode 100644 index 000000000..5179ca8ec --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/functional/__init__.py @@ -0,0 +1,6 @@ +"""Functional metrics for evaluating model performance.""" +from cyclops.evaluate.metrics.experimental.functional.confusion_matrix import ( + binary_confusion_matrix, + multiclass_confusion_matrix, + multilabel_confusion_matrix, +) diff --git a/cyclops/evaluate/metrics/experimental/functional/confusion_matrix.py b/cyclops/evaluate/metrics/experimental/functional/confusion_matrix.py new file mode 100644 index 000000000..f2f3066d1 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/functional/confusion_matrix.py @@ -0,0 +1,757 @@ +"""Functions for computing the confusion matrix for classification tasks.""" +from typing import Any, Literal, Optional, Tuple, Union + +import array_api_compat as apc + +from cyclops.evaluate.metrics.experimental.utils.ops import ( + bincount, + clone, + flatten, + remove_ignore_index, + safe_divide, + sigmoid, + squeeze_all, + to_int, +) +from cyclops.evaluate.metrics.experimental.utils.typing import Array +from cyclops.evaluate.metrics.experimental.utils.validation import ( + _basic_input_array_checks, + _check_same_shape, + is_floating_point, +) + + +def _common_confusion_matrix_args_validate( + normalize: Optional[str] = None, + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> None: + """Validate the arguments of the confusion matrix functions.""" + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError( + f"Expected argument `normalize` to be one of {allowed_normalize}, " + f"but got {normalize}", + ) + if ignore_index is not None and not ( + isinstance(ignore_index, int) + or ( + isinstance(ignore_index, tuple) + and all(isinstance(i, int) for i in ignore_index) + ) + ): + raise ValueError( + "Expected argument `ignore_index` to either be `None`, an integer, " + "or a tuple of integers (for multiclass and multilabel inputs) but " + f"got {ignore_index}", + ) + if isinstance(ignore_index, tuple) and min(ignore_index) < 0: + raise ValueError( + "Expected argument `ignore_index` to be a tuple of non-negative " + f"integers but got {ignore_index}", + ) + + +def _normalize_confusion_matrix( + confmat: Array, + normalize: Optional[str] = None, + *, + xp: Any, +) -> Array: + """Normalize the confusion matrix.""" + if normalize in ["true", "pred", "all"]: + confmat = xp.astype(confmat, xp.float32) + + if normalize == "pred": + return safe_divide(confmat, xp.sum(confmat, axis=-2, keepdims=True)) + if normalize == "true": + return safe_divide(confmat, xp.sum(confmat, axis=-1, keepdims=True)) + if normalize == "all": + return safe_divide(confmat, xp.sum(confmat, axis=(-1, -2), keepdims=True)) + + return confmat + + +def _binary_confusion_matrix_validate_args( + threshold: float = 0.5, + normalize: Optional[str] = None, + ignore_index: Optional[int] = None, +) -> None: + """Validate the arguments of the `binary_confusion_matrix` method.""" + if not (isinstance(threshold, float) and (0.0 <= threshold <= 1.0)): + raise ValueError( + "Expected argument `threshold` to be a float in the [0,1] range, " + f"but got {threshold}.", + ) + _common_confusion_matrix_args_validate( + normalize=normalize, + ignore_index=ignore_index, + ) + + +def _binary_confusion_matrix_validate_arrays( + target: Array, + preds: Array, + ignore_index: Optional[int] = None, +) -> None: + """Validate the inputs of the `binary_confusion_matrix` method.""" + _basic_input_array_checks(target, preds) + _check_same_shape(target, preds) + + xp = apc.array_namespace(target, preds) + + unique_values = xp.unique_values(target) + if ignore_index is None: + check = xp.any((unique_values != 0) & (unique_values != 1)) + else: + check = xp.any( + (unique_values != 0) + & (unique_values != 1) + & (unique_values != ignore_index), + ) + if check: + raise RuntimeError( + "Expected only the following values " + f"{[0, 1] if ignore_index is None else [ignore_index]} in `target`. " + f"But found the following values: {unique_values}", + ) + + if not is_floating_point(preds): + unique_values = xp.unique_values(preds) + if xp.any((unique_values != 0) & (unique_values != 1)): + raise RuntimeError( + "Expected only the following values " + f"{[0, 1] if ignore_index is None else [ignore_index]} in `preds`. " + f"But found the following values: {unique_values}", + ) + + +def _binary_confusion_matrix_format_inputs( + target: Array, + preds: Array, + threshold: float, + ignore_index: Optional[int], +) -> Tuple[Array, Array]: + """Format the input arrays of the `binary_confusion_matrix` method.""" + xp = apc.array_namespace(target, preds) + + preds = flatten(preds) + target = flatten(target) + + if ignore_index is not None: + target, preds = remove_ignore_index(target, preds, ignore_index=ignore_index) + + if is_floating_point(preds): + # NOTE: in the 2021.12 version of the the array API standard the `__mul__` + # operator is only defined for numeric arrays (including float and int scalars) + # so we convert the boolean array to an integer array first. + if not xp.all(to_int((preds >= 0)) * to_int((preds <= 1))): # preds are logits + preds = sigmoid(preds) # convert to probabilities with sigmoid + preds = to_int(preds > threshold) + + return target, preds + + +def _binary_confusion_matrix_update_state( + target: Array, + preds: Array, +) -> Tuple[Array, Array, Array, Array]: + """Compute stat scores for the given `target` and `preds` arrays.""" + xp = apc.array_namespace(target, preds) + + # NOTE: in the 2021.12 version of the array API standard, the `sum` method + # only supports numeric types, so we have to cast the boolean arrays to integers. + # Also, the `squeeze` method in the array API standard does not support `axis=None` + # so we define a custom method `squeeze_all` to squeeze all singleton dimensions. + tp = squeeze_all(xp.sum(to_int((target == preds) & (target == 1)))) + fn = squeeze_all(xp.sum(to_int((target != preds) & (target == 1)))) + fp = squeeze_all(xp.sum(to_int((target != preds) & (target == 0)))) + tn = squeeze_all(xp.sum(to_int((target == preds) & (target == 0)))) + + return tn, fp, fn, tp + + +def _binary_confusion_matrix_compute( + tn: Array, + fp: Array, + fn: Array, + tp: Array, + normalize: Optional[str] = None, +) -> Array: + """Compute the confusion matrix from the given stat scores.""" + xp = apc.array_namespace(tn, fp, fn, tp) + + confmat = squeeze_all( + xp.reshape(xp.stack([tn, fp, fn, tp], axis=0), shape=(-1, 2, 2)), + ) + return _normalize_confusion_matrix(confmat, normalize=normalize, xp=xp) + + +def binary_confusion_matrix( + target: Array, + preds: Array, + threshold: float = 0.5, + normalize: Optional[Literal["pred", "true", "all", "none"]] = None, + ignore_index: Optional[int] = None, +) -> Array: + """Compute the confusion matrix for binary classification tasks. + + Parameters + ---------- + target : Array + The target array with shape `(N, ...)`, where `N` is the number of samples. + preds : Array + The prediction array with shape `(N, ...)`, where `N` is the number of samples. + threshold : float, default=0.5 + The threshold to use for binarizing the predictions. + normalize : str, optional, default=None + Normalization mode. + If `None` or `'none'`, return the number of correctly classified samples + for each class. + If `'true'`, return the fraction of correctly classified samples for each + class over the number of samples with the same true class. + If `'pred'`, return the fraction of samples of each class that were correctly + classified over the number of samples with the same predicted class. + If `'all'`, return the fraction of correctly classified samples over all + samples. + ignore_index : int, optional, default=None + Specifies a target value that is ignored and does not contribute to the + confusion matrix. If `None`, ignore nothing. + + Returns + ------- + Array + The confusion matrix with shape `(2, 2)`. + + Raises + ------ + ValueError + If `target` and `preds` have different shapes. + ValueError + If `target` and `preds` are not array-API-compatible. + ValueError + If `target` or `preds` are empty. + ValueError + If `target` or `preds` are not numeric arrays. + ValueError + If `threshold` is not a float in the [0,1] range. + ValueError + If `normalize` is not one of `'pred'`, `'true'`, `'all'`, `'none'`, or `None`. + ValueError + If `ignore_index` is not `None` or an integer. + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental.functional import binary_confusion_matrix + >>> target = np.asarray([0, 1, 0, 1, 0, 1]) + >>> preds = np.asarray([0, 0, 1, 1, 0, 1]) + >>> binary_confusion_matrix(target, preds) + Array([[2, 1], + [1, 2]], dtype=int64) + >>> target = np.asarray([0, 1, 0, 1, 0, 1]) + >>> preds = np.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_confusion_matrix(target, preds) + Array([[2, 1], + [1, 2]], dtype=int32) + + """ # noqa: W505 + _binary_confusion_matrix_validate_args( + threshold=threshold, + normalize=normalize, + ignore_index=ignore_index, + ) + _binary_confusion_matrix_validate_arrays(target, preds, ignore_index) + + target, preds = _binary_confusion_matrix_format_inputs( + target, + preds, + threshold, + ignore_index, + ) + tn, fp, fn, tp = _binary_confusion_matrix_update_state(target, preds) + + return _binary_confusion_matrix_compute(tn, fp, fn, tp, normalize=normalize) + + +def _multiclass_confusion_matrix_validate_args( + num_classes: int, + normalize: Optional[str] = None, + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> None: + """Validate the arguments of the `multiclass_confusion_matrix` method.""" + if not isinstance(num_classes, int) or num_classes < 2: + raise ValueError( + "Expected argument `num_classes` to be an integer larger than 1, " + f"but got {num_classes}.", + ) + _common_confusion_matrix_args_validate( + normalize=normalize, + ignore_index=ignore_index, + ) + + +def _multiclass_confusion_matrix_validate_arrays( + target: Array, + preds: Array, + num_classes: int, + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> None: + """Validate the inputs of the `multiclass_confusion_matrix` method.""" + _basic_input_array_checks(target, preds) + + xp = apc.array_namespace(target, preds) + + if preds.ndim == target.ndim + 1: + if not is_floating_point(preds): + raise ValueError( + "If `preds` have one dimension more than `target`, `preds` should " + "contain floating point values.", + ) + + if target.ndim == 0 and preds.shape[0] != num_classes: + raise ValueError( + "If `target` is a scalar and `preds` has one dimension more than " + "`target`, the first dimension of `preds` should be equal to number " + "of classes.", + ) + if target.ndim >= 1 and preds.shape[1] != num_classes: + raise ValueError( + "If `preds` have one dimension more than `target`, the second " + "dimension of `preds` should be equal to number of classes.", + ) + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "If `preds` have one dimension more than `target`, the shape of " + "`preds` should be (N, C, ...), and the shape of `target` should " + "be (N, ...).", + ) + elif preds.ndim == target.ndim: + _check_same_shape(target, preds) + else: + raise ValueError( + "Either `preds` and `target` both should have the (same) shape (N, ...), " + "or the shape of `target` should be (N, ...) and the shape of `preds` " + "should be (N, C, ...).", + ) + + num_unique_values = apc.size(xp.unique_values(target)) + check = num_unique_values is None or ( + num_unique_values > num_classes + if ignore_index is None + else num_unique_values > num_classes + 1 + ) + if check: + raise RuntimeError( + f"Expected only {num_classes if ignore_index is None else num_classes + 1} " + f"values in `target` but found {num_unique_values} values.", + ) + + if not is_floating_point(preds): + unique_values = xp.unique_values(preds) + num_unique_values = apc.size(unique_values) + if num_unique_values is None or num_unique_values > num_classes: + raise RuntimeError( + f"Expected only {num_classes} values in `preds` but found " + f"{num_unique_values} values.", + ) + + +def _multiclass_confusion_matrix_format_inputs( + target: Array, + preds: Array, + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> Tuple[Array, Array]: + """Format the input arrays of the `multiclass_confusion_matrix` method.""" + xp = apc.array_namespace(target, preds) + if preds.ndim == target.ndim + 1: + axis = 1 if preds.ndim > 1 else 0 + preds = xp.argmax(preds, axis=axis) + + target, preds = flatten(target), flatten(preds) + + if ignore_index is not None: + target, preds = remove_ignore_index(target, preds, ignore_index=ignore_index) + + return target, preds + + +def _multiclass_confusion_matrix_update_state( + target: Array, + preds: Array, + num_classes: int, +) -> Array: + """Compute the confusion matrix for the given `target` and `preds` arrays.""" + xp = apc.array_namespace(target, preds) + + unique_mapping = to_int(target) * num_classes + to_int(preds) + bins = bincount(unique_mapping, minlength=num_classes**2) + + return squeeze_all(xp.reshape(bins, shape=(-1, num_classes, num_classes))) + + +def _multiclass_confusion_matrix_compute( + confmat: Array, + normalize: Optional[str] = None, +) -> Array: + """Normalize the confusion matrix.""" + xp = apc.array_namespace(confmat) + return _normalize_confusion_matrix(confmat, normalize=normalize, xp=xp) + + +def multiclass_confusion_matrix( + target: Array, + preds: Array, + num_classes: int, + normalize: Optional[Literal["pred", "true", "all", "none"]] = None, + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> Array: + """Compute the confusion matrix for multiclass classification tasks. + + Parameters + ---------- + target : Array + The target array of shape `(N, ...)`, where `N` is the number of samples. + preds : Array + The prediction array with shape `(N, ...)`, for integer inputs, or + `(N, C, ...)`, for float inputs, where `N` is the number of samples and + `C` is the number of classes. + num_classes : int + The number of classes. + normalize : str, optional, default=None + Normalization mode. + If `None` or `'none'`, return the number of correctly classified samples + for each class. + If `'true'`, return the fraction of correctly classified samples for each + class over the number of samples with the same true class. + If `'pred'`, return the fraction of samples of each class that were correctly + classified over the number of samples with the same predicted class. + If `'all'`, return the fraction of correctly classified samples over all + samples. + ignore_index : int, Tuple[int], optional, default=None + Specifies a target value(s) that is ignored and does not contribute to the + confusion matrix. If `None`, ignore nothing. + + Returns + ------- + Array + The confusion matrix with shape `(C, C)`, where `C` is the number of classes. + + Raises + ------ + ValueError + If `target` and `preds` are not array-API-compatible. + ValueError + If `target` or `preds` are empty. + ValueError + If `target` or `preds` are not numeric arrays. + ValueError + If `num_classes` is not an integer larger than 1. + ValueError + If `normalize` is not one of `'pred'`, `'true'`, `'all'`, `'none'`, or `None`. + ValueError + If `ignore_index` is not `None`, an integer or a tuple of integers.\ + ValueError + If `preds` contains floats but `target` does not have one dimension less than + `preds`. + ValueError + If the second dimension of `preds` is not equal to `num_classes`. + ValueError + If when `target` has one dimension less than `preds`, the shape of `preds` is + not `(N, C, ...)` while the shape of `target` is `(N, ...)`. + ValueError + If when `target` and `preds` have the same number of dimensions, they + do not have the same shape. + RuntimeError + If `target` contains values that are not in the range [0, `num_classes`). + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental.functional import multiclass_confusion_matrix + >>> target = np.asarray([2, 1, 0, 0]) + >>> preds = np.asarray([2, 1, 0, 1]) + >>> multiclass_confusion_matrix(target, preds, num_classes=3) + Array([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]], dtype=int64) + >>> target = np.asarray([2, 1, 0, 0]) + >>> preds = np.asarray([[0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13]]) + >>> multiclass_confusion_matrix(target, preds, num_classes=3) + Array([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]], dtype=int64) + + """ # noqa: W505 + _multiclass_confusion_matrix_validate_args( + num_classes, + normalize=normalize, + ignore_index=ignore_index, + ) + _multiclass_confusion_matrix_validate_arrays( + target, + preds, + num_classes, + ignore_index=ignore_index, + ) + + target, preds = _multiclass_confusion_matrix_format_inputs( + target, + preds, + ignore_index=ignore_index, + ) + confmat = _multiclass_confusion_matrix_update_state(target, preds, num_classes) + + return _multiclass_confusion_matrix_compute(confmat, normalize) + + +def _multilabel_confusion_matrix_validate_args( + num_labels: int, + threshold: float = 0.5, + normalize: Optional[str] = None, + ignore_index: Optional[int] = None, +) -> None: + """Validate the arguments of the `multilabel_confusion_matrix` method.""" + if not isinstance(num_labels, int) or num_labels < 2: + raise ValueError( + "Expected argument `num_labels` to be an integer larger than 1, " + f"but got {num_labels}.", + ) + _binary_confusion_matrix_validate_args( + threshold=threshold, + normalize=normalize, + ignore_index=ignore_index, + ) + + +def _multilabel_confusion_matrix_validate_arrays( + target: Array, + preds: Array, + num_labels: int, + ignore_index: Optional[int] = None, +) -> None: + """Validate the input arrays of the `multilabel_confusion_matrix` method.""" + _basic_input_array_checks(target, preds) + _check_same_shape(target, preds) + + xp = apc.array_namespace(target, preds) + + if preds.shape[1] != num_labels: + raise ValueError( + "Expected the second dimension of `preds` and `target` to be equal " + f"to `num_labels`={num_labels}, but got {preds.shape[1]}.", + ) + + # Check that target only contains [0,1] values or value in ignore_index + unique_values = xp.unique_values(target) + if ignore_index is None: + check = xp.any((unique_values != 0) & (unique_values != 1)) + else: + check = xp.any( + (unique_values != 0) + & (unique_values != 1) + & (unique_values != ignore_index), + ) + if check: + raise RuntimeError( + "Expected only the following values " + f"{[0, 1] if ignore_index is None else [ignore_index]} in `target`. " + f"But found the following values: {unique_values}", + ) + + if not is_floating_point(preds): + unique_values = xp.unique_values(preds) + if xp.any((unique_values != 0) & (unique_values != 1)): + raise RuntimeError( + "Expected only 0s and 1s in `preds`, but found the following values: " + f"{unique_values}", + ) + + +def _multilabel_confusion_matrix_format_inputs( + target: Array, + preds: Array, + threshold: float = 0.5, + ignore_index: Optional[int] = None, +) -> Tuple[Array, Array]: + """Format the input arrays of the `multilabel_confusion_matrix` method.""" + xp = apc.array_namespace(target, preds) + + if is_floating_point(preds): + # NOTE: in the array API standard the `__mul__` operator is only defined + # for numeric arrays (including float and int scalars) so we convert the + # boolean array to an integer array first. + if not xp.all(to_int((preds >= 0)) * to_int((preds <= 1))): + preds = sigmoid(preds) # convert logits to probabilities + preds = to_int(preds > threshold) + + preds = xp.reshape(preds, shape=(*preds.shape[:2], -1)) + target = xp.reshape(target, shape=(*target.shape[:2], -1)) + + if ignore_index is not None: + idx = target == ignore_index + target = clone(target) + target[idx] = -1 + + return target, preds + + +def _multilabel_confusion_matrix_update_state( + target: Array, + preds: Array, +) -> Tuple[Array, Array, Array, Array]: + """Compute the statistics for the given `target` and `preds` arrays.""" + xp = apc.array_namespace(target, preds) + + sum_axis = (0, -1) + tp = squeeze_all(xp.sum(to_int((target == preds) & (target == 1)), axis=sum_axis)) + fn = squeeze_all(xp.sum(to_int((target != preds) & (target == 1)), axis=sum_axis)) + fp = squeeze_all(xp.sum(to_int((target != preds) & (target == 0)), axis=sum_axis)) + tn = squeeze_all(xp.sum(to_int((target == preds) & (target == 0)), axis=sum_axis)) + + return tn, fp, fn, tp + + +def _multilabel_confusion_matrix_compute( + tn: Array, + fp: Array, + fn: Array, + tp: Array, + num_labels: int, + normalize: Optional[str] = None, +) -> Array: + """Compute the confusion matrix from the given stat scores.""" + xp = apc.array_namespace(tn, fp, fn, tp) + + confmat = squeeze_all( + xp.reshape(xp.stack([tn, fp, fn, tp], axis=-1), shape=(-1, num_labels, 2, 2)), + ) + + return _normalize_confusion_matrix(confmat, normalize=normalize, xp=xp) + + +def multilabel_confusion_matrix( + target: Array, + preds: Array, + num_labels: int, + threshold: float = 0.5, + normalize: Optional[str] = None, + ignore_index: Optional[int] = None, +) -> Array: + """Compute the confusion matrix for multilabel classification tasks. + + Parameters + ---------- + target : Array + The target array of shape `(N, L, ...)`, where `N` is the number of samples + and `L` is the number of labels. + preds : Array + The prediction array of shape `(N, L, ...)`, where `N` is the number of + samples and `L` is the number of labels. If `preds` contains floats that + are not in the range [0,1], they will be converted to probabilities using + the sigmoid function. + num_labels : int + The number of labels. + threshold : float, default=0.5 + The threshold to use for binarizing the predictions. + normalize : str, optional, default=None + Normalization mode. + If `None` or `'none'`, return the number of correctly classified samples + for each class. + If `'true'`, return the fraction of correctly classified samples for each + class over the number of true samples for each class. + If `'pred'`, return the fraction of samples of each class that were correctly + classified over the number of samples predicted for each class. + If `'all'`, return the fraction of correctly classified samples over all + samples. + ignore_index : int, optional, default=None + Specifies a target value that is ignored and does not contribute to the + confusion matrix. If `None`, ignore nothing. + + Returns + ------- + Array + The confusion matrix with shape `(L, 2, 2)`, where `L` is the number of labels. + + Raises + ------ + ValueError + If `target` and `preds` are not array-API-compatible. + ValueError + If `target` or `preds` are empty. + ValueError + If `target` or `preds` are not numeric arrays. + ValueError + If `threshold` is not a float in the [0,1] range. + ValueError + If `normalize` is not one of `'pred'`, `'true'`, `'all'`, `'none'`, or `None`. + ValueError + If `ignore_index` is not `None` or a non-negative integer. + ValueError + If `num_labels` is not an integer larger than 1. + ValueError + If `target` and `preds` do not have the same shape. + ValueError + If the second dimension of `preds` is not equal to `num_labels`. + RuntimeError + If `target` contains values that are not in the range [0, 1]. + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental.functional import multilabel_confusion_matrix + >>> target = np.asarray([[0, 1, 0], [1, 0, 1]]) + >>> preds = np.asarray([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_confusion_matrix(target, preds, num_labels=3) + Array([[[1, 0], + [0, 1]], + + [[1, 0], + [1, 0]], + + [[0, 1], + [0, 1]]], dtype=int64) + >>> target = np.asarray([[0, 1, 0], [1, 0, 1]]) + >>> preds = np.asarray([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_confusion_matrix(target, preds, num_labels=3) + Array([[[1, 0], + [0, 1]], + + [[1, 0], + [1, 0]], + + [[0, 1], + [0, 1]]], dtype=int64) + + """ # noqa: W505 + _multilabel_confusion_matrix_validate_args( + num_labels, + threshold=threshold, + normalize=normalize, + ignore_index=ignore_index, + ) + _multilabel_confusion_matrix_validate_arrays( + target, + preds, + num_labels, + ignore_index=ignore_index, + ) + + target, preds = _multilabel_confusion_matrix_format_inputs( + target, + preds, + threshold=threshold, + ignore_index=ignore_index, + ) + tn, fp, fn, tp = _multilabel_confusion_matrix_update_state(target, preds) + + return _multilabel_confusion_matrix_compute( + tn, + fp, + fn, + tp, + num_labels, + normalize=normalize, + ) diff --git a/cyclops/evaluate/metrics/experimental/metric.py b/cyclops/evaluate/metrics/experimental/metric.py new file mode 100644 index 000000000..14355c30a --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/metric.py @@ -0,0 +1,676 @@ +"""Base class for all metrics.""" +import inspect +import logging +import warnings +from abc import ABC, abstractmethod +from copy import deepcopy +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Protocol, + Union, + runtime_checkable, +) + +import array_api_compat as apc + +from cyclops.evaluate.metrics.experimental.distributed_backends import get_backend +from cyclops.evaluate.metrics.experimental.utils.ops import ( + apply_to_array_collection, + clone, + dim_zero_cat, + dim_zero_max, + dim_zero_mean, + dim_zero_min, + dim_zero_sum, + flatten_seq, +) +from cyclops.evaluate.metrics.experimental.utils.typing import Array +from cyclops.utils.log import setup_logging + + +LOGGER = logging.getLogger(__name__) +setup_logging(print_level="WARN", logger=LOGGER) + +_METRIC_REGISTRY = {} + +TState = Union[Array, List[Array]] + + +@runtime_checkable +class StateFactory(Protocol): + """Protocol for a function that creates a metric state.""" + + def __call__(self, xp: Optional[Any] = None) -> TState: + """Create a metric state.""" + ... + + +class Metric(ABC): + """Abstract base class for all metrics.""" + + def __init__(self, **kwargs: Any) -> None: + dist_backend = kwargs.get("dist_backend", "non_distributed") + self.dist_backend = get_backend(dist_backend) + + self._device = "cpu" + self._update_count: int = 0 + self._computed: Any = None + self._default_factories: Dict[str, StateFactory] = {} + self._defaults: Dict[str, TState] = {} + self._reductions: Dict[str, Union[str, Callable[..., Any], None]] = {} + + self._is_synced = False + self._cache: Optional[Dict[str, TState]] = None + + def __init_subclass__( + cls: Any, + registry_key: Optional[str] = None, + force_register: bool = False, + ) -> None: + """Add subclass to the metric registry.""" + super().__init_subclass__() + + if registry_key is None and not force_register: + warnings.warn( + "A registry key must be provided when `force_register` is True. " + "The registration will be skipped.", + category=UserWarning, + stacklevel=2, + ) + return + + if registry_key is not None and not isinstance(registry_key, str): + raise TypeError( + f"Expected `registry_key` to be a string, but got {type(registry_key)}.", + ) + + is_abstract_cls = inspect.isabstract(cls) + excluded_classes = ("OperatorMetric", "MetricCollection") + if force_register or ( + not (is_abstract_cls or cls.__name__ in excluded_classes) + and registry_key is not None + ): + _METRIC_REGISTRY[registry_key] = cls + + @property + def device(self) -> Union[str, Any]: + """Return the device on which the metric states are stored.""" + return self._device + + @property + def state_vars(self) -> Dict[str, TState]: + """Return the state variables of the metric as a dictionary.""" + return {attr: getattr(self, attr) for attr in self._defaults} + + @abstractmethod + def _update_state(self, *args: Any, **kwargs: Any) -> None: + """Update the state of the metric.""" + + @abstractmethod + def _compute_metric(self) -> Any: + """Compute the final value of the metric from the state variables.""" + + def _add_states(self, xp: Any) -> None: + """Add the state variables as attributes using the default factory functions.""" + # raise error if no default factories have been added + if not self._default_factories: + warnings.warn( + f"The metric `{self.__class__.__name__}` has no state variables, " + "which may lead to unexpected behavior. This is likely because the " + "`update` method was called before the `add_state_default_factory` " + "method was called.", + category=UserWarning, + stacklevel=2, + ) + + for name, factory in self._default_factories.items(): + params = inspect.signature(factory).parameters + if len(params) == 1 and list(params.keys())[0] == "xp": + value = factory(xp=xp) + else: + value = factory() + + _validate_state_variable_type(name, value) + + setattr(self, name, value) + self._defaults[name] = ( + clone(value) if apc.is_array_api_obj(value) else deepcopy(value) + ) + + def add_state_default_factory( + self, + name: str, + default_factory: StateFactory, + dist_reduce_fn: Optional[Union[str, Callable[..., Any]]] = None, + ) -> None: + """Add a function for creating default values for state variables. + + Parameters + ---------- + name : str + The name of the state. + default_factory : Callable[..., Union[Array, List[Array]]] + A function that creates the state. The function can take + no arguments or exactly one argument named `xp` (the array API namespace) + and must return an array-API-compatible object or a list of + array-API-compatible objects. + dist_reduce_fn : str or Callable[..., Any], optional + The function to use to reduce the state across all processes. + If `None`, no reduction will be performed. If a string, the string + must be one of ['mean', 'sum', 'cat', 'min', 'max']. If a callable, + the callable must take a single argument (the state) and + return a reduced version of the state. + + """ + if not name.isidentifier(): + raise ValueError( + f"Argument `name` must be a valid python identifier. Got `{name}`.", + ) + if not callable(default_factory): + raise TypeError( + "Expected `default_factory` to be a callable, but got " + f"{type(default_factory)}.", + ) + + params = inspect.signature(default_factory).parameters + check = ( + isinstance(default_factory, StateFactory) + and default_factory.__name__ == "list" # type: ignore + or ( + len(params) == 0 + or (len(params) == 1 and list(params.keys())[0] == "xp") + ) + ) + if not check: + raise TypeError( + "Expected `default_factory` to be a function that takes at most " + "one argument named 'xp' (the array API namespace), but got " + f"{inspect.signature(default_factory)}.", + ) + + if dist_reduce_fn == "sum": + dist_reduce_fn = dim_zero_sum + elif dist_reduce_fn == "mean": + dist_reduce_fn = dim_zero_mean + elif dist_reduce_fn == "max": + dist_reduce_fn = dim_zero_max + elif dist_reduce_fn == "min": + dist_reduce_fn = dim_zero_min + elif dist_reduce_fn == "cat": + dist_reduce_fn = dim_zero_cat + elif dist_reduce_fn is not None and not callable(dist_reduce_fn): + raise ValueError( + "`dist_reduce_fn` must be callable or one of " + "['mean', 'sum', 'cat', 'min', 'max', None]", + ) + + self._default_factories[name] = default_factory + self._reductions[name] = dist_reduce_fn + + def to_device( + self, + device: str, + stream: Optional[Union[int, Any]] = None, + ) -> "Metric": + """Move the state variables of the metric to the given device. + + Parameters + ---------- + device : str + The device to move the state variables to. + stream : int or Any, optional + The stream to use when moving the state variables to the device. + """ + for name in self._defaults: + value = getattr(self, name) + _validate_state_variable_type(name, value) + + if apc.is_array_api_obj(value): + setattr(self, name, apc.to_device(value, device, stream=stream)) + elif isinstance(value, list): + setattr( + self, + name, + [apc.to_device(array, device, stream=stream) for array in value], + ) + + self._device = device + return self + + def update(self, *args: Any, **kwargs: Any) -> None: + """Update the state of the metric. + + This method calls the `_update_state` method, which should be implemented + by the subclass. The `_update_state` method should update the state variables + of the metric using the array API-compatible objects passed to this method. + This method enusres that the state variables are created using the factory + functions added via the `add_state_default_factory` before the first call to + `_update_state`. It also ensures that the state variables are moved to the + device of the first array-API-compatible object passed to it. The method + tracks the number of times `update` is called and resets the cached result + of `compute` whenever `update` is called. + + Notes + ----- + - This method should be called before the `compute` method is called + for the first time to ensure that the state variables are initialized. + """ + if ( + not bool(self._defaults) and bool(self._default_factories) + ) or self._default_factories.keys() != self._defaults.keys(): + arrays = [obj for obj in args if apc.is_array_api_obj(obj)] + arrays.extend( + (obj for obj in kwargs.values() if apc.is_array_api_obj(obj)), + ) + if len(arrays) == 0: + raise ValueError( + f"The `update` method of metric {self.__class__.__name__} " + "was called without any array API-compatible objects. " + "This may lead to errors as metric state variables may " + "not yet be defined.", + ) + xp = apc.get_namespace(*arrays) + self._add_states(xp) + + # move state variables to device of first array + device = apc.device(arrays[0]) + self.to_device(device) + + self._computed = None + self._update_count += 1 + + self._update_state(*args, **kwargs) + + def sync(self) -> None: + """Synchronzie the metric states across all processes. + + This method is a no-op if the distributed backend is not initialized or + if using a non-distributed backend. + """ + if self._is_synced: + raise RuntimeError("The Metric has already been synced.") + + if not self.dist_backend.is_initialized: + if self.dist_backend.world_size == 1: + self._is_synced = True + self._cache = {attr: getattr(self, attr) for attr in self._defaults} + return + + self._cache = {attr: getattr(self, attr) for attr in self._defaults} + + input_dict = {attr: getattr(self, attr) for attr in self._reductions} + for attr, reduction_fn in self._reductions.items(): + # pre-concatenate metric states that are lists to reduce number of + # all_gather operations + if ( + reduction_fn == dim_zero_cat + and isinstance(input_dict[attr], list) + and len(input_dict[attr]) > 1 + ): + input_dict[attr] = [dim_zero_cat(input_dict[attr])] + + output_dict = apply_to_array_collection( + input_dict, + self.dist_backend.all_gather, + ) + + for attr, reduction_fn in self._reductions.items(): + if isinstance(output_dict[attr], list) and len(output_dict[attr]) == 0: + setattr(self, attr, []) + continue + + # stack or flatten inputs before reduction + first_elem = output_dict[attr][0] + if apc.is_array_api_obj(first_elem): + xp = apc.array_namespace(first_elem) + output_dict[attr] = xp.stack(output_dict[attr]) + elif isinstance(first_elem, list): + output_dict[attr] = flatten_seq(output_dict[attr]) + + if not (callable(reduction_fn) or reduction_fn is None): + raise TypeError("`reduction_fn` must be callable or None") + + reduced = ( + reduction_fn(output_dict[attr]) + if reduction_fn is not None + else output_dict[attr] + ) + setattr(self, attr, reduced) + + self._is_synced = True + + def unsync(self) -> None: + """Restore cached local metric state.""" + if not self._is_synced: + raise RuntimeError( + "The Metric has already been un-synced. " + "This may be because the distributed backend is not initialized.", + ) + + if self._cache is None: + raise RuntimeError( + "The internal cache should exist to unsync the Metric. " + "This is likely because the distributed backend is not initialized.", + ) + + # if we synced, restore to cache so that we can continue to accumulate + # un-synced state + for attr, val in self._cache.items(): + setattr(self, attr, val) + + self._is_synced = False + self._cache = None + + def compute(self, *args: Any, **kwargs: Any) -> Any: + """Compute the final value of the metric from the state variables. + + Prior to calling the `_compute_metric` method, which should be implemented + by the subclass, this method ensures that the metric states are synced + across all processes and guards against potentially calling the `compute` + method before the state variables have been initialized. This method + also caches the result of the metric computation so that it can be returned + without recomputing the metric. + """ + if self._update_count == 0: + raise RuntimeError( + f"The `compute` method of {self.__class__.__name__} was called " + "before the `update` method. This will lead to errors, " + "as the state variables have not yet been initialized.", + ) + + if self._computed is not None: + return self._computed # return cached result + + self.sync() + value = self._compute_metric(*args, **kwargs) + self.unsync() + + self._computed = value + + return value + + def reset(self) -> None: + """Reset the metric state to default values.""" + for state_name, default_value in self._defaults.items(): + if apc.is_array_api_obj(default_value): + setattr( + self, + state_name, + apc.to_device(clone(default_value), self.device), + ) + elif isinstance(default_value, list): + setattr( + self, + state_name, + [ + apc.to_device(clone(array), self.device) + for array in default_value + ], + ) + else: + raise TypeError( + f"Expected the value of state `{state_name}` to be an array API " + "object or a list of array API objects. But got " + f"`{type(default_value)} instead.", + ) + + self._update_count = 0 + self._computed = None + self._cache = None + self._is_synced = False + + def clone(self) -> "Metric": + """Return a deep copy of the metric.""" + return deepcopy(self) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Update the global metric state and compute the metric value for a batch.""" + # global accumulation + self.update(*args, **kwargs) + update_count = self._update_count + cache = {attr: getattr(self, attr) for attr in self._defaults} + + # batch computation + self.reset() + self.update(*args, **kwargs) + batch_result = self.compute() + + # restore global state + for attr, value in cache.items(): + setattr(self, attr, value) + self._update_count = update_count + self._computed = None + self._is_synced = False + + return batch_result + + def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "Metric": + """Deepcopy the metric. + + This is needed because the metric may contain array API objects that don + not allow Array objects to be instantiated directly using the `__new__` + method. An example of this is the `Array` object in the `numpy.array_api` + namespace. + """ + cls = self.__class__ + obj_copy = cls.__new__(cls) + + if memo is None: + memo = {} + memo[id(self)] = obj_copy + + for k, v in self.__dict__.items(): + if k == "_cache" and v is not None: + _cache_ = apply_to_array_collection( + v, + lambda x: apc.to_device(clone(x), self.device), + ) + setattr(obj_copy, k, _cache_) + elif k == "_defaults" and v is not None: + _defaults_ = apply_to_array_collection( + v, + lambda x: apc.to_device(clone(x), self.device), + ) + setattr(obj_copy, k, _defaults_) + elif apc.is_array_api_obj(v): + setattr(obj_copy, k, apc.to_device(clone(v), self.device)) + else: + setattr(obj_copy, k, deepcopy(v, memo)) + return obj_copy + + def __repr__(self) -> str: + """Return a string representation of the metric.""" + return f"{self.__class__.__name__}" + + def __abs__(self) -> "Metric": + """Return the absolute value of the metric.""" + return OperatorMetric("__abs__", self, None) + + def __add__(self, other: Union[int, float, "Metric", Array]) -> "Metric": + """Add two metrics, a metric and a scalar or a metric and an array.""" + return OperatorMetric("__add__", self, other) + + def __and__(self, other: Union[bool, int, "Metric", Array]) -> "Metric": + """Compute the bitwise AND of a metric and another object.""" + return OperatorMetric("__and__", self, other) + + def __eq__(self, other: Union[bool, float, int, "Metric", Array]) -> Array: + """Compare the metric to another object for equality.""" + return OperatorMetric("__eq__", self, other) + + def __floordiv__(self, other: Union[int, float, "Metric", Array]) -> "Metric": + """Compute the floor division of a metric and another object.""" + return OperatorMetric("__floordiv__", self, other) + + def __ge__(self, other: Union[int, float, "Metric", Array]) -> "Metric": + """Compute the truth value of `self` >= `other`.""" + return OperatorMetric("__ge__", self, other) + + def __gt__(self, other: Union[int, float, "Metric", Array]) -> "Metric": + """Compute the truth value of `self` > `other`.""" + return OperatorMetric("__gt__", self, other) + + def __invert__(self) -> "Metric": + """Compute the bitwise NOT of the metric.""" + return OperatorMetric("__invert__", self, None) + + def __le__(self, other: Union[int, float, "Metric", Array]) -> "Metric": + """Compute the truth value of `self` <= `other`.""" + return OperatorMetric("__le__", self, other) + + def __lt__(self, other: Union[int, float, "Metric", Array]) -> "Metric": + """Compute the truth value of `self` < `other`.""" + return OperatorMetric("__lt__", self, other) + + def __matmul__(self, other: Union["Metric", Array]) -> "Metric": + """Matrix multiply two metrics or a metric and an array.""" + return OperatorMetric("__matmul__", self, other) + + def __mod__(self, other: Union[int, float, "Metric", Array]) -> "Metric": + """Compute the remainder when a metric is divided by another object.""" + return OperatorMetric("__mod__", self, other) + + def __mul__(self, other: Union[int, float, "Metric", Array]) -> "Metric": + """Multiply two metrics, a metric and a scalar or a metric and an array.""" + return OperatorMetric("__mul__", self, other) + + def __ne__(self, other: Union[bool, float, int, "Metric", Array]) -> Array: + """Compute the truth value of `self` != `other`.""" + return OperatorMetric("__ne__", self, other) + + def __neg__(self) -> "Metric": + """Negate every element of the metric result.""" + return OperatorMetric("__neg__", self, None) + + def __or__(self, other: Union[bool, int, "Metric", Array]) -> "Metric": + """Evaluate `self` | `other`.""" + return OperatorMetric("__or__", self, other) + + def __pos__(self) -> "Metric": + """Evaluate `+self` for every element of the metric result.""" + return OperatorMetric("__abs__", self, None) + + def __pow__(self, other: Union[int, float, "Metric", Array]) -> "Metric": + """Raise the metric to the power of another object.""" + return OperatorMetric("__pow__", self, other) + + def __sub__(self, other: Union[int, float, "Metric", Array]) -> "Metric": + """Subtract two metrics, a metric and a scalar or a metric and an array.""" + return OperatorMetric("__sub__", self, other) + + def __truediv__(self, other: Union[int, float, "Metric", Array]) -> "Metric": + """Divide two metrics, a metric and a scalar or a metric and an array.""" + return OperatorMetric("__truediv__", self, other) + + def __xor__(self, other: Union[bool, int, "Metric", Array]) -> "Metric": + """Evaluate `self` ^ `other`.""" + return OperatorMetric("__xor__", self, other) + + +def _validate_state_variable_type(name: str, value: Any) -> None: + if not apc.is_array_api_obj(value) and not ( + isinstance(value, list) and all(apc.is_array_api_obj(x) for x in value) + ): + raise TypeError( + f"Expected the value of state `{name}` to be an array API object or a " + f"list of array API-compatible objects. But got {type(value)} instead.", + ) + + +class OperatorMetric(Metric): + """A metric used to apply an operator to one or two metrics. + + Parameters + ---------- + operator : str + The operator to apply. + metric_a : bool, int, float, Metric, Array + The first metric to apply the operator to. + metric_b : bool, int, float, Metric, Array, optional + The second metric to apply the operator to. For unary operators, this + should be None. + + """ + + def __init__( + self, + operator: str, + metric_a: Union[bool, int, float, Metric, Array], + metric_b: Optional[Union[bool, int, float, Metric, Array]], + ) -> None: + """Initialize the metric.""" + super().__init__() + + self._op = operator + self.metric_a = metric_a.clone() if isinstance(metric_a, Metric) else metric_a + self.metric_b = metric_b.clone() if isinstance(metric_b, Metric) else metric_b + + def _update_state(self, *args: Any, **kwargs: Any) -> None: + """Update the state of each metric.""" + if isinstance(self.metric_a, Metric): + self.metric_a.update(*args, **kwargs) + + if isinstance(self.metric_b, Metric): + self.metric_b.update(*args, **kwargs) + + def _compute_metric(self) -> None: + """Not implemented and not required. + + The `compute` is overridden to call the `compute` method of `metric_a` + and/or `metric_b` and then apply the operator. + + """ + + def compute(self) -> Any: + """Compute the value of each metric, then apply the operator.""" + result_a = ( + self.metric_a.compute() + if isinstance(self.metric_a, Metric) + else self.metric_a + ) + + result_b = ( + self.metric_b.compute() + if isinstance(self.metric_b, Metric) + else self.metric_b + ) + + if self.metric_b is None: + return getattr(result_a, self._op)() + + return getattr(result_a, self._op)(result_b) + + def reset(self) -> None: + """Reset the state of each metric.""" + if isinstance(self.metric_a, Metric): + self.metric_a.reset() + + if isinstance(self.metric_b, Metric): + self.metric_b.reset() + + def to_device( + self, + device: str, + stream: Optional[Union[int, Any]] = None, + ) -> Metric: + """Move the state variables of the metric to the given device.""" + if isinstance(self.metric_a, Metric): + self.metric_a.to_device(device, stream=stream) + elif apc.is_array_api_obj(self.metric_a): + apc.to_device(self.metric_a, device, stream=stream) + + if isinstance(self.metric_b, Metric): + self.metric_b.to_device(device, stream=stream) + elif apc.is_array_api_obj(self.metric_b): + apc.to_device(self.metric_b, device, stream=stream) + + return self + + def __repr__(self) -> str: + """Return a string representation of the object.""" + _op_metrics = f"(\n {self._op}(\n {self.metric_a!r},\n {self.metric_b!r}\n )\n)" # noqa: E501 + return self.__class__.__name__ + _op_metrics diff --git a/cyclops/evaluate/metrics/experimental/utils/__init__.py b/cyclops/evaluate/metrics/experimental/utils/__init__.py new file mode 100644 index 000000000..a95369f15 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/utils/__init__.py @@ -0,0 +1,24 @@ +"""Utilities for the metrics module.""" +from cyclops.evaluate.metrics.experimental.utils.ops import ( + apply_to_array_collection, + bincount, + clone, + dim_zero_cat, + dim_zero_max, + dim_zero_mean, + dim_zero_min, + dim_zero_sum, + flatten, + flatten_seq, + moveaxis, + safe_divide, + sigmoid, + softmax, + squeeze_all, + to_int, +) +from cyclops.evaluate.metrics.experimental.utils.typing import Array +from cyclops.evaluate.metrics.experimental.utils.validation import ( + is_floating_point, + is_numeric, +) diff --git a/cyclops/evaluate/metrics/experimental/utils/ops.py b/cyclops/evaluate/metrics/experimental/utils/ops.py new file mode 100644 index 000000000..e1dd09195 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/utils/ops.py @@ -0,0 +1,809 @@ +"""Utility functions for performing operations on array-API-compatible objects.""" +import warnings +from collections import OrderedDict, defaultdict +from typing import ( + Any, + Callable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) + +import array_api_compat as apc +from array_api_compat.common._helpers import _is_numpy_array, _is_torch_array +from numpy.core.multiarray import normalize_axis_index # type: ignore + +from cyclops.evaluate.metrics.experimental.utils.typing import Array +from cyclops.evaluate.metrics.experimental.utils.validation import ( + _get_int_dtypes, + is_floating_point, +) + + +def apply_to_array_collection( # noqa: PLR0911 + data: Any, + func: Callable[..., Any], + *args: Any, + **kwargs: Any, +) -> Any: + """Apply a function to an array or collection of arrays. + + Parameters + ---------- + data : Any + An array or collection of arrays. + func : Callable[..., Any] + A function to be applied to `data`. + *args : Any + Positional arguments to be passed to the function. + **kwargs : Any + Keyword arguments to be passed to the function. + + Returns + ------- + Any + The result of applying the function to the input data. + + """ + is_namedtuple = ( + isinstance(data, tuple) + and hasattr(data, "_asdict") + and hasattr(data, "_fields") + ) + if apc.is_array_api_obj(data): + return func(data, *args, **kwargs) + if isinstance(data, list) and all(apc.is_array_api_obj(x) for x in data): + return [func(x, *args, **kwargs) for x in data] + if (isinstance(data, tuple) and not is_namedtuple) and all( + apc.is_array_api_obj(x) for x in data + ): + return tuple(func(x, *args, **kwargs) for x in data) + if isinstance(data, dict) and all(apc.is_array_api_obj(x) for x in data.values()): + return {k: func(v, *args, **kwargs) for k, v in data.items()} + + elem_type = type(data) + + if isinstance(data, Mapping): + out = [] + for k, v in data.items(): + out.append((k, apply_to_array_collection(v, func, *args, **kwargs))) + if isinstance(data, defaultdict): + return elem_type(data.default_factory, OrderedDict(out)) + return elem_type(OrderedDict(out)) + + is_sequence = isinstance(data, Sequence) and not isinstance(data, str) + if is_namedtuple or is_sequence: + out = [] + for d in data: + out.append(apply_to_array_collection(d, func, *args, **kwargs)) + return elem_type(*out) if is_namedtuple else elem_type(out) + return data + + +def bincount( + array: Array, + weights: Optional[Array] = None, + minlength: int = 0, +) -> Array: + """Count the number of occurrences of each value in an array of non-negative ints. + + Parameters + ---------- + array : Array + The input array. + weights : Array, optional, default=None + An array of weights, of the same shape as `array`. Each value in `array` + only contributes its associated weight towards the bin count (instead of 1). + If `weights` is None, all values in `array` are counted equally. + minlength : int, optional, default=0 + A minimum number of bins for the output array. If `minlength` is greater + than the largest value in `array`, then the output array will have + `minlength` bins. + + Returns + ------- + Array + The result of binning the input array. + + Raises + ------ + ValueError + If `array` is not a 1D array of non-negative integers, `weights` is not None + and `weights` and `array` do not have the same shape, or `minlength` is + negative. + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental.utils.ops import bincount + >>> x = np.asarray([0, 1, 1, 2, 2, 2]) + >>> bincount(x) + Array([1, 2, 3], dtype=int64) + >>> bincount(x, weights=np.asarray([0.5, 0.5, 0.5, 0.5, 0.5, 0.5])) + Array([0.5, 1. , 1.5], dtype=float64) + >>> bincount(x, minlength=5) + Array([1, 2, 3, 0, 0], dtype=int32) + + """ + xp = apc.array_namespace(array) + + if not (isinstance(minlength, int) and minlength >= 0): + raise ValueError( + "Expected `min_length` to be a non-negative integer. " + f"Got minlength={minlength}.", + ) + + if apc.size(array) == 0: + return xp.zeros(shape=(minlength,), dtype=xp.int64, device=apc.device(array)) + + if array.ndim != 1: + raise ValueError(f"Expected `array` to be a 1D array. Got {array.ndim}D array.") + + if array.dtype not in _get_int_dtypes(namespace=xp): + raise ValueError( + f"Expected `array` to be an integer array. Got {array.dtype} type.", + ) + + if xp.any(array < 0): + raise ValueError("`array` must contain only non-negative integers.") + + if weights is not None and array.shape != weights.shape: + raise ValueError( + "Expected `array` and `weights` to have the same shape. " + f"Got array.shape={array.shape} and weights.shape={weights.shape}.", + ) + + size = int(xp.max(array)) + 1 + size = max(size, int(minlength)) + device = apc.device(array) + + bincount = xp.astype( + array == xp.arange(size, device=device)[:, None], + weights.dtype if weights is not None else xp.int32, + ) + return xp.sum(bincount * (weights if weights is not None else 1), axis=1) + + +def clone(array: Array) -> Array: + """Create a copy of an array. + + Parameters + ---------- + array : Array + The input array. + + Returns + ------- + Array + A copy of the input array. + + Notes + ----- + This method is a temporary workaround for the lack of support for the `copy` + or `clone` method in the array API standard. The 2023 version of the standard + may include a copy method. See: https://github.com/data-apis/array-api/issues/495 + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental.utils.ops import clone + >>> x = np.zeros((1, 2, 3)) + >>> y = x + >>> y is x + True + >>> y = clone(x) + >>> y is x + False + + """ + xp = apc.array_namespace(array) + return xp.asarray(array, device=apc.device(array), copy=True) + + +def dim_zero_cat(x: Union[Array, List[Array], Tuple[Array]]) -> Array: + """Concatenation along the zero dimension.""" + if apc.is_array_api_obj(x) or not x: # covers empty list/tuple + return x + + if not isinstance(x, (list, tuple)): + raise TypeError( + "Expected `x` to be an Array or a list/tuple of Arrays. " + f"Got {type(x)} instead.", + ) + + xp = apc.array_namespace(x[0]) + x_ = [] + for el in x: + if not apc.is_array_api_obj(el): + raise TypeError( + "Expected `x` to be an Array or a list/tuple of Arrays. " + f"Got a list/tuple containing a {type(el)} instead.", + ) + if apc.size(el) == 1 and el.ndim == 0: + x_.append(xp.expand_dims(el, axis=0)) + else: + x_.append(el) + + if not x_: # empty list + raise ValueError("No samples to concatenate") + return xp.concat(x_, axis=0) + + +def dim_zero_max(x: Array) -> Array: + """Max along the zero dimension.""" + xp = apc.array_namespace(x) + return xp.max(x, axis=0) + + +def dim_zero_mean(x: Array) -> Array: + """Average along the zero dimension.""" + xp = apc.array_namespace(x) + x = x if is_floating_point(x) else xp.astype(x, xp.float32) + return xp.mean(x, axis=0) + + +def dim_zero_min(x: Array) -> Array: + """Min along the zero dimension.""" + xp = apc.array_namespace(x) + return xp.min(x, axis=0) + + +def dim_zero_sum(x: Array) -> Array: + """Summation along the zero dimension.""" + xp = apc.array_namespace(x) + return xp.sum(x, axis=0) + + +def flatten(array: Array) -> Array: + """Flatten an array. + + Parameters + ---------- + array : Array + The input array. + + Returns + ------- + Array + The flattened array. + + Notes + ----- + This method is a temporary workaround for the lack of support for the `flatten` + method in the array API standard. + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental.utils.ops import flatten + >>> x = np.zeros((1, 2, 3)) + >>> x.shape + (1, 2, 3) + >>> flatten(x).shape + (6,) + """ + xp = apc.array_namespace(array) + return xp.asarray( + xp.reshape(array, shape=(-1,)), + device=apc.device(array), + copy=True, + ) + + +def flatten_seq(inp: Sequence[Any]) -> List[Any]: + """Flatten a nested sequence into a single list. + + Parameters + ---------- + inp : Sequence + The input sequence. + + Returns + ------- + List[Any] + The flattened sequence. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.utils.ops import flatten_seq + >>> x = [[1, 2, 3], [4, 5, 6]] + >>> flatten_seq(x) + [1, 2, 3, 4, 5, 6] + + """ + if not isinstance(inp, Sequence): + raise TypeError("Input must be a Sequence") + + if len(inp) == 0: + return [] + + if isinstance(inp, str) and len(inp) == 1: + return [inp] + + result = [] + for sublist in inp: + if isinstance(sublist, Sequence): + result.extend(flatten_seq(sublist)) + else: + result.append(sublist) + return result + + +def moveaxis( + array: Array, + source: Union[int, Tuple[int]], + destination: Union[int, Tuple[int]], +) -> Array: + """Move given array axes to new positions. + + Parameters + ---------- + array : Array + The input array. + source : int or Tuple[int] + Original positions of the axes to move. These must be unique. + destination : int or Tuple[int] + Destination positions for each of the original axes. These must also be + unique. + + Returns + ------- + Array + Array with moved axes. This array is a view of the input array. + + Raises + ------ + ValueError + If the source and destination axes are not unique or if the number of + elements in `source` and `destination` are not equal. + + Notes + ----- + A similar method has been added to the array API standard in v2022.12. See: + https://data-apis.org/array-api/draft/API_specification/generated/array_api.moveaxis.html + The `array_api_compat` library does not yet support that version of the standard, + so we define this method here for now. + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental.utils.ops import moveaxis + >>> x = np.zeros((1, 2, 3)) + >>> moveaxis(x, 0, 1).shape + (2, 1, 3) + + """ + if isinstance(source, int): + source = (source,) + if isinstance(destination, int): + destination = (destination,) + + if (isinstance(source, tuple) and not isinstance(destination, tuple)) or ( + isinstance(destination, tuple) and not isinstance(source, tuple) + ): + raise ValueError( + "`source` and `destination` must both be tuples or both be integers", + ) + + if len(set(source)) != len(source) or len(set(destination)) != len(destination): + raise ValueError("`source` and `destination` must not contain duplicate values") + + if len(source) != len(destination): + raise ValueError( + "`source` and `destination` must have the same number of elements", + ) + + xp = apc.array_namespace(array) + num_dims = array.ndim + if ( + max(source) >= num_dims + or max(destination) >= num_dims + or abs(min(source)) > num_dims + or abs(min(destination)) > num_dims + ): + raise ValueError( + "Values in `source` and `destination` are out of bounds for `array` " + f"with {num_dims} dimensions", + ) + + # normalize negative indices + src_ = tuple([src % num_dims for src in source]) + dest_ = tuple([dest % num_dims for dest in destination]) + + order = [n for n in range(num_dims) if n not in src_] + + for src, dest in sorted(zip(dest_, src_)): + order.insert(src, dest) + + return xp.permute_dims(array, order) + + +def remove_ignore_index( + target: Array, + preds: Array, + ignore_index: Optional[Union[Tuple[int, ...], int]], +) -> Tuple[Array, Array]: + """Remove the samples at the indices where target values match `ignore_index`. + + Parameters + ---------- + target : Array + The target array. + preds : Array + The predictions array. + ignore_index : int or Tuple[int], optional, default=None + The index or indices to ignore. If None, no indices will be ignored. + + Returns + ------- + Tuple[Array, Array] + The `target` and `preds` arrays with the samples at the indices where target + values match `ignore_index` removed. + """ + if ignore_index is None: + return target, preds + + if not ( + isinstance(ignore_index, int) + or ( + isinstance(ignore_index, tuple) + and all(isinstance(x, int) for x in ignore_index) + ) + ): + raise TypeError( + "Expected `ignore_index` to be an integer or a tuple of integers. " + f"Got {type(ignore_index)} instead.", + ) + + xp = apc.array_namespace(target, preds) + + if isinstance(ignore_index, int): + mask = target == ignore_index + else: + mask = xp.zeros_like(target, dtype=xp.bool, device=apc.device(target)) + for index in ignore_index: + mask = xp.logical_or(mask, target == index) + + return clone(target[~mask]), clone(preds[~mask]) + + +def safe_divide(numerator: Array, denominator: Array) -> Array: + """Divide two arrays and return zero if denominator is zero. + + Parameters + ---------- + numerator : Array + The numerator array. + denominator : Array + The denominator array. + + Returns + ------- + quotient : Array + The quotient of the two arrays. + + Raises + ------ + ValueError + If `numerator` and `denominator` do not have the same shape. + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental.utils.ops import safe_divide + >>> x = np.asarray([1.1, 2.0, 3.0]) + >>> y = np.asarray([1.1, 0.0, 3.0]) + >>> safe_divide(x, y) + Array([1., 0., 1.], dtype=float64) + """ + xp = apc.array_namespace(numerator, denominator) + + numerator = ( + numerator if is_floating_point(numerator) else xp.astype(numerator, xp.float32) + ) + denominator = ( + denominator + if is_floating_point(denominator) + else xp.astype(denominator, xp.float32) + ) + + return xp.where( + denominator == 0, + xp.asarray(0, dtype=xp.float32, device=apc.device(numerator)), + numerator / denominator, + ) + + +def sigmoid(array: Array) -> Array: + """Compute the sigmoid of an array. + + Parameters + ---------- + array : Array + The input array. + + Returns + ------- + Array + The sigmoid of the input array. + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental.utils.ops import sigmoid + >>> x = np.asarray([1.1, 2.0, 3.0]) + >>> sigmoid(x) + Array([0.75026011, 0.88079708, 0.95257413], dtype=float64) + + """ + xp = apc.array_namespace(array) + if apc.size(array) == 0: + return xp.asarray([], dtype=xp.float32, device=apc.device(array)) + + array = array if is_floating_point(array) else xp.astype(array, xp.float32) + + exp_array = xp.exp(array) + return xp.where( + array >= 0, + 1 / (1 + xp.exp(-array)), + exp_array / (1 + exp_array), + ) + + +def softmax(array: Array, axis: Optional[int] = None) -> Array: + """Compute the softmax of an array. + + Parameters + ---------- + array : Array + The input array. + axis : int, optional, default=None + The axis along which to compute the softmax. If None, the softmax will be + computed over all elements in the array. + + Returns + ------- + Array + The softmax of the input array. + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental.utils.ops import softmax + >>> x = np.asarray([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]]) + >>> softmax(x, axis=1) + Array([[0.09856589, 0.24243297, 0.65900114], + [0.62853172, 0.2312239 , 0.14024438]], dtype=float64) + + """ + xp = apc.array_namespace(array) + + x_max = xp.max(array, axis=axis, keepdims=True) + exp_x_shifted = xp.exp(array - x_max) + + return safe_divide(exp_x_shifted, xp.sum(exp_x_shifted, axis=axis, keepdims=True)) + + +def squeeze_all(array: Array) -> Array: + """Remove all singleton dimensions from an array. + + Parameters + ---------- + array : Array + An array to squeeze. + + Returns + ------- + Array + The squeezed array. + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental.utils.ops import squeeze_all + >>> x = np.zeros((1, 2, 1, 3, 1, 4)) + >>> x.shape + (1, 2, 1, 3, 1, 4) + >>> squeeze_all(x).shape + (2, 3, 4) + """ + xp = apc.array_namespace(array) + singleton_axes = tuple(i for i in range(array.ndim) if array.shape[i] == 1) + if len(singleton_axes) == 0: + return array + + return xp.squeeze(array, axis=singleton_axes) + + +def to_int(array: Array) -> Array: + """Convert the data type of an array to a 64-bit integer type. + + Parameters + ---------- + array : Array + The input array. + + Returns + ------- + Array + The input array converted to an integer array. + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental.utils.ops import to_int + >>> x = np.asarray([1.1, 2.0, 3.0]) + >>> to_int(x) + Array([1, 2, 3], dtype=int32) + + """ + xp = apc.array_namespace(array) + return xp.astype(array, xp.int64, copy=False) + + +def _select_topk( # noqa: PLR0912 + scores: Array, + top_k: int = 1, + axis: int = -1, +) -> Array: + """Compute a one-hot array indicating the top-k scores along an axis. + + Parameters + ---------- + scores : Array + An array of scores of shape `[..., C, ...]` where `C` is in the axis `axis`. + top_k : int, optional, default=1 + The number of top scores to select. + axis : int, optional, default=-1 + The axis along which to select the top-k scores. + + Returns + ------- + Array + A one-hot array indicating the top-k scores along an axis. + + Raises + ------ + ValueError + If `top_k` is not positive, `axis` is greater than or equal to the number of + dimensions in `scores`, or `top_k` is greater than the size of `scores` along + `axis`. + + Warnings + -------- + This method may be slow or memory-intensive for some array namespaces. See + `Notes` for more details. + + Notes + ----- + This method can be slow or memory-intensive for some array namespaces due to + several factors: + 1. The use of `argsort` to fully sort the array as opposed to a partial sort. + However, an upcoming version of the array API will include a `topk` method + that will be more efficient. See https://github.com/data-apis/array-api/issues/629 + 2. The lack of support for advanced indexing in the array API standard. + 3. The lack of support for methods that set elements along an axis, like + `np.put_along_axis` or `torch.scatter`. + + Examples + -------- + >>> import numpy.array_api as np + >>> from cyclops.evaluate.metrics.experimental.utils.ops import _select_topk + >>> x = np.asarray([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]]) + >>> _select_topk(x, top_k=2, axis=1) + Array([[0, 1, 1], + [1, 1, 0]], dtype=int32) + """ + xp = apc.array_namespace(scores) + if top_k <= 0: + raise ValueError("top_k must be positive") + if axis >= scores.ndim: + raise ValueError("axis must be less than scores.ndim") + if scores.ndim == 0 and top_k != 1: + raise ValueError("top_k must be 1 for 0-dim scores") + if top_k > scores.shape[axis]: + raise ValueError("top_k must be less than or equal to scores.shape[axis]") + + if top_k == 1: # more efficient than argsort for top_k=1 + topk_indices = xp.argmax(scores, axis=axis, keepdims=True) + else: + topk_indices = xp.argsort(scores, axis=axis, descending=True, stable=False) + slice_indices = [slice(None)] * scores.ndim + slice_indices[axis] = slice(None, top_k) + topk_indices = topk_indices[tuple(slice_indices)] + + zeros = xp.zeros_like(scores, dtype=xp.int32, device=apc.device(scores)) + + if _is_torch_array(scores): + return zeros.scatter(axis, topk_indices, 1) + if _is_numpy_array(scores): + return xp.put_along_axis(zeros, topk_indices, 1, axis) + + axis = normalize_axis_index(axis, scores.ndim) + + # --- begin code copied from numpy --- + # from https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/shape_base.py#L27 + shape_ones = (1,) * topk_indices.ndim + dest_dims = list(range(axis)) + [None] + list(range(axis + 1, topk_indices.ndim)) + + # build a fancy index, consisting of orthogonal aranges, with the + # requested index inserted at the right location + fancy_index = [] + for dim, n in zip(dest_dims, scores.shape): + if dim is None: + fancy_index.append(topk_indices) + else: + ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1 :] + fancy_index.append(xp.reshape(xp.arange(n), shape=(ind_shape))) + # --- end of code copied from numpy --- + + indices = xp.broadcast_arrays(*fancy_index) + indices = xp.stack(indices, axis=-1) + indices = xp.reshape(indices, shape=(-1, indices.shape[-1])) + + try: # advanced indexing + zeros[tuple(indices.T)] = 1 + except IndexError: + warnings.warn( + "The `select_topk` method is slow and memory-intensive for the array " + f"namespace '{xp.__name__}' and will be deprecated in a future release." + "Consider writing a custom implementation for your array namespace " + "using operations that are more efficient for your array namespace.", + category=UserWarning, + stacklevel=1, + ) + for idx in range(indices.shape[0]): + zeros[tuple(indices[idx, ...])] = 1 + + return zeros + + +def _to_one_hot(array: Array, num_classes: Optional[int] = None) -> Array: + """Convert an array of integer labels to a one-hot encoded array. + + Parameters + ---------- + array : Array + An array of integer labels. + num_classes : int, optional, default=None + The number of classes. If not provided, the number of classes will be inferred + from the array. + + Returns + ------- + Array + A one-hot encoded representation of `array`. + + Warnings + -------- + This method can be slow or memory-intensive for some array namespaces due to + the lack of support for advanced indexing in the array API standard. + + """ + xp = apc.array_namespace(array) + if array.dtype not in _get_int_dtypes(namespace=xp): + array = to_int(array) + input_shape = array.shape + array = flatten(array) + + if num_classes is None: + unique_values = xp.unique_values(array) + num_classes = int(apc.size(unique_values)) + + device = apc.device(array) + + try: # advanced indexing + return xp.eye(num_classes, dtype=xp.int64, device=device)[array] + except IndexError: + n = array.shape[0] + categorical = xp.zeros((n, num_classes), dtype=xp.int64, device=device) + + indices = xp.stack((xp.arange(n, device=device), array), axis=-1) + for idx in range(indices.shape[0]): + categorical[tuple(indices[idx, ...])] = 1 + output_shape = input_shape + (num_classes,) + + return xp.reshape(categorical, output_shape) diff --git a/cyclops/evaluate/metrics/experimental/utils/typing.py b/cyclops/evaluate/metrics/experimental/utils/typing.py new file mode 100644 index 000000000..90e501ef0 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/utils/typing.py @@ -0,0 +1,27 @@ +"""Utilities for array-API compatibility.""" +from typing import TYPE_CHECKING, Any, Optional, Protocol, Union + +import numpy.typing as npt +import torch + +from cyclops.utils.optional import import_optional_module + + +class _ArrayAPICompliantObject(Protocol): + """Protocol for objects that have a __array_namespace__ attribute.""" + + def __array_namespace__(self, api_version: Optional[str] = None) -> Any: + """Return an array-API-compatible namespace.""" + ... + + +_supported_array_types = (npt.NDArray, torch.Tensor, _ArrayAPICompliantObject) + +cp = import_optional_module("cupy", error="ignore") +if cp is not None: + _supported_array_types += (cp.ndarray,) # type: ignore[assignment] + +if TYPE_CHECKING: # noqa: SIM108 + Array = Any +else: + Array = Union[_supported_array_types] # type: ignore diff --git a/cyclops/evaluate/metrics/experimental/utils/validation.py b/cyclops/evaluate/metrics/experimental/utils/validation.py new file mode 100644 index 000000000..e81b6c312 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/utils/validation.py @@ -0,0 +1,155 @@ +"""Utility functions for performing common input validations.""" +from typing import Any, List, Literal + +import array_api_compat as apc + +from cyclops.evaluate.metrics.experimental.utils.typing import Array + + +def is_floating_point(array: Array) -> bool: + """Return `True` if the array has a floating-point datatype. + + Floating-point datatypes include: + - `float32` + - `float64` + - `float16` + - `bfloat16` + + """ + xp = apc.array_namespace(array) + float_dtypes = _get_float_dtypes(xp) + + return array.dtype in float_dtypes + + +def is_numeric(*arrays: Array) -> bool: + """Check if given arrays have numeric datatype. + + Numeric datatypes include: + - `float32` + - `float64` + - `float16` + - `bfloat16` + - `int8` + - `int16` + - `int32` + - `int64` + - `uint8` + - `uint16` + - `uint32` + - `uint64` + + Parameters + ---------- + arrays : Array + The arrays to check. + + Returns + ------- + bool + `True` if all of the arrays have a numeric datatype. `False` otherwise. + + """ + xp = apc.array_namespace(*arrays) + numeric_dtypes = _get_int_dtypes(xp) + _get_float_dtypes(xp) + + return all(array.dtype in numeric_dtypes for array in arrays) + + +def _basic_input_array_checks(target: Array, preds: Array) -> None: + """Perform basic validation of `target` and `preds`.""" + if not apc.is_array_api_obj(target): + raise ValueError( + "Expected `target` to be an array-API-compatible object, but got " + f"{type(target)}.", + ) + + if not apc.is_array_api_obj(preds): + raise ValueError( + "Expected `preds` to be an array-API-compatible object, but got " + f"{type(preds)}.", + ) + + if _is_empty(target) or _is_empty(preds): + raise ValueError("Expected `target` and `preds` to be non-empty arrays.") + + if not is_numeric(target, preds): + raise ValueError( + "Expected `target` and `preds` to be numeric arrays, but got " + f"{target.dtype} and {preds.dtype}, respectively.", + ) + + +def _check_average_arg(average: Literal["micro", "macro", "weighted", None]) -> None: + """Validate the `average` argument.""" + if average not in ["micro", "macro", "weighted", None]: + raise ValueError( + f"Argument average has to be one of 'micro', 'macro', 'weighted', " + f"or None, got {average}.", + ) + + +def _check_same_shape(target: Array, preds: Array) -> None: + """Check if `target` and `preds` have the same shape.""" + if target.shape != preds.shape: + raise ValueError( + "Expected `target` and `preds` to have the same shape, but got `target` " + f"with shape={target.shape} and `preds` with shape={preds.shape}.", + ) + + +def _get_float_dtypes(namespace: Any) -> List[Any]: + """Return a list of floating-point dtypes. + + Notes + ----- + The integer types `float16` and `bfloat16` are not defined in the API, but + are included here as they are increasingly common in deep learning frameworks. + + """ + float_dtypes = [namespace.float32, namespace.float64] + if hasattr(namespace, "float16"): + float_dtypes.append(namespace.float16) + if hasattr(namespace, "bfloat16"): + float_dtypes.append(namespace.bfloat16) + + return float_dtypes + + +def _get_int_dtypes(namespace: Any) -> List[Any]: + """Return a list of integer dtypes. + + Notes + ----- + The integer types `uint16`, `uint32` and `uint64` are defined in the API + standard but not in PyTorch. Although, PyTorch supports quantized integer + types like `qint8` and `quint8`, but we omit them here because they are not + part of the array API standard. + The 2022.12 version of the array API standard includes a `isdtype` method + that will eliminate the need for this function. The `array_api_compat` + package currently (Nov. 2023) supports only the 2021.12 version of the + standard, so we need to define this function for now. + + """ + int_dtypes = [ + namespace.int8, + namespace.int16, + namespace.int32, + namespace.int64, + namespace.uint8, + ] + + if hasattr(namespace, "uint16"): + int_dtypes.append(namespace.uint16) + if hasattr(namespace, "uint32"): + int_dtypes.append(namespace.uint32) + if hasattr(namespace, "uint64"): + int_dtypes.append(namespace.uint64) + + return int_dtypes + + +def _is_empty(array: Array) -> bool: + """Return `True` if the array is empty.""" + numel = apc.size(array) + return numel is not None and numel == 0 diff --git a/cyclops/utils/optional.py b/cyclops/utils/optional.py new file mode 100644 index 000000000..a52ffbb7b --- /dev/null +++ b/cyclops/utils/optional.py @@ -0,0 +1,56 @@ +"""Utilities for handling optional dependencies.""" +import importlib +import importlib.util +import logging +import warnings +from types import ModuleType +from typing import Literal, Optional + +from cyclops.utils.log import setup_logging + + +LOGGER = logging.getLogger(__name__) +setup_logging(print_level="WARN", logger=LOGGER) + + +def import_optional_module( + name: str, + error: Literal["raise", "warn", "ignore"] = "raise", +) -> Optional[ModuleType]: + """Import an optional module. + + Parameters + ---------- + name : str + The name of the module to import. + error : ErrorOption, optional + How to handle errors. One of: + - "raise": raise an error if the module cannot be imported. + - "warn": raise a warning if the module cannot be imported. + - "ignore": ignore the missing module and return `None`. + + Returns + ------- + Optional[ModuleType] + The imported module, if it exists. Otherwise, `None`. + + """ + if error not in ("raise", "warn", "ignore"): + raise ValueError( + "Expected `error` to be one of 'raise', 'warn, or 'ignore', " + f"but got {error}.", + ) + + try: + return importlib.import_module(name) + except ModuleNotFoundError as exc: + msg = ( + f"Missing optional dependency '{name}'. " + f"Use pip or conda to install {name}." + ) + if error == "raise": + raise type(exc)(msg) from None + if error == "warn": + warnings.warn(msg, category=ImportWarning, stacklevel=2) + + return None diff --git a/poetry.lock b/poetry.lock index 4c6bdbd1c..1888780cd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -344,6 +344,21 @@ cffi = ">=1.0.1" dev = ["cogapp", "pre-commit", "pytest", "wheel"] tests = ["pytest"] +[[package]] +name = "array-api-compat" +version = "1.4" +description = "A wrapper around NumPy and other array libraries to make them compatible with the Array API standard" +optional = false +python-versions = ">=3.8" +files = [ + {file = "array_api_compat-1.4-py3-none-any.whl", hash = "sha256:326383f5423716922724199988084250ca41074c7938bed6f9cea95d8d70f833"}, + {file = "array_api_compat-1.4.tar.gz", hash = "sha256:d49f00eb66b436cf3a6026d6f43c115d3e058a3a9936536b0bac33dd470e8b4d"}, +] + +[package.extras] +cupy = ["cupy"] +numpy = ["numpy"] + [[package]] name = "arrow" version = "1.3.0" @@ -916,74 +931,66 @@ srsly = ">=2.4.0,<3.0.0" [[package]] name = "contourpy" -version = "1.1.1" +version = "1.2.0" description = "Python library for calculating contours of 2D quadrilateral grids" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "contourpy-1.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:46e24f5412c948d81736509377e255f6040e94216bf1a9b5ea1eaa9d29f6ec1b"}, - {file = "contourpy-1.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e48694d6a9c5a26ee85b10130c77a011a4fedf50a7279fa0bdaf44bafb4299d"}, - {file = "contourpy-1.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a66045af6cf00e19d02191ab578a50cb93b2028c3eefed999793698e9ea768ae"}, - {file = "contourpy-1.1.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4ebf42695f75ee1a952f98ce9775c873e4971732a87334b099dde90b6af6a916"}, - {file = "contourpy-1.1.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6aec19457617ef468ff091669cca01fa7ea557b12b59a7908b9474bb9674cf0"}, - {file = "contourpy-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:462c59914dc6d81e0b11f37e560b8a7c2dbab6aca4f38be31519d442d6cde1a1"}, - {file = "contourpy-1.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6d0a8efc258659edc5299f9ef32d8d81de8b53b45d67bf4bfa3067f31366764d"}, - {file = "contourpy-1.1.1-cp310-cp310-win32.whl", hash = "sha256:d6ab42f223e58b7dac1bb0af32194a7b9311065583cc75ff59dcf301afd8a431"}, - {file = "contourpy-1.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:549174b0713d49871c6dee90a4b499d3f12f5e5f69641cd23c50a4542e2ca1eb"}, - {file = "contourpy-1.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:407d864db716a067cc696d61fa1ef6637fedf03606e8417fe2aeed20a061e6b2"}, - {file = "contourpy-1.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dfe80c017973e6a4c367e037cb31601044dd55e6bfacd57370674867d15a899b"}, - {file = "contourpy-1.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e30aaf2b8a2bac57eb7e1650df1b3a4130e8d0c66fc2f861039d507a11760e1b"}, - {file = "contourpy-1.1.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3de23ca4f381c3770dee6d10ead6fff524d540c0f662e763ad1530bde5112532"}, - {file = "contourpy-1.1.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:566f0e41df06dfef2431defcfaa155f0acfa1ca4acbf8fd80895b1e7e2ada40e"}, - {file = "contourpy-1.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b04c2f0adaf255bf756cf08ebef1be132d3c7a06fe6f9877d55640c5e60c72c5"}, - {file = "contourpy-1.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d0c188ae66b772d9d61d43c6030500344c13e3f73a00d1dc241da896f379bb62"}, - {file = "contourpy-1.1.1-cp311-cp311-win32.whl", hash = "sha256:0683e1ae20dc038075d92e0e0148f09ffcefab120e57f6b4c9c0f477ec171f33"}, - {file = "contourpy-1.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:8636cd2fc5da0fb102a2504fa2c4bea3cbc149533b345d72cdf0e7a924decc45"}, - {file = "contourpy-1.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:560f1d68a33e89c62da5da4077ba98137a5e4d3a271b29f2f195d0fba2adcb6a"}, - {file = "contourpy-1.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:24216552104ae8f3b34120ef84825400b16eb6133af2e27a190fdc13529f023e"}, - {file = "contourpy-1.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56de98a2fb23025882a18b60c7f0ea2d2d70bbbcfcf878f9067234b1c4818442"}, - {file = "contourpy-1.1.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:07d6f11dfaf80a84c97f1a5ba50d129d9303c5b4206f776e94037332e298dda8"}, - {file = "contourpy-1.1.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f1eaac5257a8f8a047248d60e8f9315c6cff58f7803971170d952555ef6344a7"}, - {file = "contourpy-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19557fa407e70f20bfaba7d55b4d97b14f9480856c4fb65812e8a05fe1c6f9bf"}, - {file = "contourpy-1.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:081f3c0880712e40effc5f4c3b08feca6d064cb8cfbb372ca548105b86fd6c3d"}, - {file = "contourpy-1.1.1-cp312-cp312-win32.whl", hash = "sha256:059c3d2a94b930f4dafe8105bcdc1b21de99b30b51b5bce74c753686de858cb6"}, - {file = "contourpy-1.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:f44d78b61740e4e8c71db1cf1fd56d9050a4747681c59ec1094750a658ceb970"}, - {file = "contourpy-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:70e5a10f8093d228bb2b552beeb318b8928b8a94763ef03b858ef3612b29395d"}, - {file = "contourpy-1.1.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8394e652925a18ef0091115e3cc191fef350ab6dc3cc417f06da66bf98071ae9"}, - {file = "contourpy-1.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5bd5680f844c3ff0008523a71949a3ff5e4953eb7701b28760805bc9bcff217"}, - {file = "contourpy-1.1.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66544f853bfa85c0d07a68f6c648b2ec81dafd30f272565c37ab47a33b220684"}, - {file = "contourpy-1.1.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0c02b75acfea5cab07585d25069207e478d12309557f90a61b5a3b4f77f46ce"}, - {file = "contourpy-1.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41339b24471c58dc1499e56783fedc1afa4bb018bcd035cfb0ee2ad2a7501ef8"}, - {file = "contourpy-1.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f29fb0b3f1217dfe9362ec55440d0743fe868497359f2cf93293f4b2701b8251"}, - {file = "contourpy-1.1.1-cp38-cp38-win32.whl", hash = "sha256:f9dc7f933975367251c1b34da882c4f0e0b2e24bb35dc906d2f598a40b72bfc7"}, - {file = "contourpy-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:498e53573e8b94b1caeb9e62d7c2d053c263ebb6aa259c81050766beb50ff8d9"}, - {file = "contourpy-1.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ba42e3810999a0ddd0439e6e5dbf6d034055cdc72b7c5c839f37a7c274cb4eba"}, - {file = "contourpy-1.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6c06e4c6e234fcc65435223c7b2a90f286b7f1b2733058bdf1345d218cc59e34"}, - {file = "contourpy-1.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca6fab080484e419528e98624fb5c4282148b847e3602dc8dbe0cb0669469887"}, - {file = "contourpy-1.1.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:93df44ab351119d14cd1e6b52a5063d3336f0754b72736cc63db59307dabb718"}, - {file = "contourpy-1.1.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eafbef886566dc1047d7b3d4b14db0d5b7deb99638d8e1be4e23a7c7ac59ff0f"}, - {file = "contourpy-1.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efe0fab26d598e1ec07d72cf03eaeeba8e42b4ecf6b9ccb5a356fde60ff08b85"}, - {file = "contourpy-1.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f08e469821a5e4751c97fcd34bcb586bc243c39c2e39321822060ba902eac49e"}, - {file = "contourpy-1.1.1-cp39-cp39-win32.whl", hash = "sha256:bfc8a5e9238232a45ebc5cb3bfee71f1167064c8d382cadd6076f0d51cff1da0"}, - {file = "contourpy-1.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:c84fdf3da00c2827d634de4fcf17e3e067490c4aea82833625c4c8e6cdea0887"}, - {file = "contourpy-1.1.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:229a25f68046c5cf8067d6d6351c8b99e40da11b04d8416bf8d2b1d75922521e"}, - {file = "contourpy-1.1.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a10dab5ea1bd4401c9483450b5b0ba5416be799bbd50fc7a6cc5e2a15e03e8a3"}, - {file = "contourpy-1.1.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4f9147051cb8fdb29a51dc2482d792b3b23e50f8f57e3720ca2e3d438b7adf23"}, - {file = "contourpy-1.1.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a75cc163a5f4531a256f2c523bd80db509a49fc23721b36dd1ef2f60ff41c3cb"}, - {file = "contourpy-1.1.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b53d5769aa1f2d4ea407c65f2d1d08002952fac1d9e9d307aa2e1023554a163"}, - {file = "contourpy-1.1.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:11b836b7dbfb74e049c302bbf74b4b8f6cb9d0b6ca1bf86cfa8ba144aedadd9c"}, - {file = "contourpy-1.1.1.tar.gz", hash = "sha256:96ba37c2e24b7212a77da85004c38e7c4d155d3e72a45eeaf22c1f03f607e8ab"}, -] - -[package.dependencies] -numpy = {version = ">=1.16,<2.0", markers = "python_version <= \"3.11\""} + {file = "contourpy-1.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0274c1cb63625972c0c007ab14dd9ba9e199c36ae1a231ce45d725cbcbfd10a8"}, + {file = "contourpy-1.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ab459a1cbbf18e8698399c595a01f6dcc5c138220ca3ea9e7e6126232d102bb4"}, + {file = "contourpy-1.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6fdd887f17c2f4572ce548461e4f96396681212d858cae7bd52ba3310bc6f00f"}, + {file = "contourpy-1.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d16edfc3fc09968e09ddffada434b3bf989bf4911535e04eada58469873e28e"}, + {file = "contourpy-1.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c203f617abc0dde5792beb586f827021069fb6d403d7f4d5c2b543d87edceb9"}, + {file = "contourpy-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b69303ceb2e4d4f146bf82fda78891ef7bcd80c41bf16bfca3d0d7eb545448aa"}, + {file = "contourpy-1.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:884c3f9d42d7218304bc74a8a7693d172685c84bd7ab2bab1ee567b769696df9"}, + {file = "contourpy-1.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4a1b1208102be6e851f20066bf0e7a96b7d48a07c9b0cfe6d0d4545c2f6cadab"}, + {file = "contourpy-1.2.0-cp310-cp310-win32.whl", hash = "sha256:34b9071c040d6fe45d9826cbbe3727d20d83f1b6110d219b83eb0e2a01d79488"}, + {file = "contourpy-1.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:bd2f1ae63998da104f16a8b788f685e55d65760cd1929518fd94cd682bf03e41"}, + {file = "contourpy-1.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dd10c26b4eadae44783c45ad6655220426f971c61d9b239e6f7b16d5cdaaa727"}, + {file = "contourpy-1.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5c6b28956b7b232ae801406e529ad7b350d3f09a4fde958dfdf3c0520cdde0dd"}, + {file = "contourpy-1.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ebeac59e9e1eb4b84940d076d9f9a6cec0064e241818bcb6e32124cc5c3e377a"}, + {file = "contourpy-1.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:139d8d2e1c1dd52d78682f505e980f592ba53c9f73bd6be102233e358b401063"}, + {file = "contourpy-1.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1e9dc350fb4c58adc64df3e0703ab076f60aac06e67d48b3848c23647ae4310e"}, + {file = "contourpy-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18fc2b4ed8e4a8fe849d18dce4bd3c7ea637758c6343a1f2bae1e9bd4c9f4686"}, + {file = "contourpy-1.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:16a7380e943a6d52472096cb7ad5264ecee36ed60888e2a3d3814991a0107286"}, + {file = "contourpy-1.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8d8faf05be5ec8e02a4d86f616fc2a0322ff4a4ce26c0f09d9f7fb5330a35c95"}, + {file = "contourpy-1.2.0-cp311-cp311-win32.whl", hash = "sha256:67b7f17679fa62ec82b7e3e611c43a016b887bd64fb933b3ae8638583006c6d6"}, + {file = "contourpy-1.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:99ad97258985328b4f207a5e777c1b44a83bfe7cf1f87b99f9c11d4ee477c4de"}, + {file = "contourpy-1.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:575bcaf957a25d1194903a10bc9f316c136c19f24e0985a2b9b5608bdf5dbfe0"}, + {file = "contourpy-1.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9e6c93b5b2dbcedad20a2f18ec22cae47da0d705d454308063421a3b290d9ea4"}, + {file = "contourpy-1.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:464b423bc2a009088f19bdf1f232299e8b6917963e2b7e1d277da5041f33a779"}, + {file = "contourpy-1.2.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:68ce4788b7d93e47f84edd3f1f95acdcd142ae60bc0e5493bfd120683d2d4316"}, + {file = "contourpy-1.2.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d7d1f8871998cdff5d2ff6a087e5e1780139abe2838e85b0b46b7ae6cc25399"}, + {file = "contourpy-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e739530c662a8d6d42c37c2ed52a6f0932c2d4a3e8c1f90692ad0ce1274abe0"}, + {file = "contourpy-1.2.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:247b9d16535acaa766d03037d8e8fb20866d054d3c7fbf6fd1f993f11fc60ca0"}, + {file = "contourpy-1.2.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:461e3ae84cd90b30f8d533f07d87c00379644205b1d33a5ea03381edc4b69431"}, + {file = "contourpy-1.2.0-cp312-cp312-win32.whl", hash = "sha256:1c2559d6cffc94890b0529ea7eeecc20d6fadc1539273aa27faf503eb4656d8f"}, + {file = "contourpy-1.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:491b1917afdd8638a05b611a56d46587d5a632cabead889a5440f7c638bc6ed9"}, + {file = "contourpy-1.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5fd1810973a375ca0e097dee059c407913ba35723b111df75671a1976efa04bc"}, + {file = "contourpy-1.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:999c71939aad2780f003979b25ac5b8f2df651dac7b38fb8ce6c46ba5abe6ae9"}, + {file = "contourpy-1.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7caf9b241464c404613512d5594a6e2ff0cc9cb5615c9475cc1d9b514218ae8"}, + {file = "contourpy-1.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:266270c6f6608340f6c9836a0fb9b367be61dde0c9a9a18d5ece97774105ff3e"}, + {file = "contourpy-1.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbd50d0a0539ae2e96e537553aff6d02c10ed165ef40c65b0e27e744a0f10af8"}, + {file = "contourpy-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11f8d2554e52f459918f7b8e6aa20ec2a3bce35ce95c1f0ef4ba36fbda306df5"}, + {file = "contourpy-1.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ce96dd400486e80ac7d195b2d800b03e3e6a787e2a522bfb83755938465a819e"}, + {file = "contourpy-1.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6d3364b999c62f539cd403f8123ae426da946e142312a514162adb2addd8d808"}, + {file = "contourpy-1.2.0-cp39-cp39-win32.whl", hash = "sha256:1c88dfb9e0c77612febebb6ac69d44a8d81e3dc60f993215425b62c1161353f4"}, + {file = "contourpy-1.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:78e6ad33cf2e2e80c5dfaaa0beec3d61face0fb650557100ee36db808bfa6843"}, + {file = "contourpy-1.2.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:be16975d94c320432657ad2402f6760990cb640c161ae6da1363051805fa8108"}, + {file = "contourpy-1.2.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b95a225d4948b26a28c08307a60ac00fb8671b14f2047fc5476613252a129776"}, + {file = "contourpy-1.2.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:0d7e03c0f9a4f90dc18d4e77e9ef4ec7b7bbb437f7f675be8e530d65ae6ef956"}, + {file = "contourpy-1.2.0.tar.gz", hash = "sha256:171f311cb758de7da13fc53af221ae47a5877be5a0843a9fe150818c51ed276a"}, +] + +[package.dependencies] +numpy = ">=1.20,<2.0" [package.extras] bokeh = ["bokeh", "selenium"] docs = ["furo", "sphinx (>=7.2)", "sphinx-copybutton"] -mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.4.1)", "types-Pillow"] +mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.6.1)", "types-Pillow"] test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] -test-no-images = ["pytest", "pytest-cov", "wurlitzer"] +test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"] [[package]] name = "coverage" @@ -1052,6 +1059,25 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1 [package.extras] toml = ["tomli"] +[[package]] +name = "cupy" +version = "12.2.0" +description = "CuPy: NumPy & SciPy for GPU" +optional = false +python-versions = ">=3.8" +files = [ + {file = "cupy-12.2.0.tar.gz", hash = "sha256:f95ffd0afeacb617b048fe028ede07b97dc9e95aca1610a022b1f3d20a9a027e"}, +] + +[package.dependencies] +fastrlock = ">=0.5" +numpy = ">=1.20,<1.27" + +[package.extras] +all = ["Cython (>=0.29.22,<3)", "optuna (>=2.0)", "scipy (>=1.6,<1.13)"] +stylecheck = ["autopep8 (==1.5.5)", "flake8 (==3.8.4)", "mypy (==1.4.1)", "pbr (==5.5.1)", "pycodestyle (==2.6.0)", "types-setuptools (==57.4.14)"] +test = ["hypothesis (>=6.37.2,<6.55.0)", "pytest (>=7.2)"] + [[package]] name = "cycler" version = "0.12.1" @@ -1306,6 +1332,90 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] +[[package]] +name = "fastrlock" +version = "0.8.2" +description = "Fast, re-entrant optimistic lock implemented in Cython" +optional = false +python-versions = "*" +files = [ + {file = "fastrlock-0.8.2-cp27-cp27m-macosx_10_15_x86_64.whl", hash = "sha256:94e348c72a1fd1f8191f25ea056448e4f5a87b8fbf005b39d290dcb0581a48cd"}, + {file = "fastrlock-0.8.2-cp27-cp27m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2d5595903444c854b99c42122b87edfe8a37cd698a4eae32f4fd1d2a7b6c115d"}, + {file = "fastrlock-0.8.2-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:e4bbde174a0aff5f6eeba75cf8c4c5d2a316316bc21f03a0bddca0fc3659a6f3"}, + {file = "fastrlock-0.8.2-cp27-cp27mu-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7a2ccaf88ac0db153e84305d1ef0aa138cea82c6a88309066f6eaa3bc98636cd"}, + {file = "fastrlock-0.8.2-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:31a27a2edf482df72b91fe6c6438314d2c65290aa7becc55589d156c9b91f0da"}, + {file = "fastrlock-0.8.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:e9904b5b37c3e5bb4a245c56bc4b7e497da57ffb8528f4fc39af9dcb168ee2e1"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:43a241655e83e4603a152192cf022d5ca348c2f4e56dfb02e5c9c4c1a32f9cdb"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9121a894d74e65557e47e777060a495ab85f4b903e80dd73a3c940ba042920d7"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:11bbbbc526363955aeddb9eec4cee2a0012322b7b2f15b54f44454fcf4fd398a"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:27786c62a400e282756ae1b090bcd7cfa35f28270cff65a9e7b27a5327a32561"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:08315bde19d0c2e6b06593d5a418be3dc8f9b1ee721afa96867b9853fceb45cf"}, + {file = "fastrlock-0.8.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e8b49b5743ede51e0bcf6805741f39f5e0e0fd6a172ba460cb39e3097ba803bb"}, + {file = "fastrlock-0.8.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b443e73a4dfc7b6e0800ea4c13567b9694358e86f53bb2612a51c9e727cac67b"}, + {file = "fastrlock-0.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:b3853ed4ce522598dc886160a7bab432a093051af85891fa2f5577c1dcac8ed6"}, + {file = "fastrlock-0.8.2-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:790fc19bccbd39426060047e53629f171a44745613bf360a045e9f9c8c4a2cea"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:dbdce852e6bb66e1b8c36679d482971d69d93acf1785657522e51b7de30c3356"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d47713ffe6d4a627fbf078be9836a95ac106b4a0543e3841572c91e292a5d885"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:ea96503b918fceaf40443182742b8964d47b65c5ebdea532893cb9479620000c"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:c6bffa978793bea5e1b00e677062e53a62255439339591b70e209fa1552d5ee0"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:75c07726c8b1a52147fd7987d6baaa318c5dced1416c3f25593e40f56e10755b"}, + {file = "fastrlock-0.8.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:88f079335e9da631efa64486c8207564a7bcd0c00526bb9e842e9d5b7e50a6cc"}, + {file = "fastrlock-0.8.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4fb2e77ff04bc4beb71d63c8e064f052ce5a6ea1e001d528d4d7f4b37d736f2e"}, + {file = "fastrlock-0.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:b4c9083ea89ab236b06e9ef2263971db3b4b507195fc7d5eecab95828dcae325"}, + {file = "fastrlock-0.8.2-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:98195866d3a9949915935d40a88e4f1c166e82e378f622c88025f2938624a90a"}, + {file = "fastrlock-0.8.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b22ea9bf5f9fad2b0077e944a7813f91593a4f61adf8faf734a70aed3f2b3a40"}, + {file = "fastrlock-0.8.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dcc1bf0ac8a194313cf6e645e300a8a379674ceed8e0b1e910a2de3e3c28989e"}, + {file = "fastrlock-0.8.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a3dcc876050b8f5cbc0ee84ef1e7f0c1dfe7c148f10098828bc4403683c33f10"}, + {file = "fastrlock-0.8.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:685e656048b59d8dfde8c601f188ad53a4d719eb97080cafc8696cda6d75865e"}, + {file = "fastrlock-0.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:fb5363cf0fddd9b50525ddbf64a1e1b28ec4c6dfb28670a940cb1cf988a6786b"}, + {file = "fastrlock-0.8.2-cp35-cp35m-macosx_10_15_x86_64.whl", hash = "sha256:a74f5a92fa6e51c4f3c69b29c4662088b97be12f40652a21109605a175c81824"}, + {file = "fastrlock-0.8.2-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ccf39ad5702e33e4d335b48ef9d56e21619b529b7f7471b5211419f380329b62"}, + {file = "fastrlock-0.8.2-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:66f2662c640bb71a1016a031eea6eef9d25c2bcdf7ffd1d1ddc5a58f9a1ced04"}, + {file = "fastrlock-0.8.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:17734e2e5af4c07ddb0fb10bd484e062c22de3be6b67940b9cc6ec2f18fa61ba"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:ab91b0c36e95d42e1041a4907e3eefd06c482d53af3c7a77be7e214cc7cd4a63"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b32fdf874868326351a75b1e4c02f97e802147119ae44c52d3d9da193ec34f5b"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:2074548a335fcf7d19ebb18d9208da9e33b06f745754466a7e001d2b1c58dd19"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4fb04442b6d1e2b36c774919c6bcbe3339c61b337261d4bd57e27932589095af"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:1fed2f4797ad68e9982038423018cf08bec5f4ce9fed63a94a790773ed6a795c"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e380ec4e6d8b26e389713995a43cb7fe56baea2d25fe073d4998c4821a026211"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:25945f962c7bd808415cfde3da624d4399d4ea71ed8918538375f16bceb79e1c"}, + {file = "fastrlock-0.8.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:2c1719ddc8218b01e82fb2e82e8451bd65076cb96d7bef4477194bbb4305a968"}, + {file = "fastrlock-0.8.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:5460c5ee6ced6d61ec8cd2324ebbe793a4960c4ffa2131ffff480e3b61c99ec5"}, + {file = "fastrlock-0.8.2-cp36-cp36m-win_amd64.whl", hash = "sha256:33145acbad8317584cd64588131c7e1e286beef6280c0009b4544c91fce171d2"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:59344c1d46b7dec97d3f22f1cc930fafe8980b3c5bc9c9765c56738a5f1559e4"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b2a1c354f13f22b737621d914f3b4a8434ae69d3027a775e94b3e671756112f9"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:cf81e0278b645004388873e0a1f9e3bc4c9ab8c18e377b14ed1a544be4b18c9a"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1b15430b93d7eb3d56f6ff690d2ebecb79ed0e58248427717eba150a508d1cd7"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:067edb0a0805bf61e17a251d5046af59f6e9d2b8ad01222e0ef7a0b7937d5548"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eb31fe390f03f7ae886dcc374f1099ec88526631a4cb891d399b68181f154ff0"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:643e1e65b4f5b284427e61a894d876d10459820e93aa1e724dfb415117be24e0"}, + {file = "fastrlock-0.8.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5dfb78dd600a12f23fc0c3ec58f81336229fdc74501ecf378d1ce5b3f2f313ea"}, + {file = "fastrlock-0.8.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:b8ca0fe21458457077e4cb2d81e1ebdb146a00b3e9e2db6180a773f7ea905032"}, + {file = "fastrlock-0.8.2-cp37-cp37m-win_amd64.whl", hash = "sha256:d918dfe473291e8bfd8e13223ea5cb9b317bd9f50c280923776c377f7c64b428"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:c393af77c659a38bffbca215c0bcc8629ba4299568308dd7e4ff65d62cabed39"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:73426f5eb2ecc10626c67cf86bd0af9e00d53e80e5c67d5ce8e18376d6abfa09"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:320fd55bafee3eb069cfb5d6491f811a912758387ef2193840e2663e80e16f48"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8c1c91a68926421f5ccbc82c85f83bd3ba593b121a46a1b9a554b3f0dd67a4bf"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:ad1bc61c7f6b0e58106aaab034916b6cb041757f708b07fbcdd9d6e1ac629225"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:87f4e01b042c84e6090dbc4fbe3415ddd69f6bc0130382323f9d3f1b8dd71b46"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:d34546ad2e4a480b94b6797bcc5a322b3c705c4c74c3e4e545c4a3841c1b2d59"}, + {file = "fastrlock-0.8.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ebb32d776b61acd49f859a1d16b9e3d84e7b46d0d92aebd58acd54dc38e96664"}, + {file = "fastrlock-0.8.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:30bdbe4662992348132d03996700e1cf910d141d629179b967b146a22942264e"}, + {file = "fastrlock-0.8.2-cp38-cp38-win_amd64.whl", hash = "sha256:07ed3c7b3867c05a3d6be4ced200c7767000f3431b9be6da66972822dd86e8be"}, + {file = "fastrlock-0.8.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:ddf5d247f686aec853ddcc9a1234bfcc6f57b0a0670d2ad82fc25d8ae7e6a15f"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:7269bb3fc15587b0c191eecd95831d771a7d80f0c48929e560806b038ff3066c"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:adcb9e77aa132cc6c9de2ffe7cf880a20aa8cdba21d367d1da1a412f57bddd5d"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:a3b8b5d2935403f1b4b25ae324560e94b59593a38c0d2e7b6c9872126a9622ed"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2587cedbb36c7988e707d83f0f1175c1f882f362b5ebbee25d70218ea33d220d"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:9af691a9861027181d4de07ed74f0aee12a9650ac60d0a07f4320bff84b5d95f"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:99dd6652bd6f730beadf74ef769d38c6bbd8ee6d1c15c8d138ea680b0594387f"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:4d63b6596368dab9e0cc66bf047e7182a56f33b34db141816a4f21f5bf958228"}, + {file = "fastrlock-0.8.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ff75c90663d6e8996610d435e71487daa853871ad1770dd83dc0f2fc4997241e"}, + {file = "fastrlock-0.8.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e27c3cd27fbd25e5223c5c992b300cd4ee8f0a75c6f222ce65838138d853712c"}, + {file = "fastrlock-0.8.2-cp39-cp39-win_amd64.whl", hash = "sha256:dd961a32a7182c3891cdebca417fda67496d5d5de6ae636962254d22723bdf52"}, + {file = "fastrlock-0.8.2.tar.gz", hash = "sha256:644ec9215cf9c4df8028d8511379a15d9c1af3e16d80e47f1b6fdc6ba118356a"}, +] + [[package]] name = "filelock" version = "3.13.1" @@ -1340,57 +1450,57 @@ pyflakes = ">=3.1.0,<3.2.0" [[package]] name = "fonttools" -version = "4.43.1" +version = "4.44.0" description = "Tools to manipulate font files" optional = false python-versions = ">=3.8" files = [ - {file = "fonttools-4.43.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:bf11e2cca121df35e295bd34b309046c29476ee739753bc6bc9d5050de319273"}, - {file = "fonttools-4.43.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:10b3922875ffcba636674f406f9ab9a559564fdbaa253d66222019d569db869c"}, - {file = "fonttools-4.43.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f727c3e3d08fd25352ed76cc3cb61486f8ed3f46109edf39e5a60fc9fecf6ca"}, - {file = "fonttools-4.43.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad0b3f6342cfa14be996971ea2b28b125ad681c6277c4cd0fbdb50340220dfb6"}, - {file = "fonttools-4.43.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3b7ad05b2beeebafb86aa01982e9768d61c2232f16470f9d0d8e385798e37184"}, - {file = "fonttools-4.43.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4c54466f642d2116686268c3e5f35ebb10e49b0d48d41a847f0e171c785f7ac7"}, - {file = "fonttools-4.43.1-cp310-cp310-win32.whl", hash = "sha256:1e09da7e8519e336239fbd375156488a4c4945f11c4c5792ee086dd84f784d02"}, - {file = "fonttools-4.43.1-cp310-cp310-win_amd64.whl", hash = "sha256:1cf9e974f63b1080b1d2686180fc1fbfd3bfcfa3e1128695b5de337eb9075cef"}, - {file = "fonttools-4.43.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5db46659cfe4e321158de74c6f71617e65dc92e54980086823a207f1c1c0e24b"}, - {file = "fonttools-4.43.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1952c89a45caceedf2ab2506d9a95756e12b235c7182a7a0fff4f5e52227204f"}, - {file = "fonttools-4.43.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c36da88422e0270fbc7fd959dc9749d31a958506c1d000e16703c2fce43e3d0"}, - {file = "fonttools-4.43.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bbbf8174501285049e64d174e29f9578495e1b3b16c07c31910d55ad57683d8"}, - {file = "fonttools-4.43.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d4071bd1c183b8d0b368cc9ed3c07a0f6eb1bdfc4941c4c024c49a35429ac7cd"}, - {file = "fonttools-4.43.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d21099b411e2006d3c3e1f9aaf339e12037dbf7bf9337faf0e93ec915991f43b"}, - {file = "fonttools-4.43.1-cp311-cp311-win32.whl", hash = "sha256:b84a1c00f832feb9d0585ca8432fba104c819e42ff685fcce83537e2e7e91204"}, - {file = "fonttools-4.43.1-cp311-cp311-win_amd64.whl", hash = "sha256:9a2f0aa6ca7c9bc1058a9d0b35483d4216e0c1bbe3962bc62ce112749954c7b8"}, - {file = "fonttools-4.43.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4d9740e3783c748521e77d3c397dc0662062c88fd93600a3c2087d3d627cd5e5"}, - {file = "fonttools-4.43.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:884ef38a5a2fd47b0c1291647b15f4e88b9de5338ffa24ee52c77d52b4dfd09c"}, - {file = "fonttools-4.43.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9648518ef687ba818db3fcc5d9aae27a369253ac09a81ed25c3867e8657a0680"}, - {file = "fonttools-4.43.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95e974d70238fc2be5f444fa91f6347191d0e914d5d8ae002c9aa189572cc215"}, - {file = "fonttools-4.43.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:34f713dad41aa21c637b4e04fe507c36b986a40f7179dcc86402237e2d39dcd3"}, - {file = "fonttools-4.43.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:360201d46165fc0753229afe785900bc9596ee6974833124f4e5e9f98d0f592b"}, - {file = "fonttools-4.43.1-cp312-cp312-win32.whl", hash = "sha256:bb6d2f8ef81ea076877d76acfb6f9534a9c5f31dc94ba70ad001267ac3a8e56f"}, - {file = "fonttools-4.43.1-cp312-cp312-win_amd64.whl", hash = "sha256:25d3da8a01442cbc1106490eddb6d31d7dffb38c1edbfabbcc8db371b3386d72"}, - {file = "fonttools-4.43.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8da417431bfc9885a505e86ba706f03f598c85f5a9c54f67d63e84b9948ce590"}, - {file = "fonttools-4.43.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:51669b60ee2a4ad6c7fc17539a43ffffc8ef69fd5dbed186a38a79c0ac1f5db7"}, - {file = "fonttools-4.43.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:748015d6f28f704e7d95cd3c808b483c5fb87fd3eefe172a9da54746ad56bfb6"}, - {file = "fonttools-4.43.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7a58eb5e736d7cf198eee94844b81c9573102ae5989ebcaa1d1a37acd04b33d"}, - {file = "fonttools-4.43.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6bb5ea9076e0e39defa2c325fc086593ae582088e91c0746bee7a5a197be3da0"}, - {file = "fonttools-4.43.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5f37e31291bf99a63328668bb83b0669f2688f329c4c0d80643acee6e63cd933"}, - {file = "fonttools-4.43.1-cp38-cp38-win32.whl", hash = "sha256:9c60ecfa62839f7184f741d0509b5c039d391c3aff71dc5bc57b87cc305cff3b"}, - {file = "fonttools-4.43.1-cp38-cp38-win_amd64.whl", hash = "sha256:fe9b1ec799b6086460a7480e0f55c447b1aca0a4eecc53e444f639e967348896"}, - {file = "fonttools-4.43.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:13a9a185259ed144def3682f74fdcf6596f2294e56fe62dfd2be736674500dba"}, - {file = "fonttools-4.43.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b2adca1b46d69dce4a37eecc096fe01a65d81a2f5c13b25ad54d5430ae430b13"}, - {file = "fonttools-4.43.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18eefac1b247049a3a44bcd6e8c8fd8b97f3cad6f728173b5d81dced12d6c477"}, - {file = "fonttools-4.43.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2062542a7565091cea4cc14dd99feff473268b5b8afdee564f7067dd9fff5860"}, - {file = "fonttools-4.43.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:18a2477c62a728f4d6e88c45ee9ee0229405e7267d7d79ce1f5ce0f3e9f8ab86"}, - {file = "fonttools-4.43.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a7a06f8d95b7496e53af80d974d63516ffb263a468e614978f3899a6df52d4b3"}, - {file = "fonttools-4.43.1-cp39-cp39-win32.whl", hash = "sha256:10003ebd81fec0192c889e63a9c8c63f88c7d72ae0460b7ba0cd2a1db246e5ad"}, - {file = "fonttools-4.43.1-cp39-cp39-win_amd64.whl", hash = "sha256:e117a92b07407a061cde48158c03587ab97e74e7d73cb65e6aadb17af191162a"}, - {file = "fonttools-4.43.1-py3-none-any.whl", hash = "sha256:4f88cae635bfe4bbbdc29d479a297bb525a94889184bb69fa9560c2d4834ddb9"}, - {file = "fonttools-4.43.1.tar.gz", hash = "sha256:17dbc2eeafb38d5d0e865dcce16e313c58265a6d2d20081c435f84dc5a9d8212"}, -] - -[package.extras] -all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0,<5)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.0.0)", "xattr", "zopfli (>=0.1.4)"] + {file = "fonttools-4.44.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e1cd1c6bb097e774d68402499ff66185190baaa2629ae2f18515a2c50b93db0c"}, + {file = "fonttools-4.44.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b9eab7f9837fdaa2a10a524fbcc2ec24bf60637c044b6e4a59c3f835b90f0fae"}, + {file = "fonttools-4.44.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f412954275e594f7a51c16f3b3edd850acb0d842fefc33856b63a17e18499a5"}, + {file = "fonttools-4.44.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50d25893885e80a5955186791eed5579f1e75921751539cc1dc3ffd1160b48cf"}, + {file = "fonttools-4.44.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:22ea8aa7b3712450b42b044702bd3a64fd118006bad09a6f94bd1b227088492e"}, + {file = "fonttools-4.44.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:df40daa6c03b98652ffe8110ae014fe695437f6e1cb5a07e16ea37f40e73ac86"}, + {file = "fonttools-4.44.0-cp310-cp310-win32.whl", hash = "sha256:bca49da868e8bde569ef36f0cc1b6de21d56bf9c3be185c503b629c19a185287"}, + {file = "fonttools-4.44.0-cp310-cp310-win_amd64.whl", hash = "sha256:dbac86d83d96099890e731cc2af97976ff2c98f4ba432fccde657c5653a32f1c"}, + {file = "fonttools-4.44.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e8ff7d19a6804bfd561cfcec9b4200dd1788e28f7de4be70189801530c47c1b3"}, + {file = "fonttools-4.44.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a8a1fa9a718de0bc026979c93e1e9b55c5efde60d76f91561fd713387573817d"}, + {file = "fonttools-4.44.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c05064f95aacdfc06f21e55096c964b2228d942b8675fa26995a2551f6329d2d"}, + {file = "fonttools-4.44.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31b38528f25bc662401e6ffae14b3eb7f1e820892fd80369a37155e3b636a2f4"}, + {file = "fonttools-4.44.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:05d7c4d2c95b9490e669f3cb83918799bf1c838619ac6d3bad9ea017cfc63f2e"}, + {file = "fonttools-4.44.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6999e80a125b0cd8e068d0210b63323f17338038c2ecd2e11b9209ec430fe7f2"}, + {file = "fonttools-4.44.0-cp311-cp311-win32.whl", hash = "sha256:a7aec7f5d14dfcd71fb3ebc299b3f000c21fdc4043079101777ed2042ba5b7c5"}, + {file = "fonttools-4.44.0-cp311-cp311-win_amd64.whl", hash = "sha256:518a945dbfe337744bfff31423c1430303b8813c5275dffb0f2577f0734a1189"}, + {file = "fonttools-4.44.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:59b6ad83cce067d10f4790c037a5904424f45bebb5e7be2eb2db90402f288267"}, + {file = "fonttools-4.44.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c2de1fb18198acd400c45ffe2aef5420c8d55fde903e91cba705596099550f3b"}, + {file = "fonttools-4.44.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84f308b7a8d28208d54315d11d35f9888d6d607673dd4d42d60b463682ee0400"}, + {file = "fonttools-4.44.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66bc6efd829382f7a7e6cf33c2fb32b13edc8a239eb15f32acbf197dce7a0165"}, + {file = "fonttools-4.44.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a8b99713d3a0d0e876b6aecfaada5e7dc9fe979fcd90ef9fa0ba1d9b9aed03f2"}, + {file = "fonttools-4.44.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b63da598d9cbc52e2381f922da0e94d60c0429f92207bd3fb04d112fc82ea7cb"}, + {file = "fonttools-4.44.0-cp312-cp312-win32.whl", hash = "sha256:f611c97678604e302b725f71626edea113a5745a7fb557c958b39edb6add87d5"}, + {file = "fonttools-4.44.0-cp312-cp312-win_amd64.whl", hash = "sha256:58af428746fa73a2edcbf26aff33ac4ef3c11c8d75bb200eaea2f7e888d2de4e"}, + {file = "fonttools-4.44.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:9ee8692e23028564c13d924004495f284df8ac016a19f17a87251210e1f1f928"}, + {file = "fonttools-4.44.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:dab3d00d27b1a79ae4d4a240e8ceea8af0ff049fd45f05adb4f860d93744110d"}, + {file = "fonttools-4.44.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f53526668beccdb3409c6055a4ffe50987a7f05af6436fa55d61f5e7bd450219"}, + {file = "fonttools-4.44.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3da036b016c975c2d8c69005bdc4d5d16266f948a7fab950244e0f58301996a"}, + {file = "fonttools-4.44.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b99fe8ef4093f672d00841569d2d05691e50334d79f4d9c15c1265d76d5580d2"}, + {file = "fonttools-4.44.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d16d9634ff1e5cea2cf4a8cbda9026f766e4b5f30b48f8180f0e99133d3abfc"}, + {file = "fonttools-4.44.0-cp38-cp38-win32.whl", hash = "sha256:3d29509f6e05e8d725db59c2d8c076223d793e4e35773040be6632a0349f2f97"}, + {file = "fonttools-4.44.0-cp38-cp38-win_amd64.whl", hash = "sha256:d4fa4f4bc8fd86579b8cdbe5e948f35d82c0eda0091c399d009b2a5a6b61c040"}, + {file = "fonttools-4.44.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c794de4086f06ae609b71ac944ec7deb09f34ecf73316fddc041087dd24bba39"}, + {file = "fonttools-4.44.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2db63941fee3122e31a21dd0f5b2138ce9906b661a85b63622421d3654a74ae2"}, + {file = "fonttools-4.44.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb01c49c8aa035d5346f46630209923d4927ed15c2493db38d31da9f811eb70d"}, + {file = "fonttools-4.44.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46c79af80a835410874683b5779b6c1ec1d5a285e11c45b5193e79dd691eb111"}, + {file = "fonttools-4.44.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b6e6aa2d066f8dafd06d8d0799b4944b5d5a1f015dd52ac01bdf2895ebe169a0"}, + {file = "fonttools-4.44.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:63a3112f753baef8c6ac2f5f574bb9ac8001b86c8c0c0380039db47a7f512d20"}, + {file = "fonttools-4.44.0-cp39-cp39-win32.whl", hash = "sha256:54efed22b2799a85475e6840e907c402ba49892c614565dc770aa97a53621b2b"}, + {file = "fonttools-4.44.0-cp39-cp39-win_amd64.whl", hash = "sha256:2e91e19b583961979e2e5a701269d3cfc07418963bee717f8160b0a24332826b"}, + {file = "fonttools-4.44.0-py3-none-any.whl", hash = "sha256:b9beb0fa6ff3ea808ad4a6962d68ac0f140ddab080957b20d9e268e4d67fb335"}, + {file = "fonttools-4.44.0.tar.gz", hash = "sha256:4e90dd81b6e0d97ebfe52c0d12a17a9ef7f305d6bfbb93081265057d6092f252"}, +] + +[package.extras] +all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0,<5)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0)", "xattr", "zopfli (>=0.1.4)"] graphite = ["lz4 (>=1.7.4.2)"] interpolatable = ["munkres", "scipy"] lxml = ["lxml (>=4.0,<5)"] @@ -1400,7 +1510,7 @@ repacker = ["uharfbuzz (>=0.23.0)"] symfont = ["sympy"] type1 = ["xattr"] ufo = ["fs (>=2.2.0,<3)"] -unicode = ["unicodedata2 (>=15.0.0)"] +unicode = ["unicodedata2 (>=15.1.0)"] woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] [[package]] @@ -2283,7 +2393,6 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, - {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] @@ -2861,6 +2970,87 @@ files = [ docs = ["Sphinx (==5.1.0)", "doc8 (>=0.8.1)", "sphinx-rtd-theme (>=0.5.0)", "sphinxcontrib-apidoc (>=0.3.0)"] testing = ["black", "isort", "pytest (>=6,!=7.0.0)", "pytest-xdist (>=2)", "twine"] +[[package]] +name = "lightning" +version = "2.1.0" +description = "The Deep Learning framework to train, deploy, and ship AI products Lightning fast." +optional = false +python-versions = ">=3.8" +files = [ + {file = "lightning-2.1.0-py3-none-any.whl", hash = "sha256:c12bd10bd28b9e29a8e877be039350a585f248c10b76360faa2aa2497f980de6"}, + {file = "lightning-2.1.0.tar.gz", hash = "sha256:1f78f5995ae7dcffa1edf34320db136902b73a0d1b304404c48ec8be165b3a93"}, +] + +[package.dependencies] +fsspec = {version = ">2021.06.0,<2025.0", extras = ["http"]} +lightning-utilities = ">=0.8.0,<2.0" +numpy = ">=1.17.2,<3.0" +packaging = ">=20.0,<25.0" +pytorch-lightning = "*" +PyYAML = ">=5.4,<8.0" +torch = ">=1.12.0,<4.0" +torchmetrics = ">=0.7.0,<3.0" +tqdm = ">=4.57.0,<6.0" +typing-extensions = ">=4.0.0,<6.0" + +[package.extras] +all = ["Jinja2 (<4.0)", "Pillow (>=9.5.0)", "PyYAML (<7.0)", "aiohttp (>=3.8.0,<4.0)", "arrow (>=1.2.0,<2.0)", "backoff (>=2.2.1,<3.0)", "beautifulsoup4 (>=4.8.0,<5.0)", "click (<9.0)", "croniter (>=1.3.0,<1.5.0)", "dateutils (<1.0)", "deepdiff (>=5.7.0,<7.0)", "deepspeed (>=0.8.2,<=0.9.3)", "docker (>=5.0.0,<7.0)", "fastapi (>=0.92.0,<1.0)", "fsspec (>=2022.5.0,<2024.0)", "fsspec[http] (>2021.06.0,<2024.0)", "gym[classic-control] (>=0.17.0,<1.0)", "hydra-core (>=1.0.5,<2.0)", "inquirer (>=2.10.0,<4.0)", "ipython[all] (<9.0)", "jsonargparse[signatures] (>=4.18.0,<5.0)", "lightning-api-access (>=0.0.3)", "lightning-cloud (==0.5.39)", "lightning-fabric (>=1.9.0)", "lightning-utilities (>=0.8.0,<1.0)", "matplotlib (>3.1,<4.0)", "omegaconf (>=2.0.5,<3.0)", "packaging", "panel (>=1.0.0,<2.0)", "psutil (<6.0)", "pydantic (>=1.7.4)", "python-multipart (>=0.0.5,<1.0)", "pytorch-lightning (>=1.9.0)", "redis (>=4.0.1,<6.0)", "requests (<3.0)", "rich (>=12.3.0,<14.0)", "s3fs (>=2022.5.0,<2024.0)", "starlette", "starsessions (>=1.2.1,<2.0)", "streamlit (>=1.13.0,<2.0)", "tensorboardX (>=2.2,<3.0)", "torch (>0.14.0,<3.0)", "torchdata (>0.5.9,<1.0)", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.13.0,<1.0)", "torchvision (>=0.15.2,<1.0)", "traitlets (>=5.3.0,<6.0)", "typing-extensions (>=4.0.0,<5.0)", "urllib3 (<3.0)", "uvicorn (<1.0)", "websocket-client (<2.0)", "websockets (<12.0)"] +app = ["Jinja2 (<4.0)", "PyYAML (<7.0)", "arrow (>=1.2.0,<2.0)", "backoff (>=2.2.1,<3.0)", "beautifulsoup4 (>=4.8.0,<5.0)", "click (<9.0)", "croniter (>=1.3.0,<1.5.0)", "dateutils (<1.0)", "deepdiff (>=5.7.0,<7.0)", "fastapi (>=0.92.0,<1.0)", "fsspec (>=2022.5.0,<2024.0)", "inquirer (>=2.10.0,<4.0)", "lightning-cloud (==0.5.39)", "lightning-utilities (>=0.8.0,<1.0)", "packaging", "psutil (<6.0)", "pydantic (>=1.7.4)", "python-multipart (>=0.0.5,<1.0)", "requests (<3.0)", "rich (>=12.3.0,<14.0)", "starlette", "starsessions (>=1.2.1,<2.0)", "traitlets (>=5.3.0,<6.0)", "typing-extensions (>=4.0.0,<5.0)", "urllib3 (<3.0)", "uvicorn (<1.0)", "websocket-client (<2.0)", "websockets (<12.0)"] +app-all = ["Jinja2 (<4.0)", "PyYAML (<7.0)", "aiohttp (>=3.8.0,<4.0)", "arrow (>=1.2.0,<2.0)", "backoff (>=2.2.1,<3.0)", "beautifulsoup4 (>=4.8.0,<5.0)", "click (<9.0)", "croniter (>=1.3.0,<1.5.0)", "dateutils (<1.0)", "deepdiff (>=5.7.0,<7.0)", "docker (>=5.0.0,<7.0)", "fastapi (>=0.92.0,<1.0)", "fsspec (>=2022.5.0,<2024.0)", "inquirer (>=2.10.0,<4.0)", "lightning-api-access (>=0.0.3)", "lightning-cloud (==0.5.39)", "lightning-fabric (>=1.9.0)", "lightning-utilities (>=0.8.0,<1.0)", "packaging", "panel (>=1.0.0,<2.0)", "psutil (<6.0)", "pydantic (>=1.7.4)", "python-multipart (>=0.0.5,<1.0)", "pytorch-lightning (>=1.9.0)", "redis (>=4.0.1,<6.0)", "requests (<3.0)", "rich (>=12.3.0,<14.0)", "s3fs (>=2022.5.0,<2024.0)", "starlette", "starsessions (>=1.2.1,<2.0)", "streamlit (>=1.13.0,<2.0)", "traitlets (>=5.3.0,<6.0)", "typing-extensions (>=4.0.0,<5.0)", "urllib3 (<3.0)", "uvicorn (<1.0)", "websocket-client (<2.0)", "websockets (<12.0)"] +app-cloud = ["docker (>=5.0.0,<7.0)", "redis (>=4.0.1,<6.0)", "s3fs (>=2022.5.0,<2024.0)"] +app-components = ["aiohttp (>=3.8.0,<4.0)", "lightning-api-access (>=0.0.3)", "lightning-fabric (>=1.9.0)", "pytorch-lightning (>=1.9.0)"] +app-dev = ["Jinja2 (<4.0)", "PyYAML (<7.0)", "aiohttp (>=3.8.0,<4.0)", "arrow (>=1.2.0,<2.0)", "backoff (>=2.2.1,<3.0)", "beautifulsoup4 (>=4.8.0,<5.0)", "click (<9.0)", "coverage (==7.3.1)", "croniter (>=1.3.0,<1.5.0)", "dateutils (<1.0)", "deepdiff (>=5.7.0,<7.0)", "docker (>=5.0.0,<7.0)", "fastapi (>=0.92.0,<1.0)", "fsspec (>=2022.5.0,<2024.0)", "httpx (==0.25.0)", "inquirer (>=2.10.0,<4.0)", "lightning-api-access (>=0.0.3)", "lightning-cloud (==0.5.39)", "lightning-fabric (>=1.9.0)", "lightning-utilities (>=0.8.0,<1.0)", "packaging", "panel (>=1.0.0,<2.0)", "playwright (==1.38.0)", "psutil (<6.0)", "pydantic (>=1.7.4)", "pympler", "pytest (==7.4.0)", "pytest-asyncio (==0.21.1)", "pytest-cov (==4.1.0)", "pytest-doctestplus (==0.9.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "pytest-xdist (==3.3.1)", "python-multipart (>=0.0.5,<1.0)", "pytorch-lightning (>=1.9.0)", "redis (>=4.0.1,<6.0)", "requests (<3.0)", "requests-mock (==1.11.0)", "rich (>=12.3.0,<14.0)", "s3fs (>=2022.5.0,<2024.0)", "setuptools (<69.0)", "starlette", "starsessions (>=1.2.1,<2.0)", "streamlit (>=1.13.0,<2.0)", "traitlets (>=5.3.0,<6.0)", "trio (<0.22.0)", "typing-extensions (>=4.0.0,<5.0)", "urllib3 (<3.0)", "uvicorn (<1.0)", "websocket-client (<2.0)", "websockets (<12.0)"] +app-extra = ["Jinja2 (<4.0)", "PyYAML (<7.0)", "aiohttp (>=3.8.0,<4.0)", "arrow (>=1.2.0,<2.0)", "backoff (>=2.2.1,<3.0)", "beautifulsoup4 (>=4.8.0,<5.0)", "click (<9.0)", "croniter (>=1.3.0,<1.5.0)", "dateutils (<1.0)", "deepdiff (>=5.7.0,<7.0)", "docker (>=5.0.0,<7.0)", "fastapi (>=0.92.0,<1.0)", "fsspec (>=2022.5.0,<2024.0)", "inquirer (>=2.10.0,<4.0)", "lightning-api-access (>=0.0.3)", "lightning-cloud (==0.5.39)", "lightning-fabric (>=1.9.0)", "lightning-utilities (>=0.8.0,<1.0)", "packaging", "panel (>=1.0.0,<2.0)", "psutil (<6.0)", "pydantic (>=1.7.4)", "python-multipart (>=0.0.5,<1.0)", "pytorch-lightning (>=1.9.0)", "redis (>=4.0.1,<6.0)", "requests (<3.0)", "rich (>=12.3.0,<14.0)", "s3fs (>=2022.5.0,<2024.0)", "starlette", "starsessions (>=1.2.1,<2.0)", "streamlit (>=1.13.0,<2.0)", "traitlets (>=5.3.0,<6.0)", "typing-extensions (>=4.0.0,<5.0)", "urllib3 (<3.0)", "uvicorn (<1.0)", "websocket-client (<2.0)", "websockets (<12.0)"] +app-test = ["coverage (==7.3.1)", "httpx (==0.25.0)", "playwright (==1.38.0)", "psutil (<6.0)", "pympler", "pytest (==7.4.0)", "pytest-asyncio (==0.21.1)", "pytest-cov (==4.1.0)", "pytest-doctestplus (==0.9.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "pytest-xdist (==3.3.1)", "requests-mock (==1.11.0)", "setuptools (<69.0)", "trio (<0.22.0)"] +app-ui = ["panel (>=1.0.0,<2.0)", "streamlit (>=1.13.0,<2.0)"] +cloud = ["docker (>=5.0.0,<7.0)", "fsspec[http] (>2021.06.0,<2024.0)", "redis (>=4.0.1,<6.0)", "s3fs (>=2022.5.0,<2024.0)"] +components = ["aiohttp (>=3.8.0,<4.0)", "lightning-api-access (>=0.0.3)", "lightning-fabric (>=1.9.0)", "pytorch-lightning (>=1.9.0)"] +data = ["Jinja2 (<4.0)", "PyYAML (<7.0)", "arrow (>=1.2.0,<2.0)", "backoff (>=2.2.1,<3.0)", "beautifulsoup4 (>=4.8.0,<5.0)", "click (<9.0)", "croniter (>=1.3.0,<1.5.0)", "dateutils (<1.0)", "deepdiff (>=5.7.0,<7.0)", "fastapi (>=0.92.0,<1.0)", "fsspec (>=2022.5.0,<2024.0)", "inquirer (>=2.10.0,<4.0)", "lightning-cloud (==0.5.39)", "lightning-utilities (>=0.8.0,<1.0)", "packaging", "psutil (<6.0)", "pydantic (>=1.7.4)", "python-multipart (>=0.0.5,<1.0)", "requests (<3.0)", "rich (>=12.3.0,<14.0)", "starlette", "starsessions (>=1.2.1,<2.0)", "torch (>0.14.0,<3.0)", "torchdata (>0.5.9,<1.0)", "traitlets (>=5.3.0,<6.0)", "typing-extensions (>=4.0.0,<5.0)", "urllib3 (<3.0)", "uvicorn (<1.0)", "websocket-client (<2.0)", "websockets (<12.0)"] +data-all = ["Jinja2 (<4.0)", "Pillow (>=9.5.0)", "PyYAML (<7.0)", "arrow (>=1.2.0,<2.0)", "backoff (>=2.2.1,<3.0)", "beautifulsoup4 (>=4.8.0,<5.0)", "click (<9.0)", "croniter (>=1.3.0,<1.5.0)", "dateutils (<1.0)", "deepdiff (>=5.7.0,<7.0)", "fastapi (>=0.92.0,<1.0)", "fsspec (>=2022.5.0,<2024.0)", "fsspec[http] (>2021.06.0,<2024.0)", "inquirer (>=2.10.0,<4.0)", "lightning-cloud (==0.5.39)", "lightning-utilities (>=0.8.0,<1.0)", "packaging", "psutil (<6.0)", "pydantic (>=1.7.4)", "python-multipart (>=0.0.5,<1.0)", "requests (<3.0)", "rich (>=12.3.0,<14.0)", "s3fs (>=2022.5.0,<2024.0)", "starlette", "starsessions (>=1.2.1,<2.0)", "torch (>0.14.0,<3.0)", "torchdata (>0.5.9,<1.0)", "torchvision (>=0.15.2,<1.0)", "traitlets (>=5.3.0,<6.0)", "typing-extensions (>=4.0.0,<5.0)", "urllib3 (<3.0)", "uvicorn (<1.0)", "websocket-client (<2.0)", "websockets (<12.0)"] +data-cloud = ["fsspec[http] (>2021.06.0,<2024.0)", "s3fs (>=2022.5.0,<2024.0)"] +data-dev = ["Jinja2 (<4.0)", "Pillow (>=9.5.0)", "PyYAML (<7.0)", "arrow (>=1.2.0,<2.0)", "backoff (>=2.2.1,<3.0)", "beautifulsoup4 (>=4.8.0,<5.0)", "click (<9.0)", "coverage (==7.3.1)", "croniter (>=1.3.0,<1.5.0)", "dateutils (<1.0)", "deepdiff (>=5.7.0,<7.0)", "fastapi (>=0.92.0,<1.0)", "fsspec (>=2022.5.0,<2024.0)", "fsspec[http] (>2021.06.0,<2024.0)", "inquirer (>=2.10.0,<4.0)", "lightning-cloud (==0.5.39)", "lightning-utilities (>=0.8.0,<1.0)", "packaging", "psutil (<6.0)", "pydantic (>=1.7.4)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "python-multipart (>=0.0.5,<1.0)", "requests (<3.0)", "rich (>=12.3.0,<14.0)", "s3fs (>=2022.5.0,<2024.0)", "starlette", "starsessions (>=1.2.1,<2.0)", "torch (>0.14.0,<3.0)", "torchdata (>0.5.9,<1.0)", "torchvision (>=0.15.2,<1.0)", "traitlets (>=5.3.0,<6.0)", "typing-extensions (>=4.0.0,<5.0)", "urllib3 (<3.0)", "uvicorn (<1.0)", "websocket-client (<2.0)", "websockets (<12.0)"] +data-examples = ["Pillow (>=9.5.0)", "torchvision (>=0.15.2,<1.0)"] +data-test = ["coverage (==7.3.1)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)"] +dev = ["Jinja2 (<4.0)", "Pillow (>=9.5.0)", "PyYAML (<7.0)", "aiohttp (>=3.8.0,<4.0)", "arrow (>=1.2.0,<2.0)", "backoff (>=2.2.1,<3.0)", "beautifulsoup4 (>=4.8.0,<5.0)", "click (<9.0)", "click (==8.1.7)", "cloudpickle (>=1.3,<3.0)", "coverage (==7.3.1)", "croniter (>=1.3.0,<1.5.0)", "dateutils (<1.0)", "deepdiff (>=5.7.0,<7.0)", "deepspeed (>=0.8.2,<=0.9.3)", "docker (>=5.0.0,<7.0)", "fastapi", "fastapi (>=0.92.0,<1.0)", "fsspec (>=2022.5.0,<2024.0)", "fsspec[http] (>2021.06.0,<2024.0)", "gym[classic-control] (>=0.17.0,<1.0)", "httpx (==0.25.0)", "hydra-core (>=1.0.5,<2.0)", "inquirer (>=2.10.0,<4.0)", "ipython[all] (<9.0)", "jsonargparse[signatures] (>=4.18.0,<5.0)", "lightning-api-access (>=0.0.3)", "lightning-cloud (==0.5.39)", "lightning-fabric (>=1.9.0)", "lightning-utilities (>=0.8.0,<1.0)", "matplotlib (>3.1,<4.0)", "omegaconf (>=2.0.5,<3.0)", "onnx (>=0.14.0,<2.0)", "onnxruntime (>=0.15.0,<2.0)", "packaging", "pandas (>1.0,<3.0)", "panel (>=1.0.0,<2.0)", "playwright (==1.38.0)", "psutil (<6.0)", "pydantic (>=1.7.4)", "pympler", "pytest (==7.4.0)", "pytest-asyncio (==0.21.1)", "pytest-cov (==4.1.0)", "pytest-doctestplus (==0.9.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "pytest-xdist (==3.3.1)", "python-multipart (>=0.0.5,<1.0)", "pytorch-lightning (>=1.9.0)", "redis (>=4.0.1,<6.0)", "requests (<3.0)", "requests-mock (==1.11.0)", "rich (>=12.3.0,<14.0)", "s3fs (>=2022.5.0,<2024.0)", "scikit-learn (>0.22.1,<2.0)", "setuptools (<69.0)", "starlette", "starsessions (>=1.2.1,<2.0)", "streamlit (>=1.13.0,<2.0)", "tensorboard (>=2.9.1,<3.0)", "tensorboardX (>=2.2,<3.0)", "torch (>0.14.0,<3.0)", "torchdata (>0.5.9,<1.0)", "torchmetrics (>=0.10.0,<2.0)", "torchmetrics (>=0.7.0,<2.0)", "torchvision (>=0.13.0,<1.0)", "torchvision (>=0.15.2,<1.0)", "traitlets (>=5.3.0,<6.0)", "trio (<0.22.0)", "typing-extensions (>=4.0.0,<5.0)", "urllib3 (<3.0)", "uvicorn", "uvicorn (<1.0)", "websocket-client (<2.0)", "websockets (<12.0)"] +examples = ["Pillow (>=9.5.0)", "gym[classic-control] (>=0.17.0,<1.0)", "ipython[all] (<9.0)", "lightning-utilities (>=0.8.0,<1.0)", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.13.0,<1.0)", "torchvision (>=0.15.2,<1.0)"] +extra = ["Jinja2 (<4.0)", "PyYAML (<7.0)", "aiohttp (>=3.8.0,<4.0)", "arrow (>=1.2.0,<2.0)", "backoff (>=2.2.1,<3.0)", "beautifulsoup4 (>=4.8.0,<5.0)", "click (<9.0)", "croniter (>=1.3.0,<1.5.0)", "dateutils (<1.0)", "deepdiff (>=5.7.0,<7.0)", "docker (>=5.0.0,<7.0)", "fastapi (>=0.92.0,<1.0)", "fsspec (>=2022.5.0,<2024.0)", "hydra-core (>=1.0.5,<2.0)", "inquirer (>=2.10.0,<4.0)", "jsonargparse[signatures] (>=4.18.0,<5.0)", "lightning-api-access (>=0.0.3)", "lightning-cloud (==0.5.39)", "lightning-fabric (>=1.9.0)", "lightning-utilities (>=0.8.0,<1.0)", "matplotlib (>3.1,<4.0)", "omegaconf (>=2.0.5,<3.0)", "packaging", "panel (>=1.0.0,<2.0)", "psutil (<6.0)", "pydantic (>=1.7.4)", "python-multipart (>=0.0.5,<1.0)", "pytorch-lightning (>=1.9.0)", "redis (>=4.0.1,<6.0)", "requests (<3.0)", "rich (>=12.3.0,<14.0)", "s3fs (>=2022.5.0,<2024.0)", "starlette", "starsessions (>=1.2.1,<2.0)", "streamlit (>=1.13.0,<2.0)", "tensorboardX (>=2.2,<3.0)", "traitlets (>=5.3.0,<6.0)", "typing-extensions (>=4.0.0,<5.0)", "urllib3 (<3.0)", "uvicorn (<1.0)", "websocket-client (<2.0)", "websockets (<12.0)"] +fabric-all = ["deepspeed (>=0.8.2,<=0.9.3)", "lightning-utilities (>=0.8.0,<1.0)", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.13.0,<1.0)"] +fabric-dev = ["click (==8.1.7)", "coverage (==7.3.1)", "deepspeed (>=0.8.2,<=0.9.3)", "lightning-utilities (>=0.8.0,<1.0)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "tensorboardX (>=2.2,<3.0)", "torchmetrics (>=0.10.0,<2.0)", "torchmetrics (>=0.7.0,<2.0)", "torchvision (>=0.13.0,<1.0)"] +fabric-examples = ["lightning-utilities (>=0.8.0,<1.0)", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.13.0,<1.0)"] +fabric-strategies = ["deepspeed (>=0.8.2,<=0.9.3)"] +fabric-test = ["click (==8.1.7)", "coverage (==7.3.1)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "tensorboardX (>=2.2,<3.0)", "torchmetrics (>=0.7.0,<2.0)"] +pytorch-all = ["deepspeed (>=0.8.2,<=0.9.3)", "gym[classic-control] (>=0.17.0,<1.0)", "hydra-core (>=1.0.5,<2.0)", "ipython[all] (<9.0)", "jsonargparse[signatures] (>=4.18.0,<5.0)", "lightning-utilities (>=0.8.0,<1.0)", "matplotlib (>3.1,<4.0)", "omegaconf (>=2.0.5,<3.0)", "rich (>=12.3.0,<14.0)", "tensorboardX (>=2.2,<3.0)", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.13.0,<1.0)"] +pytorch-dev = ["cloudpickle (>=1.3,<3.0)", "coverage (==7.3.1)", "deepspeed (>=0.8.2,<=0.9.3)", "fastapi", "gym[classic-control] (>=0.17.0,<1.0)", "hydra-core (>=1.0.5,<2.0)", "ipython[all] (<9.0)", "jsonargparse[signatures] (>=4.18.0,<5.0)", "lightning-utilities (>=0.8.0,<1.0)", "matplotlib (>3.1,<4.0)", "omegaconf (>=2.0.5,<3.0)", "onnx (>=0.14.0,<2.0)", "onnxruntime (>=0.15.0,<2.0)", "pandas (>1.0,<3.0)", "psutil (<6.0)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "rich (>=12.3.0,<14.0)", "scikit-learn (>0.22.1,<2.0)", "tensorboard (>=2.9.1,<3.0)", "tensorboardX (>=2.2,<3.0)", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.13.0,<1.0)", "uvicorn"] +pytorch-examples = ["gym[classic-control] (>=0.17.0,<1.0)", "ipython[all] (<9.0)", "lightning-utilities (>=0.8.0,<1.0)", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.13.0,<1.0)"] +pytorch-extra = ["hydra-core (>=1.0.5,<2.0)", "jsonargparse[signatures] (>=4.18.0,<5.0)", "matplotlib (>3.1,<4.0)", "omegaconf (>=2.0.5,<3.0)", "rich (>=12.3.0,<14.0)", "tensorboardX (>=2.2,<3.0)"] +pytorch-strategies = ["deepspeed (>=0.8.2,<=0.9.3)"] +pytorch-test = ["cloudpickle (>=1.3,<3.0)", "coverage (==7.3.1)", "fastapi", "onnx (>=0.14.0,<2.0)", "onnxruntime (>=0.15.0,<2.0)", "pandas (>1.0,<3.0)", "psutil (<6.0)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "scikit-learn (>0.22.1,<2.0)", "tensorboard (>=2.9.1,<3.0)", "uvicorn"] +store = ["Jinja2 (<4.0)", "PyYAML (<7.0)", "arrow (>=1.2.0,<2.0)", "backoff (>=2.2.1,<3.0)", "beautifulsoup4 (>=4.8.0,<5.0)", "click (<9.0)", "croniter (>=1.3.0,<1.5.0)", "dateutils (<1.0)", "deepdiff (>=5.7.0,<7.0)", "fastapi (>=0.92.0,<1.0)", "fsspec (>=2022.5.0,<2024.0)", "inquirer (>=2.10.0,<4.0)", "lightning-cloud (==0.5.39)", "lightning-utilities (>=0.8.0,<1.0)", "packaging", "psutil (<6.0)", "pydantic (>=1.7.4)", "python-multipart (>=0.0.5,<1.0)", "requests (<3.0)", "rich (>=12.3.0,<14.0)", "starlette", "starsessions (>=1.2.1,<2.0)", "traitlets (>=5.3.0,<6.0)", "typing-extensions (>=4.0.0,<5.0)", "urllib3 (<3.0)", "uvicorn (<1.0)", "websocket-client (<2.0)", "websockets (<12.0)"] +store-test = ["coverage (==7.3.1)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)"] +strategies = ["deepspeed (>=0.8.2,<=0.9.3)"] +test = ["click (==8.1.7)", "cloudpickle (>=1.3,<3.0)", "coverage (==7.3.1)", "fastapi", "httpx (==0.25.0)", "onnx (>=0.14.0,<2.0)", "onnxruntime (>=0.15.0,<2.0)", "pandas (>1.0,<3.0)", "playwright (==1.38.0)", "psutil (<6.0)", "pympler", "pytest (==7.4.0)", "pytest-asyncio (==0.21.1)", "pytest-cov (==4.1.0)", "pytest-doctestplus (==0.9.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "pytest-xdist (==3.3.1)", "requests-mock (==1.11.0)", "scikit-learn (>0.22.1,<2.0)", "setuptools (<69.0)", "tensorboard (>=2.9.1,<3.0)", "tensorboardX (>=2.2,<3.0)", "torchmetrics (>=0.7.0,<2.0)", "trio (<0.22.0)", "uvicorn"] +ui = ["panel (>=1.0.0,<2.0)", "streamlit (>=1.13.0,<2.0)"] + +[[package]] +name = "lightning-utilities" +version = "0.9.0" +description = "PyTorch Lightning Sample project." +optional = false +python-versions = ">=3.7" +files = [ + {file = "lightning-utilities-0.9.0.tar.gz", hash = "sha256:efbf2c488c257f942abdfd06cf646fb84ca215a9663b60081811e22a15ee033b"}, + {file = "lightning_utilities-0.9.0-py3-none-any.whl", hash = "sha256:918dd90c775719e3855631db6282ad75c14da4c5727c4cebdd1589d865fad03d"}, +] + +[package.dependencies] +packaging = ">=17.1" +typing-extensions = "*" + +[package.extras] +cli = ["fire"] +docs = ["requests (>=2.0.0)"] +typing = ["mypy (>=1.0.0)"] + [[package]] name = "llvmlite" version = "0.40.1" @@ -2945,16 +3135,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -3156,6 +3336,21 @@ tqdm = ["tqdm (>=4.47.0)"] transformers = ["transformers (<4.22)"] zarr = ["zarr"] +[[package]] +name = "mpi4py" +version = "4.0.0.dev0" +description = "Python bindings for MPI" +optional = false +python-versions = ">=3.6" +files = [] +develop = false + +[package.source] +type = "git" +url = "https://github.com/mpi4py/mpi4py" +reference = "HEAD" +resolved_reference = "5c57df3ec3892cd2f961e6027a02cb4056ded4df" + [[package]] name = "multidict" version = "6.0.4" @@ -4596,6 +4791,37 @@ text-unidecode = ">=1.3" [package.extras] unidecode = ["Unidecode (>=1.1.1)"] +[[package]] +name = "pytorch-lightning" +version = "2.1.0" +description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytorch-lightning-2.1.0.tar.gz", hash = "sha256:bf9e26b293e1ccda5f8e146fe58716eecfd77e9639ef3ec2210b0dcba51c4593"}, + {file = "pytorch_lightning-2.1.0-py3-none-any.whl", hash = "sha256:2802d683ef513235dfc211f6bc45d7086e8982feaac1625aafd2886c5e5b96f8"}, +] + +[package.dependencies] +fsspec = {version = ">2021.06.0", extras = ["http"]} +lightning-utilities = ">=0.8.0" +numpy = ">=1.17.2" +packaging = ">=20.0" +PyYAML = ">=5.4" +torch = ">=1.12.0" +torchmetrics = ">=0.7.0" +tqdm = ">=4.57.0" +typing-extensions = ">=4.0.0" + +[package.extras] +all = ["deepspeed (>=0.8.2,<=0.9.3)", "gym[classic-control] (>=0.17.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.15.0)", "jsonargparse[signatures] (>=4.18.0)", "lightning-utilities (>=0.8.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=12.3.0)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.13.0)"] +deepspeed = ["deepspeed (>=0.8.2,<=0.9.3)"] +dev = ["cloudpickle (>=1.3)", "coverage (==7.3.1)", "deepspeed (>=0.8.2,<=0.9.3)", "fastapi", "gym[classic-control] (>=0.17.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.15.0)", "jsonargparse[signatures] (>=4.18.0)", "lightning-utilities (>=0.8.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "onnx (>=0.14.0)", "onnxruntime (>=0.15.0)", "pandas (>1.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "rich (>=12.3.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.13.0)", "uvicorn"] +examples = ["gym[classic-control] (>=0.17.0)", "ipython[all] (<8.15.0)", "lightning-utilities (>=0.8.0)", "torchmetrics (>=0.10.0)", "torchvision (>=0.13.0)"] +extra = ["hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.18.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=12.3.0)", "tensorboardX (>=2.2)"] +strategies = ["deepspeed (>=0.8.2,<=0.9.3)"] +test = ["cloudpickle (>=1.3)", "coverage (==7.3.1)", "fastapi", "onnx (>=0.14.0)", "onnxruntime (>=0.15.0)", "pandas (>1.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "uvicorn"] + [[package]] name = "pytz" version = "2023.3.post1" @@ -4708,7 +4934,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -4716,15 +4941,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -4741,7 +4959,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -4749,7 +4966,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -6569,6 +6785,34 @@ typing-extensions = "*" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] +[[package]] +name = "torchmetrics" +version = "1.2.0" +description = "PyTorch native Metrics" +optional = false +python-versions = ">=3.8" +files = [ + {file = "torchmetrics-1.2.0-py3-none-any.whl", hash = "sha256:da2cb18822b285786d082c40efb9e1d861aac425f58230234fe6ce233cf002f8"}, + {file = "torchmetrics-1.2.0.tar.gz", hash = "sha256:7eb28340bde45e13187a9ad54a4a7010a50417815d8181a5df6131f116ffe1b7"}, +] + +[package.dependencies] +lightning-utilities = ">=0.8.0" +numpy = ">1.20.0" +torch = ">=1.8.1" + +[package.extras] +all = ["SciencePlots (>=2.0.0)", "lpips (<=0.1.4)", "matplotlib (>=3.2.0)", "mypy (==1.5.1)", "nltk (>=3.6)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.3.0)", "regex (>=2021.9.24)", "scipy (>1.0.0)", "torch-fidelity (<=0.4.0)", "torchaudio (>=0.10.0)", "torchvision (>=0.8)", "tqdm (>=4.41.0)", "transformers (>4.4.0)", "transformers (>=4.10.0)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +audio = ["pystoi (>=0.3.0)", "torchaudio (>=0.10.0)"] +detection = ["pycocotools (>2.0.0)", "torchvision (>=0.8)"] +dev = ["SciencePlots (>=2.0.0)", "bert-score (==0.3.13)", "cloudpickle (>1.3)", "coverage (==7.3.1)", "dython (<=0.7.4)", "fairlearn", "fast-bss-eval (>=0.1.0)", "faster-coco-eval (>=1.3.3)", "fire (<=0.5.0)", "huggingface-hub (<0.18)", "jiwer (>=2.3.0)", "kornia (>=0.6.7)", "lpips (<=0.1.4)", "matplotlib (>=3.2.0)", "mir-eval (>=0.6)", "mypy (==1.5.1)", "netcal (>1.0.0)", "nltk (>=3.6)", "numpy (<1.25.0)", "pandas (>1.0.0)", "pandas (>=1.4.0)", "phmdoctest (==1.4.0)", "piq (<=0.8.0)", "psutil (<=5.9.5)", "pycocotools (>2.0.0)", "pystoi (>=0.3.0)", "pytest (==7.4.2)", "pytest-cov (==4.1.0)", "pytest-doctestplus (==1.0.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "pytorch-msssim (==1.0.0)", "regex (>=2021.9.24)", "requests (<=2.31.0)", "rouge-score (>0.1.0)", "sacrebleu (>=2.0.0)", "scikit-image (>=0.19.0)", "scikit-learn (>=1.1.1)", "scipy (>1.0.0)", "sewar (>=0.4.4)", "statsmodels (>0.13.5)", "torch-complex (<=0.4.3)", "torch-fidelity (<=0.4.0)", "torchaudio (>=0.10.0)", "torchvision (>=0.8)", "tqdm (>=4.41.0)", "transformers (>4.4.0)", "transformers (>=4.10.0)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +image = ["lpips (<=0.1.4)", "scipy (>1.0.0)", "torch-fidelity (<=0.4.0)", "torchvision (>=0.8)"] +multimodal = ["piq (<=0.8.0)", "transformers (>=4.10.0)"] +test = ["bert-score (==0.3.13)", "cloudpickle (>1.3)", "coverage (==7.3.1)", "dython (<=0.7.4)", "fairlearn", "fast-bss-eval (>=0.1.0)", "faster-coco-eval (>=1.3.3)", "fire (<=0.5.0)", "huggingface-hub (<0.18)", "jiwer (>=2.3.0)", "kornia (>=0.6.7)", "mir-eval (>=0.6)", "netcal (>1.0.0)", "numpy (<1.25.0)", "pandas (>1.0.0)", "pandas (>=1.4.0)", "phmdoctest (==1.4.0)", "psutil (<=5.9.5)", "pytest (==7.4.2)", "pytest-cov (==4.1.0)", "pytest-doctestplus (==1.0.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "pytorch-msssim (==1.0.0)", "requests (<=2.31.0)", "rouge-score (>0.1.0)", "sacrebleu (>=2.0.0)", "scikit-image (>=0.19.0)", "scikit-learn (>=1.1.1)", "scipy (>1.0.0)", "sewar (>=0.4.4)", "statsmodels (>0.13.5)", "torch-complex (<=0.4.3)"] +text = ["nltk (>=3.6)", "regex (>=2021.9.24)", "tqdm (>=4.41.0)", "transformers (>4.4.0)"] +typing = ["mypy (==1.5.1)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.2.0)"] + [[package]] name = "torchvision" version = "0.14.1" @@ -7342,10 +7586,10 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] [extras] -models = ["alibi", "alibi-detect", "hydra-core", "llvmlite", "scikit-learn", "torch", "torchxrayvision", "xgboost"] +models = ["alibi", "alibi-detect", "array-api-compat", "hydra-core", "llvmlite", "scikit-learn", "torch", "torchxrayvision", "xgboost"] report = ["kaleido", "pillow", "plotly", "pybtex", "pydantic", "scour", "spdx-tools"] [metadata] lock-version = "2.0" python-versions = ">=3.9, <3.11" -content-hash = "9f7fca5ba08c424ed004fbf8ef9df2947e8c7721726d12743d5cbac1496bbdd4" +content-hash = "922e7f891d06f8f640c0a89e8d5eabd49cba3cd2bb49cb6fac4043358db1ef95" diff --git a/pyproject.toml b/pyproject.toml index 9a77e4b08..77fa0fb55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ readme = "README.md" [tool.poetry.dependencies] python = ">=3.9, <3.11" pandas = "^2.0" -numpy = "^1.23.0" +numpy = "^1.24.4" datasets = "^2.10.1" psutil = "^5.9.4" pyarrow = "^11.0.0" @@ -35,6 +35,7 @@ kaleido = { version = "0.2.1", optional = true } scour = { version = "^0.38.2", optional = true } plotly = { version = "^5.7.0", optional = true } pillow = { version = "^9.5.0", optional = true } +array-api-compat = {version = "^1.4", optional = true} [tool.poetry.group.models.dependencies] hydra-core = "^1.2.0" @@ -45,6 +46,7 @@ xgboost = "^1.5.2" alibi = { version = "^0.9.4", extras = ["shap"] } alibi-detect = { version = "^0.11.0", extras = ["torch"] } llvmlite = "^0.40.0" +array-api-compat = {version = "^1.4"} [tool.poetry.group.report.dependencies] pydantic = "^1.10.11" @@ -69,6 +71,7 @@ mypy = "^1.0.0" ruff = "^0.1.0" nbqa = { version = "^1.7.0", extras = ["toolchain"] } cycquery = "^0.1.0" # used for integration test +torchmetrics = {version = "^1.2.0", extras = ["classification"]} [tool.poetry.group.docs] optional = true @@ -98,9 +101,13 @@ jupyter = "^1.0.0" jupyterlab = "^3.4.2" ipympl = "^0.9.3" ipywidgets = "^8.0.6" +torchmetrics = {version = "^1.2.0", extras = ["classification"]} +cupy = "^12.2.0" +mpi4py = {git = "https://github.com/mpi4py/mpi4py"} +lightning = "^2.1.0" [tool.poetry.extras] -models = ["hydra-core", "scikit-learn", "torch", "torchxrayvision", "xgboost", "alibi", "alibi-detect", "llvmlite"] +models = ["array-api-compat", "hydra-core", "scikit-learn", "torch", "torchxrayvision", "xgboost", "alibi", "alibi-detect", "llvmlite"] report = ["pydantic", "spdx-tools", "pybtex", "kaleido", "scour", "plotly", "pillow"] [tool.mypy] diff --git a/tests/cyclops/evaluate/__init__.py b/tests/cyclops/evaluate/__init__.py new file mode 100644 index 000000000..c314707bc --- /dev/null +++ b/tests/cyclops/evaluate/__init__.py @@ -0,0 +1 @@ +"""Tests for the `cyclops.evaluate` package.""" diff --git a/tests/cyclops/evaluate/metrics/__init__.py b/tests/cyclops/evaluate/metrics/__init__.py index 430dd576d..7922cd278 100644 --- a/tests/cyclops/evaluate/metrics/__init__.py +++ b/tests/cyclops/evaluate/metrics/__init__.py @@ -1 +1 @@ -"""Evaluate metrics testing package.""" +"""Tests for `cyclops.evaluate.metrics` package.""" diff --git a/tests/cyclops/evaluate/metrics/conftest.py b/tests/cyclops/evaluate/metrics/conftest.py new file mode 100644 index 000000000..f3241828e --- /dev/null +++ b/tests/cyclops/evaluate/metrics/conftest.py @@ -0,0 +1,76 @@ +"""pytest plugins and constants for tests/cyclops/evaluate/metrics/.""" +import contextlib +import os +import socket +import sys +from functools import partial + +import pytest +import torch.distributed +from torch.multiprocessing import Pool as TorchPool +from torch.multiprocessing import set_sharing_strategy, set_start_method + +from cyclops.utils.optional import import_optional_module + + +MPI_pool = import_optional_module("mpi4py.util.pool", error="ignore") + +with contextlib.suppress(RuntimeError): + set_start_method("spawn") + set_sharing_strategy("file_system") + + +NUM_PROCESSES = 2 +BATCH_SIZE = 16 * NUM_PROCESSES +NUM_BATCHES = 8 +NUM_CLASSES = 10 +NUM_LABELS = 10 +EXTRA_DIM = 4 +THRESHOLD = 0.6 + + +def get_open_port(): + """Get an open port. + + Reference + --------- + 1. https://stackoverflow.com/questions/66348957/pytorch-ddp-get-stuck-in-getting-free-port + + """ + with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +def setup_ddp(rank, world_size, port): + """Initialize distributed environment.""" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): + torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) + + +def pytest_configure(): + """Inject attributes into the pytest namespace.""" + torch_pool = TorchPool(processes=NUM_PROCESSES) + torch_pool.starmap( + partial(setup_ddp, port=get_open_port()), + [(rank, NUM_PROCESSES) for rank in range(NUM_PROCESSES)], + ) + pytest.torch_pool = torch_pool # type: ignore + + if MPI_pool is not None: + mpi_pool = MPI_pool.Pool(processes=NUM_PROCESSES, path=sys.path) + pytest.mpi_pool = mpi_pool # type: ignore + + +def pytest_sessionfinish(): + """Close the global multiprocessing pools after all tests are done.""" + pytest.torch_pool.close() # type: ignore + pytest.torch_pool.join() # type: ignore + + if MPI_pool is not None: + pytest.mpi_pool.close() # type: ignore + pytest.mpi_pool.join() # type: ignore diff --git a/tests/cyclops/evaluate/metrics/experimental/__init__.py b/tests/cyclops/evaluate/metrics/experimental/__init__.py new file mode 100644 index 000000000..cc2cc0b56 --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/__init__.py @@ -0,0 +1 @@ +"""Test the `cyclops.evaluate.metrics.experimental` package.""" diff --git a/tests/cyclops/evaluate/metrics/experimental/distributed_backends/__init__.py b/tests/cyclops/evaluate/metrics/experimental/distributed_backends/__init__.py new file mode 100644 index 000000000..9864c6b05 --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/distributed_backends/__init__.py @@ -0,0 +1 @@ +"""Tests for distributed backends.""" diff --git a/tests/cyclops/evaluate/metrics/experimental/distributed_backends/test_mpi4py.py b/tests/cyclops/evaluate/metrics/experimental/distributed_backends/test_mpi4py.py new file mode 100644 index 000000000..d2909676d --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/distributed_backends/test_mpi4py.py @@ -0,0 +1,175 @@ +"""Test mpi4py backend.""" + +from functools import partial + +import numpy as np +import numpy.array_api as anp +import pytest + +from cyclops.evaluate.metrics.experimental.distributed_backends.mpi4py import ( + MPI4Py, +) +from cyclops.utils.optional import import_optional_module + +from ...conftest import NUM_PROCESSES +from ..testers import DummyListStateMetric, DummyMetric + + +MPI = import_optional_module("mpi4py.MPI", error="ignore") + + +def _test_mpi4py_class_init(rank: int, worldsize: int = 2): + """Run test.""" + if MPI is None: + with pytest.raises( + ImportError, + match="For availability of MPI4Py please install mpi4py first.", + ): + backend = MPI4Py() + assert not backend.is_initialized + pytest.skip("`mpi4py` is not installed.") + + backend = MPI4Py() + assert backend.is_initialized + assert backend.rank == rank + assert backend.world_size == worldsize + + +@pytest.mark.integration_test() +def test_mpi4py_backend_class_init(): + """Test `TorchDistributed` class.""" + pytest.mpi_pool.starmap(_test_mpi4py_class_init, [(rank, 2) for rank in range(2)]) # type: ignore + + +def _test_all_gather_simple(rank: int, worldsize: int = 2): + """Run test.""" + backend = MPI4Py() + + array = anp.ones(5) + result = backend.all_gather(array) + assert len(result) == worldsize + for idx in range(worldsize): + val = result[idx] + assert anp.all(val == anp.ones_like(val)) + + +def _test_all_gather_uneven_arrays(rank: int, worldsize: int = 2): + """Run test.""" + backend = MPI4Py() + + array = anp.ones(rank) + result = backend.all_gather(array) + assert len(result) == worldsize + for idx in range(worldsize): + val = result[idx] + assert anp.all(val == anp.ones_like(val)) + + +def _test_all_gather_uneven_multidim_arrays(rank: int, worldsize: int = 2): + """Run test.""" + backend = MPI4Py() + + array = anp.ones((rank + 1, 2 - rank, 2)) + result = backend.all_gather(array) + assert len(result) == worldsize + for idx in range(worldsize): + val = result[idx] + assert anp.all(val == anp.ones_like(val)) + + +@pytest.mark.integration_test() +@pytest.mark.skipif(MPI is None, reason="`mpi4py` is not installed.") +@pytest.mark.parametrize( + "case_fn", + [ + _test_all_gather_simple, + _test_all_gather_uneven_arrays, + _test_all_gather_uneven_multidim_arrays, + ], +) +def test_mpi4py_all_gather(case_fn): + """Test `all_gather` method.""" + pytest.mpi_pool.starmap(case_fn, [(rank, 2) for rank in range(NUM_PROCESSES)]) # type: ignore + + +def _test_dist_sum(rank: int, worldsize: int = NUM_PROCESSES) -> None: + dummy = DummyMetric(dist_backend="mpi4py") + dummy._reductions = {"foo": anp.sum} + dummy.foo = anp.asarray(1) + dummy.sync() + + assert anp.all(dummy.foo == anp.asarray(worldsize)) + + +def _test_dist_cat(rank: int, worldsize: int = NUM_PROCESSES) -> None: + dummy = DummyMetric(dist_backend="mpi4py") + dummy._reductions = {"foo": anp.concat} + dummy.foo = [anp.asarray([1])] + dummy.sync() + + assert anp.all(anp.equal(dummy.foo, anp.asarray([1, 1]))) + + +def _test_dist_sum_cat(rank: int, worldsize: int = NUM_PROCESSES) -> None: + dummy = DummyMetric(dist_backend="mpi4py") + dummy._reductions = {"foo": anp.concat, "bar": anp.sum} + dummy.foo = [anp.asarray([1])] + dummy.bar = anp.asarray(1) + dummy.sync() + + assert anp.all(anp.equal(dummy.foo, anp.asarray([1, 1]))) + assert dummy.bar == worldsize + + +def _test_dist_compositional_array(rank: int, worldsize: int = NUM_PROCESSES) -> None: + dummy = DummyMetric(dist_backend="mpi4py") + dummy = dummy.clone() + dummy.clone() + dummy.update(anp.asarray(1, dtype=anp.float32)) + val = dummy.compute() + print(val) + assert val == 2 * worldsize + + +@pytest.mark.integration_test() +@pytest.mark.skipif(MPI is None, reason="`mpi4py` is not available") +@pytest.mark.parametrize( + "process", + [ + _test_dist_cat, + _test_dist_sum, + _test_dist_sum_cat, + _test_dist_compositional_array, + ], +) +def test_ddp(process): + """Test ddp functions.""" + pytest.mpi_pool.map(process, range(NUM_PROCESSES)) # type: ignore + + +def _test_sync_on_compute_array_state(rank): + dummy = DummyMetric(dist_backend="mpi4py") + dummy.update(anp.asarray(rank + 1, dtype=anp.float32)) + val = dummy.compute() + + assert anp.all(val == 3) + + +def _test_sync_on_compute_list_state(rank): + dummy = DummyListStateMetric(dist_backend="mpi4py") + dummy.update(anp.asarray(rank + 1, dtype=anp.float32)) + val = dummy.compute() + assert anp.all(anp.sum(val) == 3) + assert np.allclose(val, anp.asarray([1, 2])) or np.allclose( + val, + anp.asarray([2, 1]), + ) + + +@pytest.mark.integration_test() +@pytest.mark.parametrize( + "test_func", + [_test_sync_on_compute_list_state, _test_sync_on_compute_array_state], +) +def test_sync_on_compute(test_func): + """Test that synchronization of states can be enabled and disabled for compute.""" + pytest.mpi_pool.map(partial(test_func), range(NUM_PROCESSES)) # type: ignore diff --git a/tests/cyclops/evaluate/metrics/experimental/distributed_backends/test_torch_distributed.py b/tests/cyclops/evaluate/metrics/experimental/distributed_backends/test_torch_distributed.py new file mode 100644 index 000000000..17077a205 --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/distributed_backends/test_torch_distributed.py @@ -0,0 +1,153 @@ +"""Test the torch distributed backend.""" +import sys +from functools import partial + +import pytest + +from cyclops.evaluate.metrics.experimental.distributed_backends.torch_distributed import ( + TorchDistributed, +) +from cyclops.utils.optional import import_optional_module + +from ...conftest import NUM_PROCESSES +from ..testers import DummyListStateMetric, DummyMetric + + +torch = import_optional_module("torch", error="ignore") + + +def _test_torch_distributed_class(rank: int, worldsize: int = NUM_PROCESSES): + """Run test.""" + torch_dist = import_optional_module("torch.distributed", error="ignore") + if torch is None: + with pytest.raises( + ImportError, + match="For availability of TorchDistributed please install .*", + ): + backend = TorchDistributed() + + if torch_dist is None: + with pytest.raises(RuntimeError): + backend = TorchDistributed() + + # skip if torch distributed is not available + if torch_dist is None: + pytest.skip("torch.distributed is not available") + + backend = TorchDistributed() + + assert backend.is_initialized == torch_dist.is_initialized() + assert backend.rank == rank + assert backend.world_size == worldsize + + # test all simple all gather (tensors of the same size) + tensor = torch.ones(2) # type: ignore + result = backend._simple_all_gather(tensor) + assert len(result) == worldsize + for idx in range(worldsize): + val = result[idx] + assert (val == torch.ones_like(val)).all() # type: ignore + + # test all gather uneven tensors + tensor = torch.ones(rank) # type: ignore + result = backend.all_gather(tensor) + assert len(result) == worldsize + for idx in range(worldsize): + val = result[idx] + assert (val == torch.ones_like(val)).all() # type: ignore + + # test all gather multidimensional uneven tensors + tensor = torch.ones(rank + 1, 2 - rank) # type: ignore + result = backend.all_gather(tensor) + assert len(result) == worldsize + for idx in range(worldsize): + val = result[idx] + assert (val == torch.ones_like(val)).all() # type: ignore + + +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +def test_torch_distributed_backend_class(): + """Test `TorchDistributed` class.""" + pytest.torch_pool.map(_test_torch_distributed_class, range(NUM_PROCESSES)) # type: ignore + + +def _test_dist_sum(rank: int, worldsize: int = NUM_PROCESSES) -> None: + dummy = DummyMetric(dist_backend="torch_distributed") + dummy._reductions = {"foo": torch.sum} + dummy.foo = torch.tensor(1) + dummy.sync() + + assert dummy.foo == worldsize + + +def _test_dist_cat(rank: int, worldsize: int = NUM_PROCESSES) -> None: + dummy = DummyMetric(dist_backend="torch_distributed") + dummy._reductions = {"foo": torch.cat} + dummy.foo = [torch.tensor([1])] + dummy.sync() + + assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) + + +def _test_dist_sum_cat(rank: int, worldsize: int = NUM_PROCESSES) -> None: + dummy = DummyMetric(dist_backend="torch_distributed") + dummy._reductions = {"foo": torch.cat, "bar": torch.sum} + dummy.foo = [torch.tensor([1])] + dummy.bar = torch.tensor(1) + dummy.sync() + + assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) + assert dummy.bar == worldsize + + +def _test_dist_compositional_tensor(rank: int, worldsize: int = NUM_PROCESSES) -> None: + dummy = DummyMetric(dist_backend="torch_distributed") + dummy = dummy.clone() + dummy.clone() + dummy.update(torch.tensor(1)) + val = dummy.compute() + assert val == 2 * worldsize + + +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +@pytest.mark.skipif(torch is None, reason="torch is not available") +@pytest.mark.parametrize( + "process", + [ + _test_dist_cat, + _test_dist_sum, + _test_dist_sum_cat, + _test_dist_compositional_tensor, + ], +) +def test_ddp(process): + """Test ddp functions.""" + pytest.torch_pool.map(process, range(NUM_PROCESSES)) # type: ignore + + +def _test_sync_on_compute_tensor_state(rank): + dummy = DummyMetric(dist_backend="torch_distributed") + dummy.update(torch.tensor(rank + 1)) + val = dummy.compute() + + assert val == 3 + + +def _test_sync_on_compute_list_state(rank): + dummy = DummyListStateMetric(dist_backend="torch_distributed") + dummy.update(torch.tensor(rank + 1)) + val = dummy.compute() + assert val.sum() == 3 + assert torch.allclose(val, torch.tensor([1, 2])) or torch.allclose( + val, + torch.tensor([2, 1]), + ) + + +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +@pytest.mark.parametrize( + "test_func", + [_test_sync_on_compute_list_state, _test_sync_on_compute_tensor_state], +) +def test_sync_on_compute(test_func): + """Test that synchronization of states can be enabled and disabled for compute.""" + pytest.torch_pool.map(partial(test_func), range(NUM_PROCESSES)) diff --git a/tests/cyclops/evaluate/metrics/experimental/inputs.py b/tests/cyclops/evaluate/metrics/experimental/inputs.py new file mode 100644 index 000000000..66f7fb191 --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/inputs.py @@ -0,0 +1,331 @@ +"""Input data for tests of metrics in cyclops/evaluate/metrics/experimental.""" +import random +from collections import namedtuple +from typing import Any + +import array_api_compat as apc +import numpy as np +import pytest +import torch +from scipy.special import log_softmax + +from cyclops.evaluate.metrics.experimental.utils.typing import Array + +from ..conftest import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, NUM_LABELS + + +InputSpec = namedtuple("InputSpec", ["target", "preds"]) + + +def set_random_seed(seed: int) -> None: + """Set random seed.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def _inv_sigmoid(arr: Array) -> Array: + """Inverse sigmoid function.""" + xp = apc.array_namespace(arr) + return xp.log(arr / (1 - arr)) + + +set_random_seed(1) + +# binary +# NOTE: the test will loop over the first dimension of the input +_binary_labels_0d = np.random.randint(0, 2, size=(NUM_BATCHES)) +_binary_preds_0d = np.random.randint(0, 2, size=(NUM_BATCHES)) +_binary_probs_0d = np.random.rand(NUM_BATCHES) +_binary_labels_1d = np.random.randint(0, 2, size=(NUM_BATCHES, BATCH_SIZE)) +_binary_preds_1d = np.random.randint(0, 2, size=(NUM_BATCHES, BATCH_SIZE)) +_binary_probs_1d = np.random.rand(NUM_BATCHES, BATCH_SIZE) +_binary_labels_multidim = np.random.randint( + 0, + 2, + size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), +) +_binary_preds_multidim = np.random.randint( + 0, + 2, + size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), +) +_binary_probs_multidim = np.random.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM) + + +def _binary_cases(*, xp: Any): + """Return binary input cases for the given array namespace.""" + return ( + pytest.param( + InputSpec( + target=xp.asarray(_binary_labels_0d), + preds=xp.asarray(_binary_preds_0d), + ), + id="input[0d-labels]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_binary_labels_0d), + preds=xp.asarray(_binary_probs_0d), + ), + id="input[0d-probs]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_binary_labels_0d), + preds=xp.asarray(_inv_sigmoid(_binary_probs_0d)), + ), + id="input[0d-logits]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_binary_labels_1d), + preds=xp.asarray(_binary_preds_1d), + ), + id="input[1d-labels]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_binary_labels_1d), + preds=xp.asarray(_binary_probs_1d), + ), + id="input[1d-probs]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_binary_labels_1d), + preds=xp.asarray(_inv_sigmoid(_binary_probs_1d)), + ), + id="input[1d-probs]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_binary_labels_multidim), + preds=xp.asarray(_binary_preds_multidim), + ), + id="input[multidim-labels]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_binary_labels_multidim), + preds=xp.asarray(_binary_probs_multidim), + ), + id="input[multidim-probs]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_binary_labels_multidim), + preds=xp.asarray(_inv_sigmoid(_binary_probs_multidim)), + ), + id="input[multidim-probs]", + ), + ) + + +def _multiclass_with_missing_class( + *shape: Any, + num_classes: int = NUM_CLASSES, + xp: Any, +) -> Array: + """Generate multiclass input where a class is missing. + + Args: + shape: shape of the tensor + num_classes: number of classes + + Returns + ------- + tensor with missing classes + + """ + x = np.random.randint(0, num_classes, shape) + x[x == 0] = 2 + return xp.asarray(x) + + +# multiclass +_multiclass_labels_0d = np.random.randint(0, NUM_CLASSES, size=(NUM_BATCHES)) +_multiclass_preds_0d = np.random.randint(0, NUM_CLASSES, size=(NUM_BATCHES)) +_multiclass_probs_0d = np.random.rand(NUM_BATCHES, NUM_CLASSES) +_multiclass_labels_1d = np.random.randint( + 0, + NUM_CLASSES, + size=(NUM_BATCHES, BATCH_SIZE), +) +_multiclass_preds_1d = np.random.randint(0, NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) +_multiclass_probs_1d = np.random.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) +_multiclass_labels_multidim = np.random.randint( + 0, + NUM_CLASSES, + size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), +) +_multiclass_preds_multidim = np.random.randint( + 0, + NUM_CLASSES, + size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), +) +_multiclass_probs_multidim = np.random.rand( + NUM_BATCHES, + BATCH_SIZE, + NUM_CLASSES, + EXTRA_DIM, +) + + +def _multiclass_cases(*, xp: Any): + """Return multiclass input cases for the given array namespace.""" + return ( + pytest.param( + InputSpec( + target=xp.asarray(_multiclass_labels_0d), + preds=xp.asarray(_multiclass_preds_0d), + ), + id="input[0d-labels]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multiclass_labels_0d), + preds=xp.asarray(_multiclass_probs_0d), + ), + id="input[0d-probs]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multiclass_labels_0d), + preds=xp.asarray(log_softmax(_multiclass_probs_0d, axis=-1)), + ), + id="input[0d-logits]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multiclass_labels_1d), + preds=xp.asarray(_multiclass_preds_1d), + ), + id="input[1d-labels]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multiclass_labels_1d), + preds=xp.asarray(_multiclass_probs_1d), + ), + id="input[1d-probs]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multiclass_labels_1d), + preds=xp.asarray(log_softmax(_multiclass_probs_1d, axis=-1)), + ), + id="input[1d-logits]", + ), + pytest.param( + InputSpec( + preds=_multiclass_with_missing_class( + NUM_BATCHES, + BATCH_SIZE, + num_classes=NUM_CLASSES, + xp=xp, + ), + target=_multiclass_with_missing_class( + NUM_BATCHES, + BATCH_SIZE, + num_classes=NUM_CLASSES, + xp=xp, + ), + ), + id="input[1d-labels-missing_class]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multiclass_labels_multidim), + preds=xp.asarray(_multiclass_preds_multidim), + ), + id="input[multidim-labels]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multiclass_labels_multidim), + preds=xp.asarray(_multiclass_probs_multidim), + ), + id="input[multidim-probs]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multiclass_labels_multidim), + preds=xp.asarray(log_softmax(_multiclass_probs_multidim, axis=-1)), + ), + id="input[multidim-logits]", + ), + ) + + +# multilabel +_multilabel_labels = np.random.randint(0, 2, size=(NUM_BATCHES, BATCH_SIZE, NUM_LABELS)) +_multilabel_preds = np.random.randint( + 0, + 2, + size=(NUM_BATCHES, BATCH_SIZE, NUM_LABELS), +) +_multilabel_probs = np.random.rand(NUM_BATCHES, BATCH_SIZE, NUM_LABELS) +_multilabel_labels_multidim = np.random.randint( + 0, + 2, + size=(NUM_BATCHES, BATCH_SIZE, NUM_LABELS, EXTRA_DIM), +) +_multilabel_preds_multidim = np.random.randint( + 0, + 2, + size=(NUM_BATCHES, BATCH_SIZE, NUM_LABELS, EXTRA_DIM), +) +_multilabel_probs_multidim = np.random.rand( + NUM_BATCHES, + BATCH_SIZE, + NUM_LABELS, + EXTRA_DIM, +) + + +def _multilabel_cases(*, xp: Any): + return ( + pytest.param( + InputSpec( + target=xp.asarray(_multilabel_labels), + preds=xp.asarray(_multilabel_preds), + ), + id="input[2d-labels]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multilabel_labels), + preds=xp.asarray(_multilabel_probs), + ), + id="input[2d-probs]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multilabel_labels), + preds=xp.asarray(_inv_sigmoid(_multilabel_probs)), + ), + id="input[2d-logits]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multilabel_labels_multidim), + preds=xp.asarray(_multilabel_preds_multidim), + ), + id="input[multidim-labels]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multilabel_labels_multidim), + preds=xp.asarray(_multilabel_probs_multidim), + ), + id="input[multidim-probs]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multilabel_labels_multidim), + preds=xp.asarray(_inv_sigmoid(_multilabel_probs_multidim)), + ), + id="input[multidim-logits]", + ), + ) diff --git a/tests/cyclops/evaluate/metrics/experimental/test_confusion_matrix.py b/tests/cyclops/evaluate/metrics/experimental/test_confusion_matrix.py new file mode 100644 index 000000000..73b378c31 --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/test_confusion_matrix.py @@ -0,0 +1,425 @@ +"""Test confusion matrix metrics.""" +from functools import partial + +import array_api_compat as apc +import array_api_compat.torch +import numpy.array_api as anp +import pytest +import torch.utils.dlpack +from torchmetrics.functional.classification import ( + binary_confusion_matrix as tm_binary_confusion_matrix, +) +from torchmetrics.functional.classification import ( + multiclass_confusion_matrix as tm_multiclass_confusion_matrix, +) +from torchmetrics.functional.classification import ( + multilabel_confusion_matrix as tm_multilabel_confusion_matrix, +) + +from cyclops.evaluate.metrics.experimental.confusion_matrix import ( + BinaryConfusionMatrix, + MulticlassConfusionMatrix, + MultilabelConfusionMatrix, +) +from cyclops.evaluate.metrics.experimental.functional.confusion_matrix import ( + binary_confusion_matrix, + multiclass_confusion_matrix, + multilabel_confusion_matrix, +) +from cyclops.evaluate.metrics.experimental.utils.ops import to_int +from cyclops.evaluate.metrics.experimental.utils.validation import is_floating_point + +from ..conftest import NUM_CLASSES, NUM_LABELS, THRESHOLD +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from .testers import MetricTester, _inject_ignore_index + + +def _binary_confusion_matrix_reference( + target, + preds, + threshold, + normalize, + ignore_index, +) -> torch.Tensor: + """Return the reference binary confusion matrix.""" + return tm_binary_confusion_matrix( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + threshold=threshold, + normalize=normalize, + ignore_index=ignore_index, + ) + + +class TestBinaryConfusionMatrix(MetricTester): + """Test binary confusion matrix function and class.""" + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)) + @pytest.mark.parametrize("normalize", [None, "true", "pred", "all"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_confusion_matrix_function_with_numpy_array_api_arrays( + self, + inputs, + normalize, + ignore_index, + ) -> None: + """Test function for binary confusion matrix using numpy.array_api arrays.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=binary_confusion_matrix, + metric_args={ + "threshold": THRESHOLD, + "normalize": normalize, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_confusion_matrix_reference, + threshold=THRESHOLD, + normalize=normalize, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)) + @pytest.mark.parametrize("normalize", [None, "true", "pred", "all"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_confusion_matrix_class_with_numpy_array_api_arrays( + self, + inputs, + normalize, + ignore_index, + ) -> None: + """Test class for binary confusion matrix.""" + target, preds = inputs + + if ( + preds.ndim == 1 + and is_floating_point(preds) + and not anp.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinaryConfusionMatrix, + metric_args={ + "threshold": THRESHOLD, + "normalize": normalize, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_confusion_matrix_reference, + threshold=THRESHOLD, + normalize=normalize, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _binary_cases(xp=array_api_compat.torch)) + @pytest.mark.parametrize("normalize", [None, "true", "pred", "all"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_confusion_matrix_class_with_torch_tensors( + self, + inputs, + normalize, + ignore_index, + ) -> None: + """Test binary confusion matrix class with torch tensors.""" + target, preds = inputs + + if ( + preds.ndim == 1 + and is_floating_point(preds) + and not torch.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinaryConfusionMatrix, + metric_args={ + "threshold": THRESHOLD, + "normalize": normalize, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_confusion_matrix_reference, + threshold=THRESHOLD, + normalize=normalize, + ignore_index=ignore_index, + ), + device=device, + use_device_for_ref=True, + ) + + +def _multiclass_confusion_matrix_reference( + target, + preds, + num_classes=NUM_CLASSES, + normalize=None, + ignore_index=None, +) -> torch.Tensor: + """Return the reference multiclass confusion matrix.""" + print(preds, target) + if preds.ndim == 1 and is_floating_point(preds): + xp = apc.array_namespace(preds) + preds = xp.argmax(preds, axis=0) + + return tm_multiclass_confusion_matrix( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_classes, + normalize=normalize, + ignore_index=ignore_index, + ) + + +class TestMulticlassConfusionMatrix(MetricTester): + """Test multiclass confusion matrix function and class.""" + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)) + @pytest.mark.parametrize("normalize", [None, "true", "pred", "all"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multiclass_confusion_matrix_with_numpy_array_api_arrays( + self, + inputs, + normalize, + ignore_index, + ) -> None: + """Test function for multiclass confusion matrix.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multiclass_confusion_matrix, + metric_args={ + "num_classes": NUM_CLASSES, + "normalize": normalize, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _multiclass_confusion_matrix_reference, + normalize=normalize, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)) + @pytest.mark.parametrize("normalize", [None, "true", "pred", "all"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_confusion_matrix_class_with_numpy_array_api_arrays( + self, + inputs, + normalize, + ignore_index, + ) -> None: + """Test class for multiclass confusion matrix.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MulticlassConfusionMatrix, + reference_metric=partial( + _multiclass_confusion_matrix_reference, + normalize=normalize, + ignore_index=ignore_index, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "normalize": normalize, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=array_api_compat.torch)) + @pytest.mark.parametrize("normalize", [None, "true", "pred", "all"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_confusion_matrix_class_with_torch_tensors( + self, + inputs, + normalize, + ignore_index, + ) -> None: + """Test class for multiclass confusion matrix.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MulticlassConfusionMatrix, + reference_metric=partial( + _multiclass_confusion_matrix_reference, + normalize=normalize, + ignore_index=ignore_index, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "normalize": normalize, + "ignore_index": ignore_index, + }, + device=device, + use_device_for_ref=True, + ) + + +def _multilabel_confusion_matrix_reference( + preds, + target, + threshold, + num_labels=NUM_LABELS, + normalize=None, + ignore_index=None, +) -> torch.Tensor: + """Return the reference multilabel confusion matrix.""" + return tm_multilabel_confusion_matrix( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_labels, + threshold=threshold, + normalize=normalize, + ignore_index=ignore_index, + ) + + +class TestMultilabelConfusionMatrix(MetricTester): + """Test multilabel confusion matrix function and class.""" + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)) + @pytest.mark.parametrize("normalize", [None, "true", "pred", "all"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_confusion_matrix_with_numpy_array_api_arrays( + self, + inputs, + normalize, + ignore_index, + ) -> None: + """Test function for multilabel confusion matrix.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multilabel_confusion_matrix, + reference_metric=partial( + _multilabel_confusion_matrix_reference, + threshold=THRESHOLD, + normalize=normalize, + ignore_index=ignore_index, + ), + metric_args={ + "threshold": THRESHOLD, + "num_labels": NUM_LABELS, + "normalize": normalize, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)) + @pytest.mark.parametrize("normalize", [None, "true", "pred", "all"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_confusion_matrix_class_with_numpy_array_api_arrays( + self, + inputs, + normalize, + ignore_index, + ) -> None: + """Test class for multilabel confusion matrix.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelConfusionMatrix, + reference_metric=partial( + _multilabel_confusion_matrix_reference, + threshold=THRESHOLD, + normalize=normalize, + ignore_index=ignore_index, + ), + metric_args={ + "threshold": THRESHOLD, + "num_labels": NUM_LABELS, + "normalize": normalize, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)) + @pytest.mark.parametrize("normalize", [None, "true", "pred", "all"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_confusion_matrix_class_with_torch_tensors( + self, + inputs, + normalize, + ignore_index, + ) -> None: + """Test class for multilabel confusion matrix.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelConfusionMatrix, + reference_metric=partial( + _multilabel_confusion_matrix_reference, + threshold=THRESHOLD, + normalize=normalize, + ignore_index=ignore_index, + ), + metric_args={ + "threshold": THRESHOLD, + "num_labels": NUM_LABELS, + "normalize": normalize, + "ignore_index": ignore_index, + }, + ) diff --git a/tests/cyclops/evaluate/metrics/experimental/test_metric.py b/tests/cyclops/evaluate/metrics/experimental/test_metric.py new file mode 100644 index 000000000..39c80fb75 --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/test_metric.py @@ -0,0 +1,419 @@ +"""Tests for the base class of metrics.""" +import array_api_compat as apc +import numpy as np +import numpy.array_api as anp +import pytest +import torch + +from cyclops.evaluate.metrics.experimental.utils.ops import ( + dim_zero_cat, + dim_zero_max, + dim_zero_mean, + dim_zero_min, + dim_zero_sum, +) + +from .testers import DummyListStateMetric, DummyMetric + + +def test_inherit(): + """Test that metric that inherits can be instantiated.""" + DummyMetric() + + +def test_dist_backend_kwarg(): + """Test different options for `dist_backend` kwarg.""" + with pytest.raises( + ValueError, + match="Backend `nonexistent` is not found.*", + ): + DummyMetric(dist_backend="nonexistent") + + with pytest.raises( + TypeError, + match="Expected `name` to be a str, but got .", + ): + DummyMetric(dist_backend=42) + + +def test_add_state_factory(): + """Test that `add_state_default_factory` method works as expected.""" + metric = DummyMetric() + + # happy path + # default_factory is callable with single argument (xp) + metric.add_state_default_factory("a", lambda xp: xp.asarray(0), None) # type: ignore + reduce_fn = metric._reductions["a"] + assert reduce_fn is None, "Saved reduction function is not None." + assert ( + metric._default_factories.get("a") is not None + ), "Default factory was not correctly created." + + # default_factory is 'list' + metric.add_state_default_factory("b", list) # type: ignore + assert ( + metric._default_factories.get("b") == list + ), "Default factory should be 'list'." + + # dist_reduce_fn is "sum" + metric.add_state_default_factory("c", lambda xp: xp.asarray(0), "sum") # type: ignore + reduce_fn = metric._reductions["c"] + assert callable(reduce_fn), "Saved reduction function is not callable." + assert reduce_fn is dim_zero_sum, ( + "Saved reduction function is not the same as the one used to " + "create the state." + ) + assert reduce_fn(anp.asarray([1, 1])) == anp.asarray( + 2, + ), "Saved reduction function does not work as expected." + + # dist_reduce_fn is "mean" + metric.add_state_default_factory("d", lambda xp: xp.asarray(0), "mean") # type: ignore + reduce_fn = metric._reductions["d"] + assert callable(reduce_fn), "Saved reduction function is not callable." + assert reduce_fn is dim_zero_mean, ( + "Saved reduction function is not the same as the one used to " + "create the state." + ) + assert np.allclose( + reduce_fn(anp.asarray([1.0, 2.0])), + 1.5, + ), "Saved reduction function does not work as expected." + + # dist_reduce_fn is "cat" + metric.add_state_default_factory("e", lambda xp: xp.asarray(0), "cat") # type: ignore + reduce_fn = metric._reductions["e"] + assert callable(reduce_fn), "Saved reduction function is not callable." + assert reduce_fn is dim_zero_cat, ( + "Saved reduction function is not the same as the one used to " + "create the state." + ) + np.testing.assert_array_equal( + reduce_fn([anp.asarray([1]), anp.asarray([1])]), + anp.asarray([1, 1]), + err_msg="Saved reduction function does not work as expected.", + ) + + # dist_reduce_fn is "max" + metric.add_state_default_factory("f", lambda xp: xp.asarray(0), "max") # type: ignore + reduce_fn = metric._reductions["f"] + assert callable(reduce_fn), "Saved reduction function is not callable." + assert reduce_fn is dim_zero_max, ( + "Saved reduction function is not the same as the one used to " + "create the state." + ) + np.testing.assert_array_equal( + reduce_fn(anp.asarray([1, 2])), + anp.asarray(2), + err_msg="Saved reduction function does not work as expected.", + ) + + # dist_reduce_fn is "min" + metric.add_state_default_factory("g", lambda xp: xp.asarray(0), "min") # type: ignore + metric._add_states(anp) + reduce_fn = metric._reductions["g"] + assert callable(reduce_fn), "Saved reduction function is not callable." + assert reduce_fn is dim_zero_min, ( + "Saved reduction function is not the same as the one used to " + "create the state." + ) + np.testing.assert_array_equal( + reduce_fn(anp.asarray([1, 2])), + anp.asarray(1), + err_msg="Saved reduction function does not work as expected.", + ) + + # custom reduction function + def custom_fn(_): + return anp.asarray(-1) + + metric.add_state_default_factory("h", lambda xp: xp.asarray(0), custom_fn) # type: ignore + assert metric._reductions["h"](anp.asarray([1, 1])) == anp.asarray(-1) # type: ignore + + # test that default values are set correctly + metric._add_states(anp) + for name in "abcdefgh": + default = metric._defaults.get(name, None) + assert default is not None, f"Default value for {name} is None." + if apc.is_array_api_obj(default): + np.testing.assert_array_equal( + default, + anp.asarray(0), + err_msg=f"Default value for {name} is not 0.", + ) + else: + assert default == [] + + assert hasattr(metric, name), f"Metric does not have attribute {name}." + attr_val = getattr(metric, name) + if apc.is_array_api_obj(default): + np.testing.assert_array_equal( + attr_val, + anp.asarray(0), + err_msg=f"Attribute {name} is not 0.", + ) + else: + assert attr_val == [] + + +def test_add_state_default_factory_invalid_input(): + """Test that `add_state_default_factory` method raises errors as expected.""" + metric = DummyMetric() + with pytest.raises( + ValueError, + match="`dist_reduce_fn` must be callable or one of .*", + ): + metric.add_state_default_factory("h1", lambda xp: xp.asarray(0), "xyz") # type: ignore + + with pytest.raises( + ValueError, + match="`dist_reduce_fn` must be callable or one of .*", + ): + metric.add_state_default_factory("h2", lambda xp: xp.asarray(0), 42) # type: ignore + + with pytest.raises( + TypeError, + match="Expected `default_factory` to be a callable, but got .", + ): + metric.add_state_default_factory("h3", [lambda xp: xp.asarray(0)], "sum") # type: ignore + + with pytest.raises( + TypeError, + match="Expected `default_factory` to be a callable, but got .", + ): + metric.add_state_default_factory("h4", 42, "sum") # type: ignore + + def custom_fn(xp, _): + return xp.asarray(-1) + + with pytest.raises( + TypeError, + match="Expected `default_factory` to be a function that takes at most .*", + ): + metric.add_state_default_factory("h5", custom_fn) # type: ignore + + with pytest.raises( + ValueError, + match="Argument `name` must be a valid python identifier. Got `h6!`.", + ): + metric.add_state_default_factory("h6!", list) # type: ignore + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA is not available.", +) +def test_to_device(): + """Test that `to_device` method works as expected.""" + metric = DummyMetric() + assert metric.device == "cpu" + + metric = metric.to_device("cuda") + assert metric.device == "cuda" + metric.update(torch.tensor(42, device="cuda")) # type: ignore + assert metric.x.device.type == "cuda" # type: ignore + + metric = metric.to_device("cpu") + assert metric.device == "cpu" + + metric = metric.to_device("cuda") + assert metric.device == "cuda" + metric.reset() + assert metric.x.device.type == "cuda" # type: ignore + + metric = DummyListStateMetric() + assert metric.device == "cpu" + + metric = metric.to_device("cuda") + metric.update(torch.tensor(1.0).to("cuda")) # type: ignore + metric.compute() + torch.testing.assert_close(metric.x, [torch.tensor(1.0, device="cuda")]) # type: ignore + + metric = metric.to_device("cpu") + torch.testing.assert_close(metric.x, [torch.tensor(1.0, device="cpu")]) # type: ignore + metric.to_device("cuda") + torch.testing.assert_close(metric.x, [torch.tensor(1.0, device="cuda")]) # type: ignore + + +def test_update(): + """Test that `update` method works as expected.""" + metric = DummyMetric() + metric._add_states(anp) + assert metric.state_vars == {"x": anp.asarray(0, dtype=anp.float32)} + assert metric._computed is None + metric.update(anp.asarray(1, dtype=anp.float32)) + assert metric._computed is None + assert metric.state_vars == {"x": anp.asarray(1, dtype=anp.float32)} + metric.update(anp.asarray(2, dtype=anp.float32)) + assert metric.state_vars == {"x": anp.asarray(3, dtype=anp.float32)} + assert metric._computed is None + + metric = DummyListStateMetric() + metric._add_states(anp) + assert metric.state_vars == {"x": []} + assert metric._computed is None + metric.update(anp.asarray(1)) + assert metric._computed is None + assert metric.state_vars == {"x": [anp.asarray(1)]} + metric.update(anp.asarray(2)) + assert metric.state_vars == {"x": [anp.asarray(1), anp.asarray(2)]} + assert metric._computed is None + + +def test_compute(): + """Test that `compute` method works as expected.""" + metric = DummyMetric() + + with pytest.raises( + RuntimeError, + match="The `compute` method of DummyMetric was called before the `update`.*", + ): + metric.compute() + + metric.update(anp.asarray(1, dtype=anp.float32)) + expected_value = anp.asarray(1, dtype=anp.float32) + assert metric._computed is None + np.testing.assert_array_equal(metric.compute(), expected_value) + np.testing.assert_array_equal(metric._computed, expected_value) + assert metric.state_vars == {"x": expected_value} + + metric.update(anp.asarray(2, dtype=anp.float32)) + expected_value = anp.asarray(3, dtype=anp.float32) + assert metric._computed is None + np.testing.assert_array_equal(metric.compute(), expected_value) + np.testing.assert_array_equal(metric._computed, expected_value) + assert metric.state_vars == {"x": expected_value} + + # called without update, should return cached value + metric._computed = anp.asarray(42, dtype=anp.float32) + np.testing.assert_array_equal( + metric.compute(), + anp.asarray(42, dtype=anp.float32), + ) + assert metric.state_vars == {"x": anp.asarray(3, dtype=anp.float32)} + + +def test_reset(): + """Test that reset method works as expected.""" + + class A(DummyMetric): + pass + + class B(DummyListStateMetric): + pass + + metric = A() + metric._add_states(anp) + assert metric.x == anp.asarray(0, dtype=anp.float32) # type: ignore + metric.x = anp.asarray(42) # type: ignore + metric.reset() + assert metric.x == anp.asarray(0, dtype=anp.float32) # type: ignore + + metric = B() + metric._add_states(anp) + assert isinstance(metric.x, list) # type: ignore + assert len(metric.x) == 0 # type: ignore + metric.x = anp.asarray(42) # type: ignore + metric.reset() + assert isinstance(metric.x, list) # type: ignore + assert len(metric.x) == 0 # type: ignore + + +def test_reset_compute(): + """Test that `reset`+`compute` methods works as expected.""" + metric = DummyMetric() + + metric.update(anp.asarray(42, dtype=anp.float32)) + assert metric.state_vars == {"x": anp.asarray(42, dtype=anp.float32)} + np.testing.assert_array_equal( + metric.compute(), + anp.asarray(42, dtype=anp.float32), + ) + metric.reset() + assert metric.state_vars == {"x": anp.asarray(0, dtype=anp.float32)} + + +def test_error_on_compute_before_update(): + """Test that `compute` method raises error when called before `update`.""" + metric = DummyMetric() + + with pytest.raises( + RuntimeError, + match="The `compute` method of DummyMetric was called before the `update`.*", + ): + metric.compute() + + # after update, should work + metric.update(anp.asarray(42, dtype=anp.float32)) + result = metric.compute() + np.testing.assert_array_equal(result, anp.asarray(42, dtype=anp.float32)) + + +def test_clone(): + """Test the `clone` method.""" + metric = DummyMetric() + metric_clone = metric.clone() + assert metric is not metric_clone + assert metric.state_vars is not metric_clone.state_vars + assert metric._default_factories is not metric_clone._default_factories + assert metric._reductions is not metric_clone._reductions + + metric.update(anp.asarray(42, dtype=anp.float32)) + assert metric.state_vars == { + "x": anp.asarray(42, dtype=anp.float32), + } and not hasattr(metric_clone, "x") + assert metric._update_count == 1 and metric_clone._update_count == 0 + + metric_clone = metric.clone() + assert metric is not metric_clone + assert metric.state_vars == metric_clone.state_vars + assert metric._update_count == metric_clone._update_count + assert metric._computed == metric_clone._computed + + metric.compute() + assert ( + anp.all(metric._computed == anp.asarray(42, dtype=anp.float32)) + and metric_clone._computed is None + ) + metric_clone = metric.clone() + assert metric is not metric_clone + assert metric.state_vars == metric_clone.state_vars + assert anp.all(metric._computed == metric_clone._computed) + + +def test_call(): + """Test that the `__call__` method works as expected.""" + metric = DummyMetric() + assert metric.state_vars == {} + assert metric._computed is None + + metric(anp.asarray(42, dtype=anp.float32)) + assert metric.state_vars == {"x": anp.asarray(42, dtype=anp.float32)} + assert metric._computed is None + + metric(anp.asarray(1, dtype=anp.float32)) + assert metric.state_vars == {"x": anp.asarray(43, dtype=anp.float32)} + assert metric._computed is None + + metric.reset() + assert metric.state_vars == {"x": anp.asarray(0, dtype=anp.float32)} + assert metric._computed is None + + +@pytest.mark.parametrize("method", ["call", "update"]) +@pytest.mark.parametrize("metric", [DummyMetric, DummyListStateMetric]) +def test_update_count_torch(metric, method): + """Test that `_update_count` attribute is correctly updated.""" + m = metric() + x = torch.randn( + 1, + ).squeeze() + for i in range(10): + if method == "update": + m.update(x) + if method == "call": + _ = m(x) + assert m._update_count == i + 1 + + m.reset() + assert m._update_count == 0 diff --git a/tests/cyclops/evaluate/metrics/experimental/test_operator_metric.py b/tests/cyclops/evaluate/metrics/experimental/test_operator_metric.py new file mode 100644 index 000000000..c015596fb --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/test_operator_metric.py @@ -0,0 +1,519 @@ +"""Test the OperatorMetric class.""" +import numpy as np +import numpy.array_api as anp +import pytest + +from cyclops.evaluate.metrics.experimental.metric import Metric, OperatorMetric +from cyclops.evaluate.metrics.experimental.utils.typing import Array + + +class DummyMetric(Metric): + """DummyMetric class for testing operator metrics.""" + + def __init__(self, val_to_return: Array) -> None: + super().__init__() + self.add_state_default_factory( + "_num_updates", + lambda xp: xp.asarray(0, device=self._device), # type: ignore + dist_reduce_fn="sum", + ) + self._val_to_return = val_to_return + + def _update_state(self, unused_arg: Array) -> None: + """Compute state.""" + self._num_updates += 1 # type: ignore + + def _compute_metric(self): + """Compute result.""" + return anp.asarray(self._val_to_return) + + +def test_metrics_abs(): + """Test that `abs` operator works and returns an operator metric.""" + metric = DummyMetric(anp.asarray(-2, dtype=anp.float32)) + abs_metric = abs(metric) + assert isinstance(abs_metric, OperatorMetric) + abs_metric.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(anp.asarray(2, dtype=anp.float32), abs_metric.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (2, anp.asarray(4)), + (2.0, anp.asarray(4.0)), + (DummyMetric(anp.asarray(2)), anp.asarray(4)), + (anp.asarray(2), anp.asarray(4)), + ], +) +def test_metrics_add(second_operand, expected_result): + """Test that `add` operator works and returns an operator metric.""" + first_metric = DummyMetric( + anp.asarray( + 2, + dtype=anp.float32 if isinstance(second_operand, float) else None, + ), + ) + + final_add = first_metric + second_operand + assert isinstance(final_add, OperatorMetric) + final_add.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(expected_result, final_add.compute()) + + if not isinstance(second_operand, DummyMetric): + with pytest.raises(TypeError, match="unsupported operand type.*"): + final_radd = second_operand + first_metric # type: ignore + else: + final_radd = second_operand + first_metric + assert isinstance(final_radd, OperatorMetric) + final_radd.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(expected_result, final_radd.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (False, anp.asarray(False)), + (2, anp.asarray(2)), + (DummyMetric(anp.asarray(42)), anp.asarray(42)), + (anp.asarray(2), anp.asarray(2)), + ], +) +def test_metrics_and(second_operand, expected_result): + """Test that `and` operator works and returns an operator metric.""" + first_metric = DummyMetric( + anp.asarray(True) if isinstance(second_operand, bool) else anp.asarray(42), + ) + + final_and = first_metric & second_operand + assert isinstance(final_and, OperatorMetric) + final_and.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(expected_result, final_and.compute()) + + if not isinstance(second_operand, DummyMetric): + with pytest.raises(TypeError, match="unsupported operand type.*"): + final_rand = second_operand & first_metric # type: ignore + else: + final_rand = second_operand & first_metric + assert isinstance(final_rand, OperatorMetric) + final_rand.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(expected_result, final_rand.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric(anp.asarray(2)), anp.asarray(True)), + (2, anp.asarray(True)), + (2.0, anp.asarray(True)), + (anp.asarray(2), anp.asarray(True)), + ], +) +def test_metrics_eq(second_operand, expected_result): + """Test that `eq` operator works and returns an operator metric.""" + first_metric = DummyMetric( + anp.asarray( + 2, + dtype=anp.float32 if isinstance(second_operand, float) else None, + ), + ) + + final_eq = first_metric == second_operand + assert isinstance(final_eq, OperatorMetric) + final_eq.update(anp.asarray(0)) # dummy value to get array namespace + assert anp.all(expected_result == final_eq.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric(anp.asarray(2)), anp.asarray(2)), + (2, anp.asarray(2)), + (2.0, anp.asarray(2.0)), + (anp.asarray(2), anp.asarray(2)), + ], +) +def test_metrics_floordiv(second_operand, expected_result): + """Test that `floordiv` operator works and returns an operator metric.""" + first_metric = DummyMetric( + anp.asarray( + 5, # python scalars can only be promoted with floating-point arrays + dtype=anp.float32 if isinstance(second_operand, float) else None, + ), + ) + + final_floordiv = first_metric // second_operand + assert isinstance(final_floordiv, OperatorMetric) + final_floordiv.update(anp.asarray(0)) # dummy value to get array namespace + + assert np.allclose(expected_result, final_floordiv.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric(anp.asarray(2)), anp.asarray(True)), + (2, anp.asarray(True)), + (2.0, anp.asarray(True)), + (anp.asarray(2), anp.asarray(True)), + ], +) +def test_metrics_ge(second_operand, expected_result): + """Test that `ge` operator works and returns an operator metric.""" + first_metric = DummyMetric( + anp.asarray( + 5, # python scalars can only be promoted with floating-point arrays + dtype=anp.float32 if isinstance(second_operand, float) else None, + ), + ) + + final_ge = first_metric >= second_operand + assert isinstance(final_ge, OperatorMetric) + final_ge.update(anp.asarray(0)) # dummy value to get array namespace + + assert anp.all(expected_result == final_ge.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric(2), anp.asarray(True)), + (2, anp.asarray(True)), + (2.0, anp.asarray(True)), + (anp.asarray(2), anp.asarray(True)), + ], +) +def test_metrics_gt(second_operand, expected_result): + """Test that `gt` operator works and returns an operator metric.""" + first_metric = DummyMetric( + anp.asarray( + 5, # python scalars can only be promoted with floating-point arrays + dtype=anp.float32 if isinstance(second_operand, float) else None, + ), + ) + + final_gt = first_metric > second_operand + assert isinstance(final_gt, OperatorMetric) + final_gt.update(anp.asarray(0)) # dummy value to get array namespace + + assert anp.all(expected_result == final_gt.compute()) + + +def test_metrics_invert(): + """Test that `invert` operator works and returns an operator metric.""" + first_metric = DummyMetric(anp.asarray(1)) + + final_inverse = ~first_metric + assert isinstance(final_inverse, OperatorMetric) + final_inverse.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(anp.asarray(-2), final_inverse.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric(2), anp.asarray(False)), + (2, anp.asarray(False)), + (2.0, anp.asarray(False)), + (anp.asarray(2), anp.asarray(False)), + ], +) +def test_metrics_le(second_operand, expected_result): + """Test that `le` operator works and returns an operator metric.""" + first_metric = DummyMetric( + anp.asarray( + 5, # python scalars can only be promoted with floating-point arrays + dtype=anp.float32 if isinstance(second_operand, float) else None, + ), + ) + + final_le = first_metric <= second_operand + assert isinstance(final_le, OperatorMetric) + final_le.update(anp.asarray(0)) # dummy value to get array namespace + + assert anp.all(expected_result == final_le.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric(2), anp.asarray(False)), + (2, anp.asarray(False)), + (2.0, anp.asarray(False)), + (anp.asarray(2), anp.asarray(False)), + ], +) +def test_metrics_lt(second_operand, expected_result): + """Test that `lt` operator works and returns an operator metric.""" + first_metric = DummyMetric( + anp.asarray( + 5, # python scalars can only be promoted with floating-point arrays + dtype=anp.float32 if isinstance(second_operand, float) else None, + ), + ) + + final_lt = first_metric < second_operand + assert isinstance(final_lt, OperatorMetric) + final_lt.update(anp.asarray(0)) # dummy value to get array namespace + + assert anp.all(expected_result == final_lt.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric(anp.asarray([2, 2, 2])), anp.asarray(12)), + (anp.asarray([2, 2, 2]), anp.asarray(12)), + ], +) +def test_metrics_matmul(second_operand, expected_result): + """Test that `matmul` operator works and returns an operator metric.""" + first_metric = DummyMetric(anp.asarray([2, 2, 2])) + + final_matmul = first_metric @ second_operand + assert isinstance(final_matmul, OperatorMetric) + final_matmul.update(anp.asarray(0)) # dummy value to get array namespace + + assert np.allclose(expected_result, final_matmul.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric(2), anp.asarray(1)), + (2, anp.asarray(1)), + (2.0, anp.asarray(1)), + (anp.asarray(2), anp.asarray(1)), + ], +) +def test_metrics_mod(second_operand, expected_result): + """Test that `mod` operator works and returns an operator metric.""" + first_metric = DummyMetric( + anp.asarray( + 5, # python scalars can only be promoted with floating-point arrays + dtype=anp.float32 if isinstance(second_operand, float) else None, + ), + ) + + final_mod = first_metric % second_operand + assert isinstance(final_mod, OperatorMetric) + final_mod.update(anp.asarray(0)) # dummy value to get array namespace + + assert np.allclose(expected_result, final_mod.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric(2), anp.asarray(4)), + (2, anp.asarray(4)), + (2.0, anp.asarray(4.0)), + pytest.param(anp.asarray(2), anp.asarray(4)), + ], +) +def test_metrics_mul(second_operand, expected_result): + """Test that `mul` operator works and returns an operator metric.""" + first_metric = DummyMetric( + anp.asarray( + 2, # python scalars can only be promoted with floating-point arrays + dtype=anp.float32 if isinstance(second_operand, float) else None, + ), + ) + + final_mul = first_metric * second_operand + assert isinstance(final_mul, OperatorMetric) + final_mul.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(expected_result, final_mul.compute()) + + if not isinstance(second_operand, DummyMetric): + with pytest.raises(TypeError, match="unsupported operand type.*"): + final_rmul = second_operand * first_metric # type: ignore + else: + final_rmul = second_operand * first_metric + assert isinstance(final_rmul, OperatorMetric) + final_rmul.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(expected_result, final_rmul.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric(anp.asarray(2)), anp.asarray(False)), + (2, anp.asarray(False)), + (2.0, anp.asarray(False)), + (anp.asarray(2), anp.asarray(False)), + ], +) +def test_metrics_ne(second_operand, expected_result): + """Test that `!=` operator works and returns an operator metric.""" + first_metric = DummyMetric( + anp.asarray( + 2, # python scalars can only be promoted with floating-point arrays + dtype=anp.float32 if isinstance(second_operand, float) else None, + ), + ) + + final_ne = first_metric != second_operand + assert isinstance(final_ne, OperatorMetric) + final_ne.update(anp.asarray(0)) # dummy value to get array namespace + + assert anp.all(expected_result == final_ne.compute()) + + +def test_metrics_neg(): + """Test that `neg` operator works and returns an operator metric.""" + first_metric = DummyMetric(anp.asarray(1)) + + final_neg = -first_metric + assert isinstance(final_neg, OperatorMetric) + final_neg.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(anp.asarray(-1), final_neg.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric(anp.asarray([1, 0, 3])), anp.asarray([-1, -2, 3])), + (anp.asarray([1, 0, 3]), anp.asarray([-1, -2, 3])), + ], +) +def test_metrics_or(second_operand, expected_result): + """Test that `or` operator works and returns an operator metric.""" + first_metric = DummyMetric(anp.asarray([-1, -2, 3])) + + final_or = first_metric | second_operand + assert isinstance(final_or, OperatorMetric) + final_or.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(expected_result, final_or.compute()) + + if not isinstance(second_operand, DummyMetric): + with pytest.raises(TypeError, match="unsupported operand type.*"): + final_ror = second_operand | first_metric # type: ignore + else: + final_ror = second_operand | first_metric + assert isinstance(final_ror, OperatorMetric) + final_ror.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(expected_result, final_ror.compute()) + + +def test_metrics_pos(): + """Test that `pos` operator works and returns an operator metric.""" + first_metric = DummyMetric(anp.asarray(-1)) + + final_pos = +first_metric + assert isinstance(final_pos, OperatorMetric) + final_pos.update(np.asanyarray(0)) # dummy value to get array namespace + assert np.allclose(anp.asarray(1), final_pos.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric(2), anp.asarray(4)), + (2, anp.asarray(4)), + (2.0, anp.asarray(4.0)), + (anp.asarray(2), anp.asarray(4)), + ], +) +def test_metrics_pow(second_operand, expected_result): + """Test that `pow` operator works and returns an operator metric.""" + first_metric = DummyMetric( + anp.asarray( + 2, + dtype=anp.float32 if isinstance(second_operand, float) else None, + ), + ) + + final_pow = first_metric**second_operand + + assert isinstance(final_pow, OperatorMetric) + + final_pow.update(np.asarray(0)) # dummy value to get array namespace + assert np.allclose(expected_result, final_pow.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric(2), anp.asarray(1)), + (2, anp.asarray(1)), + (2.0, anp.asarray(1.0)), + (anp.asarray(2), anp.asarray(1)), + ], +) +def test_metrics_sub(second_operand, expected_result): + """Test that `sub` operator works and returns an operator metric.""" + first_metric = DummyMetric( + anp.asarray( + 3, + dtype=anp.float32 if isinstance(second_operand, float) else None, + ), + ) + + final_sub = first_metric - second_operand + + assert isinstance(final_sub, OperatorMetric) + final_sub.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(expected_result, final_sub.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric(anp.asarray(3.0)), anp.asarray(2.0)), + (3, anp.asarray(2.0)), + (3.0, anp.asarray(2.0)), + (anp.asarray(3.0), anp.asarray(2.0)), + ], +) +def test_metrics_truediv(second_operand, expected_result): + """Test that `truediv` operator works and returns an operator metric.""" + first_metric = DummyMetric(anp.asarray(6.0)) # only floating-point arrays + + final_truediv = first_metric / second_operand + + assert isinstance(final_truediv, OperatorMetric) + final_truediv.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(expected_result, final_truediv.compute()) + + +@pytest.mark.parametrize( + ("second_operand", "expected_result"), + [ + (DummyMetric([1, 0, 3]), anp.asarray([-2, -2, 0])), + (anp.asarray([1, 0, 3]), anp.asarray([-2, -2, 0])), + ], +) +def test_metrics_xor(second_operand, expected_result): + """Test that `xor` operator works and returns an operator metric.""" + first_metric = DummyMetric([-1, -2, 3]) + + final_xor = first_metric ^ second_operand + assert isinstance(final_xor, OperatorMetric) + final_xor.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(expected_result, final_xor.compute()) + + if not isinstance(second_operand, DummyMetric): + with pytest.raises(TypeError, match="unsupported operand type.*"): + final_rxor = second_operand ^ first_metric # type: ignore + else: + final_rxor = second_operand ^ first_metric + assert isinstance(final_rxor, OperatorMetric) + final_rxor.update(anp.asarray(0)) # dummy value to get array namespace + assert np.allclose(expected_result, final_rxor.compute()) + + +def test_operator_metrics_update(): + """Test update method for operator metrics.""" + compos = DummyMetric(anp.asarray(5)) + DummyMetric(anp.asarray(4)) + + assert isinstance(compos, OperatorMetric) + compos.update(anp.asarray(0)) # dummy value to get array namespace + compos.update(anp.asarray(0)) # dummy value to get array namespace + compos.update(anp.asarray(0)) # dummy value to get array namespace + + assert isinstance(compos.metric_a, DummyMetric) + assert isinstance(compos.metric_b, DummyMetric) + + assert compos.metric_a._num_updates == 3 # type: ignore + assert compos.metric_b._num_updates == 3 # type: ignore diff --git a/tests/cyclops/evaluate/metrics/experimental/testers.py b/tests/cyclops/evaluate/metrics/experimental/testers.py new file mode 100644 index 000000000..b36f93d53 --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/testers.py @@ -0,0 +1,357 @@ +"""Testers for metrics.""" +from functools import partial +from typing import Any, Callable, Dict, Optional, Sequence, Type + +import array_api_compat as apc +import numpy as np + +from cyclops.evaluate.metrics.experimental.metric import Metric +from cyclops.evaluate.metrics.experimental.utils.ops import clone, flatten +from cyclops.evaluate.metrics.experimental.utils.typing import Array + + +def _assert_allclose( + cyclops_result: Any, + ref_result: Any, + atol: float = 1e-8, + key: Optional[str] = None, +) -> None: + """Recursively assert that two results are within a certain tolerance.""" + if apc.is_array_api_obj(cyclops_result) and apc.is_array_api_obj(ref_result): + # move to cpu and convert to numpy + cyclops_result = np.from_dlpack(apc.to_device(cyclops_result, "cpu")) + ref_result = np.from_dlpack(apc.to_device(ref_result, "cpu")) + + np.testing.assert_allclose( + cyclops_result, + ref_result, + atol=atol, + equal_nan=True, + ) + + # multi output comparison + elif isinstance(cyclops_result, Sequence): + for cyc_res, ref_res in zip(cyclops_result, ref_result): + _assert_allclose(cyc_res, ref_res, atol=atol) + elif isinstance(cyclops_result, dict): + if key is None: + raise KeyError("Provide Key for Dict based metric results.") + _assert_allclose(cyclops_result[key], ref_result, atol=atol) + else: + raise ValueError("Unknown format for comparison") + + +def _assert_array(cyclops_result: Any, key: Optional[str] = None) -> None: + """Recursively check that some input only consists of Arrays.""" + if isinstance(cyclops_result, Sequence): + for res in cyclops_result: + _assert_array(res) + elif isinstance(cyclops_result, Dict): + if key is None: + raise KeyError("Provide Key for Dict based metric results.") + assert apc.is_array_api_obj(cyclops_result[key]) + else: + assert apc.is_array_api_obj(cyclops_result) + + +def _class_impl_test( # noqa: PLR0912 + target: Array, + preds: Array, + metric_class: Type[Metric], + reference_metric: Callable[..., Any], + metric_args: Optional[Dict[str, Any]] = None, + atol: float = 1e-8, + device: str = "cpu", + use_device_for_ref: bool = False, +): + """Test output of metric class against a reference metric.""" + assert apc.is_array_api_obj(target) and apc.is_array_api_obj(preds), ( + f"`target` and `preds` must be Array API compatible objects, " + f"got {type(target)} and {type(preds)}." + ) + + t_size = target.shape[0] + p_size = preds.shape[0] + assert ( + p_size == t_size + ), f"`preds` and `target` have different number of samples: {p_size} and {t_size}." + num_batches = p_size + + # instantiate metric + metric_args = metric_args or {} + metric = metric_class(**metric_args) + + # check that the metric can be cloned + metric_clone = metric.clone() + assert metric_clone is not metric, "Metric clone should not be the same object." + assert type(metric_clone) is type(metric), "Metric clone should be the same type." + + # move to device + metric = metric.to_device(device) + preds = apc.to_device(preds, device) + target = apc.to_device(target, device) + + for i in range(num_batches): # type: ignore + # compute batch result and aggregate for global result + cyc_batch_result = metric(target[i, ...], preds[i, ...]) + + ref_batch_result = reference_metric( + target=apc.to_device( + target[i, ...], + device if use_device_for_ref else "cpu", + ), + preds=apc.to_device(preds[i, ...], device if use_device_for_ref else "cpu"), + ) + if isinstance(cyc_batch_result, dict): + for key in cyc_batch_result: + _assert_allclose( + cyc_batch_result, + ref_batch_result[key], + atol=atol, + key=key, + ) + else: + _assert_allclose(cyc_batch_result, ref_batch_result, atol=atol) + + # check on all batches on all ranks + cyc_result = metric.compute() + if isinstance(cyc_result, dict): + for key in cyc_result: + _assert_array(cyc_result, key=key) + else: + _assert_array(cyc_result) + + xp = apc.array_namespace(target, preds) + if preds.ndim == 1 or (preds.ndim == 2 and target.ndim == 1): + # 0-D binary and multiclass cases + total_preds = preds + else: + total_preds = xp.concat([preds[i, ...] for i in range(num_batches)]) # type: ignore + + if target.ndim > 1: + total_target = xp.concat([target[i, ...] for i in range(num_batches)]) # type: ignore + else: + total_target = target + ref_result = reference_metric( + target=apc.to_device(total_target, device if use_device_for_ref else "cpu"), + preds=apc.to_device(total_preds, device if use_device_for_ref else "cpu"), + ) + + # assert after aggregation + if isinstance(ref_result, dict): + for key in ref_result: + _assert_allclose(cyc_result, ref_result[key], atol=atol, key=key) + else: + _assert_allclose(cyc_result, ref_result, atol=atol) + + +def _function_impl_test( + target: Array, + preds: Array, + metric_function: Callable[..., Any], + reference_metric: Callable[..., Any], + metric_args: Optional[Dict[str, Any]] = None, + atol: float = 1e-8, + device: str = "cpu", + use_device_for_ref: bool = False, +): + """Test output of a metric function against a reference metric.""" + assert apc.is_array_api_obj(target) and apc.is_array_api_obj(preds), ( + f"`target` and `preds` must be Array API compatible objects, " + f"got {type(target)} and {type(preds)}." + ) + + t_size = target.shape[0] + p_size = preds.shape[0] + assert ( + p_size == t_size + ), f"`preds` and `target` have different number of samples: {p_size} and {t_size}." + + metric_args = metric_args or {} + metric = partial(metric_function, **metric_args) + + preds = apc.to_device(preds, device) + target = apc.to_device(target, device) + + num_batches = p_size + for i in range(num_batches): + cyclops_result = metric(target[i, ...], preds[i, ...]) + + # always compare to reference metric on CPU + ref_result = reference_metric( + target=apc.to_device( + target[i, ...], + device if use_device_for_ref else "cpu", + ), + preds=apc.to_device(preds[i, ...], device if use_device_for_ref else "cpu"), + ) + + _assert_allclose(cyclops_result, ref_result, atol=atol) + + +class MetricTester: + """Test class for all metrics.""" + + atol: float = 1e-8 + + def run_metric_function_implementation_test( + self, + target: Array, + preds: Array, + metric_function: Callable[..., Any], + reference_metric: Callable[..., Any], + metric_args: Optional[Dict[str, Any]] = None, + device: str = "cpu", + use_device_for_ref: bool = False, + ): + """Test output of a metric function against a reference metric. + + Parameters + ---------- + target : Array + The target array. Any Array API compatible object is accepted. + preds : Array + The predictions array. Any Array API compatible object is accepted. + metric_function : Callable[..., Any] + The metric function to test. + reference_metric : Callable[..., Any] + The reference metric function. + metric_args : Dict[str, Any], optional + The arguments to pass to the metric function. + device : str, optional, default="cpu" + The device to compute the metric on. + use_device_for_ref : bool, optional, default=False + Whether to compute the reference metric on the same device as `device`. + + """ + return _function_impl_test( + target=target, + preds=preds, + metric_function=metric_function, + reference_metric=reference_metric, + metric_args=metric_args, + atol=self.atol, + device=device, + use_device_for_ref=use_device_for_ref, + ) + + def run_metric_class_implementation_test( + self, + target: Array, + preds: Array, + metric_class: Type[Metric], + reference_metric: Callable[..., Any], + metric_args: Optional[dict] = None, + device: str = "cpu", + use_device_for_ref: bool = False, + ): + """Test output of a metric class against a reference metric. + + Parameters + ---------- + target : Array + The target array. Any Array API compatible object is accepted. + preds : Array + The predictions array. Any Array API compatible object is accepted. + metric_class : Metric + The metric class to test. + reference_metric : Callable[..., Any] + The reference metric function. + metric_args : Optional[dict], optional + The arguments to pass to the metric function. + device : str, optional, default="cpu" + The device to compute the metric on. + use_device_for_ref : bool, optional, default=False + Whether to compute the reference metric on the same device as `device`. + """ + return _class_impl_test( + target=target, + preds=preds, + metric_class=metric_class, + reference_metric=reference_metric, + metric_args=metric_args, + atol=self.atol, + device=device, + use_device_for_ref=use_device_for_ref, + ) + + +class DummyMetric(Metric): + """Dummy metric for testing core components.""" + + name = "Dummy" + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.add_state_default_factory( + "x", + lambda xp: xp.asarray(0.0, dtype=xp.float32, device=self.device), # type: ignore + dist_reduce_fn="sum", + ) + + def _update_state(self, x: Array) -> None: + """Update state.""" + self.x += x # type: ignore + + def _compute_metric(self) -> Array: + """Compute value.""" + return self.x # type: ignore + + +class DummyListStateMetric(Metric): + """Dummy metric with list state for testing core components.""" + + name = "DummyListState" + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.add_state_default_factory("x", list, dist_reduce_fn="cat") # type: ignore + + def _update_state(self, x: Array): + """Update state.""" + self.x.append(apc.to_device(x, self.device)) # type: ignore + + def _compute_metric(self): + """Compute value.""" + return self.x # type: ignore + + +def _inject_ignore_index(array, ignore_index): + """Inject ignore index into array.""" + if ignore_index is None: + return array + + if isinstance(ignore_index, int): + ignore_index = (ignore_index,) + + if any(any(flatten(array) == idx) for idx in ignore_index): + return array + + xp = apc.array_namespace(array) + classes = xp.unique_values(array) + + # select random indices (same size as ignore_index) and set them to ignore_index + indices = np.random.randint(0, apc.size(array), size=len(ignore_index)) # type: ignore + array = clone(array) + + # use loop + basic indexing to set ignore_index + for idx, ignore_idx in zip(indices, ignore_index): # type: ignore + xp.reshape(array, (-1,))[idx] = ignore_idx + + # if all classes are removed, add one back + batch_size = array.shape[0] if array.ndim > 1 else 1 + for i in range(batch_size): + batch = array[i, ...] if array.ndim > 1 else array + new_classes = xp.unique_values(batch) + class_not_in = [c not in new_classes for c in classes] + + if any(class_not_in): + missing_class = int(np.where(class_not_in)[0][0]) + mask = xp.zeros_like(batch, dtype=xp.bool, device=apc.device(batch)) + for idx in ignore_index: + mask = xp.logical_or(mask, batch == idx) + ignored_idx = np.where(mask)[0] + if len(ignored_idx) > 0: + batch[int(ignored_idx[0])] = missing_class + + return array diff --git a/tests/cyclops/evaluate/metrics/experimental/utils/test_ops.py b/tests/cyclops/evaluate/metrics/experimental/utils/test_ops.py new file mode 100644 index 000000000..cf2d2c0a7 --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/utils/test_ops.py @@ -0,0 +1,1012 @@ +"""Test utility functions for performing operations on Arrays.""" +from collections import defaultdict, namedtuple + +import numpy as np +import numpy.array_api as anp +import pytest +import torch + +from cyclops.evaluate.metrics.experimental.utils.ops import ( + apply_to_array_collection, + bincount, + clone, + dim_zero_cat, + dim_zero_max, + dim_zero_mean, + dim_zero_min, + dim_zero_sum, + flatten, + flatten_seq, + moveaxis, + remove_ignore_index, + safe_divide, + sigmoid, + squeeze_all, +) +from cyclops.utils.optional import import_optional_module + + +cp = import_optional_module("cupy", error="ignore") + + +def multiply_by_two(x): + """Multiply the input by two.""" + return x * 2 + + +class TestApplyToArrayCollection: + """Test the `apply_to_array_collection` utility function.""" + + def test_apply_to_single_array(self): + """Test applying a function to a single array.""" + data = anp.asarray([1, 2, 3, 4, 5]) + + result = apply_to_array_collection(data, multiply_by_two) + + assert anp.all(result == anp.asarray([2, 4, 6, 8, 10])) + + def test_apply_to_list_of_arrays(self): + """Test applying a function to a list of arrays.""" + data = [anp.asarray([1, 2, 3]), anp.asarray([4, 5, 6])] + + result = apply_to_array_collection(data, multiply_by_two) + expected_result = [anp.asarray([2, 4, 6]), anp.asarray([8, 10, 12])] + + assert all(anp.all(a == b) for a, b in zip(result, expected_result)) + + def test_apply_to_tuple_of_arrays(self): + """Test applying a function to a tuple of arrays.""" + data = (anp.asarray([1, 2, 3]), anp.asarray([4, 5, 6])) + + result = apply_to_array_collection(data, multiply_by_two) + expected_result = (anp.asarray([2, 4, 6]), anp.asarray([8, 10, 12])) + + assert all(anp.all(a == b) for a, b in zip(result, expected_result)) + + def test_apply_to_dictionary_of_arrays(self): + """Test applying a function to a dictionary of arrays.""" + data = {"a": anp.asarray([1, 2, 3]), "b": anp.asarray([4, 5, 6])} + + result = apply_to_array_collection(data, multiply_by_two) + expected_result = {"a": anp.asarray([2, 4, 6]), "b": anp.asarray([8, 10, 12])} + + assert all( + anp.all(a == b) for a, b in zip(result.values(), expected_result.values()) + ) + assert all(k in result for k in expected_result) + + def test_apply_to_namedtuple_of_arrays(self): + """Test applying a function to a namedtuple of arrays.""" + Data = namedtuple("Data", ["a", "b"]) + data = Data(anp.asarray([1, 2, 3]), anp.asarray([4, 5, 6])) + + result = apply_to_array_collection(data, multiply_by_two) + expected_result = Data(anp.asarray([2, 4, 6]), anp.asarray([8, 10, 12])) + + assert all(anp.all(a == b) for a, b in zip(result, expected_result)) + assert all(k in result._fields for k in expected_result._fields) + + def test_return_input_data_if_not_array(self): + """Test returning the input data for non-array inputs.""" + data = 10 + result = apply_to_array_collection(data, multiply_by_two) + assert result == 10 + + def test_return_input_data_if_not_array_collection(self): + """Test returning the input data for non-array collections.""" + data = [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + ] + result = apply_to_array_collection(data, multiply_by_two) + assert result == data + + def test_handle_empty_list_input(self): + """Test handling an empty list input.""" + result = apply_to_array_collection([], multiply_by_two) + assert result == [] + + def test_handle_empty_tuple_input(self): + """Test handling an empty tuple input.""" + result = apply_to_array_collection((), multiply_by_two) + assert result == () + + def test_handle_empty_dictionary_input(self): + """Test handling an empty dictionary input.""" + result = apply_to_array_collection({}, multiply_by_two) + assert result == {} + + def test_handle_dictionary_with_non_string_keys(self): + """Test handling a dictionary with non-string keys.""" + data = {1: anp.asarray([1, 2, 3]), 2: anp.asarray([4, 5, 6])} + + result = apply_to_array_collection(data, multiply_by_two) + expected_result = {1: anp.asarray([2, 4, 6]), 2: anp.asarray([8, 10, 12])} + assert all( + anp.all(a == b) for a, b in zip(result.values(), expected_result.values()) + ) + assert all(k in result for k in expected_result) + + def test_handle_defaultdict_input(self): + """Test handling a defaultdict input.""" + data = defaultdict( + list, + {"a": anp.asarray([1, 2, 3]), "b": anp.asarray([4, 5, 6])}, + ) + + result = apply_to_array_collection(data, multiply_by_two) + expected_result = defaultdict( + list, + {"a": anp.asarray([2, 4, 6]), "b": anp.asarray([8, 10, 12])}, + ) + + assert all( + anp.all(a == b) for a, b in zip(result.values(), expected_result.values()) + ) + assert all(k in result for k in expected_result) + + def test_apply_to_nested_collections(self): + """Test applying a function to nested collections of arrays.""" + data = { + "a": anp.asarray( + [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + ], + ), + "b": [anp.asarray([10, 11, 12]), anp.asarray([13, 14, 15])], + "c": ( + anp.asarray([16, 17, 18]), + anp.asarray([19, 20, 21]), + ), + "d": { + "e": anp.asarray([22, 23, 24]), + "f": anp.asarray([25, 26, 27]), + }, + } + + result = apply_to_array_collection(data, multiply_by_two) + + expected_result = { + "a": anp.asarray( + [ + [2, 4, 6], + [8, 10, 12], + [14, 16, 18], + ], + ), + "b": [anp.asarray([20, 22, 24]), anp.asarray([26, 28, 30])], + "c": ( + anp.asarray([32, 34, 36]), + anp.asarray([38, 40, 42]), + ), + "d": { + "e": anp.asarray([44, 46, 48]), + "f": anp.asarray([50, 52, 54]), + }, + } + + for k in expected_result: + assert k in result + + if isinstance(expected_result[k], dict): + for kk in expected_result[k]: + assert kk in result[k] + assert anp.all(expected_result[k][kk] == result[k][kk]) + elif isinstance(expected_result[k], (tuple, list)): + assert all( + anp.all(a == b) for a, b in zip(result[k], expected_result[k]) + ) + else: + assert anp.all(expected_result[k] == result[k]) + + +class TestBincount: + """Test the `bincount` utility function.""" + + def test_non_negative_integers(self): + """Test using non-negative integers as input.""" + input_array = anp.asarray([0, 1, 1, 2, 2, 2]) + expected_output = anp.asarray([1, 2, 3]) + + result = bincount(input_array) + + assert anp.all(result == expected_output) + + def test_empty_array(self): + """Test using an empty array as input.""" + input_array = anp.asarray([], dtype=anp.int32) + expected_output = anp.asarray([], dtype=anp.int64) + + result = bincount(input_array, minlength=0) + + assert anp.all(result == expected_output) + + result = bincount(input_array, minlength=10) + expected_output = anp.zeros(10, dtype=anp.int64) + + assert anp.all(result == expected_output) + + def test_single_unique_value(self): + """Test using an array with a single unique value as input.""" + input_array = anp.asarray([3, 3, 3, 3]) + expected_output = anp.asarray([0, 0, 0, 4]) + + result = bincount(input_array) + + assert anp.all(result == expected_output) + + def test_no_repeated_values(self): + """Test using an array with no repeated values as input.""" + input_array = anp.asarray([0, 1, 2, 3, 4, 5]) + expected_output = anp.ones_like(input_array) + + result = bincount(input_array) + + assert anp.all(result == expected_output) + + def test_negative_integers(self): + """Test using an array with negative integers as input.""" + input_array = anp.asarray([-1, 0, 1, 2]) + + with pytest.raises(ValueError): + bincount(input_array) + + def test_negative_minlength(self): + """Test using a negative minlength as input.""" + input_array = anp.asarray([1, 2, 3]) + + with pytest.raises(ValueError): + bincount(input_array, minlength=-5) + + def test_different_shapes(self): + """Test using arrays and weights with different shapes as input.""" + input_array = anp.asarray([1, 2, 3]) + weights = anp.asarray([0.5, 0.5]) + + with pytest.raises(ValueError): + bincount(input_array, weights=weights) + + def test_not_one_dimensional(self): + """Test using a multi-dimensional array as input.""" + input_array = anp.asarray([[1, 2], [3, 4]]) + + with pytest.raises(ValueError): + bincount(input_array) + + def test_not_integer_type(self): + """Test using a non-integer array as input.""" + input_array = anp.asarray([1.5, 2.5, 3.5]) + + with pytest.raises(ValueError): + bincount(input_array) + + +class TestClone: + """Test the `clone` utility function.""" + + def test_clone_numpy_array(self): + """Test if the clone function creates a new copy of a numpy array.""" + x = np.array([1, 2, 3]) + + y = clone(x) + + # Check if y is a new copy of x + assert y is not x + assert np.array_equal(y, x) + + @pytest.mark.skipif(cp is None, reason="Cupy is not installed.") + @pytest.mark.integration_test() # machine for integration test has GPU + def test_clone_cupy_array(self): + """Test if the clone function creates a new copy of a cupy array.""" + try: + if not cp.cuda.is_available(): # type: ignore + pytest.skip("CUDA is not available.") + except cp.cuda.runtime.CUDARuntimeError: # type: ignore + pytest.skip("CUDA is not available.") + + x = cp.asarray([1, 2, 3]) # type: ignore + + y = clone(x) + + # Check if y is a new copy of x + assert y is not x + assert cp.array_equal(y, x) # type: ignore + + def test_clone_torch_tensor(self): + """Test if the clone function properly clones a torch tensor.""" + x = torch.tensor([1, 2, 3]) # type: ignore + + y = clone(x) + + # Check if y is a new copy of x + assert y is not x + assert torch.equal(y, x) # type: ignore + + def test_clone_empty_array(self): + """Test if the clone function creates a new copy of an empty array.""" + x = anp.asarray([]) + + y = clone(x) + + # Check if y is a new copy of x + assert y is not x + assert anp.all(y == x) + + +class TestDimZeroCat: + """Test the `dim_zero_cat` utility function.""" + + def test_returns_input_if_array_or_empty_list_tuple(self): + """Test if the input is an array or empty list/tuple.""" + array1 = anp.asarray([1, 2, 3]) + array2 = anp.asarray([4, 5, 6]) + empty_list = [] + + result1 = dim_zero_cat(array1) + result2 = dim_zero_cat([array2]) + result3 = dim_zero_cat([]) + result4 = dim_zero_cat(empty_list) + + np.testing.assert_array_equal(result1, array1) + np.testing.assert_array_equal(result2, array2) + assert result3 == [] + assert result4 == [] + + def test_concatenates_arrays_along_zero_dimension(self): + """Test concatenation along the zero dimension.""" + array1 = anp.asarray([1, 2, 3]) + array2 = anp.asarray([4, 5, 6]) + + result = dim_zero_cat([array1, array2]) + + expected_result = anp.asarray([1, 2, 3, 4, 5, 6]) + np.testing.assert_array_equal(result, expected_result) + + def test_arrays_with_different_shapes(self): + """Test handling of arrays with different shapes.""" + array1 = anp.asarray([1, 2, 3]) + array2 = anp.asarray([[4, 5, 6], [7, 8, 9]]) + + with pytest.raises(ValueError): + dim_zero_cat([array1, array2]) + + def test_raises_type_error_if_input_not_array_or_list_tuple_of_arrays(self): + """Test raising TypeError if input is not an array or a list/tuple of arrays.""" + array1 = anp.asarray([1, 2, 3]) + array2 = anp.asarray([4, 5, 6]) + + with pytest.raises(TypeError): + dim_zero_cat([array1, array2, 7]) + + with pytest.raises(TypeError): + dim_zero_cat([array1, array2, None, "hello"]) + + with pytest.raises(TypeError): + dim_zero_cat(123) + + def test_raises_value_error_if_input_list_empty(self): + """Test raising ValueError if input list is empty.""" + result = dim_zero_cat([]) + assert result == [] + + +def test_dim_zero_max(): + """Test the `dim_zero_max` utility function.""" + # happy path + array1 = anp.asarray([1, 2, 3]) # 1d + array2 = anp.asarray([[4, 5, 6], [7, 8, 9]]) + + result1 = dim_zero_max(array1) + result2 = dim_zero_max(array2) + + expected_result1 = anp.asarray(3) + np.testing.assert_array_equal(result1, expected_result1) + expected_result2 = anp.asarray([7, 8, 9]) + np.testing.assert_array_equal(result2, expected_result2) + + # edge cases + with pytest.raises(ValueError): + dim_zero_max(anp.asarray([])) + + with pytest.raises(AttributeError): + dim_zero_max([array1, array2]) + + with pytest.raises(TypeError): + dim_zero_max([1, 2, 3]) + + # 1x1x1 array + array3 = anp.asarray([[[1]]]) + result4 = dim_zero_max(array3) + expected_result4 = anp.asarray([[1]]) + np.testing.assert_array_equal(result4, expected_result4) + + +def test_dim_zero_mean(): + """Test the `dim_zero_mean` utility function.""" + # happy path + array1 = anp.asarray([1, 2, 3]) + array2 = anp.asarray([[4, 5, 6], [7, 8, 9]]) + array3 = anp.asarray([[[10, 11, 12], [13, 14, 15]], [[16, 17, 18], [19, 20, 21]]]) + + result1 = dim_zero_mean(array1) + expected_result1 = anp.asarray(2) + np.testing.assert_array_equal(result1, expected_result1) + + result2 = dim_zero_mean(array2) + expected_result2 = anp.asarray([5.5, 6.5, 7.5]) + np.testing.assert_allclose(result2, expected_result2) + + result3 = dim_zero_mean(array3) + expected_result3 = anp.asarray([[13, 14, 15], [16, 17, 18]], dtype=anp.float32) + np.testing.assert_allclose(result3, expected_result3) + + # edge cases + result4 = dim_zero_mean(anp.asarray([])) + np.testing.assert_array_equal(result4, anp.asarray(anp.nan)) + + with pytest.raises(AttributeError): + dim_zero_mean([array1, array2]) + + with pytest.raises(TypeError): + dim_zero_mean([1, 2, 3]) + + +def test_dim_zero_min(): + """Test the `dim_zero_min` utility function.""" + # expected behavior + array1 = anp.asarray([1, 2, 3]) + array2 = anp.asarray([[4, 5, 6], [7, 8, 9]]) + array3 = anp.asarray([[[10, 11, 12], [13, 14, 15]], [[-16, 17, 18], [19, 20, 21]]]) + + result1 = dim_zero_min(array1) + expected_result1 = anp.asarray(1) + np.testing.assert_array_equal(result1, expected_result1) + + result2 = dim_zero_min(array2) + expected_result2 = anp.asarray([4, 5, 6]) + np.testing.assert_array_equal(result2, expected_result2) + + result3 = dim_zero_min(array3) + expected_result3 = anp.asarray([[-16, 11, 12], [13, 14, 15]]) + np.testing.assert_array_equal(result3, expected_result3) + + # edge cases + with pytest.raises(ValueError): + dim_zero_min(anp.asarray([])) + + with pytest.raises(AttributeError): + dim_zero_min([array1, array2]) + + with pytest.raises(TypeError): + dim_zero_min([1, 2, 3]) + + +def test_dim_zero_sum(): + """Test the `dim_zero_sum` utility function.""" + array1 = anp.asarray([1, 2, 3]) + array2 = anp.asarray([[4, 5, 6], [7, 8, 9]]) + array3 = anp.asarray([[[10, 11, 12], [13, 14, 15]], [[16, 17, 18], [19, 20, 21]]]) + + result1 = dim_zero_sum(array1) + expected_result1 = anp.asarray(6) + np.testing.assert_array_equal(result1, expected_result1) + + result2 = dim_zero_sum(array2) + expected_result2 = anp.asarray([11, 13, 15]) + np.testing.assert_array_equal(result2, expected_result2) + + result3 = dim_zero_sum(array3) + expected_result3 = anp.asarray([[26, 28, 30], [32, 34, 36]]) + np.testing.assert_array_equal(result3, expected_result3) + + # edge cases + result4 = dim_zero_sum(anp.asarray([])) + np.testing.assert_array_equal(result4, anp.asarray(0)) + + with pytest.raises(AttributeError): + dim_zero_sum([array1, array2]) + + with pytest.raises(TypeError): + dim_zero_sum([1, 2, 3]) + + +def test_flatten(): + """Test the `flatten` utility function.""" + x = anp.asarray([1, 2, 3]) + result = flatten(x) + assert anp.all(result == x) + assert not np.shares_memory(result, x) + + x = anp.asarray([[1, 2, 3], [4, 5, 6]]) + result = flatten(x) + expected_result = anp.asarray([1, 2, 3, 4, 5, 6]) + assert anp.all(result == expected_result) + assert not np.shares_memory(result, x) + + x = anp.asarray([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + result = flatten(x) + expected_result = anp.asarray([1, 2, 3, 4, 5, 6, 7, 8]) + assert anp.all(result == expected_result) + assert not np.shares_memory(result, x) + + x = anp.asarray([]) + result = flatten(x) + assert anp.all(result == x) + assert not np.shares_memory(result, x) + + +def test_flatten_seq(): + """Test the `flatten_seq` utility function.""" + # happy path + x = [] + assert flatten_seq(x) == [] + + x = [1, 2, 3] + assert flatten_seq(x) == x + + x = [[1, 2, 3], [4, 5, 6]] + assert flatten_seq(x) == [1, 2, 3, 4, 5, 6] + + x = [[1, [2, [3]]], [4, [5, [6]]]] + assert flatten_seq(x) == [1, 2, 3, 4, 5, 6] + + x = [[1, 2, 3], "abc", [4, [5, 6]], 7, anp.asarray(1)] + assert flatten_seq(x) == [ + 1, + 2, + 3, + "a", + "b", + "c", + 4, + 5, + 6, + 7, + anp.asarray(1), + ] + + # edge cases + x = 123 + with pytest.raises(TypeError): + flatten_seq(x) # type: ignore + + x = [1, None, [2, None, 3]] + assert flatten_seq(x) == [1, None, 2, None, 3] + + x = [[], [1, 2], [], [3, 4], []] + assert flatten_seq(x) == [1, 2, 3, 4] + + x = [[None, None], [None, None]] + assert flatten_seq(x) == [None, None, None, None] + + x = [[], [], []] + assert flatten_seq(x) == [] + + +class TestMoveAxis: + """Test the `moveaxis` utility function.""" + + def test_move_single_axis(self): + """Test moving a single axis.""" + array = anp.zeros((2, 3, 4)) + + result = moveaxis(array, 0, 2) + expected_result = np.moveaxis(array, 0, 2) # type: ignore + + assert result.shape == expected_result.shape + assert np.shares_memory(result, array) + assert np.all(result == expected_result) + + def test_move_negative_indices(self): + """Test moving an axis with negative indices.""" + array = anp.zeros((2, 3, 4)) + + result = moveaxis(array, -1, -3) + expected_result = np.moveaxis(array, -1, -3) # type: ignore + + assert result.shape == expected_result.shape + assert np.shares_memory(result, array) + + assert np.all(result == expected_result) + + def test_move_multiple_axes(self): + """Test moving multiple axes.""" + array = anp.zeros((2, 3, 4)) + + result = moveaxis(array, (0, 1), (1, 0)) # type: ignore + expected_result = np.moveaxis(array, (0, 1), (1, 0)) # type: ignore + + assert result.shape == expected_result.shape + assert np.shares_memory(result, array) + assert np.all(result == expected_result) + + def test_move_same_position(self): + """Test moving an axis to the same position.""" + array = anp.zeros((2, 3, 4)) + + result = moveaxis(array, 0, 0) + + assert result.shape == (2, 3, 4) + assert np.shares_memory(result, array) + assert np.all(result == array) + + def test_move_outside_shape(self): + """Test moving an axis outside the shape of the array.""" + array = anp.zeros((2, 3, 4)) + + with pytest.raises(ValueError): + moveaxis(array, 0, 5) + + with pytest.raises(ValueError): + moveaxis(array, 0, -5) + + def test_raise_value_error_if_duplicate_values(self): + """Test passing duplicate values for source or destination.""" + array = anp.zeros((2, 3, 4)) + + with pytest.raises(ValueError): + moveaxis(array, (0, 0), (1, 2)) # type: ignore + + with pytest.raises(ValueError): + moveaxis(array, (0, 1), (1, 1)) # type: ignore + + def test_raise_value_error_if_source_and_destination_not_same_length(self): + """Test passing source and destination with different lengths.""" + array = anp.zeros((2, 3, 4)) + + with pytest.raises(ValueError): + moveaxis(array, (0, 1), (1, 0, 2)) # type: ignore + + def test_raise_value_error_if_source_or_destination_not_integer_or_tuple(self): + """Test passing source or destination as a non-integer or non-tuple.""" + array = anp.zeros((2, 3, 4)) + + # Test with source as a float + with pytest.raises(ValueError): + moveaxis(array, 0.5, 2) # type: ignore + + # Test with destination as a string + with pytest.raises(ValueError): + moveaxis(array, 0, "2") # type: ignore + + # Test with source as a list + with pytest.raises(ValueError): + moveaxis(array, [0], 2) # type: ignore + + # Test with destination as a dictionary + with pytest.raises(ValueError): + moveaxis(array, 0, {"2": 2}) # type: ignore + + +class TestRemoveIgnoreIndex: + """Test the `remove_ignore_index` utility function.""" + + def test_return_same_input_arrays_if_ignore_index_is_none(self): + """Test case when `ignore_index` is None.""" + input_arrays = (anp.asarray([1, 2, 3]), anp.asarray([4, 5, 6])) + ignore_index = None + + result = remove_ignore_index(*input_arrays, ignore_index=ignore_index) + + assert result == input_arrays + + def test_remove_samples_equal_to_ignore_index_from_input_arrays(self): + """Test removing samples that are equal to `ignore_index`.""" + target = anp.asarray([1, 2, 3]) + preds = anp.asarray([4, 5, 6]) + ignore_index = 2 + expected_target = anp.asarray([1, 3]) + expected_preds = anp.asarray([4, 6]) + + out_target, out_preds = remove_ignore_index( + target, + preds, + ignore_index=ignore_index, + ) + + assert anp.all(out_target == expected_target) + assert anp.all(out_preds == expected_preds) + + target = anp.asarray([[1, 2, 3], [4, 5, 6]]) + preds = anp.asarray([[7, 8, 9], [10, 11, 12]]) + ignore_index = 1 + expected_target = anp.asarray([2, 3, 4, 5, 6]) + expected_preds = anp.asarray([8, 9, 10, 11, 12]) + + out_target, out_preds = remove_ignore_index( + target, + preds, + ignore_index=ignore_index, + ) + + assert anp.all(out_target == expected_target) + assert anp.all(out_preds == expected_preds) + + def test_return_same_output_arrays_if_ignore_index_not_in_input_arrays(self): + """Test returning the same arrays if `ignore_index` is not in array.""" + input_arrays = (anp.asarray([1, 2, 3]), anp.asarray([4, 5, 6])) + ignore_index = 7 + + result = remove_ignore_index(*input_arrays, ignore_index=ignore_index) + + assert all(anp.all(a == b) for a, b in zip(result, input_arrays)) + + def test_raise_type_error_if_ignore_index_not_integer_or_tuple_of_integers(self): + """Test raising TypeError on invalid `ignore_index` type.""" + input_arrays = (anp.asarray([1, 2, 3]), anp.asarray([4, 5, 6])) + ignore_index = "ignore" + + with pytest.raises(TypeError): + remove_ignore_index(*input_arrays, ignore_index=ignore_index) # type: ignore + + def test_raise_type_error_if_input_arrays_not_array_objects(self): + """Test raising TypeError on invalid input array type.""" + input_arrays = ([1, 2, 3], [4, 5, 6]) + ignore_index = 2 + + with pytest.raises(TypeError): + remove_ignore_index(*input_arrays, ignore_index=ignore_index) + + def test_return_empty_tuple_if_all_input_arrays_empty(self): + """Test with all input arrays empty.""" + input_arrays = (anp.asarray([]), anp.asarray([])) + ignore_index = 2 + + result = remove_ignore_index(*input_arrays, ignore_index=ignore_index) + + assert all(anp.all(a == b) for a, b in zip(result, input_arrays)) + + def test_return_empty_tuple_if_all_samples_are_equal_to_ignore_index(self): + """Test ignoring all samples in input arrays.""" + input_arrays = (anp.asarray([1, 1, 1]), anp.asarray([1, 1, 1])) + ignore_index = 1 + expected_result = ( + anp.asarray([], dtype=anp.int64), + anp.asarray([], dtype=anp.int64), + ) + + result = remove_ignore_index(*input_arrays, ignore_index=ignore_index) + + assert all(anp.all(a == b) for a, b in zip(result, expected_result)) + + def test_remove_samples_with_tuple_ignore_index(self): + """Test with tuple of ignore_index values.""" + input_arrays = (anp.asarray([1, 2, 3]), anp.asarray([4, 5, 6])) + ignore_index = (2, 3) + + result = remove_ignore_index(*input_arrays, ignore_index=ignore_index) + + expected_result = (anp.asarray([1]), anp.asarray([4])) + assert all(anp.all(a == b) for a, b in zip(result, expected_result)) + + input_arrays = ( + anp.asarray([[1, 2, 3], [4, 5, 6]]), + anp.asarray([[7, 8, 9], [10, 11, 12]]), + ) + ignore_index = (2, 6) + + result = remove_ignore_index(*input_arrays, ignore_index=ignore_index) + + expected_result = (anp.asarray([1, 3, 4, 5]), anp.asarray([7, 9, 10, 11])) + assert all(anp.all(a == b) for a, b in zip(result, expected_result)) + + +class TestSafeDivide: + """Test the `safe_divide` utility function.""" + + def test_divide_non_zero_denominators(self): + """Test dividing two arrays with non-zero denominators.""" + numerator = anp.asarray([1.0, 2.0, 3.0]) + denominator = anp.asarray([2.0, 3.0, 4.0]) + expected_result = anp.asarray([0.5, 0.66666667, 0.75]) + + result = safe_divide(numerator, denominator) + + assert np.allclose(result, expected_result) + + def test_divide_zero_denominators(self): + """Test dividing two arrays with zero denominators, return array of zeros.""" + numerator = anp.asarray([1.0, 2.0, 3.0]) + denominator = anp.asarray([0.0, 0.0, 0.0]) + expected_result = anp.asarray([0.0, 0.0, 0.0]) + + result = safe_divide(numerator, denominator) + + assert anp.all(result == expected_result) + + def test_divide_one_zero_denominator(self): + """Test dividing two arrays with one zero denominator.""" + numerator = anp.asarray([1.0, 2.0, 3.0]) + denominator = anp.asarray([2.0, 0.0, 4.0]) + expected_result = anp.asarray([0.5, 0.0, 0.75]) + + result = safe_divide(numerator, denominator) + + assert anp.all(result == expected_result) + + def test_divide_empty_arrays(self): + """Test dividing two empty arrays.""" + import numpy.array_api as anp + + numerator = anp.asarray([]) + denominator = anp.asarray([]) + expected_result = anp.asarray([]) + + result = safe_divide(numerator, denominator) + + assert anp.all(result == expected_result) + + def test_divide_different_datatypes(self): + """Test dividing two arrays with different datatypes.""" + numerator = anp.asarray([1.0, 2.0, 3.0], dtype=anp.float32) + denominator = anp.asarray([2.0, 3.0, 4.0], dtype=anp.float64) + expected_result = anp.asarray([0.5, 0.66666667, 0.75], dtype=anp.float64) + + result = safe_divide(numerator, denominator) + + assert np.allclose(result, expected_result) + + def test_divide_mixed_values(self): + """Test dividing two arrays with mixed positive and negative values.""" + numerator = anp.asarray([1.0, -2.0, 0.0, 3.0]) + denominator = anp.asarray([-2.0, 3.0, 0.0, -4.0]) + expected_result = anp.asarray([-0.5, -0.66666667, 0.0, -0.75]) + + result = safe_divide(numerator, denominator) + + assert np.allclose(result, expected_result) + + def test_divide_large_values(self): + """Test dividing two arrays with large values.""" + numerator = anp.asarray([1e20, 2e20, 3e20], dtype=anp.float64) + denominator = anp.asarray([1e20, 1e-20, 3e20], dtype=anp.float32) + expected_result = anp.asarray([1.0, 2e40, 1.0]) + + result = safe_divide(numerator, denominator) + + assert np.allclose(result, expected_result) + + def test_divide_different_shapes(self): + """Test dividing two arrays with different shapes.""" + numerator = anp.asarray([1.0, 2.0, 3.0]) + denominator = anp.asarray([1.0, 2.0]) + + with pytest.raises(ValueError): + safe_divide(numerator, denominator) + + def test_divide_inf_values(self): + """Test dividing two arrays with Inf values.""" + numerator = np.asarray([1.0, 2.0, np.inf]) + denominator = np.asarray([2.0, np.inf, 4.0]) + expected_result = np.asarray([0.5, 0.0, np.inf]) + + result = safe_divide(numerator, denominator) + print(result) + + assert np.all(result == expected_result) + + def test_divide_with_nan_values(self): + """Test dividing two arrays with NaN values.""" + numerator = np.asarray([1.0, 2.0, np.nan]) + denominator = np.asarray([2.0, np.nan, 4.0]) + expected_result = np.asarray([0.5, np.nan, np.nan]) + + result = safe_divide(numerator, denominator) + + assert np.all(np.isnan(result) == np.isnan(expected_result)) + + def test_returns_array_with_same_shape(self): + """Test that the shape of the output is the same as the input arrays.""" + numerator = anp.asarray([1.0, 2.0, 3.0]) + denominator = anp.asarray([2.0, 3.0, 4.0]) + + result = safe_divide(numerator, denominator) + + assert result.shape == numerator.shape + assert result.shape == denominator.shape + + +class TestSigmoid: + """Test the `sigmoid` utility function.""" + + def test_sigmoid_positive_values(self): + """Test sigmoid function with positive values.""" + x = anp.asarray([1.1, 2.0, 3.0]) + result = sigmoid(x) + expected = anp.asarray([0.75026011, 0.88079708, 0.95257413], dtype=anp.float64) + np.testing.assert_allclose(result, expected) + + def test_sigmoid_negative_values(self): + """Test sigmoid function with negative values.""" + x = anp.asarray([-1.1, -2.0, -3.0]) + result = sigmoid(x) + expected = anp.asarray([0.24973989, 0.11920292, 0.04742587], dtype=anp.float64) + np.testing.assert_allclose(result, expected) + + def test_sigmoid_zeros(self): + """Test sigmoid function with zeros.""" + x = anp.asarray([0, 0, 0]) + result = sigmoid(x) + expected = anp.asarray([0.5, 0.5, 0.5], dtype=anp.float64) + np.testing.assert_allclose(result, expected) + + def test_sigmoid_large_values(self): + """Test sigmoid function with large values.""" + x = anp.asarray([100, 1000, 10000]) + result = sigmoid(x) + expected = anp.asarray([1.0, 1.0, 1.0], dtype=anp.float64) + np.testing.assert_allclose(result, expected) + + def test_sigmoid_small_values(self): + """Test sigmoid function with small values.""" + x = anp.asarray([-100, -1000, -10000], dtype=anp.float32) + result = sigmoid(x) + expected = anp.asarray([0.0, 0.0, 0.0], dtype=anp.float32) + np.testing.assert_allclose(result, expected, atol=4e-44) + + def test_sigmoid_empty_array(self): + """Test sigmoid function with empty array.""" + x = anp.asarray([]) + result = sigmoid(x) + expected = anp.asarray([], dtype=anp.float64) + assert all(result == expected) + + def test_sigmoid_nan_values(self): + """Test sigmoid function with NaN values.""" + x = anp.asarray([anp.nan, anp.nan, anp.nan]) + result = sigmoid(x) + expected = anp.asarray([anp.nan, anp.nan, anp.nan], dtype=anp.float64) + np.testing.assert_allclose(result, expected, equal_nan=True) + + def test_sigmoid_infinity_values(self): + """Test sigmoid function with infinity values.""" + x = anp.asarray([anp.inf, anp.inf, anp.inf]) + result = sigmoid(x) + expected = anp.asarray([1.0, 1.0, 1.0], dtype=anp.float64) + np.testing.assert_allclose(result, expected) + + def test_sigmoid_negative_infinity_values(self): + """Test sigmoid function with negative infinity values.""" + x = anp.asarray([-anp.inf, -anp.inf, -anp.inf]) + result = sigmoid(x) + expected = anp.asarray([0.0, 0.0, 0.0], dtype=anp.float64) + np.testing.assert_allclose(result, expected) + + def test_sigmoid_large_number_of_elements(self): + """Test sigmoid function with a large number of elements.""" + x = anp.ones(10**6) + result = sigmoid(x) + expected = anp.ones(10**6, dtype=anp.float64) * 0.73105858 + np.testing.assert_allclose(result, expected) + + +def test_squeeze_all(): + """Test the `squeeze_all` utility function.""" + # happy path + x = anp.asarray([[1, 2, 3], [4, 5, 6]]) + result = squeeze_all(x) + np.testing.assert_array_equal(result, x) + + x = anp.asarray([[[1, 2, 3]]]) + result = squeeze_all(x) + excepted_result = np.squeeze(x) + np.testing.assert_array_equal(result, excepted_result) + + x = anp.asarray([[[1, 2, 3]], [[4, 5, 6]]]) + result = squeeze_all(x) + excepted_result = np.squeeze(x) + np.testing.assert_array_equal(result, excepted_result) + + x = anp.asarray([[[0], [1], [2]]]) + result = squeeze_all(x) + excepted_result = np.squeeze(x) + np.testing.assert_array_equal(result, excepted_result) + + # edge cases + x = anp.asarray([]) + result = squeeze_all(x) + excepted_result = np.squeeze(x) + np.testing.assert_array_equal(result, excepted_result) diff --git a/tests/cyclops/evaluate/metrics/experimental/utils/test_validation.py b/tests/cyclops/evaluate/metrics/experimental/utils/test_validation.py new file mode 100644 index 000000000..d504baddb --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/utils/test_validation.py @@ -0,0 +1,50 @@ +"""Test utility functions for validating input arrays.""" +import numpy as np +import numpy.array_api as anp +import torch + +from cyclops.evaluate.metrics.experimental.utils.validation import ( + is_floating_point, + is_numeric, +) + + +def test_is_floating_point(): + """Test `is_floating_point`.""" + x = anp.asarray([1, 2, 3], dtype=anp.float32) + assert is_floating_point(x) + + x = anp.asarray([1, 2, 3], dtype=anp.float64) + assert is_floating_point(x) + + x = torch.tensor([1, 2, 3], dtype=torch.float16) + assert is_floating_point(x) + + x = torch.tensor([1, 2, 3], dtype=torch.bfloat16) + assert is_floating_point(x) + + x = anp.asarray([1, 2, 3], dtype=anp.int32) + assert not is_floating_point(x) + + x = np.zeros((3, 3), dtype=np.bool_) + assert not is_floating_point(x) + + +def test_is_numeric(): + """Test `is_numeric`.""" + numeric_dtypes = [ + anp.int8, + anp.int16, + anp.int32, + anp.int64, + anp.uint8, + anp.uint16, + anp.uint32, + anp.uint64, + anp.float32, + anp.float64, + ] + + for dtype in numeric_dtypes: + x = anp.asarray([1, 2, 3], dtype=dtype) + assert is_numeric(x) diff --git a/tests/cyclops/evaluate/metrics/inputs.py b/tests/cyclops/evaluate/metrics/inputs.py index 32231bcd0..0306a82f3 100644 --- a/tests/cyclops/evaluate/metrics/inputs.py +++ b/tests/cyclops/evaluate/metrics/inputs.py @@ -8,12 +8,8 @@ import scipy as sp from numpy.typing import ArrayLike +from .conftest import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, NUM_LABELS -BATCH_SIZE = 16 -NUM_BATCHES = 8 -NUM_CLASSES = 10 -NUM_LABELS = 5 -THRESHOLD = random.random() Input = namedtuple("Input", ["target", "preds"]) diff --git a/tests/cyclops/evaluate/metrics/test_accuracy.py b/tests/cyclops/evaluate/metrics/test_accuracy.py index ee16df1d6..9fe90cdce 100644 --- a/tests/cyclops/evaluate/metrics/test_accuracy.py +++ b/tests/cyclops/evaluate/metrics/test_accuracy.py @@ -11,20 +11,15 @@ from cyclops.evaluate.metrics.accuracy import Accuracy from cyclops.evaluate.metrics.functional.accuracy import accuracy from cyclops.evaluate.metrics.utils import sigmoid -from metrics.helpers import MetricTester -from metrics.inputs import ( - NUM_CLASSES, - NUM_LABELS, - THRESHOLD, - _binary_cases, - _multiclass_cases, - _multilabel_cases, -) -from metrics.test_stat_scores import ( +from evaluate.metrics.test_stat_scores import ( _sk_stat_scores_multiclass, _sk_stat_scores_multilabel, ) +from .conftest import NUM_CLASSES, NUM_LABELS, THRESHOLD +from .helpers import MetricTester +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases + np.seterr(divide="ignore", invalid="ignore") # ignore divide by zero or nan diff --git a/tests/cyclops/evaluate/metrics/test_auroc.py b/tests/cyclops/evaluate/metrics/test_auroc.py index 17a58cace..667aec167 100644 --- a/tests/cyclops/evaluate/metrics/test_auroc.py +++ b/tests/cyclops/evaluate/metrics/test_auroc.py @@ -11,10 +11,10 @@ from cyclops.evaluate.metrics.auroc import AUROC from cyclops.evaluate.metrics.functional import auroc as cyclops_auroc from cyclops.evaluate.metrics.utils import sigmoid -from metrics.helpers import MetricTester -from metrics.inputs import ( - NUM_CLASSES, - NUM_LABELS, + +from .conftest import NUM_CLASSES, NUM_LABELS +from .helpers import MetricTester +from .inputs import ( _binary_cases, _multiclass_cases, _multilabel_cases, diff --git a/tests/cyclops/evaluate/metrics/test_fbeta.py b/tests/cyclops/evaluate/metrics/test_fbeta.py index e4b32e754..ce0b1aa7e 100644 --- a/tests/cyclops/evaluate/metrics/test_fbeta.py +++ b/tests/cyclops/evaluate/metrics/test_fbeta.py @@ -10,15 +10,10 @@ from cyclops.evaluate.metrics.f_beta import F1Score, FbetaScore from cyclops.evaluate.metrics.functional.f_beta import f1_score, fbeta_score from cyclops.evaluate.metrics.utils import sigmoid -from metrics.helpers import MetricTester -from metrics.inputs import ( - NUM_CLASSES, - NUM_LABELS, - THRESHOLD, - _binary_cases, - _multiclass_cases, - _multilabel_cases, -) + +from .conftest import NUM_CLASSES, NUM_LABELS, THRESHOLD +from .helpers import MetricTester +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases def _sk_binary_fbeta_score( diff --git a/tests/cyclops/evaluate/metrics/test_metric_collection.py b/tests/cyclops/evaluate/metrics/test_metric_collection.py index 6ee568633..31c36a5b1 100644 --- a/tests/cyclops/evaluate/metrics/test_metric_collection.py +++ b/tests/cyclops/evaluate/metrics/test_metric_collection.py @@ -5,14 +5,10 @@ from cyclops.evaluate.metrics import MetricCollection from cyclops.evaluate.metrics.metric import _METRIC_REGISTRY -from metrics.helpers import _assert_allclose -from metrics.inputs import ( - NUM_CLASSES, - NUM_LABELS, - _binary_cases, - _multiclass_cases, - _multilabel_cases, -) + +from .conftest import NUM_CLASSES, NUM_LABELS +from .helpers import _assert_allclose +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases @pytest.fixture(name="binary_metrics") diff --git a/tests/cyclops/evaluate/metrics/test_precision_recall.py b/tests/cyclops/evaluate/metrics/test_precision_recall.py index 3e6822d48..69a8ba698 100644 --- a/tests/cyclops/evaluate/metrics/test_precision_recall.py +++ b/tests/cyclops/evaluate/metrics/test_precision_recall.py @@ -16,15 +16,10 @@ ) from cyclops.evaluate.metrics.precision_recall import Precision, Recall from cyclops.evaluate.metrics.utils import sigmoid -from metrics.helpers import MetricTester -from metrics.inputs import ( - NUM_CLASSES, - NUM_LABELS, - THRESHOLD, - _binary_cases, - _multiclass_cases, - _multilabel_cases, -) + +from .conftest import NUM_CLASSES, NUM_LABELS, THRESHOLD +from .helpers import MetricTester +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases def _sk_binary_precision_recall( diff --git a/tests/cyclops/evaluate/metrics/test_precision_recall_curve.py b/tests/cyclops/evaluate/metrics/test_precision_recall_curve.py index e6bc0cc5c..2140dabe8 100644 --- a/tests/cyclops/evaluate/metrics/test_precision_recall_curve.py +++ b/tests/cyclops/evaluate/metrics/test_precision_recall_curve.py @@ -12,14 +12,10 @@ ) from cyclops.evaluate.metrics.precision_recall_curve import PrecisionRecallCurve from cyclops.evaluate.metrics.utils import sigmoid -from metrics.helpers import MetricTester -from metrics.inputs import ( - NUM_CLASSES, - NUM_LABELS, - _binary_cases, - _multiclass_cases, - _multilabel_cases, -) + +from .conftest import NUM_CLASSES, NUM_LABELS +from .helpers import MetricTester +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases def _sk_binary_precision_recall_curve( diff --git a/tests/cyclops/evaluate/metrics/test_roc.py b/tests/cyclops/evaluate/metrics/test_roc.py index 7c4fcecc2..02cbbdfbe 100644 --- a/tests/cyclops/evaluate/metrics/test_roc.py +++ b/tests/cyclops/evaluate/metrics/test_roc.py @@ -10,14 +10,10 @@ from cyclops.evaluate.metrics.functional import roc_curve as cyclops_roc_curve from cyclops.evaluate.metrics.roc import ROCCurve from cyclops.evaluate.metrics.utils import sigmoid -from metrics.helpers import MetricTester -from metrics.inputs import ( - NUM_CLASSES, - NUM_LABELS, - _binary_cases, - _multiclass_cases, - _multilabel_cases, -) + +from .conftest import NUM_CLASSES, NUM_LABELS +from .helpers import MetricTester +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases def _sk_binary_roc_curve( diff --git a/tests/cyclops/evaluate/metrics/test_specificity.py b/tests/cyclops/evaluate/metrics/test_specificity.py index f39ba6c0c..e0b11197c 100644 --- a/tests/cyclops/evaluate/metrics/test_specificity.py +++ b/tests/cyclops/evaluate/metrics/test_specificity.py @@ -8,16 +8,11 @@ from cyclops.evaluate.metrics.functional.specificity import specificity from cyclops.evaluate.metrics.specificity import Specificity -from metrics.helpers import MetricTester -from metrics.inputs import ( - NUM_CLASSES, - NUM_LABELS, - THRESHOLD, - _binary_cases, - _multiclass_cases, - _multilabel_cases, -) -from metrics.test_stat_scores import ( + +from .conftest import NUM_CLASSES, NUM_LABELS, THRESHOLD +from .helpers import MetricTester +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from .test_stat_scores import ( _sk_stat_scores_binary, _sk_stat_scores_multiclass, _sk_stat_scores_multilabel, diff --git a/tests/cyclops/evaluate/metrics/test_stat_scores.py b/tests/cyclops/evaluate/metrics/test_stat_scores.py index 464248cb5..d26d50ac9 100644 --- a/tests/cyclops/evaluate/metrics/test_stat_scores.py +++ b/tests/cyclops/evaluate/metrics/test_stat_scores.py @@ -13,15 +13,10 @@ from cyclops.evaluate.metrics.functional.stat_scores import stat_scores from cyclops.evaluate.metrics.stat_scores import StatScores from cyclops.evaluate.metrics.utils import sigmoid -from metrics.helpers import MetricTester -from metrics.inputs import ( - NUM_CLASSES, - NUM_LABELS, - THRESHOLD, - _binary_cases, - _multiclass_cases, - _multilabel_cases, -) + +from .conftest import NUM_CLASSES, NUM_LABELS, THRESHOLD +from .helpers import MetricTester +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases def _sk_stat_scores_binary( diff --git a/tests/cyclops/utils/test_optional.py b/tests/cyclops/utils/test_optional.py new file mode 100644 index 000000000..2ccb3854d --- /dev/null +++ b/tests/cyclops/utils/test_optional.py @@ -0,0 +1,37 @@ +"""Test optional import utilities.""" + +import pytest + +from cyclops.utils.optional import import_optional_module + + +class TestImportOptionalModule: + """Test importing optional modules.""" + + def test_import_valid_module(self): + """Test importing a valid module.""" + module = import_optional_module("math") + assert module is not None + import math + + assert module == math + + def test_import_nonexistent_module_ignore(self): + """Test importing a non-existent module with `error='ignore'`.""" + module = import_optional_module("nonexistent_module", error="ignore") + assert module is None + + def test_import_nonexistent_module_warn(self): + """Test importing a non-existent module with `error='warn'`.""" + with pytest.warns(ImportWarning): + import_optional_module("nonexistent_module", error="warn") + + def test_import_nonexistent_module_raise(self): + """Test importing a non-existent module with `error='raise'`.""" + with pytest.raises(ModuleNotFoundError): + import_optional_module("nonexistent_module", error="raise") + + def test_invalid_error_option(self): + """Test importing a valid module with an invalid error option.""" + with pytest.raises(ValueError): + import_optional_module("math", error="invalid_option") # type: ignore