-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
70 lines (60 loc) · 2.11 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os

import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from render import render_rays
def train(
    model,
    optimizer,
    dataloader,
    device="cpu",
    hn=0,
    hf=1,
    nb_epochs=10,
    nb_bins=192,
    H=400,
    W=400,
):
    """Optimise ``model`` with a per-pixel mean-squared-error loss on rendered rays.

    Each batch is split column-wise into ray origins (columns 0-2), ray
    directions (columns 3-5) and ground-truth pixel values (columns 6+).
    Predicted pixel values come from ``render_rays``. After every epoch the
    whole model object is saved to ``models/ngp_model``.

    Args:
        model: network passed to ``render_rays``; put into training mode here.
        optimizer: optimizer over ``model``'s parameters.
        dataloader: sized iterable of batches. Presumably each batch is a 2-D
            tensor of shape (batch, >= 7); TODO confirm against the dataset.
        device: device each batch is moved to; the model is moved back here
            after every checkpoint.
        hn: near bound forwarded to ``render_rays``.
        hf: far bound forwarded to ``render_rays``.
        nb_epochs: number of full passes over ``dataloader``.
        nb_bins: sample count forwarded to ``render_rays``.
        H: unused; kept so existing callers keep working.
        W: unused; kept so existing callers keep working.
    """
    # torch.save does not create parent directories.
    os.makedirs("models", exist_ok=True)
    model.train()
    for epoch in range(nb_epochs):
        # Build a fresh progress bar every epoch. A single bar wrapping one
        # iterator created before this loop is exhausted after the first
        # epoch, so every later epoch would silently train on nothing.
        progress_bar = tqdm(dataloader, total=len(dataloader))
        progress_bar.set_description(f"training epoch: {epoch}")
        for batch in progress_bar:
            ray_origins = batch[:, :3].to(device)
            ray_directions = batch[:, 3:6].to(device)
            gt_px_values = batch[:, 6:].to(device)
            pred_px_values = render_rays(
                model, ray_origins, ray_directions, hn, hf, nb_bins
            )
            loss = ((gt_px_values - pred_px_values) ** 2).mean()
            progress_bar.set_postfix({"loss": loss.item()})
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Move to CPU before saving so the pickled tensors are CPU tensors,
        # then move the model back to continue training on ``device``.
        torch.save(model.cpu(), "models/ngp_model")
        model.to(device)
@torch.inference_mode()
def test(
    model, device, hn, hf, dataset, img_index, chunk_size=20, nb_bins=192, H=400, W=400
):
    """Render view ``img_index`` and save it as ``novel_views/img_{img_index}.png``.

    The rays for image ``img_index`` are the contiguous rows
    ``[img_index * H * W, (img_index + 1) * H * W)`` of ``dataset``. Columns
    0-2 hold ray origins and columns 3-5 hold ray directions. Rays are passed
    to ``render_rays`` in chunks of ``chunk_size`` image rows (``W *
    chunk_size`` rays per call). The output is reshaped to (H, W, 3), clipped
    to [0, 1] and written as an 8-bit PNG.

    The model is switched to eval mode for rendering. Its previous
    train/eval mode is restored afterwards, even if rendering raises.

    Args:
        model: network passed to ``render_rays``.
        device: device each chunk of rays is moved to before rendering.
        hn: near bound forwarded to ``render_rays``.
        hf: far bound forwarded to ``render_rays``.
        dataset: 2-D tensor of rays (presumably also a torch tensor, since
            slices are moved with ``.to``; TODO confirm).
        img_index: index of the image whose rays are rendered.
        chunk_size: number of image rows rendered per ``render_rays`` call.
        nb_bins: sample count forwarded to ``render_rays``.
        H: image height in pixels.
        W: image width in pixels.
    """
    # PIL's save does not create parent directories.
    os.makedirs("novel_views", exist_ok=True)
    was_training = model.training
    model.eval()
    try:
        start, stop = img_index * H * W, (img_index + 1) * H * W
        ray_origins = dataset[start:stop, :3]
        # Directions are columns 3-5 only; ``:6`` would also include the origins.
        ray_directions = dataset[start:stop, 3:6]

        rays_per_chunk = W * chunk_size
        px_values = []  # rendered pixel values, one tensor per chunk
        for i in range(int(np.ceil(H / chunk_size))):
            rows = slice(i * rays_per_chunk, (i + 1) * rays_per_chunk)
            px_values.append(
                render_rays(
                    model,
                    ray_origins[rows].to(device),
                    ray_directions[rows].to(device),
                    hn,
                    hf,
                    nb_bins,
                )
            )
        img = torch.cat(px_values).cpu().numpy().reshape(H, W, 3)
    finally:
        model.train(was_training)

    img = (img.clip(0, 1) * 255.0).astype(np.uint8)
    Image.fromarray(img).save(f"novel_views/img_{img_index}.png")