fms-fsdp
Do not scan the same files multiple times in ScalableShardDataset
It seems that ScalableShardDataset can run os.walk() over the same folders and then read the length of each file potentially hundreds of times, depending on the number of logical shards per rank, because each StreamingDocDataset runs a full os.walk() in its setup: https://github.com/foundation-model-stack/fms-fsdp/blob/23a9e39fd7f83b46bf27cfbf5e09143514f8c7eb/fms_fsdp/utils/dataset_utils.py#L1226 https://github.com/foundation-model-stack/fms-fsdp/blob/23a9e39fd7f83b46bf27cfbf5e09143514f8c7eb/fms_fsdp/utils/dataset_utils.py#L925
This can be especially inefficient when the dataset contains many files. Should os.walk() run only once, with its result reused?
Thanks @rualark for bringing this up! Yes, the repeated os.walk is a pain point for setup times. Unfortunately, I don't think it's feasible to execute the walk only once, because PyTorch dataloader workers are asynchronous and not really designed to pass objects among themselves. What we might be able to do, though, is run the walk once per device by changing the way datapaths are handled during construction; that way the logical shard count won't have as much of an impact on setup time.
Per-device os.walk and get_length deduplication sounds like a great start.
What about extracting this reading logic to rank0, then NCCL tree-broadcasting the file paths and lengths to all other ranks, and then copying them into all StreamingDocDatasets:
- First, rank0 walks the folders, calls get_length in multiple threads, and pushes results to a queue.
- rank0 takes batches from the queue and tree-broadcasts them to the other ranks.
- When all ranks get the whole list of files with lengths, they replicate StreamingDocDatasets in memory.
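A minimal sketch of the rank0 scanning step above. All names here are hypothetical (not from the fms-fsdp codebase), and `get_length` just uses file size as a stand-in for the repo's actual per-file length computation; the resulting queue would then be drained in batches and sent out, e.g. with `torch.distributed.broadcast_object_list`:

```python
import os
from concurrent.futures import ThreadPoolExecutor
from queue import Queue


def get_length(path):
    # Stand-in for the real per-file length computation.
    return os.path.getsize(path)


def scan_on_rank0(root, num_threads=16):
    """Walk `root` once, compute each file's length in a thread pool,
    and stream (path, length) pairs into a queue for broadcasting."""
    out = Queue()
    paths = [
        os.path.join(dirpath, fname)
        for dirpath, _, files in os.walk(root)
        for fname in sorted(files)
    ]
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        # pool.map preserves input order, so the queue contents are
        # deterministic and identical broadcasts reach every rank.
        for path, length in zip(paths, pool.map(get_length, paths)):
            out.put((path, length))
    return out
```

The thread pool matters because get_length is I/O-bound on networked storage, so many requests can be in flight at once even under the GIL.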
Further improvements:
A. For extremely large datasets on highly distributed storage, if the rank0 NIC becomes a bottleneck during the first step, folders can be NCCL-scattered across all nodes; after each node reads its folders, the file paths and lengths can be all-gathered to all ranks.
B. If StreamingDocDataset initialization becomes a bottleneck, it can begin initializing as soon as the first batches of file paths and lengths arrive.
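Improvement B amounts to a producer/consumer overlap: a consumer thread feeds batches into dataset construction while the broadcast is still in flight. A sketch with stand-in names (the producer side below simulates the broadcast receiver; in practice it would be fed by the NCCL stream):

```python
import threading
from queue import Queue

SENTINEL = None  # marks the end of the broadcast stream


def init_incrementally(batch_queue, add_files):
    """Consume (path, length) batches as they arrive and hand them to
    the dataset immediately, instead of waiting for the full listing."""
    while True:
        batch = batch_queue.get()
        if batch is SENTINEL:
            break
        add_files(batch)  # stand-in for incremental dataset setup


# Usage sketch: the main thread stands in for the broadcast receiver.
received = []
q = Queue()
consumer = threading.Thread(target=init_incrementally, args=(q, received.extend))
consumer.start()
q.put([("f1", 10), ("f2", 20)])  # batches land as they are broadcast
q.put([("f3", 30)])
q.put(SENTINEL)
consumer.join()
```

This only pays off if dataset construction can accept files incrementally; if it needs the full sorted listing up front, the overlap window shrinks to whatever per-file preprocessing can be done early.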