Dataloading Revamp
Problems and Background
- With a sufficiently large dataset, the current `parallel_datamanager.py` tries to cache the entire dataset into RAM, which leads to an out-of-memory (OOM) error.
- `parallel_datamanager.py` uses only one worker to generate ray bundles. Since steps such as unprojecting during ray generation or pixel sampling within a custom mask can be CPU-intensive, they are better off parallelized. While `parallel_datamanager.py` does support multiple workers, each worker caches the entire dataset to RAM, so massive datasets are unsupported and duplicate copies of the dataset pile up in memory. It also implements parallelism from scratch and is not friendly to build on.
- Additionally, both `VanillaDataManager` and `ParallelDataManager` rely on `CacheDataloader`, which subclasses `torch.utils.data.DataLoader`; this is a strange coding practice and serves no particular purpose in the current nerfstudio implementation.
- Similarly for `full_images_datamanager.py`: the current implementation loads the entire dataset into the `FullImageDataloader`'s `cached_train` attribute, which fails when the dataset does not fit in RAM. To stream images efficiently instead, we need multiprocess parallelization to load and undistort images quickly enough to keep up with the GPU's forward and backward passes of the model.
Overview of Changes
- Replacing `CacheDataloader` with `RayBatchStream`, which subclasses `torch.utils.data.IterableDataset`. The goal of this class is to generate ray bundles directly without caching all images to RAM: it collates a small sampled batch of images and samples rays from that. A new `ParallelDatamanager` class is written which is available side-by-side with, and can completely replace, the original `VanillaDatamanager` (a minimal sketch of the pattern appears after this list).
- Adding an `ImageBatchStream` to create a parallel, OOM-resistant version of `FullImageDataManager`. It can be configured to load from disk by setting the `cache_images` config variable to `disk` (see the config sketch below).
- A new `pil_to_numpy()` function is added. It reads a PIL Image's data buffer and fills a preallocated empty numpy array while reading, hastening the conversion and removing an extra memory allocation. It is the fastest way to get from a PIL Image to a PyTorch tensor, averaging ~2.5 ms for a 1080x1920 image (~40% faster); the technique is sketched below.
- A new flag called `cache_compressed_imgs` caches your images to RAM in their compressed form (for example, their encoded JPEG bytes) and relies on parallelized CPU dataloading to efficiently decode them into PyTorch tensors for training (shown in miniature below).
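To make the streaming idea concrete, here is a minimal sketch of the `IterableDataset` pattern that `RayBatchStream` follows. All names and numbers below are illustrative assumptions, not the PR's actual implementation: each worker shards the image list, keeps only a small sampled sub-batch decoded in RAM, and yields pixel batches from it.

```python
import random

import numpy as np
import torch
from PIL import Image
from torch.utils.data import IterableDataset, get_worker_info


class RayBatchStreamSketch(IterableDataset):
    """Illustrative stand-in for RayBatchStream: stream pixel batches
    while keeping only a small rotating set of decoded images in RAM."""

    def __init__(self, image_paths, images_per_batch=40, pixels_per_batch=4096):
        self.image_paths = list(image_paths)
        self.images_per_batch = images_per_batch
        self.pixels_per_batch = pixels_per_batch

    def _load(self, path) -> torch.Tensor:
        # One image decoded from disk on demand; nerfstudio routes this
        # through its dataset / pil_to_numpy path instead.
        with Image.open(path) as im:
            return torch.from_numpy(np.asarray(im)).float() / 255.0

    def __iter__(self):
        info = get_worker_info()
        # Shard the image list so workers never duplicate the dataset.
        paths = self.image_paths if info is None else self.image_paths[info.id :: info.num_workers]
        while True:
            # Cache only a sampled sub-batch of images (assumed equal-sized here).
            chosen = random.sample(paths, k=min(self.images_per_batch, len(paths)))
            images = torch.stack([self._load(p) for p in chosen])  # (N, H, W, 3)
            n, h, w, _ = images.shape
            for _ in range(4):  # reuse the cached sub-batch a few times
                ns = torch.randint(n, (self.pixels_per_batch,))
                ys = torch.randint(h, (self.pixels_per_batch,))
                xs = torch.randint(w, (self.pixels_per_batch,))
                # A real implementation would also build a RayBundle from the
                # cameras here; this sketch only emits the sampled pixels.
                yield {"indices": torch.stack([ns, ys, xs], dim=-1),
                       "pixels": images[ns, ys, xs]}
```

Wrapped in a `torch.utils.data.DataLoader(stream, batch_size=None, num_workers=4)`, the decoding and sampling above run in CPU worker processes while the GPU trains.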
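Assuming the config plumbing follows the existing `FullImageDatamanagerConfig` pattern (the exact import path and field name here are an assumption; check the merged code), opting into disk streaming would look roughly like:

```python
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig

# "disk" hands image loading and undistortion to ImageBatchStream workers
# instead of caching every decoded image up front.
config = FullImageDatamanagerConfig(cache_images="disk")
```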
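The trick behind `pil_to_numpy()` is to drive Pillow's raw encoder straight into a preallocated numpy buffer instead of materializing the intermediate `bytes` object that `np.asarray(im)` goes through via `tobytes()`. The sketch below follows that technique; it leans on Pillow private APIs (`Image._getencoder`, `Image._conv_type_shape`), so treat it as an illustration rather than the PR's exact code.

```python
import numpy as np
from PIL import Image


def pil_to_numpy(im: Image.Image) -> np.ndarray:
    """Decode a PIL image directly into a preallocated numpy array,
    skipping the extra full-size copy that np.asarray(im) makes."""
    im.load()  # force Pillow to decode the image

    # Pillow's "raw" encoder re-emits the decoded pixels chunk by chunk.
    encoder = Image._getencoder(im.mode, "raw", im.mode)
    encoder.setimage(im.im)

    # Preallocate the destination array and view it as a flat byte buffer.
    shape, typestr = Image._conv_type_shape(im)
    data = np.empty(shape, dtype=np.dtype(typestr))
    mem = data.data.cast("B", (data.data.nbytes,))

    bufsize, status, offset = 65536, 0, 0
    while status == 0:
        _, status, chunk = encoder.encode(bufsize)
        mem[offset : offset + len(chunk)] = chunk
        offset += len(chunk)
    if status < 0:
        raise RuntimeError(f"encoder error {status} while converting image")
    return data
```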
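And the `cache_compressed_imgs` idea in miniature (again illustrative; `image_paths` is a stand-in): keep only the encoded bytes resident in RAM, and pay the decode cost inside parallel CPU workers rather than caching decoded tensors up front.

```python
import io

import numpy as np
import torch
from PIL import Image

# Cache the compressed files once; JPEG/PNG bytes are far smaller than
# the decoded float tensors they expand into.
image_paths = ["images/frame_0001.jpg"]  # stand-in list
compressed = [open(p, "rb").read() for p in image_paths]

def decode(i: int) -> torch.Tensor:
    # Decode to a full-size tensor only when a worker needs the image.
    with Image.open(io.BytesIO(compressed[i])) as im:
        return torch.from_numpy(np.asarray(im)).float() / 255.0
```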
Impact
- Check out these comparisons! The left was trained on 200 images of a 4K video, while the right was trained on 2000 images of the same 4K video.