PointConT
Official implementation of the paper "Point Cloud Classification Using Content-based Transformer via Clustering in Feature Space"
This repository is an official implementation of the following paper:
Point Cloud Classification Using Content-based Transformer via Clustering in Feature Space
Yahui Liu, Bin Tian, Yisheng Lv, Lingxi Li, Feiyue Wang
Accepted for publication in the IEEE/CAA Journal of Automatica Sinica
Get Started
Installation
```bash
# clone this repo
git clone https://github.com/yahuiliu99/PointConT.git
cd PointConT

# create a conda env
conda create -n pointcont -y python=3.7 numpy=1.20 numba
conda activate pointcont

# install PyTorch and libs (refer to requirements.txt)
# please install compatible PyTorch and CUDA versions
conda install -y pytorch=1.10.1 torchvision cudatoolkit=11.1 -c pytorch -c nvidia
pip install hydra-core==1.1 h5py scikit-learn einops tqdm warmup-scheduler deepspeed tensorboard

# install the pointnet++ library cuda extensions
pip install pointnet2_ops_lib/.
```
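After installation, you can run a small check to confirm that the packages listed above are importable. This is a sketch and is not part of the repository. The import names are our assumptions about what each package installs: hydra for hydra-core, sklearn for scikit-learn, warmup_scheduler for warmup-scheduler, and pointnet2_ops for the extension built from pointnet2_ops_lib/.

```python
import importlib.util

# Assumed import names for the packages installed above (not taken from the repo).
REQUIRED_MODULES = [
    "numpy", "numba", "torch", "torchvision", "hydra", "h5py", "sklearn",
    "einops", "tqdm", "warmup_scheduler", "deepspeed", "tensorboard",
    "pointnet2_ops",  # assumed name of the CUDA extension built from pointnet2_ops_lib/
]


def missing_modules(names):
    """Return the subset of top-level module `names` that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def main():
    """Print any missing modules, plus the PyTorch/CUDA status if torch is available."""
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    if "torch" not in missing:
        import torch  # imported lazily so the check itself needs only the stdlib
        print("PyTorch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
    return missing


if __name__ == "__main__":
    main()
```

find_spec only locates a module and does not import it, so the check stays cheap. A failed build of the CUDA extension should show up as pointnet2_ops in the missing list.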
Data Preparation
When you first run the command for training, the datasets will be automatically downloaded and saved in data/:
- ModelNet40 --> data/modelnet40_ply_hdf5_2048/
- ScanObjectNN --> data/h5_files/
Alternatively, you can manually download the official data (ModelNet40 | ScanObjectNN) to any location and create a symbolic link to it inside the data/ folder:
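```bash
mkdir data
# creates data/<folder name>; the linked folder's name should match the paths listed above
ln -s /path/to/your/data/folder data/
```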
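Training expects each dataset at the paths listed above, whether it was downloaded automatically or linked in. The hypothetical helper below is not part of the repository and assumes the default data/ root. It reports which of the expected folders exist and checks directories only, not file contents.

```python
from pathlib import Path

# Expected dataset folders, relative to the repository root (paths from this README).
EXPECTED_DATASETS = {
    "ModelNet40": "modelnet40_ply_hdf5_2048",
    "ScanObjectNN": "h5_files",
}


def check_datasets(root="data"):
    """Map each dataset name to True if its folder exists under `root`.

    Path.is_dir() follows symbolic links, so a folder linked in with `ln -s` counts,
    while a broken link reports False.
    """
    root = Path(root)
    return {name: (root / folder).is_dir() for name, folder in EXPECTED_DATASETS.items()}


if __name__ == "__main__":
    for name, found in check_datasets().items():
        print(f"{name}: {'found' if found else 'missing'}")
```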
Training
Step 1: Check config file
You can modify settings in config/cls.yaml.
Make sure eval is set to False.
We support wandb for logging results online. To enable it, set wandb.use_wandb=True. See the official wandb documentation for more details.
Step 2: Train PointConT
- Classification on ModelNet40

  ```bash
  python main_cls.py db=modelnet40
  ```

- Classification on ScanObjectNN

  ```bash
  python main_cls.py db=scanobjectnn
  ```
The config file config/cls.yaml is loaded automatically when you run either command.
Evaluation
To evaluate a trained model, set eval=True in config/cls.yaml and run the following, where ${dataset} is modelnet40 or scanobjectnn:
```bash
python main_cls.py db=${dataset}
```
Alternatively, leave the file unchanged and override the value from the command line:
```bash
python main_cls.py db=${dataset} eval=True
```
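Command-line arguments such as eval=True or wandb.use_wandb=True are Hydra overrides. Each key=value pair replaces the matching entry in the config loaded from config/cls.yaml, and dots address nested keys. The toy function below is a simplified sketch of that dotted-key merge on a hypothetical config dictionary; it is not Hydra itself. Real Hydra has its own value grammar and rejects unknown keys unless they are prefixed with +. It may also treat db=... as a config-group selection rather than a plain value.

```python
import copy


def parse_value(text):
    """Convert an override string to bool, int, or float, else keep it as a string."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(config, overrides):
    """Return a copy of `config` with each "a.b.c=value" override applied.

    As with Hydra's default behavior, overriding a key that does not exist is an error.
    """
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, sep, raw_value = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = dotted_key.split(".")
        node = result
        for key in parents:
            node = node[key]  # KeyError if a parent section is missing
        if leaf not in node:
            raise KeyError(f"unknown config key: {dotted_key}")
        node[leaf] = parse_value(raw_value)
    return result


# Hypothetical subset of config/cls.yaml; the real file has more (and possibly different) keys.
base_config = {"eval": False, "wandb": {"use_wandb": False}}
print(apply_overrides(base_config, ["eval=True", "wandb.use_wandb=True"]))
# {'eval': True, 'wandb': {'use_wandb': True}}
```

The original dictionary is left untouched. This mirrors how an override changes only the loaded config for that run and never edits config/cls.yaml on disk.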
Visualization
Dependency
Please refer to the following GitHub repository for point cloud rendering code: PointFlowRenderer

Results (pretrained model)
| Dataset | mAcc (%) | OA (%) | Checkpoint / Config | Log |
|---|---|---|---|---|
| ModelNet40 | 90.5 | 93.5 | ckpt | log |
| ScanObjectNN | 86.0 | 88.0 | ckpt | log |
| ScanObjectNN * | 88.5 | 90.3 | config | log |
mAcc: mean class accuracy; OA: overall accuracy. Results marked with * were evaluated with a voting strategy.
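The two metrics in the table weight classes differently. OA is the fraction of all test shapes classified correctly. mAcc averages the per-class accuracies so that every class counts equally, which means mAcc falls below OA when the model is less accurate on the smaller classes, as in every row above.

The sketch below computes both metrics. It also shows one common form of test-time voting: summing class scores over several randomly augmented copies of each point cloud. The README does not state the exact voting procedure behind the * row, so random_scale, its scale range, num_votes, and the model_fn callable are hypothetical placeholders.

```python
import random
from collections import defaultdict


def overall_accuracy(y_true, y_pred):
    """OA: fraction of all samples whose predicted label equals the true label."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def mean_class_accuracy(y_true, y_pred):
    """mAcc: accuracy computed per class, then averaged with equal weight per class."""
    total, correct = defaultdict(int), defaultdict(int)
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        correct[t] += int(t == p)
    return sum(correct[c] / total[c] for c in total) / len(total)


def random_scale(points, rng, low=0.8, high=1.25):
    """Hypothetical augmentation: scale x, y, z by independent random factors."""
    sx, sy, sz = (rng.uniform(low, high) for _ in range(3))
    return [(x * sx, y * sy, z * sz) for x, y, z in points]


def vote_predict(model_fn, points, num_votes=10, augment=random_scale, seed=0):
    """Sum class scores over `num_votes` augmented copies and return the argmax class.

    Summing yields the same argmax as averaging. `model_fn` maps a list of (x, y, z)
    points to a list of per-class scores (hypothetical stand-in for the trained network).
    """
    rng = random.Random(seed)
    totals = None
    for _ in range(num_votes):
        scores = model_fn(augment(points, rng))
        totals = list(scores) if totals is None else [a + b for a, b in zip(totals, scores)]
    return max(range(len(totals)), key=totals.__getitem__)


# Imbalanced toy example: three samples of class 0, one of class 1, all predicted as 0.
y_true = [0, 0, 0, 1]
y_pred = [0, 0, 0, 0]
print(overall_accuracy(y_true, y_pred))     # 0.75
print(mean_class_accuracy(y_true, y_pred))  # 0.5
```

scikit-learn, already a dependency, computes the same per-class average as sklearn.metrics.balanced_accuracy_score.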
Citation
If you find our work useful in your research, please consider citing:
```bibtex
@article{Liu2023PointConT,
  author  = {Liu, Yahui and Tian, Bin and Lv, Yisheng and Li, Lingxi and Wang, Fei-Yue},
  title   = {Point Cloud Classification Using Content-based Transformer via Clustering in Feature Space},
  journal = {IEEE/CAA Journal of Automatica Sinica},
  year    = {2023},
  volume  = {10},
  number  = {8},
  pages   = {1-9},
  doi     = {10.1109/JAS.2023.123432}
}
```
Acknowledgement
Our code is mainly based on the following open-source projects. Many thanks to their authors for their excellent work:
PointNet2, Point-Transformers, DGCNN, CurveNet, PointMLP, PAConv, PointNeXt