TypeError: 'EpochBatchIterator' object is not iterable
🐛 Bug
When I walk through your docs, I find that the following code raises TypeError: 'EpochBatchIterator' object is not iterable:
import fairseq

# setup the task (e.g., load dictionaries)
task = fairseq.tasks.setup_task(args)

# build model and criterion
model = task.build_model(args)
criterion = task.build_criterion(args)

# load datasets
task.load_dataset('train')
task.load_dataset('valid')

# iterate over mini-batches of data
batch_itr = task.get_batch_iterator(
    task.dataset('train'), max_tokens=4096,
)
for batch in batch_itr:
    # compute the loss
    loss, sample_size, logging_output = task.get_loss(
        model, criterion, batch,
    )
    loss.backward()
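For what it's worth, a quick check (a hypothetical interactive session, assuming the setup above succeeded) shows why the loop fails: the returned object does not define __iter__.

batch_itr = task.get_batch_iterator(
    task.dataset('train'), max_tokens=4096,
)
print(type(batch_itr).__name__)        # EpochBatchIterator
print(hasattr(batch_itr, '__iter__'))  # False on my install, hence the TypeError in the for loop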
To Reproduce
Steps to reproduce the behavior (always include the command you ran):
- Run the code snippet above
- See error: TypeError: 'EpochBatchIterator' object is not iterable
Code sample
See the snippet under 🐛 Bug above.
Expected behavior
The loop should yield mini-batches and compute the loss without raising an exception.
Environment
- fairseq Version (e.g., 1.0 or main):
- PyTorch Version (e.g., 1.0)
- OS (e.g., Linux):
- How you installed fairseq (pip, source):
- Build command you used (if compiling from source):
- Python version:
- CUDA/cuDNN version:
- GPU models and configuration:
- Any other relevant information:
Additional context
I met the same problem, and after checking the source code I finally found that there is one more step to do:
# iterate over mini-batches of data
batch_itr = task.get_batch_iterator(
    task.dataset('train'), max_tokens=4096,
)
# EpochBatchIterator is not directly iterable; this extra call
# returns an iterator over the next epoch of batches
batch_itr = batch_itr.next_epoch_itr()
You can check the function next_epoch_itr() for more details.
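Putting it together, here is a minimal sketch of the corrected loop, assuming the same args/task/model/criterion setup as in the original snippet:

# iterate over mini-batches of data
batch_itr = task.get_batch_iterator(
    task.dataset('train'), max_tokens=4096,
)
# ask for the iterator over the next epoch; that object is iterable
epoch_itr = batch_itr.next_epoch_itr()
for batch in epoch_itr:
    # compute the loss
    loss, sample_size, logging_output = task.get_loss(
        model, criterion, batch,
    )
    loss.backward()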