fix: Allow dtype cast for UDF
Fixes #25732.
There is currently no way to coerce the AnyValues produced by a UDF into a specific dtype ahead of time (e.g. a Python float always becomes an f64), so if a user supplies a compatible return_dtype, we should allow the result to be collected and cast to it.
OP example:
```python
import polars as pl
df = pl.DataFrame({"a": [1, 3, 4.0]}, schema={"a": pl.Float32})
result = df.with_columns(pl.col("a").map_elements(lambda x: x, return_dtype=pl.Float32))
print(result)
# shape: (3, 1)
# ┌─────┐
# │ a   │
# │ --- │
# │ f32 │
# ╞═════╡
# │ 1.0 │
# │ 3.0 │
# │ 4.0 │
# └─────┘
```
Not sure why I'm getting a bunch of u32/u64 failures; it looks like it thinks it's on a 32-bit runtime, but it's actually 64-bit? Edit: Orson fixed it.
Codecov Report
:x: Patch coverage is 70.00000% with 3 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 79.64%. Comparing base (fe85677) to head (3c1dd61).
| Files with missing lines | Patch % | Lines |
|---|---|---|
| crates/polars-python/src/map/series.rs | 70.00% | 3 Missing :warning: |
| Coverage Diff | main | #25735 | +/- |
|---|---|---|---|
| Coverage | 80.56% | 79.64% | -0.93% |
| Files | 1764 | 1764 | |
| Lines | 242683 | 242683 | |
| Branches | 3041 | 3041 | |
| Hits | 195528 | 193281 | -2247 |
| Misses | 46372 | 48619 | +2247 |
| Partials | 783 | 783 | |
I'm not sure we want strict=False here. Won't this also allow other silly things, such as returning strings when the return type is set to pl.Int64?
It does. I wasn't sure whether that mattered: UDFs are (or should be) a last resort, so it feels right to make them as lenient as possible.
If this isn't desired, I'll look into a more appropriate fix; this was a bit of a sledgehammer fix.
@orlp the main issue here is that the strict parameter in from_any_values_and_dtype is much stricter than our normal casting. pl.Series([1, 2, 3]).cast(pl.UInt8, strict=True) works fine, but in from_any_values_and_dtype, strict means that the dtype must match exactly or we get an error. And since Python floats are always built as f64, there is no way to get a return dtype of f32.
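A minimal pure-Python sketch of the two meanings of "strict" described above. It does not call Polars: inferred_dtype, exact_match_strict and value_checked_cast are hypothetical helpers, dtypes are plain strings, and mapping Python ints to Int64 is an assumption of the sketch. It only illustrates why exact-match strictness can never accept a Python float for a Float32 target, while a value-based check can.

```python
# Hypothetical sketch of two meanings of "strict"; pure Python, no Polars.
INT_BOUNDS = {"UInt8": (0, 2**8 - 1), "Int64": (-(2**63), 2**63 - 1)}
FLOAT_DTYPES = {"Float32", "Float64"}


def inferred_dtype(value):
    # Assumption: Python floats become Float64 and ints become Int64.
    if isinstance(value, float):
        return "Float64"
    if isinstance(value, int) and not isinstance(value, bool):
        return "Int64"
    raise TypeError(f"unsupported value {value!r}")


def exact_match_strict(values, target):
    # "strict" as described for from_any_values_and_dtype: each value's
    # own dtype must equal the target exactly.
    for v in values:
        if inferred_dtype(v) != target:
            raise TypeError(f"{v!r} is {inferred_dtype(v)}, expected {target}")
    return list(values)


def value_checked_cast(values, target):
    # "strict" in the spirit of Series.cast(..., strict=True): fail only
    # when a value does not fit the target. Simplifications: floats are
    # never accepted for integer targets, and Float32 rounding is ignored.
    out = []
    for v in values:
        inferred_dtype(v)  # still reject unsupported types such as str
        if target in INT_BOUNDS:
            lo, hi = INT_BOUNDS[target]
            if not isinstance(v, int) or not lo <= v <= hi:
                raise ValueError(f"{v!r} does not fit in {target}")
            out.append(v)
        elif target in FLOAT_DTYPES:
            out.append(float(v))
        else:
            raise TypeError(f"unknown target {target}")
    return out
```

Here exact_match_strict([1.0, 3.0, 4.0], "Float32") raises TypeError, while value_checked_cast([1.0, 3.0, 4.0], "Float32") returns [1.0, 3.0, 4.0], and value_checked_cast([1, 2, 3], "UInt8") mirrors the cast example above.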
Marking as ready, but I don't know whether this implementation is the way to go; that depends on a decision about how a UDF's return_dtype should be handled (is it a directive to cast to, or a description of the actual return dtype?).
This has always been in the back of my mind: it's undefined how this works. Whatever the answer is, I'd assume it should match the behaviour of pl.DataFrame(...) construction (see https://github.com/pola-rs/polars/pull/25823/files).
I'd propose the following, which feels aligned with the super messed-up way Python (and type checkers) treat ints and floats as "close enough to the same thing". For example, 1 == 1.0 evaluates to True, Pyright doesn't warn about a function annotated -> float returning an int, etc.
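A few standard-library checks illustrating the int/float equivalences mentioned above. Writing the first one as assert 1 == 1.0 is True would not work: Python parses it as the chained comparison (1 == 1.0) and (1.0 is True), which is False. The -> float annotation relies on the PEP 484 numeric tower, which lets type checkers such as Pyright accept an int there; at runtime the value simply stays an int.

```python
# Python treats int and float as "close enough" in several places.
assert 1 == 1.0                 # numeric equality across types
assert hash(1) == hash(1.0)     # equal numbers must hash equally
assert {1: "a"}[1.0] == "a"     # so a float can look up an int dict key

# Pitfall: `1 == x is True` chains as (1 == x) and (x is True).
x = 1.0
assert not (1 == x is True)


def as_float() -> float:
    # Accepted by type checkers (PEP 484 numeric tower), but no
    # conversion happens at runtime: the caller still gets an int.
    return 1


assert type(as_float()) is int
```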
Proposal
- A Python int can be implicitly cast to any integer return_dtype (lossless operation). Integer overflows should raise an exception.
- A Python float can be implicitly cast to any float return_dtype (lossless for Float64; casting to Float32 may lose precision).
- A Python int can be implicitly cast to any float return_dtype (lossless as long as the int is exactly representable in the target float type).
- A Python float CANNOT be implicitly cast to any integer return_dtype (lossy operation).
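A minimal sketch of the four proposed rules, assuming dtypes are represented as plain strings (e.g. "UInt8", "Float32"). coerce_udf_result is a hypothetical helper, not a Polars function, and rejecting bool (a subclass of int in Python) is an assumption, since the proposal doesn't cover it. Float32 storage is modelled by round-tripping through a 4-byte IEEE 754 value with struct, which shows that narrowing a Python float to Float32 can round it.

```python
import struct

# Hypothetical sketch of the proposed return_dtype rules; dtype names are
# plain strings standing in for Polars dtypes.
INT_BOUNDS = {
    "Int8": (-(2**7), 2**7 - 1),
    "Int16": (-(2**15), 2**15 - 1),
    "Int32": (-(2**31), 2**31 - 1),
    "Int64": (-(2**63), 2**63 - 1),
    "UInt8": (0, 2**8 - 1),
    "UInt16": (0, 2**16 - 1),
    "UInt32": (0, 2**32 - 1),
    "UInt64": (0, 2**64 - 1),
}
FLOAT_DTYPES = {"Float32", "Float64"}


def _to_float(value, dtype):
    # float() raises OverflowError for ints too large for a double.
    x = float(value)
    if dtype == "Float32":
        # Round-trip through IEEE 754 single precision; struct raises
        # OverflowError if a finite x is too large for Float32.
        x = struct.unpack("<f", struct.pack("<f", x))[0]
    return x


def coerce_udf_result(value, dtype):
    if isinstance(value, bool):
        # bool subclasses int; not covered by the proposal (assumption: reject).
        raise TypeError(f"bool cannot be implicitly cast to {dtype}")
    if isinstance(value, int):
        if dtype in INT_BOUNDS:
            # Rule 1: int -> any int dtype, but overflow raises.
            lo, hi = INT_BOUNDS[dtype]
            if not lo <= value <= hi:
                raise OverflowError(f"{value} overflows {dtype}")
            return value
        if dtype in FLOAT_DTYPES:
            # Rule 3: int -> any float dtype.
            return _to_float(value, dtype)
    elif isinstance(value, float):
        if dtype in FLOAT_DTYPES:
            # Rule 2: float -> any float dtype (Float32 may round).
            return _to_float(value, dtype)
        if dtype in INT_BOUNDS:
            # Rule 4: float -> int dtype is lossy, so it is refused.
            raise TypeError(f"float {value!r} cannot be implicitly cast to {dtype}")
    # Anything else, e.g. a str returned for an Int64 return_dtype.
    raise TypeError(f"{type(value).__name__} cannot be cast to {dtype}")
```

Under these rules the OP example (floats returned with return_dtype=pl.Float32) is accepted, while returning a string for an Int64 return_dtype still raises, which addresses the strictness concern raised earlier in the thread.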