ChainRules.jl
use ProjectTo in Array addition
The reshape was a primitive version of ProjectTo, I think?
I see it has test_rrule(+, randn(3), randn(3,1), randn(3,1,1)) for this reshape, but maybe test Diagonal + Matrix or something?
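Here is a rough illustration of the reshape idea, assuming the existing rule just hands each argument the output cotangent reshaped to that argument's own size. `add_pullback_sketch` is a hypothetical helper written with Base only. It is not the ChainRules implementation. It shows why a plain reshape covers the `randn(3)`, `randn(3,1)`, `randn(3,1,1)` case but has no notion of structured types like `Diagonal`.

```julia
# Hypothetical helper (not ChainRules API): pullback of n-ary `+` for arrays
# whose shapes differ only by trailing singleton dimensions. Every argument
# receives the same cotangent values, reshaped back to its own size.
function add_pullback_sketch(dy::AbstractArray, xs::AbstractArray...)
    return map(x -> reshape(dy, size(x)), xs)
end

a, b, c = randn(3), randn(3, 1), randn(3, 1, 1)
y = a + b + c            # Base `+` accepts trailing singleton dims; y is 3×1×1
dy = ones(size(y))       # a cotangent shaped like the output
da, db, dc = add_pullback_sketch(dy, a, b, c)
```

A `Diagonal` argument would receive a dense array here rather than something `Diagonal`-shaped. That is the gap a proper projection is meant to close.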
I added
test_rrule(+, randn(3,3), Diagonal(randn(3)), randn(3,3,1))
test_rrule(+, randn(3,3), Diagonal(randn(3)), Symmetric(randn(3,3)))
and, interestingly, both fail. So there is some debugging to do there.
The bug is that this doesn't work:
julia> ProjectTo(Diagonal([1,2,3]))(randn(3,3,1))
ERROR: MethodError: no method matching (::ProjectTo{Diagonal, ... }}}}}}}})(::Array{Float64, 3})
julia> Diagonal(randn(3)) + randn(3,3,1) # but this does
3×3×1 Array{Float64, 3}:
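To make the failing case concrete, here is a minimal sketch of what projecting a 3×3×1 cotangent onto a `Diagonal` primal has to do. First it drops the trailing singleton dimension. Then it keeps only the diagonal. `project_diagonal_sketch` is hypothetical and uses only Base and LinearAlgebra. It is not ChainRulesCore's `ProjectTo`, and not necessarily how the eventual fix handles it.

```julia
using LinearAlgebra

# Hypothetical helper, NOT ChainRulesCore's ProjectTo: project a cotangent
# onto the structure of an n×n Diagonal primal, tolerating trailing
# singleton dimensions such as the 3×3×1 result of `Diagonal + 3×3×1 Array`.
function project_diagonal_sketch(x::Diagonal, dx::AbstractArray)
    n = size(x, 1)
    sz = size(dx)
    # Leading dims must be n×n; every dimension after the second must be 1.
    (length(sz) >= 2 && sz[1] == n && sz[2] == n && all(==(1), sz[3:end])) ||
        throw(DimensionMismatch("cannot project size $sz onto a $n×$n Diagonal"))
    dxm = reshape(dx, n, n)     # drop the trailing singleton dims
    return Diagonal(diag(dxm))  # keep only the entries a Diagonal can hold
end

x = Diagonal(randn(3))
dx = randn(3, 3, 1)
dxp = project_diagonal_sketch(x, dx)
```

The explicit size check is a design choice. It keeps a mismatched cotangent (say 9×1) from being silently reshaped into a 3×3 matrix.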
Fixed in https://github.com/JuliaDiff/ChainRulesCore.jl/pull/446/commits/a3278387300de171b2ec2deb89f967daf5b3ee58
The fix is part of https://github.com/JuliaDiff/ChainRulesCore.jl/pull/446, BTW, which I think should be merged.
Closing in favor of #783 because I can't be bothered rebasing.