From 4f4750473ff8b75edd572153dc16f801f3da4ca6 Mon Sep 17 00:00:00 2001 From: Steven Henke Date: Tue, 28 Jan 2025 17:06:23 -0600 Subject: [PATCH] add autodiff settings --- src/ptychodus/model/ptychi/autodiff.py | 10 ++++++--- src/ptychodus/model/ptychi/core.py | 16 +++++++++------ src/ptychodus/model/ptychi/dm.py | 14 +++++-------- src/ptychodus/model/ptychi/epie.py | 12 ++++------- src/ptychodus/model/ptychi/lsqml.py | 28 +++++++++++--------------- src/ptychodus/model/ptychi/pie.py | 12 ++++------- src/ptychodus/model/ptychi/rpie.py | 12 ++++------- 7 files changed, 46 insertions(+), 58 deletions(-) diff --git a/src/ptychodus/model/ptychi/autodiff.py b/src/ptychodus/model/ptychi/autodiff.py index a5a32542..ac31f8e4 100644 --- a/src/ptychodus/model/ptychi/autodiff.py +++ b/src/ptychodus/model/ptychi/autodiff.py @@ -21,14 +21,18 @@ from ptychodus.api.scan import Scan from .helper import PtyChiOptionsHelper +from .settings import PtyChiAutodiffSettings logger = logging.getLogger(__name__) class AutodiffReconstructor(Reconstructor): - def __init__(self, options_helper: PtyChiOptionsHelper) -> None: + def __init__( + self, options_helper: PtyChiOptionsHelper, settings: PtyChiAutodiffSettings + ) -> None: super().__init__() self._options_helper = options_helper + self._settings = settings @property def name(self) -> str: @@ -39,7 +43,7 @@ def _create_reconstructor_options(self) -> AutodiffPtychographyReconstructorOpti #### - loss_function_str = self._autodiffSettings.lossFunction.getValue() + loss_function_str = self._settings.lossFunction.getValue() try: loss_function = LossFunctions[loss_function_str.upper()] @@ -49,7 +53,7 @@ def _create_reconstructor_options(self) -> AutodiffPtychographyReconstructorOpti #### - forward_model_class_str = self._autodiffSettings.forwardModelClass.getValue() + forward_model_class_str = self._settings.forwardModelClass.getValue() try: forward_model_class = ForwardModels[forward_model_class_str.upper()] diff --git a/src/ptychodus/model/ptychi/core.py b/src/ptychodus/model/ptychi/core.py index 122f4c4b..9645b5b3 100644 --- a/src/ptychodus/model/ptychi/core.py +++ b/src/ptychodus/model/ptychi/core.py @@ -13,6 +13,7 @@ from .device import PtyChiDeviceRepository from .enums import PtyChiEnumerators from .settings import ( + PtyChiAutodiffSettings, PtyChiDMSettings, PtyChiLSQMLSettings, PtyChiOPRSettings, @@ -31,14 +32,15 @@ def __init__( self, settingsRegistry: SettingsRegistry, detector: Detector, isDeveloperModeEnabled: bool ) -> None: super().__init__() - self.reconstructorSettings = PtyChiReconstructorSettings(settingsRegistry) + self.autodiffSettings = PtyChiAutodiffSettings(settingsRegistry) + self.dmSettings = PtyChiDMSettings(settingsRegistry) + self.lsqmlSettings = PtyChiLSQMLSettings(settingsRegistry) self.objectSettings = PtyChiObjectSettings(settingsRegistry) - self.probeSettings = PtyChiProbeSettings(settingsRegistry) - self.probePositionSettings = PtyChiProbePositionSettings(settingsRegistry) self.oprSettings = PtyChiOPRSettings(settingsRegistry) - self.dmSettings = PtyChiDMSettings(settingsRegistry) self.pieSettings = PtyChiPIESettings(settingsRegistry) - self.lsqmlSettings = PtyChiLSQMLSettings(settingsRegistry) + self.probePositionSettings = PtyChiProbePositionSettings(settingsRegistry) + self.probeSettings = PtyChiProbeSettings(settingsRegistry) + self.reconstructorSettings = PtyChiReconstructorSettings(settingsRegistry) self.enumerators = PtyChiEnumerators() self.deviceRepository = PtyChiDeviceRepository( @@ -77,7 +79,9 @@ def __init__( self.reconstructor_list.append(EPIEReconstructor(optionsHelper, self.pieSettings)) self.reconstructor_list.append(RPIEReconstructor(optionsHelper, self.pieSettings)) self.reconstructor_list.append(LSQMLReconstructor(optionsHelper, self.lsqmlSettings)) - self.reconstructor_list.append(AutodiffReconstructor(optionsHelper)) + self.reconstructor_list.append( + AutodiffReconstructor(optionsHelper, self.autodiffSettings) + ) @property def name(self) -> str: diff --git a/src/ptychodus/model/ptychi/dm.py b/src/ptychodus/model/ptychi/dm.py index 0e48a5e7..2872a827 100644 --- a/src/ptychodus/model/ptychi/dm.py +++ b/src/ptychodus/model/ptychi/dm.py @@ -25,14 +25,10 @@ class DMReconstructor(Reconstructor): - def __init__( - self, - options_helper: PtyChiOptionsHelper, - dmSettings: PtyChiDMSettings, - ) -> None: + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiDMSettings) -> None: super().__init__() self._options_helper = options_helper - self._dmSettings = dmSettings + self._settings = settings @property def name(self) -> str: @@ -51,8 +47,8 @@ def _create_reconstructor_options(self) -> DMReconstructorOptions: random_seed=helper.random_seed, displayed_loss_function=helper.displayed_loss_function, use_low_memory_forward_model=helper.use_low_memory_forward_model, - exit_wave_update_relaxation=self._dmSettings.exitWaveUpdateRelaxation.getValue(), - chunk_length=self._dmSettings.chunkLength.getValue(), + exit_wave_update_relaxation=self._settings.exitWaveUpdateRelaxation.getValue(), + chunk_length=self._settings.chunkLength.getValue(), ) def _create_object_options(self, object_: Object) -> DMObjectOptions: @@ -72,7 +68,7 @@ def _create_object_options(self, object_: Object) -> DMObjectOptions: remove_grid_artifacts=helper.remove_grid_artifacts, multislice_regularization=helper.multislice_regularization, patch_interpolation_method=helper.patch_interpolation_method, - amplitude_clamp_limit=self._dmSettings.objectAmplitudeClampLimit.getValue(), + amplitude_clamp_limit=self._settings.objectAmplitudeClampLimit.getValue(), ) def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> DMProbeOptions: diff --git a/src/ptychodus/model/ptychi/epie.py b/src/ptychodus/model/ptychi/epie.py index 69a90fdb..da1cfbda 100644 --- a/src/ptychodus/model/ptychi/epie.py +++ b/src/ptychodus/model/ptychi/epie.py @@ -25,14 +25,10 @@ class EPIEReconstructor(Reconstructor): - def __init__( - self, - options_helper: PtyChiOptionsHelper, - pieSettings: PtyChiPIESettings, - ) -> None: + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiPIESettings) -> None: super().__init__() self._options_helper = options_helper - self._pieSettings = pieSettings + self._settings = settings @property def name(self) -> str: @@ -70,7 +66,7 @@ def _create_object_options(self, object_: Object) -> PIEObjectOptions: remove_grid_artifacts=helper.remove_grid_artifacts, multislice_regularization=helper.multislice_regularization, patch_interpolation_method=helper.patch_interpolation_method, - alpha=self._pieSettings.objectAlpha.getValue(), + alpha=self._settings.objectAlpha.getValue(), ) def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEProbeOptions: @@ -88,7 +84,7 @@ def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEP support_constraint=helper.support_constraint, center_constraint=helper.center_constraint, eigenmode_update_relaxation=helper.eigenmode_update_relaxation, - alpha=self._pieSettings.probeAlpha.getValue(), + alpha=self._settings.probeAlpha.getValue(), ) def _create_probe_position_options( diff --git a/src/ptychodus/model/ptychi/lsqml.py b/src/ptychodus/model/ptychi/lsqml.py index acabbb51..dabf24da 100644 --- a/src/ptychodus/model/ptychi/lsqml.py +++ b/src/ptychodus/model/ptychi/lsqml.py @@ -26,14 +26,10 @@ class LSQMLReconstructor(Reconstructor): - def __init__( - self, - options_helper: PtyChiOptionsHelper, - lsqmlSettings: PtyChiLSQMLSettings, - ) -> None: + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiLSQMLSettings) -> None: super().__init__() self._options_helper = options_helper - self._lsqmlSettings = lsqmlSettings + self._settings = settings @property def name(self) -> str: @@ -44,7 +40,7 @@ def _create_reconstructor_options(self) -> LSQMLReconstructorOptions: #### - noise_model_str = self._lsqmlSettings.noiseModel.getValue() + noise_model_str = self._settings.noiseModel.getValue() try: noise_model = NoiseModels[noise_model_str.upper()] @@ -56,9 +52,9 @@ def _create_reconstructor_options(self) -> LSQMLReconstructorOptions: momentum_acceleration_gradient_mixing_factor: float | None = None - if self._lsqmlSettings.useMomentumAccelerationGradientMixingFactor.getValue(): + if self._settings.useMomentumAccelerationGradientMixingFactor.getValue(): momentum_acceleration_gradient_mixing_factor = ( - self._lsqmlSettings.momentumAccelerationGradientMixingFactor.getValue() + self._settings.momentumAccelerationGradientMixingFactor.getValue() ) #### @@ -75,10 +71,10 @@ def _create_reconstructor_options(self) -> LSQMLReconstructorOptions: displayed_loss_function=helper.displayed_loss_function, use_low_memory_forward_model=helper.use_low_memory_forward_model, noise_model=noise_model, - gaussian_noise_std=self._lsqmlSettings.gaussianNoiseDeviation.getValue(), - solve_obj_prb_step_size_jointly_for_first_slice_in_multislice=self._lsqmlSettings.solveObjectProbeStepSizeJointlyForFirstSliceInMultislice.getValue(), - solve_step_sizes_only_using_first_probe_mode=self._lsqmlSettings.solveStepSizesOnlyUsingFirstProbeMode.getValue(), - momentum_acceleration_gain=self._lsqmlSettings.momentumAccelerationGain.getValue(), + gaussian_noise_std=self._settings.gaussianNoiseDeviation.getValue(), + solve_obj_prb_step_size_jointly_for_first_slice_in_multislice=self._settings.solveObjectProbeStepSizeJointlyForFirstSliceInMultislice.getValue(), + solve_step_sizes_only_using_first_probe_mode=self._settings.solveStepSizesOnlyUsingFirstProbeMode.getValue(), + momentum_acceleration_gain=self._settings.momentumAccelerationGain.getValue(), momentum_acceleration_gradient_mixing_factor=momentum_acceleration_gradient_mixing_factor, ) @@ -99,8 +95,8 @@ def _create_object_options(self, object_: Object) -> LSQMLObjectOptions: remove_grid_artifacts=helper.remove_grid_artifacts, multislice_regularization=helper.multislice_regularization, patch_interpolation_method=helper.patch_interpolation_method, - optimal_step_size_scaler=self._lsqmlSettings.objectOptimalStepSizeScaler.getValue(), - multimodal_update=self._lsqmlSettings.objectMultimodalUpdate.getValue(), + optimal_step_size_scaler=self._settings.objectOptimalStepSizeScaler.getValue(), + multimodal_update=self._settings.objectMultimodalUpdate.getValue(), ) def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> LSQMLProbeOptions: @@ -118,7 +114,7 @@ def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> LSQM support_constraint=helper.support_constraint, center_constraint=helper.center_constraint, eigenmode_update_relaxation=helper.eigenmode_update_relaxation, - optimal_step_size_scaler=self._lsqmlSettings.probeOptimalStepSizeScaler.getValue(), + optimal_step_size_scaler=self._settings.probeOptimalStepSizeScaler.getValue(), ) def _create_probe_position_options( diff --git a/src/ptychodus/model/ptychi/pie.py b/src/ptychodus/model/ptychi/pie.py index c377cbec..21b9f40f 100644 --- a/src/ptychodus/model/ptychi/pie.py +++ b/src/ptychodus/model/ptychi/pie.py @@ -25,14 +25,10 @@ class PIEReconstructor(Reconstructor): - def __init__( - self, - options_helper: PtyChiOptionsHelper, - pieSettings: PtyChiPIESettings, - ) -> None: + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiPIESettings) -> None: super().__init__() self._options_helper = options_helper - self._pieSettings = pieSettings + self._settings = settings @property def name(self) -> str: @@ -70,7 +66,7 @@ def _create_object_options(self, object_: Object) -> PIEObjectOptions: remove_grid_artifacts=helper.remove_grid_artifacts, multislice_regularization=helper.multislice_regularization, patch_interpolation_method=helper.patch_interpolation_method, - alpha=self._pieSettings.objectAlpha.getValue(), + alpha=self._settings.objectAlpha.getValue(), ) def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEProbeOptions: @@ -88,7 +84,7 @@ def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEP support_constraint=helper.support_constraint, center_constraint=helper.center_constraint, eigenmode_update_relaxation=helper.eigenmode_update_relaxation, - alpha=self._pieSettings.probeAlpha.getValue(), + alpha=self._settings.probeAlpha.getValue(), ) def _create_probe_position_options( diff --git a/src/ptychodus/model/ptychi/rpie.py b/src/ptychodus/model/ptychi/rpie.py index 6090a64b..22970f4c 100644 --- a/src/ptychodus/model/ptychi/rpie.py +++ b/src/ptychodus/model/ptychi/rpie.py @@ -25,14 +25,10 @@ class RPIEReconstructor(Reconstructor): - def __init__( - self, - options_helper: PtyChiOptionsHelper, - pieSettings: PtyChiPIESettings, - ) -> None: + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiPIESettings) -> None: super().__init__() self._options_helper = options_helper - self._pieSettings = pieSettings + self._settings = settings @property def name(self) -> str: @@ -70,7 +66,7 @@ def _create_object_options(self, object_: Object) -> PIEObjectOptions: remove_grid_artifacts=helper.remove_grid_artifacts, multislice_regularization=helper.multislice_regularization, patch_interpolation_method=helper.patch_interpolation_method, - alpha=self._pieSettings.objectAlpha.getValue(), + alpha=self._settings.objectAlpha.getValue(), ) def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEProbeOptions: @@ -88,7 +84,7 @@ def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEP support_constraint=helper.support_constraint, center_constraint=helper.center_constraint, eigenmode_update_relaxation=helper.eigenmode_update_relaxation, - alpha=self._pieSettings.probeAlpha.getValue(), + alpha=self._settings.probeAlpha.getValue(), ) def _create_probe_position_options(