Skip to content

Commit

Permalink
Add torchmetrics as dependency to tests and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fcogidi committed Nov 8, 2023
1 parent 500cab3 commit b59a98e
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 56 deletions.
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ mypy = "^1.0.0"
ruff = "^0.1.0"
nbqa = { version = "^1.7.0", extras = ["toolchain"] }
cycquery = "^0.1.0" # used for integration test
torchmetrics = {version = "^1.2.0", extras = ["classification"]}

[tool.poetry.group.docs]
optional = true
Expand Down
1 change: 1 addition & 0 deletions tests/cyclops/evaluate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for the `cyclops.evaluate` package."""
2 changes: 1 addition & 1 deletion tests/cyclops/evaluate/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"""Evaluate metrics testing package."""
"""Tests for `cyclops.evaluate.metrics` package."""
2 changes: 1 addition & 1 deletion tests/cyclops/evaluate/metrics/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"""Test array-API-compatible implementations of evalution metrics."""
"""Test the `cyclops.evaluate.metrics.experimental` package."""
11 changes: 4 additions & 7 deletions tests/cyclops/evaluate/metrics/experimental/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import numpy.array_api as anp
import pytest
import torch

from cyclops.evaluate.metrics.experimental.utils.ops import (
dim_zero_cat,
Expand All @@ -11,11 +12,7 @@
dim_zero_min,
dim_zero_sum,
)
from cyclops.utils.optional import import_optional_module
from metrics.experimental.testers import DummyListStateMetric, DummyMetric


torch = import_optional_module("torch", "ignore")
from evaluate.metrics.experimental.testers import DummyListStateMetric, DummyMetric


class TestMetricBaseClass:
Expand Down Expand Up @@ -328,8 +325,8 @@ def test_dist_backend_kwarg(self):
DummyMetric(dist_backend=42)

@pytest.mark.skipif(
torch is None or not torch.cuda.is_available(),
reason="Test requires torch and cuda.",
not torch.cuda.is_available(),
reason="CUDA is not available.",
)
def test_to_device_torch(self):
"""Test that `to_device` method works as expected."""
Expand Down
42 changes: 3 additions & 39 deletions tests/cyclops/evaluate/metrics/experimental/utils/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import numpy.array_api as anp
import pytest
import torch

from cyclops.evaluate.metrics.experimental.utils.ops import (
apply_to_array_collection,
Expand All @@ -26,7 +27,6 @@


cp = import_optional_module("cupy", error="ignore")
torch = import_optional_module("torch", error="ignore")


def multiply_by_two(x):
Expand Down Expand Up @@ -148,7 +148,6 @@ def test_handle_defaultdict_input(self):

def test_apply_to_nested_collections(self):
"""Test applying a function to nested collections of arrays."""
# Given
data = {
"a": anp.asarray(
[
Expand All @@ -168,10 +167,8 @@ def test_apply_to_nested_collections(self):
},
}

# When
result = apply_to_array_collection(data, multiply_by_two)

# Then
expected_result = {
"a": anp.asarray(
[
Expand Down Expand Up @@ -211,95 +208,73 @@ class TestBincount:

def test_non_negative_integers(self):
"""Test using non-negative integers as input."""
# Arrange
input_array = anp.asarray([0, 1, 1, 2, 2, 2])
expected_output = anp.asarray([1, 2, 3])

# Act
result = bincount(input_array)

# Assert
assert anp.all(result == expected_output)

def test_empty_array(self):
"""Test using an empty array as input."""
# Arrange
input_array = anp.asarray([], dtype=anp.int32)
expected_output = anp.asarray([], dtype=anp.int32)

# Act
result = bincount(input_array, minlength=5)

# Assert
assert anp.all(result == expected_output)

def test_single_unique_value(self):
"""Test using an array with a single unique value as input."""
# Arrange
input_array = anp.asarray([3, 3, 3, 3])
expected_output = anp.asarray([0, 0, 0, 4])

# Act
result = bincount(input_array)

# Assert
assert anp.all(result == expected_output)

def test_no_repeated_values(self):
"""Test using an array with no repeated values as input."""
# Arrange
input_array = anp.asarray([0, 1, 2, 3, 4, 5])
expected_output = anp.ones_like(input_array)

# Act
result = bincount(input_array)

# Assert
assert anp.all(result == expected_output)

def test_negative_integers(self):
"""Test using an array with negative integers as input."""
# Arrange
input_array = anp.asarray([-1, 0, 1, 2])

# Act and Assert
with pytest.raises(ValueError):
bincount(input_array)

def test_negative_minlength(self):
"""Test using a negative minlength as input."""
# Arrange
input_array = anp.asarray([1, 2, 3])

# Act and Assert
with pytest.raises(ValueError):
bincount(input_array, minlength=-5)

def test_different_shapes(self):
"""Test using arrays and weights with different shapes as input."""
# Arrange
input_array = anp.asarray([1, 2, 3])
weights = anp.asarray([0.5, 0.5])

# Act and Assert
with pytest.raises(ValueError):
bincount(input_array, weights=weights)

def test_not_one_dimensional(self):
"""Test using a multi-dimensional array as input."""
# Arrange
input_array = anp.asarray([[1, 2], [3, 4]])

# Act and Assert
with pytest.raises(ValueError):
bincount(input_array)

def test_not_integer_type(self):
"""Test using a non-integer array as input."""
# Arrange
input_array = anp.asarray([1.5, 2.5, 3.5])

# Act and Assert
with pytest.raises(ValueError):
bincount(input_array)

Expand All @@ -309,10 +284,8 @@ class TestClone:

def test_clone_numpy_array(self):
"""Test if the clone function creates a new copy of a numpy array."""
# Create a numpy array
x = np.array([1, 2, 3])

# Clone the array
y = clone(x)

# Check if y is a new copy of x
Expand All @@ -328,23 +301,18 @@ def test_clone_cupy_array(self):
except cp.cuda.runtime.CUDARuntimeError: # type: ignore
pytest.skip("CUDA is not available.")

# Create a cupy array
x = cp.asarray([1, 2, 3]) # type: ignore

# Clone the array
y = clone(x)

# Check if y is a new copy of x
assert y is not x
assert cp.array_equal(y, x) # type: ignore

@pytest.mark.skipif(torch is None, reason="PyTorch is not installed.")
def test_clone_torch_tensor(self):
"""Test if the clone function properly clones a torch tensor."""
# Create a torch tensor
x = torch.tensor([1, 2, 3]) # type: ignore

# Clone the tensor
y = clone(x)

# Check if y is a new copy of x
Expand All @@ -353,17 +321,13 @@ def test_clone_torch_tensor(self):

def test_clone_empty_array(self):
"""Test if the clone function creates a new copy of an empty array."""
import numpy.array_api as np

# Create an empty array
x = np.asarray([])
x = anp.asarray([])

# Clone the array
y = clone(x)

# Check if y is a new copy of x
assert y is not x
assert np.all(y == x)
assert anp.all(y == x)


class TestDimZeroCat:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Test utility functions for validating input arrays."""
import numpy as np
import numpy.array_api as anp
import torch

from cyclops.evaluate.metrics.experimental.utils.validation import (
is_floating_point,
is_numeric,
)
from cyclops.utils.optional import import_optional_module


def test_is_floating_point():
Expand All @@ -17,13 +17,11 @@ def test_is_floating_point():
x = anp.asarray([1, 2, 3], dtype=anp.float64)
assert is_floating_point(x)

torch = import_optional_module("torch")
if torch is not None:
x = torch.tensor([1, 2, 3], dtype=torch.float16)
assert is_floating_point(x)
x = torch.tensor([1, 2, 3], dtype=torch.float16)
assert is_floating_point(x)

x = torch.tensor([1, 2, 3], dtype=torch.bfloat16)
assert is_floating_point(x)
x = torch.tensor([1, 2, 3], dtype=torch.bfloat16)
assert is_floating_point(x)

x = anp.asarray([1, 2, 3], dtype=anp.int32)
assert not is_floating_point(x)
Expand Down

0 comments on commit b59a98e

Please sign in to comment.