This repository is the official implementation of the NeurIPS 2024 paper "Point-PRC: A Prompt Learning Based Regulation Framework for Generalizable Point Cloud Analysis".
Motivation of our research: to improve performance on downstream 3D tasks while maintaining the good generalization of large 3D models. The experiments are conducted on ShapeNetCoreV2. ULIP-2 reaches 71.22% zero-shot recognition accuracy on this dataset. Recent works built on ULIP-2 introduce lightweight prompt tuning (PT) to further boost target tasks (75.80% accuracy). However, we observe that the improvement comes at the expense of a severe drop in 3D domain generalization (e.g., 57.07% accuracy on new classes, far behind the 71.22% zero-shot result), and we develop a systematic regulation constraint (RC) framework to address this challenge.
This paper investigates the 3D domain generalization (3DDG) ability of large 3D models based on prevalent prompt learning. Recent works demonstrate that the performance of 3D point cloud recognition can be boosted remarkably by parameter-efficient prompt tuning. However, we observe that the improvement on downstream tasks comes at the expense of a severe drop in 3D domain generalization. To resolve this challenge, we present a comprehensive regulation framework that allows the learnable prompts to actively interact with the well-learned general knowledge in large 3D models to maintain good generalization. Specifically, the proposed framework imposes multiple explicit constraints on the prompt learning trajectory by maximizing the mutual agreement between task-specific predictions and task-agnostic knowledge. We design the regulation framework as a plug-and-play module to embed into existing representative large 3D models. Surprisingly, our method not only realizes consistently improving generalization ability but also enhances task-specific 3D recognition performance across various 3DDG benchmarks by a clear margin. Considering the lack of study and evaluation on 3DDG, we also create three new benchmarks, namely base-to-new, cross-dataset and few-shot generalization benchmarks, to enrich the field and inspire future research.
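Schematically, and only as an illustration (not the paper's exact formulation), the regulated objective can be read as a task loss plus an agreement term that keeps the prompted model's predictions close to those of the frozen pre-trained model, e.g. with an L1 distance (cf. the `l1_dist` argument used in the scripts below):

$$\mathcal{L} = \mathcal{L}_{\text{task}} + \lambda \, \lVert f_{\text{prompted}}(x) - f_{\text{frozen}}(x) \rVert_1$$

where $f_{\text{prompted}}$ and $f_{\text{frozen}}$ denote the predictions of the prompt-tuned and the frozen model, and $\lambda$ is a weighting hyper-parameter; these symbols are ours and used only for illustration. Please refer to the paper for the actual constraints.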
- dassl
- Ubuntu 23.10
- Python 3.8.16
- PyTorch 1.12.0
- CUDA 11.6
- torchvision 0.13.0
- timm 0.9.16
- pueue & pueued 2.0.4
# NOTE: Option 1 is recommended. A complete package list is provided in `env.yaml`
# option 1: create conda virtual env by your own
conda create -n pointprc python=3.8.16
conda activate pointprc
# install torch
pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
# install dassl
git clone https://github.com/auniquesun/dassl
cd dassl/
python setup.py develop # (no need to re-build if the source code is modified)
# option 2: create conda virtual env according to the provided env.yaml
conda env create -f env.yaml
conda activate pointprc
`pueue` is a shell command management tool; we use it to schedule the model training & evaluation tasks. Please refer to the official page for installation and basic usage. We recommend this tool because it lets you run experiments at scale and thus saves your time.
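For reference, a typical pueue workflow looks like the following (the queued script and its arguments are placeholders here; the concrete commands appear in the benchmark sections below):
# start the pueue daemon (once per machine/session)
pueued -d
# allow, e.g., 2 experiments to run in parallel
pueue parallel 2
# queue an experiment; it runs in the background and its log is kept
pueue add -- ./scripts/pointprc/base2new_train.sh <arguments>
# check progress and read logs
pueue status
pueue log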
NOTE: We provide a complete package list of our virtual environment in `env.yaml`. Feel free to check whether you need a specific package. If that is the case, run the following command to install it, e.g.
pip install h5py==3.10.0 plyfile==1.0.3
- In the experiments, we use the following models as the baselines. The pre-trained weights of these models can be found in their public GitHub repositories.
- NOTE: ULIP-2 uses the same text encoder as ULIP
- Make a folder called `weights` under this project and save the pre-trained weights into this folder.
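A minimal sketch (the checkpoint file names below are placeholders; use the actual file names from the corresponding repositories):
mkdir -p weights
# e.g., weights/<ulip_checkpoint>.pt, weights/<ulip2_checkpoint>.pt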
- We conduct experiments on three new 3D domain generalization (3DDG) benchmarks proposed by us, as introduced in the next section.
- base-to-new class generalization (base2new)
- cross-dataset generalization (xset)
- few-shot generalization (fewshot)
- The structure of these benchmarks should be organized as follows.
/path/to/Point-PRC
|----data # placed at the same level as `trainers`, `weights`, etc.
|----base2new
|----modelnet40
|----scanobjectnn
|----shapenetcorev2
|----xset
|----corruption
|----dg
|----sim2real
|----pointda
|----fewshot
|----modelnet40
|----scanobjectnn
|----shapenetcorev2
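Assuming you work from the project root, this skeleton can be created with bash brace expansion:
mkdir -p data/base2new/{modelnet40,scanobjectnn,shapenetcorev2}
mkdir -p data/xset/{corruption,dg,sim2real,pointda}
mkdir -p data/fewshot/{modelnet40,scanobjectnn,shapenetcorev2}
The downloaded datasets are then extracted into the corresponding sub-folders.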
- You can find the usage instructions and download links of these new 3DDG benchmarks in the following section.
- The datasets used in this benchmark can be downloaded from the following links.
- The following table shows the statistics of this benchmark.
- The datasets used in this benchmark can be downloaded from the following links.
- OOD Generalization
  - OmniObject3d (Omni3D)
- Data Corruption
  - ModelNet-C (7 types of corruptions)
    - add global outliers, add local outliers, dropout global structure, dropout local region, rotation, scaling, jittering
- Sim-to-Real
- PointDA
- The following table shows the statistics of this benchmark.
- Although this benchmark contains the same datasets as Base-to-new Class Generalization, it investigates model generalization under an extremely low-data regime (1, 2, 4, 8, and 16 shots per class), which is quite different from the evaluation setting in Base-to-new Class Generalization.
- The following table shows the statistics of this benchmark.
We describe the experiment settings and implementation details in the paper; refer to Section 3.3, Section 4.1, and Appendix A.1 and A.2.
- This part corresponds to the experiments in Section 4.2 (Table 1) and Appendix (Table 9).
- To evaluate the performance on this benchmark, you will use `scripts/pointprc/base2new_train.sh` and `scripts/pointprc/base2new_test.sh`. The former trains a model on base classes while the latter evaluates the trained model on new classes. Both scripts have 18 input arguments, as commented in the script files.
- Taking S-PB_T50_RS (the hardest split of `ScanObjectNN`) as an example, we train the model on base classes and then evaluate the performance on new classes.
# prompt learning on base classes
./scripts/pointprc/base2new_train.sh 0 data/base2new/scanobjectnn scanobjectnn custom_ulip manual64 9 2 2 16 20 hardest full task_perform False False False ulip2 l1_dist
# test on novel classes
./scripts/pointprc/base2new_test.sh 0 data/base2new/scanobjectnn scanobjectnn custom_ulip manual64 9 2 2 16 20 hardest full task_perform False False False ulip2 l1_dist
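Because pueue runs queued tasks sequentially by default (one parallel task), the two commands above can also be queued back to back so that evaluation starts automatically once training has finished, e.g.:
pueue add -- ./scripts/pointprc/base2new_train.sh <same arguments as above>
pueue add -- ./scripts/pointprc/base2new_test.sh <same arguments as above>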
- This part corresponds to the experiments in Section 4.4 (Figure 4).
- To evaluate the performance on this benchmark, you will use `scripts/pointprc/fewshot_train_eval.sh`. This script also has 18 command-line arguments, as explained in the file.
- Taking ShapeNetCoreV2 as an example, we train the model using 1, 2, 4, 8 and 16 shots per class and then evaluate the performance on the whole test set of the dataset.
# train using (1/2/4/8/16)-shot and evaluate on the whole test set
./scripts/pointprc/fewshot_train_eval.sh 0 data/fewshot/shapenetcorev2 shapenetcorev2 custom_ulip 9 2 2 1 50 obj_only full task_perform False False False manual64 ulip2 l1_dist
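To sweep all shot settings in one go, the call can be wrapped in a loop. The sketch below assumes the 8th argument (the `1` in the command above) controls the number of shots; please verify this against the argument comments in `scripts/pointprc/fewshot_train_eval.sh` before running it.
for SHOTS in 1 2 4 8 16; do
    ./scripts/pointprc/fewshot_train_eval.sh 0 data/fewshot/shapenetcorev2 shapenetcorev2 custom_ulip 9 2 2 ${SHOTS} 50 obj_only full task_perform False False False manual64 ulip2 l1_dist
done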
- This part corresponds to the experiments in Section 4.3 (Tables 2 and 3) and Appendix (Tables 10 and 11).
- This benchmark has four types of evaluation settings: OOD Generalization, Data Corruption, Sim-to-Real, and PointDA; refer to Table 5 in the Appendix of the paper.
- To evaluate the performance on OOD Generalization of this benchmark, you will use `scripts/pointprc/xset_test_dg.sh`. This script accepts 17 command-line arguments, as explained in the file.
- Note that in this setting, ShapeNetCoreV2 serves as the source domain and the other datasets (ModelNet40, S-PB_T50_RS, S-OBJ_BG, S-OBJ_ONLY, Omni3D) act as the target domains.
- Since the model has been trained on ShapeNetCoreV2 in the Few-shot Generalization benchmark, when evaluating the performance on a target domain (e.g., Omni3D), we directly load the weights of the 16-shot ShapeNetCoreV2 model.
# evaluate on a target domain (Omni3D here) with the model trained on ShapeNetCoreV2
./scripts/pointprc/xset_test_dg.sh 0 data/xset/omniobject3d omniobject3d shapenetcorev2 12 4 4 50 obj_only full task_perform False False False manual64 ulip2 l1_dist
- To evaluate the performances on the other three settings, you can check the following scripts for details.
  - `scripts/pointprc/xset_corrupt.sh` for Data Corruption
  - `scripts/pointprc/xset_train_sim2real.sh` and `scripts/pointprc/xset_test_sim2real.sh` for Sim-to-Real
  - `scripts/pointprc/xset_train_pointda.sh` and `scripts/pointprc/xset_test_pointda.sh` for PointDA
@inproceedings{sun24point-prc,
title={Point-PRC: A Prompt Learning Based Regulation Framework for Generalizable Point Cloud Analysis},
author={Sun, Hongyu and Ke, Qiuhong and Wang, Yongcai and Chen, Wang and Yang, Kang and Li, Deying and Cai, Jianfei},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems (NeurIPS)},
year={2024},
url={https://openreview.net/forum?id=g7lYP11Erv}
}
Our implementation is partially inspired by the following projects; thanks for their great work.
If you have any questions about our work, please search the related issues or create a new one in this repository.