GradNCP
Learning Large-scale Neural Fields via Context Pruned Meta-Learning (NeurIPS 2023)
Official PyTorch implementation of "Learning Large-scale Neural Fields via Context Pruned Meta-Learning" (NeurIPS 2023) by Jihoon Tack, Subin Kim, Sihyun Yu, Jaeho Lee, Jinwoo Shin, Jonathan Richard Schwarz.
TL;DR: We propose an efficient meta-learning framework for scalable neural field learning that performs online data pruning of the context set.
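To give a rough picture of what online context pruning means, the sketch below keeps only the top fraction of context points, ranked by a per-sample score. The scoring rule is an assumption made for illustration. It uses the gradient norm of a squared-error loss with respect to the weight matrix of a final linear layer (bias ignored). For a point with prediction error e_i and penultimate features h_i, this norm equals ||e_i|| * ||h_i||. The names prune_context and keep_ratio are hypothetical, and this is not the repository's implementation.

```python
import math
from typing import List, Sequence


def l2_norm(v: Sequence[float]) -> float:
    """Euclidean norm of a vector given as a sequence of floats."""
    return math.sqrt(sum(x * x for x in v))


def prune_context(
    features: Sequence[Sequence[float]],
    errors: Sequence[Sequence[float]],
    keep_ratio: float,
) -> List[int]:
    """Return sorted indices of the context points to keep (hypothetical helper).

    features[i]: penultimate-layer activations h_i of context point i.
    errors[i]:   prediction error e_i = f(x_i) - y_i of context point i.
    For the loss 0.5 * ||W h_i - y_i||^2, the gradient w.r.t. W is the outer
    product e_i h_i^T, whose Frobenius norm is ||e_i|| * ||h_i||.
    """
    if not 0.0 < keep_ratio <= 1.0:
        raise ValueError("keep_ratio must be in (0, 1]")
    if len(features) != len(errors):
        raise ValueError("features and errors must have the same length")
    # Per-sample gradient norm, computed without forming the outer product.
    scores = [l2_norm(e) * l2_norm(h) for h, e in zip(features, errors)]
    k = max(1, int(round(keep_ratio * len(scores))))
    # Rank points by score (largest first) and keep the top k.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])


if __name__ == "__main__":
    feats = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0], [0.5, 0.5]]
    errs = [[0.1], [0.3], [0.0], [1.0]]
    print(prune_context(feats, errs, keep_ratio=0.5))  # prints [1, 3]
```

Writing the score as a product of two norms means no per-sample gradient ever has to be materialized, which is what keeps this particular score cheap to compute for every context point.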
1. Dependencies
conda create -n gradncp python=3.8 -y
conda activate gradncp
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
pip install einops pyyaml tensorboardX tensorboard natsort pyspng av pytorch_msssim lpips
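After installing, you can quickly confirm that the packages above are visible in the active environment by querying their installed versions. The sketch below uses only the standard library (importlib.metadata, available from Python 3.8). The REQUIRED list mirrors the pip commands above, and check_packages is a hypothetical helper that is not part of the repository. Distribution-name matching (case, "-" vs "_") can differ slightly across Python versions, so treat a "MISSING" entry as a prompt to double-check rather than a definitive result.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distribution names as passed to pip in the install commands above.
REQUIRED = [
    "torch", "torchvision", "torchaudio", "einops", "pyyaml", "tensorboardX",
    "tensorboard", "natsort", "pyspng", "av", "pytorch_msssim", "lpips",
]


def check_packages(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if not found."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, version in check_packages(REQUIRED).items():
        print(f"{name:15s} {version or 'MISSING'}")
```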
2. Dataset
- Dataset path: /data by default; one can change it in data.dataset.py (e.g., DATA_PATH = './PATH_TO_DATA')
- Download CelebA, CelebA-HQ, AFHQ, Imagenette-320, ImageNet, Text, UCF-101, Librispeech, ERA5
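Before training, it can help to check that the downloaded datasets actually sit under the configured data root. The sketch below assumes DATA_PATH is set to the same value as in the repository's data code. The sub-directory names in the example are hypothetical placeholders: the layout each loader really expects is defined by the repository, not by this helper.

```python
import os
from typing import List, Sequence

# Set this to the same value as DATA_PATH in the repository's data code.
DATA_PATH = "/data"


def missing_datasets(root: str, subdirs: Sequence[str]) -> List[str]:
    """Return the entries of `subdirs` that are not existing directories under `root`."""
    return [d for d in subdirs if not os.path.isdir(os.path.join(root, d))]


if __name__ == "__main__":
    # Hypothetical folder names, for illustration only.
    wanted = ["celeba", "afhq", "imagenette"]
    missing = missing_datasets(DATA_PATH, wanted)
    print("missing:", missing if missing else "none")
```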
3. How to run?
Train
# Learnit
CUDA_VISIBLE_DEVICES=0 python main.py --configs ./configs/main/maml_celeba.yaml
# Ours
CUDA_VISIBLE_DEVICES=0 python main.py --configs ./configs/main/ours_celeba.yaml
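The two training commands above differ only in the config file. If you would rather launch runs from Python (for example, to queue the baseline and GradNCP one after the other on a single GPU), a small wrapper can rebuild the same command lines. This is a hypothetical convenience helper and is not part of the repository. It assumes it is run from the repository root, where main.py lives. It uses sys.executable so that the interpreter of the active conda environment is reused.

```python
import os
import subprocess
import sys
from typing import Dict, List, Sequence, Tuple


def train_command(config: str, gpu: int = 0) -> Tuple[List[str], Dict[str, str]]:
    """Build argv and environment mirroring the README's main.py commands."""
    # Restrict the run to one GPU, like CUDA_VISIBLE_DEVICES=0 in the shell.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
    return [sys.executable, "main.py", "--configs", config], env


def train_sequentially(configs: Sequence[str], gpu: int = 0) -> None:
    """Run each config to completion before starting the next; stop on the first failure."""
    for config in configs:
        argv, env = train_command(config, gpu)
        subprocess.run(argv, env=env, check=True)
```

For example, train_sequentially(["./configs/main/maml_celeba.yaml", "./configs/main/ours_celeba.yaml"]) trains the Learnit baseline and then GradNCP on CelebA.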
Evaluation
- Example of <PATH TO CHECKPOINT>: ./logs/maml_celeba/best.pth
# Learnit
CUDA_VISIBLE_DEVICES=0 python eval.py --configs ./configs/evaluation/eval_celeba.yaml --load_path ./logs/xxxx/best.model
# Ours (CelebA example)
CUDA_VISIBLE_DEVICES=0 python eval.py --configs ./configs/evaluation/eval_celeba_ours.yaml --load_path ./logs/xxxx/best.model
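The checkpoint example above ends in best.pth, while the evaluation commands use best.model. For that reason, the hypothetical helper below checks that the given checkpoint file exists before calling eval.py, so a wrong path fails immediately with a clear message. It rebuilds the same command line as the README. It assumes it is run from the repository root and is not part of the repository itself.

```python
import os
import subprocess
import sys
from typing import Dict, List, Tuple


def eval_command(config: str, load_path: str, gpu: int = 0) -> Tuple[List[str], Dict[str, str]]:
    """Build argv and environment mirroring the README's eval.py commands."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
    argv = [sys.executable, "eval.py", "--configs", config, "--load_path", load_path]
    return argv, env


def run_eval(config: str, load_path: str, gpu: int = 0) -> None:
    """Fail early if the checkpoint file is missing, then run the evaluation."""
    if not os.path.isfile(load_path):
        raise FileNotFoundError(f"checkpoint not found: {load_path}")
    argv, env = eval_command(config, load_path, gpu)
    subprocess.run(argv, env=env, check=True)
```

For example, run_eval("./configs/evaluation/eval_celeba_ours.yaml", "<PATH TO CHECKPOINT>") evaluates a GradNCP checkpoint on CelebA once the placeholder is replaced with a real path.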
Reference
This code is mainly built upon JAX Learnit, JAX Functa, PyTorch Siren, PyTorch MetaSDF, PyTorch Meta-SparseINR, and PyTorch COIN++ repositories.
Citation
@inproceedings{tack2023learning,
title={Learning Large-scale Neural Fields via Context Pruned Meta-Learning},
author={Tack, Jihoon and Kim, Subin and Yu, Sihyun and Lee, Jaeho and Shin, Jinwoo and Schwarz, Jonathan Richard},
booktitle={Advances in Neural Information Processing Systems},
year={2023}
}