-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix #366. Some other things fixed while debugging: Fix recharge being treated as a move causing agent requests to time out. Make BattleEnv.step() loop condition more robust. Set ZeroMQ high water mark on the JS worker end to ensure no dropped messages. Fix tqdm leaving artifacts in the terminal. Add simulator scripts for debugging. Expose timeout configs. Make simulateBattle() exceptions louder in the JS worker. Fix error logging/swallowing behavior in simulateBattle(). Fix wrapTimeout() call stack. Remove unused BattleParser type params TArgs, TResult. Fix game truncation handling.
- Loading branch information
1 parent
73dd6e6
commit ec3939b
Showing
49 changed files
with
2,211 additions
and
2,942 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
/** @file Simulates a random battle used for training. */ | ||
import {randomAgent} from "../src/ts/battle/agent/random"; | ||
import {gen4Parser} from "../src/ts/battle/parser/gen4"; | ||
import {ExperienceBattleParser} from "../src/ts/battle/worker/ExperienceBattleParser"; | ||
import {PlayerOptions, simulateBattle} from "../src/ts/battle/worker/battle"; | ||
import {wrapTimeout} from "../src/ts/utils/timeout"; | ||
import {Mutable} from "../src/ts/utils/types"; | ||
|
||
Error.stackTraceLimit = Infinity; | ||
|
||
const timeoutMs = 5000; // 5s | ||
const battleTimeoutMs = 1000; // 1s | ||
const maxTurns = 50; | ||
const p1Exp = true; | ||
const p2Exp = true; | ||
|
||
void (async function () { | ||
const p1: Mutable<PlayerOptions> = { | ||
name: "p1", | ||
agent: randomAgent, | ||
parser: gen4Parser, | ||
}; | ||
const p2: Mutable<PlayerOptions> = { | ||
name: "p2", | ||
agent: randomAgent, | ||
parser: gen4Parser, | ||
}; | ||
if (p1Exp) { | ||
const expParser = new ExperienceBattleParser(p1.parser, "p1"); | ||
p1.parser = async (ctx, event) => await expParser.parse(ctx, event); | ||
p1.agent = async (state, choices) => await randomAgent(state, choices); | ||
} | ||
if (p2Exp) { | ||
const expParser = new ExperienceBattleParser(p2.parser, "p2"); | ||
p2.parser = async (ctx, event) => await expParser.parse(ctx, event); | ||
p2.agent = async (state, choices) => await randomAgent(state, choices); | ||
} | ||
|
||
const result = await wrapTimeout( | ||
async () => | ||
await simulateBattle({ | ||
players: {p1, p2}, | ||
maxTurns, | ||
timeoutMs: battleTimeoutMs, | ||
}), | ||
timeoutMs, | ||
); | ||
console.log(`winner: ${result.winner}`); | ||
console.log(`truncated: ${!!result.truncated}`); | ||
console.log(`log path: ${result.logPath}`); | ||
console.log(`err: ${result.err?.stack ?? result.err}`); | ||
})().catch(err => console.log("sim-randbat failed:", err)); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
"""Simulates a random battle used for training.""" | ||
|
||
import asyncio | ||
import os | ||
import sys | ||
from contextlib import closing | ||
from itertools import chain | ||
from typing import Optional, Union | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
|
||
# So that we can `python -m scripts.sim_randbat` from project root. | ||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | ||
|
||
# pylint: disable=wrong-import-position, import-error | ||
from src.py.agents.agent import Agent | ||
from src.py.agents.utils.epsilon_greedy import EpsilonGreedy | ||
from src.py.environments.battle_env import ( | ||
AgentDict, | ||
BattleEnv, | ||
BattleEnvConfig, | ||
EvalOpponentConfig, | ||
InfoDict, | ||
) | ||
from src.py.environments.utils.battle_pool import BattlePoolConfig | ||
from src.py.models.utils.greedy import decode_action_rankings | ||
|
||
|
||
class RandomAgent(Agent): | ||
"""Agent that acts randomly.""" | ||
|
||
def __init__(self, rng: Optional[tf.random.Generator] = None): | ||
self._epsilon_greedy = EpsilonGreedy(exploration=1.0, rng=rng) | ||
|
||
def select_action( | ||
self, | ||
state: AgentDict[Union[np.ndarray, tf.Tensor]], | ||
info: AgentDict[InfoDict], | ||
) -> AgentDict[list[str]]: | ||
"""Selects a random action.""" | ||
_ = info | ||
return dict( | ||
zip( | ||
state.keys(), | ||
decode_action_rankings( | ||
self._epsilon_greedy.rand_actions(len(state)) | ||
), | ||
) | ||
) | ||
|
||
def update_model( | ||
self, state, reward, next_state, terminated, truncated, info | ||
): | ||
"""Not implemented.""" | ||
raise NotImplementedError | ||
|
||
|
||
async def sim_randbat(): | ||
"""Starts the simulator.""" | ||
|
||
rng = tf.random.get_global_generator() | ||
|
||
agent = RandomAgent(rng) | ||
|
||
env = BattleEnv( | ||
config=BattleEnvConfig( | ||
max_turns=50, | ||
batch_limit=4, | ||
pool=BattlePoolConfig( | ||
workers=2, | ||
per_worker=1, | ||
battles_per_log=1, | ||
worker_timeout_ms=1000, # 1s | ||
sim_timeout_ms=60_000, # 1m | ||
), | ||
state_type="tensor", | ||
), | ||
rng=rng, | ||
) | ||
await env.ready() | ||
|
||
with closing(env): | ||
state, info = env.reset( | ||
rollout_battles=10, | ||
eval_opponents=( | ||
EvalOpponentConfig( | ||
name="eval_self", battles=10, type="model", model="model/p2" | ||
), | ||
), | ||
) | ||
done = False | ||
while not done: | ||
action = agent.select_action(state, info) | ||
(next_state, _, terminated, truncated, info, done) = await env.step( | ||
action | ||
) | ||
state = next_state | ||
for key, ended in chain(terminated.items(), truncated.items()): | ||
if ended: | ||
state.pop(key) | ||
info.pop(key) | ||
for key, env_info in info.items(): | ||
if key.player != "__env__": | ||
continue | ||
battle_result = env_info.get("battle_result", None) | ||
if battle_result is None: | ||
continue | ||
print(battle_result) | ||
|
||
|
||
def main(): | ||
"""Main entry point.""" | ||
asyncio.run(sim_randbat()) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.