
Multiple Condition Mapping like `numpy.select`

Open Julian-J-S opened this issue 2 years ago • 9 comments

Problem description

#5822 would be really great for mapping column values with a discrete/fixed set of possible values. But there are also many cases where the set of possible values cannot be known in advance, so you need to use a condition.

In the pandas/numpy world this is done with the `np.select` function, where you specify a list of conditions and a list of values to return for each condition.

In polars there is the when/then/otherwise chain, which is great and ergonomic for a single condition but becomes a bit messy when you have multiple conditions.
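Both `np.select` and a chained polars when/then resolve overlapping conditions the same way: for each element the first true condition wins, and elements that match no condition get the default. A small NumPy sketch of that rule (the sample values are made up for illustration):

```python
import numpy as np

r = np.array([0.1, 0.6, 0.8, 0.97])

# The conditions overlap: 0.1 satisfies all three, but only the first match counts.
condlist = [r < 0.5, r < 0.75, r < 0.9]
choicelist = [np.zeros_like(r), r * 2, r * 3]

# 0.97 matches no condition, so it falls back to the default.
result = np.select(condlist, choicelist, default=1.0)
print(result)
```

Here 0.1 satisfies `r < 0.5`, `r < 0.75` and `r < 0.9`, but only the first choice is used.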

Examples:


import numpy as np
import pandas as pd
import polars as pl

data = {
    "base_price": [100, 200, 300, 400, 500],
    "discount_level": [0, -1, -5, 2, 1],
}
df = pl.DataFrame(data)

# polars (when, then, otherwise), nested
df.with_columns(
    pl.when(pl.col("discount_level") == 0)
    .then(pl.col("base_price"))
    .otherwise(
        pl.when(pl.col("discount_level") == 1)
        .then(pl.col("base_price") * 0.9)
        .otherwise(
            pl.when(pl.col("discount_level") > 1)
            .then(pl.col("base_price") * 0.8)
            .otherwise(
                pl.when(pl.col("discount_level") == -1)
                .then(2)
                .otherwise(-3)
            )
        )
    ).alias("discounted_price")
)


# pandas/numpy (select); this needs a pandas DataFrame, since polars has no .assign
df_pd = pd.DataFrame(data)

condlist = [
    df_pd["discount_level"] == 0,
    df_pd["discount_level"] == 1,
    df_pd["discount_level"] > 1,
    df_pd["discount_level"] == -1,
]

choicelist = [
    df_pd["base_price"],
    df_pd["base_price"] * 0.9,
    df_pd["base_price"] * 0.8,
    2,
]

df_pd.assign(
    discounted_price=np.select(condlist, choicelist, default=-3)
)

Do you think it would be possible, or a worthwhile addition, to add something like this to polars? =)

Julian-J-S, Dec 15 '22 20:12

In case you're not aware: you can chain when/then/otherwise expressions, so you could write it similarly to your pandas/numpy (select) example:

import polars as pl

df = pl.DataFrame({
    "base_price": [100, 200, 300, 400, 500],
    "discount_level": [0, -1, -5, 2, 1],
})

base_price = pl.col("base_price")
discount_level = pl.col("discount_level")

condlist = [
    discount_level == 0,
    discount_level == 1,
    discount_level > 1,
    discount_level == -1,
]

choicelist = [
    base_price,
    base_price * 0.9,
    base_price * 0.8,
    2,
]

default = -3

# Seed the chain with a branch that never matches, then append one
# when/then pair per condition; the first matching condition wins.
mapping = pl.when(False).then(None)
for cond, choice in zip(condlist, choicelist):
    mapping = mapping.when(cond).then(choice)

mapping = mapping.otherwise(default)

>>> df.with_columns(mapping.alias("discounted_price"))
shape: (5, 3)
┌────────────┬────────────────┬──────────────────┐
│ base_price ┆ discount_level ┆ discounted_price │
│ ---        ┆ ---            ┆ ---              │
│ i64        ┆ i64            ┆ f64              │
╞════════════╪════════════════╪══════════════════╡
│ 100        ┆ 0              ┆ 100.0            │
├────────────┼────────────────┼──────────────────┤
│ 200        ┆ -1             ┆ 2.0              │
├────────────┼────────────────┼──────────────────┤
│ 300        ┆ -5             ┆ -3.0             │
├────────────┼────────────────┼──────────────────┤
│ 400        ┆ 2              ┆ 320.0            │
├────────────┼────────────────┼──────────────────┤
│ 500        ┆ 1              ┆ 450.0            │
└────────────┴────────────────┴──────────────────┘

cmdlineluser, Dec 15 '22 22:12

@JulianCologne you don't have to nest an otherwise every time: the object returned by .then() has its own .when() method, so the conditions can simply be chained. Here's your code cleaned up quite a bit:

import polars as pl
from polars import col, when

df = pl.DataFrame({
    "base_price": [100, 200, 300, 400, 500],
    "discount_level": [0, -1, -5, 2, 1],
})

# polars (when, then, otherwise), chained instead of nested
discount_level, base_price = col("discount_level"), col("base_price")
df.with_columns(
    when(discount_level == 0).then(base_price)
    .when(discount_level == 1).then(base_price * 0.9)
    .when(discount_level > 1).then(base_price * 0.8)
    .when(discount_level == -1).then(2)
    .otherwise(-3)
    .alias("discounted_price")
)

mcrumiller, Dec 15 '22 22:12

thanks a lot for your feedback and your ideas!

@cmdlineluser: nice solution, and very similar to the numpy/pandas way. What I don't like about it is the odd pl.when(False).then(None) initialization of the chain, but still cool!

@mcrumiller: very nice! I did not know you could chain the when statements like this. I think this is a great and expressive solution =)

I'm not too concerned about performance, but while playing around with this and comparing it to pandas/numpy I found that the when chain is considerably slower than the select approach (around 50% slower). I have no idea why, and I am by no means a performance expert, but in (almost) all of my previous experience polars has been faster than pandas/numpy.

Example:

import numpy as np
import pandas as pd
import polars as pl
from polars import col, when

data = {
    "r": np.random.rand(100_000_000),
}

df_pl = pl.DataFrame(data)
df_pd = pd.DataFrame(data)

%%timeit
df_pl.with_columns(
    when(col("r") < 0.5).then(0)
    .when(col("r") < 0.75).then(col("r") * 2)
    .when(col("r") < 0.9).then(col("r") * 3)
    .when(col("r") < 0.95).then(col("r") * 4)
    .otherwise(1)
    .alias("calc")
);
# >>> 5.5s


%%timeit
condlist = [
    df_pd["r"] < 0.5,
    df_pd["r"] < 0.75,
    df_pd["r"] < 0.9,
    df_pd["r"] < 0.95,
]
choicelist = [
    0,
    df_pd["r"] * 2,
    df_pd["r"] * 3,
    df_pd["r"] * 4,
]
# df_pd.assign(calc=np.select(condlist, choicelist, default=1));  # numpy + pandas assign variant
np.select(condlist, choicelist, default=1)
# >>> 3.3s (numpy+pandas assign)
# >>> 2.5s (numpy only)

Julian-J-S, Dec 16 '22 08:12

> I found that the when chain is considerably slower than the select approach (around 50% slower)

This is because, at least for now, Polars calculates each then expression for the whole "r" column, not just for the subset of rows selected by its condition.
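A NumPy sketch of the strategy described above, where every then-branch is computed for the whole column and the per-row result is picked afterwards. It models the behaviour described in this comment and is not Polars' implementation; `eager_select` and the `calls` counter are hypothetical names added only to make the extra work visible:

```python
import numpy as np

calls = {}  # how many elements each branch function was applied to


def counted(name, fn):
    """Wrap a branch function so we can record how many elements it processes."""
    def wrapped(x):
        calls[name] = calls.get(name, 0) + x.size
        return fn(x)
    return wrapped


def eager_select(x, condlist, branches, default):
    """Compute every branch on the full column, then pick per element (first match wins)."""
    choices = [fn(x) for fn in branches]  # each branch sees all of x
    return np.select(condlist, choices, default=default)


r = np.random.default_rng(0).random(1_000)
condlist = [r < 0.5, r < 0.75, r < 0.9, r < 0.95]
branches = [
    counted("zero", np.zeros_like),
    counted("x2", lambda v: v * 2),
    counted("x3", lambda v: v * 3),
    counted("x4", lambda v: v * 4),
]

out = eager_select(r, condlist, branches, default=1.0)
print(calls)  # every branch processed all 1000 elements
```

With uniformly distributed values roughly half the rows end up in the first branch, yet all four branch functions process every row.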

ghuls, Dec 16 '22 13:12

> I found that the when chain is considerably slower than the select approach (around 50% slower)

If I am not mistaken, polars evaluates all the branches of a when/then/otherwise construct (in parallel) and then throws away the results where the condition is false. This might be the reason why multiple nested conditionals are slower. I'm curious whether a single when/then/otherwise is also slower than numpy.

slonik-az, Dec 16 '22 13:12

@ghuls can someone explain the logic of doing 1) compute all, 2) filter? Wouldn't simply reversing those two operations significantly speed up the entire process?

mcrumiller, Dec 16 '22 14:12

@mcrumiller As all those operations run in parallel, computing every branch and filtering afterwards can be faster in wall-clock time, and each calculation can be a plain vectorized operation over the whole column.

But there are also downsides:

  • if one of the branches is a heavy function, you pay that price for each value in the original series and not just for the selected elements in that branch
  • even worse, if your condition is meant to keep bad values (e.g. a divisor of 0) away from the function, it does not help, because the function is still called on every value. Guarding like this is currently not possible with this syntax.
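The second downside can be reproduced with NumPy's `np.where`, which also evaluates both of its branches for every element. This sketch contrasts it with evaluating the branch only on the rows the condition selects; `safe_reciprocal` is a hypothetical helper for illustration:

```python
import warnings

import numpy as np

x = np.array([2.0, 0.0, 4.0])

# Eager: 1 / x is computed for every element, including the 0 the condition
# was supposed to exclude, so NumPy emits a divide-by-zero RuntimeWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    eager = np.where(x != 0, 1 / x, 0.0)


def safe_reciprocal(x):
    """Hypothetical helper: apply 1/x only where x != 0, else 0.0."""
    out = np.zeros_like(x)
    mask = x != 0
    out[mask] = 1 / x[mask]  # the division never sees the zero
    return out


with warnings.catch_warnings(record=True) as caught_masked:
    warnings.simplefilter("always")
    masked = safe_reciprocal(x)

print(eager, masked, len(caught), len(caught_masked))
```

The final values agree; the difference is that the eager version really did divide by zero and only discarded the `inf` afterwards.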

@ritchie46 Couldn't you limit the parallelisation to the when conditions:

col("r") < 0.5
col("r") < 0.75
col("r") < 0.9
col("r") < 0.95
other

Then make the masks for each branch in parallel and fill the other values with None?
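A NumPy sketch of the proposal above: compute every condition mask up front (they are independent, so they could run in parallel), turn them into mutually exclusive first-match masks, evaluate each branch only on its own rows, and fill whatever is left with the default. `masked_select` is a hypothetical name, NaN stands in for None, and this says nothing about how Polars itself is or would be implemented:

```python
import numpy as np


def masked_select(x, condlist, funclist, default):
    """Hypothetical np.select variant that evaluates each branch only on its rows."""
    out = np.full(x.shape, np.nan)          # NaN plays the role of None / null
    taken = np.zeros(x.shape, dtype=bool)   # rows already claimed by an earlier branch
    for cond, fn in zip(condlist, funclist):
        rows = cond & ~taken                # first matching condition wins
        out[rows] = fn(x[rows])             # branch sees only its own rows
        taken |= rows
    out[~taken] = default                   # rows that matched no condition
    return out


r = np.random.default_rng(0).random(10_000)
condlist = [r < 0.5, r < 0.75, r < 0.9, r < 0.95]
funclist = [np.zeros_like, lambda v: v * 2, lambda v: v * 3, lambda v: v * 4]

result = masked_select(r, condlist, funclist, default=1.0)
expected = np.select(condlist, [f(r) for f in funclist], default=1.0)
print(np.array_equal(result, expected))  # True
```

The result matches `np.select`, but each branch function only receives the rows it is responsible for.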

ghuls, Dec 16 '22 15:12

Just throwing in some performance comparisons:

polars: when/then/otherwise chain VS when/then/otherwise nested

%%timeit
df_pl.with_columns(
    when(col("r") < 0.5).then(0)
    .when(col("r") < 0.75).then(col("r") * 2)
    .when(col("r") < 0.9).then(col("r") * 3)
    .when(col("r") < 0.95).then(col("r") * 4)
    .otherwise(1)
);
# >>> 5.5s

%%timeit
df_pl.with_columns(
    when(col("r") < 0.5).then(0)
    .otherwise(
        when(col("r") < 0.75).then(col("r") * 2)
        .otherwise(
            when(col("r") < 0.9).then(col("r") * 3)
            .otherwise(
                when(col("r") < 0.95).then(col("r") * 4)
                .otherwise(1)
            )
        )
    )
);
# >>> 5.5s

chaining vs nesting makes almost no difference

(single condition) polars when/then/otherwise VS numpy select VS numpy where

%%timeit
df_pl.with_columns(
    when(col("r") < 0.5).then(0)
    .otherwise(1)
);
# >>> 2.16s

%%timeit
df_pd.assign(
    calc=np.select(
        [df_pd["r"] < 0.5],
        [0],
        default=1
    )
)
# >>> 1.75s

%%timeit
df_pd.assign(
    calc=np.where(df_pd["r"] < 0.5, 0, 1)
);
# >>> 1.34s

numpy select is faster than polars when/then/otherwise, and numpy where is faster still
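For a single condition, `np.where` and `np.select` compute the same result; `np.select` is the general, list-based form. A quick equivalence check on random data:

```python
import numpy as np

r = np.random.default_rng(1).random(1_000)
cond = r < 0.5

# Same mapping as the benchmarks above: 0 where r < 0.5, otherwise 1.
via_where = np.where(cond, 0, 1)
via_select = np.select([cond], [0], default=1)

print(np.array_equal(via_where, via_select))  # True
```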

Julian-J-S, Dec 16 '22 15:12

> chaining vs nesting makes almost no difference

It should make zero difference: one uses elseif and the other uses else(if(...)), which should be equivalent in nearly every programming language.

mcrumiller, Dec 16 '22 15:12