Commit: Citynerf (#26)
* citynerf

* remove .DS_Store

* remove print stage

* Update bungeenerf_multiscale_google.py

* remove useless codes

* fix stage bugs

* fix stage bugs

* fix a data directory bug
LOTEAT authored Jan 10, 2023
1 parent c42049b commit a6f2f6f
Showing 21 changed files with 1,091 additions and 10 deletions.
198 changes: 198 additions & 0 deletions configs/bungeenerf/bungeenerf_multiscale_google.py
@@ -0,0 +1,198 @@
_base_ = [
    # '../_base_/models/nerf.py',
    # '../_base_/schedules/adam_20w_iter.py',
    # '../_base_/default_runtime.py'
]

import os
from datetime import datetime

method = 'bungeenerf' # [nerf, kilo_nerf, mip_nerf, bungeenerf]

# optimizer
optimizer = dict(type='Adam', lr=5e-4, betas=(0.9, 0.999))
optimizer_config = dict(grad_clip=None)

max_iters = 200000
lr_config = dict(policy='step', step=500 * 1000, gamma=0.1, by_epoch=False)
checkpoint_config = dict(interval=500, by_epoch=False)
log_level = 'INFO'
log_config = dict(interval=5,
                  by_epoch=False,
                  hooks=[dict(type='TextLoggerHook')])
workflow = [('train', 500), ('val', 1)]

# hooks
# 'params' holds plain numeric values; 'variables' refers to variables in the local environment
train_hooks = [
    dict(type='SetValPipelineHook',
         params=dict(),
         variables=dict(valset='valset')),
    dict(type='ValidateHook',
         params=dict(save_folder='visualizations/validation')),
    dict(type='SaveSpiralHook',
         params=dict(save_folder='visualizations/spiral')),
    dict(type='PassIterHook', params=dict()),  # tell the dataset the current iteration number
    dict(type='OccupationHook',
         params=dict()),  # not needed for the open-source version
]

test_hooks = [
    dict(type='SetValPipelineHook',
         params=dict(),
         variables=dict(valset='testset')),
    dict(type='TestHook', params=dict()),
]

# runner
train_runner = dict(type='BungeeNerfTrainRunner')
test_runner = dict(type='BungeeNerfTestRunner')

# runtime settings
num_gpus = 1
distributed = (num_gpus > 1)  # whether to use multiple GPUs; mmcv's DataParallel support is poor, so use either a single GPU or DDP
stage = 0 # current stage for training
work_dir = './work_dirs/bungeenerf/#DATANAME#/stage_%d/' % stage
timestamp = datetime.now().strftime('%d-%b-%H-%M')

# shared params by model and data and ...
dataset_type = 'mutiscale_google'
no_batching = True # only take random rays from 1 image at a time

white_bkgd = False # set to render synthetic data on a white bkgd (always use for dvoxels)
is_perturb = False # set to 0. for no jitter, 1. for jitter
use_viewdirs = True # use full 5D input instead of 3D
N_rand_per_sampler = 1024 * 2  # number of random rays drawn per get_item() call
lindisp = False # sampling linearly in disparity rather than depth
N_samples = 65 # number of coarse samples per ray

# resume_from = os.path.join(work_dir, 'latest.pth')
load_from = os.path.join(work_dir, 'latest.pth')

model = dict(
    type='BungeeNerfNetwork',
    cfg=dict(
        phase='train',  # 'train' or 'test'
        ray_shape='cone',  # the shape of cast rays ('cone' or 'cylinder')
        resample_padding=0.01,  # Dirichlet/alpha "padding" on the histogram
        N_importance=65,  # number of additional fine samples per ray
        is_perturb=is_perturb,
        chunk=1024 * 32,  # mainly used for val
        bs_data='rays_o',  # this key's shape gives the real batch size, i.e. the number of rays
    ),

    mlp=dict(  # coarse model
        type='BungeeNerfMLP',
        cur_stage=stage,  # current stage (determines the number of residual blocks)
        netwidth=256,  # channels per layer
        netchunk=1024 * 64,  # number of pts sent through network in parallel
        embedder=dict(
            type='BungeeEmbedder',
            i_embed=0,  # set 0 for default positional encoding, -1 for none
            multires=10,  # log2 of max freq for positional encoding (3D location)
            multires_dirs=4,  # 'multires_views' in the original code; log2 of max freq for positional encoding (2D direction)
        ),
    ),

    render=dict(  # render model
        type='BungeeNerfRender',
        white_bkgd=white_bkgd,  # set to render synthetic data on a white bkgd (always use for dvoxels)
        raw_noise_std=0,  # std dev of noise added to regularize sigma_a output; 1e0 recommended
    ),
)

basedata_cfg = dict(
    dataset_type=dataset_type,
    datadir='data/multiscale_google/#DATANAME#',
    white_bkgd=white_bkgd,
    factor=3,
    N_rand_per_sampler=N_rand_per_sampler,
    mode='train',
    cur_stage=stage,
    holdout=16,
    is_batching=True,  # True for blender, False for llff
)

traindata_cfg = basedata_cfg.copy()
valdata_cfg = basedata_cfg.copy()
testdata_cfg = basedata_cfg.copy()

traindata_cfg.update(dict())
valdata_cfg.update(dict(mode='val'))
testdata_cfg.update(dict(mode='test', testskip=0))

train_pipeline = [
    dict(
        type='BungeeBatchSample',
        enable=True,
        N_rand=N_rand_per_sampler,
    ),
    dict(type='DeleteUseless', keys=['rays_rgb', 'idx']),
    dict(
        type='ToTensor',
        enable=True,
        keys=['rays_o', 'rays_d', 'target_s', 'scale_code'],
    ),
    dict(
        type='GetViewdirs',
        enable=use_viewdirs,
    ),
    dict(type='BungeeGetBounds', enable=True),
    dict(type='BungeeGetZvals', enable=True, lindisp=lindisp,
         N_samples=N_samples),  # N_samples: number of coarse samples per ray
    dict(type='PerturbZvals', enable=is_perturb),
    dict(type='DeleteUseless', enable=True,
         keys=['pose', 'iter_n']),  # delete 'pose'; it is no longer needed once rays are computed
]

test_pipeline = [
    dict(
        type='ToTensor',
        enable=True,
        keys=['pose'],
    ),
    dict(
        type='GetRays',
        include_radius=True,
        enable=True,
    ),
    dict(type='FlattenRays',
         include_radius=True,
         enable=True),  # reshape (H, W, ...) to (H*W, ...) and record the original size
    dict(
        type='GetViewdirs',
        enable=use_viewdirs,
    ),
    dict(type='BungeeGetBounds', enable=True),
    dict(type='BungeeGetZvals', enable=True, lindisp=lindisp,
         N_samples=N_samples),  # same as in train_pipeline
    dict(type='PerturbZvals', enable=False),  # no perturbation on the test set
    dict(type='DeleteUseless', enable=True,
         keys=['pose']),  # delete 'pose'; it is no longer needed once rays are computed
]

data = dict(
    train_loader=dict(batch_size=1, num_workers=4),
    train=dict(
        type='BungeeDataset',
        cfg=traindata_cfg,
        pipeline=train_pipeline,
    ),
    val_loader=dict(batch_size=1, num_workers=0),
    val=dict(
        type='BungeeDataset',
        cfg=valdata_cfg,
        pipeline=test_pipeline,
    ),
    test_loader=dict(batch_size=1, num_workers=0),
    test=dict(
        type='BungeeDataset',
        cfg=testdata_cfg,
        pipeline=test_pipeline,  # same pipeline as validation
    ),
)
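For reference, a minimal sketch of how a config like this is typically consumed in an mmcv-based project. Config.fromfile is standard mmcv; the '#DATANAME#' substitution step and the scene name '56Leonard' are assumptions based on the placeholders above, not part of this commit:

# Hypothetical usage sketch (assumption: the training script fills in
# the '#DATANAME#' placeholder with a concrete scene name).
from mmcv import Config

cfg = Config.fromfile('configs/bungeenerf/bungeenerf_multiscale_google.py')
dataname = '56Leonard'  # example scene name, assumed
cfg.work_dir = cfg.work_dir.replace('#DATANAME#', dataname)
cfg.data.train.cfg.datadir = cfg.data.train.cfg.datadir.replace('#DATANAME#', dataname)
print(cfg.method, cfg.stage, cfg.work_dir)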
3 changes: 3 additions & 0 deletions xrnerf/core/runner/__init__.py
@@ -1,11 +1,14 @@
from .base import NerfTestRunner, NerfTrainRunner
from .kilonerf_runner import (KiloNerfDistillTrainRunner, KiloNerfTestRunner,
                              KiloNerfTrainRunner)
+from .bungeenerf_runner import BungeeNerfTrainRunner, BungeeNerfTestRunner

__all__ = [
    'NerfTrainRunner',
    'NerfTestRunner',
    'KiloNerfDistillTrainRunner',
    'KiloNerfTrainRunner',
    'KiloNerfTestRunner',
+    'BungeeNerfTrainRunner',
+    'BungeeNerfTestRunner',
]
40 changes: 40 additions & 0 deletions xrnerf/core/runner/bungeenerf_runner.py
@@ -0,0 +1,40 @@

import torch
from mmcv.runner import EpochBasedRunner, IterBasedRunner


class BungeeNerfTrainRunner(IterBasedRunner):
    """Iter-based runner that steps through BungeeNerf's progressive stages."""
    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        data_batch = next(data_loader)
        self.data_batch = data_batch
        scale_code = data_batch['scale_code']
        # run one train step per stage present in this batch
        for stage in range(int(torch.max(scale_code) + 1)):
            kwargs['stage'] = stage
            self.call_hook('before_train_iter')
            outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('model.train_step() must return a dict')
            if 'log_vars' in outputs:
                if outputs['log_vars']['loss'] == 0.:
                    continue  # no rays for this stage in the batch; skip logging
                self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
                self.log_buffer.output['stage'] = stage
            self.outputs = outputs
            self.call_hook('after_train_iter')
        del self.data_batch
        self._inner_iter += 1
        self._iter += 1


class BungeeNerfTestRunner(EpochBasedRunner):
    """BungeeNerfTestRunner"""
    pass
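The per-stage loop above relies on model.train_step returning a loss of exactly 0 when the batch contains no rays for a given stage. A rough illustration of such stage masking with an MSE photometric loss (hypothetical helper, not from this commit):

import torch

def stage_mse_loss(pred, target, scale_code, stage):
    # Keep only rays whose scale code matches the current stage; if none
    # match, return exactly 0 so the runner above can skip logging.
    mask = (scale_code.squeeze(-1) == stage)
    if mask.sum() == 0:
        return pred.new_tensor(0.0)
    return ((pred[mask] - target[mask]) ** 2).mean()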

5 changes: 3 additions & 2 deletions xrnerf/datasets/__init__.py
@@ -8,10 +8,11 @@
from .samplers import DistributedSampler
from .scene_dataset import SceneBaseDataset
from .genebody_dataset import GeneBodyDataset
+from .bungee_dataset import BungeeDataset

__all__ = [
    'SceneBaseDataset', 'DATASETS', 'build_dataset', 'DistributedSampler',
    'MipMultiScaleDataset', 'KiloNerfDataset', 'KiloNerfNodeDataset',
-    'NeuralBodyDataset', 'AniNeRFDataset', 'HashNerfDataset', 'GeneBodyDataset'
+    'NeuralBodyDataset', 'AniNeRFDataset', 'HashNerfDataset', 'GeneBodyDataset',
+    'BungeeDataset'
]
75 changes: 75 additions & 0 deletions xrnerf/datasets/bungee_dataset.py
@@ -0,0 +1,75 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .builder import DATASETS
from .scene_dataset import SceneBaseDataset
from .load_data import load_data, load_rays_bungee


@DATASETS.register_module()
class BungeeDataset(SceneBaseDataset):
    def __init__(self, cfg, pipeline):
        self.cur_stage = cfg.cur_stage
        super().__init__(cfg, pipeline)

    def _init_load(self):  # load dataset on init
        (self.images, self.poses, self.render_poses, self.hwf, self.K,
         self.scene_scaling_factor, self.scene_origin, self.scale_split,
         self.i_train, self.i_val, self.i_test,
         self.n_images) = load_data(self.cfg)

        if self.is_batching and self.mode == 'train':
            # for a batching dataset, rays must be computed at init()
            self.N_rand = self.cfg.N_rand_per_sampler
            self.rays_rgb, self.radii, self.scale_codes = load_rays_bungee(
                self.hwf[0], self.hwf[1], self.hwf[2], self.poses,
                self.images, self.i_train, self.n_images, self.scale_split,
                self.cur_stage)

    def _fetch_train_data(self, idx):
        if self.is_batching:  # for batching, rays are randomly selected from all images
            data = {
                'rays_rgb': self.rays_rgb,
                'radii': self.radii,
                'scale_code': self.scale_codes,
                'idx': idx
            }
        else:  # for no_batching, rays are selected from one image at a time
            data = {
                'poses': self.poses,
                'images': self.images,
                'n_images': self.n_images,
                'i_data': self.i_train,
                'idx': idx
            }
        data['iter_n'] = self.iter_n
        return data

    def _fetch_val_data(self, idx):  # for val mode, fetch all data at once
        data = {
            'spiral_poses': self.render_poses,
            'poses': self.poses[self.i_test],
            'images': self.images[self.i_test],
        }
        return data

    def _fetch_test_data(self, idx):  # unlike val, test returns one image at a time
        data = {
            'pose': self.poses[self.i_test][idx],
            'image': self.images[self.i_test][idx],
            'idx': idx
        }
        return data

    def get_info(self):
        res = {
            'H': self.hwf[0],
            'W': self.hwf[1],
            'focal': self.hwf[2],
            'K': self.K,
            'render_poses': self.render_poses,
            'hwf': self.hwf,
            'cur_stage': self.cur_stage,
            'scene_origin': self.scene_origin,
            'scene_scaling_factor': self.scene_scaling_factor,
            'scale_split': self.scale_split,
        }
        return res
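As a usage illustration: the class registers itself via @DATASETS.register_module() and is built through the registry. build_dataset is exported by xrnerf.datasets (see the __init__ diff above); the cfg object and attribute access below follow the config file earlier and are assumptions, not code from this commit:

from xrnerf.datasets import build_dataset

# Hypothetical usage, mirroring the 'data' section of the config above.
trainset = build_dataset(cfg.data.train)
info = trainset.get_info()  # H, W, focal, K, scale_split, ...
print(info['cur_stage'], info['scale_split'])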

3 changes: 2 additions & 1 deletion xrnerf/datasets/load_data/__init__.py
@@ -1,5 +1,5 @@
from .get_rays import (get_rays_np, load_rays, load_rays_hash,
-                       load_rays_multiscale)
+                       load_rays_multiscale, load_rays_bungee)
from .load import load_data

__all__ = [
@@ -8,4 +8,5 @@
    'load_rays',
    'load_rays_hash',
    'load_rays_multiscale',
+    'load_rays_bungee'
]
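For context, bungee_dataset.py above calls the newly exported loader as follows; the argument order is taken from that call site, while the commented return shapes are assumptions:

# Assumed semantics, inferred from how the results are used:
#   rays_rgb    -- per-pixel [origin, direction, rgb] data (assumption)
#   radii       -- per-ray cone radii for mip-style sampling (assumption)
#   scale_codes -- integer scale/stage id per ray (assumption)
H, W, focal = hwf
rays_rgb, radii, scale_codes = load_rays_bungee(
    H, W, focal, poses, images, i_train, n_images, scale_split, cur_stage)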
(The remaining 15 changed files are not shown here.)
