Improve type stability of cached walks
This PR adds a special cache type that lets the compiler use the signature of the un-cached walk to generate a corresponding type assertion on lookups into the untyped cache (IdDict{Any, Any}). This improves the type stability of fmap and friends. It also loosens the constraint on the cache type, so functionality outside fmap remains the same.
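To illustrate the instability being fixed, here is a minimal hypothetical sketch (not the PR's actual code): any lookup into an IdDict{Any, Any} infers as Any, and a type assertion on the lookup recovers a concrete type.

# Hypothetical sketch, not the PR's actual code.
cache = IdDict{Any, Any}()
x = rand(Float32, 3)
cache[x] = 2 .* x                            # cached result of the "walk"

lookup_unstable(c, k) = c[k]                 # infers as Any
lookup_stable(c, k) = c[k]::Vector{Float32}  # concrete, but hard-coded

The PR computes that assertion automatically from the walk's signature instead of hard-coding it, as discussed below.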
This adds some complexity and some fragility to the code, since it seems like it could break with newer Julia versions. Can you post some benchmarks showing the performance improvements?
Not a benchmark, but without this PR:
julia> @code_warntype gpu(Chain(Dense(3, 5), Dense(5, 2)))
MethodInstance for Flux.gpu(::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
from gpu(x) @ Flux ~/.julia/packages/Flux/Wz6D4/src/functor.jl:248
Arguments
#self#::Core.Const(Flux.gpu)
x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Body::Chain{T} where T<:Tuple{Any, Any}
1 ─ %1 = Flux.FluxCUDAAdaptor()::Core.Const(Flux.FluxCUDAAdaptor(nothing))
│ %2 = Flux.gpu(%1, x)::Chain{T} where T<:Tuple{Any, Any}
└── return %2
vs. with this PR:
julia> @code_warntype gpu(Chain(Dense(3, 5), Dense(5, 2)))
MethodInstance for Flux.gpu(::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
from gpu(x) @ Flux ~/.julia/packages/Flux/Wz6D4/src/functor.jl:248
Arguments
#self#::Core.Const(Flux.gpu)
x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Body::Union{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}
1 ─ %1 = Flux.FluxCUDAAdaptor()::Core.Const(Flux.FluxCUDAAdaptor(nothing))
│ %2 = Flux.gpu(%1, x)::Union{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}
└── return %2
@darsnack @ToucheSir what do you think? I'm unfamiliar with expression manipulations.
I am also concerned about fragility. The implementation itself is sensible, but as written it seems like it will need frequent updates for internal compiler changes. The core idea is to use the return type of the walk to force the type when accessing the cache, right? That seems like a very straightforward generated function to write, with the call to return_type being the only brittle bit. Or is the rest of the current implementation necessary for performance reasons? Accessing the IdDict is the main reason Functors is type-unstable, so fixing that would be nice.
Pulling back, is there a use-case where we lack a function barrier between the call to gpu and the hot code path?
The core idea is to use the return type of the walk to force the type when accessing the cache, right? That seems like a very straightforward generated function to write, with the call to return_type being the only brittle bit. Or is the rest of the current implementation necessary for performance reasons?
Yes, essentially the whole generated function exists just to emit return cache.cache[x]::(return_type(cache.walk, typeof(args))). It also seems doable without a generated function, but with a generated function we can pin the precise world age (though I'm not familiar enough with the world-age mechanism to know whether the precise world age is required in this use case).
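A minimal sketch of that idea, with hypothetical names (TypedCache, cache_lookup) and assuming the one-argument Core.Compiler.return_type form that accepts a full signature tuple (available on recent Julia versions):

# Hypothetical sketch, not the PR's actual implementation.
struct TypedCache{W, D <: AbstractDict}
    walk::W                 # the un-cached walk
    cache::D                # typically an IdDict{Any, Any}
end

# Generated variant: the body sees only types, so the assertion is
# spliced into the lookup as a compile-time constant. Calling
# return_type inside @generated is the brittle bit discussed above.
@generated function cache_lookup(c::TypedCache{W}, x) where {W}
    RT = Core.Compiler.return_type(Tuple{W, x})
    return :(c.cache[x]::$RT)
end

# Non-generated variant, closer to the expression quoted above; it
# relies on the compiler constant-folding the return_type call.
function cache_lookup_plain(c::TypedCache, x)
    RT = Core.Compiler.return_type(c.walk, Tuple{typeof(x)})
    return c.cache[x]::RT
end

The generated version is where the world-age question comes in: a generated function's body is evaluated in a fixed world, so the inference result it splices in stays consistent with that world.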
Pulling back, is there a use-case where we lack a function barrier between the call to gpu and the hot code path?
If you need to handle data movement during the forward/backward pass.
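For instance, something like this hypothetical training step (not from the PR) keeps gpu on the hot path with no barrier in front of it:

# Hypothetical sketch, not from the PR: per-batch data movement means
# gpu runs inside the training loop, so its type instability would
# propagate into every step.
function train_step!(model, opt_state, batch)
    x, y = gpu(batch)   # device transfer in the hot loop
    grads = Flux.gradient(m -> Flux.mse(m(x), y), model)
    Flux.update!(opt_state, model, grads[1])
    return model
end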
Given the concerns expressed in https://github.com/LuxDL/Lux.jl/issues/1017, I think we should do this.
@CarloLucibello Since Julia v1.10 is the new LTS, do you think we could drop v1.6 support so that we can remove the @static if VERSION >= v"1.10.0-DEV.609" branch, which makes the code look fragile?
yes, we should do that.