HybridDepth
Official implementation of the HybridDepth model
HybridDepth: Robust Depth Fusion for Mobile AR by Leveraging Depth from Focus and Single-Image Priors
Ashkan Ganj1 · Hang Su2 · Tian Guo1
1Worcester Polytechnic Institute 2Nvidia Research
This work presents HybridDepth, a practical depth estimation solution based on focal stack images captured with a camera. HybridDepth outperforms state-of-the-art models across several well-known datasets, including NYU V2, DDFF12, and ARKitScenes.
News
- 2024-07-25: We released the pre-trained models.
- 2024-07-23: Model and GitHub repository are online.
TODOs
- [ ] (WIP) Release a new model and pre-trained weights with improved performance.
- [ ] Add Hugging Face model.
- [ ] Release Android Mobile Client for HybridDepth.
Pre-trained Models
We provide four models trained on three datasets. You can download them from the links below:
| Model | Checkpoint |
| --- | --- |
| Hybrid-Depth-NYU-5 | Download |
| Hybrid-Depth-NYU-10 | Download |
| Hybrid-Depth-DDFF12-5 | Download |
| Hybrid-Depth-ARKitScenes-5 | Download |
Usage
Preparation
- Clone the repository and install the dependencies:
git clone https://github.com/cake-lab/HybridDepth.git
cd HybridDepth
conda env create -f environment.yml
conda activate hybriddepth
- Download Necessary Files:
  - Download the necessary file here and place it in the `checkpoints` directory.
  - Download the checkpoints listed here and put them under the `checkpoints` directory.
- Install the synthesizing CUDA package:
python utils/synthetic/gauss_psf/setup.py install
This will install the Python package for synthesizing images.
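Building this extension requires a CUDA-enabled PyTorch installation. As an optional sanity check (not part of the repository), you can confirm that PyTorch sees your GPU before building:

import torch

# The CUDA extension can only be built and used against a CUDA-enabled PyTorch build.
print(torch.__version__, torch.version.cuda)  # PyTorch version and CUDA toolkit version (None on CPU-only builds)
print(torch.cuda.is_available())              # True if a usable GPU is visible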
Dataset Preparation
- NYU: Download the dataset as per the instructions given here.
- DDFF12: Download the dataset as per the instructions given here.
- ARKitScenes: Download the dataset as per the instructions given here.
Using the HybridDepth model for prediction
For inference, you can run the provided notebook `test.ipynb` or use the following code:
# Load the model checkpoint (DepthNetModule is defined in the repository's model code; import it first)
model_path = 'checkpoints/NYUBestScaleInv10Full.ckpt'
model = DepthNetModule.load_from_checkpoint(model_path)
model.eval()
model = model.to('cuda')
After loading the model, you can use the following code to process the input images and get the depth map:
import torch

from utils.io import prepare_input_image

data_dir = 'focal stack images directory'  # path to the directory containing the focal stack images

# Load the focal stack images
focal_stack, rgb_img, focus_dist = prepare_input_image(data_dir)

# Inference
with torch.no_grad():
    out = model(rgb_img, focal_stack, focus_dist)

metric_depth = out[0].squeeze().cpu().numpy()  # The metric depth
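To quickly inspect or store the result, you can use standard tooling; the snippet below is a minimal, optional sketch using NumPy and Matplotlib and is not part of the repository:

import numpy as np
import matplotlib.pyplot as plt

# Persist the raw prediction for later comparison
np.save('metric_depth.npy', metric_depth)

# Visualize the predicted depth map
plt.imshow(metric_depth, cmap='inferno')
plt.colorbar(label='depth')
plt.axis('off')
plt.savefig('metric_depth.png', bbox_inches='tight')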
Evaluation
First, set up the configuration file `config.yaml` in the `configs` directory. We already provide configuration files for the three datasets in the `configs` directory. In the configuration file, you can specify the dataloader class path, the model checkpoint path, and other hyperparameters. Here is an example of the configuration file:
data:
  class_path: dataloader.dataset.NYUDataModule  # Path to your dataloader module in dataset.py
  init_args:
    nyuv2_data_root: 'root to the NYUv2 dataset or other datasets'  # path to the specific dataset
    img_size: [480, 640]  # Adjust if your DataModule expects a tuple for img_size
    remove_white_border: True
    num_workers: 0  # if you are using synthetic data, you don't need multiple workers
    use_labels: True

model:
  invert_depth: True  # If the model outputs inverted depth

ckpt_path: checkpoints/checkpoint.ckpt
Then specify the configuration file in the `evaluate.sh` script:
python cli_run.py test --config configs/config_file_name.yaml
Finally, run the following command:
cd scripts
sh evaluate.sh
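The exact metrics reported are defined in the repository's evaluation code. For reference, the standard monocular depth metrics used on NYU V2, DDFF12, and ARKitScenes (AbsRel, RMSE, and the δ accuracy threshold) can be computed as in the following illustrative sketch, which is not taken from the repository:

import numpy as np

def depth_metrics(pred, gt, eps=1e-6):
    """Standard monocular depth metrics, computed over valid (gt > 0) pixels."""
    mask = gt > eps
    pred, gt = pred[mask], gt[mask]

    abs_rel = np.mean(np.abs(pred - gt) / gt)      # absolute relative error
    rmse = np.sqrt(np.mean((pred - gt) ** 2))      # root mean squared error
    ratio = np.maximum(pred / gt, gt / pred)
    delta1 = np.mean(ratio < 1.25)                 # fraction of pixels within a 1.25x ratio
    return {'AbsRel': abs_rel, 'RMSE': rmse, 'delta1': delta1}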
Training
First, set up the configuration file `config.yaml` in the `configs` directory. You only need to specify the dataset path and the batch size; the rest of the hyperparameters are already set.
For example, you can use the following configuration file for training on the NYUv2 dataset:
...
model:
  invert_depth: True
  # learning rate
  lr: 3e-4  # you can adjust this value
  # weight decay
  wd: 0.001  # you can adjust this value

data:
  class_path: dataloader.dataset.NYUDataModule  # Path to your dataloader module in dataset.py
  init_args:
    nyuv2_data_root: 'root to the NYUv2 dataset or other datasets'  # path to the specific dataset
    img_size: [480, 640]  # Adjust if your NYUDataModule expects a tuple for img_size
    remove_white_border: True
    batch_size: 24  # Adjust the batch size
    num_workers: 0  # if you are using synthetic data, you don't need multiple workers
    use_labels: True

ckpt_path: null
Then specify the configuration file in the `train.sh` script:
python cli_run.py train --config configs/config_file_name.yaml
Finally, run the following command:
cd scripts
sh train.sh
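Before launching a full run, you can optionally sanity-check the data pipeline. The sketch below instantiates `NYUDataModule` with the same arguments as the `init_args` shown above and pulls a single batch; it assumes the module follows the standard PyTorch Lightning `LightningDataModule` interface (`setup()` and `train_dataloader()`), which the `class_path`/`init_args` configuration suggests but which you should verify against the repository code.

# Optional sanity check of the training data pipeline (assumes the standard
# LightningDataModule interface; argument names are taken from the config above).
from dataloader.dataset import NYUDataModule

dm = NYUDataModule(
    nyuv2_data_root='path/to/nyuv2',  # replace with your dataset root
    img_size=(480, 640),
    remove_white_border=True,
    batch_size=2,
    num_workers=0,
    use_labels=True,
)
dm.setup('fit')
batch = next(iter(dm.train_dataloader()))
print(type(batch))  # inspect the batch structure before starting a long training run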
Citation
If our work assists you in your research, please cite it as follows:
@misc{ganj2024hybriddepthrobustdepthfusion,
title={HybridDepth: Robust Depth Fusion for Mobile AR by Leveraging Depth from Focus and Single-Image Priors},
author={Ashkan Ganj and Hang Su and Tian Guo},
year={2024},
eprint={2407.18443},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2407.18443},
}
@inproceedings{10.1145/3638550.3641122,
author = {Ganj, Ashkan and Zhao, Yiqin and Su, Hang and Guo, Tian},
title = {Mobile AR Depth Estimation: Challenges \& Prospects},
year = {2024},
isbn = {9798400704970},
publisher = {Association for Computing Machinery},
address = {New York, NY, USA},
url = {https://doi.org/10.1145/3638550.3641122},
doi = {10.1145/3638550.3641122},
abstract = {Accurate metric depth can help achieve more realistic user interactions such as object placement and occlusion detection in mobile augmented reality (AR). However, it can be challenging to obtain metricly accurate depth estimation in practice. We tested four different state-of-the-art (SOTA) monocular depth estimation models on a newly introduced dataset (ARKitScenes) and observed obvious performance gaps on this real-world mobile dataset. We categorize the challenges to hardware, data, and model-related challenges and propose promising future directions, including (i) using more hardware-related information from the mobile device's camera and other available sensors, (ii) capturing high-quality data to reflect real-world AR scenarios, and (iii) designing a model architecture to utilize the new information.},
booktitle = {Proceedings of the 25th International Workshop on Mobile Computing Systems and Applications},
pages = {21–26},
numpages = {6},
location = {San Diego, CA, USA},
series = {HOTMOBILE '24}
}