all: remove legacy code and add some minor improvements on data processing
Equim-chan committed Aug 16, 2024
1 parent f3be98d commit edb448b
Showing 14 changed files with 110 additions and 67 deletions.
9 changes: 5 additions & 4 deletions libriichi/src/array.rs
@@ -66,23 +66,24 @@ where
 #[cfg(test)]
 mod test {
     use super::*;
+    use ndarray::arr2;

     #[test]
     fn mutate() {
         let mut arr = Simple2DArray::<2, i32>::new(4);
         arr.fill(1, 3);
-        assert_eq!(&arr.arr, &[0, 0, 3, 3, 0, 0, 0, 0]);
+        assert_eq!(arr.build(), arr2(&[[0, 0], [3, 3], [0, 0], [0, 0]]));

         let mut arr = Simple2DArray::<2, i32>::new(4);
         arr.fill_rows(1, 2, 3);
-        assert_eq!(&arr.arr, &[0, 0, 3, 3, 3, 3, 0, 0]);
+        assert_eq!(arr.build(), arr2(&[[0, 0], [3, 3], [3, 3], [0, 0]]));

         let mut arr = Simple2DArray::<2, i32>::new(4);
         arr.assign(1, 1, 3);
-        assert_eq!(&arr.arr, &[0, 0, 0, 3, 0, 0, 0, 0]);
+        assert_eq!(arr.build(), arr2(&[[0, 0], [0, 3], [0, 0], [0, 0]]));

         let mut arr = Simple2DArray::<2, i32>::new(4);
         arr.assign_rows(1, 1, 2, 3);
-        assert_eq!(&arr.arr, &[0, 0, 0, 3, 0, 3, 0, 0]);
+        assert_eq!(arr.build(), arr2(&[[0, 0], [0, 3], [0, 3], [0, 0]]));
     }
 }
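For reference: the updated assertions check the built ndarray instead of the flat backing buffer, and the two agree because Simple2DArray appears to store its elements in row-major order. A small numpy sketch (hypothetical, mirroring the first assertion) of why the flat and 2-D views line up:

    import numpy as np

    # fill(1, 3) writes 3 into every column of row 1; in a row-major flat
    # buffer of shape (4, 2), row 1 occupies flat indices 2..4.
    flat = np.zeros(8, dtype=np.int32)
    rows, cols = 4, 2
    flat[1 * cols:(1 + 1) * cols] = 3
    assert flat.tolist() == [0, 0, 3, 3, 0, 0, 0, 0]
    assert (flat.reshape(rows, cols) == np.array([[0, 0], [3, 3], [0, 0], [0, 0]])).all()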
6 changes: 3 additions & 3 deletions libriichi/src/dataset/invisible.rs
@@ -219,9 +219,9 @@ impl Invisible {
             encode_tile(idx, tile);
             idx += 2;
         }
-        // In real life case `self.yama[yama_idx..]` is at most 69 (`yama_idx`
-        // is always >= 1), because the dealer always unconditionally deals the
-        // first tile from yama. Therefore we do the minus one here.
+        // In real life case `self.yama[yama_idx..].len()` is at most 69 since
+        // `yama_idx` >= 1 always holds, as the dealer always unconditionally
+        // deals the first tile from yama. Therefore we do the minus one here.
         idx += (yama_idx - 1) * 2;

         for &tile in &self.rinshan[rinshan_idx..] {
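A note on the arithmetic this comment documents: the live wall holds 70 tiles, the dealer's first draw is unconditional, so at most 69 hidden wall tiles ever need encoding, at 2 feature slots each. A hypothetical Python sketch (constants assumed, not taken from the crate) showing that the skip keeps the encoded region a constant width:

    YAMA_LEN = 70        # assumed live-wall size: 136 - 4 * 13 (hands) - 14 (dead wall)
    SLOTS_PER_TILE = 2

    def yama_region_end(idx: int, yama_idx: int) -> int:
        idx += (YAMA_LEN - yama_idx) * SLOTS_PER_TILE  # tiles still hidden in the wall
        idx += (yama_idx - 1) * SLOTS_PER_TILE         # skip dealt tiles, minus the dealer's first draw
        return idx

    # The region width is the same no matter how far the deal has progressed:
    assert yama_region_end(0, 1) == yama_region_end(0, 35) == 69 * SLOTS_PER_TILE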
3 changes: 1 addition & 2 deletions mortal/client.py
@@ -19,7 +19,6 @@ def main():
     num_blocks = config['resnet']['num_blocks']
     conv_channels = config['resnet']['conv_channels']

-    oracle = None
     mortal = Brain(version=version, num_blocks=num_blocks, conv_channels=conv_channels).to(device).eval()
     dqn = DQN(version=version).to(device)
     if config['online']['enable_compile']:
@@ -51,7 +50,7 @@ def main():
         dqn.load_state_dict(rsp['dqn'])
         logging.info('param has been updated')

-    rankings, file_list = train_player.train_play(oracle, mortal, dqn, device)
+    rankings, file_list = train_player.train_play(mortal, dqn, device)
     avg_rank = rankings @ np.arange(1, 5) / rankings.sum()
     avg_pt = rankings @ pts / rankings.sum()
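The two dot products compute the expected placement and expected points per game from placement counts. A quick worked example with made-up numbers (the real pts vector comes from the config):

    import numpy as np

    rankings = np.array([300, 260, 240, 200])  # hypothetical 1st..4th place counts
    pts = np.array([90, 45, 0, -135])          # hypothetical placement points

    avg_rank = rankings @ np.arange(1, 5) / rankings.sum()
    avg_pt = rankings @ pts / rankings.sum()
    print(avg_rank)  # 2.34  (1.0 is best, 4.0 is worst)
    print(avg_pt)    # 11.7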
5 changes: 2 additions & 3 deletions mortal/common.py
@@ -39,13 +39,12 @@ def drain():
             continue
         return msg['drain_dir']

-def submit_param(oracle, mortal, dqn, is_idle=False):
+def submit_param(mortal, dqn, is_idle=False):
     remote = (config['online']['remote']['host'], config['online']['remote']['port'])
     with socket.socket() as conn:
         conn.connect(remote)
         send_msg(conn, {
             'type': 'submit_param',
-            'oracle': None if oracle is None else oracle.state_dict(),
             'mortal': mortal.state_dict(),
             'dqn': dqn.state_dict(),
             'is_idle': is_idle,
@@ -65,7 +64,7 @@ def recv_msg(conn: socket.socket, map_location=torch.device('cpu')):
     rx = recv_binary(conn, 8)
     (size,) = struct.unpack('<Q', rx)
     rx = recv_binary(conn, size)
-    return torch.load(BytesIO(rx), weights_only=True, map_location=map_location)
+    return torch.load(BytesIO(rx), weights_only=False, map_location=map_location)  # TODO: weights_only=True

 def recv_binary(conn: socket.socket, size):
     assert size > 0
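For context, recv_msg implies a simple wire format: an 8-byte little-endian length prefix followed by a torch-serialized payload. A minimal sketch of the sending side (hypothetical; the actual send_msg is not shown in this diff):

    import struct
    from io import BytesIO

    import torch

    def frame_msg(obj) -> bytes:
        buf = BytesIO()
        torch.save(obj, buf)
        payload = buf.getvalue()
        return struct.pack('<Q', len(payload)) + payload  # u64 length prefix, then payload

    data = frame_msg({'type': 'ping'})
    (size,) = struct.unpack('<Q', data[:8])
    assert size == len(data) - 8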
2 changes: 1 addition & 1 deletion mortal/dataloader.py
@@ -81,10 +81,10 @@ def populate_buffer(self, file_list):
         data = self.loader.load_gz_log_files(file_list)
         for file in data:
             for game in file:
+                # per move
                 obs = game.take_obs()
                 if self.oracle:
                     invisible_obs = game.take_invisible_obs()
-                # per move
                 actions = game.take_actions()
                 masks = game.take_masks()
                 at_kyoku = game.take_at_kyoku()
17 changes: 10 additions & 7 deletions mortal/engine.py
@@ -1,4 +1,5 @@
 import json
+import traceback
 import torch
 import numpy as np
 from torch.distributions import Normal, Categorical
@@ -40,17 +41,19 @@ def __init__(
         self.top_p = top_p

     def react_batch(self, obs, masks, invisible_obs):
-        with (
-            torch.autocast(self.device.type, enabled=self.enable_amp),
-            torch.no_grad(),
-        ):
-            return self._react_batch(obs, masks, invisible_obs)
+        try:
+            with (
+                torch.autocast(self.device.type, enabled=self.enable_amp),
+                torch.inference_mode(),
+            ):
+                return self._react_batch(obs, masks, invisible_obs)
+        except Exception as ex:
+            raise Exception(f'{ex}\n{traceback.format_exc()}')

     def _react_batch(self, obs, masks, invisible_obs):
         obs = torch.as_tensor(np.stack(obs, axis=0), device=self.device)
         masks = torch.as_tensor(np.stack(masks, axis=0), device=self.device)
-        invisible_obs = None
-        if self.is_oracle:
+        if invisible_obs is not None:
             invisible_obs = torch.as_tensor(np.stack(invisible_obs, axis=0), device=self.device)
         batch_size = obs.shape[0]
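On the torch.no_grad() -> torch.inference_mode() change above: inference_mode is a stricter, slightly faster variant that also disables view tracking and version counting, so tensors created under it can never re-enter autograd. A tiny illustration:

    import torch

    x = torch.ones(3, requires_grad=True)
    with torch.inference_mode():
        y = x * 2            # no graph is recorded
    assert not y.requires_grad
    assert y.is_inference()  # y cannot be used in autograd later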
10 changes: 5 additions & 5 deletions mortal/model.py
@@ -20,7 +20,7 @@ def __init__(self, channels, ratio=16, actv_builder=nn.ReLU, bias=True):
             if isinstance(mod, nn.Linear):
                 nn.init.constant_(mod.bias, 0)

-    def forward(self, x):
+    def forward(self, x: Tensor):
         avg_out = self.shared_mlp(x.mean(-1))
         max_out = self.shared_mlp(x.amax(-1))
         weight = (avg_out + max_out).sigmoid()
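The forward shown here is channel attention in the CBAM style: average- and max-pooled channel descriptors pass through a shared MLP and gate the channels with a sigmoid. A self-contained sketch under assumed shapes (input (N, C, L)):

    import torch
    import torch.nn as nn

    channels, ratio = 64, 16
    shared_mlp = nn.Sequential(
        nn.Linear(channels, channels // ratio),
        nn.ReLU(),
        nn.Linear(channels // ratio, channels),
    )
    x = torch.randn(8, channels, 34)                                      # (N, C, L)
    weight = (shared_mlp(x.mean(-1)) + shared_mlp(x.amax(-1))).sigmoid()  # (N, C)
    out = x * weight.unsqueeze(-1)                                        # channels reweighted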
@@ -149,7 +149,7 @@ def __init__(self, *, conv_channels, num_blocks, is_oracle=False, version=1):
         # always use EMA or CMA when True
         self._freeze_bn = False

-    def forward(self, obs, invisible_obs: Optional[Tensor] = None) -> Union[Tuple[Tensor, Tensor], Tensor]:
+    def forward(self, obs: Tensor, invisible_obs: Optional[Tensor] = None) -> Union[Tuple[Tensor, Tensor], Tensor]:
         if self.is_oracle:
             assert invisible_obs is not None
             obs = torch.cat((obs, invisible_obs), dim=1)
@@ -252,7 +252,7 @@ def __init__(self, hidden_size=64, num_layers=2):
     # grand_kyoku: E1 = 0, S4 = 7, W4 = 11
     # s is 2.5 at E1
     # s[0] is score of player id 0
-    def forward(self, inputs):
+    def forward(self, inputs: List[Tensor]):
         lengths = torch.tensor([t.shape[0] for t in inputs], dtype=torch.int64)
         inputs = pad_sequence(inputs, batch_first=True)
         packed_inputs = pack_padded_sequence(inputs, lengths, batch_first=True, enforce_sorted=False)
@@ -265,7 +265,7 @@ def forward_packed(self, packed_inputs):
         return logits

     # (N, 24) -> (N, player, rank_prob)
-    def calc_matrix(self, logits):
+    def calc_matrix(self, logits: Tensor):
         batch_size = logits.shape[0]
         probs = logits.softmax(-1)
         matrix = torch.zeros(batch_size, 4, 4, dtype=probs.dtype)
return matrix

# (N, 4) -> (N)
def get_label(self, rank_by_player):
def get_label(self, rank_by_player: Tensor):
batch_size = rank_by_player.shape[0]
perms = self.perms.expand(batch_size, -1, -1).transpose(0, 1)
mappings = (perms == rank_by_player).all(-1).nonzero()
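Context for the (N, 24) -> (N, player, rank_prob) comment: 24 = 4! is the number of ways to rank four players, so the head predicts a distribution over whole rank permutations and calc_matrix marginalizes it per player. A sketch of that marginalization under assumed permutation semantics:

    from itertools import permutations

    import torch

    perms = torch.tensor(list(permutations(range(4))))  # (24, 4); perms[k][p] = rank of player p
    logits = torch.randn(2, 24)
    probs = logits.softmax(-1)
    matrix = torch.zeros(2, 4, 4)
    for k in range(24):
        for player in range(4):
            matrix[:, player, perms[k][player]] += probs[:, k]
    # each player's row is a probability distribution over ranks
    assert torch.allclose(matrix.sum(-1), torch.ones(2, 4))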
2 changes: 1 addition & 1 deletion mortal/mortal.py
@@ -76,7 +76,7 @@ def main():
         range(len(feature)),
     ))

-    with torch.no_grad():
+    with torch.inference_mode():
         logits = grp(seq)
         matrix = grp.calc_matrix(logits)
     extra_data = {
2 changes: 0 additions & 2 deletions mortal/one_vs_three.py
@@ -41,7 +41,6 @@ def main():
         dqn,
         is_oracle = False,
         version = version,
-        stochastic_latent = cfg['champion'].get('stochastic_latent', False),
         device = torch.device(cfg['champion']['device']),
         enable_amp = cfg['champion']['enable_amp'],
         enable_rule_based_agari_guard = cfg['champion']['enable_rule_based_agari_guard'],
@@ -65,7 +64,6 @@ def main():
         dqn,
         is_oracle = False,
         version = version,
-        stochastic_latent = cfg['challenger'].get('stochastic_latent', False),
         device = torch.device(cfg['challenger']['device']),
         enable_amp = cfg['challenger']['enable_amp'],
         enable_rule_based_agari_guard = cfg['challenger']['enable_rule_based_agari_guard'],
4 changes: 1 addition & 3 deletions mortal/player.py
@@ -113,19 +113,17 @@ def __init__(self):
         self.boltzmann_epsilon = cfg['boltzmann_epsilon']
         self.boltzmann_temp = cfg['boltzmann_temp']
         self.top_p = cfg['top_p']
-        self.stochastic_latent = cfg.get('stochastic_latent', True)

         self.repeats = cfg['repeats']
         self.repeat_counter = 0

-    def train_play(self, oracle, mortal, dqn, device):
+    def train_play(self, mortal, dqn, device):
         torch.backends.cudnn.benchmark = False
         engine_chal = MortalEngine(
             mortal,
             dqn,
             is_oracle = False,
             version = self.chal_version,
-            stochastic_latent = self.stochastic_latent,
             boltzmann_epsilon = self.boltzmann_epsilon,
             boltzmann_temp = self.boltzmann_temp,
             top_p = self.top_p,
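The boltzmann_epsilon / boltzmann_temp / top_p trio that survives the cleanup configures exploration during self-play. A hypothetical sketch of the usual wiring (the engine's real sampling code is not part of this diff): with probability epsilon, sample from a temperature-scaled softmax restricted to the top-p nucleus; otherwise act greedily.

    import torch

    def sample_action(q: torch.Tensor, mask: torch.Tensor,
                      epsilon: float, temp: float, top_p: float) -> int:
        q = q.masked_fill(~mask, float('-inf'))  # assume mask is True for legal actions
        if torch.rand(()).item() >= epsilon:
            return int(q.argmax())               # greedy most of the time
        probs = (q / temp).softmax(-1)           # Boltzmann distribution over Q-values
        sorted_p, idx = probs.sort(descending=True)
        keep = sorted_p.cumsum(-1) - sorted_p < top_p   # smallest nucleus reaching top_p
        nucleus = torch.zeros_like(probs).scatter(0, idx[keep], sorted_p[keep])
        return int(torch.multinomial(nucleus / nucleus.sum(), 1))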
2 changes: 1 addition & 1 deletion mortal/reward_calculator.py
@@ -16,7 +16,7 @@ def calc_grp(self, grp_feature):
             range(len(grp_feature)),
         ))

-        with torch.no_grad():
+        with torch.inference_mode():
             logits = self.grp(seq)
             matrix = self.grp.calc_matrix(logits)
         return matrix
41 changes: 20 additions & 21 deletions mortal/server.py
@@ -19,17 +19,18 @@
 class State:
     buffer_dir: str
     drain_dir: str
+    capacity: int
+    force_sequential: bool
     dir_lock: Lock
     param_lock: Lock
+    # fields below are protected by dir_lock
     buffer_size: int
     submission_id: int
-    oracle_param: Optional[OrderedDict]
+    # fields below are protected by param_lock
     mortal_param: Optional[OrderedDict]
     dqn_param: Optional[OrderedDict]
     param_version: int
     idle_param_version: int
-    capacity: int
-    force_sequential: bool
 S = None
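The reordering groups each field under the lock that guards it, which the new comments make explicit. A minimal sketch of the convention (hypothetical fields; the real State is built in main() below):

    from dataclasses import dataclass, field
    from threading import Lock

    @dataclass
    class SharedState:
        dir_lock: Lock = field(default_factory=Lock)
        # fields below are protected by dir_lock
        buffer_size: int = 0

    def add_to_buffer(s: SharedState, n: int) -> int:
        with s.dir_lock:       # take the guarding lock before touching the field
            s.buffer_size += n
            return s.buffer_size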

class Handler(BaseRequestHandler):
@@ -38,16 +39,16 @@ def handle(self):
         match msg['type']:
             # called by workers
             case 'get_param':
-                self.get_param(msg)
+                self.handle_get_param(msg)
             case 'submit_replay':
-                self.submit_replay(msg)
+                self.handle_submit_replay(msg)
             # called by trainer
             case 'submit_param':
-                self.submit_param(msg)
+                self.handle_submit_param(msg)
             case 'drain':
-                self.drain()
+                self.handle_drain()

-    def get_param(self, msg):
+    def handle_get_param(self, msg):
         with S.dir_lock:
             overflow = S.buffer_size >= S.capacity
         with S.param_lock:
@@ -74,7 +75,7 @@ def get_param(self, msg):
             torch.save(res, buf)
         self.send_msg(buf.getbuffer(), packed=True)

-    def submit_replay(self, msg):
+    def handle_submit_replay(self, msg):
         with S.dir_lock:
             for filename, content in msg['logs'].items():
                 filepath = path.join(S.buffer_dir, f'{S.submission_id}_{filename}')
@@ -84,16 +85,15 @@ def submit_replay(self, msg):
             S.submission_id += 1
         logging.info(f'total buffer size: {S.buffer_size}')

-    def submit_param(self, msg):
+    def handle_submit_param(self, msg):
         with S.param_lock:
-            S.oracle_param = msg['oracle']
             S.mortal_param = msg['mortal']
             S.dqn_param = msg['dqn']
             S.param_version += 1
             if msg['is_idle']:
                 S.idle_param_version = S.param_version

-    def drain(self):
+    def handle_drain(self):
         drained_size = 0
         with S.dir_lock:
             buffer_list = os.listdir(S.buffer_dir)
@@ -136,17 +136,16 @@ def main():
     S = State(
         buffer_dir = path.abspath(cfg['buffer_dir']),
         drain_dir = path.abspath(cfg['drain_dir']),
-        dir_lock = Lock(),
-        param_lock = Lock(),
-        buffer_size = 0, # protected by dir_lock
-        submission_id = 0, # protected by dir_lock
-        oracle_param = None, # protected by param_lock
-        mortal_param = None, # protected by param_lock
-        dqn_param = None, # protected by param_lock
-        param_version = 0, # protected by param_lock
-        idle_param_version = 0, # protected by param_lock
         capacity = cfg['capacity'],
         force_sequential = cfg['force_sequential'],
+        dir_lock = Lock(),
+        param_lock = Lock(),
+        buffer_size = 0,
+        submission_id = 0,
+        mortal_param = None,
+        dqn_param = None,
+        param_version = 0,
+        idle_param_version = 0,
     )

     bind_addr = (config['online']['remote']['host'], config['online']['remote']['port'])