Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
zyr17 committed Jul 28, 2024
1 parent aa52f00 commit f1470f8
Showing 1 changed file with 82 additions and 2 deletions.
84 changes: 82 additions & 2 deletions src/lpsim/env/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Wrappers for gymnasium. If you want to use pettingzoo, you may need to modify them.
"""

from functools import lru_cache
import os # noqa
from typing import Type
import numpy as np
Expand All @@ -18,12 +19,12 @@
from lpsim.server.character.character_base import CharacterBase, TalentBase
from lpsim.server.consts import CostLabels, ObjectType, WeaponType
from lpsim.server.deck import Deck
from lpsim.server.interaction import Requests
from lpsim.server.interaction import DeclareRoundEndRequest, Requests
from lpsim.server.match import Match, MatchConfig
from lpsim.server.object_base import CardBase, ObjectBase
from lpsim.server.player_table import PlayerTable
from lpsim.server.status.team_status.base import TeamStatusBase
from lpsim.server.struct import Cost
from lpsim.server.struct import Cost, MultipleObjectPosition
from lpsim.server.summon.base import SummonBase
from lpsim.utils.desc_registry import get_desc_patch

Expand Down Expand Up @@ -79,10 +80,12 @@ def __init__(
env: AECEnv,
max_possible_action_number: int = 20,
max_possible_status_number: int = 10,
max_possible_target_number: int = 4,
):
super().__init__(env)
self.max_possilable_action_number: int = max_possible_action_number
self.max_possible_status_number = max_possible_status_number
self.max_possible_target_number = max_possible_target_number
inner_env: LPSimBaseV0Env = env.unwrapped # type: ignore
config = inner_env.match_config
if config is None:
Expand Down Expand Up @@ -505,6 +508,82 @@ def player_observation(
result_dict.update(self.character_status_observation(c))
return result_dict

def one_request_observation(
    self,
    request: Requests,
) -> np.ndarray:
    """
    Convert one request to a flat ndarray observation. Requests are split by
    their class name, and each class is encoded in its own way.

    The returned vector is the concatenation of:

    - ``[player_idx, req_type, req_info]``, where ``req_type`` is a numeric
      id of the request class (1..8) and ``req_info`` is a class-specific
      integer (a bitmask of indices for ChooseCharacterRequest and
      ElementalTuningRequest, otherwise 0);
    - the cost observation of the request (the observation of an empty
      ``Cost`` for requests that carry no cost);
    - ``max_possible_target_number`` pairs of ``(valid flag, target id)``,
      zero-padded after the real targets.

    Args:
        request: The request to encode.

    Returns:
        1-D array with the layout described above.

    Raises:
        NotImplementedError: If the request class is not handled.
        ValueError: If the request has more targets than
            ``max_possible_target_number``.
    """
    player_idx = request.player_idx
    # Requests without a cost are encoded with the observation of an empty Cost.
    req_cost = self.cost_observation(Cost())
    req_info = 0
    req_targets: list[int] = []
    if request.name == "ChooseCharacterRequest":
        req_type = 1
        # Bitmask: bit i is set for every available character index i.
        req_info = sum(1 << idx for idx in request.available_character_idxs)
    elif request.name == "SwitchCardRequest":
        req_type = 2
    elif request.name == "UseCardRequest":
        req_type = 3
        req_cost = self.cost_observation(request.cost)
        card_targets = request.targets
        if card_targets is not None:
            # TODO how to align with existing embeddings
            for target in card_targets:
                assert not isinstance(target, MultipleObjectPosition)
                req_targets.append(self._target_to_id(target))
    elif request.name == "UseSkillRequest":
        req_type = 4
        req_cost = self.cost_observation(request.cost)
        req_targets = [self._target_to_id(request.skill_idx)]
    elif request.name == "DeclareRoundEndRequest":
        req_type = 5
    elif request.name == "ElementalTuningRequest":
        req_type = 6
        # Bitmask: bit i is set for every tunable card index i.
        req_info = sum(1 << idx for idx in request.card_idxs)
    elif request.name == "RerollDiceRequest":
        req_type = 7
        req_info = 0  # TODO
    elif request.name == "SwitchCharacterRequest":
        req_type = 8
        req_cost = self.cost_observation(request.cost)
        req_targets = [self._target_to_id(request.target_character_idx)]
    else:
        raise NotImplementedError(f"unknown request: {request.name}")

    # Pad the collected target ids (req_targets, which is defined for every
    # request type) to a fixed number of slots, with a parallel validity mask.
    max_targets = self.max_possible_target_number
    target_count = len(req_targets)
    if target_count > max_targets:
        raise ValueError(
            f"request {request.name} has {target_count} targets, but at most "
            f"{max_targets} are supported"
        )
    expanded_targets = np.zeros(max_targets, dtype=int)
    valid_target_mask = np.zeros(max_targets, dtype=int)
    expanded_targets[:target_count] = req_targets
    valid_target_mask[:target_count] = 1
    # Shape (max_targets, 2); flattening interleaves (valid flag, target id).
    final_targets = np.stack([valid_target_mask, expanded_targets], axis=1)
    return np.concatenate(
        [[player_idx, req_type, req_info], req_cost, final_targets.flatten()]
    )

def _request_observation_length(self) -> int:
    """
    Return the length of the vector produced by one_request_observation.

    The length is measured once by encoding a DeclareRoundEndRequest and is
    then memoized on the instance. A per-instance attribute is used instead
    of functools.lru_cache because lru_cache on a method keys a module-level
    cache on ``self``, which keeps every instance alive and requires it to be
    hashable.
    """
    cached = getattr(self, "_request_observation_length_cache", None)
    if cached is None:
        probe = self.one_request_observation(DeclareRoundEndRequest(player_idx=0))
        cached = probe.shape[0]
        self._request_observation_length_cache = cached
    return cached

def request_observation(
    self,
    requests: list[Requests],
) -> dict[str, np.ndarray]:
    """
    Gather and return all observation for requests. It uses
    one_request_observation to convert each request.

    Args:
        requests: Requests to encode, at most
            ``max_possilable_action_number`` of them.

    Returns:
        Dict with the single key ``"requests"``, a float array of shape
        ``(max_possilable_action_number, request observation length)``.
        Row i holds the encoding of ``requests[i]``; remaining rows are zero.

    Raises:
        ValueError: If there are more requests than
            ``max_possilable_action_number``.
    """
    max_requests = self.max_possilable_action_number
    # Explicit check instead of `assert`: asserts are stripped under -O, and
    # an oversized list must not silently produce a mis-shaped observation.
    if len(requests) > max_requests:
        raise ValueError(
            f"got {len(requests)} requests, but at most {max_requests} "
            "are supported"
        )
    row_length = self._request_observation_length()
    # Pre-allocating the zero matrix also covers the empty-requests case and
    # avoids relying on the `dtype` argument of np.stack.
    observation = np.zeros((max_requests, row_length), dtype=float)
    for row, req in enumerate(requests):
        observation[row] = self.one_request_observation(req)
    return {"requests": observation}

def full_observation(self, match: Match) -> dict[str, np.ndarray]:
"""
Full observation. It contains two player_observation, and round.
Expand All @@ -516,6 +595,7 @@ def full_observation(self, match: Match) -> dict[str, np.ndarray]:
**self.player_observation(
1, match.player_tables[1], match.requests, match.current_player == 1
),
**self.request_observation(match.requests),
"round": np.array(match.round_number),
}

Expand Down

0 comments on commit f1470f8

Please sign in to comment.