feat: "carefully" allow for dask Exprs that modify the index
What type of PR is this? (check all applicable)
- [ ] 💾 Refactor
- [x] ✨ Feature
- [ ] 🐛 Bug Fix
- [ ] 🔧 Optimization
- [ ] 📝 Documentation
- [ ] ✅ Test
- [ ] 🐳 Other
Checklist
- [ ] Code follows style guide (ruff)
- [ ] Tests added
- [ ] Documented the changes
If you have comments or can explain your changes, please do so below.
Pretty dangerous stuff to work around the dask index.
To assess that the implementation works as expected, I implemented both `sort` (different index, but same length) and `drop_nulls` (different index due to a different length).
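The two cases can be illustrated in plain pandas (which dask mirrors per partition): sorting reorders the index while keeping the length, whereas dropping nulls leaves a shorter, gappy index. A minimal sketch:

```python
import pandas as pd

s = pd.Series([3.0, None, 1.0], name="a")

# sort: same length, but the index is re-ordered
sorted_s = s.sort_values()  # na_position="last" by default
assert len(sorted_s) == len(s)
assert list(sorted_s.index) == [2, 0, 1]

# drop_nulls (dropna in pandas): shorter length, gappy index
dropped = s.dropna()
assert len(dropped) == 2
assert list(dropped.index) == [0, 2]
```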
thanks for trying this - i'll test it out and see if there's a perf impact
we've got the notebooks in `tpch/notebooks`, the first two support Dask - fancy running them with this branch and seeing if there's any perf impact?
Hey @MarcoGorelli, I am giving another thought on this feature (which I would still love to see), here is a simple idea to have partial support without loss of performance:
- `sort` is the only method that changes the index and results in an output with the same length. Instead of changing the index of each series, we can do that specifically in `sort`, namely by assigning the original index:

  ```python
  def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> Self:
      na_position = "last" if nulls_last else "first"

      def func(_input: Any, ascending: bool, na_position: str) -> Any:  # noqa: FBT001
          name = _input.name
          result = _input.to_frame(name=name).sort_values(
              by=name, ascending=ascending, na_position=na_position
          )[name]
          return de._expr.AssignIndex(result, _input.index)

      return self._from_call(
          func,
          "sort",
          not descending,
          na_position,
          returns_scalar=False,
      )
  ```

- All the other methods that change the index do so by reducing the length of the series. In my working experience, and in the TPCH queries, they are mostly used before a reduction or in isolation, therefore we should not worry about changing their index. Example:

  ```python
  df.select(
      head_sum=pl.col("a").head().sum(),
      tail_mean=pl.col("a").tail().mean(),
  )
  ```

- What is left and unsupported, you may ask? Multiple ~reductions~ operations ending up with the same length, different from the original, won't be possible. Example:

  ```python
  df.select(
      head=pl.col("a").head(),
      tail=pl.col("a").tail(),
  )
  ```
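The index re-assignment that the `sort` approach relies on can be mimicked in plain pandas with `set_axis` (used here as a stand-in for dask's private `AssignIndex`); this is just a sketch of the idea, not the actual implementation:

```python
import pandas as pd

s = pd.Series([3, 1, 2], name="a")

# sort the underlying data...
result = s.to_frame(name="a").sort_values(
    by="a", ascending=True, na_position="first"
)["a"]
# ...then put the original index back, which is the role AssignIndex
# plays in the dask version
result = result.set_axis(s.index)

assert list(result) == [1, 2, 3]
assert list(result.index) == [0, 1, 2]
```

The values are reordered while the index stays `[0, 1, 2]`, so the result can still be aligned with other columns of the original frame.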
What do you think?
@MarcoGorelli I am tagging this as ready for review as I re-worked it a bit more.
The TL;DR is:
- `sort` is kind of special, as it modifies the index but returns a Series of the same length as the original one, therefore in this specific case I am manually re-assigning the index
- for all other methods, I added a boolean flag to `DaskExpr` called `modifies_index`, and:
  - that is not allowed in `with_columns`
  - in `select` it should be allowed only if there are no other exprs or there is a reduction following ~(I need to address both these cases actually)~.
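A hypothetical sketch of how such a guard could look; the names and signatures below are illustrative only, not the actual Narwhals internals:

```python
class DaskExpr:
    # illustrative stand-in: tracks whether evaluating the expr changes the index
    def __init__(self, name: str, *, modifies_index: bool = False, is_reduction: bool = False):
        self.name = name
        self.modifies_index = modifies_index
        self.is_reduction = is_reduction


def validate_with_columns(exprs: list) -> None:
    # index-modifying exprs are never allowed in with_columns
    if any(e.modifies_index for e in exprs):
        raise ValueError("expressions that modify the index are not allowed in `with_columns`")


def validate_select(exprs: list) -> None:
    # allowed only when the expr appears alone, or when a reduction follows it
    offenders = [e for e in exprs if e.modifies_index and not e.is_reduction]
    if offenders and len(exprs) > 1:
        raise ValueError("index-modifying expressions must appear alone in `select`")
```

For example, `validate_select([sort_expr])` would pass, while `validate_select([sort_expr, other_expr])` and `validate_with_columns([sort_expr])` would raise.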
Yet before developing further, I would like some feedback on how likable this approach is and if we want to move forward with it 🙏🏼
thanks @FBruzzesi !
to be honest I don't know about using such private methods, it makes me feel slightly uneasy - @phofl do you have time/interest in taking a look? specifically the `de._collection.Series(de._expr.AssignIndex(result, _input.index))` part in `narwhals/_dask/expr.py`
I think that for sql engines (like duckdb, which hopefully we can get to eventually) operations like `df.select(nw.col('a').sort(), nw.col('b'))` would be problematic anyway, so I don't think it'd be an issue to leave them out of the Narwhals area of support
Hi @phofl, apologies to call you into the mix once more.
I have a few questions in order to make this work and to guarantee that we don't end up with something that is

> fundamentally a bad idea in Dask, it will shoot you in the foot all over the place.
- How can we test for when Dask will shoot us in the foot if we do something bad?
- The latest approach TL;DR is that if a method changes the index, then it either has to be followed by a reduction or be a single selection. Examples:
  - Reductions:

    ```python
    df.select(
        head_sum=pl.col("a").head().sum(),
        tail_mean=pl.col("a").tail().mean(),
    )
    ```

    which would translate to something like

    ```python
    dd.concat([df["a"].head().sum(), df["a"].tail().mean()])
    ```

  - Single selection:

    ```python
    df.select(
        head=pl.col("a").head(),
    )
    ```
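The intended semantics of the reduction case can be checked in plain pandas; explicit `head(2)`/`tail(2)` lengths are used here just to keep the sketch deterministic:

```python
import pandas as pd

df = pd.DataFrame({"a": [5, 1, 4, 2, 3]})

# pandas analogue of concatenating the two scalar reductions back
# into a single one-row frame
out = pd.DataFrame({
    "head_sum": [df["a"].head(2).sum()],    # 5 + 1
    "tail_mean": [df["a"].tail(2).mean()],  # (2 + 3) / 2
})

assert out["head_sum"].item() == 6
assert out["tail_mean"].item() == 2.5
```

Because each expression ends in a reduction, the mismatched intermediate indices never need to be aligned with each other.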
thanks for having explored this!
tbh i don't think we should do it - it adds complexity which we must then maintain, and then we'd need to do it for duckdb / pyspark / ... . I'd prefer to keep the lazy-only layer simpler
closing then, but thanks again for your investigation here! 🙏
Thanks @MarcoGorelli! As you mentioned in the past, "no" might be temporary, we can always come back to this if needed