# run_clustering.py
import argparse
import collections
import itertools
import os
import random
import time
import warnings

import numpy as np
import torch

from model import DCP
from utils.get_mask import get_mask
from utils.util import cal_std
from utils.logger_ import get_logger
# NOTE: star import; it supplies load_data and may rebind any name imported above it.
from utils.datasets import *
from configure.configure_clustering import get_default_config
warnings.simplefilter("ignore")
# Dataset ids accepted by --dataset, mapped to the dataset name that main()
# passes to get_default_config().
DATASETS = {
    0: "Caltech101-20",
    1: "Scene_15",
    2: "NoisyMNIST",
    3: "LandUse_21",
}

# Command-line interface. Defaults are given in their real types (argparse
# only runs *string* defaults through ``type``), and ``choices`` turns an
# unknown dataset id into a normal usage error instead of a KeyError below.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=int, default=0, choices=sorted(DATASETS), help='dataset id')
parser.add_argument('--devices', type=str, default='0', help='gpu device ids')
parser.add_argument('--print_num', type=int, default=50, help='gap of print evaluations')
parser.add_argument('--test_time', type=int, default=5, help='number of test times')
parser.add_argument('--missing_rate', type=float, default=0.0, help='missing rate')
args = parser.parse_args()

# Name of the selected dataset; main() reads this module-level value.
dataset = DATASETS[args.dataset]
def _log_config(logger, config):
    """Write every configuration entry to ``logger``, one entry per line.

    An entry whose value is a dict is expanded one level: its key is logged
    as ``key={`` followed by one line per nested item.
    """
    for key, value in config.items():
        if not isinstance(value, dict):
            logger.info("%s = %s" % (key, value))
            continue
        logger.info("%s={" % (key,))
        for sub_key, sub_value in value.items():
            logger.info(" %s = %s" % (sub_key, sub_value))


def _set_seeds(seed):
    """Seed the NumPy, ``random`` and PyTorch (CPU and CUDA) generators.

    Each generator gets a different offset of ``seed``, and cuDNN is switched
    to deterministic mode.
    """
    np.random.seed(seed)
    random.seed(seed + 1)
    torch.manual_seed(seed + 2)
    torch.cuda.manual_seed(seed + 3)
    torch.backends.cudnn.deterministic = True


def main():
    """Train and evaluate the DCP clustering model ``args.test_time`` times.

    Reads the module-level ``args`` and ``dataset``. For every run a new mask
    is drawn with ``get_mask`` (NumPy seeded with the run index), both views
    are multiplied row-wise by their mask column, a fresh ``DCP`` model is
    trained with Adam, and the returned accuracy / NMI / ARI are collected.
    After the last run the collected scores are handed to ``cal_std``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.devices)
    use_cuda = torch.cuda.is_available()
    # Only the devices listed in --devices are visible, so index 0 is the first of them.
    device = torch.device('cuda:0' if use_cuda else 'cpu')

    # Configure: command-line values override the dataset defaults.
    config = get_default_config(dataset)
    config['missing_rate'] = args.missing_rate
    config['print_num'] = args.print_num
    config['dataset'] = dataset

    # get_logger returns a second value that this script does not use.
    logger, _ = get_logger(config)
    logger.info('Dataset:' + str(dataset))
    _log_config(logger, config)

    # Load data: only the first two views of X_list are used.
    X_list, Y_list = load_data(config)
    x1_train_raw = X_list[0]
    x2_train_raw = X_list[1]

    fold_acc, fold_nmi, fold_ari = [], [], []
    for data_seed in range(1, args.test_time + 1):
        start = time.time()

        # The mask is always drawn under the run index, so it differs per run.
        np.random.seed(data_seed)
        mask = get_mask(2, x1_train_raw.shape[0], config['missing_rate'])
        # Scale each sample of a view by that view's mask entry
        # (presumably 0/1, zeroing missing samples -- confirm in get_mask).
        x1_train = x1_train_raw * mask[:, 0][:, np.newaxis]
        x2_train = x2_train_raw * mask[:, 1][:, np.newaxis]

        x1_train = torch.from_numpy(x1_train).float().to(device)
        x2_train = torch.from_numpy(x2_train).float().to(device)
        mask = torch.from_numpy(mask).long().to(device)

        # Accumulated metrics, filled in by train_clustering.
        accumulated_metrics = collections.defaultdict(list)

        # With complete data the training seed follows the run index;
        # otherwise the fixed seed from the config is used for every run.
        seed = data_seed if config['missing_rate'] == 0 else config['seed']
        _set_seeds(seed)

        # Build model
        DCP_model = DCP(config)
        networks = (
            DCP_model.autoencoder1,
            DCP_model.autoencoder2,
            DCP_model.img2txt,
            DCP_model.txt2img,
        )
        # Place the sub-networks on the target device before the optimizer
        # captures their parameters.
        for network in networks:
            network.to(device)

        optimizer = torch.optim.Adam(
            itertools.chain.from_iterable(network.parameters() for network in networks),
            lr=config['training']['lr'])

        # Print the models
        logger.info(DCP_model.autoencoder1)
        logger.info(DCP_model.img2txt)
        logger.info(optimizer)

        # Training
        acc, nmi, ari = DCP_model.train_clustering(
            config, logger, accumulated_metrics, x1_train, x2_train, Y_list, mask,
            optimizer, device)
        fold_acc.append(acc)
        fold_nmi.append(nmi)
        fold_ari.append(ari)

        # Wall-clock seconds spent on this run.
        print(time.time() - start)

    logger.info('--------------------Training over--------------------')
    # cal_std receives the logger and the per-run scores; its return value is not needed here.
    cal_std(logger, fold_acc, fold_nmi, fold_ari)
# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()