
Pytorch Lightning integration

Status: Open · hal-314 opened this issue 4 years ago · 4 comments

Hi

Currently, a webdataset dataset used with the default PyTorch DataLoader or with WebLoader doesn't work with PyTorch Lightning out of the box: you need to set a length attribute (see the Lightning example). However, that doesn't work when using an infinite dataloader, as recommended in the wds-distribute repo.

An infinite dataloader would be supported if webdataset behaved like torchtext.data.Iterator, which raises NotImplementedError from __len__, or raised a TypeError (the same error len() raises when an object doesn't implement __len__). Currently it raises a ValueError instead. See this. I tested with PyTorch Lightning 1.2.8 and 1.3.1, and raising either TypeError or NotImplementedError works.
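The distinction matters because a length probe of this kind only treats certain exceptions as "this loader has no length". Below is a minimal, torch-free sketch; `has_len` is a simplified, hypothetical stand-in for that check (not Lightning's actual code), and the three classes are illustrative only.

```python
# Hypothetical, simplified length probe: TypeError and NotImplementedError
# mean "iterable-style loader without a length"; any other exception
# (e.g. ValueError) propagates to the caller.
def has_len(loader) -> bool:
    try:
        len(loader)
    except (TypeError, NotImplementedError):
        return False
    return True


class NoLen:
    """No __len__ at all, so len() raises TypeError."""
    def __iter__(self):
        return iter(range(3))


class TorchtextStyle:
    """Declares __len__ but raises NotImplementedError, like torchtext."""
    def __iter__(self):
        return iter(range(3))

    def __len__(self):
        raise NotImplementedError


class ValueErrorStyle:
    """Declares __len__ but raises ValueError, like the current wds.Processor."""
    def __iter__(self):
        return iter(range(3))

    def __len__(self):
        raise ValueError("length not set")


print(has_len([1, 2, 3]))         # True
print(has_len(NoLen()))           # False
print(has_len(TorchtextStyle()))  # False
try:
    has_len(ValueErrorStyle())
except ValueError as e:
    print("ValueError escapes:", e)  # ValueError escapes: length not set
```

Only the ValueError case escapes the probe, which is why the options discussed in this issue revolve around which exception __len__ raises.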

It would be great if webdataset were easier to use with PyTorch Lightning. Currently, you need to write your own wds.WebLoader with a custom length function:

import webdataset as wds
from torch.utils.data import DataLoader


# Defined at module level (not as a local function) so that the object returned by LightningWebLoader can be pickled.
def not_implemented_length(*args, **kwargs):
    # Lightning expects a TypeError or NotImplementedError when a dataloader declares a __len__ method but is an
    # iterable-style dataloader. Lightning accepts NotImplementedError for compatibility with torchtext.
    raise TypeError("length is not defined")  # NotImplementedError also works


def LightningWebLoader(*args, **kw):
    """Like wds.WebLoader, but __len__ raises TypeError (torchtext raises NotImplementedError),
    so PyTorch Lightning detects it as an IterableDataset without a __len__."""
    return wds.Processor(DataLoader(*args, **kw), wds.utils.identity, length=not_implemented_length)

Some options that would make this easier:

  1. Expose a way in wds.WebLoader to pass arguments through to wds.Processor. For example: def WebLoader(*args, kw_proc={}, **kw): return wds.Processor(DataLoader(*args, **kw), wds.utils.identity, **kw_proc). Then passing kw_proc={'length': not_implemented_length} would be enough.
  2. Like option 1, but wds.Processor raises TypeError when length=False. The call would then be wds.WebLoader(..., kw_proc={'length': False}).
  3. Make wds.Processor.__len__ raise TypeError instead of ValueError.
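Options 1 and 2 can be sketched without torch or webdataset. `LengthAwareIterable` and `web_loader` are hypothetical stand-ins for wds.Processor and wds.WebLoader, not the library's API; `kw_proc` is the argument name proposed in option 1, here defaulting to None rather than a mutable `{}`.

```python
from typing import Callable, Iterable, Iterator, Optional, Union

# Hypothetical Processor-like wrapper (NOT webdataset's class).
# `length` is False (no length), an int, or a zero-argument callable.
LengthSpec = Union[int, Callable[[], int]]


class LengthAwareIterable:
    def __init__(self, source: Iterable, length: LengthSpec = False):
        self.source = source
        self.length = length

    def __iter__(self) -> Iterator:
        return iter(self.source)

    def __len__(self) -> int:
        if self.length is False:
            # Option 2: report "no length" with TypeError, the same
            # exception len() raises for objects without __len__.
            raise TypeError(f"{type(self).__name__} has no length")
        if callable(self.length):
            return self.length()
        return int(self.length)


def web_loader(source: Iterable, kw_proc: Optional[dict] = None) -> LengthAwareIterable:
    # Option 1: forward a separate dict of wrapper arguments instead of
    # mixing them into the loader's own keyword arguments.
    return LengthAwareIterable(source, **(kw_proc or {}))


loader = web_loader(range(5), kw_proc={"length": False})
print(list(loader))  # [0, 1, 2, 3, 4]
try:
    len(loader)
except TypeError as e:
    print(e)  # LengthAwareIterable has no length
print(len(web_loader(range(5), kw_proc={"length": 5})))  # 5
```

A side benefit of TypeError: in CPython, consumers such as list() that call len() only as a size hint swallow a TypeError and keep iterating, whereas a ValueError or NotImplementedError raised from __len__ would propagate.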

The second and third options are backward-incompatible. Still, option 2 makes sense to me: when passing length=False, I would expect a TypeError.

What do you think?

hal-314, May 12 '21 12:05

I've taken another stab at the integration in tmbdev/webdataset-lightning; it works, even in multinode setups.

Note that in WebDataset and WebLoader, you can set the length attribute to a function, so you can already raise whatever exception you want from it.

I will have a look at your suggestion. I think it may be good to make "DataStream" an explicit class in the webdataset library that can take care of all these details.

tmbdev, May 16 '21 16:05

Thank you!

Another option is to document it. Once you know how to do it, it isn't a big deal, but it took me a few hours to figure out: first discovering webdataset-lightning through several GitHub issues, then a couple of trials with dataloader and dataset lengths (I was using torch.utils.data.DataLoader, where assigning a length doesn't work; it must be set on the dataset), then trying length=False, and finally realizing that I could pass a custom length function that raises NotImplementedError like torchtext, or TypeError.

hal-314, May 17 '21 07:05

Yes, I agree; the methods need a lot more documentation. I'll try to add more over the next couple of weeks.

tmbdev, May 18 '21 17:05

I'm willing to help out if needed. I have been working on getting a central webdataset / AIStore setup running for my university's lab and have also been tinkering with PyTorch Lightning.

codestar12, Dec 23 '21 15:12