ChainRules.jl
reuse factorization in rule for \
Closes JuliaDiff/ChainRules.jl#250. It's pretty cute. I am not 100% sure this is correct, but CI will tell me.
This also stops thunking ∂B, since we want to reuse it to compute ∂A.
If we need ∂B we have to compute it, and if we need ∂A we have to compute ∂B first,
so either way it gets computed.
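For readers who want the idea in code, here is a minimal sketch of reusing the factorization between the primal and the pullback. It assumes a square, dense, invertible A, and `ldiv_with_pullback` is a hypothetical name for illustration, not the rule as written in this PR (which goes through ChainRulesCore's `rrule`).

```julia
using LinearAlgebra

# Hypothetical helper, NOT the ChainRules.jl rule itself.
# Assumes A is square and invertible, so factorize(A) yields a factorization
# whose adjoint also supports `\`.
function ldiv_with_pullback(A::AbstractMatrix, B::AbstractVecOrMat)
    F = factorize(A)      # factor once
    Y = F \ B             # primal result, using the factorization
    function pullback(ΔY)
        ∂B = F' \ ΔY      # reuse the same factorization for the adjoint solve
        ∂A = -∂B * Y'     # ∂A is built from ∂B, so ∂B is computed eagerly
        return ∂A, ∂B
    end
    return Y, pullback
end

A = [4.0 1.0 2.0; 2.0 5.0 1.0; 1.0 2.0 6.0]
B = [1.0, 2.0, 3.0]
Y, pb = ldiv_with_pullback(A, B)
∂A, ∂B = pb([0.5, -1.0, 2.0])
```

Because ∂A is defined in terms of ∂B, wrapping ∂B in a thunk would not save any work here.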
julia> @btime $(x)\$y
1.032 μs (3 allocations: 1.19 KiB)
10-element Vector{Float64}:
0.49018802032025616
-0.33578706090563665
-0.4615411572036624
0.24664997800385907
-0.5310612794797791
0.3537748064584685
1.470168551590106
-0.4108242996218167
0.11217305102151084
-0.4276420520809933
julia> @btime $(factorize(x))\$y
252.892 ns (1 allocation: 160 bytes)
10-element Vector{Float64}:
0.49018802032025616
-0.33578706090563665
-0.4615411572036624
0.24664997800385907
-0.5310612794797791
0.3537748064584685
1.470168551590106
-0.4108242996218167
0.11217305102151084
-0.4276420520809933
This fails for some inputs because of https://github.com/JuliaLang/julia/issues/38293
@willtebbutt gave me a workaround:
julia> @btime $A'\$B;
31.750 ms (5 allocations: 15.27 MiB)
julia> @btime $Af'\$B; # best but not always permitted
19.356 ms (2 allocations: 7.63 MiB)
julia> @btime ($B'/ $Af)';
20.160 ms (2 allocations: 7.63 MiB)
Bother: no method matching rdiv!(::Array{Float64,2}, ::QRPivoted{Float64,Array{Float64,2}})
This should be unblocked by https://github.com/JuliaLang/julia/pull/40899 and it needs to be tested against that branch of julia
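The workaround rests on the identity (B' / F)' == F' \ B. Here is a small sketch checking it for an LU factorization, assuming a square invertible A. The matrices are made up for illustration, and whether `/` accepts a given factorization type depends on the Julia version, as the `QRPivoted` error shows.

```julia
using LinearAlgebra

A = [4.0 1.0 2.0; 2.0 5.0 1.0; 1.0 2.0 6.0]   # square, invertible, dense
B = [1.0 0.0; 2.0 1.0; 3.0 -1.0]

Af = lu(A)                 # the factorization we want to reuse

X_direct  = A' \ B         # refactorizes A' from scratch
X_adjfact = Af' \ B        # adjoint of the factorization; not always permitted
X_rdiv    = (B' / Af)'     # right-division form of the same solve
```

All three should agree up to floating-point error; only the amount of repeated work differs.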
Testing this against JuliaLang/julia#40899 (cc @andreasnoack for interest), I think it has resolved the problems. It still errors, but the errors are now:
- inference failing (probably need to look into how the rules are written there)
- MethodError: no method matching factorize(::Vector{Float64}) (probably should define a _factorize_for_ldiv that returns the input if it is a vector, or if VERSION < v"1.7")
- DimensionMismatch("overdetermined systems are not supported") (need to look into that one. It was introduced by JuliaLang/julia#40899, but it might be that we shouldn't be using factorize for this case? But / does support over-determined systems)
So I think probably JuliaLang/julia#40899 works, not 100% sure re the last point
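A rough sketch of what the `_factorize_for_ldiv` helper suggested above might look like. This is an assumption about its shape based only on the description in this comment, not code from the PR.

```julia
using LinearAlgebra

# Sketch only: pass vectors through untouched, since factorize has no method for them.
_factorize_for_ldiv(A::AbstractVector) = A

# On Julia versions before 1.7, fall back to the unfactorized matrix;
# otherwise factor once so the result can be reused.
_factorize_for_ldiv(A::AbstractMatrix) = VERSION < v"1.7" ? A : factorize(A)

v = [1.0, 2.0, 3.0]
M = [4.0 1.0 2.0; 2.0 5.0 1.0; 1.0 2.0 6.0]
b = [1.0, 0.0, -1.0]

x = _factorize_for_ldiv(M) \ b   # same answer as M \ b on either code path
```

Dispatching on the argument type keeps the version check in one place, so the rule body can call `\` without caring which branch was taken.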
The error suggests that you end up doing something like factorize(A)'\b for a wide A. For wide A, factorize gives a QR and least squares isn't defined for Adjoint{QR}. Is this a real problem or are you just trying to exercise all possible combinations?
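To make the wide case concrete, here is a small sketch with made-up values showing that `factorize` returns a pivoted QR for a non-square matrix, while plain `\` still returns a solution. Whether `\` works on the adjoint of that QR depends on the Julia version, so it is deliberately left out.

```julia
using LinearAlgebra

A = [1.0 2.0 3.0; 4.0 5.0 7.0]   # wide: 2 rows, 3 columns, full row rank
b = [1.0, 2.0]

F = factorize(A)   # non-square input, so this is a pivoted QR rather than an LU
x = A \ b          # underdetermined system: one solution with A * x ≈ b
```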
Worth asking @DhairyaLGandhi, who originally opened this issue in https://github.com/FluxML/Zygote.jl/issues/773 / https://github.com/JuliaDiff/ChainRules.jl/issues/250
I never want to actually AD \ at all.
but I think the use of factorizations here is a very neat application of https://juliadiff.org/ChainRulesCore.jl/dev/design/changing_the_primal.html
In general ChainRules exercises all paths because it does.