
use ProjectTo in Array addition

Open oxinabox opened this issue 3 years ago • 4 comments

The reshape was a primitive version of ProjectTo, I think?

oxinabox, Jan 21 '22

I see there's already test_rrule(+, randn(3), randn(3,1), randn(3,1,1)) covering this reshape, but maybe also test Diagonal + Matrix or something?

mcabbott, Jan 21 '22
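To make the "reshape as a primitive ProjectTo" idea concrete, here is a minimal sketch of a pullback for `+` that only reshapes the incoming cotangent to each argument's size, using the same shapes as the existing test. `add_pullback_reshape` is a hypothetical name for illustration, not the actual rule in ChainRules.jl.

```julia
# Hypothetical sketch (not the actual ChainRules.jl rule): the pullback
# for `+` passes the same cotangent `dy` to every argument, reshaped to
# that argument's size. This fixes the shape only, never the type.
add_pullback_reshape(dy::AbstractArray, xs::AbstractArray...) =
    map(x -> reshape(dy, size(x)), xs)

xs = (randn(3), randn(3, 1), randn(3, 1, 1))
y = +(xs...)                      # trailing singleton dims are allowed; size(y) == (3, 1, 1)
dxs = add_pullback_reshape(ones(size(y)), xs...)
size.(dxs)                        # ((3,), (3, 1), (3, 1, 1))
```

With a `Diagonal` argument, this sketch hands back a dense `Matrix` cotangent, because `reshape` knows nothing about structure. Restoring structure as well as shape is what `ProjectTo` adds.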

I added

        test_rrule(+, randn(3,3), Diagonal(randn(3)), randn(3,3,1))
        test_rrule(+, randn(3,3), Diagonal(randn(3)), Symmetric(randn(3,3)))

and, interestingly, both fail, so there's some debugging to do there.

oxinabox, Jan 24 '22

The bug is that this doesn't work:

julia> ProjectTo(Diagonal([1,2,3]))(randn(3,3,1))
ERROR: MethodError: no method matching (::ProjectTo{Diagonal, ... }}}}}}}})(::Array{Float64, 3})

julia> Diagonal(randn(3)) + randn(3,3,1)  # but this does
3×3×1 Array{Float64, 3}:

Fixed in https://github.com/JuliaDiff/ChainRulesCore.jl/pull/446/commits/a3278387300de171b2ec2deb89f967daf5b3ee58

mcabbott, Feb 12 '22
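The failing case combines two steps: dropping the trailing singleton dimension of a 3×3×1 cotangent, then keeping only the diagonal for a `Diagonal` primal. The sketch below does both with plain LinearAlgebra. `project_diagonal_sketch` is a hypothetical helper loosely modelled on the size check `ProjectTo` performs for arrays. It is not ChainRulesCore's code, and it is not the fix in the linked commit.

```julia
using LinearAlgebra

# Hypothetical helper, not ChainRulesCore's ProjectTo: project a dense
# cotangent `dx` onto the tangent space of a Diagonal primal `x`.
function project_diagonal_sketch(x::Diagonal, dx::AbstractArray)
    # Sizes must agree, except that extra trailing dimensions of `dx`
    # may have length 1 (size(A, d) returns 1 for d > ndims(A)).
    for d in 1:max(ndims(dx), ndims(x))
        size(dx, d) == size(x, d) ||
            throw(DimensionMismatch("cotangent size $(size(dx)) does not match primal size $(size(x))"))
    end
    # Drop the trailing singleton dimensions, e.g. 3×3×1 -> 3×3.
    dx2 = reshape(dx, size(x))
    # A Diagonal can only vary along its diagonal, so discard the rest.
    return Diagonal(diag(dx2))
end

x = Diagonal([1.0, 2.0, 3.0])
dy = reshape(collect(1.0:9.0), 3, 3, 1)  # shaped like the result of Diagonal + randn(3,3,1)
dx = project_diagonal_sketch(x, dy)      # Diagonal([1.0, 5.0, 9.0])
```

Checking sizes explicitly instead of relying on `reshape` alone matters here: `reshape` only compares total lengths, so it would silently accept a 9-element vector as a 3×3 cotangent.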

The fix is part of https://github.com/JuliaDiff/ChainRulesCore.jl/pull/446, BTW, which I think should be merged.

mcabbott, Feb 20 '22

Closing in favor of #783 because I can't be bothered rebasing.

oxinabox, Feb 12 '24