Skip to content

Commit

Permalink
feat: Timestep now supports indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
epignatelli committed Jan 13, 2024
1 parent 86b7279 commit 15137bd
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion helx/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# file generated by setuptools_scm
# don't change, don't track in version control
__version__ = "1.1.4"
__version__ = "1.1.5"
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())
7 changes: 7 additions & 0 deletions helx/base/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from jax import Array
import jax.numpy as jnp
import jax.tree_util as jtu
from flax import struct


Expand All @@ -42,3 +43,9 @@ class Timestep(struct.PyTreeNode):
"""The true state of the MDP, $s_t$ before taking action `action`"""
info: Dict[str, Any] = struct.field(default_factory=dict)
"""Additional information about the environment. Useful for accumulations (e.g. returns)"""

def __getitem__(self, key: Any) -> Timestep:
return jtu.tree_map(lambda x: x[key], self)

def __setitem__(self, key: Any, value: Any) -> Timestep:
return jtu.tree_map(lambda x: x.at[key].set(value), self)

0 comments on commit 15137bd

Please sign in to comment.