Skip to content

Commit

Permalink
Adds support for ISystemReset in test fixture (#2647)
Browse files Browse the repository at this point in the history
* Adds support for Reset in test fixture

This PR adds support for the Reset API to the test fixture. As `TestFixture`
is one of the main ways one can get access to the ECM in python
when trying to write some scripts for Deep Reinforcement Learning I
realized that without `Reset` supported in the `TestFixture` API, end
users would have a very hard time using our python APIs (which are
actually quite nice). For reference I'm hacking a demo template here:

#2667
---------

Signed-off-by: Arjo Chakravarty <arjoc@intrinsic.ai>
  • Loading branch information
arjo129 authored Dec 20, 2024
1 parent f05cda2 commit e63e8d8
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
6 changes: 6 additions & 0 deletions include/gz/sim/TestFixture.hh
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ class GZ_SIM_VISIBLE TestFixture
public: TestFixture &OnPostUpdate(std::function<void(
const UpdateInfo &, const EntityComponentManager &)> _cb);

/// \brief Wrapper around a system's update callback
/// \param[in] _cb Function to be called every reset
/// \return Reference to self.
public: TestFixture &OnReset(std::function<void(
const UpdateInfo &, EntityComponentManager &)> _cb);

/// \brief Finalize all the functions and add fixture to server.
/// Finalize must be called before running the server, otherwise none of the
/// `On*` functions will be called.
Expand Down
11 changes: 11 additions & 0 deletions python/src/gz/sim/TestFixture.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ defineSimTestFixture(pybind11::object module)
),
pybind11::return_value_policy::reference,
"Wrapper around a system's post-update callback"
)
.def(
"on_reset", WrapCallbacks(
[](TestFixture* self, std::function<void(
const UpdateInfo &, EntityComponentManager &)> _cb)
{
self->OnReset(_cb);
}
),
pybind11::return_value_policy::reference,
"Wrapper around a system's reset callback"
);
// TODO(ahcorde): This method is not compiling for the following reason:
// The EventManager class has an unordered_map which holds a unique_ptr
Expand Down
28 changes: 27 additions & 1 deletion src/TestFixture.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class HelperSystem :
public ISystemConfigure,
public ISystemPreUpdate,
public ISystemUpdate,
public ISystemPostUpdate
public ISystemPostUpdate,
public ISystemReset
{
// Documentation inherited
public: void Configure(
Expand All @@ -50,6 +51,10 @@ class HelperSystem :
public: void PostUpdate(const UpdateInfo &_info,
const EntityComponentManager &_ecm) override;

// Documentation inherited
public: void Reset(const UpdateInfo &_info,
EntityComponentManager &_ecm) override;

/// \brief Function to call every time we configure a world
public: std::function<void(const Entity &_entity,
const std::shared_ptr<const sdf::Element> &_sdf,
Expand All @@ -68,6 +73,10 @@ class HelperSystem :
/// \brief Function to call every post-update
public: std::function<void(const UpdateInfo &,
const EntityComponentManager &)> postUpdateCallback;

/// \brief Reset callback
public: std::function<void(const UpdateInfo &,
EntityComponentManager &)> resetCallback;
};

/////////////////////////////////////////////////
Expand Down Expand Up @@ -105,6 +114,14 @@ void HelperSystem::PostUpdate(const UpdateInfo &_info,
this->postUpdateCallback(_info, _ecm);
}

/////////////////////////////////////////////////
void HelperSystem::Reset(const UpdateInfo &_info,
EntityComponentManager &_ecm)
{
if (this->resetCallback)
this->resetCallback(_info, _ecm);
}

//////////////////////////////////////////////////
class gz::sim::TestFixture::Implementation
{
Expand Down Expand Up @@ -200,6 +217,15 @@ TestFixture &TestFixture::OnPostUpdate(std::function<void(
return *this;
}

//////////////////////////////////////////////////
TestFixture &TestFixture::OnReset(std::function<void(
const UpdateInfo &, EntityComponentManager &)> _cb)
{
if (nullptr != this->dataPtr->helperSystem)
this->dataPtr->helperSystem->resetCallback = std::move(_cb);
return *this;
}

//////////////////////////////////////////////////
std::shared_ptr<Server> TestFixture::Server() const
{
Expand Down

0 comments on commit e63e8d8

Please sign in to comment.