From a14357da6cada433d28bf11a45c7bcaace76c06e Mon Sep 17 00:00:00 2001 From: Krishna Murthy Date: Thu, 16 Apr 2020 22:38:01 -0400 Subject: [PATCH] Lint Signed-off-by: Krishna Murthy --- eval_nerf.py | 17 +++++++++++++---- nerf/models.py | 4 ++-- nerf/nerf_helpers.py | 16 +++++++++++----- nerf/train_utils.py | 7 +++++-- tiny_nerf.py | 3 +-- train_nerf.py | 14 +++++++------- 6 files changed, 39 insertions(+), 22 deletions(-) diff --git a/eval_nerf.py b/eval_nerf.py index cffc398..87c46fa 100644 --- a/eval_nerf.py +++ b/eval_nerf.py @@ -9,8 +9,15 @@ import yaml from tqdm import tqdm -from nerf import (CfgNode, get_ray_bundle, load_blender_data, load_llff_data, - models, get_embedding_function, run_one_iter_of_nerf) +from nerf import ( + CfgNode, + get_ray_bundle, + load_blender_data, + load_llff_data, + models, + get_embedding_function, + run_one_iter_of_nerf, +) def cast_to_image(tensor, dataset_type): @@ -85,7 +92,7 @@ def main(): encode_position_fn = get_embedding_function( num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz, include_input=cfg.models.coarse.include_input_xyz, - log_sampling=cfg.models.coarse.log_sampling_xyz + log_sampling=cfg.models.coarse.log_sampling_xyz, ) encode_direction_fn = None @@ -174,7 +181,9 @@ def main(): times_per_image.append(time.time() - start) if configargs.savedir: savefile = os.path.join(configargs.savedir, f"{i:04d}.png") - imageio.imwrite(savefile, cast_to_image(rgb[..., :3], cfg.dataset.type.lower())) + imageio.imwrite( + savefile, cast_to_image(rgb[..., :3], cfg.dataset.type.lower()) + ) if configargs.save_disparity_image: savefile = os.path.join(configargs.savedir, "disparity", f"{i:04d}.png") imageio.imwrite(savefile, cast_to_disparity_image(disp)) diff --git a/nerf/models.py b/nerf/models.py index 84bf9ac..c866c0e 100644 --- a/nerf/models.py +++ b/nerf/models.py @@ -136,7 +136,7 @@ def __init__( use_viewdirs=True, ): super(PaperNeRFModel, self).__init__() - + include_input_xyz = 3 if include_input_xyz else 0 include_input_dir = 3 if include_input_dir else 0 self.dim_xyz = include_input_xyz + 2 * 3 * num_encoding_fn_xyz @@ -161,7 +161,7 @@ def __init__( self.relu = torch.nn.functional.relu def forward(self, x): - xyz, dirs = x[..., :self.dim_xyz], x[..., self.dim_xyz:] + xyz, dirs = x[..., : self.dim_xyz], x[..., self.dim_xyz :] for i in range(8): if i == 4: x = self.layers_xyz[i](torch.cat((xyz, x), -1)) diff --git a/nerf/nerf_helpers.py b/nerf/nerf_helpers.py index f6e5fd1..4fcb372 100644 --- a/nerf/nerf_helpers.py +++ b/nerf/nerf_helpers.py @@ -130,14 +130,20 @@ def positional_encoding( encoding = [tensor] if include_input else [] frequency_bands = None if log_sampling: - frequency_bands = 2. ** torch.linspace( - 0., num_encoding_functions - 1, num_encoding_functions, - dtype=tensor.dtype, device=tensor.device, + frequency_bands = 2.0 ** torch.linspace( + 0.0, + num_encoding_functions - 1, + num_encoding_functions, + dtype=tensor.dtype, + device=tensor.device, ) else: frequency_bands = torch.linspace( - 2. ** 0., 2. ** (num_encoding_functions - 1), num_encoding_functions, - dtype=tensor.dtype, device=tensor.device + 2.0 ** 0.0, + 2.0 ** (num_encoding_functions - 1), + num_encoding_functions, + dtype=tensor.dtype, + device=tensor.device, ) for freq in frequency_bands: diff --git a/nerf/train_utils.py b/nerf/train_utils.py index 1ebb03b..c9fcfd9 100644 --- a/nerf/train_utils.py +++ b/nerf/train_utils.py @@ -180,13 +180,16 @@ def run_one_iter_of_nerf( for batch in batches ] synthesized_images = list(zip(*pred)) - synthesized_images = [torch.cat(image, dim=0) if image[0] is not None else (None) for image in synthesized_images] + synthesized_images = [ + torch.cat(image, dim=0) if image[0] is not None else (None) + for image in synthesized_images + ] if mode == "validation": synthesized_images = [ image.view(shape) if image is not None else None for (image, shape) in zip(synthesized_images, restore_shapes) ] - + # Returns rgb_coarse, disp_coarse, acc_coarse, rgb_fine, disp_fine, acc_fine # (assuming both the coarse and fine networks are used). if model_fine: diff --git a/tiny_nerf.py b/tiny_nerf.py index c088ec6..48b92e7 100644 --- a/tiny_nerf.py +++ b/tiny_nerf.py @@ -6,8 +6,7 @@ import torch from tqdm import tqdm, trange -from nerf import (cumprod_exclusive, get_minibatches, get_ray_bundle, - positional_encoding) +from nerf import cumprod_exclusive, get_minibatches, get_ray_bundle, positional_encoding def compute_query_points_from_rays( diff --git a/train_nerf.py b/train_nerf.py index 8ebb62d..7b20219 100644 --- a/train_nerf.py +++ b/train_nerf.py @@ -10,9 +10,9 @@ from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm, trange -from nerf import (CfgNode, get_ray_bundle, img2mse, load_blender_data, - load_llff_data, meshgrid_xy, models, mse2psnr, - get_embedding_function, run_one_iter_of_nerf) +from nerf import (CfgNode, get_embedding_function, get_ray_bundle, img2mse, + load_blender_data, load_llff_data, meshgrid_xy, models, + mse2psnr, run_one_iter_of_nerf) def main(): @@ -63,7 +63,7 @@ def main(): H, W = int(H), int(W) hwf = [H, W, focal] if cfg.nerf.train.white_background: - images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:]) + images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:]) elif cfg.dataset.type.lower() == "llff": images, poses, bds, render_poses, i_test = load_llff_data( cfg.dataset.basedir, factor=cfg.dataset.downsample_factor @@ -104,7 +104,7 @@ def main(): include_input=cfg.models.coarse.include_input_xyz, log_sampling=cfg.models.coarse.log_sampling_xyz, ) - + encode_direction_fn = None if cfg.models.coarse.use_viewdirs: encode_direction_fn = get_embedding_function( @@ -250,7 +250,7 @@ def main(): rgb_fine[..., :3], target_ray_values[..., :3] ) # loss = torch.nn.functional.mse_loss(rgb_pred[..., :3], target_s[..., :3]) - loss = 0. + loss = 0.0 # if fine_loss is not None: # loss = fine_loss # else: @@ -337,7 +337,7 @@ def main(): ) target_ray_values = img_target coarse_loss = img2mse(rgb_coarse[..., :3], target_ray_values[..., :3]) - loss, fine_loss = 0., 0. + loss, fine_loss = 0.0, 0.0 if rgb_fine is not None: fine_loss = img2mse(rgb_fine[..., :3], target_ray_values[..., :3]) loss = fine_loss