Skip to content

Commit

Permalink
Added pre-computation for BaseEnvironment.get_available_options.
Browse files Browse the repository at this point in the history
Previously, every time BaseEnvironment.get_available_options was called, it would loop through every single option and test whether a given state was in its initiation set. This was unbelievably slow - it ended taking up the majority of the runtime - and scaled linearly in its horribleness with the number of available options.

Now, whenever the option set is mutated using BaseEnvironment.set_options, a list of options available in each state is now pre-computed. Then, whenever BaseEnvironment.get_available_options is called, a shallow copy of the given state's list is returned. Much simpler and faster!

In a gridworld with ~500 states and ~300 options, this reduced the time taken to run OptionsAgent for 75,000 time-steps from ~120secs to ~9secs. Crazy!
  • Loading branch information
Ueva committed Jul 11, 2024
1 parent beb785b commit 5ecd52a
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions simpleoptions/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def __init__(self):
"""
self._options = set()
self.exploration_options = set()
self._option_availability_maps = {}
self._exploration_option_availability_maps = {}
self.current_state = None

@abstractmethod
Expand Down Expand Up @@ -126,7 +128,7 @@ def get_option_space(self) -> Set["BaseOption"]:
Returns a set containing all of the options available in this environment.
Returns:
Set[BaseOption]: All possible options in this environment.
Set[BaseOption]: All possible options available in this environment.
"""
return self._options

Expand All @@ -150,11 +152,9 @@ def get_available_options(self, state: Hashable, exploration=False) -> List["Bas
return []
# Otherwise, options whose initiation set contains the given state are returned.
else:
# Lists all options (including options corresponding to primitive actions) which have the given state in their initiation sets.
available_options = [option for option in self._options if option.initiation(state)]

available_options = copy.copy(self._option_availability_maps.get(state, list()))
if exploration:
available_options.extend([option for option in self.exploration_options if option.initiation(state)])
available_options.extend(copy.copy(self._exploration_option_availability_maps.get(state, list())))

return available_options

Expand All @@ -173,6 +173,13 @@ def set_options(self, new_options: List["BaseOption"], append: bool = False) ->
else:
self._options.update(copy.copy(new_options))

self._option_availability_maps = {}
for state in self.get_state_space():
for option in self._options:
if option.initiation(state):
self._option_availability_maps[state] = self._option_availability_maps.get(state, list())
self._option_availability_maps[state].append(option)

def set_exploration_options(self, new_options: List["BaseOption"], append: bool = False) -> None:
"""
Sets the set of options available solely for exploration in this environment.
Expand All @@ -187,6 +194,15 @@ def set_exploration_options(self, new_options: List["BaseOption"], append: bool
else:
self.exploration_options.update(copy.copy(new_options))

self._exploration_option_availability_maps = {}
for state in self.get_state_space():
for exploration_option in self.exploration_options:
if exploration_option.initiation(state):
self._exploration_option_availability_maps[state] = self._exploration_option_availability_maps.get(
state, list()
)
self._exploration_option_availability_maps[state].append(exploration_option)

@abstractmethod
def is_state_terminal(self, state: Hashable = None) -> bool:
"""
Expand Down

0 comments on commit 5ecd52a

Please sign in to comment.