Skip to content

Commit a3547e2

Browse files
kodiakhq[bot]jngrad
authored andcommitted
Rewrite script interface object containers serialization (#4724)
Fixes #4280 Description of changes: - checkpoint restrictions on the number of MPI ranks have been lifted
1 parent fd56cd2 commit a3547e2

15 files changed

+176
-347
lines changed

doc/sphinx/io.rst

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,6 @@ Be aware of the following limitations:
111111
for a specific combination of features, please share your findings
112112
with the |es| community.
113113

114-
* Checkpointing only supports recursion on the head node. It is therefore
115-
impossible to checkpoint a :class:`espressomd.system.System` instance that
116-
contains LB boundaries, constraints or auto-update accumulators when the
117-
simulation is running with 2 or more MPI nodes.
118-
119114
* The active actors, i.e., the content of ``system.actors``, are checkpointed.
120115
For lattice-Boltzmann fluids, this only includes the parameters such as the
121116
lattice constant (``agrid``). The actual flow field has to be saved

src/python/espressomd/interactions.pyx

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2888,8 +2888,6 @@ class BondedInteractions(ScriptObjectMap):
28882888
return bond_id
28892889

28902890
def __getitem__(self, bond_id):
2891-
self._assert_key_type(bond_id)
2892-
28932891
if self.call_method('has_bond', bond_id=bond_id):
28942892
bond_obj = self.call_method('get_bond', bond_id=bond_id)
28952893
bond_obj._bond_id = bond_id
@@ -2932,7 +2930,6 @@ class BondedInteractions(ScriptObjectMap):
29322930
bond_id = self.call_method("insert", object=bond_obj)
29332931
else:
29342932
# Throw error if attempting to overwrite a bond of different type
2935-
self._assert_key_type(bond_id)
29362933
if self.call_method("contains", key=bond_id):
29372934
old_type = bonded_interaction_classes[
29382935
get_bonded_interaction_type_from_es_core(bond_id)]
@@ -2969,3 +2966,14 @@ class BondedInteractions(ScriptObjectMap):
29692966
for bond_id, (bond_params, bond_type) in params.items():
29702967
self[bond_id] = bonded_interaction_classes[bond_type](
29712968
**bond_params)
2969+
2970+
def __reduce__(self):
2971+
so_callback, (so_name, so_bytestring) = super().__reduce__()
2972+
return (BondedInteractions._restore_object,
2973+
(so_callback, (so_name, so_bytestring), self.__getstate__()))
2974+
2975+
@classmethod
2976+
def _restore_object(cls, so_callback, so_callback_args, state):
2977+
so = so_callback(*so_callback_args)
2978+
so.__setstate__(state)
2979+
return so

src/python/espressomd/script_interface.pyx

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -406,15 +406,6 @@ class ScriptObjectList(ScriptInterfaceHelper):
406406
407407
"""
408408

409-
def __init__(self, *args, **kwargs):
410-
if args:
411-
params, (_unpickle_so_class, (_so_name, bytestring)) = args
412-
assert _so_name == self._so_name
413-
self = _unpickle_so_class(_so_name, bytestring)
414-
self.__setstate__(params)
415-
else:
416-
super().__init__(**kwargs)
417-
418409
def __getitem__(self, key):
419410
return self.call_method("get_elements")[key]
420411

@@ -426,24 +417,6 @@ class ScriptObjectList(ScriptInterfaceHelper):
426417
def __len__(self):
427418
return self.call_method("size")
428419

429-
@classmethod
430-
def _restore_object(cls, so_callback, so_callback_args, state):
431-
so = so_callback(*so_callback_args)
432-
so.__setstate__(state)
433-
return so
434-
435-
def __reduce__(self):
436-
so_callback, (so_name, so_bytestring) = super().__reduce__()
437-
return (ScriptObjectList._restore_object,
438-
(so_callback, (so_name, so_bytestring), self.__getstate__()))
439-
440-
def __getstate__(self):
441-
return self.call_method("get_elements")
442-
443-
def __setstate__(self, object_list):
444-
for item in object_list:
445-
self.add(item)
446-
447420

448421
class ScriptObjectMap(ScriptInterfaceHelper):
449422
"""
@@ -456,17 +429,6 @@ class ScriptObjectMap(ScriptInterfaceHelper):
456429
457430
"""
458431

459-
_key_type = int
460-
461-
def __init__(self, *args, **kwargs):
462-
if args:
463-
params, (_unpickle_so_class, (_so_name, bytestring)) = args
464-
assert _so_name == self._so_name
465-
self = _unpickle_so_class(_so_name, bytestring)
466-
self.__setstate__(params)
467-
else:
468-
super().__init__(**kwargs)
469-
470432
def remove(self, key):
471433
"""
472434
Remove the element with the given key.
@@ -485,15 +447,12 @@ class ScriptObjectMap(ScriptInterfaceHelper):
485447
return self.call_method("size")
486448

487449
def __getitem__(self, key):
488-
self._assert_key_type(key)
489450
return self.call_method("get", key=key)
490451

491452
def __setitem__(self, key, value):
492-
self._assert_key_type(key)
493453
self.call_method("insert", key=key, object=value)
494454

495455
def __delitem__(self, key):
496-
self._assert_key_type(key)
497456
self.call_method("erase", key=key)
498457

499458
def keys(self):
@@ -505,28 +464,6 @@ class ScriptObjectMap(ScriptInterfaceHelper):
505464
def items(self):
506465
for k in self.keys(): yield k, self[k]
507466

508-
def _assert_key_type(self, key):
509-
if not utils.is_valid_type(key, self._key_type):
510-
raise TypeError(f"Key has to be of type {self._key_type.__name__}")
511-
512-
@classmethod
513-
def _restore_object(cls, so_callback, so_callback_args, state):
514-
so = so_callback(*so_callback_args)
515-
so.__setstate__(state)
516-
return so
517-
518-
def __reduce__(self):
519-
so_callback, (so_name, so_bytestring) = super().__reduce__()
520-
return (ScriptObjectMap._restore_object,
521-
(so_callback, (so_name, so_bytestring), self.__getstate__()))
522-
523-
def __getstate__(self):
524-
return dict(self.items())
525-
526-
def __setstate__(self, params):
527-
for key, val in params.items():
528-
self[key] = val
529-
530467

531468
# Map from script object names to their corresponding python classes
532469
_python_class_by_so_name = {}

src/script_interface/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020
add_library(
2121
Espresso_script_interface SHARED
22-
initialize.cpp ObjectHandle.cpp object_container_mpi_guard.cpp
23-
GlobalContext.cpp ContextManager.cpp ParallelExceptionHandler.cpp)
22+
initialize.cpp ObjectHandle.cpp GlobalContext.cpp ContextManager.cpp
23+
ParallelExceptionHandler.cpp)
2424
add_library(Espresso::script_interface ALIAS Espresso_script_interface)
2525

2626
add_subdirectory(accumulators)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright (C) 2023 The ESPResSo project
3+
*
4+
* This file is part of ESPResSo.
5+
*
6+
* ESPResSo is free software: you can redistribute it and/or modify
7+
* it under the terms of the GNU General Public License as published by
8+
* the Free Software Foundation, either version 3 of the License, or
9+
* (at your option) any later version.
10+
*
11+
* ESPResSo is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14+
* GNU General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU General Public License
17+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
18+
*/
19+
#ifndef SCRIPT_INTERFACE_OBJECT_CONTAINER_HPP
20+
#define SCRIPT_INTERFACE_OBJECT_CONTAINER_HPP
21+
22+
#include "script_interface/auto_parameters/AutoParameters.hpp"
23+
24+
#include <type_traits>
25+
26+
namespace ScriptInterface {
27+
28+
/**
29+
* @brief Base class for containers whose @c BaseType might be a full
30+
* specialization of @ref AutoParameters.
31+
*/
32+
template <
33+
template <typename...> class Container, typename ManagedType,
34+
class BaseType,
35+
class = std::enable_if_t<std::is_base_of<ObjectHandle, ManagedType>::value>>
36+
using ObjectContainer = std::conditional_t<
37+
std::is_same<BaseType, ObjectHandle>::value,
38+
AutoParameters<Container<ManagedType, BaseType>, BaseType>, BaseType>;
39+
40+
} // namespace ScriptInterface
41+
42+
#endif

src/script_interface/ObjectList.hpp

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
#ifndef SCRIPT_INTERFACE_OBJECT_LIST_HPP
2323
#define SCRIPT_INTERFACE_OBJECT_LIST_HPP
2424

25+
#include "script_interface/ObjectContainer.hpp"
2526
#include "script_interface/ScriptInterface.hpp"
2627
#include "script_interface/get_value.hpp"
27-
#include "script_interface/object_container_mpi_guard.hpp"
2828

2929
#include <utils/serialization/pack.hpp>
3030

@@ -35,20 +35,39 @@
3535
#include <vector>
3636

3737
namespace ScriptInterface {
38+
3839
/**
3940
* @brief Owning list of ObjectHandles
4041
* @tparam ManagedType Type of the managed objects, needs to be
41-
* derived from ObjectHandle
42+
* derived from @ref ObjectHandle
4243
*/
43-
template <
44-
typename ManagedType, class BaseType = ObjectHandle,
45-
class = std::enable_if_t<std::is_base_of<ObjectHandle, ManagedType>::value>>
46-
class ObjectList : public BaseType {
44+
template <typename ManagedType, class BaseType = ObjectHandle>
45+
class ObjectList : public ObjectContainer<ObjectList, ManagedType, BaseType> {
46+
public:
47+
using Base = ObjectContainer<ObjectList, ManagedType, BaseType>;
48+
using Base::add_parameters;
49+
4750
private:
51+
std::vector<std::shared_ptr<ManagedType>> m_elements;
52+
4853
virtual void add_in_core(const std::shared_ptr<ManagedType> &obj_ptr) = 0;
4954
virtual void remove_in_core(const std::shared_ptr<ManagedType> &obj_ptr) = 0;
5055

5156
public:
57+
ObjectList() {
58+
add_parameters({
59+
{"_objects", AutoParameter::read_only,
60+
[this]() { return make_vector_of_variants(m_elements); }},
61+
});
62+
}
63+
64+
void do_construct(VariantMap const &params) override {
65+
m_elements = get_value_or<decltype(m_elements)>(params, "_objects", {});
66+
for (auto const &object : m_elements) {
67+
add_in_core(object);
68+
}
69+
}
70+
5271
/**
5372
* @brief Add an element to the list.
5473
*
@@ -107,12 +126,7 @@ class ObjectList : public BaseType {
107126
}
108127

109128
if (method == "get_elements") {
110-
std::vector<Variant> ret;
111-
ret.reserve(m_elements.size());
112-
for (auto const &e : m_elements)
113-
ret.emplace_back(e);
114-
115-
return ret;
129+
return make_vector_of_variants(m_elements);
116130
}
117131

118132
if (method == "clear") {
@@ -128,32 +142,8 @@ class ObjectList : public BaseType {
128142
return m_elements.empty();
129143
}
130144

131-
return BaseType::do_call_method(method, parameters);
145+
return Base::do_call_method(method, parameters);
132146
}
133-
134-
private:
135-
std::string get_internal_state() const override {
136-
object_container_mpi_guard(BaseType::name(), m_elements.size());
137-
138-
std::vector<std::string> object_states(m_elements.size());
139-
140-
boost::transform(m_elements, object_states.begin(),
141-
[](auto const &e) { return e->serialize(); });
142-
143-
return Utils::pack(object_states);
144-
}
145-
146-
void set_internal_state(std::string const &state) override {
147-
auto const object_states = Utils::unpack<std::vector<std::string>>(state);
148-
149-
for (auto const &packed_object : object_states) {
150-
auto o = std::dynamic_pointer_cast<ManagedType>(
151-
BaseType::deserialize(packed_object, *BaseType::context()));
152-
add(std::move(o));
153-
}
154-
}
155-
156-
std::vector<std::shared_ptr<ManagedType>> m_elements;
157147
};
158148
} // Namespace ScriptInterface
159149
#endif

0 commit comments

Comments
 (0)