
Allow `pl.Expr.get` to handle null index values in aggregation context so that `pl.Expr.arg_min`/`pl.Expr.arg_max` is always valid input

dalejung opened this issue 1 year ago · 8 comments

Currently, `pl.Expr.get` errors with `OutOfBoundsError: index out of bounds` if it receives a null index. This is an issue in aggregations, where filtering the group can leave it empty, so `arg_min` returns null.

Take a simple example of a dataframe with date, int range, and the day of month as columns.

    import polars as pl

    dates = pl.date_range(
        pl.datetime(2023, 1, 1),
        pl.datetime(2023, 12, 31),
        eager=True,
    )
    df = pl.DataFrame({
        'date': dates,
    }).with_columns(
        num=pl.int_range(pl.len()),
        day=pl.col('date').dt.day(),
    )

If I wanted to get the first day after the 28th of each month, we run into an issue for February, since it only has 28 days.

    after_28 = pl.col('day') > 28

    report = df.group_by_dynamic('date', every='1mo').agg(
        pl.col('num').filter(after_28).min(),

        # THIS ERRORS WITH OutOfBoundsError 
        # pl.col('date').filter(after_28).get(
        #     pl.col('num').filter(after_28).arg_min()
        # ).alias('after_28_date'),
    )

Incidentally, I can get the expected output by adding a row index and grabbing the values from the original df after the group_by.

    report = df.with_row_index().group_by_dynamic('date', every='1mo').agg(
        pl.col('num').filter(after_28).min(),
        # store the filtered arg_min
        pl.col('num').filter(after_28).arg_min().alias('arg_min'),
        # store the filtered index values
        pl.col('index').filter(after_28).alias('filtered_index'),
    )

    # We grab from filtered_index, since arg_min was also from filtered
    b_after_28_index = report.select(pl.col('filtered_index').list.get(pl.col('arg_min')))['filtered_index']

    # Grab date values from original df
    report = report.with_columns(
        after_28_date=df.select(
            pl.col('date').gather(b_after_28_index)
        )['date']
    ).drop('filtered_index')

                   date   num  arg_min        after_28_date
0   2023-01-01 00:00:00    28        0  2023-01-29 00:00:00
1   2023-02-01 00:00:00  <NA>     <NA>                 <NA>
2   2023-03-01 00:00:00    87        0  2023-03-29 00:00:00
3   2023-04-01 00:00:00   118        0  2023-04-29 00:00:00
4   2023-05-01 00:00:00   148        0  2023-05-29 00:00:00
5   2023-06-01 00:00:00   179        0  2023-06-29 00:00:00
6   2023-07-01 00:00:00   209        0  2023-07-29 00:00:00
7   2023-08-01 00:00:00   240        0  2023-08-29 00:00:00
8   2023-09-01 00:00:00   271        0  2023-09-29 00:00:00
9   2023-10-01 00:00:00   301        0  2023-10-29 00:00:00
10  2023-11-01 00:00:00   332        0  2023-11-29 00:00:00
11  2023-12-01 00:00:00   362        0  2023-12-29 00:00:00

This might just be a usage issue? My toy example doesn't simply filter the original df for days after the 28th, because the aggregation might involve multiple filters, e.g. odd dates, dates ending in 3, etc.

dalejung avatar Mar 19 '24 13:03 dalejung

The issue in the title ("Allow pl.Expr.get handle null index values") is fixed by https://github.com/pola-rs/polars/pull/15239, I've moved the more general issue of allowing out-of-bounds indices in get/gather to this dedicated issue: https://github.com/pola-rs/polars/issues/15240.

orlp avatar Mar 22 '24 15:03 orlp

@orlp sounds good. Yeah, I originally opened this ticket thinking my particular issue was with OOB, but realized it was the nulls.

dalejung avatar Mar 22 '24 15:03 dalejung

@orlp

I still get an out-of-bounds error for:

    report = df.group_by_dynamic('date', every='1mo').agg(
        pl.col('num').filter(after_28).arg_min(),
        # THIS ERRORS WITH OutOfBoundsError 
        pl.col('date').filter(after_28).get(
            pl.col('num').filter(after_28).arg_min()
        ).alias('after_28_date'),
    )

I just realized that this happens in group_by but not select:

    empty_filter = pl.col('day') > 100

    # runs fine
    report = df.select(
        pl.col('num').filter(empty_filter).arg_min(),
        pl.col('date').filter(empty_filter).get(
            pl.col('num').filter(empty_filter).arg_min()
        ).alias('after_28_date'),
    )
shape: (1, 2)
┌──────┬───────────────┐
│ num  ┆ after_28_date │
│ ---  ┆ ---           │
│ u32  ┆ datetime[μs]  │
╞══════╪═══════════════╡
│ null ┆ null          │
└──────┴───────────────┘

I assumed this was a null issue, because that is what triggers the error in aggregation, but it looks like `.get` works in select context. I checked 0.20.16.

dalejung avatar Mar 22 '24 22:03 dalejung

It seems that it goes wrong in crates/polars-lazy/src/physical_plan/expressions/take.rs, where it considers nulls as out-of-bounds. @ritchie46 is that intended?

orlp avatar Mar 23 '24 09:03 orlp

@orlp Looking at take.rs, I see that we'd run into the same issue with slice. Interestingly, a `slice(0, pl.col('col').arg_max())` would fail, but `slice(0, 0)` would work fine. Furthermore, a `slice(0, 0).max()` would return null, which makes sense.

I can see this being a bit thorny. I'd expect any expression that takes an index to handle null, because `arg_min` can be null. But AFAIK this only occurs when you filter/subset within an aggregation, because the top-level group_by won't create empty groups.

Unsure whether the default behavior should be to handle null gracefully, or whether a flag should opt in to the graceful handling.

dalejung avatar Mar 26 '24 21:03 dalejung

@ritchie46 @orlp

Can I get clarification on whether this is a bug or a usage issue?

Is the fix adding `null_on_oob` to `Expr.get`? Or should nulls not trigger `OutOfBoundsError`?

dalejung avatar Aug 11 '24 00:08 dalejung

https://github.com/pola-rs/polars/issues/16842 looks to be the fix, adding `null_on_oob` to `pl.Expr.get`.

dalejung avatar Aug 17 '24 18:08 dalejung

@orlp

Here is a super barebones example.

    import polars as pl

    df = pl.DataFrame({
        'num': range(10),
    })

    first_below_5 = (pl.col('num') < 5).arg_true().first()

    report = df.group_by('num').agg(
        first_below_5=pl.col('num').get(first_below_5)
    )  # errors

Note that using implode().list.get works as expected.

    # Works
    report = df.group_by('num').agg(
        first_below_5=(
            pl.col('num').implode()
            .list.get(
                first_below_5
            )
            .get(0)
        )
    )

Also, this bug doesn't appear in a normal select context.

    df.select(
        pl.col('num').get(
            (pl.col('num') < 0).arg_true().first()
        )
    )

derekjlv avatar Oct 17 '24 16:10 derekjlv