Skip to content

Commit

Permalink
Fix: PyTorch Example (#620)
Browse files Browse the repository at this point in the history
Smaller syntax errors that prevented run.
Skip gracefully if PyTorch is not found.

Add more detail infos about slow startup (download).
  • Loading branch information
ax3l authored May 23, 2024
1 parent d3a795a commit 5d6e9ea
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
11 changes: 8 additions & 3 deletions examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
cupy_available = True
except ImportError:
cupy_available = False
import sys

import numpy as np
import scipy.optimize as opt
Expand All @@ -36,7 +35,10 @@
import torch
except ImportError:
print("Warning: Cannot import PyTorch. Skipping test.")
import sys

sys.exit(0)

import zipfile
from urllib import request

Expand Down Expand Up @@ -91,6 +93,10 @@ def download_and_unzip(url, data_dir):
zip_dataset.extractall()


print(
"Downloading trained models from Zenodo.org - this might take a minute...",
flush=True,
)
data_url = "https://zenodo.org/records/10810754/files/models.zip?download=1"
download_and_unzip(data_url, "models.zip")

Expand Down Expand Up @@ -299,8 +305,7 @@ def __init__(self, sim, stage_i, lattice_index, x_or_y):
self.x_or_y = x_or_y
self.push = self.set_lens

def set_lens(self, step):
pc = self.sim.particle_container()
def set_lens(self, pc, step):
# get envelope parameters
rbc = pc.reduced_beam_characteristics()
alpha = rbc[f"alpha_{self.x_or_y}"]
Expand Down
10 changes: 8 additions & 2 deletions examples/pytorch_surrogate_model/surrogate_model_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@

from enum import Enum

import torch
from torch import nn
try:
import torch
from torch import nn
except ImportError:
print("Warning: Cannot import PyTorch. Skipping test.")
import sys

sys.exit(0)


class Activation(Enum):
Expand Down

0 comments on commit 5d6e9ea

Please sign in to comment.