image2sphere
image2sphere copied to clipboard
Method to predict distributions over 3D object rotation from 2D images
Image to Sphere: Learning Equivariant Features for Efficient Pose Prediction
This repository implements a hybrid equivariant model for SO(3) reasoning from 2D images for object pose estimation.
The underlying SO(3) symmetry of the pose estimation task is not accessible in an image, which can only be transformed
by in-plane rotations. Our model, I2S, projects features from the image plane onto the sphere, which is SO(3) transformable. Thus,
the model is able to leverage SO(3)-equivariant group convolutions which improve sample efficiency. Conveniently,
the output of the group convolution are coefficients over the Fourier basis of SO(3), which form a concise yet expressive
representation for distributions over SO(3). Our model can capture complex pose distributions that arise from occlusions,
ambiguity or object symmetries.
Table of Contents
-
Colab Demos
- Visualize Predictions
- Model Walkthrough
- Intro to Spherical Convolution
- Installation
- Dataset Preparation
-
Train I2S
- ModelNet10-SO(3)
- SYMSOL
- PASCAL3D+
- Pretrained Models
- Citation
- Acknowledgements
Colab Demos
Visualize Predictions
This Colab notebook loads pretrained I2S models on PASCAL3D+ and ModelNet10-SO(3) and visualizes output distributions generated for images from the test set. You can also upload your own images and see what the model predicts.
Model Walkthrough
This Colab notebook goes step-by-step through the construction of I2S, and illustrates how you can modify different components for a custom application.
Intro to Spherical Convolution
This Colab notebook helps you understand spherical harmonics and spherical convolution with some visualizations.
Installation
This code was tested with python 3.8. You can install all necessary requirements with pip:
pip install -r requirements.txt
You may get lots of warnings from e3nn about deprecated functions. If so, run commands as python -W ignore -m src.train ...
Dataset preparation
Follow instruction in datasets/README.md
. Make sure to run commands from
within the datasets
folder.
Train I2S
ModelNet10-SO(3)
python -m src.train --dataset_name=modelnet10 --encoder=resnet50_pretrained --seed=0
Rotation error (in radians) on the test set will be stored in results/pascal3d-warp-synth_resnet101-pretrained_seed0/eval.npy
To train on the limited training set (20 views per instance), run:
python -m src.train --dataset_name=modelnet10-limited --encoder=resnet50_pretrained --seed=0
SYMSOL
Here is an example for training on SYMSOL I with 50k views per instance
python -m src.train --dataset_name=symsolI-50000 --encoder=resnet50_pretrained --seed=0
Average log likelihood on the test set will be stored in results/symsolI-50000_resnet50-pretrained_seed0/eval_log_likelihood.npy
You can adjust the number of views (--dataset_name=symsolI-10000
will use 10k views per instance) or
train on SYMSOL II objects (--dataset_name=symsolII-50000
will train on sphX; --dataset_name=symsolIII-50000
will train on cylO; --dataset_name=symsolIIII-50000
will train on tetX). We train a single model on all of SYMSOL I, but separate models for each object from SYMSOL II.
Train on PASCAL3D+
python -m src.train --dataset_name=pascal3d-warp-synth --encoder=resnet101_pretrained --seed=0
Rotation error (in radians) on the test set will be stored in results/pascal3d-warp-synth_resnet101-pretrained_seed0/eval.npy
Pretrained Models
Pascal3D+
Download the checkpoint here. It achieves median rotation error of 9.6 degrees averaged over all twelve classes. It was trained using the procedure outlined in the paper.
Use the following code to load the weights:
from src.predictor import I2S
model = I2S(num_classes=12, encoder='resnet101')
checkpoint = torch.load('pascal3d_checkpoint.pt')['model_state_dict']
model.load_state_dict(checkpoint)
model.eval()
ModelNet10-SO(3)
Download the checkpoint here. It achieves an Acc@15 of 0.721, averaged over all 10 classes. It was trained using the procedure outlined in the paper (100 training views).
Use the following code to load the weights:
from src.predictor import I2S
model = I2S(num_classes=10, encoder='resnet50')
checkpoint = torch.load('modelnet10so3_checkpoint.pt')['model_state_dict']
model.load_state_dict(checkpoint)
model.eval()
Citation
To cite this work, please use the following bibtex:
@inproceedings{
klee2023image2sphere,
title={Image to Sphere: Learning Equivariant Features for Efficient Pose Prediction},
author={David M. Klee and Ondrej Biza and Robert Platt and Robin Walters},
booktitle={International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=_2bDpAtr7PI}
}
Acknowledgements
The code for loading and warping PASCAL3D+ images is taken from this repo. The code for generating healpy grids and visualizing distributions over SO(3) is taken from this repo.