webdataset
Pytorch Lightning integration
Hi
Currently, a webdataset dataset used with the default PyTorch DataLoader or with WebLoader doesn't work with PyTorch Lightning. You need to set a length attribute (see the Lightning example). However, that doesn't work when using an infinite dataloader, as recommended in the wds-distribute repo.
An infinite dataloader would be supported if webdataset behaved like torchtext.data.Iterator, which raises a NotImplementedError from `__len__`, or raised a TypeError (the same error Python raises when an object doesn't implement `__len__`). Instead, webdataset raises a ValueError. See this. I tested with PyTorch Lightning 1.2.8 and 1.3.1, and raising either TypeError or NotImplementedError works.
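The length check described above can be sketched in plain Python. This is a simplified stand-in written for illustration, not Lightning's actual source: the helper name `has_len` and the two toy classes are hypothetical, and they only show why a `ValueError` from `__len__` breaks the check while `TypeError` or `NotImplementedError` lets the loader be treated as unsized.

```python
class TorchtextStyle:
    """Hypothetical iterable whose __len__ raises NotImplementedError, like torchtext."""

    def __iter__(self):
        return iter(range(3))

    def __len__(self):
        raise NotImplementedError


class ValueErrorStyle:
    """Hypothetical iterable whose __len__ raises ValueError, as described for webdataset."""

    def __iter__(self):
        return iter(range(3))

    def __len__(self):
        raise ValueError("length not known")


def has_len(loader) -> bool:
    """Simplified sketch: report whether `loader` has a usable length.

    TypeError (no __len__ at all) and NotImplementedError (torchtext-style)
    mean "no length, iterate until exhausted". Any other exception propagates.
    """
    try:
        len(loader)
        return True
    except (TypeError, NotImplementedError):
        return False
```

With this logic, `has_len(TorchtextStyle())` and `has_len(iter([1, 2]))` both return `False`, while `has_len(ValueErrorStyle())` lets the `ValueError` escape, which is the failure mode reported here.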
So it would be perfect if webdataset were easier to use with PyTorch Lightning. Currently, you need to write your own wds.WebLoader with a custom length function:
```python
import webdataset as wds
from torch.utils.data import DataLoader


# This function isn't defined as a local function, so that the object returned
# by LightningWebLoader can be pickled.
def not_implemented_length(*args, **kwargs):
    # Lightning expects a TypeError or NotImplementedError when a dataloader declares
    # a __len__ method but is an iterable dataloader. Lightning assumes this to be
    # compatible with torchtext.
    raise TypeError  # NotImplementedError also works


def LightningWebLoader(*args, **kw):
    """Like wds.WebLoader, but __len__ raises TypeError (torchtext raises
    NotImplementedError) so that PyTorch Lightning detects it as an
    IterableDataset without a __len__ attribute."""
    return wds.Processor(DataLoader(*args, **kw), wds.utils.identity, length=not_implemented_length)
```
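The note about keeping the length function at module level matters because a `DataLoader` that starts worker processes may need to pickle the dataset (for example, under the `spawn` start method). A minimal stdlib-only illustration, with hypothetical function names: a module-level function pickles by reference to its qualified name, while a function defined inside another function cannot be pickled.

```python
import pickle


def module_level_length(*args, **kwargs):
    # Defined at module scope: pickle stores a reference by qualified name.
    raise TypeError


def make_local_length():
    def local_length(*args, **kwargs):
        # Defined inside another function: its qualified name contains "<locals>",
        # so pickle cannot look it up again by name.
        raise TypeError

    return local_length


def is_picklable(obj) -> bool:
    """Return True if `obj` survives a pickle round-trip."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except (AttributeError, pickle.PicklingError):
        # The exception type for local objects differs between Python versions.
        return False
```

Here `is_picklable(module_level_length)` is `True` and `is_picklable(make_local_length())` is `False`.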
It would be easier if one of these were done:

1. Expose in `wds.WebLoader` a way to pass arguments to `wds.Processor`. For example, `def WebLoader(*args, kw_proc={}, **kw): return wds.Processor(DataLoader(*args, **kw), wds.utils.identity, **kw_proc)`. Then passing `not_implemented_length` would be enough.
2. Like 1, but `wds.Processor` raises `TypeError` when `length=False`. The call would then be `wds.WebLoader(..., kw_proc={'length': False})`.
3. `wds.Processor.__len__` raises `TypeError` instead of `ValueError`.
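Options 2 and 3 come down to how a wrapper's `__len__` interprets its `length` argument. The class below is a hypothetical, self-contained stand-in, not webdataset's actual `Processor`: `length=True` delegates to the source, a callable is invoked with the source, and `length=False` raises `TypeError` so that a Lightning-style length check would treat the loader as unsized.

```python
from typing import Any, Callable, Iterable, Iterator, Union

# Hypothetical type for the `length` argument: True, False, or a callable.
LengthSpec = Union[bool, Callable[[Any], int]]


class LengthAwareWrapper:
    """Hypothetical iterable wrapper illustrating options 2 and 3."""

    def __init__(self, source: Iterable, f: Callable[[Iterator], Iterator],
                 length: LengthSpec = True):
        self.source = source
        self.f = f
        self.length = length

    def __iter__(self) -> Iterator:
        # Apply the processing function to an iterator over the source.
        return self.f(iter(self.source))

    def __len__(self) -> int:
        if self.length is True:
            return len(self.source)
        if self.length is False:
            # Option 2: an explicit "no length" raises TypeError, the same
            # exception Python raises for objects without __len__.
            raise TypeError("length=False: this iterable has no defined length")
        if callable(self.length):
            # A callable may compute the length or raise its own exception.
            return self.length(self.source)
        raise TypeError(f"unsupported length specification: {self.length!r}")


def identity(x):
    return x
```

With this behaviour, `len(LengthAwareWrapper(data, identity, length=False))` raises `TypeError` while iteration still works, so no custom length function is needed.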
The second and third options are backward-incompatible. However, option 2 makes sense to me: when passing `length=False`, I would expect a `TypeError`.
What do you think?
I've taken another stab at integration with tmbdev/webdataset-lightning; it works, even multinode.
Note that in WebDataset and WebLoader, you can set the length attribute to a function, so you can already raise whatever exception you want from it.
I will have a look at your suggestion. I think it may be good to make "DataStream" an explicit class in the webdataset library, that can take care of all these details.
Thank you!
Another option is to document it. Once you know how to do it, it isn't a big deal. However, it took me some hours to figure out. First, I discovered webdataset-lightning through several GitHub issues. Then I made a couple of attempts with dataloader and dataset lengths (I was using torch's DataLoader, and assigning the length there doesn't work; it must be set on the dataset). Then I tried `length=False`, and finally I realized that I could pass a custom length function that raises NotImplementedError like torchtext, or TypeError.
Yes, I agree, the methods need a lot more documentation. I'll try to add a lot more over the next couple of weeks.
I’m willing to help out if needed. I have been working on getting a central webdataset / AIStore setup running for my university's lab and have also been tinkering with PyTorch Lightning.