-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Puppeteers for Allelopathic Harvest and CleanUp (with three goals: ea…
…t, clean, and sanction). PiperOrigin-RevId: 646455367 Change-Id: Id77a1adb4b4132f209378645a99ca2c77891556b
- Loading branch information
1 parent
c6ff8a5
commit 5c53180
Showing
4 changed files
with
562 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright 2022 DeepMind Technologies Limited. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Puppeteers for allelopathic_harvest.""" | ||
|
||
import dataclasses | ||
from typing import Sequence | ||
|
||
import dm_env | ||
from meltingpot.utils.puppeteers import puppeteer | ||
import numpy as np | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class ConventionFollowerState: | ||
"""Current state of the ConventionFollower. | ||
Attributes: | ||
step_count: number of timesteps previously seen in this episode. | ||
current_goal: the goal last used by the puppeteer. | ||
recent_frames: buffer of recent observation frames. | ||
""" | ||
step_count: int | ||
current_goal: puppeteer.PuppetGoal | ||
recent_frames: tuple[np.ndarray, ...] | ||
|
||
|
||
class ConventionFollower(puppeteer.Puppeteer[ConventionFollowerState]): | ||
"""Allelopathic Harvest puppeteer for a convention follower.""" | ||
|
||
def __init__( | ||
self, | ||
initial_goal: puppeteer.PuppetGoal, | ||
preference_goals: Sequence[puppeteer.PuppetGoal], | ||
color_threshold: int, | ||
recency_window: int = 5) -> None: | ||
"""Initializes the puppeteer. | ||
Args: | ||
initial_goal: the initial goal to pursue. | ||
preference_goals: sequence of goals corresponding to the R, G, B, goals | ||
for when that color becomes dominant (on average) over the last | ||
`recency_window` frames. | ||
color_threshold: threshold for a color to become dominant. | ||
recency_window: number of frames to check for a dominant color. | ||
""" | ||
self._initial_goal = initial_goal | ||
self._preference_goals = preference_goals | ||
self._color_threshold = color_threshold | ||
self._recency_window = recency_window | ||
|
||
def initial_state(self) -> ConventionFollowerState: | ||
return ConventionFollowerState( | ||
step_count=0, current_goal=self._initial_goal, recent_frames=()) | ||
|
||
def step( | ||
self, | ||
timestep: dm_env.TimeStep, | ||
prev_state: ConventionFollowerState | ||
) -> tuple[dm_env.TimeStep, ConventionFollowerState]: | ||
"""Puppeteer step. | ||
Args: | ||
timestep: the timestep. | ||
prev_state: the state of the pupeeteer. | ||
Returns: | ||
Modified timestep and new state. | ||
""" | ||
if timestep.first(): | ||
prev_state = self.initial_state() | ||
|
||
recent_frames = list(prev_state.recent_frames) | ||
current_goal = prev_state.current_goal | ||
if len(recent_frames) < self._recency_window: | ||
recent_frames = tuple([timestep.observation['RGB']] + recent_frames) | ||
else: | ||
recent_frames = tuple([timestep.observation['RGB']] + recent_frames[:-1]) | ||
|
||
average_color = np.array(recent_frames).mean(axis=(0, 1, 2)) | ||
index = np.argmax(average_color) | ||
if average_color[index] > self._color_threshold: | ||
current_goal = self._preference_goals[index] | ||
|
||
return puppeteer.puppet_timestep(timestep, current_goal), ( | ||
ConventionFollowerState( | ||
step_count=prev_state.step_count + 1, | ||
current_goal=current_goal, | ||
recent_frames=recent_frames)) |
102 changes: 102 additions & 0 deletions
102
meltingpot/utils/puppeteers/allelopathic_harvest_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright 2022 DeepMind Technologies Limited. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Tests for allelopathic_harvest puppeteers.""" | ||
|
||
from unittest import mock | ||
|
||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
from meltingpot.testing import puppeteers | ||
from meltingpot.utils.puppeteers import allelopathic_harvest | ||
|
||
_CONSUME_ANY = mock.sentinel.consume | ||
_PREFER_R = mock.sentinel.pref_red | ||
_PREFER_G = mock.sentinel.pref_green | ||
_PREFER_B = mock.sentinel.pref_blue | ||
_PREFER = (_PREFER_R, _PREFER_G, _PREFER_B) | ||
|
||
_RGB_KEY = 'RGB' | ||
_NEUTRAL_COLOR = (((100, 100, 100),),) | ||
_HIGH_R = (((255, 55, 55),),) | ||
_HIGH_G = (((55, 255, 55),),) | ||
_HIGH_B = (((55, 55, 255),),) | ||
|
||
|
||
def _goals(puppeteer, rgbs, state=None): | ||
observations = [{_RGB_KEY: rgb} for rgb in rgbs] | ||
goals, state = puppeteers.goals_from_observations( | ||
puppeteer, observations, state | ||
) | ||
return goals, state | ||
|
||
|
||
class ConventionFollowerTest(parameterized.TestCase): | ||
|
||
def test_starts_with_consume(self): | ||
puppeteer = allelopathic_harvest.ConventionFollower( | ||
initial_goal=_CONSUME_ANY, | ||
preference_goals=_PREFER, | ||
color_threshold=200, | ||
recency_window=5, | ||
) | ||
expected = [_CONSUME_ANY] | ||
actual, _ = _goals(puppeteer, [_NEUTRAL_COLOR]) | ||
self.assertEqual(actual, expected) | ||
|
||
@parameterized.parameters( | ||
(_HIGH_R, _PREFER_R), | ||
(_HIGH_G, _PREFER_G), | ||
(_HIGH_B, _PREFER_B), | ||
) | ||
def test_change_on_immediate_recency(self, image, preference): | ||
puppeteer = allelopathic_harvest.ConventionFollower( | ||
initial_goal=_CONSUME_ANY, | ||
preference_goals=_PREFER, | ||
color_threshold=200, | ||
recency_window=1, | ||
) | ||
expected = [_CONSUME_ANY, preference] | ||
actual, _ = _goals(puppeteer, [_NEUTRAL_COLOR, image]) | ||
self.assertEqual(actual, expected) | ||
|
||
@parameterized.parameters( | ||
(_HIGH_R,), | ||
(_HIGH_G,), | ||
(_HIGH_B,), | ||
) | ||
def test_no_change_with_long_recency(self, image): | ||
puppeteer = allelopathic_harvest.ConventionFollower( | ||
initial_goal=_CONSUME_ANY, | ||
preference_goals=_PREFER, | ||
color_threshold=250, | ||
recency_window=5, | ||
) | ||
expected = [_CONSUME_ANY, _CONSUME_ANY] | ||
actual, _ = _goals(puppeteer, [_NEUTRAL_COLOR, image]) | ||
self.assertEqual(actual, expected) | ||
|
||
@parameterized.parameters(5, 10, 25) | ||
def test_change_over_long_recency(self, recency): | ||
puppeteer = allelopathic_harvest.ConventionFollower( | ||
initial_goal=_CONSUME_ANY, | ||
preference_goals=_PREFER, | ||
color_threshold=254, | ||
recency_window=recency, | ||
) | ||
expected = [_CONSUME_ANY] * recency + [_PREFER_R] | ||
actual, _ = _goals(puppeteer, [_NEUTRAL_COLOR] + [_HIGH_R] * recency) | ||
self.assertEqual(actual, expected) | ||
|
||
if __name__ == '__main__': | ||
absltest.main() |
Oops, something went wrong.