Functors.jl

Improve type stability of cached walks

Open chengchingwen opened this issue 1 year ago • 5 comments

This PR adds a special cache type that lets the compiler use the signature of the un-cached walk to generate a corresponding type assertion on values read from the untyped cache (IdDict{Any, Any}). This should improve the type stability of fmap and friends. It also loosens the constraint on the cache type, so functionality outside fmap remains the same.
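To illustrate why the untyped cache hurts inference, here is a minimal sketch with no Functors.jl involved. The global CACHE and the functions lookup_untyped and lookup_asserted are hypothetical names. Reading from an IdDict{Any, Any} is inferred as Any; a type assertion on the read lets the compiler assume a concrete type for everything downstream. In the PR the asserted type is derived from the walk's signature rather than written by hand as it is here.

```julia
# Hypothetical global cache, untyped like the IdDict{Any, Any} used by fmap.
const CACHE = IdDict{Any,Any}()

# Plain lookup: inference only knows the stored value is `Any`.
lookup_untyped(x) = CACHE[x]

# Same lookup with a hand-written type assertion: the result is inferred
# as Vector{Float32}, and a mismatch would throw a TypeError at runtime.
lookup_asserted(x::Vector{Float32}) = CACHE[x]::Vector{Float32}

v = Float32[1, 2, 3]
CACHE[v] = v
Base.return_types(lookup_untyped, (Vector{Float32},))   # inferred: Any
Base.return_types(lookup_asserted, (Vector{Float32},))  # inferred: Vector{Float32}
```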

chengchingwen, May 09 '24 01:05

This adds some complexity to the code, and some fragility as well, since it looks like it could break with newer Julia versions. Can you post some benchmarks showing the performance improvement?

CarloLucibello, May 13 '24 05:05

Not a benchmark, but without this PR:

julia> @code_warntype gpu(Chain(Dense(3, 5), Dense(5, 2)))
MethodInstance for Flux.gpu(::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from gpu(x) @ Flux ~/.julia/packages/Flux/Wz6D4/src/functor.jl:248
Arguments
  #self#::Core.Const(Flux.gpu)
  x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Body::Chain{T} where T<:Tuple{Any, Any}
1 ─ %1 = Flux.FluxCUDAAdaptor()::Core.Const(Flux.FluxCUDAAdaptor(nothing))
│   %2 = Flux.gpu(%1, x)::Chain{T} where T<:Tuple{Any, Any}
└──      return %2

vs. with this PR:

julia> @code_warntype gpu(Chain(Dense(3, 5), Dense(5, 2)))
MethodInstance for Flux.gpu(::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from gpu(x) @ Flux ~/.julia/packages/Flux/Wz6D4/src/functor.jl:248
Arguments
  #self#::Core.Const(Flux.gpu)
  x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Body::Union{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}
1 ─ %1 = Flux.FluxCUDAAdaptor()::Core.Const(Flux.FluxCUDAAdaptor(nothing))
│   %2 = Flux.gpu(%1, x)::Union{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}
└──      return %2

chengchingwen, May 13 '24 05:05

@darsnack @ToucheSir what do you think? I'm unfamiliar with expression manipulations.

CarloLucibello, May 15 '24 06:05

I am also concerned about fragility. The implementation itself is sensible, but as written it seems likely to need frequent updates to keep up with changes in Julia's internals. The core idea is to use the return type of the walk to force the type when accessing the cache, right? That seems like a very straightforward generated function to write, with the call to return_type being the only brittle bit. Or is the rest of the current implementation necessary for performance reasons? Accessing the IdDict is the main reason Functors is type-unstable, so fixing it is nice.

Pulling back, is there a use case where we lack a function barrier between the call to gpu and the hot code path?

darsnack, May 15 '24 11:05

> The core idea is to use the return type of the walk to force the type when accessing the cache, right? That seems like a very straightforward generated function to write, with the call to return_type being the only brittle bit. Or is the rest of the current implementation necessary for performance reasons?

Yes, essentially the whole generated function exists just to generate return cache.cache[x]::(return_type(cache.walk, typeof(args))). It also seems doable without a generated function, but with a generated function we can get the precise world age (though I'm not familiar enough with the world-age mechanism to know whether the precise world age is required in this use case).
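The non-generated variant mentioned above can be sketched roughly as follows. This is a hypothetical illustration, not the PR's implementation: TypedCache and cached_call are made-up names, and Core.Compiler.return_type is an internal compiler function (the brittle bit discussed in this thread). When the walk and argument types are concrete, inference can usually fold the return_type call to a constant, so the assertion on the cached value yields a concrete result type.

```julia
# Hypothetical sketch: `TypedCache` and `cached_call` are illustrative names,
# not part of Functors.jl. `Core.Compiler.return_type` is internal API.
struct TypedCache{W}
    walk::W                  # the un-cached function
    cache::IdDict{Any,Any}   # untyped cache, as in fmap
end
TypedCache(walk) = TypedCache(walk, IdDict{Any,Any}())

function cached_call(c::TypedCache, x)
    # Ask inference what `c.walk(x)` returns for this argument type.
    T = Core.Compiler.return_type(c.walk, Tuple{typeof(x)})
    if haskey(c.cache, x)
        # Assert that the cached value has the walk's inferred return type.
        return c.cache[x]::T
    end
    y = c.walk(x)
    c.cache[x] = y
    return y
end

double = TypedCache(x -> 2 .* x)
v = [1, 2, 3]
cached_call(double, v)   # computes and stores the result
cached_call(double, v)   # returns the same cached array
```

Per the comment above, the generated-function approach is what gives control over the precise world age; this sketch makes no attempt at that.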

> Pulling back, is there a use case where we lack a function barrier between the call to gpu and the hot code path?

If you need to handle data movement during the forward/backward pass.
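For readers unfamiliar with the term, a function barrier moves the hot loop into a separate function so that, after one dynamic dispatch on the concrete argument type, the loop runs in specialized code. The toy sketch below uses hypothetical names (STORE, kernel, no_barrier, with_barrier) and involves no Flux or GPU code; it only shows why a type-unstable value matters when no barrier sits between where it is produced and where it is used.

```julia
# Hypothetical untyped store standing in for any type-unstable source.
const STORE = IdDict{Any,Any}()
STORE[:w] = Float32[1, 2, 3, 4]

# Hot kernel: specialized on the concrete type of `w` when called.
function kernel(w)
    s = zero(eltype(w))
    for v in w
        s += v * v
    end
    return s
end

# No barrier: `w` is inferred as Any, so the loop body dispatches dynamically.
function no_barrier()
    w = STORE[:w]
    s = zero(eltype(w))
    for v in w
        s += v * v
    end
    return s
end

# With a barrier: only the single call to `kernel` is dynamically dispatched.
with_barrier() = kernel(STORE[:w])
```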

chengchingwen, May 15 '24 11:05

Given the concerns expressed in https://github.com/LuxDL/Lux.jl/issues/1017, I think we should do this.

CarloLucibello, Nov 04 '24 08:11

@CarloLucibello Since Julia v1.10 is the new LTS, do you think we could drop v1.6 support so that we can remove the @static if VERSION >= v"1.10.0-DEV.609" branch, which makes the code look fragile?

chengchingwen, Nov 04 '24 08:11

Yes, we should do that.

CarloLucibello, Nov 04 '24 09:11