Skip to content

Commit

Permalink
Puppeteers for Allelopathic Harvest and CleanUp (with three goals: ea…
Browse files Browse the repository at this point in the history
…t, clean, and sanction).

PiperOrigin-RevId: 646455367
Change-Id: Id77a1adb4b4132f209378645a99ca2c77891556b
  • Loading branch information
duenez authored and copybara-github committed Jun 25, 2024
1 parent c6ff8a5 commit 5c53180
Show file tree
Hide file tree
Showing 4 changed files with 562 additions and 0 deletions.
99 changes: 99 additions & 0 deletions meltingpot/utils/puppeteers/allelopathic_harvest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2022 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Puppeteers for allelopathic_harvest."""

import dataclasses
from typing import Sequence

import dm_env
from meltingpot.utils.puppeteers import puppeteer
import numpy as np


@dataclasses.dataclass(frozen=True)
class ConventionFollowerState:
"""Current state of the ConventionFollower.
Attributes:
step_count: number of timesteps previously seen in this episode.
current_goal: the goal last used by the puppeteer.
recent_frames: buffer of recent observation frames.
"""
step_count: int
current_goal: puppeteer.PuppetGoal
recent_frames: tuple[np.ndarray, ...]


class ConventionFollower(puppeteer.Puppeteer[ConventionFollowerState]):
"""Allelopathic Harvest puppeteer for a convention follower."""

def __init__(
self,
initial_goal: puppeteer.PuppetGoal,
preference_goals: Sequence[puppeteer.PuppetGoal],
color_threshold: int,
recency_window: int = 5) -> None:
"""Initializes the puppeteer.
Args:
initial_goal: the initial goal to pursue.
preference_goals: sequence of goals corresponding to the R, G, B, goals
for when that color becomes dominant (on average) over the last
`recency_window` frames.
color_threshold: threshold for a color to become dominant.
recency_window: number of frames to check for a dominant color.
"""
self._initial_goal = initial_goal
self._preference_goals = preference_goals
self._color_threshold = color_threshold
self._recency_window = recency_window

def initial_state(self) -> ConventionFollowerState:
return ConventionFollowerState(
step_count=0, current_goal=self._initial_goal, recent_frames=())

def step(
self,
timestep: dm_env.TimeStep,
prev_state: ConventionFollowerState
) -> tuple[dm_env.TimeStep, ConventionFollowerState]:
"""Puppeteer step.
Args:
timestep: the timestep.
prev_state: the state of the pupeeteer.
Returns:
Modified timestep and new state.
"""
if timestep.first():
prev_state = self.initial_state()

recent_frames = list(prev_state.recent_frames)
current_goal = prev_state.current_goal
if len(recent_frames) < self._recency_window:
recent_frames = tuple([timestep.observation['RGB']] + recent_frames)
else:
recent_frames = tuple([timestep.observation['RGB']] + recent_frames[:-1])

average_color = np.array(recent_frames).mean(axis=(0, 1, 2))
index = np.argmax(average_color)
if average_color[index] > self._color_threshold:
current_goal = self._preference_goals[index]

return puppeteer.puppet_timestep(timestep, current_goal), (
ConventionFollowerState(
step_count=prev_state.step_count + 1,
current_goal=current_goal,
recent_frames=recent_frames))
102 changes: 102 additions & 0 deletions meltingpot/utils/puppeteers/allelopathic_harvest_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2022 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for allelopathic_harvest puppeteers."""

from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
from meltingpot.testing import puppeteers
from meltingpot.utils.puppeteers import allelopathic_harvest

_CONSUME_ANY = mock.sentinel.consume
_PREFER_R = mock.sentinel.pref_red
_PREFER_G = mock.sentinel.pref_green
_PREFER_B = mock.sentinel.pref_blue
_PREFER = (_PREFER_R, _PREFER_G, _PREFER_B)

_RGB_KEY = 'RGB'
_NEUTRAL_COLOR = (((100, 100, 100),),)
_HIGH_R = (((255, 55, 55),),)
_HIGH_G = (((55, 255, 55),),)
_HIGH_B = (((55, 55, 255),),)


def _goals(puppeteer, rgbs, state=None):
observations = [{_RGB_KEY: rgb} for rgb in rgbs]
goals, state = puppeteers.goals_from_observations(
puppeteer, observations, state
)
return goals, state


class ConventionFollowerTest(parameterized.TestCase):

def test_starts_with_consume(self):
puppeteer = allelopathic_harvest.ConventionFollower(
initial_goal=_CONSUME_ANY,
preference_goals=_PREFER,
color_threshold=200,
recency_window=5,
)
expected = [_CONSUME_ANY]
actual, _ = _goals(puppeteer, [_NEUTRAL_COLOR])
self.assertEqual(actual, expected)

@parameterized.parameters(
(_HIGH_R, _PREFER_R),
(_HIGH_G, _PREFER_G),
(_HIGH_B, _PREFER_B),
)
def test_change_on_immediate_recency(self, image, preference):
puppeteer = allelopathic_harvest.ConventionFollower(
initial_goal=_CONSUME_ANY,
preference_goals=_PREFER,
color_threshold=200,
recency_window=1,
)
expected = [_CONSUME_ANY, preference]
actual, _ = _goals(puppeteer, [_NEUTRAL_COLOR, image])
self.assertEqual(actual, expected)

@parameterized.parameters(
(_HIGH_R,),
(_HIGH_G,),
(_HIGH_B,),
)
def test_no_change_with_long_recency(self, image):
puppeteer = allelopathic_harvest.ConventionFollower(
initial_goal=_CONSUME_ANY,
preference_goals=_PREFER,
color_threshold=250,
recency_window=5,
)
expected = [_CONSUME_ANY, _CONSUME_ANY]
actual, _ = _goals(puppeteer, [_NEUTRAL_COLOR, image])
self.assertEqual(actual, expected)

@parameterized.parameters(5, 10, 25)
def test_change_over_long_recency(self, recency):
puppeteer = allelopathic_harvest.ConventionFollower(
initial_goal=_CONSUME_ANY,
preference_goals=_PREFER,
color_threshold=254,
recency_window=recency,
)
expected = [_CONSUME_ANY] * recency + [_PREFER_R]
actual, _ = _goals(puppeteer, [_NEUTRAL_COLOR] + [_HIGH_R] * recency)
self.assertEqual(actual, expected)

if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit 5c53180

Please sign in to comment.