diff --git a/src/qusi/internal/infer_session.py b/src/qusi/internal/infer_session.py index 2f92223..b628658 100644 --- a/src/qusi/internal/infer_session.py +++ b/src/qusi/internal/infer_session.py @@ -29,7 +29,11 @@ def infer_session( :param device: The device to run the model on. :return: A list of arrays with each element being the array predicted for each light curve in the dataset. """ - torch.multiprocessing.set_start_method("spawn") + logger.info(f'Creating dataloader workers...') + try: + torch.multiprocessing.set_start_method("spawn") + except RuntimeError: + pass if workers_per_dataloader == 0: prefetch_factor = None else: @@ -41,6 +45,7 @@ def infer_session( infer_dataloaders.append(infer_dataloader) model.eval() results = [] + logger.info(f'Entering infer loop...') for infer_dataloader in infer_dataloaders: result = infer_phase(infer_dataloader, model, device=device) results.append(result) @@ -61,6 +66,6 @@ def infer_phase(dataloader, model: Module, device: Device): batches_of_predicted_targets.append(batch_predicted_targets_array) batch_count += 1 processed_count += batch_predicted_targets_array.shape[0] - logger.info(f'Processed: {processed_count}') + logger.info(f'Processed: {processed_count}.') predicted_targets = np.concatenate(batches_of_predicted_targets, axis=0) return predicted_targets diff --git a/src/qusi/internal/train_session.py b/src/qusi/internal/train_session.py index 0b61c2e..8ca81d0 100644 --- a/src/qusi/internal/train_session.py +++ b/src/qusi/internal/train_session.py @@ -86,7 +86,10 @@ def train_session( dir=sessions_directory, ) train_dataset = InterleavedDataset.new(*train_datasets) - torch.multiprocessing.set_start_method("spawn") + try: + torch.multiprocessing.set_start_method("spawn") + except RuntimeError: + pass workers_per_dataloader = system_configuration.preprocessing_processes_per_train_process if workers_per_dataloader == 0: prefetch_factor = None