Skip to content

Commit 2ce3bde

Browse files
authored
Merge pull request #404 from skim0119/wip/401
Use OperatorGroup for constrain and callback features
2 parents 61c53e9 + a7c1de0 commit 2ce3bde

26 files changed

+464
-355
lines changed

Makefile

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ flake8:
5050
.PHONY: autoflake-check
5151
autoflake-check:
5252
poetry run autoflake --version
53-
poetry run autoflake $(AUTOFLAKE_ARGS) elastica tests examples
5453
poetry run autoflake --check $(AUTOFLAKE_ARGS) elastica tests examples
5554

5655
.PHONY: autoflake-format

elastica/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,9 @@
7575
from elastica.utils import isqrt
7676
from elastica.timestepper import (
7777
integrate,
78-
PositionVerlet,
79-
PEFRL,
80-
RungeKutta4,
81-
EulerForward,
8278
extend_stepper_interface,
8379
)
80+
from elastica.timestepper.symplectic_steppers import PositionVerlet, PEFRL
8481
from elastica.memory_block.memory_block_rigid_body import MemoryBlockRigidBody
8582
from elastica.memory_block.memory_block_rod import MemoryBlockCosseratRod
8683
from elastica.restart import save_state, load_state

elastica/timestepper/explicit_steppers.py renamed to elastica/experimental/timestepper/explicit_steppers.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@
88
from elastica.typing import (
99
SystemType,
1010
SystemCollectionType,
11-
OperatorType,
11+
StepType,
1212
SteppersOperatorsType,
1313
StateType,
1414
)
15-
from elastica.systems.protocol import ExplicitSystemProtocol
16-
from .protocol import ExplicitStepperProtocol, MemoryProtocol
15+
from elastica.experimental.timestepper.protocol import (
16+
ExplicitSystemProtocol,
17+
ExplicitStepperProtocol,
18+
MemoryProtocol,
19+
)
1720

1821

1922
"""
@@ -166,10 +169,10 @@ class EulerForward(ExplicitStepperMixin):
166169
Classical Euler Forward stepper. Stateless, coordinates operations only.
167170
"""
168171

169-
def get_stages(self) -> list[OperatorType]:
172+
def get_stages(self) -> list[StepType]:
170173
return [self._first_stage]
171174

172-
def get_updates(self) -> list[OperatorType]:
175+
def get_updates(self) -> list[StepType]:
173176
return [self._first_update]
174177

175178
def _first_stage(
@@ -198,15 +201,15 @@ class RungeKutta4(ExplicitStepperMixin):
198201
to be externally managed and allocated.
199202
"""
200203

201-
def get_stages(self) -> list[OperatorType]:
204+
def get_stages(self) -> list[StepType]:
202205
return [
203206
self._first_stage,
204207
self._second_stage,
205208
self._third_stage,
206209
self._fourth_stage,
207210
]
208211

209-
def get_updates(self) -> list[OperatorType]:
212+
def get_updates(self) -> list[StepType]:
210213
return [
211214
self._first_update,
212215
self._second_update,

elastica/systems/memory.py renamed to elastica/experimental/timestepper/memory.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from typing import Iterator, TypeVar, Generic, Type
2-
from elastica.timestepper.protocol import ExplicitStepperProtocol
32
from elastica.typing import SystemCollectionType
3+
from elastica.experimental.timestepper.explicit_steppers import (
4+
RungeKutta4,
5+
EulerForward,
6+
)
7+
from elastica.experimental.timestepper.protocol import ExplicitStepperProtocol
48

59
from copy import copy
610

@@ -12,11 +16,6 @@ def make_memory_for_explicit_stepper(
1216
) -> "MemoryCollection":
1317
# TODO Automated logic (class creation, memory management logic) agnostic of stepper details (RK, AB etc.)
1418

15-
from elastica.timestepper.explicit_steppers import (
16-
RungeKutta4,
17-
EulerForward,
18-
)
19-
2019
# is_this_system_a_collection = is_system_a_collection(system)
2120

2221
memory_cls: Type
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from typing import Protocol
2+
3+
from elastica.typing import StepType, StateType
4+
from elastica.systems.protocol import SystemProtocol, SlenderBodyGeometryProtocol
5+
from elastica.timestepper.protocol import StepperProtocol
6+
7+
import numpy as np
8+
9+
10+
class ExplicitSystemProtocol(SystemProtocol, SlenderBodyGeometryProtocol, Protocol):
11+
# TODO: Temporarily made to handle explicit stepper.
12+
# Need to be refactored as the explicit stepper is further developed.
13+
def __call__(self, time: np.float64, dt: np.float64) -> np.float64: ...
14+
@property
15+
def state(self) -> StateType: ...
16+
@state.setter
17+
def state(self, state: StateType) -> None: ...
18+
@property
19+
def n_elems(self) -> int: ...
20+
21+
22+
class MemoryProtocol(Protocol):
23+
@property
24+
def initial_state(self) -> bool: ...
25+
26+
27+
class ExplicitStepperProtocol(StepperProtocol, Protocol):
28+
"""symplectic stepper protocol."""
29+
30+
def get_stages(self) -> list[StepType]: ...
31+
32+
def get_updates(self) -> list[StepType]: ...
33+
34+
35+
# class _LinearExponentialIntegratorMixin:
36+
# """
37+
# Linear Exponential integrator mixin wrapper.
38+
# """
39+
#
40+
# def __init__(self):
41+
# pass
42+
#
43+
# def _do_stage(self, System, Memory, time, dt):
44+
# # TODO : Make more general, system should not be calculating what the state
45+
# # transition matrix directly is, but rather it should just give
46+
# Memory.linear_operator = System.get_linear_state_transition_operator(time, dt)
47+
#
48+
# def _do_update(self, System, Memory, time, dt):
49+
# # FIXME What's the right formula when doing update?
50+
# # System.linearly_evolving_state = _batch_matmul(
51+
# # System.linearly_evolving_state,
52+
# # Memory.linear_operator
53+
# # )
54+
# System.linearly_evolving_state = np.einsum(
55+
# "ijk,ljk->ilk", System.linearly_evolving_state, Memory.linear_operator
56+
# )
57+
# return time + dt
58+
#
59+
# def _first_prefactor(self, dt):
60+
# """Prefactor call to satisfy interface of SymplecticStepper. Should never
61+
# be used in actual code.
62+
#
63+
# Parameters
64+
# ----------
65+
# dt : the time step of simulation
66+
#
67+
# Raises
68+
# ------
69+
# RuntimeError
70+
# """
71+
# raise RuntimeError(
72+
# "Symplectic prefactor of LinearExponentialIntegrator should not be called!"
73+
# )
74+
#
75+
# # Code repeat!
76+
# # Easy to avoid, but keep for performance.
77+
# def _do_one_step(self, System, time, prefac):
78+
# System.linearly_evolving_state = np.einsum(
79+
# "ijk,ljk->ilk",
80+
# System.linearly_evolving_state,
81+
# System.get_linear_state_transition_operator(time, prefac),
82+
# )
83+
# return (
84+
# time # TODO fix hack that treats time separately here. Shuold be time + dt
85+
# )
86+
# # return time + dt

elastica/modules/base_system.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Basic coordinating for multiple, smaller systems that have an independently integrable
66
interface (i.e. works with symplectic or explicit routines `timestepper.py`.)
77
"""
8-
from typing import Type, Generator, Iterable, Any, overload
8+
from typing import Type, Generator, Any, overload
99
from typing import final
1010
from elastica.typing import (
1111
SystemType,
@@ -27,6 +27,7 @@
2727

2828
from .memory_block import construct_memory_block_structures
2929
from .operator_group import OperatorGroupFIFO
30+
from .protocol import ModuleProtocol
3031

3132

3233
class BaseSystemCollection(MutableSequence):
@@ -55,10 +56,18 @@ def __init__(self) -> None:
5556
# Collection of functions. Each group is executed as a collection at the different steps.
5657
# Each component (Forcing, Connection, etc.) registers the executable (callable) function
5758
# in the group that that needs to be executed. These should be initialized before mixin.
58-
self._feature_group_synchronize: Iterable[OperatorType] = OperatorGroupFIFO()
59-
self._feature_group_constrain_values: list[OperatorType] = []
60-
self._feature_group_constrain_rates: list[OperatorType] = []
61-
self._feature_group_callback: list[OperatorCallbackType] = []
59+
self._feature_group_synchronize: OperatorGroupFIFO[
60+
OperatorType, ModuleProtocol
61+
] = OperatorGroupFIFO()
62+
self._feature_group_constrain_values: OperatorGroupFIFO[
63+
OperatorType, ModuleProtocol
64+
] = OperatorGroupFIFO()
65+
self._feature_group_constrain_rates: OperatorGroupFIFO[
66+
OperatorType, ModuleProtocol
67+
] = OperatorGroupFIFO()
68+
self._feature_group_callback: OperatorGroupFIFO[
69+
OperatorCallbackType, ModuleProtocol
70+
] = OperatorGroupFIFO()
6271
self._feature_group_finalize: list[OperatorFinalizeType] = []
6372
# We need to initialize our mixin classes
6473
super().__init__()

elastica/modules/callbacks.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from elastica.typing import SystemType, SystemIdxType, OperatorFinalizeType
1010
from .protocol import ModuleProtocol
1111

12+
import functools
13+
1214
import numpy as np
1315

1416
from elastica.callback_functions import CallBackBaseClass
@@ -29,9 +31,7 @@ class CallBacks:
2931

3032
def __init__(self: SystemCollectionProtocol) -> None:
3133
self._callback_list: list[ModuleProtocol] = []
32-
self._callback_operators: list[tuple[int, CallBackBaseClass]] = []
3334
super(CallBacks, self).__init__()
34-
self._feature_group_callback.append(self._callback_execution)
3535
self._feature_group_finalize.append(self._finalize_callback)
3636

3737
def collect_diagnostics(
@@ -54,30 +54,28 @@ def collect_diagnostics(
5454
sys_idx: SystemIdxType = self.get_system_index(system)
5555

5656
# Create _Constraint object, cache it and return to user
57-
_callbacks: ModuleProtocol = _CallBack(sys_idx)
58-
self._callback_list.append(_callbacks)
57+
_callback: ModuleProtocol = _CallBack(sys_idx)
58+
self._callback_list.append(_callback)
59+
self._feature_group_callback.append_id(_callback)
5960

60-
return _callbacks
61+
return _callback
6162

6263
def _finalize_callback(self: SystemCollectionProtocol) -> None:
6364
# dev : the first index stores the rod index to collect data.
64-
self._callback_operators = [
65-
(callback.id(), callback.instantiate()) for callback in self._callback_list
66-
]
65+
for callback in self._callback_list:
66+
sys_id = callback.id()
67+
callback_instance = callback.instantiate()
68+
69+
callback_operator = functools.partial(
70+
callback_instance.make_callback, system=self[sys_id]
71+
)
72+
self._feature_group_callback.add_operators(callback, [callback_operator])
73+
6774
self._callback_list.clear()
6875
del self._callback_list
6976

7077
# First callback execution
71-
time = np.float64(0.0)
72-
self._callback_execution(time=time, current_step=0)
73-
74-
def _callback_execution(
75-
self: SystemCollectionProtocol,
76-
time: np.float64,
77-
current_step: int,
78-
) -> None:
79-
for sys_id, callback in self._callback_operators:
80-
callback.make_callback(self[sys_id], time, current_step)
78+
self.apply_callbacks(time=np.float64(0.0), current_step=0)
8179

8280

8381
class _CallBack:

elastica/modules/constraints.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from typing import Any, Type, cast
88
from typing_extensions import Self
99

10+
import functools
11+
1012
import numpy as np
1113

1214
from elastica.boundary_conditions import ConstraintBase
@@ -36,8 +38,6 @@ class Constraints:
3638
def __init__(self: SystemCollectionProtocol) -> None:
3739
self._constraints_list: list[ModuleProtocol] = []
3840
super(Constraints, self).__init__()
39-
self._feature_group_constrain_values.append(self._constrain_values)
40-
self._feature_group_constrain_rates.append(self._constrain_rates)
4141
self._feature_group_finalize.append(self._finalize_constraints)
4242

4343
def constrain(
@@ -62,6 +62,8 @@ def constrain(
6262
# Create _Constraint object, cache it and return to user
6363
_constraint: ModuleProtocol = _Constraint(sys_idx)
6464
self._constraints_list.append(_constraint)
65+
self._feature_group_constrain_values.append_id(_constraint)
66+
self._feature_group_constrain_rates.append_id(_constraint)
6567

6668
return _constraint
6769

@@ -71,11 +73,14 @@ def _finalize_constraints(self: SystemCollectionProtocol) -> None:
7173
periodic boundaries, a new constrain for memory block rod added called as _ConstrainPeriodicBoundaries. This
7274
constrain will synchronize the only periodic boundaries of position, director, velocity and omega variables.
7375
"""
74-
from elastica._synchronize_periodic_boundary import _ConstrainPeriodicBoundaries
7576

7677
for block in self.block_systems():
7778
# append the memory block to the simulation as a system. Memory block is the final system in the simulation.
7879
if hasattr(block, "ring_rod_flag"):
80+
from elastica._synchronize_periodic_boundary import (
81+
_ConstrainPeriodicBoundaries,
82+
)
83+
7984
# Apply the constrain to synchronize the periodic boundaries of the memory rod. Find the memory block
8085
# sys idx among other systems added and then apply boundary conditions.
8186
memory_block_idx = self.get_system_index(block)
@@ -89,31 +94,38 @@ def _finalize_constraints(self: SystemCollectionProtocol) -> None:
8994

9095
# dev : the first index stores the rod index to apply the boundary condition
9196
# to.
92-
self._constraints_operators = [
93-
(constraint.id(), constraint.instantiate(self[constraint.id()]))
94-
for constraint in self._constraints_list
95-
]
96-
9797
# Sort from lowest id to highest id for potentially better memory access
9898
# _constraints contains list of tuples. First element of tuple is rod number and
9999
# following elements are the type of boundary condition such as
100100
# [(0, ConstraintBase, OneEndFixedBC), (1, HelicalBucklingBC), ... ]
101101
# Thus using lambda we iterate over the list of tuples and use rod number (x[0])
102102
# to sort constraints.
103-
self._constraints_operators.sort(key=lambda x: x[0])
103+
self._constraints_list.sort(key=lambda x: x.id())
104+
for constraint in self._constraints_list:
105+
sys_id = constraint.id()
106+
constraint_instance = constraint.instantiate(self[sys_id])
107+
108+
constrain_values = functools.partial(
109+
constraint_instance.constrain_values, system=self[sys_id]
110+
)
111+
constrain_rates = functools.partial(
112+
constraint_instance.constrain_rates, system=self[sys_id]
113+
)
114+
115+
self._feature_group_constrain_values.add_operators(
116+
constraint, [constrain_values]
117+
)
118+
self._feature_group_constrain_rates.add_operators(
119+
constraint, [constrain_rates]
120+
)
104121

105122
# At t=0.0, constrain all the boundary conditions (for compatability with
106123
# initial conditions)
107-
self._constrain_values(time=np.float64(0.0))
108-
self._constrain_rates(time=np.float64(0.0))
109-
110-
def _constrain_values(self: SystemCollectionProtocol, time: np.float64) -> None:
111-
for sys_id, constraint in self._constraints_operators:
112-
constraint.constrain_values(self[sys_id], time)
124+
self.constrain_values(time=np.float64(0.0))
125+
self.constrain_rates(time=np.float64(0.0))
113126

114-
def _constrain_rates(self: SystemCollectionProtocol, time: np.float64) -> None:
115-
for sys_id, constraint in self._constraints_operators:
116-
constraint.constrain_rates(self[sys_id], time)
127+
self._constraints_list = []
128+
del self._constraints_list
117129

118130

119131
class _Constraint:

0 commit comments

Comments
 (0)