ChainRulesTestUtils.jl
test_rrule bug?
Reduced from https://github.com/JuliaDiff/ChainRules.jl/blob/24318b0321ccd48f16cbbd59dba6ae8bb9e90860/test/rulesets/LinearAlgebra/structured.jl#L120:
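```julia
using ChainRules
using ChainRulesCore
using ChainRulesTestUtils
using LinearAlgebra

f = adjoint
T = Float64
n = 5
m = 3
A = randn(T, n, m)   # n×m primal input
Y = f(A)             # m×n lazy Adjoint wrapper around A

# Two equivalent cotangents for Y:
Ȳ_mat = randn(T, m, n)                                        # a plain m×n matrix
Ȳ_composite = Composite{typeof(Y)}(parent=collect(f(Ȳ_mat)))  # structural tangent for the Adjoint's n×m `parent` field

test_rrule(f, A; output_tangent=Ȳ_mat)        # works
test_rrule(f, A; output_tangent=Ȳ_composite)  # breaks
test_rrule(f, A)                              # breaks
```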
Could you provide a stack trace please? :)
Sure, sorry:
```
Test Failed at /Users/mzgubic/JuliaEnvs/ChainRules.jl/dev/ChainRulesTestUtils/src/check_result.jl:19
Expression: isapprox(actual, expected; kwargs...)
Evaluated: isapprox([-0.9813351316701018 -1.4155573915526176 -0.08802023134709079; 0.5821187406162006 0.25518397597220116 0.6440863483747852; … ; -0.8464238028962722 0.8760142296068198 -1.4627259246978286; 1.3112413776523713 -1.8132103745503358 -1.8489217478396944], [-0.9813351316700425 0.5821187406162589 0.8906072162724507; -0.846423802896213 1.311241377652429 -1.4155573915525579; … ; -1.8132103745503456 -0.08802023134703205 0.6440863483748435; -0.004548957389037324 -1.4627259246978253 -1.8489217478396347]; rtol = 1.0e-9, atol = 1.0e-9)
```
I looked into it yesterday, and it seems that there is an extra adjoint somewhere that messes up the test. The rrule itself looks fine: it gives the same result for both output tangents (Ȳ_mat and Ȳ_composite).
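For the entries visible in the log, "expected" holds the same numbers as "actual" (up to round-off), only reordered. The first row of expected starts with -0.981, 0.582, which are actual[1, 1] and actual[2, 1]; its second row is -0.846, 1.311, -1.416, which are actual[4, 1], actual[5, 1] and then actual[1, 2]. That is the reordering you get by reading the column-major data of an n×m matrix as m×n and then taking the adjoint, which fits the "extra adjoint somewhere" suspicion. The sketch below reproduces that pattern with Base Julia on a labelled 5×3 matrix. `scramble` is a hypothetical helper for illustration only, not part of ChainRulesTestUtils, and the sketch does not show where in the test harness such a mix-up would happen.

```julia
using LinearAlgebra

# Hypothetical helper, for illustration only: take an n×m matrix, reinterpret
# its column-major data as m×n, then apply adjoint to get back an n×m matrix.
scramble(X::AbstractMatrix) = collect(adjoint(reshape(vec(X), size(X, 2), size(X, 1))))

# Label every entry of a 5×3 matrix by its linear (column-major) index.
X = reshape(collect(1:15), 5, 3)
S = scramble(X)

# Same entries and same 5×3 shape, but a different layout.
@assert size(S) == size(X)
@assert sort(vec(S)) == sort(vec(X))
@assert S != X

# Rows of S are consecutive entries running down the columns of X,
# like the rows of "expected" running down the columns of "actual" in the log.
@assert S[1, :] == [1, 2, 3]   # X[1,1], X[2,1], X[3,1]
@assert S[2, :] == [4, 5, 6]   # X[4,1], X[5,1], X[1,2]
```

Comparing `scramble(actual)` with `expected` (or `actual` with `scramble(expected)` inverted) on a reproduced failure would be one quick way to confirm or rule out this kind of transpose mix-up.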
Oh weird, yeah, that sounds like a bug.