
torch IterableDataset is not fully supported


With PyTorch 1.2.0 came IterableDataset, which only implements __iter__ but no __len__ and certainly no __getitem__. This is a problem since we use Subset to split the input dataset: Subset wraps the original dataset, introduces __getitem__, and delegates the call to the wrapped dataset, which doesn't implement that method since it is only iterable.
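To make the failure concrete, here is a minimal sketch (Stream is just a made-up IterableDataset for illustration):

from torch.utils.data import IterableDataset, Subset

class Stream(IterableDataset):
    def __iter__(self):
        yield from range(5)

subset = Subset(Stream(), indices=[0, 2])
# Subset.__getitem__ delegates to Stream[...], which IterableDataset does not
# implement, so indexing subset[0] fails (NotImplementedError or TypeError,
# depending on the PyTorch version).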

The simplest solution to this is to not split IterableDataset in any way. What do you think?

ottonemo · Feb 20 '20

Didn't know about this one. I always wondered whether Datasets really needed __getitem__; this answers the question :)

Splitting the way skorch (or rather, sklearn) does it can't easily be supported with IterableDataset. The __len__ part would be okay, since our Dataset supports passing the length explicitly. For train/valid splits, a user would need to predefine two datasets at the moment.
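For example, something along these lines should already work (a sketch, assuming the user predefines two IterableDatasets train_ds and valid_ds, that ClassifierModule is whatever module is being trained, and using skorch's predefined_split helper):

from skorch import NeuralNetClassifier
from skorch.helper import predefined_split

# train_ds and valid_ds are two separately predefined IterableDatasets
net = NeuralNetClassifier(
    ClassifierModule,
    train_split=predefined_split(valid_ds),
)
net.fit(train_ds, y=None)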

We could think about a wrapper class that allows splitting an IterableDataset by using every n-th element for validation, but e.g. stratified splits, group-based splits, or predefined splits wouldn't work.
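A rough sketch of what such a wrapper could look like (hypothetical, not part of skorch; the class name, arguments, and raw_ds are made up):

import torch

class EveryNth(torch.utils.data.IterableDataset):
    """Keep (or drop) every n-th sample of a wrapped IterableDataset."""
    def __init__(self, dataset, n, keep=True):
        super().__init__()
        self.dataset = dataset
        self.n = n
        self.keep = keep

    def __iter__(self):
        for i, sample in enumerate(self.dataset):
            if (i % self.n == 0) == self.keep:
                yield sample

# hypothetical usage: every 5th sample goes to validation, the rest to training;
# note this iterates the underlying dataset twice, which only works if it can be
# re-iterated (i.e. it is not a one-shot stream)
valid_ds = EveryNth(raw_ds, n=5, keep=True)
train_ds = EveryNth(raw_ds, n=5, keep=False)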

BenjaminBossan · Feb 20 '20

> Didn't know about this one. I always wondered whether Datasets really needed __getitem__; this answers the question :)

Did you mean __len__? :)

> We could think about a wrapper class that allows splitting an IterableDataset by using every n-th element for validation, but e.g. stratified splits, group-based splits, or predefined splits wouldn't work.

If I were a user of IterableDataset, I do not think there is a sensible default for splitting the data.

> The simplest solution to this is to not split IterableDataset in any way. What do you think?

I agree.

Currently, is there an issue with passing an IterableDataset directly into fit? Something like this works (a little hacky?):

import numpy as np
import torch
from sklearn.datasets import make_classification
from skorch import NeuralNetClassifier


class MyDataset(torch.utils.data.IterableDataset):
    def __init__(self, X, y):
        super().__init__()
        self.X = X
        self.y = y

    def _generator(self):
        # yield one (x, y) pair at a time until the data is exhausted
        for i in range(len(self.X)):
            yield self.X[i], self.y[i]

    def __iter__(self):
        return self._generator()

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
dataset = MyDataset(X, y)

net = NeuralNetClassifier(ClassifierModule, train_split=None)
net.fit(dataset, y=None)

Moving forward, we can raise an error when train_split is not None and an IterableDataset is passed in?
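A minimal sketch of such a check (hypothetical, not skorch's actual code; the function name is made up):

from torch.utils.data import IterableDataset

def check_train_split(dataset, train_split):
    # splitting relies on Subset/__getitem__, which IterableDataset lacks
    if train_split is not None and isinstance(dataset, IterableDataset):
        raise ValueError(
            "train_split is not supported for IterableDataset; "
            "pass train_split=None or predefine separate train/valid datasets."
        )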

thomasjpfan · Feb 21 '20

> Did you mean __len__? :)

I meant __getitem__, but my sentence was not very clear. What I wanted to express is that I wondered why torch Dataset was not implemented as an iterable instead of relying on __getitem__ to access its members. I concluded that there is probably a technical reason for it, but the existence of IterableDataset shows that __getitem__ is actually not strictly necessary (though still helpful in some situations).

> If I were a user of IterableDataset, I do not think there is a sensible default for splitting the data.

> The simplest solution to this is to not split IterableDataset in any way. What do you think?

> I agree.

I agree to both.

> Something like this works (a little hacky?)

I don't think it's too hacky. Maybe this could be added to helper.py?

> Moving forward, we can raise an error when train_split is not None and an IterableDataset is passed in?

I agree to this too.

BenjaminBossan · Feb 21 '20