-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathconfig.py
28 lines (22 loc) · 922 Bytes
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Experiment configuration for the FM / NFM models.
# Edit the values below to choose the dataset, model and training options.
# Each option is validated when this module is imported.


def _require_choice(name, value, choices):
    """Raise ValueError unless *value* is one of *choices*.

    An explicit raise is used instead of ``assert`` so the check still runs
    when Python is started with ``-O``, which strips assert statements.

    Args:
        name: Name of the configuration option, used in the error message.
        value: The configured value to validate.
        choices: Sequence of allowed values.

    Raises:
        ValueError: If *value* is not contained in *choices*.
    """
    if value not in choices:
        raise ValueError(
            '{} must be one of {}, got {!r}'.format(name, list(choices), value))


# dataset name
dataset = 'frappe'
_require_choice('dataset', dataset, ('ml-tag', 'frappe'))

# model name
model = 'NFM'
_require_choice('model', model, ('FM', 'NFM'))

# log_loss is not implemented in Xiangnan's repo, so that loss type was
# removed; square_loss is the only option accepted here
loss_type = 'square_loss'
_require_choice('loss_type', loss_type, ('square_loss',))

# important settings (the default is normally the paper's choice)
optimizer = 'Adagrad'
activation_function = 'relu'
_require_choice('optimizer', optimizer, ('Adagrad', 'Adam', 'SGD', 'Momentum'))
_require_choice('activation_function', activation_function,
                ('relu', 'sigmoid', 'tanh', 'identity'))

# paths
# NOTE(review): main_path is an absolute, machine-specific location; adjust it
# to wherever the data lives locally. The libfm file names are derived from
# the dataset name.
main_path = '/home/share/guoyangyang/recommendation/NFM-Data/{}/'.format(dataset)
train_libfm = main_path + '{}.train.libfm'.format(dataset)
valid_libfm = main_path + '{}.validation.libfm'.format(dataset)
test_libfm = main_path + '{}.test.libfm'.format(dataset)

model_path = './models/'
FM_model_path = model_path + 'FM.pth'  # useful when pre-training
NFM_model_path = model_path + 'NFM.pth'