Commit

feat: ArrayLikeObservationWrapper
zyr17 committed Jul 7, 2024
1 parent 7575cdb commit 5482e76
Showing 2 changed files with 815 additions and 10 deletions.
193 changes: 183 additions & 10 deletions src/lpsim/env/ts.py
@@ -1,16 +1,25 @@
"""
Modifications for tianshou.
"""

import types
from torch import nn
import warnings
import gymnasium as gym
from gymnasium import Space
import numpy as np
from typing import Any, Callable, Dict, Optional, Union
from tianshou.data import (
Collector as TianshouCollector,
AsyncCollector as TianshouAsyncCollector,
Batch,
ReplayBuffer,
)
from tianshou.env import ( # noqa
PettingZooEnv as TianshouPettingZooEnv,
DummyVectorEnv,
BaseVectorEnv,
SubprocVectorEnv, # noqa
)
from tianshou.policy import MultiAgentPolicyManager, BasePolicy
from tianshou.utils import MultipleLRSchedulers
@@ -19,13 +28,30 @@
from lpsim.agents.nothing_agent import NothingAgent
from lpsim.agents.random_agent import RandomAgent
from lpsim.env.env import LPSimBaseV0Env
from lpsim.env.wrappers import (
ArrayActionWrapper,
ArrayLikeObservationWrapper,
AutoDiceWrapper,
)
from lpsim.agents import Agents

# from pettingzoo.classic import rps_v2
from pettingzoo.utils.wrappers import BaseWrapper

from lpsim.server.deck import Deck
from lpsim.server.interaction import (
ChooseCharacterResponse,
DeclareRoundEndResponse,
ElementalTuningResponse,
Requests,
RerollDiceResponse,
Responses,
SwitchCardResponse,
SwitchCharacterResponse,
UseCardResponse,
UseSkillResponse,
)
from lpsim.server.match import Match, MatchConfig


class LPSimAgent2TianshouPolicy(BasePolicy):
@@ -61,15 +87,132 @@ def forward(
"""
Extract action from agent.
"""
# print('forward batch', len(batch))
responses = []
for obs in batch.obs:
if isinstance(obs, Match):
resp = self.agent.generate_response(obs)
elif isinstance(obs, Batch) and "requests" in obs:
# requests are set
fake_match = Match()
fake_match.requests = [x for x in obs["requests"] if x is not None]
resp = self.agent.generate_response(fake_match)
else:
raise ValueError(f"Unknown obs type {type(obs)}")
assert resp is not None
responses.append(resp)
return Batch(act=np.array(responses))

def learn(self, batch: Batch, **kwargs: Any) -> dict[str, float]:
"""currently the agent learns nothing, it returns an empty dict."""
print("learn batch", len(batch))
return {}


class CommandActionPolicy(LPSimAgent2TianshouPolicy):
def __init__(self, *args, add_dice_in_output: bool = True, **kwargs):
super().__init__(*args, **kwargs)
self._add_dice_in_output = add_dice_in_output

def _request_idx_in_list(self, request: Requests, req_list: list[Requests]) -> int:
return req_list.index(request)

def _idx_list_to_number(self, idx_list: list[int]) -> int:
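# encode a list of indices as a bitmask, e.g. [0, 2, 3] -> 0b1101 == 13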
return sum([1 << i for i in idx_list])

def forward(
self,
batch: Batch,
state: Dict | Batch | np.ndarray | types.NoneType = None,
**kwargs: Any,
) -> Batch:
resp_batch = super().forward(batch, state, **kwargs)
requests: list[list[Requests]] = []
for obs in batch.obs:
if isinstance(obs, Match):
requests.append(obs.requests)
elif isinstance(obs, Batch) and "requests" in obs:
requests.append([x for x in obs["requests"] if x is not None])
results = []
for resp, reqs in zip(resp_batch.act, requests):
rr: Responses = resp
pidx = rr.player_idx
ridx = self._request_idx_in_list(rr.request, reqs)
target = 0
dice = 0
if isinstance(rr, SwitchCardResponse):
dice = self._idx_list_to_number(rr.card_idxs)
elif isinstance(rr, ChooseCharacterResponse):
target = rr.character_idx
elif isinstance(rr, RerollDiceResponse):
dice = self._idx_list_to_number(rr.reroll_dice_idxs)
elif isinstance(rr, SwitchCharacterResponse):
dice = self._idx_list_to_number(rr.dice_idxs)
elif isinstance(rr, ElementalTuningResponse):
target = rr.card_idx
dice = self._idx_list_to_number([rr.dice_idx])
elif isinstance(rr, DeclareRoundEndResponse):
pass
elif isinstance(rr, UseSkillResponse):
dice = self._idx_list_to_number(rr.dice_idxs)
elif isinstance(rr, UseCardResponse):
if rr.target is not None:
targets = rr.request.targets
target = targets.index(rr.target)
dice = self._idx_list_to_number(rr.dice_idxs)
else:
raise ValueError(f"Unknown response type {type(rr)}")
if self._add_dice_in_output:
results.append((pidx, ridx, target, dice))
else:
results.append((pidx, ridx, target))
return Batch(act=np.array(results))
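
For illustration only (not part of the commit): with add_dice_in_output=True, a RerollDiceResponse from player 0 that answers the second pending request and rerolls dice 1 and 3 would be flattened by CommandActionPolicy into the tuple below.

# hypothetical flattened action: (player_idx, request_idx, target, dice_mask)
example_act = (0, 1, 0, (1 << 1) | (1 << 3))  # dice_mask == 10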


class TableAttnFCNet(nn.Module):
"""
Apply attention over multiple instance infos (e.g. statuses, characters, summons),
then concatenate data at the same level, repeating until all table info is used
(a rough sketch of this idea follows the class).
TODO: not implemented yet.
"""

def __init__(
self,
observation_space: Space | None = None,
action_space: Space | None = None,
) -> None:
super().__init__()
self.observation_space = observation_space
self.action_space = action_space

def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any,
) -> Batch:
"""
Input batch, output embedding for current status
"""
# print('forward batch', len(batch))
responses = []
for obs in batch.obs:
if isinstance(obs, Match):
resp = self.agent.generate_response(obs)
elif isinstance(obs, Batch) and "requests" in obs:
# requests are set
fake_match = Match()
fake_match.requests = [x for x in obs["requests"] if x is not None]
resp = self.agent.generate_response(fake_match)
else:
raise ValueError(f"Unknown obs type {type(obs)}")
assert resp is not None
responses.append(resp)
return Batch(act=np.array(responses))

def learn(self, batch: Batch, **kwargs: Any) -> dict[str, float]:
"""currently the agent learns nothing, it returns an empty dict."""
print("learn batch", len(batch))
return {}
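
The docstring of TableAttnFCNet only sketches the intended architecture. As an illustration, here is one minimal, hypothetical reading of it, not part of this commit and not necessarily the author's design: project each group of table instances, attention-pool each group with a learned query, concatenate the pooled group vectors, and feed them through an FC head. The class name, feature sizes, and the three-group split are all assumptions.

import torch
from torch import nn


class AttnPoolFCSketch(nn.Module):
    def __init__(self, feat_dim: int = 32, embed_dim: int = 64, out_dim: int = 128):
        super().__init__()
        self.proj = nn.Linear(feat_dim, embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)
        self.query = nn.Parameter(torch.randn(1, 1, embed_dim))
        # three pooled groups (status / character / summon) are concatenated
        self.fc = nn.Sequential(nn.Linear(3 * embed_dim, out_dim), nn.ReLU())

    def _pool(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_instances, feat_dim) -> (batch, embed_dim)
        h = self.proj(x)
        q = self.query.expand(x.shape[0], -1, -1)
        pooled, _ = self.attn(q, h, h)
        return pooled.squeeze(1)

    def forward(self, status, characters, summons) -> torch.Tensor:
        # each argument: (batch, n_instances, feat_dim); returns (batch, out_dim)
        groups = [self._pool(g) for g in (status, characters, summons)]
        return self.fc(torch.cat(groups, dim=-1))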


@@ -169,6 +312,11 @@ def reset_env(self, gym_reset_kwargs: Dict[str, Any] | None = None) -> None:
return super().reset_env(gym_reset_kwargs)


class AsyncCollector(TianshouAsyncCollector, Collector):
def __init__(self, *argv, **kwargs):
Collector.__init__(self, *argv, **kwargs)


def get_env_args():
mo_ying_cao_4_5 = (
"2BPTnM7QxlTU+xjW2EMTuFrSzSQEy/PT1kTE/vbWznQDD4TTz2TUzvnT1nQj1JjU0PPD"
@@ -184,11 +332,30 @@ def get_env_args():
def get_lpsim_env():
match_config = MatchConfig(max_round_number=999)
original_env = LPSimBaseV0Env(match_config=match_config)
original_env = ArrayLikeObservationWrapper(original_env)
original_env = ArrayActionWrapper(original_env)
original_env = AutoDiceWrapper(original_env)
reset_args = get_env_args()
env = PettingZooEnv(original_env, gym_reset_kwargs=reset_args)
return env


def async_render(self, **kwargs: Any) -> list[Any]:
"""
Render all of the environments. Replaces the default VectorEnv render: when an
env is still stepping, it outputs None for that env instead of raising RuntimeError.
"""
self._assert_is_not_closed()
res = []
for idx, worker in enumerate(self.workers):
if idx in self.waiting_id:
res.append(None)
else:
res.append(worker.render(**kwargs))
return res


if __name__ == "__main__":
# env = rps_v2.env(render_mode="human")
# env = PettingZooEnv(env)
@@ -199,25 +366,31 @@ def get_lpsim_env():
# Step 3: Define policies for each agent
policies = MultiAgentPolicyManager(
policies=[
CommandActionPolicy(RandomAgent(player_idx=0)),
CommandActionPolicy(NothingAgent(player_idx=1)),
],
env=env,
)

# Step 4: Convert the env to vector format
VectorEnv = DummyVectorEnv
VectorEnv = SubprocVectorEnv
VectorEnv = DummyVectorEnv
env = VectorEnv(
# [lambda: PettingZooEnv(original_env, gym_reset_kwargs=reset_args)]
[get_lpsim_env] * 4,
# wait_num=2,
# timeout=None,
)
env.render = types.MethodType(async_render, env)

reset_args = get_env_args()

# Step 5: Construct the Collector, which interfaces the policies with the
# vectorised environment
collector = AsyncCollector(policies, env, gym_reset_kwargs=reset_args)

# Step 6: Execute the environment with the agents playing for 1 episode, and render
# a frame every 0.1 seconds
result = collector.collect(n_episode=100, render=0.1, gym_reset_kwargs=reset_args)
print("done")
