Skip to content

Commit

Permalink
Allow the multiprocessing start method to have already been set
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Jul 28, 2024
1 parent 7d173ad commit 17dba3a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
9 changes: 7 additions & 2 deletions src/qusi/internal/infer_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def infer_session(
:param device: The device to run the model on.
:return: A list of arrays with each element being the array predicted for each light curve in the dataset.
"""
torch.multiprocessing.set_start_method("spawn")
logger.info(f'Creating dataloader workers...')
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
if workers_per_dataloader == 0:
prefetch_factor = None
else:
Expand All @@ -41,6 +45,7 @@ def infer_session(
infer_dataloaders.append(infer_dataloader)
model.eval()
results = []
logger.info(f'Entering infer loop...')
for infer_dataloader in infer_dataloaders:
result = infer_phase(infer_dataloader, model, device=device)
results.append(result)
Expand All @@ -61,6 +66,6 @@ def infer_phase(dataloader, model: Module, device: Device):
batches_of_predicted_targets.append(batch_predicted_targets_array)
batch_count += 1
processed_count += batch_predicted_targets_array.shape[0]
logger.info(f'Processed: {processed_count}')
logger.info(f'Processed: {processed_count}.')
predicted_targets = np.concatenate(batches_of_predicted_targets, axis=0)
return predicted_targets
5 changes: 4 additions & 1 deletion src/qusi/internal/train_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ def train_session(
dir=sessions_directory,
)
train_dataset = InterleavedDataset.new(*train_datasets)
torch.multiprocessing.set_start_method("spawn")
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
workers_per_dataloader = system_configuration.preprocessing_processes_per_train_process
if workers_per_dataloader == 0:
prefetch_factor = None
Expand Down

0 comments on commit 17dba3a

Please sign in to comment.