IterTools.jl

partition function type stability for constant arguments

Status: Open. gabrielgellner opened this issue 5 years ago (8 comments).

I frequently use `partition` as a sliding window over a fixed array, e.g. `partition(arr, 2, 1)` to get consecutive pairs. The problem is that this call is type-unstable in the current version of the package. kristoffer.carlsson suggested a fix that inlines some of the pieces, but noted that it might not be optimal in general (https://discourse.julialang.org/t/itertools-jl-partition-type-stability/13847).

If there is a way to fix this common use case, that would be super awesome!
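For the pairs-only case (`partition(arr, 2, 1)`), one Base-only option, outside IterTools, is to zip the iterator with itself shifted by one element. This is a minimal sketch: `consecutive_pairs` is a hypothetical helper name, and it assumes the input can be iterated more than once (true for arrays and ranges). Its element type follows from the argument types alone.

```julia
# Hypothetical helper name; uses only Base iterators, not IterTools.
# Yields (x[1], x[2]), (x[2], x[3]), ... like partition(x, 2, 1).
consecutive_pairs(x) = zip(x, Iterators.drop(x, 1))

ps = consecutive_pairs([1, 2, 3, 4])
collect(ps)                   # [(1, 2), (2, 3), (3, 4)]
eltype(ps) == Tuple{Int, Int} # true: element type is known from the input type
```

This only covers a step of one and a window of two; wider windows still need something like `partition` or a dedicated iterator.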

gabrielgellner, Aug 22 '18

I think his suggestion is good!

iamed2, Aug 24 '18

Okay, I will try to put a PR together over the weekend.

gabrielgellner, Aug 24 '18

Thanks! The key part will be adding @inferred tests.
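As a rough sketch of what an `@inferred` test checks, here are two hypothetical helpers (`window_val`, `window_int`, not IterTools code) that differ only in whether the window size is encoded in the type. `Test.@inferred` compares the actual return type with the type inferred from the argument types, and throws when they differ.

```julia
using Test

# Hypothetical helpers illustrating the difference, not IterTools code.
# Window size as a type parameter: the result type NTuple{N,eltype(x)} is
# known from the argument types alone.
window_val(x, ::Val{N}) where {N} = ntuple(i -> x[i], Val(N))

# Window size as a runtime Int: from the argument types alone the compiler
# cannot tell how long the returned tuple is.
window_int(x, n::Int) = ntuple(i -> x[i], n)

x = collect(1:10)
@test @inferred(window_val(x, Val(3))) == (1, 2, 3)
# @inferred throws an ErrorException when the inferred and actual return
# types differ, which is what happens here.
@test_throws ErrorException @inferred(window_int(x, 3))
```

A test for `partition` itself would follow the same pattern, e.g. wrapping a call such as `first(partition(arr, 2, 1))` in `@inferred`.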

iamed2, Aug 24 '18

I am not sure this makes a difference anymore: adding the `@inline` no longer seems to change what gets inferred. Did something change elsewhere that fixes this?

gabrielgellner, Sep 13 '18

This version might be much more performant (implemented only for a step of one):

```julia
using IterTools        # for `partition` (used in g2)
using BenchmarkTools   # for `@btime`

# Sliding windows of fixed width K (step 1). K is a type parameter, so the
# element type NTuple{K,eltype(Itr)} is known to the compiler.
struct SlidingWindows{K,Itr}
    itr::Itr
end
slidingwindow(itr::T, k) where {T} = SlidingWindows{k, T}(itr)
function Base.length(I::SlidingWindows{K}) where {K}
    l = length(I.itr) + 1 - K
    l < 1 && throw(
        ArgumentError("Argument with less than $K iterates."))
    l
end
Base.eltype(::Type{SlidingWindows{K,Itr}}) where {K,Itr} = NTuple{K,eltype(Itr)}
Base.IteratorEltype(::Type{SlidingWindows{K,Itr}}) where {K,Itr} = Base.IteratorEltype(Itr)

function Base.iterate(I::SlidingWindows{K}) where K
    # Fill the first window element by element.
    ϕ = iterate(I.itr)
    ϕ === nothing && throw(
        ArgumentError("Argument with less than $K iterates."))
    el, state = ϕ
    els_ = [el,]

    for k in 2:K
        ϕ = iterate(I.itr, state)
        ϕ === nothing && throw(
            ArgumentError("Argument with less than $K iterates."))
        el, state = ϕ
        push!(els_, el)
    end
    els = NTuple{K}(els_)
    # State: the last K-1 elements plus the underlying iterator state.
    els, (Base.tail(els), state)
end

function Base.iterate(I::SlidingWindows, els_state)
    els, state = els_state
    ϕ = iterate(I.itr, state)
    ϕ === nothing && return nothing
    el, state = ϕ
    elsnew = (els..., el)   # shift the window by one element
    elsnew, (Base.tail(elsnew), state)
end
f(x, y) = x .+ y
g(x::T) where {T} = reduce(f, SlidingWindows{3,T}(x))
g2(x) = reduce(f, partition(x, 3, 1))
```

```julia
julia> @btime g(1:100000)
  63.931 μs (23 allocations: 848 bytes)
(4999850001, 4999949999, 5000049997)

julia> @btime g2(1:100000)
  45.940 ms (898453 allocations: 39.65 MiB)
(4999850001, 4999949999, 5000049997)
```

mschauer, Sep 26 '19

I believe this is closed by https://github.com/JuliaCollections/IterTools.jl/pull/77.

goretkin, Aug 25 '20

A static implementation is still worthwhile to have somewhere, right?

mschauer, Dec 18 '20

> A static implementation is still worthwhile to have somewhere, right?

Do you mean something that doesn't rely on constant propagation for the step field?

If the `::Int` annotation on the step field is relaxed, it might be possible to use, e.g., StaticNumbers.jl (https://github.com/perrutquist/StaticNumbers.jl/blob/master/src/StaticNumbers.jl) to get the static behavior without requiring a new type and corresponding methods.
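A rough sketch of what relaxing the `::Int` could look like, assuming a hypothetical `Windowed` type (not IterTools' actual partition type) whose step field has a parametric type `S`. Base's `Val` stands in here for a static number type such as those in StaticNumbers.jl, whose exact API is not shown. With a static step, the value is recoverable from the type alone.

```julia
# Hypothetical type, not IterTools' actual partition type. The step is stored
# in a field whose type is a parameter S rather than a fixed `::Int`.
struct Windowed{I,S}
    itr::I
    step::S
end

# A runtime step: the value is only known when the program runs.
stepof(w::Windowed{<:Any,<:Integer}) = w.step
# A static step: the value is encoded in the type Val{S}, standing in here
# for a static number type such as those provided by StaticNumbers.jl.
stepof(::Windowed{<:Any,Val{S}}) where {S} = S

dyn = Windowed(1:10, 2)        # a Windowed{UnitRange{Int}, Int}
sta = Windowed(1:10, Val(2))   # a Windowed{UnitRange{Int}, Val{2}}
```

The iteration methods would then call `stepof(w)` instead of reading `w.step` directly, so the same code path serves both runtime and static steps.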

goretkin, Dec 19 '20

The Transducers.jl equivalent of this seems to use `Val` in a wrapper constructor that is supposedly inlined.

https://github.com/JuliaFolds/Transducers.jl/blob/2a51f8dbd6b4408063e52108c3e36006f63b6b0c/src/consecutive.jl#L41
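A minimal sketch of the general "integer wrapper forwarding to `Val`" pattern, with hypothetical names (`firstn`, `take3`; this is not the Transducers.jl or IterTools.jl API). Whether `Val(n)` is actually resolved at compile time depends on the compiler inlining or constant-propagating the wrapper at a call site with a constant `n`.

```julia
# Hypothetical names; this is not the Transducers.jl or IterTools.jl API.

# Implementation: N is a type parameter, so the result is an NTuple{N}
# whose length the compiler knows.
function firstn(x, ::Val{N}) where {N}
    i0 = firstindex(x)
    return ntuple(i -> x[i0 + i - 1], Val(N))
end

# Wrapper taking a plain integer. When `n` is a compile-time constant at the
# call site and this call is inlined, `Val(n)` can be resolved statically;
# with a runtime `n` it falls back to dynamic dispatch.
@inline firstn(x, n::Integer) = firstn(x, Val(n))

take3(x) = firstn(x, 3)   # literal window size at the call site
```

The design trade-off is that users keep an ordinary integer argument, while type stability hinges on constant propagation rather than being guaranteed by the type.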

I currently have a custom `PairIterator` that I would gladly replace with `partition` if it offered comparable performance. But indeed:

```julia
julia> f(x, y) = x .+ y
f (generic function with 1 method)

julia> g2(x) = reduce(f, partition(x, 2, 1))
g2 (generic function with 1 method)

julia> g(x) = reduce(f, PairIterator(x))
g (generic function with 1 method)

julia> @btime g(1:100000)
  1.603 ns (0 allocations: 0 bytes)
(4999950000, 5000049999)

julia> @btime g2(1:100000)
  10.712 ms (598974 allocations: 24.40 MiB)
(4999950000, 5000049999)
```

```julia
struct PairIterator{A}
    ax::A
end
# Pairs come out in linear iteration order, so an N-d input shape is not
# preserved; report HasLength for shaped inputs instead of defining `size`.
_pairsize(::Base.HasShape) = Base.HasLength()
_pairsize(s) = s
Base.IteratorSize(::Type{PairIterator{A}}) where {A} = _pairsize(Base.IteratorSize(A))
Base.IteratorEltype(::Type{PairIterator{A}}) where {A} = Base.IteratorEltype(A)
Base.eltype(::Type{PairIterator{A}}) where {A} = Tuple{eltype(A), eltype(A)}
Base.length(itr::PairIterator) = max(length(itr.ax) - 1, 0)

@inline function Base.iterate(it::PairIterator)
    # Read the first two elements; fewer than two means there are no pairs.
    x1 = iterate(it.ax)
    x1 === nothing && return nothing
    v1, s1 = x1
    x2 = iterate(it.ax, s1)
    x2 === nothing && return nothing
    v2, s2 = x2
    return (v1, v2), (s2, v2)   # state: underlying state + last element
end

@inline function Base.iterate(it::PairIterator, (s, v))
    # Pair the previous element `v` with the next one.
    x1 = iterate(it.ax, s)
    x1 === nothing && return nothing
    v2, s2 = x1
    return (v, v2), (s2, v2)
end
```

pepijndevos, Jun 05 '23

It seems you could just wrap `SlidingWindows` in `takenth` to cover the full functionality of `partition`, with performance several orders of magnitude better.
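A minimal sketch of the idea, with a plain generator standing in for a type-stable `SlidingWindows` (`stepped` and `windows3` are hypothetical helpers). One caveat, going by IterTools' documented example `takenth(5:15, 3)` giving `7, 10, 13`: `takenth` starts at the n-th element, so `takenth(windows, step)` would begin at the window that starts at element `step`, whereas `partition(x, n, step)` begins at element 1. The helper below therefore keeps windows 1, 1 + step, 1 + 2step, and so on.

```julia
# Hypothetical helper (not IterTools API): keep elements 1, 1+step, 1+2step, ...
# IterTools.takenth(xs, n) instead starts at the n-th element.
stepped(itr, step::Integer) =
    (w for (i, w) in enumerate(itr) if (i - 1) % step == 0)

# Stand-in for a type-stable sliding window of width 3 (step 1) over a vector.
windows3(x::AbstractVector) =
    ((x[i], x[i+1], x[i+2]) for i in firstindex(x):lastindex(x)-2)

collect(stepped(windows3(1:9), 2))
# [(1, 2, 3), (3, 4, 5), (5, 6, 7), (7, 8, 9)], same windows as partition(1:9, 3, 2)
```

Note that stepping over step-1 windows still constructs every intermediate window, so for large steps a dedicated strided iterator may be cheaper.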

pepijndevos, Jun 05 '23