Added pytest for paradigm_utils.py
ArashAkbarinia committed Dec 15, 2023
1 parent de73728 commit 7b673cd
Showing 3 changed files with 171 additions and 65 deletions.
140 changes: 76 additions & 64 deletions osculari/paradigms/paradigm_utils.py
@@ -20,20 +20,21 @@
 ]
 
 
-def accuracy_preds(output: torch.Tensor, target: torch.Tensor,
-                   topk: Optional[Sequence] = (1,)) -> (List, List):
+def _accuracy_preds(output: torch.Tensor, target: torch.Tensor,
+                    topk: Optional[Sequence] = (1,)) -> (List[float], List[torch.Tensor]):
     """
-    Computes the accuracy over the k top predictions.
+    Compute accuracy and correct predictions for the top-k thresholds.
 
-    Args:
-        output: The model's output tensor containing predictions for each input sample.
-        target: The ground-truth labels for each input sample.
-        topk: An optional list of top-k accuracy thresholds to be computed (e.g., (1, 5)).
+    Parameters:
+        output (torch.Tensor): Model predictions.
+        target (torch.Tensor): Ground truth labels.
+        topk (Optional[Sequence]): Top-k thresholds for accuracy computation. Default is (1,).
 
     Returns:
-        A tuple containing the computed accuracies and correct predictions for each top-k
-        threshold.
+        Tuple[List[float], List[torch.Tensor]]: List of accuracies for each top-k threshold,
+            list of correct predictions for each top-k threshold.
     """
 
     with torch.inference_mode():  # Ensure that the model is in inference mode
         maxk = max(topk)  # Extract the maximum top-k value
         batch_size = target.size(0)  # Get the batch size
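
For readers skimming the diff, the following is a small, self-contained sketch of the standard top-k computation that `_accuracy_preds` performs with `torch.topk`. The tensors, values, and variable names are invented for illustration; this is not the library's exact code.

```python
import torch

# Toy logits for 2 samples over 3 classes, and their ground-truth labels.
output = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.2, 0.3]])
target = torch.tensor([1, 2])

maxk = 2
# Indices of the maxk highest logits per sample, best guess first.
_, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
# Boolean matrix of shape (maxk, batch): row k marks whether the (k+1)-th guess is correct.
correct = pred.t().eq(target.view(1, -1))

for k in (1, 2):
    acc_k = correct[:k].flatten().float().sum() / target.size(0)
    print(f"top-{k} accuracy: {acc_k.item():.2f}")  # top-1: 0.50, top-2: 1.00
```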
@@ -59,38 +60,47 @@ def accuracy_preds(output: torch.Tensor, target: torch.Tensor,
 
 def accuracy(output: torch.Tensor, target: torch.Tensor) -> float:
     """
-    This function computes the accuracy of a model's prediction on a given set of data.
+    Compute the accuracy of model predictions.
 
-    Args:
-        output: The model's predicted output (torch.Tensor).
-        target: The ground truth labels (torch.Tensor).
+    Parameters:
+        output (torch.Tensor): Model predictions.
+        target (torch.Tensor): Ground truth labels.
 
     Returns:
-        The accuracy of the model's predictions (float).
+        float: Accuracy of the model predictions.
     """
+    # Ensure the output has two dimensions (Linear layer output is two-dimensional)
+    assert len(output.shape) == 2
+    # Ensure output and target have the same number of elements
+    assert len(output) == len(target)
+
+    # Check if the model is performing binary classification
     if output.shape[1] == 1:
-        # Check if the model produces one-dimensional predictions
-        pred = torch.equal(torch.gt(output, 0), target.float())
-        return pred.float().mean(0, keepdim=True)[0]
+        # Convert to binary predictions (greater than 0)
+        output_class = torch.gt(output, 0).flatten()
+        # Compute accuracy for binary classification
+        pred = torch.eq(output_class, target)
+        return pred.float().mean().item()
 
     # Otherwise, the model produces multidimensional predictions
-    acc, _ = accuracy_preds(output, target, topk=[1])
+    acc, _ = _accuracy_preds(output, target, topk=[1])
     return acc[0].item()  # Extract the top-1 accuracy
 
 
 def circular_mean(a: float, b: float) -> float:
     """
-    Computes the circular mean of two values.
+    Compute the circular mean of two angles in radians.
 
-    Args:
-        a: The first value (float).
-        b: The second value (float).
+    Parameters:
+        a (float): First angle in radians.
+        b (float): Second angle in radians.
 
-    Returns:
-        The circular mean of the two values (float).
-    """
+    Returns:
+        float: Circular mean of the two angles.
+    """
     # Calculate the circular mean using a conditional expression
     mu = (a + b + 1) / 2 if abs(a - b) > 0.5 else (a + b) / 2
     # Adjust the result to be in the range [0, 1)
     return mu - 1 if mu >= 1 else mu
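
Given the range adjustment above, `circular_mean` operates on values in the unit interval [0, 1), treating them as positions on a circle: 0.9 and 0.1 average to 0.0 rather than 0.5. A quick illustrative check, importing the module the same way the new test file does (outputs assume the code above):

```python
from osculari.paradigms import paradigm_utils

print(paradigm_utils.circular_mean(0.4, 0.6))  # 0.5 -- |a - b| <= 0.5, plain average
print(paradigm_utils.circular_mean(0.9, 0.1))  # 0.0 -- wraps around the 0/1 boundary
```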


@@ -115,24 +125,23 @@ def midpoint(
         Union[float, npt.NDArray, None]
 ):
     """
-    Finds the midpoint of the stimulus range based on the current accuracy and the target
-    accuracy threshold.
+    Compute new midpoints for a given accuracy in a binary search.
 
-    Args:
-        acc: The current accuracy value (float).
-        low: The lower bound of the stimulus range (float or NumPy array).
-        mid: The current midpoint of the stimulus range (float or NumPy array).
-        high: The upper bound of the stimulus range (float or NumPy array).
-        th: The target accuracy threshold (float).
-        ep: The convergence tolerance (float; optional).
-        circular_channels: The list of circular channels for applying circular arithmetic when
-            computing the average (list; optional).
+    Parameters:
+        acc (float): Current accuracy.
+        low (Union[float, npt.NDArray]): Low value in the search space.
+        mid (Union[float, npt.NDArray]): Midpoint in the search space.
+        high (Union[float, npt.NDArray]): High value in the search space.
+        th (float): Target accuracy.
+        ep (Optional[float]): Acceptable range around the target accuracy. Default is 1e-4.
+        circular_channels (Optional[List]): List of circular channels. Default is None.
 
     Returns:
-        The new low, mid, and high values of the stimulus range based on the current accuracy
-        and the target accuracy threshold.
+        (Union[float, npt.NDArray, None], Union[float, npt.NDArray, None],
+        Union[float, npt.NDArray, None]): Tuple containing the updated low, mid, and high
+        values. If the accuracy is within the acceptable range of the target accuracy,
+        returns (None, None, None).
     """
 
     # Calculate the difference between the current accuracy and the target accuracy
     diff_acc = acc - th
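
The sketch below illustrates how `midpoint` can drive a staircase-style binary search for the stimulus level at which accuracy crosses a target threshold. `measure_accuracy` is a hypothetical stand-in for a real evaluation routine, and the loop structure is an assumption based on the docstring, not code from the library.

```python
from osculari.paradigms import paradigm_utils

def measure_accuracy(contrast: float) -> float:
    # Hypothetical toy psychometric function standing in for a model evaluation.
    return min(1.0, 0.5 + contrast)

low, mid, high = 0.0, 0.5, 1.0
for _ in range(20):
    acc = measure_accuracy(mid)
    low, mid, high = paradigm_utils.midpoint(acc, low, mid, high, th=0.75)
    if mid is None:
        break  # accuracy landed within `ep` of the target: converged
```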

@@ -157,32 +166,35 @@ def midpoint(
     return mid, new_mid, high
 
 
-def train_linear_probe(model: ProbeNet, dataset: Union[TorchDataset, TorchDataLoader],
-                       epoch_loop: Callable[[nn.Module, TorchDataLoader, Any, torch.device], Dict],
-                       out_dir: str, device: Optional[torch.device] = None,
-                       epochs: Optional[int] = 10,
-                       optimiser: Optional[torch.optim.Optimizer] = None,
-                       scheduler: Optional[lr_scheduler.LRScheduler] = None) -> Dict:
+def train_linear_probe(
+        model: ProbeNet,
+        dataset: Union[TorchDataset, TorchDataLoader],
+        epoch_loop: Callable[[nn.Module, TorchDataLoader, Any, torch.device], Dict],
+        out_dir: str,
+        device: Optional[torch.device] = None,
+        epochs: Optional[int] = 10,
+        optimiser: Optional[torch.optim.Optimizer] = None,
+        scheduler: Optional[lr_scheduler.LRScheduler] = None
+) -> Dict:
     """
-    Trains the linear probe network on the specified dataset.
-
-    Args:
-        model: The linear probe network to train.
-        dataset: The dataset or dataloader for training.
-        epoch_loop: A function to perform an epoch of training or testing. This function
-            must accept four positional arguments (i.e., model, train_loader, optimiser,
-            device). This function should return a dictionary.
-        out_dir: The output directory for saving checkpoints.
-        device: The device to use for training (Optional).
-        epochs: The number of epochs to train for (Optional).
-        optimiser: The optimiser to use for training (default: SGD) (Optional).
-        scheduler: The learning rate scheduler to use
-            (default: MultiStepLR at 50 and 80% of epochs) (Optional).
-
-    Returns:
-        A dictionary containing training logs.
-    """
+    Train a linear probe on top of a frozen backbone model.
+
+    Parameters:
+        model (ProbeNet): Linear probe model.
+        dataset (Union[TorchDataset, TorchDataLoader]): Training dataset or data loader.
+        epoch_loop (Callable): Function defining the training loop for one epoch. This function
+            must accept four positional arguments (i.e., model, train_loader, optimiser,
+            device). This function should return a dictionary.
+        out_dir (str): Output directory to save checkpoints.
+        device (Optional[torch.device]): Device on which to perform training.
+        epochs (Optional[int]): Number of training epochs. Default is 10.
+        optimiser (Optional[torch.optim.Optimizer]): Optimization algorithm. Default is SGD.
+        scheduler (Optional[lr_scheduler.LRScheduler]): Learning rate scheduler. Default is
+            MultiStepLR at 50% and 80% of epochs.
+
+    Returns:
+        Dict: Training logs containing statistics.
+    """
     # Data loading
     if isinstance(dataset, TorchDataLoader):
         train_loader = dataset
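
A minimal usage sketch follows. The names `probe`, `train_set`, and `epoch_fn` are placeholders invented for this example; only `train_linear_probe`'s documented signature is taken from the diff above.

```python
import torch.nn as nn

def epoch_fn(model, train_loader, optimiser, device):
    # Satisfies the documented contract: four positional arguments, returns a dict.
    model.train()
    losses = []
    for x, y in train_loader:
        optimiser.zero_grad()
        loss = nn.functional.cross_entropy(model(x.to(device)), y.to(device))
        loss.backward()
        optimiser.step()
        losses.append(loss.item())
    return {'loss': losses}

# logs = train_linear_probe(model=probe, dataset=train_set, epoch_loop=epoch_fn,
#                           out_dir='./checkpoints', epochs=10)
```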
48 changes: 47 additions & 1 deletion tests/datasets/imutils_test.py
@@ -2,8 +2,8 @@
 Unit tests for imutils_test.py
 """
 
-import numpy as np
 import pytest
+import numpy as np
 
 from osculari.datasets import imutils
 
@@ -33,7 +33,53 @@ def test_michelson_contrast_valid_input(sample_image):
     np.testing.assert_almost_equal(result, expected_result)
 
 
+def test_michelson_contrast_contrast_one(sample_image):
+    contrast_factor = 1.0
+    result = imutils.michelson_contrast(sample_image, contrast_factor)
+
+    # Ensure that the output has the same shape as the input
+    assert result.shape == sample_image.shape
+
+    # Ensure that the output is a NumPy array
+    assert isinstance(result, np.ndarray)
+
+    # Ensure that the output is identical to input
+    expected_result = sample_image
+    np.testing.assert_equal(result, expected_result)
+
+
+def test_michelson_contrast_invalid_contrast():
+    with pytest.raises(AssertionError):
+        contrast_factor = 1.5  # Invalid contrast value
+        imutils.michelson_contrast(np.array([[1, 2], [3, 4]]), contrast_factor)
+
+
+def test_gamma_correction_valid_input(sample_image):
+    gamma_factor = 0.5
+    result = imutils.gamma_correction(sample_image, gamma_factor)
+
+    # Ensure that the output has the same shape as the input
+    assert result.shape == sample_image.shape
+
+    # Ensure that the output is a NumPy array
+    assert isinstance(result, np.ndarray)
+
+    # Ensure that gamma correction is applied correctly
+    expected_result = np.array([[169, 201, 223],
+                                [187, 213, 232],
+                                [201, 223, 239]], dtype='uint8')
+    np.testing.assert_almost_equal(result, expected_result)
+
+
+def test_gamma_correction_gamma_one(sample_image):
+    gamma_factor = 1.0
+    result = imutils.gamma_correction(sample_image, gamma_factor)
+
+    # Ensure that when gamma is 1, the output is the same as the input
+    np.testing.assert_almost_equal(result, sample_image)
+
+
+def test_gamma_correction_zero_gamma():
+    with pytest.raises(AssertionError):
+        gamma_factor = 0.0  # Invalid gamma value
+        imutils.gamma_correction(np.array([[1, 2], [3, 4]]), gamma_factor)
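
For context, the expected values in `test_gamma_correction_valid_input` are consistent with standard gamma encoding on 8-bit images. The reference below is an assumption for illustration only — neither the `sample_image` fixture nor the actual `imutils.gamma_correction` implementation appears in this diff:

```python
import numpy as np

def gamma_correction_ref(img: np.ndarray, gamma: float) -> np.ndarray:
    # Presumed mapping: normalise to [0, 1], apply the power law, rescale to 8 bits.
    assert gamma > 0  # matches the behaviour probed by test_gamma_correction_zero_gamma
    return np.round(255.0 * (img / 255.0) ** gamma).astype('uint8')

# e.g. a pixel of 112 with gamma 0.5 maps to round(255 * sqrt(112 / 255)) = 169
```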
48 changes: 48 additions & 0 deletions tests/paradigms/paradigm_utils_test.py
@@ -0,0 +1,48 @@
"""
Unit tests for paradigm_utils.py
"""

import pytest
import torch

from osculari.paradigms import paradigm_utils


def test_accuracy_binary_classification():
# Test accuracy for binary classification predictions
output = torch.tensor([0.2, -0.1, 0.8, -0.4]).view(4, 1)
target = torch.tensor([1, 0, 1, 0])
acc = paradigm_utils.accuracy(output, target)
assert acc == 1.0


def test_accuracy_multi_classification():
# Test accuracy for multi-class predictions
output = torch.tensor([[0.2, -0.1, 0.8, -0.4], [0.1, 0.3, -0.2, 0.5]])
target = torch.tensor([2, 0])
acc = paradigm_utils.accuracy(output, target)
assert acc == 0.5


def test_accuracy_invalid_input():
# Test with invalid input (different shapes)
output = torch.tensor([[0.2, -0.1, 0.8, -0.4], [0.1, 0.3, -0.2, 0.5]])
target = torch.tensor([2, 0, 1]) # Invalid target shape
with pytest.raises(AssertionError):
paradigm_utils.accuracy(output, target)


def test_accuracy_zero_dimensional():
# Test with zero-dimensional input (should raise an error)
output = torch.tensor(0.5)
target = torch.tensor(1)
with pytest.raises(AssertionError):
paradigm_utils.accuracy(output, target)


def test_accuracy_one_dimensional_equal():
# Test accuracy for one-dimensional predictions where output and target are equal
output = torch.tensor([0.2, -0.1, 0.8, -0.4]).view(4, 1)
target = torch.tensor([0, 0, 1, 0])
acc = paradigm_utils.accuracy(output, target)
assert acc == 0.75
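
The new tests can be run in isolation with `pytest tests/paradigms/paradigm_utils_test.py`, or together with the rest of the suite by invoking `pytest` from the repository root.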
