From 7b673cdd4f2072a5d4194667ca6e33f10f93c748 Mon Sep 17 00:00:00 2001 From: ArashAkbarinia Date: Fri, 15 Dec 2023 10:46:01 +0100 Subject: [PATCH] Added pytest for paradigm_utils.py --- osculari/paradigms/paradigm_utils.py | 140 ++++++++++++++----------- tests/datasets/imutils_test.py | 48 ++++++++- tests/paradigms/paradigm_utils_test.py | 48 +++++++++ 3 files changed, 171 insertions(+), 65 deletions(-) create mode 100644 tests/paradigms/paradigm_utils_test.py diff --git a/osculari/paradigms/paradigm_utils.py b/osculari/paradigms/paradigm_utils.py index 954d6c4..483c0d4 100644 --- a/osculari/paradigms/paradigm_utils.py +++ b/osculari/paradigms/paradigm_utils.py @@ -20,20 +20,21 @@ ] -def accuracy_preds(output: torch.Tensor, target: torch.Tensor, - topk: Optional[Sequence] = (1,)) -> (List, List): +def _accuracy_preds(output: torch.Tensor, target: torch.Tensor, + topk: Optional[Sequence] = (1,)) -> (List[float], List[torch.Tensor]): """ - Computes the accuracy over the k top predictions. + Compute accuracy and correct predictions for the top-k thresholds. - Args: - output: The model's output tensor containing predictions for each input sample. - target: The ground-truth labels for each input sample. - topk: An optional list of top-k accuracy thresholds to be computed (e.g., (1, 5)). + Parameters: + output (torch.Tensor): Model predictions. + target (torch.Tensor): Ground truth labels. + topk (Optional[Sequence]): Top-k thresholds for accuracy computation. Default is (1,). Returns: - A tuple containing the computed accuracies and correct predictions for each top-k threshold. + Tuple[List[float], List[torch.Tensor]]: List of accuracies for each top-k threshold, + list of correct predictions for each top-k + threshold. """ - with torch.inference_mode(): # Ensure that the model is in inference mode maxk = max(topk) # Extract the maximum top-k value batch_size = target.size(0) # Get the batch size @@ -59,38 +60,47 @@ def accuracy_preds(output: torch.Tensor, target: torch.Tensor, def accuracy(output: torch.Tensor, target: torch.Tensor) -> float: """ - This function computes the accuracy of a model's prediction on a given set of data. + Compute the accuracy of model predictions. - Args: - output: The model's predicted output (torch.Tensor). - target: The ground truth labels (torch.Tensor). + Parameters: + output (torch.Tensor): Model predictions. + target (torch.Tensor): Ground truth labels. Returns: - The accuracy of the model's predictions (float). + float: Accuracy of the model predictions. """ + # Ensure the output has two dimensions (Linear layer output is two-dimensional) + assert len(output.shape) == 2 + # Ensure output and target have the same number of elements + assert len(output) == len(target) + # Check if the model is performing binary classification if output.shape[1] == 1: - # Check if the model produces one-dimensional predictions - pred = torch.equal(torch.gt(output, 0), target.float()) - return pred.float().mean(0, keepdim=True)[0] + # Convert to binary predictions (greater than 0) + output_class = torch.gt(output, 0).flatten() + # Compute accuracy for binary classification + pred = torch.eq(output_class, target) + return pred.float().mean().item() # Otherwise, the model produces multidimensional predictions - acc, _ = accuracy_preds(output, target, topk=[1]) + acc, _ = _accuracy_preds(output, target, topk=[1]) return acc[0].item() # Extract the top-1 accuracy def circular_mean(a: float, b: float) -> float: """ - Computes the circular mean of two values. + Compute the circular mean of two angles in radians. - Args: - a: The first value (float). - b: The second value (float). + Parameters: + a (float): First angle in radians. + b (float): Second angle in radians. - Returns: - The circular mean of the two values (float). - """ + Returns: + float: Circular mean of the two angles. + """ + # Calculate the circular mean using a conditional expression mu = (a + b + 1) / 2 if abs(a - b) > 0.5 else (a + b) / 2 + # Adjust the result to be in the range [0, 1) return mu if mu >= 1 else mu - 1 @@ -115,24 +125,23 @@ def midpoint( Union[float, npt.NDArray, None] ): """ - Finds the midpoint of the stimulus range based on the current accuracy and the target accuracy - threshold. - - Args: - acc: The current accuracy value (float). - low: The lower bound of the stimulus range (float or NumPy array). - mid: The current midpoint of the stimulus range (float or NumPy array). - high: The upper bound of the stimulus range (float or NumPy array). - th: The target accuracy threshold (float). - ep: The convergence tolerance (float; optional). - circular_channels: The list of circular channels for applying circular arithmetic when - computing the average (list; optional). + Compute new midpoints for a given accuracy in a binary search. + + Parameters: + acc (float): Current accuracy. + low (Union[float, npt.NDArray]): Low value in the search space. + mid (Union[float, npt.NDArray]): Midpoint in the search space. + high (Union[float, npt.NDArray]): High value in the search space. + th (float): Target accuracy. + ep (Optional[float]): Acceptable range around the target accuracy. Default is 1e-4. + circular_channels (Optional[List]): List of circular channels. Default is None. Returns: - The new low, mid, and high values of the stimulus range based on the current accuracy and - the target accuracy threshold. + (Union[float, npt.NDArray, None], Union[float, npt.NDArray, None], Union[float, npt.NDArray, None]): + Tuple containing the updated low, mid, and high values. + If the accuracy is within the acceptable range of the target accuracy, returns + (None, None, None). """ - # Calculate the difference between the current accuracy and the target accuracy diff_acc = acc - th @@ -157,32 +166,35 @@ def midpoint( return mid, new_mid, high -def train_linear_probe(model: ProbeNet, dataset: Union[TorchDataset, TorchDataLoader], - epoch_loop: Callable[[nn.Module, TorchDataLoader, Any, torch.device], Dict], - out_dir: str, device: Optional[torch.device] = None, - epochs: Optional[int] = 10, - optimiser: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[lr_scheduler.LRScheduler] = None) -> Dict: +def train_linear_probe( + model: ProbeNet, + dataset: Union[TorchDataset, TorchDataLoader], + epoch_loop: Callable[[nn.Module, TorchDataLoader, Any, torch.device], Dict], + out_dir: str, + device: Optional[torch.device] = None, + epochs: Optional[int] = 10, + optimiser: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[lr_scheduler.LRScheduler] = None +) -> Dict: """ - Trains the linear probe network on the specified dataset. - - Args: - model: The linear probe network to train. - dataset: The dataset or dataloader for training. - epoch_loop: A function to perform an epoch of training or testing. This function - must accept for positional arguments (i.e., model, train_loader, optimiser, device). - This function should return a dictionary. - out_dir: The output directory for saving checkpoints. - device: The device to use for training (Optional). - epochs: The number of epochs to train for (Optional). - optimiser: The optimiser to use for training (default: SGD) (Optional). - scheduler: The learning rate scheduler to use - (default: MultiStepLR at 50 and 80% of epochs) (Optional). - - Returns: - A dictionary containing training logs. - """ + Train a linear probe on top of a frozen backbone model. + + Parameters: + model (ProbeNet): Linear probe model. + dataset (Union[TorchDataset, TorchDataLoader]): Training dataset or data loader. + epoch_loop (Callable): Function defining the training loop for one epoch. This function + must accept for positional arguments (i.e., model, train_loader, optimiser, device). + This function should return a dictionary. + out_dir (str): Output directory to save checkpoints. + device (Optional[torch.device]): Device on which to perform training. + epochs (Optional[int]): Number of training epochs. Default is 10. + optimiser (Optional[torch.optim.Optimizer]): Optimization algorithm. Default is SGD. + scheduler (Optional[lr_scheduler.LRScheduler]): Learning rate scheduler. Default is + MultiStepLR at 50 and 80% of epochs + Returns: + Dict: Training logs containing statistics. + """ # Data loading if isinstance(dataset, TorchDataLoader): train_loader = dataset diff --git a/tests/datasets/imutils_test.py b/tests/datasets/imutils_test.py index f487d8f..ed1c835 100644 --- a/tests/datasets/imutils_test.py +++ b/tests/datasets/imutils_test.py @@ -2,8 +2,8 @@ Unit tests for imutils_test.py """ -import numpy as np import pytest +import numpy as np from osculari.datasets import imutils @@ -33,7 +33,53 @@ def test_michelson_contrast_valid_input(sample_image): np.testing.assert_almost_equal(result, expected_result) +def test_michelson_contrast_contrast_one(sample_image): + contrast_factor = 1.0 + result = imutils.michelson_contrast(sample_image, contrast_factor) + + # Ensure that the output has the same shape as the input + assert result.shape == sample_image.shape + + # Ensure that the output is a NumPy array + assert isinstance(result, np.ndarray) + + # Ensure that the output is identical to input + expected_result = sample_image + np.testing.assert_equal(result, expected_result) + + def test_michelson_contrast_invalid_contrast(): with pytest.raises(AssertionError): contrast_factor = 1.5 # Invalid contrast value imutils.michelson_contrast(np.array([[1, 2], [3, 4]]), contrast_factor) + + +def test_gamma_correction_valid_input(sample_image): + gamma_factor = 0.5 + result = imutils.gamma_correction(sample_image, gamma_factor) + + # Ensure that the output has the same shape as the input + assert result.shape == sample_image.shape + + # Ensure that the output is a NumPy array + assert isinstance(result, np.ndarray) + + # Ensure that gamma correction is applied correctly + expected_result = np.array([[169, 201, 223], + [187, 213, 232], + [201, 223, 239]], dtype='uint8') + np.testing.assert_almost_equal(result, expected_result) + + +def test_gamma_correction_gamma_one(sample_image): + gamma_factor = 1.0 + result = imutils.gamma_correction(sample_image, gamma_factor) + + # Ensure that when gamma is 1, the output is the same as the input + np.testing.assert_almost_equal(result, sample_image) + + +def test_gamma_correction_zero_gamma(): + with pytest.raises(AssertionError): + gamma_factor = 0.0 # Invalid gamma value + imutils.gamma_correction(np.array([[1, 2], [3, 4]]), gamma_factor) diff --git a/tests/paradigms/paradigm_utils_test.py b/tests/paradigms/paradigm_utils_test.py new file mode 100644 index 0000000..4ebaff2 --- /dev/null +++ b/tests/paradigms/paradigm_utils_test.py @@ -0,0 +1,48 @@ +""" +Unit tests for paradigm_utils.py +""" + +import pytest +import torch + +from osculari.paradigms import paradigm_utils + + +def test_accuracy_binary_classification(): + # Test accuracy for binary classification predictions + output = torch.tensor([0.2, -0.1, 0.8, -0.4]).view(4, 1) + target = torch.tensor([1, 0, 1, 0]) + acc = paradigm_utils.accuracy(output, target) + assert acc == 1.0 + + +def test_accuracy_multi_classification(): + # Test accuracy for multi-class predictions + output = torch.tensor([[0.2, -0.1, 0.8, -0.4], [0.1, 0.3, -0.2, 0.5]]) + target = torch.tensor([2, 0]) + acc = paradigm_utils.accuracy(output, target) + assert acc == 0.5 + + +def test_accuracy_invalid_input(): + # Test with invalid input (different shapes) + output = torch.tensor([[0.2, -0.1, 0.8, -0.4], [0.1, 0.3, -0.2, 0.5]]) + target = torch.tensor([2, 0, 1]) # Invalid target shape + with pytest.raises(AssertionError): + paradigm_utils.accuracy(output, target) + + +def test_accuracy_zero_dimensional(): + # Test with zero-dimensional input (should raise an error) + output = torch.tensor(0.5) + target = torch.tensor(1) + with pytest.raises(AssertionError): + paradigm_utils.accuracy(output, target) + + +def test_accuracy_one_dimensional_equal(): + # Test accuracy for one-dimensional predictions where output and target are equal + output = torch.tensor([0.2, -0.1, 0.8, -0.4]).view(4, 1) + target = torch.tensor([0, 0, 1, 0]) + acc = paradigm_utils.accuracy(output, target) + assert acc == 0.75