nerfstudio

Dataloading Revamp

Open • AntonioMacaronio opened this issue 8 months ago • 1 comment

Problems and Background

  • With a sufficiently large dataset, the current parallel_datamanager.py tries to cache the entire dataset in RAM, which leads to an out-of-memory (OOM) error.
  • parallel_datamanager.py uses only one worker to generate ray bundles. Because steps such as unprojection during ray generation or pixel sampling within a custom mask can be CPU-intensive, this work may be better suited to parallelization. Although parallel_datamanager.py does support multiple workers, each worker caches the entire dataset in RAM, which produces duplicate copies of the dataset in memory and rules out massive datasets. It also implements parallelism from scratch, which makes it hard to build on.
  • Additionally, both VanillaDataManager and ParallelDataManager rely on CacheDataloader, which subclasses torch.utils.data.DataLoader. This is an unusual coding practice, and the subclass serves no particular purpose in the current nerfstudio implementation.
  • Similarly, full_images_datamanager.py loads the entire dataset into the FullImageDataloader's cached_train attribute, which fails once the dataset no longer fits in RAM. To avoid this without slowing training, we need multiprocess parallelization to load and undistort images quickly enough to keep up with the GPU's forward and backward passes of the model.
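For a sense of scale, the following back-of-the-envelope estimate (a sketch, not nerfstudio code) computes how much RAM a fully decoded in-memory cache needs. It assumes 3840x2160 RGB frames stored at 1 byte per channel (uint8) or 4 bytes per channel (float32), 2000 frames as in the 4K comparison under Impact, and a hypothetical helper `cache_bytes` whose `num_copies` argument models every worker holding its own full copy.

```python
def cache_bytes(num_images: int, height: int, width: int,
                channels: int = 3, bytes_per_value: int = 1,
                num_copies: int = 1) -> int:
    """Bytes needed to hold num_images decoded images in RAM (hypothetical helper).

    num_copies models the case where every worker keeps its own full cache.
    """
    per_image = height * width * channels * bytes_per_value
    return num_images * per_image * num_copies


GIB = 1024 ** 3

# 2000 decoded 4K frames (3840x2160, RGB)
uint8_total = cache_bytes(2000, 2160, 3840)                       # 1 byte per value
float32_total = cache_bytes(2000, 2160, 3840, bytes_per_value=4)  # 4 bytes per value
four_workers = cache_bytes(2000, 2160, 3840, num_copies=4)        # one uint8 copy per worker

print(f"uint8 cache:      {uint8_total / GIB:.1f} GiB")
print(f"float32 cache:    {float32_total / GIB:.1f} GiB")
print(f"4 workers, uint8: {four_workers / GIB:.1f} GiB")
```

Roughly 46 GiB as uint8 and 185 GiB as float32 already exceeds the RAM of many machines, and per-worker copies scale the footprint linearly with the number of workers.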

Overview of Changes

  • Replacing CacheDataloader with RayBatchStream, which subclasses torch.utils.data.IterableDataset. This class generates ray bundles directly without caching all images in RAM: it collates a sampled batch of images and then samples rays from that batch. A new ParallelDatamanager class is written, which is available side by side with the original VanillaDatamanager but can completely replace it.
  • Adding an ImageBatchStream to create a parallel, OOM-resistant version of FullImageDataManager. It can be configured to load images from disk by setting the cache_images config variable to disk.
  • Adding a new pil_to_numpy() function. It reads a PIL.Image's data buffer and fills an empty NumPy array while reading, which speeds up the conversion and removes an extra memory allocation. It is the fastest way to get from a PIL Image to a PyTorch tensor, averaging ~2.5 ms for a 1080x1920 image (~40% faster).
  • Adding a new flag, cache_compressed_imgs, which caches images in RAM in their compressed form (for example, caching …) and relies on parallelized CPU dataloading to efficiently decode them into PyTorch tensors for training.
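The RayBatchStream idea of producing ray batches without caching every image can be sketched as a generator that repeatedly picks a small subset of images, loads only those, and samples pixels from them. This is a minimal, stdlib-only sketch under stated assumptions, not the nerfstudio implementation: `ray_batch_stream`, `fake_load`, `images_per_batch` and `rays_per_batch` are hypothetical names, and the sampled pixel coordinates stand in for the rays that real code would compute from camera parameters. In PyTorch, the same loop would sit in the `__iter__` method of a `torch.utils.data.IterableDataset` subclass so that DataLoader workers can run it.

```python
import random
from typing import Callable, Iterator, List, Sequence, Tuple

ToyImage = List[List[float]]                   # toy grayscale image: rows of pixel values
Sample = Tuple[Tuple[int, int, int], float]    # ((image index, row, col), pixel value)


def ray_batch_stream(
    image_paths: Sequence[str],
    load_image: Callable[[str], ToyImage],
    images_per_batch: int,
    rays_per_batch: int,
    seed: int = 0,
) -> Iterator[List[Sample]]:
    """Endlessly yield pixel batches while holding only images_per_batch images.

    Hypothetical sketch: a real ray stream would turn each (image, row, col)
    sample into a ray using that camera's intrinsics and pose.
    """
    rng = random.Random(seed)
    while True:
        # 1. Pick a small subset of images and load only those.
        chosen = rng.sample(range(len(image_paths)), images_per_batch)
        loaded = {i: load_image(image_paths[i]) for i in chosen}

        # 2. Sample pixel locations from the loaded subset.
        batch: List[Sample] = []
        for _ in range(rays_per_batch):
            i = rng.choice(chosen)
            img = loaded[i]
            row = rng.randrange(len(img))
            col = rng.randrange(len(img[0]))
            batch.append(((i, row, col), img[row][col]))
        # The next loop iteration rebinds `loaded`, so this subset can then be freed.
        yield batch


def fake_load(path: str) -> ToyImage:
    # Stand-in for reading and decoding a file: a 4x6 image filled with one value.
    value = float(len(path))
    return [[value] * 6 for _ in range(4)]


paths = [f"frame_{k:05d}.png" for k in range(100)]
stream = ray_batch_stream(paths, fake_load, images_per_batch=8, rays_per_batch=32)
first = next(stream)
print(len(first), "samples drawn from at most 8 images")
```

Because only `images_per_batch` images are decoded per batch, memory stays bounded regardless of dataset size; the trade-off is that each batch draws rays from a few images rather than from the whole dataset, so the subset needs to be resampled frequently.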
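The core trick behind pil_to_numpy(), filling a buffer that is allocated once at its final size instead of building an intermediate bytes object and copying it, can be illustrated with the standard library alone. This is an analogue, not the PIL-based function: `fill_preallocated` is a hypothetical name, and `io.BytesIO` stands in for an image decoder that emits raw pixel data in chunks.

```python
import io


def fill_preallocated(stream: io.BufferedIOBase, nbytes: int,
                      chunk_size: int = 65536) -> bytearray:
    """Read exactly nbytes from stream into a buffer that is allocated once, up front."""
    buf = bytearray(nbytes)      # the only allocation for the result
    view = memoryview(buf)       # lets us address slices of buf without copying
    offset = 0
    while offset < nbytes:
        # readinto writes straight into this slice of buf; no temporary bytes object.
        n = stream.readinto(view[offset:offset + chunk_size])
        if not n:
            raise EOFError(f"stream ended after {offset} of {nbytes} bytes")
        offset += n
    return buf


# Raw RGB bytes for a 1080x1920 frame, standing in for a decoder's output.
nbytes = 1080 * 1920 * 3
raw = (bytes(range(256)) * (nbytes // 256 + 1))[:nbytes]
pixels = fill_preallocated(io.BytesIO(raw), nbytes)
print(len(pixels), bytes(pixels) == raw)
```

The naive alternative, `bytearray(stream.read())`, first builds a full-size bytes object and then copies it, briefly doubling peak memory for each image. With NumPy, the same pattern would preallocate the output with `np.empty((height, width, 3), dtype=np.uint8)` and write into its buffer.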
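The trade-off behind cache_compressed_imgs, keeping compact encoded bytes in RAM and spending CPU time to decode them in parallel when a batch is requested, can be sketched with zlib standing in for an image codec such as JPEG or PNG. The class `CompressedImageCache` and its methods are hypothetical, and a thread pool is used here for brevity, whereas the issue describes parallelized CPU dataloading workers.

```python
import zlib
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Sequence


class CompressedImageCache:
    """Keep each image's compressed bytes in RAM and decode on demand (hypothetical)."""

    def __init__(self, max_workers: int = 4) -> None:
        self._blobs: Dict[int, bytes] = {}
        self._pool = ThreadPoolExecutor(max_workers=max_workers)

    def add(self, index: int, raw_pixels: bytes) -> None:
        # Store only the compressed form; the decoded pixels are not kept.
        self._blobs[index] = zlib.compress(raw_pixels)

    def cached_bytes(self) -> int:
        # Total RAM held by the compressed cache.
        return sum(len(b) for b in self._blobs.values())

    def decode_batch(self, indices: Sequence[int]) -> List[bytes]:
        # Decode the requested images concurrently; results keep the order of indices.
        return list(self._pool.map(zlib.decompress, (self._blobs[i] for i in indices)))

    def close(self) -> None:
        self._pool.shutdown()


frame = bytes(1080 * 1920 * 3)  # a synthetic all-black 1080x1920 RGB frame
cache = CompressedImageCache()
for i in range(4):
    cache.add(i, frame)
batch = cache.decode_batch([2, 0])
print(cache.cached_bytes(), "bytes cached for 4 frames of", len(frame), "bytes each")
cache.close()
```

Real photographs compress far less than this synthetic frame, but an encoded JPEG is still typically several times smaller than its decoded pixels, which is what makes the RAM saving worth the extra decoding work.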

Impact

  • Check out these comparisons! The left was trained on 200 images from a 4K video, while the right was trained on 2000 images from the same 4K video.

AntonioMacaronio • Jun 12 '24 11:06