
How to preprocess a new dataset?

Open turian opened this issue 2 years ago • 2 comments

I have a new dataset of 128x128 images. Can you provide README instructions on how to preprocess it?

turian avatar Aug 14 '22 00:08 turian

Hello @turian,

Thank you for bringing this to our attention. We will hopefully revise the README section in the future to include this, but for now here are the steps to load your own data:

  • Prerequisite: all your training images must be in one folder and your test/validation images in a separate one, in a format readable by Pillow. We don't support reading images from zip files or other archive formats at the moment.
  1. Check the preprocessing pipeline in efficient_vdvae_torch (or efficient_vdvae_jax)/data/generic_data_loader.py and verify it's compatible with your dataset. The generic pipeline has three parts: a normalization to the number of bits you want your images to have, a normalization so that images lie in [-1, 1], and an optional horizontal flip for the training dataset (which you can control through hparams).

  2. Go to efficient_vdvae_torch (or efficient_vdvae_jax)/hparams.cfg and set your data paths and the other parameters, such as your image size (128x128), in the data section. Make sure to set a new string as your dataset_source (new_data in this example).

  3. Go to efficient_vdvae_torch/train.py and add your new dataset_source string to the list of supported datasets that use the generic dataloader pipeline:

    if hparams.data.dataset_source in ['ffhq', 'celebAHQ', 'celebA', 'new_data']:
        train_files, train_filenames = create_filenames_list(hparams.data.train_data_path)
        val_files, val_filenames = create_filenames_list(hparams.data.val_data_path)
        train_loader, val_loader = train_val_data_generic(train_files, train_filenames, val_files, val_filenames,
                                                          hparams.run.num_gpus, local_rank)
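The three-part preprocessing from step 1 can be sketched roughly as follows. This is an illustrative standalone function, not the repo's actual generic_data_loader.py code; the function name, signature, and default values are assumptions:

```python
import numpy as np

def preprocess(image, num_bits=8, training=False, rng=None):
    """Illustrative sketch: bit-depth reduction, scaling to [-1, 1],
    and an optional random horizontal flip for training images."""
    x = np.asarray(image, dtype=np.float32)  # H x W x C, 8-bit values in [0, 255]
    # 1) Bit-depth normalization: quantize 8-bit pixels down to num_bits.
    if num_bits < 8:
        x = np.floor(x / 2 ** (8 - num_bits))
        max_val = 2 ** num_bits - 1
    else:
        max_val = 255.0
    # 2) Scale pixel values so the image lies in [-1, 1].
    x = x / max_val * 2.0 - 1.0
    # 3) Optional horizontal flip, applied randomly to training images only.
    if training and rng is not None and rng.random() < 0.5:
        x = x[:, ::-1, :]
    return x
```

For example, an all-white 8-bit image maps to 1.0 everywhere and an all-black one to -1.0, regardless of the target bit depth.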

For JAX it's quite similar: go to efficient_vdvae_jax/train.py and add the new dataset_source string there as well:

    # Load datasets
    if hparams.data.dataset_source in ('ffhq', 'celebAHQ', 'celebA', 'new_data'):
        train_data, val_data = create_generic_datasets()
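For step 2, the data section of hparams.cfg might then look something like the fragment below. The dataset_source, train_data_path, and val_data_path keys come from the snippets above; the section header and the image-size key name are assumptions, so check the hparams.cfg shipped with the repo for the exact names:

```cfg
[data]
dataset_source = new_data
train_data_path = /path/to/your/train_images
val_data_path = /path/to/your/val_images
# Image size key name is hypothetical; set it to 128 for 128x128 images.
target_res = 128
```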

Hopefully that answers your question. Let me know if anything is unclear :). Otherwise, please feel free to close this issue.

Thank you! Louay Hazami

Vanlogh avatar Aug 15 '22 15:08 Vanlogh

Hello @turian and thanks for showing interest in our work.

We have added custom dataset support in our latest commit.

Instructions on how to use it are available in this section of the README. We also provide utility scripts to split your data into train/val sets or resize it if needed (as explained in the new section of the README).

Hope this helps, let us know if there are still any pending issues concerning this feature.

Best, Rayhane.

Rayhane-mamah avatar Aug 15 '22 17:08 Rayhane-mamah