ChainRulesCore.ProjectTo creates sparse matrices of the wrong element type (drops Duals)
MWE:
using Zygote, SparseArrays, ForwardDiff
x, v = rand(Float32, 5), rand(Float32, 5)
A = sprand(Float32, 5, 5, 0.5)
loss(_x) = sum(tanh.(A * _x))
T = typeof(ForwardDiff.Tag(nothing, eltype(x)))
y = ForwardDiff.Dual{T, eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
g = x -> first(Zygote.gradient(loss, x))
ForwardDiff.partials.(g(y), 1)
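For reference, the MWE is a forward-over-reverse Hessian-vector product. `y` carries value `x` and perturbation `v`, so `ForwardDiff.partials.(g(y), 1)` should equal H v, the Hessian of `loss` at `x` applied to `v`. Because `loss(x) = sum(tanh.(A * x))` has a closed form, the expected answer can be computed directly, with no dense Hessian. This is a sketch under the assumption that only the math is of interest here. The helper names `grad_reference` and `hvp_reference` are hypothetical and are not part of Zygote or ForwardDiff.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical helper: closed-form gradient of loss(x) = sum(tanh.(A * x)),
#   ∇loss(x) = Aᵀ * sech².(A * x)
grad_reference(A, x) = A' * (sech.(A * x) .^ 2)

# Hypothetical helper: closed-form Hessian-vector product.
# Since d/dz sech²(z) = -2 tanh(z) sech²(z), with z = A * x:
#   H = Aᵀ * Diagonal(-2 .* tanh.(z) .* sech.(z) .^ 2) * A
# so H * v needs only two (sparse) matrix-vector products.
function hvp_reference(A, x, v)
    z = A * x
    d = -2 .* tanh.(z) .* sech.(z) .^ 2   # diagonal of the middle factor
    return A' * (d .* (A * v))
end
```

Once the sparse projection accepts Duals, `ForwardDiff.partials.(g(y), 1)` from the MWE should agree with `hvp_reference(A, x, v)` up to Float32 rounding.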
Note that this MWE doesn't run:
julia> y = _default_autoback_hesvec_cache(x, v)
ERROR: UndefVarError: `_default_autoback_hesvec_cache` not defined in `Main`
Suggestion: check for spelling errors or missing imports.
Stacktrace:
[1] top-level scope
@ REPL[199]:1
[2] top-level scope
@ ~/.julia/packages/Metal/qeZqc/src/initialization.jl:51
julia> ForwardDiff.partials.(g(y), 1)
ERROR: MethodError: no method matching Float32(::Dual{Nothing, Float32, 1})
Fixed: just delete that extra line.
A less complicated way to trigger this seems to be:
julia> x |> summary # from above
"5-element Vector{Float32}"
julia> Zygote.gradient(loss, x)
(Float32[0.711501, 0.9295027, 0.035282552, 0.9122769, 0.3412085],)
julia> Zygote.gradient(loss, x .+ Dual(0,1))
ERROR: MethodError: no method matching Float32(::Dual{Nothing, Float32, 1})
Closest candidates are:
(::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
@ Base rounding.jl:266
(::Type{T})(::T) where T<:Number
@ Core boot.jl:894
Float32(::IrrationalConstants.Invsqrtπ)
@ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
...
Stacktrace:
[1] convert(::Type{Float32}, x::Dual{Nothing, Float32, 1})
@ Base ./number.jl:7
[2] setindex!(A::Vector{Float32}, x::Dual{Nothing, Float32, 1}, i::Int64)
@ Base ./array.jl:969
[3] (::ChainRulesCore.ProjectTo{SparseMatrixCSC, @NamedTuple{…}})(dx::Matrix{Dual{…}})
@ ChainRulesCoreSparseArraysExt ~/.julia/packages/ChainRulesCore/7MWx2/ext/ChainRulesCoreSparseArraysExt.jl:79
[4] #1476
@ ~/.julia/packages/ChainRules/snrkz/src/rulesets/Base/arraymath.jl:36 [inlined]
[5] unthunk
@ ~/.julia/packages/ChainRulesCore/7MWx2/src/tangent_types/thunks.jl:204 [inlined]
[6] wrap_chainrules_output
@ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:110 [inlined]
MWE which doesn't use Zygote:
julia> using ChainRulesCore, SparseArrays, ForwardDiff
julia> A = sprand(Float32, 5, 5, 0.5);
julia> ProjectTo(A)(ones(5, 5))
5×5 SparseMatrixCSC{Float32, Int64} with 14 stored entries:
⋅ 1.0 1.0 ⋅ ⋅
⋅ 1.0 ⋅ ⋅ ⋅
1.0 ⋅ 1.0 ⋅ 1.0
1.0 ⋅ 1.0 1.0 1.0
1.0 ⋅ 1.0 1.0 1.0
julia> ProjectTo(A)(ones(5, 5) .+ ForwardDiff.Dual(0,1))
ERROR: MethodError: no method matching Float32(::Dual{Nothing, Float64, 1})
Closest candidates are:
(::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
@ Base rounding.jl:266
(::Type{T})(::T) where T<:Number
@ Core boot.jl:894
Float32(::IrrationalConstants.Invsqrtπ)
@ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
...
Stacktrace:
[1] convert(::Type{Float32}, x::Dual{Nothing, Float64, 1})
@ Base ./number.jl:7
[2] setindex!(A::Vector{Float32}, x::Dual{Nothing, Float64, 1}, i::Int64)
@ Base ./array.jl:969
[3] (::ProjectTo{SparseMatrixCSC, @NamedTuple{…}})(dx::Matrix{Dual{…}})
@ ChainRulesCoreSparseArraysExt ~/.julia/packages/ChainRulesCore/7MWx2/ext/ChainRulesCoreSparseArraysExt.jl:79
[4] top-level scope
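Dense matrices do allow element types like Dual here: e.g. ProjectTo(Matrix(A))(ones(5, 5) .+ ForwardDiff.Dual(0,1)) works. Sparse matrices need to allow them too, because forward-over-reverse differentiation, as in Zygote.hessian, sends Dual-valued cotangents through these projections.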
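The following explains why a Dual-valued dense matrix reaches ProjectTo(A) at all. It uses the standard pullback of y = A x, which is the rule behind the arraymath.jl frame in the Zygote stack trace above. In forward-over-reverse, the input is x = x_0 + εv with ε² = 0. The incoming cotangent of y depends on x, so it is a Dual as well:

```math
y = A x, \qquad \bar{A} = \bar{y}\, x^{\top}, \qquad \bar{x} = A^{\top} \bar{y}

x = x_0 + \varepsilon v, \quad \bar{y} = \bar{y}_0 + \varepsilon \dot{\bar{y}}
\;\Longrightarrow\;
\bar{A} = \bar{y}_0 x_0^{\top} + \varepsilon \left( \dot{\bar{y}}\, x_0^{\top} + \bar{y}_0 v^{\top} \right)
```

The outer product is dense and every entry carries an ε part. ProjectTo(A) therefore has to map a dense Matrix of Duals onto A's sparsity pattern without converting the entries to Float32.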
The bug is in this line:
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/2c2d2bd6baf42e4ff754e8caa6a2ea41531daf16/ext/ChainRulesCoreSparseArraysExt.jl#L73
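Below is a minimal sketch of the direction a fix could take. It assumes the linked line preallocates the stored values with the primal element type (Float32 here), which is consistent with the `setindex!(::Vector{Float32}, ::Dual, …)` frame in both stack traces. The names `project_to_pattern` and `_project_value` are hypothetical and are not ChainRulesCore API. The sketch keeps A's sparsity pattern and still converts plain floats and integers to `eltype(A)`. For any other type, it lets the stored element type follow the projected values.

```julia
using SparseArrays

# Hypothetical helpers (not ChainRulesCore API).
# Plain floats/integers are converted to the primal float type T, like ProjectTo does...
_project_value(::Type{T}, x::Union{AbstractFloat,Integer}) where {T<:AbstractFloat} = convert(T, x)
# ...while other Real types (e.g. ForwardDiff.Dual, Rational) pass through unchanged.
_project_value(::Type, x) = x

# Restrict a dense cotangent dx to the stored entries of A. The result has A's
# sparsity pattern, but its eltype is whatever the projected values turn out to be,
# instead of a buffer preallocated with eltype(A).
function project_to_pattern(A::SparseMatrixCSC, dx::AbstractMatrix)
    size(dx) == size(A) ||
        throw(DimensionMismatch("size(dx) = $(size(dx)) but size(A) = $(size(A))"))
    T = eltype(A)
    rows = rowvals(A)
    # CSC order: outer loop over columns, inner loop over that column's stored rows.
    nzval = [_project_value(T, dx[rows[k], j]) for j in 1:size(A, 2) for k in nzrange(A, j)]
    m, n = size(A)
    return SparseMatrixCSC(m, n, copy(A.colptr), copy(rows), nzval)
end
```

For example, `project_to_pattern(A, ones(5, 5) .+ ForwardDiff.Dual(0, 1))` should return a SparseMatrixCSC with `Dual{Nothing, Float64, 1}` entries on A's pattern, mirroring what the dense ProjectTo already allows. The sketch leaves open whether the Dual's inner Float64 should also be narrowed to Float32.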
Note that the whole sparse projection story is basically placeholder code, rushed in so that 1.0 would have the desired behaviour of preserving sparsity. It is quite slow, and could really use some care from someone who actually uses it. (Maybe shipping 1.0 with deliberate errors would have been better.)