PipelineChunkIterator does not provide the correct length
System Info
- `transformers` version: 4.28.1
- PyTorch version (GPU?): 2.0.0+cu117 (True)
Who can help?
@Narsil
Information
- [ ] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
This is relatively minor in the scheme of things, but I looked into it a bit to make sure it wasn't an issue with batching.
```python
from transformers import pipeline

pipe = pipeline("token-classification")
pipe(["New York " * 600] * 2, stride=0)
```
This leads to noisy warnings from torch:
```
/tmp/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py:646: UserWarning: Length of IterableDataset <transformers.pipelines.pt_utils.PipelineChunkIterator object at 0x7f084a9bce50> was reported to be 2 (when accessing len(dataloader)), but 3 samples have been fetched.
  warnings.warn(warn_msg)
/tmp/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py:646: UserWarning: Length of IterableDataset <transformers.pipelines.pt_utils.PipelineChunkIterator object at 0x7f084a9bce50> was reported to be 2 (when accessing len(dataloader)), but 4 samples have been fetched.
  warnings.warn(warn_msg)
/tmp/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py:646: UserWarning: Length of IterableDataset <transformers.pipelines.pt_utils.PipelineChunkIterator object at 0x7f084a9bce50> was reported to be 2 (when accessing len(dataloader)), but 5 samples have been fetched.
  warnings.warn(warn_msg)
/tmp/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py:646: UserWarning: Length of IterableDataset <transformers.pipelines.pt_utils.PipelineChunkIterator object at 0x7f084a9bce50> was reported to be 2 (when accessing len(dataloader)), but 6 samples have been fetched.
  warnings.warn(warn_msg)
```
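For illustration, here is a minimal sketch of the mismatch, using a hypothetical `ChunkIterator` class (not the transformers internals): the iterator reports one length unit per input item, but chunking a long input yields several items, so `len()` undercounts what is actually fetched — exactly the discrepancy torch's DataLoader warns about.

```python
class ChunkIterator:
    """Hypothetical stand-in for PipelineChunkIterator, for illustration only."""

    def __init__(self, inputs, chunks_per_input):
        self.inputs = inputs
        self.chunks_per_input = chunks_per_input

    def __len__(self):
        # Reported length: one per input item.
        return len(self.inputs)

    def __iter__(self):
        # Actual iteration: each input is split into several chunks.
        for item in self.inputs:
            for _ in range(self.chunks_per_input):
                yield item


it = ChunkIterator(["New York " * 600] * 2, chunks_per_input=3)
print(len(it))             # reported length: 2
print(sum(1 for _ in it))  # samples actually fetched: 6
```

When a DataLoader wraps such an iterable, it compares the reported `len()` against the running fetch count, which is why one warning fires per extra sample fetched.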
Expected behavior
PipelineChunkIterator reports the correct length, so no noisy warnings are emitted.