-
Notifications
You must be signed in to change notification settings - Fork 0
/
simulate.py
41 lines (32 loc) · 1.29 KB
/
simulate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import argparse
from pathlib import Path
from loguru import logger
import pandas as pd
from ranking_constraints.config import create_parser, parse_args, parse_dataset_args
from ranking_constraints.simulator import Simulator
from ranking_constraints import controller as ctrl
if __name__ == "__main__":
    # Script entry point: build the run configuration, run the simulator with
    # the controller named in the configuration, optionally persist the
    # intermediate metrics, and print the summary metrics to stdout as a
    # two-line CSV (header line, then one row of values).

    # The configuration object is threaded through three helpers in turn.
    configs = create_parser()
    configs = parse_dataset_args(configs)
    configs = parse_args(configs)

    # NOTE(review): called with only a name, loguru's `logger.level()` looks
    # up an existing level and returns it; it does not change which messages
    # are emitted. If the intent was to set the threshold, that has to be done
    # when adding a sink (`logger.add(..., level="DEBUG")`) — confirm intent.
    logger.level("DEBUG")
    logger.info(f"T={configs.T}, N={configs.N}, delta={configs.delta}")

    # Results live under <current working dir>/results/<dataset>.
    # `exist_ok=True` creates the directory without a separate existence
    # check, so concurrent runs targeting the same dataset cannot race between
    # the check and the creation.
    output_dir = f"{os.getcwd()}/results/{configs.dataset}"
    os.makedirs(output_dir, exist_ok=True)

    # `configs.controller` holds the name of a class exported by the
    # `controller` module; an unknown name raises AttributeError here.
    ctrl_class = getattr(ctrl, configs.controller)
    controller = ctrl_class(configs)

    simulator = Simulator(configs, controller, test=True, tqdm_disable=False)
    # The returned triple is not used below; everything reported afterwards is
    # read back from the simulator object itself.
    state, utility, obs = simulator.simulate(configs.R)

    # Persist the intermediate metrics only when the configuration carries an
    # output file name; the DataFrame is built only in that case since it has
    # no other consumer.
    if hasattr(configs, 'metrics_file_name'):
        intermediate_metrics = simulator.intermediate_metrics
        df = pd.DataFrame(intermediate_metrics)
        df.to_pickle(f'{output_dir}/{configs.metrics_file_name}.pkl')

    # Summary metrics: one header line and one value line, comma-separated.
    # Values are stringified as-is and not quoted, so a value whose string
    # form contains a comma would shift the columns.
    metrics = simulator.get_metrics(configs.delta)
    columns = list(metrics.keys())
    values = [str(metrics[c]) for c in columns]
    header = ",".join(columns)
    print(header)
    row = ",".join(values)
    print(row)