From 4dc9e1f6baef83fc2e2e427a9c2cef27d401612a Mon Sep 17 00:00:00 2001 From: golmschenk Date: Mon, 6 May 2024 18:17:53 -0400 Subject: [PATCH] Add the infinite_datasets_test_session --- .../infinite_datasets_test_session.py | 51 +++++++++++++++++++ src/qusi/session.py | 2 + 2 files changed, 53 insertions(+) create mode 100644 src/qusi/internal/infinite_datasets_test_session.py diff --git a/src/qusi/internal/infinite_datasets_test_session.py b/src/qusi/internal/infinite_datasets_test_session.py new file mode 100644 index 00000000..5abddb72 --- /dev/null +++ b/src/qusi/internal/infinite_datasets_test_session.py @@ -0,0 +1,51 @@ +from torch.nn import Module +from torch.types import Device +from torch.utils.data import DataLoader +from wandb.wandb_torch import torch + +from qusi.internal.light_curve_dataset import LightCurveDataset + + +def infinite_datasets_test_session(test_datasets: list[LightCurveDataset], model: Module, + metric_functions: list[Module], *, batch_size: int, device: Device, steps: int): + """ + Runs a test session on finite datasets. + + :param test_datasets: A list of datasets to run the test session on. + :param model: A model to perform the inference. + :param metric_functions: A metrics to test. + :param batch_size: A batch size to use during testing. + :param device: A device to run the model on. + :param steps: The number of steps to run on the infinite datasets. + :return: A list of arrays, with one array for each test dataset, with each array containing an element for each + metric that was tested. + """ + test_dataloaders: list[DataLoader] = [] + for test_dataset in test_datasets: + test_dataloaders.append(DataLoader(test_dataset, batch_size=batch_size, pin_memory=True)) + model.eval() + results = [] + for test_dataloader in test_dataloaders: + result = infinite_dataset_test_phase(test_dataloader, model, metric_functions, device=device, steps=steps) + results.append(result) + return results + + +def infinite_dataset_test_phase(dataloader, model: Module, metric_functions: list[Module], device: Device, steps: int): + batch_count = 0 + metric_totals = torch.zeros(size=[len(metric_functions)]) + model.eval() + with torch.no_grad(): + for input_features, targets in dataloader: + input_features = input_features.to(device, non_blocking=True) + targets = targets.to(device, non_blocking=True) + predicted_targets = model(input_features) + for metric_function_index, metric_function in enumerate(metric_functions): + batch_metric_value = metric_function(predicted_targets.to(device, non_blocking=True), + targets) + metric_totals[metric_function_index] += batch_metric_value.to('cpu', non_blocking=True) + batch_count += 1 + if batch_count >= steps: + break + cycle_metric_values = metric_totals / batch_count + return cycle_metric_values diff --git a/src/qusi/session.py b/src/qusi/session.py index 2954a142..8d883245 100644 --- a/src/qusi/session.py +++ b/src/qusi/session.py @@ -4,6 +4,7 @@ from qusi.internal.device import get_device from qusi.internal.finite_test_session import finite_datasets_test_session from qusi.internal.infer_session import infer_session +from qusi.internal.infinite_datasets_test_session import infinite_datasets_test_session from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration from qusi.internal.train_logging_configuration import TrainLoggingConfiguration from qusi.internal.train_system_configuration import TrainSystemConfiguration @@ -13,6 +14,7 @@ 'finite_datasets_test_session', 'get_device', 'infer_session', + 'infinite_datasets_test_session', 'TrainHyperparameterConfiguration', 'TrainLoggingConfiguration', 'TrainSystemConfiguration',