streaming
Make `epoch_sample_ids` cacheable
🚀 Feature Request
It would be awesome to enable caching of `epoch_sample_ids`.
Motivation
Caching would remove a lot of redundant work that is currently re-executed on every run. Creating the sample IDs for my dataset takes 20 minutes, which wastes a lot of budget on large-scale runs.
In my case, I'll focus on the implementation in https://github.com/mosaicml/streaming/blob/2e9db78db6dd4108b697cfde92a95cd0de80539c/streaming/base/batching/random.py. The slow parts are `dataset.resample_streams` (with `sampling_method="balanced"`) and `get_shuffle` (with `shuffle_algo="py1e"`).
[Optional] Implementation
I've looked into this a bit, but `get_shuffle` depends indirectly on `sample_in_epoch` through `get_partitions` (it is called `drop_first` in the called functions), which seems to make caching very difficult. Maybe someone with more knowledge of the codebase can chime in on this, though. Personally, I would be happy with a simple, hacky solution for now. :)
For now, I've implemented a quick-and-dirty NumPy file cache keyed by a hash for `dataset.resample_streams`, which already saves around 40–50% of the time.
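A hash-keyed `.npy` cache like the one described above can be sketched roughly as follows. This is a minimal illustration, not the code from this issue and not part of the streaming API: `cache_key`, `cached_array`, and the parameter names in the usage block are hypothetical. It assumes the result depends only on a small set of JSON-serializable parameters, and the key must include every input that changes the result (seed, epoch, sampling method, the stream and shard layout, and so on); a missing input silently returns a stale array.

```python
import hashlib
import json
import os
import tempfile
from typing import Any, Callable, Dict

import numpy as np


# Hypothetical helpers for illustration; they are not part of the streaming library.
def cache_key(params: Dict[str, Any]) -> str:
    """Return a stable SHA-256 hex digest of JSON-serializable parameters."""
    # sort_keys=True makes the digest independent of dict insertion order.
    blob = json.dumps(params, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()


def cached_array(cache_dir: str, params: Dict[str, Any],
                 compute: Callable[[], np.ndarray]) -> np.ndarray:
    """Load the array for `params` from `cache_dir`; on a miss, compute and save it."""
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, cache_key(params) + ".npy")
    if os.path.exists(path):
        return np.load(path)
    arr = compute()
    np.save(path, arr)
    return arr


if __name__ == "__main__":
    # Usage: the second call is served from disk, so its compute function never runs.
    # The parameter names are assumptions, not streaming's actual arguments.
    with tempfile.TemporaryDirectory() as tmp:
        params = {"seed": 42, "epoch": 0, "sampling_method": "balanced"}
        first = cached_array(tmp, params, lambda: np.random.default_rng(42).permutation(10))
        second = cached_array(tmp, params, lambda: np.zeros(10, dtype=np.int64))
        print(np.array_equal(first, second))  # True
```

Hashing the parameters instead of the data keeps lookups cheap; the trade-off is that correctness depends entirely on the key being complete.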
Hey @janEbert, this seems sensible! We have chosen not to cache the epoch sample ID tensor mainly because persistent storage may not be available in many training setups, so reading from a cached file is not always possible.
However, this could be an optional feature for users who do have persistent storage set up. Honestly, dumping the NumPy tensor to a file is a good start. We'd be happy to help review an implementation, and we always appreciate community PRs!
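An opt-in version of that idea might look roughly like this sketch. `load_or_compute` and its `cache_dir` argument are hypothetical, not options of the streaming library. When `cache_dir` is `None`, which models a setup without persistent storage, nothing touches the disk. Writing to a temporary file and renaming it into place is an added assumption for the case where several ranks share one cache directory.

```python
import hashlib
import os
import tempfile
from typing import Callable, Optional

import numpy as np


# Hypothetical helper for illustration; not part of the streaming library.
def load_or_compute(cache_dir: Optional[str], key: str,
                    compute: Callable[[], np.ndarray]) -> np.ndarray:
    """Optionally cache compute() as <cache_dir>/<sha256(key)>.npy.

    With cache_dir=None (no persistent storage), this just calls compute(),
    so the default behaviour is unchanged.
    """
    if cache_dir is None:
        return compute()
    digest = hashlib.sha256(key.encode("utf-8")).hexdigest()
    path = os.path.join(cache_dir, digest + ".npy")
    if os.path.exists(path):
        return np.load(path)
    arr = compute()
    os.makedirs(cache_dir, exist_ok=True)
    # Write to a temporary file in the same directory, then rename it into place.
    # On a local POSIX filesystem the rename is atomic, so a concurrent reader
    # never loads a half-written file.
    fd, tmp_path = tempfile.mkstemp(dir=cache_dir, suffix=".npy")
    try:
        with os.fdopen(fd, "wb") as f:
            np.save(f, arr)
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temp file if writing or renaming failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return arr
```

Because the disabled path is the default, setups without persistent storage would behave exactly as before.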
I see, that makes sense. It also seemed that the indices are recalculated on each validation run, so caching would really only save time when starting a run or loading from a checkpoint.
Regarding the implementation, I'll be happy to put what I cooked up into a PR once I find some free time. Considering the recalculation I mentioned above (if I interpreted my logs correctly), maybe the additional complexity is not really worth adding to the code base, though. :)