
Support Lux.jl

Open aml5600 opened this issue 1 year ago • 4 comments

When tracing through a function that calls a Lux model, a variety of warnings are thrown about using fallback matmul functions, because the eltypes (Float64 mixed with the tracer types) don't reduce to common BLAS routines.

Overloading either the Lux evaluation functions or the matmul-style functions in a package extension would be greatly appreciated!

aml5600 avatar Nov 13 '24 16:11 aml5600

Thanks for the report, Andrew! Could you provide us with a stack trace, and ideally also a minimal example that reproduces the issue? E.g. a Lux model stripped down to the relevant layer(s), evaluated on random inputs?

matmul on Arrays of our tracers generally works, so there must be something else going on here as well.
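For reference, a plain matrix-vector product on tracer inputs traces without warnings. A minimal sketch (the function `f` and the matrix below are illustrative, not from the report):

```julia
using SparseConnectivityTracer

# A function whose core operation is a matrix-vector product.
# Base's generic matmul handles arbitrary Number eltypes,
# including SCT's tracer types.
A = [1.0 0.0; 0.0 2.0]
f(x) = A * x

detector = TracerSparsityDetector()
# jacobian_sparsity traces f with tracer-typed inputs and
# returns the sparsity pattern of the Jacobian.
pattern = jacobian_sparsity(f, rand(2), detector)
```

The warnings in this issue therefore come from LuxLib's own matmul dispatch layer rather than from Base's generic matmul.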

adrhill avatar Nov 13 '24 18:11 adrhill

Yes, I will try. I have a large backlog of MWEs I need to make, but I should be getting some time to address them soon!

aml5600 avatar Nov 14 '24 02:11 aml5600

Found this after seeing the same thing. I don't think it affects my actual performance: the sparsity pattern is the same either way, and if I understand it right, only the initial trace might be slower. Either way, this is the warning I get from the example below.

Warning:

┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{SparseConnectivityTracer.GradientTracer{SparseConnectivityTracer.IndexSetGradientPattern{Int64, BitSet}}}]: A [Matrix{Float32}] x B [Matrix{SparseConnectivityTracer.GradientTracer{SparseConnectivityTracer.IndexSetGradientPattern{Int64, BitSet}}}]). Falling back to generic implementation. This may be slow.
└ @ LuxLib.Impl C:\Users\johnb\.julia\packages\LuxLib\EnsP3\src\impl\matmul.jl:190

Example:

using Lux, SparseConnectivityTracer, Random

# create and instantiate MLP
rng = Random.default_rng();
mlpmodel = Chain(Dense(2, 2, relu));
ps, st = Lux.setup(rng, mlpmodel);

# set up and run sparsity detection
x0 = [1.0, 2.0];
detector = TracerSparsityDetector();
sparsity_pattern = jacobian_sparsity(x -> LuxCore.stateless_apply(mlpmodel, x, ps), x0, detector)

jbiffl avatar Apr 15 '25 17:04 jbiffl

Looks like we could overload LuxLib.jl's matmul* implementations: https://github.com/LuxDL/Lux.jl/blob/main/lib/LuxLib/src/impl/matmul.jl
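Such an overload could live in a package extension. The sketch below is speculative: the module name is hypothetical, and the signature of `LuxLib.Impl.matmul_cpu_fallback!` is an assumption inferred from the warning message above (outputs `C`, inputs `A` and `B`), not taken from LuxLib's source.

```julia
# Hypothetical package extension SparseConnectivityTracerLuxLibExt.
module SparseConnectivityTracerLuxLibExt

using LinearAlgebra: mul!
using LuxLib
using SparseConnectivityTracer: GradientTracer

# ASSUMPTION: signature inferred from the warning text (C, A, B).
# When the output eltype is a tracer, route the product to the
# generic LinearAlgebra.mul!, which supports arbitrary Number
# eltypes, instead of hitting LuxLib's mixed-precision warning path.
function LuxLib.Impl.matmul_cpu_fallback!(
    C::AbstractMatrix{<:GradientTracer},
    A::AbstractMatrix,
    B::AbstractMatrix,
)
    return mul!(C, A, B)
end

end # module
```

Whether dispatch should key on the eltype of `C`, `B`, or both would need to be checked against LuxLib's actual method table.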

I unfortunately won't have time to look at it before mid-May, but Lux support will be a high priority for SCT.

adrhill avatar Apr 15 '25 21:04 adrhill