Skip to content

Commit

Permalink
sticky actions can have random duration within a range
Browse files Browse the repository at this point in the history
  • Loading branch information
sparisi committed Nov 8, 2024
1 parent d219ee3 commit dbd47a3
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 15 deletions.
55 changes: 43 additions & 12 deletions gymnasium/wrappers/stateful_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from __future__ import annotations

import numpy as np
from typing import Any

import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.error import InvalidProbability
from gymnasium.error import InvalidProbability, InvalidBound


__all__ = ["StickyAction"]
Expand Down Expand Up @@ -46,21 +47,43 @@ def __init__(
self,
env: gym.Env[ObsType, ActType],
repeat_action_probability: float,
repeat_action_duration: int = 1,
repeat_action_duration: int | tuple[int, int] = 1,
):
"""Initialize StickyAction wrapper.
Args:
env (Env): the wrapped environment,
repeat_action_probability (int | float): a probability of repeating the old action,
repeat_action_duration (int): the number of steps the action is repeated.
repeat_action_duration (int | tuple[int, int]): the number of steps
the action is repeated. It can be either an int (for deterministic
repeats) or a tuple[int, int] for a range of stochastic number of repeats.
"""
if not 0 <= repeat_action_probability < 1:
raise InvalidProbability(
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
)

if repeat_action_duration < 1:
if isinstance(repeat_action_duration, int):
repeat_action_duration = [repeat_action_duration, repeat_action_duration]
elif not (
isinstance(repeat_action_duration, tuple)
or isinstance(repeat_action_duration, list)
):
raise ValueError(
f"repeat_action_duration should be either an integer, a tuple, or a list. Received {repeat_action_duration}"
)

if len(repeat_action_duration) != 2:
raise ValueError(
f"repeat_action_duration should be a tuple or a list of two integers. Received {repeat_action_duration}"
)

if repeat_action_duration[1] < repeat_action_duration[0]:
raise InvalidBound(
f"repeat_action_duration is not a valid bound. Received {repeat_action_duration}"
)

if np.any(np.array(repeat_action_duration) < 1):
raise ValueError(
f"repeat_action_duration should be larger or equal than 1. Received {repeat_action_duration}"
)
Expand All @@ -71,18 +94,20 @@ def __init__(
gym.ActionWrapper.__init__(self, env)

self.repeat_action_probability = repeat_action_probability
self.repeat_action_duration = repeat_action_duration
self.repeat_action_duration_range = repeat_action_duration
self.last_action: ActType | None = None
self.last_action_repeats = 0
self.is_repeating = False
self.last_action_repeats = None # number of steps last action will be repeated
self.is_repeating = False # if the agent is currently "stuck" into repeats
self.is_repeating_since = 0 # number of steps last action has been repeated

def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Reset the environment."""
self.last_action = None
self.last_action_repeats = 0
self.last_action_repeats = None
self.is_repeating = False
self.is_repeating_since = 0

return super().reset(seed=seed, options=options)

Expand All @@ -92,14 +117,20 @@ def action(self, action: ActType) -> ActType:
self.is_repeating
or self.last_action is not None
and self.np_random.uniform() < self.repeat_action_probability
):
): # either the agent was already "stuck" into repeats, or a new series of repeats is triggered
if self.last_action_repeats is None: # if a new series starts, randomly sample its duration
self.last_action_repeats = self.np_random.integers(
self.repeat_action_duration_range[0],
self.repeat_action_duration_range[1] + 1,
)
action = self.last_action
self.is_repeating = True
self.last_action_repeats += 1
self.is_repeating_since += 1

if self.last_action_repeats == self.repeat_action_duration:
if self.last_action_repeats == self.is_repeating_since: # repeats are done, reset "stuck" status
self.last_action_repeats = None
self.is_repeating = False
self.last_action_repeats = 0
self.is_repeating_since = 0

self.last_action = action
return action
6 changes: 3 additions & 3 deletions tests/wrappers/test_sticky_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

from gymnasium.error import InvalidProbability
from gymnasium.error import InvalidProbability, InvalidBound
from gymnasium.wrappers import StickyAction
from tests.testing_env import GenericTestEnv
from tests.wrappers.utils import NUM_STEPS, record_action_as_obs_step
Expand Down Expand Up @@ -51,10 +51,10 @@ def test_sticky_action_raise_probability(repeat_action_probability):
)


@pytest.mark.parametrize("repeat_action_duration", [-4, 0])
@pytest.mark.parametrize("repeat_action_duration", [-4, 0, (0, 0), (4, 2), [1,]])
def test_sticky_action_raise_duration(repeat_action_duration):
"""Tests the stick action wrapper with durations that should raise an error."""
with pytest.raises(ValueError):
with pytest.raises((ValueError, InvalidBound)):
StickyAction(
GenericTestEnv(), 0.5, repeat_action_duration=repeat_action_duration
)

0 comments on commit dbd47a3

Please sign in to comment.