diff --git a/examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py b/examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py index 11f73f42c..033dcec37 100644 --- a/examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py +++ b/examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py @@ -16,7 +16,6 @@ cupy_available = True except ImportError: cupy_available = False -import sys import numpy as np import scipy.optimize as opt @@ -36,7 +35,10 @@ import torch except ImportError: print("Warning: Cannot import PyTorch. Skipping test.") + import sys + sys.exit(0) + import zipfile from urllib import request @@ -91,6 +93,10 @@ def download_and_unzip(url, data_dir): zip_dataset.extractall() +print( + "Downloading trained models from Zenodo.org - this might take a minute...", + flush=True, +) data_url = "https://zenodo.org/records/10810754/files/models.zip?download=1" download_and_unzip(data_url, "models.zip") @@ -299,8 +305,7 @@ def __init__(self, sim, stage_i, lattice_index, x_or_y): self.x_or_y = x_or_y self.push = self.set_lens - def set_lens(self, step): - pc = self.sim.particle_container() + def set_lens(self, pc, step): # get envelope parameters rbc = pc.reduced_beam_characteristics() alpha = rbc[f"alpha_{self.x_or_y}"] diff --git a/examples/pytorch_surrogate_model/surrogate_model_definitions.py b/examples/pytorch_surrogate_model/surrogate_model_definitions.py index c3ac44ce9..4819c9d49 100644 --- a/examples/pytorch_surrogate_model/surrogate_model_definitions.py +++ b/examples/pytorch_surrogate_model/surrogate_model_definitions.py @@ -8,8 +8,14 @@ from enum import Enum -import torch -from torch import nn +try: + import torch + from torch import nn +except ImportError: + print("Warning: Cannot import PyTorch. Skipping test.") + import sys + + sys.exit(0) class Activation(Enum):