
Problem with StopIteration on dataset when creating vocabulary

Open 00krishna opened this issue 4 years ago • 2 comments

❓ Questions and Help

Description

Hey folks, I was hoping someone could tell me a better way to deal with this issue. I am getting a StopIteration error on the dataset, and I am not clear on how to get around it. Here is a minimal example below which reproduces the error. I am using torchtext 0.10.0.

In the real code, I am pulling the AG_NEWS dataset into the train_iter variable, building a vocabulary from that train_iter dataset, and then trying to process batches for that same dataset using a DataLoader with a collate function.

The problem seems to be that I iterate through train_iter once in order to build the vocabulary with the yield_tokens function, so by the time I call next(iter(train_iter)), the iterator has already been exhausted. Is there a way to copy train_iter so that I can build the vocabulary from the copy? I can probably write some hacky code to work around this, but I wanted to see if there is a better or more appropriate way.

from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator


def yield_tokens(data_iter):
    # Tokenize the text of each (label, text) pair in the dataset.
    for _, text in data_iter:
        yield tokenizer(text)


tokenizer = get_tokenizer('basic_english')
train_iter, test_iter = AG_NEWS()

# Building the vocabulary consumes train_iter in full.
vocab = build_vocab_from_iterator(yield_tokens(train_iter),
                                  specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

# train_iter has already been exhausted, so this raises StopIteration.
print(next(iter(train_iter)))

The error message generated is:

Exception has occurred: StopIteration
exception: no description

    print(next(iter(train_iter)))

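For context, the kind of hacky workaround I had in mind was just to materialize the iterator into a list first, so the data can be traversed more than once. A rough sketch (assuming the train split fits in memory):

train_data = list(AG_NEWS(split='train'))  # cache the raw (label, text) pairs

vocab = build_vocab_from_iterator(yield_tokens(train_data),
                                  specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

print(train_data[0])  # the list can be read again, no StopIteration

But holding the whole split in memory feels wasteful, so I would prefer a more idiomatic approach if one exists.
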
00krishna · Nov 19 '21 01:11

Yes, this is a known issue/limitation of the current datasets. We are working on improving this user experience in the next release.

For now, the only workaround is to get a new iterator every time you need one (for instance, at the start of each epoch). Getting the iterator is expensive only the first time around, due to the download and extraction (if any) of the data. After that, it is returned almost immediately by the function.

So, for example, you may do something like this:


from torch.utils.data import DataLoader
from torchtext.datasets import AG_NEWS

num_epochs = 5     # replace with the actual number of epochs
batch_size = 64    # replace with the actual batch size
collate_fn = None  # replace with the actual collate function

for _ in range(num_epochs):
    # Get a fresh iterator at the start of every epoch.
    train_iter = AG_NEWS(split='train')
    data_loader = DataLoader(train_iter, batch_size=batch_size, collate_fn=collate_fn)
    for labels, texts in data_loader:
        pass  # do something with the batch

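To make that example concrete, the collate function could look roughly like the following. This is only a minimal sketch that reuses the vocab and tokenizer from your snippet (built once from a fresh AG_NEWS(split='train') iterator, before the epoch loop); how you batch the variable-length texts (padding, offsets, etc.) depends on your model.

import torch

def collate_batch(batch):
    # batch is a list of raw (label, text) pairs from the dataset.
    label_list, text_list = [], []
    for label, text in batch:
        label_list.append(label - 1)  # AG_NEWS labels are 1..4; shift to 0-based
        token_ids = [vocab[token] for token in tokenizer(text)]
        text_list.append(torch.tensor(token_ids, dtype=torch.int64))
    return torch.tensor(label_list, dtype=torch.int64), text_list

collate_fn = collate_batch
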
parmeet · Nov 19 '21 02:11

@parmeet Oh I see. That is interesting. Yeah, thanks for explaining the issue. This helps me understand how to run training now. I will have to copy this code and keep it in mind.

00krishna · Nov 20 '21 02:11