
`join`ing tables with ExtensionArrays

Open · NellyWhads opened this issue 1 year ago · 0 comments

Describe the enhancement requested

I'm looking for documentation on how to implement an ExtensionArray that supports joins.

In particular, I'd like to join a table that includes a FixedShapeTensorArray column with another table.

Here's a minimal example that fails:

```python
import numpy as np
import pyarrow as pa

# First dim is the batch dim
tensors = np.arange(3 * 10 * 10).reshape((3, 10, 10)).astype(np.uint8)
tensor_array = pa.FixedShapeTensorArray.from_numpy_ndarray(tensors)
ids = pa.array([1, 2, 3], type=pa.uint8())
table = pa.Table.from_arrays([ids, tensor_array], names=["id", "tensor"])
print(table.schema)

classes = pa.array(["one", "two", "three"], type=pa.string())
table_2 = pa.Table.from_arrays([ids, classes], names=["id", "name"])
print(table_2.schema)

# Fails: "tensor" is a non-key column with an extension type
table.join(table_2, keys=["id"], join_type="full outer")
```

This raises the following error:

```
---------------------------------------------------------------------------
ArrowInvalid                              Traceback (most recent call last)
Cell In[42], line 1
----> 1 table.join(table_2, keys=["id"], join_type="full outer")

File ~/.pyenv/versions/next_gen_data_38/lib/python3.8/site-packages/pyarrow/table.pxi:5570, in pyarrow.lib.Table.join()

File ~/.pyenv/versions/next_gen_data_38/lib/python3.8/site-packages/pyarrow/acero.py:247, in _perform_join(join_type, left_operand, left_keys, right_operand, right_keys, left_suffix, right_suffix, use_threads, coalesce_keys, output_type)
    242     projection = Declaration(
    243         "project", ProjectNodeOptions(projections, projected_col_names)
    244     )
    245     decl = Declaration.from_sequence([decl, projection])
--> 247 result_table = decl.to_table(use_threads=use_threads)
    249 if output_type == Table:
    250     return result_table

File ~/.pyenv/versions/next_gen_data_38/lib/python3.8/site-packages/pyarrow/_acero.pyx:590, in pyarrow._acero.Declaration.to_table()

File ~/.pyenv/versions/next_gen_data_38/lib/python3.8/site-packages/pyarrow/error.pxi:155, in pyarrow.lib.pyarrow_internal_check_status()

File ~/.pyenv/versions/next_gen_data_38/lib/python3.8/site-packages/pyarrow/error.pxi:92, in pyarrow.lib.check_status()

ArrowInvalid: Data type extension<arrow.fixed_shape_tensor[value_type=uint8, shape=[10,10], permutation=[0,1]]> is not supported in join non-key field tensor
```

How can I make this work? The individual tensors I want to store are small (single-digit dimensions), but the join may lead to list aggregations over a few hundred rows.

I've tagged this as a Python question because I don't know at which API level this functionality would need to be added.

Component(s)

Python

NellyWhads, Oct 18 '24 16:10