
TypeError: 'EpochBatchIterator' object is not iterable

Open jordane95 opened this issue 3 years ago • 1 comment

🐛 Bug

While working through your docs, I found that the following code raises `TypeError: 'EpochBatchIterator' object is not iterable`:

# setup the task (e.g., load dictionaries)
task = fairseq.tasks.setup_task(args)

# build model and criterion
model = task.build_model(args)
criterion = task.build_criterion(args)

# load datasets
task.load_dataset('train')
task.load_dataset('valid')

# iterate over mini-batches of data
batch_itr = task.get_batch_iterator(
    task.dataset('train'), max_tokens=4096,
)
for batch in batch_itr:
    # compute the loss
    loss, sample_size, logging_output = task.get_loss(
        model, criterion, batch,
    )
    loss.backward()

To Reproduce

Steps to reproduce the behavior (always include the command you ran):

  1. Run cmd '....'
  2. See error

Code sample

Expected behavior

Environment

  • fairseq Version (e.g., 1.0 or main):
  • PyTorch Version (e.g., 1.0)
  • OS (e.g., Linux):
  • How you installed fairseq (pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

jordane95 avatar Feb 07 '22 07:02 jordane95

I ran into the same problem. After checking the source code, I finally found that there is one more step to do:

# iterate over mini-batches of data
batch_itr = task.get_batch_iterator(
    task.dataset('train'), max_tokens=4096,
)
batch_itr = batch_itr.next_epoch_itr()

You can check the function next_epoch_itr() for more details.
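The underlying pattern can be illustrated with a minimal stand-in class (a hypothetical sketch, not fairseq's actual implementation, which lives in fairseq/data/iterators.py and supports shuffling, sharding, and epoch state):

```python
class EpochBatchIterator:
    """Holds batches but is NOT itself iterable; you must call
    next_epoch_itr() to start an epoch and get a real iterator."""

    def __init__(self, batches):
        self._batches = list(batches)
        self.epoch = 0

    def next_epoch_itr(self):
        # Advance the epoch counter and return a plain iterator
        # over the batches for this epoch.
        self.epoch += 1
        return iter(self._batches)


batch_itr = EpochBatchIterator([[1, 2], [3, 4]])

# `for batch in batch_itr:` would raise TypeError here, because
# the class defines no __iter__. Calling next_epoch_itr() first
# yields an ordinary iterator you can loop over:
for batch in batch_itr.next_epoch_itr():
    print(batch)
```

This mirrors why the snippet in the issue fails: fairseq's get_batch_iterator() returns the epoch-level wrapper object, and iteration only works on the per-epoch iterator it hands out.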

slatter666 avatar Oct 09 '22 12:10 slatter666