arrow icon indicating copy to clipboard operation
arrow copied to clipboard

Support union arrays in `concat_tables`

Open rudolfbyker opened this issue 1 year ago • 1 comments

Describe the enhancement requested

It would be nice to have another step up from promote_options="permissive", e.g., promote_options="union" which uses dense unions when columns are heterogeneous across schemas. For example:

from pyarrow import table, concat_tables

t1 = table({"a": [1, 2, 3]})
t2 = table({"a": ["a", "b", "c"]})

concat_tables(tables=[t1, t2], promote_options="permissive")  # Currently raises `ArrowTypeError`.
concat_tables(tables=[t1, t2], promote_options="union")  # Does not exist at the moment.

The latter should use a dense union for column "a".

I've implemented this myself, but it's hard to do, because there is no is_mergeable function which exposes the logic used by concat_tables(tables=…, promote_options="permissive") for me to use, causing me to have to re-implement that, either using guesswork, or using lots of try-excepts. Here is a rough attempt, which works for some cases, but not all. It also does not preserve metadata, nor support missing columns:

from itertools import chain
from logging import getLogger
from typing import Sequence

from pyarrow import (
    Table,
    ArrowInvalid,
    ArrowTypeError,
    UnionArray,
    array,
    chunked_array,
    concat_arrays,
    concat_tables,
    int8,
    int32,
    table,
)

logger = getLogger(__name__)


def concat_tables_heterogeneous(tables: Sequence[Table]) -> Table:
    """
    Concatenate multiple tables vertically.

    This is similar to `pyarrow.concat_tables`, but it allows for heterogeneous
    schemas by falling back to a dense union for any column whose types cannot
    be unified.

    Args:
        tables: The tables to concatenate. They must all have the same column
            names (in the same order).

    Returns:
        A single table containing the rows of all input tables, in order.

    Raises:
        NotImplementedError: If the tables don't all have the same column names.
    """
    try:
        # Fast path: let pyarrow unify the schemas whenever it can.
        return concat_tables(tables=tables, promote_options="permissive")
    except ArrowTypeError:
        logger.warning(
            "Heterogeneous table schemas detected. "
            "Some columns will be represented as dense unions, which are slower."
        )

    it = iter(tables)
    column_names = next(it).column_names
    for t in it:
        if t.column_names != column_names:
            raise NotImplementedError(
                "The tables don't all have the same column names."
            )

    result = {}
    for column_name in column_names:
        try:
            # Probe with `concat_tables` on the single column rather than
            # `chunked_array`: `concat_tables` can still unify combinations
            # (e.g. int64 + float64 + null) that `chunked_array` rejects.
            result[column_name] = concat_tables(
                tables=[t.select([column_name]) for t in tables],
                promote_options="permissive",
            )[column_name]
        except ArrowTypeError:
            # These can't be unified into one type. Build a dense union with
            # one child per *distinct type* (not per table), so the int8 type
            # codes don't overflow when there are 128 or more input tables.
            chunks = list(chain(*(t[column_name].chunks for t in tables)))
            chunks_by_type = {}
            for chunk in chunks:
                chunks_by_type.setdefault(chunk.type, []).append(chunk)
            type_codes = {t: i for i, t in enumerate(chunks_by_type)}
            children = [
                concat_arrays(same_type) for same_type in chunks_by_type.values()
            ]

            # Build the per-value type codes and child offsets with bulk
            # `extend` calls; iterating the chunks element-by-element would
            # box every Arrow value into a Python scalar.
            types = []
            offsets = []
            next_offset = {t: 0 for t in chunks_by_type}
            for chunk in chunks:
                n = len(chunk)
                start = next_offset[chunk.type]
                types.extend([type_codes[chunk.type]] * n)
                offsets.extend(range(start, start + n))
                next_offset[chunk.type] = start + n

            result[column_name] = UnionArray.from_dense(
                array(types, type=int8()),
                array(offsets, type=int32()),
                children,
            )

    return table(data=result)

Component(s)

Python

rudolfbyker avatar Oct 14 '24 14:10 rudolfbyker

One problem with my implementation above (besides the fact that it's done in Python rather than on a lower level in the Arrow engine) is that chunked_array will fail on columns where concat_tables would have succeeded (e.g., when merging int64, float64, and null). So instead of chunked_array we could use:

        try:
            result[column_name] = concat_tables(
                tables=[t.select([column_name]) for t in tables],
                promote_options="permissive",
            )[column_name]
        except ArrowTypeError:

But that still feels hacky...

rudolfbyker avatar Oct 14 '24 14:10 rudolfbyker

The code above only works for a small number of tables (n < 128). Here is an improved version that merges chunks, and therefore also works on a large number of tables:

from itertools import chain
from logging import getLogger
from typing import Sequence, List, Dict

from pyarrow import (
    Table,
    concat_tables,
    ArrowTypeError,
    table,
    UnionArray,
    array,
    int8,
    int32,
    Array,
    concat_arrays,
    DataType,
)

logger = getLogger(__name__)


def concat_tables_heterogeneous(
    *,
    tables: Sequence[Table],
    warn_about_heterogeneous_schemas: bool = True,
) -> Table:
    """
    Vertically concatenate tables whose schemas may disagree.

    Behaves like `pyarrow.concat_tables`, except that columns whose types
    cannot be unified are represented as dense unions instead of raising.
    See the feature request at https://github.com/apache/arrow/issues/44397

    Args:
        tables: The tables to concatenate. All must share the same column names.
        warn_about_heterogeneous_schemas:
            Whether to log a warning when the schemas don't all match. Defaults
            to True, because building dense unions is much slower than building
            arrays with homogeneous types.

    Raises:
        NotImplementedError: If the tables don't all have the same column names.
    """
    try:
        # Happy path: pyarrow can unify the schemas on its own.
        return concat_tables(tables=tables, promote_options="permissive")
    except ArrowTypeError:
        if warn_about_heterogeneous_schemas:
            logger.warning(
                "Heterogeneous table schemas detected. "
                "Some columns will be represented as dense unions, which are slower."
            )

    remaining = iter(tables)
    expected_names = next(remaining).column_names
    if any(t.column_names != expected_names for t in remaining):
        raise NotImplementedError(
            "The tables don't all have the same column names."
        )

    columns = {}
    for name in expected_names:
        # TODO: Ask the `pyarrow` maintainers for an `is_mergeable` function
        #  that we could use to check which columns are mergeable without dense
        #  unions, instead of relying on try-except here. That would also let
        #  us name the offending columns in the warning message above.
        try:
            merged = concat_tables(
                tables=[t.select([name]) for t in tables],
                promote_options="permissive",
            )
            columns[name] = merged[name]
        except ArrowTypeError:
            # The column's types can't be unified; fall back to a dense union.
            every_chunk = [c for t in tables for c in t[name].chunks]
            columns[name] = create_dense_union_from_chunks(chunks=every_chunk)

    return table(data=columns)


def create_dense_union_from_chunks(*, chunks: Sequence[Array]) -> Array:
    """
    Given a sequence of `Array`s, create a dense union, where all chunks of the same type are merged.

    The types and order of values are preserved.
    Values are NOT converted to and fro between Python and Arrow.

    Args:
        chunks: The arrays to combine. May be empty.

    Returns:
        A dense `UnionArray` when the chunks span more than one type; otherwise
        a plain (possibly concatenated) array of the single shared type.
    """
    if len(chunks) == 0:
        # No data at all: an empty (null-typed) array.
        return array([])

    if len(chunks) == 1:
        return chunks[0]

    # Group the chunks by their Arrow type, preserving first-seen order.
    chunks_by_type: Dict[DataType, List[Array]] = {}
    for chunk in chunks:
        chunks_by_type.setdefault(chunk.type, []).append(chunk)

    chunk_type_ordinals = {t: i for i, t in enumerate(chunks_by_type)}

    # A dense union should have one child array per type.
    children = [
        concat_arrays(chunks_of_same_type)
        for chunks_of_same_type in chunks_by_type.values()
    ]

    if len(children) == 1:
        # All chunks share one type, so no union is needed.
        return children[0]

    # The `types` array says in which child array each value lives, and the
    # `offsets` array gives each value's index within that child. Build both
    # with bulk `extend` calls using `len(chunk)`; iterating each chunk
    # element-by-element (as `for _ in chunk` would) boxes every Arrow value
    # into a Python scalar just to count it.
    types: List[int] = []
    offsets: List[int] = []

    next_offset_per_type = {t: 0 for t in chunks_by_type}
    for chunk in chunks:
        n = len(chunk)
        start = next_offset_per_type[chunk.type]
        types.extend([chunk_type_ordinals[chunk.type]] * n)
        offsets.extend(range(start, start + n))
        next_offset_per_type[chunk.type] = start + n

    return UnionArray.from_dense(
        array(types, type=int8()),
        array(offsets, type=int32()),
        children,
    )

rudolfbyker avatar Mar 18 '25 10:03 rudolfbyker