-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
69 lines (52 loc) · 2.23 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
from preprocessing import DataSanitizer, CellGraph
from config import CFG
from utils import Visualize, Runner
from pathlib import Path
from scipy.stats import pearsonr
import torch
import torch.nn as nn
import torch.optim as optim
from model import Net
from rich import print
import gzip
# --- Reproducibility and device selection ---------------------------------
# Seed every RNG the script touches (torch CPU, all CUDA devices, NumPy) from
# the single configured seed, so runs are as repeatable as the backends allow.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
    _seed_fn(CFG.seed)
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# --- Preprocessing (cached) ----------------------------------------------
# Sanitising the matrix and building the cell graph is done once. The result
# is pickled (gzip-compressed) to `processed_file` and reloaded on later runs;
# delete the file to force a rebuild.
processed_file = 'processed.pt.gz'
if Path(processed_file).exists():
    print(f"{processed_file} does exist!")
    with gzip.open(processed_file, 'rb') as f:
        # map_location: the cached graph was built on whatever device was
        # active at the time (CellGraph receives `device`). Remapping lets a
        # CUDA-built cache load on a CPU-only machine, and puts a CPU-built
        # cache on the GPU when one is now available, matching the model.
        # weights_only=False: the cache pickles whole objects (the sanitizer
        # and the graph), which the weights-only loader (torch.load's default
        # since torch 2.6) refuses. This unpickles arbitrary code, so only load
        # a cache this script wrote itself. Requires torch >= 1.13.
        tload = torch.load(f, map_location=device, weights_only=False)
    data, gd = tload['data'], tload['gd']
else:
    data = DataSanitizer('matrix.mtx', data_path='./data')
    gd = CellGraph(data.get(), device=device)[0]
    # Write to a temporary file and rename it into place, so an interrupted
    # save cannot leave a truncated cache that every later run tries to load.
    tmp_file = Path(processed_file + '.tmp')
    with gzip.open(tmp_file, 'wb') as f:
        torch.save({'data': data, 'gd': gd}, f)
    tmp_file.replace(processed_file)

# Graph inputs shared by every training run that follows.
x, edge_index, n_features = gd.x, gd.edge_index, gd.num_features
print(gd)
# From here on `data` is the array returned by DataSanitizer.get() as a
# float32 tensor on `device` (it previously held the DataSanitizer instance).
data = torch.tensor(data.get(), dtype=torch.float, device=device)
# --- Train / evaluate one fresh model per training method ----------------
# Neither value changes inside the loop, so compute them once.
n_features = gd.num_features
# Filename tag recording the seed and the final epoch index of this run.
ext = f'(S{CFG.seed}E{CFG.n_epochs-1})'
for method in ('masked', 'full', 'zeros'):
    print(method.title().center(80, '='))
    # A new model/optimizer per method, so the runs do not share weights.
    # Layer sizes passed to Net: input width and n_features // 64
    # (presumably the hidden/embedding width - confirm against model.Net).
    model = Net([n_features, n_features // 64]).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
    runner = Runner(model, criterion, optimizer, x=x,
                    edge_index=edge_index, gd=gd, data=data)
    logs = runner.train_full(epochs=CFG.n_epochs, method=method)
    # NOTE(review): evaluation always uses method='zeros', whatever the
    # training method. Presumably this scores every model under the same test
    # condition; confirm this is intended rather than `method=method`.
    output, target, loss_test = runner.evaluate(method='zeros')
    print(f'Testing Loss: {loss_test:.4f} ')
    # Flatten both to 1-D NumPy arrays for correlation and plotting.
    output = output.cpu().numpy().reshape(-1)
    target = target.cpu().numpy().reshape(-1)
    print(pearsonr(output, target))
    Visualize(logs).plot(
        title=f'Train Validation Loss #{method}', xlabel='#Epochs', ylabel='Loss Values',
        savefigname=f'train_val_{method}_{ext}.png')
    Visualize(output, target).regplot(
        title=f'neuron1K scGCN #{method}', xlabel='True', ylabel='Predicted',
        savefigname=f'neuron1k_correlation_{method}.png')