Best strategy to do joint image/video training

Open universome opened this issue 1 year ago • 1 comments

Hi, I have a video dataset and have a video generation model which can take both images and videos as training inputs. It was shown in prior work that it's beneficial for a video generator to be trained on both videos and images. What is the best strategy to do so with StreamingDataset? I currently see 3 variants:

  1. Two dataloaders and two datasets (does not work)
dataset_videos = VideoStreamingDataset(**dataset_kwargs)
dataset_images = VideoStreamingDataset(return_random_frames=True, **dataset_kwargs)
iterator_videos = infinite_iterator(StreamingDataLoader(dataset_videos, batch_size=batch_size_videos))
iterator_images = infinite_iterator(StreamingDataLoader(dataset_images, batch_size=batch_size_images))
while True:
  batch_videos = next(iterator_videos)
  batch_images = next(iterator_images)
  ...

This does not work because local points to the same directory for both dataset_videos and dataset_images, which triggers the "Reused directory" error. We cannot set local=None for dataset_images either, because it iterates at a different pace (batch_size_images > batch_size_videos) and gets stuck when shards are absent (it only works when local is fully populated, i.e. remote=None).

  2. Two dataloaders and a single dataset (works, but slow)
dataset = VideoStreamingDataset(**dataset_kwargs)
iterator_videos = infinite_iterator(StreamingDataLoader(dataset, batch_size=batch_size_videos))
iterator_images = infinite_iterator(StreamingDataLoader(dataset, batch_size=batch_size_images))
while True:
  batch_videos = next(iterator_videos)
  batch_images = sample_random_frame_and_return_batch(next(iterator_images))
  ...

It is slow because we decode all the frames of each video and then use just a single frame from each video in iterator_images. Also, two dataloaders need twice as many workers as a single dataloader would.

  3. Single dataloader and single dataset (does not work, but would be an ideal solution I guess)
dataset = VideoStreamingDataset(**dataset_kwargs)
dataloader = StreamingDataLoader(dataset)
magic_joint_iterator = some_magic_joint_iterator(dataloader)
while True:
  batch_videos = next(magic_joint_iterator)
  batch_images = next(magic_joint_iterator, return_images_instead_of_videos=True)
  ...

This does not work, since it's not clear how to pass the parameter into the get_item method of the dataset.

Could you please tell me what the best strategy would be for implementing joint image/video training with StreamingDataset (note that the image and video batch sizes are different)?

universome, Oct 09 '23 15:10

Thanks for trying Streaming.

We haven't done a whole lot with the video modality and have not seen this particular use case before, so caveat emptor:

Two dataloaders and a single dataset (works, but slow)

It is slow because we decode all the frames for each video, and then use just a single frame from each video in the dataloader_images.

This slowness is the problem, right? Particularly with longer videos, more batch_images, or better networking, I would benchmark the time to iterate when you decode the videos into frames ahead of time and serialize samples to MDS not as videos but as arrays of JPEGs with image quality cranked way down, figuring that the network is faster than video decoding. People hate using disk inefficiently, but compared to time, CPU, etc., it's your most elastic resource imho.

This would require doing your own sample ser/deser, which would conceivably look like:

  • num frames: u32
  • byte offset of each frame: array of u32
  • each frame's JPEG data

This has the added benefit that you could jump directly to single frames when decoding for batch_images.
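
As a rough illustration of the layout described above, here is a minimal sketch in plain Python. The function names (pack_frames, unpack_frame) are hypothetical and not part of Streaming. It also assumes little-endian u32 values and offsets measured from the start of the blob, neither of which is specified above. The resulting bytes could then be stored as a raw-bytes column of a sample.

```python
import struct


def pack_frames(frames):
    """Serialize JPEG byte strings as: frame count, per-frame offsets, frame data.

    Assumption: little-endian u32 fields, offsets relative to the blob start.
    u32 offsets cap the blob size at 4 GiB.
    """
    count = len(frames)
    header_size = 4 + 4 * count  # one u32 for the count, one u32 per offset
    offsets = []
    pos = header_size
    for frame in frames:
        offsets.append(pos)
        pos += len(frame)
    header = struct.pack(f"<I{count}I", count, *offsets)
    return header + b"".join(frames)


def unpack_frame(blob, index):
    """Return the bytes of a single frame without touching the other frames."""
    (count,) = struct.unpack_from("<I", blob, 0)
    if not 0 <= index < count:
        raise IndexError(f"frame {index} out of range for {count} frames")
    (start,) = struct.unpack_from("<I", blob, 4 + 4 * index)
    if index + 1 < count:
        # A frame ends where the next one begins.
        (end,) = struct.unpack_from("<I", blob, 4 + 4 * (index + 1))
    else:
        # The last frame runs to the end of the blob.
        end = len(blob)
    return blob[start:end]


# Placeholder payloads standing in for real JPEG data.
frames = [b"jpeg-frame-0", b"jpeg-1", b"jpeg-frame-two"]
blob = pack_frames(frames)
middle = unpack_frame(blob, 1)  # reads only the header and one slice
```

For batch_images, a single unpack_frame call followed by one JPEG decode would replace decoding the whole video.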

This does not work, since it's not clear how to pass the parameter into the get_item method of the dataset.

StreamingDataset is a regular python object, so you can do hacks like:

dataset.ret_modality = ReturnModality.VIDEOS
batch_videos = next(it)  # Calls dataset.get_item(), whose behavior depends on self.ret_modality

dataset.ret_modality = ReturnModality.IMAGES
batch_images = next(it)  # Calls dataset.get_item(), whose behavior depends on self.ret_modality

dataset.ret_modality = None
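
To make the mechanics of this hack concrete, here is a self-contained toy sketch. ToyVideoDataset and lazy_batches are hypothetical stand-ins, not Streaming or PyTorch APIs, and the sketch assumes get_item runs lazily in the same process as the training loop. With a PyTorch DataLoader using num_workers > 0, each worker process holds its own copy of the dataset and prefetches ahead, so flipping the attribute in the main process would not reach the workers without extra plumbing.

```python
from enum import Enum


class ReturnModality(Enum):
    VIDEOS = "videos"
    IMAGES = "images"


class ToyVideoDataset:
    """Hypothetical stand-in for a StreamingDataset subclass."""

    def __init__(self, num_samples, num_frames):
        self.num_samples = num_samples
        self.num_frames = num_frames
        self.ret_modality = ReturnModality.VIDEOS

    def __len__(self):
        return self.num_samples

    def get_item(self, idx):
        # Fake "decoding": a real dataset would decode frames from a shard here.
        frames = [f"video{idx}_frame{t}" for t in range(self.num_frames)]
        if self.ret_modality is ReturnModality.IMAGES:
            return frames[0]  # a single frame instead of the whole clip
        return frames


def lazy_batches(dataset, batch_sizes):
    """Yield batches forever, reading dataset.ret_modality at request time."""
    idx = 0
    while True:
        size = batch_sizes[dataset.ret_modality]
        batch = [dataset.get_item((idx + i) % len(dataset)) for i in range(size)]
        idx += size
        yield batch


dataset = ToyVideoDataset(num_samples=8, num_frames=4)
it = lazy_batches(dataset, {ReturnModality.VIDEOS: 2, ReturnModality.IMAGES: 3})

dataset.ret_modality = ReturnModality.VIDEOS
batch_videos = next(it)  # 2 samples, each a list of 4 frames

dataset.ret_modality = ReturnModality.IMAGES
batch_images = next(it)  # 3 samples, each a single frame
```

Because the generator body only runs when next() is called, each batch picks up whatever modality was set just before the call, including the per-modality batch size.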

knighton, Oct 15 '23 02:10