EasyVolcap: Accelerating Neural Volumetric Video Research
Paper | arXiv | Example Dataset | Pretrained Model | 4K4D
News:
- 24.08.06 Multiple feature updates. See changelog.md.
- 24.02.27 4K4D has been accepted to CVPR 2024.
- 23.12.13 EasyVolcap will be presented at SIGGRAPH Asia 2023, Sydney.
- 23.10.17 4K4D, a real-time 4D view synthesis algorithm developed using EasyVolcap, has been made public.
EasyVolcap is a PyTorch library for accelerating neural volumetric video research, particularly in areas of volumetric video capturing, reconstruction, and rendering.
Install only the core dependencies for running the viewer locally:
# Editable install, with dependencies from requirements.txt
pip install -v -e .
On Windows or older versions of Linux, you might end up with a CPU-only PyTorch installation by only running the above command since only the CPU version for Windows is available on PyPI (more info).
To install a CUDA-enabled PyTorch, append an extra search link to the above command:
pip install -v -e . -f https://download.pytorch.org/whl/torch_stable.html
# If you're still unable to install a CUDA-enabled version of PyTorch, also try installing it directly with pip install torch --index-url https://download.pytorch.org/whl/cu118 (or https://download.pytorch.org/whl/cu121), as suggested on https://pytorch.org/get-started/locally/
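To confirm which PyTorch build actually got installed, a quick check along the following lines may help. This is a minimal sketch, not part of EasyVolcap: `describe_torch_install` is a hypothetical helper name, and it relies only on `torch.__version__`, `torch.version.cuda`, and `torch.cuda.is_available()`.

```python
import importlib.util


def describe_torch_install():
    """Summarize the installed PyTorch build (hypothetical helper, not an EasyVolcap API)."""
    if importlib.util.find_spec("torch") is None:
        # PyTorch is not installed in this environment at all
        return {"installed": False, "version": None, "cuda_build": None, "cuda_available": False}
    import torch
    return {
        "installed": True,
        "version": torch.__version__,
        "cuda_build": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": bool(torch.cuda.is_available()),
    }


if __name__ == "__main__":
    print(describe_torch_install())
```

If `cuda_build` is `None`, the CPU-only wheel was installed and the extra search link above is likely needed.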
Or install all dependencies for development (this requires a valid CUDA build environment with PyTorch already installed).
Note that you can just run these two commands in tandem to take care of PyTorch and EasyVolcap before compiling CUDA extensions:
# Editable install, with dependencies from requirements-devel.txt
pip install -v -e . # will install from requirements.txt
pip install -v -r requirements-devel.txt
Aside from running git pull, you might also need to re-register the command lines and code path by running pip install -e . --no-build-isolation --no-deps again.
A notable example is updating to [4K4D](https://github.com/zju3dv/4K4D): you need to rerun the editable install command so that repository is used instead of this one.
In the following sections, we'll show examples of how to run EasyVolcap on a small multi-view video dataset with several of our implemented algorithms, including Instant-NGP+T, 3DGS+T, and ENeRFi (ENeRF Improved).
In the documentation static.md, we also provide a complete example of how to prepare the dataset using COLMAP and run the above-mentioned three models using EasyVolcap.
The example dataset for this section can be downloaded from this Google Drive link. After downloading the example dataset, place the unzipped files inside data/enerf_outdoor such that you can see files like:
data/enerf_outdoor/actor1_4_subseq/images
data/enerf_outdoor/actor1_4_subseq/intri.yml
data/enerf_outdoor/actor1_4_subseq/extri.yml
This dataset is a small subset of the ENeRF-Outdoor dataset released by our team. For downloading the full dataset, please follow the guide in the link.
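Before moving on, it can be worth checking that the unzipped files landed where expected. The snippet below is a minimal sketch using only the standard library; `missing_entries` is a hypothetical helper name, and the three required entries are the ones listed above.

```python
from pathlib import Path

# Entries the example dataset is expected to contain (taken from the listing above)
REQUIRED_ENTRIES = ("images", "intri.yml", "extri.yml")


def missing_entries(data_root):
    """Return the required entries that are absent under data_root (hypothetical helper)."""
    root = Path(data_root)
    return [name for name in REQUIRED_ENTRIES if not (root / name).exists()]


if __name__ == "__main__":
    missing = missing_entries("data/enerf_outdoor/actor1_4_subseq")
    print("dataset looks complete" if not missing else f"missing: {missing}")
```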
Before running the models, let's first prepare some shell variables for easy access.
expname=actor1_4_subseq
data_root=data/enerf_outdoor/actor1_4_subseq
The pre-trained model for ENeRFi on the DTU dataset can be downloaded from this Google Drive link. After downloading, rename the model to latest.npz and place it in data/trained_model/enerfi_dtu.
# Render ENeRFi with pretrained model
evc-test -c configs/exps/enerfi/enerfi_${expname}.yaml,configs/specs/spiral.yaml,configs/specs/ibr.yaml runner_cfg.visualizer_cfg.save_tag=${expname} exp_name=enerfi_dtu
# Render ENeRFi with GUI
evc-gui -c configs/exps/enerfi/enerfi_${expname}.yaml exp_name=enerfi_dtu val_dataloader_cfg.dataset_cfg.ratio=0.5 # 2.5 FPS on 3060
If more performance is desired:
# Slightly worse quality, faster rendering
evc-gui -c configs/exps/enerfi/enerfi_${expname}.yaml exp_name=enerfi_dtu val_dataloader_cfg.dataset_cfg.ratio=0.5 model_cfg.sampler_cfg.n_planes=32,8 model_cfg.sampler_cfg.n_samples=4,1 # 3.6 FPS on 3060
# Fine quality, faster rendering
evc-gui -c configs/exps/enerfi/enerfi_${expname}.yaml,configs/specs/fp16.yaml exp_name=enerfi_dtu val_dataloader_cfg.dataset_cfg.ratio=0.5 # 3.6 FPS on 3060
# Worst quality, fastest rendering
evc-gui -c configs/exps/enerfi/enerfi_${expname}.yaml,configs/specs/fp16.yaml exp_name=enerfi_dtu val_dataloader_cfg.dataset_cfg.ratio=0.5 model_cfg.sampler_cfg.n_planes=32,8 model_cfg.sampler_cfg.n_samples=4,1 # 5.0 FPS on 3060
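The commands above pass overrides as dotted `key=value` pairs after the config file list. As a rough mental model only, the sketch below shows how such dotted keys could map onto a nested dictionary. It is a generic illustration, not EasyVolcap's actual parser (see docs/design/config.md for that); `parse_value` and `apply_overrides` are hypothetical names, and treating `32,8` as a list is an assumption.

```python
def parse_value(text):
    """Parse a scalar, or a comma-separated list of scalars (assumed convention)."""
    def scalar(token):
        for cast in (int, float):
            try:
                return cast(token)
            except ValueError:
                pass
        return token  # fall back to a plain string

    parts = text.split(",")
    return scalar(parts[0]) if len(parts) == 1 else [scalar(p) for p in parts]


def apply_overrides(config, overrides):
    """Write each dotted key=value override into a nested dict, creating levels as needed."""
    for item in overrides:
        key, _, value = item.partition("=")
        *parents, leaf = key.split(".")
        node = config
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = parse_value(value)
    return config


if __name__ == "__main__":
    cfg = apply_overrides({}, [
        "exp_name=enerfi_dtu",
        "val_dataloader_cfg.dataset_cfg.ratio=0.5",
        "model_cfg.sampler_cfg.n_planes=32,8",
    ])
    print(cfg)
```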
Note that EasyVolcap supports WebSocket-based server-side rendering. More info.
To use the WebSocket-based rendering, append the config server.yaml to any of the native rendering commands beginning with evc-gui.
# Run the rendering server, append `configs/specs/server.yaml` to the config file list
evc-gui -c configs/exps/enerfi/enerfi_${expname}.yaml,configs/specs/server.yaml exp_name=enerfi_dtu val_dataloader_cfg.dataset_cfg.ratio=0.5 model_cfg.sampler_cfg.n_planes=32,8 model_cfg.sampler_cfg.n_samples=4,1
Then run the viewer on your desired viewing client. This has been tested on Windows, macOS and Linux.
# Separate the WebSocket client parameters from the evc parameters with --; for now, the viewer can be configured with evc
# Replace 10.76.5.252 with your server IP
# Replace -c configs/datasets/enerf_outdoor/enerf_outdoor.yaml with whatever other config you want to use
# The enerf_outdoor.yaml here provides a basic camera setup for the viewer
evc-ws --host 10.76.5.252 --port 1024 -- -c configs/datasets/enerf_outdoor/enerf_outdoor.yaml viewer_cfg.window_size="768,1366"
Note that this example requires you to have tiny-cuda-nn installed. Guide.
We extend Instant-NGP to be time-aware, as a baseline method. With the data preparation completed, we've got an images folder and a pair of intri.yml and extri.yml files, and we can run the l3mhet model.
Note that this model is not built for dynamic scenes; we train it here mainly for extracting initialization point clouds and computing a tighter bounding box.
Similar procedures can be applied to other datasets if such initialization is required.
We need to write a config file for this model:
- Write the data-folder-related stuff inside configs/datasets. Just copy and paste configs/datasets/enerf_outdoor/actor1_4_subseq.yaml and modify the data_root and bounds (bounding box), or maybe add a camera near-far threshold.
- Write the experiment config inside configs/exps. Just copy and paste configs/exps/l3mhet/l3mhet_actor1_4_subseq.yaml and modify the dataset-related line in configs.
# With your config files ready, you can run the following command to train the model
evc-train -c configs/exps/l3mhet/l3mhet_${expname}.yaml
# Now run the following command to render some output
evc-test -c configs/exps/l3mhet/l3mhet_${expname}.yaml,configs/specs/spiral.yaml
# And maybe render the model with GUI in lower resolution
evc-gui -c configs/exps/l3mhet/l3mhet_${expname}.yaml viewer_cfg.render_ratio=0.15
configs/specs/spiral.yaml: please check this file for more details; it's a collection of configs that tell the data loader and visualizer to generate a spiral path by interpolating the given cameras.
Note that this example requires you to have diff_gauss or diff_gaussian_rasterization installed. Guide.
The original 3DGS uses the sparse reconstruction result of COLMAP for initialization.
However, we found that the sparse reconstruction result often contains a lot of floating points, which are hard for 3DGS to prune and can easily make the model fail to converge.
Thus, we opted to use the "dense" reconstruction result of our Instant-NGP+T implementation by computing the RGBD image for input views and concatenating them as the input of 3DGS. The script volume_fusion.py controls this process and it should work similarly on all models that support depth output.
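To make the idea of turning a depth image into points concrete, here is a minimal pure-Python sketch of back-projecting a depth map through a pinhole camera. It is not the code in volume_fusion.py: `unproject_depth` is a hypothetical name, and the optional `rotation` and `translation` are assumed to describe a camera-to-world transform.

```python
def unproject_depth(depth, fx, fy, cx, cy, rotation=None, translation=None):
    """Back-project a depth map (list of rows) to 3D points with a pinhole model.

    rotation (3x3) and translation (3,) are assumed to map camera to world space.
    Pixels with non-positive depth are treated as invalid and skipped.
    """
    points = []
    for v, row in enumerate(depth):
        for u, d in enumerate(row):
            if d <= 0:
                continue
            # Pixel (u, v) at depth d in camera coordinates
            p = ((u - cx) / fx * d, (v - cy) / fy * d, d)
            if rotation is not None and translation is not None:
                p = tuple(
                    sum(rotation[i][j] * p[j] for j in range(3)) + translation[i]
                    for i in range(3)
                )
            points.append(p)
    return points


if __name__ == "__main__":
    # A 2x2 depth map with one invalid pixel
    print(unproject_depth([[1.0, 0.0], [2.0, 2.0]], fx=1.0, fy=1.0, cx=0.0, cy=0.0))
```

Concatenating the points from every input view then gives a dense initialization cloud.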
The following script block provides an example of how to prepare an initialization for our 3DGS+T implementation.
# Extract geometry (point cloud) for initialization from the l3mhet model
# Tune image sample rate and resizing ratio for a denser or sparser estimation
python scripts/fusion/volume_fusion.py -- -c configs/exps/l3mhet/l3mhet_${expname}.yaml val_dataloader_cfg.dataset_cfg.ratio=0.15
# Move the rendering results to the dataset folder
source_folder="data/geometry/l3mhet_${expname}/POINT"
destination_folder="${data_root}/vhulls"
# Create the destination directory if it doesn't exist
mkdir -p ${destination_folder}
# Loop through all .ply files in the source directory
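for file in "${source_folder}"/*.ply; do
    # Extract the frame index from names like frame0008.ply
    number=$(basename "${file}" | sed -e 's/frame\([0-9]*\)\.ply/\1/')
    # Force base 10 (bash arithmetic) so zero-padded indices such as 0008 are not parsed as octal
    formatted_number=$(printf "%06d" "$((10#${number}))")
    destination_file="${destination_folder}/${formatted_number}.ply"
    cp "${file}" "${destination_file}"
done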
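The same renaming step can also be written in Python, which sidesteps shell quoting and number-parsing pitfalls. This is a minimal sketch under the same assumption as the loop above, namely that the extracted files are named like `frame<digits>.ply`; `copy_point_clouds` is a hypothetical helper name.

```python
import re
import shutil
from pathlib import Path


def copy_point_clouds(source_folder, destination_folder):
    """Copy frame<digits>.ply files to <six-digit index>.ply (hypothetical helper)."""
    src, dst = Path(source_folder), Path(destination_folder)
    dst.mkdir(parents=True, exist_ok=True)
    copied = []
    for file in sorted(src.glob("*.ply")):
        match = re.fullmatch(r"frame(\d+)\.ply", file.name)
        if match is None:
            continue  # skip files that do not follow the assumed naming
        target = dst / f"{int(match.group(1)):06d}.ply"
        shutil.copy(file, target)
        copied.append(target)
    return copied


if __name__ == "__main__":
    expname = "actor1_4_subseq"
    data_root = f"data/enerf_outdoor/{expname}"
    copy_point_clouds(f"data/geometry/l3mhet_{expname}/POINT", f"{data_root}/vhulls")
```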
Our conventions for storing initialization point clouds:
- Raw point clouds extracted using Instant-NGP or Space Carving are placed inside the vhulls folder. These files might be large. It's OK to directly optimize 3DGS+T on these.
- We might perform some cleanup of the point clouds and store them in the surfs folder.
  - For 3DGS+T, the cleaned-up point clouds might be easier to optimize since 3DGS is good at growing details but not so good at dealing with floaters (removing or splitting).
  - For other representations, the cleaned-up point clouds work better than the visual hull (from Space Carving) but might not work so well as the raw point clouds of Instant-NGP.
Then, prepare an experiment config like configs/exps/gaussiant/gaussiant_actor1_4_subseq.yaml.
The colmap.yaml provides some heuristics for large-scale static scenes. Remove these if you're not planning on using COLMAP's parameters directly.
# Train a 3DGS model on the ${expname} dataset
evc-train -c configs/exps/gaussiant/gaussiant_${expname}.yaml # might run out of VRAM, try reducing densify until iter
# Perform rendering on the trained ${expname} dataset
evc-test -c configs/exps/gaussiant/gaussiant_${expname}.yaml,configs/specs/superm.yaml,configs/specs/spiral.yaml
# Perform rendering with GUI, do this on a machine with monitor, tested on Windows and Ubuntu
evc-gui -c configs/exps/gaussiant/gaussiant_${expname}.yaml
The superm.yaml skips the loading of input images and other initializations for network-only rendering since all the information we need is contained inside the trained model.
Most of the time when we want to build a new set of algorithms on top of the framework, we only have to worry about the actual network itself. Before writing your new volumetric video algorithm, we need a basic understanding of the network's input and output:
We use Python dictionaries for passing network input and output in and out.
- The batch variable stores the network input you sampled from the dataset (e.g. camera parameters).
- The output key of the batch variable should contain the network output. For each network module's output definition, please refer to the design documents for them (camera, sampler, network, renderer) or just see the definitions in volumetric_video_model.py (the render_rays function).
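The batch-in, output-under-`batch.output` convention can be illustrated with a small stand-in. The sketch below does not use EasyVolcap: `DotDict` is a hypothetical substitute for the library's `dotdict`, `toy_forward` is a made-up function, and the `ray_o` / `ray_d` keys are assumed names chosen only for the example.

```python
class DotDict(dict):
    """Dictionary with attribute access (hypothetical stand-in for EasyVolcap's dotdict)."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as error:
            raise AttributeError(key) from error

    def __setattr__(self, key, value):
        self[key] = value


def toy_forward(batch):
    """Pretend to render: write one constant gray color per ray into batch.output.rgb_map."""
    n_rays = len(batch.ray_o)
    batch.output = DotDict(rgb_map=[[0.5, 0.5, 0.5] for _ in range(n_rays)])
    return batch


# Network input sampled from the dataset goes into the batch
batch = DotDict(ray_o=[[0, 0, 0], [0, 0, 1]], ray_d=[[0, 0, 1], [0, 1, 0]])
batch = toy_forward(batch)
print(batch.output.rgb_map)
```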
We support purely customized network construction & usage and also a unified NeRF-like pipeline.
- If your new network model's structure is similar to NeRF-based ones (i.e. with the separation of sampler, network and renderer), you can simply swap out parts of the volumetric_video_network.py by writing a new config to swap the type parameter of the ***_cfg dictionaries.
- If you'd like to build a completely new network model: to save you some hassle, we grant the sampler classes the ability to directly output the core network output (rgb_map stored in batch.output). Define your rendering function and network structure however you like and reuse other parts of the codebase. An example: gaussiant_sampler.py.
  - TODO: Replace the custom sampler with a custom network, an example: TemporalForestGaussianSplatting
A minimal custom module using all other EasyVolcap components should look something like this:
from easyvolcap.engine import SAMPLERS
from easyvolcap.utils.net_utils import VolumetricVideoModule
from easyvolcap.utils.console_utils import *


@SAMPLERS.register_module()  # make the custom module callable by class name
class CustomVolumetricVideoModule(VolumetricVideoModule):
    def __init__(self,
                 network,  # ignore noop_network
                 **kwargs,  # placeholder for your configurable parameters
                 ):
        # Initialize custom network parameters
        ...

    def forward(self, batch: dotdict):
        # Perform network forwarding
        ...

        # Store output for further processing
        batch.output.rgb_map = ...  # store rendered image for loss (B, N, 3)
In the respective config, select this module with:
model_cfg:
    sampler_cfg:
        type: CustomVolumetricVideoModule
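Selecting a class through a `type` entry in the config is a registry pattern. The sketch below shows the general idea in plain Python. It is an illustration only and not EasyVolcap's registry implementation: `Registry`, `DEMO_SAMPLERS`, and `DemoSampler` are hypothetical names.

```python
class Registry:
    """Map class names to classes so configs can select them by name (illustrative only)."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        def decorator(cls):
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg, **extra):
        cfg = dict(cfg)  # copy so the caller's config is left untouched
        cls = self._modules[cfg.pop("type")]
        return cls(**cfg, **extra)


DEMO_SAMPLERS = Registry("demo_samplers")


@DEMO_SAMPLERS.register_module()
class DemoSampler:
    def __init__(self, n_samples=4):
        self.n_samples = n_samples


# The remaining config keys are forwarded to the constructor as keyword arguments
sampler = DEMO_SAMPLERS.build({"type": "DemoSampler", "n_samples": 8})
print(type(sampler).__name__, sampler.n_samples)
```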
data/dataset/sequence # data_root
├── intri.yml # required: intrinsics
├── extri.yml # required: extrinsics
└── images # required: source images
    ├── 000000 # camera / frame
    │   ├── 000000.jpg # image
    │   ├── 000001.jpg # for dynamic dataset, more images can be placed here
    │   ...
    │   ├── 000298.jpg # for dynamic dataset, more images can be placed here
    │   └── 000299.jpg # for dynamic dataset, more images can be placed here
    ├── 000001
    ├── 000002
    ...
    ├── 000058
    └── 000059
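For a dataset in this layout, a quick consistency check is to count the frames under each camera folder. The following is a minimal standard-library sketch; `count_frames_per_camera` is a hypothetical helper name, and the `.jpg` extension is taken from the tree above.

```python
from pathlib import Path


def count_frames_per_camera(data_root, images_dir="images", ext=".jpg"):
    """Return {camera folder name: number of frames} for the layout shown above."""
    images = Path(data_root) / images_dir
    return {
        camera.name: sum(1 for frame in camera.iterdir() if frame.suffix == ext)
        for camera in sorted(images.iterdir())
        if camera.is_dir()
    }


if __name__ == "__main__":
    counts = count_frames_per_camera("data/enerf_outdoor/actor1_4_subseq")
    if len(set(counts.values())) > 1:
        print("cameras have differing frame counts:", counts)
    else:
        print(f"{len(counts)} cameras, frames per camera: {set(counts.values())}")
```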
EasyVolcap is designed to work on the simplest data form: images and no more. The key data preprocessing is done in the dataloader and dataset modules. These steps are performed during the data loader's initialization:
- We might correct the camera poses using their center of attention and world-up vector (dataloader_cfg.dataset_cfg.use_aligned_cameras=True).
- We undistort images read from disk using the camera intrinsics and store them as JPEG bytes in memory.
EasyVolcap now supports direct import from other locations and code bases.
After installing, you can not only directly use utility modules and functions from easyvolcap.utils, but also import and build upon our core modules and classes.
# Import the logging and debugging functions
from easyvolcap.utils.console_utils import * # log, tqdm, @catch_throw
from easyvolcap.utils.timer_utils import timer # timer.record
from easyvolcap.utils.data_utils import export_pts, export_mesh, export_npz
...
# Import the OpenGL-based viewer and build upon it
from easyvolcap.runners.volumetric_video_viewer import VolumetricVideoViewer
class CustomViewer(VolumetricVideoViewer):
...
The import will work when actually running the code, but your editor's autocompletion might fail since some autocompletion modules are not fully compatible with the newest editable install.
If you see warnings when importing EasyVolcap in an editor like VSCode, you might want to add the path of your EasyVolcap codebase to python.autoComplete.extraPaths and python.analysis.extraPaths like this:
{
"python.autoComplete.extraPaths": ["/home/zju3dv/code/easyvolcap"],
"python.analysis.extraPaths": ["/home/zju3dv/code/easyvolcap"]
}
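If you are unsure which path to put into those settings, Python's import machinery can usually report where the package is loaded from. The sketch below is a minimal standard-library example; `package_source_root` is a hypothetical helper name, and it assumes the directory that contains the `easyvolcap` package folder is the one the editor needs.

```python
import importlib.util
from pathlib import Path


def package_source_root(package="easyvolcap"):
    """Return the directory containing the given package, or None if it cannot be found."""
    try:
        spec = importlib.util.find_spec(package)
    except (ImportError, ValueError):
        return None
    if spec is None or not spec.submodule_search_locations:
        return None
    package_dir = Path(list(spec.submodule_search_locations)[0])
    return str(package_dir.parent)


if __name__ == "__main__":
    # Prints None when EasyVolcap is not installed in the current environment
    print(package_source_root("easyvolcap"))
```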
Another solution is to install EasyVolcap in the compatibility mode of the editable install:
pip install -e . --no-build-isolation --no-deps --config-settings editable_mode=compat
Note that this is marked deprecated in the PEP specification. Thus our recommendation is to change the setting of your editor instead.
If you're interested in developing or researching with EasyVolcap, the recommended way is to fork the repository and modify or append to our source code directly instead of using EasyVolcap as a module.
After cloning and forking, add https://github.com/zju3dv/EasyVolcap as an upstream if you want to receive updates from our side. Use git fetch upstream to pull and merge our updates to EasyVolcap to your new project if needed. The following code block provides an example of this development process.
Our recent project 4K4D is developed in this fashion.
# Prepare the name and GitHub repo of your new project
project=4K4D
repo=https://github.com/zju3dv/${project}
# Clone EasyVolcap and add our repo as an upstream
git clone https://github.com/zju3dv/EasyVolcap ${project}
# Setup the remote of your new project
git remote set-url origin ${repo}
# Add EasyVolcap as an upstream
git remote add upstream https://github.com/zju3dv/EasyVolcap
# If EasyVolcap updates, fetch the updates and maybe merge with it
git fetch upstream
git merge upstream/main
Nevertheless, we still encourage you to read on and possibly follow the tutorials in the Examples section and maybe read our design documents in the Design Docs section to grasp an understanding of how EasyVolcap works as a project.
- Documentation is still WIP. We'll gradually add more guides and examples, especially regarding the usage of EasyVolcap's various systems.
The documents in the docs/design directory explain design choices and various best practices when developing with EasyVolcap.
- docs/design/main.md: Gives an overview of the structure of the EasyVolcap codebase.
- docs/design/config.md: Thoroughly explains the command-line and configuration API of EasyVolcap.
- docs/design/logging.md: Describes the functionalities of the logging system of EasyVolcap.
We would like to acknowledge the following inspiring prior work:
- EasyMocap: Make Human Motion Capture Easier
- XRNeRF: OpenXRLab Neural Radiance Field (NeRF) Toolbox and Benchmark
- Nerfstudio: A Modular Framework for Neural Radiance Field Development
- Dear ImGui: Bloat-Free Graphical User Interface for C++ With Minimal Dependencies
- Neural Body: Implicit Neural Representations with Structured Latent Codes
- ENeRF: Efficient Neural Radiance Fields for Interactive Free-Viewpoint Video
- Instant Neural Graphics Primitives with a Multiresolution Hash Encoding
- 3D Gaussian Splatting for Real-Time Radiance Field Rendering
EasyVolcap's license can be found here.
Note that the license of the algorithms or other components implemented in EasyVolcap might be different from the license of EasyVolcap itself. You will have to install their respective modules to use them in EasyVolcap following the guide in the installation section.
Please refer to their respective licensing terms if you're planning on using them. For example, EasyVolcap's own implementation of the hash embedding module is under the same license as EasyVolcap. However, the imported tiny-cuda-nn package in the importing implementation is under the same license as the original Instant-NGP paper.
All imported modules (as specified in the requirement files or in the source code itself) are under their respective licenses. Please refer to their respective repositories for more information.
If you find this code useful for your research, please cite us using the following BibTeX entry. If you used our implementation of other methods, please also cite them separately.
@inproceedings{xu2023easyvolcap,
title={EasyVolcap: Accelerating Neural Volumetric Video Research},
author={Xu, Zhen and Xie, Tao and Peng, Sida and Lin, Haotong and Shuai, Qing and Yu, Zhiyuan and He, Guangzhao and Sun, Jiaming and Bao, Hujun and Zhou, Xiaowei},
booktitle={SIGGRAPH Asia 2023 Technical Communications},
year={2023}
}
@article{xu20234k4d,
title={4K4D: Real-Time 4D View Synthesis at 4K Resolution},
author={Xu, Zhen and Peng, Sida and Lin, Haotong and He, Guangzhao and Sun, Jiaming and Shen, Yujun and Bao, Hujun and Zhou, Xiaowei},
journal={arXiv preprint arXiv:2310.11448},
year={2023}
}