-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_Policy_IPS.py
45 lines (35 loc) · 1.36 KB
/
run_Policy_IPS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import argparse
import sys
import traceback
from run_Policy_Main import get_args_all, main
# import pytest
sys.path.extend(["./src", "./src/DeepCTR-Torch", "./src/tianshou"])
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from core.configs import get_common_args
import logzero
def get_args_ips_policy():
    """Parse the command-line options specific to the IPS policy run.

    Only the flags declared below are consumed. Any other flags on the
    command line are tolerated and ignored (``parse_known_args``), so they
    can be handled by other argument parsers without causing an error here.

    Returns:
        argparse.Namespace with the attributes ``user_model_name``,
        ``epoch``, ``model_name``, ``read_message``, ``entropy_window``,
        ``lambda_variance``, ``lambda_entropy`` and ``message``.
    """
    ips_parser = argparse.ArgumentParser()

    # (flag, keyword arguments for add_argument), registered in this order.
    # Built inside the function so every call gets a fresh default list.
    option_specs = (
        ("--user_model_name", dict(type=str, default="DeepFM-IPS")),
        ("--epoch", dict(type=int, default=200)),
        ("--model_name", dict(type=str, default="IPS")),
        ("--read_message", dict(type=str, default="DeepFM-IPS")),
        ("--entropy_window", dict(type=int, nargs="*", default=[])),
        ("--lambda_variance", dict(type=float, default=0)),
        ("--lambda_entropy", dict(type=float, default=0)),
        ("--message", dict(type=str, default="IPS")),
    )
    for flag, spec in option_specs:
        ips_parser.add_argument(flag, **spec)

    known_namespace, _unrecognized = ips_parser.parse_known_args()
    return known_namespace
if __name__ == '__main__':
    # Build one merged namespace from three sources. Later updates win on key
    # collisions, so the IPS-specific options override the common ones.
    args_all = get_args_all()
    args = get_common_args(args_all)
    args_ips = get_args_ips_policy()
    vars(args_all).update(vars(args))
    vars(args_all).update(vars(args_ips))

    # Force both penalty weights to 0 for this IPS run.
    # NOTE(review): this overrides any --lambda_variance / --lambda_entropy
    # passed on the command line (parsed in get_args_ips_policy), so those
    # flags have no effect in this script -- confirm this is intended.
    args_all.lambda_variance = 0
    args_all.lambda_entropy = 0

    try:
        main(args_all)
    except Exception:
        # Emit the full traceback on stdout and to the logzero logger, then
        # exit non-zero so the calling shell / scheduler sees the failure.
        # Without the exit, a crashed run would finish with status 0.
        tb_text = traceback.format_exc()
        print(tb_text)
        logzero.logger.error(tb_text)
        sys.exit(1)