Skip to content

Commit

Permalink
add autodiff settings
Browse files Browse the repository at this point in the history
  • Loading branch information
stevehenke committed Jan 28, 2025
1 parent b898173 commit 4f47504
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 58 deletions.
10 changes: 7 additions & 3 deletions src/ptychodus/model/ptychi/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,18 @@
from ptychodus.api.scan import Scan

from .helper import PtyChiOptionsHelper
from .settings import PtyChiAutodiffSettings

logger = logging.getLogger(__name__)


class AutodiffReconstructor(Reconstructor):
def __init__(self, options_helper: PtyChiOptionsHelper) -> None:
def __init__(
self, options_helper: PtyChiOptionsHelper, settings: PtyChiAutodiffSettings
) -> None:
super().__init__()
self._options_helper = options_helper
self._settings = settings

@property
def name(self) -> str:
Expand All @@ -39,7 +43,7 @@ def _create_reconstructor_options(self) -> AutodiffPtychographyReconstructorOpti

####

loss_function_str = self._autodiffSettings.lossFunction.getValue()
loss_function_str = self._settings.lossFunction.getValue()

try:
loss_function = LossFunctions[loss_function_str.upper()]
Expand All @@ -49,7 +53,7 @@ def _create_reconstructor_options(self) -> AutodiffPtychographyReconstructorOpti

####

forward_model_class_str = self._autodiffSettings.forwardModelClass.getValue()
forward_model_class_str = self._settings.forwardModelClass.getValue()

try:
forward_model_class = ForwardModels[forward_model_class_str.upper()]
Expand Down
16 changes: 10 additions & 6 deletions src/ptychodus/model/ptychi/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .device import PtyChiDeviceRepository
from .enums import PtyChiEnumerators
from .settings import (
PtyChiAutodiffSettings,
PtyChiDMSettings,
PtyChiLSQMLSettings,
PtyChiOPRSettings,
Expand All @@ -31,14 +32,15 @@ def __init__(
self, settingsRegistry: SettingsRegistry, detector: Detector, isDeveloperModeEnabled: bool
) -> None:
super().__init__()
self.reconstructorSettings = PtyChiReconstructorSettings(settingsRegistry)
self.autodiffSettings = PtyChiAutodiffSettings(settingsRegistry)
self.dmSettings = PtyChiDMSettings(settingsRegistry)
self.lsqmlSettings = PtyChiLSQMLSettings(settingsRegistry)
self.objectSettings = PtyChiObjectSettings(settingsRegistry)
self.probeSettings = PtyChiProbeSettings(settingsRegistry)
self.probePositionSettings = PtyChiProbePositionSettings(settingsRegistry)
self.oprSettings = PtyChiOPRSettings(settingsRegistry)
self.dmSettings = PtyChiDMSettings(settingsRegistry)
self.pieSettings = PtyChiPIESettings(settingsRegistry)
self.lsqmlSettings = PtyChiLSQMLSettings(settingsRegistry)
self.probePositionSettings = PtyChiProbePositionSettings(settingsRegistry)
self.probeSettings = PtyChiProbeSettings(settingsRegistry)
self.reconstructorSettings = PtyChiReconstructorSettings(settingsRegistry)

self.enumerators = PtyChiEnumerators()
self.deviceRepository = PtyChiDeviceRepository(
Expand Down Expand Up @@ -77,7 +79,9 @@ def __init__(
self.reconstructor_list.append(EPIEReconstructor(optionsHelper, self.pieSettings))
self.reconstructor_list.append(RPIEReconstructor(optionsHelper, self.pieSettings))
self.reconstructor_list.append(LSQMLReconstructor(optionsHelper, self.lsqmlSettings))
self.reconstructor_list.append(AutodiffReconstructor(optionsHelper))
self.reconstructor_list.append(
AutodiffReconstructor(optionsHelper, self.autodiffSettings)
)

@property
def name(self) -> str:
Expand Down
14 changes: 5 additions & 9 deletions src/ptychodus/model/ptychi/dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,10 @@


class DMReconstructor(Reconstructor):
def __init__(
self,
options_helper: PtyChiOptionsHelper,
dmSettings: PtyChiDMSettings,
) -> None:
def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiDMSettings) -> None:
super().__init__()
self._options_helper = options_helper
self._dmSettings = dmSettings
self._settings = settings

@property
def name(self) -> str:
Expand All @@ -51,8 +47,8 @@ def _create_reconstructor_options(self) -> DMReconstructorOptions:
random_seed=helper.random_seed,
displayed_loss_function=helper.displayed_loss_function,
use_low_memory_forward_model=helper.use_low_memory_forward_model,
exit_wave_update_relaxation=self._dmSettings.exitWaveUpdateRelaxation.getValue(),
chunk_length=self._dmSettings.chunkLength.getValue(),
exit_wave_update_relaxation=self._settings.exitWaveUpdateRelaxation.getValue(),
chunk_length=self._settings.chunkLength.getValue(),
)

def _create_object_options(self, object_: Object) -> DMObjectOptions:
Expand All @@ -72,7 +68,7 @@ def _create_object_options(self, object_: Object) -> DMObjectOptions:
remove_grid_artifacts=helper.remove_grid_artifacts,
multislice_regularization=helper.multislice_regularization,
patch_interpolation_method=helper.patch_interpolation_method,
amplitude_clamp_limit=self._dmSettings.objectAmplitudeClampLimit.getValue(),
amplitude_clamp_limit=self._settings.objectAmplitudeClampLimit.getValue(),
)

def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> DMProbeOptions:
Expand Down
12 changes: 4 additions & 8 deletions src/ptychodus/model/ptychi/epie.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,10 @@


class EPIEReconstructor(Reconstructor):
def __init__(
self,
options_helper: PtyChiOptionsHelper,
pieSettings: PtyChiPIESettings,
) -> None:
def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiPIESettings) -> None:
super().__init__()
self._options_helper = options_helper
self._pieSettings = pieSettings
self._settings = settings

@property
def name(self) -> str:
Expand Down Expand Up @@ -70,7 +66,7 @@ def _create_object_options(self, object_: Object) -> PIEObjectOptions:
remove_grid_artifacts=helper.remove_grid_artifacts,
multislice_regularization=helper.multislice_regularization,
patch_interpolation_method=helper.patch_interpolation_method,
alpha=self._pieSettings.objectAlpha.getValue(),
alpha=self._settings.objectAlpha.getValue(),
)

def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEProbeOptions:
Expand All @@ -88,7 +84,7 @@ def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEP
support_constraint=helper.support_constraint,
center_constraint=helper.center_constraint,
eigenmode_update_relaxation=helper.eigenmode_update_relaxation,
alpha=self._pieSettings.probeAlpha.getValue(),
alpha=self._settings.probeAlpha.getValue(),
)

def _create_probe_position_options(
Expand Down
28 changes: 12 additions & 16 deletions src/ptychodus/model/ptychi/lsqml.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,10 @@


class LSQMLReconstructor(Reconstructor):
def __init__(
self,
options_helper: PtyChiOptionsHelper,
lsqmlSettings: PtyChiLSQMLSettings,
) -> None:
def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiLSQMLSettings) -> None:
super().__init__()
self._options_helper = options_helper
self._lsqmlSettings = lsqmlSettings
self._settings = settings

@property
def name(self) -> str:
Expand All @@ -44,7 +40,7 @@ def _create_reconstructor_options(self) -> LSQMLReconstructorOptions:

####

noise_model_str = self._lsqmlSettings.noiseModel.getValue()
noise_model_str = self._settings.noiseModel.getValue()

try:
noise_model = NoiseModels[noise_model_str.upper()]
Expand All @@ -56,9 +52,9 @@ def _create_reconstructor_options(self) -> LSQMLReconstructorOptions:

momentum_acceleration_gradient_mixing_factor: float | None = None

if self._lsqmlSettings.useMomentumAccelerationGradientMixingFactor.getValue():
if self._settings.useMomentumAccelerationGradientMixingFactor.getValue():
momentum_acceleration_gradient_mixing_factor = (
self._lsqmlSettings.momentumAccelerationGradientMixingFactor.getValue()
self._settings.momentumAccelerationGradientMixingFactor.getValue()
)

####
Expand All @@ -75,10 +71,10 @@ def _create_reconstructor_options(self) -> LSQMLReconstructorOptions:
displayed_loss_function=helper.displayed_loss_function,
use_low_memory_forward_model=helper.use_low_memory_forward_model,
noise_model=noise_model,
gaussian_noise_std=self._lsqmlSettings.gaussianNoiseDeviation.getValue(),
solve_obj_prb_step_size_jointly_for_first_slice_in_multislice=self._lsqmlSettings.solveObjectProbeStepSizeJointlyForFirstSliceInMultislice.getValue(),
solve_step_sizes_only_using_first_probe_mode=self._lsqmlSettings.solveStepSizesOnlyUsingFirstProbeMode.getValue(),
momentum_acceleration_gain=self._lsqmlSettings.momentumAccelerationGain.getValue(),
gaussian_noise_std=self._settings.gaussianNoiseDeviation.getValue(),
solve_obj_prb_step_size_jointly_for_first_slice_in_multislice=self._settings.solveObjectProbeStepSizeJointlyForFirstSliceInMultislice.getValue(),
solve_step_sizes_only_using_first_probe_mode=self._settings.solveStepSizesOnlyUsingFirstProbeMode.getValue(),
momentum_acceleration_gain=self._settings.momentumAccelerationGain.getValue(),
momentum_acceleration_gradient_mixing_factor=momentum_acceleration_gradient_mixing_factor,
)

Expand All @@ -99,8 +95,8 @@ def _create_object_options(self, object_: Object) -> LSQMLObjectOptions:
remove_grid_artifacts=helper.remove_grid_artifacts,
multislice_regularization=helper.multislice_regularization,
patch_interpolation_method=helper.patch_interpolation_method,
optimal_step_size_scaler=self._lsqmlSettings.objectOptimalStepSizeScaler.getValue(),
multimodal_update=self._lsqmlSettings.objectMultimodalUpdate.getValue(),
optimal_step_size_scaler=self._settings.objectOptimalStepSizeScaler.getValue(),
multimodal_update=self._settings.objectMultimodalUpdate.getValue(),
)

def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> LSQMLProbeOptions:
Expand All @@ -118,7 +114,7 @@ def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> LSQM
support_constraint=helper.support_constraint,
center_constraint=helper.center_constraint,
eigenmode_update_relaxation=helper.eigenmode_update_relaxation,
optimal_step_size_scaler=self._lsqmlSettings.probeOptimalStepSizeScaler.getValue(),
optimal_step_size_scaler=self._settings.probeOptimalStepSizeScaler.getValue(),
)

def _create_probe_position_options(
Expand Down
12 changes: 4 additions & 8 deletions src/ptychodus/model/ptychi/pie.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,10 @@


class PIEReconstructor(Reconstructor):
def __init__(
self,
options_helper: PtyChiOptionsHelper,
pieSettings: PtyChiPIESettings,
) -> None:
def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiPIESettings) -> None:
super().__init__()
self._options_helper = options_helper
self._pieSettings = pieSettings
self._settings = settings

@property
def name(self) -> str:
Expand Down Expand Up @@ -70,7 +66,7 @@ def _create_object_options(self, object_: Object) -> PIEObjectOptions:
remove_grid_artifacts=helper.remove_grid_artifacts,
multislice_regularization=helper.multislice_regularization,
patch_interpolation_method=helper.patch_interpolation_method,
alpha=self._pieSettings.objectAlpha.getValue(),
alpha=self._settings.objectAlpha.getValue(),
)

def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEProbeOptions:
Expand All @@ -88,7 +84,7 @@ def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEP
support_constraint=helper.support_constraint,
center_constraint=helper.center_constraint,
eigenmode_update_relaxation=helper.eigenmode_update_relaxation,
alpha=self._pieSettings.probeAlpha.getValue(),
alpha=self._settings.probeAlpha.getValue(),
)

def _create_probe_position_options(
Expand Down
12 changes: 4 additions & 8 deletions src/ptychodus/model/ptychi/rpie.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,10 @@


class RPIEReconstructor(Reconstructor):
def __init__(
self,
options_helper: PtyChiOptionsHelper,
pieSettings: PtyChiPIESettings,
) -> None:
def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiPIESettings) -> None:
super().__init__()
self._options_helper = options_helper
self._pieSettings = pieSettings
self._settings = settings

@property
def name(self) -> str:
Expand Down Expand Up @@ -70,7 +66,7 @@ def _create_object_options(self, object_: Object) -> PIEObjectOptions:
remove_grid_artifacts=helper.remove_grid_artifacts,
multislice_regularization=helper.multislice_regularization,
patch_interpolation_method=helper.patch_interpolation_method,
alpha=self._pieSettings.objectAlpha.getValue(),
alpha=self._settings.objectAlpha.getValue(),
)

def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEProbeOptions:
Expand All @@ -88,7 +84,7 @@ def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEP
support_constraint=helper.support_constraint,
center_constraint=helper.center_constraint,
eigenmode_update_relaxation=helper.eigenmode_update_relaxation,
alpha=self._pieSettings.probeAlpha.getValue(),
alpha=self._settings.probeAlpha.getValue(),
)

def _create_probe_position_options(
Expand Down

0 comments on commit 4f47504

Please sign in to comment.