Fix chunk casting and schema unification in dataset
Updated chunk handling to cast chunks to the expected schema when features are provided, or to unify their schemas when they are not. This keeps the schemas of the yielded batches aligned.
Fixes #7872
This PR fixes a bug where an IterableDataset created from a generator with an explicit features parameter would fail during Arrow operations (such as .to_pandas()) when the data contains missing or null values.
Problem
When an IterableDataset is created with explicit features but the generator yields data with missing values (e.g., empty lists), PyArrow infers a schema for each batch from the actual data instead of using the provided schema. An empty list is inferred as list<item: null>, while a populated one is inferred as list<item: struct<...>>, so concatenating the batches fails with an ArrowInvalid error because their schemas do not match.
Example error:
```
pyarrow.lib.ArrowInvalid: Schema at index 1 was different:
a: int64
b: list<item: null>
vs
a: int64
b: list<item: struct<c: int64>>
```
Solution
Modified RebatchedArrowExamplesIterable._iter_arrow() to:
- Cast chunks to the expected schema when explicit features are provided
- Unify schemas across chunks when no explicit features are set
- Gracefully handle cast failures by falling back to the original chunk
This ensures that the user-provided schema is respected throughout the iteration process.
Testing
Verified the fix with the following test case:
```python
import datasets
from datasets import features


def test_to_pandas_works_with_explicit_schema():
    common_features = features.Features(
        {
            "a": features.Value("int64"),
            "b": features.List({"c": features.Value("int64")}),
        }
    )

    def row_generator():
        data = [{"a": 1, "b": []}, {"a": 1, "b": [{"c": 1}]}]
        for row in data:
            yield row

    d = datasets.IterableDataset.from_generator(row_generator, features=common_features)
    print("Iterating…")
    for _ in d.to_pandas():
        pass


test_to_pandas_works_with_explicit_schema()
```
Before the patch:
```
@ArjunJagdale ➜ /workspaces/datasets (main) $ python test_arjun.py
Iterating…
Traceback (most recent call last):
  File "/workspaces/datasets/test_arjun.py", line 24, in <module>
    test_to_pandas_works_with_explicit_schema()
  File "/workspaces/datasets/test_arjun.py", line 21, in test_to_pandas_works_with_explicit_schema
    for _ in d.to_pandas():
  File "/workspaces/datasets/src/datasets/iterable_dataset.py", line 3736, in to_pandas
    table = pa.concat_tables(list(self.with_format("arrow").iter(batch_size=1000)))
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspaces/datasets/src/datasets/iterable_dataset.py", line 2596, in iter
    for key, pa_table in iterator:
  File "/workspaces/datasets/src/datasets/iterable_dataset.py", line 2111, in _iter_arrow
    for key, pa_table in self.ex_iterable._iter_arrow():
  File "/workspaces/datasets/src/datasets/iterable_dataset.py", line 632, in _iter_arrow
    yield new_key, pa.Table.from_batches(chunks_buffer)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "pyarrow/table.pxi", line 5039, in pyarrow.lib.Table.from_batches
  File "pyarrow/error.pxi", line 155, in pyarrow.lib.pyarrow_internal_check_status
  File "pyarrow/error.pxi", line 92, in pyarrow.lib.check_status
pyarrow.lib.ArrowInvalid: Schema at index 1 was different:
a: int64
b: list<item: null>
vs
a: int64
b: list<item: struct<c: int64>>
```
After the patch, the script runs to completion without errors:
```
@ArjunJagdale ➜ /workspaces/datasets (main) $ python test_arjun.py
Iterating…
@ArjunJagdale ➜ /workspaces/datasets (main) $
```
@lhoestq, I would like to hear your thoughts on this!