
ChainRulesCore.ProjectTo creates sparse matrices of the wrong element type (drops Duals)

Open • ChrisRackauckas opened this issue on Dec 30 '23 • 3 comments

MWE:

using Zygote, SparseArrays, ForwardDiff

x, v = rand(Float32, 5), rand(Float32, 5)
A = sprand(Float32, 5, 5, 0.5)
loss(_x) = sum(tanh.(A * _x))

# Seed x with Duals carrying the direction v, so that differentiating the Zygote
# gradient below yields a Hessian-vector product (forward-over-reverse).
T = typeof(ForwardDiff.Tag(nothing, eltype(x)))
y = ForwardDiff.Dual{T, eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
g = x -> first(Zygote.gradient(loss, x))
ForwardDiff.partials.(g(y), 1)

ChrisRackauckas • Dec 30 '23

Note that this MWE doesn't run:

julia> y = _default_autoback_hesvec_cache(x, v)
ERROR: UndefVarError: `_default_autoback_hesvec_cache` not defined in `Main`
Suggestion: check for spelling errors or missing imports.
Stacktrace:
 [1] top-level scope
   @ REPL[199]:1
 [2] top-level scope
   @ ~/.julia/packages/Metal/qeZqc/src/initialization.jl:51

julia> ForwardDiff.partials.(g(y), 1)
ERROR: MethodError: no method matching Float32(::Dual{Nothing, Float32, 1})

mcabbott • Dec 31 '23

Fixed, just delete that extra line.

ChrisRackauckas • Dec 31 '23

A less complicated way to trigger this seems to be:

julia> x |> summary  # from above
"5-element Vector{Float32}"

julia> Zygote.gradient(loss, x)
(Float32[0.711501, 0.9295027, 0.035282552, 0.9122769, 0.3412085],)

julia> Zygote.gradient(loss, x .+ Dual(0,1))
ERROR: MethodError: no method matching Float32(::Dual{Nothing, Float32, 1})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:266
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:894
  Float32(::IrrationalConstants.Invsqrtπ)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
  ...

Stacktrace:
  [1] convert(::Type{Float32}, x::Dual{Nothing, Float32, 1})
    @ Base ./number.jl:7
  [2] setindex!(A::Vector{Float32}, x::Dual{Nothing, Float32, 1}, i::Int64)
    @ Base ./array.jl:969
  [3] (::ChainRulesCore.ProjectTo{SparseMatrixCSC, @NamedTuple{…}})(dx::Matrix{Dual{…}})
    @ ChainRulesCoreSparseArraysExt ~/.julia/packages/ChainRulesCore/7MWx2/ext/ChainRulesCoreSparseArraysExt.jl:79
  [4] #1476
    @ ~/.julia/packages/ChainRules/snrkz/src/rulesets/Base/arraymath.jl:36 [inlined]
  [5] unthunk
    @ ~/.julia/packages/ChainRulesCore/7MWx2/src/tangent_types/thunks.jl:204 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:110 [inlined]

MWE which doesn't use Zygote:

julia> using ChainRulesCore, SparseArrays, ForwardDiff

julia> A = sprand(Float32, 5, 5, 0.5);

julia> ProjectTo(A)(ones(5, 5))
5×5 SparseMatrixCSC{Float32, Int64} with 14 stored entries:
  ⋅   1.0  1.0   ⋅    ⋅ 
  ⋅   1.0   ⋅    ⋅    ⋅ 
 1.0   ⋅   1.0   ⋅   1.0
 1.0   ⋅   1.0  1.0  1.0
 1.0   ⋅   1.0  1.0  1.0

julia> ProjectTo(A)(ones(5, 5) .+ ForwardDiff.Dual(0,1))
ERROR: MethodError: no method matching Float32(::Dual{Nothing, Float64, 1})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:266
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:894
  Float32(::IrrationalConstants.Invsqrtπ)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
  ...

Stacktrace:
 [1] convert(::Type{Float32}, x::Dual{Nothing, Float64, 1})
   @ Base ./number.jl:7
 [2] setindex!(A::Vector{Float32}, x::Dual{Nothing, Float64, 1}, i::Int64)
   @ Base ./array.jl:969
 [3] (::ProjectTo{SparseMatrixCSC, @NamedTuple{…}})(dx::Matrix{Dual{…}})
   @ ChainRulesCoreSparseArraysExt ~/.julia/packages/ChainRulesCore/7MWx2/ext/ChainRulesCoreSparseArraysExt.jl:79
 [4] top-level scope

Dense matrices do allow element types like Dual here, e.g. ProjectTo(Matrix(A))(ones(5, 5) .+ ForwardDiff.Dual(0,1)) works. This is needed for forward-over-reverse computations, like Zygote.hessian.
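For comparison, a small self-contained check (a sketch reusing the same A as above) of what the dense projector accepts but the sparse one rejects:

using ChainRulesCore, SparseArrays, ForwardDiff

A = sprand(Float32, 5, 5, 0.5)
dx = ones(5, 5) .+ ForwardDiff.Dual(0, 1)   # Dual-valued cotangent

ProjectTo(Matrix(A))(dx)   # dense path: entries are still Duals
ProjectTo(A)(dx)           # sparse path: throws MethodError, no method matching Float32(::Dual{...})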

The bug is in this line:

https://github.com/JuliaDiff/ChainRulesCore.jl/blob/2c2d2bd6baf42e4ff754e8caa6a2ea41531daf16/ext/ChainRulesCoreSparseArraysExt.jl#L73
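For what it's worth, a minimal sketch (not the actual patch, and using a hypothetical helper name) of the behaviour the sparse projector presumably wants here: keep A's sparsity pattern, but let the stored values keep the tangent's element type instead of forcing convert(Float32, ...):

using SparseArrays, ForwardDiff

# Hypothetical illustration only, not ChainRulesCore's implementation:
# project a dense tangent onto A's sparsity pattern without narrowing the eltype.
function project_keep_eltype(A::SparseMatrixCSC, dx::AbstractMatrix)
    I, J, _ = findnz(A)                          # nonzero pattern of the primal
    vals = [dx[i, j] for (i, j) in zip(I, J)]    # values keep eltype(dx), e.g. Dual
    return sparse(I, J, vals, size(A)...)
end

project_keep_eltype(sprand(Float32, 5, 5, 0.5), ones(5, 5) .+ ForwardDiff.Dual(0, 1))
# 5×5 SparseMatrixCSC with Dual-valued stored entries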

Note that the whole sparse projection story is basically placeholder code, rushed in so that 1.0 would have the desired behaviour of preserving sparsity. It is quite slow, and could really use some care from someone who actually uses it. (Maybe shipping 1.0 with deliberate errors would have been better.)

mcabbott • Dec 31 '23