Skip to content

Commit

Permalink
update algorithm-specific options in model
Browse files Browse the repository at this point in the history
  • Loading branch information
stevehenke committed Jan 28, 2025
1 parent 8664581 commit b898173
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 18 deletions.
28 changes: 28 additions & 0 deletions src/ptychodus/model/ptychi/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
AutodiffPtychographyProbeOptions,
AutodiffPtychographyProbePositionOptions,
AutodiffPtychographyReconstructorOptions,
ForwardModels,
LossFunctions,
)
from ptychi.api.task import PtychographyTask

Expand All @@ -34,6 +36,29 @@ def name(self) -> str:

def _create_reconstructor_options(self) -> AutodiffPtychographyReconstructorOptions:
helper = self._options_helper.reconstructor_helper

####

loss_function_str = self._autodiffSettings.lossFunction.getValue()

try:
loss_function = LossFunctions[loss_function_str.upper()]
except KeyError:
logger.warning('Failed to parse loss function "{loss_function_str}"!')
loss_function = LossFunctions.MSE_SQRT

####

forward_model_class_str = self._autodiffSettings.forwardModelClass.getValue()

try:
forward_model_class = ForwardModels[forward_model_class_str.upper()]
except KeyError:
logger.warning('Failed to parse forward model class "{forward_model_class_str}"!')
forward_model_class = ForwardModels.PLANAR_PTYCHOGRAPHY

####

return AutodiffPtychographyReconstructorOptions(
num_epochs=helper.num_epochs,
batch_size=helper.batch_size,
Expand All @@ -45,6 +70,9 @@ def _create_reconstructor_options(self) -> AutodiffPtychographyReconstructorOpti
random_seed=helper.random_seed,
displayed_loss_function=helper.displayed_loss_function,
use_low_memory_forward_model=helper.use_low_memory_forward_model,
loss_function=loss_function,
forward_model_class=forward_model_class,
forward_model_params=None,
)

def _create_object_options(self, object_: Object) -> AutodiffPtychographyObjectOptions:
Expand Down
6 changes: 6 additions & 0 deletions src/ptychodus/model/ptychi/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def __init__(self) -> None:
from ptychi.api import (
BatchingModes,
Directions,
ForwardModels,
ImageGradientMethods,
ImageIntegrationMethods,
LossFunctions,
Expand All @@ -19,6 +20,7 @@ def __init__(self) -> None:
except ModuleNotFoundError:
self._batchingModes: Sequence[str] = list()
self._directions: Sequence[str] = list()
self._forwardModels: Sequence[str] = list()
self._imageGradientMethods: Sequence[str] = list()
self._imageIntegrationMethods: Sequence[str] = list()
self._lossFunctions: Sequence[str] = list()
Expand All @@ -30,6 +32,7 @@ def __init__(self) -> None:
else:
self._batchingModes = [member.name for member in BatchingModes]
self._directions = [member.name for member in Directions]
self._forwardModels = [member.name for member in ForwardModels]
self._imageGradientMethods = [member.name for member in ImageGradientMethods]
self._imageIntegrationMethods = [member.name for member in ImageIntegrationMethods]
self._lossFunctions = [member.name for member in LossFunctions]
Expand All @@ -45,6 +48,9 @@ def batchingModes(self) -> Iterator[str]:
def directions(self) -> Iterator[str]:
return iter(self._directions)

def forwardModels(self) -> Iterator[str]:
return iter(self._forwardModels)

def imageGradientMethods(self) -> Iterator[str]:
return iter(self._imageGradientMethods)

Expand Down
48 changes: 30 additions & 18 deletions src/ptychodus/model/ptychi/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,36 +364,34 @@ def update(self, observable: Observable) -> None:
self.notifyObservers()


class PtyChiDMSettings(Observable, Observer): # FIXME to view
class PtyChiAutodiffSettings(Observable, Observer): # FIXME to view
def __init__(self, registry: SettingsRegistry) -> None:
super().__init__()
self._settingsGroup = registry.createGroup('PtyChiDM')
self._settingsGroup = registry.createGroup('PtyChiAutodiff')
self._settingsGroup.addObserver(self)

self.exitWaveUpdateRelaxation = self._settingsGroup.createRealParameter(
'ExitWaveUpdateRelaxation', 1.0, minimum=0.0, maximum=1.0
)
self.chunkLength = self._settingsGroup.createIntegerParameter('ChunkLength', 1, minimum=1)
self.objectAmplitudeClampLimit = self._settingsGroup.createRealParameter(
'ObjectAmplitudeClampLimit', 1000, minimum=0.0
self.lossFunction = self._settingsGroup.createStringParameter('LossFunction', 'MSE_SQRT')
self.forwardModelClass = self._settingsGroup.createStringParameter(
'ForwardModelClass', 'PLANAR_PTYCHOGRAPHY'
)

def update(self, observable: Observable) -> None:
if observable is self._settingsGroup:
self.notifyObservers()


class PtyChiPIESettings(Observable, Observer): # FIXME to view
class PtyChiDMSettings(Observable, Observer): # FIXME to view
def __init__(self, registry: SettingsRegistry) -> None:
super().__init__()
self._settingsGroup = registry.createGroup('PtyChiPIE')
self._settingsGroup = registry.createGroup('PtyChiDM')
self._settingsGroup.addObserver(self)

self.probeAlpha = self._settingsGroup.createRealParameter(
'ProbeAlpha', 0.1, minimum=0.0, maximum=1.0
self.exitWaveUpdateRelaxation = self._settingsGroup.createRealParameter(
'ExitWaveUpdateRelaxation', 1.0, minimum=0.0, maximum=1.0
)
self.objectAlpha = self._settingsGroup.createRealParameter(
'ObjectAlpha', 0.1, minimum=0.0, maximum=1.0
self.chunkLength = self._settingsGroup.createIntegerParameter('ChunkLength', 1, minimum=1)
self.objectAmplitudeClampLimit = self._settingsGroup.createRealParameter(
'ObjectAmplitudeClampLimit', 1000, minimum=0.0
)

def update(self, observable: Observable) -> None:
Expand Down Expand Up @@ -431,18 +429,32 @@ def __init__(self, registry: SettingsRegistry) -> None:
'MomentumAccelerationGradientMixingFactor', 1.0
)

self.probeOptimalStepSizeScaler = self._settingsGroup.createRealParameter(
'ProbeOptimalStepSizeScaler', 0.9
)
self.objectOptimalStepSizeScaler = self._settingsGroup.createRealParameter(
'ObjectOptimalStepSizeScaler', 0.9, minimum=0.0
)
self.objectMultimodalUpdate = self._settingsGroup.createBooleanParameter(
'ObjectMultimodalUpdate', True
)

self.probeEigenmodeUpdateRelaxation = self._settingsGroup.createRealParameter(
'ProbeEigenmodeUpdateRelaxation', 1.0
def update(self, observable: Observable) -> None:
if observable is self._settingsGroup:
self.notifyObservers()


class PtyChiPIESettings(Observable, Observer): # FIXME to view
def __init__(self, registry: SettingsRegistry) -> None:
super().__init__()
self._settingsGroup = registry.createGroup('PtyChiPIE')
self._settingsGroup.addObserver(self)

self.probeAlpha = self._settingsGroup.createRealParameter(
'ProbeAlpha', 0.1, minimum=0.0, maximum=1.0
)
self.probeOptimalStepSizeScaler = self._settingsGroup.createRealParameter(
'ProbeOptimalStepSizeScaler', 0.9
self.objectAlpha = self._settingsGroup.createRealParameter(
'ObjectAlpha', 0.1, minimum=0.0, maximum=1.0
)

def update(self, observable: Observable) -> None:
Expand Down

0 comments on commit b898173

Please sign in to comment.