diff --git a/configs/bungeenerf/bungeenerf_multiscale_google.py b/configs/bungeenerf/bungeenerf_multiscale_google.py
new file mode 100644
index 0000000..98cf59d
--- /dev/null
+++ b/configs/bungeenerf/bungeenerf_multiscale_google.py
@@ -0,0 +1,198 @@
+_base_ = [
+    # '../_base_/models/nerf.py',
+    # '../_base_/schedules/adam_20w_iter.py',
+    # '../_base_/default_runtime.py'
+]
+
+import os
+from datetime import datetime
+
+method = 'bungeenerf'  # [nerf, kilo_nerf, mip_nerf, bungeenerf]
+
+# optimizer
+optimizer = dict(type='Adam', lr=5e-4, betas=(0.9, 0.999))
+optimizer_config = dict(grad_clip=None)
+
+max_iters = 200000
+lr_config = dict(policy='step', step=500 * 1000, gamma=0.1, by_epoch=False)
+checkpoint_config = dict(interval=500, by_epoch=False)
+log_level = 'INFO'
+log_config = dict(interval=5,
+                  by_epoch=False,
+                  hooks=[dict(type='TextLoggerHook')])
+workflow = [('train', 500), ('val', 1)]
+
+# hooks
+# 'params' are plain numeric values; 'variables' refer to variables in the local environment
+train_hooks = [
+    dict(type='SetValPipelineHook',
+         params=dict(),
+         variables=dict(valset='valset')),
+    dict(type='ValidateHook',
+         params=dict(save_folder='visualizations/validation')),
+    dict(type='SaveSpiralHook',
+         params=dict(save_folder='visualizations/spiral')),
+    dict(type='PassIterHook', params=dict()),  # pass the current iter number to the dataset
+    dict(type='OccupationHook',
+         params=dict()),  # not needed for the open-source version
+]
+
+test_hooks = [
+    dict(type='SetValPipelineHook',
+         params=dict(),
+         variables=dict(valset='testset')),
+    dict(type='TestHook', params=dict()),
+]
+
+# runner
+train_runner = dict(type='BungeeNerfTrainRunner')
+test_runner = dict(type='BungeeNerfTestRunner')
+
+# runtime settings
+num_gpus = 1
+distributed = (num_gpus > 1)  # mmcv's DataParallel support is poor, so train on either a single GPU or DDP
+stage = 0  # current stage for training
+work_dir = './work_dirs/bungeenerf/#DATANAME#/stage_%d/' % stage
+timestamp = datetime.now().strftime('%d-%b-%H-%M')
+
+# params shared by model, data, etc.
+dataset_type = 'mutiscale_google'
+no_batching = True  # only take random rays from one image at a time
+
+white_bkgd = False  # set to render synthetic data on a white background (always use for dvoxels)
+is_perturb = False  # False for no jitter, True for jitter
+use_viewdirs = True  # use full 5D input instead of 3D
+N_rand_per_sampler = 1024 * 2  # number of rays drawn per get_item() call
+lindisp = False  # sample linearly in disparity rather than depth
+N_samples = 65  # number of coarse samples per ray
+
+# resume_from = os.path.join(work_dir, 'latest.pth')
+load_from = os.path.join(work_dir, 'latest.pth')
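# Sketch (not part of the patch): BungeeNerf trains progressively, so stage N
# is usually warm-started from the stage N-1 checkpoint. Assuming only the
# work_dir layout above, that could look like:
# if stage > 0:
#     prev_dir = './work_dirs/bungeenerf/#DATANAME#/stage_%d/' % (stage - 1)
#     load_from = os.path.join(prev_dir, 'latest.pth')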
+
+model = dict(
+    type='BungeeNerfNetwork',
+    cfg=dict(
+        phase='train',  # 'train' or 'test'
+        ray_shape='cone',  # the shape of cast rays ('cone' or 'cylinder')
+        resample_padding=0.01,  # Dirichlet/alpha "padding" on the histogram
+        N_importance=65,  # number of additional fine samples per ray
+        is_perturb=is_perturb,
+        chunk=1024 * 32,  # mainly used for validation
+        bs_data='rays_o',  # this key's shape[0] is the real batch size, i.e. the number of rays
+    ),
+    mlp=dict(  # coarse model
+        type='BungeeNerfMLP',
+        cur_stage=stage,  # current stage; determines the number of residual blocks
+        netwidth=256,  # channels per layer
+        netchunk=1024 * 64,  # number of pts sent through the network in parallel
+        embedder=dict(
+            type='BungeeEmbedder',
+            i_embed=0,  # set 0 for default positional encoding, -1 for none
+            multires=10,  # log2 of max freq for positional encoding (3D location)
+            multires_dirs=4,  # 'multires_views' in the original code; log2 of max freq for positional encoding (2D direction)
+        ),
+    ),
+    render=dict(  # render model
+        type='BungeeNerfRender',
+        white_bkgd=white_bkgd,  # set to render synthetic data on a white background (always use for dvoxels)
+        raw_noise_std=0,  # std dev of noise added to regularize sigma_a output, 1e0 recommended
+    ),
+)
+
+basedata_cfg = dict(
+    dataset_type=dataset_type,
+    datadir='data/multiscale_google/#DATANAME#',
+    white_bkgd=white_bkgd,
+    factor=3,
+    N_rand_per_sampler=N_rand_per_sampler,
+    mode='train',
+    cur_stage=stage,
+    holdout=16,
+    is_batching=True,  # True for blender, False for llff
+)
+
+traindata_cfg = basedata_cfg.copy()
+valdata_cfg = basedata_cfg.copy()
+testdata_cfg = basedata_cfg.copy()
+
+traindata_cfg.update(dict())
+valdata_cfg.update(dict(mode='val'))
+testdata_cfg.update(dict(mode='test', testskip=0))
+
+train_pipeline = [
+    dict(
+        type='BungeeBatchSample',
+        enable=True,
+        N_rand=N_rand_per_sampler,
+    ),
+    dict(type='DeleteUseless', keys=['rays_rgb', 'idx']),
+    dict(
+        type='ToTensor',
+        enable=True,
+        keys=['rays_o', 'rays_d', 'target_s', 'scale_code'],
+    ),
+    dict(
+        type='GetViewdirs',
+        enable=use_viewdirs,
+    ),
+    dict(type='BungeeGetBounds', enable=True),
+    dict(type='BungeeGetZvals', enable=True, lindisp=lindisp,
+         N_samples=N_samples),  # N_samples: number of coarse samples per ray
+    dict(type='PerturbZvals', enable=is_perturb),
+    dict(type='DeleteUseless', enable=True,
+         keys=['pose', 'iter_n']),  # drop 'pose'; it is no longer needed once rays are computed
+]
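# Rough key flow through train_pipeline (transform names from the list above;
# shapes assume N = N_rand_per_sampler rays):
#   BungeeBatchSample : rays_rgb [N, 3, 3] -> rays_o / rays_d / target_s [N, 3],
#                       plus slices radii [N, 1] and scale_code [N, 1]
#   GetViewdirs       : rays_d -> viewdirs (unit directions)
#   BungeeGetBounds   : rays_o, viewdirs -> near / far [N, 1] via scene geometry
#   BungeeGetZvals    : near, far -> z_vals [N, N_samples]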
+
+test_pipeline = [
+    dict(
+        type='ToTensor',
+        enable=True,
+        keys=['pose'],
+    ),
+    dict(
+        type='GetRays',
+        include_radius=True,
+        enable=True,
+    ),
+    dict(type='FlattenRays',
+         include_radius=True,
+         enable=True),  # reshape (H, W, ...) to (H*W, ...) and record the original shape
+    dict(
+        type='GetViewdirs',
+        enable=use_viewdirs,
+    ),
+    dict(type='BungeeGetBounds', enable=True),
+    dict(type='BungeeGetZvals', enable=True, lindisp=lindisp,
+         N_samples=N_samples),  # same as in train_pipeline
+    dict(type='PerturbZvals', enable=False),  # no jitter on the test set
+    dict(type='DeleteUseless', enable=True,
+         keys=['pose']),  # drop 'pose'; it is no longer needed once rays are computed
+]
+
+data = dict(
+    train_loader=dict(batch_size=1, num_workers=4),
+    train=dict(
+        type='BungeeDataset',
+        cfg=traindata_cfg,
+        pipeline=train_pipeline,
+    ),
+    val_loader=dict(batch_size=1, num_workers=0),
+    val=dict(
+        type='BungeeDataset',
+        cfg=valdata_cfg,
+        pipeline=test_pipeline,
+    ),
+    test_loader=dict(batch_size=1, num_workers=0),
+    test=dict(
+        type='BungeeDataset',
+        cfg=testdata_cfg,
+        pipeline=test_pipeline,  # same pipeline as validation
+    ),
+)
diff --git a/xrnerf/core/runner/__init__.py b/xrnerf/core/runner/__init__.py
index 0ad49cd..85e6687 100644
--- a/xrnerf/core/runner/__init__.py
+++ b/xrnerf/core/runner/__init__.py
@@ -1,6 +1,7 @@
 from .base import NerfTestRunner, NerfTrainRunner
 from .kilonerf_runner import (KiloNerfDistillTrainRunner, KiloNerfTestRunner,
                               KiloNerfTrainRunner)
+from .bungeenerf_runner import BungeeNerfTrainRunner, BungeeNerfTestRunner
 
 __all__ = [
     'NerfTrainRunner',
@@ -8,4 +9,6 @@
     'KiloNerfDistillTrainRunner',
     'KiloNerfTrainRunner',
     'KiloNerfTestRunner',
+    'BungeeNerfTrainRunner',
+    'BungeeNerfTestRunner',
 ]
diff --git a/xrnerf/core/runner/bungeenerf_runner.py b/xrnerf/core/runner/bungeenerf_runner.py
new file mode 100644
index 0000000..329c81e
--- /dev/null
+++ b/xrnerf/core/runner/bungeenerf_runner.py
@@ -0,0 +1,40 @@
+import time
+import warnings
+
+import mmcv
+import torch
+from mmcv.runner import EpochBasedRunner, IterBasedRunner
+from mmcv.runner.utils import get_host_info
+
+
+class BungeeNerfTrainRunner(IterBasedRunner):
+    def train(self, data_loader, **kwargs):
+        self.model.train()
+        self.mode = 'train'
+        self.data_loader = data_loader
+        self._epoch = data_loader.epoch
+        data_batch = next(data_loader)
+        self.data_batch = data_batch
+        scale_code = data_batch['scale_code']
+        # run one optimisation step per scale stage present in this batch
+        for stage in range(int(torch.max(scale_code) + 1)):
+            kwargs['stage'] = stage
+            self.call_hook('before_train_iter')
+            outputs = self.model.train_step(data_batch, self.optimizer,
+                                            **kwargs)
+            if not isinstance(outputs, dict):
+                raise TypeError('model.train_step() must return a dict')
+            if 'log_vars' in outputs:
+                if outputs['log_vars']['loss'] == 0.:
+                    continue
+                self.log_buffer.update(outputs['log_vars'],
+                                       outputs['num_samples'])
+                self.log_buffer.output['stage'] = stage
+            self.outputs = outputs
+            self.call_hook('after_train_iter')
+        del self.data_batch
+        self._inner_iter += 1
+        self._iter += 1
+
+
+class BungeeNerfTestRunner(EpochBasedRunner):
+    """BungeeNerfTestRunner"""
+    pass
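# Why the loss == 0 skip in BungeeNerfTrainRunner.train above: with
# progressive training, a batch may contain no rays whose scale_code matches a
# given stage, in which case the masked loss (see train_step in
# xrnerf/models/networks/bungeenerf.py below) is exactly zero and there is
# nothing to log. A toy trace of the loop, with assumed scale codes:
import torch

scale_code = torch.tensor([0, 0, 1, 2])          # per-ray scale labels
for stage in range(int(torch.max(scale_code) + 1)):
    active = (scale_code <= stage).sum().item()  # rays supervised at this stage
    print(stage, active)                         # -> (0, 2), (1, 3), (2, 4)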
diff --git a/xrnerf/datasets/__init__.py b/xrnerf/datasets/__init__.py
index 943a6d9..5781914 100644
--- a/xrnerf/datasets/__init__.py
+++ b/xrnerf/datasets/__init__.py
@@ -8,10 +8,11 @@
 from .samplers import DistributedSampler
 from .scene_dataset import SceneBaseDataset
 from .genebody_dataset import GeneBodyDataset
+from .bungee_dataset import BungeeDataset
 
 __all__ = [
     'SceneBaseDataset', 'DATASETS', 'build_dataset', 'DistributedSampler',
     'MipMultiScaleDataset', 'KiloNerfDataset', 'KiloNerfNodeDataset',
-    'NeuralBodyDataset', 'AniNeRFDataset', 'HashNerfDataset', 'GeneBodyDataset'
-
+    'NeuralBodyDataset', 'AniNeRFDataset', 'HashNerfDataset', 'GeneBodyDataset',
+    'BungeeDataset'
 ]
diff --git a/xrnerf/datasets/bungee_dataset.py b/xrnerf/datasets/bungee_dataset.py
new file mode 100644
index 0000000..0bb7f36
--- /dev/null
+++ b/xrnerf/datasets/bungee_dataset.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import numpy as np
+import torch
+
+from .builder import DATASETS
+from .scene_dataset import SceneBaseDataset
+from .load_data import load_data, load_rays_bungee
+
+
+@DATASETS.register_module()
+class BungeeDataset(SceneBaseDataset):
+    def __init__(self, cfg, pipeline):
+        self.cur_stage = cfg.cur_stage
+        super().__init__(cfg, pipeline)
+
+    def _init_load(self):  # load the dataset at init time
+        (self.images, self.poses, self.render_poses, self.hwf, self.K,
+         self.scene_scaling_factor, self.scene_origin, self.scale_split,
+         self.i_train, self.i_val, self.i_test,
+         self.n_images) = load_data(self.cfg)
+
+        if self.is_batching and self.mode == 'train':
+            # for a batching dataset, rays must be computed in init()
+            self.N_rand = self.cfg.N_rand_per_sampler
+            self.rays_rgb, self.radii, self.scale_codes = load_rays_bungee(
+                self.hwf[0], self.hwf[1], self.hwf[2], self.poses,
+                self.images, self.i_train, self.n_images, self.scale_split,
+                self.cur_stage)
+
+    def _fetch_train_data(self, idx):
+        if self.is_batching:  # for a batching dataset, rays are randomly selected from all images
+            data = {
+                'rays_rgb': self.rays_rgb,
+                'radii': self.radii,
+                'scale_code': self.scale_codes,
+                'idx': idx
+            }
+        else:  # for a non-batching dataset, rays are selected from one image
+            data = {
+                'poses': self.poses,
+                'images': self.images,
+                'n_images': self.n_images,
+                'i_data': self.i_train,
+                'idx': idx
+            }
+        data['iter_n'] = self.iter_n
+        return data
+
+    def _fetch_val_data(self, idx):  # for val mode, fetch all data at once
+        data = {
+            'spiral_poses': self.render_poses,
+            'poses': self.poses[self.i_test],
+            'images': self.images[self.i_test],
+        }
+        return data
+
+    def _fetch_test_data(self, idx):  # unlike val, test returns one image at a time
+        data = {
+            'pose': self.poses[self.i_test][idx],
+            'image': self.images[self.i_test][idx],
+            'idx': idx
+        }
+        return data
+
+    def get_info(self):
+        res = {
+            'H': self.hwf[0],
+            'W': self.hwf[1],
+            'focal': self.hwf[2],
+            'K': self.K,
+            'render_poses': self.render_poses,
+            'hwf': self.hwf,
+            'cur_stage': self.cur_stage,
+            'scene_origin': self.scene_origin,
+            'scene_scaling_factor': self.scene_scaling_factor,
+            'scale_split': self.scale_split,
+        }
+        return res
diff --git a/xrnerf/datasets/load_data/__init__.py b/xrnerf/datasets/load_data/__init__.py
index 2598dcb..54ac225 100644
--- a/xrnerf/datasets/load_data/__init__.py
+++ b/xrnerf/datasets/load_data/__init__.py
@@ -1,5 +1,5 @@
 from .get_rays import (get_rays_np, load_rays, load_rays_hash,
-                       load_rays_multiscale)
+                       load_rays_multiscale, load_rays_bungee)
 from .load import load_data
 
 __all__ = [
@@ -8,4 +8,5 @@
     'load_rays',
     'load_rays_hash',
     'load_rays_multiscale',
+    'load_rays_bungee',
 ]
diff --git a/xrnerf/datasets/load_data/get_rays.py b/xrnerf/datasets/load_data/get_rays.py
index d333ecf..37689fd 100644
--- a/xrnerf/datasets/load_data/get_rays.py
+++ b/xrnerf/datasets/load_data/get_rays.py
@@ -1,5 +1,5 @@
 import numpy as np
-
+import torch
 
 def get_rays_np(H, W, K, c2w):
     i, j = np.meshgrid(np.arange(W, dtype=np.float32),
@@ -150,3 +150,48 @@ def broadcast_scalar_attribute(x):
                 near=near,
                 far=far)
     return rays
+
+
+def get_rays_np_bungee(H, W, focal, c2w):
+    i, j = np.meshgrid(np.arange(W, dtype=np.float32),
+                       np.arange(H, dtype=np.float32),
+                       indexing='xy')
+    dirs = np.stack(
+        [(i - W * .5) / focal, -(j - H * .5) / focal, -np.ones_like(i)], -1)
+    dirs = dirs / np.linalg.norm(dirs, axis=-1)[..., None]
+    # rotate ray directions from the camera frame to the world frame
+    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
+    rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d))
+    return rays_o, rays_d
+
+
+def load_rays_bungee(H, W, focal, poses, images, i_data, n_images,
+                     scale_split, cur_stage):
+    # build per-ray scale codes: images are grouped by scale_split and each
+    # group gets an increasing integer code
+    scale_codes = []
+    prev_spl = n_images
+    cur_scale = 0
+    for spl in scale_split[:cur_stage + 1]:
+        scale_codes.append(
+            np.tile(np.ones(((prev_spl - spl), 1, 1, 1)) * cur_scale,
+                    (1, H, W, 1)))
+        prev_spl = spl
+        cur_scale += 1
+    scale_codes = np.concatenate(scale_codes, 0)
+    scale_codes = scale_codes.astype(np.int64)
+    # [N, ro+rd, H, W, 3]
+    rays = np.stack([get_rays_np_bungee(H, W, focal, p) for p in poses], 0)
+    # mip-NeRF style base radii from the distance between neighbouring rays
+    directions = rays[:, 1, :, :, :]
+    dx = np.sqrt(
+        np.sum((directions[:, :-1, :, :] - directions[:, 1:, :, :])**2, -1))
+    dx = np.concatenate([dx, dx[:, -2:-1, :]], 1)
+    radii = dx[..., None] * 2 / np.sqrt(12)
+
+    # [N, ro+rd+rgb, H, W, 3]
+    rays_rgb = np.concatenate([rays, images[:, None]], 1)
+    rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])
+    rays_rgb = np.stack([rays_rgb[i] for i in i_data], 0)
+    radii = np.stack([radii[i] for i in i_data], 0)
+    scale_codes = np.stack([scale_codes[i] for i in i_data], 0)
+
+    rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])
+    radii = np.reshape(radii, [-1, 1])
+    scale_codes = np.reshape(scale_codes, [-1, 1])
+
+    # shuffle all rays once so that sequential batch slices are random
+    rand_idx = torch.randperm(rays_rgb.shape[0])
+    rays_rgb = rays_rgb[rand_idx.cpu().data.numpy()]
+    radii = radii[rand_idx.cpu().data.numpy()]
+    scale_codes = scale_codes[rand_idx.cpu().data.numpy()]
+    return rays_rgb, radii, scale_codes
diff --git a/xrnerf/datasets/load_data/load.py b/xrnerf/datasets/load_data/load.py
index c113f77..08c5d94 100644
--- a/xrnerf/datasets/load_data/load.py
+++ b/xrnerf/datasets/load_data/load.py
@@ -6,6 +6,9 @@
 from .load_llff import load_llff_data
 from .load_multiscale import load_multiscale_data
 from .load_nsvf_dataset import load_nsvf_dataset
+from .load_multiscale_google import load_google_data
+
+
 
 def load_data(args):
@@ -141,6 +144,34 @@
         if render_subset != 'custom_path':
             render_poses = np.array(poses[i_render])
 
+    elif args.dataset_type == 'mutiscale_google':
+        images, poses, scene_scale, scene_origin, scale_split = \
+            load_google_data(args.datadir, args.factor)
+        n_images = len(images)
+        print('Load Multiscale Google', n_images)
+        if args.white_bkgd:
+            images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
+        else:
+            images = images[..., :3]
+        # keep only the images belonging to stages <= cur_stage
+        images = images[scale_split[args.cur_stage]:]
+        poses = poses[scale_split[args.cur_stage]:]
+
+        if args.holdout > 0:
+            i_test = np.arange(images.shape[0])[::args.holdout]
+        i_val = i_test
+        i_train = np.array([
+            i for i in np.arange(int(images.shape[0])) if i not in i_test
+        ])
+
+        hwf = poses[0, :3, -1]
+        poses = poses[:, :3, :4]
+        H, W, focal = hwf
+        H, W = int(H), int(W)
+        hwf = [H, W, focal]
+        K = np.array([[focal, 0, 0.5 * W], [0, focal, 0.5 * H], [0, 0, 1]])
+        render_poses = np.array(poses[i_test])
+        return (images, poses, render_poses, hwf, K, scene_scale,
+                scene_origin, scale_split, i_train, i_val, i_test, n_images)
+
     else:
         print('Unknown dataset type', args.dataset_type, 'exiting')
         return
@@ -158,3 +189,4 @@
     #     exit(0)
 
     return images, poses, render_poses, hwf, K, near, far, i_train, i_val, i_test
+
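# How scale_split partitions the image list, on assumed toy numbers (the real
# values come from poses_enu.json, loaded below): with n_images = 100 and
# scale_split = [60, 30, 0], the first group of 100 - 60 = 40 images gets
# scale code 0, the next 30 get code 1, and so on; training at cur_stage
# keeps images[scale_split[cur_stage]:], i.e. every scale unlocked so far.
import numpy as np

n_images, scale_split, cur_stage = 100, [60, 30, 0], 1
prev, codes = n_images, []
for scale, spl in enumerate(scale_split[:cur_stage + 1]):
    codes.append(np.full(prev - spl, scale))
    prev = spl
print(np.concatenate(codes))  # 40 zeros then 30 ones: 70 codes for 70 kept images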
diff --git a/xrnerf/datasets/load_data/load_multiscale_google.py b/xrnerf/datasets/load_data/load_multiscale_google.py
new file mode 100644
index 0000000..33f7722
--- /dev/null
+++ b/xrnerf/datasets/load_data/load_multiscale_google.py
@@ -0,0 +1,38 @@
+import json
+import os
+
+import cv2
+import numpy as np
+
+
+def load_google_data(datadir, factor=None):
+    imgdir = os.path.join(datadir, 'images')
+    imgfiles = [
+        os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))
+        if any(f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG'])
+    ]
+
+    sh = np.array(cv2.imread(imgfiles[0]).shape)
+    imgs = []
+    for f in imgfiles:
+        im = cv2.imread(f, cv2.IMREAD_UNCHANGED)
+        if im.shape[-1] == 3:
+            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+        else:
+            im = cv2.cvtColor(im, cv2.COLOR_BGRA2RGBA)
+        im = cv2.resize(im, (sh[1] // factor, sh[0] // factor),
+                        interpolation=cv2.INTER_AREA)
+        im = im.astype(np.float32) / 255
+        imgs.append(im)
+    imgs = np.stack(imgs, -1)
+    imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
+
+    data = json.load(open(os.path.join(datadir, 'poses_enu.json')))
+    poses = np.array(data['poses'])[:, :-2].reshape([-1, 3, 5])
+    # patch the (H, W, focal) column for the downscale factor
+    poses[:, :2, 4] = np.array(sh[:2] // factor).reshape([1, 2])
+    poses[:, 2, 4] = poses[:, 2, 4] * 1. / factor
+
+    scene_scale = data['scene_scale']
+    scene_origin = np.array(data['scene_origin'])
+    scale_split = data['scale_split']
+    return imgs, poses, scene_scale, scene_origin, scale_split
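# Assumed shape of poses_enu.json, inferred only from the loader above (the
# key names are the ones actually read; the values are illustrative):
# {
#   "poses":        [...],         # one row per image; drops 2 trailing entries,
#                                  # then reshapes to a 3x5 matrix (pose + hwf column)
#   "scene_scale":  0.02,
#   "scene_origin": [x0, y0, z0],
#   "scale_split":  [n0, n1, ...]  # image-index boundaries between scales
# }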
+ """ + if self.enable: + start_i = self.N_rand * results['idx'] + batch_rays = results['rays_rgb'][start_i:start_i + + self.N_rand] # [B, 2+1, 3*?] + results['rays_o'], results['rays_d'], results[ + 'target_s'] = batch_rays[:, + 0, :], batch_rays[:, + 1, :], batch_rays[:, + 2, :] + results['radii'] = results['radii'][start_i:start_i + + self.N_rand] + results['scale_code'] = results['scale_code'][start_i:start_i + + self.N_rand] + + return results + + def __repr__(self): + return '{}:sample a batch of rays from all rays'.format( + self.__class__.__name__) @PIPELINES.register_module() @@ -185,6 +223,7 @@ def __call__(self, results): dirs = torch.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -torch.ones_like(i)], -1).to(device) + # Rotate ray directions from camera frame to the world frame rays_d = torch.sum( dirs[..., np.newaxis, :] * c2w[:3, :3], @@ -496,6 +535,43 @@ def __repr__(self): self.__class__.__name__) + +@PIPELINES.register_module() +class BungeeGetZvals: + """get intervals between samples + Args: + keys (Sequence[str]): Required keys to be converted. + """ + def __init__(self, + enable=True, + N_samples=64, + **kwargs): + self.enable = enable + self.N_samples = N_samples + + def __call__(self, results): + """get intervals between samples + Args: + results (dict): The resulting dict to be modified and passed + to the next transform in pipeline. + """ + if self.enable: + device = results['rays_o'].device + N_rays = results['rays_o'].shape[0] + t_vals = torch.linspace(0., 1., steps=self.N_samples).to(device) + z_vals_lindisp = 1./(1./results['near'] * (1.-t_vals) + 1./results['far'] * (t_vals)) + z_vals_lindisp_half = z_vals_lindisp[:,:int(self.N_samples*2/3)] + linear_start = z_vals_lindisp_half[:,-1:] + t_vals_linear = torch.linspace(0., 1., steps=self.N_samples-int(self.N_samples*2/3)+1).to(device) + z_vals_linear_half = linear_start * (1-t_vals_linear) + results['far'] * t_vals_linear + z_vals = torch.cat((z_vals_lindisp_half, z_vals_linear_half[:,1:]), -1) + z_vals, _ = torch.sort(z_vals, -1) + z_vals = z_vals.expand([N_rays, self.N_samples]) + results['z_vals'] = z_vals + return results + + + @PIPELINES.register_module() class GetPts: """get pts @@ -749,3 +825,63 @@ def __call__(self, results): def __repr__(self): return '{}:load the Camera and SMPL parameters'.format( self.__class__.__name__) + + + +@PIPELINES.register_module() +class BungeeGetBounds: + """get near and far + Args: + keys (Sequence[str]): Required keys to be converted. + """ + def __init__(self, enable=True, ray_nearfar='sphere', **kwargs): + self.enable = enable + # kwargs来自于dataset读取完毕后,记录的datainfo信息 + self.ray_nearfar = ray_nearfar + self.kwargs = kwargs + + def __call__(self, results): + """get bound(near and far) + Args: + results (dict): The resulting dict to be modified and passed + to the next transform in pipeline. 
+ """ + if self.enable: + scene_origin = self.kwargs['scene_origin'] + scene_scaling_factor = self.kwargs['scene_scaling_factor'] + device = results['rays_o'].device + if self.ray_nearfar == 'sphere': + globe_center = torch.tensor(np.array(scene_origin) * scene_scaling_factor).float().to(device) + # 6371011 is earth radius, 250 is the assumed height limitation of buildings in the scene + earth_radius = 6371011 * scene_scaling_factor + earth_radius_plus_bldg = (6371011+250) * scene_scaling_factor + # intersect with building upper limit sphere + delta = (2*torch.sum((results['rays_o']-globe_center) * results['viewdirs'], dim=-1))**2 - 4*torch.norm(results['viewdirs'], dim=-1)**2 * (torch.norm((results['rays_o']-globe_center), dim=-1)**2 - (earth_radius_plus_bldg)**2) + d_near = (-2*torch.sum((results['rays_o']-globe_center) * results['viewdirs'], dim=-1) - delta**0.5) / (2*torch.norm(results['viewdirs'], dim=-1)**2) + rays_start = results['rays_o'] + (d_near[...,None]*results['viewdirs']) + # intersect with earth + delta = (2*torch.sum((results['rays_o']-globe_center) * results['viewdirs'], dim=-1))**2 - 4*torch.norm(results['viewdirs'], dim=-1)**2 * (torch.norm((results['rays_o']-globe_center), dim=-1)**2 - (earth_radius)**2) + d_far = (-2*torch.sum((results['rays_o']-globe_center) * results['viewdirs'], dim=-1) - delta**0.5) / (2*torch.norm(results['viewdirs'], dim=-1)**2) + rays_end = results['rays_o'] + (d_far[...,None]*results['viewdirs']) + # compute near and far for each ray + new_near = torch.norm(results['rays_o'] - rays_start, dim=-1, keepdim=True) + near = new_near * 0.9 + new_far = torch.norm(results['rays_o'] - rays_end, dim=-1, keepdim=True) + far = new_far * 1.1 + elif self.ray_nearfar == 'flat': + normal = torch.tensor([0, 0, 1]).to(results['rays_o']) * scene_scaling_factor + p0_far = torch.tensor([0, 0, 0]).to(results['rays_o']) * scene_scaling_factor + p0_near = torch.tensor([0, 0, 250]).to(results['rays_o']) * scene_scaling_factor + + near = (p0_near - results['rays_o'] * normal).sum(-1) / (results['viewdirs'] * normal).sum(-1) + far = (p0_far - results['rays_o'] * normal).sum(-1) / (results['viewdirs'] * normal).sum(-1) + near = near.clamp(min=1e-6) + near, far = near.unsqueeze(-1), far.unsqueeze(-1) + results['far'] = far + results['near'] = near + return results + + def __repr__(self): + return '{}:get bounds(near and far)'.format(self.__class__.__name__) + + diff --git a/xrnerf/models/embedders/__init__.py b/xrnerf/models/embedders/__init__.py index ea17d5b..4411e5d 100644 --- a/xrnerf/models/embedders/__init__.py +++ b/xrnerf/models/embedders/__init__.py @@ -4,8 +4,10 @@ from .mipnerf_embedder import MipNerfEmbedder from .neuralbody_embedder import SmplEmbedder from .gnr_embedder import SRFilters, HourGlass, HGFilter, PositionalEncoding, SphericalHarmonics +from .bungee_embedder import BungeeEmbedder __all__ = [ 'BaseEmbedder', 'MipNerfEmbedder', 'KiloNerfFourierEmbedder', - 'SmplEmbedder', 'SRFilters', 'HourGlass', 'HGFilter', 'PositionalEncoding', 'SphericalHarmonics' + 'SmplEmbedder', 'SRFilters', 'HourGlass', 'HGFilter', 'PositionalEncoding', 'SphericalHarmonics', + 'BungeeEmbedder' ] diff --git a/xrnerf/models/embedders/bungee_embedder.py b/xrnerf/models/embedders/bungee_embedder.py new file mode 100644 index 0000000..e8058f4 --- /dev/null +++ b/xrnerf/models/embedders/bungee_embedder.py @@ -0,0 +1,105 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
diff --git a/xrnerf/models/embedders/__init__.py b/xrnerf/models/embedders/__init__.py
index ea17d5b..4411e5d 100644
--- a/xrnerf/models/embedders/__init__.py
+++ b/xrnerf/models/embedders/__init__.py
@@ -4,8 +4,10 @@
 from .mipnerf_embedder import MipNerfEmbedder
 from .neuralbody_embedder import SmplEmbedder
 from .gnr_embedder import SRFilters, HourGlass, HGFilter, PositionalEncoding, SphericalHarmonics
+from .bungee_embedder import BungeeEmbedder
 
 __all__ = [
     'BaseEmbedder', 'MipNerfEmbedder', 'KiloNerfFourierEmbedder',
-    'SmplEmbedder', 'SRFilters', 'HourGlass', 'HGFilter', 'PositionalEncoding', 'SphericalHarmonics'
+    'SmplEmbedder', 'SRFilters', 'HourGlass', 'HGFilter', 'PositionalEncoding', 'SphericalHarmonics',
+    'BungeeEmbedder'
 ]
diff --git a/xrnerf/models/embedders/bungee_embedder.py b/xrnerf/models/embedders/bungee_embedder.py
new file mode 100644
index 0000000..e8058f4
--- /dev/null
+++ b/xrnerf/models/embedders/bungee_embedder.py
@@ -0,0 +1,105 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+
+from ..builder import EMBEDDERS
+
+
+@EMBEDDERS.register_module()
+class BungeeEmbedder(nn.Module):
+    def __init__(self,
+                 i_embed=0,
+                 multires=10,
+                 multires_dirs=4,
+                 input_ch=3,
+                 **kwargs):
+        # subclasses of nn.Module with learnable parameters must call super().__init__()
+        super().__init__()
+        if i_embed == -1:
+            self.embed_fns, self.embed_ch = [nn.Identity()], input_ch
+            self.embed_fns_dirs, self.embed_ch_dirs = [nn.Identity()], input_ch
+        else:
+            self.embed_fns, self.embed_ch = self.create_mip_embedding_fn(
+                multires, input_ch=input_ch)
+            self.embed_fns_dirs, self.embed_ch_dirs = self.create_embedding_fn(
+                multires_dirs, input_ch=input_ch)
+
+    def create_mip_embedding_fn(self,
+                                multires,
+                                input_ch=3,
+                                cat_input=True,
+                                log_sampling=True,
+                                periodic_fns=[torch.sin, torch.cos]):
+        num_freqs = multires
+        max_freq_log2 = multires - 1
+        embed_fns = []
+        out_dim = 0
+        d = input_ch
+        if cat_input:
+            embed_fns.append(lambda x: x[:, :d])
+            out_dim += d
+        N_freqs = num_freqs
+        max_freq = max_freq_log2
+
+        if log_sampling:
+            freq_bands_y = 2.**torch.linspace(0., max_freq, steps=N_freqs)
+            freq_bands_w = 4.**torch.linspace(0., max_freq, steps=N_freqs)
+        else:
+            freq_bands_y = torch.linspace(2.**0, 2.**max_freq, steps=N_freqs)
+            freq_bands_w = torch.linspace(4.**0, 4.**max_freq, steps=N_freqs)
+        for freq_y, freq_w in zip(freq_bands_y, freq_bands_w):
+            for p_fn in periodic_fns:
+                # the periodic term on the means (first d channels) is damped
+                # by exp(-0.5 * freq^2 * variance); note freq_w == freq_y**2
+                embed_fns.append(
+                    lambda inputs, p_fn=p_fn, freq_y=freq_y, freq_w=freq_w:
+                    p_fn(inputs[:, :d] * freq_y) * torch.exp(
+                        (-0.5) * freq_w * inputs[:, d:]))
+                out_dim += d
+        return embed_fns, out_dim
+
+    def create_embedding_fn(self,
+                            multires,
+                            input_ch=3,
+                            cat_input=True,
+                            log_sampling=True,
+                            periodic_fns=[torch.sin, torch.cos]):
+        num_freqs = multires
+        max_freq_log2 = multires - 1
+        embed_fns = []
+        out_dim = 0
+        d = input_ch
+        if cat_input:
+            embed_fns.append(lambda x: x)
+            out_dim += d
+        N_freqs = num_freqs
+        max_freq = max_freq_log2
+
+        if log_sampling:
+            freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
+        else:
+            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
+        for freq in freq_bands:
+            for p_fn in periodic_fns:
+                embed_fns.append(
+                    lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
+                out_dim += d
+        return embed_fns, out_dim
+
+    def get_embed_ch(self):
+        return self.embed_ch, self.embed_ch_dirs
+
+    def forward(self, data):
+        means, cov_diags = data['samples']
+        means_flat = torch.reshape(means, [-1, means.shape[-1]])
+        cov_diags_flat = torch.reshape(cov_diags, [-1, cov_diags.shape[-1]])
+        inputs_flat = torch.cat((means_flat, cov_diags_flat), -1)
+        embedded = self.run_embed(inputs_flat, self.embed_fns)
+
+        viewdirs = data['viewdirs']
+        input_dirs = viewdirs[:, None].expand(means.shape)
+        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
+        embedded_dirs = self.run_embed(input_dirs_flat, self.embed_fns_dirs)
+
+        embedded = torch.cat([embedded, embedded_dirs], -1)
+        data['unflatten_shape'] = data['samples'][0].shape[:-1]
+        data['embedded'] = embedded
+        return data
+
+    def run_embed(self, x, embed_fns):
+        return torch.cat([fn(x) for fn in embed_fns], -1)
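# The mip-style branch above implements the closed-form expectation of a
# positional encoding under a Gaussian (as in mip-NeRF's integrated PE):
#   E[sin(2^k x)] = sin(2^k * mu) * exp(-0.5 * 4^k * var),
# which is why freq_bands_w equals freq_bands_y squared and the second half
# of the flattened input carries the per-sample variances. A quick Monte
# Carlo check with assumed toy numbers:
import torch

mu, var, k = torch.tensor(0.3), torch.tensor(0.01), 3
closed = torch.sin(2.0**k * mu) * torch.exp(-0.5 * 4.0**k * var)
mc = torch.sin(2.0**k * (mu + var.sqrt() * torch.randn(200000))).mean()
print(closed.item(), mc.item())  # the two values should nearly match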
diff --git a/xrnerf/models/mlps/__init__.py b/xrnerf/models/mlps/__init__.py
index 1504959..4a38365 100644
--- a/xrnerf/models/mlps/__init__.py
+++ b/xrnerf/models/mlps/__init__.py
@@ -6,6 +6,7 @@
 from .nb_mlp import NB_NeRFMLP
 from .nerf_mlp import NerfMLP
 from .gnr_mlp import GNRMLP
+from .bungeenerf_mlp import BungeeNerfMLP
 
 __all__ = [
     'NerfMLP',
@@ -15,5 +16,6 @@
     'DeformField',
     'NB_NeRFMLP',
     'HashNerfMLP',
-    'GNRMLP'
+    'GNRMLP',
+    'BungeeNerfMLP'
 ]
diff --git a/xrnerf/models/mlps/bungeenerf_mlp.py b/xrnerf/models/mlps/bungeenerf_mlp.py
new file mode 100644
index 0000000..bbd5799
--- /dev/null
+++ b/xrnerf/models/mlps/bungeenerf_mlp.py
@@ -0,0 +1,115 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .. import builder
+from ..builder import MLPS
+from .base import BaseMLP
+
+
+class BungeeNerfBaseBlock(nn.Module):
+    def __init__(self, netwidth=256, input_ch=3, input_ch_views=3):
+        super(BungeeNerfBaseBlock, self).__init__()
+        self.pts_linears = nn.ModuleList(
+            [nn.Linear(input_ch, netwidth)] +
+            [nn.Linear(netwidth, netwidth) for _ in range(3)])
+        self.views_linear = nn.Linear(input_ch_views + netwidth, netwidth // 2)
+        self.feature_linear = nn.Linear(netwidth, netwidth)
+        self.alpha_linear = nn.Linear(netwidth, 1)
+        self.rgb_linear = nn.Linear(netwidth // 2, 3)
+
+    def forward(self, input_pts, input_views):
+        h = input_pts.float()
+        for i, l in enumerate(self.pts_linears):
+            h = self.pts_linears[i](h)
+            h = F.relu(h)
+        alpha = self.alpha_linear(h)
+        feature0 = self.feature_linear(h)
+        h0 = torch.cat([feature0, input_views], -1)
+        h0 = self.views_linear(h0)
+        h0 = F.relu(h0)
+        rgb = self.rgb_linear(h0)
+        return rgb, alpha, h
+
+
+class BungeeNerfResBlock(nn.Module):
+    def __init__(self, netwidth=256, input_ch=3, input_ch_views=3):
+        super(BungeeNerfResBlock, self).__init__()
+        self.pts_linears = nn.ModuleList([
+            nn.Linear(input_ch + netwidth, netwidth),
+            nn.Linear(netwidth, netwidth)
+        ])
+        self.views_linear = nn.Linear(input_ch_views + netwidth, netwidth // 2)
+        self.feature_linear = nn.Linear(netwidth, netwidth)
+        self.alpha_linear = nn.Linear(netwidth, 1)
+        self.rgb_linear = nn.Linear(netwidth // 2, 3)
+
+    def forward(self, input_pts, input_views, h):
+        # skip connection: concatenate the raw encoding with the running feature
+        h = torch.cat([input_pts, h], -1)
+        for i, l in enumerate(self.pts_linears):
+            h = self.pts_linears[i](h)
+            h = F.relu(h)
+        alpha = self.alpha_linear(h)
+        feature0 = self.feature_linear(h)
+        h0 = torch.cat([feature0, input_views], -1)
+        h0 = self.views_linear(h0)
+        h0 = F.relu(h0)
+        rgb = self.rgb_linear(h0)
+        return rgb, alpha, h
+
+
+@MLPS.register_module()
+class BungeeNerfMLP(BaseMLP):
+    def __init__(self,
+                 cur_stage=0,
+                 netwidth=256,
+                 netchunk=1024 * 32,
+                 embedder=None,
+                 **kwarg):
+        # subclasses of nn.Module with learnable parameters must call super().__init__()
+        super().__init__()
+        self.chunk = netchunk
+        self.embedder = builder.build_embedder(embedder)
+        self.num_resblocks = cur_stage
+        self.init_mlp(netwidth)
+
+    def init_mlp(self, netwidth):
+        W = netwidth
+        self.input_ch, self.input_ch_dirs = self.embedder.get_embed_ch()
+        self.baseblock = BungeeNerfBaseBlock(netwidth=W,
+                                             input_ch=self.input_ch,
+                                             input_ch_views=self.input_ch_dirs)
+        self.resblocks = nn.ModuleList([
+            BungeeNerfResBlock(netwidth=W,
+                               input_ch=self.input_ch,
+                               input_ch_views=self.input_ch_dirs)
+            for _ in range(self.num_resblocks)
+        ])
+        return
+
+    def forward(self, data):
+        data = self.embedder(data)
+        data['embedded'] = data['embedded'].float()
+        outputs_flat = self.batchify_run_mlp(data['embedded'])
+        data['raw'] = torch.reshape(
+            outputs_flat,
+            list(data['unflatten_shape']) + list(outputs_flat.shape[1:]))
+        del data['unflatten_shape']
+        return data
+
+    def batchify_run_mlp(self, x):
+        if self.chunk is None:
+            return self.run_mlp(x)
+        outputs = torch.cat([
+            self.run_mlp(x[i:i + self.chunk])
+            for i in range(0, x.shape[0], self.chunk)
+        ], 0)
+        return outputs
+
+    def run_mlp(self, x):
+        input_pts, input_views = torch.split(
+            x, [self.input_ch, self.input_ch_dirs], dim=-1)
+        alphas = []
+        rgbs = []
+        base_rgb, base_alpha, h = self.baseblock(input_pts, input_views)
+        alphas.append(base_alpha)
+        rgbs.append(base_rgb)
+        for i in range(self.num_resblocks):
+            res_rgb, res_alpha, h = self.resblocks[i](input_pts, input_views, h)
+            alphas.append(res_alpha)
+            rgbs.append(res_rgb)
+        # one (rgb, alpha) head per block: [B, num_blocks, 4]
+        outputs = torch.cat([torch.stack(rgbs, 1), torch.stack(alphas, 1)], -1)
+        return outputs
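# Shape walkthrough for run_mlp, assuming cur_stage = 2 (one base block plus
# two residual blocks) and B flattened points with the config's encodings
# (input_ch = 63, input_ch_dirs = 27 for multires = 10/4 on 3 channels):
#   each block yields rgb [B, 3] and alpha [B, 1]
#   torch.stack(rgbs, 1)   -> [B, 3, 3]
#   torch.stack(alphas, 1) -> [B, 3, 1]
#   outputs                -> [B, 3, 4]   (num_blocks = cur_stage + 1)
# After the reshape in forward(), data['raw'] is [N_rays, N_samples, 3, 4];
# the render then sums the first stage+1 block outputs (see
# bungeenerf_render.py at the end of this patch).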
diff --git a/xrnerf/models/networks/__init__.py b/xrnerf/models/networks/__init__.py
index dfc77b3..45b280b 100644
--- a/xrnerf/models/networks/__init__.py
+++ b/xrnerf/models/networks/__init__.py
@@ -7,8 +7,9 @@
 from .neuralbody import NeuralBodyNetwork
 from .student_nerf import StudentNerfNetwork
 from .gnr import GnrNetwork
+from .bungeenerf import BungeeNerfNetwork
 
 __all__ = [
     'NerfNetwork', 'MipNerfNetwork', 'KiloNerfNetwork', 'StudentNerfNetwork',
-    'NeuralBodyNetwork', 'AniNeRFNetwork', 'GnrNetwork'
+    'NeuralBodyNetwork', 'AniNeRFNetwork', 'GnrNetwork', 'BungeeNerfNetwork'
 ]
diff --git a/xrnerf/models/networks/bungeenerf.py b/xrnerf/models/networks/bungeenerf.py
new file mode 100644
index 0000000..1e1e51e
--- /dev/null
+++ b/xrnerf/models/networks/bungeenerf.py
@@ -0,0 +1,177 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+
+import torch
+from mmcv.runner import get_dist_info
+from torch import nn
+from tqdm import tqdm
+
+from .. import builder
+from ..builder import NETWORKS
+from .base import BaseNerfNetwork
+from .utils import *
+
+
+@NETWORKS.register_module()
+class BungeeNerfNetwork(BaseNerfNetwork):
+    """There are three forward modes for this network:
+
+    1. 'train': phase=='train', forward via 'train_step()' with a batch of rays
+    2. 'val': phase=='train', forward via 'val_step()' with all testset poses & images at once
+    3. 'test': phase=='test', forward via 'test_step()' with testset images one by one
+    """
+    def __init__(self, cfg, mlp=None, render=None):
+        super().__init__()
+
+        self.phase = cfg.get('phase', 'train')
+        if 'chunk' in cfg: self.chunk = cfg.chunk
+        if 'bs_data' in cfg: self.bs_data = cfg.bs_data
+        if 'is_perturb' in cfg: self.is_perturb = cfg.is_perturb
+        if 'N_importance' in cfg: self.N_importance = cfg.N_importance
+        self.resample_padding = cfg.resample_padding
+        self.ray_shape = cfg.ray_shape
+        if mlp is not None:
+            self.mlp = builder.build_mlp(mlp)
+        if render is not None:
+            self.render = builder.build_render(render)
+
+    def forward(self, data, is_test=False):
+        randomized = not is_test
+        data = sample_along_rays(data, self.ray_shape)
+        data, ret = self.render(self.mlp(data), is_test)
+        if self.N_importance > 0:
+            data = resample_along_rays(data, randomized, self.ray_shape,
+                                       self.resample_padding)
+            _, ret2 = self.render(self.mlp(data), is_test)
+            ret = merge_ret(ret, ret2)  # add the fine net's returns to ret
+        return ret
+
+    def batchify_forward(self, data, is_test=False):
+        """forward in smaller minibatches to avoid OOM."""
+        # self.bs_data's shape[0] is the real batch size, i.e. the number of rays
+        N = data[self.bs_data].shape[0]
+        all_ret = {}
+        for i in range(0, N, self.chunk):
+            data_chunk = {}
+            for k in data:
+                if data[k].shape[0] == N:
+                    data_chunk[k] = data[k][i:i + self.chunk]
+                else:
+                    data_chunk[k] = data[k]
+
+            ret = self.forward(data_chunk, is_test)
+
+            for k in ret:
+                if k not in all_ret: all_ret[k] = []
+                all_ret[k].append(ret[k])
+        all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}
+        return all_ret
+
+    def train_step(self, data, optimizer, **kwargs):
+        for k in data:
+            data[k] = unfold_batching(data[k])
+        stage = kwargs['stage']
+        self.render.stage = stage
+        ret = self.forward(data, is_test=False)
+
+        # only rays whose scale code is <= the current stage contribute
+        mask = (data['scale_code'] <= stage)
+        img_loss = img2mse(ret['rgb'] * mask, data['target_s'] * mask)
+        psnr = mse2psnr(img_loss)
+        loss = img_loss
+
+        if 'coarse_rgb' in ret:
+            coarse_img_loss = img2mse(ret['coarse_rgb'] * mask,
+                                      data['target_s'] * mask)
+            loss = loss + coarse_img_loss
+
+        log_vars = {'loss': loss.item(), 'psnr': psnr.item()}
+        outputs = {
+            'loss': loss,
+            'log_vars': log_vars,
+            'num_samples': ret['rgb'].shape[0]
+        }
+        return outputs
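# The scale mask in train_step, on assumed toy values: at stage 0 only rays
# with scale_code 0 contribute to the MSE; finer-scale rays are zeroed out of
# both prediction and target before img2mse.
import torch

rgb, target = torch.rand(4, 3), torch.rand(4, 3)
scale_code = torch.tensor([[0], [0], [1], [1]])
mask = (scale_code <= 0)                        # stage 0
masked_mse = torch.mean((rgb * mask - target * mask)**2)
print(masked_mse)  # what img2mse computes on the masked tensors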
+
+    def val_step(self, data, optimizer=None, **kwargs):
+        if self.phase == 'test':
+            return self.test_step(data, **kwargs)
+
+        rank, world_size = get_dist_info()
+        if rank == 0:
+            for k in data:
+                data[k] = unfold_batching(data[k])
+            poses = data['poses']
+            images = data['images']
+            spiral_poses = data['spiral_poses']
+
+            rgbs, disps, gt_imgs = [], [], []
+            elapsed_time_list = []
+            for i in tqdm(range(poses.shape[0])):
+                start = time.time()
+                data = self.val_pipeline({'pose': poses[i]})
+                ret = self.batchify_forward(
+                    data, is_test=True)  # raw_noise_std is disabled at test time
+                end = time.time()
+                # elapsed_time includes pipeline time and forward time
+                elapsed_time = end - start
+                rgb = recover_shape(ret['rgb'], data['src_shape'])
+                disp = recover_shape(ret['disp'], data['src_shape'])
+                rgbs.append(rgb.cpu().numpy())
+                disps.append(disp.cpu().numpy())
+                gt_imgs.append(images[i].cpu().numpy())
+                elapsed_time_list.append(elapsed_time)
+
+            spiral_rgbs, spiral_disps = [], []
+            for i in tqdm(range(spiral_poses.shape[0])):
+                data = self.val_pipeline({'pose': spiral_poses[i]})
+                ret = self.batchify_forward(data, is_test=True)
+                rgb = recover_shape(ret['rgb'], data['src_shape'])
+                disp = recover_shape(ret['disp'], data['src_shape'])
+                spiral_rgbs.append(rgb.cpu().numpy())
+                spiral_disps.append(disp.cpu().numpy())
+
+            outputs = {
+                'spiral_rgbs': spiral_rgbs,
+                'spiral_disps': spiral_disps,
+                'rgbs': rgbs,
+                'disps': disps,
+                'gt_imgs': gt_imgs,
+                'elapsed_time': elapsed_time_list
+            }
+        else:
+            outputs = {}
+        return outputs
+
+    def test_step(self, data, **kwargs):
+        """mmcv's runner only provides train_step and val_step, so
+        [val_step() + phase=='test'] is used to represent testing."""
+        rank, world_size = get_dist_info()
+        if rank == 0:
+            for k in data:
+                data[k] = unfold_batching(data[k])
+
+            image = data['image']
+            idx = data['idx'].item()
+
+            data = self.val_pipeline({'pose': data['pose']})
+
+            ret = self.batchify_forward(data, is_test=True)
+            rgb = recover_shape(ret['rgb'], data['src_shape'])
+
+            rgb = rgb.cpu().numpy()
+            image = image.cpu().numpy()
+
+            outputs = {'rgb': rgb, 'gt_img': image, 'idx': idx}
+        else:
+            outputs = {}
+        return outputs
+
+    def set_val_pipeline(self, func):
+        self.val_pipeline = func
+        return
diff --git a/xrnerf/models/networks/utils/__init__.py b/xrnerf/models/networks/utils/__init__.py
index cf4d58a..2abd63e 100644
--- a/xrnerf/models/networks/utils/__init__.py
+++ b/xrnerf/models/networks/utils/__init__.py
@@ -35,5 +35,5 @@
     'LPIPS',
     'SSIM',
     'psnr',
-    'init_weights'
+    'init_weights',
 ]
diff --git a/xrnerf/models/renders/__init__.py b/xrnerf/models/renders/__init__.py
index 4ae4ee2..40255a3 100644
--- a/xrnerf/models/renders/__init__.py
+++ b/xrnerf/models/renders/__init__.py
@@ -4,6 +4,7 @@
 from .mipnerf_render import MipNerfRender
 from .nerf_render import NerfRender
 from .gnr_render import GnrRenderer
+from .bungeenerf_render import BungeeNerfRender
 
 __all__ = [
     'NerfRender',
@@ -11,4 +12,5 @@
     'KiloNerfSimpleRender',
     'HashNerfRender',
-    'GnrRenderer'
+    'GnrRenderer',
+    'BungeeNerfRender',
 ]
diff --git a/xrnerf/models/renders/bungeenerf_render.py b/xrnerf/models/renders/bungeenerf_render.py
new file mode 100644
index 0000000..5a3b22b
--- /dev/null
+++ b/xrnerf/models/renders/bungeenerf_render.py
@@ -0,0 +1,108 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..builder import RENDERS
+from .base import BaseRender
+
+
+@RENDERS.register_module()
+class BungeeNerfRender(BaseRender):
+    def __init__(self,
+                 stage=0,
+                 white_bkgd=False,
+                 raw_noise_std=0,
+                 rgb_padding=0,
+                 density_bias=-1,
+                 density_activation='softplus',
+                 **kwarg):
+        # subclasses of nn.Module with learnable parameters must call super().__init__()
+        super().__init__()
+        self.white_bkgd = white_bkgd
+        self.raw_noise_std = raw_noise_std
+        self.rgb_padding = rgb_padding
+        self.density_bias = density_bias
+        self.stage = stage
+
+        if density_activation == 'softplus':  # density activation
+            self.density_activation = F.softplus
+        elif density_activation == 'relu':
+            self.density_activation = F.relu
+        else:
+            raise NotImplementedError
+
+    def get_disp_map(self, weights, z_vals):
+        depth_map = torch.sum(weights * z_vals, -1)
+        disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map),
+                                  depth_map / torch.sum(weights, -1))
+        return disp_map
+
+    def get_weights(self, density_delta):
+        alpha = 1 - torch.exp(density_delta)
+        weights = alpha * torch.cumprod(
+            torch.cat([
+                torch.ones((alpha.shape[0], 1)).to(alpha.device),
+                1. - alpha + 1e-10
+            ], -1), -1)[:, :-1]
+        return weights
+
+    def forward(self, data, is_test=False):
+        """Transforms the model's predictions to semantically meaningful values.
+
+        Args:
+            data: inputs
+            is_test: is_test
+        Returns:
+            rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
+            disp_map: [num_rays]. Disparity map. Inverse of depth map.
+            acc_map: [num_rays]. Sum of weights along each ray.
+            weights: [num_rays, num_samples]. Weights assigned to each sampled color.
+            depth_map: [num_rays]. Estimated distance to object.
+            ret: return values
+        """
+        raw = data['raw']
+        z_vals = data['z_vals']
+        # z_vals: [N_rays, N_samples] for nerf or [N_rays, N_samples+1] for mip
+        viewdirs = data['viewdirs']
+        raw_noise_std = 0 if is_test else self.raw_noise_std
+        device = raw.device
+        z_vals = .5 * (z_vals[..., 1:] + z_vals[..., :-1])  # interval midpoints
+        dists = z_vals[..., 1:] - z_vals[..., :-1]
+        if dists.shape[1] != raw.shape[1]:  # if z_vals is [N_rays, N_samples]
+            dists = torch.cat([
+                dists,
+                torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)
+            ], -1)  # [N_rays, N_samples]
+        dists = dists * torch.norm(viewdirs[..., None, :], dim=-1)
+
+        # sum the outputs of blocks 0..stage (progressive refinement)
+        acc_rgb = torch.sum(raw[..., :self.stage + 1, :3], dim=2)
+
+        # padded sigmoid, as in mip-NeRF
+        rgb = (1 + 2 * self.rgb_padding) / (1 + torch.exp(-acc_rgb)) - self.rgb_padding
+
+        acc_alpha = torch.sum(raw[..., :self.stage + 1, 3], dim=2)
+
+        noise = 0.
+        if raw_noise_std > 0.:
+            noise = torch.randn(acc_alpha.shape) * raw_noise_std
+            noise = noise.to(device)
+
+        density_delta = -self.density_activation(acc_alpha + noise +
+                                                 self.density_bias) * dists
+
+        weights = self.get_weights(density_delta)
+
+        rgb_map = torch.sum(weights[..., None] * rgb, -2)  # [N_rays, 3]
+        disp_map = self.get_disp_map(weights, z_vals)
+        acc_map = torch.sum(weights, -1)
+
+        if self.white_bkgd:
+            rgb_map = rgb_map + (1. - acc_map[..., None])
+
+        ret = {'rgb': rgb_map, 'disp': disp_map, 'acc': acc_map}
+        data['weights'] = weights  # kept in data for the resampling step
+
+        return data, ret
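# Reading the render: with raw of shape [N_rays, N_samples, num_blocks, 4],
# summing blocks 0..stage composites the base output with every residual
# refinement unlocked so far. A toy check of the accumulation, with assumed
# shapes:
import torch

N, S, blocks, stage = 2, 5, 3, 1
raw = torch.rand(N, S, blocks, 4)
acc_rgb = torch.sum(raw[..., :stage + 1, :3], dim=2)   # [N, S, 3]
acc_alpha = torch.sum(raw[..., :stage + 1, 3], dim=2)  # [N, S]
print(acc_rgb.shape, acc_alpha.shape)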