Modify StreamingDataset to support passing process_group as constructor arg
Description of changes:
Modify StreamingDataset to support passing process_group as a constructor argument. Currently, StreamingDataset assumes it should use the default process group; however, for certain use cases (e.g., Pipeline Distributed Parallel training), users may want to specify a different process group. This change adds that flexibility.
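To make the proposal concrete, here is a minimal, hypothetical sketch of how a dataset could resolve its rank and world size from an optional process group instead of always using the default one. `resolve_world` and `FakeGroup` are illustrative names invented for this sketch, not part of the StreamingDataset API; a real implementation would query a `torch.distributed.ProcessGroup`.

```python
# Hypothetical sketch: resolve (rank, world_size) from an optional
# process group, falling back to default-group values otherwise.
def resolve_world(process_group=None, default=(0, 1)):
    """Return (rank, world_size), preferring the given process group."""
    if process_group is not None:
        return process_group.rank(), process_group.size()
    return default  # stand-in for the default process group's values

class FakeGroup:
    """Stand-in for torch.distributed.ProcessGroup in this sketch."""
    def __init__(self, rank, size):
        self._rank, self._size = rank, size
    def rank(self):
        return self._rank
    def size(self):
        return self._size

print(resolve_world(FakeGroup(2, 8)))  # -> (2, 8)
print(resolve_world())                 # -> (0, 1)
```

With this shape, a pipeline-parallel setup could hand each pipeline stage its own data-parallel group, so each stage shards samples only across the ranks that actually consume data.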
Issue #, if available:
N/A
Merge Checklist:
Put an `x` without space in the boxes that apply. If you are unsure about any checklist item, please don't hesitate to ask. We are here to help! This is simply a reminder of what we are going to look for before merging your pull request.
General
- [X] I have read the contributor guidelines
- [ ] This is a documentation change or typo fix. If so, skip the rest of this checklist.
- [X] I certify that the changes I am introducing will be backward compatible, and I have discussed concerns about this, if any, with the MosaicML team.
- [ ] I have updated any necessary documentation, including README and API docs (if appropriate).
Tests
- [X] I ran `pre-commit` on my change. (check out the `pre-commit` section of prerequisites)
- [ ] I have added tests that prove my fix is effective or that my feature works (if appropriate).
- [X] I ran the tests locally to make sure they pass. (check out testing)
- [ ] I have added unit and/or integration tests as appropriate to ensure backward compatibility of the changes.
Hey @jasonkrone, thanks for submitting this PR! So two things:
- we're trying to reduce our dependency on torch.distributed because it can, at times, cause some messy issues with more complex parallelism schemes at scale
- we are definitely looking to support other parallelism schemes, including pipeline parallelism
We recently added the `replication` arg, which duplicates samples across groups of consecutive GPUs. For your pipelining case, does that work for you? If it doesn't, is there a way to do this which doesn't involve sending in a ProcessGroup object? For example, you may be able to pass in a `World` object or add a function to the `World` object to enable pipelining. Mind elaborating on your training setup and what data specific groups of GPUs need to see?
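As a plain-Python illustration of the `replication` semantics described above (this is not the library's internal code, just a sketch of the rank-to-group arithmetic as I understand it): with `replication=r`, each group of `r` consecutive global ranks receives identical samples, so all GPUs in one pipeline stage group can see the same data.

```python
# Sketch (assumption): with replication=r, consecutive groups of r ranks
# see the same samples. sample_group(rank, r) gives the index of the
# sample shard a given global rank would read.
def sample_group(rank: int, replication: int) -> int:
    return rank // replication

# With replication=4 on 8 GPUs, ranks 0-3 form one group sharing samples,
# and ranks 4-7 form another.
groups = [sample_group(r, 4) for r in range(8)]
print(groups)  # -> [0, 0, 0, 0, 1, 1, 1, 1]
```

Under this scheme, a pipeline-parallel job with 4 stages per replica would want the 4 GPUs of each replica to land in the same sample group, which is what duplicating across consecutive ranks achieves.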
Ah I see @knighton already discussed much of this with you on the community slack.
@jasonkrone is this PR still needed? Or was there a resolution from that community slack thread?
We can cancel this one! I didn't get a resolution, but I think this is not likely to be a common issue for the community and therefore doesn't warrant the change.