torch IterableDataset is not fully supported
With PyTorch 1.2.0 came IterableDataset, which only implements __iter__ but no __len__ and certainly no __getitem__. This is definitely a problem, since we are using Subset to split the input dataset: it wraps the original dataset, introduces __getitem__, and delegates the call to the wrapped dataset - which doesn't implement that method since it is iterable.

The simplest solution to this is to not split IterableDataset in any way. What do you think?
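For illustration (not from the original report), here is a minimal sketch of how that failure looks; the class name is made up and the exact exception depends on the PyTorch version:

from torch.utils.data import IterableDataset, Subset

class StreamDataset(IterableDataset):
    # only __iter__ is implemented - no __len__, no __getitem__
    def __init__(self, data):
        super().__init__()
        self.data = data

    def __iter__(self):
        return iter(self.data)

ds = StreamDataset(list(range(10)))
subset = Subset(ds, indices=[0, 2, 4])

try:
    # Subset.__getitem__ delegates to ds[0], which an iterable-only
    # dataset does not support
    subset[0]
except (NotImplementedError, TypeError) as exc:
    print(type(exc).__name__)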
Didn't know about this one. I always wondered whether Datasets really needed __getitem__, this answers the question :)

Splitting the way skorch (or rather, sklearn) does it can't be easily supported with IterableDataset. The __len__ part would be okay, since our Dataset supports passing the length explicitly. For train/valid, a user would need to predefine two datasets at the moment.

We could think about a wrapper class that allows splitting IterableDataset by using every n-th element for validation, but e.g. stratified splits, group-based splits, or predefined splits wouldn't work.
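As an illustration only, a hypothetical sketch of such a wrapper (the class name and API are made up, not part of skorch): it routes every n-th sample to the validation stream and everything else to the training stream.

from torch.utils.data import IterableDataset

class EveryNthSplit(IterableDataset):
    """Hypothetical wrapper: route every n-th sample to validation."""
    def __init__(self, dataset, n, validation=False):
        super().__init__()
        self.dataset = dataset
        self.n = n
        self.validation = validation

    def __iter__(self):
        # note: each split iterates the underlying dataset independently
        for i, sample in enumerate(self.dataset):
            if (i % self.n == 0) == self.validation:
                yield sample

# train_ds = EveryNthSplit(raw_ds, n=5, validation=False)
# valid_ds = EveryNthSplit(raw_ds, n=5, validation=True)

With two such streams, the validation set could then be passed explicitly, e.g. via skorch.helper.predefined_split, assuming the rest of the pipeline can consume an IterableDataset for validation.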
Didn't know about this one. I always wondered whether Datasets really needed __getitem__, this answers the question :)

Did you mean __len__? :)
We could think about a wrapper class that allows splitting IterableDataset by using every n-th element for validation, but e.g. stratified splits, group-based splits, or predefined splits wouldn't work.

If I was a user of IterableDataset, I do not think there is a sensible default for splitting the data.
The simplest solution to this is to not split IterableDataset in any way. What do you think?
I agree.
Currently, is there an issue with passing an IterableDataset directly into fit? Something like this works: (a little hacky?)
import numpy as np
import torch
from sklearn.datasets import make_classification
from skorch import NeuralNetClassifier

# ClassifierModule (used below) is assumed to be defined elsewhere.

class MyDataset(torch.utils.data.IterableDataset):
    def __init__(self, X, y):
        super().__init__()
        self.X = X
        self.y = y

    def _generator(self):
        # yield one (X, y) pair at a time; a fresh generator is returned
        # on every __iter__ call, so the dataset can be iterated repeatedly
        for i in range(len(self.X)):
            yield self.X[i], self.y[i]

    def __iter__(self):
        return self._generator()
X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
dataset = MyDataset(X, y)
net = NeuralNetClassifier(ClassifierModule, train_split=None)
net.fit(dataset, y=None)
Moving forward, we can raise an error when train_split is not None and IterableDataset is passed in?
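For illustration, a minimal sketch of what such a guard could look like (the function name and where it would hook in are assumptions, not skorch API):

from torch.utils.data import IterableDataset

def check_train_split(net, X):
    # hypothetical check: splitting an IterableDataset is not supported
    if isinstance(X, IterableDataset) and net.train_split is not None:
        raise ValueError(
            "IterableDataset cannot be split. Pass train_split=None or "
            "provide a predefined validation dataset instead."
        )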
Did you mean __len__? :)

I meant __getitem__, but my sentence was not very clear. What I wanted to express is that I wondered why torch Dataset was not implemented as an iterable instead of relying on __getitem__ to access its members. I just concluded that there probably is a technical reason for it, but the existence of IterableDataset shows that __getitem__ is actually not strictly necessary (though still helpful in some situations).
If I was a user of IterableDataset, I do not think there is a sensible default for splitting the data.
The simplest solution to this is to not split IterableDataset in any way. What do you think?
I agree.
I agree to both.
Something like this works: (a little hacky?)

I don't think it's too hacky. Maybe this could be added to helper.py?
Moving forward, we can raise an error when train_split is not None and IterableDataset is passed in?
I agree to this too.