Incorrectly Sized Series of Struct returned from Plugin

Open J-Meyers opened this issue 4 months ago • 3 comments

Checks

  • [X] I have checked that this issue has not already been reported.
  • [X] I have confirmed this bug exists on the latest version of Polars.

Reproducible example

fn sample_struct(_input_fields: &[Field]) -> PolarsResult<Field> {
    let maxout = Field::new("out1".into(), DataType::UInt64);
    let data = Field::new("out2".into(), DataType::UInt64);

    let struct_type = DataType::Struct(vec![maxout, data]);
    Ok(Field::new("out_combined".into(), struct_type))
}

// split_offsets (used below) is copied directly from the example linked in the pyo3-polars repo

pub fn sample_struct_ca(data: &ChunkedArray<UInt64Type>) -> (ChunkedArray<UInt64Type>, ChunkedArray<UInt64Type>) {
    let mut output_builder =
        PrimitiveChunkedBuilder::<UInt64Type>::new("out1".into(), data.len());
    let mut other_builder = PrimitiveChunkedBuilder::<UInt64Type>::new("out2".into(), data.len());

    for s in data.iter() {
        match s {
            Some(_) => {
                output_builder.append_value(0);
                other_builder.append_value(0);
            }
            None => {
                output_builder.append_null();
                other_builder.append_null();
            }
        }
    }

    (output_builder.finish(), other_builder.finish())
}

#[polars_expr(output_type_func=sample_struct)]
fn sample_func(inputs: &[Series]) -> PolarsResult<Series> {
    let input = inputs[0].u64()?;

    let n_threads = 20;
    let splits = split_offsets(input.len(), n_threads);

    let chunks: Vec<(Vec<_>, Vec<_>)> = splits
        .into_iter()
        .map(|(offset, len)| -> (Vec<_>, Vec<_>) {
            let sliced = input.slice(offset as i64, len);
            let (out1, out2) = sample_struct_ca(&sliced);
            (out1.downcast_iter().cloned().collect::<Vec<_>>(), out2.downcast_iter().cloned().collect::<Vec<_>>())
        })
        .collect();

    let (unzipped1, unzipped2): (Vec<Vec<_>>, Vec<Vec<_>>) = chunks.into_iter().unzip();
    let all1 = UInt64Chunked::from_chunk_iter("out1".into(), unzipped1.into_iter().flatten());
    let all2 = UInt64Chunked::from_chunk_iter("out2".into(), unzipped2.into_iter().flatten());
    let s = StructChunked::from_series("out_combined".into(), &[all1.into_series(), all2.into_series()])?;
    Ok(s.into_series())
}

def get_sample(input_column: IntoExprColumn) -> pl.Expr:
    return register_plugin(
        args=[input_column],
        symbol="sample_func",
        is_elementwise=True,
        lib=lib,
    )

sample_df = pl.DataFrame({
    "A": list(range(2000)),
})

print(sample_df.select(get_sample(pl.col("A").cast(pl.UInt64))).select(pl.len()))

Log output

No response

Issue description

This was done with polars-u64-idx.

When trying to follow the pattern demonstrated here: https://github.com/pola-rs/pyo3-polars/blob/d426148ae27410aa4fb10a4a9dc67647a058244f/example/derive_expression/expression_lib/src/expressions.rs

I found that directly returning a series of a primitive type, as in the linked example, works correctly. When I return a series of structs as in the code above, however, the result is incorrectly sized: instead of the full length, it has the length of the single last chunk that was processed (100 in the example above, i.e. 2000 / 20). If I use range(1999) instead of range(2000), the result has length 99.
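To make the sizes concrete, here is a minimal plain-Python sketch of the chunk arithmetic for the range(2000) case. It assumes that split_offsets divides the input into n (offset, length) pairs of equal size, with the last pair absorbing any remainder; the function below is a hypothetical model written for illustration, not the helper from the pyo3-polars example, which may behave differently.

```python
def split_offsets(length: int, n: int) -> list[tuple[int, int]]:
    """Hypothetical model of the Rust helper: n (offset, len) pairs.

    Assumption: every pair has length `length // n`, except the last,
    which takes whatever is left over.
    """
    if n <= 1:
        return [(0, length)]
    chunk_size = length // n
    return [
        (i * chunk_size, length - i * chunk_size if i == n - 1 else chunk_size)
        for i in range(n)
    ]


splits = split_offsets(2000, 20)

# Length the plugin output should have: the sum of all chunk lengths.
expected_len = sum(size for _, size in splits)

# Length reported above: only the last chunk's length survives.
last_chunk_len = splits[-1][1]

print(expected_len, last_chunk_len)  # prints "2000 100"
```

Under this assumption the 20 chunks together cover all 2000 rows, so a result of length 100 corresponds to exactly one chunk rather than the concatenation of all of them.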

Printing s.len() inside the plugin returns the correct length, but once the series comes back into DataFrame form the length is no longer correct. If I skip the chunking process (the downcasting and the rest) and directly return the constructed struct, it works properly.

Expected behavior

It should return the full-length series regardless of whether the columns are wrapped in structs.

Installed versions

["dtype-struct", "dtype-i8", "dtype-array", "dtype-i16", "bigidx"]

--------Version info---------
Polars:      1.9.0
Index type:  UInt64
Platform:    Linux-6.8.0-45-generic-x86_64-with-glibc2.31
Python:      3.11.9 (main, Oct 8 2024, 17:30:19) [GCC 13.1.0]

----Optional dependencies----
adbc_driver_manager
altair
cloudpickle          3.0.0
connectorx
deltalake
fastexcel
fsspec               2023.9.2
gevent
great_tables
matplotlib           3.9.2
nest_asyncio         1.6.0
numpy                1.26.4
openpyxl
pandas               2.2.3
pyarrow              17.0.0
pydantic             2.9.2
pyiceberg
sqlalchemy
torch                2.4.1+cu124
xlsx2csv
xlsxwriter

J-Meyers, Oct 17 '24 13:10