Clarify documentation for the `agg_list` argument in `Expr.map_batches`

Open Wainberg opened this issue 2 years ago • 7 comments

Description

The documentation for Expr.map_batches pithily describes the function of the agg_list argument as "Aggregate list". What does this argument do? It would be good to update the documentation.

Wainberg avatar Jan 10 '24 20:01 Wainberg

It seems to be for map_elements. I don't think it has a practical usage outside of map_elements using it but I haven't messed with it much

deanm0000 avatar Jan 10 '24 22:01 deanm0000

map_elements doesn't have an agg_list argument, though.

Wainberg avatar Jan 10 '24 22:01 Wainberg

Sorry, to clarify, map_elements calls map_batches, and in doing so it sets that parameter under different conditions that I don't remember offhand.

deanm0000 avatar Jan 10 '24 22:01 deanm0000

Looks like it controls ApplyOptions::ApplyList

https://github.com/pola-rs/polars/blob/a8bdc76000c059afdac1f215e3b95654a0057712/crates/polars-plan/src/dsl/python_udf.rs#L209-L210

Which is defined here:

https://github.com/pola-rs/polars/blob/a8bdc76000c059afdac1f215e3b95654a0057712/crates/polars-plan/src/logical_plan/options.rs#L162-L172

cmdlineluser avatar Jan 10 '24 22:01 cmdlineluser

Does it actually do anything? I haven't been able to find an example where it changes the result.

Wainberg avatar Jan 11 '24 03:01 Wainberg

The two settings differ mainly in the agg context. If agg_list is False, the UDF is called once per group. If agg_list is True, the UDF is invoked only once, on a single list Series that holds all of the groups.

Let's use an example to illustrate this further:

import polars as pl

df = pl.DataFrame(
    {
        "a": [0, 1, 0, 1],
        "b": [1, 2, 3, 4],
    }
)

def f(x):
    # Print whatever the UDF receives, then return it unchanged.
    print(x)
    return x

1. Disable agg_list:

df.group_by("a").agg(pl.col("b").map_batches(f, agg_list=False))

# first output
Series: '' [i64]
[
	2
	4
]

# second output
Series: '' [i64]
[
	1
	3
]

2. Enable agg_list:

df.group_by("a").agg(pl.col("b").map_batches(f, agg_list=True))

# output
Series: 'b' [list[i64]]
[
	[2, 4]
	[1, 3]
]
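
To make the difference in call pattern concrete without depending on a particular polars version, here is a rough pure-Python model of the two modes. It is only a sketch of the behaviour described above, not how polars implements it: the helpers `group_values`, `agg_per_group` and `agg_as_list` are hypothetical names made up for illustration, and plain lists stand in for Series.

```python
from collections import defaultdict


def group_values(keys, values):
    """Collect values into one list per key (stand-in for group_by)."""
    groups = defaultdict(list)
    for key, value in zip(keys, values):
        groups[key].append(value)
    return dict(groups)


def agg_per_group(groups, udf):
    """Models agg_list=False: the UDF is called once for every group."""
    return {key: udf(values) for key, values in groups.items()}


def agg_as_list(groups, udf):
    """Models agg_list=True: the UDF is called once with a list of all groups."""
    keys = list(groups)
    results = udf([groups[key] for key in keys])
    return dict(zip(keys, results))


calls = []


def f(x):
    # Record each invocation so the number of calls can be compared.
    calls.append(x)
    return x


groups = group_values([0, 1, 0, 1], [1, 2, 3, 4])

per_group = agg_per_group(groups, f)
n_calls_per_group = len(calls)

calls.clear()
as_list = agg_as_list(groups, f)
n_calls_as_list = len(calls)

print(n_calls_per_group, n_calls_as_list)
```

With an identity UDF both modes return the same grouped values, which would be consistent with not finding an example where the result changes; what differs is how many times the UDF runs and what shape of input it sees.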

reswqa avatar Jan 11 '24 08:01 reswqa

Maybe I can update the documentation to make it easier to understand.

reswqa avatar Jan 11 '24 08:01 reswqa