Support union arrays in `concat_tables`
Describe the enhancement requested
It would be nice to have another step up from `promote_options="permissive"`, e.g. `promote_options="union"`, which uses dense unions when column types are heterogeneous across schemas. For example:
from pyarrow import table, concat_tables
t1 = table({"a": [1, 2, 3]})
t2 = table({"a": ["a", "b", "c"]})
concat_tables(tables=[t1, t2], promote_options="permissive") # Currently raises `ArrowTypeError`.
concat_tables(tables=[t1, t2], promote_options="union") # Does not exist at the moment.
The latter should use a dense union for column "a".
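For reference, here is roughly what the expected result for column "a" could look like if built by hand with `UnionArray.from_dense` (a sketch of the desired output, not of the proposed implementation):

from pyarrow import UnionArray, array, int8, int32

# Child 0 holds the int64 values, child 1 the string values.
types = array([0, 0, 0, 1, 1, 1], type=int8())     # Which child each row comes from.
offsets = array([0, 1, 2, 0, 1, 2], type=int32())  # The row's index within that child.
expected = UnionArray.from_dense(types, offsets, [array([1, 2, 3]), array(["a", "b", "c"])])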
I've implemented this myself, but it's hard to do, because there is no `is_mergeable` function that exposes the logic used by `concat_tables(tables=…, promote_options="permissive")` for me to use, so I had to re-implement that logic myself, either through guesswork or with lots of try/excepts (see the sketch after the code below for one possible approximation). Here is a rough attempt, which works for some cases, but not all. It also does not preserve metadata, nor does it support missing columns:
from itertools import chain
from logging import getLogger
from typing import Sequence

from pyarrow import (
    Table,
    concat_tables,
    ArrowTypeError,
    table,
    chunked_array,
    ArrowInvalid,
    UnionArray,
    array,
    int8,
    int32,
)

logger = getLogger(__name__)


def concat_tables_heterogeneous(tables: Sequence[Table]) -> Table:
    """
    Concatenate multiple tables vertically.

    This is similar to `pyarrow.concat_tables`, but it allows for heterogeneous schemas by using dense unions.
    """
    try:
        return concat_tables(tables=tables, promote_options="permissive")
    except ArrowTypeError:
        logger.warning(
            "Heterogeneous table schemas detected. "
            "Some columns will be represented as dense unions, which are slower."
        )
    # TODO: Ask the `pyarrow` maintainers to give us an `is_mergeable` function that we can use to check which
    # columns are mergeable without using dense unions, instead of maintaining our own heuristics here.
    it = iter(tables)
    column_names = next(it).column_names
    for t in it:
        if t.column_names != column_names:
            raise NotImplementedError(
                "The tables don't all have the same column names."
            )
    result = {}
    for column_name in column_names:
        try:
            result[column_name] = chunked_array([t[column_name] for t in tables])
        except ArrowInvalid:
            # These can't be concatenated into a normal `ChunkedArray`. Use a dense union with one child per
            # table: `types` says which child each value lives in, `offsets` gives its index within that child.
            result[column_name] = UnionArray.from_dense(
                array(
                    list(chain(*([i] * t.num_rows for i, t in enumerate(tables)))),
                    type=int8(),
                ),
                array(
                    list(chain(*(range(t.num_rows) for t in tables))),
                    type=int32(),
                ),
                # Each child must be a single `Array`, so combine each column's chunks.
                [t[column_name].combine_chunks() for t in tables],
            )
    return table(data=result)
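For what it's worth, a rough approximation of the missing `is_mergeable` check is possible today by probing `pyarrow.unify_schemas`, which accepts the same `promote_options` in recent pyarrow versions; the helper name and the exact set of exceptions caught here are my own guesses:

from typing import Sequence

from pyarrow import ArrowInvalid, ArrowTypeError, Table, unify_schemas


def is_mergeable(tables: Sequence[Table]) -> bool:
    """Guess whether `concat_tables(..., promote_options="permissive")` would succeed,
    by checking whether the schemas unify under the same promotion rules."""
    try:
        unify_schemas([t.schema for t in tables], promote_options="permissive")
        return True
    except (ArrowTypeError, ArrowInvalid):
        return False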
Component(s)
Python
One problem with my implementation above (besides the fact that it's done in Python rather than at a lower level in the Arrow engine) is that `chunked_array` will fail on columns where `concat_tables` would have succeeded (e.g., when merging `int64`, `float64`, and `null`). So instead of `chunked_array` we could use:
try:
    result[column_name] = concat_tables(
        tables=[t.select([column_name]) for t in tables],
        promote_options="permissive",
    )[column_name]
except ArrowTypeError:
    ...  # Fall back to the dense-union code above.
But that still feels hacky...
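To illustrate the difference (a minimal check; the variable names are just for the example):

from pyarrow import array, chunked_array, concat_tables, float64, int64, null, table

t1 = table({"a": array([1, 2], type=int64())})
t2 = table({"a": array([0.5], type=float64())})
t3 = table({"a": array([None], type=null())})

# `chunked_array` refuses to mix chunk types, so the first version above would
# needlessly fall back to a dense union here:
# chunked_array([t["a"].chunks[0] for t in (t1, t2, t3)])  # Raises ArrowInvalid.
# ...whereas `concat_tables` promotes everything to float64:
merged = concat_tables([t1, t2, t3], promote_options="permissive")
assert merged["a"].type == float64()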
The code above only works for a small number of tables (n < 128), since it creates one union child per table and dense-union type codes are `int8`. Here is an improved version that merges chunks of the same type, and therefore also works on a large number of tables:
from itertools import chain
from logging import getLogger
from typing import Sequence, List, Dict

from pyarrow import (
    Table,
    concat_tables,
    ArrowTypeError,
    table,
    UnionArray,
    array,
    int8,
    int32,
    Array,
    concat_arrays,
    DataType,
)

logger = getLogger(__name__)


def concat_tables_heterogeneous(
    *,
    tables: Sequence[Table],
    warn_about_heterogeneous_schemas: bool = True,
) -> Table:
    """
    Concatenate multiple tables vertically.

    This is similar to `pyarrow.concat_tables`, but it allows for heterogeneous schemas by using dense unions.

    See the feature request at https://github.com/apache/arrow/issues/44397

    Args:
        tables: The tables to concatenate.
        warn_about_heterogeneous_schemas:
            Whether to log a warning if the tables don't all have the same schema. This is True by default, because
            creating dense unions is a lot slower than creating arrow arrays with homogeneous types.
    """
    try:
        return concat_tables(tables=tables, promote_options="permissive")
    except ArrowTypeError:
        if warn_about_heterogeneous_schemas:
            logger.warning(
                "Heterogeneous table schemas detected. "
                "Some columns will be represented as dense unions, which are slower."
            )
    it = iter(tables)
    column_names = next(it).column_names
    for t in it:
        if t.column_names != column_names:
            raise NotImplementedError(
                "The tables don't all have the same column names."
            )
    result = {}
    for column_name in column_names:
        # TODO: Ask the `pyarrow` maintainers to give us an `is_mergeable` function that we can use to check which
        # columns are mergeable without using dense unions, instead of using try-except here. This would also allow
        # us to include the names of the offending columns in the warning message above.
        try:
            result[column_name] = concat_tables(
                tables=[t.select([column_name]) for t in tables],
                promote_options="permissive",
            )[column_name]
        except ArrowTypeError:
            # These can't be concatenated into a normal `ChunkedArray`. Use a dense union.
            result[column_name] = create_dense_union_from_chunks(
                chunks=list(chain(*[t[column_name].chunks for t in tables]))
            )
    return table(data=result)

def create_dense_union_from_chunks(*, chunks: Sequence[Array]) -> Array:
    """
    Given a sequence of `Array`s, create a dense union, where all chunks of the same type are merged.

    The types and order of values are preserved.
    Values are NOT converted back and forth between Python and Arrow.
    """
    if len(chunks) == 0:
        return array([])
    if len(chunks) == 1:
        return chunks[0]
    chunks_by_type: Dict[DataType, List[Array]] = {}
    for chunk in chunks:
        chunks_by_type.setdefault(chunk.type, []).append(chunk)
    chunk_type_ordinals = {t: i for i, t in enumerate(chunks_by_type)}
    # A dense union should have one child array per type.
    children = [
        concat_arrays(chunks_of_same_type)
        for chunks_of_same_type in chunks_by_type.values()
    ]
    if len(children) == 1:
        return children[0]
    # The `types` array specifies in which child array each value is.
    types: List[int] = []
    # The `offsets` array specifies the index of each value in the child array.
    offsets: List[int] = []
    next_offset_per_type = {t: 0 for t in chunks_by_type}
    for chunk in chunks:
        type_index = chunk_type_ordinals[chunk.type]
        for _ in chunk:
            types.append(type_index)
            offsets.append(next_offset_per_type[chunk.type])
            next_offset_per_type[chunk.type] += 1
    return UnionArray.from_dense(
        array(types, type=int8()),
        array(offsets, type=int32()),
        children,
    )
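Hypothetical usage, tying this back to the example at the top of the issue:

t1 = table({"a": [1, 2, 3]})
t2 = table({"a": ["a", "b", "c"]})
merged = concat_tables_heterogeneous(tables=[t1, t2])
print(merged["a"].type)  # A dense union over int64 and string.

A built-in promote_options="union" could produce the same result without round-tripping the `types` and `offsets` buffers through Python lists.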