Refactor battle stream API
Fix #366.

Some other things fixed while debugging:

- Fix recharge being treated as a move, which caused agent requests to time out.
- Make the BattleEnv.step() loop condition more robust.
- Set the ZeroMQ high water mark on the JS worker end to ensure no dropped messages (socket options sketched below).
- Fix tqdm leaving artifacts in the terminal.
- Add simulator scripts for debugging.
- Expose timeout configs.
- Make simulateBattle() exceptions louder in the JS worker.
- Fix error logging/swallowing behavior in simulateBattle().
- Fix the wrapTimeout() call stack.
- Remove unused BattleParser type params TArgs and TResult.
- Fix game truncation handling.
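
The high-water-mark and timeout items above come down to a few ZeroMQ socket options. A minimal pyzmq sketch of the idea (illustration only, not the project's worker code; the JS worker uses its own ZeroMQ bindings, and the endpoint and timeout values here are placeholders):

import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.ROUTER)
# HWM=0 removes the queue limit so pending messages are never silently
# dropped under load.
sock.setsockopt(zmq.SNDHWM, 0)
sock.setsockopt(zmq.RCVHWM, 0)
# Bounded blocking: send/recv raise zmq.error.Again after the timeout instead
# of hanging forever, which surfaces rare async bugs.
sock.setsockopt(zmq.SNDTIMEO, 60_000)  # placeholder: 60s
sock.setsockopt(zmq.RCVTIMEO, 60_000)
sock.bind("ipc:///tmp/example-socket")  # placeholder endpoint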
taylorhansen committed Jan 16, 2024
1 parent 73dd6e6 commit ec3939b
Showing 49 changed files with 2,211 additions and 2,942 deletions.
4 changes: 4 additions & 0 deletions config/train_example.yml
@@ -72,6 +72,8 @@ rollout:
workers: 1
per_worker: 1
battles_per_log: 1000
worker_timeout_ms: 60_000 # 1m
sim_timeout_ms: 300_000 # 5m
state_type: numpy
opponents:
- name: previous
@@ -86,6 +88,8 @@ eval:
workers: 4
per_worker: 2
battles_per_log: 100
worker_timeout_ms: 60_000 # 1m
sim_timeout_ms: 300_000 # 5m
state_type: tensor
opponents:
- name: previous
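
The two new timeout keys map onto the BattlePoolConfig fields introduced in battle_pool.py below. A rough sketch of the equivalent in-code construction (the YAML-to-config wiring is not part of this diff, so treat it as an assumption):

from src.py.environments.utils.battle_pool import BattlePoolConfig

# Mirrors the rollout section above.
pool = BattlePoolConfig(
    workers=1,
    per_worker=1,
    battles_per_log=1000,
    worker_timeout_ms=60_000,  # 1m: worker communication timeout
    sim_timeout_ms=300_000,  # 5m: per-battle simulator timeout
)
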
52 changes: 52 additions & 0 deletions scripts/sim-randbat.ts
@@ -0,0 +1,52 @@
/** @file Simulates a random battle used for training. */
import {randomAgent} from "../src/ts/battle/agent/random";
import {gen4Parser} from "../src/ts/battle/parser/gen4";
import {ExperienceBattleParser} from "../src/ts/battle/worker/ExperienceBattleParser";
import {PlayerOptions, simulateBattle} from "../src/ts/battle/worker/battle";
import {wrapTimeout} from "../src/ts/utils/timeout";
import {Mutable} from "../src/ts/utils/types";

Error.stackTraceLimit = Infinity;

const timeoutMs = 5000; // 5s
const battleTimeoutMs = 1000; // 1s
const maxTurns = 50;
const p1Exp = true;
const p2Exp = true;

void (async function () {
    const p1: Mutable<PlayerOptions> = {
        name: "p1",
        agent: randomAgent,
        parser: gen4Parser,
    };
    const p2: Mutable<PlayerOptions> = {
        name: "p2",
        agent: randomAgent,
        parser: gen4Parser,
    };
    if (p1Exp) {
        const expParser = new ExperienceBattleParser(p1.parser, "p1");
        p1.parser = async (ctx, event) => await expParser.parse(ctx, event);
        p1.agent = async (state, choices) => await randomAgent(state, choices);
    }
    if (p2Exp) {
        const expParser = new ExperienceBattleParser(p2.parser, "p2");
        p2.parser = async (ctx, event) => await expParser.parse(ctx, event);
        p2.agent = async (state, choices) => await randomAgent(state, choices);
    }

    const result = await wrapTimeout(
        async () =>
            await simulateBattle({
                players: {p1, p2},
                maxTurns,
                timeoutMs: battleTimeoutMs,
            }),
        timeoutMs,
    );
    console.log(`winner: ${result.winner}`);
    console.log(`truncated: ${!!result.truncated}`);
    console.log(`log path: ${result.logPath}`);
    console.log(`err: ${result.err?.stack ?? result.err}`);
})().catch(err => console.log("sim-randbat failed:", err));
118 changes: 118 additions & 0 deletions scripts/sim_randbat.py
@@ -0,0 +1,118 @@
"""Simulates a random battle used for training."""

import asyncio
import os
import sys
from contextlib import closing
from itertools import chain
from typing import Optional, Union

import numpy as np
import tensorflow as tf

# So that we can `python -m scripts.sim_randbat` from project root.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

# pylint: disable=wrong-import-position, import-error
from src.py.agents.agent import Agent
from src.py.agents.utils.epsilon_greedy import EpsilonGreedy
from src.py.environments.battle_env import (
AgentDict,
BattleEnv,
BattleEnvConfig,
EvalOpponentConfig,
InfoDict,
)
from src.py.environments.utils.battle_pool import BattlePoolConfig
from src.py.models.utils.greedy import decode_action_rankings


class RandomAgent(Agent):
"""Agent that acts randomly."""

def __init__(self, rng: Optional[tf.random.Generator] = None):
self._epsilon_greedy = EpsilonGreedy(exploration=1.0, rng=rng)

def select_action(
self,
state: AgentDict[Union[np.ndarray, tf.Tensor]],
info: AgentDict[InfoDict],
) -> AgentDict[list[str]]:
"""Selects a random action."""
_ = info
return dict(
zip(
state.keys(),
decode_action_rankings(
self._epsilon_greedy.rand_actions(len(state))
),
)
)

def update_model(
self, state, reward, next_state, terminated, truncated, info
):
"""Not implemented."""
raise NotImplementedError


async def sim_randbat():
"""Starts the simulator."""

rng = tf.random.get_global_generator()

agent = RandomAgent(rng)

env = BattleEnv(
config=BattleEnvConfig(
max_turns=50,
batch_limit=4,
pool=BattlePoolConfig(
workers=2,
per_worker=1,
battles_per_log=1,
worker_timeout_ms=1000, # 1s
sim_timeout_ms=60_000, # 1m
),
state_type="tensor",
),
rng=rng,
)
await env.ready()

with closing(env):
state, info = env.reset(
rollout_battles=10,
eval_opponents=(
EvalOpponentConfig(
name="eval_self", battles=10, type="model", model="model/p2"
),
),
)
done = False
while not done:
action = agent.select_action(state, info)
(next_state, _, terminated, truncated, info, done) = await env.step(
action
)
state = next_state
for key, ended in chain(terminated.items(), truncated.items()):
if ended:
state.pop(key)
info.pop(key)
for key, env_info in info.items():
if key.player != "__env__":
continue
battle_result = env_info.get("battle_result", None)
if battle_result is None:
continue
print(battle_result)


def main():
"""Main entry point."""
asyncio.run(sim_randbat())


if __name__ == "__main__":
main()
19 changes: 14 additions & 5 deletions src/py/environments/battle_env.py
@@ -338,16 +338,25 @@ async def step(
        num_pending = 0
        all_reqs: AgentDict[Union[AgentRequest, AgentFinalRequest]] = {}
        while (
            self.config.batch_limit <= 0
            or num_pending < self.config.batch_limit
        ) and await self.battle_pool.agent_poll(
            timeout=0 if len(all_reqs) > 0 else None
            (
                self.config.batch_limit <= 0
                or num_pending < self.config.batch_limit
            )
            and (
                len(self.active_battles) > 0
                or (self.queue_task is not None and not self.queue_task.done())
            )
            and await self.battle_pool.agent_poll(
                timeout=0
                if len(all_reqs) > 0
                else self.config.pool.worker_timeout_ms
            )
        ):
            key, req, state = await self.battle_pool.agent_recv(
                flags=zmq.DONTWAIT
            )
            assert all_reqs.get(key, None) is None, (
                f"Received duplicate agent request for {key}: "
                f"Received too many agent requests for {key}: "
                f"{(req)}, previous {all_reqs[key]}"
            )
            all_reqs[key] = req
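
For readability, here is the new polling guard from step() restated as a standalone predicate; the function name and signature are hypothetical, and only the boolean logic comes from the hunk above:

def keep_polling(config, active_battles, queue_task, num_pending) -> bool:
    """Hypothetical restatement of the new step() polling guard."""
    under_batch_limit = (
        config.batch_limit <= 0 or num_pending < config.batch_limit
    )
    # Only keep waiting while battles are still active or still being queued,
    # so the loop can no longer hang on a pool that has nothing left to do.
    still_has_work = len(active_battles) > 0 or (
        queue_task is not None and not queue_task.done()
    )
    return under_batch_limit and still_has_work

The agent_poll call that follows now also uses the configured worker_timeout_ms instead of waiting indefinitely when no requests have been collected yet, and timeout=0 once at least one has.
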
28 changes: 19 additions & 9 deletions src/py/environments/utils/battle_pool.py
@@ -55,6 +55,18 @@ class BattlePoolConfig:
    regardless of this value. Omit to not store logs except on error.
    """

    worker_timeout_ms: Optional[int] = None
    """
    Worker communication timeout in milliseconds for both starting battles and
    managing battle agents. Used for catching rare async bugs.
    """

    sim_timeout_ms: Optional[int] = None
    """
    Simulator timeout in milliseconds for processing battle-related actions and
    events. Used for catching rare async bugs.
    """


class BattleKey(NamedTuple):
    """Key type used to identify individual battles when using many workers."""
@@ -119,21 +131,16 @@ def __init__(
        self.ctx.setsockopt(zmq.LINGER, 0)
        # Prevent messages from getting dropped.
        self.ctx.setsockopt(zmq.ROUTER_MANDATORY, 1)
        self.ctx.setsockopt(zmq.SNDHWM, 0)
        self.ctx.setsockopt(zmq.RCVHWM, 0)
        self.ctx.setsockopt(zmq.SNDHWM, 0)
        if config.worker_timeout_ms is not None:
            self.ctx.setsockopt(zmq.RCVTIMEO, config.worker_timeout_ms)
            self.ctx.setsockopt(zmq.SNDTIMEO, config.worker_timeout_ms)

        self.battle_sock = self.ctx.socket(zmq.ROUTER)
        # Prevent indefinite blocking.
        self.battle_sock.setsockopt(zmq.SNDTIMEO, 10_000)  # 10s
        self.battle_sock.setsockopt(zmq.RCVTIMEO, 10_000)
        self.battle_sock.bind(f"ipc:///tmp/psai-battle-socket-{self.sock_id}")

        self.agent_sock = self.ctx.socket(zmq.ROUTER)
        # The JS simulator is very fast compared to the ML code so it shouldn't
        # take long at all to send predictions to or receive requests from any
        # of the connected workers.
        self.agent_sock.setsockopt(zmq.SNDTIMEO, 10_000)  # 10s
        self.agent_sock.setsockopt(zmq.RCVTIMEO, 10_000)
        self.agent_sock.bind(f"ipc:///tmp/psai-agent-socket-{self.sock_id}")

        self.agent_poller = zmq.asyncio.Poller()
@@ -260,6 +267,9 @@ async def queue_battle(
"onlyLogOnError": self.config.battles_per_log is None
or self.battle_count % self.config.battles_per_log != 0,
"seed": prng_seeds[0],
"timeoutMs": self.config.sim_timeout_ms
if self.config.sim_timeout_ms is not None
else None,
}
await self.battle_sock.send_multipart(
[worker_id, json.dumps(req).encode()]
11 changes: 10 additions & 1 deletion src/py/environments/utils/protocol.py
@@ -1,7 +1,7 @@
"""
Describes the JSON protocol for the BattlePool.
Corresponds to src/ts/battle/worker/protocol.ts.
MUST keep this in sync with src/ts/battle/worker/protocol.ts.
"""
from typing import Optional, TypedDict

@@ -75,6 +75,12 @@ class BattleRequest(TypedDict):
    seed: Optional[PRNGSeed]
    """Seed for battle engine."""

    timeoutMs: Optional[int]
    """
    Simulator timeout in milliseconds for processing battle-related actions and
    events. Used for catching rare async bugs.
    """


class BattleReply(TypedDict):
    """Result of finished battle."""
@@ -94,6 +100,9 @@ class BattleReply(TypedDict):
    truncated: Optional[bool]
    """Whether the battle was truncated due to max turn limit or error."""

    logPath: Optional[str]
    """Resolved path to the log file."""

    err: Optional[str]
    """Captured exception with stack trace if it was thrown during the game."""

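Putting the protocol additions together, a battle request/reply exchange now carries the new fields roughly as follows; only the fields visible in the hunks on this page are included, and every concrete value below is made up for illustration:

# Illustrative payloads only; shapes taken from the hunks above.
battle_request = {
    "onlyLogOnError": True,
    "seed": None,  # Optional[PRNGSeed]; the seed format is not shown here
    "timeoutMs": 300_000,  # new: simulator timeout (5m)
}
battle_reply = {
    "truncated": False,  # set on max-turn limit or error
    "logPath": "/tmp/example-battle.log",  # new: resolved log file path
    "err": None,  # stack trace string if the game threw
    # ...other BattleReply fields are not shown in this diff...
}
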
4 changes: 2 additions & 2 deletions src/py/train.py
@@ -95,7 +95,7 @@ async def run_eval(
unit="battles",
unit_scale=True,
dynamic_ncols=True,
position=0,
position=1,
) as pbar:
state, info = env.reset(eval_opponents=opponents)
done = False
@@ -250,7 +250,7 @@ async def train(config: TrainConfig):
        dynamic_ncols=True,
        smoothing=0.1,
        initial=min(int(episode), config.rollout.num_episodes),
        position=1,
        position=0,
    ) as pbar:
        if config.rollout.eps_per_eval > 0 and not restored and episode == 0:
            # Pre-evaluation for comparison against the later trained model.
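
The position swap above ties into the "Fix tqdm leaving artifacts in the terminal" item from the commit message: with nested progress bars, each bar needs its own position row so redraws do not clobber one another. A minimal standalone sketch of the pattern (not the training code itself):

from tqdm import tqdm

# Outer bar pinned to row 0, inner bar to row 1; leave=False clears the inner
# bar when it finishes, and dynamic_ncols adapts to terminal resizes.
with tqdm(total=100, desc="rollout", position=0, dynamic_ncols=True) as outer:
    for _ in range(10):
        with tqdm(
            total=10, desc="eval", position=1, leave=False, dynamic_ncols=True
        ) as inner:
            for _ in range(10):
                inner.update(1)
        outer.update(10)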