ChainRulesTestUtils.jl

Support testing chunked forward mode

Open • oxinabox opened this issue on May 18, 2020 • 2 comments

Consider this example (from #36):

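using ChainRulesCore  # frule, Composite (Composite was later renamed Tangent, in ChainRulesCore v0.10)

simo(x) = (x, 2x)  # single input, multiple outputs
function ChainRulesCore.frule((_, ẋ), ::typeof(simo), x)  # dispatch on simo only, not on every function
    y = simo(x)
    # The tangent is linear in ẋ, so the same expression also works
    # if ẋ is a chunk of several tangents rather than a single one.
    return y, Composite{typeof(y)}(ẋ, 2ẋ)
end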
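Because the tangent part of this `frule` is linear in `ẋ`, pushing a whole chunk forward in one call should give the same result as pushing each direction forward separately and stacking the results. The sketch below checks that property in plain Julia. `simo_pushforward` is a hypothetical stand-in for the tangent part of the rule (a plain `Tuple` replaces `Composite` so no packages are needed), and the layout with one direction per column is an assumption.

```julia
# Hypothetical stand-in for the tangent part of simo's frule: a plain Tuple
# replaces Composite so that the sketch needs no packages.
simo_pushforward(xdot) = (xdot, 2xdot)

# A chunk of two tangent directions for a length-3 primal, one per column
# (assumed layout: primal index along dim 1, directions along dim 2).
xdot_chunk = [1.0 1.0; 0.0 1.0; 1.0 0.0]

chunked = simo_pushforward(xdot_chunk)   # one call for the whole chunk

# Push each direction forward separately, then reassemble the columns.
percol = [simo_pushforward(xdot_chunk[:, k]) for k in axes(xdot_chunk, 2)]
stacked = (reduce(hcat, first.(percol)), reduce(hcat, last.(percol)))

println(chunked == stacked)   # true, because the pushforward is linear in xdot
```

A rule whose tangent code is nonlinear in `ẋ`, or that assumes `ẋ` has the same shape as `x`, would break this equivalence, which is what a chunked-mode test would need to catch.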

I believe that the following should work:

frule_test(simo, (randn(), randn(4)))  # chunked mode, scalar/vector
frule_test(simo, (randn(3), randn(2, 3)))  # chunked mode, vector/matrix
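One way a test utility could support chunked mode is to push the whole chunk forward once and then compare each tangent direction against a central finite difference. The sketch below does this in plain Julia. `direction`, `ndirections`, `fd`, `check_chunked` and `simo_pushforward` are all hypothetical names, not part of the ChainRulesTestUtils API, and the layout (directions along the trailing dimension, as in the 3×2 matrix example in this issue) is an assumption. The `randn(2, 3)` chunk in the second call above would not fit this layout, since it puts the directions in the leading dimension.

```julia
# Hypothetical helpers, not part of ChainRulesTestUtils. Assumed layout: the
# k-th tangent direction is chunk[k] for a scalar primal and chunk[:, k] for a
# vector primal (directions along the trailing dimension).
direction(chunk, k, primal) = primal isa Number ? chunk[k] : chunk[:, k]
ndirections(chunk, primal) = primal isa Number ? length(chunk) : size(chunk, 2)

# Central finite difference of a tuple-valued f at x along direction d.
function fd(f, x, d; h=1e-6)
    return map((a, b) -> (a .- b) ./ (2h), f(x .+ h .* d), f(x .- h .* d))
end

# Push the whole chunk forward once, then compare every direction
# against finite differences.
function check_chunked(f, pushforward, x, xdot_chunk; atol=1e-6)
    y = f(x)
    ydot_chunk = pushforward(xdot_chunk)
    for k in 1:ndirections(xdot_chunk, x)
        expected = fd(f, x, direction(xdot_chunk, k, x))
        got = map((c, yi) -> direction(c, k, yi), ydot_chunk, y)
        all(map((a, b) -> isapprox(a, b; atol=atol), got, expected)) || return false
    end
    return true
end

simo(x) = (x, 2x)
simo_pushforward(xdot) = (xdot, 2xdot)

println(check_chunked(simo, simo_pushforward, randn(), randn(4)))      # scalar primal, vector chunk
println(check_chunked(simo, simo_pushforward, randn(3), randn(3, 2)))  # vector primal, matrix chunk
```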

as I believe the following is the correct chunked-mode behaviour, at least for the scalar primal and vector differential. @YingboMa, am I right?

julia> frule((Zero(), [1, 2, 3]), simo, π)
((π, 6.283185307179586), ([1, 2, 3], [2, 4, 6]))

julia> frule((Zero(), [1 1; 0 1; 1 0]), simo, [1, 1, 1])
(([1, 1, 1], [2, 2, 2]), Composite{Tuple{Array{Int64,1},Array{Int64,1}}}([1 1; 0 1; 1 0], [2 2; 0 2; 2 0]))
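A useful way to read the matrix example: if the chunk is the identity matrix, each column is one standard basis direction, so a single chunked call returns the full Jacobian of each output. The plain-Julia sketch below shows this, assuming one direction per column; `simo_pushforward` is again a hypothetical stand-in for the tangent part of the `frule`, without `Composite`.

```julia
using LinearAlgebra: I

# Hypothetical plain-Julia stand-in for the tangent part of simo's frule.
simo_pushforward(xdot) = (xdot, 2xdot)

# Same tangents as the first REPL example above (scalar primal, vector chunk).
println(simo_pushforward([1, 2, 3]) == ([1, 2, 3], [2, 4, 6]))   # true

n = 3
seed = Matrix{Float64}(I, n, n)   # n directions: the standard basis, one per column
J1, J2 = simo_pushforward(seed)   # one chunked call

# Column k of each result is the derivative along basis direction k,
# so J1 and J2 are the Jacobians of the two outputs of simo.
println(J1 == Matrix(1.0I, n, n))   # true: d(x)/dx = I
println(J2 == Matrix(2.0I, n, n))   # true: d(2x)/dx = 2I
```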

oxinabox, May 18 '20 17:05

Yes, if you define the partials for the i-th primal to be `partials[:, i]`.

YingboMa, May 27 '20 01:05

I am still thinking about which dimension the partials should belong to. Should it be the leading dimension or the trailing dimension? It does feel right that the partials should be in the leading dimension, for fast access in scalar code.

YingboMa, May 27 '20 02:05
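The leading-versus-trailing question comes down to memory layout. Julia arrays are column-major, so if the partials are in the leading dimension, the partials of one primal element (`partials[:, i]`, read literally as in the reply above) are contiguous in memory; in the trailing layout used by the 3×2 matrix example, they are strided. The sketch below is illustrative only: it shows both layouts holding the same numbers and inspects the strides of each slice. The variable names are hypothetical, and the thread has not settled on either layout.

```julia
# Two candidate layouts for a chunk of 2 tangent directions of a length-3 primal.
trailing = [1 1; 0 1; 1 0]        # 3 × 2: primal index in dim 1, partials in dim 2
leading  = permutedims(trailing)  # 2 × 3: partials in dim 1, primal index in dim 2

j = 2  # a primal element
# Partials of primal element j under each layout.
p_trailing = view(trailing, j, :)
p_leading  = view(leading, :, j)

println(p_trailing == p_leading)   # true: same numbers, different memory layout
println(strides(p_leading))        # (1,): the partials are adjacent in memory
println(strides(p_trailing))       # (3,): one primal length apart in memory
```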