-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference_vae.py
63 lines (54 loc) · 2.19 KB
/
inference_vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
import json
from datetime import datetime
from pathlib import Path
from data.dataset import ShapeNetDataModule2
from models.autoencoder import AutoencoderKL
from dotmap import DotMap
from pytorch_lightning import seed_everything
import numpy as np
from omegaconf import OmegaConf
import os
import yaml
def get_current_time():
    """Return the current local time formatted as ``MM-DD-HHMMSS``."""
    stamp_format = "%m-%d-%H%M%S"
    return datetime.now().strftime(stamp_format)
def main(args):
    """Run the trained VAE over the ShapeNet test/validation loaders and save the output.

    The model config is copied into the experiment directory. The value returned
    by ``AutoencoderKL.inference`` is written to
    ``./output/vae_reconstruction/<exp_name>/rec_data.npy``.

    Args:
        args: Parsed command-line namespace providing ``gpu``,
            ``target_categories``, ``config``, ``ckpt`` and ``exp_name``.

    Raises:
        ValueError: If no checkpoint path was given.
        FileNotFoundError: If the checkpoint path does not point to a file.
    """
    # Validate the checkpoint before creating directories or loading data.
    # Raise explicitly rather than `assert` so the check survives `python -O`.
    if args.ckpt is None:
        raise ValueError("Please provide the path to the checkpoint.")
    ckpt_path = Path(args.ckpt)
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # Run-time options from the CLI, plus the CUDA device selected by --gpu.
    config = DotMap()
    config.update(vars(args))
    config.device = f"cuda:{args.gpu}"

    vae_config = OmegaConf.load(args.config)

    # Keep a copy of the model config next to the outputs for reproducibility.
    # pathlib handles separators portably (the old split('/') did not).
    save_dir = Path("./output/vae_reconstruction") / args.exp_name
    save_dir.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(vae_config, save_dir / Path(args.config).name)

    ds_module = ShapeNetDataModule2(
        "./data",
        target_categories=config.target_categories,
        batch_size=vae_config.data.batch_size,
        num_workers=vae_config.data.num_workers,
    )
    test_dl = ds_module.test_dataloader()
    val_dl = ds_module.val_dataloader()

    # Discriminator weights are skipped when loading the checkpoint.
    autoencoder = AutoencoderKL(
        ddconfig=vae_config.model.params.ddconfig,
        disc_config=vae_config.model.params.disc_config,
        kl_weight=vae_config.model.params.kl_weight,
        embed_dim=vae_config.model.params.embed_dim,
        learning_rate=vae_config.model.learning_rate,
        ckpt_path=args.ckpt,
        ignore_keys=["discriminator"],
    )
    autoencoder.to(config.device)
    autoencoder.eval()

    # NOTE(review): no torch.no_grad() here; presumably inference() disables
    # gradients itself — confirm.
    rec_data = autoencoder.inference(test_dl, val_dl)
    np.save(save_dir / "rec_data.npy", rec_data)
if __name__ == "__main__":
    # CLI: GPU index, optional category filter, and three required string options.
    cli = argparse.ArgumentParser()
    cli.add_argument("--gpu", type=int, default=0)
    cli.add_argument("--target_categories", type=str, default=None)
    for required_flag in ("--config", "--ckpt", "--exp_name"):
        cli.add_argument(required_flag, type=str, required=True)
    parsed = cli.parse_args()
    # Seed every RNG deterministically before any data or model work.
    seed_everything(0)
    main(parsed)