ChainRulesTestUtils.jl
Support testing chunked forward mode
Consider this example (from #36):
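using ChainRulesCore

simo(x) = (x, 2x)
# Dispatch on ::typeof(simo); an untyped `simo` argument would define this rule for every function
function ChainRulesCore.frule((_, ẋ), ::typeof(simo), x)
    y = simo(x)
    return y, Composite{typeof(y)}(ẋ, 2ẋ)
end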
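To make the single-direction rule concrete, here is a minimal plain-Julia sketch that does not depend on ChainRulesCore. `simo_pushforward` and `fd_jvp` are hypothetical helpers: the first mirrors the frule above but returns a plain tuple instead of a `Composite`, and the second is a central finite difference. Comparing the two is roughly the kind of check frule_test automates, though this is not its actual implementation.

```julia
# Primal function from the issue: one input, two outputs.
simo(x) = (x, 2x)

# Hypothetical plain-tuple stand-in for the frule above:
# returns the primal output and one tangent per output.
simo_pushforward(x, ẋ) = (simo(x), (ẋ, 2ẋ))

# Central finite-difference estimate of each output's directional
# derivative of f at x along ẋ (works for scalar or array x).
function fd_jvp(f, x, ẋ; h = 1e-6)
    yp = f(x + h * ẋ)
    ym = f(x - h * ẋ)
    return map((a, b) -> (a - b) / (2h), yp, ym)
end

x, ẋ = 0.5, 2.0
y, ẏ = simo_pushforward(x, ẋ)
ẏ_fd = fd_jvp(simo, x, ẋ)
println(all(isapprox.(ẏ, ẏ_fd; atol = 1e-6)))  # true
```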
I believe that the following should work:
using ChainRulesTestUtils
frule_test(simo, (randn(), randn(4))) # chunked mode, scalar/vector
frule_test(simo, (randn(3), randn(2, 3))) # chunked mode, vector/matrix
as I believe the following is the correct chunked-mode behaviour, at least for a scalar primal with a vector differential. @YingboMa, am I right?
julia> frule((Zero(), [1, 2, 3]), simo, π)
((π, 6.283185307179586), ([1, 2, 3], [2, 4, 6]))
julia> frule((Zero(), [1 1; 0 1; 1 0]), simo, [1, 1, 1])
(([1, 1, 1], [2, 2, 2]), Composite{Tuple{Array{Int64,1},Array{Int64,1}}}([1 1; 0 1; 1 0], [2 2; 0 2; 2 0]))
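One way to read the scalar-primal case: pushing a vector of tangents through the rule at once should agree with pushing each direction through separately and regrouping the results by output. A minimal plain-Julia sketch of that check, where `simo_pushforward` is a hypothetical plain-tuple stand-in for the frule (not a ChainRulesCore API):

```julia
# Primal function from the issue.
simo(x) = (x, 2x)

# Hypothetical plain-tuple stand-in for the frule. Because the tangent
# map is linear, `2ẋ` also works when ẋ is a vector of directions.
simo_pushforward(x, ẋ) = (simo(x), (ẋ, 2ẋ))

x = π
ẋs = [1, 2, 3]                          # three directions for a scalar primal

# Chunked: push all directions through at once.
y, (ẏ1, ẏ2) = simo_pushforward(x, ẋs)

# Reference: one direction at a time, then regroup the tangents by output.
per_dir = [simo_pushforward(x, d)[2] for d in ẋs]
println(ẏ1 == first.(per_dir))          # true
println(ẏ2 == last.(per_dir))           # true
println(ẏ2)                             # [2, 4, 6]
```

This reproduces the `([1, 2, 3], [2, 4, 6])` tangents from the first REPL call; it only works this simply because every operation in the rule is linear in ẋ.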
Yes, if you define the partials for the i-th primal element as partials[:, i].
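I am still thinking about which dimension the partials should belong to. Should it be the leading dimension or the trailing dimension? It does feel right for the partials to be in the leading dimension, for fast access in scalar code: Julia arrays are column-major, so the partials of each primal element would then be contiguous in memory.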
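The two candidate layouts can be compared with a small sketch, assuming a primal of length 3 and a chunk of 2 directions. The leading-dimension layout matches the `randn(2, 3)` tangent in the second frule_test call, and the trailing-dimension layout matches the 3×2 matrix `[1 1; 0 1; 1 0]` in the second REPL call. The values are arbitrary placeholders.

```julia
n, k = 3, 2                       # primal length, chunk size (number of directions)
vals = collect(1.0:n*k)           # arbitrary partials values 1.0, ..., 6.0

# Leading-dimension layout (k × n, like randn(2, 3)): column i holds
# the k partials of primal element i.
P_lead = reshape(vals, k, n)

# Trailing-dimension layout (n × k, like [1 1; 0 1; 1 0]): row i holds them.
P_trail = permutedims(P_lead)

# Both layouts carry the same partials for every primal element.
@assert all(P_lead[:, i] == P_trail[i, :] for i in 1:n)

# Julia arrays are column-major: in the leading layout the partials of
# element 2 sit at adjacent memory offsets, in the trailing layout they
# are n elements apart.
println(collect(LinearIndices(P_lead)[:, 2]))   # [3, 4]
println(collect(LinearIndices(P_trail)[2, :]))  # [2, 5]
```

The content is identical either way; the choice only affects memory access patterns, which is why the leading layout favours code that walks one primal element at a time.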